├── .DS_Store ├── .gitignore ├── ARCTIC ├── ARCTIC_dataloader.py ├── ARCTIC_model.py ├── __pycache__ │ ├── ARCTIC_dataloader.cpython-39.pyc │ └── ARCTIC_model.cpython-39.pyc └── train.py ├── README.md ├── SwinTrans ├── evaluate.ipynb └── gridSwinTrans.ipynb ├── ViT ├── ViT.ipynb ├── config.json └── generate.ipynb ├── evaluate.py ├── new_dataset ├── QianFan-agent.py ├── combined_input.json ├── merge_json.py ├── new_generate.ipynb ├── res.json ├── res_add.json ├── res_new.json └── statement.txt ├── tools └── test_blip.py └── 结题报告.ipynb /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/henryli2002/Image2TextEvaluation/056f36b0c84fad6a410d55fcc7ad2e8fa0b5a367/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | new_dataset/new_genrate.ipynb 2 | -------------------------------------------------------------------------------- /ARCTIC/ARCTIC_dataloader.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils.rnn import pack_padded_sequence # 压紧填充序列 5 | from torch.utils.data import Dataset 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | from torchvision.models import ResNet101_Weights 9 | from nltk.translate.bleu_score import corpus_bleu # BLEU评价指标 10 | import numpy as np 11 | import json 12 | from torch.utils.data import Dataset 13 | import os 14 | from PIL import Image 15 | from collections import Counter,defaultdict 16 | class ImageTextDataset(Dataset): 17 | def __init__(self, dataset_path, vocab_path, split, captions_per_image=1, max_len=93, transform=None): 18 | 19 | self.split = split 20 | assert self.split in {'train', 'test'} 21 | self.cpi = captions_per_image 22 | self.max_len = max_len 23 | 24 | # 载入数据集 25 | with open(dataset_path, 'r') as f: 26 | self.data = json.load(f) #key是图片名字 value是描述 27 | self.data_img=list(self.data.keys()) 28 | # 载入词典 29 | with open(vocab_path, 'r') as f: 30 | self.vocab = json.load(f) 31 | 32 | # PyTorch图像预处理流程 33 | self.transform = transform 34 | 35 | # Total number of datapoints 36 | self.dataset_size = len(self.data_img) 37 | 38 | def __getitem__(self, i): 39 | # 第i个文本描述对应第(i // captions_per_image)张图片 40 | print(self.data_img[i]) 41 | img = Image.open(img_path+"/"+self.data_img[i]).convert('RGB') 42 | if self.transform is not None: 43 | img = self.transform(img) 44 | c_vec=cap_to_wvec(self.vocab,self.data[self.data_img[i]]) 45 | #加入起始和结束标志 46 | c_vec = [self.vocab['']] + c_vec + [self.vocab['']] 47 | caplen = len(c_vec) 48 | caption = torch.LongTensor(c_vec+ [self.vocab['']] * (self.max_len + 2 - caplen)) 49 | 50 | return img, caption, caplen 51 | 52 | def __len__(self): 53 | return self.dataset_size 54 | def mktrainval(data_dir, vocab_path, batch_size, workers=4,is_transform=True): 55 | train_tx = transforms.Compose([ 56 | transforms.Resize(256), # 重置图像分辨率 57 | transforms.RandomCrop(224), # 随机裁剪 58 | transforms.ToTensor(), # 转换成Tensor 59 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 标准化--三个参数为三个通道的均值和标准差 60 | ]) 61 | val_tx = transforms.Compose([ 62 | transforms.Resize(256), 63 | transforms.CenterCrop(224), 64 | transforms.ToTensor(), 65 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 66 | ]) 67 | no_trans=transforms.Compose([ 68 | transforms.Resize(256), 69 | transforms.CenterCrop(224), 70 | transforms.ToTensor(), 71 | ]) 72 | if is_transform: 73 | train_set = ImageTextDataset(os.path.join(data_dir, 'train_captions.json'), vocab_path, 'train', transform=train_tx) 74 | test_set = ImageTextDataset(os.path.join(data_dir, 'test_captions.json'), vocab_path, 'test', transform=val_tx) 75 | else: 76 | train_set = ImageTextDataset(os.path.join(data_dir, 'train_captions.json'), vocab_path, 'train', transform=no_trans) 77 | test_set = ImageTextDataset(os.path.join(data_dir, 'test_captions.json'), vocab_path, 'test', transform=no_trans) 78 | train_loader = torch.utils.data.DataLoader( 79 | train_set, batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True) 80 | 81 | test_loader = torch.utils.data.DataLoader( 82 | test_set, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True, drop_last=False) 83 | 84 | return train_loader, test_loader 85 | img_path = f'../data/deepfashion-multimodal/images' 86 | def cap_to_wvec(vocab,cap):#将文本描述转换成向量 87 | cap.replace(",","") 88 | cap.replace(".","") 89 | cap=cap.split() 90 | res=[] 91 | for word in cap: 92 | if word in vocab.keys(): 93 | res.append(vocab[word]) 94 | else: #不在字典的词 95 | res.append(vocab['']) 96 | return res 97 | def wvec_to_cap(vocab,wvec):#将向量转换成文本描述 98 | res=[] 99 | for word in wvec: 100 | for key,value in vocab.items(): 101 | if value==word and key not in ['','','','']: 102 | res.append(key) 103 | res=" ".join(res) 104 | return res 105 | def wvec_to_capls(vocab,wvec):#将向量转换成文本描述 106 | res=[] 107 | for word in wvec: 108 | for key,value in vocab.items(): 109 | if value==word and key not in ['','','','']: 110 | res.append(key) 111 | return res -------------------------------------------------------------------------------- /ARCTIC/ARCTIC_model.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils.rnn import pack_padded_sequence # 压紧填充序列 5 | from torch.utils.data import Dataset 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | from torchvision.models import ResNet101_Weights 9 | from nltk.translate.bleu_score import corpus_bleu # BLEU评价指标 10 | import numpy as np 11 | import json 12 | from torch.utils.data import Dataset 13 | import os 14 | from PIL import Image 15 | from collections import Counter,defaultdict 16 | ARCTIC_config = Namespace( 17 | max_len = 93, 18 | captions_per_image = 1, 19 | batch_size = 32, 20 | image_code_dim = 2048, 21 | word_dim = 512, 22 | hidden_size = 512, 23 | attention_dim = 512, 24 | num_layers = 1, 25 | encoder_learning_rate = 0.0001, 26 | decoder_learning_rate = 0.0005, 27 | num_epochs = 10, 28 | grad_clip = 5.0, 29 | alpha_weight = 1.0, 30 | evaluate_step = 900, # 每隔多少步在验证集上测试一次 31 | checkpoint = None, # 如果不为None,则利用该变量路径的模型继续训练 32 | best_checkpoint = 'model/ARCTIC/best_ARCTIC.ckpt', # 验证集上表现最优的模型的路径 33 | last_checkpoint = 'model/ARCTIC/last_ARCTIC.ckpt', # 训练完成时的模型的路径 34 | beam_k = 5 #束搜索的束宽 35 | ) 36 | class ImageEncoder(nn.Module): 37 | def __init__(self, finetuned=True): 38 | super(ImageEncoder, self).__init__() 39 | model = torchvision.models.resnet101(weights=ResNet101_Weights.DEFAULT) 40 | # ResNet-101网格表示提取器 41 | self.grid_rep_extractor = nn.Sequential(*(list(model.children())[:-2])) #去掉最后两层 42 | for param in self.grid_rep_extractor.parameters(): #冻结参数--不参与训练 43 | param.requires_grad = finetuned #是否微调 44 | def forward(self, images): 45 | out = self.grid_rep_extractor(images) 46 | return out 47 | class AdditiveAttention(nn.Module): #加性注意力 48 | def __init__(self, query_dim, key_dim, attn_dim): 49 | """ 50 | query_dim: 查询Q的维度 51 | key_dim: 键K的维度 52 | attn_dim: 注意力函数隐藏层表示的维度 53 | """ 54 | 55 | super(AdditiveAttention, self).__init__() 56 | self.attn_w_1_q = nn.Linear(query_dim, attn_dim) #Q的线性变换 57 | self.attn_w_1_k = nn.Linear(key_dim, attn_dim) #K的线性变换 58 | self.attn_w_2 = nn.Linear(attn_dim, 1) #注意力函数隐藏层到输出层的线性变换 59 | self.tanh = nn.Tanh() #激活函数 60 | self.softmax = nn.Softmax(dim=1) #归一化函数 61 | 62 | def forward(self, query, key_value): 63 | """ 64 | Q K V:Q和K算出相关性得分,作为V的权重,K=V 65 | 参数: 66 | query: 查询 (batch_size, q_dim) 67 | key_value: 键和值,(batch_size, n_kv, kv_dim) 68 | """ 69 | # (2)计算query和key的相关性,实现注意力评分函数 70 | # -> (batch_size, 1, attn_dim) 71 | queries = self.attn_w_1_q(query).unsqueeze(1) 72 | # -> (batch_size, n_kv, attn_dim) 73 | keys = self.attn_w_1_k(key_value) # 74 | # -> (batch_size, n_kv) 75 | attn = self.attn_w_2(self.tanh(queries+keys)).squeeze(2) #注意力评分函数 76 | # (3)归一化相关性分数 77 | # -> (batch_size, n_kv) 78 | attn = self.softmax(attn) #归一化 79 | # (4)计算输出 80 | # (batch_size x 1 x n_kv)(batch_size x n_kv x kv_dim) 81 | # -> (batch_size, 1, kv_dim) 82 | output = torch.bmm(attn.unsqueeze(1), key_value).squeeze(1) 83 | return output, attn 84 | class AttentionDecoder(nn.Module): 85 | def __init__(self, image_code_dim, vocab_size, word_dim, attention_dim, hidden_size, num_layers, dropout=0.5): 86 | super(AttentionDecoder, self).__init__() 87 | self.embed = nn.Embedding(vocab_size, word_dim) #词嵌入 88 | self.attention = AdditiveAttention(hidden_size, image_code_dim, attention_dim) #注意力机制 89 | self.init_state = nn.Linear(image_code_dim, num_layers*hidden_size) #初始化隐状态 90 | self.rnn = nn.GRU(word_dim + image_code_dim, hidden_size, num_layers) #GRU 91 | self.dropout = nn.Dropout(p=dropout) #dropout 92 | self.fc = nn.Linear(hidden_size, vocab_size) #全连接层 93 | # RNN默认已初始化 94 | self.init_weights() #初始化权重 95 | 96 | def init_weights(self): #初始化权重 97 | self.embed.weight.data.uniform_(-0.1, 0.1) #词嵌入 98 | self.fc.bias.data.fill_(0) #全连接层 99 | self.fc.weight.data.uniform_(-0.1, 0.1) #全连接层 100 | 101 | def init_hidden_state(self, image_code, captions, cap_lens): 102 | """ 103 | 参数: 104 | image_code:图像编码器输出的图像表示 105 | (batch_size, image_code_dim, grid_height, grid_width) 106 | """ 107 | # 将图像网格表示转换为序列表示形式 108 | batch_size, image_code_dim = image_code.size(0), image_code.size(1) 109 | # -> (batch_size, grid_height, grid_width, image_code_dim) 110 | image_code = image_code.permute(0, 2, 3, 1) 111 | # -> (batch_size, grid_height * grid_width, image_code_dim) 112 | image_code = image_code.view(batch_size, -1, image_code_dim) 113 | # (1)按照caption的长短排序 114 | sorted_cap_lens, sorted_cap_indices = torch.sort(cap_lens, 0, True) 115 | captions = captions[sorted_cap_indices] 116 | image_code = image_code[sorted_cap_indices] 117 | #(2)初始化隐状态 118 | hidden_state = self.init_state(image_code.mean(axis=1)) 119 | hidden_state = hidden_state.view( 120 | batch_size, 121 | self.rnn.num_layers, 122 | self.rnn.hidden_size).permute(1, 0, 2) 123 | return image_code, captions, sorted_cap_lens, sorted_cap_indices, hidden_state 124 | 125 | def forward_step(self, image_code, curr_cap_embed, hidden_state): 126 | #(3.2)利用注意力机制获得上下文向量 127 | # query:hidden_state[-1],即最后一个隐藏层输出 (batch_size, hidden_size) 128 | # context: (batch_size, hidden_size) 129 | context, alpha = self.attention(hidden_state[-1], image_code) 130 | #(3.3)以上下文向量和当前时刻词表示为输入,获得GRU输出 131 | x = torch.cat((context, curr_cap_embed), dim=-1).unsqueeze(0) 132 | # x: (1, real_batch_size, hidden_size+word_dim) 133 | # out: (1, real_batch_size, hidden_size) 134 | out, hidden_state = self.rnn(x, hidden_state) 135 | #(3.4)获取该时刻的预测结果 136 | # (real_batch_size, vocab_size) 137 | preds = self.fc(self.dropout(out.squeeze(0))) 138 | return preds, alpha, hidden_state 139 | 140 | def forward(self, image_code, captions, cap_lens): 141 | """ 142 | 参数: 143 | hidden_state: (num_layers, batch_size, hidden_size) 144 | image_code: (batch_size, feature_channel, feature_size) 145 | captions: (batch_size, ) 146 | """ 147 | # (1)将图文数据按照文本的实际长度从长到短排序 148 | # (2)获得GRU的初始隐状态 149 | image_code, captions, sorted_cap_lens, sorted_cap_indices, hidden_state \ 150 | = self.init_hidden_state(image_code, captions, cap_lens) 151 | batch_size = image_code.size(0) 152 | # 输入序列长度减1,因为最后一个时刻不需要预测下一个词 153 | lengths = sorted_cap_lens.cpu().numpy() - 1 154 | # 初始化变量:模型的预测结果和注意力分数 155 | predictions = torch.zeros(batch_size, lengths[0], self.fc.out_features).to(captions.device) 156 | alphas = torch.zeros(batch_size, lengths[0], image_code.shape[1]).to(captions.device) 157 | # 获取文本嵌入表示 cap_embeds: (batch_size, num_steps, word_dim) 158 | cap_embeds = self.embed(captions) 159 | # Teacher-Forcing模式 160 | for step in range(lengths[0]): 161 | #(3)解码 162 | #(3.1)模拟pack_padded_sequence函数的原理,获取该时刻的非输入 163 | real_batch_size = np.where(lengths>step)[0].shape[0] 164 | preds, alpha, hidden_state = self.forward_step( 165 | image_code[:real_batch_size], 166 | cap_embeds[:real_batch_size, step, :], 167 | hidden_state[:, :real_batch_size, :].contiguous()) 168 | # 记录结果 169 | predictions[:real_batch_size, step, :] = preds 170 | alphas[:real_batch_size, step, :] = alpha 171 | return predictions, alphas, captions, lengths, sorted_cap_indices 172 | 173 | class ARCTIC(nn.Module): #模型主体部分 174 | def __init__(self, image_code_dim, vocab, word_dim, attention_dim, hidden_size, num_layers): 175 | super(ARCTIC, self).__init__() 176 | self.vocab = vocab 177 | self.encoder = ImageEncoder() 178 | self.decoder = AttentionDecoder(image_code_dim, len(vocab), 179 | word_dim, attention_dim, hidden_size, num_layers) 180 | print("test") 181 | def forward(self, images, captions, cap_lens): 182 | image_code = self.encoder(images) 183 | return self.decoder(image_code, captions, cap_lens) 184 | def generate_by_beamsearch(self, images, beam_k, max_len): # beam_k束搜索 185 | vocab_size = len(self.vocab) 186 | image_codes = self.encoder(images) 187 | texts = [] 188 | device = images.device 189 | # 对每个图像样本执行束搜索 190 | for image_code in image_codes: 191 | # 将图像表示复制k份 192 | image_code = image_code.unsqueeze(0).repeat(beam_k,1,1,1) 193 | # 生成k个候选句子,初始时,仅包含开始符号 194 | cur_sents = torch.full((beam_k, 1), self.vocab[''], dtype=torch.long).to(device) 195 | cur_sent_embed = self.decoder.embed(cur_sents)[:,0,:] 196 | sent_lens = torch.LongTensor([1]*beam_k).to(device) 197 | # 获得GRU的初始隐状态 198 | image_code, cur_sent_embed, _, _, hidden_state = \ 199 | self.decoder.init_hidden_state(image_code, cur_sent_embed, sent_lens) 200 | # 存储已生成完整的句子(以句子结束符结尾的句子) 201 | end_sents = [] 202 | # 存储已生成完整的句子的概率 203 | end_probs = [] 204 | # 存储未完整生成的句子的概率 205 | probs = torch.zeros(beam_k, 1).to(device) 206 | k = beam_k 207 | while True: 208 | preds, _, hidden_state = self.decoder.forward_step(image_code[:k], cur_sent_embed, hidden_state.contiguous()) 209 | # -> (k, vocab_size) 210 | preds = nn.functional.log_softmax(preds, dim=1) 211 | # 对每个候选句子采样概率值最大的前k个单词生成k个新的候选句子,并计算概率 212 | # -> (k, vocab_size) 213 | probs = probs.repeat(1,preds.size(1)) + preds 214 | if cur_sents.size(1) == 1: 215 | # 第一步时,所有句子都只包含开始标识符,因此,仅利用其中一个句子计算topk 216 | values, indices = probs[0].topk(k, 0, True, True) 217 | else: 218 | # probs: (k, vocab_size) 是二维张量 219 | # topk函数直接应用于二维张量会按照指定维度取最大值,这里需要在全局取最大值 220 | # 因此,将probs转换为一维张量,再使用topk函数获取最大的k个值 221 | values, indices = probs.view(-1).topk(k, 0, True, True) 222 | # 计算最大的k个值对应的句子索引和词索引 223 | sent_indices = torch.div(indices, vocab_size, rounding_mode='trunc') 224 | word_indices = indices % vocab_size 225 | # 将词拼接在前一轮的句子后,获得此轮的句子 226 | cur_sents = torch.cat([cur_sents[sent_indices], word_indices.unsqueeze(1)], dim=1) 227 | # 查找此轮生成句子结束符的句子 228 | end_indices = [idx for idx, word in enumerate(word_indices) if word == self.vocab['']] 229 | if len(end_indices) > 0: 230 | end_probs.extend(values[end_indices]) 231 | end_sents.extend(cur_sents[end_indices].tolist()) 232 | # 如果所有的句子都包含结束符,则停止生成 233 | k -= len(end_indices) 234 | if k == 0: 235 | break 236 | # 查找还需要继续生成词的句子 237 | cur_indices = [idx for idx, word in enumerate(word_indices) 238 | if word != self.vocab['']] 239 | if len(cur_indices) > 0: 240 | cur_sent_indices = sent_indices[cur_indices] 241 | cur_word_indices = word_indices[cur_indices] 242 | # 仅保留还需要继续生成的句子、句子概率、隐状态、词嵌入 243 | cur_sents = cur_sents[cur_indices] 244 | probs = values[cur_indices].view(-1,1) 245 | hidden_state = hidden_state[:,cur_sent_indices,:] 246 | cur_sent_embed = self.decoder.embed( 247 | cur_word_indices.view(-1,1))[:,0,:] 248 | # 句子太长,停止生成 249 | if cur_sents.size(1) >= max_len: 250 | break 251 | if len(end_sents) == 0: 252 | # 如果没有包含结束符的句子,则选取第一个句子作为生成句子 253 | gen_sent = cur_sents[0].tolist() 254 | else: 255 | # 否则选取包含结束符的句子中概率最大的句子 256 | gen_sent = end_sents[end_probs.index(max(end_probs))] 257 | texts.append(gen_sent) 258 | return texts 259 | def generate_normal_version(self,images,max_len): #普通版本的生成--相当于k=1的束搜索 260 | device=images.device 261 | image_codes = self.encoder(images)#一个batch的图像编码 262 | texts = [] 263 | for image in image_codes: 264 | image=image.unsqueeze(0) 265 | cur_sents=torch.LongTensor([self.vocab['']]).to(device) #序列 266 | 267 | sent_lens=torch.LongTensor([1]).to(device) 268 | cur_sent_embed=self.decoder.embed(cur_sents) 269 | image, cur_sent_embed, _, _,hidden_state = self.decoder.init_hidden_state(image,cur_sent_embed, sent_lens) 270 | while True: 271 | preds, _, hidden_state = self.decoder.forward_step( #一个时间步 272 | image, cur_sent_embed, hidden_state.contiguous()) 273 | preds = nn.functional.log_softmax(preds, dim=1) #log_softmax 274 | #print(f"preds {preds.shape}") #获得概率分布 275 | values, indices = preds[0].topk(1, 0, True, True) #获得最大概率的概率 和对应词索引 276 | cur_sents=torch.cat([cur_sents,indices],dim=0) #拼接 生成序列 长度逐渐+1 277 | #print(f"indices {indices.shape} {indices} | {values}") 278 | cur_sent_embed=self.decoder.embed(indices) 279 | if indices==self.vocab['']: 280 | break 281 | if cur_sents.size(0) >= max_len: 282 | break 283 | texts.append(cur_sents.tolist()) 284 | return texts -------------------------------------------------------------------------------- /ARCTIC/__pycache__/ARCTIC_dataloader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/henryli2002/Image2TextEvaluation/056f36b0c84fad6a410d55fcc7ad2e8fa0b5a367/ARCTIC/__pycache__/ARCTIC_dataloader.cpython-39.pyc -------------------------------------------------------------------------------- /ARCTIC/__pycache__/ARCTIC_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/henryli2002/Image2TextEvaluation/056f36b0c84fad6a410d55fcc7ad2e8fa0b5a367/ARCTIC/__pycache__/ARCTIC_model.cpython-39.pyc -------------------------------------------------------------------------------- /ARCTIC/train.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils.rnn import pack_padded_sequence # 压紧填充序列 5 | from torch.utils.data import Dataset 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | from torchvision.models import ResNet101_Weights 9 | from nltk.translate.bleu_score import corpus_bleu # BLEU评价指标 10 | import numpy as np 11 | import json 12 | from torch.utils.data import Dataset 13 | import os 14 | from PIL import Image 15 | from collections import Counter,defaultdict 16 | from ARCTIC_model import ImageEncoder ,AdditiveAttention,AttentionDecoder,ARCTIC 17 | from ARCTIC_dataset import ImageTextDataset ,mktrainval,cap_to_wvec,wvec_to_cap,wvec_to_capls 18 | ARCTIC_config = Namespace( 19 | max_len = 93, 20 | captions_per_image = 1, 21 | batch_size = 32, 22 | image_code_dim = 2048, 23 | word_dim = 512, 24 | hidden_size = 512, 25 | attention_dim = 512, 26 | num_layers = 1, 27 | encoder_learning_rate = 0.0001, 28 | decoder_learning_rate = 0.0005, 29 | num_epochs = 10, 30 | grad_clip = 5.0, 31 | alpha_weight = 1.0, 32 | evaluate_step = 900, # 每隔多少步在验证集上测试一次 33 | checkpoint = None, # 如果不为None,则利用该变量路径的模型继续训练 34 | best_checkpoint = 'model/ARCTIC/best_ARCTIC.ckpt', # 验证集上表现最优的模型的路径 35 | last_checkpoint = 'model/ARCTIC/last_ARCTIC.ckpt', # 训练完成时的模型的路径 36 | beam_k = 5 #束搜索的束宽 37 | ) 38 | 39 | class PackedCrossEntropyLoss(nn.Module): #损失函数 40 | def __init__(self): 41 | super(PackedCrossEntropyLoss, self).__init__() 42 | self.loss_fn = nn.CrossEntropyLoss() 43 | 44 | def forward(self, predictions, targets, lengths): #压紧填充序列 45 | 46 | predictions = pack_padded_sequence(predictions, lengths, batch_first=True)[0] 47 | targets = pack_padded_sequence(targets, lengths, batch_first=True)[0] 48 | return self.loss_fn(predictions, targets) #计算损失 49 | def filter_useless_words(sent, filterd_words): 50 | # 去除句子中不参与BLEU值计算的符号 51 | return [w for w in sent if w not in filterd_words] 52 | def evaluate(data_loader, model, config): 53 | model.eval() 54 | # 存储候选文本 55 | cands = [] 56 | # 存储参考文本 57 | refs = [] 58 | # 需要过滤的词 59 | filterd_words = set({model.vocab[''], model.vocab[''], model.vocab['']}) 60 | cpi = config.captions_per_image 61 | device = next(model.parameters()).device 62 | for i, (imgs, caps, caplens) in enumerate(data_loader): 63 | with torch.no_grad(): 64 | # 通过束搜索,生成候选文本 65 | texts = model.generate_by_beamsearch(imgs.to(device), config.beam_k, config.max_len+2) 66 | # 候选文本 67 | cands.extend([filter_useless_words(text, filterd_words) for text in texts]) 68 | # 参考文本 69 | refs.extend([filter_useless_words(cap, filterd_words) for cap in caps.tolist()]) 70 | # 实际上,每个候选文本对应cpi条参考文本 71 | multiple_refs = [] 72 | for idx in range(len(refs)): 73 | multiple_refs.append(refs[(idx//cpi)*cpi : (idx//cpi)*cpi+cpi]) 74 | # 计算BLEU-4值,corpus_bleu函数默认weights权重为(0.25,0.25,0.25,0.25) 75 | # 即计算1-gram到4-gram的BLEU几何平均值 76 | bleu4 = corpus_bleu(multiple_refs, cands, weights=(0.25,0.25,0.25,0.25)) 77 | model.train() 78 | return bleu4 79 | 80 | if __name__ == '__main__': 81 | # 设置GPU信息 82 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 83 | config = ARCTIC_config 84 | # 数据 85 | data_dir = 'data/deepfashion-multimodal/' 86 | vocab_path = 'data/deepfashion-multimodal/vocab.json' 87 | 88 | train_loader,test_loader=mktrainval(data_dir='data/deepfashion-multimodal',\ 89 | vocab_path='data/deepfashion-multimodal/vocab.json',\ 90 | batch_size=2,workers=2) 91 | 92 | # 模型 93 | with open(vocab_path, 'r') as f: 94 | vocab = json.load(f) 95 | 96 | # 随机初始化 或 载入已训练的模型 97 | start_epoch = 0 98 | checkpoint = config.checkpoint 99 | if checkpoint is None: 100 | model = ARCTIC(config.image_code_dim, vocab, config.word_dim, config.attention_dim, config.hidden_size, config.num_layers) 101 | else: 102 | checkpoint = torch.load(checkpoint) 103 | start_epoch = checkpoint['epoch'] + 1 104 | model = checkpoint['model'] 105 | 106 | # 优化器 107 | optimizer= torch.optim.Adam(lr=0.0001, params=model.parameters()) 108 | # 将模型拷贝至GPU,并开启训练模式 109 | model.to(device) 110 | model.train() 111 | # 损失函数 112 | loss_fn = PackedCrossEntropyLoss().to(device) 113 | best_res = 0 114 | print("开始训练") 115 | 116 | for epoch in range(start_epoch, config.num_epochs): 117 | for i, (imgs, caps, caplens) in enumerate(train_loader): 118 | optimizer.zero_grad() 119 | # 1. 读取数据至GPU 120 | imgs = imgs.to(device) 121 | caps = caps.to(device) 122 | caplens = caplens.to(device) 123 | predictions, alphas, sorted_captions, lengths, sorted_cap_indices = model(imgs, caps, caplens) 124 | loss = loss_fn(predictions, sorted_captions[:, 1:], lengths) 125 | loss += config.alpha_weight * ((1. - alphas.sum(axis=1)) ** 2).mean() 126 | loss.backward() 127 | if config.grad_clip > 0: 128 | nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) 129 | 130 | # 4. 更新参数 131 | optimizer.step() 132 | 133 | if (i+1) % 100 == 0: 134 | print('epoch %d, step %d: loss=%.2f' % (epoch, i+1, loss.cpu())) 135 | 136 | state = { 137 | 'epoch': epoch, 138 | 'step': i, 139 | 'model': model, 140 | 'optimizer': optimizer 141 | } 142 | if (i+1) % config.evaluate_step == 0: 143 | bleu_score = evaluate(test_loader, model, config) #在验证集上测试 144 | # 5. 选择模型 145 | if best_res < bleu_score: 146 | best_res = bleu_score 147 | torch.save(state, config.best_checkpoint) 148 | torch.save(state, config.last_checkpoint) 149 | print('Validation@epoch, %d, step, %d, BLEU-4=%.2f' % (epoch, i+1, bleu_score)) 150 | checkpoint = torch.load(config.best_checkpoint) 151 | model = checkpoint['model'] 152 | bleu_score = evaluate(test_loader, model, config) 153 | print("Evaluate on the test set with the model that has the best performance on the validation set") 154 | print('Epoch: %d, BLEU-4=%.2f' % (checkpoint['epoch'], bleu_score)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image2TextEvaluation 2 | BUPT神经网络与深度学习课设 3 | -------------------------------------------------------------------------------- /ViT/ViT.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "import os\n", 11 | "from PIL import Image\n", 12 | "import torch\n", 13 | "from torch.utils.data import Dataset, DataLoader\n", 14 | "from transformers import ViTFeatureExtractor, BertTokenizer\n", 15 | "from collections import defaultdict, Counter\n", 16 | "import numpy as np\n", 17 | "from nltk.translate.bleu_score import corpus_bleu\n", 18 | "\n", 19 | "dataset='deepfashion-multimodal'\n", 20 | "img_path = f'data/{dataset}/images'\n", 21 | "train_json_path= f'data/{dataset}/train_captions.json'\n", 22 | "test_json_path= f'data/{dataset}/test_captions.json'\n", 23 | "vocab_path = f'data/{dataset}/vocab.json'\n", 24 | "\n", 25 | "def idx_to_word(idx, vocab):#将向量转化为文本描述\n", 26 | " reverse_vocab = {v: k for k, v in vocab.items()}\n", 27 | " return reverse_vocab.get(int(idx), '')\n", 28 | "\n", 29 | "def cap_to_wvec(vocab,cap):#将文本描述转换成向量\n", 30 | " cap.replace(\",\",\"\")\n", 31 | " cap.replace(\".\",\"\")\n", 32 | " cap=cap.split()\n", 33 | " res=[]\n", 34 | " for word in cap:\n", 35 | " if word in vocab.keys():\n", 36 | " res.append(vocab[word])\n", 37 | " else: #不在字典的词\n", 38 | " res.append(vocab[''])\n", 39 | " return res\n", 40 | "\n", 41 | "def filter_cut_useless_words(sent, filterd_words):\n", 42 | " res=[]\n", 43 | " for w in sent:\n", 44 | " if w not in filterd_words:\n", 45 | " res.append(w)\n", 46 | " else:\n", 47 | " if w==155:\n", 48 | " return res\n", 49 | "\n", 50 | "def get_BLEU_score(cands, refs): #获取BLEU分数\n", 51 | " multiple_refs = []\n", 52 | " for idx in range(len(refs)):\n", 53 | " multiple_refs.append(refs[(idx//1)*1 : (idx//1)*1+1])#每个候选文本对应cpi==1条参考文本\n", 54 | " bleu4 = corpus_bleu(multiple_refs, cands, weights=(0.25,0.25,0.25,0.25))\n", 55 | " return bleu4\n", 56 | "\n", 57 | "def cider_d(reference_list, candidate_list, n=4):\n", 58 | " def count_ngrams(tokens, n):\n", 59 | " ngrams = []\n", 60 | " for i in range(len(tokens) - n + 1):\n", 61 | " ngram = tuple(tokens[i:i+n])\n", 62 | " ngrams.append(ngram)\n", 63 | " return ngrams\n", 64 | "\n", 65 | " def compute_cider_d(reference_list, candidate_list, n):\n", 66 | " cider_d_scores = []\n", 67 | " for refs, cand in zip(reference_list, candidate_list):\n", 68 | " cider_d_score = 0.0\n", 69 | " for i in range(1, n + 1):\n", 70 | " cand_ngrams = count_ngrams(cand, i)\n", 71 | " ref_ngrams_list = [count_ngrams(ref, i) for ref in refs]\n", 72 | "\n", 73 | " total_ref_ngrams = [ngram for ref_ngrams in ref_ngrams_list for ngram in ref_ngrams]\n", 74 | "\n", 75 | " count_cand = 0\n", 76 | " count_clip = 0\n", 77 | "\n", 78 | " for ngram in cand_ngrams:\n", 79 | " count_cand += 1\n", 80 | " if ngram in total_ref_ngrams:\n", 81 | " count_clip += 1\n", 82 | "\n", 83 | " precision = count_clip / count_cand if count_cand > 0 else 0.0\n", 84 | " recall = count_clip / len(total_ref_ngrams) if len(total_ref_ngrams) > 0 else 0.0\n", 85 | "\n", 86 | " beta = 1.0\n", 87 | " f_score = (1 + beta**2) * precision * recall / (beta**2 * precision + recall) if precision + recall > 0 else 0.0\n", 88 | "\n", 89 | " cider_d_score += f_score\n", 90 | "\n", 91 | " cider_d_score /= n\n", 92 | " cider_d_scores.append(cider_d_score)\n", 93 | "\n", 94 | " return cider_d_scores\n", 95 | "\n", 96 | " reference_tokens_list = reference_list\n", 97 | " candidate_tokens_list = candidate_list\n", 98 | "\n", 99 | " scores = compute_cider_d(reference_tokens_list, candidate_tokens_list, n)\n", 100 | "\n", 101 | " return np.mean(scores)\n", 102 | "\n", 103 | "def spice(reference_list, candidate_list, idf=None, beta=3):\n", 104 | " def tokenize(sentence):\n", 105 | " return sentence.lower().split()\n", 106 | "\n", 107 | " def count_ngrams(tokens, n):\n", 108 | " ngrams = []\n", 109 | " for i in range(len(tokens) - n + 1):\n", 110 | " ngram = tuple(tokens[i:i+n])\n", 111 | " ngrams.append(ngram)\n", 112 | " return ngrams\n", 113 | "\n", 114 | " def compute_spice_score(reference, candidate, idf, beta):\n", 115 | " reference_tokens = reference\n", 116 | " candidate_tokens = candidate\n", 117 | "\n", 118 | " reference_ngrams = [count_ngrams(reference_tokens, i) for i in range(1, beta + 1)]\n", 119 | " candidate_ngrams = [count_ngrams(candidate_tokens, i) for i in range(1, beta + 1)]\n", 120 | "\n", 121 | " precision_scores = []\n", 122 | " recall_scores = []\n", 123 | "\n", 124 | " for i in range(beta):\n", 125 | " common_ngrams = set(candidate_ngrams[i]) & set(reference_ngrams[i])\n", 126 | "\n", 127 | " precision = len(common_ngrams) / len(candidate_ngrams[i]) if len(candidate_ngrams[i]) > 0 else 0.0\n", 128 | " recall = len(common_ngrams) / len(reference_ngrams[i]) if len(reference_ngrams[i]) > 0 else 0.0\n", 129 | "\n", 130 | " precision_scores.append(precision)\n", 131 | " recall_scores.append(recall)\n", 132 | "\n", 133 | " precision_avg = np.mean(precision_scores)\n", 134 | " recall_avg = np.mean(recall_scores)\n", 135 | "\n", 136 | " spice_score = (precision_avg * recall_avg) / (precision_avg + recall_avg) if precision_avg + recall_avg > 0 else 0.0\n", 137 | "\n", 138 | " if idf:\n", 139 | " spice_score *= np.exp(np.sum([idf[token] for token in common_ngrams]) / len(candidate_tokens))\n", 140 | "\n", 141 | " return spice_score\n", 142 | "\n", 143 | " if idf is None:\n", 144 | " idf = {}\n", 145 | "\n", 146 | " spice_scores = []\n", 147 | "\n", 148 | " for reference, candidate in zip(reference_list, candidate_list):\n", 149 | " spice_score = compute_spice_score(reference, candidate, idf, beta)\n", 150 | " spice_scores.append(spice_score)\n", 151 | "\n", 152 | " return np.mean(spice_scores)\n", 153 | "\n", 154 | "def wvec_to_capls(vocab,wvec):#将向量转换成文本描述\n", 155 | " res=[]\n", 156 | " for word in wvec:\n", 157 | " for key,value in vocab.items():\n", 158 | " if value==word and key not in ['','','','']:\n", 159 | " res.append(key)\n", 160 | " return res\n", 161 | "\n", 162 | "def wvec_to_cap(vocab,wvec):#将向量转换成文本描述\n", 163 | " res=[]\n", 164 | " for word in wvec:\n", 165 | " for key,value in vocab.items():\n", 166 | " if value==word and key not in ['','','','']:\n", 167 | " res.append(key)\n", 168 | " res=\" \".join(res)\n", 169 | " return res\n", 170 | "\n", 171 | "def get_CIDER_D_score(vocab,cands, refs): #获得CIDER-D分数\n", 172 | " refs_ = [wvec_to_capls(vocab,ref) for ref in refs]\n", 173 | " cands_ = [wvec_to_capls(vocab,cand) for cand in cands]\n", 174 | " return cider_d(refs_, cands_)\n", 175 | "\n", 176 | "def get_SPICE_score(vocab,cands, refs): #获得SPICE分数\n", 177 | " refs_ = [wvec_to_cap(vocab,ref) for ref in refs]\n", 178 | " cands_ = [wvec_to_cap(vocab,cand) for cand in cands]\n", 179 | " return spice(refs_, cands_)\n", 180 | "\n", 181 | "class ImageTextDataset(Dataset):\n", 182 | " def __init__(self, dataset_path, vocab_path, split, captions_per_image=6, max_len=93, transform=None):\n", 183 | "\n", 184 | " self.split = split\n", 185 | " assert self.split in {'train', 'test'}\n", 186 | " self.cpi = captions_per_image\n", 187 | " self.max_len = max_len\n", 188 | "\n", 189 | " # 载入数据集\n", 190 | " with open(dataset_path, 'r') as f:\n", 191 | " self.data = json.load(f) #key是图片名字 value是描述\n", 192 | " self.data_img=list(self.data.keys())\n", 193 | " # 载入词典\n", 194 | " with open(vocab_path, 'r') as f:\n", 195 | " self.vocab = json.load(f)\n", 196 | "\n", 197 | " # PyTorch图像预处理流程\n", 198 | " self.transform = transform\n", 199 | "\n", 200 | " # Total number of datapoints\n", 201 | " self.dataset_size = len(self.data_img)\n", 202 | "\n", 203 | " def __getitem__(self, i):\n", 204 | " # 第i个文本描述对应第(i // captions_per_image)张图片\n", 205 | " img = Image.open(img_path+\"/\"+self.data_img[i]).convert('RGB')\n", 206 | " if self.transform is not None:\n", 207 | " img = self.transform(img)\n", 208 | " c_vec=cap_to_wvec(self.vocab,self.data[self.data_img[i]])\n", 209 | " #加入起始和结束标志\n", 210 | " c_vec = [self.vocab['']] + c_vec + [self.vocab['']]\n", 211 | " caplen = len(c_vec)\n", 212 | " caption = torch.LongTensor(c_vec+ [self.vocab['']] * (self.max_len + 2 - caplen))\n", 213 | " \n", 214 | " return img, caption, caplen\n", 215 | " \n", 216 | " def __len__(self):\n", 217 | " return self.dataset_size\n", 218 | " \n" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 4, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "from transformers import ViTModel, BertModel, BertConfig\n", 228 | "from torch import nn\n", 229 | "import torch\n", 230 | "\n", 231 | "class Img2TxtModel(nn.Module):\n", 232 | " def __init__(self, vit_model_name, transformer_config, vocab_size):\n", 233 | " super(Img2TxtModel, self).__init__()\n", 234 | " # ViT模型作为编码器\n", 235 | " self.encoder = ViTModel.from_pretrained(vit_model_name)\n", 236 | "\n", 237 | " # Transformer解码器配置\n", 238 | " transformer_config = BertConfig(vocab_size=vocab_size, num_hidden_layers=1, is_decoder=True, add_cross_attention=True)\n", 239 | " self.decoder = BertModel(transformer_config)\n", 240 | "\n", 241 | " # 预测每个词的线性层\n", 242 | " self.vocab_size = vocab_size\n", 243 | " self.fc = nn.Linear(transformer_config.hidden_size, vocab_size)\n", 244 | " \n", 245 | " def forward(self, input_ids, decoder_input_ids, decoder_attention_mask):\n", 246 | " # 通过ViT编码器获取图像特征\n", 247 | " encoder_outputs = self.encoder(pixel_values=input_ids).last_hidden_state\n", 248 | "\n", 249 | " # 将图像特征作为解码器的输入\n", 250 | " decoder_outputs = self.decoder(input_ids=decoder_input_ids, \n", 251 | " attention_mask=decoder_attention_mask,\n", 252 | " encoder_hidden_states=encoder_outputs).last_hidden_state\n", 253 | "\n", 254 | " # 预测下一个词\n", 255 | " prediction_scores = self.fc(decoder_outputs)\n", 256 | " return prediction_scores\n", 257 | "\n", 258 | " def generate_text(self, input_ids, max_length=95, start_token_id=154):\n", 259 | " # 获取图像特征\n", 260 | " encoder_outputs = self.encoder(pixel_values=input_ids).last_hidden_state\n", 261 | "\n", 262 | " # 初始化解码器输入为标记\n", 263 | " decoder_input_ids = torch.full((input_ids.size(0), 1), start_token_id).to(input_ids.device)\n", 264 | " \n", 265 | " # 存储所有时间步的logits\n", 266 | " all_logits = []\n", 267 | "\n", 268 | " for step in range(max_length):\n", 269 | " # 获取解码器输出\n", 270 | " decoder_outputs = self.decoder(\n", 271 | " input_ids=decoder_input_ids, \n", 272 | " encoder_hidden_states=encoder_outputs\n", 273 | " ).last_hidden_state\n", 274 | "\n", 275 | " # 预测下一个词\n", 276 | " next_word_logits = self.fc(decoder_outputs[:, -1, :])\n", 277 | " all_logits.append(next_word_logits.unsqueeze(1))\n", 278 | " next_word_id = next_word_logits.argmax(dim=-1).unsqueeze(-1)\n", 279 | " \n", 280 | " # 将预测的词添加到解码器输入中\n", 281 | " decoder_input_ids = torch.cat([decoder_input_ids, next_word_id], dim=-1)\n", 282 | " \n", 283 | " return decoder_input_ids ,torch.cat(all_logits, dim=1)\n", 284 | "\n", 285 | "\n", 286 | "\n", 287 | "\n", 288 | "\n" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 3, 294 | "metadata": {}, 295 | "outputs": [ 296 | { 297 | "name": "stdout", 298 | "output_type": "stream", 299 | "text": [ 300 | "Using GPU: NVIDIA GeForce RTX 3070 Laptop GPU\n", 301 | "Epoch [1/10], Batch [100/3385], Loss: 1.506847620010376\n", 302 | "Epoch [1/10], Batch [200/3385], Loss: 1.0654538869857788\n", 303 | "Epoch [1/10], Batch [300/3385], Loss: 0.720127284526825\n", 304 | "Epoch [1/10], Batch [400/3385], Loss: 0.7672491669654846\n", 305 | "Epoch [1/10], Batch [500/3385], Loss: 0.6676333546638489\n", 306 | "Epoch [1/10], Batch [600/3385], Loss: 0.7424653172492981\n", 307 | "Epoch [1/10], Batch [700/3385], Loss: 0.526539146900177\n", 308 | "Epoch [1/10], Batch [800/3385], Loss: 0.5412440896034241\n", 309 | "Epoch [1/10], Batch [900/3385], Loss: 0.5988483428955078\n", 310 | "Epoch [1/10], Batch [900/3385], Loss: 0.5988483428955078, BLEU Score: 0.6259826859669076\n", 311 | "New best model saved to model/best_model_epoch_1_batch_900.pth with BLEU score 0.6259826859669076\n", 312 | "Epoch [1/10], Batch [1000/3385], Loss: 0.6830837726593018\n", 313 | "Epoch [1/10], Batch [1100/3385], Loss: 0.6536044478416443\n", 314 | "Epoch [1/10], Batch [1200/3385], Loss: 0.6345134973526001\n", 315 | "Epoch [1/10], Batch [1300/3385], Loss: 0.6075872182846069\n", 316 | "Epoch [1/10], Batch [1400/3385], Loss: 0.6370975375175476\n", 317 | "Epoch [1/10], Batch [1500/3385], Loss: 0.5964183211326599\n", 318 | "Epoch [1/10], Batch [1600/3385], Loss: 0.633446455001831\n", 319 | "Epoch [1/10], Batch [1700/3385], Loss: 0.4941997230052948\n", 320 | "Epoch [1/10], Batch [1800/3385], Loss: 0.6094977855682373\n", 321 | "Epoch [1/10], Batch [1800/3385], Loss: 0.6094977855682373, BLEU Score: 0.6499761614377776\n", 322 | "New best model saved to model/best_model_epoch_1_batch_1800.pth with BLEU score 0.6499761614377776\n", 323 | "Epoch [1/10], Batch [1900/3385], Loss: 0.5010225772857666\n", 324 | "Epoch [1/10], Batch [2000/3385], Loss: 0.5958297252655029\n", 325 | "Epoch [1/10], Batch [2100/3385], Loss: 0.5294401049613953\n", 326 | "Epoch [1/10], Batch [2200/3385], Loss: 0.4647580087184906\n", 327 | "Epoch [1/10], Batch [2300/3385], Loss: 0.6333135962486267\n", 328 | "Epoch [1/10], Batch [2400/3385], Loss: 0.5327613949775696\n", 329 | "Epoch [1/10], Batch [2500/3385], Loss: 0.5936553478240967\n", 330 | "Epoch [1/10], Batch [2600/3385], Loss: 0.5662683248519897\n", 331 | "Epoch [1/10], Batch [2700/3385], Loss: 0.4893517792224884\n", 332 | "Epoch [1/10], Batch [2700/3385], Loss: 0.4893517792224884, BLEU Score: 0.6273520776805721\n", 333 | "Epoch [1/10], Batch [2800/3385], Loss: 0.5826975107192993\n", 334 | "Epoch [1/10], Batch [2900/3385], Loss: 0.5940765738487244\n", 335 | "Epoch [1/10], Batch [3000/3385], Loss: 0.45800718665122986\n", 336 | "Epoch [1/10], Batch [3100/3385], Loss: 0.4425952732563019\n", 337 | "Epoch [1/10], Batch [3200/3385], Loss: 0.4501250982284546\n", 338 | "Epoch [1/10], Batch [3300/3385], Loss: 0.5157708525657654\n", 339 | "Epoch [2/10], Batch [100/3385], Loss: 0.46476876735687256\n", 340 | "Epoch [2/10], Batch [200/3385], Loss: 0.5956483483314514\n", 341 | "Epoch [2/10], Batch [300/3385], Loss: 0.4944517910480499\n", 342 | "Epoch [2/10], Batch [400/3385], Loss: 0.46747153997421265\n", 343 | "Epoch [2/10], Batch [500/3385], Loss: 0.5078564286231995\n", 344 | "Epoch [2/10], Batch [600/3385], Loss: 0.49568527936935425\n", 345 | "Epoch [2/10], Batch [700/3385], Loss: 0.48543137311935425\n", 346 | "Epoch [2/10], Batch [800/3385], Loss: 0.5878289341926575\n", 347 | "Epoch [2/10], Batch [900/3385], Loss: 0.3999037444591522\n", 348 | "Epoch [2/10], Batch [900/3385], Loss: 0.3999037444591522, BLEU Score: 0.6306047781805396\n", 349 | "Epoch [2/10], Batch [1000/3385], Loss: 0.4190608859062195\n", 350 | "Epoch [2/10], Batch [1100/3385], Loss: 0.6893869638442993\n", 351 | "Epoch [2/10], Batch [1200/3385], Loss: 0.45286086201667786\n", 352 | "Epoch [2/10], Batch [1300/3385], Loss: 0.5955387353897095\n", 353 | "Epoch [2/10], Batch [1400/3385], Loss: 0.48197141289711\n", 354 | "Epoch [2/10], Batch [1500/3385], Loss: 0.5032705068588257\n", 355 | "Epoch [2/10], Batch [1600/3385], Loss: 0.5179394483566284\n", 356 | "Epoch [2/10], Batch [1700/3385], Loss: 0.5289398431777954\n", 357 | "Epoch [2/10], Batch [1800/3385], Loss: 0.48149487376213074\n", 358 | "Epoch [2/10], Batch [1800/3385], Loss: 0.48149487376213074, BLEU Score: 0.5952451904844176\n", 359 | "Epoch [2/10], Batch [1900/3385], Loss: 0.5015336871147156\n", 360 | "Epoch [2/10], Batch [2000/3385], Loss: 0.41118180751800537\n", 361 | "Epoch [2/10], Batch [2100/3385], Loss: 0.6792697906494141\n", 362 | "Epoch [2/10], Batch [2200/3385], Loss: 0.466006875038147\n", 363 | "Epoch [2/10], Batch [2300/3385], Loss: 0.492031991481781\n", 364 | "Epoch [2/10], Batch [2400/3385], Loss: 0.4278179109096527\n", 365 | "Epoch [2/10], Batch [2500/3385], Loss: 0.5117957592010498\n", 366 | "Epoch [2/10], Batch [2600/3385], Loss: 0.46907398104667664\n", 367 | "Epoch [2/10], Batch [2700/3385], Loss: 0.599402129650116\n", 368 | "Epoch [2/10], Batch [2700/3385], Loss: 0.599402129650116, BLEU Score: 0.6457694776801362\n", 369 | "Epoch [2/10], Batch [2800/3385], Loss: 0.5363532304763794\n", 370 | "Epoch [2/10], Batch [2900/3385], Loss: 0.5502106547355652\n", 371 | "Epoch [2/10], Batch [3000/3385], Loss: 0.5129263401031494\n", 372 | "Epoch [2/10], Batch [3100/3385], Loss: 0.4920698404312134\n", 373 | "Epoch [2/10], Batch [3200/3385], Loss: 0.47866472601890564\n", 374 | "Epoch [2/10], Batch [3300/3385], Loss: 0.525836169719696\n", 375 | "Epoch [3/10], Batch [100/3385], Loss: 0.4898045063018799\n", 376 | "Epoch [3/10], Batch [200/3385], Loss: 0.5084397196769714\n", 377 | "Epoch [3/10], Batch [300/3385], Loss: 0.4473249614238739\n", 378 | "Epoch [3/10], Batch [400/3385], Loss: 0.4763574004173279\n", 379 | "Epoch [3/10], Batch [500/3385], Loss: 0.5079389214515686\n", 380 | "Epoch [3/10], Batch [600/3385], Loss: 0.44530239701271057\n", 381 | "Epoch [3/10], Batch [700/3385], Loss: 0.44787898659706116\n", 382 | "Epoch [3/10], Batch [800/3385], Loss: 0.4493464529514313\n", 383 | "Epoch [3/10], Batch [900/3385], Loss: 0.4891371428966522\n", 384 | "Epoch [3/10], Batch [900/3385], Loss: 0.4891371428966522, BLEU Score: 0.6265277873070316\n", 385 | "Epoch [3/10], Batch [1000/3385], Loss: 0.44602829217910767\n", 386 | "Epoch [3/10], Batch [1100/3385], Loss: 0.5403011441230774\n", 387 | "Epoch [3/10], Batch [1200/3385], Loss: 0.4890633821487427\n", 388 | "Epoch [3/10], Batch [1300/3385], Loss: 0.4871854782104492\n", 389 | "Epoch [3/10], Batch [1400/3385], Loss: 0.4959268867969513\n", 390 | "Epoch [3/10], Batch [1500/3385], Loss: 0.4298979938030243\n", 391 | "Epoch [3/10], Batch [1600/3385], Loss: 0.5024811625480652\n", 392 | "Epoch [3/10], Batch [1700/3385], Loss: 0.5016076564788818\n", 393 | "Epoch [3/10], Batch [1800/3385], Loss: 0.4154992401599884\n", 394 | "Epoch [3/10], Batch [1800/3385], Loss: 0.4154992401599884, BLEU Score: 0.5135787960939046\n", 395 | "Epoch [3/10], Batch [1900/3385], Loss: 0.4994415044784546\n", 396 | "Epoch [3/10], Batch [2000/3385], Loss: 0.5580220818519592\n", 397 | "Epoch [3/10], Batch [2100/3385], Loss: 0.4753514528274536\n", 398 | "Epoch [3/10], Batch [2200/3385], Loss: 0.5234423875808716\n", 399 | "Epoch [3/10], Batch [2300/3385], Loss: 0.5374242663383484\n", 400 | "Epoch [3/10], Batch [2400/3385], Loss: 0.46420326828956604\n", 401 | "Epoch [3/10], Batch [2500/3385], Loss: 0.7553501129150391\n", 402 | "Epoch [3/10], Batch [2600/3385], Loss: 0.4739769697189331\n", 403 | "Epoch [3/10], Batch [2700/3385], Loss: 0.45902305841445923\n", 404 | "Epoch [3/10], Batch [2700/3385], Loss: 0.45902305841445923, BLEU Score: 0.5779657123842289\n", 405 | "Epoch [3/10], Batch [2800/3385], Loss: 0.5141093134880066\n", 406 | "Epoch [3/10], Batch [2900/3385], Loss: 0.5070458054542542\n", 407 | "Epoch [3/10], Batch [3000/3385], Loss: 0.4712170958518982\n", 408 | "Epoch [3/10], Batch [3100/3385], Loss: 0.42678409814834595\n", 409 | "Epoch [3/10], Batch [3200/3385], Loss: 0.43770480155944824\n", 410 | "Epoch [3/10], Batch [3300/3385], Loss: 0.4264000356197357\n", 411 | "Epoch [4/10], Batch [100/3385], Loss: 0.46770212054252625\n", 412 | "Epoch [4/10], Batch [200/3385], Loss: 0.43026411533355713\n", 413 | "Epoch [4/10], Batch [300/3385], Loss: 0.46422722935676575\n", 414 | "Epoch [4/10], Batch [400/3385], Loss: 0.4911220073699951\n", 415 | "Epoch [4/10], Batch [500/3385], Loss: 0.4276711344718933\n", 416 | "Epoch [4/10], Batch [600/3385], Loss: 0.6344284415245056\n", 417 | "Epoch [4/10], Batch [700/3385], Loss: 0.48250553011894226\n", 418 | "Epoch [4/10], Batch [800/3385], Loss: 0.4964185953140259\n", 419 | "Epoch [4/10], Batch [900/3385], Loss: 0.5124995112419128\n", 420 | "Epoch [4/10], Batch [900/3385], Loss: 0.5124995112419128, BLEU Score: 0.5957952320057779\n", 421 | "Epoch [4/10], Batch [1000/3385], Loss: 0.4611283242702484\n", 422 | "Epoch [4/10], Batch [1100/3385], Loss: 0.3852692246437073\n", 423 | "Epoch [4/10], Batch [1200/3385], Loss: 0.4541368782520294\n", 424 | "Epoch [4/10], Batch [1300/3385], Loss: 0.46968477964401245\n", 425 | "Epoch [4/10], Batch [1400/3385], Loss: 0.4993707835674286\n", 426 | "Epoch [4/10], Batch [1500/3385], Loss: 0.43745920062065125\n", 427 | "Epoch [4/10], Batch [1600/3385], Loss: 0.5956730246543884\n", 428 | "Epoch [4/10], Batch [1700/3385], Loss: 0.4268757104873657\n", 429 | "Epoch [4/10], Batch [1800/3385], Loss: 0.39837056398391724\n", 430 | "Epoch [4/10], Batch [1800/3385], Loss: 0.39837056398391724, BLEU Score: 0.5280371000845845\n", 431 | "Epoch [4/10], Batch [1900/3385], Loss: 0.5282601118087769\n", 432 | "Epoch [4/10], Batch [2000/3385], Loss: 0.4508409798145294\n", 433 | "Epoch [4/10], Batch [2100/3385], Loss: 0.44163599610328674\n", 434 | "Epoch [4/10], Batch [2200/3385], Loss: 0.4085763990879059\n", 435 | "Epoch [4/10], Batch [2300/3385], Loss: 0.42351624369621277\n", 436 | "Epoch [4/10], Batch [2400/3385], Loss: 0.4682406485080719\n", 437 | "Epoch [4/10], Batch [2500/3385], Loss: 0.39348071813583374\n", 438 | "Epoch [4/10], Batch [2600/3385], Loss: 0.45065101981163025\n", 439 | "Epoch [4/10], Batch [2700/3385], Loss: 0.4543023407459259\n", 440 | "Epoch [4/10], Batch [2700/3385], Loss: 0.4543023407459259, BLEU Score: 0.639408653717198\n", 441 | "Epoch [4/10], Batch [2800/3385], Loss: 0.4714568853378296\n", 442 | "Epoch [4/10], Batch [2900/3385], Loss: 0.42544957995414734\n", 443 | "Epoch [4/10], Batch [3000/3385], Loss: 0.40744078159332275\n", 444 | "Epoch [4/10], Batch [3100/3385], Loss: 0.5178104639053345\n", 445 | "Epoch [4/10], Batch [3200/3385], Loss: 0.4387783408164978\n", 446 | "Epoch [4/10], Batch [3300/3385], Loss: 0.43427574634552\n", 447 | "Epoch [5/10], Batch [100/3385], Loss: 0.4102031886577606\n", 448 | "Epoch [5/10], Batch [200/3385], Loss: 0.5433588624000549\n", 449 | "Epoch [5/10], Batch [300/3385], Loss: 0.6278513669967651\n", 450 | "Epoch [5/10], Batch [400/3385], Loss: 0.4025243818759918\n", 451 | "Epoch [5/10], Batch [500/3385], Loss: 0.44197726249694824\n", 452 | "Epoch [5/10], Batch [600/3385], Loss: 0.43441057205200195\n", 453 | "Epoch [5/10], Batch [700/3385], Loss: 0.4190601706504822\n", 454 | "Epoch [5/10], Batch [800/3385], Loss: 0.4117361605167389\n", 455 | "Epoch [5/10], Batch [900/3385], Loss: 0.4667566418647766\n", 456 | "Epoch [5/10], Batch [900/3385], Loss: 0.4667566418647766, BLEU Score: 0.657183633386523\n", 457 | "New best model saved to model/best_model_epoch_5_batch_900.pth with BLEU score 0.657183633386523\n", 458 | "Epoch [5/10], Batch [1000/3385], Loss: 0.451249897480011\n", 459 | "Epoch [5/10], Batch [1100/3385], Loss: 0.49214866757392883\n", 460 | "Epoch [5/10], Batch [1200/3385], Loss: 0.43253254890441895\n", 461 | "Epoch [5/10], Batch [1300/3385], Loss: 0.5628259181976318\n", 462 | "Epoch [5/10], Batch [1400/3385], Loss: 0.37490856647491455\n", 463 | "Epoch [5/10], Batch [1500/3385], Loss: 0.499080628156662\n", 464 | "Epoch [5/10], Batch [1600/3385], Loss: 0.4075632095336914\n", 465 | "Epoch [5/10], Batch [1700/3385], Loss: 0.45617181062698364\n", 466 | "Epoch [5/10], Batch [1800/3385], Loss: 0.4551757872104645\n", 467 | "Epoch [5/10], Batch [1800/3385], Loss: 0.4551757872104645, BLEU Score: 0.6360896363451195\n", 468 | "Epoch [5/10], Batch [1900/3385], Loss: 0.4530367851257324\n", 469 | "Epoch [5/10], Batch [2000/3385], Loss: 0.4712294340133667\n", 470 | "Epoch [5/10], Batch [2100/3385], Loss: 0.5259047746658325\n", 471 | "Epoch [5/10], Batch [2200/3385], Loss: 0.551623523235321\n", 472 | "Epoch [5/10], Batch [2300/3385], Loss: 0.4146064519882202\n", 473 | "Epoch [5/10], Batch [2400/3385], Loss: 0.5055832862854004\n", 474 | "Epoch [5/10], Batch [2500/3385], Loss: 0.48640984296798706\n", 475 | "Epoch [5/10], Batch [2600/3385], Loss: 0.43123242259025574\n", 476 | "Epoch [5/10], Batch [2700/3385], Loss: 0.4834398329257965\n", 477 | "Epoch [5/10], Batch [2700/3385], Loss: 0.4834398329257965, BLEU Score: 0.6066966294284946\n", 478 | "Epoch [5/10], Batch [2800/3385], Loss: 0.37496018409729004\n", 479 | "Epoch [5/10], Batch [2900/3385], Loss: 0.41133013367652893\n", 480 | "Epoch [5/10], Batch [3000/3385], Loss: 0.4146104156970978\n", 481 | "Epoch [5/10], Batch [3100/3385], Loss: 0.46421417593955994\n", 482 | "Epoch [5/10], Batch [3200/3385], Loss: 0.3927662670612335\n", 483 | "Epoch [5/10], Batch [3300/3385], Loss: 0.4173738360404968\n", 484 | "Epoch [6/10], Batch [100/3385], Loss: 0.45240864157676697\n", 485 | "Epoch [6/10], Batch [200/3385], Loss: 0.4844304323196411\n", 486 | "Epoch [6/10], Batch [300/3385], Loss: 0.5530915260314941\n", 487 | "Epoch [6/10], Batch [400/3385], Loss: 0.45176994800567627\n", 488 | "Epoch [6/10], Batch [500/3385], Loss: 0.4052988886833191\n", 489 | "Epoch [6/10], Batch [600/3385], Loss: 0.42596396803855896\n", 490 | "Epoch [6/10], Batch [700/3385], Loss: 0.4326914846897125\n", 491 | "Epoch [6/10], Batch [800/3385], Loss: 0.4937518835067749\n", 492 | "Epoch [6/10], Batch [900/3385], Loss: 0.4936274290084839\n", 493 | "Epoch [6/10], Batch [900/3385], Loss: 0.4936274290084839, BLEU Score: 0.6099578034593809\n", 494 | "Epoch [6/10], Batch [1000/3385], Loss: 0.4961879849433899\n", 495 | "Epoch [6/10], Batch [1100/3385], Loss: 0.41622450947761536\n", 496 | "Epoch [6/10], Batch [1200/3385], Loss: 0.39469578862190247\n", 497 | "Epoch [6/10], Batch [1300/3385], Loss: 0.4247240126132965\n", 498 | "Epoch [6/10], Batch [1400/3385], Loss: 0.5799625515937805\n", 499 | "Epoch [6/10], Batch [1500/3385], Loss: 0.4860732853412628\n", 500 | "Epoch [6/10], Batch [1600/3385], Loss: 0.4164365828037262\n", 501 | "Epoch [6/10], Batch [1700/3385], Loss: 0.48073866963386536\n", 502 | "Epoch [6/10], Batch [1800/3385], Loss: 0.3841512203216553\n", 503 | "Epoch [6/10], Batch [1800/3385], Loss: 0.3841512203216553, BLEU Score: 0.5947547813004236\n", 504 | "Epoch [6/10], Batch [1900/3385], Loss: 0.4323679506778717\n", 505 | "Epoch [6/10], Batch [2000/3385], Loss: 0.49509772658348083\n", 506 | "Epoch [6/10], Batch [2100/3385], Loss: 0.49317845702171326\n", 507 | "Epoch [6/10], Batch [2200/3385], Loss: 0.4318307340145111\n", 508 | "Epoch [6/10], Batch [2300/3385], Loss: 0.46929118037223816\n", 509 | "Epoch [6/10], Batch [2400/3385], Loss: 0.386508971452713\n", 510 | "Epoch [6/10], Batch [2500/3385], Loss: 0.3676878809928894\n", 511 | "Epoch [6/10], Batch [2600/3385], Loss: 0.39446550607681274\n", 512 | "Epoch [6/10], Batch [2700/3385], Loss: 0.525370717048645\n", 513 | "Epoch [6/10], Batch [2700/3385], Loss: 0.525370717048645, BLEU Score: 0.6140873664042887\n", 514 | "Epoch [6/10], Batch [2800/3385], Loss: 0.3571808338165283\n", 515 | "Epoch [6/10], Batch [2900/3385], Loss: 0.4875484108924866\n", 516 | "Epoch [6/10], Batch [3000/3385], Loss: 0.4953784942626953\n", 517 | "Epoch [6/10], Batch [3100/3385], Loss: 0.4420313239097595\n", 518 | "Epoch [6/10], Batch [3200/3385], Loss: 0.39990687370300293\n", 519 | "Epoch [6/10], Batch [3300/3385], Loss: 0.4982645809650421\n", 520 | "Epoch [7/10], Batch [100/3385], Loss: 0.4385923147201538\n", 521 | "Epoch [7/10], Batch [200/3385], Loss: 0.43997159600257874\n", 522 | "Epoch [7/10], Batch [300/3385], Loss: 0.5116627216339111\n", 523 | "Epoch [7/10], Batch [400/3385], Loss: 0.43550121784210205\n", 524 | "Epoch [7/10], Batch [500/3385], Loss: 0.4919750988483429\n", 525 | "Epoch [7/10], Batch [600/3385], Loss: 0.4595571756362915\n", 526 | "Epoch [7/10], Batch [700/3385], Loss: 0.46697908639907837\n", 527 | "Epoch [7/10], Batch [800/3385], Loss: 0.4819534718990326\n", 528 | "Epoch [7/10], Batch [900/3385], Loss: 0.3747020959854126\n", 529 | "Epoch [7/10], Batch [900/3385], Loss: 0.3747020959854126, BLEU Score: 0.6191507813486982\n", 530 | "Epoch [7/10], Batch [1000/3385], Loss: 0.4624747037887573\n", 531 | "Epoch [7/10], Batch [1100/3385], Loss: 0.44760662317276\n", 532 | "Epoch [7/10], Batch [1200/3385], Loss: 0.5016758441925049\n", 533 | "Epoch [7/10], Batch [1300/3385], Loss: 0.5357837677001953\n", 534 | "Epoch [7/10], Batch [1400/3385], Loss: 0.47158148884773254\n", 535 | "Epoch [7/10], Batch [1500/3385], Loss: 0.40382376313209534\n", 536 | "Epoch [7/10], Batch [1600/3385], Loss: 0.4869937300682068\n", 537 | "Epoch [7/10], Batch [1700/3385], Loss: 0.47104746103286743\n", 538 | "Epoch [7/10], Batch [1800/3385], Loss: 0.4044100344181061\n", 539 | "Epoch [7/10], Batch [1800/3385], Loss: 0.4044100344181061, BLEU Score: 0.6562587813189472\n", 540 | "Epoch [7/10], Batch [1900/3385], Loss: 0.5432260036468506\n", 541 | "Epoch [7/10], Batch [2000/3385], Loss: 0.4081021249294281\n", 542 | "Epoch [7/10], Batch [2100/3385], Loss: 0.39168065786361694\n", 543 | "Epoch [7/10], Batch [2200/3385], Loss: 0.45291900634765625\n", 544 | "Epoch [7/10], Batch [2300/3385], Loss: 0.5602723956108093\n", 545 | "Epoch [7/10], Batch [2400/3385], Loss: 0.4504901170730591\n", 546 | "Epoch [7/10], Batch [2500/3385], Loss: 0.42020854353904724\n", 547 | "Epoch [7/10], Batch [2600/3385], Loss: 0.4315740168094635\n", 548 | "Epoch [7/10], Batch [2700/3385], Loss: 0.42608895897865295\n", 549 | "Epoch [7/10], Batch [2700/3385], Loss: 0.42608895897865295, BLEU Score: 0.6368497371633002\n", 550 | "Epoch [7/10], Batch [2800/3385], Loss: 0.44904786348342896\n", 551 | "Epoch [7/10], Batch [2900/3385], Loss: 0.4278545379638672\n", 552 | "Epoch [7/10], Batch [3000/3385], Loss: 0.46741461753845215\n", 553 | "Epoch [7/10], Batch [3100/3385], Loss: 0.5061702132225037\n", 554 | "Epoch [7/10], Batch [3200/3385], Loss: 0.4447651207447052\n", 555 | "Epoch [7/10], Batch [3300/3385], Loss: 0.40045034885406494\n", 556 | "Epoch [8/10], Batch [100/3385], Loss: 0.43047285079956055\n", 557 | "Epoch [8/10], Batch [200/3385], Loss: 0.35144561529159546\n", 558 | "Epoch [8/10], Batch [300/3385], Loss: 0.4458153247833252\n", 559 | "Epoch [8/10], Batch [400/3385], Loss: 0.4366624653339386\n", 560 | "Epoch [8/10], Batch [500/3385], Loss: 0.5055469870567322\n", 561 | "Epoch [8/10], Batch [600/3385], Loss: 0.4011968672275543\n", 562 | "Epoch [8/10], Batch [700/3385], Loss: 0.48366671800613403\n", 563 | "Epoch [8/10], Batch [800/3385], Loss: 0.5047667622566223\n", 564 | "Epoch [8/10], Batch [900/3385], Loss: 0.482930451631546\n", 565 | "Epoch [8/10], Batch [900/3385], Loss: 0.482930451631546, BLEU Score: 0.6122230712176184\n", 566 | "Epoch [8/10], Batch [1000/3385], Loss: 0.4426495134830475\n", 567 | "Epoch [8/10], Batch [1100/3385], Loss: 0.5322694182395935\n", 568 | "Epoch [8/10], Batch [1200/3385], Loss: 0.46033889055252075\n", 569 | "Epoch [8/10], Batch [1300/3385], Loss: 0.47482892870903015\n", 570 | "Epoch [8/10], Batch [1400/3385], Loss: 0.4434620141983032\n", 571 | "Epoch [8/10], Batch [1500/3385], Loss: 0.4904381334781647\n", 572 | "Epoch [8/10], Batch [1600/3385], Loss: 0.3616958558559418\n", 573 | "Epoch [8/10], Batch [1700/3385], Loss: 0.37503698468208313\n", 574 | "Epoch [8/10], Batch [1800/3385], Loss: 0.4366777539253235\n", 575 | "Epoch [8/10], Batch [1800/3385], Loss: 0.4366777539253235, BLEU Score: 0.6585437255741698\n", 576 | "New best model saved to model/best_model_epoch_8_batch_1800.pth with BLEU score 0.6585437255741698\n", 577 | "Epoch [8/10], Batch [1900/3385], Loss: 0.3897594213485718\n", 578 | "Epoch [8/10], Batch [2000/3385], Loss: 0.40130987763404846\n", 579 | "Epoch [8/10], Batch [2100/3385], Loss: 0.37560680508613586\n", 580 | "Epoch [8/10], Batch [2200/3385], Loss: 0.444630891084671\n", 581 | "Epoch [8/10], Batch [2300/3385], Loss: 0.4399137496948242\n", 582 | "Epoch [8/10], Batch [2400/3385], Loss: 0.5606774687767029\n", 583 | "Epoch [8/10], Batch [2500/3385], Loss: 0.4447225332260132\n", 584 | "Epoch [8/10], Batch [2600/3385], Loss: 0.36013519763946533\n", 585 | "Epoch [8/10], Batch [2700/3385], Loss: 0.5258728861808777\n", 586 | "Epoch [8/10], Batch [2700/3385], Loss: 0.5258728861808777, BLEU Score: 0.6099369823315987\n", 587 | "Epoch [8/10], Batch [2800/3385], Loss: 0.5524425506591797\n", 588 | "Epoch [8/10], Batch [2900/3385], Loss: 0.49892255663871765\n", 589 | "Epoch [8/10], Batch [3000/3385], Loss: 0.43546628952026367\n", 590 | "Epoch [8/10], Batch [3100/3385], Loss: 0.4473860561847687\n", 591 | "Epoch [8/10], Batch [3200/3385], Loss: 0.45614057779312134\n", 592 | "Epoch [8/10], Batch [3300/3385], Loss: 0.4784499704837799\n", 593 | "Epoch [9/10], Batch [100/3385], Loss: 0.4882113039493561\n", 594 | "Epoch [9/10], Batch [200/3385], Loss: 0.39622798562049866\n", 595 | "Epoch [9/10], Batch [300/3385], Loss: 0.43471258878707886\n", 596 | "Epoch [9/10], Batch [400/3385], Loss: 0.5436696410179138\n", 597 | "Epoch [9/10], Batch [500/3385], Loss: 0.4261394441127777\n", 598 | "Epoch [9/10], Batch [600/3385], Loss: 0.46043333411216736\n", 599 | "Epoch [9/10], Batch [700/3385], Loss: 0.5166803002357483\n", 600 | "Epoch [9/10], Batch [800/3385], Loss: 0.5033921003341675\n", 601 | "Epoch [9/10], Batch [900/3385], Loss: 0.43346792459487915\n", 602 | "Epoch [9/10], Batch [900/3385], Loss: 0.43346792459487915, BLEU Score: 0.6411087982755627\n", 603 | "Epoch [9/10], Batch [1000/3385], Loss: 0.4604831337928772\n", 604 | "Epoch [9/10], Batch [1100/3385], Loss: 0.48319172859191895\n", 605 | "Epoch [9/10], Batch [1200/3385], Loss: 0.40194717049598694\n", 606 | "Epoch [9/10], Batch [1300/3385], Loss: 0.5072685480117798\n", 607 | "Epoch [9/10], Batch [1400/3385], Loss: 0.4238347113132477\n", 608 | "Epoch [9/10], Batch [1500/3385], Loss: 0.6097126007080078\n", 609 | "Epoch [9/10], Batch [1600/3385], Loss: 0.4651317596435547\n", 610 | "Epoch [9/10], Batch [1700/3385], Loss: 0.45657482743263245\n", 611 | "Epoch [9/10], Batch [1800/3385], Loss: 0.3479486107826233\n", 612 | "Epoch [9/10], Batch [1800/3385], Loss: 0.3479486107826233, BLEU Score: 0.6366775191183544\n", 613 | "Epoch [9/10], Batch [1900/3385], Loss: 0.36093080043792725\n", 614 | "Epoch [9/10], Batch [2000/3385], Loss: 0.371528297662735\n", 615 | "Epoch [9/10], Batch [2100/3385], Loss: 0.5159643888473511\n", 616 | "Epoch [9/10], Batch [2200/3385], Loss: 0.38837355375289917\n", 617 | "Epoch [9/10], Batch [2300/3385], Loss: 0.5057320594787598\n", 618 | "Epoch [9/10], Batch [2400/3385], Loss: 0.46596574783325195\n", 619 | "Epoch [9/10], Batch [2500/3385], Loss: 0.450432151556015\n", 620 | "Epoch [9/10], Batch [2600/3385], Loss: 0.40980589389801025\n", 621 | "Epoch [9/10], Batch [2700/3385], Loss: 0.4379664361476898\n", 622 | "Epoch [9/10], Batch [2700/3385], Loss: 0.4379664361476898, BLEU Score: 0.6150770088422075\n", 623 | "Epoch [9/10], Batch [2800/3385], Loss: 0.4877370297908783\n", 624 | "Epoch [9/10], Batch [2900/3385], Loss: 0.40075528621673584\n", 625 | "Epoch [9/10], Batch [3000/3385], Loss: 0.3522195816040039\n", 626 | "Epoch [9/10], Batch [3100/3385], Loss: 0.3765401244163513\n", 627 | "Epoch [9/10], Batch [3200/3385], Loss: 0.38588494062423706\n", 628 | "Epoch [9/10], Batch [3300/3385], Loss: 0.4325076937675476\n", 629 | "Epoch [10/10], Batch [100/3385], Loss: 0.4888536036014557\n", 630 | "Epoch [10/10], Batch [200/3385], Loss: 0.3897317051887512\n", 631 | "Epoch [10/10], Batch [300/3385], Loss: 0.35797613859176636\n", 632 | "Epoch [10/10], Batch [400/3385], Loss: 0.39244332909584045\n", 633 | "Epoch [10/10], Batch [500/3385], Loss: 0.42655041813850403\n", 634 | "Epoch [10/10], Batch [600/3385], Loss: 0.4513237178325653\n", 635 | "Epoch [10/10], Batch [700/3385], Loss: 0.4259300231933594\n", 636 | "Epoch [10/10], Batch [800/3385], Loss: 0.3846912384033203\n", 637 | "Epoch [10/10], Batch [900/3385], Loss: 0.407758891582489\n", 638 | "Epoch [10/10], Batch [900/3385], Loss: 0.407758891582489, BLEU Score: 0.6306034361402596\n", 639 | "Epoch [10/10], Batch [1000/3385], Loss: 0.5385867357254028\n", 640 | "Epoch [10/10], Batch [1100/3385], Loss: 0.46417561173439026\n", 641 | "Epoch [10/10], Batch [1200/3385], Loss: 0.4319520592689514\n", 642 | "Epoch [10/10], Batch [1300/3385], Loss: 0.36374565958976746\n", 643 | "Epoch [10/10], Batch [1400/3385], Loss: 0.3605068624019623\n", 644 | "Epoch [10/10], Batch [1500/3385], Loss: 0.4097943603992462\n", 645 | "Epoch [10/10], Batch [1600/3385], Loss: 0.42661750316619873\n", 646 | "Epoch [10/10], Batch [1700/3385], Loss: 0.4788033664226532\n", 647 | "Epoch [10/10], Batch [1800/3385], Loss: 0.3844136893749237\n", 648 | "Epoch [10/10], Batch [1800/3385], Loss: 0.3844136893749237, BLEU Score: 0.6339382113448834\n", 649 | "Epoch [10/10], Batch [1900/3385], Loss: 0.46550217270851135\n", 650 | "Epoch [10/10], Batch [2000/3385], Loss: 0.33330783247947693\n", 651 | "Epoch [10/10], Batch [2100/3385], Loss: 0.4416351318359375\n", 652 | "Epoch [10/10], Batch [2200/3385], Loss: 0.5118553042411804\n", 653 | "Epoch [10/10], Batch [2300/3385], Loss: 0.3825181722640991\n", 654 | "Epoch [10/10], Batch [2400/3385], Loss: 0.3916151523590088\n", 655 | "Epoch [10/10], Batch [2500/3385], Loss: 0.42515692114830017\n", 656 | "Epoch [10/10], Batch [2600/3385], Loss: 0.43171802163124084\n", 657 | "Epoch [10/10], Batch [2700/3385], Loss: 0.4633576571941376\n", 658 | "Epoch [10/10], Batch [2700/3385], Loss: 0.4633576571941376, BLEU Score: 0.6640697619027633\n", 659 | "New best model saved to model/best_model_epoch_10_batch_2700.pth with BLEU score 0.6640697619027633\n", 660 | "Epoch [10/10], Batch [2800/3385], Loss: 0.5012379288673401\n", 661 | "Epoch [10/10], Batch [2900/3385], Loss: 0.4428495466709137\n", 662 | "Epoch [10/10], Batch [3000/3385], Loss: 0.4523705840110779\n", 663 | "Epoch [10/10], Batch [3100/3385], Loss: 0.35377541184425354\n", 664 | "Epoch [10/10], Batch [3200/3385], Loss: 0.44981083273887634\n", 665 | "Epoch [10/10], Batch [3300/3385], Loss: 0.37558305263519287\n" 666 | ] 667 | } 668 | ], 669 | "source": [ 670 | "import torch\n", 671 | "from torch import nn\n", 672 | "from torch.optim import Adam\n", 673 | "from torch.utils.data import DataLoader\n", 674 | "from torchvision import transforms\n", 675 | "\n", 676 | "if torch.cuda.is_available():\n", 677 | " device = torch.device(\"cuda\")\n", 678 | " print(\"Using GPU:\", torch.cuda.get_device_name(0))\n", 679 | "else:\n", 680 | " device = torch.device(\"cpu\")\n", 681 | " print(\"Using CPU\")\n", 682 | "\n", 683 | "def evaluate(test_loader, model):\n", 684 | " model.eval() # 将模型设置为评估模式\n", 685 | " generated_captions = []\n", 686 | " actual_captions = []\n", 687 | "\n", 688 | " with torch.no_grad():\n", 689 | " for images, captions, caplens in test_loader:\n", 690 | " images = images.to(device)\n", 691 | " input_ids = images\n", 692 | " outputs,_ = model.generate_text(input_ids, max_length=95, start_token_id=test_dataset.vocab[''])\n", 693 | " for i in range(outputs.shape[0]):\n", 694 | " # 生成字幕\n", 695 | " gen_caption = [idx_to_word(idx, test_dataset.vocab) for idx in outputs[i]]\n", 696 | " # print(gen_caption)\n", 697 | " # 移除 \n", 698 | " if '' in gen_caption:\n", 699 | " gen_caption = gen_caption[1:] # 移除第一个元素 ()\n", 700 | " if '' in gen_caption:\n", 701 | " gen_caption = gen_caption[:gen_caption.index('')] # 移除 及其后面的元素\n", 702 | " \n", 703 | " generated_captions.append(' '.join(gen_caption))\n", 704 | "\n", 705 | " # 真实字幕\n", 706 | " act_caption = [idx_to_word(idx, test_dataset.vocab) for idx in captions[i]]\n", 707 | " # print(act_caption)\n", 708 | " # 移除 \n", 709 | " if '' in act_caption:\n", 710 | " act_caption = act_caption[1:] # 移除第一个元素 ()\n", 711 | " if '' in act_caption:\n", 712 | " act_caption = act_caption[:act_caption.index('')] # 移除 及其后面的元素\n", 713 | " \n", 714 | " actual_captions.append([' '.join(act_caption)])\n", 715 | "\n", 716 | " # 计算BLEU分数\n", 717 | " bleu4 = corpus_bleu(actual_captions, generated_captions, weights=(0.25,0.25,0.25,0.25))\n", 718 | " model.train()\n", 719 | " return bleu4\n", 720 | "\n", 721 | "transform = transforms.Compose([\n", 722 | " transforms.Resize((224, 224)), \n", 723 | " transforms.ToTensor(),\n", 724 | " # 这里可以添加其他必要的转换,如归一化等\n", 725 | "])\n", 726 | "\n", 727 | "dataset = ImageTextDataset(train_json_path, vocab_path, split='train', transform=transform)\n", 728 | "data_loader = DataLoader(dataset, batch_size=3, shuffle=True)\n", 729 | "\n", 730 | "vocab_size = len(dataset.vocab)\n", 731 | "vit_model_name = 'google/vit-base-patch16-224-in21k'\n", 732 | "transformer_config = BertConfig()\n", 733 | "\n", 734 | "# 初始化模型\n", 735 | "model = Img2TxtModel(vit_model_name, transformer_config, vocab_size)\n", 736 | "model = model.to(device)\n", 737 | "optimizer = Adam(model.parameters(), lr=0.0001)\n", 738 | "criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab[''])\n", 739 | "\n", 740 | "test_dataset = ImageTextDataset(test_json_path, vocab_path, split='test', transform=transform)\n", 741 | "test_loader = DataLoader(test_dataset, batch_size=3, shuffle=True)\n", 742 | "# 设定训练周期\n", 743 | "num_epochs = 10\n", 744 | "best_bleu_score = 0.0 # 初始化最高BLEU分数\n", 745 | "\n", 746 | "for epoch in range(num_epochs):\n", 747 | " for i, (images, captions, caplens) in enumerate(data_loader):\n", 748 | " # 假设您的ViT模型接受标准化的图像张量作为输入\n", 749 | " images = images.to(device)\n", 750 | " captions = captions.to(device)\n", 751 | " input_ids = images\n", 752 | "\n", 753 | " # 准备解码器输入\n", 754 | " decoder_input_ids = captions[:, :-1] # 删除每个字幕的最后一个单词\n", 755 | " decoder_attention_mask = (decoder_input_ids != dataset.vocab['']).type(torch.uint8)\n", 756 | " \n", 757 | " # 前向传播\n", 758 | " outputs = model(input_ids, decoder_input_ids, decoder_attention_mask)\n", 759 | "\n", 760 | " # 计算损失,outputs需要调整以适配损失函数的要求\n", 761 | " loss = criterion(outputs.view(-1, outputs.size(-1)), captions[:, 1:].contiguous().view(-1))\n", 762 | "\n", 763 | " # 反向传播和优化\n", 764 | " optimizer.zero_grad()\n", 765 | " loss.backward()\n", 766 | " optimizer.step()\n", 767 | "\n", 768 | " if (i+1) % 100 == 0:\n", 769 | " print(f\"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(data_loader)}], Loss: {loss.item()}\")\n", 770 | " \n", 771 | " if (i + 1) % 900 == 0:\n", 772 | " bleu4 = evaluate(test_loader, model)\n", 773 | " print(f\"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(data_loader)}], Loss: {loss.item()}, BLEU Score: {bleu4}\")\n", 774 | "\n", 775 | " # 如果BLEU分数是新的最高分,则保存模型\n", 776 | " if bleu4 > best_bleu_score:\n", 777 | " best_bleu_score = bleu4\n", 778 | " save_path = f\"model/best_model_epoch_{epoch+1}_batch_{i+1}.pth\"\n", 779 | " torch.save({\n", 780 | " 'epoch': epoch,\n", 781 | " 'batch': i,\n", 782 | " 'model_state_dict': model.state_dict(),\n", 783 | " 'optimizer_state_dict': optimizer.state_dict(),\n", 784 | " 'loss': loss,\n", 785 | " 'bleu_score': bleu4,\n", 786 | " }, save_path)\n", 787 | " print(f\"New best model saved to {save_path} with BLEU score {bleu4}\")\n", 788 | "\n", 789 | "\n", 790 | "\n", 791 | "\n" 792 | ] 793 | }, 794 | { 795 | "cell_type": "code", 796 | "execution_count": 5, 797 | "metadata": {}, 798 | "outputs": [ 799 | { 800 | "name": "stdout", 801 | "output_type": "stream", 802 | "text": [ 803 | "Using GPU: NVIDIA GeForce RTX 3070 Laptop GPU\n", 804 | "bleu4 :0.3074786561430062\n", 805 | "cider_d_score :0.004119164250591897\n", 806 | "相对cider_d_score :0.9379024924812082\n", 807 | "spice_score :0.1321016861247424\n", 808 | "相对spice_score :0.706289522613148\n" 809 | ] 810 | } 811 | ], 812 | "source": [ 813 | "import torch\n", 814 | "from torch.utils.data import DataLoader\n", 815 | "from torchvision import transforms\n", 816 | "\n", 817 | "\n", 818 | "if torch.cuda.is_available():\n", 819 | " device = torch.device(\"cuda\")\n", 820 | " print(\"Using GPU:\", torch.cuda.get_device_name(0))\n", 821 | "else:\n", 822 | " device = torch.device(\"cpu\")\n", 823 | " print(\"Using CPU\")\n", 824 | "\n", 825 | "transform = transforms.Compose([\n", 826 | " transforms.Resize((224, 224)), \n", 827 | " transforms.ToTensor(),\n", 828 | " # 这里可以添加其他必要的转换,如归一化等\n", 829 | "])\n", 830 | "test_dataset = ImageTextDataset(test_json_path, vocab_path, split='test', transform=transform)\n", 831 | "test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)\n", 832 | "vocab_size = len(test_dataset.vocab)\n", 833 | "vit_model_name = 'google/vit-base-patch16-224-in21k'\n", 834 | "transformer_config = BertConfig()\n", 835 | "model = Img2TxtModel(vit_model_name, transformer_config, vocab_size)\n", 836 | "# 加载模型状态字典\n", 837 | "checkpoint = torch.load('./model/best_model_epoch_10_batch_2700.pth')\n", 838 | "\n", 839 | "# 将状态字典应用到模型实例中\n", 840 | "model.load_state_dict(checkpoint['model_state_dict'])\n", 841 | "model = model.to(device)\n", 842 | "\n", 843 | "model.eval() \n", 844 | "\n", 845 | "generated_captions = []\n", 846 | "actual_captions = []\n", 847 | "cands = []\n", 848 | "refs = []\n", 849 | "filterd_words = set({test_dataset.vocab[''], test_dataset.vocab[''], test_dataset.vocab['']})\n", 850 | "with torch.no_grad():\n", 851 | " for images, captions, caplens in test_loader:\n", 852 | " images = images.to(device)\n", 853 | " input_ids = images\n", 854 | " outputs,_ = model.generate_text(input_ids, max_length=95, start_token_id=test_dataset.vocab[''])\n", 855 | " for i in range(outputs.shape[0]):\n", 856 | " gen_caption = [idx_to_word(idx, test_dataset.vocab) for idx in outputs[i]]\n", 857 | " if '' in gen_caption:\n", 858 | " gen_caption = gen_caption[1:] # 移除第一个元素 ()\n", 859 | " if '' in gen_caption:\n", 860 | " gen_caption = gen_caption[:gen_caption.index('')] # 移除 及其后面的元素\n", 861 | " generated_captions.append(' '.join(gen_caption))\n", 862 | " act_caption = [idx_to_word(idx, test_dataset.vocab) for idx in captions[i]]\n", 863 | " # 移除 \n", 864 | " if '' in act_caption:\n", 865 | " act_caption = act_caption[1:] # 移除第一个元素 ()\n", 866 | " if '' in act_caption:\n", 867 | " act_caption = act_caption[:act_caption.index('')] # 移除 及其后面的元素\n", 868 | " \n", 869 | " actual_captions.append([' '.join(act_caption)])\n", 870 | " texts=outputs\n", 871 | " cands.extend([filter_cut_useless_words(text, filterd_words) for text in texts.tolist()])\n", 872 | " # 参考文本\n", 873 | " refs.extend([filter_cut_useless_words(cap, filterd_words) for cap in captions.tolist()])\n", 874 | " \n", 875 | " \n", 876 | " bleu4=get_BLEU_score(cands, refs)\n", 877 | "\n", 878 | " cider_d_score=get_CIDER_D_score(test_dataset.vocab,refs, cands)\n", 879 | " \n", 880 | " spice_score=get_SPICE_score(test_dataset.vocab,refs, cands)\n", 881 | " \n", 882 | " max_cider=0.0043918896512309125\n", 883 | " max_spice=0.18703616844830037\n", 884 | "\n", 885 | " print(f\"bleu4 :{bleu4}\")\n", 886 | " print(f\"cider_d_score :{cider_d_score}\")\n", 887 | " print(f\"相对cider_d_score :{cider_d_score / max_cider}\")\n", 888 | " print(f\"spice_score :{spice_score}\")\n", 889 | " print(f\"相对spice_score :{spice_score / max_spice}\")\n", 890 | "\n" 891 | ] 892 | } 893 | ], 894 | "metadata": { 895 | "kernelspec": { 896 | "display_name": ".venv", 897 | "language": "python", 898 | "name": "python3" 899 | }, 900 | "language_info": { 901 | "codemirror_mode": { 902 | "name": "ipython", 903 | "version": 3 904 | }, 905 | "file_extension": ".py", 906 | "mimetype": "text/x-python", 907 | "name": "python", 908 | "nbconvert_exporter": "python", 909 | "pygments_lexer": "ipython3", 910 | "version": "3.10.6" 911 | } 912 | }, 913 | "nbformat": 4, 914 | "nbformat_minor": 2 915 | } 916 | -------------------------------------------------------------------------------- /ViT/generate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 15, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "import os\n", 11 | "from PIL import Image\n", 12 | "import torch\n", 13 | "from torch.utils.data import Dataset, DataLoader\n", 14 | "from transformers import ViTFeatureExtractor, BertTokenizer\n", 15 | "from collections import defaultdict, Counter\n", 16 | "import numpy as np\n", 17 | "\n", 18 | "\n", 19 | "dataset='deepfashion-multimodal'\n", 20 | "img_path = f'data/{dataset}/img-001/img'\n", 21 | "vocab_path = f'data/{dataset}/vocab.json'\n", 22 | "\n", 23 | "def idx_to_word(idx, vocab):#将向量转化为文本描述\n", 24 | " reverse_vocab = {v: k for k, v in vocab.items()}\n", 25 | " return reverse_vocab.get(int(idx), '')\n", 26 | "\n", 27 | "class CustomImageDataset(Dataset):\n", 28 | " def __init__(self, img_folder, transform=None):\n", 29 | " self.img_folder = img_folder\n", 30 | " self.img_names = [img for img in os.listdir(img_folder) if img.endswith(('.png', '.jpg', '.jpeg'))]\n", 31 | " print(len(self.img_names))\n", 32 | " self.img_names = self.img_names[31000:36000]\n", 33 | " print(self.img_names[0])\n", 34 | " self.transform = transform\n", 35 | "\n", 36 | " def __len__(self):\n", 37 | " return len(self.img_names)\n", 38 | "\n", 39 | " def __getitem__(self, idx):\n", 40 | " img_path = os.path.join(self.img_folder, self.img_names[idx])\n", 41 | " image = Image.open(img_path).convert('RGB')\n", 42 | " if self.transform:\n", 43 | " image = self.transform(image)\n", 44 | "\n", 45 | " return image, self.img_names[idx]\n", 46 | "\n", 47 | "\n" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 16, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "from transformers import ViTModel, BertModel, BertConfig\n", 57 | "from torch import nn\n", 58 | "import torch\n", 59 | "\n", 60 | "class Img2TxtModel(nn.Module):\n", 61 | " def __init__(self, vit_model_name, transformer_config, vocab_size):\n", 62 | " super(Img2TxtModel, self).__init__()\n", 63 | " # ViT模型作为编码器\n", 64 | " self.encoder = ViTModel.from_pretrained(vit_model_name)\n", 65 | "\n", 66 | " # Transformer解码器配置\n", 67 | " transformer_config = BertConfig(vocab_size=vocab_size, num_hidden_layers=1, is_decoder=True, add_cross_attention=True)\n", 68 | " self.decoder = BertModel(transformer_config)\n", 69 | "\n", 70 | " # 预测每个词的线性层\n", 71 | " self.vocab_size = vocab_size\n", 72 | " self.fc = nn.Linear(transformer_config.hidden_size, vocab_size)\n", 73 | " \n", 74 | " def forward(self, input_ids, decoder_input_ids, decoder_attention_mask):\n", 75 | " # 通过ViT编码器获取图像特征\n", 76 | " encoder_outputs = self.encoder(pixel_values=input_ids).last_hidden_state\n", 77 | "\n", 78 | " # 将图像特征作为解码器的输入\n", 79 | " decoder_outputs = self.decoder(input_ids=decoder_input_ids, \n", 80 | " attention_mask=decoder_attention_mask,\n", 81 | " encoder_hidden_states=encoder_outputs).last_hidden_state\n", 82 | "\n", 83 | " # 预测下一个词\n", 84 | " prediction_scores = self.fc(decoder_outputs)\n", 85 | " return prediction_scores\n", 86 | "\n", 87 | " def generate_text(self, input_ids, max_length=95, start_token_id=154):\n", 88 | " # 获取图像特征\n", 89 | " encoder_outputs = self.encoder(pixel_values=input_ids).last_hidden_state\n", 90 | "\n", 91 | " # 初始化解码器输入为标记\n", 92 | " decoder_input_ids = torch.full((input_ids.size(0), 1), start_token_id).to(input_ids.device)\n", 93 | " \n", 94 | " # 存储所有时间步的logits\n", 95 | " all_logits = []\n", 96 | "\n", 97 | " for step in range(max_length):\n", 98 | " # 获取解码器输出\n", 99 | " decoder_outputs = self.decoder(\n", 100 | " input_ids=decoder_input_ids, \n", 101 | " encoder_hidden_states=encoder_outputs\n", 102 | " ).last_hidden_state\n", 103 | "\n", 104 | " # 预测下一个词\n", 105 | " next_word_logits = self.fc(decoder_outputs[:, -1, :])\n", 106 | " all_logits.append(next_word_logits.unsqueeze(1))\n", 107 | " next_word_id = next_word_logits.argmax(dim=-1).unsqueeze(-1)\n", 108 | " \n", 109 | " # 将预测的词添加到解码器输入中\n", 110 | " decoder_input_ids = torch.cat([decoder_input_ids, next_word_id], dim=-1)\n", 111 | " \n", 112 | " return decoder_input_ids ,torch.cat(all_logits, dim=1)\n", 113 | "\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 17, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "Using GPU: NVIDIA GeForce RTX 3070 Laptop GPU\n", 126 | "123016\n", 127 | "img_00031006.jpg\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "from transformers import ViTModel, BertModel, BertConfig\n", 133 | "import torch\n", 134 | "from torch import nn\n", 135 | "from torch.utils.data import DataLoader\n", 136 | "from torchvision import transforms\n", 137 | "\n", 138 | "if torch.cuda.is_available():\n", 139 | " device = torch.device(\"cuda\")\n", 140 | " print(\"Using GPU:\", torch.cuda.get_device_name(0))\n", 141 | "else:\n", 142 | " device = torch.device(\"cpu\")\n", 143 | " print(\"Using CPU\")\n", 144 | "\n", 145 | "# 图像预处理\n", 146 | "transform = transforms.Compose([\n", 147 | " transforms.Resize((224, 224)),\n", 148 | " transforms.ToTensor(),\n", 149 | " # 根据需要添加更多的转换\n", 150 | "])\n", 151 | "\n", 152 | "# 创建 Dataset 实例\n", 153 | "dataset = CustomImageDataset(img_folder=img_path, transform=transform)\n", 154 | "\n", 155 | "# 创建 DataLoader\n", 156 | "data_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", 157 | "\n", 158 | "with open(vocab_path, 'r') as f:\n", 159 | " vocab = json.load(f)\n", 160 | "\n", 161 | "vocab_size = len(vocab)\n", 162 | "vit_model_name = 'google/vit-base-patch16-224-in21k'\n", 163 | "transformer_config = BertConfig()\n", 164 | "\n", 165 | "model = Img2TxtModel(vit_model_name, transformer_config, vocab_size)\n", 166 | "# 加载模型状态字典\n", 167 | "checkpoint = torch.load('./model/best_model_epoch_10_batch_2700.pth')\n", 168 | "\n", 169 | "\n", 170 | "# 将状态字典应用到模型实例中\n", 171 | "model.load_state_dict(checkpoint['model_state_dict'])\n", 172 | "model = model.to(device)\n", 173 | "\n", 174 | "model.eval() # 将模型设置为评估模式\n", 175 | "\n", 176 | "generated_captions_dict = {}\n", 177 | "\n", 178 | "with torch.no_grad():\n", 179 | " for images, name in data_loader:\n", 180 | " images = images.to(device)\n", 181 | " input_ids = images\n", 182 | " outputs,_ = model.generate_text(input_ids, max_length=95, start_token_id=vocab[''])\n", 183 | " for i in range(outputs.shape[0]):\n", 184 | " gen_caption = [idx_to_word(idx, vocab) for idx in outputs[i]]\n", 185 | " if '' in gen_caption:\n", 186 | " gen_caption = gen_caption[1:] # 移除第一个元素 ()\n", 187 | " if '' in gen_caption:\n", 188 | " gen_caption = gen_caption[:gen_caption.index('')] # 移除 及其后面的元素\n", 189 | "\n", 190 | " caption_text = ' '.join(gen_caption)\n", 191 | " generated_captions_dict[name[0]] = caption_text\n", 192 | "with open('res.json', 'w') as f:\n", 193 | " json.dump(generated_captions_dict, f, indent=2)" 194 | ] 195 | } 196 | ], 197 | "metadata": { 198 | "kernelspec": { 199 | "display_name": ".venv", 200 | "language": "python", 201 | "name": "python3" 202 | }, 203 | "language_info": { 204 | "codemirror_mode": { 205 | "name": "ipython", 206 | "version": 3 207 | }, 208 | "file_extension": ".py", 209 | "mimetype": "text/x-python", 210 | "name": "python", 211 | "nbconvert_exporter": "python", 212 | "pygments_lexer": "ipython3", 213 | "version": "3.10.6" 214 | } 215 | }, 216 | "nbformat": 4, 217 | "nbformat_minor": 2 218 | } 219 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pack_padded_sequence # 压紧填充序列 4 | from torch.utils.data import Dataset 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | from torchvision.models import ResNet101_Weights 8 | from nltk.translate.bleu_score import corpus_bleu # BLEU评价指标 9 | import numpy as np 10 | import json 11 | from torch.utils.data import Dataset 12 | import os 13 | from PIL import Image 14 | from collections import Counter,defaultdict 15 | from argparse import Namespace 16 | 17 | from ARCTIC.ARCTIC_model import ARCTIC,AttentionDecoder,AdditiveAttention,ImageEncoder 18 | from ARCTIC.ARCTIC_dataloader import ImageTextDataset ,mktrainval,cap_to_wvec,wvec_to_cap,wvec_to_capls 19 | import matplotlib.pyplot as plt 20 | 21 | max_cider=0.0043918896512309125 22 | max_spice=0.18703616844830037 23 | img_path = f'../data/deepfashion-multimodal/images' 24 | def cider_d(reference_list, candidate_list, n=4): 25 | def count_ngrams(tokens, n): 26 | ngrams = [] 27 | for i in range(len(tokens) - n + 1): 28 | ngram = tuple(tokens[i:i+n]) 29 | ngrams.append(ngram) 30 | return ngrams 31 | 32 | def compute_cider_d(reference_list, candidate_list, n): 33 | cider_d_scores = [] 34 | for refs, cand in zip(reference_list, candidate_list): 35 | cider_d_score = 0.0 36 | for i in range(1, n + 1): 37 | cand_ngrams = count_ngrams(cand, i) 38 | ref_ngrams_list = [count_ngrams(ref, i) for ref in refs] 39 | 40 | total_ref_ngrams = [ngram for ref_ngrams in ref_ngrams_list for ngram in ref_ngrams] 41 | 42 | count_cand = 0 43 | count_clip = 0 44 | 45 | for ngram in cand_ngrams: 46 | count_cand += 1 47 | if ngram in total_ref_ngrams: 48 | count_clip += 1 49 | 50 | precision = count_clip / count_cand if count_cand > 0 else 0.0 51 | recall = count_clip / len(total_ref_ngrams) if len(total_ref_ngrams) > 0 else 0.0 52 | 53 | beta = 1.0 54 | f_score = (1 + beta**2) * precision * recall / (beta**2 * precision + recall) if precision + recall > 0 else 0.0 55 | 56 | cider_d_score += f_score 57 | 58 | cider_d_score /= n 59 | cider_d_scores.append(cider_d_score) 60 | 61 | return cider_d_scores 62 | 63 | reference_tokens_list = reference_list 64 | candidate_tokens_list = candidate_list 65 | 66 | scores = compute_cider_d(reference_tokens_list, candidate_tokens_list, n) 67 | 68 | return np.mean(scores) 69 | def spice(reference_list, candidate_list, idf=None, beta=3): 70 | def tokenize(sentence): 71 | return sentence.lower().split() 72 | 73 | def count_ngrams(tokens, n): 74 | ngrams = [] 75 | for i in range(len(tokens) - n + 1): 76 | ngram = tuple(tokens[i:i+n]) 77 | ngrams.append(ngram) 78 | return ngrams 79 | 80 | def compute_spice_score(reference, candidate, idf, beta): 81 | reference_tokens = reference 82 | candidate_tokens = candidate 83 | 84 | reference_ngrams = [count_ngrams(reference_tokens, i) for i in range(1, beta + 1)] 85 | candidate_ngrams = [count_ngrams(candidate_tokens, i) for i in range(1, beta + 1)] 86 | 87 | precision_scores = [] 88 | recall_scores = [] 89 | 90 | for i in range(beta): 91 | common_ngrams = set(candidate_ngrams[i]) & set(reference_ngrams[i]) 92 | 93 | precision = len(common_ngrams) / len(candidate_ngrams[i]) if len(candidate_ngrams[i]) > 0 else 0.0 94 | recall = len(common_ngrams) / len(reference_ngrams[i]) if len(reference_ngrams[i]) > 0 else 0.0 95 | 96 | precision_scores.append(precision) 97 | recall_scores.append(recall) 98 | 99 | precision_avg = np.mean(precision_scores) 100 | recall_avg = np.mean(recall_scores) 101 | 102 | spice_score = (precision_avg * recall_avg) / (precision_avg + recall_avg) if precision_avg + recall_avg > 0 else 0.0 103 | 104 | if idf: 105 | spice_score *= np.exp(np.sum([idf[token] for token in common_ngrams]) / len(candidate_tokens)) 106 | 107 | return spice_score 108 | 109 | if idf is None: 110 | idf = {} 111 | 112 | spice_scores = [] 113 | 114 | for reference, candidate in zip(reference_list, candidate_list): 115 | spice_score = compute_spice_score(reference, candidate, idf, beta) 116 | spice_scores.append(spice_score) 117 | 118 | return np.mean(spice_scores) 119 | def get_BLEU_score(cands, refs): #获取BLEU分数 120 | multiple_refs = [] 121 | for idx in range(len(refs)): 122 | multiple_refs.append(refs[(idx//1)*1 : (idx//1)*1+1])#每个候选文本对应cpi==1条参考文本 123 | bleu4 = corpus_bleu(multiple_refs, cands, weights=(0.25,0.25,0.25,0.25)) 124 | return bleu4 125 | def get_CIDER_D_score(vocab,cands, refs): #获得CIDER-D分数 126 | refs_ = [wvec_to_capls(vocab,ref) for ref in refs] 127 | cands_ = [wvec_to_capls(vocab,cand) for cand in cands] 128 | return cider_d(refs_, cands_) 129 | def get_SPICE_score(vocab,cands, refs): #获得SPICE分数 130 | refs_ = [wvec_to_cap(vocab,ref) for ref in refs] 131 | cands_ = [wvec_to_cap(vocab,cand) for cand in cands] 132 | return spice(refs_, cands_) 133 | 134 | def filter_useless_words(sent, filterd_words): 135 | # 去除句子中不参与BLEU值计算的符号 136 | return [w for w in sent if w not in filterd_words] 137 | def filter_cut_useless_words(sent, filterd_words): 138 | res=[] 139 | for w in sent: 140 | if w not in filterd_words: 141 | res.append(w) 142 | else: 143 | if w==155: 144 | return res 145 | return res 146 | def evaluate_ARCTIC(model_path="../best_arctic.ckpt"): 147 | model=torch.load(model_path)["model"] #加载模型 148 | ARCTIC_config = Namespace( 149 | max_len = 93, 150 | captions_per_image = 1, 151 | batch_size = 32, 152 | image_code_dim = 2048, 153 | word_dim = 512, 154 | hidden_size = 512, 155 | attention_dim = 512, 156 | num_layers = 1, 157 | encoder_learning_rate = 0.0001, 158 | decoder_learning_rate = 0.0005, 159 | num_epochs = 10, 160 | grad_clip = 5.0, 161 | alpha_weight = 1.0, 162 | evaluate_step = 900, # 每隔多少步在验证集上测试一次 163 | checkpoint = None, # 如果不为None,则利用该变量路径的模型继续训练 164 | best_checkpoint = 'model/ARCTIC/best_ARCTIC.ckpt', # 验证集上表现最优的模型的路径 165 | last_checkpoint = 'model/ARCTIC/last_ARCTIC.ckpt', # 训练完成时的模型的路径 166 | beam_k = 5 #束搜索的束宽 167 | ) 168 | config=ARCTIC_config 169 | img_path = f'../data/deepfashion-multimodal/images' 170 | _ ,test_loader=mktrainval(data_dir='../data/deepfashion-multimodal',\ 171 | vocab_path='../data/deepfashion-multimodal/vocab.json',\ 172 | batch_size=32,workers=0) 173 | data_loader=test_loader 174 | model.eval() 175 | cands = [] 176 | # 存储参考文本 177 | refs = [] 178 | # 需要过滤的词 179 | filterd_words = set({model.vocab[''], model.vocab[''], model.vocab['']}) 180 | cpi = config.captions_per_image 181 | device = next(model.parameters()).device 182 | for i, (imgs, caps, caplens) in enumerate(data_loader): 183 | with torch.no_grad(): 184 | # 通过束搜索,生成候选文本 185 | texts = model.generate_by_beamsearch(imgs.to(device), config.beam_k, config.max_len+2) 186 | #texts= model.generate_normal_version(imgs.to(device), config.max_len+2) --k=1的情况 效果会差点,但是跑得快 187 | # 候选文本 188 | cands.extend([filter_useless_words(text, filterd_words) for text in texts]) 189 | # 参考文本 190 | refs.extend([filter_useless_words(cap, filterd_words) for cap in caps.tolist()]) 191 | print("正在计算BLEU-4分数...") 192 | bleu4_score = get_BLEU_score(cands, refs) 193 | print("正在计算CIDEr-D分数...") 194 | cider_d_score = get_CIDER_D_score(model.vocab,cands, refs) 195 | print("正在计算SPICE分数...") 196 | spice_score= get_SPICE_score(model.vocab,cands, refs) 197 | print(f"@@@实际值 BLEU:{bleu4_score}|CIDEr-D:{cider_d_score}|SPICE:{spice_score}") 198 | print(f"@@@相对值(0-1) BLEU:{bleu4_score}|CIDEr-D:{cider_d_score/max_cider}|SPICE:{spice_score/max_spice}") 199 | def generate_n(model_path="../best_arctic.ckpt",num_g=5): #随机生成n个文本 200 | model=torch.load(model_path)["model"] #加载模型 201 | ARCTIC_config = Namespace( 202 | max_len = 93, 203 | captions_per_image = 1, 204 | batch_size = 32, 205 | image_code_dim = 2048, 206 | word_dim = 512, 207 | hidden_size = 512, 208 | attention_dim = 512, 209 | num_layers = 1, 210 | encoder_learning_rate = 0.0001, 211 | decoder_learning_rate = 0.0005, 212 | num_epochs = 10, 213 | grad_clip = 5.0, 214 | alpha_weight = 1.0, 215 | evaluate_step = 900, # 每隔多少步在验证集上测试一次 216 | checkpoint = None, # 如果不为None,则利用该变量路径的模型继续训练 217 | best_checkpoint = 'model/ARCTIC/best_ARCTIC.ckpt', # 验证集上表现最优的模型的路径 218 | last_checkpoint = 'model/ARCTIC/last_ARCTIC.ckpt', # 训练完成时的模型的路径 219 | beam_k = 5 #束搜索的束宽 220 | ) 221 | img_path = f'../data/deepfashion-multimodal/images' 222 | _ ,test_loader=mktrainval(data_dir='../data/deepfashion-multimodal',\ 223 | vocab_path='../data/deepfashion-multimodal/vocab.json',\ 224 | batch_size=1,workers=0,is_transform=False) 225 | 226 | model.eval() 227 | cands = [] 228 | # 存储参考文本 229 | refs = [] 230 | # 需要过滤的词 231 | filterd_words = set({model.vocab[''], model.vocab[''], model.vocab['']}) 232 | device = next(model.parameters()).device 233 | for i, (imgs, caps, caplens) in enumerate(test_loader): #随机抽 234 | with torch.no_grad(): 235 | # 通过束搜索,生成候选文本 236 | imgs_=imgs.clone() 237 | texts = model.generate_by_beamsearch(imgs.to(device), ARCTIC_config.beam_k, ARCTIC_config.max_len+2) 238 | # 候选文本 239 | cands.extend([filter_useless_words(text, filterd_words) for text in texts]) 240 | # 参考文本 241 | refs.extend([filter_useless_words(cap, filterd_words) for cap in caps.tolist()]) 242 | res_string=wvec_to_cap(model.vocab,cands[i]) 243 | ref_string=wvec_to_cap(model.vocab,refs[i]) 244 | print(f"@@@生成文本({i+1}/{num_g}):{res_string}") 245 | print(f"@@@实际文本({i+1}/{num_g}):{ref_string}") 246 | #窗口化显示图片 247 | transform_ = transforms.ToPILImage() 248 | image_pil = transform_(imgs[0]) 249 | #image_pil.show() 250 | #画出图像 251 | plt.imshow(image_pil) 252 | plt.show() 253 | if i+1>=num_g: 254 | break 255 | #evaluate_ARCTIC() 256 | generate_n() 257 | -------------------------------------------------------------------------------- /new_dataset/QianFan-agent.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import os 4 | import datetime 5 | 6 | API_KEY = 'jIHlGsMYHp4j1MrMZGNmYbCL' 7 | SECRET_KEY = 'GzG1o4HC6G0qDVxPKcn9Zl4pv20j7CGA' #估计看到的时候已经过了有效期了所以无所谓了 8 | 9 | headers = { 10 | 'Content-Type': 'application/json', 11 | 'Accept': 'application/json' 12 | } 13 | 14 | def get_access_token(): 15 | """ 16 | 使用 AK,SK 生成鉴权签名(Access Token) 17 | :return: access_token,或是None(如果错误) 18 | """ 19 | url = "https://aip.baidubce.com/oauth/2.0/token" 20 | params = {"grant_type": "client_credentials", "client_id": API_KEY, "client_secret": SECRET_KEY} 21 | return str(requests.post(url, params=params).json().get("access_token")) 22 | 23 | def read_json(file_path): 24 | with open(file_path, 'r', encoding='utf-8') as file: 25 | return json.load(file) 26 | 27 | def QianFan(url, inputs): 28 | responses = {} # 创建一个字典来存储输入和相应的响应 29 | 30 | for key, user_input in inputs.items(): # inputs是一个字典 31 | request = { 32 | "messages": [ 33 | { 34 | "role": "user", 35 | "content": "I will give you a sentence, where the first part is a summary of the background and the second part is information about a person's outfit. Please focus on the background information from the first part and provide an overall background description. " 36 | }, 37 | { 38 | "role": "assistant", 39 | "content": "Of course, please provide the sentence, and I will only output a sentence like: 'The background is'. " 40 | }, 41 | { 42 | "role": "user", 43 | "content": "a people in front of a bed. a pair of jeans" 44 | }, 45 | { 46 | "role": "assistant", 47 | "content": "The backgroud is a homely bedroom." 48 | }, 49 | { 50 | "role": "user", 51 | "content": "a people in front of a mirror. a man in a blue shorts and a white shirt" 52 | }, 53 | { 54 | "role": "assistant", 55 | "content": "The backgroud is a mirror." 56 | }, 57 | ] 58 | }#添加一些前置词 59 | request["messages"].append({"role": "user", "content": f"Describe the setting of the following scene, focusing solely on the background without including any details about the person:{user_input}"}) 60 | 61 | 62 | try: 63 | response = requests.request("POST", url, headers=headers, data=json.dumps(request)) 64 | text = response.text 65 | data = json.loads(text) 66 | model_response = data['result'] 67 | print("\n回答:\n", model_response, '\n') 68 | # 根据句号分割文本 69 | sentences = model_response.split(". ") 70 | 71 | # 获取第一个句子 72 | first_sentence = sentences[0] + "." 73 | responses[key] = first_sentence # 将响应存储在字典中 74 | 75 | except Exception as e: 76 | print(f"QianFan 接口调用出错: {e}") 77 | 78 | # 保存或处理responses字典 79 | return responses 80 | 81 | def main(): 82 | if not os.path.exists('./Amadeus/history'): 83 | os.makedirs('./Amadeus/history') 84 | 85 | url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=" + get_access_token() 86 | 87 | inputs = read_json('res_new.json') 88 | responses = QianFan(url, inputs) 89 | 90 | # 可以选择保存responses字典 91 | with open('./Amadeus/history/responses.json', 'w', encoding='utf-8') as file: 92 | json.dump(responses, file, ensure_ascii=False, indent=4) 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /new_dataset/merge_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | # 读取JSON文件的函数 4 | def read_json(file_path): 5 | with open(file_path, 'r', encoding='utf-8') as file: 6 | return json.load(file) 7 | 8 | # 读取两个JSON文件 9 | json1 = read_json('res.json') 10 | json2 = read_json('res_add.json') 11 | 12 | # 存储组合后的结果 13 | combined_values = {} 14 | 15 | # 遍历第一个JSON文件的键 16 | for key in json1: 17 | if key in json2: 18 | # 将两个文件中相同键的值组合在一起 19 | combined_values[key] = json1[key] + json2[key] 20 | 21 | # 保存组合后的结果到新的JSON文件 22 | with open('combined_input.json', 'w', encoding='utf-8') as file: 23 | json.dump(combined_values, file, ensure_ascii=False, indent=2) 24 | 25 | print("Combined JSON saved as combined_output.json") -------------------------------------------------------------------------------- /new_dataset/new_generate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 30, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "import os\n", 11 | "from PIL import Image\n", 12 | "import torch\n", 13 | "from torch.utils.data import Dataset, DataLoader\n", 14 | "from transformers import ViTFeatureExtractor, BertTokenizer\n", 15 | "from collections import defaultdict, Counter\n", 16 | "import numpy as np\n", 17 | "\n", 18 | "\n", 19 | "dataset='deepfashion-multimodal'\n", 20 | "img_path = f'data/{dataset}/test_image'\n", 21 | "vocab_path = f'data/{dataset}/vocab.json'\n", 22 | "\n", 23 | "def idx_to_word(idx, vocab):#将向量转化为文本描述\n", 24 | " reverse_vocab = {v: k for k, v in vocab.items()}\n", 25 | " return reverse_vocab.get(int(idx), '')\n", 26 | "\n", 27 | "class CustomImageDataset(Dataset):\n", 28 | " def __init__(self, img_folder, transform=None):\n", 29 | " self.img_folder = img_folder\n", 30 | " self.img_names = [img for img in os.listdir(img_folder) if img.endswith(('.png', '.jpg', '.jpeg'))]\n", 31 | " print(len(self.img_names))\n", 32 | " print(self.img_names[0])\n", 33 | " self.transform = transform\n", 34 | "\n", 35 | " def __len__(self):\n", 36 | " return len(self.img_names)\n", 37 | "\n", 38 | " def __getitem__(self, idx):\n", 39 | " img_path = os.path.join(self.img_folder, self.img_names[idx])\n", 40 | " image = Image.open(img_path).convert('RGB')\n", 41 | " if self.transform:\n", 42 | " image = self.transform(image)\n", 43 | "\n", 44 | " return image, self.img_names[idx]\n", 45 | "\n", 46 | "\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 31, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "from transformers import ViTModel, BertModel, BertConfig\n", 56 | "from torch import nn\n", 57 | "import torch\n", 58 | "\n", 59 | "class Img2TxtModel(nn.Module):\n", 60 | " def __init__(self, vit_model_name, transformer_config, vocab_size):\n", 61 | " super(Img2TxtModel, self).__init__()\n", 62 | " # ViT模型作为编码器\n", 63 | " self.encoder = ViTModel.from_pretrained(vit_model_name)\n", 64 | "\n", 65 | " # Transformer解码器配置\n", 66 | " transformer_config = BertConfig(vocab_size=vocab_size, num_hidden_layers=1, is_decoder=True, add_cross_attention=True)\n", 67 | " self.decoder = BertModel(transformer_config)\n", 68 | "\n", 69 | " # 预测每个词的线性层\n", 70 | " self.vocab_size = vocab_size\n", 71 | " self.fc = nn.Linear(transformer_config.hidden_size, vocab_size)\n", 72 | " \n", 73 | " def forward(self, input_ids, decoder_input_ids, decoder_attention_mask):\n", 74 | " # 通过ViT编码器获取图像特征\n", 75 | " encoder_outputs = self.encoder(pixel_values=input_ids).last_hidden_state\n", 76 | "\n", 77 | " # 将图像特征作为解码器的输入\n", 78 | " decoder_outputs = self.decoder(input_ids=decoder_input_ids, \n", 79 | " attention_mask=decoder_attention_mask,\n", 80 | " encoder_hidden_states=encoder_outputs).last_hidden_state\n", 81 | "\n", 82 | " # 预测下一个词\n", 83 | " prediction_scores = self.fc(decoder_outputs)\n", 84 | " return prediction_scores\n", 85 | "\n", 86 | " def generate_text(self, input_ids, max_length=95, start_token_id=154):\n", 87 | " # 获取图像特征\n", 88 | " encoder_outputs = self.encoder(pixel_values=input_ids).last_hidden_state\n", 89 | "\n", 90 | " # 初始化解码器输入为标记\n", 91 | " decoder_input_ids = torch.full((input_ids.size(0), 1), start_token_id).to(input_ids.device)\n", 92 | " \n", 93 | " # 存储所有时间步的logits\n", 94 | " all_logits = []\n", 95 | "\n", 96 | " for step in range(max_length):\n", 97 | " # 获取解码器输出\n", 98 | " decoder_outputs = self.decoder(\n", 99 | " input_ids=decoder_input_ids, \n", 100 | " encoder_hidden_states=encoder_outputs\n", 101 | " ).last_hidden_state\n", 102 | "\n", 103 | " # 预测下一个词\n", 104 | " next_word_logits = self.fc(decoder_outputs[:, -1, :])\n", 105 | " all_logits.append(next_word_logits.unsqueeze(1))\n", 106 | " next_word_id = next_word_logits.argmax(dim=-1).unsqueeze(-1)\n", 107 | " \n", 108 | " # 将预测的词添加到解码器输入中\n", 109 | " decoder_input_ids = torch.cat([decoder_input_ids, next_word_id], dim=-1)\n", 110 | " \n", 111 | " return decoder_input_ids ,torch.cat(all_logits, dim=1)\n", 112 | "\n" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 32, 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "Using GPU: NVIDIA GeForce RTX 3070 Laptop GPU\n", 125 | "5\n", 126 | "MEN-Sweaters-id_00000702-06_7_additional.jpg\n", 127 | "{'MEN-Sweaters-id_00000702-06_7_additional.jpg': 'The person is wearing a short-sleeve shirt with graphic patterns. The shirt is with cotton fabric. It has a crew neckline. The person wears a three-point shorts. The shorts are with denim fabric and pure color patterns. There is an accessory on her wrist. There is a ring on her finger.', 'MEN-Sweatshirts_Hoodies-id_00000911-01_4_full.jpg': 'The person is wearing a tank tank top with graphic patterns. The tank top is with cotton fabric. It has a suspenders neckline. The person wears a long trousers. The trousers are with cotton fabric and graphic patterns. There is an accessory on her wrist. There is a ring on her finger.', 'WOMEN-Pants-id_00005000-06_1_front.jpg': 'The person is wearing a tank tank top with graphic patterns. The tank top is with cotton fabric. It has a suspenders neckline. The person wears a long trousers. The trousers are with cotton fabric and graphic patterns. There is an accessory on her wrist. There is a ring on her finger.', 'WOMEN-Rompers_Jumpsuits-id_00004968-01_2_side.jpg': 'The person is wearing a long-sleeve shirt with graphic patterns. The shirt is with cotton fabric and its neckline is round. The trousers the person wears is of long length. The trousers are with cotton fabric and solid color patterns. There is an accessory on his wrist.', 'WOMEN-Shorts-id_00006003-01_4_full.jpg': 'The tank top this person wears has no sleeves and it is with cotton fabric and graphic patterns. The neckline of the tank top is suspenders. This person wears a three-point pants, with denim fabric and solid color patterns. There is an accessory on her wrist. There is a ring on her finger.'}\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "from transformers import ViTModel, BertModel, BertConfig\n", 133 | "import torch\n", 134 | "from torch import nn\n", 135 | "from torch.utils.data import DataLoader\n", 136 | "from torchvision import transforms\n", 137 | "\n", 138 | "if torch.cuda.is_available():\n", 139 | " device = torch.device(\"cuda\")\n", 140 | " print(\"Using GPU:\", torch.cuda.get_device_name(0))\n", 141 | "else:\n", 142 | " device = torch.device(\"cpu\")\n", 143 | " print(\"Using CPU\")\n", 144 | "\n", 145 | "# 图像预处理\n", 146 | "transform = transforms.Compose([\n", 147 | " transforms.Resize((224, 224)),\n", 148 | " transforms.ToTensor(),\n", 149 | " # 根据需要添加更多的转换\n", 150 | "])\n", 151 | "\n", 152 | "# 创建 Dataset 实例\n", 153 | "dataset = CustomImageDataset(img_folder=img_path, transform=transform)\n", 154 | "\n", 155 | "# 创建 DataLoader\n", 156 | "data_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", 157 | "\n", 158 | "with open(vocab_path, 'r') as f:\n", 159 | " vocab = json.load(f)\n", 160 | "\n", 161 | "vocab_size = len(vocab)\n", 162 | "vit_model_name = 'google/vit-base-patch16-224-in21k'\n", 163 | "transformer_config = BertConfig()\n", 164 | "\n", 165 | "model = Img2TxtModel(vit_model_name, transformer_config, vocab_size)\n", 166 | "# 加载模型状态字典\n", 167 | "checkpoint = torch.load('./model/best_model_epoch_10_batch_2700.pth')\n", 168 | "\n", 169 | "\n", 170 | "# 将状态字典应用到模型实例中\n", 171 | "model.load_state_dict(checkpoint['model_state_dict'])\n", 172 | "model = model.to(device)\n", 173 | "\n", 174 | "model.eval() # 将模型设置为评估模式\n", 175 | "\n", 176 | "generated_captions_dict = {}\n", 177 | "\n", 178 | "with torch.no_grad():\n", 179 | " for images, name in data_loader:\n", 180 | " images = images.to(device)\n", 181 | " input_ids = images\n", 182 | " outputs,_ = model.generate_text(input_ids, max_length=95, start_token_id=vocab[''])\n", 183 | " for i in range(outputs.shape[0]):\n", 184 | " gen_caption = [idx_to_word(idx, vocab) for idx in outputs[i]]\n", 185 | " if '' in gen_caption:\n", 186 | " gen_caption = gen_caption[1:] # 移除第一个元素 ()\n", 187 | " if '' in gen_caption:\n", 188 | " gen_caption = gen_caption[:gen_caption.index('')] # 移除 及其后面的元素\n", 189 | "\n", 190 | " caption_text = ' '.join(gen_caption)\n", 191 | " generated_captions_dict[name[0]] = caption_text\n", 192 | " print(generated_captions_dict)" 193 | ] 194 | } 195 | ], 196 | "metadata": { 197 | "kernelspec": { 198 | "display_name": ".venv", 199 | "language": "python", 200 | "name": "python3" 201 | }, 202 | "language_info": { 203 | "codemirror_mode": { 204 | "name": "ipython", 205 | "version": 3 206 | }, 207 | "file_extension": ".py", 208 | "mimetype": "text/x-python", 209 | "name": "python", 210 | "nbconvert_exporter": "python", 211 | "pygments_lexer": "ipython3", 212 | "version": "3.10.6" 213 | } 214 | }, 215 | "nbformat": 4, 216 | "nbformat_minor": 2 217 | } 218 | -------------------------------------------------------------------------------- /new_dataset/statement.txt: -------------------------------------------------------------------------------- 1 | res.json 是ViT模型在新数据集上跑的结果 2 | res_add.json 是LLM接受rew_new以后生成的结果中的第一个句子 3 | combined_input.json 是上面两个文件按照key拼接value得到的数据,使用时用前4000作为训练,后1000作为测试 4 | res_new.json blip生成的图像描述,较为简陋 5 | 6 | config.json是ViT的设置 7 | merge——json用来和res_add.json、res.json生成combined_input.json 8 | 9 | new_generate.ipynb存储了一些新数据集上训练的模型结果,虽然结果不是很好,主要问题是blip的提取错误和模型性能不足,数据量不足;LLM虽然工作不算完美,但是基本满足要求。 -------------------------------------------------------------------------------- /tools/test_blip.py: -------------------------------------------------------------------------------- 1 | 2 | import requests 3 | from PIL import Image 4 | from transformers import BlipProcessor, BlipForConditionalGeneration 5 | import os 6 | import json 7 | class blip_model(): 8 | def __init__(self) -> None: 9 | self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") 10 | self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to("cuda") 11 | def gen_res(self,img_path): 12 | raw_image = Image.open(img_path).convert('RGB') 13 | 14 | text = "a people in front of " 15 | input_1 = self.processor(raw_image, text, return_tensors="pt").to("cuda") 16 | out_1 = self.model.generate(**input_1,max_length=100) 17 | res_1=self.processor.decode(out_1[0], skip_special_tokens=True) 18 | 19 | input_2 = self.processor(raw_image, return_tensors="pt").to("cuda") 20 | out_2 = self.model.generate(**input_2,max_length=100) 21 | res_2=self.processor.decode(out_2[0], skip_special_tokens=True) 22 | return res_1+". "+res_2 23 | def gen_json(img_path,n): #使用Blip模型标注图片 24 | model=blip_model() 25 | #img_path="D:/NNDL/data/deepfashion-multimodal/images" 26 | #获取该目录下所有文件,存入列表中 27 | imgs=os.listdir(img_path) 28 | res={} 29 | start=31000 30 | 31 | for img in range(start,len(imgs)): 32 | img_k=imgs[img] 33 | img_path_=img_path+"/"+img_k 34 | res[img_k]=model.gen_res(img_path_) 35 | if len(res)>=n: 36 | break 37 | #保存为json文件 38 | with open('res.json', 'w') as f: 39 | json.dump(res, f,indent=2) 40 | #print(model.gen_res("test.JPG")) 41 | img_path="D:/NNDL/data/img" 42 | #gen_json(img_path,5000) 43 | print("@@@@") 44 | def read_json(json_path): 45 | with open(json_path,'r',encoding='utf-8') as f: 46 | data=json.load(f) 47 | return data 48 | #data=read_json("res.json") 49 | #重新保存 50 | def save_json(data,json_path): 51 | with open(json_path,'w',encoding='utf-8') as f: 52 | json.dump(data,f,indent=2) 53 | #save_json(data,"res_new.json") -------------------------------------------------------------------------------- /结题报告.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# NNDL课设报告\n", 8 | "**小组成员及对应工作量:苏柏侨(0.3) 李恒屹(0.4) 石基宽(0.3)**" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "# 任务说明\n", 16 | "服饰图像描述,训练一个模型,对输入的服饰图片,输出描述信息,我们实现的模型有以下三个实现:\n", 17 | "- ARCTIC,一个典型的基于注意力的编解码模型\n", 18 | "- 视觉Transformer (ViT) + Transformer解码器\n", 19 | "- 网格/区域表示、Transformer编码器+Transformer解码器\n", 20 | " \n", 21 | "同时也实现三种测评方法进行测评:\n", 22 | "- BLEU (Bilingual Evaluation Understudy)\n", 23 | "- SPICE (Semantic Propositional Image Caption Evaluation): \n", 24 | "- CIDEr-D (Consensus-based Image Description Evaluation)\n", 25 | "\n", 26 | "以及实现了附加任务:\n", 27 | "- 利用训练的服饰图像描述模型和多模态大语言模型,为真实背景的服饰图像数据集增加服饰描述和背景描述,构建全新的服饰图像描述数据集\n", 28 | " - 在新数据集上重新训练服饰图像描述模型" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "# 实验数据\n", 36 | "数据集使用的是 DeepFashion-MultiModal (https://github.com/yumingj/DeepFashion-MultiModal), 仅用到image和textual descriptions ,数据集划分为10k+行数据的训练集和2k+行数据的测试集,`train_captions.json`和`test_captions.json`分别对应训练集和测试集的图片与描述信息的键值对应" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 8, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "img_path = f'../data/deepfashion-multimodal/images'\n", 46 | "def cap_to_wvec(vocab,cap):#将文本描述转换成向量\n", 47 | " cap.replace(\",\",\"\")\n", 48 | " cap.replace(\".\",\"\")\n", 49 | " cap=cap.split()\n", 50 | " res=[]\n", 51 | " for word in cap:\n", 52 | " if word in vocab.keys():\n", 53 | " res.append(vocab[word])\n", 54 | " else: #不在字典的词\n", 55 | " res.append(vocab[''])\n", 56 | " return res\n", 57 | "def wvec_to_cap(vocab,wvec):#将向量转换成文本描述\n", 58 | " res=[]\n", 59 | " for word in wvec:\n", 60 | " for key,value in vocab.items():\n", 61 | " if value==word and key not in ['','','','']:\n", 62 | " res.append(key)\n", 63 | " res=\" \".join(res)\n", 64 | " return res\n", 65 | "def wvec_to_capls(vocab,wvec):#将向量转换成文本描述\n", 66 | " res=[]\n", 67 | " for word in wvec:\n", 68 | " for key,value in vocab.items():\n", 69 | " if value==word and key not in ['','','','']:\n", 70 | " res.append(key)\n", 71 | " return res\n", 72 | "class ImageTextDataset(Dataset):\n", 73 | " def __init__(self, dataset_path, vocab_path, split, captions_per_image=1, max_len=93, transform=None):\n", 74 | "\n", 75 | " self.split = split\n", 76 | " assert self.split in {'train', 'test'}\n", 77 | " self.cpi = captions_per_image\n", 78 | " self.max_len = max_len\n", 79 | "\n", 80 | " # 载入数据集\n", 81 | " with open(dataset_path, 'r') as f:\n", 82 | " self.data = json.load(f) #key是图片名字 value是描述\n", 83 | " self.data_img=list(self.data.keys())\n", 84 | " # 载入词典\n", 85 | " with open(vocab_path, 'r') as f:\n", 86 | " self.vocab = json.load(f)\n", 87 | "\n", 88 | " # PyTorch图像预处理流程\n", 89 | " self.transform = transform\n", 90 | "\n", 91 | " # Total number of datapoints\n", 92 | " self.dataset_size = len(self.data_img)\n", 93 | "\n", 94 | " def __getitem__(self, i):\n", 95 | " # 第i个文本描述对应第(i // captions_per_image)张图片\n", 96 | " img = Image.open(img_path+\"/\"+self.data_img[i]).convert('RGB')\n", 97 | " if self.transform is not None:\n", 98 | " img = self.transform(img)\n", 99 | " c_vec=cap_to_wvec(self.vocab,self.data[self.data_img[i]])\n", 100 | " #加入起始和结束标志\n", 101 | " c_vec = [self.vocab['']] + c_vec + [self.vocab['']]\n", 102 | " caplen = len(c_vec)\n", 103 | " caption = torch.LongTensor(c_vec+ [self.vocab['']] * (self.max_len + 2 - caplen))\n", 104 | " \n", 105 | " return img, caption, caplen\n", 106 | " \n", 107 | " def __len__(self):\n", 108 | " return self.dataset_size\n", 109 | "def mktrainval(data_dir, vocab_path, batch_size, workers=1):\n", 110 | " train_tx = transforms.Compose([\n", 111 | " transforms.Resize(256), # 重置图像分辨率\n", 112 | " transforms.RandomCrop(224), # 随机裁剪\n", 113 | " transforms.ToTensor(), # 转换成Tensor\n", 114 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 标准化--三个参数为三个通道的均值和标准差\n", 115 | " ])\n", 116 | " val_tx = transforms.Compose([\n", 117 | " transforms.Resize(256),\n", 118 | " transforms.CenterCrop(224),\n", 119 | " transforms.ToTensor(),\n", 120 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 121 | " ])\n", 122 | " train_set = ImageTextDataset(os.path.join(data_dir, 'train_captions.json'), vocab_path, 'train', transform=train_tx)\n", 123 | " test_set = ImageTextDataset(os.path.join(data_dir, 'test_captions.json'), vocab_path, 'test', transform=val_tx)\n", 124 | "\n", 125 | " train_loader = torch.utils.data.DataLoader(\n", 126 | " train_set, batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)\n", 127 | " \n", 128 | " test_loader = torch.utils.data.DataLoader(\n", 129 | " test_set, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True, drop_last=False)\n", 130 | "\n", 131 | " return train_loader, test_loader \n", 132 | "train_loader,test_loader=mktrainval(data_dir='../data/deepfashion-multimodal',\\\n", 133 | " vocab_path='../data/deepfashion-multimodal/vocab.json',\\\n", 134 | " batch_size=3,workers=0) \n", 135 | "#workers=0 是因为ipynb不支持多线程" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "- 关于字典的处理\n", 143 | " 我们采用了一个非常传统的方式,加载所有的训练集和测试集,进行词频统计,我们默认阈值为5 到达阈值的词将会加入到词典中\n", 144 | "\n", 145 | " 之后我们额外添加了 < pad > < start > < end > < unk >四个词,分别代表填充词,句首标记,句尾标记,未知词\n", 146 | " \n", 147 | " 最终写入到vocab.json文件中\n", 148 | "- 关于数据集类的处理\n", 149 | "我们使用Pytroch的Dataset来构建数据集类,在此之外封装了返回测试集和训练集的函数,可以进行自定义的批量预处理,我们在训练和推理过程中进行了如下的处理\n", 150 | " - resize 图像大小为256*256\n", 151 | " - 随机裁剪 为224*224\n", 152 | " - 转换为Torch Tensor\n", 153 | " - normalize 归一化为\n", 154 | " - mean=[0.485, 0.456, 0.406]\n", 155 | " - std=[0.229, 0.224, 0.225]" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "# 实验环境\n", 163 | "- Python 3.9.16\n", 164 | "- 主要依赖库\n", 165 | " - torch \n", 166 | " - torchvision\n", 167 | " - nltk \n", 168 | "实际使用py库情况如下\n", 169 | "\n", 170 | "项目开发中除数据集和模型外代码使用git进行版本控制,需要说明的是,由于课设由多个人多个设备下完成,而训练模型和数据集有所不同,我们维持同步的仅仅是Image2TextEvaluation这一项目,所以不同设备下由细微不同之处,如此测试,需要更换实际环境替换对应代码中的路径,直接进行测试可能会出现一些问题————这里提前声明这一点————文件路径大致如下:\n", 171 | "```\n", 172 | "├── Image2TextEvaluation \n", 173 | "│ ├── ARCTIC\n", 174 | "│ │ ├── ARCTIC_dataloader.py\n", 175 | "│ │ ├── ARCTIC_model.py\n", 176 | "│ │ └── train.py\n", 177 | "│ ├── Vit\n", 178 | "│ │ ├── Vit.ipynb\n", 179 | "│ │ ├── config.json\n", 180 | "│ │ ├── generate.ipynb\n", 181 | "│ ├── SwinTrans\n", 182 | "│ │ ├── ...//同上\n", 183 | "│ │ ├── evaluate.ipynb\n", 184 | "│ │ └── gridSwinTrans.ipynb\n", 185 | "│ │\n", 186 | "│ ├── Tools\n", 187 | "│ │ ├── test_blip.py\n", 188 | "│ │ ├── ...\n", 189 | "│ │\n", 190 | "│ ├── new_dataset //Blip+多模态构建新数据集\n", 191 | "│ │ ├── QianFan-agent.py \n", 192 | "│ │ ├── combined_input.json\n", 193 | "│ │ ├── statement.txt\n", 194 | "│ │ ├── merge_json.py\n", 195 | "│ │ ├── ...\n", 196 | "│ │\n", 197 | "│ ├── evaluate.py\n", 198 | "│ ├── README.md\n", 199 | "│ └── 结题报告.ipynb\n", 200 | "├── model\n", 201 | "│ ├── best_arctic.ckpt\n", 202 | "│ ├── last_arctic.ckpt\n", 203 | "│ └── ...\n", 204 | "├── data\n", 205 | "│ ├── deepfashion-multimodal\n", 206 | "│ │ ├── images\n", 207 | "│ │ │ ├── 001.jpg\n", 208 | "│ │ │ ├── ...\n", 209 | "│ │ ├── train_captions.json\n", 210 | "│ │ ├── test_captions.json\n", 211 | "│ │ ├── vocab.json\n", 212 | "│ │ └── ...\n", 213 | "\n", 214 | "```\n" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 2, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "import torch\n", 224 | "import torch.nn as nn\n", 225 | "from torch.nn.utils.rnn import pack_padded_sequence # 压紧填充序列\n", 226 | "from torch.utils.data import Dataset\n", 227 | "import torchvision\n", 228 | "import torchvision.transforms as transforms\n", 229 | "from torchvision.models import ResNet101_Weights\n", 230 | "from nltk.translate.bleu_score import corpus_bleu # BLEU评价指标\n", 231 | "import numpy as np\n", 232 | "import json\n", 233 | "from torch.utils.data import Dataset\n", 234 | "import os\n", 235 | "from PIL import Image\n", 236 | "from collections import Counter,defaultdict\n", 237 | "from argparse import Namespace \n" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "metadata": {}, 243 | "source": [ 244 | "# 所用的方法或模型\n", 245 | "\n", 246 | "## 评估方法\n", 247 | "\n", 248 | "### BLEU (BiLingual Evaluation Understudy)\n", 249 | "- BLUE是比较常用的评估指标之一,也是我们默认指标,需要注意的是,再调用计算BLEU值之前,要先将文本中人工添加的文本开始符、结束符和占位符去掉,其公式如下, 实际代码中我们借助nltk库进行实现\n", 250 | "$$BLEU = \\sum_{n=1}^k w_n \\frac{ngram_{sys}(n)}{ngram_{ref}(n)}$$\n", 251 | "其中:\n", 252 | " - n 是 n-gram 的阶数,取值范围为 1 到 4。\n", 253 | " - wn 是 n-gram 的权重,通常取均匀权重。\n", 254 | " - ngramsys(n) 是机器翻译结果中的 n-gram 数量。\n", 255 | " - ngramref(n) 是参考翻译中的 n-gram 数量。\n", 256 | " BLEU 的得分范围为 0 到 1。得分越高,表示机器翻译结果与参考翻译越相似。\n", 257 | " - 优点:容易计算\n", 258 | " - 缺点:\n", 259 | " - 没有考虑n-gram的顺序\n", 260 | " - 平等对待所有的n-gram\n", 261 | " - 衡量的是句子之间的流畅性而非语义相似度\n", 262 | "### CIDEr-D (Consensus-based Image Description Evaluation)\n", 263 | "- 是CIDEr的改进,对于动词原形和名词匹配成功的问题,CIDEr-D不再取词根\n", 264 | "其用了一种折扣函数来降低长句子对评分的影响,增加了惩罚生成句子和参考句子的长度差别的权重,并且通过对n-gram计数的截断操作不再计算生成句子中出现次数超过参考句子的n-gram,\n", 265 | "从而减少了重复单词对评分的影响,其实也是计算1到4 gram的结果的平均值,其公式如下\n", 266 | "$$C I D E r - D _ { n } ( c _ { i } , S _ { i } ) = \\frac { 1 0 } { m } \\sum _ { j } e ^ { - \\frac { -( i ( c _ { i } ) - l ( s _ { i j } ) ) ^ { 2 } } { 2 \\sigma ^ { 2 } } } \\times \\frac { \\min ( g ^ { n } ( c _ { i } ) , g ^ { n } ( s _ { i j } ) ) \\cdot g ^ { n } ( s _ { i j } ) } {| | g ^ { n } ( c _ { i } ) | | | g ^ { n } ( s _ { i j } ) || } $$\n", 267 | "- 优点:\n", 268 | " - CIDEr引入了TF-IDF为n-gram进行加权,这样就避免评价候选句子时因为一些常见却不够有信息量的n-gram打上高分\n", 269 | "- 缺点:\n", 270 | " - CIDEr取词根的操作会让一些动词的原型和名词匹配成功\n", 271 | " - 高置信度的词重复出现的长句的CIDEr得分也很高\n", 272 | "### SPICE (Semantic Propositional Image Caption Evaluation): \n", 273 | "- 是以名词为中心的度量,是以图的语义表示来编码图像描述中的对象、属性和关系\n", 274 | "首先要将候选句子和参考句子集转化为场景图\n", 275 | "然后比较候选句子和参考句子集中元组的precision、recall,最终计算出F1 score\n", 276 | "公式如下\n", 277 | "$$SPICE = \\sum_{i=1}^m \\frac{1}{|S_i|} \\sum_{j=1}^n \\frac{s_{ij}}{|R_i|}\n", 278 | "$$\n", 279 | " - m 是图像描述的数量。\n", 280 | " - n 是图像描述中的对象、属性和关系的数量。\n", 281 | " - Si 是图像描述 i 中的对象、属性和关系。\n", 282 | " - Ri 是参考图像描述 i 中的对象、属性和关系。\n", 283 | " - sij 是图像描述 i 中的对象、属性和关系 j 与参考图像描述 i 中的对象、属性和关系 j 的相似度\n", 284 | "- 优点:\n", 285 | " - 在语义而非n-gram层级度量\n", 286 | " - 每个句子映射到场景图后可以从中提取出模型关于某些关系或者属性的识别能力\n", 287 | "- 缺点\n", 288 | " - 缺少n-gram来度量句子的流畅性\n", 289 | " - 度量的准确性受到场景图解析器的制约\n", 290 | "\n", 291 | "使用代码如下,在evaluate的时候调用,接受cands, refs返回对应评估分数" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": 12, 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "def cider_d(reference_list, candidate_list, n=4):\n", 301 | " def count_ngrams(tokens, n):\n", 302 | " ngrams = []\n", 303 | " for i in range(len(tokens) - n + 1):\n", 304 | " ngram = tuple(tokens[i:i+n])\n", 305 | " ngrams.append(ngram)\n", 306 | " return ngrams\n", 307 | "\n", 308 | " def compute_cider_d(reference_list, candidate_list, n):\n", 309 | " cider_d_scores = []\n", 310 | " for refs, cand in zip(reference_list, candidate_list):\n", 311 | " cider_d_score = 0.0\n", 312 | " for i in range(1, n + 1):\n", 313 | " cand_ngrams = count_ngrams(cand, i)\n", 314 | " ref_ngrams_list = [count_ngrams(ref, i) for ref in refs]\n", 315 | "\n", 316 | " total_ref_ngrams = [ngram for ref_ngrams in ref_ngrams_list for ngram in ref_ngrams]\n", 317 | "\n", 318 | " count_cand = 0\n", 319 | " count_clip = 0\n", 320 | "\n", 321 | " for ngram in cand_ngrams:\n", 322 | " count_cand += 1\n", 323 | " if ngram in total_ref_ngrams:\n", 324 | " count_clip += 1\n", 325 | "\n", 326 | " precision = count_clip / count_cand if count_cand > 0 else 0.0\n", 327 | " recall = count_clip / len(total_ref_ngrams) if len(total_ref_ngrams) > 0 else 0.0\n", 328 | "\n", 329 | " beta = 1.0\n", 330 | " f_score = (1 + beta**2) * precision * recall / (beta**2 * precision + recall) if precision + recall > 0 else 0.0\n", 331 | "\n", 332 | " cider_d_score += f_score\n", 333 | "\n", 334 | " cider_d_score /= n\n", 335 | " cider_d_scores.append(cider_d_score)\n", 336 | "\n", 337 | " return cider_d_scores\n", 338 | "\n", 339 | " reference_tokens_list = reference_list\n", 340 | " candidate_tokens_list = candidate_list\n", 341 | "\n", 342 | " scores = compute_cider_d(reference_tokens_list, candidate_tokens_list, n)\n", 343 | "\n", 344 | " return np.mean(scores)\n", 345 | "def spice(reference_list, candidate_list, idf=None, beta=3):\n", 346 | " def count_ngrams(tokens, n):\n", 347 | " ngrams = []\n", 348 | " for i in range(len(tokens) - n + 1):\n", 349 | " ngram = tuple(tokens[i:i+n])\n", 350 | " ngrams.append(ngram)\n", 351 | " return ngrams\n", 352 | "\n", 353 | " def compute_spice_score(reference, candidate, idf, beta):\n", 354 | " reference_tokens = reference\n", 355 | " candidate_tokens = candidate\n", 356 | "\n", 357 | " reference_ngrams = [count_ngrams(reference_tokens, i) for i in range(1, beta + 1)]\n", 358 | " candidate_ngrams = [count_ngrams(candidate_tokens, i) for i in range(1, beta + 1)]\n", 359 | "\n", 360 | " precision_scores = []\n", 361 | " recall_scores = []\n", 362 | "\n", 363 | " for i in range(beta):\n", 364 | " common_ngrams = set(candidate_ngrams[i]) & set(reference_ngrams[i])\n", 365 | "\n", 366 | " precision = len(common_ngrams) / len(candidate_ngrams[i]) if len(candidate_ngrams[i]) > 0 else 0.0\n", 367 | " recall = len(common_ngrams) / len(reference_ngrams[i]) if len(reference_ngrams[i]) > 0 else 0.0\n", 368 | "\n", 369 | " precision_scores.append(precision)\n", 370 | " recall_scores.append(recall)\n", 371 | "\n", 372 | " precision_avg = np.mean(precision_scores)\n", 373 | " recall_avg = np.mean(recall_scores)\n", 374 | "\n", 375 | " spice_score = (precision_avg * recall_avg) / (precision_avg + recall_avg) if precision_avg + recall_avg > 0 else 0.0\n", 376 | "\n", 377 | " if idf:\n", 378 | " spice_score *= np.exp(np.sum([idf[token] for token in common_ngrams]) / len(candidate_tokens))\n", 379 | "\n", 380 | " return spice_score\n", 381 | "\n", 382 | " if idf is None:\n", 383 | " idf = {}\n", 384 | "\n", 385 | " spice_scores = []\n", 386 | "\n", 387 | " for reference, candidate in zip(reference_list, candidate_list):\n", 388 | " spice_score = compute_spice_score(reference, candidate, idf, beta)\n", 389 | " spice_scores.append(spice_score)\n", 390 | "\n", 391 | " return np.mean(spice_scores)\n", 392 | "def get_BLEU_score(cands, refs): #获取BLEU分数\n", 393 | " multiple_refs = []\n", 394 | " for idx in range(len(refs)):\n", 395 | " multiple_refs.append(refs[(idx//1)*1 : (idx//1)*1+1])#每个候选文本对应cpi==1条参考文本\n", 396 | " bleu4 = corpus_bleu(multiple_refs, cands, weights=(0.25,0.25,0.25,0.25))\n", 397 | " return bleu4\n", 398 | "def get_CIDER_D_score(cands, refs): #获得CIDER-D分数\n", 399 | " refs_ = [wvec_to_capls(model.vocab,ref) for ref in refs]\n", 400 | " cands_ = [wvec_to_capls(model.vocab,cand) for cand in cands]\n", 401 | " return cider_d(refs_, cands_)\n", 402 | "def get_SPICE_score(cands, refs): #获得SPICE分数\n", 403 | " refs_ = [wvec_to_cap(model.vocab,ref) for ref in refs]\n", 404 | " cands_ = [wvec_to_cap(model.vocab,cand) for cand in cands]\n", 405 | " return spice(refs_, cands_)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "markdown", 410 | "metadata": {}, 411 | "source": [ 412 | "## 模型定义\n", 413 | "### ARCTIC,一个典型的基于注意力的编解码模型\n", 414 | "\n", 415 | "模型架构: ARCTIC 是一个基于注意力的编码-解码模型,使用了图像编码器和注意力解码器。\n", 416 | "#### 编码器部分\n", 417 | "图像编码器使用了 ResNet-101 网络进行特征提取。\n", 418 | "\n", 419 | "我们使用了预训练模型的权重,所以后续的训练中实际上编码器不参与训练过程,其参数是冻结的\n", 420 | "\n", 421 | "\n", 422 | "\n", 423 | "我们并将其最后一个非全连接层作为网格表示提取层。\n", 424 | "#### 解码器部分\n", 425 | "解码器采用 GRU,利用注意力上下文向量和当前时刻的词嵌入来生成预测结果。该模型支持束搜索来生成更准确的描述。\n", 426 | "\n", 427 | "解码器实质是一个rnn,其是有一层加性注意力机制,它接受查询(Q)和键值对(K,V),计算注意力分数,最后输出上下文向量。\n", 428 | "\n", 429 | "- **加性注意力机制**:我们在AdditiveAttention 类上实现了加性注意力机制,用于在解码过程中关注图像中不同部分的信息。\n", 430 | "\n", 431 | "- **使用GRU**: rnn具体使用的是 GRU :是一种门控循环单元(gated recurrent unit)的缩写,它是一种用于处理序列数据的循环神经网络(RNN)\n", 432 | "- \n", 433 | " 我个人认为选择GRU 而不是 LSTM 是因为相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,可以很出现可观的效果。\n", 434 | "\n", 435 | "- **forward过程:** 在单步forward过程中,我们将上下文向量和当前时刻的词表示拼接,然后和我们的rnn贴合在一起\n", 436 | "\n", 437 | " 作为 GRU 的输入,进入全连接层对GRU 输出进行线性变换,得到单步forward;而实际生成句子的时候就是重复这个过程,选择概率最大的词作为下一个词,直到遇到结束符或者到达最大的长度\n", 438 | "\n", 439 | " 当然,这种推理方法不是最好的,实质上这是一种 贪心算法\n", 440 | " 我们知道,**贪心算法的缺点就是它无法保证全局最优**,我们要的是所有预测的的概率相乘最大。\n", 441 | "\n", 442 | "- **使用束搜索**:所以我们还实现了另一种方法,即使用束搜索来生成更准确的描述。\n", 443 | "\n", 444 | " 束搜索就是在每一步的时候,计算到这一步为止的预测y序列的概率最大的前k条,k叫集束宽。\n", 445 | "\n", 446 | " 当然束搜索的缺点也很明显,推理所需时间实际上是倍增的,但是效果上是有增益的。\n", 447 | "\n" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 13, 453 | "metadata": {}, 454 | "outputs": [], 455 | "source": [ 456 | "ARCTIC_config = Namespace(\n", 457 | " max_len = 93,\n", 458 | " captions_per_image = 1,\n", 459 | " batch_size = 32,\n", 460 | " image_code_dim = 2048,\n", 461 | " word_dim = 512,\n", 462 | " hidden_size = 512,\n", 463 | " attention_dim = 512,\n", 464 | " num_layers = 1,\n", 465 | " encoder_learning_rate = 0.0001,\n", 466 | " decoder_learning_rate = 0.0005,\n", 467 | " num_epochs = 10,\n", 468 | " grad_clip = 5.0,\n", 469 | " alpha_weight = 1.0,\n", 470 | " evaluate_step = 900, # 每隔多少步在验证集上测试一次\n", 471 | " checkpoint = None, # 如果不为None,则利用该变量路径的模型继续训练\n", 472 | " best_checkpoint = 'model/ARCTIC/best_ARCTIC.ckpt', # 验证集上表现最优的模型的路径\n", 473 | " last_checkpoint = 'model/ARCTIC/last_ARCTIC.ckpt', # 训练完成时的模型的路径\n", 474 | " beam_k = 5 #束搜索的束宽\n", 475 | " )\n", 476 | "class ImageEncoder(nn.Module):\n", 477 | " def __init__(self, finetuned=True):\n", 478 | " super(ImageEncoder, self).__init__()\n", 479 | " model = torchvision.models.resnet101(weights=ResNet101_Weights.DEFAULT)\n", 480 | " # ResNet-101网格表示提取器\n", 481 | " self.grid_rep_extractor = nn.Sequential(*(list(model.children())[:-2])) #去掉最后两层 \n", 482 | " for param in self.grid_rep_extractor.parameters(): #冻结参数--不参与训练\n", 483 | " param.requires_grad = finetuned #是否微调\n", 484 | " def forward(self, images):\n", 485 | " out = self.grid_rep_extractor(images) \n", 486 | " return out\n", 487 | "class AdditiveAttention(nn.Module): #加性注意力\n", 488 | " def __init__(self, query_dim, key_dim, attn_dim):\n", 489 | " \"\"\"\n", 490 | " query_dim: 查询Q的维度\n", 491 | " key_dim: 键K的维度\n", 492 | " attn_dim: 注意力函数隐藏层表示的维度\n", 493 | " \"\"\"\n", 494 | " \n", 495 | " super(AdditiveAttention, self).__init__()\n", 496 | " self.attn_w_1_q = nn.Linear(query_dim, attn_dim) #Q的线性变换\n", 497 | " self.attn_w_1_k = nn.Linear(key_dim, attn_dim) #K的线性变换\n", 498 | " self.attn_w_2 = nn.Linear(attn_dim, 1) #注意力函数隐藏层到输出层的线性变换\n", 499 | " self.tanh = nn.Tanh() #激活函数\n", 500 | " self.softmax = nn.Softmax(dim=1) #归一化函数\n", 501 | "\n", 502 | " def forward(self, query, key_value):\n", 503 | " \"\"\"\n", 504 | " Q K V:Q和K算出相关性得分,作为V的权重,K=V\n", 505 | " 参数:\n", 506 | " query: 查询 (batch_size, q_dim)\n", 507 | " key_value: 键和值,(batch_size, n_kv, kv_dim)\n", 508 | " \"\"\"\n", 509 | " # (2)计算query和key的相关性,实现注意力评分函数\n", 510 | " # -> (batch_size, 1, attn_dim)\n", 511 | " queries = self.attn_w_1_q(query).unsqueeze(1) \n", 512 | " # -> (batch_size, n_kv, attn_dim)\n", 513 | " keys = self.attn_w_1_k(key_value) #\n", 514 | " # -> (batch_size, n_kv)\n", 515 | " attn = self.attn_w_2(self.tanh(queries+keys)).squeeze(2) #注意力评分函数\n", 516 | " # (3)归一化相关性分数\n", 517 | " # -> (batch_size, n_kv)\n", 518 | " attn = self.softmax(attn) #归一化\n", 519 | " # (4)计算输出\n", 520 | " # (batch_size x 1 x n_kv)(batch_size x n_kv x kv_dim)\n", 521 | " # -> (batch_size, 1, kv_dim)\n", 522 | " output = torch.bmm(attn.unsqueeze(1), key_value).squeeze(1)\n", 523 | " return output, attn\n", 524 | "class AttentionDecoder(nn.Module):\n", 525 | " def __init__(self, image_code_dim, vocab_size, word_dim, attention_dim, hidden_size, num_layers, dropout=0.5):\n", 526 | " super(AttentionDecoder, self).__init__()\n", 527 | " self.embed = nn.Embedding(vocab_size, word_dim) #词嵌入 \n", 528 | " self.attention = AdditiveAttention(hidden_size, image_code_dim, attention_dim) #注意力机制\n", 529 | " self.init_state = nn.Linear(image_code_dim, num_layers*hidden_size) #初始化隐状态\n", 530 | " self.rnn = nn.GRU(word_dim + image_code_dim, hidden_size, num_layers) #GRU 门控循环\n", 531 | " self.dropout = nn.Dropout(p=dropout) #dropout\n", 532 | " self.fc = nn.Linear(hidden_size, vocab_size) #全连接层\n", 533 | " # RNN默认已初始化\n", 534 | " self.init_weights() #初始化权重\n", 535 | " \n", 536 | " def init_weights(self): #初始化权重\n", 537 | " self.embed.weight.data.uniform_(-0.1, 0.1) #词嵌入\n", 538 | " self.fc.bias.data.fill_(0) #全连接层\n", 539 | " self.fc.weight.data.uniform_(-0.1, 0.1) #全连接层\n", 540 | " \n", 541 | " def init_hidden_state(self, image_code, captions, cap_lens):\n", 542 | " \"\"\"\n", 543 | " 参数:\n", 544 | " image_code:图像编码器输出的图像表示 \n", 545 | " (batch_size, image_code_dim, grid_height, grid_width)\n", 546 | " \"\"\"\n", 547 | " # 将图像网格表示转换为序列表示形式 \n", 548 | " batch_size, image_code_dim = image_code.size(0), image_code.size(1)\n", 549 | " # -> (batch_size, grid_height, grid_width, image_code_dim) \n", 550 | " image_code = image_code.permute(0, 2, 3, 1) \n", 551 | " # -> (batch_size, grid_height * grid_width, image_code_dim)\n", 552 | " image_code = image_code.view(batch_size, -1, image_code_dim)\n", 553 | " # (1)按照caption的长短排序\n", 554 | " sorted_cap_lens, sorted_cap_indices = torch.sort(cap_lens, 0, True)\n", 555 | " captions = captions[sorted_cap_indices]\n", 556 | " image_code = image_code[sorted_cap_indices]\n", 557 | " #(2)初始化隐状态\n", 558 | " hidden_state = self.init_state(image_code.mean(axis=1))\n", 559 | " hidden_state = hidden_state.view(\n", 560 | " batch_size, \n", 561 | " self.rnn.num_layers, \n", 562 | " self.rnn.hidden_size).permute(1, 0, 2)\n", 563 | " return image_code, captions, sorted_cap_lens, sorted_cap_indices, hidden_state\n", 564 | "\n", 565 | " def forward_step(self, image_code, curr_cap_embed, hidden_state):\n", 566 | " #(3.2)利用注意力机制获得上下文向量\n", 567 | " # query:hidden_state[-1],即最后一个隐藏层输出 (batch_size, hidden_size)\n", 568 | " # context: (batch_size, hidden_size)\n", 569 | " context, alpha = self.attention(hidden_state[-1], image_code)\n", 570 | " #(3.3)以上下文向量和当前时刻词表示为输入,获得GRU输出\n", 571 | " x = torch.cat((context, curr_cap_embed), dim=-1).unsqueeze(0)\n", 572 | " # x: (1, real_batch_size, hidden_size+word_dim)\n", 573 | " # out: (1, real_batch_size, hidden_size)\n", 574 | " out, hidden_state = self.rnn(x, hidden_state)\n", 575 | " #(3.4)获取该时刻的预测结果\n", 576 | " # (real_batch_size, vocab_size)\n", 577 | " preds = self.fc(self.dropout(out.squeeze(0)))\n", 578 | " return preds, alpha, hidden_state\n", 579 | " \n", 580 | " def forward(self, image_code, captions, cap_lens):\n", 581 | " \"\"\"\n", 582 | " 参数:\n", 583 | " hidden_state: (num_layers, batch_size, hidden_size)\n", 584 | " image_code: (batch_size, feature_channel, feature_size)\n", 585 | " captions: (batch_size, )\n", 586 | " \"\"\"\n", 587 | " # (1)将图文数据按照文本的实际长度从长到短排序\n", 588 | " # (2)获得GRU的初始隐状态\n", 589 | " image_code, captions, sorted_cap_lens, sorted_cap_indices, hidden_state \\\n", 590 | " = self.init_hidden_state(image_code, captions, cap_lens)\n", 591 | " batch_size = image_code.size(0)\n", 592 | " # 输入序列长度减1,因为最后一个时刻不需要预测下一个词\n", 593 | " lengths = sorted_cap_lens.cpu().numpy() - 1\n", 594 | " # 初始化变量:模型的预测结果和注意力分数\n", 595 | " predictions = torch.zeros(batch_size, lengths[0], self.fc.out_features).to(captions.device)\n", 596 | " alphas = torch.zeros(batch_size, lengths[0], image_code.shape[1]).to(captions.device)\n", 597 | " # 获取文本嵌入表示 cap_embeds: (batch_size, num_steps, word_dim)\n", 598 | " cap_embeds = self.embed(captions)\n", 599 | " # Teacher-Forcing模式\n", 600 | " for step in range(lengths[0]):\n", 601 | " #(3)解码\n", 602 | " #(3.1)模拟pack_padded_sequence函数的原理,获取该时刻的非输入\n", 603 | " real_batch_size = np.where(lengths>step)[0].shape[0]\n", 604 | " preds, alpha, hidden_state = self.forward_step(\n", 605 | " image_code[:real_batch_size], \n", 606 | " cap_embeds[:real_batch_size, step, :],\n", 607 | " hidden_state[:, :real_batch_size, :].contiguous()) \n", 608 | " # 记录结果\n", 609 | " predictions[:real_batch_size, step, :] = preds\n", 610 | " alphas[:real_batch_size, step, :] = alpha\n", 611 | " return predictions, alphas, captions, lengths, sorted_cap_indices\n", 612 | " \n", 613 | "class ARCTIC(nn.Module): #模型主体部分\n", 614 | " def __init__(self, image_code_dim, vocab, word_dim, attention_dim, hidden_size, num_layers):\n", 615 | " super(ARCTIC, self).__init__()\n", 616 | " self.vocab = vocab\n", 617 | " self.encoder = ImageEncoder()\n", 618 | " self.decoder = AttentionDecoder(image_code_dim, len(vocab),\n", 619 | " word_dim, attention_dim, hidden_size, num_layers)\n", 620 | " print(\"test\")\n", 621 | " def forward(self, images, captions, cap_lens):\n", 622 | " image_code = self.encoder(images)\n", 623 | " return self.decoder(image_code, captions, cap_lens)\n", 624 | " def generate_by_beamsearch(self, images, beam_k, max_len): # beam_k束搜索\n", 625 | " vocab_size = len(self.vocab)\n", 626 | " image_codes = self.encoder(images)\n", 627 | " texts = []\n", 628 | " device = images.device\n", 629 | " # 对每个图像样本执行束搜索\n", 630 | " for image_code in image_codes:\n", 631 | " # 将图像表示复制k份\n", 632 | " image_code = image_code.unsqueeze(0).repeat(beam_k,1,1,1)\n", 633 | " # 生成k个候选句子,初始时,仅包含开始符号\n", 634 | " cur_sents = torch.full((beam_k, 1), self.vocab[''], dtype=torch.long).to(device)\n", 635 | " cur_sent_embed = self.decoder.embed(cur_sents)[:,0,:]\n", 636 | " sent_lens = torch.LongTensor([1]*beam_k).to(device)\n", 637 | " # 获得GRU的初始隐状态\n", 638 | " image_code, cur_sent_embed, _, _, hidden_state = \\\n", 639 | " self.decoder.init_hidden_state(image_code, cur_sent_embed, sent_lens)\n", 640 | " # 存储已生成完整的句子(以句子结束符结尾的句子)\n", 641 | " end_sents = []\n", 642 | " # 存储已生成完整的句子的概率\n", 643 | " end_probs = []\n", 644 | " # 存储未完整生成的句子的概率\n", 645 | " probs = torch.zeros(beam_k, 1).to(device)\n", 646 | " k = beam_k\n", 647 | " while True:\n", 648 | " preds, _, hidden_state = self.decoder.forward_step(image_code[:k], cur_sent_embed, hidden_state.contiguous())\n", 649 | " # -> (k, vocab_size)\n", 650 | " preds = nn.functional.log_softmax(preds, dim=1)\n", 651 | " # 对每个候选句子采样概率值最大的前k个单词生成k个新的候选句子,并计算概率\n", 652 | " # -> (k, vocab_size)\n", 653 | " probs = probs.repeat(1,preds.size(1)) + preds\n", 654 | " if cur_sents.size(1) == 1:\n", 655 | " # 第一步时,所有句子都只包含开始标识符,因此,仅利用其中一个句子计算topk\n", 656 | " values, indices = probs[0].topk(k, 0, True, True)\n", 657 | " else:\n", 658 | " # probs: (k, vocab_size) 是二维张量\n", 659 | " # topk函数直接应用于二维张量会按照指定维度取最大值,这里需要在全局取最大值\n", 660 | " # 因此,将probs转换为一维张量,再使用topk函数获取最大的k个值\n", 661 | " values, indices = probs.view(-1).topk(k, 0, True, True)\n", 662 | " # 计算最大的k个值对应的句子索引和词索引\n", 663 | " sent_indices = torch.div(indices, vocab_size, rounding_mode='trunc') \n", 664 | " word_indices = indices % vocab_size \n", 665 | " # 将词拼接在前一轮的句子后,获得此轮的句子\n", 666 | " cur_sents = torch.cat([cur_sents[sent_indices], word_indices.unsqueeze(1)], dim=1)\n", 667 | " # 查找此轮生成句子结束符的句子\n", 668 | " end_indices = [idx for idx, word in enumerate(word_indices) if word == self.vocab['']]\n", 669 | " if len(end_indices) > 0:\n", 670 | " end_probs.extend(values[end_indices])\n", 671 | " end_sents.extend(cur_sents[end_indices].tolist())\n", 672 | " # 如果所有的句子都包含结束符,则停止生成\n", 673 | " k -= len(end_indices)\n", 674 | " if k == 0:\n", 675 | " break\n", 676 | " # 查找还需要继续生成词的句子\n", 677 | " cur_indices = [idx for idx, word in enumerate(word_indices) \n", 678 | " if word != self.vocab['']]\n", 679 | " if len(cur_indices) > 0:\n", 680 | " cur_sent_indices = sent_indices[cur_indices]\n", 681 | " cur_word_indices = word_indices[cur_indices]\n", 682 | " # 仅保留还需要继续生成的句子、句子概率、隐状态、词嵌入\n", 683 | " cur_sents = cur_sents[cur_indices]\n", 684 | " probs = values[cur_indices].view(-1,1)\n", 685 | " hidden_state = hidden_state[:,cur_sent_indices,:]\n", 686 | " cur_sent_embed = self.decoder.embed(\n", 687 | " cur_word_indices.view(-1,1))[:,0,:]\n", 688 | " # 句子太长,停止生成\n", 689 | " if cur_sents.size(1) >= max_len:\n", 690 | " break\n", 691 | " if len(end_sents) == 0:\n", 692 | " # 如果没有包含结束符的句子,则选取第一个句子作为生成句子\n", 693 | " gen_sent = cur_sents[0].tolist()\n", 694 | " else: \n", 695 | " # 否则选取包含结束符的句子中概率最大的句子\n", 696 | " gen_sent = end_sents[end_probs.index(max(end_probs))]\n", 697 | " texts.append(gen_sent)\n", 698 | " return texts\n", 699 | "mode_arctic=\"../best_arctic.ckpt\"\n", 700 | "model=torch.load(mode_arctic)[\"model\"] #加载模型" 701 | ] 702 | }, 703 | { 704 | "cell_type": "markdown", 705 | "metadata": {}, 706 | "source": [ 707 | "### 视觉Transformer (ViT) + Transformer解码器\n", 708 | "\n", 709 | "\n", 710 | "\n", 711 | "\n", 712 | "### ViT部分:\n", 713 | "ViT是一种用于处理视觉数据的转换器架构。传统的卷积神经网络(CNN)在图像处理任务中表现出色,但ViT引入了自注意力机制,使得它可以处理可变大小的图像。ViT将输入图片分为多个patch(16x16),再将每个patch投影为固定长度的向量送入Transformer,后续encoder的操作和原始Transformer中完全相同。但是因为对图片分类,因此在输入序列中加入一个特殊的token,该token对应的输出即为最后的类别预测\n", 714 | "一个ViT block可以分为以下几个步骤\n", 715 | "\n", 716 | "- **patch embedding**\n", 717 | " 例如输入图片大小为224x224,将图片分为固定大小的patch,patch大小为16x16,则每张图像会生成224x224/16x16=196个patch,即输入序列长度为196,每个patch维度16x16x3=768,线性投射层的维度为768xN (N=768),因此输入通过线性投射层之后的维度依然为196x768,即一共有196个token,每个token的维度是768。这里还需要加上一个特殊字符cls,因此最终的维度是197x768。到目前为止,已经通过patch embedding将一个视觉问题转化为了一个seq2seq问题\n", 718 | "\n", 719 | "- **positional encoding(standard learnable 1D position embeddings)**\n", 720 | " 加入位置编码。加入位置编码信息之后,维度依然是197x768。位置编码的方式有**1-D位置编码**和**2-D 位置编码**,无论是哪种方式精度都很接近,甚至不适用位置编码性能损失也没有大到无法接受的地步。猜测因为ViT作用在image patch上,对网络来说这些patch之间的相对位置信息很容易理解,所以位置编码的方式与是否采用的影响都不大\n", 721 | "\n", 722 | "- **LN/multi-head attention/LN**\n", 723 | " \n", 724 | " LN输出维度依然是197x768。将输入映射到q,k,v从而进行MHA操作,如果只有一个head,qkv的维度都是197x768,如果有12个头(768/12=64),则qkv的维度是197x64,一共有12组qkv,最后再将12组qkv的输出拼接起来,输出维度是197x768,然后在过一层LN,维度依然是197x768\n", 725 | "\n", 726 | "- **MLP**\n", 727 | " \n", 728 | " 将维度放大再缩小回去,197x768放大为197x3072,再缩小变为197x768一个block之后维度依然和输入相同,都是197x768,因此可以堆叠多个block。最后会将特殊字符cls对应的输出 作为encoder的最终输出 ,代表最终的image presentation(另一种做法是不加cls字符,对所有的tokens的输出做一个平均),后面接一个MLP进行图片分类\n", 729 | "\n", 730 | "### Transformer解码器:\n", 731 | "我们首先采用BERT来对提取的特征进行编码,再用Transformer解码器对编码内容进行处理后输出。要注意的是,这里我们仅采用了BERT的结构,并没有调用预训练模型。\n", 732 | "BERT(Bidirectional Encoder Representations from Transformers)是一种自然语言处理中常用的预训练模型。我们将BERT用作图像特征的编码器。通过将ViT提取的图像特征输入BERT,模型能够学习图像中的上下文信息,捕捉图像内部的复杂关系。这使得模型能够更好地理解图像内容,从而更好地生成相关的图像描述。\n", 733 | "- BERT的结构是标准transformer结构的encoder部分, 一个transformer的encoder单元由一个multi-head-Attention + Layer Normalization + feedforword + Layer Normalization 叠加产生,BERT的每一层由一个这样的encoder单元构成\n", 734 | "- 这种transformer的结构可以使用上下文来预测mask的token,从而捕捉双向关系\n", 735 | "\n", 736 | "之后,我们采用经典的Transformer解码器结构,Transformer解码器在NLP领域中被广泛用于生成序列数据,如文本。在这里,我们使用Transformer解码器来生成图像描述。解码器接收来自BERT编码器的图像特征作为输入,然后通过自注意力机制和前馈神经网络层逐步生成与图像内容相关的描述序列。这个描述序列最终形成图像的摘要或说明。\n", 737 | "\n", 738 | "### 网格/区域表示、Transformer编码器+Transformer解码器\n", 739 | "\n", 740 | "### 网格表示:\n", 741 | "我们借鉴了 **SwinTransformer** 的结构来做网格特征提取。SwinTransformer是一个基于注意力机制的深度学习模型,专门设计用于图像处理任务。它采用了分层的注意力机制,允许模型有效地处理大尺寸的图像。在图像摘要生成任务中,SwinTransformer作为网格特征提取器,从输入图像中提取有用的特征。\n", 742 | "- **Hierarchical Patch-based Attention**\n", 743 | " \n", 744 | " Swin Transformer引入了层次化的基于Patch的注意力机制。传统的Transformer模型是基于全连接的,对于大尺寸的图像,计算复杂度可能会非常高。为了解决这个问题,Swin Transformer将图像划分为一系列的非重叠图像块(patches),并在这些块上应用自注意力机制。这种分块处理允许Swin Transformer更好地扩展到大规模图像。\n", 745 | "- **Shifted Windows**\n", 746 | " \n", 747 | " Swin Transformer采用了一种称为\"shifted windows\"的策略,通过改变注意力机制中的窗口偏移,提高了模型的局部感知能力。这对于捕捉图像中不同区域的特征非常有帮助,尤其是在需要考虑对象之间相对位置关系时。\n", 748 | "- **Tokenization and Positional Embeddings**\n", 749 | " \n", 750 | " Swin Transformer将图像块转换为序列,每个序列元素对应一个图像块。为了使模型能够处理序列数据,需要引入类似自然语言处理中的tokenization和positional embeddings机制。这样,Swin Transformer可以对图像块进行有效的注意力计算,并理解它们之间的语义关系。\n", 751 | "\n", 752 | "- **PatchMergin**\n", 753 | " \n", 754 | " 类似于池化操作\n", 755 | "\n", 756 | "- 借鉴了许多CNN中的trick,每经过一个stage后size就会缩小为原来的二分之一,channel扩大为两倍,与CNN相似\n", 757 | " - Patch Merging 模块将 尺寸为 H×WH×W 的 Patch 块首先进行拼接并在 channel 维度上进行 concatenate 构成了 H/2×W/2×4C的特征图,然后再进行 Layer Normalization 操作进行正则化,通过一个 Linear 层后形成了一个 H/2×W/2×2C得到特征图,完成了特征图的下采样过程。size 缩小为原来的 1/2,channel 扩大为原来的 2 倍。\n", 758 | "- 分割成多个固定的窗口,每个窗口内的像素只能内部进行内积,虽然减少了计算开销,但是也因为这个操作,各个窗口无法进行信息交互,也即是感受野减小,难以整体获得信息,因此需要进行平移图像,改变分割方式,但这会使得计算量增加,因此重新滑动分割后的窗口,同时加入mask机制,防止不同位置的元素进行了自注意力的计算,最后再将信息数据平移回到原来的位置。通过 SW-MSA 机制完成了偏移窗口的像素点的 MSA 计算并实现了不同窗口间像素点的信息交流,从而间接扩大了网络的“感受野”,提高了信息的利用效率。\n", 759 | "\n", 760 | "\n", 761 | "### Transformer编码器+Transformer解码器\n", 762 | "\n", 763 | "同上,我们首先采用BERT来对提取的特征进行编码,再用Transformer解码器对编码内容进行处理后输出。要注意的是,这里我们仅采用了BERT的结构,并没有调用预训练模型。\n", 764 | "BERT(Bidirectional Encoder Representations from Transformers)是一种自然语言处理中常用的预训练模型。我们将BERT用作图像特征的编码器。通过将SwinTransformer提取特征信息输入BERT,模型能够学习图像中的上下文信息,捕捉图像内部的复杂关系。这使得模型能够更好地理解图像内容,从而更好地生成相关的图像描述。\n", 765 | "- BERT的结构是标准transformer结构的encoder部分, 一个transformer的encoder单元由一个multi-head-Attention + Layer Normalization + feedforword + Layer Normalization 叠加产生,BERT的每一层由一个这样的encoder单元构成\n", 766 | "- 这种transformer的结构可以使用上下文来预测mask的token,从而捕捉双向关系\n", 767 | "\n", 768 | "之后,我们采用经典的Transformer解码器结构,Transformer解码器在NLP领域中被广泛用于生成序列数据,如文本。在这里,我们使用Transformer解码器来生成图像描述。解码器接收来自BERT编码器的图像特征作为输入,然后通过自注意力机制和前馈神经网络层逐步生成与图像内容相关的描述序列。这个描述序列最终形成图像的摘要或说明。\n", 769 | "\n", 770 | "\n" 771 | ] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "execution_count": null, 776 | "metadata": {}, 777 | "outputs": [], 778 | "source": [ 779 | "# ViT编码器+Transformer解码器\n", 780 | "class Img2TxtModel(nn.Module):\n", 781 | " def __init__(self, vit_model_name, transformer_config, vocab_size):\n", 782 | " super(Img2TxtModel, self).__init__()\n", 783 | " # ViT模型作为编码器\n", 784 | " self.encoder = ViTModel.from_pretrained(vit_model_name)\n", 785 | "\n", 786 | " # Transformer解码器配置\n", 787 | " transformer_config = BertConfig(vocab_size=vocab_size, num_hidden_layers=1, is_decoder=True, add_cross_attention=True)\n", 788 | " self.decoder = BertModel(transformer_config)\n", 789 | "\n", 790 | " # 预测每个词的线性层\n", 791 | " self.vocab_size = vocab_size\n", 792 | " self.fc = nn.Linear(transformer_config.hidden_size, vocab_size)\n", 793 | " \n", 794 | " def forward(self, input_ids, decoder_input_ids, decoder_attention_mask):\n", 795 | " # 通过ViT编码器获取图像特征\n", 796 | " encoder_outputs = self.encoder(pixel_values=input_ids).last_hidden_state\n", 797 | "\n", 798 | " # 将图像特征作为解码器的输入\n", 799 | " decoder_outputs = self.decoder(input_ids=decoder_input_ids, \n", 800 | " attention_mask=decoder_attention_mask,\n", 801 | " encoder_hidden_states=encoder_outputs).last_hidden_state\n", 802 | "\n", 803 | " # 预测下一个词\n", 804 | " prediction_scores = self.fc(decoder_outputs)\n", 805 | " return prediction_scores\n", 806 | "\n", 807 | " def generate_text(self, input_ids, max_length=95, start_token_id=154):\n", 808 | " # 获取图像特征\n", 809 | " encoder_outputs = self.encoder(pixel_values=input_ids).last_hidden_state\n", 810 | "\n", 811 | " # 初始化解码器输入为标记\n", 812 | " decoder_input_ids = torch.full((input_ids.size(0), 1), start_token_id).to(input_ids.device)\n", 813 | " \n", 814 | " # 存储所有时间步的logits\n", 815 | " all_logits = []\n", 816 | "\n", 817 | " for step in range(max_length):\n", 818 | " # 获取解码器输出\n", 819 | " decoder_outputs = self.decoder(\n", 820 | " input_ids=decoder_input_ids, \n", 821 | " encoder_hidden_states=encoder_outputs\n", 822 | " ).last_hidden_state\n", 823 | "\n", 824 | " # 预测下一个词\n", 825 | " next_word_logits = self.fc(decoder_outputs[:, -1, :])\n", 826 | " all_logits.append(next_word_logits.unsqueeze(1))\n", 827 | " next_word_id = next_word_logits.argmax(dim=-1).unsqueeze(-1)\n", 828 | " \n", 829 | " # 将预测的词添加到解码器输入中\n", 830 | " decoder_input_ids = torch.cat([decoder_input_ids, next_word_id], dim=-1)\n", 831 | " \n", 832 | " return decoder_input_ids ,torch.cat(all_logits, dim=1)" 833 | ] 834 | }, 835 | { 836 | "cell_type": "code", 837 | "execution_count": null, 838 | "metadata": {}, 839 | "outputs": [], 840 | "source": [ 841 | "#SwinTransformerBlock ,SwinTransformer经过多个这样的block\n", 842 | "class SwinTransformerBlock(nn.Module):\n", 843 | " def __init__(self, dim, num_heads, window_size=7, shift_size=0., mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n", 844 | " super(SwinTransformerBlock, self).__init__()\n", 845 | "\n", 846 | " self.dim = dim\n", 847 | " self.num_heads = num_heads\n", 848 | " self.window_size = window_size\n", 849 | " self.shift_size = shift_size\n", 850 | " self.mlp_ratio = mlp_ratio\n", 851 | "\n", 852 | " self.norm1 = norm_layer(dim)\n", 853 | " self.attn = self._create_attention_module(dim, num_heads, window_size, qkv_bias, attn_drop, drop)\n", 854 | " self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n", 855 | " self.norm2 = norm_layer(dim)\n", 856 | " mlp_hidden_dim = int(dim * mlp_ratio)\n", 857 | " self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act=act_layer, drop=drop)\n", 858 | "\n", 859 | " def _create_attention_module(self, dim, num_heads, window_size, qkv_bias, attn_drop, drop):\n", 860 | " return WindowAttention(\n", 861 | " dim=dim, window_size=(window_size, window_size), num_heads=num_heads,\n", 862 | " qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop\n", 863 | " )\n", 864 | "\n", 865 | " def forward(self, x, attn_mask):\n", 866 | " x, Hp, Wp = self._process_input(x)\n", 867 | " shifted_x = self._shift_and_pad(x)\n", 868 | " x_windows = self._prepare_windows(shifted_x)\n", 869 | " attn_windows = self.attn(x_windows, mask=attn_mask)\n", 870 | " shifted_x = self._reverse_windows(attn_windows, Hp, Wp)\n", 871 | " x = self._restore_data(shifted_x, x.shape, Hp, Wp)\n", 872 | " x = x + self.drop_path(self.mlp(self.norm2(x)))\n", 873 | "\n", 874 | " return x\n", 875 | "\n", 876 | " def _process_input(self, x):\n", 877 | " H, W = self.H, self.W # feature map\n", 878 | " B, L, C = x.shape\n", 879 | " shortcut = x\n", 880 | " x = self.norm1(x)\n", 881 | " x = x.view(B, H, W, C)\n", 882 | " return x, H, W\n", 883 | "\n", 884 | " def _shift_and_pad(self, x):\n", 885 | " if self.shift_size > 0.:\n", 886 | " shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n", 887 | " else:\n", 888 | " shifted_x = x\n", 889 | " return shifted_x\n", 890 | "\n", 891 | " def _prepare_windows(self, x):\n", 892 | " x_windows = window_partition(x, self.window_size)\n", 893 | " x_windows = x_windows.view(-1, self.window_size * self.window_size, x.shape[-1])\n", 894 | " return x_windows\n", 895 | "\n", 896 | " def _reverse_windows(self, attn_windows, Hp, Wp):\n", 897 | " attn_windows = attn_windows.view(-1, self.window_size, self.window_size, attn_windows.shape[-1])\n", 898 | " shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)\n", 899 | " return shifted_x\n", 900 | "\n", 901 | " def _restore_data(self, shifted_x, original_shape, Hp, Wp):\n", 902 | " if self.shift_size > 0:\n", 903 | " x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n", 904 | " else:\n", 905 | " x = shifted_x\n", 906 | " x_r = (self.window_size - original_shape[2] % self.window_size) % self.window_size\n", 907 | " x_d = (self.window_size - original_shape[1] % self.window_size) % self.window_size\n", 908 | " if x_r > 0 or x_d > 0:\n", 909 | " x = x[:, :original_shape[1], :original_shape[2], :].contiguous()\n", 910 | " x = x.view(original_shape[0], original_shape[1] * original_shape[2], original_shape[3])\n", 911 | " return x" 912 | ] 913 | }, 914 | { 915 | "cell_type": "code", 916 | "execution_count": null, 917 | "metadata": {}, 918 | "outputs": [], 919 | "source": [ 920 | "#特征提取器的总体结构\n", 921 | "class SwinTransformerFeatureExtractor(nn.Module):\n", 922 | " def __init__(self, downsapmle_size=4, in_channels=3, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), window_size=7, mlp_ratio=4.,\n", 923 | " qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, patch_norm=True, **kwargs):\n", 924 | " super(SwinTransformerFeatureExtractor, self).__init__()\n", 925 | "\n", 926 | " self.num_layers = len(depths)\n", 927 | " self.embed_dim = embed_dim\n", 928 | " self.patch_norm = patch_norm\n", 929 | " # stage4 输出的特征矩阵的Channel\n", 930 | " self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))\n", 931 | " self.mlp_ratio = mlp_ratio\n", 932 | "\n", 933 | " self.patch_embed = patchEmbed(patch_size=downsapmle_size, in_channels=in_channels, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None)\n", 934 | " self.pos_drop = nn.Dropout(p=drop_rate)\n", 935 | "\n", 936 | " dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]\n", 937 | "\n", 938 | " self.layers = nn.ModuleList()\n", 939 | " for i_layer in range(self.num_layers):\n", 940 | " layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),\n", 941 | " depth=depths[i_layer],\n", 942 | " num_heads=num_heads[i_layer],\n", 943 | " window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate,\n", 944 | " drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=patchmerging if (i_layer < self.num_layers - 1) else None)\n", 945 | " self.layers.append(layers)\n", 946 | "\n", 947 | " def _init_weights(self, m):\n", 948 | " if isinstance(m, nn.Linear):\n", 949 | " nn.init.trunc_normal_(m.weight, std=.02)\n", 950 | " if isinstance(m, nn.Linear) and m.bias is not None:\n", 951 | " nn.init.constant_(m.bias, 0)\n", 952 | "\n", 953 | " elif isinstance(m, nn.LayerNorm):\n", 954 | " nn.init.constant_(m.bias, 0)\n", 955 | " nn.init.constant_(m.weight, 1.0)\n", 956 | "\n", 957 | " def forward(self, x):\n", 958 | " # [B, L, C]\n", 959 | " x, H, W = self.patch_embed(x)\n", 960 | " x = self.pos_drop(x)\n", 961 | " hidden_states = []\n", 962 | " for layer in self.layers:\n", 963 | " x, H, W = layer(x, H, W)\n", 964 | " \n", 965 | " hidden_states.append(x.clone())\n", 966 | "\n", 967 | " return x, hidden_states\n" 968 | ] 969 | }, 970 | { 971 | "cell_type": "markdown", 972 | "metadata": {}, 973 | "source": [ 974 | "### blip + 多模态构建 新数据集\n", 975 | "\n", 976 | "为了得到带有背景描述的数据集,我们利用训练的服饰图像描述模型和多模态大语言模型,为真实背景的服饰图像数据集增加服饰描述和背景描述,构建全新的服饰图像描述数据集。\n", 977 | "\n", 978 | "我的使用的新数据集是选用 DeepFasion开源的12w数据集,仅使用图片,选用其中背景较丰富的5000张图像区间(从31006开始),具体范围可以看new_dataset/combined_input.json部分的起点和终点key。\n", 979 | "\n", 980 | "由于没有条件调用GPT-4 Vision的API,我们没有直接使用多模态的大模型,而是采用img2txt再txt2txt的模式,先让图像描述模型(Blip)生成图像的文字描述,再用大语言模型(文心一言)接受这些文字描述,按照提示性的prompt生成更有深度的背景信息。\n", 981 | "\n", 982 | "下面我封装了一个模块,用以快速调用Blip 进行批量处理,生成关于图片的简单文字表述,结果保存在new_dataset/res_new.json中,包括一个背景的粗略描述和一个整体的图像描述。\n" 983 | ] 984 | }, 985 | { 986 | "cell_type": "code", 987 | "execution_count": null, 988 | "metadata": {}, 989 | "outputs": [], 990 | "source": [ 991 | "class blip_model():\n", 992 | " def __init__(self) -> None:\n", 993 | " self.processor = BlipProcessor.from_pretrained(\"Salesforce/blip-image-captioning-base\")\n", 994 | " self.model = BlipForConditionalGeneration.from_pretrained(\"Salesforce/blip-image-captioning-base\").to(\"cuda\")\n", 995 | " def gen_res(self,img_path):\n", 996 | " raw_image = Image.open(img_path).convert('RGB')\n", 997 | "\n", 998 | " text = \"a people in front of \"\n", 999 | " input_1 = self.processor(raw_image, text, return_tensors=\"pt\").to(\"cuda\")\n", 1000 | " out_1 = self.model.generate(**input_1,max_length=100)\n", 1001 | " res_1=self.processor.decode(out_1[0], skip_special_tokens=True)\n", 1002 | "\n", 1003 | " input_2 = self.processor(raw_image, return_tensors=\"pt\").to(\"cuda\")\n", 1004 | " out_2 = self.model.generate(**input_2,max_length=100)\n", 1005 | " res_2=self.processor.decode(out_2[0], skip_special_tokens=True)\n", 1006 | " return res_1+\". \"+res_2\n", 1007 | "def gen_json(img_path,n):\n", 1008 | " model=blip_model()\n", 1009 | " #img_path=\"D:/NNDL/data/deepfashion-multimodal/images\"\n", 1010 | " #获取该目录下所有文件,存入列表中\n", 1011 | " imgs=os.listdir(img_path)\n", 1012 | " res={}\n", 1013 | " start=31000\n", 1014 | "\n", 1015 | " for img in range(start,len(imgs)):\n", 1016 | " img_k=imgs[img]\n", 1017 | " img_path_=img_path+\"/\"+img_k\n", 1018 | " res[img_k]=model.gen_res(img_path_)\n", 1019 | " if len(res)>=n:\n", 1020 | " break\n", 1021 | " #保存为json文件\n", 1022 | " with open('res.json', 'w') as f:\n", 1023 | " json.dump(res, f,indent=2)" 1024 | ] 1025 | }, 1026 | { 1027 | "cell_type": "markdown", 1028 | "metadata": {}, 1029 | "source": [ 1030 | "### LLM生成更有深度的图像描述\n", 1031 | "将Blip生成的文字输入给文心一言,返回更有深度的背景图像描述。由于文心一言未能按照预期仅仅输出背景相关的描述,这里取描述的第一句话存于new_dataset/res_add.json中作为数据。" 1032 | ] 1033 | }, 1034 | { 1035 | "cell_type": "code", 1036 | "execution_count": null, 1037 | "metadata": {}, 1038 | "outputs": [], 1039 | "source": [ 1040 | "import requests\n", 1041 | "import json\n", 1042 | "import os\n", 1043 | "import datetime\n", 1044 | "\n", 1045 | "API_KEY = 'jIHlGsMYHp4j1MrMZGNmYbCL' \n", 1046 | "SECRET_KEY = 'GzG1o4HC6G0qDVxPKcn9Zl4pv20j7CGA' #已经过了有效期了所以直接放这里了\n", 1047 | "\n", 1048 | "headers = {\n", 1049 | " 'Content-Type': 'application/json',\n", 1050 | " 'Accept': 'application/json'\n", 1051 | "}\n", 1052 | "\n", 1053 | "def get_access_token():\n", 1054 | " \"\"\"\n", 1055 | " 使用 AK,SK 生成鉴权签名(Access Token)\n", 1056 | " :return: access_token,或是None(如果错误)\n", 1057 | " \"\"\"\n", 1058 | " url = \"https://aip.baidubce.com/oauth/2.0/token\"\n", 1059 | " params = {\"grant_type\": \"client_credentials\", \"client_id\": API_KEY, \"client_secret\": SECRET_KEY}\n", 1060 | " return str(requests.post(url, params=params).json().get(\"access_token\"))\n", 1061 | "\n", 1062 | "def read_json(file_path):\n", 1063 | " with open(file_path, 'r', encoding='utf-8') as file:\n", 1064 | " return json.load(file)\n", 1065 | "\n", 1066 | "def QianFan(url, inputs):\n", 1067 | " responses = {} # 创建一个字典来存储输入和相应的响应\n", 1068 | "\n", 1069 | " for key, user_input in inputs.items(): # inputs是一个字典\n", 1070 | " request = {\n", 1071 | " \"messages\": [\n", 1072 | " {\n", 1073 | " \"role\": \"user\",\n", 1074 | " \"content\": \"I will give you a sentence, where the first part is a summary of the background and the second part is information about a person's outfit. Please focus on the background information from the first part and provide an overall background description. \"\n", 1075 | " },\n", 1076 | " {\n", 1077 | " \"role\": \"assistant\",\n", 1078 | " \"content\": \"Of course, please provide the sentence, and I will only output a sentence like: 'The background is'. \"\n", 1079 | " },\n", 1080 | " {\n", 1081 | " \"role\": \"user\",\n", 1082 | " \"content\": \"a people in front of a bed. a pair of jeans\"\n", 1083 | " },\n", 1084 | " {\n", 1085 | " \"role\": \"assistant\",\n", 1086 | " \"content\": \"The backgroud is a homely bedroom.\"\n", 1087 | " },\n", 1088 | " {\n", 1089 | " \"role\": \"user\",\n", 1090 | " \"content\": \"a people in front of a mirror. a man in a blue shorts and a white shirt\"\n", 1091 | " },\n", 1092 | " {\n", 1093 | " \"role\": \"assistant\",\n", 1094 | " \"content\": \"The backgroud is a mirror.\"\n", 1095 | " },\n", 1096 | " ]\n", 1097 | " }#添加一些前置词,以获得更贴合且标准化的回答。\n", 1098 | " # 添加数据\n", 1099 | " request[\"messages\"].append({\"role\": \"user\", \"content\": f\"Describe the setting of the following scene, focusing solely on the background without including any details about the person:{user_input}\"})\n", 1100 | "\n", 1101 | "\n", 1102 | " try:\n", 1103 | " response = requests.request(\"POST\", url, headers=headers, data=json.dumps(request))\n", 1104 | " text = response.text\n", 1105 | " data = json.loads(text)\n", 1106 | " model_response = data['result']\n", 1107 | " print(\"\\n回答:\\n\", model_response, '\\n')\n", 1108 | " # 根据句号分割文本\n", 1109 | " sentences = model_response.split(\". \")\n", 1110 | "\n", 1111 | " # 获取第一个句子\n", 1112 | " first_sentence = sentences[0] + \".\"\n", 1113 | " responses[key] = first_sentence # 将响应存储在字典中\n", 1114 | "\n", 1115 | " except Exception as e:\n", 1116 | " print(f\"QianFan 接口调用出错: {e}\")\n", 1117 | "\n", 1118 | " # 保存或处理responses字典\n", 1119 | " return responses\n", 1120 | "\n", 1121 | "def main():\n", 1122 | " if not os.path.exists('./Amadeus/history'):\n", 1123 | " os.makedirs('./Amadeus/history')\n", 1124 | "\n", 1125 | " url = \"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=\" + get_access_token()\n", 1126 | "\n", 1127 | " inputs = read_json('res_new.json') \n", 1128 | " responses = QianFan(url, inputs)\n", 1129 | "\n", 1130 | " # 可以选择保存responses字典\n", 1131 | " with open('./Amadeus/history/responses.json', 'w', encoding='utf-8') as file:\n", 1132 | " json.dump(responses, file, ensure_ascii=False, indent=4)\n", 1133 | "\n", 1134 | "if __name__ == \"__main__\":\n", 1135 | " main()\n" 1136 | ] 1137 | }, 1138 | { 1139 | "cell_type": "markdown", 1140 | "metadata": {}, 1141 | "source": [ 1142 | "### 生成基本的服饰描述\n", 1143 | "使用之前训练好的ViT模型,生成新数据集上的图像描述,结果储存在new_dataset/res.json中。\n", 1144 | "\n", 1145 | "下面是用于调用之前训练的模型生成服装描述的脚本代码" 1146 | ] 1147 | }, 1148 | { 1149 | "cell_type": "code", 1150 | "execution_count": null, 1151 | "metadata": {}, 1152 | "outputs": [], 1153 | "source": [ 1154 | "from torchvision import transforms\n", 1155 | "import json\n", 1156 | "import os\n", 1157 | "from PIL import Image\n", 1158 | "import torch\n", 1159 | "from torch.utils.data import Dataset, DataLoader\n", 1160 | "from transformers import ViTFeatureExtractor, BertTokenizer ,ViTModel, BertModel, BertConfig\n", 1161 | "from collections import defaultdict, Counter\n", 1162 | "import numpy as np\n", 1163 | "\n", 1164 | "\n", 1165 | "dataset='deepfashion-multimodal'\n", 1166 | "img_path = f'data/{dataset}/img-001/img'\n", 1167 | "vocab_path = f'data/{dataset}/vocab.json'\n", 1168 | "\n", 1169 | "def idx_to_word(idx, vocab):#将向量转化为文本描述\n", 1170 | " reverse_vocab = {v: k for k, v in vocab.items()}\n", 1171 | " return reverse_vocab.get(int(idx), '')\n", 1172 | "\n", 1173 | "class CustomImageDataset(Dataset):\n", 1174 | " def __init__(self, img_folder, transform=None):\n", 1175 | " self.img_folder = img_folder\n", 1176 | " self.img_names = [img for img in os.listdir(img_folder) if img.endswith(('.png', '.jpg', '.jpeg'))]\n", 1177 | " print(len(self.img_names))\n", 1178 | " self.img_names = self.img_names[31000:36000]\n", 1179 | " print(self.img_names[0])\n", 1180 | " self.transform = transform\n", 1181 | "\n", 1182 | " def __len__(self):\n", 1183 | " return len(self.img_names)\n", 1184 | "\n", 1185 | " def __getitem__(self, idx):\n", 1186 | " img_path = os.path.join(self.img_folder, self.img_names[idx])\n", 1187 | " image = Image.open(img_path).convert('RGB')\n", 1188 | " if self.transform:\n", 1189 | " image = self.transform(image)\n", 1190 | "\n", 1191 | " return image, self.img_names[idx]\n", 1192 | "if torch.cuda.is_available():\n", 1193 | " device = torch.device(\"cuda\")\n", 1194 | " print(\"Using GPU:\", torch.cuda.get_device_name(0))\n", 1195 | "else:\n", 1196 | " device = torch.device(\"cpu\")\n", 1197 | " print(\"Using CPU\")\n", 1198 | "\n", 1199 | "# 图像预处理\n", 1200 | "transform = transforms.Compose([\n", 1201 | " transforms.Resize((224, 224)),\n", 1202 | " transforms.ToTensor(),\n", 1203 | " # 根据需要添加更多的转换\n", 1204 | "])\n", 1205 | "\n", 1206 | "# 创建 Dataset 实例\n", 1207 | "dataset = CustomImageDataset(img_folder=img_path, transform=transform)\n", 1208 | "\n", 1209 | "# 创建 DataLoader\n", 1210 | "data_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", 1211 | "\n", 1212 | "with open(vocab_path, 'r') as f:\n", 1213 | " vocab = json.load(f)\n", 1214 | "\n", 1215 | "vocab_size = len(vocab)\n", 1216 | "vit_model_name = 'google/vit-base-patch16-224-in21k'\n", 1217 | "transformer_config = BertConfig()\n", 1218 | "\n", 1219 | "model = Img2TxtModel(vit_model_name, transformer_config, vocab_size)\n", 1220 | "# 加载模型状态字典\n", 1221 | "checkpoint = torch.load('./model/best_model_epoch_10_batch_2700.pth')\n", 1222 | "\n", 1223 | "\n", 1224 | "# 将状态字典应用到模型实例中\n", 1225 | "model.load_state_dict(checkpoint['model_state_dict'])\n", 1226 | "model = model.to(device)\n", 1227 | "\n", 1228 | "model.eval() # 将模型设置为评估模式\n", 1229 | "\n", 1230 | "generated_captions_dict = {}\n", 1231 | "\n", 1232 | "with torch.no_grad():\n", 1233 | " for images, name in data_loader:\n", 1234 | " images = images.to(device)\n", 1235 | " input_ids = images\n", 1236 | " outputs,_ = model.generate_text(input_ids, max_length=95, start_token_id=vocab[''])\n", 1237 | " for i in range(outputs.shape[0]):\n", 1238 | " gen_caption = [idx_to_word(idx, vocab) for idx in outputs[i]]\n", 1239 | " if '' in gen_caption:\n", 1240 | " gen_caption = gen_caption[1:] # 移除第一个元素 ()\n", 1241 | " if '' in gen_caption:\n", 1242 | " gen_caption = gen_caption[:gen_caption.index('')] # 移除 及其后面的元素\n", 1243 | "\n", 1244 | " caption_text = ' '.join(gen_caption)\n", 1245 | " generated_captions_dict[name[0]] = caption_text\n", 1246 | "with open('res.json', 'w') as f:\n", 1247 | " json.dump(generated_captions_dict, f, indent=2)" 1248 | ] 1249 | }, 1250 | { 1251 | "cell_type": "markdown", 1252 | "metadata": {}, 1253 | "source": [ 1254 | "### 组合数据\n", 1255 | "使用脚本将服饰描述和背景描述拼接在一起,得到完整的数据。" 1256 | ] 1257 | }, 1258 | { 1259 | "cell_type": "code", 1260 | "execution_count": null, 1261 | "metadata": {}, 1262 | "outputs": [], 1263 | "source": [ 1264 | "import json\n", 1265 | "\n", 1266 | "# 读取JSON文件的函数\n", 1267 | "def read_json(file_path):\n", 1268 | " with open(file_path, 'r', encoding='utf-8') as file:\n", 1269 | " return json.load(file)\n", 1270 | "\n", 1271 | "# 读取两个JSON文件\n", 1272 | "json1 = read_json('res.json')\n", 1273 | "json2 = read_json('res_add.json')\n", 1274 | "\n", 1275 | "# 存储组合后的结果\n", 1276 | "combined_values = {}\n", 1277 | "\n", 1278 | "# 遍历第一个JSON文件的键\n", 1279 | "for key in json1:\n", 1280 | " if key in json2:\n", 1281 | " # 将两个文件中相同键的值组合在一起\n", 1282 | " combined_values[key] = json1[key] + json2[key]\n", 1283 | "\n", 1284 | "# 保存组合后的结果到新的JSON文件\n", 1285 | "with open('combined_input.json', 'w', encoding='utf-8') as file:\n", 1286 | " json.dump(combined_values, file, ensure_ascii=False, indent=2)\n", 1287 | "\n", 1288 | "print(\"Combined JSON saved as combined_output.json\")" 1289 | ] 1290 | }, 1291 | { 1292 | "cell_type": "markdown", 1293 | "metadata": {}, 1294 | "source": [ 1295 | "### 再次训练模型\n", 1296 | "按照之前的流程重新在新数据集上训练一个模型,得到新的模型,这里不再放训练的代码。" 1297 | ] 1298 | }, 1299 | { 1300 | "cell_type": "markdown", 1301 | "metadata": {}, 1302 | "source": [ 1303 | "# 实验结果\n", 1304 | "由于除了BLEU-4以外,其他指标的最大值其实和数据有关,所以为此我们先进行一个实验:得到如下指标的最大结果,然后将其作为一种相对的标准,结果如下:\n", 1305 | "BLEU-MAX:1.0|CIDEr-D-MAX:0.0043918896512309125|SPICE-MAX:0.18703616844830037,在评估过程中将会同时输出实际值和相对值(被压缩到0-1之间)," 1306 | ] 1307 | }, 1308 | { 1309 | "cell_type": "code", 1310 | "execution_count": 38, 1311 | "metadata": {}, 1312 | "outputs": [ 1313 | { 1314 | "name": "stdout", 1315 | "output_type": "stream", 1316 | "text": [ 1317 | "BLEU-MAX:1.0|CIDEr-D-MAX:0.0043918896512309125|SPICE-MAX:0.18703616844830037\n" 1318 | ] 1319 | } 1320 | ], 1321 | "source": [ 1322 | "max_cider=0.0043918896512309125\n", 1323 | "max_spice=0.18703616844830037\n", 1324 | "def evaluate_(data_loader) :#用来测试剩下两个指标的最大值\n", 1325 | " cands = []# 存储参考文本\n", 1326 | " refs = []# 需要过滤的词\n", 1327 | " filterd_words = set({model.vocab[''], model.vocab[''], model.vocab['']})\n", 1328 | " for i, (imgs, caps, caplens) in enumerate(data_loader):\n", 1329 | " cands.extend([filter_useless_words(cap, filterd_words) for cap in caps.tolist()])# 参考文本\n", 1330 | " refs.extend([filter_useless_words(cap, filterd_words) for cap in caps.tolist()]) #候选文本\n", 1331 | " bleu4_score = get_BLEU_score(cands, refs)\n", 1332 | " cider_d_score = get_CIDER_D_score(cands, refs)\n", 1333 | " spice_score= get_SPICE_score(cands, refs)\n", 1334 | " print(f\"BLEU-MAX:{bleu4_score}|CIDEr-D-MAX:{cider_d_score}|SPICE-MAX:{spice_score}\")\n", 1335 | " max_cider=cider_d_score\n", 1336 | " max_spice=spice_score\n", 1337 | "evaluate_(test_loader)" 1338 | ] 1339 | }, 1340 | { 1341 | "cell_type": "markdown", 1342 | "metadata": {}, 1343 | "source": [ 1344 | "这里以 ARCTIC 的测评代码为例子" 1345 | ] 1346 | }, 1347 | { 1348 | "cell_type": "code", 1349 | "execution_count": 45, 1350 | "metadata": {}, 1351 | "outputs": [ 1352 | { 1353 | "name": "stdout", 1354 | "output_type": "stream", 1355 | "text": [ 1356 | "@@@实际值 BLEU:0.30466660274770024|CIDEr-D:0.0017447527516659645|SPICE:0.1336081641272871\n", 1357 | "@@@相对值(0-1) BLEU:0.30466660274770024|CIDEr-D:0.3972669830575009|SPICE:0.7143439968629298\n" 1358 | ] 1359 | } 1360 | ], 1361 | "source": [ 1362 | "def filter_useless_words(sent, filterd_words):\n", 1363 | " # 去除句子中不参与BLEU值计算的符号\n", 1364 | " return [w for w in sent if w not in filterd_words]\n", 1365 | "cider_d_score=0\n", 1366 | "spice_score=0\n", 1367 | "def evaluate(data_loader, model, config):\n", 1368 | " model.eval()\n", 1369 | " # 存储候选文本\n", 1370 | " cands = []\n", 1371 | " # 存储参考文本\n", 1372 | " refs = []\n", 1373 | " # 需要过滤的词\n", 1374 | " filterd_words = set({model.vocab[''], model.vocab[''], model.vocab['']})\n", 1375 | " cpi = config.captions_per_image\n", 1376 | " device = next(model.parameters()).device\n", 1377 | " for i, (imgs, caps, caplens) in enumerate(data_loader):\n", 1378 | " with torch.no_grad():\n", 1379 | " # 通过束搜索,生成候选文本\n", 1380 | " texts = model.generate_by_beamsearch(imgs.to(device), config.beam_k, config.max_len+2)\n", 1381 | " # 候选文本\n", 1382 | " cands.extend([filter_useless_words(text, filterd_words) for text in texts])\n", 1383 | " # 参考文本\n", 1384 | " refs.extend([filter_useless_words(cap, filterd_words) for cap in caps.tolist()])\n", 1385 | " \n", 1386 | " \n", 1387 | " bleu4_score = get_BLEU_score(cands, refs)\n", 1388 | " cider_d_score = get_CIDER_D_score(cands, refs)\n", 1389 | " spice_score= get_SPICE_score(cands, refs)\n", 1390 | " print(f\"@@@实际值 BLEU:{bleu4_score}|CIDEr-D:{cider_d_score}|SPICE:{spice_score}\")\n", 1391 | " print(f\"@@@相对值(0-1) BLEU:{bleu4_score}|CIDEr-D:{cider_d_score/max_cider}|SPICE:{spice_score/max_spice}\")\n", 1392 | " #model.train()\n", 1393 | "evaluate(test_loader, model, ARCTIC_config)" 1394 | ] 1395 | }, 1396 | { 1397 | "cell_type": "markdown", 1398 | "metadata": {}, 1399 | "source": [ 1400 | "同时我们以测试集的第一行数据来进行演示:以生成的描述文本和参考文本进行比较,直观评估ARCTIC的性能" 1401 | ] 1402 | }, 1403 | { 1404 | "cell_type": "code", 1405 | "execution_count": 24, 1406 | "metadata": {}, 1407 | "outputs": [ 1408 | { 1409 | "name": "stdout", 1410 | "output_type": "stream", 1411 | "text": [ 1412 | "@生成数据: [[154, 21, 47, 16, 31, 32, 39, 39, 52, 28, 10, 11, 12, 1, 39, 52, 16, 28, 7, 8, 9, 72, 13, 16, 84, 1, 49, 50, 28, 7, 8, 9, 10, 11, 12, 57, 16, 25, 59, 34, 35, 60, 155], [154, 21, 47, 16, 31, 32, 39, 39, 52, 28, 43, 12, 1, 39, 52, 16, 28, 7, 8, 9, 72, 13, 16, 84, 1, 49, 50, 28, 7, 8, 9, 43, 12, 57, 16, 25, 59, 34, 35, 60, 155], [154, 38, 40, 4, 5, 6, 7, 8, 9, 43, 12, 1, 13, 14, 15, 16, 81, 1, 18, 3, 16, 14, 5, 19, 1, 8, 16, 7, 9, 15, 4, 43, 12, 1, 26, 3, 16, 28, 7, 8, 9, 43, 12, 1, 26, 3, 16, 28, 7, 8, 9, 43, 12, 57, 16, 25, 59, 34, 35, 60, 155]]\n", 1413 | "@生成的文本: This person is wearing a tank tank top with solid color patterns. The tank top is with cotton fabric and its neckline is round. The pants are with cotton fabric and solid color patterns. There is an accessory on her wrist.\n", 1414 | "@实际的文本: This woman is wearing a tank tank shirt with graphic patterns and a three-point shorts. The tank shirt is with cotton fabric and its neckline is crew. The shorts are with cotton fabric and graphic patterns.\n" 1415 | ] 1416 | } 1417 | ], 1418 | "source": [ 1419 | "def batch_eva(data_loader, model, config): #这里使用实验数据的第一个batch来进行演示\n", 1420 | " model.eval()\n", 1421 | " for i, (imgs, caps, caplens) in enumerate(test_loader):\n", 1422 | " cands = [] # 存储候选文本 \n", 1423 | " refs = [] # 存储参考文本\n", 1424 | " filterd_words = set({model.vocab[''], model.vocab[''], model.vocab['']}) #过滤词\n", 1425 | " cpi = config.captions_per_image\n", 1426 | " texts = model.generate_by_beamsearch(imgs.to(\"cuda\"), config.beam_k, config.max_len+2)\n", 1427 | " print(\"@生成数据:\",texts)\n", 1428 | " print(\"@生成的文本:\",wvec_to_cap(model.vocab,texts[0])) #抽出一个batch的第一个文本\n", 1429 | " print(\"@实际的文本:\",wvec_to_cap(model.vocab,caps[0].tolist()))\n", 1430 | " break\n", 1431 | "batch_eva(test_loader, model, ARCTIC_config)" 1432 | ] 1433 | }, 1434 | { 1435 | "cell_type": "markdown", 1436 | "metadata": {}, 1437 | "source": [ 1438 | "之后的两个模型的forward形式同上, 我们整理了最大值的统计情况,如下表格\n", 1439 | "\n", 1440 | "\n", 1441 | "| 模型名称 | BLEU-4 | CIDEr-D | SPICE |\n", 1442 | "|---------|------|------|------|\n", 1443 | "| ARCTIC | 0.30466660274770024 | (相对值:0.39726) 0.0017447527516659645 | (相对值:0.71434)0.1336081641272871 |\n", 1444 | "| VIT | **0.3074786561430062** | (相对值:0.93790) 0.004119164250591897 | (相对值:0.70628)0.1321016861247424 |\n", 1445 | "| SwinTrans | 0.25770958160979623 | (相对值:0.91071) 0.003999756354414652 | (相对值:0.60681)0.11349623153205401 |" 1446 | ] 1447 | }, 1448 | { 1449 | "cell_type": "markdown", 1450 | "metadata": {}, 1451 | "source": [ 1452 | "## 实验结果分析\n", 1453 | "\n", 1454 | "根据实验结果,分析模型的优缺点\n", 1455 | "- BLEU-4 分析下\n", 1456 | " - ARCTIC 模型 性能评价:\n", 1457 | " ARCTIC 模型在 BLEU-4 上表现良好,得分为(0.30466660274770024)。\n", 1458 | " 语法和词汇的准确性得到了有效提升,而且经过实验贪心算法下的得分要低于束搜索算法下的得分,所以束搜索算法在生成文本时,能够生成更流畅的文本。\n", 1459 | " - 视觉Transformer (ViT) + Transformer解码器 性能评价:\n", 1460 | " ViT + Transformer 模型在 BLEU-4 上的得分为(0.3074786561430062)。模型在服饰描述任务中取得了良好的结果,生成文本在语和词汇方面表现出色。**在这个指标下 其语义流畅度最好**\n", 1461 | " - 网格/区域表示、Transformer编码器+Transformer解码器 性能评价: \n", 1462 | " 网格/区域表示 + Transformer 模型在 BLEU-4 上的得分为(0.25770958160979623)。\n", 1463 | " 模型在服饰描述任务中呈现出略低的性能,表现缺乏一定语句流畅度\n", 1464 | "- CIDEr-D 分析下\n", 1465 | " - ARCTIC 模型 性能评价:\n", 1466 | " ARCTIC 模型在 CIDEr-D 上取得了(相对值:0.39726) 0.0017447527516659645 \n", 1467 | " CIDEr-D 分数显示模型在文本多样性和丰富性方面有一定的成功,但是不如其他模型。\n", 1468 | " - 视觉Transformer (ViT) + Transformer解码器 性能评价:\n", 1469 | " ViT + Transformer 模型在 CIDEr-D 上的得分为 (相对值:0.93790) 0.004119164250591897。\n", 1470 | " 模型在服饰描述任务中具有较高的文本多样性和丰富性,**是表现最好的模型**\n", 1471 | " - 网格/区域表示、Transformer编码器+Transformer解码器 性能评价:\n", 1472 | " 网格/区域表示 + Transformer 模型在 CIDEr-D 上的得分为(相对值:0.91071) 0.003999756354414652。\n", 1473 | " 模型对服饰描述的多样性和丰富性取得了一定的成功。可以看出虽然流畅度不如其他模型,但是多样性还是不错的。\n", 1474 | "- SPICE 分析下\n", 1475 | " - ARCTIC 模型 性能评价:\n", 1476 | " ARCTIC 模型在 SPICE 上取得了 (相对值:0.71434)0.1336081641272871 。\n", 1477 | " SPICE 分数反映了模型生成文本与图像内容相关性的程度,在语义层面上有一个很好的效果\n", 1478 | " - 视觉Transformer (ViT) + Transformer解码器 性能评价:\n", 1479 | " ViT + Transformer 模型在 SPICE 上的得分为 (相对值:0.70628)0.1321016861247424 。\n", 1480 | " 模型在服饰描述任务中表现出色,成功捕捉图像语义信息。\n", 1481 | " - 网格/区域表示、Transformer编码器+Transformer解码器 性能评价:\n", 1482 | " 网格/区域表示 + Transformer 模型在 SPICE 上的得分为 (相对值:0.60681)0.11349623153205401。\n", 1483 | " 模型在 SPICE 上的表现显示其在描述图像内容方面的良好性能,但是效果相对较低\n", 1484 | "\n", 1485 | "当然 ARCTIC、视觉Transformer (ViT) + Transformer解码器和网格/区域表示、Transformer编码器+Transformer解码器这三个模型有不同的优劣势\n", 1486 | "\n", 1487 | "- ARCTIC 模型:\n", 1488 | " - 优势:\n", 1489 | " 结合了注意力机制和编解码模型,有助于捕捉输入图像和生成描述之间的语义关系。\n", 1490 | " 注意力机制使得模型在生成描述时能够更加关注与服饰相关的区域,提高了描述的准确性。\n", 1491 | " - 劣势:\n", 1492 | " 可能需要更多的计算资源和训练时间,因为结合了多个模型组件。\n", 1493 | " 在处理大规模数据集时,训练和推理速度可能较慢。\n", 1494 | " 推理过程中尝试了不同的生成方式,发现加入 beam search 策略效果最佳。\n", 1495 | "\n", 1496 | "\n", 1497 | "- 视觉 Transformer (ViT) + Transformer 解码器:\n", 1498 | " - 优势:\n", 1499 | " ViT 模型将输入图像转换为序列数据,直接应用 Transformer 解码器进行描述生成。\n", 1500 | " Transformer 解码器在自然语言处理任务中表现出色,生成准确且流畅的描述。\n", 1501 | " - 劣势:\n", 1502 | " ViT 模型可能对输入图像的分辨率和细节要求较高,对于复杂的服饰图像可能需要更多的训练数据和计算资源。\n", 1503 | " 在处理长序列数据时,Transformer 解码器可能面临较长的训练和推理时间。\n", 1504 | "\n", 1505 | "\n", 1506 | "- 网格/区域表示、Transformer 编码器+Transformer 解码器:\n", 1507 | " - 优势:\n", 1508 | " 网格/区域表示将图像划分为网格或区域,有助于捕捉局部特征。\n", 1509 | " Transformer 编码器和解码器在处理序列数据时具有较强的建模能力,能够生成准确的描述。\n", 1510 | " - 劣势:\n", 1511 | " 网格/区域表示可能需要额外的预处理步骤来划分图像,并可能导致信息损失。\n", 1512 | " Transformer 编码器和解码器的训练和推理时间可能较长,特别是在处理大规模数据集时\n", 1513 | "\n", 1514 | "\n", 1515 | "\n" 1516 | ] 1517 | }, 1518 | { 1519 | "cell_type": "markdown", 1520 | "metadata": {}, 1521 | "source": [ 1522 | "# 新数据集分析\n", 1523 | "1.\"img_00035891.jpg\": \"a people in front of a red wall. a young man wearing a blue shirt and khaki pants\" --->\"img_00035891.jpg\": \"The scene takes place in front of a red wall.\"\n", 1524 | "\n", 1525 | "2.\"img_00035892.jpg\": \"a people in front of a mirror. a woman wearing a t - shirt with a picture of a man\"--->\"img_00035892.jpg\": \"The setting is a mirror in a room.\"\n", 1526 | "\n", 1527 | "3.\"img_00035899.jpg\": \"a people in front of a computer. a young boy sitting on a desk with a laptop\"--->\"img_00035899.jpg\": \"The setting is a room filled with a modern office desks and computers.\"\n", 1528 | "\n", 1529 | "4.\"img_00035904.jpg\": \"a people in front of a flower. a little girl with a flower in her hand\"--->\"img_00035904.jpg\": \"The setting is a garden or park with flowers blooming.\"\n", 1530 | "\n", 1531 | "在上述几个例子中,LLM接受blip生成的文字,确实可以输出更加有深度的背景信息,并且不会明显影响原本的信息,如1中的red wall。而且也能获得更深层的信息,如从mirror推出room,还有从laptop推出office,从flower推出花园.\n", 1532 | "但是可惜的是,blip的功能比较弱,提取图片中央信息就已经不准确了,背景更是经常出错,如1中其实真实的背景是一个红色集装箱,而2中背景是窗帘和玻璃。" 1533 | ] 1534 | }, 1535 | { 1536 | "cell_type": "markdown", 1537 | "metadata": {}, 1538 | "source": [ 1539 | "## 存在的问题 \n", 1540 | "\n", 1541 | "- 模型训练部分\n", 1542 | " - **资源不足**:不同模型对计算资源的需求不同,,作为学生的算力水平非常有限, 要训练出以该模型架构下的最优模型是一件困难的问题\n", 1543 | " - **数据集限制**:以训练数据构建的词典其实是一个非常小的词典,其实限制了模型的上限,离开数据集,在真实数据下的推理可能还是会有不完备的地方\n", 1544 | " - **图像裁剪**:数据集进行批量处理时的正方形裁剪可能会导致原始数据丢失部分信息,尤其是头部和脚步,还有白边导致的背景错误。\n", 1545 | " \n", 1546 | "- 数据生成部分\n", 1547 | " - **Blip性能不足**:我们使用的较小的模型,性能不足以提取所有背景,而且错误率出奇的高,导致后面的各项工作收到严重影响\n", 1548 | " 比如对于一个图像,经过blip生成的文本是:a people in front of a brick wall. a girl wearing a denim dress and white **shoes** ;但是原图没有鞋子的信息。在没有鞋子信息的情况下还是生成了鞋子信息\n", 1549 | " - **LLM性能不足**:文心一言相比ChatGPT,理解能力还是较弱,而且除了第一句总结性的背景概括,后续便开始进行一些离奇的没有根据的幻想,比如下面这个例子中即使在我反复强调的情况下,LLM还是违背我对文本内容和长度的限制,生成如下长篇大论:\n", 1550 | " \n", 1551 | " The scene takes place in a room with a green mat on the floor. The background is a wall, which is painted in a neutral color and has no decorations or features. The room is dimly lit, with only a small amount of natural light coming from the window. There is a door leading out of the room, but it is closed. The man standing on the green mat is wearing a blue shorts and a white shirt, and he is facing the wall. He appears to be in a meditative or contemplative state, as he is standing still and not interacting with anything else in the room.\n", 1552 | " \n", 1553 | " 但是,实际上我们需要的是相对较短的描述\n", 1554 | "\n", 1555 | " - **真实多模态能力的缺失**:由于现在的结构是先img2txt再txt2txt,对于LLM来说会损失很多图像信息,真正的多模态还是需要直接进行img、txt2txt,有一个badcase:img2txt模型将楼上的玻璃识别成了镜子,导致LLM的输出也出错,误认为其在一个房间里,而真实的图片是在一个街道上。\n", 1556 | " 这些因素导致数据集目前是一个不可用的状态。\n", 1557 | "\n", 1558 | "\n" 1559 | ] 1560 | }, 1561 | { 1562 | "cell_type": "markdown", 1563 | "metadata": {}, 1564 | "source": [ 1565 | "# 总结\n", 1566 | "\n", 1567 | "我们基本完成了作业要求,并额外进行了数据集的构建。\n", 1568 | "\n", 1569 | "完成了数个模型的构建和训练。\n", 1570 | "\n", 1571 | "利用训练的服饰图像描述模型和多模态大语言模型,为真实背景的服饰图像数据集增加服饰描述和背景描述,构建了一套可以生成经过了增强的全新数据集,并且在新数据集上重新训练服饰图像描述模型。\n", 1572 | "\n", 1573 | "尽管在课设完成的过程中我们面临了很多挑战和问题,例如有限的算力水平,这可能影响了构建模型的性能和实验的严谨性,但是我们依然尽我们所能完成了实验任务。" 1574 | ] 1575 | } 1576 | ], 1577 | "metadata": { 1578 | "kernelspec": { 1579 | "display_name": "base", 1580 | "language": "python", 1581 | "name": "python3" 1582 | }, 1583 | "language_info": { 1584 | "codemirror_mode": { 1585 | "name": "ipython", 1586 | "version": 3 1587 | }, 1588 | "file_extension": ".py", 1589 | "mimetype": "text/x-python", 1590 | "name": "python", 1591 | "nbconvert_exporter": "python", 1592 | "pygments_lexer": "ipython3", 1593 | "version": "3.9.16" 1594 | }, 1595 | "orig_nbformat": 4 1596 | }, 1597 | "nbformat": 4, 1598 | "nbformat_minor": 2 1599 | } 1600 | --------------------------------------------------------------------------------