├── model ├── txt ├── HCNet.py └── FC5lstm.py ├── data ├── UCM_images └── UCM_images1 ├── README.md ├── datasets.py ├── eval.py ├── eval_HCNet_UCM.py ├── HCNet.py ├── train_HCNet_UCM.py ├── train.py └── utils.py /model/txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/UCM_images: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/UCM_images1: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Remote sensing image captioning aims to describe the crucial objects from remote sensing images in the form of natural language. Currently, it is challenging to generate high-quality captions due to the multi-scale targets in remote sensing images and the cross-modality differences between images and text features. To address these problems, this paper presents an approach for generating captions through hierarchical feature aggregation and cross-modality feature alignment, namely HCNet. Specifically, we propose a hierarchical feature aggregation module (HFAM) to obtain a comprehensive representation of vision features. Considering the disparities among different modality features, a cross-modality feature interaction module (CFIM) is designed in the decoder to facilitate feature alignment. Meanwhile, a cross-modality align loss is introduced to realize the alignment of image and text features. Extensive experiments on the three public caption datasets show our HCNet can achieve satisfactory performance. Especially, we demonstrate significant 2 | performance improvements of +14.15\% CIDEr score on NWPU datasets compared to existing approaches. 3 | 4 | 5 | First, refer to the [MLAT](https://github.com/Chen-Yang-Liu/MLAT) to generate the required data in the data\UCM_images1. 6 | 7 | Then, python train_HCNet_UCM.py, generate the weights in the best_UCM_weights. 8 | 9 | Finally, python eval_HCNet_UCM.py. 10 | 11 | This code is based on the [MLAT](https://github.com/Chen-Yang-Liu/MLAT) and [Clip](https://github.com/openai/CLIP). 12 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import h5py 4 | import json 5 | import os 6 | 7 | 8 | class CaptionDataset(Dataset): 9 | """ 10 | A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches. 11 | """ 12 | 13 | def __init__(self, data_folder, data_name, split, transform=None): 14 | """ 15 | :param data_folder: folder where data files are stored - /Users/skye/docs/image_dataset/dataset 16 | :param data_name: base name of processed datasets 17 | :param split: split, one of 'TRAIN', 'VAL', or 'TEST' 18 | :param transform: image transform pipeline 19 | """ 20 | self.split = split 21 | assert self.split in {'TRAIN', 'VAL', 'TEST'} 22 | 23 | # Open hdf5 file where images are stored 24 | self.h = h5py.File(os.path.join(data_folder, self.split + '_IMAGES_' + data_name + '.hdf5'), 'r') 25 | self.imgs = self.h['images'] 26 | 27 | # Captions per image 28 | self.cpi = self.h.attrs['captions_per_image'] 29 | 30 | # Load encoded captions (completely into memory) 31 | with open(os.path.join(data_folder, self.split + '_CAPTIONS_' + data_name + '.json'), 'r') as j: 32 | self.captions = json.load(j) 33 | 34 | # Load caption lengths (completely into memory) 35 | with open(os.path.join(data_folder, self.split + '_CAPLENS_' + data_name + '.json'), 'r') as j: 36 | self.caplens = json.load(j) 37 | 38 | # PyTorch transformation pipeline for the image (normalizing, etc.) 39 | self.transform = transform 40 | 41 | # Total number of datapoints 42 | self.dataset_size = len(self.captions) 43 | 44 | def __getitem__(self, i): 45 | # Remember, the Nth caption corresponds to the (N // captions_per_image)th image 46 | img = torch.FloatTensor(self.imgs[i // self.cpi] / 255.) 47 | if self.transform is not None: 48 | img = self.transform(img) 49 | 50 | caption = torch.LongTensor(self.captions[i]) 51 | 52 | caplen = torch.LongTensor([self.caplens[i]]) 53 | 54 | if self.split is 'TRAIN': 55 | return img, caption, caplen 56 | else: 57 | # For validation of testing, also return all 'captions_per_image' captions to find BLEU-4 score 58 | all_captions = torch.LongTensor(self.captions[((i // self.cpi) * self.cpi):(((i // self.cpi) * self.cpi) + self.cpi)]) 59 | return img, caption, caplen, all_captions 60 | 61 | def __len__(self): 62 | return self.dataset_size 63 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch.backends.cudnn as cudnn 4 | import torch.optim 5 | import torch.utils.data 6 | import torchvision.transforms as transforms 7 | from datasets import * 8 | from utils import * 9 | from nltk.translate.bleu_score import corpus_bleu 10 | import torch.nn.functional as F 11 | from tqdm import tqdm 12 | import argparse 13 | import time 14 | # import transformer, models 15 | 16 | 17 | def evaluate_transformer(args): 18 | """ 19 | Evaluation for decoder_mode: transformer 20 | 21 | :param beam_size: beam size at which to generate captions for evaluation 22 | :return: BLEU-4 score 23 | """ 24 | beam_size = args.beam_size 25 | Caption_End = False 26 | # DataLoader 27 | loader = torch.utils.data.DataLoader( 28 | CaptionDataset(args.data_folder, args.data_name, 'TEST', transform=transforms.Compose([transforms.RandomHorizontalFlip(),normalize])), 29 | batch_size=1, shuffle=False, num_workers=0, pin_memory=True) 30 | 31 | # Lists to store references (true captions), and hypothesis (prediction) for each image 32 | # If for n images, we have n hypotheses, and references a, b, c... for each image, we need - 33 | # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...] 34 | references = list() 35 | hypotheses = list() 36 | 37 | with torch.no_grad(): 38 | for i, (image, caps, caplens, allcaps) in enumerate( 39 | tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size))): 40 | # if i>30: 41 | # break 42 | if (i+1)%5 != 0: 43 | continue 44 | k = beam_size 45 | # Move to GPU device, if available 46 | image = image.to(device) # [1, 3, 256, 256] 47 | 48 | # Encode 49 | encoder_out = encoder(image) # [1, enc_image_size=14, enc_image_size=14, encoder_dim=2048] 50 | enc_image_size = encoder_out.size(1) 51 | encoder_dim = encoder_out.size(-1) 52 | # We'll treat the problem as having a batch size of k, where k is beam_size 53 | encoder_out = encoder_out.expand(k, enc_image_size, enc_image_size, encoder_dim) # [k, enc_image_size, enc_image_size, encoder_dim] 54 | # Tensor to store top k previous words at each step; now they're just 55 | # Important: [1, 52] (eg: [[ ...]]) will not work, since it contains the position encoding 56 | k_prev_words = torch.LongTensor([[word_map['']]*52] * k).to(device) # (k, 52) 57 | # Tensor to store top k sequences; now they're just 58 | seqs = torch.LongTensor([[word_map['']]] * k).to(device) # (k, 1) 59 | # Tensor to store top k sequences' scores; now they're just 0 60 | top_k_scores = torch.zeros(k, 1).to(device) 61 | # Lists to store completed sequences and scores 62 | complete_seqs = [] 63 | complete_seqs_scores = [] 64 | step = 1 65 | 66 | # Start decoding 67 | # s is a number less than or equal to k, because sequences are removed from this process once they hit 68 | while True: 69 | # print("steps {} k_prev_words: {}".format(step, k_prev_words)) 70 | # cap_len = torch.LongTensor([52]).repeat(k, 1).to(device) may cause different sorted results on GPU/CPU in transformer.py 71 | cap_len = torch.LongTensor([52]).repeat(k, 1) # [s, 1] 72 | scores, _, _, _, _,_,_ = decoder(encoder_out, k_prev_words, cap_len) 73 | scores = scores[:, step-1, :].squeeze(1) # [s, 1, vocab_size] -> [s, vocab_size] 74 | scores = F.log_softmax(scores, dim=1) 75 | # top_k_scores: [s, 1] 76 | scores = top_k_scores.expand_as(scores) + scores # [s, vocab_size] 77 | # For the first step, all k points will have the same scores (since same k previous words, h, c) 78 | if step == 1: 79 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s) 80 | else: 81 | # Unroll and find top scores, and their unrolled indices 82 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 83 | 84 | # Convert unrolled indices to actual indices of scores 85 | prev_word_inds = top_k_words // vocab_size # (s) 86 | next_word_inds = top_k_words % vocab_size # (s) 87 | 88 | # Add new words to sequences 89 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 90 | # Which sequences are incomplete (didn't reach )? 91 | incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if 92 | next_word != word_map['']] 93 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds)) 94 | # Set aside complete sequences 95 | if len(complete_inds) > 0: 96 | Caption_End = True 97 | complete_seqs.extend(seqs[complete_inds].tolist()) 98 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 99 | k -= len(complete_inds) # reduce beam length accordingly 100 | # Proceed with incomplete sequences 101 | if k == 0: 102 | break 103 | seqs = seqs[incomplete_inds] 104 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]] 105 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 106 | # Important: this will not work, since decoder has self-attention 107 | # k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1).repeat(k, 52) 108 | k_prev_words = k_prev_words[incomplete_inds] 109 | k_prev_words[:, :step+1] = seqs # [s, 52] 110 | # k_prev_words[:, step] = next_word_inds[incomplete_inds] # [s, 52] 111 | # Break if things have been going on too long 112 | if step > 50: 113 | break 114 | step += 1 115 | 116 | # choose the caption which has the best_score. 117 | assert Caption_End 118 | indices = complete_seqs_scores.index(max(complete_seqs_scores)) 119 | seq = complete_seqs[indices] 120 | # References 121 | img_caps = allcaps[0].tolist() 122 | img_captions = list( 123 | map(lambda c: [w for w in c if w not in {word_map[''], word_map[''], word_map['']}], 124 | img_caps)) # remove and pads 125 | references.append(img_captions) 126 | # Hypotheses 127 | # tmp_hyp = [w for w in seq if w not in {word_map[''], word_map[''], word_map['']}] 128 | hypotheses.append([w for w in seq if w not in {word_map[''], word_map[''], word_map['']}]) 129 | assert len(references) == len(hypotheses) 130 | # Print References, Hypotheses and metrics every step 131 | # words = [] 132 | # # print('*' * 10 + 'ImageCaptions' + '*' * 10, len(img_captions)) 133 | # for seq in img_captions: 134 | # words.append([rev_word_map[ind] for ind in seq]) 135 | # for i, seq in enumerate(words): 136 | # print('Reference{}: '.format(i), seq) 137 | # print('Hypotheses: ', [rev_word_map[ind] for ind in tmp_hyp]) 138 | # metrics = get_eval_score([img_captions], [tmp_hyp]) 139 | # print("{} - beam size {}: BLEU-1 {} BLEU-2 {} BLEU-3 {} BLEU-4 {} METEOR {} ROUGE_L {} CIDEr {}".format 140 | # (args.decoder_mode, args.beam_size, metrics["Bleu_1"], metrics["Bleu_2"], metrics["Bleu_3"], 141 | # metrics["Bleu_4"], 142 | # metrics["METEOR"], metrics["ROUGE_L"], metrics["CIDEr"])) 143 | 144 | # Calculate BLEU1~4, METEOR, ROUGE_L, CIDEr scores 145 | with open('./results/33LFC5lstm_NWPU.json', 'w') as file: 146 | json.dump(hypotheses, file) 147 | metrics = get_eval_score(references, hypotheses) 148 | 149 | return metrics 150 | 151 | 152 | if __name__ == '__main__': 153 | parser = argparse.ArgumentParser(description='Image_Captioning') 154 | parser.add_argument('--data_folder', default="./data/NWPU_images1",help='' 155 | '' 156 | 'folder with data files saved by create_input_files.py.') 157 | parser.add_argument('--data_name', default="NWPU_5_cap_per_img_4_min_word_freq",help='base name shared by data files.') 158 | 159 | # FIXME:note to change these 160 | parser.add_argument('--encoder_mode', default="resnet50", help='which model does encoder use?') # inception_v3 or vgg16 or vgg19 or resnet50 or resnet101 or resnet152 161 | parser.add_argument('--decoder_mode', default="lstm_attention", help='which model does decoder use?') # lstm or lstm_attention or transformer or transformer_decoder 162 | 163 | parser.add_argument('--beam_size', type=int, default=3, help='beam_size.') 164 | parser.add_argument('--path', default="./best_models_weights/", help='model checkpoint.') 165 | args = parser.parse_args() 166 | 167 | for encoder_layers, decoder_layers in [(3, 3)]: # ,,(0,6),(2,2), 168 | 169 | 170 | args.encoder_layers = encoder_layers 171 | args.decoder_layers = decoder_layers 172 | 173 | word_map_file = os.path.join(args.data_folder, 'WORDMAP_' + args.data_name + '.json') 174 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 175 | # transformer.device = torch.device("cpu") 176 | # models.device = torch.device("cpu") 177 | cudnn.benchmark = True # set to true only if inputs to model are fixed size; otherwise lot of computational overhead 178 | print(device) 179 | 180 | # Load model 181 | # checkpoint_path = args.checkpoint + args.encoder_mode + '_' + args.decoder_mode + '_'+ str(args.encoder_layers) + '_' + str(args.decoder_layers) +'_Res+MLAT'+'.pth.tar' 182 | 183 | filename = os.listdir(args.path) 184 | pathname = 'BEST_checkpoint_HCNet_NWPU.pth.tar' 185 | 186 | print(time.strftime("%m-%d %H : %M : %S", time.localtime(time.time()))) 187 | 188 | checkpoint_path = os.path.join(args.path,pathname) 189 | print(pathname) 190 | 191 | checkpoint = torch.load(checkpoint_path, map_location=str(device)) 192 | decoder = checkpoint['decoder'] 193 | decoder = decoder.to(device) 194 | decoder.eval() 195 | encoder = checkpoint['encoder'] 196 | encoder = encoder.to(device) 197 | encoder.eval() 198 | # print(encoder) 199 | # print(decoder) 200 | 201 | # Load word map (word2id) 202 | with open(word_map_file, 'r') as j: 203 | word_map = json.load(j) 204 | vocab_size = len(word_map) 205 | rev_word_map = {v: k for k, v in word_map.items()} # ix2word 206 | 207 | # Normalization transform 208 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 209 | std=[0.229, 0.224, 0.225]) 210 | if args.decoder_mode == "lstm_attention" or args.decoder_mode == "transformer_decoder": 211 | metrics = evaluate_transformer(args) 212 | 213 | print("{} - beam size {}: BLEU-1 {} BLEU-2 {} BLEU-3 {} BLEU-4 {} METEOR {} ROUGE_L {} CIDEr {}".format 214 | (args.decoder_mode, args.beam_size, metrics["Bleu_1"], metrics["Bleu_2"], metrics["Bleu_3"], metrics["Bleu_4"], 215 | metrics["METEOR"], metrics["ROUGE_L"], metrics["CIDEr"])) 216 | 217 | print(time.strftime("%m-%d %H : %M : %S", time.localtime(time.time()))) 218 | 219 | print("\n") 220 | print("\n") 221 | print("\n") 222 | 223 | -------------------------------------------------------------------------------- /eval_HCNet_UCM.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch.backends.cudnn as cudnn 4 | import torch.optim 5 | import torch.utils.data 6 | import torchvision.transforms as transforms 7 | from datasets import * 8 | from utils import * 9 | from nltk.translate.bleu_score import corpus_bleu 10 | import torch.nn.functional as F 11 | from tqdm import tqdm 12 | import argparse 13 | import time 14 | # import transformer, models 15 | 16 | 17 | def evaluate_transformer(args): 18 | """ 19 | Evaluation for decoder_mode: transformer 20 | 21 | :param beam_size: beam size at which to generate captions for evaluation 22 | :return: BLEU-4 score 23 | """ 24 | beam_size = args.beam_size 25 | Caption_End = False 26 | # DataLoader 27 | loader = torch.utils.data.DataLoader( 28 | CaptionDataset(args.data_folder, args.data_name, 'TEST', transform=transforms.Compose([transforms.RandomHorizontalFlip(),normalize])), 29 | batch_size=1, shuffle=False, num_workers=0, pin_memory=True) 30 | 31 | # Lists to store references (true captions), and hypothesis (prediction) for each image 32 | # If for n images, we have n hypotheses, and references a, b, c... for each image, we need - 33 | # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...] 34 | references = list() 35 | hypotheses = list() 36 | 37 | with torch.no_grad(): 38 | for i, (image, caps, caplens, allcaps) in enumerate( 39 | tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size))): 40 | # if i>30: 41 | # break 42 | if (i+1)%5 != 0: 43 | continue 44 | k = beam_size 45 | # Move to GPU device, if available 46 | image = image.to(device) # [1, 3, 256, 256] 47 | 48 | # Encode 49 | encoder_out = encoder(image) # [1, enc_image_size=14, enc_image_size=14, encoder_dim=2048] 50 | enc_image_size = encoder_out.size(1) 51 | encoder_dim = encoder_out.size(-1) 52 | # We'll treat the problem as having a batch size of k, where k is beam_size 53 | encoder_out = encoder_out.expand(k, enc_image_size, enc_image_size, encoder_dim) # [k, enc_image_size, enc_image_size, encoder_dim] 54 | # Tensor to store top k previous words at each step; now they're just 55 | # Important: [1, 52] (eg: [[ ...]]) will not work, since it contains the position encoding 56 | k_prev_words = torch.LongTensor([[word_map['']]*52] * k).to(device) # (k, 52) 57 | # Tensor to store top k sequences; now they're just 58 | seqs = torch.LongTensor([[word_map['']]] * k).to(device) # (k, 1) 59 | # Tensor to store top k sequences' scores; now they're just 0 60 | top_k_scores = torch.zeros(k, 1).to(device) 61 | # Lists to store completed sequences and scores 62 | complete_seqs = [] 63 | complete_seqs_scores = [] 64 | step = 1 65 | 66 | # Start decoding 67 | # s is a number less than or equal to k, because sequences are removed from this process once they hit 68 | while True: 69 | # print("steps {} k_prev_words: {}".format(step, k_prev_words)) 70 | # cap_len = torch.LongTensor([52]).repeat(k, 1).to(device) may cause different sorted results on GPU/CPU in transformer.py 71 | cap_len = torch.LongTensor([52]).repeat(k, 1) # [s, 1] 72 | scores, _, _, _, _,_,_ = decoder(encoder_out, k_prev_words, cap_len) 73 | scores = scores[:, step-1, :].squeeze(1) # [s, 1, vocab_size] -> [s, vocab_size] 74 | scores = F.log_softmax(scores, dim=1) 75 | # top_k_scores: [s, 1] 76 | scores = top_k_scores.expand_as(scores) + scores # [s, vocab_size] 77 | # For the first step, all k points will have the same scores (since same k previous words, h, c) 78 | if step == 1: 79 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s) 80 | else: 81 | # Unroll and find top scores, and their unrolled indices 82 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 83 | 84 | # Convert unrolled indices to actual indices of scores 85 | prev_word_inds = top_k_words // vocab_size # (s) 86 | next_word_inds = top_k_words % vocab_size # (s) 87 | 88 | # Add new words to sequences 89 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 90 | # Which sequences are incomplete (didn't reach )? 91 | incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if 92 | next_word != word_map['']] 93 | complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds)) 94 | # Set aside complete sequences 95 | if len(complete_inds) > 0: 96 | Caption_End = True 97 | complete_seqs.extend(seqs[complete_inds].tolist()) 98 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 99 | k -= len(complete_inds) # reduce beam length accordingly 100 | # Proceed with incomplete sequences 101 | if k == 0: 102 | break 103 | seqs = seqs[incomplete_inds] 104 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]] 105 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 106 | # Important: this will not work, since decoder has self-attention 107 | # k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1).repeat(k, 52) 108 | k_prev_words = k_prev_words[incomplete_inds] 109 | k_prev_words[:, :step+1] = seqs # [s, 52] 110 | # k_prev_words[:, step] = next_word_inds[incomplete_inds] # [s, 52] 111 | # Break if things have been going on too long 112 | if step > 50: 113 | break 114 | step += 1 115 | 116 | # choose the caption which has the best_score. 117 | assert Caption_End 118 | indices = complete_seqs_scores.index(max(complete_seqs_scores)) 119 | seq = complete_seqs[indices] 120 | # References 121 | img_caps = allcaps[0].tolist() 122 | img_captions = list( 123 | map(lambda c: [w for w in c if w not in {word_map[''], word_map[''], word_map['']}], 124 | img_caps)) # remove and pads 125 | references.append(img_captions) 126 | # Hypotheses 127 | # tmp_hyp = [w for w in seq if w not in {word_map[''], word_map[''], word_map['']}] 128 | hypotheses.append([w for w in seq if w not in {word_map[''], word_map[''], word_map['']}]) 129 | assert len(references) == len(hypotheses) 130 | # Print References, Hypotheses and metrics every step 131 | # words = [] 132 | # # print('*' * 10 + 'ImageCaptions' + '*' * 10, len(img_captions)) 133 | # for seq in img_captions: 134 | # words.append([rev_word_map[ind] for ind in seq]) 135 | # for i, seq in enumerate(words): 136 | # print('Reference{}: '.format(i), seq) 137 | # print('Hypotheses: ', [rev_word_map[ind] for ind in tmp_hyp]) 138 | # metrics = get_eval_score([img_captions], [tmp_hyp]) 139 | # print("{} - beam size {}: BLEU-1 {} BLEU-2 {} BLEU-3 {} BLEU-4 {} METEOR {} ROUGE_L {} CIDEr {}".format 140 | # (args.decoder_mode, args.beam_size, metrics["Bleu_1"], metrics["Bleu_2"], metrics["Bleu_3"], 141 | # metrics["Bleu_4"], 142 | # metrics["METEOR"], metrics["ROUGE_L"], metrics["CIDEr"])) 143 | 144 | # Calculate BLEU1~4, METEOR, ROUGE_L, CIDEr scores 145 | with open('./results/HCNet_UCM.json', 'w') as file: 146 | json.dump(hypotheses, file) 147 | metrics = get_eval_score(references, hypotheses) 148 | 149 | return metrics 150 | 151 | 152 | if __name__ == '__main__': 153 | parser = argparse.ArgumentParser(description='Image_Captioning') 154 | parser.add_argument('--data_folder', default="./data/UCM_images1",help='' 155 | '' 156 | 'folder with data files saved by create_input_files.py.') 157 | parser.add_argument('--data_name', default="UCM_5_cap_per_img_4_min_word_freq",help='base name shared by data files.') 158 | 159 | # FIXME:note to change these 160 | parser.add_argument('--encoder_mode', default="resnet50", help='which model does encoder use?') # inception_v3 or vgg16 or vgg19 or resnet50 or resnet101 or resnet152 161 | parser.add_argument('--decoder_mode', default="lstm_attention", help='which model does decoder use?') # lstm or lstm_attention or transformer or transformer_decoder 162 | 163 | parser.add_argument('--beam_size', type=int, default=3, help='beam_size.') 164 | # parser.add_argument('--checkpoint', default="./models_checkpoint_GRSL/BEST_checkpoint_",help='model checkpoint.') 165 | parser.add_argument('--path', default="./best_models_weights/", help='model checkpoint.') 166 | args = parser.parse_args() 167 | 168 | for encoder_layers, decoder_layers in [(3, 3)]: # ,,(0,6),(2,2), 169 | 170 | 171 | args.encoder_layers = encoder_layers 172 | args.decoder_layers = decoder_layers 173 | 174 | word_map_file = os.path.join(args.data_folder, 'WORDMAP_' + args.data_name + '.json') 175 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 176 | # transformer.device = torch.device("cpu") 177 | # models.device = torch.device("cpu") 178 | cudnn.benchmark = True # set to true only if inputs to model are fixed size; otherwise lot of computational overhead 179 | print(device) 180 | 181 | # Load model 182 | # checkpoint_path = args.checkpoint + args.encoder_mode + '_' + args.decoder_mode + '_'+ str(args.encoder_layers) + '_' + str(args.decoder_layers) +'_Res+MLAT'+'.pth.tar' 183 | 184 | filename = os.listdir(args.path) 185 | pathname = 'BEST_checkpoint_HCNet_UCM.pth.tar' 186 | 187 | print(time.strftime("%m-%d %H : %M : %S", time.localtime(time.time()))) 188 | 189 | checkpoint_path = os.path.join(args.path,pathname) 190 | print(pathname) 191 | 192 | checkpoint = torch.load(checkpoint_path, map_location=str(device)) 193 | decoder = checkpoint['decoder'] 194 | decoder = decoder.to(device) 195 | decoder.eval() 196 | encoder = checkpoint['encoder'] 197 | encoder = encoder.to(device) 198 | encoder.eval() 199 | # print(encoder) 200 | # print(decoder) 201 | 202 | # Load word map (word2id) 203 | with open(word_map_file, 'r') as j: 204 | word_map = json.load(j) 205 | vocab_size = len(word_map) 206 | rev_word_map = {v: k for k, v in word_map.items()} # ix2word 207 | 208 | # Normalization transform 209 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 210 | std=[0.229, 0.224, 0.225]) 211 | if args.decoder_mode == "lstm_attention" or args.decoder_mode == "transformer_decoder": 212 | metrics = evaluate_transformer(args) 213 | 214 | print("{} - beam size {}: BLEU-1 {} BLEU-2 {} BLEU-3 {} BLEU-4 {} METEOR {} ROUGE_L {} CIDEr {}".format 215 | (args.decoder_mode, args.beam_size, metrics["Bleu_1"], metrics["Bleu_2"], metrics["Bleu_3"], metrics["Bleu_4"], 216 | metrics["METEOR"], metrics["ROUGE_L"], metrics["CIDEr"])) 217 | 218 | print(time.strftime("%m-%d %H : %M : %S", time.localtime(time.time()))) 219 | 220 | print("\n") 221 | print("\n") 222 | print("\n") 223 | -------------------------------------------------------------------------------- /HCNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torchvision 4 | import torch.nn.functional as F 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | class HFAM(nn.Module): 9 | 10 | def __init__(self): 11 | super(HFAM, self).__init__() 12 | self.branch1 = nn.Sequential( 13 | nn.Conv2d( 14 | 512, 512, kernel_size=15, stride=2, 15 | padding=7, groups=128, bias=False), 16 | nn.BatchNorm2d(512), 17 | nn.Conv2d( 18 | 512, 512, kernel_size=1, stride=1, 19 | padding=0, bias=False), 20 | ) 21 | self.branch2 = nn.Sequential( 22 | nn.Conv2d( 23 | 1024, 512, kernel_size=11, stride=1, 24 | padding=5, bias=False), 25 | nn.BatchNorm2d(512), 26 | ) 27 | self.branch3 = nn.Sequential( 28 | nn.Conv2d( 29 | 2048, 512, kernel_size=7, stride=1, 30 | padding=3, groups=128, bias=False), 31 | nn.Upsample(scale_factor=2), 32 | nn.BatchNorm2d(512), 33 | ) 34 | self.conv = nn.Sequential( 35 | nn.Conv2d( 36 | 1024, 1024, kernel_size=11, stride=1, 37 | padding=5, bias=False), 38 | nn.BatchNorm2d(1024), 39 | nn.ReLU(inplace=True), # not shown in paper 40 | ) 41 | 42 | self.conv1 = nn.Conv2d(2, 1, 3, padding=1, bias=False) 43 | self.sigmoid1 = nn.Sigmoid() 44 | self.sigmoid2 = nn.Sigmoid() 45 | 46 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 47 | self.max_pool = nn.AdaptiveMaxPool2d(1) 48 | 49 | self.fc1 = nn.Conv2d(512, 512 // 16, 1, bias=False) 50 | self.relu1 = nn.ReLU() 51 | self.fc2 = nn.Conv2d(512 // 16, 512, 1, bias=False) 52 | 53 | def forward(self, x1, x2 ,x3): 54 | x1_1 = self.branch1(x1) 55 | x2_1 = self.branch2(x2) 56 | x3_1 = self.branch3(x3) 57 | 58 | pixavg = torch.mean(x1_1, dim=1, keepdim=True) 59 | detail = self.sigmoid1(pixavg) * x2_1 60 | 61 | chaavg = self.fc2(self.relu1(self.fc1(self.avg_pool(x3_1)))) 62 | seman = self.sigmoid2(chaavg) * x2_1 63 | out = self.conv(torch.cat([detail,seman],dim=1)) 64 | return out 65 | 66 | class Encoder(nn.Module): 67 | """ 68 | CNN_Encoder. 69 | """ 70 | def __init__(self, NetType='resnet50', encoded_image_size=14, attention_method="ByPixel"): 71 | super(Encoder, self).__init__() 72 | self.enc_image_size = encoded_image_size 73 | self.attention_method = attention_method 74 | 75 | self.FF = HFAM() 76 | 77 | # resnet = torchvision.models.resnet101(pretrained=True) # pretrained ImageNet ResNet-101 78 | net = torchvision.models.inception_v3(pretrained=True, transform_input=False) if NetType == 'inception_v3' else \ 79 | torchvision.models.vgg16(pretrained=True) if NetType == 'vgg16' else \ 80 | torchvision.models.resnet50(pretrained=True) if NetType == 'resnet50' else torchvision.models.resnet50(pretrained=True) 81 | # Remove linear and pool layers (since we're not doing classification) 82 | # Specifically, Remove: AdaptiveAvgPool2d(output_size=(1, 1)), Linear(in_features=2048, out_features=1000, bias=True)] 83 | 84 | # modules = list(net.children())[:-2] 85 | modules = list(net.children())[:-1] if NetType == 'inception_v3' or NetType == 'vgg16' else list(net.children())[:-2] 86 | # modules = list(net.children())[:-1] if NetType == 'inception_v3' else list(net.children())[:-2] # -2 for resnet & vgg 87 | if NetType == 'inception_v3': del modules[13] 88 | 89 | self.net = nn.Sequential(*modules) 90 | 91 | # every block of resnet for fusion 92 | if NetType == 'resnet50' or NetType == 'resnet101' or NetType == 'resnet152': 93 | resnet_block1 = list(net.children())[:5] 94 | self.resnet_block1 = nn.Sequential(*resnet_block1) 95 | resnet_block2 = list(net.children())[5] 96 | self.resnet_block2 = nn.Sequential(*resnet_block2) 97 | resnet_block3 = list(net.children())[6] 98 | self.resnet_block3 = nn.Sequential(*resnet_block3) 99 | resnet_block4 = list(net.children())[7] 100 | self.resnet_block4 = nn.Sequential(*resnet_block4) 101 | self.conv4 = nn.Conv2d(in_channels=2048, out_channels=1024, kernel_size=1, stride=1) 102 | self.conv3 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1) 103 | self.conv2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1) 104 | 105 | # if self.attention_method == "ByChannel": 106 | # self.cnn1 = nn.Conv2d(in_channels=2048, out_channels=512, kernel_size=(1, 1), stride=(1, 1), bias=False) 107 | # self.bn1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 108 | # self.relu = nn.ReLU(inplace=True) 109 | 110 | # Resize image to fixed size to allow input images of variable size 111 | self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) 112 | # self.adaptive_pool4 = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) 113 | # self.adaptive_pool3 = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) 114 | 115 | self.fine_tune() 116 | 117 | def forward(self, images): 118 | """ 119 | Forward propagation. 120 | 121 | :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size) 122 | :return: encoded images [batch_size, encoded_image_size=14, encoded_image_size=14, 2048] 123 | """ 124 | # with fusion for resnet 125 | out1 = self.resnet_block1(images) # 256 126 | out2 = self.resnet_block2(out1) # 512 127 | out3 = self.resnet_block3(out2) # 1024 128 | out4 = self.resnet_block4(out3) # 2048 129 | 130 | # # FIXME:concat432 131 | out = self.FF(out2,out3,out4) 132 | 133 | 134 | # without fusion 135 | # out = self.net(images) # (batch_size, 2048, image_size/32, image_size/32) 136 | # if self.attention_method == "ByChannel": 137 | # out = self.relu(self.bn1(self.cnn1(out))) 138 | out = self.adaptive_pool(out) # [batch_size, 2048/512, 8, 8] -> [batch_size, 2048/512, 14, 14] #FIXME:for fusion 139 | out = out.permute(0, 2, 3, 1) 140 | return out 141 | 142 | def fine_tune(self, fine_tune=True): 143 | """ 144 | Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder. 145 | 146 | :param fine_tune: Allow? 147 | """ 148 | for p in self.net.parameters(): 149 | p.requires_grad = False 150 | # If fine-tuning, only fine-tune convolutional blocks 2 through 4 151 | for c in list(self.net.children())[5:]: # FIXME:maybe try 6: 152 | for p in c.parameters(): 153 | p.requires_grad = fine_tune 154 | 155 | 156 | class Attention(nn.Module): 157 | """ 158 | Attention Network. 159 | """ 160 | 161 | def __init__(self, encoder_dim, decoder_dim, attention_dim): 162 | """ 163 | :param encoder_dim: feature size of encoded images 164 | :param decoder_dim: size of decoder's RNN 165 | :param attention_dim: size of the attention network 166 | """ 167 | super(Attention, self).__init__() 168 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image 169 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output 170 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed 171 | self.relu = nn.ReLU() 172 | self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights 173 | 174 | def forward(self, encoder_out, decoder_hidden): 175 | """ 176 | Forward propagation. 177 | 178 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 179 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) 180 | :return: attention weighted encoding, weights 181 | """ 182 | att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim) 183 | att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim) 184 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels) 185 | alpha = self.softmax(att) # (batch_size, num_pixels) 186 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (batch_size, encoder_dim) 187 | #attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)) # (batch_size, pixels, encoder_dim) 188 | return attention_weighted_encoding, alpha 189 | 190 | class CrossAttention(nn.Module): 191 | """ 192 | Cross Transformer layer 193 | """ 194 | 195 | def __init__(self, dropout, d_model=512, n_head=8): 196 | """ 197 | :param dropout: dropout rate 198 | :param d_model: dimension of hidden state 199 | :param n_head: number of heads in multi head attention 200 | """ 201 | super(CrossAttention, self).__init__() 202 | 203 | self.attention = nn.MultiheadAttention(d_model, n_head, dropout=dropout) 204 | 205 | self.norm1 = nn.LayerNorm(d_model) 206 | self.norm2 = nn.LayerNorm(d_model) 207 | 208 | self.dropout1 = nn.Dropout(dropout) 209 | self.dropout2 = nn.Dropout(dropout) 210 | self.dropout3 = nn.Dropout(dropout) 211 | self.activation = nn.ReLU() 212 | 213 | def forward(self, input1, input2): 214 | # dif_as_kv 215 | input1 = input1.permute(1, 0, 2) 216 | input2 = input2.permute(1, 0, 2) 217 | output_1 = self.cross1(input1, input2) # (Q,K,V) 218 | output_1 = output_1.permute(1, 0, 2) 219 | return output_1 220 | def cross1(self, input,input2): 221 | # RSICCformer_D (diff_as_kv) 222 | attn_output, attn_weight = self.attention(input, input2, input2) # (Q,K,V) 223 | output = input + self.dropout1(attn_output) 224 | output = self.activation(self.norm1(output)) 225 | return output 226 | 227 | 228 | class CFIM(nn.Module): 229 | """ 230 | Attention Network. 231 | """ 232 | 233 | def __init__(self, encoder_dim, embed_dim, attention_dim): 234 | """ 235 | :param encoder_dim: feature size of encoded images 236 | :param decoder_dim: size of decoder's RNN 237 | :param attention_dim: size of the attention network 238 | """ 239 | super(CFIM, self).__init__() 240 | self.nn1 = nn.Linear(encoder_dim, encoder_dim) # linear layer to transform encoded image 241 | self.nn2 = nn.Linear(1000, attention_dim) # linear layer to transform encoded image 242 | self.crossatt = CrossAttention(dropout=0.5) 243 | 244 | 245 | def forward(self, TextFeature, wordFeature, VisionFeature): 246 | """ 247 | Forward propagation. 248 | 249 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 250 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) 251 | :return: attention weighted encoding, weights 252 | """ 253 | b, n, channels = TextFeature.size(0), TextFeature.size(1), TextFeature.size(1) 254 | 255 | 256 | visions = torch.chunk(VisionFeature,chunks=2,dim=2) 257 | vision1 = visions[0] 258 | vision2 = visions[1] 259 | # vision1 TextFeature 260 | 261 | TextFeature = self.nn2(TextFeature.unsqueeze(1)) 262 | sim_mapv_T = vision1 * TextFeature 263 | sim_mapv_T = F.softmax(sim_mapv_T, dim=-2) 264 | vision1_T = vision1 * sim_mapv_T + vision1 265 | 266 | 267 | wordFeature = wordFeature.unsqueeze(1) 268 | vision2_w = self.crossatt(vision2, wordFeature)+ vision2 269 | vision = self.nn1(torch.cat([vision1_T,vision2_w],dim=2)) 270 | 271 | out = vision.mean(1).squeeze(1) 272 | return out 273 | 274 | class TextEncoder(nn.Module): 275 | def __init__(self, input_size, hidden_size, output_size): 276 | super(TextEncoder, self).__init__() 277 | self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) 278 | self.fc = nn.Linear(hidden_size, output_size) 279 | 280 | def forward(self, x): 281 | output, _ = self.lstm(x) 282 | output = self.fc(output[:, -1, :]) 283 | return output 284 | 285 | class DecoderWithAttention(nn.Module): 286 | """ 287 | Decoder. 288 | """ 289 | 290 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=1024, dropout=0.5): 291 | """ 292 | :param attention_dim: size of attention network 293 | :param embed_dim: embedding size 294 | :param decoder_dim: size of decoder's RNN 295 | :param vocab_size: size of vocabulary 296 | :param encoder_dim: feature size of encoded images 297 | :param dropout: dropout 298 | """ 299 | super(DecoderWithAttention, self).__init__() 300 | 301 | self.encoder_dim = encoder_dim 302 | self.attention_dim = attention_dim 303 | self.embed_dim = embed_dim 304 | self.decoder_dim = decoder_dim 305 | self.vocab_size = vocab_size 306 | self.dropout = dropout 307 | 308 | self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network 309 | self.attention2 = CFIM(encoder_dim, embed_dim, attention_dim) 310 | 311 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer 312 | self.dropout = nn.Dropout(p=self.dropout) 313 | 314 | #self.decode_step = nn.LSTMCell(attention_dim+attention_dim, decoder_dim, bias=True) # decoding LSTMCell 315 | self.top_down_attention = nn.LSTMCell(decoder_dim+encoder_dim+embed_dim, decoder_dim, bias=True) # decoding LSTMCell 316 | self.language_attention = nn.LSTMCell(encoder_dim+decoder_dim, decoder_dim, bias=True) # decoding LSTMCell 317 | 318 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell 319 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell 320 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate 321 | self.sigmoid = nn.Sigmoid() 322 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary 323 | self.init_weights() # initialize some layers with the uniform distribution 324 | self.textencoder = TextEncoder(input_size=embed_dim, hidden_size=decoder_dim, output_size=attention_dim) 325 | self.nnimg = nn.Linear(encoder_dim, attention_dim) 326 | 327 | 328 | def init_weights(self): 329 | """ 330 | Initializes some parameters with values from the uniform distribution, for easier convergence. 331 | """ 332 | self.embedding.weight.data.uniform_(-0.1, 0.1) 333 | self.fc.bias.data.fill_(0) 334 | self.fc.weight.data.uniform_(-0.1, 0.1) 335 | 336 | def load_pretrained_embeddings(self, embeddings): 337 | """ 338 | Loads embedding layer with pre-trained embeddings. 339 | 340 | :param embeddings: pre-trained embeddings 341 | """ 342 | self.embedding.weight = nn.Parameter(embeddings) 343 | 344 | def fine_tune_embeddings(self, fine_tune=True): 345 | """ 346 | Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings). 347 | 348 | :param fine_tune: Allow? 349 | """ 350 | for p in self.embedding.parameters(): 351 | p.requires_grad = fine_tune 352 | 353 | def init_hidden_state(self, encoder_out): 354 | """ 355 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 356 | 357 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 358 | :return: hidden state, cell state 359 | """ 360 | mean_encoder_out = encoder_out.mean(dim=1) 361 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim) 362 | c = self.init_c(mean_encoder_out) 363 | return h, c 364 | 365 | def forward(self, encoder_out, encoded_captions, caption_lengths): 366 | """ 367 | Forward propagation. 368 | 369 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) 370 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length) 371 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1) 372 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices 373 | """ 374 | 375 | batch_size = encoder_out.size(0) 376 | encoder_dim = encoder_out.size(-1) 377 | vocab_size = self.vocab_size 378 | 379 | # Flatten image 380 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 381 | num_pixels = encoder_out.size(1) 382 | 383 | # Sort input data by decreasing lengths; why? apparent below 384 | caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True) 385 | # 64 64 386 | encoder_out = encoder_out[sort_ind] 387 | 388 | #64 196 2048 389 | encoded_captions = encoded_captions[sort_ind] 390 | #64 52 391 | # Embedding 392 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim) 393 | embeddings1 = embeddings.clone() 394 | text_feature = self.textencoder(embeddings1) 395 | 396 | 397 | # Initialize LSTM state 398 | h1, c1 = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim) 399 | h2, c2 = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim) 400 | encoder_out_mean = encoder_out.mean(1) 401 | encoder_out_mean1 = encoder_out_mean.clone() 402 | img_feature = self.nnimg(encoder_out_mean1).squeeze(1) 403 | 404 | # We won't decode at the position, since we've finished generating as soon as we generate 405 | # So, decoding lengths are actual lengths - 1 406 | decode_lengths = (caption_lengths - 1).tolist() 407 | 408 | # Create tensors to hold word predicion scores and alphas 409 | predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device) 410 | alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device) 411 | 412 | # At each time-step, decode by 413 | # attention-weighing the encoder's output based on the decoder's previous hidden state output 414 | # then generate a new word in the decoder with the previous word and the attention weighted encoding 415 | for t in range(max(decode_lengths)): 416 | batch_size_t = sum([l > t for l in decode_lengths]) 417 | 418 | 419 | out_feature = self.attention2(h2[:batch_size_t], embeddings[:batch_size_t, t, :], encoder_out[:batch_size_t]) 420 | 421 | h1, c1 = self.top_down_attention( 422 | torch.cat([h2[:batch_size_t], out_feature, embeddings[:batch_size_t, t, :]], dim=1), 423 | (h1[:batch_size_t], c1[:batch_size_t])) # (batch_size_t, decoder_dim) 424 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 425 | h1[:batch_size_t]) 426 | h2, c2 = self.language_attention( 427 | torch.cat([h1[:batch_size_t], attention_weighted_encoding[:batch_size_t]], dim=1), 428 | (h2[:batch_size_t], c2[:batch_size_t])) # (batch_size_t, decoder_dim) 429 | 430 | preds = self.fc(self.dropout(h2)) # (batch_size_t, vocab_size) 431 | predictions[:batch_size_t, t, :] = preds 432 | 433 | alphas[:batch_size_t, t, :] = alpha 434 | 435 | return predictions, encoded_captions, decode_lengths, alphas, sort_ind, img_feature, text_feature 436 | -------------------------------------------------------------------------------- /model/HCNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torchvision 4 | import torch.nn.functional as F 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | class HFAM(nn.Module): 9 | 10 | def __init__(self): 11 | super(HFAM, self).__init__() 12 | self.branch1 = nn.Sequential( 13 | nn.Conv2d( 14 | 512, 512, kernel_size=15, stride=2, 15 | padding=7, groups=128, bias=False), 16 | nn.BatchNorm2d(512), 17 | nn.Conv2d( 18 | 512, 512, kernel_size=1, stride=1, 19 | padding=0, bias=False), 20 | ) 21 | self.branch2 = nn.Sequential( 22 | nn.Conv2d( 23 | 1024, 512, kernel_size=11, stride=1, 24 | padding=5, bias=False), 25 | nn.BatchNorm2d(512), 26 | ) 27 | self.branch3 = nn.Sequential( 28 | nn.Conv2d( 29 | 2048, 512, kernel_size=7, stride=1, 30 | padding=3, groups=128, bias=False), 31 | nn.Upsample(scale_factor=2), 32 | nn.BatchNorm2d(512), 33 | ) 34 | self.conv = nn.Sequential( 35 | nn.Conv2d( 36 | 1024, 1024, kernel_size=11, stride=1, 37 | padding=5, bias=False), 38 | nn.BatchNorm2d(1024), 39 | nn.ReLU(inplace=True), # not shown in paper 40 | ) 41 | 42 | self.conv1 = nn.Conv2d(2, 1, 3, padding=1, bias=False) 43 | self.sigmoid1 = nn.Sigmoid() 44 | self.sigmoid2 = nn.Sigmoid() 45 | 46 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 47 | self.max_pool = nn.AdaptiveMaxPool2d(1) 48 | 49 | self.fc1 = nn.Conv2d(512, 512 // 16, 1, bias=False) 50 | self.relu1 = nn.ReLU() 51 | self.fc2 = nn.Conv2d(512 // 16, 512, 1, bias=False) 52 | 53 | def forward(self, x1, x2 ,x3): 54 | x1_1 = self.branch1(x1) 55 | x2_1 = self.branch2(x2) 56 | x3_1 = self.branch3(x3) 57 | 58 | pixavg = torch.mean(x1_1, dim=1, keepdim=True) 59 | detail = self.sigmoid1(pixavg) * x2_1 60 | 61 | chaavg = self.fc2(self.relu1(self.fc1(self.avg_pool(x3_1)))) 62 | seman = self.sigmoid2(chaavg) * x2_1 63 | out = self.conv(torch.cat([detail,seman],dim=1)) 64 | return out 65 | 66 | class Encoder(nn.Module): 67 | """ 68 | CNN_Encoder. 69 | """ 70 | def __init__(self, NetType='resnet50', encoded_image_size=14, attention_method="ByPixel"): 71 | super(Encoder, self).__init__() 72 | self.enc_image_size = encoded_image_size 73 | self.attention_method = attention_method 74 | 75 | self.FF = HFAM() 76 | 77 | # resnet = torchvision.models.resnet101(pretrained=True) # pretrained ImageNet ResNet-101 78 | net = torchvision.models.inception_v3(pretrained=True, transform_input=False) if NetType == 'inception_v3' else \ 79 | torchvision.models.vgg16(pretrained=True) if NetType == 'vgg16' else \ 80 | torchvision.models.resnet50(pretrained=True) if NetType == 'resnet50' else torchvision.models.resnet50(pretrained=True) 81 | # Remove linear and pool layers (since we're not doing classification) 82 | # Specifically, Remove: AdaptiveAvgPool2d(output_size=(1, 1)), Linear(in_features=2048, out_features=1000, bias=True)] 83 | 84 | # modules = list(net.children())[:-2] 85 | modules = list(net.children())[:-1] if NetType == 'inception_v3' or NetType == 'vgg16' else list(net.children())[:-2] 86 | # modules = list(net.children())[:-1] if NetType == 'inception_v3' else list(net.children())[:-2] # -2 for resnet & vgg 87 | if NetType == 'inception_v3': del modules[13] 88 | 89 | self.net = nn.Sequential(*modules) 90 | 91 | # every block of resnet for fusion 92 | if NetType == 'resnet50' or NetType == 'resnet101' or NetType == 'resnet152': 93 | resnet_block1 = list(net.children())[:5] 94 | self.resnet_block1 = nn.Sequential(*resnet_block1) 95 | resnet_block2 = list(net.children())[5] 96 | self.resnet_block2 = nn.Sequential(*resnet_block2) 97 | resnet_block3 = list(net.children())[6] 98 | self.resnet_block3 = nn.Sequential(*resnet_block3) 99 | resnet_block4 = list(net.children())[7] 100 | self.resnet_block4 = nn.Sequential(*resnet_block4) 101 | self.conv4 = nn.Conv2d(in_channels=2048, out_channels=1024, kernel_size=1, stride=1) 102 | self.conv3 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1) 103 | self.conv2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1) 104 | 105 | # if self.attention_method == "ByChannel": 106 | # self.cnn1 = nn.Conv2d(in_channels=2048, out_channels=512, kernel_size=(1, 1), stride=(1, 1), bias=False) 107 | # self.bn1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 108 | # self.relu = nn.ReLU(inplace=True) 109 | 110 | # Resize image to fixed size to allow input images of variable size 111 | self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) 112 | # self.adaptive_pool4 = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) 113 | # self.adaptive_pool3 = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) 114 | 115 | self.fine_tune() 116 | 117 | def forward(self, images): 118 | """ 119 | Forward propagation. 120 | 121 | :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size) 122 | :return: encoded images [batch_size, encoded_image_size=14, encoded_image_size=14, 2048] 123 | """ 124 | # with fusion for resnet 125 | out1 = self.resnet_block1(images) # 256 126 | out2 = self.resnet_block2(out1) # 512 127 | out3 = self.resnet_block3(out2) # 1024 128 | out4 = self.resnet_block4(out3) # 2048 129 | 130 | # # FIXME:concat432 131 | out = self.FF(out2,out3,out4) 132 | 133 | 134 | # without fusion 135 | # out = self.net(images) # (batch_size, 2048, image_size/32, image_size/32) 136 | # if self.attention_method == "ByChannel": 137 | # out = self.relu(self.bn1(self.cnn1(out))) 138 | out = self.adaptive_pool(out) # [batch_size, 2048/512, 8, 8] -> [batch_size, 2048/512, 14, 14] #FIXME:for fusion 139 | out = out.permute(0, 2, 3, 1) 140 | return out 141 | 142 | def fine_tune(self, fine_tune=True): 143 | """ 144 | Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder. 145 | 146 | :param fine_tune: Allow? 147 | """ 148 | for p in self.net.parameters(): 149 | p.requires_grad = False 150 | # If fine-tuning, only fine-tune convolutional blocks 2 through 4 151 | for c in list(self.net.children())[5:]: # FIXME:maybe try 6: 152 | for p in c.parameters(): 153 | p.requires_grad = fine_tune 154 | 155 | 156 | class Attention(nn.Module): 157 | """ 158 | Attention Network. 159 | """ 160 | 161 | def __init__(self, encoder_dim, decoder_dim, attention_dim): 162 | """ 163 | :param encoder_dim: feature size of encoded images 164 | :param decoder_dim: size of decoder's RNN 165 | :param attention_dim: size of the attention network 166 | """ 167 | super(Attention, self).__init__() 168 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image 169 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output 170 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed 171 | self.relu = nn.ReLU() 172 | self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights 173 | 174 | def forward(self, encoder_out, decoder_hidden): 175 | """ 176 | Forward propagation. 177 | 178 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 179 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) 180 | :return: attention weighted encoding, weights 181 | """ 182 | att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim) 183 | att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim) 184 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels) 185 | alpha = self.softmax(att) # (batch_size, num_pixels) 186 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (batch_size, encoder_dim) 187 | #attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)) # (batch_size, pixels, encoder_dim) 188 | return attention_weighted_encoding, alpha 189 | 190 | class CrossAttention(nn.Module): 191 | """ 192 | Cross Transformer layer 193 | """ 194 | 195 | def __init__(self, dropout, d_model=512, n_head=8): 196 | """ 197 | :param dropout: dropout rate 198 | :param d_model: dimension of hidden state 199 | :param n_head: number of heads in multi head attention 200 | """ 201 | super(CrossAttention, self).__init__() 202 | 203 | self.attention = nn.MultiheadAttention(d_model, n_head, dropout=dropout) 204 | 205 | self.norm1 = nn.LayerNorm(d_model) 206 | self.norm2 = nn.LayerNorm(d_model) 207 | 208 | self.dropout1 = nn.Dropout(dropout) 209 | self.dropout2 = nn.Dropout(dropout) 210 | self.dropout3 = nn.Dropout(dropout) 211 | self.activation = nn.ReLU() 212 | 213 | def forward(self, input1, input2): 214 | # dif_as_kv 215 | input1 = input1.permute(1, 0, 2) 216 | input2 = input2.permute(1, 0, 2) 217 | output_1 = self.cross1(input1, input2) # (Q,K,V) 218 | output_1 = output_1.permute(1, 0, 2) 219 | return output_1 220 | def cross1(self, input,input2): 221 | # RSICCformer_D (diff_as_kv) 222 | attn_output, attn_weight = self.attention(input, input2, input2) # (Q,K,V) 223 | output = input + self.dropout1(attn_output) 224 | output = self.activation(self.norm1(output)) 225 | return output 226 | 227 | 228 | class CFIM(nn.Module): 229 | """ 230 | Attention Network. 231 | """ 232 | 233 | def __init__(self, encoder_dim, embed_dim, attention_dim): 234 | """ 235 | :param encoder_dim: feature size of encoded images 236 | :param decoder_dim: size of decoder's RNN 237 | :param attention_dim: size of the attention network 238 | """ 239 | super(CFIM, self).__init__() 240 | self.nn1 = nn.Linear(encoder_dim, encoder_dim) # linear layer to transform encoded image 241 | self.nn2 = nn.Linear(1000, attention_dim) # linear layer to transform encoded image 242 | self.crossatt = CrossAttention(dropout=0.5) 243 | 244 | 245 | def forward(self, TextFeature, wordFeature, VisionFeature): 246 | """ 247 | Forward propagation. 248 | 249 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 250 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) 251 | :return: attention weighted encoding, weights 252 | """ 253 | b, n, channels = TextFeature.size(0), TextFeature.size(1), TextFeature.size(1) 254 | 255 | 256 | visions = torch.chunk(VisionFeature,chunks=2,dim=2) 257 | vision1 = visions[0] 258 | vision2 = visions[1] 259 | # vision1 TextFeature 260 | 261 | TextFeature = self.nn2(TextFeature.unsqueeze(1)) 262 | sim_mapv_T = vision1 * TextFeature 263 | sim_mapv_T = F.softmax(sim_mapv_T, dim=-2) 264 | vision1_T = vision1 * sim_mapv_T + vision1 265 | 266 | 267 | wordFeature = wordFeature.unsqueeze(1) 268 | vision2_w = self.crossatt(vision2, wordFeature)+ vision2 269 | vision = self.nn1(torch.cat([vision1_T,vision2_w],dim=2)) 270 | 271 | out = vision.mean(1).squeeze(1) 272 | return out 273 | 274 | class TextEncoder(nn.Module): 275 | def __init__(self, input_size, hidden_size, output_size): 276 | super(TextEncoder, self).__init__() 277 | self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) 278 | self.fc = nn.Linear(hidden_size, output_size) 279 | 280 | def forward(self, x): 281 | output, _ = self.lstm(x) 282 | output = self.fc(output[:, -1, :]) 283 | return output 284 | 285 | class DecoderWithAttention(nn.Module): 286 | """ 287 | Decoder. 288 | """ 289 | 290 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=1024, dropout=0.5): 291 | """ 292 | :param attention_dim: size of attention network 293 | :param embed_dim: embedding size 294 | :param decoder_dim: size of decoder's RNN 295 | :param vocab_size: size of vocabulary 296 | :param encoder_dim: feature size of encoded images 297 | :param dropout: dropout 298 | """ 299 | super(DecoderWithAttention, self).__init__() 300 | 301 | self.encoder_dim = encoder_dim 302 | self.attention_dim = attention_dim 303 | self.embed_dim = embed_dim 304 | self.decoder_dim = decoder_dim 305 | self.vocab_size = vocab_size 306 | self.dropout = dropout 307 | 308 | self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network 309 | self.attention2 = CFIM(encoder_dim, embed_dim, attention_dim) 310 | 311 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer 312 | self.dropout = nn.Dropout(p=self.dropout) 313 | 314 | #self.decode_step = nn.LSTMCell(attention_dim+attention_dim, decoder_dim, bias=True) # decoding LSTMCell 315 | self.top_down_attention = nn.LSTMCell(decoder_dim+encoder_dim+embed_dim, decoder_dim, bias=True) # decoding LSTMCell 316 | self.language_attention = nn.LSTMCell(encoder_dim+decoder_dim, decoder_dim, bias=True) # decoding LSTMCell 317 | 318 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell 319 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell 320 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate 321 | self.sigmoid = nn.Sigmoid() 322 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary 323 | self.init_weights() # initialize some layers with the uniform distribution 324 | self.textencoder = TextEncoder(input_size=embed_dim, hidden_size=decoder_dim, output_size=attention_dim) 325 | self.nnimg = nn.Linear(encoder_dim, attention_dim) 326 | 327 | 328 | def init_weights(self): 329 | """ 330 | Initializes some parameters with values from the uniform distribution, for easier convergence. 331 | """ 332 | self.embedding.weight.data.uniform_(-0.1, 0.1) 333 | self.fc.bias.data.fill_(0) 334 | self.fc.weight.data.uniform_(-0.1, 0.1) 335 | 336 | def load_pretrained_embeddings(self, embeddings): 337 | """ 338 | Loads embedding layer with pre-trained embeddings. 339 | 340 | :param embeddings: pre-trained embeddings 341 | """ 342 | self.embedding.weight = nn.Parameter(embeddings) 343 | 344 | def fine_tune_embeddings(self, fine_tune=True): 345 | """ 346 | Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings). 347 | 348 | :param fine_tune: Allow? 349 | """ 350 | for p in self.embedding.parameters(): 351 | p.requires_grad = fine_tune 352 | 353 | def init_hidden_state(self, encoder_out): 354 | """ 355 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 356 | 357 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 358 | :return: hidden state, cell state 359 | """ 360 | mean_encoder_out = encoder_out.mean(dim=1) 361 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim) 362 | c = self.init_c(mean_encoder_out) 363 | return h, c 364 | 365 | def forward(self, encoder_out, encoded_captions, caption_lengths): 366 | """ 367 | Forward propagation. 368 | 369 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) 370 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length) 371 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1) 372 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices 373 | """ 374 | 375 | batch_size = encoder_out.size(0) 376 | encoder_dim = encoder_out.size(-1) 377 | vocab_size = self.vocab_size 378 | 379 | # Flatten image 380 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 381 | num_pixels = encoder_out.size(1) 382 | 383 | # Sort input data by decreasing lengths; why? apparent below 384 | caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True) 385 | # 64 64 386 | encoder_out = encoder_out[sort_ind] 387 | 388 | #64 196 2048 389 | encoded_captions = encoded_captions[sort_ind] 390 | #64 52 391 | # Embedding 392 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim) 393 | embeddings1 = embeddings.clone() 394 | text_feature = self.textencoder(embeddings1) 395 | 396 | 397 | # Initialize LSTM state 398 | h1, c1 = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim) 399 | h2, c2 = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim) 400 | encoder_out_mean = encoder_out.mean(1) 401 | encoder_out_mean1 = encoder_out_mean.clone() 402 | img_feature = self.nnimg(encoder_out_mean1).squeeze(1) 403 | 404 | # We won't decode at the position, since we've finished generating as soon as we generate 405 | # So, decoding lengths are actual lengths - 1 406 | decode_lengths = (caption_lengths - 1).tolist() 407 | 408 | # Create tensors to hold word predicion scores and alphas 409 | predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device) 410 | alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device) 411 | 412 | # At each time-step, decode by 413 | # attention-weighing the encoder's output based on the decoder's previous hidden state output 414 | # then generate a new word in the decoder with the previous word and the attention weighted encoding 415 | for t in range(max(decode_lengths)): 416 | batch_size_t = sum([l > t for l in decode_lengths]) 417 | 418 | 419 | out_feature = self.attention2(h2[:batch_size_t], embeddings[:batch_size_t, t, :], encoder_out[:batch_size_t]) 420 | 421 | h1, c1 = self.top_down_attention( 422 | torch.cat([h2[:batch_size_t], out_feature, embeddings[:batch_size_t, t, :]], dim=1), 423 | (h1[:batch_size_t], c1[:batch_size_t])) # (batch_size_t, decoder_dim) 424 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 425 | h1[:batch_size_t]) 426 | h2, c2 = self.language_attention( 427 | torch.cat([h1[:batch_size_t], attention_weighted_encoding[:batch_size_t]], dim=1), 428 | (h2[:batch_size_t], c2[:batch_size_t])) # (batch_size_t, decoder_dim) 429 | 430 | preds = self.fc(self.dropout(h2)) # (batch_size_t, vocab_size) 431 | predictions[:batch_size_t, t, :] = preds 432 | 433 | alphas[:batch_size_t, t, :] = alpha 434 | 435 | return predictions, encoded_captions, decode_lengths, alphas, sort_ind, img_feature, text_feature 436 | -------------------------------------------------------------------------------- /model/FC5lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torchvision 4 | import torch.nn.functional as F 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | class FeatureFusion(nn.Module): 9 | 10 | def __init__(self): 11 | super(FeatureFusion, self).__init__() 12 | self.branch1 = nn.Sequential( 13 | nn.Conv2d( 14 | 512, 512, kernel_size=15, stride=2, 15 | padding=7, groups=128, bias=False), 16 | nn.BatchNorm2d(512), 17 | nn.Conv2d( 18 | 512, 512, kernel_size=1, stride=1, 19 | padding=0, bias=False), 20 | ) 21 | self.branch2 = nn.Sequential( 22 | nn.Conv2d( 23 | 1024, 512, kernel_size=11, stride=1, 24 | padding=5, bias=False), 25 | nn.BatchNorm2d(512), 26 | ) 27 | self.branch3 = nn.Sequential( 28 | nn.Conv2d( 29 | 2048, 512, kernel_size=7, stride=1, 30 | padding=3, groups=128, bias=False), 31 | nn.Upsample(scale_factor=2), 32 | nn.BatchNorm2d(512), 33 | ) 34 | self.conv = nn.Sequential( 35 | nn.Conv2d( 36 | 1024, 1024, kernel_size=11, stride=1, 37 | padding=5, bias=False), 38 | nn.BatchNorm2d(1024), 39 | nn.ReLU(inplace=True), # not shown in paper 40 | ) 41 | 42 | self.conv1 = nn.Conv2d(2, 1, 3, padding=1, bias=False) 43 | self.sigmoid1 = nn.Sigmoid() 44 | self.sigmoid2 = nn.Sigmoid() 45 | 46 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 47 | self.max_pool = nn.AdaptiveMaxPool2d(1) 48 | 49 | self.fc1 = nn.Conv2d(512, 512 // 16, 1, bias=False) 50 | self.relu1 = nn.ReLU() 51 | self.fc2 = nn.Conv2d(512 // 16, 512, 1, bias=False) 52 | 53 | def forward(self, x1, x2 ,x3): 54 | x1_1 = self.branch1(x1) 55 | x2_1 = self.branch2(x2) 56 | x3_1 = self.branch3(x3) 57 | 58 | pixavg = torch.mean(x1_1, dim=1, keepdim=True) 59 | detail = self.sigmoid1(pixavg) * x2_1 60 | 61 | chaavg = self.fc2(self.relu1(self.fc1(self.avg_pool(x3_1)))) 62 | seman = self.sigmoid2(chaavg) * x2_1 63 | out = self.conv(torch.cat([detail,seman],dim=1)) 64 | return out 65 | 66 | class Encoder(nn.Module): 67 | """ 68 | CNN_Encoder. 69 | """ 70 | def __init__(self, NetType='resnet50', encoded_image_size=14, attention_method="ByPixel"): 71 | super(Encoder, self).__init__() 72 | self.enc_image_size = encoded_image_size 73 | self.attention_method = attention_method 74 | 75 | self.FF = FeatureFusion() 76 | 77 | # resnet = torchvision.models.resnet101(pretrained=True) # pretrained ImageNet ResNet-101 78 | net = torchvision.models.inception_v3(pretrained=True, transform_input=False) if NetType == 'inception_v3' else \ 79 | torchvision.models.vgg16(pretrained=True) if NetType == 'vgg16' else \ 80 | torchvision.models.resnet50(pretrained=True) if NetType == 'resnet50' else torchvision.models.resnet50(pretrained=True) 81 | # Remove linear and pool layers (since we're not doing classification) 82 | # Specifically, Remove: AdaptiveAvgPool2d(output_size=(1, 1)), Linear(in_features=2048, out_features=1000, bias=True)] 83 | 84 | # modules = list(net.children())[:-2] 85 | modules = list(net.children())[:-1] if NetType == 'inception_v3' or NetType == 'vgg16' else list(net.children())[:-2] 86 | # modules = list(net.children())[:-1] if NetType == 'inception_v3' else list(net.children())[:-2] # -2 for resnet & vgg 87 | if NetType == 'inception_v3': del modules[13] 88 | 89 | self.net = nn.Sequential(*modules) 90 | 91 | # every block of resnet for fusion 92 | if NetType == 'resnet50' or NetType == 'resnet101' or NetType == 'resnet152': 93 | resnet_block1 = list(net.children())[:5] 94 | self.resnet_block1 = nn.Sequential(*resnet_block1) 95 | resnet_block2 = list(net.children())[5] 96 | self.resnet_block2 = nn.Sequential(*resnet_block2) 97 | resnet_block3 = list(net.children())[6] 98 | self.resnet_block3 = nn.Sequential(*resnet_block3) 99 | resnet_block4 = list(net.children())[7] 100 | self.resnet_block4 = nn.Sequential(*resnet_block4) 101 | self.conv4 = nn.Conv2d(in_channels=2048, out_channels=1024, kernel_size=1, stride=1) 102 | self.conv3 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1) 103 | self.conv2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1) 104 | 105 | # if self.attention_method == "ByChannel": 106 | # self.cnn1 = nn.Conv2d(in_channels=2048, out_channels=512, kernel_size=(1, 1), stride=(1, 1), bias=False) 107 | # self.bn1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 108 | # self.relu = nn.ReLU(inplace=True) 109 | 110 | # Resize image to fixed size to allow input images of variable size 111 | self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) 112 | # self.adaptive_pool4 = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) 113 | # self.adaptive_pool3 = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size)) 114 | 115 | self.fine_tune() 116 | 117 | def forward(self, images): 118 | """ 119 | Forward propagation. 120 | 121 | :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size) 122 | :return: encoded images [batch_size, encoded_image_size=14, encoded_image_size=14, 2048] 123 | """ 124 | # with fusion for resnet 125 | out1 = self.resnet_block1(images) # 256 126 | out2 = self.resnet_block2(out1) # 512 127 | out3 = self.resnet_block3(out2) # 1024 128 | out4 = self.resnet_block4(out3) # 2048 129 | 130 | # # FIXME:concat432 131 | out = self.FF(out2,out3,out4) 132 | 133 | 134 | # without fusion 135 | # out = self.net(images) # (batch_size, 2048, image_size/32, image_size/32) 136 | # if self.attention_method == "ByChannel": 137 | # out = self.relu(self.bn1(self.cnn1(out))) 138 | out = self.adaptive_pool(out) # [batch_size, 2048/512, 8, 8] -> [batch_size, 2048/512, 14, 14] #FIXME:for fusion 139 | out = out.permute(0, 2, 3, 1) 140 | return out 141 | 142 | def fine_tune(self, fine_tune=True): 143 | """ 144 | Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder. 145 | 146 | :param fine_tune: Allow? 147 | """ 148 | for p in self.net.parameters(): 149 | p.requires_grad = False 150 | # If fine-tuning, only fine-tune convolutional blocks 2 through 4 151 | for c in list(self.net.children())[5:]: # FIXME:maybe try 6: 152 | for p in c.parameters(): 153 | p.requires_grad = fine_tune 154 | 155 | 156 | class Attention(nn.Module): 157 | """ 158 | Attention Network. 159 | """ 160 | 161 | def __init__(self, encoder_dim, decoder_dim, attention_dim): 162 | """ 163 | :param encoder_dim: feature size of encoded images 164 | :param decoder_dim: size of decoder's RNN 165 | :param attention_dim: size of the attention network 166 | """ 167 | super(Attention, self).__init__() 168 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image 169 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output 170 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed 171 | self.relu = nn.ReLU() 172 | self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights 173 | 174 | def forward(self, encoder_out, decoder_hidden): 175 | """ 176 | Forward propagation. 177 | 178 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 179 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) 180 | :return: attention weighted encoding, weights 181 | """ 182 | att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim) 183 | att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim) 184 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels) 185 | alpha = self.softmax(att) # (batch_size, num_pixels) 186 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (batch_size, encoder_dim) 187 | #attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)) # (batch_size, pixels, encoder_dim) 188 | return attention_weighted_encoding, alpha 189 | 190 | class CrossAttention(nn.Module): 191 | """ 192 | Cross Transformer layer 193 | """ 194 | 195 | def __init__(self, dropout, d_model=512, n_head=8): 196 | """ 197 | :param dropout: dropout rate 198 | :param d_model: dimension of hidden state 199 | :param n_head: number of heads in multi head attention 200 | """ 201 | super(CrossAttention, self).__init__() 202 | 203 | self.attention = nn.MultiheadAttention(d_model, n_head, dropout=dropout) 204 | 205 | self.norm1 = nn.LayerNorm(d_model) 206 | self.norm2 = nn.LayerNorm(d_model) 207 | 208 | self.dropout1 = nn.Dropout(dropout) 209 | self.dropout2 = nn.Dropout(dropout) 210 | self.dropout3 = nn.Dropout(dropout) 211 | self.activation = nn.ReLU() 212 | 213 | def forward(self, input1, input2): 214 | # dif_as_kv 215 | input1 = input1.permute(1, 0, 2) 216 | input2 = input2.permute(1, 0, 2) 217 | output_1 = self.cross1(input1, input2) # (Q,K,V) 218 | output_1 = output_1.permute(1, 0, 2) 219 | return output_1 220 | def cross1(self, input,input2): 221 | # RSICCformer_D (diff_as_kv) 222 | attn_output, attn_weight = self.attention(input, input2, input2) # (Q,K,V) 223 | output = input + self.dropout1(attn_output) 224 | output = self.activation(self.norm1(output)) 225 | return output 226 | 227 | 228 | class TVAttention(nn.Module): 229 | """ 230 | Attention Network. 231 | """ 232 | 233 | def __init__(self, encoder_dim, embed_dim, attention_dim): 234 | """ 235 | :param encoder_dim: feature size of encoded images 236 | :param decoder_dim: size of decoder's RNN 237 | :param attention_dim: size of the attention network 238 | """ 239 | super(TVAttention, self).__init__() 240 | self.nn1 = nn.Linear(encoder_dim, encoder_dim) # linear layer to transform encoded image 241 | self.nn2 = nn.Linear(1000, attention_dim) # linear layer to transform encoded image 242 | self.crossatt = CrossAttention(dropout=0.5) 243 | 244 | 245 | def forward(self, TextFeature, wordFeature, VisionFeature): 246 | """ 247 | Forward propagation. 248 | 249 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 250 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) 251 | :return: attention weighted encoding, weights 252 | """ 253 | b, n, channels = TextFeature.size(0), TextFeature.size(1), TextFeature.size(1) 254 | 255 | 256 | visions = torch.chunk(VisionFeature,chunks=2,dim=2) 257 | vision1 = visions[0] 258 | vision2 = visions[1] 259 | # vision1 TextFeature 260 | 261 | TextFeature = self.nn2(TextFeature.unsqueeze(1)) 262 | # sim_mapv_W = torch.matmul(vision2, wordFeature) # 64 196 263 | # sim_mapv_W = (channels ** -.5) * sim_mapv_W 264 | # sim_mapv_W = F.softmax(sim_mapv_W, dim=-1) 265 | sim_mapv_T = vision1 * TextFeature 266 | sim_mapv_T = F.softmax(sim_mapv_T, dim=-2) 267 | vision1_T = vision1 * sim_mapv_T + vision1 268 | 269 | 270 | # VisionFeature =VisionFeature.unsqueeze(1) 271 | wordFeature = wordFeature.unsqueeze(1) 272 | #sim_mapv_W = torch.matmul(vision2, wordFeature) # 64 196 273 | #sim_mapv_W = (channels ** -.5) * sim_mapv_W 274 | #sim_mapv_W = F.softmax(sim_mapv_W, dim=-1) 275 | """ 276 | sim_mapv_W = vision2 * wordFeature 277 | sim_mapv_W = F.softmax(sim_mapv_W, dim=-2) 278 | vision2_w = vision2 * sim_mapv_W + vision2 279 | """ 280 | vision2_w = self.crossatt(vision2, wordFeature)+ vision2 281 | 282 | vision = self.nn1(torch.cat([vision1_T,vision2_w],dim=2)) 283 | 284 | out = vision.mean(1).squeeze(1) 285 | return out 286 | 287 | class TextEncoder(nn.Module): 288 | def __init__(self, input_size, hidden_size, output_size): 289 | super(TextEncoder, self).__init__() 290 | self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) 291 | self.fc = nn.Linear(hidden_size, output_size) 292 | 293 | def forward(self, x): 294 | output, _ = self.lstm(x) 295 | output = self.fc(output[:, -1, :]) 296 | return output 297 | 298 | class DecoderWithAttention(nn.Module): 299 | """ 300 | Decoder. 301 | """ 302 | 303 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=1024, dropout=0.5): 304 | """ 305 | :param attention_dim: size of attention network 306 | :param embed_dim: embedding size 307 | :param decoder_dim: size of decoder's RNN 308 | :param vocab_size: size of vocabulary 309 | :param encoder_dim: feature size of encoded images 310 | :param dropout: dropout 311 | """ 312 | super(DecoderWithAttention, self).__init__() 313 | 314 | self.encoder_dim = encoder_dim 315 | self.attention_dim = attention_dim 316 | self.embed_dim = embed_dim 317 | self.decoder_dim = decoder_dim 318 | self.vocab_size = vocab_size 319 | self.dropout = dropout 320 | 321 | self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network 322 | self.attention2 = TVAttention(encoder_dim, embed_dim, attention_dim) 323 | 324 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer 325 | self.dropout = nn.Dropout(p=self.dropout) 326 | 327 | #self.decode_step = nn.LSTMCell(attention_dim+attention_dim, decoder_dim, bias=True) # decoding LSTMCell 328 | self.top_down_attention = nn.LSTMCell(decoder_dim+encoder_dim+embed_dim, decoder_dim, bias=True) # decoding LSTMCell 329 | self.language_attention = nn.LSTMCell(encoder_dim+decoder_dim, decoder_dim, bias=True) # decoding LSTMCell 330 | 331 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell 332 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell 333 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate 334 | self.sigmoid = nn.Sigmoid() 335 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary 336 | self.init_weights() # initialize some layers with the uniform distribution 337 | self.textencoder = TextEncoder(input_size=embed_dim, hidden_size=decoder_dim, output_size=attention_dim) 338 | self.nnimg = nn.Linear(encoder_dim, attention_dim) 339 | 340 | 341 | def init_weights(self): 342 | """ 343 | Initializes some parameters with values from the uniform distribution, for easier convergence. 344 | """ 345 | self.embedding.weight.data.uniform_(-0.1, 0.1) 346 | self.fc.bias.data.fill_(0) 347 | self.fc.weight.data.uniform_(-0.1, 0.1) 348 | 349 | def load_pretrained_embeddings(self, embeddings): 350 | """ 351 | Loads embedding layer with pre-trained embeddings. 352 | 353 | :param embeddings: pre-trained embeddings 354 | """ 355 | self.embedding.weight = nn.Parameter(embeddings) 356 | 357 | def fine_tune_embeddings(self, fine_tune=True): 358 | """ 359 | Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings). 360 | 361 | :param fine_tune: Allow? 362 | """ 363 | for p in self.embedding.parameters(): 364 | p.requires_grad = fine_tune 365 | 366 | def init_hidden_state(self, encoder_out): 367 | """ 368 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 369 | 370 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 371 | :return: hidden state, cell state 372 | """ 373 | mean_encoder_out = encoder_out.mean(dim=1) 374 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim) 375 | c = self.init_c(mean_encoder_out) 376 | return h, c 377 | 378 | def forward(self, encoder_out, encoded_captions, caption_lengths): 379 | """ 380 | Forward propagation. 381 | 382 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) 383 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length) 384 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1) 385 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices 386 | """ 387 | 388 | batch_size = encoder_out.size(0) 389 | encoder_dim = encoder_out.size(-1) 390 | vocab_size = self.vocab_size 391 | 392 | # Flatten image 393 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 394 | num_pixels = encoder_out.size(1) 395 | 396 | # Sort input data by decreasing lengths; why? apparent below 397 | caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True) 398 | # 64 64 399 | encoder_out = encoder_out[sort_ind] 400 | 401 | #64 196 2048 402 | encoded_captions = encoded_captions[sort_ind] 403 | #64 52 404 | # Embedding 405 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim) 406 | embeddings1 = embeddings.clone() 407 | text_feature = self.textencoder(embeddings1) 408 | 409 | 410 | # Initialize LSTM state 411 | h1, c1 = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim) 412 | h2, c2 = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim) 413 | encoder_out_mean = encoder_out.mean(1) 414 | encoder_out_mean1 = encoder_out_mean.clone() 415 | img_feature = self.nnimg(encoder_out_mean1).squeeze(1) 416 | 417 | # We won't decode at the position, since we've finished generating as soon as we generate 418 | # So, decoding lengths are actual lengths - 1 419 | decode_lengths = (caption_lengths - 1).tolist() 420 | 421 | # Create tensors to hold word predicion scores and alphas 422 | predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device) 423 | alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device) 424 | 425 | # At each time-step, decode by 426 | # attention-weighing the encoder's output based on the decoder's previous hidden state output 427 | # then generate a new word in the decoder with the previous word and the attention weighted encoding 428 | for t in range(max(decode_lengths)): 429 | batch_size_t = sum([l > t for l in decode_lengths]) 430 | ''' 431 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 432 | h[:batch_size_t]) 433 | gate = self.sigmoid(self.f_beta(h1[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim) 434 | attention_weighted_encoding = gate * attention_weighted_encoding 435 | ''' 436 | 437 | out_feature = self.attention2(h2[:batch_size_t], embeddings[:batch_size_t, t, :], encoder_out[:batch_size_t]) 438 | 439 | h1, c1 = self.top_down_attention( 440 | torch.cat([h2[:batch_size_t], out_feature, embeddings[:batch_size_t, t, :]], dim=1), 441 | (h1[:batch_size_t], c1[:batch_size_t])) # (batch_size_t, decoder_dim) 442 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 443 | h1[:batch_size_t]) 444 | h2, c2 = self.language_attention( 445 | torch.cat([h1[:batch_size_t], attention_weighted_encoding[:batch_size_t]], dim=1), 446 | (h2[:batch_size_t], c2[:batch_size_t])) # (batch_size_t, decoder_dim) 447 | 448 | preds = self.fc(self.dropout(h2)) # (batch_size_t, vocab_size) 449 | predictions[:batch_size_t, t, :] = preds 450 | 451 | alphas[:batch_size_t, t, :] = alpha 452 | 453 | return predictions, encoded_captions, decode_lengths, alphas, sort_ind, img_feature, text_feature 454 | -------------------------------------------------------------------------------- /train_HCNet_UCM.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch.nn.functional as F 3 | import time 4 | import torch.backends.cudnn as cudnn 5 | import torch.optim 6 | import torch.utils.data 7 | import torchvision.transforms as transforms 8 | from torch import nn 9 | from torch.nn.utils.rnn import pack_padded_sequence 10 | from models.FC5lstm import * 11 | from datasets import * 12 | from utils import * 13 | from nltk.translate.bleu_score import corpus_bleu 14 | import argparse 15 | import codecs 16 | import numpy as np 17 | from torch.optim.lr_scheduler import StepLR 18 | dataset = "UCM" 19 | model = "33LFC5LSTM" 20 | 21 | def train(args, train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch): 22 | """ 23 | Performs one epoch's training. 24 | 25 | :param train_loader: DataLoader for training data 26 | :param encoder: encoder model 27 | :param decoder: decoder model 28 | :param criterion: loss layer 29 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning) 30 | :param decoder_optimizer: optimizer to update decoder's weights 31 | :param epoch: epoch number 32 | """ 33 | 34 | encoder.train() 35 | decoder.train() # train mode (dropout and batchnorm is used) 36 | 37 | batch_time = AverageMeter() # forward prop. + back prop. time 38 | data_time = AverageMeter() # data loading time 39 | losses = AverageMeter() # loss (per word decoded) 40 | top5accs = AverageMeter() # top5 accuracy 41 | start = time.time() 42 | 43 | # Batches 44 | best_bleu4 = 0. # BLEU-4 score right now 45 | steps_since_improvement = 0 46 | final_args = {"emb_dim": args.emb_dim, 47 | "attention_dim": args.attention_dim, 48 | "decoder_dim": args.decoder_dim, 49 | "n_heads": args.n_heads, 50 | "dropout": args.dropout, 51 | "decoder_mode": args.decoder_mode, 52 | "attention_method": args.attention_method, 53 | "encoder_layers": args.encoder_layers, 54 | "decoder_layers": args.decoder_layers} 55 | for i, (imgs, caps, caplens) in enumerate(train_loader): 56 | data_time.update(time.time() - start) 57 | 58 | # Move to GPU, if available 59 | # print(caps) 60 | # print(caplens) 61 | imgs = imgs.to(device) 62 | caps = caps.to(device) 63 | caplens = caplens.to(device) 64 | 65 | # Forward prop. 66 | imgs = encoder(imgs) 67 | # imgs: [batch_size, 14, 14, 2048] 68 | # caps: [batch_size, 52] 69 | # caplens: [batch_size, 1] 70 | if args.decoder_mode == 'lstm_attention': 71 | scores, caps_sorted, decode_lengths, alphas, sort_ind, img_feature, text_feature = decoder(imgs, caps, caplens) 72 | else: 73 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens) 74 | 75 | # Since we decoded starting with , the targets are all words after , up to 76 | targets = caps_sorted[:, 1:] 77 | 78 | # Remove timesteps that we didn't decode at, or are pads 79 | # pack_padded_sequence is an easy trick to do this 80 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data 81 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data 82 | # print(scores.size()) 83 | # print(targets.size()) 84 | 85 | image_features = img_feature / img_feature.norm(dim=1, keepdim=True) 86 | text_features = text_feature / text_feature.norm(dim=1, keepdim=True) 87 | logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 88 | # 计算余弦相似度 89 | logit_scale = logit_scale.exp() 90 | logits_per_image = logit_scale * image_features @ text_features.t() 91 | logits_per_text = logits_per_image.t() 92 | labels = torch.arange(logits_per_image.shape[0], dtype=torch.long) 93 | #logits = logits.to(device) 94 | labels = labels.to(device) 95 | loss1 = (F.cross_entropy(logits_per_image, labels) +F.cross_entropy(logits_per_text, labels)) / 2 96 | 97 | # Calculate loss 98 | loss = criterion(scores, targets) 99 | 100 | total_loss = loss1+ 5*loss 101 | # Add doubly stochastic attention regularization 102 | # Second loss, mentioned in paper "Show, Attend and Tell: Neural Image Caption Generation with Visual Attention" 103 | # https://arxiv.org/abs/1502.03044 104 | # In section 4.2.1 Doubly stochastic attention regularization: We know the weights sum to 1 at a given timestep. 105 | # But we also encourage the weights at a single pixel p to sum to 1 across all timesteps T. 106 | # This means we want the model to attend to every pixel over the course of generating the entire sequence. 107 | # Therefore, we want to minimize the difference between 1 and the sum of a pixel's weights across all timesteps. 108 | 109 | 110 | # Back prop. 111 | decoder_optimizer.zero_grad() 112 | if encoder_optimizer is not None: 113 | encoder_optimizer.zero_grad() 114 | total_loss.backward() 115 | 116 | # Clip gradients 117 | if args.grad_clip is not None: 118 | clip_gradient(decoder_optimizer, args.grad_clip) 119 | if encoder_optimizer is not None: 120 | clip_gradient(encoder_optimizer, args.grad_clip) 121 | 122 | # Update weights 123 | decoder_optimizer.step() 124 | 125 | if encoder_optimizer is not None: 126 | encoder_optimizer.step() 127 | 128 | 129 | # Keep track of metrics 130 | top5 = accuracy(scores, targets, 5) 131 | losses.update(loss.item(), sum(decode_lengths)) 132 | top5accs.update(top5, sum(decode_lengths)) 133 | batch_time.update(time.time() - start) 134 | start = time.time() 135 | if i % args.print_freq == 0: 136 | # print('TIME: ', time.strftime("%m-%d %H : %M : %S", time.localtime(time.time()))) 137 | print("Epoch: {}/{} step: {}/{} Loss: {} AVG_Loss: {} Top-5 Accuracy: {} Batch_time: {}s".format(epoch+0, args.epochs, i+0, len(train_loader), losses.val, losses.avg, top5accs.val, batch_time.val)) 138 | 139 | 140 | def validate(args, val_loader, encoder, decoder, criterion): 141 | """ 142 | Performs one epoch's validation. 143 | 144 | :param val_loader: DataLoader for validation data. 145 | :param encoder: encoder model 146 | :param decoder: decoder model 147 | :param criterion: loss layer 148 | :return: score_dict {'Bleu_1': 0., 'Bleu_2': 0., 'Bleu_3': 0., 'Bleu_4': 0., 'METEOR': 0., 'ROUGE_L': 0., 'CIDEr': 1.} 149 | """ 150 | decoder.eval() # eval mode (no dropout or batchnorm) 151 | if encoder is not None: 152 | encoder.eval() 153 | 154 | batch_time = AverageMeter() 155 | losses = AverageMeter() 156 | top5accs = AverageMeter() 157 | 158 | start = time.time() 159 | 160 | references = list() # references (true captions) for calculating BLEU-4 score 161 | hypotheses = list() # hypotheses (predictions) 162 | 163 | # explicitly disable gradient calculation to avoid CUDA memory error 164 | with torch.no_grad(): 165 | # Batches 166 | for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader): 167 | 168 | # Move to device, if available 169 | imgs = imgs.to(device) 170 | caps = caps.to(device) 171 | caplens = caplens.to(device) 172 | 173 | # Forward prop. 174 | if encoder is not None: 175 | imgs = encoder(imgs) 176 | 177 | if args.decoder_mode == 'lstm_attention': 178 | scores, caps_sorted, decode_lengths, alphas, sort_ind, img_feature, text_feature = decoder(imgs, caps, caplens) 179 | else: 180 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens) 181 | 182 | # Since we decoded starting with , the targets are all words after , up to 183 | targets = caps_sorted[:, 1:] 184 | 185 | # Remove timesteps that we didn't decode at, or are pads 186 | # pack_padded_sequence is an easy trick to do this 187 | scores_copy = scores.clone() 188 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data 189 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data 190 | 191 | # Calculate loss 192 | loss = criterion(scores, targets) 193 | 194 | # Add doubly stochastic attention regularization 195 | 196 | # Keep track of metrics 197 | losses.update(loss.item(), sum(decode_lengths)) 198 | top5 = accuracy(scores, targets, 5) 199 | top5accs.update(top5, sum(decode_lengths)) 200 | batch_time.update(time.time() - start) 201 | start = time.time() 202 | 203 | 204 | # Store references (true captions), and hypothesis (prediction) for each image 205 | # If for n images, we have n hypotheses, and references a, b, c... for each image, we need - 206 | # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...] 207 | 208 | # References 209 | allcaps = allcaps[sort_ind] # because images were sorted in the decoder 210 | for j in range(allcaps.shape[0]): 211 | img_caps = allcaps[j].tolist() 212 | img_captions = list( 213 | map(lambda c: [w for w in c if w not in {word_map[''], word_map['']}], 214 | img_caps)) # remove and pads 215 | references.append(img_captions) 216 | 217 | # Hypotheses 218 | _, preds = torch.max(scores_copy, dim=2) 219 | preds = preds.tolist() 220 | temp_preds = list() 221 | for j, p in enumerate(preds): 222 | temp_preds.append(preds[j][:decode_lengths[j]]) # remove pads 223 | preds = temp_preds 224 | hypotheses.extend(preds) 225 | 226 | assert len(references) == len(hypotheses) 227 | 228 | # Calculate BLEU1~4, METEOR, ROUGE_L, CIDEr scores 229 | print('Validation:') 230 | metrics = get_eval_score(references, hypotheses) 231 | 232 | # print("EVA LOSS: {} TOP-5 Accuracy {} BLEU-1 {} BLEU2 {} BLEU3 {} BLEU-4 {} METEOR {} ROUGE_L {} CIDEr {}".format 233 | # (losses.avg, top5accs.avg, metrics["Bleu_1"], metrics["Bleu_2"], metrics["Bleu_3"], metrics["Bleu_4"], 234 | # metrics["METEOR"],metrics["ROUGE_L"], metrics["CIDEr"])) 235 | print('\n') 236 | 237 | return metrics 238 | 239 | 240 | if __name__ == '__main__': 241 | parser = argparse.ArgumentParser(description='Image_Captioning') 242 | 243 | # Data parameters 244 | parser.add_argument('--data_folder', default="./data/UCM_images1",help='folder with data files saved by create_input_files.py.') 245 | parser.add_argument('--data_name', default="UCM_5_cap_per_img_4_min_word_freq",help='base name shared by data files.') 246 | 247 | 248 | # Model parameters 249 | parser.add_argument('--emb_dim', type=int, default=512, help='dimension of word embeddings.')#300 250 | parser.add_argument('--attention_dim', type=int, default=512, help='dimension of attention linear layers.') 251 | parser.add_argument('--decoder_dim', type=int, default=1000, help='dimension of decoder RNN.') 252 | parser.add_argument('--n_heads', type=int, default=8, help='Multi-head attention in Transformer.') 253 | parser.add_argument('--dropout', type=float, default=0.5, help='dropout') 254 | 255 | # FIXME:note to change these 256 | parser.add_argument('--encoder_mode', default="resnet50", help='which model does encoder use?') # inception_v3 or vgg16 or vgg19 or resnet50 or resnet101 or resnet152 257 | parser.add_argument('--decoder_mode', default="lstm_attention", help='which model does decoder use?') # lstm or lstm_attention or transformer or transformer_decoder 258 | 259 | parser.add_argument('--attention_method', default="ByPixel", help='which attention method to use?') # ByPixel or ByChannel 260 | parser.add_argument('--encoder_layers', type=int, default=3, help='the number of layers of encoder in Transformer.') 261 | parser.add_argument('--decoder_layers', type=int, default=3, help='the number of layers of decoder in Transformer.') 262 | 263 | 264 | # Training parameters 265 | parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train for (if early stopping is not triggered).') 266 | parser.add_argument('--stop_criteria', type=int, default=20, help='training stop if epochs_since_improvement == stop_criteria') 267 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 268 | parser.add_argument('--print_freq', type=int, default=100, help='print training/validation stats every __ batches.') 269 | parser.add_argument('--workers', type=int, default=16, help='for data-loading; right now, only 0 works with h5pys in windows.') 270 | parser.add_argument('--encoder_lr', type=float, default=1e-4, help='learning rate for encoder if fine-tuning.') 271 | parser.add_argument('--decoder_lr', type=float, default=1e-4, help='learning rate for decoder.') 272 | parser.add_argument('--grad_clip', type=float, default=5., help='clip gradients at an absolute value of.') 273 | parser.add_argument('--alpha_c', type=float, default=1., help='regularization parameter for doubly stochastic attention, as in the paper.') 274 | parser.add_argument('--fine_tune_encoder', type=bool, default= True, help='whether fine-tune encoder or not') 275 | parser.add_argument('--fine_tune_embedding', type=bool, default= True, help='whether fine-tune word embeddings or not') 276 | parser.add_argument('--checkpoint', default=None, help='path to checkpoint, None if none.') 277 | parser.add_argument('--embedding_path', default=None, help='path to pre-trained word Embedding.') 278 | 279 | args = parser.parse_args() 280 | 281 | for encoder_layers, decoder_layers in [(3,3)]: #,,(0,6),(2,2), 282 | args.encoder_layers = encoder_layers 283 | args.decoder_layers = decoder_layers 284 | # args.encoder_mode = encoder_mode 285 | 286 | # load checkpoint, these parameters can't be modified 287 | final_args = {"emb_dim": args.emb_dim, 288 | "attention_dim": args.attention_dim, 289 | "decoder_dim": args.decoder_dim, 290 | "n_heads": args.n_heads, 291 | "dropout": args.dropout, 292 | "decoder_mode": args.decoder_mode, 293 | "attention_method": args.attention_method, 294 | "encoder_layers": args.encoder_layers, 295 | "decoder_layers": args.decoder_layers} 296 | 297 | start_epoch = 0 298 | best_bleu4 = 0. # BLEU-4 score right now 299 | epochs_since_improvement = 0 # keeps track of number of epochs since there's been an improvement in validation BLEU 300 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # sets device for model and PyTorch tensors 301 | cudnn.benchmark = True # set to true only if inputs to model are fixed size; otherwise lot of computational overhead 302 | # print(device) 303 | 304 | # Read word map 305 | word_map_file = os.path.join(args.data_folder, 'WORDMAP_' + args.data_name + '.json') 306 | with open(word_map_file, 'r') as j: 307 | word_map = json.load(j) 308 | 309 | # Initialize / load checkpoint 310 | if args.checkpoint is None: 311 | 312 | # Encoder 313 | encoder = Encoder(NetType=args.encoder_mode) 314 | 315 | encoder.fine_tune(args.fine_tune_encoder) 316 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()), 317 | lr=args.encoder_lr) if args.fine_tune_encoder else None 318 | encoder_lr_scheduler = StepLR(encoder_optimizer,step_size=600,gamma=0.9) 319 | # set the encoder_dim 320 | encoder_dim = 512 if args.encoder_mode == 'vgg16' else 512 if args.encoder_mode == 'vgg19' \ 321 | else 2048 # FIXME: encoder_dim depends on the model 322 | 323 | # different Decoder 324 | if args.decoder_mode == "lstm_attention": 325 | decoder = DecoderWithAttention(attention_dim=args.attention_dim, 326 | embed_dim= args.emb_dim, 327 | decoder_dim= args.decoder_dim, 328 | vocab_size=len(word_map), 329 | dropout=args.dropout) 330 | 331 | decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()), 332 | lr=args.decoder_lr) 333 | decoder_lr_scheduler = StepLR(decoder_optimizer,step_size=600,gamma=0.9) 334 | 335 | # load pre-trained word embedding 336 | if args.embedding_path is not None: 337 | all_word_embeds = {} 338 | for i, line in enumerate(codecs.open(args.embedding_path, 'r', 'utf-8')): 339 | s = line.strip().split() 340 | all_word_embeds[s[0]] = np.array([float(i) for i in s[1:]]) 341 | 342 | # change emb_dim 343 | args.emb_dim = list(all_word_embeds.values())[-1].size 344 | word_embeds = np.random.uniform(-np.sqrt(0.06), np.sqrt(0.06), (len(word_map), args.emb_dim)) 345 | for w in word_map: 346 | if w in all_word_embeds: 347 | word_embeds[word_map[w]] = all_word_embeds[w] 348 | elif w.lower() in all_word_embeds: 349 | word_embeds[word_map[w]] = all_word_embeds[w.lower()] 350 | else: 351 | # 352 | embedding_i = torch.ones(1, args.emb_dim) 353 | torch.nn.init.xavier_uniform_(embedding_i) 354 | word_embeds[word_map[w]] = embedding_i 355 | 356 | word_embeds = torch.FloatTensor(word_embeds).to(device) 357 | decoder.load_pretrained_embeddings(word_embeds) 358 | decoder.fine_tune_embeddings(args.fine_tune_embedding) 359 | print('Loaded {} pre-trained word embeddings.'.format(len(word_embeds))) 360 | 361 | else: 362 | print("isNone") 363 | print(args.checkpoint) 364 | checkpoint = torch.load(args.checkpoint, map_location=str(device)) 365 | start_epoch = checkpoint['epoch'] + 1 366 | epochs_since_improvement = checkpoint['epochs_since_improvement'] 367 | best_bleu1 = checkpoint['metrics']["Bleu_1"] 368 | encoder = checkpoint['encoder'] 369 | encoder_optimizer = checkpoint['encoder_optimizer'] 370 | decoder = checkpoint['decoder'] 371 | decoder_optimizer = checkpoint['decoder_optimizer'] 372 | decoder.fine_tune_embeddings(args.fine_tune_embedding) 373 | # load final_args from checkpoint 374 | final_args = checkpoint['final_args'] 375 | for key in final_args.keys(): 376 | args.__setattr__(key, final_args[key]) 377 | if args.fine_tune_encoder is True and encoder_optimizer is None: 378 | print("Encoder_Optimizer is None, Creating new Encoder_Optimizer!") 379 | encoder.fine_tune(args.fine_tune_encoder) 380 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()), 381 | lr=args.encoder_lr) 382 | 383 | # Move to GPU, if available 384 | decoder = decoder.to(device) 385 | encoder = encoder.to(device) 386 | print("Encoder_mode:{} Decoder_mode:{}".format(args.encoder_mode,args.decoder_mode)) 387 | print("encoder_layers {} decoder_layers {} n_heads {} dropout {} attention_method {} encoder_lr {} " 388 | "decoder_lr {} alpha_c {}".format(args.encoder_layers, args.decoder_layers, args.n_heads, args.dropout, 389 | args.attention_method, args.encoder_lr, args.decoder_lr, args.alpha_c)) 390 | # print(encoder) 391 | # print(decoder) 392 | 393 | # Loss function 394 | criterion = nn.CrossEntropyLoss().to(device) 395 | 396 | # Custom dataloaders 397 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 398 | # normalize = transforms.Normalize(mean=[0.399, 0.410, 0.371], std=[0.151, 0.138, 0.134]) 399 | # normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 400 | 401 | # pin_memory: If True, the data loader will copy Tensors into CUDA pinned memory before returning them. 402 | # If your data elements are a custom type, or your collate_fn returns a batch that is a custom type. 403 | train_loader = torch.utils.data.DataLoader( 404 | CaptionDataset(args.data_folder, args.data_name, 'TRAIN', transform=transforms.Compose([transforms.RandomHorizontalFlip(),normalize])), 405 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 406 | val_loader = torch.utils.data.DataLoader( 407 | CaptionDataset(args.data_folder, args.data_name, 'VAL', transform=transforms.Compose([transforms.RandomHorizontalFlip(),normalize])), 408 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 409 | 410 | # Epochs 411 | for epoch in range(start_epoch, args.epochs): 412 | 413 | # Decay learning rate if there is no improvement for 5 consecutive epochs, and terminate training after 25 414 | # 8 20 415 | if epochs_since_improvement == args.stop_criteria: 416 | print("the model has not improved in the last {} epochs".format(args.stop_criteria)) 417 | break 418 | if epochs_since_improvement > 0 and epochs_since_improvement % 5 == 0: 419 | adjust_learning_rate(decoder_optimizer, 0.8) 420 | if args.fine_tune_encoder and encoder_optimizer is not None: 421 | print(encoder_optimizer) 422 | adjust_learning_rate(encoder_optimizer, 0.8) 423 | 424 | # One epoch's training 425 | train(args, 426 | train_loader=train_loader, 427 | # val_loader=val_loader, 428 | encoder=encoder, 429 | decoder=decoder, 430 | criterion=criterion, 431 | encoder_optimizer=encoder_optimizer, 432 | #encoder_lr_scheduler=encoder_lr_scheduler, 433 | decoder_optimizer=decoder_optimizer, 434 | #decoder_lr_scheduler=decoder_lr_scheduler, 435 | epoch=epoch) 436 | 437 | 438 | # One epoch's validation 439 | metrics = validate(args, 440 | val_loader=val_loader, 441 | encoder=encoder, 442 | decoder=decoder, 443 | criterion=criterion) 444 | 445 | recent_bleu4 = metrics["Bleu_4"] 446 | 447 | # Check if there was an improvement 448 | is_best = recent_bleu4 > best_bleu4 449 | best_bleu4 = max(recent_bleu4, best_bleu4) 450 | if not is_best: 451 | epochs_since_improvement += 1 452 | print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,)) 453 | else: 454 | epochs_since_improvement = 0 455 | 456 | # Save checkpoint 457 | checkpoint_name = model+"_"+dataset #_tengxun_aggregation 458 | save_checkpoint(checkpoint_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, 459 | decoder_optimizer, metrics, is_best, final_args) 460 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch.nn.functional as F 3 | import time 4 | import torch.backends.cudnn as cudnn 5 | import torch.optim 6 | import torch.utils.data 7 | import torchvision.transforms as transforms 8 | from torch import nn 9 | from torch.nn.utils.rnn import pack_padded_sequence 10 | from models.TVFC5lstm import * 11 | from datasets import * 12 | from utils import * 13 | from nltk.translate.bleu_score import corpus_bleu 14 | import argparse 15 | import codecs 16 | import numpy as np 17 | from torch.optim.lr_scheduler import StepLR 18 | dataset = "NWPU" 19 | model = "33LTVFC5LSTM" 20 | 21 | def train(args, train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch): 22 | """ 23 | Performs one epoch's training. 24 | 25 | :param train_loader: DataLoader for training data 26 | :param encoder: encoder model 27 | :param decoder: decoder model 28 | :param criterion: loss layer 29 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning) 30 | :param decoder_optimizer: optimizer to update decoder's weights 31 | :param epoch: epoch number 32 | """ 33 | 34 | encoder.train() 35 | decoder.train() # train mode (dropout and batchnorm is used) 36 | 37 | batch_time = AverageMeter() # forward prop. + back prop. time 38 | data_time = AverageMeter() # data loading time 39 | losses = AverageMeter() # loss (per word decoded) 40 | top5accs = AverageMeter() # top5 accuracy 41 | start = time.time() 42 | 43 | # Batches 44 | best_bleu4 = 0. # BLEU-4 score right now 45 | steps_since_improvement = 0 46 | final_args = {"emb_dim": args.emb_dim, 47 | "attention_dim": args.attention_dim, 48 | "decoder_dim": args.decoder_dim, 49 | "n_heads": args.n_heads, 50 | "dropout": args.dropout, 51 | "decoder_mode": args.decoder_mode, 52 | "attention_method": args.attention_method, 53 | "encoder_layers": args.encoder_layers, 54 | "decoder_layers": args.decoder_layers} 55 | for i, (imgs, caps, caplens) in enumerate(train_loader): 56 | data_time.update(time.time() - start) 57 | 58 | # Move to GPU, if available 59 | # print(caps) 60 | # print(caplens) 61 | imgs = imgs.to(device) 62 | caps = caps.to(device) 63 | caplens = caplens.to(device) 64 | 65 | # Forward prop. 66 | imgs = encoder(imgs) 67 | # imgs: [batch_size, 14, 14, 2048] 68 | # caps: [batch_size, 52] 69 | # caplens: [batch_size, 1] 70 | if args.decoder_mode == 'lstm_attention': 71 | scores, caps_sorted, decode_lengths, alphas, sort_ind, img_feature, text_feature = decoder(imgs, caps, caplens) 72 | else: 73 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens) 74 | 75 | # Since we decoded starting with , the targets are all words after , up to 76 | targets = caps_sorted[:, 1:] 77 | 78 | # Remove timesteps that we didn't decode at, or are pads 79 | # pack_padded_sequence is an easy trick to do this 80 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data 81 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data 82 | # print(scores.size()) 83 | # print(targets.size()) 84 | 85 | image_features = img_feature / img_feature.norm(dim=1, keepdim=True) 86 | text_features = text_feature / text_feature.norm(dim=1, keepdim=True) 87 | logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 88 | # 计算余弦相似度 89 | logit_scale = logit_scale.exp() 90 | logits_per_image = logit_scale * image_features @ text_features.t() 91 | logits_per_text = logits_per_image.t() 92 | labels = torch.arange(logits_per_image.shape[0], dtype=torch.long) 93 | #logits = logits.to(device) 94 | labels = labels.to(device) 95 | loss1 = (F.cross_entropy(logits_per_image, labels) +F.cross_entropy(logits_per_text, labels)) / 2 96 | 97 | # Calculate loss 98 | loss = criterion(scores, targets) 99 | 100 | total_loss = loss1+ 3*loss 101 | # Add doubly stochastic attention regularization 102 | # Second loss, mentioned in paper "Show, Attend and Tell: Neural Image Caption Generation with Visual Attention" 103 | # https://arxiv.org/abs/1502.03044 104 | # In section 4.2.1 Doubly stochastic attention regularization: We know the weights sum to 1 at a given timestep. 105 | # But we also encourage the weights at a single pixel p to sum to 1 across all timesteps T. 106 | # This means we want the model to attend to every pixel over the course of generating the entire sequence. 107 | # Therefore, we want to minimize the difference between 1 and the sum of a pixel's weights across all timesteps. 108 | 109 | 110 | # Back prop. 111 | decoder_optimizer.zero_grad() 112 | if encoder_optimizer is not None: 113 | encoder_optimizer.zero_grad() 114 | total_loss.backward() 115 | 116 | # Clip gradients 117 | if args.grad_clip is not None: 118 | clip_gradient(decoder_optimizer, args.grad_clip) 119 | if encoder_optimizer is not None: 120 | clip_gradient(encoder_optimizer, args.grad_clip) 121 | 122 | # Update weights 123 | decoder_optimizer.step() 124 | 125 | if encoder_optimizer is not None: 126 | encoder_optimizer.step() 127 | 128 | 129 | # Keep track of metrics 130 | top5 = accuracy(scores, targets, 5) 131 | losses.update(loss.item(), sum(decode_lengths)) 132 | top5accs.update(top5, sum(decode_lengths)) 133 | batch_time.update(time.time() - start) 134 | start = time.time() 135 | if i % args.print_freq == 0: 136 | # print('TIME: ', time.strftime("%m-%d %H : %M : %S", time.localtime(time.time()))) 137 | print("Epoch: {}/{} step: {}/{} Loss: {} AVG_Loss: {} Top-5 Accuracy: {} Batch_time: {}s".format(epoch+0, args.epochs, i+0, len(train_loader), losses.val, losses.avg, top5accs.val, batch_time.val)) 138 | 139 | 140 | def validate(args, val_loader, encoder, decoder, criterion): 141 | """ 142 | Performs one epoch's validation. 143 | 144 | :param val_loader: DataLoader for validation data. 145 | :param encoder: encoder model 146 | :param decoder: decoder model 147 | :param criterion: loss layer 148 | :return: score_dict {'Bleu_1': 0., 'Bleu_2': 0., 'Bleu_3': 0., 'Bleu_4': 0., 'METEOR': 0., 'ROUGE_L': 0., 'CIDEr': 1.} 149 | """ 150 | decoder.eval() # eval mode (no dropout or batchnorm) 151 | if encoder is not None: 152 | encoder.eval() 153 | 154 | batch_time = AverageMeter() 155 | losses = AverageMeter() 156 | top5accs = AverageMeter() 157 | 158 | start = time.time() 159 | 160 | references = list() # references (true captions) for calculating BLEU-4 score 161 | hypotheses = list() # hypotheses (predictions) 162 | 163 | # explicitly disable gradient calculation to avoid CUDA memory error 164 | with torch.no_grad(): 165 | # Batches 166 | for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader): 167 | 168 | # Move to device, if available 169 | imgs = imgs.to(device) 170 | caps = caps.to(device) 171 | caplens = caplens.to(device) 172 | 173 | # Forward prop. 174 | if encoder is not None: 175 | imgs = encoder(imgs) 176 | 177 | if args.decoder_mode == 'lstm_attention': 178 | scores, caps_sorted, decode_lengths, alphas, sort_ind, img_feature, text_feature = decoder(imgs, caps, caplens) 179 | else: 180 | scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens) 181 | 182 | # Since we decoded starting with , the targets are all words after , up to 183 | targets = caps_sorted[:, 1:] 184 | 185 | # Remove timesteps that we didn't decode at, or are pads 186 | # pack_padded_sequence is an easy trick to do this 187 | scores_copy = scores.clone() 188 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data 189 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data 190 | 191 | # Calculate loss 192 | loss = criterion(scores, targets) 193 | 194 | # Add doubly stochastic attention regularization 195 | 196 | # Keep track of metrics 197 | losses.update(loss.item(), sum(decode_lengths)) 198 | top5 = accuracy(scores, targets, 5) 199 | top5accs.update(top5, sum(decode_lengths)) 200 | batch_time.update(time.time() - start) 201 | start = time.time() 202 | 203 | 204 | # Store references (true captions), and hypothesis (prediction) for each image 205 | # If for n images, we have n hypotheses, and references a, b, c... for each image, we need - 206 | # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...] 207 | 208 | # References 209 | allcaps = allcaps[sort_ind] # because images were sorted in the decoder 210 | for j in range(allcaps.shape[0]): 211 | img_caps = allcaps[j].tolist() 212 | img_captions = list( 213 | map(lambda c: [w for w in c if w not in {word_map[''], word_map['']}], 214 | img_caps)) # remove and pads 215 | references.append(img_captions) 216 | 217 | # Hypotheses 218 | _, preds = torch.max(scores_copy, dim=2) 219 | preds = preds.tolist() 220 | temp_preds = list() 221 | for j, p in enumerate(preds): 222 | temp_preds.append(preds[j][:decode_lengths[j]]) # remove pads 223 | preds = temp_preds 224 | hypotheses.extend(preds) 225 | 226 | assert len(references) == len(hypotheses) 227 | 228 | # Calculate BLEU1~4, METEOR, ROUGE_L, CIDEr scores 229 | print('Validation:') 230 | metrics = get_eval_score(references, hypotheses) 231 | 232 | # print("EVA LOSS: {} TOP-5 Accuracy {} BLEU-1 {} BLEU2 {} BLEU3 {} BLEU-4 {} METEOR {} ROUGE_L {} CIDEr {}".format 233 | # (losses.avg, top5accs.avg, metrics["Bleu_1"], metrics["Bleu_2"], metrics["Bleu_3"], metrics["Bleu_4"], 234 | # metrics["METEOR"],metrics["ROUGE_L"], metrics["CIDEr"])) 235 | print('\n') 236 | 237 | return metrics 238 | 239 | 240 | if __name__ == '__main__': 241 | parser = argparse.ArgumentParser(description='Image_Captioning') 242 | 243 | # Data parameters 244 | parser.add_argument('--data_folder', default="./data/NWPU_images1",help='folder with data files saved by create_input_files.py.') 245 | parser.add_argument('--data_name', default="NWPU_5_cap_per_img_4_min_word_freq",help='base name shared by data files.') 246 | 247 | 248 | # Model parameters 249 | parser.add_argument('--emb_dim', type=int, default=512, help='dimension of word embeddings.')#300 250 | parser.add_argument('--attention_dim', type=int, default=512, help='dimension of attention linear layers.') 251 | parser.add_argument('--decoder_dim', type=int, default=1000, help='dimension of decoder RNN.') 252 | parser.add_argument('--n_heads', type=int, default=8, help='Multi-head attention in Transformer.') 253 | parser.add_argument('--dropout', type=float, default=0.5, help='dropout') 254 | 255 | # FIXME:note to change these 256 | parser.add_argument('--encoder_mode', default="resnet50", help='which model does encoder use?') # inception_v3 or vgg16 or vgg19 or resnet50 or resnet101 or resnet152 257 | parser.add_argument('--decoder_mode', default="lstm_attention", help='which model does decoder use?') # lstm or lstm_attention or transformer or transformer_decoder 258 | 259 | parser.add_argument('--attention_method', default="ByPixel", help='which attention method to use?') # ByPixel or ByChannel 260 | parser.add_argument('--encoder_layers', type=int, default=3, help='the number of layers of encoder in Transformer.') 261 | parser.add_argument('--decoder_layers', type=int, default=3, help='the number of layers of decoder in Transformer.') 262 | 263 | 264 | # Training parameters 265 | parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train for (if early stopping is not triggered).') 266 | parser.add_argument('--stop_criteria', type=int, default=20, help='training stop if epochs_since_improvement == stop_criteria') 267 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 268 | parser.add_argument('--print_freq', type=int, default=100, help='print training/validation stats every __ batches.') 269 | parser.add_argument('--workers', type=int, default=16, help='for data-loading; right now, only 0 works with h5pys in windows.') 270 | parser.add_argument('--encoder_lr', type=float, default=5e-4, help='learning rate for encoder if fine-tuning.') 271 | parser.add_argument('--decoder_lr', type=float, default=5e-4, help='learning rate for decoder.') 272 | parser.add_argument('--grad_clip', type=float, default=5., help='clip gradients at an absolute value of.') 273 | parser.add_argument('--alpha_c', type=float, default=1., help='regularization parameter for doubly stochastic attention, as in the paper.') 274 | parser.add_argument('--fine_tune_encoder', type=bool, default= True, help='whether fine-tune encoder or not') 275 | parser.add_argument('--fine_tune_embedding', type=bool, default= True, help='whether fine-tune word embeddings or not') 276 | parser.add_argument('--checkpoint', default=None, help='path to checkpoint, None if none.') 277 | parser.add_argument('--embedding_path', default=None, help='path to pre-trained word Embedding.') 278 | 279 | args = parser.parse_args() 280 | 281 | for encoder_layers, decoder_layers in [(3,3)]: #,,(0,6),(2,2), 282 | args.encoder_layers = encoder_layers 283 | args.decoder_layers = decoder_layers 284 | # args.encoder_mode = encoder_mode 285 | 286 | # load checkpoint, these parameters can't be modified 287 | final_args = {"emb_dim": args.emb_dim, 288 | "attention_dim": args.attention_dim, 289 | "decoder_dim": args.decoder_dim, 290 | "n_heads": args.n_heads, 291 | "dropout": args.dropout, 292 | "decoder_mode": args.decoder_mode, 293 | "attention_method": args.attention_method, 294 | "encoder_layers": args.encoder_layers, 295 | "decoder_layers": args.decoder_layers} 296 | 297 | start_epoch = 0 298 | best_bleu4 = 0. # BLEU-4 score right now 299 | epochs_since_improvement = 0 # keeps track of number of epochs since there's been an improvement in validation BLEU 300 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # sets device for model and PyTorch tensors 301 | cudnn.benchmark = True # set to true only if inputs to model are fixed size; otherwise lot of computational overhead 302 | # print(device) 303 | 304 | # Read word map 305 | word_map_file = os.path.join(args.data_folder, 'WORDMAP_' + args.data_name + '.json') 306 | with open(word_map_file, 'r') as j: 307 | word_map = json.load(j) 308 | 309 | # Initialize / load checkpoint 310 | if args.checkpoint is None: 311 | 312 | # Encoder 313 | encoder = Encoder(NetType=args.encoder_mode) 314 | 315 | encoder.fine_tune(args.fine_tune_encoder) 316 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()), 317 | lr=args.encoder_lr) if args.fine_tune_encoder else None 318 | encoder_lr_scheduler = StepLR(encoder_optimizer,step_size=600,gamma=0.9) 319 | # set the encoder_dim 320 | encoder_dim = 512 if args.encoder_mode == 'vgg16' else 512 if args.encoder_mode == 'vgg19' \ 321 | else 2048 # FIXME: encoder_dim depends on the model 322 | 323 | # different Decoder 324 | if args.decoder_mode == "lstm_attention": 325 | decoder = DecoderWithAttention(attention_dim=args.attention_dim, 326 | embed_dim= args.emb_dim, 327 | decoder_dim= args.decoder_dim, 328 | vocab_size=len(word_map), 329 | dropout=args.dropout) 330 | 331 | decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()), 332 | lr=args.decoder_lr) 333 | decoder_lr_scheduler = StepLR(decoder_optimizer,step_size=600,gamma=0.9) 334 | 335 | # load pre-trained word embedding 336 | if args.embedding_path is not None: 337 | all_word_embeds = {} 338 | for i, line in enumerate(codecs.open(args.embedding_path, 'r', 'utf-8')): 339 | s = line.strip().split() 340 | all_word_embeds[s[0]] = np.array([float(i) for i in s[1:]]) 341 | 342 | # change emb_dim 343 | args.emb_dim = list(all_word_embeds.values())[-1].size 344 | word_embeds = np.random.uniform(-np.sqrt(0.06), np.sqrt(0.06), (len(word_map), args.emb_dim)) 345 | for w in word_map: 346 | if w in all_word_embeds: 347 | word_embeds[word_map[w]] = all_word_embeds[w] 348 | elif w.lower() in all_word_embeds: 349 | word_embeds[word_map[w]] = all_word_embeds[w.lower()] 350 | else: 351 | # 352 | embedding_i = torch.ones(1, args.emb_dim) 353 | torch.nn.init.xavier_uniform_(embedding_i) 354 | word_embeds[word_map[w]] = embedding_i 355 | 356 | word_embeds = torch.FloatTensor(word_embeds).to(device) 357 | decoder.load_pretrained_embeddings(word_embeds) 358 | decoder.fine_tune_embeddings(args.fine_tune_embedding) 359 | print('Loaded {} pre-trained word embeddings.'.format(len(word_embeds))) 360 | 361 | else: 362 | print("isNone") 363 | print(args.checkpoint) 364 | checkpoint = torch.load(args.checkpoint, map_location=str(device)) 365 | start_epoch = checkpoint['epoch'] + 1 366 | epochs_since_improvement = checkpoint['epochs_since_improvement'] 367 | best_bleu1 = checkpoint['metrics']["Bleu_1"] 368 | encoder = checkpoint['encoder'] 369 | encoder_optimizer = checkpoint['encoder_optimizer'] 370 | decoder = checkpoint['decoder'] 371 | decoder_optimizer = checkpoint['decoder_optimizer'] 372 | decoder.fine_tune_embeddings(args.fine_tune_embedding) 373 | # load final_args from checkpoint 374 | final_args = checkpoint['final_args'] 375 | for key in final_args.keys(): 376 | args.__setattr__(key, final_args[key]) 377 | if args.fine_tune_encoder is True and encoder_optimizer is None: 378 | print("Encoder_Optimizer is None, Creating new Encoder_Optimizer!") 379 | encoder.fine_tune(args.fine_tune_encoder) 380 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()), 381 | lr=args.encoder_lr) 382 | 383 | # Move to GPU, if available 384 | decoder = decoder.to(device) 385 | encoder = encoder.to(device) 386 | print("Encoder_mode:{} Decoder_mode:{}".format(args.encoder_mode,args.decoder_mode)) 387 | print("encoder_layers {} decoder_layers {} n_heads {} dropout {} attention_method {} encoder_lr {} " 388 | "decoder_lr {} alpha_c {}".format(args.encoder_layers, args.decoder_layers, args.n_heads, args.dropout, 389 | args.attention_method, args.encoder_lr, args.decoder_lr, args.alpha_c)) 390 | # print(encoder) 391 | # print(decoder) 392 | 393 | # Loss function 394 | criterion = nn.CrossEntropyLoss().to(device) 395 | 396 | # Custom dataloaders 397 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 398 | # normalize = transforms.Normalize(mean=[0.399, 0.410, 0.371], std=[0.151, 0.138, 0.134]) 399 | # normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 400 | 401 | # pin_memory: If True, the data loader will copy Tensors into CUDA pinned memory before returning them. 402 | # If your data elements are a custom type, or your collate_fn returns a batch that is a custom type. 403 | train_loader = torch.utils.data.DataLoader( 404 | CaptionDataset(args.data_folder, args.data_name, 'TRAIN', transform=transforms.Compose([transforms.RandomHorizontalFlip(),normalize])), 405 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 406 | val_loader = torch.utils.data.DataLoader( 407 | CaptionDataset(args.data_folder, args.data_name, 'VAL', transform=transforms.Compose([transforms.RandomHorizontalFlip(),normalize])), 408 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 409 | 410 | # Epochs 411 | for epoch in range(start_epoch, args.epochs): 412 | 413 | # Decay learning rate if there is no improvement for 5 consecutive epochs, and terminate training after 25 414 | # 8 20 415 | if epochs_since_improvement == args.stop_criteria: 416 | print("the model has not improved in the last {} epochs".format(args.stop_criteria)) 417 | break 418 | if epochs_since_improvement > 0 and epochs_since_improvement % 5 == 0: 419 | adjust_learning_rate(decoder_optimizer, 0.8) 420 | if args.fine_tune_encoder and encoder_optimizer is not None: 421 | print(encoder_optimizer) 422 | adjust_learning_rate(encoder_optimizer, 0.8) 423 | 424 | # One epoch's training 425 | train(args, 426 | train_loader=train_loader, 427 | # val_loader=val_loader, 428 | encoder=encoder, 429 | decoder=decoder, 430 | criterion=criterion, 431 | encoder_optimizer=encoder_optimizer, 432 | #encoder_lr_scheduler=encoder_lr_scheduler, 433 | decoder_optimizer=decoder_optimizer, 434 | #decoder_lr_scheduler=decoder_lr_scheduler, 435 | epoch=epoch) 436 | 437 | 438 | # One epoch's validation 439 | metrics = validate(args, 440 | val_loader=val_loader, 441 | encoder=encoder, 442 | decoder=decoder, 443 | criterion=criterion) 444 | 445 | recent_bleu4 = metrics["Bleu_4"] 446 | 447 | # Check if there was an improvement 448 | is_best = recent_bleu4 > best_bleu4 449 | best_bleu4 = max(recent_bleu4, best_bleu4) 450 | if not is_best: 451 | epochs_since_improvement += 1 452 | print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,)) 453 | else: 454 | epochs_since_improvement = 0 455 | 456 | # Save checkpoint 457 | checkpoint_name = model+"_"+dataset #_tengxun_aggregation 458 | save_checkpoint(checkpoint_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, 459 | decoder_optimizer, metrics, is_best, final_args) 460 | 461 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import h5py 4 | import json 5 | import torch 6 | # from scipy.misc import imread, imresize 7 | from PIL import Image 8 | import numpy as np 9 | #import imageio 10 | from tqdm import tqdm 11 | from collections import Counter 12 | from random import seed, choice, sample 13 | #from scipy.misc import imread, imresize 14 | from PIL import Image 15 | 16 | from eval_func.bleu.bleu import Bleu 17 | from eval_func.rouge.rouge import Rouge 18 | from eval_func.cider.cider import Cider 19 | from eval_func.meteor.meteor import Meteor 20 | # from eval_func.spice.spice import Spice 21 | import matplotlib.pyplot as plt 22 | import matplotlib as mpl 23 | import numpy as np 24 | 25 | 26 | def create_input_files(dataset, karpathy_json_path, image_folder, captions_per_image, min_word_freq, output_folder, 27 | max_len=100, imgnum=None): 28 | """ 29 | Creates input files for training, validation, and test data. 30 | 31 | :param dataset: name of dataset, one of 'coco', 'flickr8k', 'flickr30k' 32 | :param karpathy_json_path: path of Karpathy JSON file with splits and captions 33 | :param image_folder: folder with downloaded images 34 | :param captions_per_image: number of captions to sample per image 35 | :param min_word_freq: words occuring less frequently than this threshold are binned as s 36 | :param output_folder: folder to save files 37 | :param max_len: don't sample captions longer than this length 38 | """ 39 | 40 | assert dataset in {'coco', 'flickr8k', 'flickr30k','RSICD','Sydney','UCM'} 41 | 42 | # Read Karpathy JSON 43 | with open(karpathy_json_path, 'r') as j: 44 | data = json.load(j) 45 | 46 | # Read image paths and captions for each image 47 | train_image_paths = [] 48 | train_image_captions = [] 49 | val_image_paths = [] 50 | val_image_captions = [] 51 | test_image_paths = [] 52 | test_image_captions = [] 53 | word_freq = Counter() 54 | 55 | count = 0 56 | for img in data['images']: 57 | if imgnum is not None: 58 | count += 1 59 | if count > imgnum: # FIXME: fewer images 60 | break 61 | captions = [] 62 | for c in img['sentences']: 63 | # Update word frequency 64 | word_freq.update(c['tokens']) 65 | if len(c['tokens']) <= max_len: 66 | captions.append(c['tokens']) 67 | 68 | if len(captions) == 0: 69 | continue 70 | 71 | path = os.path.join(image_folder, img['filepath'], img['filename']) if dataset == 'coco' else os.path.join( 72 | image_folder, img['filename']) 73 | 74 | if img['split'] in {'train', 'restval'}: 75 | train_image_paths.append(path) 76 | train_image_captions.append(captions) 77 | elif img['split'] in {'val'}: 78 | val_image_paths.append(path) 79 | val_image_captions.append(captions) 80 | elif img['split'] in {'test'}: 81 | test_image_paths.append(path) 82 | test_image_captions.append(captions) 83 | 84 | with open("./data/test_image_paths.json", 'wb') as json_file: # FIXME: Store the test img paths 85 | json.dump(test_image_paths, json_file, ensure_ascii=False) 86 | 87 | # Sanity check 88 | assert len(train_image_paths) == len(train_image_captions) 89 | assert len(val_image_paths) == len(val_image_captions) 90 | assert len(test_image_paths) == len(test_image_captions) 91 | 92 | # Create word map 93 | words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq] 94 | word_map = {k: v + 1 for v, k in enumerate(words)} 95 | word_map[''] = len(word_map) + 1 96 | word_map[''] = len(word_map) + 1 97 | word_map[''] = len(word_map) + 1 98 | word_map[''] = 0 99 | 100 | # Create a base/root name for all output files 101 | base_filename = dataset + '_' + str(captions_per_image) + '_cap_per_img_' + str(min_word_freq) + '_min_word_freq' 102 | 103 | # Save word map to a JSON 104 | with open(os.path.join(output_folder, 'WORDMAP_' + base_filename + '.json'), 'wb') as j: 105 | json.dump(word_map, j) 106 | 107 | # Sample captions for each image, save images to HDF5 file, and captions and their lengths to JSON files 108 | seed(123) 109 | for impaths, imcaps, split in [(train_image_paths, train_image_captions, 'TRAIN'), 110 | (val_image_paths, val_image_captions, 'VAL'), 111 | (test_image_paths, test_image_captions, 'TEST')]: 112 | 113 | with h5py.File(os.path.join(output_folder, split + '_IMAGES_' + base_filename + '.hdf5'), 'a') as h: 114 | # Make a note of the number of captions we are sampling per image 115 | h.attrs['captions_per_image'] = captions_per_image 116 | 117 | # Create dataset inside HDF5 file to store images 118 | images = h.create_dataset('images', (len(impaths), 3, 256, 256), dtype='uint8') 119 | 120 | print("\nReading %s images and captions, storing to file...\n" % split) 121 | 122 | enc_captions = [] 123 | caplens = [] 124 | 125 | for i, path in enumerate(tqdm(impaths)): 126 | 127 | # Sample captions 128 | if len(imcaps[i]) < captions_per_image: 129 | captions = imcaps[i] + [choice(imcaps[i]) for _ in range(captions_per_image - len(imcaps[i]))] 130 | else: 131 | captions = sample(imcaps[i], k=captions_per_image) 132 | 133 | # Sanity check 134 | assert len(captions) == captions_per_image 135 | 136 | # Read images 137 | img = imageio.imread(impaths[i]) 138 | if len(img.shape) == 2: 139 | img = img[:, :, np.newaxis] 140 | img = np.concatenate([img, img, img], axis=2) 141 | img = np.array(Image.fromarray(img).resize((256, 256))) 142 | # img = imresize(img, (256, 256)) 143 | img = img.transpose(2, 0, 1) 144 | assert img.shape == (3, 256, 256) 145 | assert np.max(img) <= 255 146 | 147 | # Save image to HDF5 file 148 | images[i] = img 149 | 150 | for j, c in enumerate(captions): 151 | # Encode captions 152 | enc_c = [word_map['']] + [word_map.get(word, word_map['']) for word in c] + [ 153 | word_map['']] + [word_map['']] * (max_len - len(c)) 154 | 155 | # Find caption lengths 156 | c_len = len(c) + 2 157 | 158 | enc_captions.append(enc_c) 159 | caplens.append(c_len) 160 | 161 | # Sanity check 162 | assert images.shape[0] * captions_per_image == len(enc_captions) == len(caplens) 163 | 164 | # Save encoded captions and their lengths to JSON files 165 | with open(os.path.join(output_folder, split + '_CAPTIONS_' + base_filename + '.json'), 'wb') as j: 166 | json.dump(enc_captions, j) 167 | 168 | with open(os.path.join(output_folder, split + '_CAPLENS_' + base_filename + '.json'), 'wb') as j: 169 | json.dump(caplens, j) 170 | 171 | def create_input_files222(dataset, karpathy_json_path, image_folder, captions_per_image, min_word_freq, output_folder, 172 | max_len=100): 173 | """ 174 | Creates input files for training, validation, and test data. 175 | 176 | :param dataset: name of dataset, one of 'coco', 'flickr8k', 'flickr30k' 177 | :param karpathy_json_path: path of Karpathy JSON file with splits and captions 178 | :param image_folder: folder with downloaded images 179 | :param captions_per_image: number of captions to sample per image 180 | :param min_word_freq: words occuring less frequently than this threshold are binned as s 181 | :param output_folder: folder to save files 182 | :param max_len: don't sample captions longer than this length 183 | """ 184 | 185 | assert dataset in {'coco', 'flickr8k', 'flickr30k','RSICD'} 186 | 187 | # Read Karpathy JSON 188 | with open(karpathy_json_path, 'r') as j: 189 | data = json.load(j) 190 | 191 | # Read image paths and captions for each image 192 | train_image_paths = [] 193 | train_image_captions = [] 194 | val_image_paths = [] 195 | val_image_captions = [] 196 | test_image_paths = [] 197 | test_image_captions = [] 198 | word_freq = Counter() 199 | 200 | classlist = ['tree','building','airport','land','field','beach','bridge','center','church','commercial','residential','desert', 201 | 'farmland','forest','industrial','meadow',#'mediumresidential', 202 | 'mountain','park','school','square','parking', 203 | 'playground','pond','viaduct','port','railway', 204 | 'resort','river',#'sparseresidential', 205 | 'tank','stadium'] 206 | train_dict = {'tree':0,'building':0,'airport':0,'land':0,'field':0,'beach':0,'bridge':0,'center':0,'church':0,'commercial':0,'residential':0,'desert':0, 207 | 'farmland':0,'forest':0,'industrial':0,'meadow':0,#'mediumresidential':0, 208 | 'mountain':0,'park':0,'school':0,'square':0,'parking':0, 209 | 'playground':0,'pond':0,'viaduct':0,'port':0,'railway':0, 210 | 'resort':0,'river':0,#'sparseresidential':0, 211 | 'tank':0,'stadium':0} 212 | val_dict = {'tree':0,'building':0,'airport':0,'land':0,'field':0,'beach':0,'bridge':0,'center':0,'church':0,'commercial':0,'residential':0,'desert':0, 213 | 'farmland':0,'forest':0,'industrial':0,'meadow':0,#'mediumresidential':0, 214 | 'mountain':0,'park':0,'school':0,'square':0,'parking':0, 215 | 'playground':0,'pond':0,'viaduct':0,'port':0,'railway':0, 216 | 'resort':0,'river':0,#'sparseresidential':0, 217 | 'tank':0,'stadium':0} 218 | test_dict = {'tree':0,'building':0,'airport':0,'land':0,'field':0,'beach':0,'bridge':0,'center':0,'church':0,'commercial':0,'residential':0,'desert':0, 219 | 'farmland':0,'forest':0,'industrial':0,'meadow':0,#'mediumresidential':0, 220 | 'mountain':0,'park':0,'school':0,'square':0,'parking':0, 221 | 'playground':0,'pond':0,'viaduct':0,'port':0,'railway':0, 222 | 'resort':0,'river':0,#'sparseresidential':0, 223 | 'tank':0,'stadium':0} 224 | train_leng=0 225 | val_leng = 0 226 | test_leng = 0 227 | num=0 228 | for img in data['images']: 229 | captions = [] 230 | for c in img['sentences']: 231 | # Update word frequency 232 | word_freq.update(c['tokens']) 233 | if len(c['tokens']) <= max_len: 234 | captions.append(c['tokens']) # [[0], [1], [2], [3], [4]] 235 | 236 | if len(captions) == 0: 237 | continue 238 | 239 | path = os.path.join(image_folder, img['filepath'], img['filename']) if dataset == 'coco' else os.path.join( 240 | image_folder, img['filename']) 241 | 242 | if img['split'] in {'train', 'restval','val','test'}: 243 | num = num + 1 244 | if num % 10 == 0: 245 | val_image_paths.append(path) 246 | val_image_captions.append(captions) 247 | for imageclass in classlist: 248 | for i in range(5): 249 | sentences = img['sentences'] 250 | sentencesi = sentences[i] 251 | # if imageclass in sentencesi['tokens']: 252 | if (imageclass + ' ' in sentencesi['raw']) or (imageclass + 's ' in sentencesi['raw']): 253 | val_dict[imageclass] = val_dict[imageclass] + 1 254 | val_leng = val_leng + 1 255 | elif num % 10 == 1: 256 | test_image_paths.append(path) 257 | test_image_captions.append(captions) 258 | for imageclass in classlist: 259 | for i in range(5): 260 | sentences = img['sentences'] 261 | sentencesi = sentences[i] 262 | # if imageclass in sentencesi['tokens']: 263 | if (imageclass + ' ' in sentencesi['raw']) or (imageclass + 's ' in sentencesi['raw']): 264 | test_dict[imageclass] = test_dict[imageclass] + 1 265 | test_leng = test_leng + 1 266 | else: 267 | train_image_paths.append(path) 268 | train_image_captions.append(captions) 269 | for imageclass in classlist: 270 | for i in range(5): 271 | sentences = img['sentences'] 272 | sentencesi = sentences[i] 273 | # if imageclass in img['filename']: 274 | if (imageclass+' ' in sentencesi['raw']) or (imageclass +'s ' in sentencesi['raw']): 275 | train_dict[imageclass] = train_dict[imageclass] +1 276 | train_leng = train_leng + 1 277 | 278 | # if img['split'] in {'train', 'restval'}: 279 | # train_image_paths.append(path) 280 | # train_image_captions.append(captions) 281 | # for imageclass in classlist: 282 | # for i in range(5): 283 | # sentences = img['sentences'] 284 | # sentencesi = sentences[i] 285 | # # if imageclass in img['filename']: 286 | # if (imageclass + ' ' in sentencesi['raw']) or (imageclass + 's ' in sentencesi['raw']): 287 | # train_dict[imageclass] = train_dict[imageclass] + 1 288 | # train_leng = train_leng + 1 289 | # elif img['split'] in {'val'}: 290 | # val_image_paths.append(path) 291 | # val_image_captions.append(captions) 292 | # for imageclass in classlist: 293 | # for i in range(5): 294 | # sentences = img['sentences'] 295 | # sentencesi = sentences[i] 296 | # # if imageclass in sentencesi['tokens']: 297 | # if (imageclass + ' ' in sentencesi['raw']) or (imageclass + 's ' in sentencesi['raw']): 298 | # val_dict[imageclass] = val_dict[imageclass] +1 299 | # val_leng = val_leng + 1 300 | # elif img['split'] in {'test'}: 301 | # test_image_paths.append(path) 302 | # test_image_captions.append(captions) 303 | # for imageclass in classlist: 304 | # for i in range(5): 305 | # sentences = img['sentences'] 306 | # sentencesi = sentences[i] 307 | # # if imageclass in sentencesi['tokens']: 308 | # if (imageclass + ' ' in sentencesi['raw']) or (imageclass + 's ' in sentencesi['raw']): 309 | # test_dict[imageclass] = test_dict[imageclass] +1 310 | # test_leng = test_leng + 1 311 | 312 | 313 | total_dict = Counter(train_dict) + Counter(val_dict) +Counter(test_dict) 314 | for imageclass in classlist: 315 | train_dict[imageclass] = train_dict[imageclass]/(5*len(train_image_paths)) 316 | val_dict[imageclass] = val_dict[imageclass] / (5*len(val_image_paths)) 317 | test_dict[imageclass] = test_dict[imageclass]/(5*len(test_image_paths)) 318 | # total_dict[imageclass] = total_dict[imageclass] / (len(train_image_paths)+len(val_image_paths) + len(test_image_paths)) 319 | 320 | draw_from_dict([train_dict,val_dict,test_dict], len(classlist), 1) 321 | print('train_dict:\n', train_dict) 322 | print('val_dict:\n', val_dict) 323 | print('test_dict:\n', test_dict) 324 | print('total_dict:\n', total_dict) 325 | print(train_leng,val_leng,test_leng) 326 | # Sanity check 327 | assert len(train_image_paths) == len(train_image_captions) 328 | assert len(val_image_paths) == len(val_image_captions) 329 | assert len(test_image_paths) == len(test_image_captions) 330 | print("find {} training data, {} val data, {} test data".format(len(train_image_paths), len(val_image_paths), len(test_image_paths))) 331 | 332 | # # Create word map 333 | words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq] 334 | word_map = {k: v + 1 for v, k in enumerate(words)} # word2id 335 | word_map[''] = len(word_map) + 1 336 | word_map[''] = len(word_map) + 1 337 | word_map[''] = len(word_map) + 1 338 | word_map[''] = 0 339 | 340 | # Create a base/root name for all output files 341 | base_filename = dataset + '_' + str(captions_per_image) + '_cap_per_img_' + str(min_word_freq) + '_min_word_freq' 342 | 343 | # Save word map to a JSON 344 | with open(os.path.join(output_folder, 'WORDMAP_' + base_filename + '.json'), 'w') as j: 345 | json.dump(word_map, j) 346 | print("{} words write into WORDMAP".format(len(word_map))) 347 | 348 | # Sample captions for each image, save images to HDF5 file, and captions and their lengths to JSON files 349 | seed(123) 350 | for impaths, imcaps, split in [(train_image_paths, train_image_captions, 'TRAIN'), 351 | (val_image_paths, val_image_captions, 'VAL'), 352 | (test_image_paths, test_image_captions, 'TEST')]: 353 | 354 | with h5py.File(os.path.join(output_folder, split + '_IMAGES_' + base_filename + '.hdf5'), 'a') as h: 355 | # Make a note of the number of captions we are sampling per image 356 | h.attrs['captions_per_image'] = captions_per_image 357 | 358 | # Create dataset inside HDF5 file to store images 359 | images = h.create_dataset('images', (len(impaths), 3, 256, 256), dtype='uint8') 360 | 361 | print("\nReading %s images and captions, storing to file...\n" % split) 362 | 363 | enc_captions = [] 364 | caplens = [] 365 | 366 | for i, path in enumerate(tqdm(impaths)): 367 | 368 | # Sample captions 369 | if len(imcaps[i]) < captions_per_image: 370 | captions = imcaps[i] + [choice(imcaps[i]) for _ in range(captions_per_image - len(imcaps[i]))] 371 | else: 372 | captions = sample(imcaps[i], k=captions_per_image) 373 | 374 | # Sanity check 375 | assert len(captions) == captions_per_image 376 | 377 | # Read images 378 | img = imageio.imread(impaths[i]) 379 | # img = imread(impaths[i]) 380 | if len(img.shape) == 2: 381 | # gray-scale 382 | img = img[:, :, np.newaxis] 383 | img = np.concatenate([img, img, img], axis=2) # [256, 256, 1+1+1] 384 | img = np.array(Image.fromarray(img).resize((256, 256))) 385 | # img = imresize(img, (256, 256)) 386 | img = img.transpose(2, 0, 1) 387 | assert img.shape == (3, 256, 256) 388 | assert np.max(img) <= 255 389 | 390 | # Save image to HDF5 file 391 | images[i] = img 392 | 393 | for j, c in enumerate(captions): 394 | # Encode captions 395 | enc_c = [word_map['']] + [word_map.get(word, word_map['']) for word in c] + [ 396 | word_map['']] + [word_map['']] * (max_len - len(c)) 397 | 398 | # Find caption lengths 399 | c_len = len(c) + 2 400 | 401 | enc_captions.append(enc_c) 402 | caplens.append(c_len) 403 | 404 | # Sanity check 405 | assert images.shape[0] * captions_per_image == len(enc_captions) == len(caplens) 406 | 407 | # Save encoded captions and their lengths to JSON files 408 | with open(os.path.join(output_folder, split + '_CAPTIONS_' + base_filename + '.json'), 'w') as j: 409 | json.dump(enc_captions, j) 410 | 411 | with open(os.path.join(output_folder, split + '_CAPLENS_' + base_filename + '.json'), 'w') as j: 412 | json.dump(caplens, j) 413 | 414 | 415 | def init_embedding(embeddings): 416 | """ 417 | Fills embedding tensor with values from the uniform distribution. 418 | 419 | :param embeddings: embedding tensor 420 | """ 421 | bias = np.sqrt(3.0 / embeddings.size(1)) 422 | torch.nn.init.uniform_(embeddings, -bias, bias) 423 | 424 | 425 | def load_embeddings(emb_file, word_map): 426 | """ 427 | Creates an embedding tensor for the specified word map, for loading into the model. 428 | 429 | :param emb_file: file containing embeddings (stored in GloVe format) 430 | :param word_map: word map 431 | :return: embeddings in the same order as the words in the word map, dimension of embeddings 432 | """ 433 | 434 | # Find embedding dimension 435 | with open(emb_file, 'r') as f: 436 | emb_dim = len(f.readline().split(' ')) - 1 437 | 438 | vocab = set(word_map.keys()) 439 | 440 | # Create tensor to hold embeddings, initialize 441 | embeddings = torch.FloatTensor(len(vocab), emb_dim) 442 | init_embedding(embeddings) 443 | 444 | # Read embedding file 445 | print("\nLoading embeddings...") 446 | for line in open(emb_file, 'r'): 447 | line = line.split(' ') 448 | 449 | emb_word = line[0] 450 | embedding = list(map(lambda t: float(t), filter(lambda n: n and not n.isspace(), line[1:]))) 451 | 452 | # Ignore word if not in train_vocab 453 | if emb_word not in vocab: 454 | continue 455 | 456 | embeddings[word_map[emb_word]] = torch.FloatTensor(embedding) 457 | 458 | return embeddings, emb_dim 459 | 460 | 461 | def clip_gradient(optimizer, grad_clip): 462 | """ 463 | Clips gradients computed during backpropagation to avoid explosion of gradients. 464 | 465 | :param optimizer: optimizer with the gradients to be clipped 466 | :param grad_clip: clip value 467 | """ 468 | for group in optimizer.param_groups: 469 | for param in group['params']: 470 | if param.grad is not None: 471 | param.grad.data.clamp_(-grad_clip, grad_clip) 472 | 473 | 474 | def save_checkpoint(checkpoint_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, decoder_optimizer, 475 | metrics, is_best, final_args): 476 | """ 477 | Saves model checkpoint. 478 | 479 | :param data_name: base name of processed dataset #FIXME:change data_name to decoder_mode 480 | :param epoch: epoch number 481 | :param epochs_since_improvement: number of epochs since last improvement in BLEU-4 score 482 | :param encoder: encoder model 483 | :param decoder: decoder model 484 | :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning 485 | :param decoder_optimizer: optimizer to update decoder's weights 486 | :param bleu4: validation BLEU-4 score for this epoch #FIXME:change bleu4 to metrics 487 | :param is_best: is this checkpoint the best so far? 488 | """ 489 | state = {'epoch': epoch, 490 | 'epochs_since_improvement': epochs_since_improvement, 491 | 'metrics': metrics, 492 | 'encoder': encoder, 493 | 'decoder': decoder, 494 | 'encoder_optimizer': encoder_optimizer, 495 | 'decoder_optimizer': decoder_optimizer, 496 | 'final_args': final_args} 497 | filename = 'checkpoint_' + checkpoint_name +'.pth.tar' 498 | 499 | filepath = os.path.join('./last_models_weights/', filename) # 最终参数模型 500 | torch.save(state, filepath) 501 | 502 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint 503 | if is_best: 504 | torch.save(state, os.path.join('./best_models_weights/', 'BEST_' + filename)) 505 | #torch.save(state, os.path.join('./models_checkpoint3/', 'Last_epoch_'+ filename)) 506 | #torch.save(state, os.path.join('./models_checkpoint3/', 'epoch_'+str(epoch) + filename)) 507 | 508 | 509 | class AverageMeter(object): 510 | """ 511 | Keeps track of most recent, average, sum, and count of a metric. 512 | """ 513 | 514 | def __init__(self): 515 | self.reset() 516 | 517 | def reset(self): 518 | self.val = 0 519 | self.avg = 0 520 | self.sum = 0 521 | self.count = 0 522 | 523 | def update(self, val, n=1): 524 | self.val = val 525 | self.sum += val * n 526 | self.count += n 527 | self.avg = self.sum / self.count 528 | 529 | 530 | def adjust_learning_rate(optimizer, shrink_factor): 531 | """ 532 | Shrinks learning rate by a specified factor. 533 | 534 | :param optimizer: optimizer whose learning rate must be shrunk. 535 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with. 536 | """ 537 | 538 | print("\nDECAYING learning rate.") 539 | for param_group in optimizer.param_groups: 540 | param_group['lr'] = param_group['lr'] * shrink_factor 541 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 542 | 543 | 544 | def accuracy(scores, targets, k): 545 | """ 546 | Computes top-k accuracy, from predicted and true labels. 547 | 548 | :param scores: scores from the model 549 | :param targets: true labels 550 | :param k: k in top-k accuracy 551 | :return: top-k accuracy 552 | """ 553 | 554 | batch_size = targets.size(0) 555 | _, ind = scores.topk(k, 1, True, True) 556 | correct = ind.eq(targets.view(-1, 1).expand_as(ind)) 557 | correct_total = correct.view(-1).float().sum() # 0D tensor 558 | return correct_total.item() * (100.0 / batch_size) 559 | 560 | 561 | def get_eval_score(references, hypotheses): 562 | """ 563 | scorers = [ 564 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 565 | (Meteor(), "METEOR"), 566 | (Rouge(), "ROUGE_L"), 567 | (Cider(), "CIDEr") 568 | ] 569 | """ 570 | scorers = [ 571 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 572 | (Meteor(), "METEOR"), 573 | (Rouge(), "ROUGE_L"), 574 | (Cider(), "CIDEr") 575 | ] 576 | 577 | hypo = [[' '.join(hypo)] for hypo in [[str(x) for x in hypo] for hypo in hypotheses]] 578 | ref = [[' '.join(reft) for reft in reftmp] for reftmp in 579 | [[[str(x) for x in reft] for reft in reftmp] for reftmp in references]] 580 | 581 | score = [] 582 | method = [] 583 | for scorer, method_i in scorers: 584 | score_i, scores_i = scorer.compute_score(ref, hypo) 585 | score.extend(score_i) if isinstance(score_i, list) else score.append(score_i) 586 | method.extend(method_i) if isinstance(method_i, list) else method.append(method_i) 587 | print("{} {}".format(method_i, score_i)) 588 | score_dict = dict(zip(method, score)) 589 | 590 | return score_dict 591 | 592 | 593 | def convert2words(sequences, rev_word_map): 594 | for l1 in sequences: 595 | caption = "" 596 | for l2 in l1: 597 | caption += rev_word_map[l2] 598 | caption += " " 599 | print(caption) 600 | --------------------------------------------------------------------------------