├── transformer ├── readme.txt ├── user_modeling │ ├── readme.txt │ ├── Embed.py │ ├── download_image.py │ ├── Layers.py │ ├── pytorchtools.py │ ├── Optim.py │ ├── Sublayers.py │ ├── build_vocab.py │ ├── resize_image.py │ ├── Beam.py │ ├── preprocess.py │ ├── test.py │ ├── dataset.py │ ├── preprocess_fashionIQ.py │ ├── Models.py │ └── train.py ├── attribute_prediction │ ├── readme.txt │ ├── attribute_loader.py │ ├── test.py │ └── finetune.py └── interactive_retrieval │ ├── readme.txt │ ├── Vocabulary.py │ ├── data_loader.py │ ├── Ranker.py │ ├── glove_embedding.py │ ├── UserModel.py │ ├── utils.py │ ├── Beam.py │ ├── models.py │ ├── eval.py │ ├── user_model.py │ └── train.py ├── start_kit ├── models │ └── models_will_be_saved_here ├── data │ ├── captions │ │ └── download_the_caption_data │ ├── data_splits │ │ └── download_the_data_split │ ├── resized_images │ │ └── download_and_resize_the_images_via_url │ └── .DS_Store ├── README.md ├── resize_images.py ├── build_vocab.py ├── models.py ├── utils.py ├── eval.py ├── data_loader.py └── train.py └── README.md /transformer/readme.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /start_kit/models/models_will_be_saved_here: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformer/user_modeling/readme.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /transformer/attribute_prediction/readme.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /start_kit/data/captions/download_the_caption_data: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /start_kit/data/data_splits/download_the_data_split: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transformer/interactive_retrieval/readme.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /start_kit/data/resized_images/download_and_resize_the_images_via_url: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /start_kit/data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoxiaoGuo/fashion-iq/HEAD/start_kit/data/.DS_Store -------------------------------------------------------------------------------- /start_kit/README.md: -------------------------------------------------------------------------------- 1 | # Fashion-IQ Starter Code 2 | ## About this repository 3 | For more information on the dataset, please visit its [project page](https://www.spacewu.com/posts/fashion-iq). 
4 | ## Starter code for Fashion IQ challenge 5 | To get started with the framework, install the following dependencies: 6 | - Python 3.6 7 | - [PyTorch 0.4](https://pytorch.org/get-started/previous-versions/) 8 | ## Train and evaluate a model 9 | Follow these steps to train and evaluate a model: 10 | 1. Download the dataset and resize the images. 11 | 2. Build the vocabulary for a specific dataset: 12 | ``` 13 | python build_vocab.py --data_set dress 14 | ``` 15 | 3. Train the model 16 | ``` 17 | python train.py --data_set dress --batch_size 128 --log_step 15 18 | ``` 19 | The trained models will be saved to the `models/` folder after every epoch. 20 | 4. Generate the submission results 21 | ``` 22 | python eval.py --data_set dress --batch_size 128 --model_folder <path_to_saved_models> --data_split test 23 | ``` 24 | 5. Submit your results at: https://codalab.lri.fr/competitions/573#results 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fashion-iq 2 | 3 | ## About this repository 4 | Fashion IQ is a dataset we contribute to the research community to 5 | facilitate research on natural language-based interactive image retrieval. 6 | We released the Fashion IQ dataset at the ICCV 2019 workshop on 7 | [Linguistics Meets Image and Video Retrieval](https://sites.google.com/view/lingir/fashion-iq). 8 | 9 | The images can be downloaded from [here](https://github.com/hongwang600/fashion-iq-metadata). 10 | 11 | The image attribute features can be downloaded from [here](https://ibm.box.com/s/imyukakmnrkk2zuitju2m8akln3ayoct). 12 | 13 | ## Starter code for Fashion IQ challenge 14 | To get started with the framework, install the following dependencies: 15 | - Python 3.6 16 | - [PyTorch 0.4](https://pytorch.org/get-started/previous-versions/) 17 | 18 | ## Citations 19 | If you find Fashion IQ useful, please cite the following paper: 20 | 21 | ``` 22 | @article{guo2019fashion, 23 | title={The Fashion IQ Dataset: Retrieving Images by Combining Side Information and Relative Natural Language Feedback}, 24 | author={Wu, Hui and Gao, Yupeng and Guo, Xiaoxiao and Al-Halah, Ziad and Rennie, Steven and Grauman, Kristen and Feris, Rogerio}, 25 | journal={CVPR}, 26 | year={2021} 27 | } 28 | ``` 29 | 30 | ## License 31 | [Community Data License Agreement](https://cdla.io/) (CDLA) License 32 | -------------------------------------------------------------------------------- /transformer/interactive_retrieval/Vocabulary.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class Vocabulary(object): 5 | """Simple vocabulary wrapper.""" 6 | 7 | def __init__(self): 8 | self.word2idx = {} 9 | self.idx2word = {} 10 | self.idx = 0 11 | 12 | def add_word(self, word): 13 | if not word in self.word2idx: 14 | self.word2idx[word] = self.idx 15 | self.idx2word[self.idx] = word 16 | self.idx += 1 17 | 18 | def __call__(self, word): 19 | if not word in self.word2idx: 20 | return self.word2idx['<unk>'] 21 | return self.word2idx[word] 22 | 23 | def __len__(self): 24 | return len(self.word2idx) 25 | 26 | def init_vocab(self): 27 | self.add_word('<pad>') 28 | self.add_word('<start>') 29 | self.add_word('<end>') 30 | self.add_word('<unk>') 31 | self.add_word('<and>') 32 | 33 | def save(self, file_name): 34 | data = {'word2idx': self.word2idx, 'idx2word': self.idx2word, 35 | 'idx': self.idx} 36 | with open(file_name, 'w') as f: 37 | json.dump(data, f, indent=4) 38 | return 39 | 40 | def load(self, file_name): 41 |
with open(file_name, 'r') as f: 42 | data = json.load(f) 43 | self.word2idx = data['word2idx'] 44 | self.idx2word = data['idx2word'] 45 | self.idx = data['idx'] 46 | return 47 | -------------------------------------------------------------------------------- /transformer/user_modeling/Embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from torch.autograd import Variable 5 | 6 | class Embedder(nn.Module): 7 | def __init__(self, vocab_size, d_model): 8 | super().__init__() 9 | self.d_model = d_model 10 | self.embed = nn.Embedding(vocab_size, d_model) 11 | def forward(self, x): 12 | return self.embed(x) 13 | 14 | class PositionalEncoder(nn.Module): 15 | def __init__(self, d_model, max_seq_len = 200, dropout = 0.1): 16 | super().__init__() 17 | self.d_model = d_model 18 | self.dropout = nn.Dropout(dropout) 19 | # create constant 'pe' matrix with values dependant on 20 | # pos and i 21 | pe = torch.zeros(max_seq_len, d_model) 22 | for pos in range(max_seq_len): 23 | for i in range(0, d_model, 2): 24 | pe[pos, i] = \ 25 | math.sin(pos / (10000 ** ((2 * i)/d_model))) 26 | pe[pos, i + 1] = \ 27 | math.cos(pos / (10000 ** ((2 * (i + 1))/d_model))) 28 | pe = pe.unsqueeze(0) 29 | self.register_buffer('pe', pe) 30 | 31 | 32 | def forward(self, x): 33 | # make embeddings relatively larger 34 | x = x * math.sqrt(self.d_model) 35 | #add constant to embedding 36 | seq_len = x.size(1) 37 | pe = Variable(self.pe[:,:seq_len], requires_grad=False) 38 | if x.is_cuda: 39 | pe.cuda() 40 | x = x + pe 41 | return self.dropout(x) -------------------------------------------------------------------------------- /transformer/user_modeling/download_image.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import ssl 3 | import urllib.request 4 | import os 5 | import requests 6 | import tqdm 7 | 8 | ctx = ssl.create_default_context() 9 | ctx.check_hostname = False 10 | ctx.verify_mode = ssl.CERT_NONE 11 | 12 | def parse_url(url): 13 | # print('url', url) 14 | tokens = url.split('/') 15 | # print(tokens) 16 | folder = tokens[4] 17 | tokens = tokens[5].split('?') 18 | tokens.reverse() 19 | file = '.'.join(tokens) 20 | # print(tokens[1]) 21 | # print(tokens) 22 | # if len(tokens) > 1: 23 | # file = tokens[1] 24 | # else: 25 | # file = 'null' 26 | # print(tokens[4], tokens[5]) 27 | # print(folder, file) 28 | return 'images/' + folder + '.' 
+ file 29 | 30 | 31 | def make_folder(folder): 32 | if not os.path.exists(folder): 33 | os.mkdir(folder) 34 | return 35 | 36 | def process_url(url): 37 | file = parse_url(url).lower() 38 | if file[-1] == '.': 39 | file = file + 'jpg' 40 | 41 | # make_folder(folder) 42 | 43 | if not os.path.isfile(file): 44 | with open(file, 'wb') as f: 45 | resp = requests.get(url, verify=False) 46 | f.write(resp.content) 47 | f.close() 48 | 49 | with open('birds-to-words-v1.0.tsv') as tsvfile: 50 | reader = csv.reader(tsvfile, delimiter='\t') 51 | # i = 0 52 | data = [] 53 | for i, row in enumerate(reader): 54 | # i += 1 55 | if i == 0: 56 | continue 57 | process_url(row[1]) 58 | process_url(row[5]) 59 | # print(i) 60 | # if i == 10: 61 | # break 62 | -------------------------------------------------------------------------------- /transformer/interactive_retrieval/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image 4 | from joblib import Parallel, delayed 5 | import json 6 | 7 | 8 | class Dataset(): 9 | def __init__(self, root, data_file_name, transform=None, num_workers=4): 10 | """Set the path for images, captions and vocabulary wrapper. 11 | 12 | Args: 13 | root: image directory. 14 | data: index file name 15 | vocab: vocabulary wrapper. 16 | transform: image transformer. 17 | """ 18 | self.num_workers = num_workers 19 | self.root = root 20 | with open(data_file_name, 'r') as f: 21 | data = json.load(f) 22 | self.data = data 23 | self.ids = range(len(self.data)) 24 | self.transform = transform 25 | 26 | def get_item(self, index): 27 | """Returns one data pair (image and caption).""" 28 | data = self.data 29 | id = self.ids[index] 30 | 31 | img_name = data[id] + '.jpg' 32 | 33 | image = Image.open(os.path.join(self.root, img_name)).convert('RGB') 34 | if self.transform is not None: 35 | image = self.transform(image) 36 | 37 | return image, [data[id]] 38 | 39 | def get_items(self, indexes): 40 | items = Parallel(n_jobs=self.num_workers)( 41 | delayed(self.get_item)( 42 | i) for i in indexes) 43 | 44 | return collate_fn(items) 45 | 46 | def __len__(self): 47 | return len(self.ids) 48 | 49 | 50 | def collate_fn(data): 51 | images, meta_info = zip(*data) 52 | 53 | # Merge images (from tuple of 3D tensor to 4D tensor). 
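# e.g. a batch of N transformed images, each a (3, H, W) tensor, becomes a single (N, 3, H, W) tensor.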
54 | images = torch.stack(images, dim=0) 55 | 56 | return images, meta_info 57 | 58 | -------------------------------------------------------------------------------- /transformer/interactive_retrieval/Ranker.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import utils 4 | 5 | 6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 7 | 8 | 9 | class Ranker: 10 | def __init__(self, device): 11 | self.data_emb = None 12 | self.device = device 13 | return 14 | 15 | def update_emb(self, img_fts, batch_size, model): 16 | 17 | self.data_emb = [] 18 | num_data = len(img_fts['asins']) 19 | num_batch = math.floor(num_data / batch_size) 20 | 21 | def append_emb(first, last): 22 | batch_ids = torch.tensor( 23 | [j for j in range(first, last)], dtype=torch.long) 24 | 25 | feat = utils.get_image_batch(img_fts, batch_ids) 26 | 27 | feat = feat.to(device) 28 | with torch.no_grad(): 29 | feat = model.encode_image(feat) 30 | self.data_emb.append(feat) 31 | 32 | for i in range(num_batch): 33 | append_emb(i*batch_size, (i+1)*batch_size) 34 | 35 | if num_batch * batch_size < num_data: 36 | append_emb(num_batch * batch_size, num_data) 37 | 38 | self.data_emb = torch.cat(self.data_emb, dim=0) 39 | 40 | return 41 | 42 | def nearest_neighbors(self, inputs): 43 | neighbors = [] 44 | for i in range(inputs.size(0)): 45 | [_, neighbor] = ((self.data_emb - inputs[i]).pow(2) 46 | .sum(dim=1).min(dim=0)) 47 | 48 | neighbors.append(neighbor) 49 | return torch.tensor(neighbors).to( 50 | device=self.device, dtype=torch.long) 51 | 52 | def compute_rank(self, inputs, target_ids): 53 | rankings = [] 54 | for i in range(inputs.size(0)): 55 | distances = (self.data_emb - inputs[i]).pow(2).sum(dim=1) 56 | ranking = (distances < distances[target_ids[i]]).float().sum(dim=0) 57 | rankings.append(ranking) 58 | return torch.tensor(rankings).to(device=self.device, dtype=torch.float) 59 | 60 | -------------------------------------------------------------------------------- /transformer/user_modeling/Layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from Sublayers import FeedForward, MultiHeadAttention, Norm 4 | 5 | class EncoderLayer(nn.Module): 6 | def __init__(self, d_model, heads, dropout=0.1): 7 | super().__init__() 8 | self.norm_1 = Norm(d_model) 9 | self.norm_2 = Norm(d_model) 10 | self.attn = MultiHeadAttention(heads, d_model, dropout=dropout) 11 | self.ff = FeedForward(d_model, dropout=dropout) 12 | self.dropout_1 = nn.Dropout(dropout) 13 | self.dropout_2 = nn.Dropout(dropout) 14 | 15 | def forward(self, x, mask=None): 16 | x2 = self.norm_1(x) 17 | x = x + self.dropout_1(self.attn(x2,x2,x2,mask)) 18 | x2 = self.norm_2(x) 19 | x = x + self.dropout_2(self.ff(x2)) 20 | return x 21 | # def forward(self, image1, image2, mask=None): 22 | # # image1 = self.norm_1(image1) 23 | # # image2 = self.norm_2(image2) 24 | 25 | # output = self.dropout_1(self.attn(image1, image2, image2)) #q,k,v 26 | 27 | # build a decoder layer with two multi-head attention layers and 28 | # one feed-forward layer 29 | class DecoderLayer(nn.Module): 30 | def __init__(self, d_model, heads, dropout=0.1): 31 | super().__init__() 32 | self.norm_1 = Norm(d_model) 33 | self.norm_2 = Norm(d_model) 34 | self.norm_3 = Norm(d_model) 35 | 36 | self.dropout_1 = nn.Dropout(dropout) 37 | self.dropout_2 = nn.Dropout(dropout) 38 | self.dropout_3 = nn.Dropout(dropout) 39 | 40 | self.attn_1 = 
MultiHeadAttention(heads, d_model, dropout=dropout) 41 | self.attn_2 = MultiHeadAttention(heads, d_model, dropout=dropout) 42 | self.ff = FeedForward(d_model, dropout=dropout) 43 | 44 | def forward(self, x, e_outputs, src_mask=None, trg_mask=None): 45 | x2 = self.norm_1(x) 46 | x = x + self.dropout_1(self.attn_1(x2, x2, x2, trg_mask)) 47 | x2 = self.norm_2(x) 48 | x = x + self.dropout_2(self.attn_2(x2, e_outputs, e_outputs, \ 49 | src_mask)) 50 | x2 = self.norm_3(x) 51 | x = x + self.dropout_3(self.ff(x2)) 52 | return x -------------------------------------------------------------------------------- /transformer/user_modeling/pytorchtools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | class EarlyStopping: 5 | """Early stops the training if validation loss doesn't improve after a given patience.""" 6 | def __init__(self, patience=7, verbose=False, args=None, delta=0): 7 | """ 8 | Args: 9 | patience (int): How long to wait after last time validation loss improved. 10 | Default: 7 11 | verbose (bool): If True, prints a message for each validation loss improvement. 12 | Default: False 13 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 14 | Default: 0 15 | """ 16 | self.patience = patience 17 | self.verbose = verbose 18 | self.counter = 0 19 | self.best_score = None 20 | self.early_stop = False 21 | self.val_loss_min = -np.Inf 22 | self.delta = delta 23 | # self.save_path = args.save_model + '.chkpt' 24 | self.args = args 25 | self.epoch_i = 0 26 | 27 | def __call__(self, val_loss, model, epoch_i): 28 | 29 | score = val_loss 30 | self.epoch_i = epoch_i 31 | 32 | if self.best_score is None: 33 | self.best_score = score 34 | self.save_checkpoint(val_loss, model, epoch_i) 35 | elif score <= self.best_score + self.delta: 36 | self.counter += 1 37 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 38 | if self.counter >= self.patience: 39 | self.early_stop = True 40 | else: 41 | self.best_score = score 42 | self.save_checkpoint(val_loss, model, epoch_i) 43 | self.counter = 0 44 | 45 | def save_checkpoint(self, val_loss, model, epoch_i): 46 | '''Saves model when validation loss decrease.''' 47 | if self.verbose: 48 | print(f'Validation score changed ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') 49 | model_state_dict = model.state_dict() 50 | checkpoint = { 51 | 'model': model_state_dict, 52 | 'settings': self.args, 53 | 'epoch': epoch_i} 54 | torch.save(checkpoint, self.args.save_model + '.chkpt') 55 | self.val_loss_min = val_loss -------------------------------------------------------------------------------- /transformer/user_modeling/Optim.py: -------------------------------------------------------------------------------- 1 | '''A wrapper class for optimizer ''' 2 | import numpy as np 3 | import torch 4 | import torch.optim 5 | # class ScheduledOptim(): 6 | # '''A simple wrapper class for learning rate scheduling''' 7 | 8 | # def __init__(self, optimizer, d_model, n_warmup_steps): 9 | # self._optimizer = optimizer 10 | # self.n_warmup_steps = n_warmup_steps 11 | # self.n_current_steps = 0 12 | # self.init_lr = np.power(d_model, -0.5) 13 | 14 | # def step_and_update_lr(self): 15 | # "Step with the inner optimizer" 16 | # self._update_learning_rate() 17 | # self._optimizer.step() 18 | 19 | # def zero_grad(self): 20 | # "Zero out the gradients by the inner optimizer" 21 | # self._optimizer.zero_grad() 22 | 23 | # def _get_lr_scale(self): 24 | # return np.min([ 25 | # np.power(self.n_current_steps, -0.5), 26 | # np.power(self.n_warmup_steps, -1.5) * self.n_current_steps]) 27 | 28 | # def _update_learning_rate(self): 29 | # ''' Learning rate scheduling per step ''' 30 | 31 | # self.n_current_steps += 1 32 | # lr = self.init_lr * self._get_lr_scale() 33 | 34 | # for param_group in self._optimizer.param_groups: 35 | # param_group['lr'] = lr 36 | 37 | class NoamOpt: 38 | "Optim wrapper that implements rate." 39 | def __init__(self, model_size, factor, warmup, optimizer): 40 | self.optimizer = optimizer 41 | self._step = 0 42 | self.warmup = warmup 43 | self.factor = factor 44 | self.model_size = model_size 45 | self._rate = 0 46 | 47 | def step(self): 48 | "Update parameters and rate" 49 | self._step += 1 50 | rate = self.rate() 51 | for p in self.optimizer.param_groups: 52 | p['lr'] = rate 53 | self._rate = rate 54 | self.optimizer.step() 55 | 56 | def rate(self, step = None): 57 | "Implement `lrate` above" 58 | if step is None: 59 | step = self._step 60 | return self.factor * \ 61 | (self.model_size ** (-0.5) * 62 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 63 | 64 | def get_std_opt(model,opt): 65 | return NoamOpt(opt.d_model, 2, opt.n_warmup_steps, 66 | torch.optim.Adam(model.get_trainable_parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 67 | 68 | -------------------------------------------------------------------------------- /transformer/attribute_prediction/attribute_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import os 4 | import json 5 | from PIL import Image 6 | import multiprocessing 7 | 8 | 9 | class Dataset(data.Dataset): 10 | def __init__(self, root, data_file_name, class_file, transform=None): 11 | """Set the path for images, captions and vocabulary wrapper. 12 | 13 | Args: 14 | root: image directory. 15 | data_file_name: asin --> [tag] 16 | transform: image transformer. 
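class_file: JSON file mapping each attribute tag to its class index (used below to build the multi-hot label vector).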
17 | """ 18 | self.root = root 19 | 20 | with open(data_file_name, 'r') as f: 21 | self.data = json.load(f) 22 | self.ids = range(len(self.data)) 23 | 24 | self.asin = [] 25 | for key, _ in self.data.items(): 26 | self.asin.append(key) 27 | 28 | with open(class_file, 'r') as f: 29 | self.cls = json.load(f) 30 | 31 | self.transform = transform 32 | return 33 | 34 | def __getitem__(self, index): 35 | """Returns one data pair (image and caption).""" 36 | id = self.ids[index] 37 | asin = self.asin[id] 38 | img_name = self.asin[id] + '.jpg' 39 | 40 | image = Image.open(os.path.join(self.root, img_name)).convert('RGB') 41 | if self.transform is not None: 42 | image = self.transform(image) 43 | 44 | attribute_labels = self.data[asin] 45 | # convert text words to idx, flattened 46 | label = torch.zeros(len(self.cls)) 47 | for sublist in attribute_labels[1:]: 48 | for tag in sublist: 49 | label[self.cls[tag]] = 1 50 | 51 | return image, label, asin 52 | 53 | def __len__(self): 54 | return len(self.ids) 55 | 56 | 57 | def collate_fn(data): 58 | """Creates mini-batch tensors from the list of tuples (image, tags). 59 | """ 60 | images, labels, asins = zip(*data) 61 | 62 | # Merge images (from tuple of 3D tensor to 4D tensor). 63 | images = torch.stack(images, dim=0) 64 | labels = torch.stack(labels, dim=0) 65 | 66 | return images, labels, asins 67 | 68 | 69 | def get_loader(root, data_file, class_file, transform, batch_size, shuffle): 70 | """Returns torch.utils.data.DataLoader for custom coco dataset.""" 71 | cpu_num = multiprocessing.cpu_count() 72 | num_workers = cpu_num - 2 if cpu_num > 2 else 1 73 | 74 | dataset = Dataset(root=root, 75 | data_file_name=data_file, 76 | class_file=class_file, 77 | transform=transform) 78 | 79 | data_loader = torch.utils.data.DataLoader( 80 | dataset=dataset, batch_size=batch_size, shuffle=shuffle, 81 | num_workers=num_workers, collate_fn=collate_fn, pin_memory=True) 82 | 83 | return data_loader 84 | 85 | -------------------------------------------------------------------------------- /start_kit/resize_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from PIL import Image 4 | 5 | from joblib import Parallel, delayed 6 | import multiprocessing 7 | 8 | def resize_image(image, size): 9 | """Resize an image to the given size.""" 10 | return image.resize(size, Image.ANTIALIAS) 11 | 12 | def resize_images(image_dir, output_dir, size): 13 | """Resize the images in 'image_dir' and save into 'output_dir'.""" 14 | if not os.path.exists(output_dir): 15 | os.makedirs(output_dir) 16 | 17 | images = os.listdir(image_dir) 18 | num_images = len(images) 19 | for i, image in enumerate(images): 20 | with open(os.path.join(image_dir, image), 'r+b') as f: 21 | with Image.open(f) as img: 22 | img = resize_image(img, size) 23 | img.save(os.path.join(output_dir, image), img.format) 24 | if (i+1) % 100 == 0: 25 | print ("[{}/{}] Resized the images and saved into '{}'." 26 | .format(i+1, num_images, output_dir)) 27 | 28 | def resize_image_operator(image_file, output_file, size, i, num_images): 29 | with open(image_file, 'r+b') as f: 30 | with Image.open(f) as img: 31 | img = resize_image(img, size) 32 | img.save(output_file, img.format) 33 | if (i + 1) % 100 == 0: 34 | print("[{}/{}] Resized the images and saved." 
35 | .format(i + 1, num_images)) 36 | return 37 | 38 | def resize_images_parallel(image_dir, output_dir, size): 39 | """Resize the images in 'image_dir' and save into 'output_dir'.""" 40 | if not os.path.exists(output_dir): 41 | os.makedirs(output_dir) 42 | num_cores = multiprocessing.cpu_count() 43 | print('resize on {} CPUs'.format(num_cores)) 44 | 45 | images = os.listdir(image_dir) 46 | num_images = len(images) 47 | Parallel(n_jobs=num_cores)( 48 | delayed(resize_image_operator)( 49 | os.path.join(image_dir, image), 50 | os.path.join(output_dir, image), 51 | size, 52 | i, 53 | num_images) for i, image in enumerate(images)) 54 | 55 | 56 | def main(args): 57 | image_dir = args.image_dir 58 | output_dir = args.output_dir 59 | image_size = [args.image_size, args.image_size] 60 | resize_images_parallel(image_dir, output_dir, image_size) 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('--image_dir', type=str, default='./data/train2014/', 66 | help='directory for train images') 67 | parser.add_argument('--output_dir', type=str, default='./data/resized2014/', 68 | help='directory for saving resized images') 69 | parser.add_argument('--image_size', type=int, default=256, 70 | help='size for image after processing') 71 | args = parser.parse_args() 72 | main(args) -------------------------------------------------------------------------------- /transformer/interactive_retrieval/glove_embedding.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pickle 4 | import json 5 | 6 | 7 | def extract_glove_embedding(): 8 | embs = [] 9 | word2id = {'': 0, 10 | '': 1} 11 | 12 | glove_dim = 300 13 | embs.append(np.asarray([0.0] * glove_dim, 'float32')) 14 | embs.append(np.asarray([0.0] * glove_dim, 'float32')) 15 | 16 | emb_sum = 0 17 | count = 0 18 | with open('glove.6B.{}d.txt'.format(glove_dim), 'r') as f: 19 | for i, line in enumerate(f): 20 | array = line.strip().split(' ') 21 | word = array[0] 22 | word2id[word] = len(word2id) 23 | e = np.asarray(array[1:], 'float32') 24 | # print(len(e)) 25 | emb_sum += e 26 | count += 1 27 | embs.append(e) 28 | emb_sum /= count 29 | embs[word2id['']] = emb_sum 30 | 31 | # special token 32 | word2id[''] = len(word2id) 33 | embs.append(-emb_sum) 34 | 35 | save_data = {'word2id': word2id, 'embs': embs} 36 | with open('dict_x.pt', 'wb') as f: 37 | pickle.dump(save_data, f) 38 | 39 | 40 | def extract_vocab_embedding(args): 41 | with open(args.glove_file, 'rb') as f: 42 | glove = pickle.load(f) 43 | 44 | vocab_file = args.user_vocab_file.format(args.data_set) 45 | with open(vocab_file, 'r') as f: 46 | vocab = json.load(f) 47 | 48 | glove_word2idx = glove['word2id'] 49 | glove_embs = glove['embs'] 50 | 51 | vocab_emb = [] 52 | for idx in range(len(vocab['idx2word'])): 53 | word = vocab['idx2word'][str(idx)] 54 | if word in glove_word2idx: 55 | glove_word = word 56 | else: 57 | if word in ['', '']: 58 | glove_word = '' 59 | elif word == '': 60 | glove_word = '' 61 | else: 62 | glove_word = '' 63 | glove_word_idx = glove_word2idx[glove_word] 64 | vocab_emb.append(glove_embs[glove_word_idx]) 65 | print('idx:', idx, '\tword:', word, '\tglove_word:', glove_word) 66 | vocab_emb.append(glove_embs[glove_word2idx['']]) 67 | 68 | with open(args.save.format(args.data_set), 'wb') as f: 69 | pickle.dump(vocab_emb, f) 70 | return 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser() 75 | 
parser.add_argument('--glove_file', type=str, 76 | default='dict_x.pt') 77 | parser.add_argument('--data_set', type=str, 78 | default='shirt') 79 | parser.add_argument('--user_vocab_file', type=str, 80 | default='../user_modeling/data/{}_vocab.json') 81 | parser.add_argument('--save', type=str, 82 | default='data/{}_emb.pt') 83 | args = parser.parse_args() 84 | 85 | extract_glove_embedding() 86 | extract_vocab_embedding(args) 87 | 88 | -------------------------------------------------------------------------------- /transformer/user_modeling/Sublayers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | class Norm(nn.Module): 8 | def __init__(self, d_model, eps = 1e-6): 9 | super().__init__() 10 | 11 | self.size = d_model 12 | 13 | # create two learnable parameters to calibrate normalisation 14 | self.alpha = nn.Parameter(torch.ones(self.size)) 15 | self.bias = nn.Parameter(torch.zeros(self.size)) 16 | 17 | self.eps = eps 18 | 19 | def forward(self, x): 20 | norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \ 21 | / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias 22 | return norm 23 | 24 | def attention(q, k, v, d_k, mask=None, dropout=None): 25 | 26 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) 27 | 28 | if mask is not None: 29 | mask = mask.unsqueeze(1) 30 | scores = scores.masked_fill(mask == 0, -1e9) 31 | 32 | scores = F.softmax(scores, dim=-1) 33 | 34 | if dropout is not None: 35 | scores = dropout(scores) 36 | 37 | output = torch.matmul(scores, v) 38 | return output 39 | 40 | class MultiHeadAttention(nn.Module): 41 | def __init__(self, heads, d_model, dropout = 0.1): 42 | super().__init__() 43 | 44 | self.d_model = d_model 45 | self.d_k = d_model // heads 46 | self.h = heads 47 | 48 | self.q_linear = nn.Linear(d_model, d_model) 49 | self.v_linear = nn.Linear(d_model, d_model) 50 | self.k_linear = nn.Linear(d_model, d_model) 51 | 52 | self.dropout = nn.Dropout(dropout) 53 | self.out = nn.Linear(d_model, d_model) 54 | 55 | def forward(self, q, k, v, mask=None): 56 | 57 | bs = q.size(0) 58 | 59 | # perform linear operation and split into N heads 60 | k = self.k_linear(k).view(bs, -1, self.h, self.d_k) 61 | q = self.q_linear(q).view(bs, -1, self.h, self.d_k) 62 | v = self.v_linear(v).view(bs, -1, self.h, self.d_k) 63 | 64 | # transpose to get dimensions bs * N * sl * d_model 65 | k = k.transpose(1,2) 66 | q = q.transpose(1,2) 67 | v = v.transpose(1,2) 68 | 69 | 70 | # calculate attention using function we will define next 71 | scores = attention(q, k, v, self.d_k, mask, self.dropout) 72 | # concatenate heads and put through final linear layer 73 | concat = scores.transpose(1,2).contiguous()\ 74 | .view(bs, -1, self.d_model) 75 | output = self.out(concat) 76 | 77 | return output 78 | 79 | class FeedForward(nn.Module): 80 | def __init__(self, d_model, d_ff=2048, dropout = 0.1): 81 | super().__init__() 82 | 83 | # We set d_ff as a default to 2048 84 | self.linear_1 = nn.Linear(d_model, d_ff) 85 | self.dropout = nn.Dropout(dropout) 86 | self.linear_2 = nn.Linear(d_ff, d_model) 87 | 88 | def forward(self, x): 89 | x = self.dropout(F.relu(self.linear_1(x))) 90 | x = self.linear_2(x) 91 | return x 92 | -------------------------------------------------------------------------------- /start_kit/build_vocab.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import pickle 3 | 
import argparse 4 | from collections import Counter 5 | 6 | CAP_FILE = 'data/captions/cap.{}.train.json' 7 | DICT_OUTPUT_FILE = 'data/captions/dict.{}.json' 8 | 9 | class Vocabulary(object): 10 | """Simple vocabulary wrapper.""" 11 | def __init__(self): 12 | self.word2idx = {} 13 | self.idx2word = {} 14 | self.idx = 0 15 | 16 | def add_word(self, word): 17 | if not word in self.word2idx: 18 | self.word2idx[word] = self.idx 19 | self.idx2word[self.idx] = word 20 | self.idx += 1 21 | 22 | def __call__(self, word): 23 | if not word in self.word2idx: 24 | return self.word2idx['<unk>'] 25 | return self.word2idx[word] 26 | 27 | def __len__(self): 28 | return len(self.word2idx) 29 | 30 | def init_vocab(self): 31 | self.add_word('<pad>') 32 | self.add_word('<start>') 33 | self.add_word('<end>') 34 | self.add_word('<unk>') 35 | self.add_word('<and>') 36 | 37 | def save(self, file_name): 38 | data = {} 39 | data['word2idx'] = self.word2idx 40 | data['idx2word'] = self.idx2word 41 | data['idx'] = self.idx 42 | import json 43 | with open(file_name, 'w') as f: 44 | json.dump(data, f, indent=4) 45 | return 46 | 47 | def load(self, file_name): 48 | import json 49 | with open(file_name, 'r') as f: 50 | data = json.load(f) 51 | self.word2idx = data['word2idx'] 52 | self.idx2word = data['idx2word'] 53 | self.idx = data['idx'] 54 | return 55 | 56 | def build_vocab(cap_file, threshold): 57 | """Build a simple vocabulary wrapper.""" 58 | import json 59 | data = json.load(open(cap_file, 'r')) 60 | # with open(json, 'rb') as f: 61 | # [data] = pickle.load(f) 62 | 63 | counter = Counter() 64 | for i in range(len(data)): 65 | captions = data[i]['captions'] 66 | for caption in captions: 67 | tokens = nltk.tokenize.word_tokenize(caption.lower()) 68 | counter.update(tokens) 69 | 70 | if (i+1) % 1000 == 0: 71 | print("[{}/{}] Tokenized the captions.".format(i+1, len(data))) 72 | # break 73 | 74 | # If the word frequency is less than 'threshold', then the word is discarded. 75 | words = [word for word, cnt in counter.items() if cnt >= threshold] 76 | 77 | # Create a vocab wrapper and add some special tokens. 78 | vocab = Vocabulary() 79 | vocab.init_vocab() 80 | 81 | # Add the words to the vocabulary.
82 | for i, word in enumerate(words): 83 | vocab.add_word(word) 84 | 85 | return vocab 86 | 87 | def main(args): 88 | vocab = build_vocab(cap_file=CAP_FILE.format(args.data_set), threshold=args.threshold) 89 | vocab.save(DICT_OUTPUT_FILE.format(args.data_set)) 90 | print("Total vocabulary size: {}".format(len(vocab))) 91 | 92 | if __name__ == '__main__': 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument('--data_set', type=str, default='dress') 95 | parser.add_argument('--threshold', type=int, default=2, 96 | help='minimum word count threshold') 97 | args = parser.parse_args() 98 | main(args) -------------------------------------------------------------------------------- /transformer/user_modeling/build_vocab.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import pickle 3 | import argparse 4 | from collections import Counter 5 | 6 | # CAP_FILE = './data/captions/{}.tsv' 7 | # DICT_OUTPUT_FILE = './data/captions/{}.json' 8 | 9 | class Vocabulary(object): 10 | """Simple vocabulary wrapper.""" 11 | def __init__(self): 12 | self.word2idx = {} 13 | self.idx2word = {} 14 | self.idx = 0 15 | 16 | def add_word(self, word): 17 | if not word in self.word2idx: 18 | self.word2idx[word] = self.idx 19 | self.idx2word[self.idx] = word 20 | self.idx += 1 21 | 22 | def __call__(self, word): 23 | if not word in self.word2idx: 24 | return self.word2idx['<unk>'] 25 | return self.word2idx[word] 26 | 27 | def __len__(self): 28 | return len(self.word2idx) 29 | 30 | def init_vocab(self): 31 | self.add_word('<pad>') 32 | self.add_word('<start>') 33 | self.add_word('<end>') 34 | self.add_word('<unk>') 35 | self.add_word('<and>') 36 | 37 | 38 | def save(self, file_name): 39 | data = {} 40 | data['word2idx'] = self.word2idx 41 | data['idx2word'] = self.idx2word 42 | data['idx'] = self.idx 43 | import json 44 | with open(file_name, 'w') as f: 45 | json.dump(data, f, indent=4) 46 | return 47 | 48 | def load(self, file_name): 49 | import json 50 | with open(file_name, 'r') as f: 51 | data = json.load(f) 52 | self.word2idx = data['word2idx'] 53 | self.idx2word = data['idx2word'] 54 | self.idx = data['idx'] 55 | return 56 | 57 | def build_vocab(cap_file, threshold): 58 | """Build a simple vocabulary wrapper.""" 59 | import json 60 | data = json.load(open(cap_file, 'r')) 61 | # with open(json, 'rb') as f: 62 | # [data] = pickle.load(f) 63 | 64 | counter = Counter() 65 | print('total number of image pairs',len(data)) 66 | for i in range(len(data)): 67 | captions = data[i]['captions'] 68 | # for caption in captions: 69 | tokens = nltk.tokenize.word_tokenize(captions.lower()) 70 | counter.update(tokens) 71 | 72 | if (i+1) % 1000 == 0: 73 | print("[{}/{}] Tokenized the captions.".format(i+1, len(data))) 74 | # break 75 | 76 | # If the word frequency is less than 'threshold', then the word is discarded. 77 | words = [word for word, cnt in counter.items() if cnt >= threshold] 78 | 79 | # Create a vocab wrapper and add some special tokens. 80 | vocab = Vocabulary() 81 | vocab.init_vocab() 82 | 83 | # Add the words to the vocabulary.
84 | for i, word in enumerate(words): 85 | vocab.add_word(word) 86 | 87 | return vocab 88 | 89 | def main(args): 90 | vocab = build_vocab(cap_file=args.data_set_path, threshold=args.threshold) 91 | vocab.save(args.save_output_path) 92 | print("Total vocabulary size: {}".format(len(vocab))) 93 | 94 | if __name__ == '__main__': 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument('--data_set_path', type=str, default='dress') 97 | parser.add_argument('--save_output_path', type=str, default='dress') 98 | parser.add_argument('--threshold', type=int, default=2, 99 | help='minimum word count threshold') 100 | args = parser.parse_args() 101 | main(args) -------------------------------------------------------------------------------- /transformer/user_modeling/resize_image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from PIL import Image 4 | 5 | from joblib import Parallel, delayed 6 | import multiprocessing 7 | 8 | def resize_image(image, size): 9 | """Resize an image to the given size.""" 10 | return image.resize(size, Image.ANTIALIAS) 11 | 12 | def resize_images(image_dir, output_dir, size): 13 | """Resize the images in 'image_dir' and save into 'output_dir'.""" 14 | if not os.path.exists(output_dir): 15 | os.makedirs(output_dir) 16 | 17 | images = os.listdir(image_dir) 18 | num_images = len(images) 19 | for i, image in enumerate(images): 20 | print(image) 21 | with open(os.path.join(image_dir, image), 'r+b') as f: 22 | with Image.open(f) as img: 23 | img = resize_image(img, size) 24 | img.save(os.path.join(output_dir, image), img.format) 25 | if (i+1) % 100 == 0: 26 | print ("[{}/{}] Resized the images and saved into '{}'." 27 | .format(i+1, num_images, output_dir)) 28 | 29 | def resize_image_operator(image_file, output_file, size, i, num_images): 30 | with open(image_file, 'r+b') as f: 31 | with Image.open(f) as img: 32 | img = resize_image(img, size) 33 | img.save(output_file, img.format) 34 | if (i + 1) % 100 == 0: 35 | print("[{}/{}] Resized the images and saved." 
36 | .format(i + 1, num_images)) 37 | return 38 | 39 | def resize_images_parallel(image_dir, output_dir, size): 40 | """Resize the images in 'image_dir' and save into 'output_dir'.""" 41 | if not os.path.exists(output_dir): 42 | os.makedirs(output_dir) 43 | num_cores = multiprocessing.cpu_count() 44 | print('resize on {} CPUs'.format(num_cores)) 45 | 46 | images = os.listdir(image_dir) 47 | num_images = len(images) 48 | Parallel(n_jobs=num_cores)( 49 | delayed(resize_image_operator)( 50 | os.path.join(image_dir, image), 51 | os.path.join(output_dir, image), 52 | size, 53 | i, 54 | num_images) for i, image in enumerate(images)) 55 | 56 | 57 | # def main(): 58 | # image_dir = '../data/images/' 59 | # output_dir = '../data/revised_images' 60 | # image_size = [256, 256] 61 | # resize_images(image_dir, output_dir, image_size) 62 | 63 | # if __name__ == '__main__': 64 | # # parser = argparse.ArgumentParser() 65 | # # parser.add_argument('--image_dir', type=str, default='../data/images/', 66 | # # help='directory for train images') 67 | # # parser.add_argument('--output_dir', type=str, default='../data/revised_images', 68 | # # help='directory for saving resized images') 69 | # # parser.add_argument('--image_size', type=int, default=256, 70 | # # help='size for image after processing') 71 | # # args = parser.parse_args() 72 | # main() 73 | 74 | def main(args): 75 | image_dir = args.image_dir 76 | output_dir = args.output_dir 77 | image_size = [args.image_size, args.image_size] 78 | resize_images_parallel(image_dir, output_dir, image_size) 79 | 80 | 81 | if __name__ == '__main__': 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument('--image_dir', type=str, default='./data/train2014/', 84 | help='directory for train images') 85 | parser.add_argument('--output_dir', type=str, default='./data/resized2014/', 86 | help='directory for saving resized images') 87 | parser.add_argument('--image_size', type=int, default=256, 88 | help='size for image after processing') 89 | args = parser.parse_args() 90 | main(args) 91 | -------------------------------------------------------------------------------- /start_kit/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from torch.nn.utils.rnn import pack_padded_sequence 5 | 6 | 7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 8 | 9 | 10 | class DummyImageEncoder(nn.Module): 11 | def __init__(self, embed_size): 12 | """Load the pretrained ResNet-152 and replace top fc layer.""" 13 | super(DummyImageEncoder, self).__init__() 14 | resnet = models.resnet152(pretrained=True) 15 | modules = list(resnet.children())[:-1] # delete the last fc layer. 16 | self.resnet = nn.Sequential(*modules) 17 | self.linear = nn.Linear(resnet.fc.in_features, embed_size) 18 | self.bn = nn.BatchNorm1d(resnet.fc.in_features, momentum=0.01) 19 | 20 | def get_trainable_parameters(self): 21 | return list(self.bn.parameters()) + list(self.linear.parameters()) 22 | 23 | def load_resnet(self, resnet=None): 24 | if resnet is None: 25 | resnet = models.resnet152(pretrained=True) 26 | modules = list(resnet.children())[:-1] # delete the last fc layer. 
27 | self.resnet = nn.Sequential(*modules) 28 | self.resnet_in_features = resnet.fc.in_features 29 | else: 30 | self.resnet = resnet 31 | return 32 | 33 | def delete_resnet(self): 34 | resnet = self.resnet 35 | self.resnet = None 36 | return resnet 37 | 38 | def forward(self, image): 39 | with torch.no_grad(): 40 | img_ft = self.resnet(image) 41 | 42 | out = self.linear(self.bn(img_ft.reshape(img_ft.size(0), -1))) 43 | return out 44 | 45 | 46 | class DummyCaptionEncoder(nn.Module): 47 | def __init__(self, vocab_size, vocab_embed_size, embed_size): 48 | super(DummyCaptionEncoder, self).__init__() 49 | self.out_linear = nn.Linear(embed_size, embed_size, bias=False) 50 | self.rnn = nn.GRU(vocab_embed_size, embed_size) 51 | self.embed = nn.Embedding(vocab_size, vocab_embed_size) 52 | 53 | def forward(self, input, lengths): 54 | input = self.embed(input) 55 | lengths = torch.LongTensor(lengths) 56 | [_, sort_ids] = torch.sort(lengths, descending=True) 57 | sorted_input = input[sort_ids] 58 | sorted_length = lengths[sort_ids] 59 | reverse_sort_ids = sort_ids.clone() 60 | for i in range(sort_ids.size(0)): 61 | reverse_sort_ids[sort_ids[i]] = i 62 | packed = pack_padded_sequence(sorted_input, sorted_length, batch_first=True) 63 | output, _ = self.rnn(packed) 64 | padded, output_length = torch.nn.utils.rnn.pad_packed_sequence(output) 65 | output = [padded[output_length[i]-1, i, :] for i in range(len(output_length))] 66 | output = torch.stack([output[reverse_sort_ids[i]] for i in range(len(output))], dim=0) 67 | output = self.out_linear(output) 68 | return output 69 | 70 | def get_trainable_parameters(self): 71 | return list(self.parameters()) 72 | 73 | # 74 | # model = DummyCaptionEncoder(100, 64, 10) 75 | # 76 | # x1 = [ 77 | # [45, 4, 7, 9, 2, 0, 0], 78 | # [11, 2, 3, 4, 5, 6, 7], 79 | # [99, 98, 97, 96, 7, 8, 0], 80 | # [89, 87, 86, 2, 0, 0, 0] 81 | # ] 82 | # len1 = [5, 2, 3, 2] 83 | # x1 = torch.tensor(x1) 84 | # y1 = model(x1, len1) 85 | # 86 | # x2 = [ 87 | # [56, 56, 3, 0, 0, 0, 0], 88 | # [89, 87, 86, 1, 0, 0, 0], 89 | # [1, 36, 4, 7, 8, 4, 0], 90 | # [99, 98, 97, 96, 4, 0, 0] 91 | # ] 92 | # len2 = [2, 2, 5, 3] 93 | # x2 = torch.tensor(x2) 94 | # y2 = model(x2, len2) 95 | # 96 | # print('max dif 1', (y1[3,:] - y2[1,:]).max(), (y1[3,:] - y2[1,:]).min()) 97 | # print('max dif 2', (y1[2,:] - y2[3,:]).max(), (y1[2,:] - y2[3,:]).min()) -------------------------------------------------------------------------------- /start_kit/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import shutil 4 | import json 5 | import math 6 | from PIL import Image 7 | from joblib import Parallel, delayed 8 | 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | 11 | 12 | def create_exp_dir(path, scripts_to_save=None): 13 | if not os.path.exists(path): 14 | os.mkdir(path) 15 | print('Experiment dir : {}'.format(path)) 16 | if scripts_to_save is not None: 17 | if not os.path.exists(os.path.join(path, 'scripts')): 18 | os.mkdir(os.path.join(path, 'scripts')) 19 | for script in scripts_to_save: 20 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 21 | shutil.copyfile(script, dst_file) 22 | return 23 | 24 | 25 | class Ranker(): 26 | def __init__(self, root, image_split_file, transform=None, num_workers=16): 27 | self.num_workers = num_workers 28 | self.root = root 29 | with open(image_split_file, 'r') as f: 30 | data = json.load(f) 31 | self.data = data 32 | self.ids = range(len(self.data)) 33 | 
self.transform = transform 34 | return 35 | 36 | def get_item(self, index): 37 | data = self.data 38 | id = self.ids[index] 39 | img_name = data[id] + '.jpg' 40 | image = Image.open(os.path.join(self.root, img_name)).convert('RGB') 41 | if self.transform is not None: 42 | image = self.transform(image) 43 | return image, data[id] 44 | 45 | def get_items(self, indexes): 46 | items = Parallel(n_jobs=self.num_workers)( 47 | delayed(self.get_item)( 48 | i) for i in indexes) 49 | images, meta_info = zip(*items) 50 | images = torch.stack(images, dim=0) 51 | return images, meta_info 52 | 53 | def update_emb(self, image_encoder, batch_size=64): 54 | data_emb = [] 55 | data_asin = [] 56 | num_data = len(self.data) 57 | num_batch = math.floor(num_data / batch_size) 58 | print('updating emb') 59 | for i in range(num_batch): 60 | batch_ids = torch.LongTensor([i for i in range(i * batch_size, (i + 1) * batch_size)]) 61 | images, asins = self.get_items(batch_ids) 62 | images = images.to(device) 63 | with torch.no_grad(): 64 | feat = image_encoder(images) 65 | data_emb.append(feat) 66 | data_asin.extend(asins) 67 | 68 | if num_batch * batch_size < num_data: 69 | batch_ids = torch.LongTensor([i for i in range(num_batch * batch_size, num_data)]) 70 | images, asins = self.get_items(batch_ids) 71 | images = images.to(device) 72 | with torch.no_grad(): 73 | feat = image_encoder(images) 74 | data_emb.append(feat) 75 | data_asin.extend(asins) 76 | 77 | self.data_emb = torch.cat(data_emb, dim=0) 78 | self.data_asin = data_asin 79 | print('emb updated') 80 | return 81 | 82 | def compute_rank(self, inputs, target_ids): 83 | rankings = [] 84 | for i in range(inputs.size(0)): 85 | distances = (self.data_emb - inputs[i]).pow(2).sum(dim=1) 86 | ranking = (distances < distances[self.data_asin.index(target_ids[i])]).sum(dim=0) 87 | rankings.append(ranking) 88 | return torch.FloatTensor(rankings).to(device) 89 | 90 | def get_nearest_neighbors(self, inputs, topK=50): 91 | neighbors = [] 92 | for i in range(inputs.size(0)): 93 | [_, neighbor] = (self.data_emb - inputs[i]).pow(2).sum(dim=1).topk(dim=0, k=topK, largest=False, sorted=True) 94 | neighbors.append(neighbor) 95 | return torch.stack(neighbors, dim=0).to(device) -------------------------------------------------------------------------------- /start_kit/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import json 4 | import os 5 | from torchvision import transforms 6 | from data_loader import get_loader 7 | from build_vocab import Vocabulary 8 | from models import DummyImageEncoder, DummyCaptionEncoder 9 | from utils import Ranker 10 | 11 | # Device configuration 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | # Paths to data 14 | IMAGE_ROOT = 'data/resized_images/{}/' 15 | CAPT = 'data/captions/cap.{}.{}.json' 16 | DICT = 'data/captions/dict.{}.json' 17 | SPLIT = 'data/image_splits/split.{}.{}.json' 18 | 19 | 20 | def evaluate(args): 21 | # Image pre-processing, normalization for the pre-trained resnet 22 | transform_test = transforms.Compose([ 23 | transforms.CenterCrop(args.crop_size), 24 | transforms.ToTensor(), 25 | transforms.Normalize((0.485, 0.456, 0.406), 26 | (0.229, 0.224, 0.225))]) 27 | vocab = Vocabulary() 28 | vocab.load(DICT.format(args.data_set)) 29 | # Build data loader 30 | data_loader_test = get_loader(IMAGE_ROOT.format(args.data_set), 31 | CAPT.format(args.data_set, args.data_split), 32 | vocab, transform_test, 33 | args.batch_size, 
shuffle=False, return_target=False, num_workers=args.num_workers) 34 | ranker = Ranker(root=IMAGE_ROOT.format(args.data_set), image_split_file=SPLIT.format(args.data_set, args.data_split), 35 | transform=transform_test, num_workers=args.num_workers) 36 | 37 | # Build the dummy models 38 | image_encoder = DummyImageEncoder(args.embed_size).to(device) 39 | caption_encoder = DummyCaptionEncoder(vocab_size=len(vocab), vocab_embed_size=args.embed_size * 2, 40 | embed_size=args.embed_size).to(device) 41 | # load trained models 42 | image_model = os.path.join(args.model_folder, 43 | "image-{}.th".format(args.embed_size)) 44 | resnet = image_encoder.delete_resnet() 45 | image_encoder.load_state_dict(torch.load(image_model, map_location=device)) 46 | image_encoder.load_resnet(resnet) 47 | 48 | cap_model = os.path.join(args.model_folder, 49 | "cap-{}.th".format(args.embed_size)) 50 | caption_encoder.load_state_dict(torch.load(cap_model, map_location=device)) 51 | 52 | ranker.update_emb(image_encoder) 53 | image_encoder.eval() 54 | caption_encoder.eval() 55 | 56 | output = json.load(open(CAPT.format(args.data_set, args.data_split))) 57 | 58 | index = 0 59 | for _, candidate_images, captions, lengths, meta_info in data_loader_test: 60 | with torch.no_grad(): 61 | candidate_images = candidate_images.to(device) 62 | candidate_ft = image_encoder.forward(candidate_images) 63 | captions = captions.to(device) 64 | caption_ft = caption_encoder(captions, lengths) 65 | rankings = ranker.get_nearest_neighbors(candidate_ft + caption_ft) 66 | # print(rankings) 67 | for j in range(rankings.size(0)): 68 | output[index]['ranking'] = [ranker.data_asin[rankings[j, m].item()] for m in range(rankings.size(1))] 69 | index += 1 70 | 71 | json.dump(output, open("{}.{}.pred.json".format(args.data_set, args.data_split), 'w'), indent=4) 72 | print('eval completed. 
Output file: {}'.format("{}.{}.pred.json".format(args.data_set, args.data_split))) 73 | 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument('--model_folder', type=str, default='models/dress-20190612-112918/', 78 | help='path for trained models') 79 | parser.add_argument('--crop_size', type=int, default=224, 80 | help='size for randomly cropping images') 81 | parser.add_argument('--data_set', type=str, default='dress') 82 | parser.add_argument('--data_split', type=str, default='test') 83 | # Model parameters 84 | parser.add_argument('--embed_size', type=int, default=512, 85 | help='dimension of word embedding vectors') 86 | # Learning parameters 87 | parser.add_argument('--batch_size', type=int, default=2) 88 | parser.add_argument('--num_workers', type=int, default=16) 89 | args = parser.parse_args() 90 | evaluate(args) 91 | -------------------------------------------------------------------------------- /transformer/user_modeling/Beam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | from Models import nopeak_mask, create_masks 5 | 6 | def init_vars(image0, image1, model, opt, vocab, image0_attribute, image1_attribute): 7 | 8 | init_tok = vocab.word2idx[''] 9 | # src_mask = (src != SRC.vocab.stoi['']).unsqueeze(-2) 10 | image0 = model.cnn1(image0) 11 | 12 | image1 = model.cnn2(image1) 13 | 14 | if model.add_attribute: 15 | 16 | attribute = model.attribute_embedding(image0_attribute - image1_attribute).unsqueeze(1) 17 | # attribute = self.norm(attribute) 18 | 19 | # image0_attribute = self.attribute_embedding1(image0_attribute) 20 | 21 | # image1_attribute = self.attribute_embedding2(image1_attribute) 22 | 23 | # image0 = torch.cat((image0, image0_attribute), 1) 24 | # image1 = torch.cat((image1, image1_attribute), 1) 25 | 26 | #joint_encoding = self.joint_encoding(torch.cat((image0, image0_attribute),1), torch.cat((image1,image1_attribute),1)) 27 | joint_encoding = model.joint_encoding(image0, image1) 28 | joint_encoding = torch.cat((joint_encoding, attribute), 1) 29 | # joint_encoding = model.bn(joint_encoding.transpose(1,2)).transpose(1,2) 30 | 31 | else: 32 | joint_encoding = model.joint_encoding(image0, image1) 33 | 34 | e_output = model.encoder(joint_encoding) 35 | 36 | outputs = torch.LongTensor([[init_tok]]).to(opt.device) 37 | 38 | trg_mask = nopeak_mask(1).to(opt.device) 39 | 40 | out = model.out(model.decoder(outputs, e_output, trg_mask))# (batch_size, seq_len, vocab_size) 41 | 42 | out = F.softmax(out, dim=-1) 43 | 44 | probs, ix = out[:, -1].data.topk(opt.beam_size) 45 | 46 | log_scores = torch.Tensor([math.log(prob) for prob in probs.data[0]]).unsqueeze(0) 47 | 48 | outputs = torch.zeros(opt.beam_size, opt.max_seq_len).long().to(opt.device) 49 | 50 | outputs[:, 0] = init_tok 51 | 52 | outputs[:, 1] = ix[0] 53 | 54 | e_outputs = torch.zeros(opt.beam_size, e_output.size(-2),e_output.size(-1)).to(opt.device) 55 | 56 | e_outputs[:, :] = e_output[0] 57 | 58 | return outputs, e_outputs, log_scores 59 | 60 | def k_best_outputs(outputs, out, log_scores, i, k): 61 | 62 | probs, ix = out[:, -1].data.topk(k) 63 | 64 | log_probs = torch.Tensor([math.log(p) for p in probs.data.view(-1)]).view(k, -1) + log_scores.transpose(0,1) 65 | 66 | k_probs, k_ix = log_probs.view(-1).topk(k) 67 | 68 | row = k_ix // k 69 | col = k_ix % k 70 | 71 | outputs[:, :i] = outputs[row, :i] 72 | outputs[:, i] = ix[row, col] 73 | 74 | log_scores = 
k_probs.unsqueeze(0) 75 | 76 | return outputs, log_scores 77 | 78 | def beam_search(image0, image1, model, opt, vocab, image0_attribute, image1_attribute): 79 | 80 | 81 | outputs, e_outputs, log_scores = init_vars(image0, image1, model, opt, vocab, image0_attribute, image1_attribute) 82 | eos_tok = vocab.word2idx[''] 83 | ind = None 84 | for i in range(2, opt.max_seq_len): 85 | 86 | trg_mask = nopeak_mask(i).to(opt.device) 87 | 88 | out = model.out(model.decoder(outputs[:,:i], e_outputs, trg_mask)) 89 | 90 | out = F.softmax(out, dim=-1) 91 | 92 | outputs, log_scores = k_best_outputs(outputs, out, log_scores, i, opt.beam_size) 93 | 94 | ones = (outputs==eos_tok).nonzero() # Occurrences of end symbols for all input sentences. 95 | 96 | sentence_lengths = torch.zeros(len(outputs), dtype=torch.long).cuda() 97 | for vec in ones: 98 | i = vec[0] 99 | if sentence_lengths[i]==0: # First end symbol has not been found yet 100 | sentence_lengths[i] = vec[1] # Position of first end symbol 101 | 102 | num_finished_sentences = len([s for s in sentence_lengths if s > 0]) 103 | 104 | if num_finished_sentences == opt.beam_size: 105 | alpha = 0.7 106 | div = 1/(sentence_lengths.type_as(log_scores)**alpha) 107 | _, ind = torch.max(log_scores * div, 1) 108 | ind = ind.data[0] 109 | break 110 | 111 | if ind is None: 112 | # length = (outputs[0]==eos_tok).nonzero()[0] 113 | # return ' '.join([vocab.idx2word[str(tok.item())] for tok in outputs[0][1:length]]) 114 | return ' '.join([vocab.idx2word[str(tok.item())] for tok in outputs[0][1:]]) 115 | 116 | else: 117 | length = (outputs[ind]==eos_tok).nonzero()[0] 118 | return ' '.join([vocab.idx2word[str(tok.item())] for tok in outputs[ind][1:length]]) 119 | -------------------------------------------------------------------------------- /transformer/user_modeling/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import json 3 | import argparse 4 | import csv 5 | def preprocess_data(tsv_file_path): 6 | """load tsv_file and save image1_path and image2_path and caption in json and save it. 
7 | 8 | Args: 9 | tsv_file_path 10 | 11 | Returns: 12 | None 13 | """ 14 | with open(tsv_file_path) as tsvfile: 15 | reader = csv.reader(tsvfile, delimiter='\t') 16 | # i = 0 17 | data_train = [] 18 | data_dev = [] 19 | data_test = [] 20 | for i, row in enumerate(reader): 21 | if i == 0: 22 | continue 23 | image0 = process_url(row[1]) 24 | image1 = process_url(row[5]) 25 | text = row[10] 26 | 27 | if row[8] == 'train': 28 | data_train.append({'image0': image0, 'image1':image1, "captions":text}) 29 | elif row[8] == 'val': 30 | data_dev.append({'image0': image0, 'image1':image1, "captions":text}) 31 | elif row[8] == 'test': 32 | data_test.append({'image0': image0, 'image1':image1, "captions":text}) 33 | 34 | data_dev_combine = {} 35 | data_test_combine = {} 36 | 37 | for data in data_dev: 38 | key = data["image0"] + data["image1"] 39 | cap = data["captions"] 40 | if key in data_dev_combine: 41 | temp = data_dev_combine[key] 42 | temp["captions"].append(cap) 43 | else: 44 | data_dev_combine[key] = {'image0': data["image0"], 'image1':data["image1"], "captions":[cap]} 45 | 46 | temp = data_dev_combine.values() 47 | 48 | data_dev_new = [] 49 | 50 | for t in temp: 51 | data_dev_new.append(t) 52 | assert len(t['captions']) > 1 53 | 54 | for data in data_test: 55 | key = data["image0"] + data["image1"] 56 | cap = data["captions"] 57 | if key in data_test_combine: 58 | temp = data_test_combine[key] 59 | temp["captions"].append(cap) 60 | else: 61 | data_test_combine[key] = {'image0': data["image0"], 'image1':data["image1"], "captions":[cap]} 62 | 63 | temp = data_test_combine.values() 64 | 65 | data_test_new = [] 66 | 67 | for t in temp: 68 | data_test_new.append(t) 69 | assert len(t['captions']) > 1 70 | 71 | 72 | 73 | with open('./data_train.json', 'w') as outfile: 74 | json.dump(data_train, outfile) 75 | 76 | with open('./data_dev.json', 'w') as outfile: 77 | json.dump(data_dev, outfile) 78 | 79 | with open('./data_test.json', 'w') as outfile: 80 | json.dump(data_test, outfile) 81 | 82 | with open('./data_dev_combine.json', 'w') as outfile: 83 | json.dump(data_dev_new, outfile) 84 | 85 | with open('./data_test_combine.json', 'w') as outfile: 86 | json.dump(data_test_new, outfile) 87 | 88 | 89 | 90 | 91 | 92 | 93 | def parse_url(url): 94 | # print('url', url) 95 | tokens = url.split('/') 96 | # print(tokens) 97 | folder = tokens[4] 98 | tokens = tokens[5].split('?') 99 | tokens.reverse() 100 | file = '.'.join(tokens) 101 | # print(tokens[1]) 102 | # print(tokens) 103 | # if len(tokens) > 1: 104 | # file = tokens[1] 105 | # else: 106 | # file = 'null' 107 | # print(tokens[4], tokens[5]) 108 | # print(folder, file) 109 | return '/dccstor/extrastore/Neural-Naturalist/data/resized_images/' + folder + '.' 
+ file 110 | 111 | 112 | def process_url(url): 113 | file = parse_url(url) 114 | if file[-1] == '.': 115 | file = file + 'jpg' 116 | # make_folder(folder) 117 | 118 | # if not os.path.isfile(file): 119 | # with open(file, 'wb') as f: 120 | # resp = requests.get(url, verify=False) 121 | # f.write(resp.content) 122 | # f.close() 123 | return file 124 | 125 | def main(args): 126 | ''' Main function ''' 127 | 128 | 129 | 130 | preprocess_data(args.tsv_file_path) 131 | 132 | 133 | if __name__ == '__main__': 134 | 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument('-tsv_file_path', required=True) 137 | # parser.add_argument('-train_tgt', required=True) 138 | # parser.add_argument('-valid_src', required=True) 139 | # parser.add_argument('-valid_tgt', required=True) 140 | # parser.add_argument('-save_data', required=True) 141 | # parser.add_argument('-max_len', '--max_word_seq_len', type=int, default=50) 142 | # parser.add_argument('-min_word_count', type=int, default=5) 143 | # parser.add_argument('-keep_case', action='store_true') 144 | # parser.add_argument('-share_vocab', action='store_true') 145 | # parser.add_argument('-vocab', default=None) 146 | args = parser.parse_args() 147 | main(args) 148 | -------------------------------------------------------------------------------- /start_kit/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import torch.utils.data as data 4 | import os 5 | import nltk 6 | from PIL import Image 7 | import json 8 | 9 | 10 | class Dataset(data.Dataset): 11 | 12 | def __init__(self, root, data_file_name, vocab, transform=None, return_target=True): 13 | """Set the path for images, captions and vocabulary wrapper. 14 | 15 | Args: 16 | root: image directory. 17 | data: index file name. 18 | transform: image transformer. 19 | vocab: pre-processed vocabulary. 20 | """ 21 | self.root = root 22 | with open(data_file_name, 'r') as f: 23 | self.data = json.load(f) 24 | self.ids = range(len(self.data)) 25 | self.vocab = vocab 26 | self.transform = transform 27 | self.return_target = return_target 28 | 29 | def __getitem__(self, index): 30 | """Returns one data pair (image and concatenated captions).""" 31 | data = self.data 32 | vocab = self.vocab 33 | id = self.ids[index] 34 | 35 | candidate_asin = data[id]['candidate'] 36 | candidate_img_name = candidate_asin + '.jpg' 37 | candidate_image = Image.open(os.path.join(self.root, candidate_img_name)).convert('RGB') 38 | if self.transform is not None: 39 | candidate_image = self.transform(candidate_image) 40 | 41 | if self.return_target: 42 | target_asin = data[id]['target'] 43 | target_img_name = target_asin + '.jpg' 44 | target_image = Image.open(os.path.join(self.root, target_img_name)).convert('RGB') 45 | if self.transform is not None: 46 | target_image = self.transform(target_image) 47 | else: 48 | target_image = candidate_image 49 | target_asin = '' 50 | 51 | caption_texts = data[id]['captions'] 52 | # Convert caption (string) to word ids. 
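        # The two annotator captions for this pair are tokenized and concatenated around a
        # joining token from the vocabulary (the exact special-token names are whatever
        # build_vocab.py registered), then wrapped with the start and end indices below and
        # converted to a tensor of word ids.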
53 | tokens = nltk.tokenize.word_tokenize(str(caption_texts[0]).lower()) + [''] + \ 54 | nltk.tokenize.word_tokenize(str(caption_texts[1]).lower()) 55 | caption = [] 56 | caption.append(vocab('')) 57 | caption.extend([vocab(token) for token in tokens]) 58 | caption.append(vocab('')) 59 | caption = torch.Tensor(caption) 60 | 61 | return target_image, candidate_image, caption, {'target': target_asin, 'candidate': candidate_asin, 'caption': caption_texts} 62 | 63 | def __len__(self): 64 | return len(self.ids) 65 | 66 | 67 | def collate_fn(data): 68 | """Creates mini-batch tensors from the list of tuples (image, caption). 69 | 70 | Args: 71 | data: list of tuple (image, caption). 72 | - image: torch tensor of shape 73 | - caption: torch tensor of shape (?); variable length. 74 | 75 | Returns: 76 | images: torch tensor of images. 77 | targets: torch tensor of shape (batch_size, padded_length). 78 | lengths: list; valid length for each padded caption. 79 | """ 80 | # Sort a data list by caption length (descending order). 81 | target_images, candidate_images, captions, meta = zip(*data) 82 | 83 | # Merge images (from tuple of 3D tensor to 4D tensor). 84 | target_images = torch.stack(target_images, 0) 85 | candidate_images = torch.stack(candidate_images, 0) 86 | 87 | # Merge captions (from tuple of 1D tensor to 2D tensor). 88 | lengths = [len(cap) for cap in captions] 89 | captions_out = torch.zeros(len(captions), max(lengths)).long() 90 | for i, cap in enumerate(captions): 91 | end = lengths[i] 92 | captions_out[i, :end] = cap[:end] 93 | return target_images, candidate_images, captions_out, lengths, meta 94 | 95 | 96 | def get_loader(root, data_file_name, vocab, transform, batch_size, shuffle, return_target, num_workers): 97 | """Returns torch.utils.data.DataLoader for custom dataset.""" 98 | # relative caption dataset 99 | dataset = Dataset(root=root, 100 | data_file_name=data_file_name, 101 | vocab=vocab, 102 | transform=transform, 103 | return_target=return_target) 104 | 105 | # Data loader for the dataset 106 | # This will return (images, captions, lengths) for each iteration. 107 | # images: a tensor of shape (batch_size, 3, 224, 224). 108 | # captions: a tensor of shape (batch_size, padded_length). 109 | # lengths: a list indicating valid length for each caption. 
length is (batch_size) 110 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 111 | batch_size=batch_size, 112 | shuffle=shuffle, 113 | num_workers=num_workers, 114 | collate_fn=collate_fn, 115 | timeout=60) 116 | 117 | return data_loader 118 | -------------------------------------------------------------------------------- /transformer/interactive_retrieval/UserModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from user_model import load_trained_model 3 | import Vocabulary 4 | from Beam import beam_search, greedy_search 5 | import argparse 6 | 7 | 8 | class UserModel: 9 | def __init__(self, opt, mode='greedy_search'): 10 | self.model = load_trained_model( 11 | opt.user_model_file.format(opt.data_set)) 12 | self.model.eval() 13 | self.vocab = Vocabulary.Vocabulary() 14 | vocab_file = opt.user_vocab_file.format(opt.data_set) 15 | self.vocab.load(vocab_file) 16 | self.max_seq_len = opt.max_seq_len 17 | self.opt = opt 18 | self.decode_mode = mode 19 | return 20 | 21 | def to(self, device): 22 | self.model = self.model.to(device) 23 | 24 | def get_max_seq_len(self): 25 | return self.max_seq_len 26 | 27 | def get_vocab_size(self): 28 | return len(self.vocab) 29 | 30 | def get_caption(self, target_img, candidate_img, 31 | target_attr, candidate_attr, return_cap=False): 32 | pad_idx = self.vocab('') 33 | if self.decode_mode == 'beam_search': 34 | packed_results = [beam_search( 35 | candidate_img[i].unsqueeze(dim=0).unsqueeze(dim=0), 36 | target_img[i].unsqueeze(dim=0).unsqueeze(dim=0), 37 | self.model, self.opt, self.vocab, 38 | candidate_attr[i].unsqueeze(dim=0), 39 | target_attr[i].unsqueeze(dim=0)) 40 | for i in range(target_img.size(0))] 41 | 42 | pad_cap_idx = [] 43 | caps = [] 44 | for cap in packed_results: 45 | caps.append(cap[1]) 46 | if len(cap[0]) > self.max_seq_len: 47 | pad_cap_idx.append(cap[0][:self.max_seq_len]) 48 | else: 49 | pad_cap_idx.append( 50 | cap[0] + [pad_idx] * (self.max_seq_len - len(cap[0]))) 51 | 52 | pad_cap_idx = torch.tensor(pad_cap_idx, dtype=torch.long) 53 | else: 54 | pad_cap_idx, caps = greedy_search(candidate_img.unsqueeze(dim=1), 55 | target_img.unsqueeze(dim=1), 56 | self.model, self.opt, self.vocab, 57 | candidate_attr, target_attr) 58 | if return_cap: 59 | return pad_cap_idx, caps 60 | return pad_cap_idx 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('--image_folder', type=str, 66 | default='../resized_images/{}/') 67 | 68 | parser.add_argument('--pretrained_image_model', type=str, 69 | default='../attribute_prediction/deepfashion_models/' 70 | 'dfattributes_efficientnet_b7ns.pth') 71 | 72 | parser.add_argument('--save', type=str, default='../models/ranker/', 73 | help='path for saving ranker models') 74 | parser.add_argument('--crop_size', type=int, default=224, 75 | help='size for randomly cropping images') 76 | parser.add_argument('--data_set', type=str, default='dress', 77 | help='dress / toptee / shirt') 78 | 79 | parser.add_argument('--rep_type', type=str, default='image', 80 | help='all / side_info / image ') 81 | 82 | parser.add_argument('--merger_type', type=str, default='attention', 83 | help='attention / sum-image / sum-all / sum-other') 84 | 85 | parser.add_argument('--log_step', type=int, default=44, 86 | help='step size for printing log info') 87 | parser.add_argument('--checkpoint', type=int, default=2, 88 | help='step size for saving models') 89 | parser.add_argument('--patient', type=int, default=3, 90 | 
help='patient for reducing learning rate') 91 | 92 | # User model parameters 93 | parser.add_argument('--user_model_file', type=str, 94 | default='../user_modeling/models/' 95 | 'dress-efficientnet-b7.chkpt') 96 | parser.add_argument('--user_vocab_file', type=str, 97 | default='../user_modeling/vocab.json') 98 | parser.add_argument('--max_seq_len', type=int, default=10, 99 | help='maximum caption length') 100 | parser.add_argument('--beam_size', type=int, default=5, 101 | help='beam search branch size') 102 | 103 | # Model parameters 104 | parser.add_argument('--history_input_size', type=int, default=256) 105 | parser.add_argument('--image_embed_size', type=int, default=256) 106 | parser.add_argument('--text_embed_size', type=int, default=256) 107 | parser.add_argument('--vocab_embed_size', type=int, default=256) 108 | 109 | parser.add_argument('--num_workers', type=int, default=4) 110 | parser.add_argument('--num_dialog_turns', type=int, default=5) 111 | parser.add_argument('--margin', type=float, default=0.1) 112 | parser.add_argument('--clip', type=float, default=10) 113 | 114 | parser.add_argument('--no_cuda', action='store_true') 115 | parser.add_argument('--batch_size', type=int, default=2) 116 | parser.add_argument('--learning_rate', type=float, default=0.0003) 117 | 118 | args = parser.parse_args() 119 | 120 | model = UserModel(args) 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /transformer/attribute_prediction/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import torchvision 4 | from torchvision import transforms 5 | import torch 6 | import torch.utils.data as data 7 | import json 8 | from PIL import Image 9 | import tqdm 10 | from efficientnet_pytorch import EfficientNet 11 | 12 | 13 | class Dataset(data.Dataset): 14 | def __init__(self, root, transform=None): 15 | """Set the path for images, captions and vocabulary wrapper. 16 | 17 | Args: 18 | root: image directory. 19 | data_file_name: asin --> [tag] 20 | transform: image transformer. 21 | """ 22 | self.root = root 23 | self.image_list = glob.glob(self.root + "/*") 24 | # print('image list', self.image_list) 25 | 26 | self.transform = transform 27 | return 28 | 29 | def __getitem__(self, index): 30 | """Returns one data pair (image and caption).""" 31 | 32 | img_name = self.image_list[index] 33 | asin = img_name.split('/')[-1] 34 | asin = asin.split('.')[0] 35 | image = Image.open(img_name).convert('RGB') 36 | if self.transform is not None: 37 | image = self.transform(image) 38 | 39 | return image, asin 40 | 41 | def __len__(self): 42 | return len(self.image_list) 43 | 44 | 45 | def collate_fn(data): 46 | """Creates mini-batch tensors from the list of tuples (image, tags). 47 | """ 48 | images, asins = zip(*data) 49 | 50 | # Merge images (from tuple of 3D tensor to 4D tensor). 
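    # Stacking yields a float tensor of shape (batch_size, 3, H, W), assuming the resized
    # images all share the same spatial size; `asins` stays a tuple of image-ID strings
    # parsed from the file names.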
51 | images = torch.stack(images, dim=0) 52 | 53 | return images, asins 54 | 55 | 56 | def get_loader(root, transform, batch_size, shuffle, num_workers=4): 57 | """Returns torch.utils.data.DataLoader for custom coco dataset.""" 58 | # cpu_num = multiprocessing.cpu_count() 59 | # num_workers = cpu_num - 2 if cpu_num > 2 else 1 60 | 61 | dataset = Dataset(root=root, 62 | transform=transform) 63 | 64 | data_loader = torch.utils.data.DataLoader( 65 | dataset=dataset, batch_size=batch_size, shuffle=shuffle, 66 | num_workers=num_workers, collate_fn=collate_fn, pin_memory=True) 67 | 68 | return data_loader 69 | 70 | 71 | def evaluate_model(data_loader, model, idx2attr, device, topk=5): 72 | model.eval() 73 | test_num = 0 74 | prediction = {} 75 | for images, asins in tqdm.tqdm(data_loader): 76 | # Set mini-batch dataset 77 | with torch.no_grad(): 78 | images = images.to(device) 79 | outs = model(images) 80 | top_scores, top_outs = outs.topk(dim=1, k=topk*2, largest=True) 81 | 82 | for j in range(images.size(0)): 83 | top_out_tags = [idx2attr[top_outs[j, m].item()] 84 | for m in range(topk*2)] 85 | 86 | prediction[asins[j]] = { 87 | 'predict': top_out_tags, 88 | 'pred_score': top_scores[j].cpu().numpy().tolist(), 89 | } 90 | 91 | return prediction 92 | 93 | 94 | def evaluate_attributes(args): 95 | device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu' 96 | # Build data loader 97 | # --Image preprocessing, normalization for the pretrained resnet 98 | transform = transforms.Compose([ 99 | transforms.Resize(args.crop_size), 100 | transforms.ToTensor(), 101 | transforms.Normalize((0.485, 0.456, 0.406), 102 | (0.229, 0.224, 0.225))]) 103 | 104 | image_folder = args.image_folder.format(args.data_set) 105 | 106 | with open(args.label_file, 'r') as f: 107 | attr2idx = json.load(f) 108 | 109 | idx2attr = {} 110 | for key, val in attr2idx.items(): 111 | idx2attr[val] = key 112 | 113 | model = EfficientNet.from_pretrained('efficientnet-b7') 114 | ckpt = torch.load(args.pretrained_model, map_location='cpu') 115 | if "model_state" in ckpt: 116 | model.load_state_dict(ckpt["model_state"]) 117 | else: 118 | model.load_state_dict(ckpt) 119 | model.to(device) 120 | 121 | def logging(s, print_=True): 122 | if print_: 123 | print(s) 124 | 125 | data_loader = get_loader( 126 | root=image_folder, 127 | transform=transform, 128 | batch_size=args.batch_size, 129 | shuffle=False) 130 | 131 | model.eval() 132 | logging('-' * 87) 133 | with torch.no_grad(): 134 | prediction = evaluate_model(data_loader, model, idx2attr, device) 135 | 136 | with open('fashion_iq_{}.json'.format(args.data_set), 'w') as f: 137 | json.dump(prediction, f, indent=4) 138 | 139 | 140 | if __name__ == '__main__': 141 | parser = argparse.ArgumentParser() 142 | # data 143 | parser.add_argument('--image_folder', type=str, 144 | default='../resized_images/{}/') 145 | parser.add_argument('--label_file', type=str, 146 | default='data/attribute2idx.json') 147 | parser.add_argument('--crop_size', type=int, default=224, 148 | help='size for randomly cropping images') 149 | parser.add_argument('--data_set', type=str, default='dress', 150 | help='dress / shirt / toptee') 151 | 152 | # model 153 | parser.add_argument('--pretrained_model', type=str, 154 | default='dfattributes_efficientnet_b7ns.pth') 155 | 156 | parser.add_argument('--no_cuda', action='store_true') 157 | parser.add_argument('--batch_size', type=int, default=2) 158 | 159 | args = parser.parse_args() 160 | print(args) 161 | evaluate_attributes(args) 162 | 
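A minimal usage sketch for the attribute-prediction script above, assuming the default folder layout implied by the argparse defaults (the checkpoint and image paths are placeholders; point them at the actual resized-image folder and pretrained EfficientNet weights):

python test.py --data_set dress --batch_size 32 --image_folder ../resized_images/{}/ --pretrained_model dfattributes_efficientnet_b7ns.pth

The predictions are written to fashion_iq_dress.json, keyed by image asin. One way to inspect them (the asin picked below is arbitrary):

import json

with open('fashion_iq_dress.json') as f:
    predictions = json.load(f)

some_asin = next(iter(predictions))
# 'predict' holds the top-scoring attribute names, 'pred_score' their raw scores.
print(predictions[some_asin]['predict'][:5])
print(predictions[some_asin]['pred_score'][:5])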
-------------------------------------------------------------------------------- /transformer/interactive_retrieval/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from torchvision import transforms 4 | from efficientnet_pytorch import EfficientNet 5 | import torch 6 | import tqdm 7 | import math 8 | import numpy as np 9 | # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | 11 | 12 | def get_image_batch(fts, batch_ids): 13 | image = fts['image'][batch_ids] 14 | return image 15 | 16 | 17 | def get_attribute_batch(fts, batch_ids): 18 | attribute = fts['attribute'][batch_ids] 19 | return attribute 20 | 21 | 22 | def extract_features(data_loader, attr_file, attr2idx_file, device, image_model, 23 | attribute_topk=8, batch_size=128): 24 | model = EfficientNet.from_pretrained('efficientnet-b7') 25 | ckpt = torch.load(image_model, map_location='cpu') 26 | print('[INFO] Loading weights from {}'.format(image_model)) 27 | if "model_state" in ckpt: 28 | model.load_state_dict(ckpt["model_state"]) 29 | else: 30 | model.load_state_dict(ckpt) 31 | model = model.to(device) 32 | model = model.eval() 33 | 34 | with open(attr_file, 'r') as f: 35 | predicted_attr = json.load(f) 36 | 37 | with open(attr2idx_file, 'r') as f: 38 | attr2idx = json.load(f) 39 | 40 | num_data = len(data_loader) 41 | num_batch = math.floor(num_data / batch_size) 42 | asins = [] 43 | image_ft = [] 44 | attributes = [] 45 | 46 | def compute_features(data): 47 | with torch.no_grad(): 48 | outs = model.extract_features(data) 49 | outs = model._avg_pooling(outs) 50 | outs = outs.flatten(start_dim=1) 51 | image_ft_batch = model._dropout(outs) 52 | return image_ft_batch 53 | 54 | def compute_attribute_idx(asin_batch): 55 | labels = [predicted_attr[asin[0]]['predict'][:attribute_topk] 56 | for asin in asin_batch] 57 | attribute_idx = [[attr2idx[attr] for attr in label] 58 | for label in labels] 59 | return attribute_idx 60 | 61 | def append_batch(first, last): 62 | batch_ids = torch.tensor( 63 | [j for j in range(first, last)], 64 | dtype=torch.long, device=device) 65 | [data, meta_info] = data_loader.get_items(batch_ids) 66 | data = data.to(device) 67 | image_ft_batch = compute_features(data) 68 | 69 | image_ft.append(image_ft_batch) 70 | asins.extend(meta_info) 71 | attribute_idx = compute_attribute_idx(meta_info) 72 | attributes.extend(attribute_idx) 73 | 74 | for i in tqdm.tqdm(range(num_batch), ascii=True): 75 | append_batch(i * batch_size, (i + 1) * batch_size) 76 | 77 | if num_batch * batch_size < num_data: 78 | append_batch(num_batch * batch_size, num_data) 79 | 80 | image_ft = torch.cat(image_ft, dim=0).to('cpu') 81 | attributes = torch.from_numpy(np.asarray(attributes, dtype=int)) 82 | features = {'asins': asins, 'image': image_ft, 'attribute': attributes} 83 | 84 | return features 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | # 102 | # # 103 | # def extract_features(caption_model, data_loader, save_name): 104 | # # a list of {'asin':[X], 'img_ft': Tensor, 'si_1': Tensor} 105 | # # keep each separate 106 | # # load side info models 107 | # model_path = os.path.join(SI_MODEL.format(args.data_set, 'image')) 108 | # si_image_encoder = ImageEncoder(1024).to(device) 109 | # resnet = si_image_encoder.delete_resnet() 110 | # si_image_encoder.load_state_dict(torch.load(model_path, map_location=device)) 111 | # si_image_encoder.load_resnet(resnet) 112 | # 113 | # model_path = 
os.path.join(SI_MODEL.format(args.data_set, 'text')) 114 | # si_text_decoder = MultiColumnPredictor(1024).to(device) 115 | # si_text_decoder.load_state_dict(torch.load(model_path, map_location=device)) 116 | # si_image_encoder.eval() 117 | # si_text_decoder.eval() 118 | # 119 | # batch_size = 128 120 | # 121 | # num_data = len(data_loader) 122 | # num_batch = math.floor(num_data / batch_size) 123 | # asins = [] 124 | # image_ft = [] 125 | # texture_ft = [] 126 | # fabric_ft = [] 127 | # shape_ft = [] 128 | # part_ft = [] 129 | # style_ft = [] 130 | # for iter in tqdm(range(num_batch), ascii=True): 131 | # # for i in range(2): 132 | # batch_ids = torch.LongTensor([i for i in range(iter * batch_size, (iter + 1) * batch_size)]) 133 | # [data, meta_info] = data_loader.get_items(batch_ids) 134 | # data = data.to(device) 135 | # with torch.no_grad(): 136 | # image_ft_batch = caption_model['image_encoder'](data) 137 | # si_ft_batch = si_image_encoder(data) 138 | # si_ft_batch = si_text_decoder(si_ft_batch) 139 | # 140 | # image_ft.append(image_ft_batch) 141 | # texture_ft.append(si_ft_batch['texture']) 142 | # fabric_ft.append(si_ft_batch['fabric']) 143 | # shape_ft.append(si_ft_batch['shape']) 144 | # part_ft.append(si_ft_batch['part']) 145 | # style_ft.append(si_ft_batch['style']) 146 | # asins.extend(meta_info) 147 | # 148 | # if num_batch * batch_size < num_data: 149 | # batch_ids = torch.LongTensor([i for i in range(num_batch * batch_size, num_data)]) 150 | # [data, meta_info] = data_loader.get_items(batch_ids) 151 | # data = data.to(device) 152 | # with torch.no_grad(): 153 | # image_ft_batch = caption_model['image_encoder'](data) 154 | # si_ft_batch = si_image_encoder(data) 155 | # si_ft_batch = si_text_decoder(si_ft_batch) 156 | # 157 | # image_ft.append(image_ft_batch) 158 | # texture_ft.append(si_ft_batch['texture']) 159 | # fabric_ft.append(si_ft_batch['fabric']) 160 | # shape_ft.append(si_ft_batch['shape']) 161 | # part_ft.append(si_ft_batch['part']) 162 | # style_ft.append(si_ft_batch['style']) 163 | # asins.extend(meta_info) 164 | # 165 | # image_ft = torch.cat(image_ft, dim=0) 166 | # texture_ft = torch.cat(texture_ft, dim=0) 167 | # fabric_ft = torch.cat(fabric_ft, dim=0) 168 | # shape_ft = torch.cat(shape_ft, dim=0) 169 | # part_ft = torch.cat(part_ft, dim=0) 170 | # style_ft = torch.cat(style_ft, dim=0) 171 | # features = {'asins':asins, 'image':image_ft, 'texture':texture_ft, 172 | # 'fabric':fabric_ft, 'part':part_ft, 'shape':shape_ft, 173 | # 'style':style_ft} 174 | # 175 | # torch.save(features, save_name) 176 | # 177 | # return features -------------------------------------------------------------------------------- /transformer/interactive_retrieval/Beam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | import numpy as np 5 | from torch.autograd import Variable 6 | 7 | 8 | Constants_PAD = 0 9 | 10 | 11 | def nopeak_mask(size): 12 | np_mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8') 13 | np_mask = Variable(torch.from_numpy(np_mask) == 0) 14 | 15 | return np_mask 16 | 17 | 18 | def create_masks(trg): 19 | # src_mask = (src != Constants_PAD.unsqueeze(-2) 20 | 21 | if trg is not None: 22 | trg_mask = (trg != Constants_PAD).unsqueeze(-2) 23 | size = trg.size(1) # get seq_len for matrix 24 | np_mask = nopeak_mask(size).to(trg_mask.device) 25 | 26 | trg_mask = trg_mask & np_mask 27 | 28 | else: 29 | trg_mask = None 30 | 31 | return trg_mask 32 | 33 | 34 | def 
init_vars(image0, image1, model, opt, vocab, image0_attribute, 35 | image1_attribute): 36 | init_tok = vocab.word2idx['<start>'] 37 | image0 = model.cnn1(image0) 38 | image1 = model.cnn2(image1) 39 | 40 | if model.add_attribute: 41 | image0_attribute = model.attribute_embedding1(image0_attribute) 42 | image1_attribute = model.attribute_embedding2(image1_attribute) 43 | joint_encoding = model.joint_encoding(image0, image1) 44 | joint_encoding = torch.cat((joint_encoding, image0_attribute), 1) 45 | joint_encoding = torch.cat((joint_encoding, image1_attribute), 1) 46 | else: 47 | joint_encoding = model.joint_encoding(image0, image1) 48 | 49 | e_output = model.encoder(joint_encoding) 50 | outputs = torch.LongTensor([[init_tok]]).to(opt.device) 51 | 52 | trg_mask = nopeak_mask(1).to(opt.device) 53 | out = model.out(model.decoder( 54 | outputs, e_output, trg_mask)) 55 | 56 | out = F.softmax(out, dim=-1) 57 | probs, ix = out[:, -1].data.topk(opt.beam_size) 58 | log_scores = torch.Tensor( 59 | [math.log(prob) for prob in probs.data[0]]).unsqueeze(0) 60 | 61 | outputs = torch.zeros(opt.beam_size, opt.max_seq_len).long().to(opt.device) 62 | outputs[:, 0] = init_tok 63 | outputs[:, 1] = ix[0] 64 | 65 | e_outputs = torch.zeros( 66 | opt.beam_size, e_output.size(-2), e_output.size(-1)).to(opt.device) 67 | e_outputs[:, :] = e_output[0] 68 | 69 | return outputs, e_outputs, log_scores 70 | 71 | 72 | def k_best_outputs(outputs, out, log_scores, i, k): 73 | probs, ix = out[:, -1].data.topk(k) 74 | 75 | log_probs = (torch.Tensor( 76 | [math.log(p) for p in probs.data.view(-1)]).view(k,-1) + 77 | log_scores.transpose(0, 1)) 78 | 79 | k_probs, k_ix = log_probs.view(-1).topk(k) 80 | row = k_ix // k 81 | col = k_ix % k 82 | outputs[:, :i] = outputs[row, :i] 83 | outputs[:, i] = ix[row, col] 84 | log_scores = k_probs.unsqueeze(0) 85 | return outputs, log_scores 86 | 87 | 88 | def beam_search(image0, image1, model, opt, vocab, image0_label, image1_label): 89 | outputs, e_outputs, log_scores = init_vars( 90 | image0, image1, model, opt, vocab, image0_label, image1_label) 91 | eos_tok = vocab.word2idx['<end>'] 92 | ind = None 93 | 94 | for i in range(2, opt.max_seq_len): 95 | trg_mask = nopeak_mask(i).to(opt.device) 96 | out = model.out(model.decoder(outputs[:, :i], e_outputs, trg_mask)) 97 | out = F.softmax(out, dim=-1) 98 | outputs, log_scores = k_best_outputs( 99 | outputs, out, log_scores, i, opt.beam_size) 100 | # Occurrences of end symbols for all input sentences.
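        # `outputs` has shape (beam_size, max_seq_len), so nonzero() returns an (N, 2) tensor
        # of (beam row, token position) index pairs; the loop below keeps only the first
        # end-token position found for each beam.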
101 | ones = (outputs == eos_tok).nonzero() 102 | 103 | sentence_lengths = torch.zeros(len(outputs), dtype=torch.long).to(opt.device) 104 | for vec in ones: 105 | sent_idx = vec[0] 106 | if sentence_lengths[sent_idx] == 0: 107 | # First end symbol has not been found yet 108 | sentence_lengths[sent_idx] = vec[1] # Position of first end symbol 109 | 110 | num_finished_sentences = len([s for s in sentence_lengths if s > 0]) 111 | 112 | if num_finished_sentences == opt.beam_size: 113 | alpha = 0.7 114 | div = 1 / (sentence_lengths.type_as(log_scores) ** alpha) 115 | _, ind = torch.max(log_scores * div, 1) 116 | ind = ind.data[0] 117 | break 118 | 119 | if ind is None: 120 | out_str = ' '.join([vocab.idx2word[str(tok.item())] 121 | for tok in outputs[0][1:]]) 122 | out_idx = [tok.item() for tok in outputs[0][1:]] 123 | return out_idx, out_str 124 | else: 125 | length = (outputs[ind] == eos_tok).nonzero()[0] 126 | 127 | out_str = ' '.join([vocab.idx2word[str(tok.item())] 128 | for tok in outputs[ind][1:length]]) 129 | out_idx = [tok.item() 130 | for tok in outputs[ind][1:length]] 131 | return out_idx, out_str 132 | 133 | 134 | def greedy_search(image0, image1, model, opt, vocab, 135 | image0_label, image1_label): 136 | image0 = model.cnn1(image0) 137 | image1 = model.cnn2(image1) 138 | 139 | if model.add_attribute: 140 | image0_attribute = model.attribute_embedding1(image0_label) 141 | image1_attribute = model.attribute_embedding2(image1_label) 142 | joint_encoding = model.joint_encoding(image0, image1) 143 | joint_encoding = torch.cat((joint_encoding, image0_attribute), 1) 144 | joint_encoding = torch.cat((joint_encoding, image1_attribute), 1) 145 | else: 146 | joint_encoding = model.joint_encoding(image0, image1) 147 | 148 | e_outputs = model.encoder(joint_encoding) 149 | 150 | outputs = torch.from_numpy( 151 | np.zeros((image1.size(0), opt.max_seq_len))).to( 152 | dtype=torch.long, device=opt.device) 153 | 154 | init_tok = vocab.word2idx['<start>'] 155 | 156 | outputs[:, 0] = init_tok 157 | for i in range(1, opt.max_seq_len): 158 | trg_mask = nopeak_mask(i).to(opt.device) 159 | out = model.out(model.decoder(outputs[:, :i], e_outputs, trg_mask)) 160 | probs, ix = out.max(dim=2) 161 | outputs[:, i] = ix[:, -1] 162 | 163 | end_tok = vocab.word2idx['<end>'] 164 | mask = (outputs == end_tok).to(dtype=torch.float).cumsum(dim=1) 165 | # print('mask', mask) 166 | outputs = (outputs * ((mask == 0).to(dtype=torch.int)) + 167 | end_tok * ((mask > 0).to(dtype=torch.int))) 168 | out_str = [' '.join([vocab.idx2word[str(tok.item())] 169 | for tok in outputs[j][0:]]) 170 | for j in range(outputs.size(0))] 171 | 172 | return outputs, out_str 173 | -------------------------------------------------------------------------------- /transformer/user_modeling/test.py: -------------------------------------------------------------------------------- 1 | ''' Evaluate a trained user-model captioner on image pairs and report caption metrics.
''' 2 | 3 | import torch 4 | import torch.utils.data 5 | import argparse 6 | from tqdm import tqdm 7 | import torchvision.transforms as transforms 8 | from nltk.translate.bleu_score import corpus_bleu 9 | from dataset import get_loader_test, load_ori_token_data_new 10 | from build_vocab import Vocabulary 11 | from Models import get_model, create_masks 12 | from Beam import beam_search 13 | from torch.autograd import Variable 14 | import numpy as np 15 | 16 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 17 | from pycocoevalcap.bleu.bleu import Bleu 18 | # from pycocoevalcap.meteor.meteor import Meteor 19 | from pycocoevalcap.rouge.rouge import Rouge 20 | from pycocoevalcap.ciderd.ciderD import CiderD 21 | from pycocoevalcap.cider.cider import Cider 22 | # from pycocoevalcap.spice.spice import Spice 23 | 24 | 25 | # from dataset import collate_fn, TranslationDataset 26 | # from transformer.Translator import Translator 27 | # from preprocess import read_instances_from_file, convert_instance_to_idx_seq 28 | 29 | 30 | 31 | def main(): 32 | '''Main Function''' 33 | 34 | parser = argparse.ArgumentParser(description='test.py') 35 | 36 | parser.add_argument('-pretrained_model', required=True, 37 | help='Path to model .pt file') 38 | parser.add_argument('-data_test', required=True, 39 | help='Path to input file') 40 | parser.add_argument('-vocab', required=True, 41 | help='Path to vocab file') 42 | parser.add_argument('-output', default='pred.txt', 43 | help="""Path to output the predictions (each line will 44 | be the decoded sequence""") 45 | parser.add_argument('-beam_size', type=int, default=5, 46 | help='Beam size') 47 | parser.add_argument('-batch_size', type=int, default=1, 48 | help='Batch size must be 1') 49 | parser.add_argument('-n_best', type=int, default=1, 50 | help="""If verbose is set, will output the n_best 51 | decoded sentences""") 52 | parser.add_argument('-no_cuda', action='store_true') 53 | parser.add_argument('-crop_size', type=int, default=224, help="""crop size""") 54 | parser.add_argument('-max_seq_len', type=int, default=64, help="""seq length""") 55 | parser.add_argument('-attribute_len', type=int, default=5, help="""attribute length""") 56 | 57 | opt = parser.parse_args() 58 | if opt.batch_size != 1: 59 | print("batch size must be 1") 60 | exit() 61 | 62 | opt.cuda = not opt.no_cuda 63 | 64 | opt.device = torch.device('cuda' if opt.cuda else 'cpu') 65 | 66 | # print(opt) 67 | test(opt) 68 | 69 | def test(opt): 70 | 71 | transform = transforms.Compose([ 72 | transforms.CenterCrop(opt.crop_size), 73 | transforms.ToTensor(), 74 | transforms.Normalize((0.485, 0.456, 0.406), 75 | (0.229, 0.224, 0.225))]) 76 | 77 | vocab = Vocabulary() 78 | 79 | vocab.load(opt.vocab) 80 | 81 | data_loader = get_loader_test(opt.data_test, 82 | vocab, transform, 83 | opt.batch_size, shuffle=False, attribute_len=opt.attribute_len) 84 | 85 | list_of_refs = load_ori_token_data_new(opt.data_test) 86 | 87 | model = get_model(opt, load_weights=True) 88 | 89 | count = 0 90 | 91 | hypotheses = {} 92 | 93 | model.eval() 94 | 95 | for batch in tqdm(data_loader, mininterval=2, desc=' - (Test)', leave=False): 96 | 97 | image0, image1, image0_attribute, image1_attribute = map(lambda x: x.to(opt.device), batch) 98 | 99 | hyp = beam_search(image0, image1, model, opt, vocab, image0_attribute, image1_attribute) 100 | # hyp = greedy_search(image1.to(device), image2.to(device), model, opt, vocab) 101 | 102 | hyp = hyp.split('<end>')[0].strip() 103 | 104 | hypotheses[count] = ["it " + hyp] 105 | 106 | 
count += 1 107 | 108 | # ================================================= 109 | # Set up scorers 110 | # ================================================= 111 | print('setting up scorers...') 112 | scorers = [ 113 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 114 | # (Meteor(),"METEOR"), 115 | (Rouge(), "ROUGE_L"), 116 | # (Cider(), "CIDEr"), 117 | (Cider(), "CIDEr"), 118 | (CiderD(), "CIDEr-D") 119 | # (Spice(), "SPICE") 120 | ] 121 | 122 | for scorer, method in scorers: 123 | print('computing %s score...'%(scorer.method())) 124 | score, scores = scorer.compute_score(list_of_refs, hypotheses) 125 | if type(method) == list: 126 | for sc, scs, m in zip(score, scores, method): 127 | # self.setEval(sc, m) 128 | # self.setImgToEvalImgs(scs, gts.keys(), m) 129 | print("%s: %0.3f"%(m, sc)) 130 | else: 131 | # self.setEval(score, method) 132 | # self.setImgToEvalImgs(scores, gts.keys(), method) 133 | print("%s: %0.3f"%(method, score)) 134 | 135 | for i in range(len(hypotheses)): 136 | ref = {i:list_of_refs[i]} 137 | hyp = {i:hypotheses[i]} 138 | print(ref) 139 | print(hyp) 140 | for scorer, method in scorers: 141 | print('computing %s score...'%(scorer.method())) 142 | score, scores = scorer.compute_score(ref, hyp) 143 | if type(method) == list: 144 | for sc, scs, m in zip(score, scores, method): 145 | # self.setEval(sc, m) 146 | # self.setImgToEvalImgs(scs, gts.keys(), m) 147 | print("%s: %0.3f"%(m, sc)) 148 | else: 149 | # self.setEval(score, method) 150 | # self.setImgToEvalImgs(scores, gts.keys(), method) 151 | print("%s: %0.3f"%(method, score)) 152 | 153 | 154 | def greedy_search(image1, image2, model, opt, vocab): 155 | 156 | # Autoregressive inference 157 | embedding_1 = model.cnn(image1)#(1, batch_size, embed_size) 158 | 159 | embedding_2 = model.cnn(image2)#(1, batch_size, embed_size) 160 | 161 | joint_embedding = model.joint_encoding(embedding_1, embedding_2)#(1, batch_size, embed_size) 162 | 163 | e_output = model.encoder(joint_embedding) 164 | 165 | preds_t = torch.LongTensor(np.zeros((image1.size(0), opt.max_seq_len), np.int32)).cuda() 166 | 167 | init_tok = vocab.word2idx[''] 168 | 169 | preds_t[:,0] = init_tok 170 | 171 | for j in range(opt.max_seq_len): 172 | 173 | # _, _preds, _ = model(x_, preds) 174 | 175 | trg_mask = create_masks(preds_t).to(opt.device) 176 | 177 | hidden = model.decoder(preds_t, e_output, trg_mask)#(seq_len, batch_size, hidden) 178 | 179 | logits = model.out(hidden)#(batch_size, seq_len, vocab_size) 180 | 181 | _preds = logits.max(2)[1] #(batch_size, seq_len) 182 | 183 | preds_t[:, j] = _preds.data[:, j] 184 | 185 | preds = preds_t.cpu().numpy() 186 | 187 | return ' '.join([vocab.idx2word[str(tok)] for tok in preds[0]]) 188 | 189 | if __name__ == "__main__": 190 | main() 191 | 192 | 193 | -------------------------------------------------------------------------------- /start_kit/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | from torchvision import transforms 7 | from data_loader import get_loader 8 | from build_vocab import Vocabulary 9 | from models import DummyImageEncoder, DummyCaptionEncoder 10 | from utils import create_exp_dir, Ranker 11 | 12 | # Device configuration 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | # Paths to data 16 | IMAGE_ROOT = 'data/resized_images/{}/' 17 | CAPT = 'data/captions/cap.{}.{}.json' 18 | DICT = 'data/captions/dict.{}.json' 19 | SPLIT = 
'data/image_splits/split.{}.{}.json' 20 | 21 | # Loss function 22 | triplet_avg = nn.TripletMarginLoss(reduction='elementwise_mean', margin=1) 23 | 24 | 25 | def eval_batch(data_loader, image_encoder, caption_encoder, ranker): 26 | ranker.update_emb(image_encoder) 27 | rankings = [] 28 | loss = [] 29 | for i, (target_images, candidate_images, captions, lengths, meta_info) in enumerate(data_loader): 30 | with torch.no_grad(): 31 | target_images = target_images.to(device) 32 | target_ft = image_encoder.forward(target_images) 33 | candidate_images = candidate_images.to(device) 34 | candidate_ft = image_encoder.forward(candidate_images) 35 | captions = captions.to(device) 36 | caption_ft = caption_encoder(captions, lengths) 37 | target_asins = [ meta_info[m]['target'] for m in range(len(meta_info)) ] 38 | rankings.append(ranker.compute_rank(candidate_ft + caption_ft, target_asins)) 39 | m = target_images.size(0) 40 | random_index = [m - 1 - n for n in range(m)] 41 | random_index = torch.LongTensor(random_index) 42 | negative_ft = target_ft[random_index] 43 | loss.append(triplet_avg(anchor=(candidate_ft + caption_ft), 44 | positive=target_ft, negative=negative_ft)) 45 | 46 | metrics = {} 47 | rankings = torch.cat(rankings, dim=0) 48 | metrics['score'] = 1 - rankings.mean().item() / ranker.data_emb.size(0) 49 | metrics['loss'] = torch.stack(loss, dim=0).mean().item() 50 | return metrics 51 | 52 | 53 | def train(args): 54 | # Image preprocessing, normalization for the pretrained resnet 55 | transform = transforms.Compose([ 56 | transforms.RandomCrop(args.crop_size), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.ToTensor(), 59 | transforms.Normalize((0.485, 0.456, 0.406), 60 | (0.229, 0.224, 0.225))]) 61 | 62 | transform_dev = transforms.Compose([ 63 | transforms.CenterCrop(args.crop_size), 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.485, 0.456, 0.406), 66 | (0.229, 0.224, 0.225))]) 67 | 68 | vocab = Vocabulary() 69 | vocab.load(DICT.format(args.data_set)) 70 | 71 | # Build data loader 72 | data_loader = get_loader(IMAGE_ROOT.format(args.data_set), 73 | CAPT.format(args.data_set, 'train'), 74 | vocab, transform, 75 | args.batch_size, shuffle=True, return_target=True, num_workers=args.num_workers) 76 | 77 | data_loader_dev = get_loader(IMAGE_ROOT.format(args.data_set), 78 | CAPT.format(args.data_set, 'val'), 79 | vocab, transform_dev, 80 | args.batch_size, shuffle=False, return_target=True, num_workers=args.num_workers) 81 | 82 | ranker = Ranker(root=IMAGE_ROOT.format(args.data_set), image_split_file=SPLIT.format(args.data_set, 'val'), 83 | transform=transform_dev, num_workers=args.num_workers) 84 | 85 | save_folder = '{}/{}-{}'.format(args.save, args.data_set, time.strftime("%Y%m%d-%H%M%S")) 86 | create_exp_dir(save_folder, scripts_to_save=['models.py', 'data_loader.py', 'train.py', 'build_vocab.py', 'utils.py']) 87 | 88 | def logging(s, print_=True, log_=True): 89 | if print_: 90 | print(s) 91 | if log_: 92 | with open(os.path.join(save_folder, 'log.txt'), 'a+') as f_log: 93 | f_log.write(s + '\n') 94 | 95 | logging(str(args)) 96 | # Build the dummy models 97 | image_encoder = DummyImageEncoder(args.embed_size).to(device) 98 | caption_encoder = DummyCaptionEncoder(vocab_size=len(vocab), vocab_embed_size=args.embed_size * 2, 99 | embed_size=args.embed_size).to(device) 100 | 101 | image_encoder.train() 102 | caption_encoder.train() 103 | params = image_encoder.get_trainable_parameters() + caption_encoder.get_trainable_parameters() 104 | 105 | current_lr = args.learning_rate 
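    # A single Adam optimizer covers the trainable parameters of both encoders; the learning
    # rate starts at args.learning_rate and, in the loop below, is halved after `patient`
    # evaluations without improvement, with training stopping once it falls below one
    # thousandth of the initial value.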
106 | optimizer = torch.optim.Adam(params, lr=current_lr) 107 | 108 | cur_patient = 0 109 | best_score = float('-inf') 110 | stop_train = False 111 | total_step = len(data_loader) 112 | # epoch = 1 for dummy setting 113 | for epoch in range(1): 114 | 115 | for i, (target_images, candidate_images, captions, lengths, meta_info) in enumerate(data_loader): 116 | 117 | target_images = target_images.to(device) 118 | target_ft = image_encoder.forward(target_images) 119 | 120 | candidate_images = candidate_images.to(device) 121 | candidate_ft = image_encoder.forward(candidate_images) 122 | 123 | captions = captions.to(device) 124 | caption_ft = caption_encoder(captions, lengths) 125 | 126 | # random select negative examples 127 | m = target_images.size(0) 128 | random_index = [m - 1 - n for n in range(m)] 129 | random_index = torch.LongTensor(random_index) 130 | negative_ft = target_ft[random_index] 131 | 132 | loss = triplet_avg(anchor=(candidate_ft + caption_ft), 133 | positive=target_ft, negative=negative_ft) 134 | 135 | caption_encoder.zero_grad() 136 | image_encoder.zero_grad() 137 | loss.backward() 138 | optimizer.step() 139 | 140 | if i % args.log_step == 0: 141 | logging( 142 | '| epoch {:3d} | step {:6d}/{:6d} | lr {:06.6f} | train loss {:8.3f}'.format(epoch, i, total_step, 143 | current_lr, 144 | loss.item())) 145 | 146 | image_encoder.eval() 147 | caption_encoder.eval() 148 | logging('-' * 77) 149 | metrics = eval_batch(data_loader_dev, image_encoder, caption_encoder, ranker) 150 | logging('| eval loss: {:8.3f} | score {:8.5f} / {:8.5f} '.format( 151 | metrics['loss'], metrics['score'], best_score)) 152 | logging('-' * 77) 153 | 154 | image_encoder.train() 155 | caption_encoder.train() 156 | 157 | dev_score = metrics['score'] 158 | if dev_score > best_score: 159 | best_score = dev_score 160 | # save best model 161 | resnet = image_encoder.delete_resnet() 162 | torch.save(image_encoder.state_dict(), os.path.join( 163 | save_folder, 164 | 'image-{}.th'.format(args.embed_size))) 165 | image_encoder.load_resnet(resnet) 166 | 167 | torch.save(caption_encoder.state_dict(), os.path.join( 168 | save_folder, 169 | 'cap-{}.th'.format(args.embed_size))) 170 | 171 | cur_patient = 0 172 | else: 173 | cur_patient += 1 174 | if cur_patient >= args.patient: 175 | current_lr /= 2 176 | for param_group in optimizer.param_groups: 177 | param_group['lr'] = current_lr 178 | if current_lr < args.learning_rate * 1e-3: 179 | stop_train = True 180 | break 181 | 182 | if stop_train: 183 | break 184 | logging('best_dev_score: {}'.format(best_score)) 185 | 186 | 187 | if __name__ == '__main__': 188 | parser = argparse.ArgumentParser() 189 | parser.add_argument('--save', type=str, default='models', 190 | help='path for saving trained models') 191 | parser.add_argument('--crop_size', type=int, default=224, 192 | help='size for randomly cropping images') 193 | 194 | parser.add_argument('--data_set', type=str, default='dress') 195 | parser.add_argument('--log_step', type=int, default=3, 196 | help='step size for printing log info') 197 | parser.add_argument('--patient', type=int, default=3, 198 | help='patient for reducing learning rate') 199 | 200 | # Model parameters 201 | parser.add_argument('--embed_size', type=int , default=512, 202 | help='dimension of word embedding vectors') 203 | # Learning parameters 204 | parser.add_argument('--batch_size', type=int, default=2) 205 | parser.add_argument('--num_workers', type=int, default=16) 206 | parser.add_argument('--learning_rate', type=float, default=0.001) 207 | 208 | 
args = parser.parse_args() 209 | 210 | train(args) 211 | 212 | -------------------------------------------------------------------------------- /transformer/attribute_prediction/finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import tqdm 4 | import shutil 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | from torchvision import transforms 10 | import attribute_loader 11 | from efficientnet_pytorch import EfficientNet 12 | 13 | 14 | def create_exp_dir(path, scripts_to_save=None): 15 | if not os.path.exists(path): 16 | os.mkdir(path) 17 | print('Experiment dir : {}'.format(path)) 18 | if scripts_to_save is not None: 19 | if not os.path.exists(os.path.join(path, 'scripts')): 20 | os.mkdir(os.path.join(path, 'scripts')) 21 | for script in scripts_to_save: 22 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 23 | shutil.copyfile(script, dst_file) 24 | return 25 | 26 | 27 | def compute_metric(pred, label, topk=5): 28 | topk_tags = pred.topk(dim=1, k=topk, largest=True)[1] 29 | # compute precision at topk 30 | positive = label.gather(dim=1, index=topk_tags) 31 | p = positive.sum(dim=1) / topk 32 | r = positive.sum(dim=1) / (label.sum(dim=1)+1e-5) 33 | score = 2 * p * r / (p + r + 1e-5) 34 | return score.sum() / label.size(0) 35 | 36 | 37 | def evaluate_model(data_loader, model, loss, device): 38 | model.eval() 39 | test_num = 0 40 | # Update learning rate and create optimizer 41 | error_sum = 0.0 42 | # Training loop 43 | fs_sum = 0.0 44 | for i, (images, labels, _) in enumerate(data_loader): 45 | # Set mini-batch dataset 46 | with torch.no_grad(): 47 | images = images.to(device) 48 | labels = labels.to(device) 49 | outs = model(images) 50 | error = loss(outs, labels) 51 | fs = compute_metric(outs, labels) 52 | 53 | error_sum += error.item() * images.size(0) 54 | fs_sum += fs.item() * images.size(0) 55 | 56 | test_num += images.size(0) 57 | 58 | return error_sum / test_num, fs_sum / test_num 59 | 60 | 61 | def finetune_attributes(args): 62 | device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu' 63 | # Build data loader 64 | # --Image preprocessing, normalization for the pretrained resnet 65 | transform = transforms.Compose([ 66 | transforms.Resize(args.crop_size), 67 | transforms.ToTensor(), 68 | transforms.Normalize((0.485, 0.456, 0.406), 69 | (0.229, 0.224, 0.225))]) 70 | 71 | transform_dev = transforms.Compose([ 72 | transforms.Resize(args.crop_size), 73 | transforms.ToTensor(), 74 | transforms.Normalize((0.485, 0.456, 0.406), 75 | (0.229, 0.224, 0.225))]) 76 | 77 | image_folder = args.image_folder.format(args.data_set) 78 | data_file = args.data_file 79 | train_data_loader = attribute_loader.get_loader( 80 | root=image_folder, 81 | data_file=data_file.format(args.data_set, 'train'), 82 | class_file=args.label_file, 83 | transform=transform, 84 | batch_size=args.batch_size, 85 | shuffle=True) 86 | 87 | val_data_loader = attribute_loader.get_loader( 88 | root=image_folder, 89 | data_file=data_file.format(args.data_set, 'val'), 90 | class_file=args.label_file, 91 | transform=transform_dev, 92 | batch_size=args.batch_size, 93 | shuffle=False) 94 | 95 | # Load the models 96 | model = EfficientNet.from_pretrained('efficientnet-b7') 97 | model_type = 'ft' 98 | ckpt = torch.load(args.pretrained_model, map_location='cpu') 99 | if "model_state" in ckpt: 100 | model.load_state_dict(ckpt["model_state"]) 101 | else: 102 | 
model.load_state_dict(ckpt) 103 | model.to(device) 104 | 105 | # - freeze the bottom part 106 | trainable_parameters = [] 107 | for name, param in model.named_parameters(): 108 | if 'fc' in name: 109 | param.requires_grad = True 110 | trainable_parameters.append(param) 111 | else: 112 | param.requires_grad = False 113 | 114 | for name, param in model.named_parameters(): 115 | if param.requires_grad: 116 | print(name) 117 | model.to(device) 118 | 119 | # Loss and optimizer 120 | current_lr = args.learning_rate 121 | optimizer = torch.optim.Adam(lr=current_lr, params=trainable_parameters) 122 | bce_average = nn.BCEWithLogitsLoss(reduction='mean').to(device) 123 | 124 | # Experiment logging 125 | global_step = 0 126 | cur_patient = 0 127 | best_score = float('-inf') 128 | total_step = len(train_data_loader) 129 | 130 | save_folder = 'logs/{}-{}'.format( 131 | args.data_set, time.strftime("%Y%m%d-%H%M%S")) 132 | create_exp_dir(save_folder, scripts_to_save=[]) 133 | 134 | def logging(s, print_=True, log_=True): 135 | if print_: 136 | print(s) 137 | if log_: 138 | with open(os.path.join(save_folder, 'log.txt'), 'a+') as f_log: 139 | f_log.write(s + '\n') 140 | 141 | for epoch in range(1000): 142 | if global_step % args.checkpoint == 0: 143 | model.eval() 144 | logging('-' * 87) 145 | with torch.no_grad(): 146 | error, fs = evaluate_model( 147 | val_data_loader, model, bce_average, device) 148 | logging( 149 | '| ({}) eval loss: {:8.3f} | score {:8.5f} / {:8.5f}'.format( 150 | epoch, error, fs, best_score)) 151 | logging('-' * 87) 152 | # print(metrics) 153 | dev_score = fs 154 | if dev_score > best_score: 155 | best_score = dev_score 156 | 157 | torch.save(model.state_dict(), os.path.join( 158 | args.save_model.format(args.data_set, model_type))) 159 | else: 160 | cur_patient += 1 161 | if cur_patient >= args.patient: 162 | current_lr /= 2 163 | for param_group in optimizer.param_groups: 164 | param_group['lr'] = current_lr 165 | if current_lr < args.learning_rate * 1e-3: 166 | # stop_train = True 167 | break 168 | cur_patient = 0 169 | 170 | model.train() 171 | for images, labels, _ in tqdm.tqdm( 172 | train_data_loader, 173 | desc="training epoch {}".format(epoch)): 174 | # Set mini-batch dataset 175 | images = images.to(device) 176 | labels = labels.to(device) 177 | outs = model(images) 178 | error = bce_average(outs, labels) 179 | 180 | optimizer.zero_grad() 181 | error.backward() 182 | optimizer.step() 183 | global_step += 1 184 | # Print log info 185 | if global_step % args.log_step == 0: 186 | if global_step >= total_step: 187 | global_step = (global_step - int(global_step / total_step) * 188 | total_step) 189 | 190 | logging('| epoch {:3d} | step {:6d}/{:6d} | ' 191 | 'lr {:06.6f} | train loss {:8.3f}'.format( 192 | epoch, global_step, total_step, current_lr, 193 | error.item())) 194 | 195 | logging('beset_dev_score: {}'.format(best_score)) 196 | 197 | 198 | if __name__ == '__main__': 199 | parser = argparse.ArgumentParser() 200 | # data 201 | parser.add_argument('--image_folder', type=str, 202 | default='../resized_images/{}/') 203 | parser.add_argument('--data_file', type=str, 204 | default='data/asin2attr.{}.{}.json') 205 | parser.add_argument('--label_file', type=str, 206 | default='data/attribute2idx.json') 207 | parser.add_argument('--crop_size', type=int, default=224, 208 | help='size for randomly cropping images') 209 | parser.add_argument('--data_set', type=str, default='dress', 210 | help='dress / shirt / toptee') 211 | 212 | # model 213 | 
parser.add_argument('--pretrained_model', type=str, 214 | default='deepfashion_models/' 215 | 'dfattributes_efficientnet_b7ns.pth') 216 | parser.add_argument('--save_model', type=str, 217 | default='models/attributes_{}_{}.pth', 218 | help='path for saving trained models') 219 | parser.add_argument('--loss', type=str, default='binary', 220 | help='binary / rank') 221 | 222 | parser.add_argument('--log_step', type=int, default=45, 223 | help='step size for printing log info') 224 | parser.add_argument('--checkpoint', type=int, default=1, 225 | help='step size for saving models') 226 | parser.add_argument('--patient', type=int, default=3, 227 | help='patient for reducing learning rate') 228 | 229 | parser.add_argument('--no_cuda', action='store_true') 230 | parser.add_argument('--batch_size', type=int, default=2) 231 | parser.add_argument('--learning_rate', type=float, default=0.001) 232 | 233 | args = parser.parse_args() 234 | print(args) 235 | finetune_attributes(args) 236 | -------------------------------------------------------------------------------- /transformer/interactive_retrieval/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn.init import xavier_uniform_, constant_ 6 | from torch.autograd import Variable 7 | import copy 8 | import numpy as np 9 | 10 | 11 | class PositionalEncoder(nn.Module): 12 | def __init__(self, d_model, max_seq_len=200, dropout=None): 13 | super().__init__() 14 | self.d_model = d_model 15 | if dropout is not None: 16 | self.dropout = nn.Dropout(dropout) 17 | else: 18 | self.dropout = None 19 | # create constant 'pe' matrix with values dependant on 20 | # pos and i 21 | pe = torch.zeros(max_seq_len, d_model) 22 | for pos in range(max_seq_len): 23 | for i in range(0, d_model, 2): 24 | pe[pos, i] = \ 25 | math.sin(pos / (10000 ** ((2 * i) / d_model))) 26 | pe[pos, i + 1] = \ 27 | math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model))) 28 | pe = pe.unsqueeze(0) 29 | self.register_buffer('pe', pe) 30 | 31 | def forward(self, x): 32 | # make embeddings relatively larger 33 | x = x * math.sqrt(self.d_model) 34 | # add constant to embedding 35 | seq_len = x.size(1) 36 | pe = Variable(self.pe[:, :seq_len], requires_grad=False) 37 | if x.is_cuda: 38 | pe.cuda() 39 | x = x + pe 40 | if self.dropout: 41 | return self.dropout(x) 42 | return x 43 | 44 | 45 | class Norm(nn.Module): 46 | def __init__(self, d_model, eps=1e-6, calibrate=True): 47 | super().__init__() 48 | 49 | self.size = d_model 50 | self.calibrate = calibrate 51 | # create two learnable parameters to calibrate normalisation 52 | if self.calibrate: 53 | self.alpha = nn.Parameter(torch.ones(self.size)) 54 | self.bias = nn.Parameter(torch.zeros(self.size)) 55 | 56 | self.eps = eps 57 | 58 | def forward(self, x): 59 | if self.calibrate: 60 | norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \ 61 | / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias 62 | else: 63 | norm = (x - x.mean(dim=-1, keepdim=True)) \ 64 | / (x.std(dim=-1, keepdim=True) + self.eps) 65 | return norm 66 | 67 | 68 | def attention(q, k, v, d_k, mask=None, dropout=None): 69 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) 70 | 71 | if mask is not None: 72 | mask = mask.unsqueeze(1) 73 | scores = scores.masked_fill(mask == 0, -1e9) 74 | 75 | scores = F.softmax(scores, dim=-1) 76 | 77 | if dropout is not None: 78 | scores = dropout(scores) 79 | 80 | output = 
torch.matmul(scores, v) 81 | return output 82 | 83 | 84 | class MultiHeadAttention(nn.Module): 85 | def __init__(self, heads, d_model, dropout): 86 | super().__init__() 87 | 88 | self.d_model = d_model 89 | self.d_k = d_model // heads 90 | self.h = heads 91 | 92 | self.q_linear = nn.Linear(d_model, d_model) 93 | self.v_linear = nn.Linear(d_model, d_model) 94 | self.k_linear = nn.Linear(d_model, d_model) 95 | if dropout is not None: 96 | self.dropout = nn.Dropout(dropout) 97 | else: 98 | self.dropout = None 99 | self.out = nn.Linear(d_model, d_model) 100 | 101 | def forward(self, q, k, v, mask=None): 102 | bs = q.size(0) 103 | 104 | # perform linear operation and split into N heads 105 | k = self.k_linear(k).view(bs, -1, self.h, self.d_k) 106 | q = self.q_linear(q).view(bs, -1, self.h, self.d_k) 107 | v = self.v_linear(v).view(bs, -1, self.h, self.d_k) 108 | 109 | # transpose to get dimensions bs * N * sl * d_model 110 | k = k.transpose(1, 2) 111 | q = q.transpose(1, 2) 112 | v = v.transpose(1, 2) 113 | 114 | # calculate attention using function we will define next 115 | scores = attention(q, k, v, self.d_k, mask, self.dropout) 116 | # concatenate heads and put through final linear layer 117 | concat = (scores.transpose(1, 2).contiguous() 118 | .view(bs, -1, self.d_model)) 119 | output = self.out(concat) 120 | 121 | return output 122 | 123 | 124 | class FeedForward(nn.Module): 125 | def __init__(self, d_model, d_ff=1024): 126 | super().__init__() 127 | # We set d_ff as a default to 2048 128 | self.linear_1 = nn.Linear(d_model, d_ff) 129 | # self.dropout = nn.Dropout(dropout) 130 | self.linear_2 = nn.Linear(d_ff, d_model) 131 | 132 | def forward(self, x): 133 | x = F.relu(self.linear_1(x)) 134 | x = self.linear_2(x) 135 | return x 136 | 137 | 138 | class EncoderLayer(nn.Module): 139 | def __init__(self, d_model, heads, dropout=None): 140 | super().__init__() 141 | self.norm_1 = Norm(d_model) 142 | self.norm_2 = Norm(d_model) 143 | self.attn = MultiHeadAttention(heads, d_model, dropout=dropout) 144 | self.ff = FeedForward(d_model) 145 | 146 | def forward(self, x, mask=None): 147 | x2 = self.norm_1(x) 148 | x = x + self.attn(x2, x2, x2, mask) 149 | x2 = self.norm_2(x) 150 | x = x + self.ff(x2) 151 | return x 152 | 153 | 154 | def get_clones(module, N): 155 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 156 | 157 | 158 | class Encoder(nn.Module): 159 | def __init__(self, d_model, N_layers, heads=8, dropout=None): 160 | super().__init__() 161 | self.N_layers = N_layers 162 | self.pe = PositionalEncoder(d_model, dropout=dropout) 163 | self.layers = get_clones( 164 | EncoderLayer(d_model, heads, dropout), N_layers) 165 | self.norm = Norm(d_model) 166 | 167 | def forward(self, x): 168 | x = self.pe(x) 169 | for i in range(self.N_layers): 170 | x = self.layers[i](x) 171 | return self.norm(x) 172 | 173 | 174 | class RetrieverTransformer(nn.Module): 175 | def __init__(self, vocab_size, word_mat, img_dim, hist_dim, 176 | layer_num, attribute_num): 177 | super().__init__() 178 | 179 | # text part 180 | glove_dim = 300 181 | # glove_dim = len(word_mat[0]) 182 | 183 | self.word_emb = nn.Embedding(vocab_size+1, glove_dim) 184 | print('[INFO] Load glove embedding ({})'.format(glove_dim)) 185 | self.word_emb.weight.data.copy_( 186 | torch.from_numpy(np.asarray(word_mat))) 187 | self.word_emb.weight.requires_grad = False 188 | self.fix_word_emb = True 189 | 190 | self.text_linear = nn.Linear(glove_dim, hist_dim) 191 | self.text_norm = Norm(hist_dim) 192 | 193 | # image part 194 | 
self.img_emb = nn.Linear(img_dim, hist_dim, bias=False) 195 | self.img_norm = Norm(hist_dim) 196 | 197 | # attribute part 198 | self.attr_emb = nn.Embedding(attribute_num, hist_dim) 199 | self.attr_emb.scale_grad_by_freq = True 200 | self.attr_norm = Norm(hist_dim) 201 | 202 | # response encoder 203 | self.tran = Encoder( 204 | d_model=hist_dim, N_layers=layer_num) 205 | self.layer_num = layer_num 206 | 207 | # output part 208 | self.out_linear = nn.Linear(hist_dim, hist_dim, bias=True) 209 | 210 | self.vocab_size = vocab_size+1 211 | self.hist_vectors = [] 212 | 213 | self.sp_token = nn.Parameter( 214 | torch.zeros(size=(1, hist_dim)), requires_grad=False) 215 | 216 | self.hist_dim = hist_dim 217 | self.init_parameters() 218 | 219 | def init_parameters(self): 220 | return 221 | 222 | def init_hist(self): 223 | self.hist_vectors.clear() 224 | return 225 | 226 | def encode_image(self, images): 227 | return self.img_norm(self.img_emb(images)) 228 | 229 | def get_sp_emb(self, batch_size): 230 | 231 | with torch.no_grad(): 232 | sp_emb = self.sp_token.expand( 233 | size=(batch_size, 1, self.hist_dim)) 234 | 235 | sp_emb = self.text_norm(sp_emb) 236 | return sp_emb 237 | 238 | # input: 239 | # text: B x L x V 240 | # image: B x Hi 241 | # hist: B x Hh 242 | def forward(self, text, image, attribute): 243 | # special token 244 | sp_emb = self.get_sp_emb(text.size(0)) 245 | self.hist_vectors.append(sp_emb) 246 | 247 | # text part 248 | # B x L x H 249 | with torch.no_grad(): 250 | text_emb = self.word_emb(text) 251 | 252 | text_emb = self.text_linear(text_emb) 253 | text_emb = self.text_norm(text_emb) 254 | self.hist_vectors.append(text_emb) 255 | 256 | # attribute part 257 | attr_emb = self.attr_emb(attribute) 258 | attr_emb = self.attr_norm(attr_emb) 259 | self.hist_vectors.append(attr_emb) 260 | 261 | # image part 262 | # B x 1 x H 263 | img_emb = self.encode_image(image).unsqueeze(dim=1) 264 | self.hist_vectors.append(img_emb) 265 | 266 | full_input = torch.cat(self.hist_vectors, dim=1) 267 | outs = self.tran(full_input) 268 | outs = self.out_linear(F.relu(outs.mean(dim=1))) 269 | 270 | return outs 271 | 272 | def convert_onehot(self, text): 273 | B, L = text.size(0), text.size(1) 274 | onehot = torch.zeros(B * L, self.vocab_size).to(text.device) 275 | onehot.scatter_(1, text.view(-1, 1), 1) 276 | onehot = onehot.view(B, L, self.vocab_size) 277 | return onehot 278 | 279 | -------------------------------------------------------------------------------- /transformer/user_modeling/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import torch.utils.data as data 4 | import os 5 | import nltk 6 | from PIL import Image 7 | import json 8 | import numpy as np 9 | 10 | SEQ_LEN = None 11 | class Dataset(data.Dataset): 12 | 13 | def __init__(self, data_file_name, vocab, transform=None, max_seq_len=64, label_len=5): 14 | """Set the path for images, captions and vocabulary wrapper. 15 | 16 | Args: 17 | root: image directory. 18 | data: index file name. 19 | transform: image transformer. 20 | vocab: pre-processed vocabulary. 
21 | """ 22 | # self.root = root 23 | with open(data_file_name, 'r') as f: 24 | self.data = json.load(f) 25 | self.ids = range(len(self.data)) 26 | self.vocab = vocab 27 | self.transform = transform 28 | # self.return_target = return_target 29 | self.seq_len = max_seq_len 30 | SEQ_LEN = max_seq_len 31 | self.label_len = label_len 32 | 33 | def __getitem__(self, index): 34 | """Returns one data pair (image and concatenated captions).""" 35 | data = self.data 36 | vocab = self.vocab 37 | id = self.ids[index] 38 | 39 | image0 = data[id]['image0'] 40 | image0 = Image.open(os.path.join(image0)).convert('RGB') 41 | if self.transform is not None: 42 | image0 = self.transform(image0) 43 | 44 | image1 = data[id]['image1'] 45 | image1 = Image.open(os.path.join(image1)).convert('RGB') 46 | if self.transform is not None: 47 | image1 = self.transform(image1) 48 | 49 | caption = [] 50 | caption_texts = data[id]['captions'] 51 | # Convert caption (string) to word ids. 52 | tokens = nltk.tokenize.word_tokenize(str(caption_texts).lower()) 53 | 54 | if len(tokens) >= self.seq_len: 55 | tokens = tokens[:self.seq_len] 56 | 57 | caption.append(vocab('')) 58 | caption.extend([vocab(token) for token in tokens]) 59 | caption.append(vocab('')) 60 | caption = torch.Tensor(caption) 61 | 62 | # image0_label = torch.Tensor(data[id]['image0_label'][:self.label_len]).long() 63 | # image1_label = torch.Tensor(data[id]['image1_label'][:self.label_len]).long() 64 | image0_label = torch.Tensor(data[id]['image0_full_score']).float() 65 | image1_label = torch.Tensor(data[id]['image1_full_score']).float() 66 | 67 | return image0, image1, caption, image0_label, image1_label 68 | 69 | def __len__(self): 70 | return len(self.ids) 71 | 72 | 73 | class Dataset_fastrcnn(data.Dataset): 74 | 75 | def __init__(self, data_file_name, vocab, transform=None, max_seq_len=64): 76 | """Set the path for images, captions and vocabulary wrapper. 77 | 78 | Args: 79 | root: image directory. 80 | data: index file name. 81 | transform: image transformer. 82 | vocab: pre-processed vocabulary. 83 | """ 84 | # self.root = root 85 | with open(data_file_name, 'r') as f: 86 | self.data = json.load(f) 87 | self.ids = range(len(self.data)) 88 | self.vocab = vocab 89 | self.transform = transform 90 | # self.return_target = return_target 91 | self.seq_len = max_seq_len 92 | SEQ_LEN = max_seq_len 93 | 94 | def __getitem__(self, index): 95 | """Returns one data pair (image and concatenated captions).""" 96 | data = self.data 97 | vocab = self.vocab 98 | id = self.ids[index] 99 | 100 | image_0 = data[id]['image0'] 101 | # image_0 = Image.open(os.path.join(image_0)).convert('RGB') 102 | image_0 = np.load(image_0, allow_pickle=True) 103 | # if self.transform is not None: 104 | # image_0 = self.transform(image_0) 105 | 106 | image_1 = data[id]['image1'] 107 | image_1 = np.load(image_1, allow_pickle=True) 108 | # image_1 = Image.open(os.path.join(image_1)).convert('RGB') 109 | # if self.transform is not None: 110 | # image_1 = self.transform(image_1) 111 | 112 | caption = [] 113 | caption_texts = data[id]['captions'] 114 | # Convert caption (string) to word ids. 
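        # Illustrative sketch of the conversion performed below; the token ids
        # are hypothetical and depend on the vocabulary built by build_vocab.py:
        #   "is shorter and more colorful"
        #   -> ['is', 'shorter', 'and', 'more', 'colorful']      (word_tokenize)
        #   -> [start_id, 12, 48, 7, 31, 95, end_id]             (vocab lookup)
        #   -> torch.Tensor([...])                               (caption tensor)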
115 | tokens = nltk.tokenize.word_tokenize(str(caption_texts).lower()) 116 | 117 | if len(tokens) >= self.seq_len: 118 | tokens = tokens[:self.seq_len] 119 | 120 | caption.append(vocab('')) 121 | caption.extend([vocab(token) for token in tokens]) 122 | caption.append(vocab('')) 123 | caption = torch.Tensor(caption) 124 | 125 | return image_0, image_1, caption 126 | 127 | def __len__(self): 128 | return len(self.ids) 129 | 130 | 131 | 132 | 133 | 134 | def collate_fn(data): 135 | """Creates mini-batch tensors from the list of tuples (image, caption). 136 | 137 | Args: 138 | data: list of tuple (image, caption). 139 | - image: torch tensor of shape 140 | - caption: torch tensor of shape (?); variable length. 141 | 142 | Returns: 143 | images: torch tensor of images. 144 | targets: torch tensor of shape (batch_size, padded_length). 145 | lengths: list; valid length for each padded caption. 146 | """ 147 | # Sort a data list by caption length (descending order). 148 | image0, image1, captions, image0_label, image1_label = zip(*data) 149 | 150 | 151 | # Merge images (from tuple of 3D tensor to 4D tensor). 152 | image0 = torch.stack(image0, 0) 153 | image1 = torch.stack(image1, 0) 154 | 155 | image0_label = torch.stack(image0_label, 0) 156 | image1_label = torch.stack(image1_label, 0) 157 | # Merge captions (from tuple of 1D tensor to 2D tensor). 158 | lengths = [len(cap) for cap in captions] 159 | 160 | captions_src = torch.zeros(len(captions), max(lengths)).long() 161 | captions_tgt = torch.zeros(len(captions), max(lengths)).long() 162 | for i, cap in enumerate(captions): 163 | end = lengths[i] 164 | captions_src[i, :end-1] = cap[:end-1] 165 | captions_tgt[i, :end-1] = cap[1:end] 166 | # caption_padding_mask = (captions_src != 0) 167 | return image0, image1, captions_src, captions_tgt, image0_label, image1_label 168 | 169 | 170 | def collate_fn_test(data): 171 | """Creates mini-batch tensors from the list of tuples (image, caption). 172 | Args: 173 | data: list of tuple (image, caption). 174 | - image: torch tensor of shape 175 | - caption: torch tensor of shape (?); variable length. 176 | Returns: 177 | images: torch tensor of images. 178 | targets: torch tensor of shape (batch_size, padded_length). 179 | lengths: list; valid length for each padded caption. 180 | """ 181 | # Sort a data list by caption length (descending order). 182 | image0, image1, _, image0_label, image1_label = zip(*data) 183 | # Merge images (from tuple of 3D tensor to 4D tensor). 184 | image0 = torch.stack(image0, 0) 185 | image1 = torch.stack(image1, 0) 186 | 187 | image0_label = torch.stack(image0_label, 0) 188 | image1_label = torch.stack(image1_label, 0) 189 | # # Merge captions (from tuple of 1D tensor to 2D tensor). 
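    # (left commented out: collate_fn_test returns only the stacked images
    # and attribute labels, so captions are not padded here)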
190 | # lengths = [len(cap) for cap in captions] 191 | # captions_src = torch.zeros(len(captions), max(lengths)).long() 192 | # captions_tgt = torch.zeros(len(captions), max(lengths)).long() 193 | # for i, cap in enumerate(captions): 194 | # end = lengths[i] 195 | # captions_src[i, :end-1] = cap[:end-1] 196 | # captions_tgt[i, :end-1] = cap[1:end] 197 | # # caption_padding_mask = (captions_src != 0) 198 | # return target_images, candidate_images, captions_src, captions_tgt 199 | return image0, image1, image0_label, image1_label 200 | 201 | 202 | def get_loader(data_file_path, vocab, transform, batch_size, shuffle, num_workers=1,max_seq_len=64, attribute_len=5): 203 | """Returns torch.utils.data.DataLoader for custom dataset.""" 204 | 205 | # relative caption dataset 206 | print('Reading data from',data_file_path) 207 | dataset = Dataset( 208 | data_file_name=data_file_path, 209 | vocab=vocab, 210 | transform=transform, 211 | max_seq_len=max_seq_len, 212 | label_len=attribute_len 213 | ) 214 | print('data size',len(dataset)) 215 | # Data loader for the dataset 216 | # This will return (images, captions, lengths) for each iteration. 217 | # images: a tensor of shape (batch_size, 3, 224, 224). 218 | # captions: a tensor of shape (batch_size, padded_length). 219 | # lengths: a list indicating valid length for each caption. length is (batch_size) 220 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 221 | batch_size=batch_size, 222 | shuffle=shuffle, 223 | num_workers=num_workers, 224 | collate_fn=collate_fn, 225 | timeout=60) 226 | 227 | return data_loader 228 | 229 | def get_loader_test(data_file_path, vocab, transform, batch_size, shuffle, num_workers=1,max_seq_len=64,attribute_len=5): 230 | """Returns torch.utils.data.DataLoader for custom dataset.""" 231 | # relative caption dataset 232 | print('Reading data from',data_file_path) 233 | dataset = Dataset( 234 | data_file_name=data_file_path, 235 | vocab=vocab, 236 | transform=transform, 237 | max_seq_len=max_seq_len, 238 | label_len=attribute_len 239 | ) 240 | print('data size',len(dataset)) 241 | # Data loader for the dataset 242 | # This will return (images, captions, lengths) for each iteration. 243 | # images: a tensor of shape (batch_size, 3, 224, 224). 244 | # captions: a tensor of shape (batch_size, padded_length). 245 | # lengths: a list indicating valid length for each caption. 
length is (batch_size) 246 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 247 | batch_size=1, 248 | shuffle=shuffle, 249 | num_workers=num_workers, 250 | collate_fn=collate_fn_test, 251 | timeout=60) 252 | return data_loader 253 | 254 | def load_ori_token_data(data_file_name): 255 | test_data_captions = [] 256 | with open(data_file_name, 'r') as f: 257 | data = json.load(f) 258 | 259 | for line in data: 260 | caption_texts = line['captions'] 261 | temp = [] 262 | for c in caption_texts: 263 | # tokens = nltk.tokenize.word_tokenize(str(c).lower()) 264 | temp.append(c) 265 | test_data_captions.append(temp) 266 | 267 | 268 | return test_data_captions 269 | 270 | 271 | def load_ori_token_data_new(data_file_name): 272 | test_data_captions = {} 273 | with open(data_file_name, 'r') as f: 274 | data = json.load(f) 275 | count = 0 276 | for line in data: 277 | caption_texts = line['captions'] 278 | caption_texts = ["it " + x for x in caption_texts] 279 | # temp = [] 280 | # for c in caption_texts: 281 | # # tokens = nltk.tokenize.word_tokenize(str(c).lower()) 282 | # temp.append(c) 283 | test_data_captions[count] = caption_texts 284 | count += 1 285 | 286 | return test_data_captions 287 | -------------------------------------------------------------------------------- /transformer/user_modeling/preprocess_fashionIQ.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import json 3 | import argparse 4 | import csv 5 | import os 6 | 7 | 8 | 9 | 10 | def preprocess_data(file_prefix, topic): 11 | """load tsv_file and save image1_path and image2_path and caption in json and save it. 12 | 13 | Args: 14 | tsv_file_path 15 | 16 | Returns: 17 | None 18 | """ 19 | IMAGE_ROOT = 'resized_images/{}/' 20 | CAPT = 'caption/pair2cap.{}.{}.json' 21 | LABEL = 'attribute_prediction/attributes/{}_attribute_best.pth_{}_attributes.json' 22 | # DICT = 'attribute_prediction/attribtue2idx.json' 23 | # SPLIT = 'image_splits/split.{}.{}.json' 24 | 25 | #training file 26 | training_json = json.load(open(file_prefix + CAPT.format(topic, 'train'))) 27 | dev_json = json.load(open(file_prefix + CAPT.format(topic, 'val'))) 28 | test_json = json.load(open(file_prefix + CAPT.format(topic, 'test'))) 29 | image_path = file_prefix + IMAGE_ROOT.format(topic) 30 | data_train = [] 31 | data_dev = [] 32 | data_test = [] 33 | data_dev_combine = [] 34 | data_test_combine = [] 35 | 36 | 37 | label_path = LABEL.format(topic, topic) 38 | label_json = json.load(open(label_path)) 39 | label_dic = {} 40 | 41 | for line in label_json: 42 | label_dic[line['image']] = {} 43 | label_dic[line['image']]["id"] = line["predict_id"] 44 | label_dic[line['image']]["pred"] = line["prediction"] 45 | 46 | for entry in training_json: 47 | 48 | image0 = image_path + entry["candidate"] + '.jpg' 49 | 50 | image0_label = label_dic[entry["candidate"]+ '.jpg']["id"] 51 | 52 | image1 = image_path + entry["target"] + ".jpg" 53 | 54 | image1_label = label_dic[entry["target"]+ '.jpg']["id"] 55 | 56 | caps = entry["captions"] 57 | 58 | for cap in caps: 59 | 60 | data_train.append({'image0': image0, 'image1':image1, "captions":cap, \ 61 | "image0_label":image0_label, 62 | "image1_label":image1_label, 63 | }) 64 | 65 | for entry in dev_json: 66 | image0 = image_path + entry["candidate"] + '.jpg' 67 | image0_label = label_dic[entry["candidate"]+ '.jpg']["id"] 68 | image1 = image_path + entry["target"] + ".jpg" 69 | image1_label = label_dic[entry["target"]+ '.jpg']["id"] 70 | caps = entry["captions"] 71 | 72 | 
for cap in caps: 73 | 74 | data_dev.append({'image0': image0, 'image1':image1, "captions":cap,\ 75 | "image0_label":image0_label, 76 | "image1_label":image1_label, 77 | }) 78 | 79 | data_dev_combine.append({'image0': image0, 'image1':image1, "captions":caps,\ 80 | "image0_label":image0_label, 81 | "image1_label":image1_label, 82 | }) 83 | 84 | for entry in test_json: 85 | image0 = image_path + entry["candidate"] + '.jpg' 86 | image0_label = label_dic[entry["candidate"]+ '.jpg']["id"] 87 | image1 = image_path + entry["target"] + ".jpg" 88 | image1_label = label_dic[entry["target"]+ '.jpg']["id"] 89 | caps = entry["captions"] 90 | 91 | 92 | for cap in caps: 93 | 94 | data_test.append({'image0': image0, 'image1':image1, "captions":cap, \ 95 | "image0_label":image0_label, 96 | "image1_label":image1_label, 97 | }) 98 | 99 | data_test_combine.append({'image0': image0, 'image1':image1, "captions":caps, 100 | "image0_label":image0_label, 101 | "image1_label":image1_label, 102 | 103 | }) 104 | 105 | file_prefix += "/" + topic 106 | 107 | if not os.path.exists(file_prefix): 108 | os.makedirs(file_prefix) 109 | 110 | with open(file_prefix + '/data_train.json', 'w') as outfile: 111 | json.dump(data_train, outfile) 112 | 113 | with open(file_prefix + '/data_dev.json', 'w') as outfile: 114 | json.dump(data_dev, outfile) 115 | 116 | with open(file_prefix + '/data_test.json', 'w') as outfile: 117 | json.dump(data_test, outfile) 118 | 119 | with open(file_prefix + '/data_dev_combine.json', 'w') as outfile: 120 | json.dump(data_dev_combine, outfile) 121 | 122 | with open(file_prefix + '/data_test_combine.json', 'w') as outfile: 123 | json.dump(data_test_combine, outfile) 124 | 125 | 126 | 127 | 128 | 129 | 130 | def parse_url(url): 131 | # print('url', url) 132 | tokens = url.split('/') 133 | # print(tokens) 134 | folder = tokens[4] 135 | tokens = tokens[5].split('?') 136 | tokens.reverse() 137 | file = '.'.join(tokens) 138 | # print(tokens[1]) 139 | # print(tokens) 140 | # if len(tokens) > 1: 141 | # file = tokens[1] 142 | # else: 143 | # file = 'null' 144 | # print(tokens[4], tokens[5]) 145 | # print(folder, file) 146 | return '/dccstor/extrastore/Neural-Naturalist/data/resized_images/' + folder + '.' 
+ file 147 | 148 | 149 | def preprocess_data_from_xiaoxiao(file_prefix, topic): 150 | IMAGE_ROOT = 'resized_images/{}/' 151 | CAPT = 'caption/pair2cap.{}.{}.json' 152 | LABEL = 'data/fashion-IQ/predicted_attributes/{}_attribute_best.pth_{}_attributes.json' 153 | DICT = 'attribute_prediction/attribute2idx.json' 154 | # SPLIT = 'image_splits/split.{}.{}.json' 155 | 156 | training_json = json.load(open(file_prefix + CAPT.format(topic, 'train'))) 157 | dev_json = json.load(open(file_prefix + CAPT.format(topic, 'val'))) 158 | test_json = json.load(open(file_prefix + CAPT.format(topic, 'test'))) 159 | image_path = file_prefix + IMAGE_ROOT.format(topic) 160 | data_train = [] 161 | data_dev = [] 162 | data_test = [] 163 | data_dev_combine = [] 164 | data_test_combine = [] 165 | 166 | 167 | label_predict = LABEL.format(topic, topic) 168 | label_predict = json.load(open(label_predict)) 169 | 170 | 171 | label_dic = json.load(open(DICT)) 172 | 173 | # for line in label_json: 174 | # label_dic[line['image']] = {} 175 | # label_dic[line['image']]["id"] = line["predict_id"] 176 | # label_dic[line['image']]["pred"] = line["prediction"] 177 | 178 | for entry in training_json: 179 | 180 | image0 = image_path + entry["candidate"] + '.jpg' 181 | 182 | image0_label = [ label_dic[x] for x in label_predict[entry["candidate"]]["predict"]] 183 | image0_full_score = label_predict[entry["candidate"]]["full_predict"] 184 | 185 | image1 = image_path + entry["target"] + ".jpg" 186 | 187 | image1_label = [ label_dic[x] for x in label_predict[entry["target"]]["predict"]] 188 | image1_full_score = label_predict[entry["target"]]["full_predict"] 189 | 190 | caps = entry["captions"] 191 | 192 | for cap in caps: 193 | 194 | data_train.append({'image0': image0, 'image1':image1, "captions":cap, \ 195 | "image0_label":image0_label, 196 | "image1_label":image1_label, 197 | "image0_full_score":image0_full_score, 198 | "image1_full_score":image1_full_score 199 | }) 200 | 201 | for entry in dev_json: 202 | image0 = image_path + entry["candidate"] + '.jpg' 203 | image0_label = [ label_dic[x] for x in label_predict[entry["candidate"]]["predict"]] 204 | image1 = image_path + entry["target"] + ".jpg" 205 | image1_label = [ label_dic[x] for x in label_predict[entry["target"]]["predict"]] 206 | caps = entry["captions"] 207 | image0_full_score = label_predict[entry["candidate"]]["full_predict"] 208 | image1_full_score = label_predict[entry["target"]]["full_predict"] 209 | for cap in caps: 210 | 211 | data_dev.append({'image0': image0, 'image1':image1, "captions":cap,\ 212 | "image0_label":image0_label, 213 | "image1_label":image1_label, 214 | "image0_full_score":image0_full_score, 215 | "image1_full_score":image1_full_score 216 | }) 217 | 218 | data_dev_combine.append({'image0': image0, 'image1':image1, "captions":caps,\ 219 | "image0_label":image0_label, 220 | "image1_label":image1_label, 221 | "image0_full_score":image0_full_score, 222 | "image1_full_score":image1_full_score 223 | }) 224 | 225 | for entry in test_json: 226 | image0 = image_path + entry["candidate"] + '.jpg' 227 | image0_label = [ label_dic[x] for x in label_predict[entry["candidate"]]["predict"]] 228 | image1 = image_path + entry["target"] + ".jpg" 229 | image1_label = [ label_dic[x] for x in label_predict[entry["target"]]["predict"]] 230 | caps = entry["captions"] 231 | image0_full_score = label_predict[entry["candidate"]]["full_predict"] 232 | image1_full_score = label_predict[entry["target"]]["full_predict"] 233 | 234 | 235 | for cap in caps: 236 | 237 | 
data_test.append({'image0': image0, 'image1':image1, "captions":cap, \ 238 | "image0_label":image0_label, 239 | "image1_label":image1_label, 240 | "image0_full_score":image0_full_score, 241 | "image1_full_score":image1_full_score 242 | }) 243 | 244 | data_test_combine.append({'image0': image0, 'image1':image1, "captions":caps, 245 | "image0_label":image0_label, 246 | "image1_label":image1_label, 247 | "image0_full_score":image0_full_score, 248 | "image1_full_score":image1_full_score 249 | 250 | }) 251 | 252 | file_prefix += "/" + topic 253 | 254 | if not os.path.exists(file_prefix): 255 | os.makedirs(file_prefix) 256 | 257 | with open(file_prefix + '/data_train.json', 'w') as outfile: 258 | json.dump(data_train, outfile) 259 | 260 | with open(file_prefix + '/data_dev.json', 'w') as outfile: 261 | json.dump(data_dev, outfile) 262 | 263 | with open(file_prefix + '/data_test.json', 'w') as outfile: 264 | json.dump(data_test, outfile) 265 | 266 | with open(file_prefix + '/data_dev_combine.json', 'w') as outfile: 267 | json.dump(data_dev_combine, outfile) 268 | 269 | with open(file_prefix + '/data_test_combine.json', 'w') as outfile: 270 | json.dump(data_test_combine, outfile) 271 | 272 | def process_url(url): 273 | file = parse_url(url) 274 | if file[-1] == '.': 275 | file = file + 'jpg' 276 | # make_folder(folder) 277 | 278 | # if not os.path.isfile(file): 279 | # with open(file, 'wb') as f: 280 | # resp = requests.get(url, verify=False) 281 | # f.write(resp.content) 282 | # f.close() 283 | return file 284 | 285 | def main(args): 286 | ''' Main function ''' 287 | 288 | 289 | 290 | preprocess_data_from_xiaoxiao(args.data_prefix, 'dress') 291 | preprocess_data_from_xiaoxiao(args.data_prefix, 'shirt') 292 | preprocess_data_from_xiaoxiao(args.data_prefix, "toptee") 293 | 294 | 295 | if __name__ == '__main__': 296 | 297 | parser = argparse.ArgumentParser() 298 | parser.add_argument('-data_prefix', required=True) 299 | 300 | args = parser.parse_args() 301 | main(args) -------------------------------------------------------------------------------- /transformer/interactive_retrieval/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import shutil 4 | import time 5 | import json 6 | import os 7 | import tqdm 8 | import torch 9 | import torch.nn as nn 10 | from torchvision import transforms 11 | from data_loader import Dataset 12 | import utils 13 | import Ranker 14 | import UserModel 15 | 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | 20 | def create_exp_dir(path, scripts_to_save=None): 21 | if not os.path.exists(path): 22 | os.mkdir(path) 23 | print('Experiment dir : {}'.format(path)) 24 | if scripts_to_save is not None: 25 | if not os.path.exists(os.path.join(path, 'scripts')): 26 | os.mkdir(os.path.join(path, 'scripts')) 27 | for script in scripts_to_save: 28 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 29 | shutil.copyfile(script, dst_file) 30 | return 31 | 32 | 33 | def load_test_image_features(args): 34 | # Image preprocessing, normalization for the pretrained b7 35 | transform = transforms.Compose([ 36 | transforms.Resize(args.crop_size), 37 | transforms.ToTensor(), 38 | transforms.Normalize((0.485, 0.456, 0.406), 39 | (0.229, 0.224, 0.225))]) 40 | 41 | # test split 42 | fts_file = os.path.join(args.save, 'b7_v2.{}.{}.pair.th'.format( 43 | args.data_set, 'test')) 44 | if os.path.isfile(fts_file): 45 | print('[INFO] loading image features: 
{}'.format(fts_file)) 46 | fts = torch.load(fts_file, map_location='cpu') 47 | else: 48 | print('[INFO] computing image features: {}'.format(fts_file)) 49 | data_loader = Dataset( 50 | args.image_folder.format(args.data_set), 51 | args.data_split_file.format(args.data_set, 'test'), 52 | transform, num_workers=args.num_workers) 53 | attr_file = args.attribute_file.format(args.data_set) 54 | 55 | fts = utils.extract_features(data_loader, attr_file, 56 | args.attr2idx_file, device, 57 | args.image_model) 58 | torch.save(fts, fts_file) 59 | 60 | return fts 61 | 62 | 63 | def eval_batch(fts, captioner, retriever, args): 64 | criterion = nn.TripletMarginLoss( 65 | reduction='mean', margin=args.margin).to(device) 66 | # generate a mapping for dev, to ensure sampling bias is reduced 67 | num_target = len(fts['asins']) 68 | 69 | batch_size = args.batch_size 70 | ranker = Ranker.Ranker(device) 71 | total_step = math.floor(num_target / batch_size) 72 | 73 | ranking_tracker = [0] * args.num_dialog_turns 74 | loss_tracker = [0] * args.num_dialog_turns 75 | 76 | with open('data/shuffled.{}.{}.json'.format( 77 | args.data_set, 'test')) as f: 78 | first_candidate_set = json.load(f) 79 | 80 | with torch.no_grad(): 81 | retriever.eval() 82 | ranker.update_emb(fts, args.batch_size, retriever) 83 | 84 | retriever.eval() 85 | ret_results = {} 86 | total_time = 0 87 | 88 | for step in tqdm.tqdm(range(total_step)): 89 | # sample target 90 | target_ids = torch.tensor( 91 | [i for i in 92 | range(step * batch_size, (step + 1) * batch_size)]).to( 93 | device=device, dtype=torch.long) 94 | 95 | # sample first batch of candidates 96 | candidate_ids = torch.tensor( 97 | [first_candidate_set[i] 98 | for i in range(step * batch_size, (step + 1) * batch_size)], 99 | device=device, dtype=torch.long) 100 | 101 | # keep track of results 102 | ret_result = {} 103 | for batch_id in range(target_ids.size(0)): 104 | idx = target_ids[batch_id].cpu().item() 105 | ret_result[idx] = {} 106 | ret_result[idx]['candidate'] = [] 107 | ret_result[idx]['ranking'] = [] 108 | ret_result[idx]['caption'] = [] 109 | 110 | target_img_ft = utils.get_image_batch(fts, target_ids) 111 | target_img_ft = target_img_ft.to(device) 112 | target_img_emb = retriever.encode_image(target_img_ft) 113 | 114 | target_attr = utils.get_attribute_batch(fts, target_ids) 115 | target_attr = target_attr.to(device) 116 | 117 | # clean up dialog history tracker 118 | retriever.init_hist() 119 | # history_hidden = history_hidden.expand_as(target_img_emb) 120 | 121 | loss = 0 122 | 123 | for d_turn in range(args.num_dialog_turns): 124 | last_timer = int(round(time.time() * 1000)) 125 | # get candidate image features 126 | candidate_img_ft = utils.get_image_batch(fts, candidate_ids) 127 | candidate_img_ft = candidate_img_ft.to(device) 128 | 129 | candidate_attr = utils.get_attribute_batch(fts, candidate_ids) 130 | candidate_attr = candidate_attr.to(device) 131 | # generate captions from model 132 | total_time += (int(round(time.time() * 1000)) - last_timer) 133 | with torch.no_grad(): 134 | sentence_ids, caps = captioner.get_caption( 135 | target_img_ft, candidate_img_ft, 136 | target_attr, candidate_attr, return_cap=True) 137 | last_timer = int(round(time.time() * 1000)) 138 | sentence_ids = sentence_ids.to(device) 139 | 140 | candidate_img_ft = candidate_img_ft.to(device) 141 | 142 | history_hidden = retriever.forward( 143 | text=sentence_ids, image=candidate_img_ft, 144 | attribute=candidate_attr) 145 | 146 | # sample negatives, update tracker's output to 147 | # 
match targets via triplet loss 148 | negative_ids = torch.tensor( 149 | [0]*args.batch_size, device=device, dtype=torch.long) 150 | negative_ids.random_(0, num_target) 151 | 152 | negative_img_ft = utils.get_image_batch(fts, negative_ids) 153 | negative_img_ft = negative_img_ft.to(device) 154 | negative_img_emb = retriever.encode_image(negative_img_ft) 155 | 156 | # accumulate loss 157 | loss_tmp = criterion(history_hidden, target_img_emb, 158 | negative_img_emb) 159 | loss += loss_tmp 160 | loss_tracker[d_turn] += loss_tmp.item() 161 | 162 | # generate new candidates, compute ranking information 163 | with torch.no_grad(): 164 | candidate_ids = ranker.nearest_neighbors(history_hidden) 165 | ranking = ranker.compute_rank(history_hidden, target_ids) 166 | ranking_tracker[d_turn] += (ranking.mean().item() / 167 | (num_target * 1.0)) 168 | 169 | for batch_id in range(target_ids.size(0)): 170 | idx = target_ids[batch_id].cpu().item() 171 | ret_result[idx]['caption'].append( 172 | caps[batch_id]) 173 | ret_result[idx]['candidate'].append( 174 | candidate_ids[batch_id].item()) 175 | ret_result[idx]['ranking'].append( 176 | ranking[batch_id].item()) 177 | 178 | total_time += (int(round(time.time() * 1000)) - last_timer) 179 | 180 | ret_results.update(ret_result) 181 | 182 | loss = loss.item() / total_step 183 | for i in range(args.num_dialog_turns): 184 | ranking_tracker[i] /= total_step 185 | loss_tracker[i] /= total_step 186 | 187 | metrics = {'loss': loss, 'score': 5 - sum(ranking_tracker), 188 | 'loss_tracker': loss_tracker, 189 | 'ranking_tracker': ranking_tracker, 'retrieve_time': total_time/float(num_target)} 190 | return metrics, ret_results 191 | 192 | 193 | def eval(args): 194 | def logging(s, print_=True, log_=False): 195 | if print_: 196 | print(s) 197 | return 198 | 199 | logging(str(args)) 200 | 201 | # load image features (or compute and store them if necessary) 202 | fts = load_test_image_features(args) 203 | 204 | # user model: captioner 205 | captioner = UserModel.UserModel(args, mode='greedy') 206 | captioner.to(device) 207 | 208 | # ranker 209 | checkpoint_model = torch.load( 210 | os.path.join(args.trained_model, 'checkpoint_model.th')) 211 | retriever = checkpoint_model['retrieval_model'] 212 | opt = checkpoint_model['args'] 213 | 214 | print("=" * 88) 215 | print(opt) 216 | print("=" * 88) 217 | 218 | retriever.eval() 219 | logging('-' * 77) 220 | 221 | with torch.no_grad(): 222 | metrics, ret_results = eval_batch(fts, captioner, retriever, opt) 223 | res = metrics['ranking_tracker'] 224 | logging( 225 | '|eval loss: {:8.3f} | score {:8.5f} | ' 226 | 'rank {:5.3f}/{:5.3f}/{:5.3f}/{:5.3f}/{:5.3f} | time:{}'.format( 227 | metrics['loss'], metrics['score'], 228 | 1 - res[0], 1 - res[1], 1 - res[2], 1 - res[3], 1 - res[4], metrics['retrieve_time'])) 229 | logging('-' * 77) 230 | 231 | with open('prediction.test.{}.u{}.l{}.json'.format( 232 | opt.data_set, 233 | opt.hidden_unit_num, opt.layer_num), 'w') as f: 234 | json.dump(ret_results, f, indent=4) 235 | logging('evaluation complete') 236 | 237 | 238 | if __name__ == '__main__': 239 | parser = argparse.ArgumentParser() 240 | parser.add_argument('--image_folder', type=str, 241 | default='../resized_images/{}/') 242 | parser.add_argument('--data_split_file', type=str, 243 | default='data/split.{}.{}.json') 244 | parser.add_argument('--attribute_file', type=str, 245 | default='../attribute_prediction/prediction/' 246 | 'predict_{}_b7_ft.json') 247 | parser.add_argument('--attr2idx_file', type=str, 248 | 
default='../attribute_prediction/data/' 249 | 'attribute2idx.json') 250 | parser.add_argument('--image_model', type=str, 251 | default='../attribute_prediction/deepfashion_models/' 252 | 'dfattributes_efficientnet_b7ns.pth') 253 | parser.add_argument('--save', type=str, default='models/', 254 | help='path for saving ranker models') 255 | parser.add_argument('--trained_model', type=str, 256 | default='models/dress') 257 | parser.add_argument('--crop_size', type=int, default=224, 258 | help='size for randomly cropping images') 259 | parser.add_argument('--data_set', type=str, default='dress', 260 | help='dress / toptee / shirt') 261 | parser.add_argument('--exp_folder', type=str, default='v2') 262 | 263 | # User model parameters 264 | parser.add_argument('--user_model_file', type=str, 265 | default='../user_modeling/models/' 266 | '{}-efficient-b7-finetune.chkpt') 267 | parser.add_argument('--user_vocab_file', type=str, 268 | default='../user_modeling/data/{}_vocab.json') 269 | parser.add_argument('--max_seq_len', type=int, default=8, 270 | help='maximum caption length') 271 | parser.add_argument('--beam_size', type=int, default=5, 272 | help='beam search branch size') 273 | parser.add_argument('--glove_emb_file', type=str, 274 | default='data/{}_emb.pt') 275 | parser.add_argument('--attribute_num', type=int, 276 | default=1000) 277 | 278 | # Model parameters 279 | parser.add_argument('--num_workers', type=int, default=4) 280 | parser.add_argument('--num_dialog_turns', type=int, default=5) 281 | parser.add_argument('--margin', type=float, default=1) 282 | parser.add_argument('--clip_norm', type=float, default=10) 283 | 284 | parser.add_argument('--no_cuda', action='store_true') 285 | parser.add_argument('--batch_size', type=int, default=2) 286 | parser.add_argument('--learning_rate', type=float, default=0.001) 287 | 288 | parser.add_argument('--device', type=str, default='cuda') 289 | 290 | args = parser.parse_args() 291 | 292 | eval(args) 293 | 294 | -------------------------------------------------------------------------------- /transformer/user_modeling/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from Layers import EncoderLayer, DecoderLayer 4 | from Embed import Embedder, PositionalEncoder 5 | from Sublayers import FeedForward, MultiHeadAttention, Norm 6 | import copy 7 | import torchvision.models as models 8 | import numpy as np 9 | from torch.autograd import Variable 10 | from efficientnet_pytorch import EfficientNet 11 | 12 | Constants_PAD = 0 13 | 14 | def get_clones(module, N): 15 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 16 | 17 | class Encoder(nn.Module): 18 | def __init__(self, d_model, N_layers, heads, dropout): 19 | super().__init__() 20 | self.N_layers = N_layers 21 | # self.embed = Embedder(vocab_size, d_model) 22 | # self.pe = PositionalEncoder(d_model, dropout=dropout) 23 | # self.attn = MultiHeadAttention(heads, d_model, dropout=dropout) 24 | self.layers = get_clones(EncoderLayer(d_model, heads, dropout), N_layers) 25 | self.norm = Norm(d_model) 26 | # self.norm_1 = Norm(d_model) 27 | # self.norm_2 = Norm(d_model) 28 | # self.dropout= nn.Dropout(dropout) 29 | def forward(self, x): 30 | # x = self.embed(src) 31 | # x = self.pe(x) 32 | # x = src 33 | for i in range(self.N_layers): 34 | x = self.layers[i](x) 35 | return self.norm(x) 36 | 37 | # def forward(self, image1, image2): 38 | # image1 = self.norm_1(image1) 39 | # image2 = self.norm_2(image2) 40 | # x 
= self.dropout(self.attn(image1,image2,image2)) 41 | # for i in range(self.N_layers): 42 | # x = self.layers[i](x) 43 | # return self.norm(x) 44 | 45 | class Decoder(nn.Module): 46 | def __init__(self, vocab_size, d_model, N_layers, heads, dropout): 47 | super().__init__() 48 | self.N_layers = N_layers 49 | self.embed = Embedder(vocab_size, d_model) 50 | self.pe = PositionalEncoder(d_model, dropout=dropout) 51 | self.layers = get_clones(DecoderLayer(d_model, heads, dropout), N_layers) 52 | self.norm = Norm(d_model) 53 | def forward(self, trg, e_outputs, trg_mask): 54 | x = self.embed(trg) 55 | x = self.pe(x) 56 | for i in range(self.N_layers): 57 | x = self.layers[i](x, e_outputs, src_mask=None, trg_mask=trg_mask) 58 | return self.norm(x) 59 | 60 | class CNN_Embedding(nn.Module): 61 | def __init__(self, d_model, model_name, pretrained_model=None): 62 | """Load the pretrained ResNet-152 and replace top fc layer.""" 63 | super().__init__() 64 | 65 | self.d_model = d_model 66 | self.model_name = model_name 67 | 68 | if model_name[:6] == 'resnet': 69 | print("cnn model name: ", model_name) 70 | if model_name == "resnet101": 71 | 72 | model = models.resnet101(pretrained=True) 73 | elif model_name == "resnet18": 74 | model = models.resnet18(pretrained=True) 75 | 76 | if pretrained_model: 77 | print("cnn initialed from pretrained_model") 78 | ckpt = torch.load(pretrained_model, map_location='cpu') 79 | if "model_state" in ckpt: 80 | model.load_state_dict(ckpt["model_state"]) 81 | else: 82 | model.load_state_dict(ckpt) 83 | 84 | modules = list(model.children())[:-1] # delete the last fc layer. 85 | self.model = nn.Sequential(*modules) 86 | 87 | for param in self.model.parameters(): 88 | param.requires_grad = False 89 | self.linear = nn.Linear(model.fc.in_features, d_model) 90 | self.bn = nn.BatchNorm1d(model.fc.in_features, momentum=0.01) 91 | 92 | 93 | elif model_name[:12] == "efficientnet": 94 | self.model = EfficientNet.from_pretrained(model_name) 95 | if pretrained_model: 96 | ckpt = torch.load(pretrained_model, map_location='cpu') 97 | if "model_state" in ckpt: 98 | self.model.load_state_dict(ckpt["model_state"]) 99 | else: 100 | self.model.load_state_dict(ckpt) 101 | 102 | for param in self.model.parameters(): 103 | param.requires_grad = False 104 | 105 | self.linear = nn.Linear(self.model._fc.in_features, d_model) 106 | self.bn = nn.BatchNorm1d(self.model._fc.in_features, momentum=0.01) 107 | 108 | 109 | def get_trainable_parameters(self): 110 | return list(self.linear.parameters()) + list(self.bn.parameters()) 111 | 112 | 113 | 114 | def forward(self, image): 115 | with torch.no_grad(): 116 | if self.model_name[:12] == "efficientnet": 117 | img_ft = self.model.extract_features(image) 118 | img_ft = self.model._avg_pooling(img_ft) 119 | img_ft = img_ft.flatten(start_dim=1) 120 | img_ft = self.model._dropout(img_ft) 121 | else: 122 | img_ft = self.model(image) 123 | 124 | img_ft = self.linear(self.bn(img_ft.reshape(img_ft.size(0), img_ft.size(1), -1)).transpose(1,2)) # (batch_size, d, d, f) -> (batch_size, d^2, f) 125 | #(batch_size, f, 1, 1) -> (batch_size, 1, f) 126 | return img_ft#.transpose(0,1)#(1, batch_size,f) 127 | 128 | class Joint_Encoding: 129 | def __init__(self, joint_encoding_function): 130 | # super().__init__() 131 | if joint_encoding_function == 'addition': 132 | self.joint_encoding_function = lambda x1, x2 : x1 + x2 133 | elif joint_encoding_function == 'deduction': 134 | self.joint_encoding_function = lambda x1, x2 : x1 - x2 135 | elif joint_encoding_function == 'max': 136 
| self.joint_encoding_function = lambda x1, x2 : torch.max(x1,x2) 137 | elif joint_encoding_function == 'element_multiplication': 138 | self.joint_encoding_function = lambda x1, x2 : x1 * x2 139 | 140 | def __call__(self,E1, E2): 141 | 142 | return self.joint_encoding_function(E1, E2) 143 | 144 | class Attribute_Embedding(nn.Module): 145 | def __init__(self, d_model, attribute_vocab_size): 146 | """Load the pretrained ResNet-152 and replace top fc layer.""" 147 | super().__init__() 148 | self.embed = nn.Linear(attribute_vocab_size, d_model)#Embedder(attribute_vocab_size, d_model) 149 | self.norm = nn.BatchNorm1d(attribute_vocab_size, momentum=0.01) 150 | 151 | def forward(self, attribute): 152 | attribute = self.norm(attribute) 153 | attribute = self.embed(attribute) 154 | return attribute 155 | 156 | class Transformer(nn.Module): 157 | def __init__(self, trg_vocab, d_model, N, heads, dropout, cnn_model_name, \ 158 | joint_encoding_function, attribute_vocab_size=1000, cnn_pretrained_model=None, add_attribute=False): 159 | super().__init__() 160 | self.add_attribute = add_attribute 161 | self.cnn1 = CNN_Embedding(d_model, cnn_model_name, cnn_pretrained_model) 162 | self.cnn2 = CNN_Embedding(d_model, cnn_model_name, cnn_pretrained_model) 163 | # self.bn = nn.BatchNorm1d(d_model, momentum=0.01) 164 | if self.add_attribute: 165 | self.attribute_embedding = Attribute_Embedding(d_model, attribute_vocab_size) 166 | # self.attribute_embedding2 = Attribute_Embedding(d_model, attribute_vocab_size) 167 | self.joint_encoding = Joint_Encoding(joint_encoding_function) 168 | self.encoder = Encoder(d_model, N, heads, dropout) 169 | self.decoder = Decoder(trg_vocab, d_model, N, heads, dropout) 170 | self.out = nn.Linear(d_model, trg_vocab) 171 | 172 | def get_trainable_parameters(self): 173 | if not self.add_attribute: 174 | return self.cnn1.get_trainable_parameters() \ 175 | + self.cnn2.get_trainable_parameters() \ 176 | + list(self.encoder.parameters()) \ 177 | + list(self.decoder.parameters()) \ 178 | + list(self.out.parameters()) 179 | else: 180 | return self.cnn1.get_trainable_parameters() \ 181 | + self.cnn2.get_trainable_parameters() \ 182 | + list(self.encoder.parameters()) \ 183 | + list(self.decoder.parameters()) \ 184 | + list(self.out.parameters()) \ 185 | + list(self.attribute_embedding.parameters()) 186 | # + list(self.bn.parameters()) \ 187 | # + list(self.attribute_embedding1.parameters()) \ 188 | # + list(self.attribute_embedding2.parameters()) 189 | 190 | # def get_parameters_to_initial(self): 191 | # return list(self.encoder.parameters()) \ 192 | # + list(self.decoder.parameters()) \ 193 | # + list(self.out.parameters()) \ 194 | # + list(self.attribute_embedding1.parameters()) \ 195 | # + list(self.attribute_embedding2.parameters()) 196 | 197 | 198 | def forward(self, image0, image1, trg, trg_mask, image0_attribute, image1_attribute): 199 | #image1, image2 = image2, image1 200 | 201 | image0 = self.cnn1(image0) 202 | 203 | image1 = self.cnn2(image1) 204 | 205 | if self.add_attribute: 206 | attribute = self.attribute_embedding(image0_attribute - image1_attribute).unsqueeze(1) 207 | # attribute = self.norm(attribute) 208 | 209 | # image0_attribute = self.attribute_embedding1(image0_attribute) 210 | 211 | # image1_attribute = self.attribute_embedding2(image1_attribute) 212 | 213 | # image0 = torch.cat((image0, image0_attribute), 1) 214 | # image1 = torch.cat((image1, image1_attribute), 1) 215 | 216 | #joint_encoding = self.joint_encoding(torch.cat((image0, image0_attribute),1), 
torch.cat((image1,image1_attribute),1)) 217 | joint_encoding = self.joint_encoding(image0, image1) 218 | joint_encoding = torch.cat((joint_encoding, attribute), 1) 219 | # joint_encoding = self.bn(joint_encoding.transpose(1,2)).transpose(1,2) 220 | # joint_encoding = torch.cat((joint_encoding, image0_attribute), 1) 221 | 222 | # joint_encoding = torch.cat((joint_encoding, image1_attribute), 1) 223 | else: 224 | joint_encoding = self.joint_encoding(image0, image1) 225 | 226 | joint_encoding = self.encoder(joint_encoding) 227 | #print("DECODER") 228 | output = self.decoder(trg, joint_encoding, trg_mask) 229 | 230 | output = self.out(output) 231 | 232 | return output 233 | 234 | def get_model(opt, load_weights=False): 235 | 236 | 237 | 238 | if load_weights: 239 | 240 | device = torch.device('cuda' if opt.cuda else 'cpu') 241 | 242 | checkpoint = torch.load(opt.pretrained_model + '.chkpt') 243 | 244 | model_opt = checkpoint['settings'] 245 | 246 | model = Transformer(model_opt.vocab_size, model_opt.d_model, \ 247 | model_opt.n_layers, model_opt.n_heads, model_opt.dropout, \ 248 | model_opt.cnn_name, model_opt.joint_enc_func, \ 249 | model_opt.attribute_vocab_size, model_opt.cnn_pretrained_model, model_opt.add_attribute, 250 | ) 251 | 252 | model.load_state_dict(checkpoint['model']) 253 | 254 | print('[Info] Trained model state loaded from: ', opt.pretrained_model) 255 | 256 | model = model.to(device) 257 | 258 | 259 | else: 260 | assert opt.d_model % opt.n_heads == 0 261 | 262 | assert opt.dropout < 1 263 | 264 | model = Transformer(opt.vocab_size, opt.d_model, opt.n_layers, opt.n_heads, opt.dropout, \ 265 | opt.cnn_name, opt.joint_enc_func, opt.attribute_vocab_size, opt.cnn_pretrained_model, \ 266 | opt.add_attribute) 267 | 268 | for p in model.get_trainable_parameters(): 269 | if p.dim() > 1: 270 | nn.init.xavier_uniform_(p) 271 | 272 | model.to(opt.device) 273 | 274 | return model 275 | 276 | def nopeak_mask(size): 277 | np_mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8') 278 | np_mask = Variable(torch.from_numpy(np_mask) == 0) 279 | 280 | return np_mask 281 | 282 | def create_masks(trg): 283 | # src_mask = (src != Constants_PAD.unsqueeze(-2) 284 | 285 | if trg is not None: 286 | trg_mask = (trg != Constants_PAD).unsqueeze(-2) 287 | size = trg.size(1) # get seq_len for matrix 288 | np_mask = nopeak_mask(size).to(trg_mask.device) 289 | 290 | trg_mask = trg_mask & np_mask 291 | 292 | else: 293 | trg_mask = None 294 | 295 | return trg_mask 296 | 297 | 298 | 299 | 300 | -------------------------------------------------------------------------------- /transformer/interactive_retrieval/user_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.nn.functional as F 5 | 6 | import copy 7 | from torch.autograd import Variable 8 | 9 | 10 | class Embedder(nn.Module): 11 | def __init__(self, vocab_size, d_model): 12 | super().__init__() 13 | self.d_model = d_model 14 | self.embed = nn.Embedding(vocab_size, d_model) 15 | 16 | def forward(self, x): 17 | return self.embed(x) 18 | 19 | 20 | class PositionalEncoder(nn.Module): 21 | def __init__(self, d_model, max_seq_len=200, dropout=0.1): 22 | super().__init__() 23 | self.d_model = d_model 24 | self.dropout = nn.Dropout(dropout) 25 | # create constant 'pe' matrix with values dependant on 26 | # pos and i 27 | pe = torch.zeros(max_seq_len, d_model) 28 | for pos in range(max_seq_len): 29 | for i in range(0, d_model, 2): 30 | pe[pos, i] 
= \ 31 | math.sin(pos / (10000 ** ((2 * i) / d_model))) 32 | pe[pos, i + 1] = \ 33 | math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model))) 34 | pe = pe.unsqueeze(0) 35 | self.register_buffer('pe', pe) 36 | 37 | def forward(self, x): 38 | # make embeddings relatively larger 39 | x = x * math.sqrt(self.d_model) 40 | # add constant to embedding 41 | seq_len = x.size(1) 42 | pe = Variable(self.pe[:, :seq_len], requires_grad=False) 43 | if x.is_cuda: 44 | pe.cuda() 45 | x = x + pe 46 | return self.dropout(x) 47 | 48 | 49 | class Norm(nn.Module): 50 | def __init__(self, d_model, eps=1e-6): 51 | super().__init__() 52 | 53 | self.size = d_model 54 | 55 | # create two learnable parameters to calibrate normalisation 56 | self.alpha = nn.Parameter(torch.ones(self.size)) 57 | self.bias = nn.Parameter(torch.zeros(self.size)) 58 | 59 | self.eps = eps 60 | 61 | def forward(self, x): 62 | norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \ 63 | / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias 64 | return norm 65 | 66 | 67 | def attention(q, k, v, d_k, mask=None, dropout=None): 68 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) 69 | 70 | if mask is not None: 71 | mask = mask.unsqueeze(1) 72 | scores = scores.masked_fill(mask == 0, -1e9) 73 | 74 | scores = F.softmax(scores, dim=-1) 75 | 76 | if dropout is not None: 77 | scores = dropout(scores) 78 | 79 | output = torch.matmul(scores, v) 80 | return output 81 | 82 | 83 | class MultiHeadAttention(nn.Module): 84 | def __init__(self, heads, d_model, dropout=0.1): 85 | super().__init__() 86 | 87 | self.d_model = d_model 88 | self.d_k = d_model // heads 89 | self.h = heads 90 | 91 | self.q_linear = nn.Linear(d_model, d_model) 92 | self.v_linear = nn.Linear(d_model, d_model) 93 | self.k_linear = nn.Linear(d_model, d_model) 94 | 95 | self.dropout = nn.Dropout(dropout) 96 | self.out = nn.Linear(d_model, d_model) 97 | 98 | def forward(self, q, k, v, mask=None): 99 | bs = q.size(0) 100 | 101 | # perform linear operation and split into N heads 102 | k = self.k_linear(k).view(bs, -1, self.h, self.d_k) 103 | q = self.q_linear(q).view(bs, -1, self.h, self.d_k) 104 | v = self.v_linear(v).view(bs, -1, self.h, self.d_k) 105 | 106 | # transpose to get dimensions bs * N * sl * d_model 107 | k = k.transpose(1, 2) 108 | q = q.transpose(1, 2) 109 | v = v.transpose(1, 2) 110 | 111 | # calculate attention using function we will define next 112 | scores = attention(q, k, v, self.d_k, mask, self.dropout) 113 | # concatenate heads and put through final linear layer 114 | concat = scores.transpose(1, 2).contiguous() \ 115 | .view(bs, -1, self.d_model) 116 | output = self.out(concat) 117 | 118 | return output 119 | 120 | 121 | class FeedForward(nn.Module): 122 | def __init__(self, d_model, d_ff=2048, dropout=0.1): 123 | super().__init__() 124 | 125 | # We set d_ff as a default to 2048 126 | self.linear_1 = nn.Linear(d_model, d_ff) 127 | self.dropout = nn.Dropout(dropout) 128 | self.linear_2 = nn.Linear(d_ff, d_model) 129 | 130 | def forward(self, x): 131 | x = self.dropout(F.relu(self.linear_1(x))) 132 | x = self.linear_2(x) 133 | return x 134 | 135 | 136 | class EncoderLayer(nn.Module): 137 | def __init__(self, d_model, heads, dropout=0.1): 138 | super().__init__() 139 | self.norm_1 = Norm(d_model) 140 | self.norm_2 = Norm(d_model) 141 | self.attn = MultiHeadAttention(heads, d_model, dropout=dropout) 142 | self.ff = FeedForward(d_model, dropout=dropout) 143 | self.dropout_1 = nn.Dropout(dropout) 144 | self.dropout_2 = nn.Dropout(dropout) 145 | 146 | def 
forward(self, x, mask=None): 147 | x2 = self.norm_1(x) 148 | x = x + self.dropout_1(self.attn(x2, x2, x2, mask)) 149 | x2 = self.norm_2(x) 150 | x = x + self.dropout_2(self.ff(x2)) 151 | return x 152 | 153 | 154 | # build a decoder layer with two multi-head attention layers and 155 | # one feed-forward layer 156 | class DecoderLayer(nn.Module): 157 | def __init__(self, d_model, heads, dropout=0.1): 158 | super().__init__() 159 | self.norm_1 = Norm(d_model) 160 | self.norm_2 = Norm(d_model) 161 | self.norm_3 = Norm(d_model) 162 | 163 | self.dropout_1 = nn.Dropout(dropout) 164 | self.dropout_2 = nn.Dropout(dropout) 165 | self.dropout_3 = nn.Dropout(dropout) 166 | 167 | self.attn_1 = MultiHeadAttention(heads, d_model, dropout=dropout) 168 | self.attn_2 = MultiHeadAttention(heads, d_model, dropout=dropout) 169 | self.ff = FeedForward(d_model, dropout=dropout) 170 | 171 | def forward(self, x, e_outputs, src_mask=None, trg_mask=None): 172 | x2 = self.norm_1(x) 173 | x = x + self.dropout_1(self.attn_1(x2, x2, x2, trg_mask)) 174 | x2 = self.norm_2(x) 175 | x = x + self.dropout_2( 176 | self.attn_2(x2, e_outputs, e_outputs, src_mask)) 177 | x2 = self.norm_3(x) 178 | x = x + self.dropout_3(self.ff(x2)) 179 | return x 180 | 181 | 182 | def get_clones(module, N): 183 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 184 | 185 | 186 | class Encoder(nn.Module): 187 | def __init__(self, d_model, N_layers, heads, dropout): 188 | super().__init__() 189 | self.N_layers = N_layers 190 | self.pe = PositionalEncoder(d_model, dropout=dropout) 191 | self.layers = get_clones( 192 | EncoderLayer(d_model, heads, dropout), N_layers) 193 | self.norm = Norm(d_model) 194 | 195 | def forward(self, x): 196 | x = self.pe(x) 197 | for i in range(self.N_layers): 198 | x = self.layers[i](x) 199 | return self.norm(x) 200 | 201 | 202 | class Decoder(nn.Module): 203 | def __init__(self, vocab_size, d_model, N_layers, heads, dropout): 204 | super().__init__() 205 | self.N_layers = N_layers 206 | self.embed = Embedder(vocab_size, d_model) 207 | self.pe = PositionalEncoder(d_model, dropout=dropout) 208 | self.layers = get_clones( 209 | DecoderLayer(d_model, heads, dropout), N_layers) 210 | self.norm = Norm(d_model) 211 | 212 | def forward(self, trg, e_outputs, trg_mask): 213 | x = self.embed(trg) 214 | x = self.pe(x) 215 | for i in range(self.N_layers): 216 | x = self.layers[i](x, e_outputs, src_mask=None, trg_mask=trg_mask) 217 | return self.norm(x) 218 | 219 | 220 | class CNN_Embedding(nn.Module): 221 | def __init__(self, d_model, model_name, pretrained_model=None): 222 | """Load the pretrained ResNet-152 and replace top fc layer.""" 223 | super().__init__() 224 | 225 | self.d_model = d_model 226 | self.model_name = model_name 227 | print("cnn model name: ", model_name) 228 | if model_name[:6] == 'resnet': 229 | if model_name == "resnet101": 230 | in_features = 2048 231 | elif model_name == "resnet18": 232 | in_features = 512 233 | elif model_name[:12] == "efficientnet": 234 | if model_name == "efficientnet-b7": 235 | in_features = 2560 236 | elif model_name == "efficientnet-b4": 237 | in_features = 1792 238 | 239 | self.linear = nn.Linear(in_features, d_model) 240 | self.bn = nn.BatchNorm1d(in_features, momentum=0.01) 241 | 242 | def forward(self, img_ft): 243 | # (batch_size, d, d, f) -> (batch_size, d^2, f) 244 | img_ft = self.linear( 245 | self.bn(img_ft.squeeze(1))).unsqueeze(1) 246 | return img_ft 247 | 248 | 249 | class Joint_Encoding: 250 | def __init__(self, joint_encoding_function): 251 | if 
joint_encoding_function == 'addition': 252 | self.joint_encoding_function = lambda x1, x2 : x1 + x2 253 | elif joint_encoding_function == 'deduction': 254 | self.joint_encoding_function = lambda x1, x2 : x1 - x2 255 | elif joint_encoding_function == 'max': 256 | self.joint_encoding_function = lambda x1, x2 : torch.max(x1,x2) 257 | elif joint_encoding_function == 'element_multiplication': 258 | self.joint_encoding_function = lambda x1, x2 : x1 * x2 259 | 260 | def __call__(self,E1, E2): 261 | return self.joint_encoding_function(E1, E2) 262 | 263 | 264 | class Attribute_Embedding(nn.Module): 265 | def __init__(self, d_model, attribute_vocab_size): 266 | """Load the pretrained ResNet-152 and replace top fc layer.""" 267 | super().__init__() 268 | self.embed = Embedder(attribute_vocab_size, d_model) 269 | 270 | def forward(self, attribute): 271 | attribute = self.embed(attribute) 272 | return attribute 273 | 274 | 275 | class Transformer(nn.Module): 276 | def __init__(self, trg_vocab, d_model, N, heads, dropout, cnn_model_name, 277 | joint_encoding_function, attribute_vocab_size=1000, 278 | cnn_pretrained_model=None, add_attribute=False): 279 | super().__init__() 280 | self.add_attribute = add_attribute 281 | self.cnn1 = CNN_Embedding(d_model, cnn_model_name, cnn_pretrained_model) 282 | self.cnn2 = CNN_Embedding(d_model, cnn_model_name, cnn_pretrained_model) 283 | 284 | if self.add_attribute: 285 | self.attribute_embedding1 = Attribute_Embedding( 286 | d_model, attribute_vocab_size) 287 | self.attribute_embedding2 = Attribute_Embedding( 288 | d_model, attribute_vocab_size) 289 | self.joint_encoding = Joint_Encoding(joint_encoding_function) 290 | self.encoder = Encoder(d_model, N, heads, dropout) 291 | self.decoder = Decoder(trg_vocab, d_model, N, heads, dropout) 292 | self.out = nn.Linear(d_model, trg_vocab) 293 | 294 | def forward(self, image0, image1, trg, trg_mask, 295 | image0_attribute, image1_attribute): 296 | 297 | image0 = self.cnn1(image0) 298 | image1 = self.cnn2(image1) 299 | 300 | if self.add_attribute: 301 | image0_attribute = self.attribute_embedding1(image0_attribute) 302 | image1_attribute = self.attribute_embedding2(image1_attribute) 303 | joint_encoding = self.joint_encoding(image0, image1) 304 | joint_encoding = torch.cat((joint_encoding, image0_attribute), 1) 305 | joint_encoding = torch.cat((joint_encoding, image1_attribute), 1) 306 | else: 307 | joint_encoding = self.joint_encoding(image0, image1) 308 | 309 | joint_encoding = self.encoder(joint_encoding) 310 | output = self.decoder(trg, joint_encoding, trg_mask) 311 | output = self.out(output) 312 | 313 | return output 314 | 315 | 316 | def load_trained_model(model_name): 317 | checkpoint = torch.load(model_name, map_location='cpu') 318 | model_opt = checkpoint['settings'] 319 | 320 | model = Transformer(model_opt.vocab_size, model_opt.d_model, 321 | model_opt.n_layers, model_opt.n_heads, 322 | model_opt.dropout, model_opt.cnn_name, 323 | model_opt.joint_enc_func, 324 | model_opt.attribute_vocab_size, 325 | model_opt.cnn_pretrained_model, 326 | model_opt.add_attribute) 327 | 328 | model.load_state_dict(checkpoint['model']) 329 | print('[Info] Trained model state loaded from: ', model_name) 330 | return model 331 | 332 | 333 | def create_model(opt): 334 | assert opt.d_model % opt.n_heads == 0 335 | assert opt.dropout < 1 336 | model = Transformer(opt.vocab_size, opt.d_model, opt.n_layers, 337 | opt.n_heads, opt.dropout, opt.cnn_name, 338 | opt.joint_enc_func, opt.attribute_vocab_size, 339 | opt.cnn_pretrained_model, 
opt.add_attribute) 340 | 341 | for p in model.parameters(): 342 | if p.dim() > 1: 343 | nn.init.xavier_uniform_(p) 344 | return model 345 | -------------------------------------------------------------------------------- /transformer/interactive_retrieval/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import shutil 4 | import time 5 | import json 6 | import os 7 | import tqdm 8 | import torch 9 | import pickle 10 | import torch.nn as nn 11 | from torchvision import transforms 12 | from models import RetrieverTransformer 13 | from data_loader import Dataset 14 | import utils 15 | import Ranker 16 | import UserModel 17 | 18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | 20 | 21 | def create_exp_dir(path, scripts_to_save=None): 22 | if not os.path.exists(path): 23 | os.mkdir(path) 24 | print('Experiment dir : {}'.format(path)) 25 | if scripts_to_save is not None: 26 | if not os.path.exists(os.path.join(path, 'scripts')): 27 | os.mkdir(os.path.join(path, 'scripts')) 28 | for script in scripts_to_save: 29 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 30 | shutil.copyfile(script, dst_file) 31 | return 32 | 33 | 34 | def load_image_features(args): 35 | # Image preprocessing, normalization for the pretrained b7 36 | transform = transforms.Compose([ 37 | transforms.Resize(args.crop_size), 38 | transforms.ToTensor(), 39 | transforms.Normalize((0.485, 0.456, 0.406), 40 | (0.229, 0.224, 0.225))]) 41 | 42 | # train split 43 | fts_file = os.path.join(args.save, 'b7_v2.{}.{}.th'.format( 44 | args.data_set, 'train')) 45 | if os.path.isfile(fts_file): 46 | print('[INFO] loading image features: {}'.format(fts_file)) 47 | fts = torch.load(fts_file, map_location='cpu') 48 | else: 49 | print('[INFO] computing image features: {}'.format(fts_file)) 50 | data_loader = Dataset( 51 | args.image_folder.format(args.data_set), 52 | args.data_split_file.format(args.data_set, 'train'), 53 | transform, num_workers=args.num_workers) 54 | attr_file = args.attribute_file.format(args.data_set) 55 | 56 | fts = utils.extract_features(data_loader, attr_file, 57 | args.attr2idx_file, device, 58 | args.image_model) 59 | torch.save(fts, fts_file) 60 | 61 | # dev split 62 | fts_file_dev = os.path.join(args.save, 'b7_v2.{}.{}.th'.format( 63 | args.data_set, 'val')) 64 | if os.path.isfile(fts_file_dev): 65 | print('[INFO] loading image features: {}'.format(fts_file_dev)) 66 | fts_dev = torch.load(fts_file_dev, map_location='cpu') 67 | else: 68 | print('[INFO] computing image features: {}'.format(fts_file_dev)) 69 | data_loader_dev = Dataset( 70 | args.image_folder.format(args.data_set), 71 | args.data_split_file.format(args.data_set, 'val'), 72 | transform, num_workers=args.num_workers) 73 | attr_file_dev = args.attribute_file.format(args.data_set) 74 | fts_dev = utils.extract_features(data_loader_dev, attr_file_dev, 75 | args.attr2idx_file, device, 76 | args.image_model) 77 | torch.save(fts_dev, fts_file_dev) 78 | 79 | return fts, fts_dev 80 | 81 | 82 | def eval_batch(fts, captioner, retriever, args, 83 | train_mode=False, 84 | optimizer=None): 85 | criterion = nn.TripletMarginLoss( 86 | reduction='mean', margin=args.margin).to(device) 87 | # generate a mapping for dev, to ensure sampling bias is reduced 88 | num_target = len(fts['asins']) 89 | 90 | batch_size = args.batch_size 91 | ranker = Ranker.Ranker(device) 92 | total_step = math.floor(num_target / batch_size) 93 | 94 | ranking_tracker = [0] 
* args.num_dialog_turns 95 | loss_tracker = [0] * args.num_dialog_turns 96 | 97 | with open('data/shuffled.{}.{}.json'.format( 98 | args.data_set, 'val')) as f: 99 | first_candidate_set = json.load(f) 100 | 101 | with torch.no_grad(): 102 | retriever.eval() 103 | ranker.update_emb(fts, args.batch_size, retriever) 104 | 105 | if train_mode: 106 | retriever.train() 107 | else: 108 | retriever.eval() 109 | 110 | for step in tqdm.tqdm(range(total_step)): 111 | # sample target 112 | if train_mode: 113 | target_ids = torch.tensor( 114 | [0]*args.batch_size, device=device, dtype=torch.long) 115 | target_ids.random_(0, num_target) 116 | else: 117 | target_ids = torch.tensor( 118 | [i for i in 119 | range(step * batch_size, (step + 1) * batch_size)]).to( 120 | device=device, dtype=torch.long) 121 | 122 | # sample first batch of candidates 123 | if train_mode: 124 | candidate_ids = torch.tensor( 125 | [0]*args.batch_size, device=device, dtype=torch.long) 126 | candidate_ids.random_(0, num_target) 127 | else: 128 | candidate_ids = torch.tensor( 129 | [first_candidate_set[i] 130 | for i in range(step * batch_size, (step + 1) * batch_size)], 131 | device=device, dtype=torch.long) 132 | 133 | # target_ids.random_(0, num_target) 134 | target_img_ft = utils.get_image_batch(fts, target_ids) 135 | target_img_ft = target_img_ft.to(device) 136 | target_img_emb = retriever.encode_image(target_img_ft) 137 | 138 | target_attr = utils.get_attribute_batch(fts, target_ids) 139 | target_attr = target_attr.to(device) 140 | 141 | # clean up dialog history tracker 142 | retriever.init_hist() 143 | # history_hidden = history_hidden.expand_as(target_img_emb) 144 | # history_hidden = None 145 | loss = 0 146 | 147 | for d_turn in range(args.num_dialog_turns): 148 | # get candidate image features 149 | candidate_img_ft = utils.get_image_batch(fts, candidate_ids) 150 | candidate_img_ft = candidate_img_ft.to(device) 151 | 152 | candidate_attr = utils.get_attribute_batch(fts, candidate_ids) 153 | candidate_attr = candidate_attr.to(device) 154 | # generate captions from model 155 | with torch.no_grad(): 156 | sentence_ids = captioner.get_caption( 157 | target_img_ft, candidate_img_ft, 158 | target_attr, candidate_attr) 159 | sentence_ids = sentence_ids.to(device) 160 | 161 | candidate_img_ft = candidate_img_ft.to(device) 162 | 163 | history_hidden = retriever.forward( 164 | text=sentence_ids, image=candidate_img_ft, 165 | attribute=candidate_attr) 166 | 167 | # sample negatives, update tracker's output to 168 | # match targets via triplet loss 169 | negative_ids = torch.tensor( 170 | [0]*args.batch_size, device=device, dtype=torch.long) 171 | negative_ids.random_(0, num_target) 172 | 173 | negative_img_ft = utils.get_image_batch(fts, negative_ids) 174 | negative_img_ft = negative_img_ft.to(device) 175 | negative_img_emb = retriever.encode_image(negative_img_ft) 176 | 177 | # accumulate loss 178 | loss_tmp = criterion(history_hidden, target_img_emb, 179 | negative_img_emb) 180 | loss += loss_tmp 181 | loss_tracker[d_turn] += loss_tmp.item() 182 | 183 | # generate new candidates, compute ranking information 184 | with torch.no_grad(): 185 | candidate_ids = ranker.nearest_neighbors(history_hidden) 186 | ranking = ranker.compute_rank(history_hidden, target_ids) 187 | ranking_tracker[d_turn] += (ranking.mean().item() / 188 | (num_target * 1.0)) 189 | 190 | # update weights 191 | if train_mode: 192 | optimizer.zero_grad() 193 | # loss = loss / args.num_dialog_turns 194 | loss.backward() 195 | # clip_grad_norm_(retriever.parameters(), 
args.clip_norm) 196 | optimizer.step() 197 | with torch.no_grad(): 198 | retriever.eval() 199 | ranker.update_emb(fts, args.batch_size, retriever) 200 | retriever.train() 201 | 202 | loss = loss.item() / total_step 203 | for i in range(args.num_dialog_turns): 204 | ranking_tracker[i] /= total_step 205 | loss_tracker[i] /= total_step 206 | 207 | metrics = {'loss': loss, 'score': 5 - sum(ranking_tracker), 208 | 'loss_tracker': loss_tracker, 209 | 'ranking_tracker': ranking_tracker} 210 | return metrics 211 | 212 | 213 | def train(args): 214 | save_folder = '{}/{}'.format( 215 | args.save, args.data_set) 216 | create_exp_dir(save_folder) 217 | 218 | def logging(s, print_=True, log_=True): 219 | if print_: 220 | print(s) 221 | if log_: 222 | with open(os.path.join(save_folder, 'log.txt'), 'a+') as f_log: 223 | f_log.write(s + '\n') 224 | return 225 | 226 | logging(str(args)) 227 | 228 | # load image features (or compute and store them if necessary) 229 | fts, fts_dev = load_image_features(args) 230 | 231 | # user model: captioner 232 | captioner = UserModel.UserModel(args) 233 | captioner.to(device) 234 | 235 | # ranker 236 | img_dim = fts['image'].size(dim=1) 237 | 238 | # response encoder 239 | glove_emb_file = args.glove_emb_file.format(args.data_set) 240 | with open(glove_emb_file, 'rb') as f: 241 | glove_emb = pickle.load(f) 242 | 243 | retriever = RetrieverTransformer( 244 | captioner.get_vocab_size(), glove_emb, img_dim, 245 | args.hidden_unit_num, args.layer_num, 246 | args.attribute_num).to(device) 247 | 248 | # Loss and optimizer 249 | params = retriever.parameters() 250 | optimizer = torch.optim.Adam(params, lr=args.learning_rate) 251 | 252 | current_lr = args.learning_rate 253 | cur_patient = 0 254 | best_score = float('-inf') 255 | best_train_score = float('-inf') 256 | 257 | for epoch in range(100): 258 | retriever.train() 259 | metrics = eval_batch(fts, captioner, retriever, args, 260 | train_mode=True, optimizer=optimizer) 261 | res = metrics['ranking_tracker'] 262 | logging( 263 | '| ({}) train loss: {:8.5f} | lr: {:8.7f} | ' 264 | 'score {:8.5f} / {:8.5f} | ' 265 | 'rank {:5.3f}/{:5.3f}/{:5.3f}/{:5.3f}/{:5.3f}'.format( 266 | epoch, metrics['loss'], current_lr, 267 | metrics['score'], best_train_score, 268 | 1 - res[0], 1 - res[1], 1 - res[2], 1 - res[3], 1 - res[4])) 269 | logging('-' * 77) 270 | 271 | if metrics['score'] < best_train_score + 1e-3: 272 | cur_patient += 1 273 | if cur_patient >= args.patient: 274 | current_lr *= 0.5 275 | if current_lr < args.learning_rate * 1e-1: 276 | break 277 | params = retriever.parameters() 278 | optimizer = torch.optim.Adam(params, lr=current_lr) 279 | 280 | best_train_score = max(metrics['score'], best_train_score) 281 | 282 | # eval on validation split 283 | if epoch % args.checkpoint == 0: 284 | retriever.eval() 285 | logging('-' * 77) 286 | 287 | with torch.no_grad(): 288 | metrics = eval_batch(fts_dev, captioner, retriever, args) 289 | res = metrics['ranking_tracker'] 290 | logging( 291 | '| ({}) eval loss: {:8.3f} | score {:8.5f} / {:8.5f} | ' 292 | 'rank {:5.3f}/{:5.3f}/{:5.3f}/{:5.3f}/{:5.3f}'.format( 293 | epoch, metrics['loss'], metrics['score'], best_score, 294 | 1-res[0], 1-res[1], 1-res[2], 1-res[3], 1-res[4])) 295 | logging('-' * 77) 296 | dev_score = metrics['score'] 297 | 298 | if dev_score > best_score: 299 | best_score = dev_score 300 | # save best model 301 | checkpoint_model = {'args': args, 302 | 'retrieval_model': retriever} 303 | torch.save(checkpoint_model, os.path.join( 304 | save_folder, 'checkpoint_model.th')) 
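                    # (Illustrative note, not part of the original script) torch.save() above
                    # pickles the whole retriever module inside a dict, so a saved checkpoint
                    # can later be restored with something like:
                    #     ckpt = torch.load(os.path.join(save_folder, 'checkpoint_model.th'))
                    #     retriever = ckpt['retrieval_model']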
305 |                 cur_patient = 0
306 | 
307 |     logging('best_dev_score: {}'.format(best_score))
308 | 
309 | 
310 | if __name__ == '__main__':
311 |     parser = argparse.ArgumentParser()
312 |     parser.add_argument('--image_folder', type=str,
313 |                         default='../resized_images/{}/')
314 |     parser.add_argument('--data_split_file', type=str,
315 |                         default='data/split.{}.{}.json')
316 |     parser.add_argument('--attribute_file', type=str,
317 |                         default='../attribute_prediction/prediction/'
318 |                                 'predict_{}_b7_ft.json')
319 |     parser.add_argument('--attr2idx_file', type=str,
320 |                         default='../attribute_prediction/data/'
321 |                                 'attribute2idx.json')
322 |     parser.add_argument('--image_model', type=str,
323 |                         default='../attribute_prediction/deepfashion_models/'
324 |                                 'dfattributes_efficientnet_b7ns.pth')
325 | 
326 |     parser.add_argument('--save', type=str, default='models/',
327 |                         help='path for saving ranker models')
328 |     parser.add_argument('--crop_size', type=int, default=224,
329 |                         help='size for randomly cropping images')
330 |     parser.add_argument('--data_set', type=str, default='dress',
331 |                         help='dress / toptee / shirt')
332 | 
333 |     parser.add_argument('--checkpoint', type=int, default=3,
334 |                         help='epoch interval for validating and saving the best model')
335 |     parser.add_argument('--patient', type=int, default=2,
336 |                         help='patience (in epochs) before reducing the learning rate')
337 |     parser.add_argument('--exp_folder', type=str, default='v1')
338 | 
339 |     # User model parameters
340 |     parser.add_argument('--user_model_file', type=str,
341 |                         default='../user_modeling/models/'
342 |                                 '{}-efficient-b7-finetune.chkpt')
343 |     parser.add_argument('--user_vocab_file', type=str,
344 |                         default='../user_modeling/data/{}_vocab.json')
345 |     parser.add_argument('--max_seq_len', type=int, default=8,
346 |                         help='maximum caption length')
347 |     parser.add_argument('--beam_size', type=int, default=1,
348 |                         help='beam search branch size')
349 |     parser.add_argument('--glove_emb_file', type=str,
350 |                         default='data/{}_emb.pt')
351 |     parser.add_argument('--attribute_num', type=int,
352 |                         default=1000)
353 | 
354 |     # Model parameters
355 |     parser.add_argument('--hidden_unit_num', type=int, default=256)
356 |     parser.add_argument('--layer_num', type=int, default=6)
357 | 
358 |     parser.add_argument('--num_workers', type=int, default=4)
359 |     parser.add_argument('--num_dialog_turns', type=int, default=5)
360 |     parser.add_argument('--margin', type=float, default=10)
361 |     parser.add_argument('--clip_norm', type=float, default=10)
362 | 
363 |     parser.add_argument('--no_cuda', action='store_true')
364 |     parser.add_argument('--batch_size', type=int, default=2)
365 |     parser.add_argument('--learning_rate', type=float, default=0.001)
366 | 
367 |     parser.add_argument('--device', type=str, default='cuda')
368 | 
369 |     args = parser.parse_args()
370 | 
371 |     train(args)
372 | 
373 | 
--------------------------------------------------------------------------------
/transformer/user_modeling/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | This script handles the training process.
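An illustrative invocation (the file paths below are placeholders, not files shipped with
the repository; -data_dev, -vocab and -data_dev_combined are required by the parser, and
-log / -save_model should include a directory component because the script creates those
directories):

    python train.py -data_train data/train.json -data_dev data/dev.json \
        -data_dev_combined data/dev.combined.json -vocab data/vocab.json \
        -save_model models/user_model -log logs/user_model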
3 | ''' 4 | 5 | 6 | import argparse 7 | import math 8 | import time 9 | import os 10 | from tqdm import tqdm 11 | import torch 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import torch.utils.data 15 | # import transformer.Constants as Constants 16 | from dataset import get_loader, load_ori_token_data, get_loader_test, load_ori_token_data_new 17 | from Beam import beam_search 18 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 19 | from pycocoevalcap.bleu.bleu import Bleu 20 | import numpy as np 21 | from build_vocab import Vocabulary 22 | # from model import Neural_Naturalist 23 | import torchvision.transforms as transforms 24 | from nltk.translate.bleu_score import sentence_bleu 25 | 26 | from torch.optim.lr_scheduler import StepLR 27 | from Optim import NoamOpt, get_std_opt 28 | 29 | from Models import get_model, create_masks 30 | from torch.autograd import Variable 31 | from pytorchtools import EarlyStopping 32 | 33 | # from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer 34 | 35 | # # from pycocoevalcap.meteor.meteor import Meteor 36 | # from pycocoevalcap.rouge.rouge import Rouge 37 | # from pycocoevalcap.ciderd.ciderD import CiderD 38 | # from pycocoevalcap.cider.cider import Cider 39 | from test import test 40 | 41 | # from dataset import TranslationDataset, paired_collate_fn 42 | # from transformer.Models import Transformer 43 | # from transformer.Optim import ScheduledOptim 44 | 45 | """ 46 | key_padding_mask should be a ByteTensor where True values are positions 47 | that should be masked with float('-inf') and False values will be unchanged 48 | 49 | """ 50 | 51 | Constants_PAD = 0 52 | 53 | # def nopeak_mask(size): 54 | # np_mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8') 55 | # np_mask = Variable(torch.from_numpy(np_mask) == 0) 56 | 57 | # return np_mask 58 | 59 | # def create_masks(trg): 60 | # # src_mask = (src != Constants_PAD.unsqueeze(-2) 61 | 62 | # if trg is not None: 63 | # trg_mask = (trg != Constants_PAD).unsqueeze(-2) 64 | # size = trg.size(1) # get seq_len for matrix 65 | # np_mask = nopeak_mask(size).to(trg_mask.device) 66 | 67 | # trg_mask = trg_mask & np_mask 68 | 69 | # else: 70 | # trg_mask = None 71 | 72 | # return trg_mask 73 | 74 | def cal_performance(pred, gold, smoothing=False): 75 | ''' Apply label smoothing if needed 76 | pred:(batch_size, sequence_len, vocab_size) 77 | gold:(batch_size, seq_len) (indices) 78 | ''' 79 | 80 | loss = cal_loss(pred, gold, smoothing) 81 | 82 | #greedy decoding 83 | pred = pred.max(2)[1]# torch.max() return (values, indices) 84 | #shape:(batch_size, sequence_len) 85 | 86 | # gold = gold.transpose(0,1) #shape:(sequence_len, batch_size) 87 | 88 | non_pad_mask = gold.ne(Constants_PAD)#Compute input!=other element-wise 89 | n_correct = pred.eq(gold) 90 | n_correct = n_correct.masked_select(non_pad_mask).sum().item() 91 | 92 | return loss, n_correct 93 | # return loss 94 | 95 | def calculate_bleu(tgt, logits, vocab): 96 | """ 97 | reference = [['this', 'is', 'small', 'test']] 98 | candidate = ['this', 'is', 'a', 'test'] 99 | bleu_4 = sentence_bleu(reference, candidate, weights=(0.25, 0.25, 0.25, 0.25)) 100 | 101 | """ 102 | 103 | # TODO: Batched Beam Search 104 | # Therefore, do not use a batch_size greater than 1 - IMPORTANT! 105 | 106 | # Lists to store references (true captions), and hypothesis (prediction) for each image 107 | # If for n images, we have n hypotheses, and references a, b, c... 
for each image, we need -
108 |     # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
109 | 
110 |     word_map = vocab.word2idx
111 |     special = {word_map['<start>'], word_map['<end>'], word_map['<pad>']}  # token names assumed to match build_vocab.py
112 |     pred = logits.max(2)[1]
113 | 
114 |     references = list()
115 |     hypotheses = list()
116 | 
117 |     img_caps = tgt.tolist()
118 |     img_captions = list(
119 |         map(lambda c: [w for w in c if w not in special],
120 |             img_caps))  # remove <start>, <end> and <pad> tokens
121 |     for ref in img_captions:
122 |         references.append([ref])  # one list of references per image
123 |     # Hypotheses: greedy predictions, with the same special tokens stripped
124 |     for seq in pred.tolist():
125 |         hypotheses.append([w for w in seq if w not in special])
126 |     bleu4 = np.mean([sentence_bleu(refs, hyp) for refs, hyp in zip(references, hypotheses)])
127 | 
128 |     return bleu4
129 | 
130 | 
131 | 
132 | def cal_loss(pred, gold, smoothing):
133 |     ''' Calculate cross entropy loss, apply label smoothing if needed. '''
134 |     # the loss uses reduction='sum' because we later divide the total loss by the total number of non-pad tokens
135 |     # input (Tensor): (N, C) where C = number of classes, or (N, C, H, W)
136 |     # target (Tensor): (N), where each value satisfies 0 <= target[i] <= C-1, or (N, d_1, d_2, ..., d_K)
137 |     # since pred has shape (batch_size, seq_len, vocab_size), we either transpose(1, 2) or flatten it
138 | 
139 |     gold = gold.contiguous().view(-1)
140 |     pred = pred.contiguous().view(-1, pred.size(-1))
141 | 
142 |     if smoothing:
143 |         eps = 0.1
144 |         n_class = pred.size(1)
145 | 
146 |         one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
147 |         one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
148 |         log_prb = F.log_softmax(pred, dim=1)
149 | 
150 |         non_pad_mask = gold.ne(Constants_PAD)
151 |         loss = -(one_hot * log_prb).sum(dim=1)
152 |         loss = loss.masked_select(non_pad_mask).sum()  # average later
153 |     else:
154 |         loss = F.cross_entropy(pred, gold, ignore_index=Constants_PAD, reduction='sum')  # combines log_softmax and nll_loss in a single function
155 | 
156 |     return loss
157 | 
158 | def get_subsequent_mask(seq):
159 |     ''' For masking out the subsequent info. '''
160 | 
161 |     sz_b, len_s = seq.size()
162 |     mask = (torch.triu(torch.ones(len_s, len_s)) == 1).transpose(0, 1)
163 |     mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
164 |     # subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls
165 | 
166 |     return mask
167 | 
168 | 
169 | def train_epoch(model, training_data, optimizer, device, smoothing=False):
170 |     ''' Epoch operation in training phase'''
171 | 
172 |     model.train()
173 | 
174 |     total_loss = 0
175 |     n_word_total = 0
176 |     n_word_correct = 0
177 | 
178 |     for batch in tqdm(
179 |             training_data, mininterval=2,
180 |             desc=' - (Training) ', leave=False):
181 | 
182 |         # prepare data
183 |         image0, image1, captions, gold, image0_attribute, image1_attribute = map(lambda x: x.to(device), batch)
184 | 
185 |         """[src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions
186 |             that should be masked with float('-inf') and False values will be unchanged.
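            (Here the decoder mask comes from create_masks(captions), which presumably mirrors
            the commented-out helpers near the top of this file: the padding mask combined with
            an upper-triangular "no-peek" mask, so each position only attends to itself and
            earlier tokens.)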
187 | This mask ensures that no information will be taken from position i if 188 | it is masked, and has a separate mask for each sequence in a batch.""" 189 | 190 | # caption_padding_mask = captions.eq(Constants_PAD) 191 | 192 | # look_ahead_mask = get_subsequent_mask(captions).to(device) 193 | 194 | # trg_input = caption[:, :-1] 195 | trg_mask = create_masks(captions).to(device) 196 | 197 | # ys = trg[:, 1:].contiguous().view(-1) 198 | 199 | # forward 200 | optimizer.optimizer.zero_grad() 201 | 202 | logits = model(image0, image1, captions, trg_mask, image0_attribute, image1_attribute) 203 | # logits = model(image1, image2, captions, look_ahead_mask, caption_padding_mask)#(batch_size, sequence_len, vocab_size) 204 | 205 | # backward 206 | loss, n_correct = cal_performance(logits, gold, smoothing=smoothing) 207 | 208 | loss.backward() 209 | 210 | torch.nn.utils.clip_grad_norm_(model.get_trainable_parameters(), max_norm=5) 211 | # update parameters 212 | optimizer.step() 213 | 214 | # note keeping 215 | total_loss += loss.item() 216 | 217 | non_pad_mask = gold.ne(Constants_PAD) #pad = 0 218 | n_word = non_pad_mask.sum().item() 219 | n_word_total += n_word 220 | n_word_correct += n_correct 221 | 222 | loss_per_word = total_loss/n_word_total 223 | accuracy = n_word_correct/n_word_total 224 | return loss_per_word, accuracy 225 | 226 | def eval_epoch(model, validation_data, device, vocab): 227 | ''' Epoch operation in evaluation phase ''' 228 | 229 | model.eval() 230 | 231 | total_loss = 0 232 | n_word_total = 0 233 | n_word_correct = 0 234 | 235 | with torch.no_grad(): 236 | for batch in tqdm( 237 | validation_data, mininterval=2, 238 | desc=' - (Validation) ', leave=False): 239 | 240 | # prepare data 241 | image0, image1, captions, gold, image0_attribute, image1_attribute = map(lambda x: x.to(device), batch) 242 | 243 | """[src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions 244 | that should be masked with float('-inf') and False values will be unchanged. 
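            (Note: this evaluation teacher-forces the ground-truth captions and reports
            token-level loss/accuracy; generation quality is measured separately in
            eval_epoch_bleu below, which decodes with beam search.)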
245 | This mask ensures that no information will be taken from position i if 246 | it is masked, and has a separate mask for each sequence in a batch.""" 247 | # caption_padding_mask = captions.eq(Constants_PAD) 248 | 249 | # look_ahead_mask = get_subsequent_mask(captions).to(device) 250 | 251 | # # forward 252 | # logits = model(image1, image2, captions, look_ahead_mask, caption_padding_mask) 253 | 254 | trg_mask = create_masks(captions).to(device) 255 | 256 | logits = model(image0, image1, captions, trg_mask, image0_attribute, image1_attribute) 257 | 258 | loss, n_correct = cal_performance(logits, gold, smoothing=False) 259 | 260 | # bleu_4 = calculate_bleu(captions, logits, vocab) 261 | 262 | # note keeping 263 | total_loss += loss.item() 264 | 265 | non_pad_mask = gold.ne(Constants_PAD) 266 | n_word = non_pad_mask.sum().item() 267 | n_word_total += n_word 268 | n_word_correct += n_correct 269 | 270 | loss_per_word = total_loss/n_word_total 271 | accuracy = n_word_correct/n_word_total 272 | return loss_per_word, accuracy 273 | 274 | 275 | def eval_epoch_bleu(model, validation_data, device, vocab, list_of_refs_dev, args): 276 | ''' Epoch operation in evaluation phase ''' 277 | 278 | model.eval() 279 | 280 | total_loss = 0 281 | n_word_total = 0 282 | n_word_correct = 0 283 | 284 | hypotheses = {} 285 | count = 0 286 | 287 | with torch.no_grad(): 288 | for batch in tqdm( 289 | validation_data, mininterval=2, 290 | desc=' - (Validation) ', leave=False): 291 | 292 | # prepare data 293 | image0, image1, image0_attribute, image1_attribute = map(lambda x: x.to(device), batch) 294 | 295 | """[src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions 296 | that should be masked with float('-inf') and False values will be unchanged. 
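            (Note: decoding is done one example at a time via beam_search, which appears to be
            why main() builds this loader with batch size 1; hypotheses and list_of_refs_dev are
            index-aligned dicts of caption lists in the format pycocoevalcap's Bleu scorer expects.)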
297 |             This mask ensures that no information will be taken from position i if
298 |             it is masked, and has a separate mask for each sequence in a batch."""
299 | 
300 |             hyp = beam_search(image0, image1, model, args, vocab, image0_attribute, image1_attribute)
301 |             # keep only the words before the end-of-sentence token (name assumed to match build_vocab.py)
302 |             hyp = hyp.split('<end>')[0].strip()
303 | 
304 |             hypotheses[count] = [hyp]
305 | 
306 |             count += 1
307 | 
308 |     scorer = Bleu(4)
309 | 
310 |     score, _ = scorer.compute_score(list_of_refs_dev, hypotheses)
311 | 
312 |     return score
313 | 
314 | def train(model, training_data, validation_data, optimizer, args, vocab, list_of_refs_dev, validation_data_combined):
315 |     ''' Start training '''
316 | 
317 |     early_stopping_with_saving = EarlyStopping(patience=args.patience, verbose=True, args=args)
318 | 
319 |     log_train_file = None
320 |     log_valid_file = None
321 | 
322 |     if args.log:
323 |         log_train_file = args.log + '.train.log'
324 |         log_valid_file = args.log + '.valid.log'
325 |         log_valid_bleu_file = args.log + '.valid.bleu.log'
326 | 
327 |         print('[Info] Training performance will be written to file: {} and {}'.format(
328 |             log_train_file, log_valid_file))
329 | 
330 |         with open(log_train_file, 'w') as log_tf, open(log_valid_file, 'w') as log_vf:
331 |             log_tf.write('epoch,loss,ppl,accuracy\n')
332 |             log_vf.write('epoch,loss,ppl,accuracy\n')
333 | 
334 |         with open(log_valid_bleu_file, 'w') as log_vf_bleu:
335 |             log_vf_bleu.write('epoch,bleu1,bleu2,bleu3,bleu4\n')
336 | 
337 |     valid_accus = []
338 | 
339 |     best_valid_score = float('-inf')
340 | 
341 |     for epoch_i in range(args.epoch):
342 | 
343 |         if early_stopping_with_saving.early_stop:
344 |             print("Early stopping")
345 |             break
346 | 
347 |         print('Epoch {}, lr {}'.format(epoch_i, optimizer.optimizer.param_groups[0]['lr']))
348 | 
349 |         start = time.time()
350 | 
351 |         train_loss, train_accu = train_epoch(
352 |             model, training_data, optimizer, args.device)
353 | 
354 |         print(' - (Training) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\
355 |               'elapse: {elapse:3.3f} min'.format(
356 |                   ppl=math.exp(min(train_loss, 100)), accu=100*train_accu,
357 |                   elapse=(time.time()-start)/60))
358 | 
359 |         start = time.time()
360 | 
361 |         valid_loss, valid_accu = eval_epoch(model, validation_data, args.device, vocab)
362 | 
363 |         print(' - (Validation) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\
364 |               'elapse: {elapse:3.3f} min'.format(
365 |                   ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu,
366 |                   elapse=(time.time()-start)/60))
367 | 
368 |         if epoch_i != 0 and epoch_i % args.bleu_valid_every_n == 0:
369 | 
370 |             bleu_valid = eval_epoch_bleu(model, validation_data_combined, args.device, vocab, list_of_refs_dev, args)
371 | 
372 |             print(' - (Validation) bleu-1: {bleu1: 8.5f}, bleu-2: {bleu2: 8.5f}, bleu-3: {bleu3: 8.5f}, bleu-4: {bleu4: 8.5f}'.format(\
373 |                 bleu1=bleu_valid[0], bleu2=bleu_valid[1], bleu3=bleu_valid[2], bleu4=bleu_valid[3]))
374 | 
375 |             # if args.save_model:
376 |             #     if args.save_mode == 'all':
377 |             #         model_name = args.save_model + '_accu_{accu:3.3f}.chkpt'.format(accu=100*valid_accu)
378 |             #         torch.save(checkpoint, model_name)
379 |             #     elif args.save_mode == 'best':
380 |             early_stopping_with_saving(bleu_valid[2], model, epoch_i)
381 | 
382 |             with open(log_valid_bleu_file, 'a') as log_vf_bleu:
383 |                 log_vf_bleu.write('{epoch},{bleu1: 8.5f},{bleu2: 8.5f},{bleu3: 8.5f},{bleu4: 8.5f}\n'.format( \
384 |                     epoch=epoch_i, bleu1=bleu_valid[0], bleu2=bleu_valid[1], bleu3=bleu_valid[2], bleu4=bleu_valid[3]))
385 |             # model_name = args.save_model + '.chkpt'
386 |             # if bleu4_valid >= best_valid_score:
387 |             # 
best_valid_score = bleu4_valid 388 | # torch.save(checkpoint, model_name) 389 | # print(' - [Info] The checkpoint file has been updated.') 390 | 391 | 392 | # checkpoint = { 393 | # 'model': model.state_dict(), 394 | # 'settings': args, 395 | # 'epoch': epoch_i} 396 | # torch.save(checkpoint, args.save_model + '.latest.chkpt') 397 | 398 | 399 | 400 | if log_train_file and log_valid_file: 401 | with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf: 402 | log_tf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format( 403 | epoch=epoch_i, loss=train_loss, 404 | ppl=math.exp(min(train_loss, 100)), accu=100*train_accu)) 405 | log_vf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format( 406 | epoch=epoch_i, loss=valid_loss, 407 | ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu)) 408 | 409 | 410 | 411 | def count_parameters(model): 412 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 413 | 414 | 415 | 416 | def main(): 417 | ''' Main function ''' 418 | parser = argparse.ArgumentParser() 419 | 420 | parser.add_argument('-data_train', type=str, default="") 421 | parser.add_argument('-data_dev', required=True) 422 | parser.add_argument('-data_test', type=str, default="") 423 | parser.add_argument('-vocab', required=True) 424 | 425 | parser.add_argument('-epoch', type=int, default=10000) 426 | parser.add_argument('-batch_size', type=int, default=64) 427 | 428 | #parser.add_argument('-d_word_vec', type=int, default=512) 429 | parser.add_argument('-d_model', type=int, default=512) 430 | # parser.add_argument('-d_inner_hid', type=int, default=2048) 431 | # parser.add_argument('-d_k', type=int, default=64) 432 | # parser.add_argument('-d_v', type=int, default=64) 433 | 434 | parser.add_argument('-n_heads', type=int, default=8) 435 | parser.add_argument('-n_layers', type=int, default=6) 436 | parser.add_argument('-n_warmup_steps', type=int, default=4000) 437 | 438 | parser.add_argument('-dropout', type=float, default=0.1) 439 | # parser.add_argument('-embs_share_weight', action='store_true') 440 | # parser.add_argument('-proj_share_weight', action='store_true') 441 | 442 | parser.add_argument('-log', default=None) 443 | parser.add_argument('-save_model', default=None) 444 | parser.add_argument('-save_mode', type=str, choices=['all', 'best'], default='best') 445 | 446 | parser.add_argument('-no_cuda', action='store_true') 447 | parser.add_argument('-label_smoothing', action='store_true') 448 | parser.add_argument('-num_workers', type=int, default=1) 449 | 450 | parser.add_argument('-cnn_name', type=str, default="resnet101") 451 | parser.add_argument('-cnn_pretrained_model', type=str, default="") 452 | parser.add_argument('-joint_enc_func', type=str, default="element_multiplication") 453 | # parser.add_argument('-comparative_module_name', type=str, default="transformer_encoder") 454 | parser.add_argument('-lr', type=float, default=0.01) 455 | # parser.add_argument('-step_size', type=int, default=1000) 456 | # parser.add_argument('-gamma', type=float, default=0.9) 457 | parser.add_argument('-crop_size', type=int, default=224) 458 | parser.add_argument('-max_seq_len', type=int, default=64) 459 | parser.add_argument('-attribute_len', type=int, default=5) 460 | 461 | parser.add_argument('-pretrained_model', type=str, default="") 462 | 463 | parser.add_argument('-rank_alpha', type=float, default=1.0) 464 | parser.add_argument('-patience', type=int, default=7) 465 | parser.add_argument('-bleu_valid_every_n', type=int, default=5) 466 | 
parser.add_argument('-data_dev_combined', required=True) 467 | parser.add_argument('-beam_size', type=int, default=5) 468 | parser.add_argument('-seed', type=int, default=0) 469 | parser.add_argument('-attribute_vocab_size', type=int, default=1000) 470 | parser.add_argument('-add_attribute', action='store_true') 471 | 472 | 473 | 474 | args = parser.parse_args() 475 | args.cuda = not args.no_cuda 476 | args.d_word_vec = args.d_model 477 | 478 | args.load_weights = False 479 | if args.pretrained_model: 480 | args.load_weights = True 481 | 482 | np.random.seed(0) 483 | torch.manual_seed(0) 484 | args.device = torch.device('cuda' if args.cuda else 'cpu') 485 | 486 | log_path = args.log.split("/") 487 | log_path = "/".join(log_path[:-1]) 488 | if not os.path.exists(log_path): 489 | os.makedirs(log_path) 490 | 491 | model_path = args.save_model.split("/") 492 | model_path = "/".join(model_path[:-1]) 493 | if not os.path.exists(model_path): 494 | os.makedirs(model_path) 495 | 496 | print(args) 497 | 498 | if args.data_train: 499 | print("======================================start training======================================") 500 | transform = transforms.Compose([ 501 | transforms.RandomCrop(args.crop_size), 502 | transforms.RandomHorizontalFlip(), 503 | transforms.ToTensor(), 504 | transforms.Normalize((0.485, 0.456, 0.406), 505 | (0.229, 0.224, 0.225))]) 506 | 507 | transform_dev = transforms.Compose([ 508 | transforms.CenterCrop(args.crop_size), 509 | transforms.ToTensor(), 510 | transforms.Normalize((0.485, 0.456, 0.406), 511 | (0.229, 0.224, 0.225))]) 512 | 513 | vocab = Vocabulary() 514 | 515 | vocab.load(args.vocab) 516 | 517 | args.vocab_size = len(vocab) 518 | 519 | # Build data loader 520 | data_loader_training = get_loader(args.data_train, 521 | vocab, transform, 522 | args.batch_size, shuffle=True, num_workers=args.num_workers, \ 523 | max_seq_len=args.max_seq_len,\ 524 | attribute_len=args.attribute_len 525 | ) 526 | 527 | data_loader_dev = get_loader(args.data_dev, 528 | vocab, transform_dev, 529 | args.batch_size, shuffle=False, num_workers=args.num_workers, \ 530 | max_seq_len=args.max_seq_len,\ 531 | attribute_len=args.attribute_len 532 | ) 533 | 534 | data_loader_bleu = get_loader_test(args.data_dev_combined, 535 | vocab, transform_dev, 536 | 1, shuffle=False, 537 | attribute_len=args.attribute_len 538 | ) 539 | 540 | list_of_refs_dev = load_ori_token_data_new(args.data_dev_combined) 541 | 542 | model = get_model(args, load_weights=False) 543 | 544 | 545 | print(count_parameters(model)) 546 | 547 | # print(model.get_trainable_parameters()) 548 | # init_lr = np.power(args.d_model, -0.5) 549 | 550 | # optimizer = torch.optim.Adam(model.get_trainable_parameters(), lr=init_lr) 551 | optimizer = get_std_opt(model, args) 552 | 553 | train( model, data_loader_training, data_loader_dev, optimizer ,args, vocab, list_of_refs_dev, data_loader_bleu) 554 | 555 | if args.data_test: 556 | print("======================================start testing==============================") 557 | args.pretrained_model = args.save_model 558 | test(args) 559 | 560 | 561 | 562 | 563 | if __name__ == '__main__': 564 | main() 565 | --------------------------------------------------------------------------------