├── .gitignore
├── CUB.py
├── cub_demo.py
├── fig
├── a3m.png
├── cub-dir.png
├── result.png
└── title.png
├── model
└── readme.md
├── readme.md
├── run.sh
└── tools
├── attributes_process.py
├── processed_attributes.txt
└── readme.md
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore folder
2 | build
3 | log
4 |
5 | # backup files
6 | *.*~
7 |
8 | #.o and .a files
9 | *.[oa]
10 |
11 | *.DS_Store
12 | *.pyc
13 | *.pt
14 | *.h5
15 |
--------------------------------------------------------------------------------
/CUB.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | CUB-200-2011 Dataset.
4 | '''
5 | from __future__ import print_function
6 |
7 | import numpy as np
8 | import warnings
9 |
10 | from PIL import Image
11 | from keras.preprocessing import image
12 | from keras.utils.layer_utils import convert_all_kernels_in_model
13 | from keras.utils.data_utils import get_file
14 | from keras import backend as K
15 | from keras.applications.imagenet_utils import decode_predictions, preprocess_input
16 | import scipy.io as sio
17 | import os
18 | import time
19 |
def _read_records(path):
    """Return the space-separated fields of every non-blank line in `path`."""
    with open(path, 'r') as rf:
        return [line.strip().split(' ') for line in rf if line.strip()]


def load_data(data_folder, target_size=(224, 224), bounding_box=True):
    """Load the CUB-200-2011 images, class labels and attribute annotations.

    Parameters
    ----------
    data_folder : str
        Root folder of the dataset. It must contain `images.txt`,
        `image_class_labels.txt`, `train_test_split.txt`,
        `bounding_boxes.txt`, `processed_attributes.txt`, the `images/`
        folder and `attributes/class_attribute_labels_continuous.txt`.
    target_size : tuple of int
        Size every image is resized to (passed to PIL's `resize`) and used
        for the final reshape of the image arrays.
    bounding_box : bool
        When true, each image is cropped to its annotated bounding box
        before being resized.

    Returns
    -------
    (X_train, y_train), (X_test, y_test), (A_train, A_test, C_A)
        `X_*` are the preprocessed image arrays, shaped
        (N, 3, target_size[0], target_size[1]) under the 'th' dim ordering
        and (N, target_size[0], target_size[1], 3) otherwise.
        `y_*` are zero-based class indices.
        `A_*` are the rows of `processed_attributes.txt` belonging to each
        split. `C_A` is a 200x312 array holding the reciprocal of every
        non-zero class-attribute value (zeros stay zero).
    """
    images_file = data_folder+'/images.txt'
    label_file = data_folder+'/image_class_labels.txt'
    class_attributes_file = data_folder+'/attributes/class_attribute_labels_continuous.txt'
    split_file = data_folder+'/train_test_split.txt'
    bb_file = data_folder+'/bounding_boxes.txt'
    processed_attribute_file = data_folder+'/processed_attributes.txt'

    # train test split: the flag '1' marks a training image. The index lists
    # hold row positions so they can select rows of the attribute matrix.
    train_test_list = [fields[1] for fields in _read_records(split_file)]
    train_idx = [i for i, flag in enumerate(train_test_list) if flag == '1']
    test_idx = [i for i, flag in enumerate(train_test_list) if flag != '1']

    # bounding boxes: the four values after the id are turned into the
    # (left, upper, right, lower) box that PIL's crop expects.
    bb_list = []
    for fields in _read_records(bb_file):
        left, upper, width, height = (float(v) for v in fields[1:5])
        bb_list.append((left, upper, left+width, upper+height))

    # images: ids in the annotation files are used as 1-based row numbers
    X_train = []
    X_test = []
    for count, fields in enumerate(_read_records(images_file), 1):
        row = int(fields[0])-1
        img = image.load_img(data_folder+'/images/'+fields[1])
        if bounding_box:
            img = img.crop(bb_list[row])
        img = img.resize(target_size)
        x = image.img_to_array(img)
        if train_test_list[row] == '1':
            X_train.append(x)
        else:
            X_test.append(x)
        if count % 1000 == 0:
            print(count,' images load.')

    # labels: shifted to start from zero
    y_train = []
    y_test = []
    for fields in _read_records(label_file):
        label = int(fields[1])-1
        if train_test_list[int(fields[0])-1] == '1':
            y_train.append(label)
        else:
            y_test.append(label)

    # per-image attributes
    A_all = np.genfromtxt(processed_attribute_file, dtype=int, delimiter=' ')
    A_train = A_all[train_idx]
    A_test = A_all[test_idx]

    # class attributes: store 1/value, keeping zeros as zero. The zero test
    # is done on the parsed number so that '0' or '0.00' cannot divide by zero.
    C_A = np.zeros((200,312))
    for i, fields in enumerate(_read_records(class_attributes_file)):
        for j, field in enumerate(fields):
            value = float(field)
            C_A[i][j] = 1.0/value if value != 0 else 0

    X_train = np.array(X_train)
    X_test = np.array(X_test)
    y_train = np.array(y_train)
    y_test = np.array(y_test)
    X_train = preprocess_input(X_train)
    X_test = preprocess_input(X_test)
    # theano or tensorflow
    if K.image_dim_ordering() == 'th':
        X_train = X_train.reshape(X_train.shape[0], 3, target_size[0], target_size[1])
        X_test = X_test.reshape(X_test.shape[0], 3, target_size[0], target_size[1])
    else:
        X_train = X_train.reshape(X_train.shape[0], target_size[0], target_size[1], 3)
        X_test = X_test.reshape(X_test.shape[0], target_size[0], target_size[1], 3)
    return (X_train,y_train), (X_test,y_test), (A_train,A_test,C_A)
113 |
114 |
if __name__ == '__main__':
    # load_data requires the dataset folder, so take it from the command line.
    import sys
    if len(sys.argv) < 2:
        sys.exit('usage: python CUB.py <CUB_200_2011 data folder>')
    (X_train,y_train), (X_test,y_test),(A_train,A_test,C_A) = load_data(sys.argv[1])
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
/cub_demo.py:
--------------------------------------------------------------------------------
1 | ''' A3M for fine-grained recognition
2 | '''
3 |
4 | from __future__ import print_function
5 | import sys
6 | sys.path.append("..")
7 | sys.setrecursionlimit(10000)
8 | import numpy as np
9 | np.random.seed(2208) # for reproducibility
10 |
11 | import time
12 | from keras.preprocessing.image import ImageDataGenerator
13 | from keras.layers import Input, Dense, RepeatVector, Permute, merge
14 | from keras.layers import BatchNormalization, Lambda, Bidirectional, GRU
15 | from keras.layers import Dense, Dropout, Activation, Flatten, Reshape
16 | from keras.layers import Convolution2D, MaxPooling2D, Convolution1D
17 | from keras.layers import GlobalAveragePooling2D, GlobalAveragePooling1D
18 | from keras.models import Model
19 | from keras.optimizers import SGD
20 | from keras.utils import np_utils
21 | from keras import backend as K
22 | from keras.models import load_model
23 | #from keras.utils.visualize_util import plot
24 | from keras.applications.vgg16 import VGG16
25 | from keras.applications.resnet50 import ResNet50
26 | import scipy.misc
27 | from sklearn import preprocessing
28 | import CUB
29 |
30 | # args
# usage: python cub_demo.py <net> <data_folder>
if len(sys.argv) < 3:
    sys.exit('usage: python cub_demo.py <VGG16|ResNet50> <data_folder>')
net = sys.argv[1]          # backbone name; only VGG16 and ResNet50 are imported above
data_folder = sys.argv[2]  # dataset root handed to CUB.load_data

# model config
flag_test = False   # True: load model_weight_path and only evaluate on the test set
batch_size = 10
nb_epoch = 10       # epochs per learning-rate stage
dropout = 0.5
final_dim = 512 if net=='VGG16' else 2048   # channel count of the shared feature map
emb_dim = 512       # channel count of every classification branch
# NOTE(review): 'activation_49' is an auto-generated layer name; it presumably
# relies on ResNet50 being the first model built in the process - confirm.
shared_layer_name = 'block5_pool' if net=='VGG16' else 'activation_49'
model_weight_path = './model/weights_resnet50_86.1.h5'
# loss weights: [identity head 'p0', all attribute heads together, final head 'p']
lambdas = [0.2,0.5,1.0]
attr_equal = False    # True: use the plain average of attribute features instead of the weighted one
region_equal = False  # True: use the pooled identity feature instead of the attention-weighted one

# dataset config
dataset = 'CUB'
nb_classes = 200
# number of classes of each attribute head, one entry per head
nb_attributes = [10, 16, 16, 16, 5, 16, 7, 16, 12, 16, 16, 15, 4, 16, 16, 16, 16, 6, 6, 15, 5, 5, 5, 16, 16, 16, 16, 5]
img_rows, img_cols = 448, 448
L = 14*14   # number of positions the shared feature map is flattened to
lr_list = [0.001,0.003,0.001,0.001,0.001,0.001,0.001,0.0001]   # one training stage per entry
54 |
def init_classification(input_fea_map, dim_channel, nb_class, name=None):
    """Attach one classification branch to a flattened feature map.

    The branch is: 1x1 Convolution1D -> BatchNormalization -> ReLU, then
    global average pooling -> BatchNormalization -> ReLU, then Dropout
    (rate taken from the module-level `dropout`) -> Dense -> softmax.

    Parameters
    ----------
    input_fea_map : tensor
        Feature map the branch is built on.
    dim_channel : int
        Number of filters of the 1x1 convolution.
    nb_class : int
        Number of output classes.
    name : str or None
        Name of the softmax output layer; the pooling layer is named
        `name + '_avg_pool'`. With None, Keras picks both names.

    Returns
    -------
    (prob, pool, fea_map)
        The softmax output, the pooled (pre-dropout) feature and the
        convolved feature map.
    """
    # conv (built on the argument, not on the module-level share_fea_map)
    fea_map = Convolution1D(dim_channel, 1, border_mode='same')(input_fea_map)
    fea_map = BatchNormalization(axis=2)(fea_map)
    fea_map = Activation('relu')(fea_map)
    # pool
    pool_name = None if name is None else name+'_avg_pool'
    pool = GlobalAveragePooling1D(name=pool_name)(fea_map)
    pool = BatchNormalization()(pool)
    pool = Activation('relu')(pool)
    # classification
    dropped = Dropout(dropout)(pool)
    prob = Dense(nb_class)(dropped)
    prob = Activation(activation='softmax',name=name)(prob)
    # The original rebinds `pool` to the Dropout output before returning it,
    # so return the dropped tensor to keep the contract callers rely on.
    return prob,dropped,fea_map
69 |
70 | # model define
# The statement order below is kept on purpose: Keras auto-generates layer
# names in creation order, and shared_layer_name may refer to such a name.

# the attribute loss weight lambdas[1] is split evenly over the attribute heads
alphas = [lambdas[1]*1.0/len(nb_attributes)]*len(nb_attributes)
loss_dict = {}
weight_dict = {}
# input and output
inputs = Input(shape=(3, img_rows, img_cols))
out_list = []

# shared CNN: look the backbone up explicitly instead of eval()-ing a
# command-line argument
_BACKBONES = {'VGG16': VGG16, 'ResNet50': ResNet50}
if net not in _BACKBONES:
    raise ValueError('unsupported net %r, expected one of %s' % (net, sorted(_BACKBONES)))
model_raw = _BACKBONES[net](input_tensor=inputs, include_top=False, weights='imagenet')
share_fea_map = model_raw.get_layer(shared_layer_name).output
# flatten to (final_dim, L), then swap the axes to (L, final_dim)
share_fea_map = Reshape((final_dim, L), name='reshape_layer')(share_fea_map)
share_fea_map = Permute((2, 1))(share_fea_map)

# loss-1: identity classification
id_prob,id_pool,id_fea_map = init_classification(share_fea_map, emb_dim, nb_classes, name='p0')
out_list.append(id_prob)
loss_dict['p0'] = 'categorical_crossentropy'
weight_dict['p0'] = lambdas[0]

# loss-2: attribute classification, one head per attribute
attr_fea_list = []
for i, nb_attr_class in enumerate(nb_attributes):
    name = 'attr'+str(i)
    attr_prob,attr_pool,_ = init_classification(share_fea_map, emb_dim, nb_attr_class, name)
    out_list.append(attr_prob)
    attr_fea_list.append(attr_pool)
    loss_dict[name] = 'categorical_crossentropy'
    weight_dict[name] = alphas[i]

# attention generation: dot every attribute feature with the identity feature
# map (one score per region) and with the pooled identity feature (one score
# per attribute)
region_score_map_list = []
attr_score_list = []
for attr_pool in attr_fea_list:
    attn1 = merge([id_fea_map,attr_pool], mode='dot', dot_axes=(2,1))
    fea_score = merge([id_pool,attr_pool], mode='dot', dot_axes=(1,1))
    region_score_map_list.append(attn1)
    attr_score_list.append(fea_score)

# regional feature fusion
region_score_map = merge(region_score_map_list, mode='ave', name='attn')
region_score_map = BatchNormalization()(region_score_map)
region_score_map = Activation('sigmoid', name='region_attention')(region_score_map)
region_fea = merge([id_fea_map,region_score_map], mode='dot', dot_axes=(1,1))
region_fea = Lambda(lambda x: x*(1.0/L))(region_fea)
region_fea = BatchNormalization()(region_fea)

# attribute feature fusion
attr_scores = merge(attr_score_list, mode='concat')
attr_scores = BatchNormalization()(attr_scores)
attr_scores = Activation('sigmoid')(attr_scores)
attr_fea = merge(attr_fea_list, mode='concat')
attr_fea = Reshape((emb_dim, len(nb_attributes)))(attr_fea)
equal_attr_fea = GlobalAveragePooling1D()(attr_fea)
attr_fea = merge([attr_fea,attr_scores], mode='dot', dot_axes=(2,1))
attr_fea = Lambda(lambda x: x*(1.0/len(nb_attributes)))(attr_fea)
attr_fea = BatchNormalization()(attr_fea)

# loss-3: final classification
if attr_equal:
    attr_fea = equal_attr_fea
if region_equal:
    region_fea = id_pool
final_fea = merge([attr_fea,region_fea], mode='concat')
final_fea = Activation('relu', name='final_fea')(final_fea)
final_fea = Dropout(dropout)(final_fea)
final_prob = Dense(nb_classes)(final_fea)
final_prob = Activation(activation='softmax',name='p')(final_prob)
out_list.append(final_prob)
loss_dict['p'] = 'categorical_crossentropy'
weight_dict['p'] = lambdas[2]

# outputs are ordered: p0, attr0..attrN, p
model = Model(inputs, out_list)
if flag_test:
    model.load_weights(model_weight_path)

model.summary()
#plot(model, show_shapes=True, to_file='./fig/'+net+'_attention.png')
148 |
# the data, split between train and test sets (no shuffling is done here)
# The dataset module is looked up explicitly instead of eval()-ing its name.
_DATASETS = {'CUB': CUB}
(X_train, y_train),(X_test, y_test),(A_train,A_test,C_A) = _DATASETS[dataset].load_data(
    data_folder, target_size=(img_rows, img_cols), bounding_box=True)

# debug output: a small patch of one training image, then the array shapes
print(X_train[100][1][50:60,100:110])
print('X_train shape:', X_train.shape)
print('X_test shape:', X_test.shape)

# concat Y A: put the class label in column 0 and the attribute labels after
# it, so one array can be passed through the data generator as the target
yA_train = np.concatenate((np.expand_dims(y_train,1), A_train), axis=1)
yA_test = np.concatenate((np.expand_dims(y_test,1), A_test), axis=1)
print('yA_train shape:', yA_train.shape)
print('yA_test shape:', yA_test.shape)
162 |
163 | # train/test
def _one_hot_targets(labels, attrs):
    """Build the one-hot targets in model output order: p0, attr0..attrN, p."""
    targets = [np_utils.to_categorical(labels, nb_classes)]
    for k, nb_attr_class in enumerate(nb_attributes):
        targets.append(np_utils.to_categorical(attrs[:,k], nb_attr_class))
    targets.append(np_utils.to_categorical(labels, nb_classes))
    return targets


# The test targets never change, so build them once.
label_test_list = _one_hot_targets(y_test, A_test)
# Entries of the running-average vector shown on the progress line.
# Presumably [total loss, acc of 'p0', acc of 'attr0', acc of 'p'] given the
# order Keras reports losses and metrics in - confirm.
show_idx = [0,len(loss_dict)+1,len(loss_dict)+2,2*len(loss_dict)]

# NOTE(review): evaluation is assumed to run once per epoch and the weights to
# be saved once per learning-rate stage - confirm against the intended nesting.
for lr in lr_list:
    # test only: evaluate the loaded weights once and stop
    if flag_test:
        scores = model.evaluate(X_test, label_test_list, verbose=0)
        print('\nval-loss: ',scores[:1+len(loss_dict)], '\nval-acc: ', scores[1+len(loss_dict):])
        break
    # train
    # NOTE(review): 0.011 is not in lr_list, so this freeze branch never runs.
    if lr == 0.011:
        # freeze every layer that comes before 'reshape_layer'
        for layer in model.layers:
            if layer.name == 'reshape_layer':
                break
            layer.trainable = False
    else:
        for layer in model.layers:
            layer.trainable = True
    opt = SGD(lr=lr, decay=5e-4, momentum=0.9, nesterov=True)
    model.compile(loss=loss_dict,
                  loss_weights=weight_dict,
                  optimizer=opt, metrics=['accuracy'])
    # data augment this will do preprocessing and realtime data augmentation
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        rotation_range=30,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        zoom_range=[0.75,1.33],
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False)  # do not flip images vertically
    # train for nb_epoch epoches
    for e in range(nb_epoch):
        print('Epoch %d/%d' % (e+1,nb_epoch))
        batches = 1
        ave_loss = np.zeros(1+2*len(loss_dict))
        for X_batch, yA_batch in datagen.flow(X_train, yA_train, batch_size=batch_size):
            # column 0 is the class label, the remaining columns the attributes
            label_batch_list = _one_hot_targets(yA_batch[:,:1], yA_batch[:,1:])
            loss = model.train_on_batch(X_batch, label_batch_list)
            # print the running mean over the batches seen so far
            ave_loss = ave_loss*(batches-1)/batches + np.array(loss)/batches
            sys.stdout.write('\rtrain-loss: %.4f, train-acc: %.4f %.4f %.4f'
                             % tuple(ave_loss[show_idx].tolist()))
            sys.stdout.flush()
            batches += 1
            # the generator loops forever, so stop after one pass over the data
            if batches > len(X_train)/batch_size:
                sys.stdout.write("\r \r\n")
                break
        # test
        scores = model.evaluate(X_test, label_test_list, verbose=0)
        print('\nval-loss: ',scores[:1+len(loss_dict)], '\nval-acc: ', scores[1+len(loss_dict):])
        print('Main acc: %f' %(scores[-1]))
    # save model
    model.save_weights('./model/weights_'+net+str(lr)+'.h5')
    print('train stage:',lr,' sgd done!')
239 |
240 |
--------------------------------------------------------------------------------
/fig/a3m.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iamhankai/attribute-aware-attention/4b4cf873d6e398f1e64891dbc34ccb8fbd891f30/fig/a3m.png
--------------------------------------------------------------------------------
/fig/cub-dir.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iamhankai/attribute-aware-attention/4b4cf873d6e398f1e64891dbc34ccb8fbd891f30/fig/cub-dir.png
--------------------------------------------------------------------------------
/fig/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iamhankai/attribute-aware-attention/4b4cf873d6e398f1e64891dbc34ccb8fbd891f30/fig/result.png
--------------------------------------------------------------------------------
/fig/title.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iamhankai/attribute-aware-attention/4b4cf873d6e398f1e64891dbc34ccb8fbd891f30/fig/title.png
--------------------------------------------------------------------------------
/model/readme.md:
--------------------------------------------------------------------------------
1 | model path
2 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | ## Attribute-Aware Attention Model
2 | Code for ACM Multimedia 2018 oral paper: Attribute-Aware Attention Model for Fine-grained Representation Learning
3 |
4 |
5 |
6 | The paper presents results on **fine-grained classification**, **person re-id**, and **image retrieval** tasks, using the CUB-200-2011, Market-1501, and CARS196 datasets. This repository provides the fine-grained classification example. For detailed results, refer to the [original paper](https://dl.acm.org/citation.cfm?id=3240550) or [ArXiv](https://arxiv.org/abs/1901.00392).
7 |
8 |
9 | ### Usage
10 | Requires: Keras 1.2.1 ("image_data_format": "channels_first")
11 |
12 | Run in two steps:
13 |
14 | 1. Download CUB-200-2011 dataset [here](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) and unzip it to `$CUB`; Copy file `tools/processed_attributes.txt` to `$CUB`.
15 |
16 | - The `$CUB` dir should be like this:
17 |
18 |
19 | 2. Change `data_dir` in `run.sh` to `$CUB`, then run the script with `sh run.sh` to obtain the result.
20 |
21 | - Result on CUB dataset
22 |
23 |
24 |
25 |
26 | ### Citation
27 | Please use the following bibtex to cite our work:
28 | ```
29 | @inproceedings{han2018attribute,
30 | title={Attribute-Aware Attention Model for Fine-grained Representation Learning},
31 | author={Han, Kai and Guo, Jianyuan and Zhang, Chao and Zhu, Mingjian},
32 | booktitle={Proceedings of the 26th ACM international conference on Multimedia},
33 | pages={2040--2048},
34 | year={2018},
35 | organization={ACM}
36 | }
37 | ```
38 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
# Train the attribute-aware attention model on CUB and log the output.
timestamp=$(date +%s)
datetime=$(date -d "@$timestamp" +"%Y-%m-%d-%H:%M:%S")
# cub_demo.py only imports VGG16 and ResNet50, so the other choices are kept for reference only.
#net=AlexNet
#net=VGG16
#net=InceptionV3
net=ResNet50
data_dir=/home/hankai/data/CUB_200_2011
gpu_id=0
# The log folder is git-ignored and may not exist yet; tee would fail without it.
mkdir -p ./log
THEANO_FLAGS="device=gpu${gpu_id},floatX=float32,lib.cnmem=0.6" python cub_demo.py "$net" "$data_dir" | tee "./log/${net}-${datetime}.log.txt"
10 |
11 |
12 |
--------------------------------------------------------------------------------
/tools/attributes_process.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import re
3 | import numpy as np
4 |
5 | # get attribute cluster idx
attribute_name_file = 'attributes.txt'
start_idxs = []
last_attr = ''
# Open in text mode: the str pattern passed to re.split below cannot be
# applied to the bytes lines that 'rb' yields on Python 3.
with open(attribute_name_file, 'r') as f1:
    for line in f1:
        if not line.strip():
            continue  # a blank line has no id/name fields to split
        # fields: [id, group name, ...]; a new group starts when the name changes
        strs = re.split(' |::', line)
        if strs[1] != last_attr:
            start_idxs.append(int(strs[0]))
            last_attr = strs[1]
# sentinel: one past the id on the last line, closing the final group
start_idxs.append(int(strs[0])+1)
print(start_idxs)
a = np.array(start_idxs)
# group sizes plus one extra slot per group
nums = a[1:]-a[:-1]+1
print(np.sum(nums))
print(nums.tolist())
21 |
22 | # transform binary attribute to clustered attribute
23 | nb_attr = len(start_idxs)-1
24 | A_all = np.zeros((11788,nb_attr))
25 | image_attribute_file = 'attributes/image_attribute_labels.txt'
26 | f2 = open(image_attribute_file,'rb')
27 | for line in f2.readlines():
28 | strs = re.split(' ', line)
29 | img_id = int(strs[0])-1
30 | attr_id = int(strs[1])
31 | is_present = int(strs[2])
32 | if(is_present>0):
33 | for i in range(len(start_idxs)):
34 | if(attr_id