├── README.md ├── dataset ├── process_CUHK_data.py ├── process_ICFG_data.py └── utils │ └── read_write_data.py ├── experiment ├── CUHK-PEDES │ ├── test.sh │ └── train.sh └── ICFG-PEDES │ ├── test.sh │ └── train.sh ├── figure ├── CUHK-PEDES_result.GIF └── ICFG-PEDES_result.GIF └── src ├── data ├── dataloader.py └── dataset.py ├── loss ├── Id_loss.py └── RankingLoss.py ├── model ├── model.py └── text_feature_extract.py ├── option └── options.py ├── test.py ├── test_during_train.py ├── train.py └── utils └── read_write_data.py /README.md: -------------------------------------------------------------------------------- 1 | # Semantically Self-Aligned Network for Text-to-Image Part-aware Person Re-identification 2 | 3 | [![LICENSE](https://img.shields.io/badge/license-MIT-green)](https://github.com/taksau/GPS-Net/blob/master/LICENSE) 4 | [![Python](https://img.shields.io/badge/python-3.7-blue.svg)](https://www.python.org/) 5 | ![PyTorch](https://img.shields.io/badge/pytorch-1.5.0-%237732a8) 6 | 7 | We provide the code for reproducing the results of our paper [**Semantically Self-Aligned Network for Text-to-Image Part-aware Person Re-identification**](https://arxiv.org/pdf/2107.12666.pdf). 8 | 9 | ## Getting Started 10 | #### Dataset Preparation 11 | 12 | 1. **CUHK-PEDES** 13 | 14 | Organize the data in the `dataset` folder as follows: 15 | 16 | 17 | ~~~ 18 | |-- dataset/ 19 | | |-- CUHK-PEDES/ 20 | | |-- imgs 21 | |-- cam_a 22 | |-- cam_b 23 | |-- ... 24 | | |-- reid_raw.json 25 | 26 | ~~~ 27 | 28 | Download the CUHK-PEDES dataset from [here](https://github.com/ShuangLI59/Person-Search-with-Natural-Language-Description) and then run `process_CUHK_data.py` as follows: 29 | 30 | ~~~ 31 | cd SSAN 32 | python ./dataset/process_CUHK_data.py 33 | ~~~ 34 | 35 | 2. **ICFG-PEDES** 36 | 37 | Organize the data in the `dataset` folder as follows: 38 | 39 | ~~~ 40 | |-- dataset/ 41 | | |-- ICFG-PEDES/ 42 | | |-- imgs 43 | |-- test 44 | |-- train 45 | | |-- ICFG_PEDES.json 46 | 47 | ~~~ 48 | 49 | Note that ICFG-PEDES is collected from [MSMT17](https://github.com/pkuvmc/PTGAN), so we keep MSMT17's storage structure to avoid losing information such as camera labels, shooting time, etc. The `test` and `train` folders here therefore do not reflect how ICFG-PEDES is split; the exact split is defined by `ICFG-PEDES.json`, which is organized like the `reid_raw.json` of [CUHK-PEDES](https://github.com/ShuangLI59/Person-Search-with-Natural-Language-Description).
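Both annotation files follow the same schema: each record carries at least the fields read by the processing scripts (`split`, `id`, `file_path`, `captions`, `processed_tokens`). As a rough illustration (the values below are placeholders, not real dataset entries), a record looks like:

~~~
{
    "split": "train",
    "id": 1,
    "file_path": "cam_a/0001_001.png",
    "captions": ["A man wearing a black jacket and blue jeans ..."],
    "processed_tokens": [["a", "man", "wearing", "a", "black", "jacket", "and", "blue", "jeans"]]
}
~~~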
50 | 51 | Please request the ICFG-PEDES database from [272521211@qq.com](mailto:272521211@qq.com) and then run `process_ICFG_data.py` as follows: 52 | 53 | ~~~ 54 | cd SSAN 55 | python ./dataset/process_ICFG_data.py 56 | ~~~ 57 | 58 | #### Training and Testing 59 | ~~~ 60 | sh experiment/CUHK-PEDES/train.sh 61 | sh experiment/ICFG-PEDES/train.sh 62 | ~~~ 63 | #### Evaluation 64 | ~~~ 65 | sh experiment/CUHK-PEDES/test.sh 66 | sh experiment/ICFG-PEDES/test.sh 67 | ~~~ 68 | 69 | ## Results on CUHK-PEDES and ICFG-PEDES 70 | 71 | **Our Results on CUHK-PEDES dataset** 72 | 73 | ![CUHK-PEDES results](figure/CUHK-PEDES_result.GIF) 74 | 75 | **Our Results on ICFG-PEDES dataset** 76 | 77 | ![ICFG-PEDES results](figure/ICFG-PEDES_result.GIF) 78 | 79 | ## Citation 80 | 81 | If this work is helpful for your research, please cite our paper: 82 | 83 | ~~~ 84 | @article{ding2021semantically, 85 | title={Semantically Self-Aligned Network for Text-to-Image Part-aware Person Re-identification}, 86 | author={Ding, Zefeng and Ding, Changxing and Shao, Zhiyin and Tao, Dacheng}, 87 | journal={arXiv preprint arXiv:2107.12666}, 88 | year={2021} 89 | } 90 | ~~~ 91 | -------------------------------------------------------------------------------- /dataset/process_CUHK_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Processes CUHK-PEDES/reid_raw.json and outputs train_save.pkl, val_save.pkl and test_save.pkl. 4 | 5 | For example: 6 | train_save.pkl contains the data in dict format 7 | 'id': identity label of the caption-image pair 8 | 'img_path': path of the image in the caption-image pair 9 | 'same_id_index': indices (within this dict) of the captions of other images with the same id 10 | 'lstm_caption_id': the caption encoded as word indices for the bi-LSTM 11 | 'captions': the raw caption of the caption-image pair 12 | 13 | 14 | @author: zifyloo 15 | """ 16 | 17 | from utils.read_write_data import read_json, makedir, save_dict, write_txt 18 | import argparse 19 | import os 20 | import numpy as np 21 | 22 | 23 | class Word2Index(object): 24 | 25 | def __init__(self, vocab): 26 | self._vocab = {w: index + 1 for index, w in enumerate(vocab)} 27 | self.unk_id = len(vocab) + 1 28 | 29 | def __call__(self, word): 30 | if word not in self._vocab: 31 | return self.unk_id 32 | return self._vocab[word] 33 | 34 | 35 | def parse_args(): 36 | parser = argparse.ArgumentParser(description='Command for data pre-processing') 37 | parser.add_argument('--json_root', default='./CUHK-PEDES/reid_raw.json', type=str) 38 | parser.add_argument('--out_root', default='./CUHK-PEDES/processed_data', type=str) 39 | parser.add_argument('--min_word_count', default=2, type=int) 40 | args = parser.parse_args() 41 | return args 42 | 43 | 44 | def split_json(args): 45 | """ 46 | reid_raw_data contains 40206 images and 13003 ids; 47 | every id has several images and every image has several captions. 48 | each item of reid_raw_data is a dict with keys 49 | ['split', 'captions', 'file_path', 'processed_tokens', 'id'] 50 | """ 51 | reid_raw_data = read_json(args.json_root) 52 | 53 | train_json = [] 54 | test_json = [] 55 | val_json = [] 56 | for data in reid_raw_data: 57 | data_save = { 58 | 'img_path': 'imgs/'+data['file_path'], 59 | 'id': data['id'], 60 | 'tokens': data['processed_tokens'], 61 | 'captions': data['captions'] 62 | } 63 | 64 | split = data['split'].lower() 65 | if split == 'train': 66 | train_json.append(data_save) 67 | elif split == 'test': 68 | test_json.append(data_save) 69 | else: 70 | val_json.append(data_save) 71 | return train_json, test_json, val_json 72 | 73 | 74 | def build_vocabulary(train_json, args): 75 |
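    # The vocabulary is built from the training captions only: count lower-cased
    # word frequencies, keep the words that appear at least args.min_word_count
    # times, and map every remaining word to the unknown id via Word2Index.
    # The kept vocabulary is also saved to '<out_root>/ind2word.pkl'.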
76 | word_count = {} 77 | for data in train_json: 78 | for caption in data['tokens']: 79 | for word in caption: 80 | word_count[word.lower()] = word_count.get(word.lower(), 0) + 1 81 | 82 | word_count_list = [[v, k] for v, k in word_count.items()] 83 | word_count_list.sort(key=lambda x: x[1], reverse=True) # from high to low 84 | 85 | good_vocab = [v for v, k in word_count.items() if k >= args.min_word_count] 86 | 87 | print('top-10 highest frequency words:') 88 | for w, n in word_count_list[0:10]: 89 | print(w, n) 90 | 91 | good_count = len(good_vocab) 92 | all_count = len(word_count_list) 93 | good_word_rate = good_count * 100.0 / all_count 94 | st = 'good words: %d, total_words: %d, good_word_rate: %f%%' % (good_count, all_count, good_word_rate) 95 | write_txt(st, os.path.join(args.out_root, 'data_message')) 96 | print(st) 97 | word2Ind = Word2Index(good_vocab) 98 | 99 | save_dict(good_vocab, os.path.join(args.out_root, 'ind2word')) 100 | return word2Ind 101 | 102 | 103 | def generate_captionid(data_json, word2Ind, data_name, args): 104 | 105 | id_save = [] 106 | lstm_caption_id_save = [] 107 | img_path_save = [] 108 | caption_save = [] 109 | same_id_index_save = [] 110 | un_idx = word2Ind.unk_id 111 | data_save_by_id = {} 112 | for data in data_json: 113 | 114 | if data['id'] in [1369, 4116, 6116]: # CR need two images for per id at least, these ids have only one image, 115 | continue 116 | if data['id'] > 6116: 117 | id_new = data['id'] - 4 118 | elif data['id'] > 4116: 119 | id_new = data['id'] - 3 120 | elif data['id'] > 1369: 121 | id_new = data['id'] - 2 122 | else: 123 | id_new = data['id'] - 1 124 | 125 | data_save_i = { 126 | 'img_path': data['img_path'], 127 | 'id': id_new, 128 | 'tokens': data['tokens'], 129 | 'captions': data['captions'] 130 | } 131 | if id_new not in data_save_by_id.keys(): 132 | data_save_by_id[id_new] = [] 133 | 134 | data_save_by_id[id_new].append(data_save_i) 135 | 136 | data_order = 0 137 | for id_new, data_save_by_id_i in data_save_by_id.items(): 138 | 139 | caption_length = 0 140 | for data_save_by_id_i_i in data_save_by_id_i: 141 | caption_length += len(data_save_by_id_i_i['captions']) 142 | 143 | data_order_i = data_order + np.arange(caption_length) 144 | data_order_i_begin = 0 145 | 146 | for data_save_by_id_i_i in data_save_by_id_i: 147 | caption_length_i = len(data_save_by_id_i_i['captions']) 148 | data_order_i_end = data_order_i_begin + caption_length_i 149 | data_order_i_select = np.delete(data_order_i, np.arange(data_order_i_begin, data_order_i_end)) 150 | data_order_i_begin = data_order_i_end 151 | 152 | for j in range(len(data_save_by_id_i_i['tokens'])): 153 | tokens_j = data_save_by_id_i_i['tokens'][j] 154 | lstm_caption_id = [] 155 | for word in tokens_j: 156 | lstm_caption_id.append(word2Ind(word)) 157 | if un_idx in lstm_caption_id: 158 | lstm_caption_id = list(filter(lambda x: x != un_idx, lstm_caption_id)) 159 | 160 | caption_j = data_save_by_id_i_i['captions'][j] 161 | 162 | id_save.append(data_save_by_id_i_i['id']) 163 | img_path_save.append(data_save_by_id_i_i['img_path']) 164 | same_id_index_save.append(data_order_i_select) 165 | 166 | lstm_caption_id_save.append(lstm_caption_id) 167 | caption_save.append(caption_j) 168 | 169 | data_order = data_order + caption_length 170 | 171 | data_save = { 172 | 'id': id_save, 173 | 'img_path': img_path_save, 174 | 'same_id_index': same_id_index_save, 175 | 176 | 'lstm_caption_id': lstm_caption_id_save, 177 | 'captions': caption_save, 178 | } 179 | 180 | img_num = len(set(img_path_save)) 181 | 
id_num = len(set(id_save)) 182 | caption_num = len(lstm_caption_id_save) 183 | 184 | st = '%s_img_num: %d, %s_id_num: %d, %s_caption_num: %d, ' \ 185 | % (data_name, img_num, data_name, id_num, data_name, caption_num) 186 | write_txt(st, os.path.join(args.out_root, 'data_message')) 187 | 188 | return data_save 189 | 190 | 191 | def generate_test_val_caption_id(data_json, word2Ind, data_name, args): 192 | id_save = [] 193 | lstm_caption_id_save = [] 194 | caption_save = [] 195 | img_path_save = [] 196 | img_caption_index_save = [] 197 | caption_matching_img_index_save = [] 198 | caption_label_save = [] 199 | 200 | un_idx = word2Ind.unk_id 201 | 202 | img_caption_index_i = 0 203 | caption_matching_img_index_i = 0 204 | for data in data_json: 205 | id_save.append(data['id']) 206 | img_path_save.append(data['img_path']) 207 | 208 | for j in range(len(data['tokens'])): 209 | 210 | tokens_j = data['tokens'][j] 211 | lstm_caption_id = [] 212 | for word in tokens_j: 213 | lstm_caption_id.append(word2Ind(word)) 214 | if un_idx in lstm_caption_id: 215 | lstm_caption_id = list(filter(lambda x: x != un_idx, lstm_caption_id)) 216 | 217 | caption_j = data['captions'][j] 218 | 219 | caption_matching_img_index_save.append(caption_matching_img_index_i) 220 | lstm_caption_id_save.append(lstm_caption_id) 221 | caption_save.append(caption_j) 222 | 223 | caption_label_save.append(data['id']) 224 | img_caption_index_save.append([img_caption_index_i, img_caption_index_i+len(data['captions'])-1]) 225 | img_caption_index_i += len(data['captions']) 226 | caption_matching_img_index_i += 1 227 | 228 | data_save = { 229 | 'id': id_save, 230 | 'img_path': img_path_save, 231 | 'img_caption_index': img_caption_index_save, 232 | 233 | 'caption_matching_img_index': caption_matching_img_index_save, 234 | 'caption_label': caption_label_save, 235 | 'lstm_caption_id': lstm_caption_id_save, 236 | 'captions': caption_save, 237 | } 238 | 239 | img_num = len(set(img_path_save)) 240 | id_num = len(set(id_save)) 241 | caption_num = len(lstm_caption_id_save) 242 | st = '%s_img_num: %d, %s_id_num: %d, %s_caption_num: %d, ' % ( 243 | data_name, img_num, data_name, id_num, data_name, caption_num) 244 | write_txt(st, os.path.join(args.out_root, 'data_message')) 245 | 246 | return data_save 247 | 248 | 249 | def main(args): 250 | train_json, test_json, val_json = split_json(args) 251 | 252 | word2Ind = build_vocabulary(train_json, args) 253 | 254 | train_save = generate_captionid(train_json, word2Ind, 'train', args) 255 | test_save = generate_test_val_caption_id(test_json, word2Ind, 'test', args) 256 | val_save = generate_captionid(val_json, word2Ind, 'val', args) 257 | 258 | save_dict(train_save, os.path.join(args.out_root, 'train_save')) 259 | save_dict(test_save, os.path.join(args.out_root, 'test_save')) 260 | save_dict(val_save, os.path.join(args.out_root, 'val_save')) 261 | 262 | 263 | if __name__ == '__main__': 264 | 265 | args = parse_args() 266 | 267 | makedir(args.out_root) 268 | main(args) 269 | -------------------------------------------------------------------------------- /dataset/process_ICFG_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | processes the ICFG_PEDES/ICFG_PEDES.json, output the train_data, test_data 4 | For example: 5 | train_data.pkl contains the data in dict format 6 | 'id': id for the caption-image pair 7 | 'img_path': the image in the caption-image pair 8 | 'same_id_index': the id number in the dict of captions of other images 
from same id 9 | 'lstm_caption_id': the code of per caption for bi-lstm 10 | 'captions': the caption in the caption-image pair 11 | @author: zifyloo 12 | """ 13 | 14 | from utils.read_write_data import read_json, makedir, save_dict, write_txt 15 | import argparse 16 | import os 17 | import numpy as np 18 | 19 | 20 | class Word2Index(object): 21 | 22 | def __init__(self, vocab): 23 | self._vocab = {w: index + 1 for index, w in enumerate(vocab)} 24 | self.unk_id = len(vocab) + 1 25 | 26 | def __call__(self, word): 27 | if word not in self._vocab: 28 | return self.unk_id 29 | return self._vocab[word] 30 | 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser(description='Command for data pre_processing') 34 | parser.add_argument('--json_root', default='./ICFG-PEDES/ICFG-PEDES.json', type=str) 35 | parser.add_argument('--out_root', default='./ICFG-PEDES/processed_data', type=str) 36 | parser.add_argument('--min_word_count', default='2', type=int) 37 | args = parser.parse_args() 38 | return args 39 | 40 | 41 | def split_json(args): 42 | """ 43 | has 40206 image in reid_raw_data 44 | has 13003 id 45 | every id has several images and every image has several caption 46 | data's structure in reid_raw_data is dict ['split', 'captions', 'file_path', 'processed_tokens', 'id'] 47 | """ 48 | reid_raw_data = read_json(args.json_root) 49 | 50 | train_json = [] 51 | test_json = [] 52 | for data in reid_raw_data: 53 | data_save = { 54 | 'img_path': 'imgs/'+data['file_path'], 55 | 'id': data['id'], 56 | 'tokens': data['processed_tokens'], 57 | 'captions': data['captions'] 58 | } 59 | 60 | split = data['split'].lower() 61 | if split == 'train': 62 | train_json.append(data_save) 63 | elif split == 'test': 64 | test_json.append(data_save) 65 | return train_json, test_json 66 | 67 | 68 | def build_vocabulary(train_json, args): 69 | 70 | word_count = {} 71 | for data in train_json: 72 | for caption in data['tokens']: 73 | for word in caption: 74 | word_count[word.lower()] = word_count.get(word.lower(), 0) + 1 75 | 76 | word_count_list = [[v, k] for v, k in word_count.items()] 77 | word_count_list.sort(key=lambda x: x[1], reverse=True) # from high to low 78 | 79 | good_vocab = [v for v, k in word_count.items() if k >= args.min_word_count] 80 | 81 | print('top-10 highest frequency words:') 82 | for w, n in word_count_list[0:10]: 83 | print(w, n) 84 | 85 | good_count = len(good_vocab) 86 | all_count = len(word_count_list) 87 | good_word_rate = good_count * 100.0 / all_count 88 | st = 'good words: %d, total_words: %d, good_word_rate: %f%%' % (good_count, all_count, good_word_rate) 89 | write_txt(st, os.path.join(args.out_root, 'data_message')) 90 | print(st) 91 | word2Ind = Word2Index(good_vocab) 92 | 93 | save_dict(good_vocab, os.path.join(args.out_root, 'ind2word')) 94 | return word2Ind 95 | 96 | 97 | def generate_captionid(data_json, word2Ind, data_name, args): 98 | 99 | id_save = [] 100 | lstm_caption_id_save = [] 101 | img_path_save = [] 102 | caption_save = [] 103 | same_id_index_save = [] 104 | un_idx = word2Ind.unk_id 105 | data_save_by_id = {} 106 | for data in data_json: 107 | 108 | id_new = data['id'] 109 | 110 | data_save_i = { 111 | 'img_path': data['img_path'], 112 | 'id': id_new, 113 | 'tokens': data['tokens'], 114 | 'captions': data['captions'] 115 | } 116 | if id_new not in data_save_by_id.keys(): 117 | data_save_by_id[id_new] = [] 118 | 119 | data_save_by_id[id_new].append(data_save_i) 120 | 121 | data_order = 0 122 | for id_new, data_save_by_id_i in data_save_by_id.items(): 123 | 124 | 
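        # Captions are written out with consecutive global indices (data_order).
        # For every image of this identity, 'same_id_index' stores the global indices
        # of the identity's captions after removing the ones of this image itself,
        # i.e. captions of the same person taken from its other images; the training
        # dataset later samples one of these indices to pick a second caption of the
        # same identity. Words outside the vocabulary are dropped from 'lstm_caption_id'.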
caption_length = 0 125 | for data_save_by_id_i_i in data_save_by_id_i: 126 | caption_length += len(data_save_by_id_i_i['captions']) 127 | 128 | data_order_i = data_order + np.arange(caption_length) 129 | data_order_i_begin = 0 130 | 131 | for data_save_by_id_i_i in data_save_by_id_i: 132 | caption_length_i = len(data_save_by_id_i_i['captions']) 133 | data_order_i_end = data_order_i_begin + caption_length_i 134 | data_order_i_select = np.delete(data_order_i, np.arange(data_order_i_begin, data_order_i_end)) 135 | data_order_i_begin = data_order_i_end 136 | 137 | for j in range(len(data_save_by_id_i_i['tokens'])): 138 | tokens_j = data_save_by_id_i_i['tokens'][j] 139 | lstm_caption_id = [] 140 | for word in tokens_j: 141 | lstm_caption_id.append(word2Ind(word)) 142 | if un_idx in lstm_caption_id: 143 | lstm_caption_id = list(filter(lambda x: x != un_idx, lstm_caption_id)) 144 | 145 | caption_j = data_save_by_id_i_i['captions'][j] 146 | 147 | id_save.append(data_save_by_id_i_i['id']) 148 | img_path_save.append(data_save_by_id_i_i['img_path']) 149 | same_id_index_save.append(data_order_i_select) 150 | 151 | lstm_caption_id_save.append(lstm_caption_id) 152 | caption_save.append(caption_j) 153 | 154 | data_order = data_order + caption_length 155 | 156 | data_save = { 157 | 'id': id_save, 158 | 'img_path': img_path_save, 159 | 'same_id_index': same_id_index_save, 160 | 161 | 'lstm_caption_id': lstm_caption_id_save, 162 | 'captions': caption_save, 163 | } 164 | 165 | img_num = len(set(img_path_save)) 166 | id_num = len(set(id_save)) 167 | caption_num = len(lstm_caption_id_save) 168 | 169 | st = '%s_img_num: %d, %s_id_num: %d, %s_caption_num: %d, ' \ 170 | % (data_name, img_num, data_name, id_num, data_name, caption_num) 171 | write_txt(st, os.path.join(args.out_root, 'data_message')) 172 | 173 | return data_save 174 | 175 | 176 | def generate_test_val_caption_id(data_json, word2Ind, data_name, args): 177 | id_save = [] 178 | lstm_caption_id_save = [] 179 | caption_save = [] 180 | img_path_save = [] 181 | img_caption_index_save = [] 182 | caption_matching_img_index_save = [] 183 | caption_label_save = [] 184 | 185 | un_idx = word2Ind.unk_id 186 | 187 | img_caption_index_i = 0 188 | caption_matching_img_index_i = 0 189 | for data in data_json: 190 | id_save.append(data['id']) 191 | img_path_save.append(data['img_path']) 192 | 193 | for j in range(len(data['tokens'])): 194 | 195 | tokens_j = data['tokens'][j] 196 | lstm_caption_id = [] 197 | for word in tokens_j: 198 | lstm_caption_id.append(word2Ind(word)) 199 | if un_idx in lstm_caption_id: 200 | lstm_caption_id = list(filter(lambda x: x != un_idx, lstm_caption_id)) 201 | 202 | caption_j = data['captions'][j] 203 | 204 | caption_matching_img_index_save.append(caption_matching_img_index_i) 205 | lstm_caption_id_save.append(lstm_caption_id) 206 | caption_save.append(caption_j) 207 | 208 | caption_label_save.append(data['id']) 209 | img_caption_index_save.append([img_caption_index_i, img_caption_index_i+len(data['captions'])-1]) 210 | img_caption_index_i += len(data['captions']) 211 | caption_matching_img_index_i += 1 212 | 213 | data_save = { 214 | 'id': id_save, 215 | 'img_path': img_path_save, 216 | 'img_caption_index': img_caption_index_save, 217 | 218 | 'caption_matching_img_index': caption_matching_img_index_save, 219 | 'caption_label': caption_label_save, 220 | 'lstm_caption_id': lstm_caption_id_save, 221 | 'captions': caption_save, 222 | } 223 | 224 | img_num = len(set(img_path_save)) 225 | id_num = len(set(id_save)) 226 | caption_num = 
len(lstm_caption_id_save) 227 | st = '%s_img_num: %d, %s_id_num: %d, %s_caption_num: %d, ' % ( 228 | data_name, img_num, data_name, id_num, data_name, caption_num) 229 | write_txt(st, os.path.join(args.out_root, 'data_message')) 230 | 231 | return data_save 232 | 233 | 234 | def main(args): 235 | train_json, test_json = split_json(args) 236 | 237 | word2Ind = build_vocabulary(train_json, args) 238 | 239 | train_save = generate_captionid(train_json, word2Ind, 'train', args) 240 | test_save = generate_test_val_caption_id(test_json, word2Ind, 'test', args) 241 | 242 | save_dict(train_save, os.path.join(args.out_root, 'train_save')) 243 | save_dict(test_save, os.path.join(args.out_root, 'test_save')) 244 | 245 | 246 | if __name__ == '__main__': 247 | 248 | args = parse_args() 249 | 250 | makedir(args.out_root) 251 | main(args) 252 | -------------------------------------------------------------------------------- /dataset/utils/read_write_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | the tool to read or write the data. Have a good luck! 4 | 5 | Created on Thurs., Aug. 1(st), 2019 at 20:15. 6 | 7 | Updated on Thurs., Aug. 9(th), 2021 at 11:12 8 | 9 | @author: zifyloo 10 | """ 11 | 12 | import os 13 | import json 14 | import pickle 15 | 16 | 17 | def makedir(root): 18 | if not os.path.exists(root): 19 | os.makedirs(root) 20 | 21 | 22 | def write_json(data, root): 23 | with open(root, 'w') as f: 24 | json.dump(data, f) 25 | 26 | 27 | def read_json(root): 28 | with open(root, 'r') as f: 29 | data = json.load(f) 30 | 31 | return data 32 | 33 | 34 | def read_dict(root): 35 | with open(root, 'rb') as f: 36 | data = pickle.load(f) 37 | 38 | return data 39 | 40 | 41 | def save_dict(data, name): 42 | with open(name + '.pkl', 'wb') as f: 43 | pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) 44 | 45 | 46 | def write_txt(data, name): 47 | with open(name, 'a') as f: 48 | f.write(data) 49 | f.write('\n') 50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /experiment/CUHK-PEDES/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd src 4 | 5 | python test.py --model_name 'SSAN' \ 6 | --GPU_id 0 \ 7 | --part 6 \ 8 | --lr 0.001 \ 9 | --dataset 'CUHK-PEDES' \ 10 | --dataroot '../dataset/CUHK-PEDES/' \ 11 | --vocab_size 5000 \ 12 | --feature_length 1024 \ 13 | --mode 'test' 14 | -------------------------------------------------------------------------------- /experiment/CUHK-PEDES/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd src 4 | 5 | python train.py --model_name 'SSAN' \ 6 | --GPU_id 0 \ 7 | --part 6 \ 8 | --lr 0.001 \ 9 | --dataset 'CUHK-PEDES' \ 10 | --epoch 60 \ 11 | --dataroot '../dataset/CUHK-PEDES/' \ 12 | --class_num 11000 \ 13 | --vocab_size 5000 \ 14 | --feature_length 1024 \ 15 | --mode 'train' \ 16 | --batch_size 64 \ 17 | --cr_beta 0.1 18 | -------------------------------------------------------------------------------- /experiment/ICFG-PEDES/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd src 4 | 5 | python test.py --model_name 'SSAN' \ 6 | --GPU_id 0 \ 7 | --part 6 \ 8 | --lr 0.001 \ 9 | --dataset 'ICFG-PEDES' \ 10 | --dataroot '../dataset/ICFG-PEDES/' \ 11 | --vocab_size 2500 \ 12 | --feature_length 1024 \ 13 | --mode 'test' 14 | 
-------------------------------------------------------------------------------- /experiment/ICFG-PEDES/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd src 4 | 5 | python train.py --model_name 'SSAN' \ 6 | --GPU_id 0 \ 7 | --part 6 \ 8 | --lr 0.001 \ 9 | --dataset 'ICFG-PEDES' \ 10 | --epoch 60 \ 11 | --dataroot '../dataset/ICFG-PEDES/' \ 12 | --class_num 3102 \ 13 | --vocab_size 2500 \ 14 | --feature_length 1024 \ 15 | --batch_size 64 \ 16 | --mode 'train' \ 17 | --cr_beta 0.1 18 | -------------------------------------------------------------------------------- /figure/CUHK-PEDES_result.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zifyloo/SSAN/567836932aaf7a54e29c39933320dded74383107/figure/CUHK-PEDES_result.GIF -------------------------------------------------------------------------------- /figure/ICFG-PEDES_result.GIF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zifyloo/SSAN/567836932aaf7a54e29c39933320dded74383107/figure/ICFG-PEDES_result.GIF -------------------------------------------------------------------------------- /src/data/dataloader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: zifyloo 4 | """ 5 | 6 | from torchvision import transforms 7 | from PIL import Image 8 | import torch 9 | from data.dataset import CUHKPEDEDataset, CUHKPEDE_img_dateset, CUHKPEDE_txt_dateset 10 | 11 | 12 | def get_dataloader(opt): 13 | """ 14 | tranforms the image, downloads the image with the id by data.DataLoader 15 | """ 16 | 17 | if opt.mode == 'train': 18 | transform_list = [ 19 | transforms.RandomHorizontalFlip(), 20 | transforms.Resize((384, 128), Image.BICUBIC), # interpolation 21 | transforms.ToTensor(), 22 | transforms.Normalize((0.5, 0.5, 0.5), 23 | (0.5, 0.5, 0.5))] 24 | tran = transforms.Compose(transform_list) 25 | 26 | dataset = CUHKPEDEDataset(opt, tran) 27 | 28 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, 29 | shuffle=True, drop_last=True, num_workers=3) 30 | print('{}-{} has {} pohtos'.format(opt.dataset, opt.mode, len(dataset))) 31 | 32 | return dataloader 33 | 34 | else: 35 | tran = transforms.Compose([ 36 | transforms.Resize((384, 128), Image.BICUBIC), # interpolation 37 | transforms.ToTensor(), 38 | transforms.Normalize((0.5, 0.5, 0.5), 39 | (0.5, 0.5, 0.5))] 40 | ) 41 | 42 | img_dataset = CUHKPEDE_img_dateset(opt, tran) 43 | 44 | img_dataloader = torch.utils.data.DataLoader(img_dataset, batch_size=opt.batch_size, 45 | shuffle=False, drop_last=False, num_workers=3) 46 | 47 | txt_dataset = CUHKPEDE_txt_dateset(opt) 48 | 49 | txt_dataloader = torch.utils.data.DataLoader(txt_dataset, batch_size=opt.batch_size, 50 | shuffle=False, drop_last=False, num_workers=3) 51 | 52 | print('{}-{} has {} pohtos, {} text'.format(opt.dataset, opt.mode, len(img_dataset), len(txt_dataset))) 53 | 54 | return img_dataloader, txt_dataloader 55 | -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: zifyloo 4 | """ 5 | 6 | import torch 7 | import torch.utils.data as data 8 | import numpy as np 9 | from PIL import Image 10 | import os 11 | from utils.read_write_data import read_dict 12 | import cv2 
13 | import torchvision.transforms.functional as F 14 | import random 15 | 16 | 17 | def fliplr(img, dim): 18 | """ 19 | flip horizontal 20 | :param img: 21 | :return: 22 | """ 23 | inv_idx = torch.arange(img.size(dim) - 1, -1, -1).long() # N x C x H x W 24 | img_flip = img.index_select(dim, inv_idx) 25 | return img_flip 26 | 27 | 28 | class CUHKPEDEDataset(data.Dataset): 29 | def __init__(self, opt, tran): 30 | 31 | self.opt = opt 32 | self.flip_flag = (self.opt.mode == 'train') 33 | 34 | data_save = read_dict(os.path.join(opt.dataroot, 'processed_data', opt.mode + '_save.pkl')) 35 | 36 | self.img_path = [os.path.join(opt.dataroot, img_path) for img_path in data_save['img_path']] 37 | 38 | self.label = data_save['id'] 39 | 40 | self.caption_code = data_save['lstm_caption_id'] 41 | 42 | self.same_id_index = data_save['same_id_index'] 43 | 44 | self.transform = tran 45 | 46 | self.num_data = len(self.img_path) 47 | 48 | def __getitem__(self, index): 49 | """ 50 | :param index: 51 | :return: image and its label 52 | """ 53 | 54 | image = Image.open(self.img_path[index]) 55 | image = self.transform(image) 56 | label = torch.from_numpy(np.array([self.label[index]])).long() 57 | caption_code, caption_length = self.caption_mask(self.caption_code[index]) 58 | 59 | same_id_index = np.random.randint(len(self.same_id_index[index])) 60 | same_id_index = self.same_id_index[index][same_id_index] 61 | same_id_caption_code, same_id_caption_length = self.caption_mask(self.caption_code[same_id_index]) 62 | 63 | return image, label, caption_code, caption_length, same_id_caption_code, same_id_caption_length 64 | 65 | def get_data(self, index, img=True): 66 | if img: 67 | image = Image.open(self.img_path[index]) 68 | image = self.transform(image) 69 | else: 70 | image = 0 71 | 72 | label = torch.from_numpy(np.array([self.label[index]])).long() 73 | 74 | caption_code, caption_length = self.caption_mask(self.caption_code[index]) 75 | 76 | return image, label, caption_code, caption_length 77 | 78 | def caption_mask(self, caption): 79 | caption_length = len(caption) 80 | caption = torch.from_numpy(np.array(caption)).view(-1).long() 81 | 82 | if caption_length < self.opt.caption_length_max: 83 | zero_padding = torch.zeros(self.opt.caption_length_max - caption_length).long() 84 | caption = torch.cat([caption, zero_padding], 0) 85 | else: 86 | caption = caption[:self.opt.caption_length_max] 87 | caption_length = self.opt.caption_length_max 88 | 89 | return caption, caption_length 90 | 91 | def __len__(self): 92 | return self.num_data 93 | 94 | 95 | class CUHKPEDE_img_dateset(data.Dataset): 96 | def __init__(self, opt, tran): 97 | 98 | self.opt = opt 99 | 100 | data_save = read_dict(os.path.join(opt.dataroot, 'processed_data', opt.mode + '_save.pkl')) 101 | 102 | self.img_path = [os.path.join(opt.dataroot, img_path) for img_path in data_save['img_path']] 103 | 104 | self.label = data_save['id'] 105 | 106 | self.transform = tran 107 | 108 | self.num_data = len(self.img_path) 109 | 110 | def __getitem__(self, index): 111 | """ 112 | :param index: 113 | :return: image and its label 114 | """ 115 | 116 | image = Image.open(self.img_path[index]) 117 | image = self.transform(image) 118 | 119 | label = torch.from_numpy(np.array([self.label[index]])).long() 120 | 121 | return image, label 122 | 123 | def __len__(self): 124 | return self.num_data 125 | 126 | 127 | class CUHKPEDE_txt_dateset(data.Dataset): 128 | def __init__(self, opt): 129 | 130 | self.opt = opt 131 | 132 | data_save = read_dict(os.path.join(opt.dataroot, 
'processed_data', opt.mode + '_save.pkl')) 133 | 134 | self.label = data_save['caption_label'] 135 | self.caption_code = data_save['lstm_caption_id'] 136 | 137 | self.num_data = len(self.caption_code) 138 | 139 | def __getitem__(self, index): 140 | """ 141 | :param index: 142 | :return: image and its label 143 | """ 144 | 145 | label = torch.from_numpy(np.array([self.label[index]])).long() 146 | 147 | caption_code, caption_length = self.caption_mask(self.caption_code[index]) 148 | return label, caption_code, caption_length 149 | 150 | def caption_mask(self, caption): 151 | caption_length = len(caption) 152 | caption = torch.from_numpy(np.array(caption)).view(-1).float() 153 | if caption_length < self.opt.caption_length_max: 154 | zero_padding = torch.zeros(self.opt.caption_length_max - caption_length) 155 | caption = torch.cat([caption, zero_padding], 0) 156 | else: 157 | caption = caption[:self.opt.caption_length_max] 158 | caption_length = self.opt.caption_length_max 159 | 160 | return caption, caption_length 161 | 162 | def __len__(self): 163 | return self.num_data 164 | 165 | 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /src/loss/Id_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: zifyloo 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.parameter import Parameter 9 | from torch.nn import init 10 | 11 | 12 | def weights_init_classifier(m): 13 | classname = m.__class__.__name__ 14 | if classname.find('Linear') != -1: 15 | init.normal_(m.weight.data, std=0.001) 16 | init.constant_(m.bias.data, 0.0) 17 | 18 | 19 | class classifier(nn.Module): 20 | 21 | def __init__(self, input_dim, output_dim): 22 | super(classifier, self).__init__() 23 | 24 | self.block = nn.Linear(input_dim, output_dim) 25 | self.block.apply(weights_init_classifier) 26 | 27 | def forward(self, x): 28 | x = self.block(x) 29 | return x 30 | 31 | 32 | class Id_Loss(nn.Module): 33 | 34 | def __init__(self, opt, part, feature_length): 35 | super(Id_Loss, self).__init__() 36 | 37 | self.opt = opt 38 | self.part = part 39 | 40 | W = [] 41 | for i in range(part): 42 | W.append(classifier(feature_length, opt.class_num)) 43 | self.W = nn.Sequential(*W) 44 | 45 | def calculate_IdLoss(self, image_embedding_local, text_embedding_local, label): 46 | 47 | label = label.view(label.size(0)) 48 | 49 | criterion = nn.CrossEntropyLoss(reduction='mean') 50 | 51 | Lipt_local = 0 52 | Ltpi_local = 0 53 | 54 | for i in range(self.part): 55 | 56 | score_i2t_local_i = self.W[i](image_embedding_local[:, :, i]) 57 | score_t2i_local_i = self.W[i](text_embedding_local[:, :, i]) 58 | 59 | Lipt_local += criterion(score_i2t_local_i, label) 60 | Ltpi_local += criterion(score_t2i_local_i, label) 61 | 62 | loss = (Lipt_local + Ltpi_local) / self.part 63 | 64 | return loss 65 | 66 | def forward(self, image_embedding_local, text_embedding_local, label): 67 | 68 | loss = self.calculate_IdLoss(image_embedding_local, text_embedding_local, label) 69 | 70 | return loss 71 | 72 | -------------------------------------------------------------------------------- /src/loss/RankingLoss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 4 | @author: zifyloo 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | import torch.nn.functional as F 11 | 12 | 13 | def calculate_similarity(image_embedding, 
text_embedding): 14 | image_embedding = image_embedding.view(image_embedding.size(0), -1) 15 | text_embedding = text_embedding.view(text_embedding.size(0), -1) 16 | image_embedding_norm = image_embedding / (image_embedding.norm(dim=1, keepdim=True) + 1e-8) 17 | text_embedding_norm = text_embedding / (text_embedding.norm(dim=1, keepdim=True) + 1e-8) 18 | 19 | similarity = torch.mm(image_embedding_norm, text_embedding_norm.t()) 20 | 21 | similarity_match = torch.sum(image_embedding_norm * text_embedding_norm, dim=1) 22 | 23 | return similarity, similarity_match 24 | 25 | 26 | def calculate_margin_cr(similarity_match_cr, similarity_match, auto_margin_flag, margin): 27 | if auto_margin_flag: 28 | lambda_cr = abs(similarity_match_cr.detach()) / abs(similarity_match.detach()) 29 | ones = torch.ones_like(lambda_cr) 30 | data = torch.ge(ones, lambda_cr).float() 31 | data_2 = torch.ge(lambda_cr, ones).float() 32 | lambda_cr = data * lambda_cr + data_2 33 | 34 | lambda_cr = lambda_cr.detach().cpu().numpy() 35 | margin_cr = ((lambda_cr + 1) * margin) / 2.0 36 | else: 37 | margin_cr = margin / 2.0 38 | 39 | return margin_cr 40 | 41 | 42 | class CRLoss(nn.Module): 43 | 44 | def __init__(self, opt): 45 | super(CRLoss, self).__init__() 46 | 47 | self.device = opt.device 48 | self.margin = np.array([opt.margin]).repeat(opt.batch_size) 49 | self.beta = opt.cr_beta 50 | 51 | def semi_hard_negative(self, loss, margin): 52 | negative_index = np.where(np.logical_and(loss < margin, loss > 0))[0] 53 | return np.random.choice(negative_index) if len(negative_index) > 0 else None 54 | 55 | def get_triplets(self, similarity, labels, auto_margin_flag, margin): 56 | 57 | similarity = similarity.cpu().data.numpy() 58 | 59 | labels = labels.cpu().data.numpy() 60 | triplets = [] 61 | 62 | for idx, label in enumerate(labels): # same class calculate together 63 | if margin[idx] >= 0.16 or auto_margin_flag is False: 64 | negative = np.where(labels != label)[0] 65 | 66 | ap_sim = similarity[idx, idx] 67 | 68 | loss = similarity[idx, negative] - ap_sim + margin[idx] 69 | 70 | negetive_index = self.semi_hard_negative(loss, margin[idx]) 71 | 72 | if negetive_index is not None: 73 | triplets.append([idx, idx, negative[negetive_index]]) 74 | 75 | if len(triplets) == 0: 76 | triplets.append([idx, idx, negative[0]]) 77 | 78 | triplets = torch.LongTensor(np.array(triplets)) 79 | 80 | return_margin = torch.FloatTensor(np.array(margin[triplets[:, 0]])).to(self.device) 81 | 82 | return triplets, return_margin 83 | 84 | def calculate_loss(self, similarity, label, auto_margin_flag, margin): 85 | 86 | image_triplets, img_margin = self.get_triplets(similarity, label, auto_margin_flag, margin) 87 | text_triplets, txt_margin = self.get_triplets(similarity.t(), label, auto_margin_flag, margin) 88 | # print(img_margin[:10], img_margin.size()) 89 | # print(txt_margin[:10], txt_margin.size()) 90 | image_anchor_loss = F.relu(img_margin 91 | - similarity[image_triplets[:, 0], image_triplets[:, 1]] 92 | + similarity[image_triplets[:, 0], image_triplets[:, 2]]) 93 | 94 | similarity = similarity.t() 95 | text_anchor_loss = F.relu(txt_margin 96 | - similarity[text_triplets[:, 0], text_triplets[:, 1]] 97 | + similarity[text_triplets[:, 0], text_triplets[:, 2]]) 98 | 99 | loss = torch.sum(image_anchor_loss) + torch.sum(text_anchor_loss) 100 | 101 | return loss 102 | 103 | def forward(self, img, txt, txt_cr, labels, auto_margin_flag): 104 | 105 | similarity, similarity_match = calculate_similarity(img, txt) 106 | similarity_cr, similarity_cr_match = 
calculate_similarity(img, txt_cr) 107 | margin_cr = calculate_margin_cr(similarity_cr_match, similarity_match, auto_margin_flag, self.margin) 108 | 109 | cr_loss = self.calculate_loss(similarity, labels, auto_margin_flag, self.margin) \ 110 | + self.beta * self.calculate_loss(similarity_cr, labels, auto_margin_flag, margin_cr) 111 | 112 | return cr_loss 113 | 114 | -------------------------------------------------------------------------------- /src/model/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 4 | @author: zifyloo 5 | """ 6 | 7 | from torch import nn 8 | from model.text_feature_extract import TextExtract 9 | from torchvision import models 10 | import torch 11 | from torch.nn import init 12 | from torch.nn import functional as F 13 | 14 | 15 | def l2norm(X, dim, eps=1e-8): 16 | """L2-normalize columns of X 17 | """ 18 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps 19 | X = torch.div(X, norm) 20 | return X 21 | 22 | 23 | def weights_init_kaiming(m): 24 | classname = m.__class__.__name__ 25 | if classname.find('Conv2d') != -1: 26 | init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu') 27 | elif classname.find('Linear') != -1: 28 | init.kaiming_normal(m.weight.data, a=0, mode='fan_out') 29 | init.constant_(m.bias.data, 0.0) 30 | elif classname.find('BatchNorm1d') != -1: 31 | init.normal(m.weight.data, 1.0, 0.02) 32 | init.constant_(m.bias.data, 0.0) 33 | elif classname.find('BatchNorm2d') != -1: 34 | init.constant_(m.weight.data, 1) 35 | init.constant_(m.bias.data, 0) 36 | 37 | 38 | def weights_init_classifier(m): 39 | classname = m.__class__.__name__ 40 | if classname.find('Linear') != -1: 41 | init.normal(m.weight.data, std=0.001) 42 | init.constant(m.bias.data, 0.0) 43 | 44 | 45 | class conv(nn.Module): 46 | 47 | def __init__(self, input_dim, output_dim, relu=False, BN=False): 48 | super(conv, self).__init__() 49 | 50 | block = [] 51 | block += [nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=False)] 52 | 53 | if BN: 54 | block += [nn.BatchNorm2d(output_dim)] 55 | if relu: 56 | block += [nn.LeakyReLU(0.25, inplace=True)] 57 | 58 | self.block = nn.Sequential(*block) 59 | self.block.apply(weights_init_kaiming) 60 | 61 | def forward(self, x): 62 | x = self.block(x) 63 | x = x.squeeze(3).squeeze(2) 64 | return x 65 | 66 | 67 | class NonLocalNet(nn.Module): 68 | def __init__(self, opt, dim_cut=8): 69 | super(NonLocalNet, self).__init__() 70 | self.opt = opt 71 | 72 | up_dim_conv = [] 73 | part_sim_conv = [] 74 | cur_sim_conv = [] 75 | conv_local_att = [] 76 | for i in range(opt.part): 77 | up_dim_conv.append(conv(opt.feature_length//dim_cut, 1024, relu=True, BN=True)) 78 | part_sim_conv.append(conv(opt.feature_length, opt.feature_length // dim_cut, relu=True, BN=False)) 79 | cur_sim_conv.append(conv(opt.feature_length, opt.feature_length // dim_cut, relu=True, BN=False)) 80 | conv_local_att.append(conv(opt.feature_length, 512)) 81 | 82 | self.up_dim_conv = nn.Sequential(*up_dim_conv) 83 | self.part_sim_conv = nn.Sequential(*part_sim_conv) 84 | self.cur_sim_conv = nn.Sequential(*cur_sim_conv) 85 | self.conv_local_att = nn.Sequential(*conv_local_att) 86 | 87 | self.zero_eye = (torch.eye(opt.part, opt.part) * -1e6).unsqueeze(0).to(opt.device) 88 | 89 | self.lambda_softmax = 1 90 | 91 | def forward(self, embedding): 92 | embedding = embedding.unsqueeze(3) 93 | embedding_part_sim = [] 94 | embedding_cur_sim = [] 95 | 96 | for i in range(self.opt.part): 97 | embedding_i 
= embedding[:, :, i, :].unsqueeze(2) 98 | 99 | embedding_part_sim_i = self.part_sim_conv[i](embedding_i).unsqueeze(2) 100 | embedding_part_sim.append(embedding_part_sim_i) 101 | 102 | embedding_cur_sim_i = self.cur_sim_conv[i](embedding_i).unsqueeze(2) 103 | embedding_cur_sim.append(embedding_cur_sim_i) 104 | 105 | embedding_part_sim = torch.cat(embedding_part_sim, dim=2) 106 | embedding_cur_sim = torch.cat(embedding_cur_sim, dim=2) 107 | 108 | embedding_part_sim_norm = l2norm(embedding_part_sim, dim=1) # N*D*n 109 | embedding_cur_sim_norm = l2norm(embedding_cur_sim, dim=1) # N*D*n 110 | self_att = torch.bmm(embedding_part_sim_norm.transpose(1, 2), embedding_cur_sim_norm) # N*n*n 111 | self_att = self_att + self.zero_eye.repeat(self_att.size(0), 1, 1) 112 | self_att = F.softmax(self_att * self.lambda_softmax, dim=1) # .transpose(1, 2).contiguous() 113 | embedding_att = torch.bmm(embedding_part_sim_norm, self_att).unsqueeze(3) 114 | 115 | embedding_att_up_dim = [] 116 | for i in range(self.opt.part): 117 | embedding_att_up_dim_i = embedding_att[:, :, i, :].unsqueeze(2) 118 | embedding_att_up_dim_i = self.up_dim_conv[i](embedding_att_up_dim_i).unsqueeze(2) 119 | embedding_att_up_dim.append(embedding_att_up_dim_i) 120 | embedding_att_up_dim = torch.cat(embedding_att_up_dim, dim=2).unsqueeze(3) 121 | 122 | embedding_att = embedding + embedding_att_up_dim 123 | 124 | embedding_local_att = [] 125 | for i in range(self.opt.part): 126 | embedding_att_i = embedding_att[:, :, i, :].unsqueeze(2) 127 | embedding_att_i = self.conv_local_att[i](embedding_att_i).unsqueeze(2) 128 | embedding_local_att.append(embedding_att_i) 129 | 130 | embedding_local_att = torch.cat(embedding_local_att, 2) 131 | 132 | return embedding_local_att.squeeze() 133 | 134 | 135 | class TextImgPersonReidNet(nn.Module): 136 | 137 | def __init__(self, opt): 138 | super(TextImgPersonReidNet, self).__init__() 139 | 140 | self.opt = opt 141 | resnet50 = models.resnet50(pretrained=True) 142 | self.ImageExtract = nn.Sequential(*(list(resnet50.children())[:-2])) 143 | self.TextExtract = TextExtract(opt) 144 | 145 | self.global_avgpool = nn.AdaptiveMaxPool2d((1, 1)) 146 | self.local_avgpool = nn.AdaptiveMaxPool2d((opt.part, 1)) 147 | 148 | conv_local = [] 149 | for i in range(opt.part): 150 | conv_local.append(conv(2048, opt.feature_length)) 151 | self.conv_local = nn.Sequential(*conv_local) 152 | 153 | self.conv_global = conv(2048, opt.feature_length) 154 | 155 | self.non_local_net = NonLocalNet(opt, dim_cut=2) 156 | self.leaky_relu = nn.LeakyReLU(0.25, inplace=True) 157 | 158 | self.conv_word_classifier = nn.Sequential( 159 | nn.Conv2d(2048, 6, kernel_size=1, bias=False), 160 | nn.Sigmoid() 161 | ) 162 | 163 | def forward(self, image, caption_id, text_length): 164 | 165 | img_global, img_local, img_non_local = self.img_embedding(image) 166 | txt_global, txt_local, txt_non_local = self.txt_embedding(caption_id, text_length) 167 | 168 | return img_global, img_local, img_non_local, txt_global, txt_local, txt_non_local 169 | 170 | def img_embedding(self, image): 171 | 172 | image_feature = self.ImageExtract(image) 173 | 174 | image_feature_global = self.global_avgpool(image_feature) 175 | image_global = self.conv_global(image_feature_global).unsqueeze(2) 176 | 177 | image_feature_local = self.local_avgpool(image_feature) 178 | image_local = [] 179 | for i in range(self.opt.part): 180 | image_feature_local_i = image_feature_local[:, :, i, :] 181 | image_feature_local_i = image_feature_local_i.unsqueeze(2) 182 | image_embedding_local_i = 
self.conv_local[i](image_feature_local_i).unsqueeze(2) 183 | image_local.append(image_embedding_local_i) 184 | 185 | image_local = torch.cat(image_local, 2) 186 | 187 | image_non_local = self.leaky_relu(image_local) 188 | image_non_local = self.non_local_net(image_non_local) 189 | 190 | return image_global, image_local, image_non_local 191 | 192 | def txt_embedding(self, caption_id, text_length): 193 | 194 | text_feature_g, text_feature_l = self.TextExtract(caption_id, text_length) 195 | 196 | text_global, _ = torch.max(text_feature_g, dim=2, keepdim=True) 197 | text_global = self.conv_global(text_global).unsqueeze(2) 198 | 199 | text_feature_local = [] 200 | for text_i in range(text_feature_l.size(0)): 201 | text_feature_local_i = text_feature_l[text_i, :, :text_length[text_i]].unsqueeze(0) 202 | 203 | word_classifier_score_i = self.conv_word_classifier(text_feature_local_i) 204 | 205 | word_classifier_score_i = word_classifier_score_i.permute(0, 3, 2, 1).contiguous() 206 | text_feature_local_i = text_feature_local_i.repeat(1, 1, 1, 6).contiguous() 207 | 208 | text_feature_local_i = text_feature_local_i * word_classifier_score_i 209 | 210 | text_feature_local_i, _ = torch.max(text_feature_local_i, dim=2) 211 | 212 | text_feature_local.append(text_feature_local_i) 213 | 214 | text_feature_local = torch.cat(text_feature_local, dim=0) 215 | 216 | text_local = [] 217 | for p in range(self.opt.part): 218 | text_feature_local_conv_p = text_feature_local[:, :, p].unsqueeze(2).unsqueeze(2) 219 | text_feature_local_conv_p = self.conv_local[p](text_feature_local_conv_p).unsqueeze(2) 220 | text_local.append(text_feature_local_conv_p) 221 | text_local = torch.cat(text_local, dim=2) 222 | 223 | text_non_local = self.leaky_relu(text_local) 224 | text_non_local = self.non_local_net(text_non_local) 225 | 226 | return text_global, text_local, text_non_local 227 | 228 | -------------------------------------------------------------------------------- /src/model/text_feature_extract.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: zifyloo 4 | """ 5 | 6 | 7 | from torch import nn 8 | import torch 9 | 10 | 11 | class TextExtract(nn.Module): 12 | 13 | def __init__(self, opt): 14 | super(TextExtract, self).__init__() 15 | 16 | self.embedding_local = nn.Embedding(opt.vocab_size, 512, padding_idx=0) 17 | self.embedding_global = nn.Embedding(opt.vocab_size, 512, padding_idx=0) 18 | self.dropout = nn.Dropout(0.3) 19 | self.lstm = nn.LSTM(512, 2048, num_layers=1, bidirectional=True, bias=False) 20 | 21 | def forward(self, caption_id, text_length): 22 | 23 | text_embedding_global = self.embedding_global(caption_id) 24 | text_embedding_global = self.dropout(text_embedding_global) 25 | text_embedding_global = self.calculate_different_length_lstm(text_embedding_global, text_length, self.lstm) 26 | 27 | text_embedding_local = self.embedding_local(caption_id) 28 | text_embedding_local = self.dropout(text_embedding_local) 29 | text_embedding_local = self.calculate_different_length_lstm(text_embedding_local, text_length, self.lstm) 30 | 31 | return text_embedding_global, text_embedding_local 32 | 33 | def calculate_different_length_lstm(self, text_embedding, text_length, lstm): 34 | text_length = text_length.view(-1) 35 | _, sort_index = torch.sort(text_length, dim=0, descending=True) 36 | _, unsort_index = sort_index.sort() 37 | 38 | sortlength_text_embedding = text_embedding[sort_index, :] 39 | sort_text_length = text_length[sort_index] 40 
| # print(sort_text_length) 41 | packed_text_embedding = nn.utils.rnn.pack_padded_sequence(sortlength_text_embedding, 42 | sort_text_length, 43 | batch_first=True) 44 | 45 | 46 | # self.lstm.flatten_parameters() 47 | packed_feature, _ = lstm(packed_text_embedding) # [hn, cn] 48 | total_length = text_embedding.size(1) 49 | sort_feature = nn.utils.rnn.pad_packed_sequence(packed_feature, 50 | batch_first=True, 51 | total_length=total_length) # including[feature, length] 52 | 53 | unsort_feature = sort_feature[0][unsort_index, :] 54 | unsort_feature = (unsort_feature[:, :, :int(unsort_feature.size(2) / 2)] 55 | + unsort_feature[:, :, int(unsort_feature.size(2) / 2):]) / 2 56 | 57 | return unsort_feature.permute(0, 2, 1).contiguous().unsqueeze(3) 58 | -------------------------------------------------------------------------------- /src/option/options.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: zifyloo 4 | """ 5 | 6 | import argparse 7 | import torch 8 | import logging 9 | import os 10 | from utils.read_write_data import makedir 11 | 12 | logger = logging.getLogger() 13 | logger.setLevel(logging.INFO) 14 | 15 | 16 | class options(): 17 | def __init__(self): 18 | self._par = argparse.ArgumentParser(description='options for Deep Cross Modal') 19 | 20 | self._par.add_argument('--model_name', type=str, help='experiment name') 21 | self._par.add_argument('--mode', type=str, default='', help='choose mode [train or test]') 22 | 23 | self._par.add_argument('--epoch', type=int, default=60, help='train epoch') 24 | self._par.add_argument('--epoch_decay', type=list, default=[20, 40], help='decay epoch') 25 | self._par.add_argument('--epoch_begin', type=int, default=5, help='when calculate the auto margin') 26 | 27 | self._par.add_argument('--batch_size', type=int, default=64, help='batch size') 28 | self._par.add_argument('--adam_alpha', type=float, default=0.9, help='momentum term of adam') 29 | self._par.add_argument('--adam_beta', type=float, default=0.999, help='momentum term of adam') 30 | self._par.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam') 31 | self._par.add_argument('--margin', type=float, default=0.2, help='ranking loss margin') 32 | self._par.add_argument('--cr_beta', type=float, default=0.1, help='ranking loss margin') 33 | 34 | self._par.add_argument('--vocab_size', type=int, default=5000, help='the size of vocab') 35 | self._par.add_argument('--feature_length', type=int, default=512, help='the length of feature') 36 | self._par.add_argument('--class_num', type=int, default=11000, 37 | help='num of class for StarGAN training on second dataset') 38 | self._par.add_argument('--part', type=int, default=6, help='the num of image part') 39 | self._par.add_argument('--caption_length_max', type=int, default=100, help='the max length of caption') 40 | 41 | self._par.add_argument('--save_path', type=str, default='./checkpoints/test', 42 | help='save the result during training') 43 | self._par.add_argument('--GPU_id', type=str, default='2', help='choose GPU ID [0 1]') 44 | self._par.add_argument('--device', type=str, default='', help='cuda devie') 45 | self._par.add_argument('--dataset', type=str, help='choose the dataset ') 46 | self._par.add_argument('--dataroot', type=str, help='data root of the Data') 47 | 48 | self.opt = self._par.parse_args() 49 | 50 | self.opt.device = torch.device('cuda:{}'.format(self.opt.GPU_id[0])) 51 | 52 | 53 | def config(opt): 54 | 55 | 
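    # Prepare the output locations before training/testing starts: log_config()
    # attaches a file handler writing to '<save_path>/log/<mode>.log', and the
    # 'model' sub-folder created below is where checkpoints (e.g. best.pth.tar) go.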
log_config(opt) 56 | model_root = os.path.join(opt.save_path, 'model') 57 | if os.path.exists(model_root) is False: 58 | makedir(model_root) 59 | 60 | 61 | def log_config(opt): 62 | logroot = os.path.join(opt.save_path, 'log') 63 | if os.path.exists(logroot) is False: 64 | makedir(logroot) 65 | filename = os.path.join(logroot, opt.mode + '.log') 66 | handler = logging.FileHandler(filename) 67 | handler.setLevel(logging.INFO) 68 | formatter = logging.Formatter('%(message)s') 69 | handler.setFormatter(formatter) 70 | logger.addHandler(logging.StreamHandler()) 71 | logger.addHandler(handler) 72 | if opt.mode != 'test': 73 | logger.info(opt) 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: zifyloo 4 | """ 5 | 6 | 7 | from option.options import options 8 | from data.dataloader import get_dataloader 9 | import torch 10 | from model.model import TextImgPersonReidNet 11 | import os 12 | from test_during_train import test 13 | 14 | 15 | def save_checkpoint(state, opt): 16 | 17 | filename = os.path.join(opt.save_path, 'model/best.pth.tar') 18 | torch.save(state, filename) 19 | 20 | 21 | def load_checkpoint(opt): 22 | filename = os.path.join(opt.save_path, 'model/best.pth.tar') 23 | state = torch.load(filename) 24 | print('Load the {} epoch parameter successfully'.format(state['epoch'])) 25 | 26 | return state 27 | 28 | 29 | def main(opt): 30 | opt.device = torch.device('cuda:{}'.format(opt.GPU_id)) 31 | 32 | opt.save_path = './checkpoints/{}/'.format(opt.dataset) + opt.model_name 33 | 34 | test_img_dataloader, test_txt_dataloader = get_dataloader(opt) 35 | 36 | network = TextImgPersonReidNet(opt).to(opt.device) 37 | 38 | test_best = 0 39 | state = load_checkpoint(opt) 40 | network.load_state_dict(state['network']) 41 | epoch = state['epoch'] 42 | 43 | print(opt.model_name) 44 | network.eval() 45 | test(opt, epoch + 1, network, test_img_dataloader, test_txt_dataloader, test_best, False) 46 | network.train() 47 | 48 | 49 | if __name__ == '__main__': 50 | opt = options().opt 51 | main(opt) 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /src/test_during_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: zifyloo 4 | """ 5 | 6 | 7 | import torch 8 | import numpy as np 9 | import os 10 | import torch.nn.functional as F 11 | from utils.read_write_data import write_txt 12 | 13 | 14 | def calculate_similarity(image_feature_local, text_feature_local): 15 | 16 | image_feature_local = image_feature_local.view(image_feature_local.size(0), -1) 17 | image_feature_local = image_feature_local / image_feature_local.norm(dim=1, keepdim=True) 18 | 19 | text_feature_local = text_feature_local.view(text_feature_local.size(0), -1) 20 | text_feature_local = text_feature_local / text_feature_local.norm(dim=1, keepdim=True) 21 | 22 | similarity = torch.mm(image_feature_local, text_feature_local.t()) 23 | 24 | return similarity.cpu() 25 | 26 | 27 | def calculate_ap(similarity, label_query, label_gallery): 28 | """ 29 | calculate the similarity, and rank the distance, according to the distance, calculate the ap, cmc 30 | :param label_query: the id of query [1] 31 | :param label_gallery:the id of gallery [N] 32 | :return: ap, cmc 33 | """ 34 | 35 | index = np.argsort(similarity)[::-1] # 
the index of the similarity from huge to small 36 | good_index = np.argwhere(label_gallery == label_query) # the index of the same label in gallery 37 | 38 | cmc = np.zeros(index.shape) 39 | 40 | mask = np.in1d(index, good_index) # get the flag the if index[i] is in the good_index 41 | 42 | precision_result = np.argwhere(mask == True) # get the situation of the good_index in the index 43 | 44 | precision_result = precision_result.reshape(precision_result.shape[0]) 45 | 46 | if precision_result.shape[0] != 0: 47 | cmc[int(precision_result[0]):] = 1 # get the cmc 48 | 49 | d_recall = 1.0 / len(precision_result) 50 | ap = 0 51 | 52 | for i in range(len(precision_result)): # ap is to calculate the PR area 53 | precision = (i + 1) * 1.0 / (precision_result[i] + 1) 54 | 55 | if precision_result[i] != 0: 56 | old_precision = i * 1.0 / precision_result[i] 57 | else: 58 | old_precision = 1.0 59 | 60 | ap += d_recall * (old_precision + precision) / 2 61 | 62 | return ap, cmc 63 | else: 64 | return None, None 65 | 66 | 67 | def evaluate(similarity, label_query, label_gallery): 68 | similarity = similarity.numpy() 69 | label_query = label_query.numpy() 70 | label_gallery = label_gallery.numpy() 71 | 72 | cmc = np.zeros(label_gallery.shape) 73 | ap = 0 74 | for i in range(len(label_query)): 75 | ap_i, cmc_i = calculate_ap(similarity[i, :], label_query[i], label_gallery) 76 | cmc += cmc_i 77 | ap += ap_i 78 | """ 79 | cmc_i is the vector [0,0,...1,1,..1], the first 1 is the first right prediction n, 80 | rank-n and the rank-k after it all add one right prediction, therefore all of them's index mark 1 81 | Through the add all the vector and then divive the n_query, we can get the rank-k accuracy cmc 82 | cmc[k-1] is the rank-k accuracy 83 | """ 84 | cmc = cmc / len(label_query) 85 | map = ap / len(label_query) # map = sum(ap) / n_query 86 | 87 | # print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f' % (cmc[0], cmc[4], cmc[9], map)) 88 | 89 | return cmc, map 90 | 91 | 92 | def evaluate_without_matching_image(similarity, label_query, label_gallery, txt_img_index): 93 | similarity = similarity.numpy() 94 | label_query = label_query.numpy() 95 | label_gallery = label_gallery.numpy() 96 | 97 | cmc = np.zeros(label_gallery.shape[0] - 1) 98 | ap = 0 99 | count = 0 100 | for i in range(len(label_query)): 101 | 102 | similarity_i = similarity[i, :] 103 | similarity_i = np.delete(similarity_i, txt_img_index[i]) 104 | label_gallery_i = np.delete(label_gallery, txt_img_index[i]) 105 | ap_i, cmc_i = calculate_ap(similarity_i, label_query[i], label_gallery_i) 106 | if ap_i is not None: 107 | cmc += cmc_i 108 | ap += ap_i 109 | else: 110 | count += 1 111 | """ 112 | cmc_i is the vector [0,0,...1,1,..1], the first 1 is the first right prediction n, 113 | rank-n and the rank-k after it all add one right prediction, therefore all of them's index mark 1 114 | Through the add all the vector and then divive the n_query, we can get the rank-k accuracy cmc 115 | cmc[k-1] is the rank-k accuracy 116 | """ 117 | cmc = cmc / (len(label_query) - count) 118 | map = ap / (len(label_query) - count) # map = sum(ap) / n_query 119 | # print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f' % (cmc[0], cmc[4], cmc[9], map)) 120 | 121 | return cmc, map 122 | 123 | 124 | def load_checkpoint(model_root, model_name): 125 | filename = os.path.join(model_root, 'model', model_name) 126 | state = torch.load(filename, map_location='cpu') 127 | 128 | return state 129 | 130 | 131 | def write_result(similarity, img_labels, txt_labels, name, txt_root, best_txt_root, 


def evaluate(similarity, label_query, label_gallery):
    similarity = similarity.numpy()
    label_query = label_query.numpy()
    label_gallery = label_gallery.numpy()

    cmc = np.zeros(label_gallery.shape)
    ap = 0
    for i in range(len(label_query)):
        ap_i, cmc_i = calculate_ap(similarity[i, :], label_query[i], label_gallery)
        cmc += cmc_i
        ap += ap_i
    """
    cmc_i is a vector of the form [0, 0, ..., 1, 1, ..., 1]: it turns to 1 at the rank of the
    first correct prediction and stays 1 afterwards, so rank-k and every rank after it each
    count one correct retrieval. Summing these vectors over all queries and dividing by the
    number of queries gives the rank-k accuracy, i.e. cmc[k - 1] is the rank-k accuracy.
    """
    cmc = cmc / len(label_query)
    map = ap / len(label_query)  # map = sum(ap) / n_query

    # print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f' % (cmc[0], cmc[4], cmc[9], map))

    return cmc, map


def evaluate_without_matching_image(similarity, label_query, label_gallery, txt_img_index):
    similarity = similarity.numpy()
    label_query = label_query.numpy()
    label_gallery = label_gallery.numpy()

    cmc = np.zeros(label_gallery.shape[0] - 1)
    ap = 0
    count = 0
    for i in range(len(label_query)):

        similarity_i = similarity[i, :]
        similarity_i = np.delete(similarity_i, txt_img_index[i])
        label_gallery_i = np.delete(label_gallery, txt_img_index[i])
        ap_i, cmc_i = calculate_ap(similarity_i, label_query[i], label_gallery_i)
        if ap_i is not None:
            cmc += cmc_i
            ap += ap_i
        else:
            count += 1
    """
    same convention as in evaluate(): cmc_i turns to 1 at the rank of the first correct
    prediction and stays 1 afterwards; averaging over the evaluated queries gives the
    rank-k accuracy, i.e. cmc[k - 1] is the rank-k accuracy.
    """
    cmc = cmc / (len(label_query) - count)
    map = ap / (len(label_query) - count)  # map = sum(ap) / n_query
    # print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f' % (cmc[0], cmc[4], cmc[9], map))

    return cmc, map


def load_checkpoint(model_root, model_name):
    filename = os.path.join(model_root, 'model', model_name)
    state = torch.load(filename, map_location='cpu')

    return state


def write_result(similarity, img_labels, txt_labels, name, txt_root, best_txt_root,
                 epoch, best):
    write_txt(name, txt_root)
    print(name)
    t2i_cmc, t2i_map = evaluate(similarity.t(), txt_labels, img_labels)
    str = "t2i: @R1: {:.4}, @R5: {:.4}, @R10: {:.4}, map: {:.4}".format(t2i_cmc[0], t2i_cmc[4], t2i_cmc[9], t2i_map)
    write_txt(str, txt_root)
    print(str)

    if t2i_cmc[0] > best:
        str = "Testing Epoch: {}".format(epoch)
        write_txt(str, best_txt_root)
        write_txt(name, best_txt_root)
        str = "t2i: @R1: {:.4}, @R5: {:.4}, @R10: {:.4}, map: {:.4}".format(t2i_cmc[0], t2i_cmc[4], t2i_cmc[9], t2i_map)
        write_txt(str, best_txt_root)

        return t2i_cmc[0]
    else:
        return best


def test(opt, epoch, network, img_dataloader, txt_dataloader, best, return_flag=True):

    best_txt_root = os.path.join(opt.save_path, 'log', 'best_test.log')
    txt_root = os.path.join(opt.save_path, 'log', 'test_separate.log')

    str = "Testing Epoch: {}".format(epoch)
    write_txt(str, txt_root)
    print(str)

    image_feature_global = []
    image_feature_local = []
    image_feature_non_local = []
    img_labels = []

    for times, [image, label] in enumerate(img_dataloader):

        image = image.to(opt.device)
        label = label.to(opt.device)

        with torch.no_grad():
            img_global_i, img_local_i, img_non_local_i = network.img_embedding(image)

        image_feature_global.append(img_global_i)
        image_feature_local.append(img_local_i)
        image_feature_non_local.append(img_non_local_i)
        img_labels.append(label.view(-1))

    image_feature_local = torch.cat(image_feature_local, 0)
    image_feature_global = torch.cat(image_feature_global, 0)
    image_feature_non_local = torch.cat(image_feature_non_local, 0)
    img_labels = torch.cat(img_labels, 0)

    text_feature_local = []
    text_feature_global = []
    text_feature_non_local = []
    txt_labels = []

    for times, [label, caption_code, caption_length] in enumerate(txt_dataloader):
        label = label.to(opt.device)
        caption_code = caption_code.to(opt.device).long()
        caption_length = caption_length.to(opt.device)

        with torch.no_grad():
            text_global_i, text_local_i, text_non_local_i = network.txt_embedding(caption_code, caption_length)
        text_feature_local.append(text_local_i)
        text_feature_global.append(text_global_i)
        text_feature_non_local.append(text_non_local_i)
        txt_labels.append(label.view(-1))

    text_feature_local = torch.cat(text_feature_local, 0)
    text_feature_global = torch.cat(text_feature_global, 0)
    text_feature_non_local = torch.cat(text_feature_non_local, 0)
    txt_labels = torch.cat(txt_labels, 0)

    similarity_local = calculate_similarity(image_feature_local, text_feature_local)
    # similarity_local = 0
    similarity_global = calculate_similarity(image_feature_global, text_feature_global)
    similarity_non_local = calculate_similarity(image_feature_non_local, text_feature_non_local)
    similarity_all = similarity_local + similarity_global + similarity_non_local

    img_labels = img_labels.cpu()
    txt_labels = txt_labels.cpu()

    best = write_result(similarity_global, img_labels, txt_labels, 'similarity_global:',
                        txt_root, best_txt_root, epoch, best)

    best = write_result(similarity_local, img_labels, txt_labels, 'similarity_local:',
                        txt_root, best_txt_root, epoch, best)

    best = write_result(similarity_non_local, img_labels, txt_labels, 'similarity_non_local:',
                        txt_root, best_txt_root, epoch, best)

    best = write_result(similarity_all, img_labels, txt_labels, 'similarity_all:',
                        txt_root, best_txt_root, epoch, best)
    if return_flag:
        return best

--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
@author: zifyloo
"""

from option.options import options, config
from data.dataloader import get_dataloader
import torch
from model.model import TextImgPersonReidNet
from loss.Id_loss import Id_Loss
from loss.RankingLoss import CRLoss
from torch import optim
import logging
import os
from test_during_train import test
from torch.autograd import Variable


logger = logging.getLogger()
logger.setLevel(logging.INFO)


def save_checkpoint(state, opt):
    filename = os.path.join(opt.save_path, 'model/best.pth.tar')
    torch.save(state, filename)


def train(opt):
    opt.device = torch.device('cuda:{}'.format(opt.GPU_id))

    opt.save_path = './checkpoints/{}/'.format(opt.dataset) + opt.model_name

    config(opt)
    train_dataloader = get_dataloader(opt)
    opt.mode = 'test'
    test_img_dataloader, test_txt_dataloader = get_dataloader(opt)
    opt.mode = 'train'

    id_loss_fun_global = Id_Loss(opt, 1, opt.feature_length).to(opt.device)
    id_loss_fun_local = Id_Loss(opt, opt.part, opt.feature_length).to(opt.device)
    id_loss_fun_non_local = Id_Loss(opt, opt.part, 512).to(opt.device)
    cr_loss_fun = CRLoss(opt)
    network = TextImgPersonReidNet(opt).to(opt.device)

    cnn_params = list(map(id, network.ImageExtract.parameters()))
    other_params = filter(lambda p: id(p) not in cnn_params, network.parameters())
    other_params = list(other_params)
    other_params.extend(list(id_loss_fun_global.parameters()))
    other_params.extend(list(id_loss_fun_local.parameters()))
    other_params.extend(list(id_loss_fun_non_local.parameters()))
    param_groups = [{'params': other_params, 'lr': opt.lr},
                    {'params': network.ImageExtract.parameters(), 'lr': opt.lr * 0.1}]

    optimizer = optim.Adam(param_groups, betas=(opt.adam_alpha, opt.adam_beta))
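
    # Two learning-rate groups: the CNN image backbone (network.ImageExtract) is fine-tuned
    # with 0.1 * opt.lr, while everything trained from scratch (text branch, projections and
    # the Id_Loss classifiers collected into other_params above) uses the full opt.lr. The
    # id()-based filter is simply a way of excluding the backbone parameters from other_params.
    # A quick way to check the two initial rates (illustrative only):
    #   >>> [group['lr'] for group in optimizer.param_groups]   # -> [opt.lr, opt.lr * 0.1]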

    test_best = 0
    test_history = 0

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, opt.epoch_decay)

    for epoch in range(opt.epoch):

        id_loss_sum = 0
        ranking_loss_sum = 0

        for param in optimizer.param_groups:
            logging.info('lr:{}'.format(param['lr']))

        for times, [image, label, caption_code, caption_length, caption_code_cr, caption_length_cr] in enumerate(
                train_dataloader):

            image = Variable(image.to(opt.device))
            label = Variable(label.to(opt.device))
            caption_code = Variable(caption_code.to(opt.device).long())
            caption_length = caption_length.to(opt.device)
            caption_code_cr = Variable(caption_code_cr.to(opt.device).long())
            caption_length_cr = caption_length_cr.to(opt.device)

            img_global, img_local, img_non_local, txt_global, txt_local, txt_non_local = network(image, caption_code,
                                                                                                 caption_length)

            txt_global_cr, txt_local_cr, txt_non_local_cr = network.txt_embedding(caption_code_cr, caption_length_cr)

            id_loss_global = id_loss_fun_global(img_global, txt_global, label)
            id_loss_local = id_loss_fun_local(img_local, txt_local, label)
            id_loss_non_local = id_loss_fun_non_local(img_non_local, txt_non_local, label)
            id_loss = id_loss_global + (id_loss_local + id_loss_non_local) * 0.5

            cr_loss_global = cr_loss_fun(img_global, txt_global, txt_global_cr, label, epoch >= opt.epoch_begin)
            cr_loss_local = cr_loss_fun(img_local, txt_local, txt_local_cr, label, epoch >= opt.epoch_begin)
            cr_loss_non_local = cr_loss_fun(img_non_local, txt_non_local,
                                            txt_non_local_cr, label, epoch >= opt.epoch_begin)

            ranking_loss = cr_loss_global + (cr_loss_local + cr_loss_non_local) * 0.5
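
            # The total objective mirrors the three-branch feature split: the global branch
            # is weighted 1.0 and the part-level (local) and non-local branches 0.5 each, for
            # both the ID loss and the CR ranking loss. The boolean epoch >= opt.epoch_begin
            # passed to cr_loss_fun presumably lets CRLoss delay part of its ranking term
            # until a few warm-up epochs have passed; see loss/RankingLoss.py for the details.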

            optimizer.zero_grad()
            loss = (id_loss + ranking_loss)
            loss.backward()
            optimizer.step()

            if (times + 1) % 50 == 0:
                logging.info("Epoch: %d/%d Step: %d, ranking_loss: %.2f, id_loss: %.2f"
                             % (epoch + 1, opt.epoch, times + 1, ranking_loss, id_loss))

            ranking_loss_sum += ranking_loss
            id_loss_sum += id_loss
        ranking_loss_avg = ranking_loss_sum / (times + 1)
        id_loss_avg = id_loss_sum / (times + 1)

        logging.info("Epoch: %d/%d , ranking_loss: %.2f, id_loss: %.2f"
                     % (epoch + 1, opt.epoch, ranking_loss_avg, id_loss_avg))

        print(opt.model_name)
        network.eval()
        test_best = test(opt, epoch + 1, network, test_img_dataloader, test_txt_dataloader, test_best)
        network.train()

        if test_best > test_history:
            test_history = test_best
            state = {
                'network': network.cpu().state_dict(),
                'test_best': test_best,
                'epoch': epoch,
                'WN': id_loss_fun_non_local.cpu().state_dict(),
                'WL': id_loss_fun_local.cpu().state_dict(),
            }
            save_checkpoint(state, opt)
            network.to(opt.device)
            id_loss_fun_non_local.to(opt.device)
            id_loss_fun_local.to(opt.device)

        scheduler.step()

    logging.info('Training Done')


if __name__ == '__main__':
    opt = options().opt
    train(opt)

--------------------------------------------------------------------------------
/src/utils/read_write_data.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
utilities for reading and writing data

@author: zifyloo
"""

import os
import json
import pickle


def makedir(root):
    if not os.path.exists(root):
        os.makedirs(root)


def write_json(data, root):
    with open(root, 'w') as f:
        json.dump(data, f)


def read_json(root):
    with open(root, 'r') as f:
        data = json.load(f)

    return data


def read_dict(root):
    with open(root, 'rb') as f:
        data = pickle.load(f)

    return data


def save_dict(data, name):
    with open(name + '.pkl', 'wb') as f:
        pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)


def write_txt(data, name):
    with open(name, 'a') as f:
        f.write(data)
        f.write('\n')

--------------------------------------------------------------------------------