├── README.md ├── config.py ├── datasets ├── data.sh ├── directory.py ├── pedes.py └── preprocess.py ├── models ├── bi_lstm.py ├── mobilenet.py ├── model.py └── resnet.py ├── scripts ├── run.sh ├── test.sh └── train.sh ├── test.py ├── test_config.py ├── train.py ├── train_config.py └── utils ├── directory.py ├── metric.py ├── statistics.py └── visualize.py /README.md: -------------------------------------------------------------------------------- 1 | # Deep Cross-Modal Projection Learning for Image-Text Matching 2 | This is a PyTorch implementation of the paper [Deep Cross-Modal Projection Learning for Image-Text Matching](http://openaccess.thecvf.com/content_ECCV_2018/papers/Ying_Zhang_Deep_Cross-Modal_Projection_ECCV_2018_paper.pdf). 3 | The official implementation in TensorFlow can be found [here](https://github.com/YingZhangDUT/Cross-Modal-Projection-Learning). 4 | ## Requirements 5 | * Python 3.5 6 | * PyTorch 1.0.0 & torchvision 0.2.1 7 | * numpy 8 | * scipy 1.2.1 9 | 10 | ## Data Preparation 11 | - Download the pre-computed/pre-extracted data from [GoogleDrive](https://drive.google.com/drive/folders/1Nbx5Oa5746_uAcuRi73DmuhQhrxvrAc9?usp=sharing) and move them to the ```data/processed``` folder. Alternatively, you can use ```datasets/preprocess.py``` to prepare your own data. 12 | - *[Optional]* Download the pre-trained model weights from [GoogleDrive](https://drive.google.com/drive/folders/1LtTjWeGuLNvQYMTjdrYbdVjbxr7bLQQC?usp=sharing) and move them to the ```pretrained_models``` folder. 13 | 14 | ## Training & Testing 15 | First, change the ```model_path``` param to your own directory. 16 | ``` 17 | sh scripts/run.sh 18 | ``` 19 | This runs training and testing together, instead of performing them separately. 20 | To run training only: 21 | ``` 22 | sh scripts/train.sh 23 | ``` 24 | To run testing only: 25 | ``` 26 | sh scripts/test.sh 27 | ``` 28 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | import torch.nn as nn 5 | import torchvision.models as models 6 | import torch.backends.cudnn as cudnn 7 | import random 8 | import numpy as np 9 | import logging 10 | from datasets.pedes import CuhkPedes 11 | from models.model import Model 12 | from utils import directory 13 | 14 | logger = logging.getLogger() 15 | logger.setLevel(logging.INFO) 16 | 17 | 18 | def data_config(image_dir, anno_dir, batch_size, split, max_length, transform): 19 | data_split = CuhkPedes(image_dir, anno_dir, split, max_length, transform) 20 | if split == 'train': 21 | shuffle = True 22 | else: 23 | shuffle = False 24 | loader = data.DataLoader(data_split, batch_size, shuffle=shuffle, num_workers=4) 25 | return loader 26 | 27 | def network_config(args, split='train', param=None, resume=False, model_path=None, ema=False): 28 | network = Model(args) 29 | network = nn.DataParallel(network).cuda() 30 | cudnn.benchmark = True 31 | args.start_epoch = 0 32 | 33 | # process network params 34 | if resume: 35 | directory.check_file(model_path, 'model_file') 36 | checkpoint = torch.load(model_path) 37 | args.start_epoch = checkpoint['epoch'] + 1 38 | # best_prec1 = checkpoint['best_prec1'] 39 | #network.load_state_dict(checkpoint['state_dict']) 40 | network_dict = checkpoint['network'] 41 | if ema: 42 | logging.info('==> EMA Loading') 43 | network_dict.update(checkpoint['network_ema']) 44 | network.load_state_dict(network_dict) 45 |
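# (Note: these are the checkpoints written by save_checkpoint() in train.py. Besides 'network'
# and the optional EMA shadow weights under 'network_ema', they also carry 'optimizer',
# 'W' (the CMPC classifier weight consumed by Loss in utils/metric.py) and 'epoch', which is
# why resuming restores args.start_epoch above and the optimizer state further below.)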
print('==> Loading checkpoint "{}"'.format(model_path)) 46 | else: 47 | # pretrained 48 | if model_path is not None: 49 | print('==> Loading from pretrained models') 50 | network_dict = network.state_dict() 51 | if args.image_model == 'mobilenet_v1': 52 | cnn_pretrained = torch.load(model_path)['state_dict'] 53 | start = 7 54 | else: 55 | cnn_pretrained = torch.load(model_path) 56 | start = 0 57 | # process keyword of pretrained model 58 | prefix = 'module.image_model.' 59 | pretrained_dict = {prefix + k[start:] :v for k,v in cnn_pretrained.items()} 60 | pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in network_dict} 61 | network_dict.update(pretrained_dict) 62 | network.load_state_dict(network_dict) 63 | 64 | # process optimizer params 65 | if split == 'test': 66 | optimizer = None 67 | else: 68 | # optimizer 69 | # different params for different part 70 | cnn_params = list(map(id, network.module.image_model.parameters())) 71 | other_params = filter(lambda p: id(p) not in cnn_params, network.parameters()) 72 | other_params = list(other_params) 73 | if param is not None: 74 | other_params.extend(list(param)) 75 | param_groups = [{'params':other_params}, 76 | {'params':network.module.image_model.parameters(), 'weight_decay':args.wd}] 77 | optimizer = torch.optim.Adam( 78 | param_groups, 79 | lr = args.lr, betas=(args.adam_alpha, args.adam_beta), eps=args.epsilon) 80 | if resume: 81 | optimizer.load_state_dict(checkpoint['optimizer']) 82 | 83 | print('Total params: %2.fM' % (sum(p.numel() for p in network.parameters()) / 1000000.0)) 84 | # seed 85 | manualSeed = random.randint(1, 10000) 86 | random.seed(manualSeed) 87 | np.random.seed(manualSeed) 88 | torch.manual_seed(manualSeed) 89 | torch.cuda.manual_seed_all(manualSeed) 90 | 91 | return network, optimizer 92 | 93 | 94 | def log_config(args, ca): 95 | filename = args.log_dir +'/' + ca + '.log' 96 | handler = logging.FileHandler(filename) 97 | handler.setLevel(logging.INFO) 98 | formatter = logging.Formatter('%(message)s') 99 | handler.setFormatter(formatter) 100 | logger.addHandler(logging.StreamHandler()) 101 | logger.addHandler(handler) 102 | logging.info(args) 103 | 104 | 105 | def dir_config(args): 106 | if not os.path.exists(args.image_dir): 107 | raise ValueError('Supply the dataset directory with --image_dir') 108 | if not os.path.exists(args.anno_dir): 109 | raise ValueError('Supply the anno file with --anno_dir') 110 | directory.makedir(args.log_dir) 111 | # save checkpoint 112 | directory.makedir(args.checkpoint_dir) 113 | directory.makedir(os.path.join(args.checkpoint_dir,'model_best')) 114 | 115 | 116 | def adjust_lr(optimizer, epoch, args): 117 | # Decay learning rate by args.lr_decay_ratio every args.epoches_decay 118 | if args.lr_decay_type == 'exponential': 119 | if '_' in args.epoches_decay: 120 | epoches_list = args.epoches_decay.split('_') 121 | epoches_list = [int(e) for e in epoches_list] 122 | for times, e in enumerate(epoches_list): 123 | if epoch / e == 0: 124 | lr = args.lr * ((1 - args.lr_decay_ratio) ** times) 125 | break 126 | times = len(epoches_list) 127 | lr = args.lr * ((1 - args.lr_decay_ratio) ** times) 128 | else: 129 | epoches_decay = int(args.epoches_decay) 130 | lr = args.lr * ((1 - args.lr_decay_ratio) ** (epoch // epoches_decay)) 131 | for param_group in optimizer.param_groups: 132 | param_group['lr'] = lr 133 | logging.info('lr:{}'.format(lr)) 134 | 135 | def lr_scheduler(optimizer, args): 136 | if '_' in args.epoches_decay: 137 | epoches_list = args.epoches_decay.split('_') 138 | 
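# e.g. the provided scripts pass --epoches_decay 80_150_200, which becomes the MultiStepLR
# milestone list [80, 150, 200]; a plain integer string such as '100' has no '_' and falls
# through to the StepLR branch below instead.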
epoches_list = [int(e) for e in epoches_list] 139 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, epoches_list) 140 | else: 141 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(args.epoches_decay)) 142 | return scheduler 143 | -------------------------------------------------------------------------------- /datasets/data.sh: -------------------------------------------------------------------------------- 1 | BASE_ROOT=/home/labyrinth7x/Codes/PersonSearch/Deep-Cross-Modal-Projection-Learning-for-Image-Text-Matching 2 | 3 | IMAGE_ROOT=$BASE_ROOT/data/CUHK-PEDES/imgs 4 | JSON_ROOT=$BASE_ROOT/data/reid_raw.json 5 | OUT_ROOT=$BASE_ROOT/data/processed_data 6 | 7 | 8 | echo "Process CUHK-PEDES dataset and save it as pickle form" 9 | 10 | python ${BASE_ROOT}/datasets/preprocess.py \ 11 | --img_root=${IMAGE_ROOT} \ 12 | --json_root=${JSON_ROOT} \ 13 | --out_root=${OUT_ROOT} \ 14 | --min_word_count 3 15 | -------------------------------------------------------------------------------- /datasets/directory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | def makedir(root): 5 | if not os.path.exists(root): 6 | os.makedirs(root) 7 | 8 | 9 | def write_json(data, root): 10 | with open(root, 'w') as f: 11 | json.dump(data, f) 12 | -------------------------------------------------------------------------------- /datasets/pedes.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | import os 4 | import pickle 5 | import h5py 6 | from PIL import Image 7 | from utils.directory import check_exists 8 | from scipy.misc import imread, imresize 9 | 10 | class CuhkPedes(data.Dataset): 11 | ''' 12 | Args: 13 | root (string): Base root directory of dataset where [split].pkl and [split].h5 exists 14 | split (string): 'train', 'val' or 'test' 15 | transform (callable, optional): A function/transform that takes in an PIL image 16 | and returns a transformed vector. E.g, ''transform.RandomCrop' 17 | target_transform (callable, optional): A funciton/transform that tkes in the 18 | targt and transfomrs it. 19 | ''' 20 | pklname_list = ['train.pkl', 'val.pkl', 'test.pkl'] 21 | h5name_list = ['train.h5', 'val.h5', 'test.h5'] 22 | 23 | def __init__(self, image_root, anno_root, split, max_length, transform=None, target_transform=None, cap_transform=None): 24 | 25 | self.image_root = image_root 26 | self.anno_root = anno_root 27 | self.max_length = max_length 28 | self.transform = transform 29 | self.target_transform = target_transform 30 | self.cap_transform = cap_transform 31 | self.split = split.lower() 32 | 33 | if not check_exists(self.image_root): 34 | raise RuntimeError('Dataset not found or corrupted.' 
+ 35 | 'Please follow the directions to generate datasets') 36 | 37 | if self.split == 'train': 38 | self.pklname = self.pklname_list[0] 39 | #self.h5name = self.h5name_list[0] 40 | 41 | with open(os.path.join(self.anno_root, self.pklname), 'rb') as f_pkl: 42 | data = pickle.load(f_pkl) 43 | self.train_labels = data['labels'] 44 | self.train_captions = data['caption_id'] 45 | self.train_images = data['images_path'] 46 | #data_h5py = h5py.File(os.path.join(self.root, self.h5name), 'r') 47 | #self.train_images = data_h5py['images'] 48 | 49 | 50 | elif self.split == 'val': 51 | self.pklname = self.pklname_list[1] 52 | #self.h5name = self.h5name_list[1] 53 | with open(os.path.join(self.anno_root, self.pklname), 'rb') as f_pkl: 54 | data = pickle.load(f_pkl) 55 | self.val_labels = data['labels'] 56 | self.val_captions = data['caption_id'] 57 | self.val_images = data['images_path'] 58 | #data_h5py = h5py.File(os.path.join(self.root, self.h5name), 'r') 59 | #self.val_images = data_h5py['images'] 60 | 61 | elif self.split == 'test': 62 | self.pklname = self.pklname_list[2] 63 | #self.h5name = self.h5name_list[2] 64 | 65 | with open(os.path.join(self.anno_root, self.pklname), 'rb') as f_pkl: 66 | data = pickle.load(f_pkl) 67 | self.test_labels = data['labels'] 68 | self.test_captions = data['caption_id'] 69 | self.test_images = data['images_path'] 70 | 71 | #data_h5py = h5py.File(os.path.join(self.root, self.h5name), 'r') 72 | #self.test_images = data_h5py['images'] 73 | 74 | else: 75 | raise RuntimeError('Wrong split which should be one of "train","val" or "test"') 76 | 77 | def __getitem__(self, index): 78 | """ 79 | Args: 80 | index(int): Index 81 | Returns: 82 | tuple: (images, labels, captions) 83 | """ 84 | if self.split == 'train': 85 | img_path, caption, label = self.train_images[index], self.train_captions[index], self.train_labels[index] 86 | elif self.split == 'val': 87 | img_path, caption, label = self.val_images[index], self.val_captions[index], self.val_labels[index] 88 | else: 89 | img_path, caption, label = self.test_images[index], self.test_captions[index], self.test_labels[index] 90 | img_path = os.path.join(self.image_root, img_path) 91 | img = imread(img_path) 92 | img = imresize(img, (224,224)) 93 | if len(img.shape) == 2: 94 | img = np.dstack((img,img,img)) 95 | img = Image.fromarray(img) 96 | 97 | if self.transform is not None: 98 | img = self.transform(img) 99 | 100 | if self.target_transform is not None: 101 | label = self.target_transform(label) 102 | 103 | if self.cap_transform is not None: 104 | caption = self.cap_transform(caption) 105 | caption = caption[1:-1] 106 | caption = np.array(caption) 107 | caption, mask = self.fix_length(caption) 108 | return img, caption, label, mask 109 | 110 | def fix_length(self, caption): 111 | caption_len = caption.shape[0] 112 | if caption_len < self.max_length: 113 | pad = np.zeros((self.max_length - caption_len, 1), dtype=np.int64) 114 | caption = np.append(caption, pad) 115 | return caption, caption_len 116 | 117 | def __len__(self): 118 | if self.split == 'train': 119 | return len(self.train_labels) 120 | elif self.split == 'val': 121 | return len(self.val_labels) 122 | else: 123 | return len(self.test_labels) 124 | -------------------------------------------------------------------------------- /datasets/preprocess.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | import argparse 4 | import string 5 | import os 6 | from utils.directory import write_json, 
makedir 7 | from collections import namedtuple 8 | 9 | 10 | ImageMetaData = namedtuple('ImageMetaData', ['id', 'image_path', 'captions', 'split']) 11 | ImageDecodeData = namedtuple('ImageDecodeData', ['id', 'image_path', 'captions_id', 'split']) 12 | 13 | 14 | class Vocabulary(object): 15 | """ 16 | Vocabulary wrapper 17 | """ 18 | def __init__(self, vocab, unk_id): 19 | """ 20 | :param vocab: A dictionary of word to word_id 21 | :param unk_id: Id of the bad/unknown words 22 | """ 23 | self._vocab = vocab 24 | self._unk_id = unk_id 25 | 26 | def word_to_id(self, word): 27 | if word not in self._vocab: 28 | return self._unk_id 29 | return self._vocab[word] 30 | 31 | 32 | def cap2tokens(cap): 33 | exclude = set(string.punctuation) 34 | caption = ''.join(c for c in cap if c not in exclude) 35 | tokens = caption.split() 36 | tokens = add_start_end(tokens) 37 | return tokens 38 | 39 | 40 | def add_start_end(tokens, start_word='', end_word=''): 41 | """ 42 | Add start and end words for a caption 43 | """ 44 | tokens_processed = [start_word] 45 | tokens_processed.extend(tokens) 46 | tokens_processed.append(end_word) 47 | return tokens_processed 48 | 49 | 50 | def process_captions(imgs): 51 | for img in imgs: 52 | img['processed_tokens'] = [] 53 | for s in img['captions']: 54 | tokens = cap2tokens(s) 55 | img['processed_tokens'].append(tokens) 56 | 57 | 58 | def build_vocab(imgs, args): 59 | print('start build vodabulary') 60 | counts = {} 61 | for img in imgs: 62 | for tokens in img['processed_tokens']: 63 | for word in tokens: 64 | counts[word] = counts.get(word, 0) + 1 65 | print('Total words:', len(counts)) 66 | 67 | # filter uncommon words and sort by descending count. 68 | # word_counts: a list of (words, count) for words satisfying the condition. 69 | word_counts = [(w,n) for w,n in counts.items() if n >= args.min_word_count] 70 | word_counts.sort(key = lambda x : x[1], reverse=True) 71 | print('Words in vocab:', len(word_counts)) 72 | 73 | # words_out: a list of (words, count) for words unsatisfying the condition. 
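# For example, with --min_word_count 3 (the value used in datasets/data.sh), any token seen
# fewer than 3 times in the captions is left out of the vocabulary and is later mapped to the
# unknown-word id (build_vocab passes len(vocab_dict) as unk_id). A toy sketch, with a
# hypothetical two-word vocabulary:
#   vocab = Vocabulary({'the': 0, 'man': 1}, unk_id=2)
#   vocab.word_to_id('man')          # -> 1
#   vocab.word_to_id('skateboard')   # -> 2, i.e. the unk_id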
74 | words_out = [(w,n) for w,n in counts.items() if n < args.min_word_count] 75 | bad_words = len(words_out) 76 | bad_count = sum(x[1] for x in words_out) 77 | 78 | # save the word counts file 79 | word_counts_root = os.path.join(args.out_root + '/word_counts.txt') 80 | with open(word_counts_root, 'w') as f: 81 | f.write('Total words: %d \n' % len(counts)) 82 | f.write('Words in vocabulary: %d \n' % len(word_counts)) 83 | f.write(str(word_counts)) 84 | 85 | word_reverse = [w for (w,n) in word_counts] 86 | vocab_dict = dict([(word, index) for (index, word) in enumerate(word_reverse)]) 87 | vocab = Vocabulary(vocab_dict, len(vocab_dict)) 88 | 89 | # Save word index as pickle form 90 | word_to_idx = {} 91 | for index, word in enumerate(word_reverse): 92 | word_to_idx[word] = index 93 | 94 | with open(os.path.join(args.out_root, 'word_to_index.pkl'), 'wb') as f: 95 | pickle.dump(word_to_idx, f) 96 | 97 | print('number of bad words: %d/%d = %.2f%%' % (bad_words, len(counts), bad_words * 100.0 / len(counts))) 98 | print('number of words in vocab: %d/%d = %.2f%%' % (len(word_counts), len(counts), len(word_counts) * 100.0 / len(counts))) 99 | print('number of Null: %d/%d = %.2f%%' % (bad_count, len(counts), bad_count * 100.0 / len(counts))) 100 | 101 | return vocab 102 | 103 | def load_vocab(args): 104 | 105 | with open(os.path.join(args.out_root, 'word_to_index.pkl'), 'rb') as f: 106 | word_to_idx = pickle.load(f) 107 | 108 | vocab = Vocabulary(word_to_idx, len(word_to_idx)) 109 | print('load vocabulary done') 110 | return vocab 111 | 112 | 113 | def process_metadata(split, data, args): 114 | """ 115 | Wrap data into ImageMatadata form 116 | """ 117 | id_to_captions = {} 118 | image_metadata = [] 119 | num_captions = 0 120 | count = 0 121 | 122 | for img in data: 123 | count += 1 124 | # absolute image path 125 | # filepath = os.path.join(args.img_root, img['file_path']) 126 | # relative image path 127 | filepath = img['file_path'] 128 | # assert os.path.exists(filepath) 129 | id = img['id'] - 1 130 | captions = img['processed_tokens'] 131 | id_to_captions.setdefault(id, []) 132 | id_to_captions[id].append(captions) 133 | assert split == img['split'], 'error: wrong split' 134 | image_metadata.append(ImageMetaData(id, filepath, captions, split)) 135 | num_captions += len(captions) 136 | 137 | print("Process metadata done!") 138 | print("Total %d captions %d images %d identities in %s" % (num_captions, count, len(id_to_captions), split)) 139 | with open(os.path.join(args.out_root, 'metadata_info.txt') ,'a') as f: 140 | f.write("Total %d captions %d images %d identities in %s" % (num_captions, count, len(id_to_captions), split)) 141 | f.write('\n') 142 | 143 | return image_metadata 144 | 145 | 146 | def process_decodedata(data, vocab): 147 | """ 148 | Decode ImageMetaData to ImageDecodeData 149 | Each item in imagedecodedata has 2 captions. 
(len(captions_id) = 2) 150 | """ 151 | image_decodedata = [] 152 | for img in data: 153 | image_path = img.image_path 154 | #image = imread(img.filepath) 155 | #image = imresize(image, (args.default_image_size, args.default_image_size)) 156 | # handle grayscale input images 157 | #if len(image.shape) == 2: 158 | # image = np.dstack((image, image, image)) 159 | # (height, width, channel) to (channel, height, weight) 160 | # (224,224,3) to (3,224,224)) 161 | #image = image.transpose(2,0,1) 162 | cap_to_vec = [] 163 | for cap in img.captions: 164 | cap_to_vec.append([vocab.word_to_id(word) for word in cap]) 165 | image_decodedata.append(ImageDecodeData(img.id, image_path, cap_to_vec, img.split)) 166 | 167 | print('Process decodedata done!') 168 | 169 | return image_decodedata 170 | 171 | 172 | def process_dataset(split, decodedata): 173 | # Process dataset 174 | 175 | # Arrange by caption in a sorted form 176 | dataset, label_range = create_dataset_sort(split, decodedata) 177 | write_dataset(split, dataset, args, label_range) 178 | 179 | 180 | def create_dataset_sort(split, data): 181 | images_sort = [] 182 | label_range = {} 183 | images = {} 184 | for img in data: 185 | label = img.id 186 | image = [ImageDecodeData(img.id, img.image_path, [caption_id], img.split) for caption_id in img.captions_id] 187 | if label in images: 188 | images[label].extend(image) 189 | label_range[label].append(len(image)) 190 | else: 191 | images[label] = image 192 | label_range[label] = [len(image)] 193 | 194 | print('=========== Arrange by id=============================') 195 | index = -1 196 | for label in images.keys(): 197 | # all captions arrange together 198 | images_sort.extend(images[label]) 199 | # label_range is arranged according to their actual index 200 | # label_range[label] = (previous, current] 201 | start = index 202 | for index_image in range(len(label_range[label])): 203 | label_range[label][index_image] += index 204 | index = label_range[label][index_image] 205 | label_range[label].append(start) 206 | 207 | return images_sort, label_range 208 | 209 | 210 | def write_dataset(split, data, args, label_range=None): 211 | """ 212 | Separate each component 213 | Write dataset into binary file 214 | """ 215 | caption_id = [] 216 | images_path = [] 217 | labels = [] 218 | 219 | for img in data: 220 | assert len(img.captions_id) == 1 221 | caption_id.append(img.captions_id[0]) 222 | labels.append(img.id) 223 | images_path.append(img.image_path) 224 | 225 | #N = len(images) 226 | data = {'caption_id':caption_id, 'labels':labels, 'images_path':images_path} 227 | 228 | if label_range is not None: 229 | data['label_range'] = label_range 230 | pickle_root = os.path.join(args.out_root, split + '_sort.pkl') 231 | else: 232 | pickle_root = os.path.join(args.out_root, split + '.pkl') 233 | # Write caption_id and labels as pickle form 234 | with open(pickle_root, 'wb') as f: 235 | pickle.dump(data, f) 236 | 237 | #h5py_root = os.path.join(args.out_root, split + '.h5') 238 | #f = h5py.File(h5py_root, 'w') 239 | #f.create_dataset('images', (N, 3, args.default_image_size, args.default_image_size), data=images) 240 | 241 | print('Save dataset') 242 | 243 | 244 | def generate_split(args): 245 | 246 | with open(args.json_root,'r') as f: 247 | imgs = json.load(f) 248 | # process caption 249 | process_captions(imgs) 250 | val_data = [] 251 | train_data = [] 252 | test_data = [] 253 | for img in imgs: 254 | if img['split'] == 'train': 255 | train_data.append(img) 256 | elif img['split'] =='val': 257 | 
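# (each split is also dumped below as <out_root>/<split>_reid.json, so later preprocessing runs
# can drop --first and reload them through load_split() instead of re-tokenising reid_raw.json)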
val_data.append(img) 258 | else: 259 | test_data.append(img) 260 | write_json(train_data, os.path.join(args.out_root, 'train_reid.json')) 261 | write_json(val_data, os.path.join(args.out_root, 'val_reid.json')) 262 | write_json(test_data, os.path.join(args.out_root, 'test_reid.json')) 263 | 264 | return [train_data, val_data, test_data] 265 | 266 | 267 | def load_split(args): 268 | 269 | data = [] 270 | splits = ['train', 'val', 'test'] 271 | for split in splits: 272 | split_root = os.path.join(args.out_root, split + '_reid.json') 273 | with open(split_root, 'r') as f: 274 | split_data = json.load(f) 275 | data.append(split_data) 276 | 277 | print('load data done') 278 | return data 279 | 280 | 281 | def process_data(args): 282 | 283 | if args.first: 284 | train_data, val_data, test_data = generate_split(args) 285 | vocab = build_vocab(train_data, args) 286 | else: 287 | train_data, val_data, test_data = load_split(args) 288 | vocab = load_vocab(args) 289 | 290 | # Transform original data to Imagedata form. 291 | train_metadata = process_metadata('train', train_data, args) 292 | val_metadata = process_metadata('val', val_data, args) 293 | test_metadata = process_metadata('test', test_data, args) 294 | 295 | 296 | # Decode Imagedata to index caption and replace image file_root with image vecetor. 297 | train_decodedata = process_decodedata(train_metadata, vocab) 298 | val_decodedata = process_decodedata(val_metadata, vocab) 299 | test_decodedata = process_decodedata(test_metadata, vocab) 300 | 301 | 302 | process_dataset('train', train_decodedata) 303 | process_dataset('val', val_decodedata) 304 | process_dataset('test', test_decodedata) 305 | 306 | 307 | def parse_args(): 308 | parser = argparse.ArgumentParser(description='Command for data preprocessing') 309 | parser.add_argument('--img_root', type=str) 310 | parser.add_argument('--json_root', type=str) 311 | parser.add_argument('--out_root',type=str) 312 | parser.add_argument('--min_word_count', type=int) 313 | parser.add_argument('--default_image_size', type=int, default=224) 314 | parser.add_argument('--first', action='store_true') 315 | args = parser.parse_args() 316 | return args 317 | 318 | if __name__ == '__main__': 319 | args = parse_args() 320 | makedir(args.out_root) 321 | process_data(args) 322 | -------------------------------------------------------------------------------- /models/bi_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | 5 | seed_num = 223 6 | torch.manual_seed(seed_num) 7 | random.seed(seed_num) 8 | 9 | """ 10 | Neural Networks model : Bidirection LSTM 11 | """ 12 | 13 | 14 | class BiLSTM(nn.Module): 15 | def __init__(self, args): 16 | super(BiLSTM, self).__init__() 17 | 18 | self.hidden_dim = args.num_lstm_units 19 | 20 | V = args.vocab_size 21 | D = args.embedding_size 22 | 23 | # word embedding 24 | self.embed = nn.Embedding(V, D, padding_idx=0) 25 | 26 | self.bilstm = nn.ModuleList() 27 | self.bilstm.append(nn.LSTM(D, args.num_lstm_units, num_layers=1, dropout=0, bidirectional=False, bias=False)) 28 | 29 | self.bidirectional = args.bidirectional 30 | if self.bidirectional: 31 | self.bilstm.append(nn.LSTM(D, args.num_lstm_units, num_layers=1, dropout=0, bidirectional=False, bias=False)) 32 | 33 | 34 | def forward(self, text, text_length): 35 | embed = self.embed(text) 36 | 37 | # unidirectional lstm 38 | bilstm_out = self.bilstm_out(embed, text_length, 0) 39 | 40 | if self.bidirectional: 41 | index_reverse 
= list(range(embed.shape[0]-1, -1, -1)) 42 | index_reverse = torch.LongTensor(index_reverse).cuda() 43 | embed_reverse = embed.index_select(0, index_reverse) 44 | text_length_reverse = text_length.index_select(0, index_reverse) 45 | bilstm_out_bidirection = self.bilstm_out(embed_reverse, text_length_reverse, 1) 46 | bilstm_out_bidirection_reverse = bilstm_out_bidirection.index_select(0, index_reverse) 47 | bilstm_out = torch.cat([bilstm_out, bilstm_out_bidirection_reverse], dim=2) 48 | bilstm_out, _ = torch.max(bilstm_out, dim=1) 49 | bilstm_out = bilstm_out.unsqueeze(2).unsqueeze(2) 50 | return bilstm_out 51 | 52 | 53 | def bilstm_out(self, embed, text_length, index): 54 | 55 | _, idx_sort = torch.sort(text_length, dim=0, descending=True) 56 | _, idx_unsort = torch.sort(idx_sort, dim=0) 57 | 58 | embed_sort = embed.index_select(0, idx_sort) 59 | length_list = text_length[idx_sort] 60 | pack = nn.utils.rnn.pack_padded_sequence(embed_sort, length_list, batch_first=True) 61 | 62 | bilstm_sort_out, _ = self.bilstm[index](pack) 63 | bilstm_sort_out = nn.utils.rnn.pad_packed_sequence(bilstm_sort_out, batch_first=True) 64 | bilstm_sort_out = bilstm_sort_out[0] 65 | 66 | bilstm_out = bilstm_sort_out.index_select(0, idx_unsort) 67 | 68 | return bilstm_out 69 | 70 | 71 | def weight_init(self, m): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.xavier_uniform_(m.weight.data, 1) 74 | nn.init.constant(m.bias.data, 0) 75 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | """ 5 | Imported by https://github.com/marvis/pytorch-mobilenet/blob/master/main.py 6 | """ 7 | 8 | 9 | class MobileNetV1(nn.Module): 10 | def __init__(self, dropout_keep_prob=0.999): 11 | super(MobileNetV1, self).__init__() 12 | self.dropout_keep_prob = dropout_keep_prob 13 | self.dropout = nn.Dropout(1 - dropout_keep_prob) 14 | def conv_bn(inp, oup, stride): 15 | return nn.Sequential( 16 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 17 | nn.BatchNorm2d(oup), 18 | nn.ReLU6(inplace=True) 19 | ) 20 | 21 | def conv_dw(inp, oup, stride): 22 | return nn.Sequential( 23 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 24 | nn.BatchNorm2d(inp), 25 | nn.ReLU6(inplace=True), 26 | 27 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 28 | nn.BatchNorm2d(oup), 29 | nn.ReLU6(inplace=True), 30 | ) 31 | 32 | self.model = nn.Sequential( 33 | conv_bn(3, 32, 2), 34 | conv_dw(32, 64, 1), 35 | conv_dw(64, 128, 2), 36 | conv_dw(128, 128, 1), 37 | conv_dw(128, 256, 2), 38 | conv_dw(256, 256, 1), 39 | conv_dw(256, 512, 2), 40 | conv_dw(512, 512, 1), 41 | conv_dw(512, 512, 1), 42 | conv_dw(512, 512, 1), 43 | conv_dw(512, 512, 1), 44 | conv_dw(512, 512, 1), 45 | conv_dw(512, 1024, 2), 46 | conv_dw(1024, 1024, 1), 47 | nn.AvgPool2d(7), 48 | ) 49 | 50 | 51 | def weight_init(self, m): 52 | if isinstance(m, nn.Conv2d): 53 | # truncated_normal_initializer in tensorflow 54 | nn.init.normal_(m.weight.data, std=0.09) 55 | #nn.init.constant(m.bias.data, 0) 56 | 57 | def forward(self, x): 58 | x = self.model(x) 59 | x = self.dropout(x) 60 | return x 61 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .bi_lstm import BiLSTM 3 | from .mobilenet import MobileNetV1 4 | from .resnet import resnet50 5 | 6 | 7 | class 
Model(nn.Module): 8 | def __init__(self, args): 9 | super(Model, self).__init__() 10 | if args.image_model == 'mobilenet_v1': 11 | self.image_model = MobileNetV1() 12 | self.image_model.apply(self.image_model.weight_init) 13 | elif args.image_model == 'resnet50': 14 | self.image_model = resnet50() 15 | elif args.image_model == 'resent101': 16 | self.image_model = resnet101() 17 | 18 | self.bilstm = BiLSTM(args) 19 | self.bilstm.apply(self.bilstm.weight_init) 20 | 21 | inp_size = 1024 22 | if args.image_model == 'resnet50' or args.image_model == 'resnet101': 23 | inp_size = 2048 24 | # shorten the tensor using 1*1 conv 25 | self.conv_images = nn.Conv2d(inp_size, args.feature_size, 1) 26 | self.conv_text = nn.Conv2d(1024, args.feature_size, 1) 27 | 28 | 29 | def forward(self, images, text, text_length): 30 | image_features = self.image_model(images) 31 | text_features = self.bilstm(text, text_length) 32 | image_embeddings, text_embeddings= self.build_joint_embeddings(image_features, text_features) 33 | 34 | return image_embeddings, text_embeddings 35 | 36 | 37 | def build_joint_embeddings(self, images_features, text_features): 38 | 39 | #images_features = images_features.permute(0,2,3,1) 40 | #text_features = text_features.permute(0,3,1,2) 41 | image_embeddings = self.conv_images(images_features).squeeze() 42 | text_embeddings = self.conv_text(text_features).squeeze() 43 | 44 | return image_embeddings, text_embeddings 45 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None): 65 | super(Bottleneck, 
self).__init__() 66 | self.conv1 = conv1x1(inplanes, planes) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = conv3x3(planes, planes, stride) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.conv3 = conv1x1(planes, planes * self.expansion) 71 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, block, layers, num_classes=1000): 102 | super(ResNet, self).__init__() 103 | self.inplanes = 64 104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 105 | bias=False) 106 | self.bn1 = nn.BatchNorm2d(64) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 109 | self.layer1 = self._make_layer(block, 64, layers[0]) 110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 112 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 113 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 114 | self.fc = nn.Linear(512 * block.expansion, num_classes) 115 | 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 119 | elif isinstance(m, nn.BatchNorm2d): 120 | nn.init.constant_(m.weight, 1) 121 | nn.init.constant_(m.bias, 0) 122 | 123 | def _make_layer(self, block, planes, blocks, stride=1): 124 | downsample = None 125 | if stride != 1 or self.inplanes != planes * block.expansion: 126 | downsample = nn.Sequential( 127 | conv1x1(self.inplanes, planes * block.expansion, stride), 128 | nn.BatchNorm2d(planes * block.expansion), 129 | ) 130 | 131 | layers = [] 132 | layers.append(block(self.inplanes, planes, stride, downsample)) 133 | self.inplanes = planes * block.expansion 134 | for _ in range(1, blocks): 135 | layers.append(block(self.inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | x = self.bn1(x) 142 | x = self.relu(x) 143 | x = self.maxpool(x) 144 | 145 | x = self.layer1(x) 146 | x = self.layer2(x) 147 | x = self.layer3(x) 148 | x = self.layer4(x) 149 | 150 | x = self.avgpool(x) 151 | #x = x.view(x.size(0), -1) 152 | #x = self.fc(x) 153 | 154 | return x 155 | 156 | 157 | def resnet18(pretrained=False, **kwargs): 158 | """Constructs a ResNet-18 model. 159 | 160 | Args: 161 | pretrained (bool): If True, returns a model pre-trained on ImageNet 162 | """ 163 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 164 | if pretrained: 165 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 166 | return model 167 | 168 | 169 | def resnet34(pretrained=False, **kwargs): 170 | """Constructs a ResNet-34 model. 
171 | 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 178 | return model 179 | 180 | 181 | def resnet50(pretrained=False, **kwargs): 182 | """Constructs a ResNet-50 model. 183 | 184 | Args: 185 | pretrained (bool): If True, returns a model pre-trained on ImageNet 186 | """ 187 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 188 | if pretrained: 189 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 190 | return model 191 | 192 | 193 | def resnet101(pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 195 | 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 200 | if pretrained: 201 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 202 | return model 203 | 204 | 205 | def resnet152(pretrained=False, **kwargs): 206 | """Constructs a ResNet-152 model. 207 | 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 212 | if pretrained: 213 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 214 | return model 215 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | GPUS=3 2 | export CUDA_VISIBLE_DEVICES=$GPUS 3 | 4 | BASE_ROOT=/home/zhangqi/Deep-Cross-Modal-Projection-Learning-for-Image-Text-Matching 5 | IMAGE_DIR=/home/zhangqi/TriCrossModalV2/data/ 6 | ANNO_DIR=$BASE_ROOT/data/processed_data 7 | CKPT_DIR=$BASE_ROOT/data/model_data 8 | LOG_DIR=$BASE_ROOT/data/logs 9 | PRETRAINED_PATH=$BASE_ROOT/pretrained_models/mobilenet.tar 10 | #PRETRAINED_PATH=$BASE_ROOT/resnet50.pth 11 | IMAGE_MODEL=mobilenet_v1 12 | lr=0.0002 13 | num_epoches=300 14 | batch_size=16 15 | lr_decay_ratio=0.9 16 | epoches_decay=80_150_200 17 | 18 | python3.5 $BASE_ROOT/train.py \ 19 | --CMPC \ 20 | --CMPM \ 21 | --bidirectional \ 22 | --pretrained \ 23 | --model_path $PRETRAINED_PATH \ 24 | --image_model $IMAGE_MODEL \ 25 | --log_dir $LOG_DIR/lr-$lr-decay-$lr_decay_ratio-batch-$batch_size \ 26 | --checkpoint_dir $CKPT_DIR/lr-$lr-decay-$lr_decay_ratio-batch-$batch_size \ 27 | --image_dir $IMAGE_DIR \ 28 | --anno_dir $ANNO_DIR \ 29 | --batch_size $batch_size \ 30 | --gpus $GPUS \ 31 | --num_epoches $num_epoches \ 32 | --lr $lr \ 33 | --lr_decay_ratio $lr_decay_ratio \ 34 | --epoches_decay ${epoches_decay} 35 | 36 | 37 | python3.5 ${BASE_ROOT}/test.py \ 38 | --bidirectional \ 39 | --model_path $CKPT_DIR/lr-$lr-decay-$lr_decay_ratio-batch-$batch_size \ 40 | --image_model $IMAGE_MODEL \ 41 | --log_dir $LOG_DIR/lr-$lr-decay-$lr_decay_ratio-batch-$batch_size \ 42 | --image_dir $IMAGE_DIR \ 43 | --anno_dir $ANNO_DIR \ 44 | --gpus $GPUS \ 45 | --epoch_ema 0 46 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | GPUS=3 2 | export CUDA_VISIBLE_DEVICES=$GPUS 3 | 4 | BASE_ROOT=/home/zhangqi/Deep-Cross-Modal-Projection-Learning-for-Image-Text-Matching 5 | IMAGE_DIR=/home/zhangqi/TriCrossModalV2/data/ 6 | ANNO_DIR=$BASE_ROOT/data/processed_data 7 | CKPT_DIR=$BASE_ROOT/data/model_data 8 | LOG_DIR=$BASE_ROOT/data/logs 9 
| IMAGE_MODEL=mobilenet_v1 10 | lr=0.0002 11 | batch_size=16 12 | lr_decay_ratio=0.9 13 | epoches_decay=80_150_200 14 | 15 | 16 | python3.5 ${BASE_ROOT}/test.py \ 17 | --bidirectional \ 18 | --model_path $CKPT_DIR/lr-$lr-decay-$lr_decay_ratio-batch-$batch_size \ 19 | --image_model $IMAGE_MODEL \ 20 | --log_dir $LOG_DIR/lr-$lr-decay-$lr_decay_ratio-batch-$batch_size \ 21 | --image_dir $IMAGE_DIR \ 22 | --anno_dir $ANNO_DIR \ 23 | --gpus $GPUS \ 24 | --epoch_ema 0 25 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | GPUS=3 2 | export CUDA_VISIBLE_DEVICES=$GPUS 3 | 4 | BASE_ROOT=/home/zhangqi/Deep-Cross-Modal-Projection-Learning-for-Image-Text-Matching 5 | IMAGE_DIR=/home/zhangqi/TriCrossModalV2/data/ 6 | ANNO_DIR=$BASE_ROOT/data/processed_data 7 | CKPT_DIR=$BASE_ROOT/data/model_data 8 | LOG_DIR=$BASE_ROOT/data/logs 9 | PRETRAINED_PATH=$BASE_ROOT/pretrained_models/mobilenet.tar 10 | #PRETRAINED_PATH=$BASE_ROOT/resnet50.pth 11 | IMAGE_MODEL=mobilenet_v1 12 | lr=0.0002 13 | num_epoches=300 14 | batch_size=16 15 | lr_decay_ratio=0.9 16 | epoches_decay=80_150_200 17 | 18 | python3.5 $BASE_ROOT/train.py \ 19 | --CMPC \ 20 | --CMPM \ 21 | --bidirectional \ 22 | --pretrained \ 23 | --model_path $PRETRAINED_PATH \ 24 | --image_model $IMAGE_MODEL \ 25 | --log_dir $LOG_DIR/lr-$lr-decay-$lr_decay_ratio-batch-$batch_size \ 26 | --checkpoint_dir $CKPT_DIR/lr-$lr-decay-$lr_decay_ratio-batch-$batch_size \ 27 | --image_dir $IMAGE_DIR \ 28 | --anno_dir $ANNO_DIR \ 29 | --batch_size $batch_size \ 30 | --gpus $GPUS \ 31 | --num_epoches $num_epoches \ 32 | --lr $lr \ 33 | --lr_decay_ratio $lr_decay_ratio \ 34 | --epoches_decay ${epoches_decay} 35 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import shutil 5 | import logging 6 | import gc 7 | import torch 8 | import torchvision.transforms as transforms 9 | from utils.metric import AverageMeter, compute_topk 10 | from test_config import config 11 | from config import data_config, network_config 12 | 13 | 14 | def test(data_loader, network, args): 15 | batch_time = AverageMeter() 16 | 17 | # switch to evaluate mode 18 | network.eval() 19 | max_size = 64 * len(data_loader) 20 | images_bank = torch.zeros((max_size, args.feature_size)).cuda() 21 | text_bank = torch.zeros((max_size,args.feature_size)).cuda() 22 | labels_bank = torch.zeros(max_size).cuda() 23 | index = 0 24 | with torch.no_grad(): 25 | end = time.time() 26 | for images, captions, labels, captions_length in data_loader: 27 | images = images.cuda() 28 | captions = captions.cuda() 29 | 30 | interval = images.shape[0] 31 | image_embeddings, text_embeddings = network(images, captions, captions_length) 32 | images_bank[index: index + interval] = image_embeddings 33 | text_bank[index: index + interval] = text_embeddings 34 | labels_bank[index:index + interval] = labels 35 | batch_time.update(time.time() - end) 36 | end = time.time() 37 | 38 | index = index + interval 39 | 40 | images_bank = images_bank[:index] 41 | text_bank = text_bank[:index] 42 | labels_bank = labels_bank[:index] 43 | #[ac_top1_t2i, ac_top10_t2i] = compute_topk(text_bank, images_bank, labels_bank, labels_bank, [1,10]) 44 | #[ac_top1_i2t, ac_top10_i2t] = compute_topk(images_bank, text_bank, labels_bank, labels_bank, [1,10]) 45 | 
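# compute_topk comes from utils/metric.py; judging from the two commented-out calls above, the
# trailing True asks for both retrieval directions in one call, returning
# (top1_i2t, top10_i2t, top1_t2i, top10_t2i). Conceptually, top-k retrieval accuracy over the
# banks can be sketched as below (names are illustrative, embeddings assumed L2-normalised):
#   sim = query_bank @ gallery_bank.t()                                  # similarity matrix
#   _, idx = sim.topk(k, dim=1)                                          # k nearest gallery items per query
#   hit = (gallery_labels[idx] == query_labels.unsqueeze(1)).any(dim=1)  # same-identity hit in top-k?
#   topk_acc = hit.float().mean()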
ac_top1_i2t, ac_top10_i2t, ac_top1_t2i, ac_top10_t2i = compute_topk(images_bank, text_bank, labels_bank, labels_bank, [1,10], True) 46 | return ac_top1_i2t, ac_top10_i2t, ac_top1_t2i, ac_top10_t2i, batch_time.avg 47 | 48 | 49 | def main(args): 50 | # need to clear the pipeline 51 | # top1 & top10 need to be chosen in the same params ??? 52 | test_transform = transforms.Compose([ 53 | transforms.ToTensor(), 54 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 55 | ]) 56 | test_loader = data_config(args.image_dir, args.anno_dir, 64, 'test', args.max_length, test_transform) 57 | 58 | ac_i2t_top1_best = 0.0 59 | ac_i2t_top10_best = 0.0 60 | ac_t2i_top1_best = 0.0 61 | ac_t2i_top10_best = 0.0 62 | i2t_models = os.listdir(args.model_path) 63 | i2t_models.sort() 64 | for i2t_model in i2t_models: 65 | model_file = os.path.join(args.model_path, i2t_model) 66 | if os.path.isdir(model_file): 67 | continue 68 | epoch = i2t_model.split('.')[0] 69 | if int(epoch) >= args.epoch_ema: 70 | ema = True 71 | else: 72 | ema = False 73 | network, _ = network_config(args, [0], 'test', None, True, model_file, ema) 74 | ac_top1_i2t, ac_top10_i2t, ac_top1_t2i, ac_top10_t2i, test_time = test(test_loader, network, args) 75 | if ac_top1_t2i > ac_t2i_top1_best: 76 | ac_i2t_top1_best = ac_top1_i2t 77 | ac_i2t_top10_best = ac_top10_i2t 78 | ac_t2i_top1_best = ac_top1_t2i 79 | ac_t2i_top10_best = ac_top10_t2i 80 | dst_best = os.path.join(args.model_path, 'model_best', str(epoch)) + '.pth.tar' 81 | shutil.copyfile(model_file, dst_best) 82 | 83 | logging.info('epoch:{}'.format(epoch)) 84 | logging.info('top1_t2i: {:.3f}, top10_t2i: {:.3f}, top1_i2t: {:.3f}, top10_i2t: {:.3f}'.format( 85 | ac_top1_t2i, ac_top10_t2i, ac_top1_i2t, ac_top10_i2t)) 86 | logging.info('t2i_top1_best: {:.3f}, t2i_top10_best: {:.3f}, i2t_top1_best: {:.3f}, i2t_top10_best: {:.3f}'.format( 87 | ac_t2i_top1_best, ac_t2i_top10_best, ac_i2t_top1_best, ac_i2t_top10_best)) 88 | logging.info(args.model_path) 89 | logging.info(args.log_dir) 90 | 91 | if __name__ == '__main__': 92 | args = config() 93 | main(args) 94 | -------------------------------------------------------------------------------- /test_config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from config import log_config 3 | import logging 4 | 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description='command for evaluate on CUHK-PEDES') 8 | # Directory 9 | parser.add_argument('--image_dir', type=str, help='directory to store dataset') 10 | parser.add_argument('--anno_dir', type=str, help='directory to store anno') 11 | parser.add_argument('--model_path', type=str, help='directory to load checkpoint') 12 | parser.add_argument('--log_dir', type=str, help='directory to store log') 13 | 14 | # LSTM setting 15 | parser.add_argument('--embedding_size', type=int, default=512) 16 | parser.add_argument('--num_lstm_units', type=int, default=512) 17 | parser.add_argument('--vocab_size', type=int, default=12000) 18 | parser.add_argument('--lstm_dropout_ratio', type=float, default=0.7) 19 | parser.add_argument('--bidirectional', action='store_true') 20 | 21 | parser.add_argument('--max_length', type=int, default=100) 22 | parser.add_argument('--feature_size', type=int, default=512) 23 | 24 | parser.add_argument('--image_model', type=str, default='mobilenet_v1') 25 | parser.add_argument('--cnn_dropout_keep', type=float, default=0.999) 26 | 27 | parser.add_argument('--epoch_ema', type=int, default=0) 28 | 29 | # Default 
setting 30 | parser.add_argument('--gpus', type=str, default='0') 31 | args = parser.parse_args() 32 | return args 33 | 34 | 35 | 36 | def config(): 37 | args = parse_args() 38 | log_config(args, 'test') 39 | return args 40 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import time 5 | import logging 6 | import torch 7 | import torch.utils.data as data 8 | import torch.nn as nn 9 | import torchvision.transforms as transforms 10 | from utils.metric import AverageMeter, Loss, constraints_loss, EMA 11 | from test import test 12 | from config import data_config, network_config, adjust_lr, lr_scheduler 13 | from train_config import config 14 | 15 | logger = logging.getLogger() 16 | logger.setLevel(logging.INFO) 17 | 18 | 19 | def save_checkpoint(state, epoch, dst, is_best): 20 | filename = os.path.join(dst, str(args.start_epoch + epoch)) + '.pth.tar' 21 | torch.save(state, filename) 22 | if is_best: 23 | dst_best = os.path.join(dst, 'model_best', str(epoch)) + '.pth.tar' 24 | shutil.copyfile(filename, dst_best) 25 | 26 | 27 | def train(epoch, train_loader, network, optimizer, compute_loss, args): 28 | batch_time = AverageMeter() 29 | train_loss = AverageMeter() 30 | image_pre = AverageMeter() 31 | text_pre = AverageMeter() 32 | 33 | # switch to train mode 34 | network.train() 35 | 36 | end = time.time() 37 | for step, (images, captions, labels, captions_length) in enumerate(train_loader): 38 | images = images.cuda() 39 | labels = labels.cuda() 40 | captions = captions.cuda() 41 | 42 | # compute loss 43 | image_embeddings, text_embeddings = network(images, captions, captions_length) 44 | cmpm_loss, cmpc_loss, loss, image_precision, text_precision, pos_avg_sim, neg_arg_sim = compute_loss(image_embeddings, text_embeddings, labels) 45 | 46 | 47 | if step % 10 == 0: 48 | print('epoch:{}, step:{}, cmpm_loss:{:.3f}, cmpc_loss:{:.3f}'.format(epoch, step, cmpm_loss, cmpc_loss)) 49 | 50 | # constrain embedding with the same id at the end of one epoch 51 | if (args.constraints_images or args.constraints_text) and step == len(train_loader) - 1: 52 | con_images, con_text = constraints_loss(train_loader, network, args) 53 | loss += (con_images + con_text) 54 | print('epoch:{}, step:{}, con_images:{:.3f}, con_text:{:.3f}'.format(epoch, step, con_images, con_text)) 55 | 56 | 57 | # compute gradient and do ADAM step 58 | optimizer.zero_grad() 59 | loss.backward() 60 | #nn.utils.clip_grad_norm(network.parameters(), 5) 61 | optimizer.step() 62 | 63 | # measure elapsed time 64 | batch_time.update(time.time() - end) 65 | end = time.time() 66 | 67 | train_loss.update(loss, images.shape[0]) 68 | image_pre.update(image_precision, images.shape[0]) 69 | text_pre.update(text_precision, images.shape[0]) 70 | 71 | return train_loss.avg, batch_time.avg, image_pre.avg, text_pre.avg 72 | 73 | 74 | def main(args): 75 | # transform 76 | train_transform = transforms.Compose([ 77 | transforms.RandomHorizontalFlip(), 78 | transforms.ToTensor(), 79 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 80 | ]) 81 | val_transform = transforms.Compose([ 82 | transforms.ToTensor(), 83 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 84 | ]) 85 | # data 86 | train_loader = data_config(args.image_dir, args.anno_dir, args.batch_size, 'train', args.max_length, train_transform) 87 | #val_loader = data_config(args.dataset_dir, 64, 'val', args.max_length, 
val_transform) 88 | 89 | # loss 90 | compute_loss = Loss(args) 91 | nn.DataParallel(compute_loss).cuda() 92 | 93 | # network 94 | network, optimizer = network_config(args, 'train', compute_loss.parameters(), args.resume, args.model_path) 95 | 96 | ema = EMA(args.ema_decay) 97 | for name, param in network.named_parameters(): 98 | if param.requires_grad: 99 | ema.register(name, param.data) 100 | 101 | # lr_scheduler 102 | scheduler = lr_scheduler(optimizer, args) 103 | for epoch in range(args.num_epoches - args.start_epoch): 104 | # train for one epoch 105 | train_loss, train_time, image_precision, text_precision = train(args.start_epoch + epoch, train_loader, network, optimizer, compute_loss, args) 106 | # evaluate on validation set 107 | print('Train done for epoch-{}'.format(args.start_epoch + epoch)) 108 | 109 | if epoch == args.epoch_ema: 110 | for name, param in network.named_parameters(): 111 | if param.requires_grad: 112 | ema.register(name, param.data) 113 | 114 | 115 | if epoch > args.epoch_ema: 116 | # ema update 117 | for name, param in network.named_parameters(): 118 | if param.requires_grad: 119 | ema.update(name, param.data) 120 | 121 | state = {'network': network.state_dict(), 'optimizer': optimizer.state_dict(), 'W': compute_loss.W, 'epoch': args.start_epoch + epoch} 122 | # 'ac': [ac_top1_i2t, ac_top10_i2t, ac_top1_t2i, ac_top10_t2i], 123 | # 'best_ac': [ac_i2t_best, ac_t2i_best]} 124 | save_checkpoint(state, epoch, args.checkpoint_dir, False) 125 | state = {'network': network.state_dict(), 'network_ema': ema.shadow, 'optimizer': optimizer.state_dict(), 'W': comput_loss.W,'epoch': args.start_epoch + epoch} 126 | save_checkpoint(state, args.start_epoch + epoch, args.checkpoint_dir, False) 127 | logging.info('Epoch: [{}|{}], train_time: {:.3f}, train_loss: {:.3f}'.format(args.start_epoch + epoch, args.num_epoches, train_time, train_loss)) 128 | logging.info('image_precision: {:.3f}, text_precision: {:.3f}'.format(image_precision, text_precision)) 129 | adjust_lr(optimizer, args.start_epoch + epoch, args) 130 | scheduler.step() 131 | for param in optimizer.param_groups: 132 | print('lr:{}'.format(param['lr'])) 133 | break 134 | logging.info('Train done') 135 | logging.info(args.checkpoint_dir) 136 | logging.info(args.log_dir) 137 | 138 | 139 | if __name__ == "__main__": 140 | args = config() 141 | main(args) 142 | -------------------------------------------------------------------------------- /train_config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import logging 4 | from config import log_config, dir_config 5 | 6 | logger = logging.getLogger() 7 | logger.setLevel(logging.INFO) 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='command for train on CUHK-PEDES') 11 | 12 | # Directory 13 | parser.add_argument('--image_dir', type=str, help='directory to store dataset') 14 | parser.add_argument('--anno_dir', type=str, help='directory to store anno file') 15 | parser.add_argument('--checkpoint_dir', type=str, help='directory to store checkpoint') 16 | parser.add_argument('--log_dir', type=str, help='directory to store log') 17 | parser.add_argument('--model_path', type=str, default = None, help='directory to pretrained model, whole model or just visual part') 18 | 19 | # LSTM setting 20 | parser.add_argument('--embedding_size', type=int, default=512) 21 | parser.add_argument('--num_lstm_units', type=int, default=512) 22 | parser.add_argument('--vocab_size', type=int, 
default=12000) 23 | parser.add_argument('--lstm_dropout_ratio', type=float, default=0.7) 24 | parser.add_argument('--max_length', type=int, default=100) 25 | parser.add_argument('--bidirectional', action='store_true') 26 | 27 | # Model setting 28 | parser.add_argument('--image_model', type=str, default='mobilenet_v1') 29 | parser.add_argument('--resume', action='store_true', help='whether or not to restore the pretrained whole model') 30 | parser.add_argument('--batch_size', type=int, default=16) 31 | parser.add_argument('--num_epoches', type=int, default=100) 32 | parser.add_argument('--ckpt_steps', type=int, default=5000, help='#steps to save checkpoint') 33 | parser.add_argument('--feature_size', type=int, default=512) 34 | parser.add_argument('--img_model', type=str, default='mobilenet_v1', help='model to train images') 35 | parser.add_argument('--loss_weight', type=float, default=1) 36 | parser.add_argument('--CMPM', action='store_true') 37 | parser.add_argument('--CMPC', action='store_true') 38 | parser.add_argument('--cnn_dropout_keep', type=float, default=0.999) 39 | parser.add_argument('--constraints_text', action='store_true') 40 | parser.add_argument('--constraints_images', action='store_true') 41 | parser.add_argument('--num_classes', type=int, default=11003) 42 | parser.add_argument('--pretrained', action='store_true', help='whether or not to restore the pretrained visual model') 43 | 44 | # Optimization setting 45 | parser.add_argument('--optimizer', type=str, default='adam', help='one of "sgd", "adam", "rmsprop", "adadelta", or "adagrad"') 46 | parser.add_argument('--lr', type=float, default=0.0002) 47 | parser.add_argument('--wd', type=float, default=0.00004) 48 | parser.add_argument('--adam_alpha', type=float, default=0.9) 49 | parser.add_argument('--adam_beta', type=float, default=0.999) 50 | parser.add_argument('--epsilon', type=float, default=1e-8) 51 | parser.add_argument('--end_lr', type=float, default=0.0001, help='minimum end learning rate used by a polynomial decay learning rate') 52 | parser.add_argument('--lr_decay_type', type=str, default='exponential', help='One of "fixed" or "exponential"') 53 | parser.add_argument('--lr_decay_ratio', type=float, default=0.1) 54 | parser.add_argument('--epoches_decay', type=str, default='50,100', help='#epoches when learning rate decays') 55 | parser.add_argument('--epoch_ema', type=int, default=0) 56 | parser.add_argument('--ema_decay', type=float, default=0.9) 57 | 58 | # Default setting 59 | parser.add_argument('--gpus', type=str, default='0') 60 | 61 | args = parser.parse_args() 62 | return args 63 | 64 | 65 | def config(): 66 | args = parse_args() 67 | dir_config(args) 68 | log_config(args,'train') 69 | return args 70 | -------------------------------------------------------------------------------- /utils/directory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | def makedir(root): 5 | if not os.path.exists(root): 6 | os.makedirs(root) 7 | 8 | 9 | def write_json(data, root): 10 | with open(dir, 'w') as f: 11 | json.dump(data, f) 12 | 13 | 14 | def check_exists(root): 15 | if os.path.exists(root): 16 | return True 17 | return False 18 | 19 | def check_file(root, keyword): 20 | if not os.path.isfile(root): 21 | raise RuntimeError('===> No {} in {}'.format(keyword, root)) 22 | -------------------------------------------------------------------------------- /utils/metric.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import numpy as np 6 | import logging 7 | from torch.nn.parameter import Parameter 8 | from torch.autograd import Variable 9 | 10 | logger = logging.getLogger() 11 | logger.setLevel(logging.INFO) 12 | 13 | 14 | 15 | class EMA(): 16 | def __init__(self, decay=0.999): 17 | self.decay = decay 18 | self.shadow = {} 19 | 20 | def register(self, name, val): 21 | self.shadow[name] = val.cpu().detach() 22 | 23 | def get(self, name): 24 | return self.shadow[name] 25 | 26 | def update(self, name, x): 27 | assert name in self.shadow 28 | new_average = (1.0 - self.decay) * x.cpu().detach() + self.decay * self.shadow[name] 29 | self.shadow[name] = new_average.clone() 30 | 31 | 32 | def pairwise_distance(A, B): 33 | """ 34 | Compute distance between points in A and points in B 35 | :param A: (m,n) -m points, each of n dimension. Every row vector is a point, denoted as A(i). 36 | :param B: (k,n) -k points, each of n dimension. Every row vector is a point, denoted as B(j). 37 | :return: Matrix with (m, k). And the ele in (i,j) is the distance between A(i) and B(j) 38 | """ 39 | A_square = torch.sum(A * A, dim=1, keepdim=True) 40 | B_square = torch.sum(B * B, dim=1, keepdim=True) 41 | 42 | distance = A_square + B_square.t() - 2 * torch.matmul(A, B.t()) 43 | 44 | return distance 45 | 46 | 47 | def one_hot_coding(index, k): 48 | if type(index) is torch.Tensor: 49 | length = len(index) 50 | else: 51 | length = 1 52 | out = torch.zeros((length, k), dtype=torch.int64).cuda() 53 | index = index.reshape((len(index), 1)) 54 | out.scatter_(1, index, 1) 55 | return out 56 | 57 | 58 | # deprecated due to the large memory usage 59 | def constraints_old(features, labels): 60 | distance = pairwise_distance(features, features) 61 | labels_reshape = torch.reshape(labels, (features.shape[0], 1)) 62 | labels_dist = labels_reshape - labels_reshape.t() 63 | labels_mask = (labels_dist == 0).float() 64 | 65 | # Average loss with each matching pair 66 | num = torch.sum(labels_mask) - features.shape[0] 67 | if num == 0: 68 | con_loss = 0.0 69 | else: 70 | con_loss = torch.sum(distance * labels_mask) / num 71 | 72 | return con_loss 73 | 74 | 75 | def constraints(features, labels): 76 | labels = torch.reshape(labels, (labels.shape[0],1)) 77 | con_loss = AverageMeter() 78 | index_dict = {k.item() for k in labels} 79 | for index in index_dict: 80 | labels_mask = (labels == index) 81 | feas = torch.masked_select(features, labels_mask) 82 | feas = feas.view(-1, features.shape[1]) 83 | distance = pairwise_distance(feas, feas) 84 | #torch.sqrt_(distance) 85 | num = feas.shape[0] * (feas.shape[0] - 1) 86 | loss = torch.sum(distance) / num 87 | con_loss.update(loss, n = num / 2) 88 | return con_loss.avg 89 | 90 | 91 | def constraints_loss(data_loader, network, args): 92 | network.eval() 93 | max_size = args.batch_size * len(data_loader) 94 | images_bank = torch.zeros((max_size, args.feature_size)).cuda() 95 | text_bank = torch.zeros((max_size,args.feature_size)).cuda() 96 | labels_bank = torch.zeros(max_size).cuda() 97 | index = 0 98 | con_images = 0.0 99 | con_text = 0.0 100 | with torch.no_grad(): 101 | for images, captions, labels, captions_length in data_loader: 102 | images = images.cuda() 103 | captions = captions.cuda() 104 | interval = images.shape[0] 105 | image_embeddings, text_embeddings = network(images, captions, captions_length) 106 | 
images_bank[index: index + interval] = image_embeddings 107 | text_bank[index: index + interval] = text_embeddings 108 | labels_bank[index: index + interval] = labels 109 | index = index + interval 110 | images_bank = images_bank[:index] 111 | text_bank = text_bank[:index] 112 | labels_bank = labels_bank[:index] 113 | 114 | if args.constraints_text: 115 | con_text = constraints(text_bank, labels_bank) 116 | if args.constraints_images: 117 | con_images = constraints(images_bank, labels_bank) 118 | 119 | return con_images, con_text 120 | 121 | 122 | class Loss(nn.Module): 123 | def __init__(self, args): 124 | super(Loss, self).__init__() 125 | self.CMPM = args.CMPM 126 | self.CMPC = args.CMPC 127 | self.epsilon = args.epsilon 128 | self.num_classes = args.num_classes 129 | if args.resume: 130 | checkpoint = torch.load(args.model_path) 131 | self.W = Parameter(checkpoint['W']) 132 | print('=========> Loading in parameter W from pretrained models') 133 | else: 134 | self.W = Parameter(torch.randn(args.feature_size, args.num_classes)) 135 | self.init_weight() 136 | 137 | def init_weight(self): 138 | nn.init.xavier_uniform_(self.W.data, gain=1) 139 | 140 | 141 | def compute_cmpc_loss(self, image_embeddings, text_embeddings, labels): 142 | """ 143 | Cross-Modal Projection Classfication loss(CMPC) 144 | :param image_embeddings: Tensor with dtype torch.float32 145 | :param text_embeddings: Tensor with dtype torch.float32 146 | :param labels: Tensor with dtype torch.int32 147 | :return: 148 | """ 149 | criterion = nn.CrossEntropyLoss(reduction='mean') 150 | self.W_norm = self.W / self.W.norm(dim=0) 151 | #labels_onehot = one_hot_coding(labels, self.num_classes).float() 152 | image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True) 153 | text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True) 154 | 155 | image_proj_text = torch.sum(image_embeddings * text_norm, dim=1, keepdim=True) * text_norm 156 | text_proj_image = torch.sum(text_embeddings * image_norm, dim=1, keepdim=True) * image_norm 157 | 158 | image_logits = torch.matmul(image_proj_text, self.W_norm) 159 | text_logits = torch.matmul(text_proj_image, self.W_norm) 160 | 161 | #labels_one_hot = one_hot_coding(labels, num_classes) 162 | ''' 163 | ipt_loss = criterion(input=image_logits, target=labels) 164 | tpi_loss = criterion(input=text_logits, target=labels) 165 | cmpc_loss = ipt_loss + tpi_loss 166 | ''' 167 | cmpc_loss = criterion(image_logits, labels) + criterion(text_logits, labels) 168 | #cmpc_loss = - (F.log_softmax(image_logits, dim=1) + F.log_softmax(text_logits, dim=1)) * labels_onehot 169 | #cmpc_loss = torch.mean(torch.sum(cmpc_loss, dim=1)) 170 | # classification accuracy for observation 171 | image_pred = torch.argmax(image_logits, dim=1) 172 | text_pred = torch.argmax(text_logits, dim=1) 173 | 174 | image_precision = torch.mean((image_pred == labels).float()) 175 | text_precision = torch.mean((text_pred == labels).float()) 176 | 177 | return cmpc_loss, image_precision, text_precision 178 | 179 | 180 | def compute_cmpm_loss(self, image_embeddings, text_embeddings, labels): 181 | """ 182 | Cross-Modal Projection Matching Loss(CMPM) 183 | :param image_embeddings: Tensor with dtype torch.float32 184 | :param text_embeddings: Tensor with dtype torch.float32 185 | :param labels: Tensor with dtype torch.int32 186 | :return: 187 | i2t_loss: cmpm loss for image projected to text 188 | t2i_loss: cmpm loss for text projected to image 189 | pos_avg_sim: average cosine-similarity for positive pairs 190 | 
neg_avg_sim: averate cosine-similarity for negative pairs 191 | """ 192 | 193 | batch_size = image_embeddings.shape[0] 194 | labels_reshape = torch.reshape(labels, (batch_size, 1)) 195 | labels_dist = labels_reshape - labels_reshape.t() 196 | labels_mask = (labels_dist == 0) 197 | 198 | image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True) 199 | text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True) 200 | image_proj_text = torch.matmul(image_embeddings, text_norm.t()) 201 | text_proj_image = torch.matmul(text_embeddings, image_norm.t()) 202 | 203 | # normalize the true matching distribution 204 | labels_mask_norm = labels_mask.float() / labels_mask.float().norm(dim=1) 205 | 206 | i2t_pred = F.softmax(image_proj_text, dim=1) 207 | #i2t_loss = i2t_pred * torch.log((i2t_pred + self.epsilon)/ (labels_mask_norm + self.epsilon)) 208 | i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_mask_norm + self.epsilon)) 209 | 210 | t2i_pred = F.softmax(text_proj_image, dim=1) 211 | #t2i_loss = t2i_pred * torch.log((t2i_pred + self.epsilon)/ (labels_mask_norm + self.epsilon)) 212 | t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_mask_norm + self.epsilon)) 213 | 214 | cmpm_loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1)) 215 | 216 | sim_cos = torch.matmul(image_norm, text_norm.t()) 217 | 218 | pos_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask)) 219 | neg_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask == 0)) 220 | 221 | return cmpm_loss, pos_avg_sim, neg_avg_sim 222 | 223 | 224 | def forward(self, image_embeddings, text_embeddings, labels): 225 | cmpm_loss = 0.0 226 | cmpc_loss = 0.0 227 | image_precision = 0.0 228 | text_precision = 0.0 229 | neg_avg_sim = 0.0 230 | pos_avg_sim =0.0 231 | if self.CMPM: 232 | cmpm_loss, pos_avg_sim, neg_avg_sim = self.compute_cmpm_loss(image_embeddings, text_embeddings, labels) 233 | if self.CMPC: 234 | cmpc_loss, image_precision, text_precision = self.compute_cmpc_loss(image_embeddings, text_embeddings, labels) 235 | 236 | loss = cmpm_loss + cmpc_loss 237 | 238 | return cmpm_loss, cmpc_loss, loss, image_precision, text_precision, pos_avg_sim, neg_avg_sim 239 | 240 | 241 | class AverageMeter(object): 242 | """ 243 | Computes and stores the averate and current value 244 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py #L247-262 245 | """ 246 | def __init__(self): 247 | self.reset() 248 | 249 | def reset(self): 250 | self.val = 0 251 | self.avg = 0 252 | self.sum = 0 253 | self.count = 0 254 | 255 | def update(self, val, n=1): 256 | self.val = val 257 | self.sum += n * val 258 | self.count += n 259 | self.avg = self.sum / self.count 260 | 261 | 262 | def compute_topk(query, gallery, target_query, target_gallery, k=[1,10], reverse=False): 263 | result = [] 264 | query = query / query.norm(dim=1,keepdim=True) 265 | gallery = gallery / gallery.norm(dim=1,keepdim=True) 266 | sim_cosine = torch.matmul(query, gallery.t()) 267 | result.extend(topk(sim_cosine, target_gallery, target_query, k=[1,10])) 268 | if reverse: 269 | result.extend(topk(sim_cosine, target_query, target_gallery, k=[1,10], dim=0)) 270 | return result 271 | 272 | 273 | def topk(sim, target_gallery, target_query, k=[1,10], dim=1): 274 | result = [] 275 | maxk = max(k) 276 | size_total = len(target_gallery) 277 | _, pred_index = sim.topk(maxk, dim, True, True) 278 | pred_labels = target_gallery[pred_index] 279 | if dim 
== 1:
280 |         pred_labels = pred_labels.t()
281 |     correct = pred_labels.eq(target_query.view(1, -1).expand_as(pred_labels))
282 | 
283 |     for topk in k:
284 |         #correct_k = torch.sum(correct[:topk]).float()
285 |         correct_k = torch.sum(correct[:topk], dim=0)
286 |         correct_k = torch.sum(correct_k > 0).float()
287 |         result.append(correct_k * 100 / size_total)
288 |     return result
289 | 
--------------------------------------------------------------------------------
/utils/statistics.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import matplotlib
 3 | matplotlib.use('TkAgg')
 4 | import matplotlib.pyplot as plt
 5 | import json
 6 | import pickle
 7 | 
 8 | def count_ids(root, flag=0):
 9 |     ids_dict = {}
10 |     captions = 0
11 |     with open(root, 'r') as f:
12 |         info = json.load(f)
13 |     for data in info:
14 |         label = data['id'] - flag
15 |         ids_dict[label] = ids_dict.get(label, 0) + 1
16 |         captions += len(data['captions'])
17 |     return ids_dict, captions
18 | 
19 | 
20 | def count_images(root):
21 |     info = pickle.load(open(root, 'rb'))['label_range']
22 |     images_dict = {}
23 |     # info['#images'] = num
24 |     for label in info:
25 |         num_images = len(info[label]) - 1
26 |         images_dict[num_images] = images_dict.get(num_images, 0) + 1
27 |     return images_dict
28 | 
29 | def count_captions(root):
30 |     info = pickle.load(open(root, 'rb'))['label_range']
31 |     captions_dict = {}
32 |     for label in info:
33 |         for index in range(1, len(info[label])):
34 |             num_captions = info[label][index] - info[label][index - 1]
35 |             captions_dict[num_captions] = captions_dict.get(num_captions, 0) + 1
36 |     return captions_dict
37 | 
38 | def visualize(data):
39 |     keys = list(data.keys())
40 |     keys.sort()
41 |     values = []
42 |     for key in keys:
43 |         values.append(data[key])
44 |     plt.figure('#captions in each image')
45 |     a = plt.bar(keys, values)
46 |     #plt.yticks([1,5,1,100,200,500,1000,5000])
47 |     plt.xticks(list(range(min(keys), max(keys) + 1, 1)))
48 |     autolabel(a)
49 |     plt.xlim(min(keys) - 1, max(keys) + 1)
50 |     plt.show()
51 | 
52 | 
53 | def autolabel(rects):
54 |     for rect in rects:
55 |         height = rect.get_height()
56 |         plt.text(rect.get_x() + rect.get_width() / 2 - 0.2, height + 2, '%s' % int(height))
57 | 
58 | 
59 | if __name__ == "__main__":
60 |     root = '/Users/zhangqi/Codes/Deep-Cross-Modal-Projection-Learning-for-Image-Text-Matching/data/processed_data/train_sort.pkl'
61 |     data = count_images(root)
62 |     print(data)
63 |     visualize(data)
64 | 
--------------------------------------------------------------------------------
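Note: count_ids above expects the raw annotation file to be a JSON list in which every entry carries an integer `id` and a list of `captions`. The toy example below illustrates that assumed format and a typical call; the field values are made up and the temporary file is only for demonstration.
```
import json
import tempfile

from utils.statistics import count_ids

# Toy annotation file in the format count_ids expects (values are illustrative only).
anno = [
    {'id': 1, 'captions': ['a man in a red jacket', 'a pedestrian wearing red']},
    {'id': 2, 'captions': ['a woman carrying a backpack']},
]
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump(anno, f)

ids_dict, num_captions = count_ids(f.name, flag=0)
print(ids_dict)       # {1: 1, 2: 1} -> number of images per identity
print(num_captions)   # 3
```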
/utils/visualize.py:
--------------------------------------------------------------------------------
 1 | import matplotlib.pyplot as plt
 2 | import os
 3 | import cv2
 4 | 
 5 | # visualize loss & accuracy
 6 | def visualize_curve(log_root):
 7 |     log_file = open(log_root, 'r')
 8 |     result_root = log_root[:log_root.rfind('/') + 1] + 'train.jpg'
 9 |     loss = []
10 | 
11 |     top1_i2t = []
12 |     top10_i2t = []
13 |     top1_t2i = []
14 |     top10_t2i = []
15 |     for line in log_file.readlines():
16 |         line = line.strip().split()
17 | 
18 |         if 'top10_t2i' not in line[-2]:
19 |             continue
20 | 
21 |         loss.append(line[1])
22 |         top1_i2t.append(line[3])
23 |         top10_i2t.append(line[5])
24 |         top1_t2i.append(line[7])
25 |         top10_t2i.append(line[9])
26 | 
27 |     log_file.close()
28 | 
29 |     plt.figure('loss')
30 |     plt.plot(loss)
31 | 
32 |     plt.figure('accuracy')
33 |     plt.subplot(211)
34 |     plt.plot(top1_i2t, label = 'top1')
35 |     plt.plot(top10_i2t, label = 'top10')
36 |     plt.legend(['image to text'], loc = 'upper right')
37 |     plt.subplot(212)
38 |     plt.plot(top1_t2i, label = 'top1')
39 |     plt.plot(top10_t2i, label = 'top10')
40 |     plt.legend(['text to image'], loc = 'upper right')
41 |     plt.savefig(result_root)
42 |     plt.show()
43 | 
44 | 
45 | if __name__ == '__main__':
46 |     log_root = '/home/zhangqi/Deep-Cross-Modal-Projection-Learning-for-Image-Text-Matching/data/logs/train.log'
47 |     visualize_curve(log_root)
48 | 
--------------------------------------------------------------------------------
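Note: as a closing illustration, the sketch below shows one way compute_topk from utils/metric.py can be called on a small bank of embeddings to obtain top-1/top-10 retrieval accuracy in both directions. The tensor shapes and the direction naming follow the function signature above, but they are an assumption for this example; the actual call site in test.py may differ in detail.
```
import torch

from utils.metric import compute_topk

torch.manual_seed(0)
num_images, num_texts, feat_dim = 32, 32, 512
image_bank = torch.randn(num_images, feat_dim)   # stand-ins for extracted image embeddings
text_bank = torch.randn(num_texts, feat_dim)     # stand-ins for extracted text embeddings
image_labels = torch.arange(num_images)
text_labels = torch.arange(num_texts)

# Rows of the cosine-similarity matrix are queries (images here), columns are gallery
# items (texts); reverse=True additionally ranks along the other dimension.
top1_i2t, top10_i2t, top1_t2i, top10_t2i = compute_topk(
    image_bank, text_bank, image_labels, text_labels, k=[1, 10], reverse=True)
print(top1_i2t, top10_i2t, top1_t2i, top10_t2i)  # accuracies in percent
```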