├── .gitignore ├── Corrnet ├── Corrnet_Pytorch.ipynb ├── README.md └── word2vec.ipynb ├── Demo_Web_App ├── README.md ├── config.py ├── directory.py ├── models │ ├── __pycache__ │ │ ├── bi_lstm.cpython-37.pyc │ │ ├── bi_lstm.cpython-38.pyc │ │ ├── mobilenet.cpython-37.pyc │ │ ├── mobilenet.cpython-38.pyc │ │ ├── model.cpython-37.pyc │ │ ├── model.cpython-38.pyc │ │ ├── resnet.cpython-37.pyc │ │ └── resnet.cpython-38.pyc │ ├── bi_lstm.py │ ├── mobilenet.py │ ├── model.py │ └── resnet.py ├── saved_model │ ├── test_sort.pkl │ └── word_to_index.pkl ├── static │ └── temp │ │ └── 1620480640.jpg ├── templates │ ├── image_results.html │ ├── index.html │ └── text_results.html ├── test_config.py ├── tester.py └── web_app.py ├── README.md ├── Siamese ├── README.md ├── Word2vecgensim.ipynb └── siamese_network.ipynb ├── assets ├── correlational.png ├── deepcmpl.png ├── i2t.PNG ├── siamese.png └── t2i.PNG └── deep_cmpl_model ├── .gitignore ├── README.md ├── code ├── __pycache__ │ ├── config.cpython-37.pyc │ ├── config.cpython-38.pyc │ ├── test.cpython-37.pyc │ ├── test.cpython-38.pyc │ ├── test_config.cpython-37.pyc │ ├── test_config.cpython-38.pyc │ ├── test_params.cpython-37.pyc │ ├── train_config.cpython-37.pyc │ ├── train_config.cpython-38.pyc │ └── train_params.cpython-37.pyc ├── datasets │ ├── __pycache__ │ │ ├── directory.cpython-37.pyc │ │ ├── fashion.cpython-37.pyc │ │ ├── pedes.cpython-37.pyc │ │ └── pedes.cpython-38.pyc │ ├── data.sh │ ├── directory.py │ ├── fashion.py │ └── preprocess.py ├── models │ ├── __pycache__ │ │ ├── bi_lstm.cpython-37.pyc │ │ ├── bi_lstm.cpython-38.pyc │ │ ├── eff_net.cpython-37.pyc │ │ ├── efficient_net.cpython-37.pyc │ │ ├── mobilenet.cpython-37.pyc │ │ ├── mobilenet.cpython-38.pyc │ │ ├── model.cpython-37.pyc │ │ ├── model.cpython-38.pyc │ │ ├── resnet.cpython-37.pyc │ │ └── resnet.cpython-38.pyc │ ├── bi_lstm.py │ ├── eff_net.py │ ├── mobilenet.py │ ├── model.py │ └── resnet.py ├── scripts │ ├── tester.py │ └── trainer.py ├── test.py ├── test_params.py ├── train.py ├── train_params.py └── utils │ ├── __pycache__ │ ├── directory.cpython-37.pyc │ ├── directory.cpython-38.pyc │ ├── helpers.cpython-37.pyc │ ├── metric.cpython-37.pyc │ └── metric.cpython-38.pyc │ ├── directory.py │ ├── helpers.py │ └── metric.py ├── data ├── images.csv ├── make_json.py ├── processed_data │ ├── metadata_info.txt │ ├── test_reid.json │ ├── test_sort.pkl │ ├── train_reid.json │ ├── train_sort.pkl │ ├── val_reid.json │ ├── val_sort.pkl │ ├── word_counts.txt │ └── word_to_index.pkl └── reid_raw.json └── deep_cmpl_jupyter.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | <<<<<<< HEAD 2 | # Images of the dataset 3 | dataset 4 | 5 | # Checkpoint directories 6 | model_data 7 | logs 8 | 9 | #zip files 10 | *.zip 11 | 12 | # Misccellaneous 13 | data/f8k 14 | 15 | *.pth.tar 16 | Demo_Web_App/static/dataset/ 17 | ======= 18 | dataset/ 19 | drive/ 20 | dgpu/drive/ 21 | dgpu/data/logs 22 | >>>>>>> 7d7a4ac943534fa53775eed173be3313ae889b49 23 | -------------------------------------------------------------------------------- /Corrnet/README.md: -------------------------------------------------------------------------------- 1 | # Corrnet 2 | 3 | > The way to load, train and test this model has been shown in the jupyter notebooks. The code has also been well documented detailing the use of every function. 4 | 5 | ## Instructions to Run Code 6 | 1. The text csv file is to be generated using the word2vec.ipynb notebook, for some text data. 7 | 2. 
Images must be added in the appropriate folder. 8 | 3. Set all the paths in Corrnet_Pytorch.ipynb. 9 | 4. Then, simply run the Corrnet_Pytorch.ipynb notebook. -------------------------------------------------------------------------------- /Demo_Web_App/README.md: -------------------------------------------------------------------------------- 1 | # Image-Text-Retrieval-Web-App 2 | Flask Web App for ES-654 Machine Learning course project 3 | 4 | 5 | ## Instructions to Run 6 | 7 | To run the web app, follow the instructions in the following Colab Notebook: [link](https://colab.research.google.com/drive/1Obk2DQpHAyQANObUnth_gWuxi-CR5lMn?usp=sharing) 8 | 9 | ## Requirements 10 | 11 | * Flask 12 | * Flask-ngrok 13 | * Werkzeug 14 | * PyTorch 15 | * imageio 16 | -------------------------------------------------------------------------------- /Demo_Web_App/config.py: -------------------------------------------------------------------------------- 1 | import torch.backends.cudnn as cudnn 2 | import torch.nn as nn 3 | import torch 4 | from models.model import Model 5 | from directory import check_file 6 | import random 7 | import numpy as np 8 | 9 | 10 | def network_config(args, split='train', param=None, resume=False, model_path=None, ema=False): 11 | network = Model(args) 12 | network = nn.DataParallel(network).cuda() 13 | cudnn.benchmark = True 14 | args.start_epoch = 0 15 | 16 | # process network params 17 | if resume: 18 | check_file(model_path, 'model_file') 19 | checkpoint = torch.load(model_path) 20 | args.start_epoch = checkpoint['epoch'] + 1 21 | # best_prec1 = checkpoint['best_prec1'] 22 | #network.load_state_dict(checkpoint['state_dict']) 23 | network_dict = checkpoint['network'] 24 | # if ema: 25 | # logging.info('==> EMA Loading') 26 | # network_dict.update(checkpoint['network_ema']) 27 | network.load_state_dict(network_dict) 28 | print('==> Loading checkpoint "{}"'.format(model_path)) 29 | else: 30 | # pretrained 31 | if model_path is not None: 32 | print('==> Loading from pretrained models') 33 | network_dict = network.state_dict() 34 | if args.image_model == 'mobilenet_v1': 35 | cnn_pretrained = torch.load(model_path)['state_dict'] 36 | start = 7 37 | else: 38 | cnn_pretrained = torch.load(model_path) 39 | start = 0 40 | # process keyword of pretrained model 41 | prefix = 'module.image_model.' 
42 | pretrained_dict = {prefix + k[start:] :v for k,v in cnn_pretrained.items()} 43 | pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in network_dict} 44 | network_dict.update(pretrained_dict) 45 | network.load_state_dict(network_dict) 46 | 47 | # process optimizer params 48 | if split == 'test': 49 | optimizer = None 50 | else: 51 | # optimizer 52 | # different params for different part 53 | cnn_params = list(map(id, network.module.image_model.parameters())) 54 | other_params = filter(lambda p: id(p) not in cnn_params, network.parameters()) 55 | other_params = list(other_params) 56 | if param is not None: 57 | other_params.extend(list(param)) 58 | param_groups = [{'params':other_params}, 59 | {'params':network.module.image_model.parameters(), 'weight_decay':args.wd}] 60 | optimizer = torch.optim.Adam( 61 | param_groups, 62 | lr = args.lr, betas=(args.adam_alpha, args.adam_beta), eps=args.epsilon) 63 | if resume: 64 | optimizer.load_state_dict(checkpoint['optimizer']) 65 | 66 | print('Total params: %2.fM' % (sum(p.numel() for p in network.parameters()) / 1000000.0)) 67 | # seed 68 | manualSeed = random.randint(1, 10000) 69 | random.seed(manualSeed) 70 | np.random.seed(manualSeed) 71 | torch.manual_seed(manualSeed) 72 | torch.cuda.manual_seed_all(manualSeed) 73 | 74 | return network, optimizer 75 | -------------------------------------------------------------------------------- /Demo_Web_App/directory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | def makedir(root): 5 | if not os.path.exists(root): 6 | os.makedirs(root) 7 | 8 | 9 | def write_json(data, root): 10 | with open(root, 'w') as f: 11 | json.dump(data, f) 12 | 13 | 14 | def check_exists(root): 15 | if os.path.exists(root): 16 | return True 17 | return False 18 | 19 | def check_file(root, keyword): 20 | if not os.path.isfile(root): 21 | raise RuntimeError('===> No {} in {}'.format(keyword, root)) 22 | -------------------------------------------------------------------------------- /Demo_Web_App/models/__pycache__/bi_lstm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/bi_lstm.cpython-37.pyc -------------------------------------------------------------------------------- /Demo_Web_App/models/__pycache__/bi_lstm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/bi_lstm.cpython-38.pyc -------------------------------------------------------------------------------- /Demo_Web_App/models/__pycache__/mobilenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/mobilenet.cpython-37.pyc -------------------------------------------------------------------------------- /Demo_Web_App/models/__pycache__/mobilenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/mobilenet.cpython-38.pyc 
-------------------------------------------------------------------------------- /Demo_Web_App/models/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /Demo_Web_App/models/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /Demo_Web_App/models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /Demo_Web_App/models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /Demo_Web_App/models/bi_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | 5 | seed_num = 223 6 | torch.manual_seed(seed_num) 7 | random.seed(seed_num) 8 | 9 | """ 10 | Neural Networks model : Bidirection LSTM 11 | """ 12 | 13 | 14 | class BiLSTM(nn.Module): 15 | def __init__(self, args): 16 | super(BiLSTM, self).__init__() 17 | 18 | self.hidden_dim = args.num_lstm_units 19 | 20 | V = args.vocab_size 21 | D = args.embedding_size 22 | 23 | # word embedding 24 | self.embed = nn.Embedding(V, D, padding_idx=0) 25 | 26 | self.bilstm = nn.ModuleList() 27 | self.bilstm.append(nn.LSTM(D, args.num_lstm_units, num_layers=1, dropout=0, bidirectional=False, bias=False)) 28 | 29 | self.bidirectional = args.bidirectional 30 | if self.bidirectional: 31 | self.bilstm.append(nn.LSTM(D, args.num_lstm_units, num_layers=1, dropout=0, bidirectional=False, bias=False)) 32 | 33 | 34 | def forward(self, text, text_length): 35 | embed = self.embed(text) 36 | 37 | # unidirectional lstm 38 | bilstm_out = self.bilstm_out(embed, text_length, 0) 39 | 40 | if self.bidirectional: 41 | index_reverse = list(range(embed.shape[0]-1, -1, -1)) 42 | index_reverse = torch.LongTensor(index_reverse).cuda() 43 | embed_reverse = embed.index_select(0, index_reverse) 44 | text_length_reverse = text_length.index_select(0, index_reverse) 45 | bilstm_out_bidirection = self.bilstm_out(embed_reverse, text_length_reverse, 1) 46 | bilstm_out_bidirection_reverse = bilstm_out_bidirection.index_select(0, index_reverse) 47 | bilstm_out = torch.cat([bilstm_out, bilstm_out_bidirection_reverse], dim=2) 48 | bilstm_out, _ = torch.max(bilstm_out, dim=1) 49 | bilstm_out = bilstm_out.unsqueeze(2).unsqueeze(2) 50 | return bilstm_out 51 | 52 | 53 | def bilstm_out(self, embed, text_length, index): 54 | 55 | _, idx_sort = torch.sort(text_length, dim=0, descending=True) 56 | _, idx_unsort = 
torch.sort(idx_sort, dim=0) 57 | 58 | embed_sort = embed.index_select(0, idx_sort) 59 | length_list = text_length[idx_sort] 60 | pack = nn.utils.rnn.pack_padded_sequence(embed_sort, length_list.cpu(), batch_first=True) 61 | 62 | bilstm_sort_out, _ = self.bilstm[index](pack) 63 | bilstm_sort_out = nn.utils.rnn.pad_packed_sequence(bilstm_sort_out, batch_first=True) 64 | bilstm_sort_out = bilstm_sort_out[0] 65 | 66 | bilstm_out = bilstm_sort_out.index_select(0, idx_unsort) 67 | 68 | return bilstm_out 69 | 70 | 71 | def weight_init(self, m): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.xavier_uniform_(m.weight.data, 1) 74 | nn.init.constant(m.bias.data, 0) 75 | -------------------------------------------------------------------------------- /Demo_Web_App/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | """ 5 | Imported by https://github.com/marvis/pytorch-mobilenet/blob/master/main.py 6 | """ 7 | 8 | 9 | class MobileNetV1(nn.Module): 10 | def __init__(self, dropout_keep_prob=0.999): 11 | super(MobileNetV1, self).__init__() 12 | self.dropout_keep_prob = dropout_keep_prob 13 | self.dropout = nn.Dropout(1 - dropout_keep_prob) 14 | def conv_bn(inp, oup, stride): 15 | return nn.Sequential( 16 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 17 | nn.BatchNorm2d(oup), 18 | nn.ReLU6(inplace=True) 19 | ) 20 | 21 | def conv_dw(inp, oup, stride): 22 | return nn.Sequential( 23 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 24 | nn.BatchNorm2d(inp), 25 | nn.ReLU6(inplace=True), 26 | 27 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 28 | nn.BatchNorm2d(oup), 29 | nn.ReLU6(inplace=True), 30 | ) 31 | 32 | self.model = nn.Sequential( 33 | conv_bn(3, 32, 2), 34 | conv_dw(32, 64, 1), 35 | conv_dw(64, 128, 2), 36 | conv_dw(128, 128, 1), 37 | conv_dw(128, 256, 2), 38 | conv_dw(256, 256, 1), 39 | conv_dw(256, 512, 2), 40 | conv_dw(512, 512, 1), 41 | conv_dw(512, 512, 1), 42 | conv_dw(512, 512, 1), 43 | conv_dw(512, 512, 1), 44 | conv_dw(512, 512, 1), 45 | conv_dw(512, 1024, 2), 46 | conv_dw(1024, 1024, 1), 47 | nn.AvgPool2d(7), 48 | ) 49 | 50 | 51 | def weight_init(self, m): 52 | if isinstance(m, nn.Conv2d): 53 | # truncated_normal_initializer in tensorflow 54 | nn.init.normal_(m.weight.data, std=0.09) 55 | #nn.init.constant(m.bias.data, 0) 56 | 57 | def forward(self, x): 58 | x = self.model(x) 59 | x = self.dropout(x) 60 | return x 61 | -------------------------------------------------------------------------------- /Demo_Web_App/models/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .bi_lstm import BiLSTM 3 | from .mobilenet import MobileNetV1 4 | from .resnet import resnet50 5 | 6 | 7 | class Model(nn.Module): 8 | def __init__(self, args): 9 | super(Model, self).__init__() 10 | if args.image_model == 'mobilenet_v1': 11 | self.image_model = MobileNetV1() 12 | self.image_model.apply(self.image_model.weight_init) 13 | elif args.image_model == 'resnet50': 14 | self.image_model = resnet50() 15 | elif args.image_model == 'resent101': 16 | self.image_model = resnet101() 17 | 18 | self.bilstm = BiLSTM(args) 19 | self.bilstm.apply(self.bilstm.weight_init) 20 | 21 | inp_size = 1024 22 | if args.image_model == 'resnet50' or args.image_model == 'resnet101': 23 | inp_size = 2048 24 | # shorten the tensor using 1*1 conv 25 | self.conv_images = nn.Conv2d(inp_size, args.feature_size, 1) 26 | self.conv_text = nn.Conv2d(1024, 
args.feature_size, 1) 27 | 28 | 29 | def forward(self, images, text, text_length): 30 | image_features = self.image_model(images) 31 | text_features = self.bilstm(text, text_length) 32 | image_embeddings, text_embeddings= self.build_joint_embeddings(image_features, text_features) 33 | 34 | return image_embeddings, text_embeddings 35 | 36 | 37 | def build_joint_embeddings(self, images_features, text_features): 38 | 39 | #images_features = images_features.permute(0,2,3,1) 40 | #text_features = text_features.permute(0,3,1,2) 41 | image_embeddings = self.conv_images(images_features).squeeze() 42 | text_embeddings = self.conv_text(text_features).squeeze() 43 | 44 | return image_embeddings, text_embeddings 45 | -------------------------------------------------------------------------------- /Demo_Web_App/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None): 65 | super(Bottleneck, self).__init__() 66 | self.conv1 = conv1x1(inplanes, planes) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = conv3x3(planes, planes, stride) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.conv3 = conv1x1(planes, planes * self.expansion) 71 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 
| residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, block, layers, num_classes=1000): 102 | super(ResNet, self).__init__() 103 | self.inplanes = 64 104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 105 | bias=False) 106 | self.bn1 = nn.BatchNorm2d(64) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 109 | self.layer1 = self._make_layer(block, 64, layers[0]) 110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 112 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 113 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 114 | self.fc = nn.Linear(512 * block.expansion, num_classes) 115 | 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 119 | elif isinstance(m, nn.BatchNorm2d): 120 | nn.init.constant_(m.weight, 1) 121 | nn.init.constant_(m.bias, 0) 122 | 123 | def _make_layer(self, block, planes, blocks, stride=1): 124 | downsample = None 125 | if stride != 1 or self.inplanes != planes * block.expansion: 126 | downsample = nn.Sequential( 127 | conv1x1(self.inplanes, planes * block.expansion, stride), 128 | nn.BatchNorm2d(planes * block.expansion), 129 | ) 130 | 131 | layers = [] 132 | layers.append(block(self.inplanes, planes, stride, downsample)) 133 | self.inplanes = planes * block.expansion 134 | for _ in range(1, blocks): 135 | layers.append(block(self.inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | x = self.bn1(x) 142 | x = self.relu(x) 143 | x = self.maxpool(x) 144 | 145 | x = self.layer1(x) 146 | x = self.layer2(x) 147 | x = self.layer3(x) 148 | x = self.layer4(x) 149 | 150 | x = self.avgpool(x) 151 | #x = x.view(x.size(0), -1) 152 | #x = self.fc(x) 153 | 154 | return x 155 | 156 | 157 | def resnet18(pretrained=False, **kwargs): 158 | """Constructs a ResNet-18 model. 159 | 160 | Args: 161 | pretrained (bool): If True, returns a model pre-trained on ImageNet 162 | """ 163 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 164 | if pretrained: 165 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 166 | return model 167 | 168 | 169 | def resnet34(pretrained=False, **kwargs): 170 | """Constructs a ResNet-34 model. 171 | 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 178 | return model 179 | 180 | 181 | def resnet50(pretrained=False, **kwargs): 182 | """Constructs a ResNet-50 model. 183 | 184 | Args: 185 | pretrained (bool): If True, returns a model pre-trained on ImageNet 186 | """ 187 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 188 | if pretrained: 189 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 190 | return model 191 | 192 | 193 | def resnet101(pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 
195 | 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 200 | if pretrained: 201 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 202 | return model 203 | 204 | 205 | def resnet152(pretrained=False, **kwargs): 206 | """Constructs a ResNet-152 model. 207 | 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 212 | if pretrained: 213 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 214 | return model 215 | -------------------------------------------------------------------------------- /Demo_Web_App/saved_model/test_sort.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/saved_model/test_sort.pkl -------------------------------------------------------------------------------- /Demo_Web_App/saved_model/word_to_index.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/saved_model/word_to_index.pkl -------------------------------------------------------------------------------- /Demo_Web_App/static/temp/1620480640.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/static/temp/1620480640.jpg -------------------------------------------------------------------------------- /Demo_Web_App/templates/image_results.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Cross-modal learning for Fashion Retrieval 12 | 13 | 14 | 25 |
26 |
27 |
28 |
29 |
30 |

{{data[0]}}

31 |
32 |
33 | Go back to search 34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | Card image cap 42 |
43 |
1
44 |
45 |
46 |
47 | Card image cap 48 |
49 |
5
50 |
51 |
52 |
53 | Card image cap 54 |
55 |
9
56 |
57 |
58 |
59 | Card image cap 60 |
61 |
13
62 |
63 |
64 |
65 | Card image cap 66 |
67 |
17
68 |
69 |
70 |
71 |
72 |
73 | Card image cap 74 |
75 |
2
76 |
77 |
78 |
79 | Card image cap 80 |
81 |
6
82 |
83 |
84 |
85 | Card image cap 86 |
87 |
10
88 |
89 |
90 |
91 | Card image cap 92 |
93 |
14
94 |
95 |
96 |
97 | Card image cap 98 |
99 |
18
100 |
101 |
102 |
103 |
104 |
105 | Card image cap 106 |
107 |
3
108 |
109 |
110 |
111 | Card image cap 112 |
113 |
7
114 |
115 |
116 |
117 | Card image cap 118 |
119 |
11
120 |
121 |
122 |
123 | Card image cap 124 |
125 |
15
126 |
127 |
128 |
129 | Card image cap 130 |
131 |
19
132 |
133 |
134 |
135 |
136 |
137 | Card image cap 138 |
139 |
4
140 |
141 |
142 |
143 | Card image cap 144 |
145 |
8
146 |
147 |
148 |
149 | Card image cap 150 |
151 |
12
152 |
153 |
154 |
155 | Card image cap 156 |
157 |
16
158 |
159 |
160 |
161 | Card image cap 162 |
163 |
20
164 |
165 |
166 |
167 |
168 |
169 | 170 | 171 | 172 | 173 | 174 | 175 | 178 | 179 | -------------------------------------------------------------------------------- /Demo_Web_App/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Cross-modal learning for Fashion Retrieval 12 | 13 | 14 | 25 | 26 |
27 |
28 |

Query using an image or a text description

29 |

If you search using an image, we will retrieve the most similar text descriptions and if you search using a text description, we will retrieve the most similar images

30 | 31 |
32 |
33 |
34 |
35 |
36 |
37 | 38 | 39 |
40 | 41 | 42 |
43 |
44 |
45 |
46 |
47 | 48 | 49 |
50 | 51 |
52 |
53 |
54 | 55 | 56 | 57 | 58 | 59 | 62 | 63 | -------------------------------------------------------------------------------- /Demo_Web_App/templates/text_results.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | Cross-modal learning for Fashion Retrieval 12 | 13 | 14 | 25 |
26 |
27 |
28 | Card image cap 29 |
30 |
Query Image
31 | Go back to search 32 |
33 |
34 |
35 |
36 | 39 | 42 | 45 | 48 | 51 | 54 | 57 | 60 | 63 | 66 | 69 | 72 | 75 | 78 | 81 | 84 | 87 | 90 | 93 | 96 |
97 |
98 | 99 | 100 | 101 | 102 | 103 | 106 | 107 | -------------------------------------------------------------------------------- /Demo_Web_App/test_config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser(description='command for evaluate on CUHK-PEDES') 7 | # Directory 8 | parser.add_argument('--image_dir', type=str, help='directory to store dataset') 9 | parser.add_argument('--anno_dir', type=str, help='directory to store anno') 10 | parser.add_argument('--model_path', type=str, help='directory to load checkpoint') 11 | parser.add_argument('--log_dir', type=str, help='directory to store log') 12 | 13 | # LSTM setting 14 | parser.add_argument('--embedding_size', type=int, default=512) 15 | parser.add_argument('--num_lstm_units', type=int, default=512) 16 | parser.add_argument('--vocab_size', type=int, default=12000) 17 | parser.add_argument('--lstm_dropout_ratio', type=float, default=0.7) 18 | parser.add_argument('--bidirectional', action='store_true') 19 | 20 | parser.add_argument('--max_length', type=int, default=100) 21 | parser.add_argument('--feature_size', type=int, default=512) 22 | 23 | parser.add_argument('--image_model', type=str, default='mobilenet_v1') 24 | parser.add_argument('--cnn_dropout_keep', type=float, default=0.999) 25 | 26 | parser.add_argument('--epoch_ema', type=int, default=0) 27 | 28 | # Default setting 29 | parser.add_argument('--gpus', type=str, default='0') 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | 35 | def config(): 36 | args = parse_args() 37 | return args 38 | -------------------------------------------------------------------------------- /Demo_Web_App/tester.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | GPUS='0' 4 | os.system('export CUDA_VISIBLE_DEVICES='+GPUS) 5 | 6 | BASE_ROOT='' 7 | IMAGE_DIR='/content' 8 | ANNO_DIR=BASE_ROOT+'/data/processed_data' 9 | CKPT_DIR=BASE_ROOT+'/data/model_data' 10 | LOG_DIR=BASE_ROOT+'/data/logs' 11 | IMAGE_MODEL='mobilenet_v1' 12 | lr='0.0002' 13 | batch_size='16' 14 | lr_decay_ratio='0.9' 15 | epoches_decay='80_150_200' 16 | 17 | 18 | string = 'python3 {BASE_ROOT}web_app.py --bidirectional --model_path {CKPT_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size} --image_model {IMAGE_MODEL} --log_dir {LOG_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size} --image_dir {IMAGE_DIR} --anno_dir {ANNO_DIR} --gpus {GPUS} --epoch_ema 0'.format(BASE_ROOT=BASE_ROOT, IMAGE_DIR=IMAGE_DIR, IMAGE_MODEL=IMAGE_MODEL, ANNO_DIR=ANNO_DIR, CKPT_DIR=CKPT_DIR, LOG_DIR=LOG_DIR, lr=lr, batch_size=batch_size, lr_decay_ratio=lr_decay_ratio, GPUS=GPUS) 19 | 20 | os.system(string) 21 | 22 | -------------------------------------------------------------------------------- /Demo_Web_App/web_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from PIL import Image 4 | import shutil 5 | import pickle 6 | import string 7 | import time 8 | from imageio import imread 9 | import numpy as np 10 | import torch 11 | import torchvision.transforms as transforms 12 | from test_config import config 13 | from config import network_config 14 | from werkzeug.utils import secure_filename 15 | from flask import Flask, render_template, flash, request, redirect, url_for 16 | from flask_ngrok import run_with_ngrok 17 | 18 | app = Flask(__name__) 19 | app.debug = True 20 | 
run_with_ngrok(app) 21 | network = None 22 | model_path = 'saved_model/299.pth.tar' 23 | test_sort_path = 'saved_model/test_sort.pkl' 24 | word_to_index_path = 'saved_model/word_to_index.pkl' 25 | test_sort = None 26 | word_to_index = None 27 | test_transform = transforms.Compose([ 28 | transforms.ToTensor(), 29 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 30 | ]) 31 | 32 | def load_model(args): 33 | network, _ = network_config(args,'test', None, True, model_path, False) 34 | network.eval() 35 | with open(test_sort_path, 'rb') as f_pkl: 36 | test_sort = pickle.load(f_pkl) 37 | with open(word_to_index_path, 'rb') as f_2_pkl: 38 | word_to_index = pickle.load(f_2_pkl) 39 | word_to_index = {k.lower(): v for k, v in word_to_index.items()} 40 | return network, test_sort, word_to_index 41 | 42 | def index_to_word(indexed_caption): 43 | index_to_word_dict = {value:key for key, value in word_to_index.items()} 44 | 45 | caption = [] 46 | for token in indexed_caption: 47 | if token in index_to_word_dict: 48 | caption.append(index_to_word_dict[token]) 49 | else: 50 | caption.append(' ') 51 | 52 | return ' '.join(caption[1:-1]) 53 | 54 | def retrieve_captions(): 55 | img = imread('static/temp/'+os.listdir('static/temp/')[-1]) 56 | img = np.array(Image.fromarray(img).resize(size=(224,224))) 57 | images = test_transform(img) 58 | images = torch.reshape(images, (1, images.shape[0], images.shape[1], images.shape[2])) 59 | 60 | captions = test_sort['caption_id'] 61 | caption_lengths = torch.tensor([len(c) for c in captions]) 62 | captions = torch.tensor([c + [0]*(100-len(c)) for c in captions]) 63 | 64 | with torch.no_grad(): 65 | image_embeddings, text_embeddings = network(images, captions, caption_lengths) 66 | 67 | sim = torch.matmul(text_embeddings, image_embeddings.t()) 68 | ind = sim.topk(20, 0)[1].reshape(-1) 69 | 70 | retrieved_captions = ['temp/'+os.listdir('static/temp/')[-1]] 71 | for i in ind: 72 | retrieved_captions.append(index_to_word(test_sort['caption_id'][i])) 73 | 74 | return retrieved_captions 75 | 76 | def retrieve_images(caption): 77 | exclude = set(string.punctuation) 78 | cap = ''.join(c for c in caption if c not in exclude) 79 | tokens = cap.split() 80 | tokens = [''] + tokens + [''] 81 | indexed_caption = [] 82 | 83 | for token in tokens: 84 | if token.lower() in word_to_index: 85 | indexed_caption.append(word_to_index[token.lower()]) 86 | else: 87 | indexed_caption.append(0) 88 | 89 | caption_lengths = torch.tensor([len(indexed_caption)]) 90 | indexed_caption += [0] * (100 - len(indexed_caption)) 91 | captions = torch.tensor([indexed_caption]) 92 | 93 | paths = test_sort['images_path'] 94 | images = [] 95 | for img_path in paths: 96 | img = imread('static/'+img_path) 97 | img = np.array(Image.fromarray(img).resize(size=(224,224))) 98 | images.append(test_transform(img)) 99 | images = torch.stack(images) 100 | 101 | with torch.no_grad(): 102 | image_embeddings, text_embeddings = network(images, captions, caption_lengths) 103 | 104 | sim = torch.matmul(image_embeddings, text_embeddings.t()) 105 | ind = sim.topk(20, 0)[1].reshape(-1) 106 | 107 | retrieved_images = [caption] 108 | for i in ind: 109 | retrieved_images.append(paths[i]) 110 | 111 | return retrieved_images 112 | 113 | def get_query(request): 114 | try: 115 | text = request.form['textquery'] 116 | except: 117 | text = None 118 | 119 | try: 120 | image = request.files['imagequery'] 121 | for fname in os.listdir('static/temp'): 122 | os.remove('static/temp/'+fname) 123 | 
image.save('static/temp/'+str(int(time.time()))+'.jpg') 124 | except: 125 | image = None 126 | 127 | if text is None: 128 | return (image, 'image') 129 | else: 130 | return (text, 'text') 131 | 132 | 133 | @app.route('/') 134 | def index(): 135 | return render_template('index.html') 136 | 137 | @app.route('/', methods=['POST']) 138 | def predict_from_location(): 139 | query, query_type = get_query(request) 140 | if query_type == 'text': 141 | retrieved_images = retrieve_images(query) 142 | return render_template('image_results.html', data=retrieved_images) 143 | else: 144 | retrieved_captions = retrieve_captions() 145 | return render_template('text_results.html', data=retrieved_captions) 146 | 147 | if __name__ == '__main__': 148 | print('Parsing arguments...') 149 | args = config() 150 | print('Loading model weights...') 151 | network, test_sort, word_to_index = load_model(args) 152 | app.run() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image_Text_Retrieval 2 | 3 | ## Motivation 4 | Cross-modal learning involves information obtained from more than one modality. The fashion clothing industry is one such field, where product retrieval based on multiple modalities such as image and text has become important. In the online fashion industry, being able to search for a product that matches either an image query or a text query is in high demand. 5 | 6 | ## Our Work 7 | In this work, we implement different cross-modal learning schemes such as the Siamese Network, the Correlational Network and the Deep Cross-Modal Projection Learning model, and study their performance. We also propose a modified Deep Cross-Modal Projection Learning model that uses a different image feature extractor. We evaluate the model’s performance on image-text retrieval on a fashion clothing dataset. 8 | 9 | ## Instructions to run the code 10 | 11 | The repository contains 3 folders, each of which contains the source code for one of the model architectures we experimented with. Specifically, these are: 12 | * Deep CMPL model 13 | * Siamese Network 14 | * Correlational Network 15 | 16 | Each of the folders contains a dedicated README detailing the instructions to run the source code for each model. The source code is well commented and readable. 17 | 18 |
19 | 20 | # Theoretical Details 21 | 22 | ## Model Architectures 23 | 24 | ### Siamese Network 25 | A Siamese Network is a neural network architecture that contains two or more identical sub-networks sharing the same weights and parameters. It is commonly used to measure the similarity of inputs by comparing their feature vector outputs. We implemented a two-branch neural network inspired by the Siamese Network architecture and used a contrastive loss function for our task. 26 | 27 |
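For reference, the contrastive loss used in `Siamese/siamese_network.ipynb` has the following form (a minimal sketch of that notebook's loss; the notebook fixes `margin` to 1 and passes only `y_true` and `y_pred`):

```python
from tensorflow.keras import backend as K

def contrastive_loss(y_true, y_pred, margin=1.0):
    # y_pred is the Euclidean distance between the two branch embeddings.
    # Matching pairs (y_true = 1) are pulled together; non-matching pairs
    # (y_true = 0) are pushed at least `margin` apart.
    positive = y_true * K.square(y_pred)
    negative = (1.0 - y_true) * K.square(K.maximum(margin - y_pred, 0.0))
    return K.mean(positive + negative)
```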
28 | 29 | #### *Network Architecture* 30 | ![alt text](assets/siamese.png) 31 | 32 |
33 | 34 | ### Correlational Network 35 | The Correlational Network is an autoencoder-based approach that explicitly maximises the correlation between the image and text embedded vectors, in addition to minimising the error of reconstructing the two views (image and text). This model also has two branches, one for images and one for text, and at the same time it has an encoder and a decoder. 36 | 37 |
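The objective can be summarised as minimising the reconstruction errors while maximising the correlation between the two encoded views. Below is a minimal PyTorch sketch of just the correlation term (illustrative only; the full loss, including the reconstruction terms, is implemented in `Corrnet/Corrnet_Pytorch.ipynb`, and the function and argument names here are assumptions):

```python
import torch

def correlation_term(h_img, h_txt, eps=1e-8):
    # h_img, h_txt: (batch, hidden) encodings of the image and text views.
    h_img = h_img - h_img.mean(dim=0, keepdim=True)
    h_txt = h_txt - h_txt.mean(dim=0, keepdim=True)
    numerator = (h_img * h_txt).sum()
    denominator = torch.sqrt((h_img ** 2).sum() * (h_txt ** 2).sum()) + eps
    # Returned negated so that minimising the total loss maximises the correlation.
    return -(numerator / denominator)
```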
38 | 39 | #### *Network Architecture* 40 | ![alt text](assets/correlational.png) 41 | 42 |
43 | 44 | ### DEEP CMPL Network 45 | Cross-Modal Projection Learning includes a Cross-Modal Projection Matching (CMPM) loss for learning discriminative image-text embeddings. This novel image-text matching loss minimizes the relative entropy between the projection distributions and the normalized matching distributions. 46 | 47 | #### Modified Deep CMPL 48 | We modified the Deep Cross-Modal Projection Learning model by using the EfficientNet architecture instead of MobileNet as the image feature extractor. EfficientNet is a recently proposed convolutional neural architecture which outperforms other state-of-the-art convolutional neural networks both in terms of efficiency and accuracy. 49 | 50 |
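To make the CMPM loss described above concrete, here is a rough PyTorch sketch of its image-to-text half (not the exact code under `deep_cmpl_model/code`, which implements the full symmetric version; the function name, tensor shapes and `labels` convention are assumptions for illustration):

```python
import torch
import torch.nn.functional as F

def cmpm_i2t(image_emb, text_emb, labels, eps=1e-8):
    # image_emb, text_emb: (batch, dim) embeddings; labels: (batch,) identity ids.
    text_norm = F.normalize(text_emb, p=2, dim=1)
    proj = image_emb @ text_norm.t()                    # projections of images onto texts
    p = F.softmax(proj, dim=1)                          # projection distribution
    match = (labels.unsqueeze(1) == labels.unsqueeze(0)).float()
    q = match / (match.sum(dim=1, keepdim=True) + eps)  # normalized matching distribution
    # relative entropy KL(p || q), averaged over the batch
    return (p * (torch.log(p + eps) - torch.log(q + eps))).sum(dim=1).mean()
```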
51 | 52 | #### *Network Architecture* 53 | ![alt text](assets/deepcmpl.png) 54 | 55 |
56 | 57 | ## Experimentations 58 | 59 | We experimented with different combinations of text and image feature extractors for learning common image-text embeddings. The tested combinations include: 60 | * Experiment 1: Siamese Network 61 | * Experiment 2: Correlational Network 62 | * Experiment 3: Deep CMPL with MobileNet 63 | * Experiment 4: Deep CMPL with EfficientNet on Indian Fashion 64 | * Experiment 5: Deep CMPL with EfficientNet on DeepFashion 65 | 66 |
67 | The metrics obtained from these experiments are as follows: 68 |
69 | 70 | ![alt text](assets/i2t.PNG) 71 | 72 | ![alt text](assets/t2i.PNG) 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /Siamese/README.md: -------------------------------------------------------------------------------- 1 | # Siamese Network 2 | 3 | > The way to load, train and test this model has been shown in the jupyter notebooks. The code has also been well documented detailing the use of every function. 4 | 5 | ## Instructions to Run Code 6 | 1. The text csv file is to be generated using the Word2vecgensim.ipynb notebook, for some text data. 7 | 2. Images must be added in the appropriate folder, comments are added to help. 8 | 3. Set all the paths in siamese_network.ipynb. 9 | 4. Then, simply run the siamese_network.ipynb notebook. -------------------------------------------------------------------------------- /Siamese/siamese_network.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "language": "python", 7 | "display_name": "Python 3", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "name": "python", 12 | "version": "3.7.9", 13 | "mimetype": "text/x-python", 14 | "codemirror_mode": { 15 | "name": "ipython", 16 | "version": 3 17 | }, 18 | "pygments_lexer": "ipython3", 19 | "nbconvert_exporter": "python", 20 | "file_extension": ".py" 21 | }, 22 | "colab": { 23 | "name": "siamese-network.ipynb", 24 | "provenance": [] 25 | } 26 | }, 27 | "cells": [ 28 | { 29 | "cell_type": "markdown", 30 | "metadata": { 31 | "id": "t37o4I-3obHH" 32 | }, 33 | "source": [ 34 | "## Importing libraries, initialising global variables" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "metadata": { 40 | "_uuid": "4df1608c-a2a9-48f9-8812-af1b5cdf1b16", 41 | "_cell_guid": "b9556bde-e5b6-4643-9afe-82298daf444e", 42 | "trusted": true, 43 | "id": "7CNvDBi7obHL" 44 | }, 45 | "source": [ 46 | "import imageio\n", 47 | "from statistics import median\n", 48 | "from random import randint\n", 49 | "from glob import glob\n", 50 | "import pandas as pd\n", 51 | "import numpy as np\n", 52 | "from keras.layers.core import Flatten, Dropout\n", 53 | "from keras.layers import Input, Dense, Lambda, Layer\n", 54 | "from keras import backend as K\n", 55 | "from keras import applications\n", 56 | "from keras.models import Sequential, Model\n", 57 | "from keras.optimizers import RMSprop, Adam\n", 58 | "from keras.callbacks import ModelCheckpoint\n", 59 | "from tensorflow.keras.preprocessing.image import load_img\n", 60 | "from tensorflow.keras.applications.resnet import preprocess_input\n", 61 | "from tensorflow.keras.preprocessing.image import img_to_array\n", 62 | "from tensorflow.keras.applications.resnet import ResNet152\n", 63 | "\n", 64 | "# Path to folder containing images\n", 65 | "DATASET_PATH = './Images'\n", 66 | "\n", 67 | "num_samples = 12000" 68 | ], 69 | "execution_count": 1, 70 | "outputs": [] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": { 75 | "id": "xPpnmyfIobHO" 76 | }, 77 | "source": [ 78 | "## Generator function" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "metadata": { 84 | "trusted": true, 85 | "id": "T37PuCDJobHQ" 86 | }, 87 | "source": [ 88 | "\n", 89 | "# data generator for neural network\n", 90 | "# forms correct and incorrect pairings of images with text descriptions and labels them as correct (1) or incorrect (0)\n", 91 | "\n", 92 | "def generator(batch_size, 
df):\n", 93 | " \n", 94 | " batch_img = np.zeros((batch_size, 224, 224, 3))\n", 95 | " batch_txt = np.zeros((batch_size, 512))\n", 96 | " batch_labels = np.zeros((batch_size,1))\n", 97 | " \n", 98 | " video_ids = df['image']\n", 99 | " video_txt = df['txt_enc']\n", 100 | " \n", 101 | " length = len(df) -1\n", 102 | " \n", 103 | " while True:\n", 104 | " for i in range(batch_size//2):\n", 105 | " \n", 106 | " i = i*2\n", 107 | " \n", 108 | " #correct\n", 109 | " sample = randint(0,length)\n", 110 | " file = video_ids.iloc[sample]\n", 111 | " \n", 112 | " correct_txt = video_txt.iloc[sample]\n", 113 | " \n", 114 | " im = load_img(file, target_size=(224, 224))\n", 115 | " im = img_to_array(im)\n", 116 | " im = np.expand_dims(im, axis=0)\n", 117 | " im = preprocess_input(im)\n", 118 | " \n", 119 | " batch_img[i-2] = im\n", 120 | " batch_txt[i-2] = correct_txt\n", 121 | " batch_labels[i-2] = 1\n", 122 | " \n", 123 | " #incorrect \n", 124 | " file = video_ids.iloc[randint(0,length)]\n", 125 | " \n", 126 | " im = load_img(file, target_size=(224, 224))\n", 127 | " im = img_to_array(im)\n", 128 | " im = np.expand_dims(im, axis=0)\n", 129 | " im = preprocess_input(im)\n", 130 | "\n", 131 | " batch_img[i-1] = im\n", 132 | " batch_txt[i-1] = correct_txt\n", 133 | " batch_labels[i-1] = 0\n", 134 | " \n", 135 | " yield [batch_txt, batch_img], batch_labels" 136 | ], 137 | "execution_count": 2, 138 | "outputs": [] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": { 143 | "id": "9Jd2CQ73obHS" 144 | }, 145 | "source": [ 146 | "## Utils" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "metadata": { 152 | "trusted": true, 153 | "id": "Ye8dgWGgobHT" 154 | }, 155 | "source": [ 156 | "def euclidean_distance(vects):\n", 157 | " x, y = vects\n", 158 | " return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))\n", 159 | "\n", 160 | "def eucl_dist_output_shape(shapes):\n", 161 | " shape1, shape2 = shapes\n", 162 | " return (shape1[0], 1)\n", 163 | "\n", 164 | "def contrastive_loss(y_true, y_pred):\n", 165 | " margin = 1\n", 166 | " return K.mean(y_true * K.square(y_pred) + (1 - y_true) * K.square(K.maximum(margin - y_pred, 0)))\n", 167 | "\n", 168 | "def create_img_encoder(input_dim, resnet):\n", 169 | " x = Sequential()\n", 170 | " x.add(resnet)\n", 171 | " x.add(Dense(500, activation=\"relu\"))\n", 172 | " x.add(Dropout(0.5))\n", 173 | " x.add(Dense(512, activation=\"relu\"))\n", 174 | " return x\n", 175 | "\n", 176 | "def create_txt_encoder(input_dim):\n", 177 | " x = Sequential()\n", 178 | " x.add(Dense(500, input_shape = (512,), activation=\"relu\"))\n", 179 | " x.add(Dropout(0.5))\n", 180 | " x.add(Dense(512, activation=\"relu\"))\n", 181 | " return x\n", 182 | "\n", 183 | "def compute_accuracy(predictions, labels):\n", 184 | " return labels[predictions.ravel() < 0.5].mean()" 185 | ], 186 | "execution_count": 3, 187 | "outputs": [] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": { 192 | "id": "3T5g0oNbobHV" 193 | }, 194 | "source": [ 195 | "## Initialise ResNet152" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "metadata": { 201 | "trusted": true, 202 | "colab": { 203 | "base_uri": "https://localhost:8080/" 204 | }, 205 | "id": "-77a0Xr8obHV", 206 | "outputId": "7e5a9a53-9a96-4212-826e-07eaab5a6a25" 207 | }, 208 | "source": [ 209 | "resnet = ResNet152(include_top=True, weights='imagenet')\n", 210 | "\n", 211 | "for layer in resnet.layers:\n", 212 | " layer.trainable = False" 213 | ], 214 | "execution_count": 4, 215 | "outputs": 
[ 216 | { 217 | "output_type": "stream", 218 | "text": [ 219 | "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet152_weights_tf_dim_ordering_tf_kernels.h5\n", 220 | "242900992/242900224 [==============================] - 2s 0us/step\n" 221 | ], 222 | "name": "stdout" 223 | } 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": { 229 | "id": "luQkh4-XobHX" 230 | }, 231 | "source": [ 232 | "## Creating model and loading data" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "metadata": { 238 | "trusted": true, 239 | "colab": { 240 | "base_uri": "https://localhost:8080/" 241 | }, 242 | "id": "x92fCwLmobHX", 243 | "outputId": "57462219-c1f8-4477-bbfb-885fba5c4d2f" 244 | }, 245 | "source": [ 246 | "input_txt = Input(shape=(512,))\n", 247 | "input_img = Input(shape=(224, 224, 3))\n", 248 | "\n", 249 | "txt_enc = create_txt_encoder(input_txt)\n", 250 | "img_enc = create_img_encoder(input_img, resnet)\n", 251 | "\n", 252 | "encoded_txt = txt_enc(input_txt)\n", 253 | "encoded_img = img_enc(input_img)\n", 254 | "\n", 255 | "distance = Lambda(euclidean_distance, output_shape=eucl_dist_output_shape)([encoded_txt, encoded_img])\n", 256 | "\n", 257 | "model = Model([input_txt, input_img], distance)\n", 258 | "\n", 259 | "adam = Adam(lr=0.00001)\n", 260 | "model.compile(loss=contrastive_loss, optimizer=adam)\n", 261 | "\n", 262 | "model.summary()\n", 263 | "\n", 264 | "\n" 265 | ], 266 | "execution_count": 5, 267 | "outputs": [ 268 | { 269 | "output_type": "stream", 270 | "text": [ 271 | "Model: \"model\"\n", 272 | "__________________________________________________________________________________________________\n", 273 | "Layer (type) Output Shape Param # Connected to \n", 274 | "==================================================================================================\n", 275 | "input_2 (InputLayer) [(None, 512)] 0 \n", 276 | "__________________________________________________________________________________________________\n", 277 | "input_3 (InputLayer) [(None, 224, 224, 3) 0 \n", 278 | "__________________________________________________________________________________________________\n", 279 | "sequential (Sequential) (None, 512) 513012 input_2[0][0] \n", 280 | "__________________________________________________________________________________________________\n", 281 | "sequential_1 (Sequential) (None, 512) 61176956 input_3[0][0] \n", 282 | "__________________________________________________________________________________________________\n", 283 | "lambda (Lambda) (None, 1) 0 sequential[0][0] \n", 284 | " sequential_1[0][0] \n", 285 | "==================================================================================================\n", 286 | "Total params: 61,689,968\n", 287 | "Trainable params: 1,270,024\n", 288 | "Non-trainable params: 60,419,944\n", 289 | "__________________________________________________________________________________________________\n" 290 | ], 291 | "name": "stdout" 292 | } 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "metadata": { 298 | "trusted": true, 299 | "id": "e2mh48irobHY" 300 | }, 301 | "source": [ 302 | "# The CSV generated by the word2vec(gensim) model\n", 303 | "data = pd.read_csv('./word2vec_gensim.csv', header=None)\n", 304 | "data = list(np.array(data))\n", 305 | "\n", 306 | "img_paths = [DATASET_PATH + str(i) + '.jpg' for i in range(12305)]\n", 307 | "\n", 308 | "dataset = pd.DataFrame()\n", 309 | "dataset['image'] = pd.Series(img_paths)\n", 310 | 
"dataset['txt_enc'] = pd.Series(data)\n", 311 | "\n", 312 | "df_test = dataset[num_samples:]\n", 313 | "dataset = dataset[:num_samples]\n", 314 | "\n", 315 | "df_train = dataset[:int(num_samples*0.8)]\n", 316 | "df_val = dataset[int(num_samples*0.8):]\n" 317 | ], 318 | "execution_count": null, 319 | "outputs": [] 320 | }, 321 | { 322 | "cell_type": "markdown", 323 | "metadata": { 324 | "id": "aES0ha1fobHa" 325 | }, 326 | "source": [ 327 | "## Training" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "metadata": { 333 | "trusted": true, 334 | "id": "mrlWYM3aobHb" 335 | }, 336 | "source": [ 337 | "model.fit_generator(generator(30, df_train), steps_per_epoch= int(int(num_samples*0.8)/30), validation_data= generator(30, df_val), validation_steps=int(int(num_samples*0.2)/30), epochs=200, verbose=1)\n", 338 | "model.save_weights('./weights.h5')" 339 | ], 340 | "execution_count": null, 341 | "outputs": [] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "metadata": { 346 | "id": "n6aJV10WobHb" 347 | }, 348 | "source": [ 349 | "## Load saved weights" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "metadata": { 355 | "trusted": true, 356 | "id": "wgSxvtsrobHb" 357 | }, 358 | "source": [ 359 | "# Load from where you stored the weights\n", 360 | "model.load_weights('./weights.h5')" 361 | ], 362 | "execution_count": null, 363 | "outputs": [] 364 | }, 365 | { 366 | "cell_type": "markdown", 367 | "metadata": { 368 | "id": "2SLZVox5obHd" 369 | }, 370 | "source": [ 371 | "## Decide size of test set" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "metadata": { 377 | "trusted": true, 378 | "id": "mZgSAMXxobHd" 379 | }, 380 | "source": [ 381 | "subset_size = 300\n", 382 | "subset = df_test.iloc[:subset_size]" 383 | ], 384 | "execution_count": null, 385 | "outputs": [] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": { 390 | "id": "IQnoQLbZobHd" 391 | }, 392 | "source": [ 393 | "## Metrics" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "metadata": { 399 | "trusted": true, 400 | "id": "4J7OWSyGobHd" 401 | }, 402 | "source": [ 403 | "# metrics - img -> text\n", 404 | "\n", 405 | "mr = []\n", 406 | "top_1_count = 0\n", 407 | "top_5_count = 0\n", 408 | "top_10_count = 0\n", 409 | "\n", 410 | "for i in range(subset_size):\n", 411 | " file = subset['image'].iloc[i]\n", 412 | " im = load_img(file, target_size=(224, 224))\n", 413 | " im = img_to_array(im)\n", 414 | " im = np.expand_dims(im, axis=0)\n", 415 | " im = preprocess_input(im)\n", 416 | " \n", 417 | " image_array = np.zeros((subset_size, 224, 224, 3))\n", 418 | " for k in range(subset_size):\n", 419 | " image_array[k] = im\n", 420 | " \n", 421 | " txt_array = np.zeros((subset_size, 512))\n", 422 | " for j in range(subset_size):\n", 423 | " txt = subset['txt_enc'].iloc[j]\n", 424 | " txt_array[j] = txt\n", 425 | " \n", 426 | " predictions = [pred[0] for pred in model.predict([txt_array, image_array])]\n", 427 | " pred_i = predictions[i]\n", 428 | " predictions.sort()\n", 429 | " rank = predictions.index(pred_i)\n", 430 | " if rank < 10:\n", 431 | " top_10_count += 1\n", 432 | " if rank < 5:\n", 433 | " top_5_count += 1\n", 434 | " if rank < 1:\n", 435 | " top_1_count += 1\n", 436 | " mr.append(rank+1) \n", 437 | "\n", 438 | "print('Median Rank(img->txt):', median(mr)*100/subset_size, '%')\n", 439 | "print('R@1(img->txt):', top_1_count*100/subset_size, '%')\n", 440 | "print('R@5(img->txt):', top_5_count*100/subset_size, '%')\n", 441 | "print('R@10(img->txt):', top_10_count*100/subset_size, '%')" 
442 | ], 443 | "execution_count": null, 444 | "outputs": [] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "metadata": { 449 | "trusted": true, 450 | "id": "a7YkBwZoobHd", 451 | "outputId": "e20587e6-d8d5-45af-8e8d-7678adef6643" 452 | }, 453 | "source": [ 454 | "# metrics - txt -> img\n", 455 | "\n", 456 | "mr = []\n", 457 | "top_1_count = 0\n", 458 | "top_5_count = 0\n", 459 | "top_10_count = 0\n", 460 | "\n", 461 | "for i in range(subset_size):\n", 462 | " txt = subset['txt_enc'].iloc[i] \n", 463 | " txt_array = np.zeros((subset_size, 512))\n", 464 | " for k in range(subset_size):\n", 465 | " txt_array[k] = txt\n", 466 | " \n", 467 | " \n", 468 | " image_array = np.zeros((subset_size, 224, 224, 3))\n", 469 | " for j in range(subset_size):\n", 470 | " file = subset['image'].iloc[j]\n", 471 | " im = load_img(file, target_size=(224, 224))\n", 472 | " im = img_to_array(im)\n", 473 | " im = np.expand_dims(im, axis=0)\n", 474 | " im = preprocess_input(im)\n", 475 | " image_array[k] = im\n", 476 | " \n", 477 | " predictions = [pred[0] for pred in model.predict([txt_array, image_array])]\n", 478 | " pred_i = predictions[i]\n", 479 | " predictions.sort()\n", 480 | " rank = predictions.index(pred_i)\n", 481 | " if rank < 10:\n", 482 | " top_10_count += 1\n", 483 | " if rank < 5:\n", 484 | " top_5_count += 1\n", 485 | " if rank < 1:\n", 486 | " top_1_count += 1\n", 487 | " mr.append(rank+1) \n", 488 | "\n", 489 | "print('Median Rank(txt->img):', median(mr)*100/subset_size, '%')\n", 490 | "print('R@1(txt->img):', top_1_count*100/subset_size, '%')\n", 491 | "print('R@5(txt->img):', top_5_count*100/subset_size, '%')\n", 492 | "print('R@10(txt->img):', top_10_count*100/subset_size, '%')" 493 | ], 494 | "execution_count": null, 495 | "outputs": [ 496 | { 497 | "output_type": "stream", 498 | "text": [ 499 | "Median Rank(txt->img): 0.6666666666666666 %\n", 500 | "R@1(txt->img): 33.666666666666664 %\n", 501 | "R@5(txt->img): 95.33333333333333 %\n", 502 | "R@10(txt->img): 95.33333333333333 %\n" 503 | ], 504 | "name": "stdout" 505 | } 506 | ] 507 | }, 508 | { 509 | "cell_type": "markdown", 510 | "metadata": { 511 | "id": "z8LTe6sGobHf" 512 | }, 513 | "source": [ 514 | "## Download Weights" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "metadata": { 520 | "trusted": true, 521 | "id": "osTWoctwobHf" 522 | }, 523 | "source": [ 524 | "# download weights\n", 525 | "\n", 526 | "from IPython.display import FileLink\n", 527 | "\n", 528 | "FileLink(r'./weights.h5')" 529 | ], 530 | "execution_count": null, 531 | "outputs": [] 532 | }, 533 | { 534 | "cell_type": "markdown", 535 | "metadata": { 536 | "id": "XObkIosAobHf" 537 | }, 538 | "source": [ 539 | "## Try predicting" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "metadata": { 545 | "trusted": true, 546 | "id": "s7pceQ4dobHg" 547 | }, 548 | "source": [ 549 | "# trying out predict\n", 550 | "\n", 551 | "text = np.zeros((2, 512))\n", 552 | "image = np.zeros((2, 224, 224, 3))\n", 553 | "\n", 554 | "file = dataset['image'].iloc[21] \n", 555 | "correct_txt = dataset['txt_enc'].iloc[21]\n", 556 | "\n", 557 | "im = load_img(file, target_size=(224, 224))\n", 558 | "im = img_to_array(im)\n", 559 | "im = np.expand_dims(im, axis=0)\n", 560 | "im = preprocess_input(im)\n", 561 | "\n", 562 | "image[0] = im\n", 563 | "\n", 564 | "text[0] = correct_txt\n", 565 | "\n", 566 | "file = dataset['image'].iloc[21] \n", 567 | "correct_txt = dataset['txt_enc'].iloc[90]\n", 568 | "\n", 569 | "im = load_img(file, target_size=(224, 224))\n", 570 | "im = 
img_to_array(im)\n", 571 | "im = np.expand_dims(im, axis=0)\n", 572 | "im = preprocess_input(im)\n", 573 | "\n", 574 | "image[1] = im\n", 575 | "\n", 576 | "text[1] = correct_txt\n", 577 | "\n", 578 | "model.predict([text, image])" 579 | ], 580 | "execution_count": null, 581 | "outputs": [] 582 | } 583 | ] 584 | } -------------------------------------------------------------------------------- /assets/correlational.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/assets/correlational.png -------------------------------------------------------------------------------- /assets/deepcmpl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/assets/deepcmpl.png -------------------------------------------------------------------------------- /assets/i2t.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/assets/i2t.PNG -------------------------------------------------------------------------------- /assets/siamese.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/assets/siamese.png -------------------------------------------------------------------------------- /assets/t2i.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/assets/t2i.PNG -------------------------------------------------------------------------------- /deep_cmpl_model/.gitignore: -------------------------------------------------------------------------------- 1 | <<<<<<< HEAD 2 | # Images of the dataset 3 | dataset 4 | 5 | # Checkpoint directories 6 | model_data 7 | logs 8 | 9 | #zip files 10 | *.zip 11 | 12 | # Misccellaneous 13 | data/f8k 14 | 15 | 16 | ======= 17 | dataset/ 18 | drive/ 19 | dgpu/drive/ 20 | dgpu/data/logs 21 | >>>>>>> 7d7a4ac943534fa53775eed173be3313ae889b49 22 | -------------------------------------------------------------------------------- /deep_cmpl_model/README.md: -------------------------------------------------------------------------------- 1 | # Deep Cross-Modal Projection Matching Model 2 | 3 | > The way to load, train and test this model has been shown in the deep_cmpl_jupyter.ipynb file. The code has also been well documented detailing the use of every function. 4 | 5 |
6 | 7 | 8 | ## Directory Structure 9 | 1. The code folder contains the source code for the model. 10 | * The image and text models are stored in the models directory. These include MobileNet, EfficientNet and BiLSTM. 11 | * The scripts folder contains the scripts for training and testing the model. 12 | * The utils folder contains metric calculation functions and other helper functions. 13 | * The datasets folder contains `data.sh`, which preprocesses the dataset and stores it in a pickled format. The `fashion.py` file contains the class for creating the dataset object (see the sketch below). 14 | 2. The data folder contains dataset-related files. 15 | * The processed_data folder contains the pickled dataset. 16 | * `images.csv` is the file storing image paths and corresponding captions. 17 | * `reid_raw.json` is the file accepted as input by `preprocess.py`. 18 | 19 | 
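For reference, a minimal sketch of how the pickled splits produced by `data.sh` are loaded through the `Fashion` dataset class. It mirrors the loaders in `code/test.py` and `code/train.py`; the two directory paths are placeholders for your own setup, and the script is assumed to run from the `code` folder so that the `datasets` package is importable.

```python
# Illustrative sketch only: image_root and anno_root below are placeholder paths.
import torch.utils.data as data
import torchvision.transforms as transforms

from datasets.fashion import Fashion

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# anno_root must point at the folder holding the *_sort.pkl files written by preprocess.py
test_set = Fashion(image_root='/content', anno_root='deep_cmpl_model/data/processed_data',
                   split='test', max_length=100, transform=transform)
loader = data.DataLoader(test_set, batch_size=16, shuffle=True, num_workers=2)

# each batch yields (images, captions, labels, caption_lengths), as returned by Fashion.__getitem__
images, captions, labels, caption_lengths = next(iter(loader))
```
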
21 | 22 | ## Instructions to run 23 | 1. Set the appropriate path for storing checkpoints in `trainer.py` and `tester.py` scripts. Specifically, you need to change the BASE_ROOT_2 parameter in these files. Set it to a google folder to which you have access. 24 | 2. The sample code in the jupyter notebook loads the Indian Fashion Dataset. For training the model on some other dataset, you need to change `make_json.py`, `images.csv` and `preprocess.py` appropriately. 25 | 3. Execute the instructions in the notebook. 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /deep_cmpl_model/code/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/__pycache__/test.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/test.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/__pycache__/test.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/test.cpython-38.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/__pycache__/test_config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/test_config.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/__pycache__/test_config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/test_config.cpython-38.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/__pycache__/test_params.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/test_params.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/__pycache__/train_config.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/train_config.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/__pycache__/train_config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/train_config.cpython-38.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/__pycache__/train_params.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/train_params.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/datasets/__pycache__/directory.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/datasets/__pycache__/directory.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/datasets/__pycache__/fashion.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/datasets/__pycache__/fashion.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/datasets/__pycache__/pedes.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/datasets/__pycache__/pedes.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/datasets/__pycache__/pedes.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/datasets/__pycache__/pedes.cpython-38.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/datasets/data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_ROOT=Image_Text_Retrieval/deep_cmpl_model 4 | IMAGE_ROOT=/content/dataset 5 | 6 | JSON_ROOT=$BASE_ROOT/data/reid_raw.json 7 | OUT_ROOT=$BASE_ROOT/data/processed_data 8 | 9 | echo "Preprocessing dataset" 10 | 11 | rm -rf $OUT_ROOT 12 | 13 | python3 $BASE_ROOT/code/datasets/preprocess.py \ 14 | --img_root=${IMAGE_ROOT} \ 15 | --json_root=${JSON_ROOT} \ 16 | --out_root=${OUT_ROOT} \ 17 | --min_word_count 3 \ 18 | --first 19 | -------------------------------------------------------------------------------- /deep_cmpl_model/code/datasets/directory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | def makedir(root): 5 | if not os.path.exists(root): 6 | os.makedirs(root) 7 | 8 | 9 | def 
write_json(data, root): 10 | with open(root, 'w') as f: 11 | json.dump(data, f) 12 | -------------------------------------------------------------------------------- /deep_cmpl_model/code/datasets/fashion.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import numpy as np 3 | import os 4 | import pickle 5 | import h5py 6 | from PIL import Image 7 | from utils.directory import check_exists 8 | # from scipy.misc import imresize 9 | from PIL import Image 10 | from imageio import imread 11 | 12 | 13 | """ 14 | Created the dataset object 15 | """ 16 | 17 | class Fashion(data.Dataset): 18 | ''' 19 | Args: 20 | root (string): Base root directory of dataset where [split].pkl and [split].h5 exists 21 | split (string): 'train', 'val' or 'test' 22 | transform (callable, optional): A function/transform that takes in an PIL image 23 | and returns a transformed vector. E.g, ''transform.RandomCrop' 24 | target_transform (callable, optional): A funciton/transform that tkes in the 25 | targt and transfomrs it. 26 | ''' 27 | pklname_list = ['train_sort.pkl', 'val_sort.pkl', 'test_sort.pkl'] 28 | h5name_list = ['train.h5', 'val.h5', 'test.h5'] 29 | 30 | def __init__(self, image_root, anno_root, split, max_length, transform=None, target_transform=None, cap_transform=None): 31 | 32 | self.image_root = image_root 33 | self.anno_root = anno_root 34 | self.max_length = max_length 35 | self.transform = transform 36 | self.target_transform = target_transform 37 | self.cap_transform = cap_transform 38 | self.split = split.lower() 39 | 40 | if not check_exists(self.image_root): 41 | raise RuntimeError('Dataset not found or corrupted.' + 42 | 'Please follow the directions to generate datasets') 43 | 44 | if self.split == 'train': 45 | self.pklname = self.pklname_list[0] 46 | #self.h5name = self.h5name_list[0] 47 | 48 | with open(os.path.join(self.anno_root, self.pklname), 'rb') as f_pkl: 49 | data = pickle.load(f_pkl) 50 | self.train_labels = data['labels'] 51 | self.train_captions = data['caption_id'] 52 | self.train_images = data['images_path'] 53 | #data_h5py = h5py.File(os.path.join(self.root, self.h5name), 'r') 54 | #self.train_images = data_h5py['images'] 55 | 56 | 57 | elif self.split == 'val': 58 | self.pklname = self.pklname_list[1] 59 | #self.h5name = self.h5name_list[1] 60 | with open(os.path.join(self.anno_root, self.pklname), 'rb') as f_pkl: 61 | data = pickle.load(f_pkl) 62 | self.val_labels = data['labels'] 63 | self.val_captions = data['caption_id'] 64 | self.val_images = data['images_path'] 65 | #data_h5py = h5py.File(os.path.join(self.root, self.h5name), 'r') 66 | #self.val_images = data_h5py['images'] 67 | 68 | elif self.split == 'test': 69 | self.pklname = self.pklname_list[2] 70 | #self.h5name = self.h5name_list[2] 71 | 72 | with open(os.path.join(self.anno_root, self.pklname), 'rb') as f_pkl: 73 | data = pickle.load(f_pkl) 74 | self.test_labels = data['labels'] 75 | self.test_captions = data['caption_id'] 76 | self.test_images = data['images_path'] 77 | 78 | #data_h5py = h5py.File(os.path.join(self.root, self.h5name), 'r') 79 | #self.test_images = data_h5py['images'] 80 | 81 | else: 82 | raise RuntimeError('Wrong split which should be one of "train","val" or "test"') 83 | 84 | def __getitem__(self, index): 85 | """ 86 | Args: 87 | index(int): Index 88 | Returns: 89 | tuple: (images, labels, captions) 90 | """ 91 | if self.split == 'train': 92 | img_path, caption, label = self.train_images[index], 
self.train_captions[index], self.train_labels[index] 93 | elif self.split == 'val': 94 | img_path, caption, label = self.val_images[index], self.val_captions[index], self.val_labels[index] 95 | else: 96 | img_path, caption, label = self.test_images[index], self.test_captions[index], self.test_labels[index] 97 | img_path = os.path.join(self.image_root, img_path) 98 | img = imread(img_path) 99 | img=np.array(Image.fromarray(img).resize(size=(224,224))) 100 | # img = imresize(img, (224,224)) 101 | if len(img.shape) == 2: 102 | img = np.dstack((img,img,img)) 103 | img = Image.fromarray(img) 104 | 105 | if self.transform is not None: 106 | img = self.transform(img) 107 | 108 | if self.target_transform is not None: 109 | label = self.target_transform(label) 110 | 111 | if self.cap_transform is not None: 112 | caption = self.cap_transform(caption) 113 | caption = caption[1:-1] 114 | caption = np.array(caption) 115 | caption, mask = self.fix_length(caption) 116 | return img, caption, label, mask 117 | 118 | def fix_length(self, caption): 119 | caption_len = caption.shape[0] 120 | if caption_len < self.max_length: 121 | pad = np.zeros((self.max_length - caption_len, 1), dtype=np.int64) 122 | caption = np.append(caption, pad) 123 | return caption, caption_len 124 | 125 | def __len__(self): 126 | if self.split == 'train': 127 | return len(self.train_labels) 128 | elif self.split == 'val': 129 | return len(self.val_labels) 130 | else: 131 | return len(self.test_labels) 132 | -------------------------------------------------------------------------------- /deep_cmpl_model/code/datasets/preprocess.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pickle 3 | import json 4 | import argparse 5 | import string 6 | import os 7 | from directory import write_json, makedir 8 | from collections import namedtuple 9 | 10 | 11 | """ 12 | Preprocesses the json file 13 | """ 14 | 15 | ImageMetaData = namedtuple('ImageMetaData', ['id', 'image_path', 'captions', 'split']) 16 | ImageDecodeData = namedtuple('ImageDecodeData', ['id', 'image_path', 'captions_id', 'split']) 17 | 18 | 19 | class Vocabulary(object): 20 | """ 21 | Vocabulary wrapper 22 | """ 23 | def __init__(self, vocab, unk_id): 24 | """ 25 | :param vocab: A dictionary of word to word_id 26 | :param unk_id: Id of the bad/unknown words 27 | """ 28 | self._vocab = vocab 29 | self._unk_id = unk_id 30 | 31 | def word_to_id(self, word): 32 | if word not in self._vocab: 33 | return self._unk_id 34 | return self._vocab[word] 35 | 36 | 37 | def cap2tokens(cap): 38 | exclude = set(string.punctuation) 39 | caption = ''.join(c for c in cap if c not in exclude) 40 | tokens = caption.split() 41 | tokens = add_start_end(tokens) 42 | return tokens 43 | 44 | 45 | def add_start_end(tokens, start_word='', end_word=''): 46 | """ 47 | Add start and end words for a caption 48 | """ 49 | tokens_processed = [start_word] 50 | tokens_processed.extend(tokens) 51 | tokens_processed.append(end_word) 52 | return tokens_processed 53 | 54 | 55 | def process_captions(imgs): 56 | for img in imgs: 57 | img['processed_tokens'] = [] 58 | for s in img['captions']: 59 | tokens = cap2tokens(s) 60 | img['processed_tokens'].append(tokens) 61 | 62 | 63 | def build_vocab(imgs, args): 64 | print('start build vodabulary') 65 | counts = {} 66 | for img in imgs: 67 | for tokens in img['processed_tokens']: 68 | for word in tokens: 69 | counts[word] = counts.get(word, 0) + 1 70 | print('Total words:', len(counts)) 71 | 72 | # filter uncommon 
words and sort by descending count. 73 | # word_counts: a list of (words, count) for words satisfying the condition. 74 | word_counts = [(w,n) for w,n in counts.items() if n >= args.min_word_count] 75 | word_counts.sort(key = lambda x : x[1], reverse=True) 76 | print('Words in vocab:', len(word_counts)) 77 | 78 | # words_out: a list of (words, count) for words unsatisfying the condition. 79 | words_out = [(w,n) for w,n in counts.items() if n < args.min_word_count] 80 | bad_words = len(words_out) 81 | bad_count = sum(x[1] for x in words_out) 82 | 83 | # save the word counts file 84 | word_counts_root = os.path.join(args.out_root + '/word_counts.txt') 85 | with open(word_counts_root, 'w') as f: 86 | f.write('Total words: %d \n' % len(counts)) 87 | f.write('Words in vocabulary: %d \n' % len(word_counts)) 88 | f.write(str(word_counts)) 89 | 90 | word_reverse = [w for (w,n) in word_counts] 91 | vocab_dict = dict([(word, index) for (index, word) in enumerate(word_reverse)]) 92 | vocab = Vocabulary(vocab_dict, len(vocab_dict)) 93 | 94 | # Save word index as pickle form 95 | word_to_idx = {} 96 | for index, word in enumerate(word_reverse): 97 | word_to_idx[word] = index 98 | 99 | with open(os.path.join(args.out_root, 'word_to_index.pkl'), 'wb') as f: 100 | pickle.dump(word_to_idx, f) 101 | 102 | print('number of bad words: %d/%d = %.2f%%' % (bad_words, len(counts), bad_words * 100.0 / len(counts))) 103 | print('number of words in vocab: %d/%d = %.2f%%' % (len(word_counts), len(counts), len(word_counts) * 100.0 / len(counts))) 104 | print('number of Null: %d/%d = %.2f%%' % (bad_count, len(counts), bad_count * 100.0 / len(counts))) 105 | 106 | return vocab 107 | 108 | def load_vocab(args): 109 | 110 | with open(os.path.join(args.out_root, 'word_to_index.pkl'), 'rb') as f: 111 | word_to_idx = pickle.load(f) 112 | 113 | vocab = Vocabulary(word_to_idx, len(word_to_idx)) 114 | print('load vocabulary done') 115 | return vocab 116 | 117 | 118 | def process_metadata(split, data, args): 119 | """ 120 | Wrap data into ImageMatadata form 121 | """ 122 | id_to_captions = {} 123 | image_metadata = [] 124 | num_captions = 0 125 | count = 0 126 | 127 | for img in data: 128 | count += 1 129 | # absolute image path 130 | # filepath = os.path.join(args.img_root, img['file_path']) 131 | # relative image path 132 | filepath = img['file_path'] 133 | # assert os.path.exists(filepath) 134 | id = img['id'] - 1 135 | captions = img['processed_tokens'] 136 | id_to_captions.setdefault(id, []) 137 | id_to_captions[id].append(captions) 138 | assert split == img['split'], 'error: wrong split' 139 | image_metadata.append(ImageMetaData(id, filepath, captions, split)) 140 | num_captions += len(captions) 141 | 142 | print("Process metadata done!") 143 | print("Total %d captions %d images %d identities in %s" % (num_captions, count, len(id_to_captions), split)) 144 | with open(os.path.join(args.out_root, 'metadata_info.txt') ,'a') as f: 145 | f.write("Total %d captions %d images %d identities in %s" % (num_captions, count, len(id_to_captions), split)) 146 | f.write('\n') 147 | 148 | return image_metadata 149 | 150 | 151 | def process_decodedata(data, vocab): 152 | """ 153 | Decode ImageMetaData to ImageDecodeData 154 | Each item in imagedecodedata has 2 captions. 
(len(captions_id) = 2) 155 | """ 156 | image_decodedata = [] 157 | for img in data: 158 | image_path = img.image_path 159 | #image = imread(img.filepath) 160 | #image = imresize(image, (args.default_image_size, args.default_image_size)) 161 | # handle grayscale input images 162 | #if len(image.shape) == 2: 163 | # image = np.dstack((image, image, image)) 164 | # (height, width, channel) to (channel, height, weight) 165 | # (224,224,3) to (3,224,224)) 166 | #image = image.transpose(2,0,1) 167 | cap_to_vec = [] 168 | for cap in img.captions: 169 | cap_to_vec.append([vocab.word_to_id(word) for word in cap]) 170 | image_decodedata.append(ImageDecodeData(img.id, image_path, cap_to_vec, img.split)) 171 | 172 | print('Process decodedata done!') 173 | 174 | return image_decodedata 175 | 176 | 177 | def process_dataset(split, decodedata): 178 | # Process dataset 179 | 180 | # Arrange by caption in a sorted form 181 | dataset, label_range = create_dataset_sort(split, decodedata) 182 | write_dataset(split, dataset, args, label_range) 183 | 184 | 185 | def create_dataset_sort(split, data): 186 | images_sort = [] 187 | label_range = {} 188 | images = {} 189 | for img in data: 190 | label = img.id 191 | image = [ImageDecodeData(img.id, img.image_path, [caption_id], img.split) for caption_id in img.captions_id] 192 | if label in images: 193 | images[label].extend(image) 194 | label_range[label].append(len(image)) 195 | else: 196 | images[label] = image 197 | label_range[label] = [len(image)] 198 | 199 | print('=========== Arrange by id=============================') 200 | index = -1 201 | for label in images.keys(): 202 | # all captions arrange together 203 | images_sort.extend(images[label]) 204 | # label_range is arranged according to their actual index 205 | # label_range[label] = (previous, current] 206 | start = index 207 | for index_image in range(len(label_range[label])): 208 | label_range[label][index_image] += index 209 | index = label_range[label][index_image] 210 | label_range[label].append(start) 211 | 212 | return images_sort, label_range 213 | 214 | 215 | def write_dataset(split, data, args, label_range=None): 216 | """ 217 | Separate each component 218 | Write dataset into binary file 219 | """ 220 | caption_id = [] 221 | images_path = [] 222 | labels = [] 223 | 224 | for img in data: 225 | assert len(img.captions_id) == 1 226 | caption_id.append(img.captions_id[0]) 227 | labels.append(img.id) 228 | images_path.append(img.image_path) 229 | 230 | #N = len(images) 231 | data = {'caption_id':caption_id, 'labels':labels, 'images_path':images_path} 232 | 233 | if label_range is not None: 234 | data['label_range'] = label_range 235 | pickle_root = os.path.join(args.out_root, split + '_sort.pkl') 236 | else: 237 | pickle_root = os.path.join(args.out_root, split + '.pkl') 238 | # Write caption_id and labels as pickle form 239 | with open(pickle_root, 'wb') as f: 240 | pickle.dump(data, f) 241 | 242 | #h5py_root = os.path.join(args.out_root, split + '.h5') 243 | #f = h5py.File(h5py_root, 'w') 244 | #f.create_dataset('images', (N, 3, args.default_image_size, args.default_image_size), data=images) 245 | 246 | print('Save dataset') 247 | 248 | 249 | def generate_split(args): 250 | 251 | with open(args.json_root,'r') as f: 252 | imgs = json.load(f) 253 | # process caption 254 | process_captions(imgs) 255 | val_data = [] 256 | train_data = [] 257 | test_data = [] 258 | for img in imgs: 259 | if img['split'] == 'train': 260 | train_data.append(img) 261 | elif img['split'] =='val': 262 | 
val_data.append(img) 263 | else: 264 | test_data.append(img) 265 | write_json(train_data, os.path.join(args.out_root, 'train_reid.json')) 266 | write_json(val_data, os.path.join(args.out_root, 'val_reid.json')) 267 | write_json(test_data, os.path.join(args.out_root, 'test_reid.json')) 268 | 269 | return [train_data, val_data, test_data] 270 | 271 | 272 | def load_split(args): 273 | 274 | data = [] 275 | splits = ['train', 'val', 'test'] 276 | for split in splits: 277 | split_root = os.path.join(args.out_root, split + '_reid.json') 278 | with open(split_root, 'r') as f: 279 | split_data = json.load(f) 280 | data.append(split_data) 281 | 282 | print('load data done') 283 | return data 284 | 285 | 286 | def process_data(args): 287 | 288 | if args.first: 289 | train_data, val_data, test_data = generate_split(args) 290 | vocab = build_vocab(train_data, args) 291 | else: 292 | train_data, val_data, test_data = load_split(args) 293 | vocab = load_vocab(args) 294 | 295 | # Transform original data to Imagedata form. 296 | train_metadata = process_metadata('train', train_data, args) 297 | val_metadata = process_metadata('val', val_data, args) 298 | test_metadata = process_metadata('test', test_data, args) 299 | 300 | 301 | # Decode Imagedata to index caption and replace image file_root with image vecetor. 302 | train_decodedata = process_decodedata(train_metadata, vocab) 303 | val_decodedata = process_decodedata(val_metadata, vocab) 304 | test_decodedata = process_decodedata(test_metadata, vocab) 305 | 306 | 307 | process_dataset('train', train_decodedata) 308 | process_dataset('val', val_decodedata) 309 | process_dataset('test', test_decodedata) 310 | 311 | 312 | def parse_args(): 313 | parser = argparse.ArgumentParser(description='Command for data preprocessing') 314 | parser.add_argument('--img_root', type=str) 315 | parser.add_argument('--json_root', type=str) 316 | parser.add_argument('--out_root',type=str) 317 | parser.add_argument('--min_word_count', type=int) 318 | parser.add_argument('--default_image_size', type=int, default=224) 319 | parser.add_argument('--first', action='store_true') 320 | args = parser.parse_args() 321 | return args 322 | 323 | if __name__ == '__main__': 324 | args = parse_args() 325 | makedir(args.out_root) 326 | process_data(args) 327 | -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/__pycache__/bi_lstm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/bi_lstm.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/__pycache__/bi_lstm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/bi_lstm.cpython-38.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/__pycache__/eff_net.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/eff_net.cpython-37.pyc 
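The preprocessing script above (`preprocess.py`) consumes `reid_raw.json`, a list of per-image records. Below is a minimal sketch of one such record and of the caption-to-id mapping the script performs; the path, caption text and toy vocabulary are invented examples, while the field names follow `generate_split()` and `process_metadata()` above.

```python
# Illustrative sketch only: the values are made up, the keys are the ones
# preprocess.py reads ('id', 'file_path', 'captions', 'split').
import string

record = {
    "id": 1,                                  # identity label (preprocess.py stores id - 1)
    "file_path": "images/0001.jpg",           # path relative to the image root
    "split": "train",                         # one of 'train', 'val', 'test'
    "captions": ["A red kurta with golden embroidery."],
}

# Strip punctuation and split into words, as cap2tokens() does
# (cap2tokens() additionally appends start/end markers).
caption = "".join(c for c in record["captions"][0] if c not in set(string.punctuation))
tokens = caption.split()

# Toy word -> id map; build_vocab() creates the real one (sorted by frequency,
# filtered by --min_word_count) and saves it to word_to_index.pkl.
word_to_idx = {w: i for i, w in enumerate(sorted(set(tokens)))}
unk_id = len(word_to_idx)                     # unknown words map to the last id
caption_ids = [word_to_idx.get(w, unk_id) for w in tokens]
```
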
-------------------------------------------------------------------------------- /deep_cmpl_model/code/models/__pycache__/efficient_net.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/efficient_net.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/__pycache__/mobilenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/mobilenet.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/__pycache__/mobilenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/mobilenet.cpython-38.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/bi_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | 5 | seed_num = 223 6 | torch.manual_seed(seed_num) 7 | random.seed(seed_num) 8 | 9 | """ 10 | Neural Networks model : Bidirection LSTM 11 | """ 12 | 13 | 14 | class BiLSTM(nn.Module): 15 | def __init__(self, args): 16 | super(BiLSTM, self).__init__() 17 | 18 | self.hidden_dim = args.num_lstm_units 19 | 20 | V = args.vocab_size 21 | D = args.embedding_size 22 | 23 | # word embedding 24 | self.embed = nn.Embedding(V, D, padding_idx=0) 25 | 26 | self.bilstm = nn.ModuleList() 27 | 
self.bilstm.append(nn.LSTM(D, args.num_lstm_units, num_layers=1, dropout=0, bidirectional=False, bias=False)) 28 | 29 | self.bidirectional = args.bidirectional 30 | if self.bidirectional: 31 | self.bilstm.append(nn.LSTM(D, args.num_lstm_units, num_layers=1, dropout=0, bidirectional=False, bias=False)) 32 | 33 | 34 | def forward(self, text, text_length): 35 | embed = self.embed(text) 36 | 37 | # unidirectional lstm 38 | bilstm_out = self.bilstm_out(embed, text_length, 0) 39 | 40 | if self.bidirectional: 41 | index_reverse = list(range(embed.shape[0]-1, -1, -1)) 42 | index_reverse = torch.LongTensor(index_reverse).cuda() 43 | embed_reverse = embed.index_select(0, index_reverse) 44 | text_length_reverse = text_length.index_select(0, index_reverse) 45 | bilstm_out_bidirection = self.bilstm_out(embed_reverse, text_length_reverse, 1) 46 | bilstm_out_bidirection_reverse = bilstm_out_bidirection.index_select(0, index_reverse) 47 | bilstm_out = torch.cat([bilstm_out, bilstm_out_bidirection_reverse], dim=2) 48 | bilstm_out, _ = torch.max(bilstm_out, dim=1) 49 | bilstm_out = bilstm_out.unsqueeze(2).unsqueeze(2) 50 | return bilstm_out 51 | 52 | 53 | def bilstm_out(self, embed, text_length, index): 54 | 55 | _, idx_sort = torch.sort(text_length, dim=0, descending=True) 56 | _, idx_unsort = torch.sort(idx_sort, dim=0) 57 | 58 | embed_sort = embed.index_select(0, idx_sort) 59 | length_list = text_length[idx_sort] 60 | pack = nn.utils.rnn.pack_padded_sequence(embed_sort, length_list.cpu(), batch_first=True) 61 | 62 | bilstm_sort_out, _ = self.bilstm[index](pack) 63 | bilstm_sort_out = nn.utils.rnn.pad_packed_sequence(bilstm_sort_out, batch_first=True) 64 | bilstm_sort_out = bilstm_sort_out[0] 65 | 66 | bilstm_out = bilstm_sort_out.index_select(0, idx_unsort) 67 | 68 | return bilstm_out 69 | 70 | 71 | def weight_init(self, m): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.xavier_uniform_(m.weight.data, 1) 74 | nn.init.constant(m.bias.data, 0) 75 | -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/eff_net.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import torch.nn as nn 4 | import math 5 | 6 | import torch 7 | from torchvision import transforms 8 | from efficientnet_pytorch import EfficientNet 9 | 10 | """ 11 | Efficient Net 12 | """ 13 | 14 | class EffNet(nn.Module): 15 | def __init__(self): 16 | super(EffNet,self).__init__() 17 | self.main_model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=1024) 18 | 19 | def forward(self,x): 20 | x = self.main_model(x) 21 | x = x.unsqueeze(-1) 22 | x = x.unsqueeze(-1) 23 | return x -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | """ 5 | Imported by https://github.com/marvis/pytorch-mobilenet/blob/master/main.py 6 | """ 7 | 8 | """ 9 | Mobilenet 10 | """ 11 | 12 | class MobileNetV1(nn.Module): 13 | def __init__(self, dropout_keep_prob=0.999): 14 | super(MobileNetV1, self).__init__() 15 | self.dropout_keep_prob = dropout_keep_prob 16 | self.dropout = nn.Dropout(1 - dropout_keep_prob) 17 | def conv_bn(inp, oup, stride): 18 | return nn.Sequential( 19 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 20 | nn.BatchNorm2d(oup), 21 | nn.ReLU6(inplace=True) 22 | ) 23 | 24 | def conv_dw(inp, oup, stride): 25 | return 
nn.Sequential( 26 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 27 | nn.BatchNorm2d(inp), 28 | nn.ReLU6(inplace=True), 29 | 30 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 31 | nn.BatchNorm2d(oup), 32 | nn.ReLU6(inplace=True), 33 | ) 34 | 35 | self.model = nn.Sequential( 36 | conv_bn(3, 32, 2), 37 | conv_dw(32, 64, 1), 38 | conv_dw(64, 128, 2), 39 | conv_dw(128, 128, 1), 40 | conv_dw(128, 256, 2), 41 | conv_dw(256, 256, 1), 42 | conv_dw(256, 512, 2), 43 | conv_dw(512, 512, 1), 44 | conv_dw(512, 512, 1), 45 | conv_dw(512, 512, 1), 46 | conv_dw(512, 512, 1), 47 | conv_dw(512, 512, 1), 48 | conv_dw(512, 1024, 2), 49 | conv_dw(1024, 1024, 1), 50 | nn.AvgPool2d(7), 51 | ) 52 | 53 | 54 | def weight_init(self, m): 55 | if isinstance(m, nn.Conv2d): 56 | # truncated_normal_initializer in tensorflow 57 | nn.init.normal_(m.weight.data, std=0.09) 58 | #nn.init.constant(m.bias.data, 0) 59 | 60 | def forward(self, x): 61 | x = self.model(x) 62 | x = self.dropout(x) 63 | return x 64 | -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .bi_lstm import BiLSTM 3 | from .mobilenet import MobileNetV1 4 | from .resnet import resnet50 5 | from .eff_net import EffNet 6 | 7 | """ 8 | The entire model pipeline 9 | """ 10 | 11 | class Model(nn.Module): 12 | def __init__(self, args): 13 | super(Model, self).__init__() 14 | if args.image_model == 'mobilenet_v1': 15 | self.image_model = MobileNetV1() 16 | self.image_model.apply(self.image_model.weight_init) 17 | elif args.image_model == 'resnet50': 18 | self.image_model = resnet50() 19 | elif args.image_model == 'resnet101': 20 | self.image_model = resnet101() 21 | elif args.image_model == 'efficient_net': 22 | self.image_model = EffNet() 23 | 24 | # self.bilstm = BiLSTM(args.num_lstm_units, num_stacked_layers = 1, vocab_size = args.vocab_size, embedding_dim = 512) 25 | self.bilstm = BiLSTM(args) 26 | self.bilstm.apply(self.bilstm.weight_init) 27 | 28 | inp_size = 1024 29 | if args.image_model == 'resnet50' or args.image_model == 'resnet101': 30 | inp_size = 2048 31 | # shorten the tensor using 1*1 conv 32 | self.conv_images = nn.Conv2d(inp_size, args.feature_size, 1) 33 | self.conv_text = nn.Conv2d(1024, args.feature_size, 1) 34 | 35 | 36 | def forward(self, images, text, text_length): 37 | image_features = self.image_model(images) 38 | # print("Image shape", image_features.shape) 39 | text_features = self.bilstm(text, text_length) 40 | image_embeddings, text_embeddings= self.build_joint_embeddings(image_features, text_features) 41 | 42 | return image_embeddings, text_embeddings 43 | 44 | 45 | def build_joint_embeddings(self, images_features, text_features): 46 | 47 | #images_features = images_features.permute(0,2,3,1) 48 | #text_features = text_features.permute(0,3,1,2) 49 | image_embeddings = self.conv_images(images_features).squeeze() 50 | text_embeddings = self.conv_text(text_features).squeeze() 51 | 52 | return image_embeddings, text_embeddings 53 | -------------------------------------------------------------------------------- /deep_cmpl_model/code/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | """ 6 | Resnet 7 | """ 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152'] 11 | 12 | 13 | 
model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | 27 | 28 | def conv1x1(in_planes, out_planes, stride=1): 29 | """1x1 convolution""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None): 37 | super(BasicBlock, self).__init__() 38 | self.conv1 = conv3x3(inplanes, planes, stride) 39 | self.bn1 = nn.BatchNorm2d(planes) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | 65 | class Bottleneck(nn.Module): 66 | expansion = 4 67 | 68 | def __init__(self, inplanes, planes, stride=1, downsample=None): 69 | super(Bottleneck, self).__init__() 70 | self.conv1 = conv1x1(inplanes, planes) 71 | self.bn1 = nn.BatchNorm2d(planes) 72 | self.conv2 = conv3x3(planes, planes, stride) 73 | self.bn2 = nn.BatchNorm2d(planes) 74 | self.conv3 = conv1x1(planes, planes * self.expansion) 75 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | 80 | def forward(self, x): 81 | residual = x 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv3(out) 92 | out = self.bn3(out) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = self.relu(out) 99 | 100 | return out 101 | 102 | 103 | class ResNet(nn.Module): 104 | 105 | def __init__(self, block, layers, num_classes=1000): 106 | super(ResNet, self).__init__() 107 | self.inplanes = 64 108 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 109 | bias=False) 110 | self.bn1 = nn.BatchNorm2d(64) 111 | self.relu = nn.ReLU(inplace=True) 112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 113 | self.layer1 = self._make_layer(block, 64, layers[0]) 114 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 115 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 116 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 117 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 118 | self.fc = nn.Linear(512 * block.expansion, num_classes) 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 123 | elif isinstance(m, nn.BatchNorm2d): 124 | 
nn.init.constant_(m.weight, 1) 125 | nn.init.constant_(m.bias, 0) 126 | 127 | def _make_layer(self, block, planes, blocks, stride=1): 128 | downsample = None 129 | if stride != 1 or self.inplanes != planes * block.expansion: 130 | downsample = nn.Sequential( 131 | conv1x1(self.inplanes, planes * block.expansion, stride), 132 | nn.BatchNorm2d(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append(block(self.inplanes, planes, stride, downsample)) 137 | self.inplanes = planes * block.expansion 138 | for _ in range(1, blocks): 139 | layers.append(block(self.inplanes, planes)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | x = self.conv1(x) 145 | x = self.bn1(x) 146 | x = self.relu(x) 147 | x = self.maxpool(x) 148 | 149 | x = self.layer1(x) 150 | x = self.layer2(x) 151 | x = self.layer3(x) 152 | x = self.layer4(x) 153 | 154 | x = self.avgpool(x) 155 | #x = x.view(x.size(0), -1) 156 | #x = self.fc(x) 157 | 158 | return x 159 | 160 | 161 | def resnet18(pretrained=False, **kwargs): 162 | """Constructs a ResNet-18 model. 163 | 164 | Args: 165 | pretrained (bool): If True, returns a model pre-trained on ImageNet 166 | """ 167 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 168 | if pretrained: 169 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 170 | return model 171 | 172 | 173 | def resnet34(pretrained=False, **kwargs): 174 | """Constructs a ResNet-34 model. 175 | 176 | Args: 177 | pretrained (bool): If True, returns a model pre-trained on ImageNet 178 | """ 179 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 180 | if pretrained: 181 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 182 | return model 183 | 184 | 185 | def resnet50(pretrained=False, **kwargs): 186 | """Constructs a ResNet-50 model. 187 | 188 | Args: 189 | pretrained (bool): If True, returns a model pre-trained on ImageNet 190 | """ 191 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 192 | if pretrained: 193 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 194 | return model 195 | 196 | 197 | def resnet101(pretrained=False, **kwargs): 198 | """Constructs a ResNet-101 model. 199 | 200 | Args: 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | """ 203 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 204 | if pretrained: 205 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 206 | return model 207 | 208 | 209 | def resnet152(pretrained=False, **kwargs): 210 | """Constructs a ResNet-152 model. 
211 | 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 216 | if pretrained: 217 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 218 | return model 219 | -------------------------------------------------------------------------------- /deep_cmpl_model/code/scripts/tester.py: -------------------------------------------------------------------------------- 1 | # Used for testing the model 2 | 3 | import os 4 | 5 | GPUS='1' 6 | os.system('export CUDA_VISIBLE_DEVICES='+GPUS) 7 | 8 | BASE_ROOT='/content/Image_Text_Retrieval/deep_cmpl_model' 9 | BASE_ROOT_2='drive/Shareddrives/Image-Text-Retrieval/tempckpt' 10 | IMAGE_DIR='/content' 11 | ANNO_DIR=BASE_ROOT+'/data/processed_data' 12 | # CKPT_DIR=BASE_ROOT+'/data/model_data' 13 | # LOG_DIR=BASE_ROOT+'/data/logs' 14 | CKPT_DIR=BASE_ROOT_2+'/data/model_data' 15 | LOG_DIR=BASE_ROOT_2+'/data/logs' 16 | IMAGE_MODEL='efficient_net' 17 | lr='0.0002' 18 | batch_size='16' 19 | lr_decay_ratio='0.9' 20 | epoches_decay='80_150_200' 21 | 22 | 23 | string = 'python3 {BASE_ROOT}/code/test.py --bidirectional --model_path {CKPT_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size} --image_model {IMAGE_MODEL} --log_dir {LOG_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size} --image_dir {IMAGE_DIR} --anno_dir {ANNO_DIR} --gpus {GPUS} --epoch_ema 0'.format(BASE_ROOT=BASE_ROOT, IMAGE_DIR=IMAGE_DIR, IMAGE_MODEL=IMAGE_MODEL, ANNO_DIR=ANNO_DIR, CKPT_DIR=CKPT_DIR, LOG_DIR=LOG_DIR, lr=lr, batch_size=batch_size, lr_decay_ratio=lr_decay_ratio, GPUS=GPUS) 24 | 25 | os.system(string) 26 | 27 | -------------------------------------------------------------------------------- /deep_cmpl_model/code/scripts/trainer.py: -------------------------------------------------------------------------------- 1 | # Used for training the model 2 | 3 | import os 4 | 5 | GPUS = '1' 6 | 7 | os.system('export CUDA_VISIBLE_DEVICES='+GPUS) 8 | 9 | 10 | BASE_ROOT='/content/Image_Text_Retrieval/deep_cmpl_model' 11 | BASE_ROOT_2='drive/Shareddrives/Image-Text-Retrieval/tempckpt' 12 | IMAGE_DIR='/content' 13 | ANNO_DIR=BASE_ROOT+'/data/processed_data' 14 | CKPT_DIR=BASE_ROOT_2+'/data/model_data' 15 | LOG_DIR=BASE_ROOT_2+'/data/logs' 16 | IMAGE_MODEL='efficient_net' 17 | lr='0.0002' 18 | num_epoches='300' 19 | batch_size='16' 20 | lr_decay_ratio='0.9' 21 | epoches_decay='80_150_200' 22 | 23 | models_path='{CKPT_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size}'.format(CKPT_DIR=CKPT_DIR,lr=lr,lr_decay_ratio=lr_decay_ratio,batch_size=batch_size) 24 | 25 | def get_model_path(): 26 | # print(models_path) 27 | MODEL_PATH=None 28 | if os.path.exists(models_path): 29 | l=os.listdir(models_path) 30 | try: 31 | l.remove('model_best') 32 | except: 33 | pass 34 | # print(l) 35 | if len(l)==0: 36 | return None 37 | # print(len(l)) 38 | 39 | l.sort(key=lambda x:int(x.split('.')[0])) 40 | if len(l)>=2: 41 | x=l[-2] 42 | else: 43 | x=l[-1] 44 | MODEL_PATH=os.path.join(models_path,x) 45 | 46 | return MODEL_PATH 47 | 48 | string = 'python3 {BASE_ROOT}/code/train.py --CMPM --bidirectional --image_model {IMAGE_MODEL} --log_dir {LOG_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size} --checkpoint_dir {CKPT_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size} --image_dir {IMAGE_DIR} --anno_dir {ANNO_DIR} --batch_size {batch_size} --gpus {GPUS} --num_epoches {num_epoches} --lr {lr} --lr_decay_ratio {lr_decay_ratio} --epoches_decay {epoches_decay} --num_images 
12305'.format(BASE_ROOT=BASE_ROOT, IMAGE_DIR=IMAGE_DIR, IMAGE_MODEL=IMAGE_MODEL, ANNO_DIR=ANNO_DIR, CKPT_DIR=CKPT_DIR, LOG_DIR=LOG_DIR, lr=lr, num_epoches=num_epoches, batch_size=batch_size, lr_decay_ratio=lr_decay_ratio, epoches_decay=epoches_decay, GPUS=GPUS) 49 | 50 | MODEL_PATH = get_model_path() 51 | if(MODEL_PATH!=None): 52 | string += ' --resume --model_path {MODEL_PATH}'.format(MODEL_PATH=MODEL_PATH) 53 | 54 | os.system(string) 55 | 56 | -------------------------------------------------------------------------------- /deep_cmpl_model/code/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import shutil 5 | import gc 6 | import torch 7 | import torchvision.transforms as transforms 8 | import torch.nn as nn 9 | import torch.utils.data as data 10 | 11 | from utils.helpers import avg_calculator 12 | from utils.metric import compute_topk, compute_mr 13 | from utils.directory import makedir, check_file 14 | 15 | from datasets.fashion import Fashion 16 | from models.model import Model 17 | 18 | from test_params import get_test_args 19 | 20 | 21 | 22 | """ 23 | Calculates the recall rate and median rank metrics 24 | """ 25 | def get_metrics(test_loader, network, args): 26 | 27 | batch_time = avg_calculator() 28 | 29 | # switch to evaluate mode 30 | network.eval() 31 | 32 | num_samples = args.batch_size * len(test_loader) 33 | 34 | images_bank = torch.zeros((num_samples, args.feature_size)) 35 | text_bank = torch.zeros((num_samples,args.feature_size)) 36 | labels_bank = torch.zeros(num_samples) 37 | 38 | index = 0 39 | with torch.no_grad(): 40 | 41 | timer = time.time() 42 | 43 | for images, captions, labels, captions_length in test_loader: 44 | 45 | test_images = images 46 | test_captions = captions 47 | 48 | image_embeddings, text_embeddings = network(test_images, test_captions, captions_length) 49 | 50 | tsize = images.shape[0] 51 | 52 | images_bank[index: index + tsize] = image_embeddings 53 | text_bank[index: index + tsize] = text_embeddings 54 | labels_bank[index:index + tsize] = labels 55 | index+=tsize 56 | 57 | batch_time.update(time.time() - timer) 58 | timer = time.time() 59 | 60 | images_bank = images_bank[:index] 61 | text_bank = text_bank[:index] 62 | labels_bank = labels_bank[:index] 63 | 64 | i2t_top1, i2t_top5, i2t_top10, t2i_top1, t2i_top5, t2i_top10 = compute_topk(images_bank, text_bank, labels_bank, labels_bank, [1,5,10], True) 65 | i2t_mr, t2i_mr = compute_mr(images_bank, text_bank, labels_bank, labels_bank, 50, True) 66 | 67 | 68 | return i2t_top1, i2t_top5, i2t_top10, i2t_mr, t2i_top1, t2i_top5, t2i_top10, t2i_mr, batch_time.avg 69 | 70 | 71 | """ 72 | Initialise the data loader 73 | """ 74 | def get_data_loader(image_dir, anno_dir, batch_size, split, max_length): 75 | 76 | test_transform = transforms.Compose([ 77 | transforms.ToTensor(), 78 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 79 | ]) 80 | 81 | data_split = Fashion(image_dir, anno_dir, split, max_length, test_transform) 82 | 83 | loader = data.DataLoader(data_split, batch_size, shuffle=True, num_workers=2) 84 | 85 | return loader 86 | 87 | 88 | """ 89 | Returns a list of model checkpoint paths 90 | """ 91 | def get_test_model_paths(ckpt_path): 92 | 93 | test_models = os.listdir(ckpt_path) 94 | test_models.remove('model_best') 95 | test_models = sorted(test_models,key=lambda x: int(x.split(".")[0])) 96 | test_models = [os.path.join(ckpt_path,x) for x in test_models] 97 | # print(test_models) 98 | return 
test_models 99 | 100 | 101 | 102 | """ 103 | Initialise the network object 104 | """ 105 | def get_network(args, model_path=None): 106 | 107 | network = Model(args) 108 | network = nn.DataParallel(network) 109 | # cudnn.benchmark = True 110 | args.start_epoch = 0 111 | 112 | # process network params 113 | if model_path==None: 114 | raise ValueError('Supply the model path with --model_path while testing') 115 | check_file(model_path, 'model_file') 116 | checkpoint = torch.load(model_path) 117 | args.start_epoch = checkpoint['epoch'] + 1 118 | network_dict = checkpoint['network'] 119 | network.load_state_dict(network_dict) 120 | print('==> Loading checkpoint "{}"'.format(model_path)) 121 | 122 | return network 123 | 124 | 125 | 126 | """ 127 | Tests the model on the test set using the checkpoints stored, prints the metrics 128 | """ 129 | def main(args): 130 | 131 | test_loader = get_data_loader(args.image_dir, args.anno_dir, args.batch_size, 'test', args.max_length) 132 | 133 | i2t_top1 = 0.0 134 | i2t_top5 = 0.0 135 | i2t_top10 = 0.0 136 | i2t_mr = 0.0 137 | 138 | t2i_top1 = 0.0 139 | t2i_top5 = 0.0 140 | t2i_top10 = 0.0 141 | t2i_mr = 0.0 142 | 143 | test_models = get_test_model_paths(args.model_path) 144 | best_model_path = None 145 | 146 | for model_path in test_models: 147 | 148 | network= get_network(args, model_path) 149 | 150 | i2t_top1_cur, i2t_top5_cur, i2t_top10_cur, i2t_mr_cur, t2i_top1_cur, t2i_top5_cur, t2i_top10_cur, t2i_mr_cur, test_time = get_metrics(test_loader, network, args) 151 | 152 | if t2i_top1_cur > t2i_top1: 153 | 154 | i2t_top1 = i2t_top1_cur 155 | i2t_top5 = i2t_top5_cur 156 | i2t_top10 = i2t_top10_cur 157 | i2t_mr = i2t_mr_cur 158 | 159 | t2i_top1 = t2i_top1_cur 160 | t2i_top5 = t2i_top5_cur 161 | t2i_top10 = t2i_top10_cur 162 | t2i_mr = t2i_mr_cur 163 | 164 | best_model_path = model_path 165 | dst_best = os.path.join(args.model_path, 'model_best', 'best.pth.tar') 166 | shutil.copyfile(model_path, dst_best) 167 | 168 | 169 | # print("Best model: {}".format(best_model_path)) 170 | print('t2i_top1_best: {:.3f}, t2i_top5_best: {:.3f}, t2i_top10_best: {:.3f}, t2i_mr_best: {:.3f}'.format(t2i_top1, t2i_top5, t2i_top10, t2i_mr)) 171 | print('i2t_top1: {:.3f}, i2t_top5: {:.3f}, i2t_top10: {:.3f}, i2t_mr_best: {:.3f}'.format(i2t_top1, i2t_top5, i2t_top10, i2t_mr)) 172 | 173 | 174 | 175 | if __name__ == '__main__': 176 | # get user's arguments 177 | args = get_test_args() 178 | main(args) -------------------------------------------------------------------------------- /deep_cmpl_model/code/test_params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | """ 5 | Used for setting the paramsters while testing 6 | """ 7 | 8 | def get_test_args(): 9 | parser = argparse.ArgumentParser(description='command for evaluate on Fashion Dataset') 10 | 11 | # Directory 12 | parser.add_argument('--image_dir', type=str, help='directory to store dataset') 13 | parser.add_argument('--anno_dir', type=str, help='directory to store anno') 14 | parser.add_argument('--model_path', type=str, help='directory to load checkpoint') 15 | parser.add_argument('--log_dir', type=str, help='directory to store log') 16 | 17 | # LSTM setting 18 | parser.add_argument('--embedding_size', type=int, default=512) 19 | parser.add_argument('--num_lstm_units', type=int, default=512) 20 | parser.add_argument('--vocab_size', type=int, default=12000) 21 | parser.add_argument('--lstm_dropout_ratio', type=float, default=0.7) 22 | 
parser.add_argument('--bidirectional', action='store_true') 23 | parser.add_argument('--max_length', type=int, default=100) 24 | parser.add_argument('--feature_size', type=int, default=512) 25 | 26 | # Model Setting 27 | parser.add_argument('--image_model', type=str, default='mobilenet_v1') 28 | parser.add_argument('--cnn_dropout_keep', type=float, default=0.999) 29 | parser.add_argument('--batch_size', type=int, default=64) 30 | 31 | # Optimization Settings 32 | parser.add_argument('--epoch_ema', type=int, default=0) 33 | 34 | # Default setting 35 | parser.add_argument('--gpus', type=str, default='0') 36 | args = parser.parse_args() 37 | return args 38 | -------------------------------------------------------------------------------- /deep_cmpl_model/code/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import time 5 | import random 6 | import numpy as np 7 | import torch 8 | import torch.utils.data as data 9 | import torch.nn as nn 10 | import torchvision.transforms as transforms 11 | 12 | from utils.metric import Loss 13 | from utils.helpers import avg_calculator 14 | from utils.directory import makedir, check_file 15 | 16 | from datasets.fashion import Fashion 17 | from models.model import Model 18 | 19 | from train_params import get_train_args 20 | 21 | 22 | """ 23 | Train for one epoch, apply backpropagation, return loss after current epoch 24 | """ 25 | def train(epoch, train_loader, network, optimizer, compute_loss, args): 26 | 27 | # Trains for 1 epoch 28 | batch_time = avg_calculator() 29 | train_loss = avg_calculator() 30 | 31 | #switch to train mode, req in some modules 32 | network.train() 33 | 34 | end = time.time() 35 | 36 | for step, (images, captions, labels, captions_length) in enumerate(train_loader): 37 | 38 | images = images 39 | labels = labels 40 | captions = captions 41 | 42 | # compute loss 43 | image_embeddings, text_embeddings = network(images, captions, captions_length) 44 | cmpm_loss, pos_avg_sim, neg_arg_sim = compute_loss(image_embeddings, text_embeddings, labels) 45 | 46 | 47 | if step % 10 == 0: 48 | print('epoch:{}, step:{}, cmpm_loss:{:.3f}'.format(epoch, step, cmpm_loss)) 49 | 50 | 51 | # compute gradient and do ADAM step 52 | optimizer.zero_grad() 53 | cmpm_loss.backward() 54 | #nn.utils.clip_grad_norm(network.parameters(), 5) 55 | optimizer.step() 56 | 57 | # measure elapsed time 58 | batch_time.update(time.time() - end) 59 | end = time.time() 60 | 61 | train_loss.update(cmpm_loss, images.shape[0]) 62 | 63 | return train_loss.avg, batch_time.avg 64 | 65 | 66 | 67 | """ 68 | Initialise the data loader 69 | """ 70 | def get_data_loader(image_dir, anno_dir, batch_size, split, max_length): 71 | 72 | train_transform = transforms.Compose([ 73 | transforms.RandomHorizontalFlip(), 74 | transforms.ToTensor(), 75 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 76 | ]) 77 | 78 | data_split = Fashion(image_dir, anno_dir, split, max_length, train_transform) 79 | 80 | loader = data.DataLoader(data_split, batch_size, shuffle=True, num_workers=2) 81 | 82 | return loader 83 | 84 | 85 | """ 86 | Initialise the network object 87 | """ 88 | def get_network(args, resume=False, model_path=None): 89 | 90 | network = Model(args) 91 | network = nn.DataParallel(network) 92 | # cudnn.benchmark = True 93 | args.start_epoch = 0 94 | 95 | # process network params 96 | if resume: 97 | if model_path==None: 98 | raise ValueError('Supply the model path with --model_path while using 
---resume') 99 | check_file(model_path, 'model_file') 100 | checkpoint = torch.load(model_path) 101 | args.start_epoch = checkpoint['epoch'] + 1 102 | network_dict = checkpoint['network'] 103 | network.load_state_dict(network_dict) 104 | print('==> Loading checkpoint "{}"'.format(model_path)) 105 | 106 | return network 107 | 108 | 109 | """ 110 | Initialise optimizer object 111 | """ 112 | def get_optimizer(args, network=None, param=None, resume=False, model_path=None): 113 | 114 | #process optimizer params 115 | 116 | # optimizer 117 | # different params for different part 118 | cnn_params = list(map(id, network.module.image_model.parameters())) 119 | other_params = filter(lambda p: id(p) not in cnn_params, network.parameters()) 120 | other_params = list(other_params) 121 | if param is not None: 122 | other_params.extend(list(param)) 123 | param_groups = [{'params':other_params}, {'params':network.module.image_model.parameters(), 'weight_decay':args.wd}] 124 | 125 | optimizer = torch.optim.Adam( 126 | param_groups, 127 | lr = args.lr, betas=(args.adam_alpha, args.adam_beta), eps=args.epsilon 128 | ) 129 | 130 | if resume: 131 | check_file(model_path, 'model_file') 132 | checkpoint = torch.load(model_path) 133 | optimizer.load_state_dict(checkpoint['optimizer']) 134 | 135 | print('Total params: %2.fM' % (sum(p.numel() for p in network.parameters()) / 1000000.0)) 136 | # seed 137 | 138 | manualSeed = random.randint(1, 10000) 139 | random.seed(manualSeed) 140 | np.random.seed(manualSeed) 141 | torch.manual_seed(manualSeed) 142 | # torch.cuda.manual_seed_all(manualSeed) 143 | 144 | return optimizer 145 | 146 | 147 | """ 148 | Modify learning rate 149 | """ 150 | def adjust_lr(optimizer, epoch, args): 151 | 152 | # Decay learning rate by args.lr_decay_ratio every args.epoches_decay 153 | 154 | if args.lr_decay_type == 'exponential': 155 | 156 | lr = args.lr * (1 - args.lr_decay_ratio) 157 | 158 | for param_group in optimizer.param_groups: 159 | param_group['lr'] = lr 160 | 161 | 162 | """ 163 | Multistep Learning Rate Scheduler which decays learning rate after a no. of epochs specified by the user 164 | """ 165 | def lr_scheduler(optimizer, args): 166 | 167 | if '_' in args.epoches_decay: 168 | epoches_list = args.epoches_decay.split('_') 169 | epoches_list = [int(e) for e in epoches_list] 170 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, epoches_list) 171 | else: 172 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(args.epoches_decay)) 173 | 174 | return scheduler 175 | 176 | 177 | """ 178 | Saves the checkpoints 179 | """ 180 | def save_checkpoint(state, epoch, ckpt_dir, is_best): 181 | 182 | filename = os.path.join(ckpt_dir, str(args.start_epoch + epoch)) + '.pth.tar' 183 | torch.save(state, filename) 184 | if is_best: 185 | dst_best = os.path.join(ckpt_dir, 'model_best', str(epoch)) + '.pth.tar' 186 | shutil.copyfile(filename, dst_best) 187 | 188 | 189 | 190 | 191 | """ 192 | Initializes data loader, loss object, network, optimizer, runs the desired no. 
of epochs 193 | """ 194 | def main(args): 195 | 196 | train_loader = get_data_loader(args.image_dir, args.anno_dir, args.batch_size, 'train', args.max_length) 197 | 198 | # loss 199 | compute_loss = Loss(args) 200 | nn.DataParallel(compute_loss) 201 | 202 | # network 203 | network = get_network(args, args.resume, args.model_path) 204 | optimizer = get_optimizer(args, network, compute_loss.parameters(), args.resume, args.model_path) 205 | 206 | # lr_scheduler 207 | scheduler = lr_scheduler(optimizer, args) 208 | 209 | for epoch in range(args.num_epoches - args.start_epoch): 210 | # train for one epoch 211 | train_loss, train_time = train(args.start_epoch + epoch, train_loader, network, optimizer, compute_loss, args) 212 | # evaluate on validation set 213 | print('Train done for epoch-{}'.format(args.start_epoch + epoch)) 214 | 215 | 216 | state = {'network': network.state_dict(), 217 | 'optimizer': optimizer.state_dict(), 218 | 'W': compute_loss.W, 219 | 'epoch': args.start_epoch + epoch 220 | } 221 | 222 | save_checkpoint(state, epoch, args.checkpoint_dir, False) 223 | 224 | adjust_lr(optimizer, args.start_epoch + epoch, args) 225 | scheduler.step() 226 | 227 | 228 | for param in optimizer.param_groups: 229 | print('lr:{}'.format(param['lr'])) 230 | break 231 | 232 | 233 | 234 | 235 | if __name__ == "__main__": 236 | 237 | # Get arguments passed by user 238 | args = get_train_args() 239 | 240 | # Validate existence of image and annotation directory 241 | if not os.path.exists(args.image_dir): 242 | raise ValueError('Supply the dataset directory with --image_dir') 243 | if not os.path.exists(args.anno_dir): 244 | raise ValueError('Supply the anno file with --anno_dir') 245 | 246 | # save checkpoint 247 | makedir(args.checkpoint_dir) 248 | makedir(os.path.join(args.checkpoint_dir,'model_best')) 249 | 250 | main(args) -------------------------------------------------------------------------------- /deep_cmpl_model/code/train_params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | """ 5 | Used for setting the parameters for model, bilstm, etc. 
while training 6 | """ 7 | 8 | def get_train_args(): 9 | parser = argparse.ArgumentParser(description='command for train on Fashion Dataset') 10 | 11 | # Directory 12 | parser.add_argument('--image_dir', type=str, help='directory to store dataset') 13 | parser.add_argument('--anno_dir', type=str, help='directory to store anno file') 14 | parser.add_argument('--checkpoint_dir', type=str, help='directory to store checkpoint') 15 | parser.add_argument('--log_dir', type=str, help='directory to store log') 16 | parser.add_argument('--model_path', type=str, default = None, help='directory to pretrained model, whole model or just visual part') 17 | 18 | # LSTM setting 19 | parser.add_argument('--embedding_size', type=int, default=512) 20 | parser.add_argument('--num_lstm_units', type=int, default=512) 21 | parser.add_argument('--vocab_size', type=int, default=12000) 22 | parser.add_argument('--lstm_dropout_ratio', type=float, default=0.7) 23 | parser.add_argument('--max_length', type=int, default=100) 24 | parser.add_argument('--bidirectional', action='store_true') 25 | 26 | # Model setting 27 | parser.add_argument('--image_model', type=str, default='mobilenet_v1') 28 | parser.add_argument('--resume', action='store_true', help='whether or not to restore the pretrained whole model') 29 | parser.add_argument('--batch_size', type=int, default=16) 30 | parser.add_argument('--num_epoches', type=int, default=100) 31 | parser.add_argument('--ckpt_steps', type=int, default=5000, help='#steps to save checkpoint') 32 | parser.add_argument('--feature_size', type=int, default=512) 33 | parser.add_argument('--img_model', type=str, default='mobilenet_v1', help='model to train images') 34 | parser.add_argument('--loss_weight', type=float, default=1) 35 | parser.add_argument('--CMPM', action='store_true') 36 | parser.add_argument('--cnn_dropout_keep', type=float, default=0.999) 37 | parser.add_argument('--constraints_text', action='store_true') 38 | parser.add_argument('--constraints_images', action='store_true') 39 | parser.add_argument('--num_images', type=int, default=12305) 40 | parser.add_argument('--pretrained', action='store_true', help='whether or not to restore the pretrained visual model') 41 | 42 | # Optimization setting 43 | parser.add_argument('--optimizer', type=str, default='adam', help='one of "sgd", "adam", "rmsprop", "adadelta", or "adagrad"') 44 | parser.add_argument('--lr', type=float, default=0.0002) 45 | parser.add_argument('--wd', type=float, default=0.00004) 46 | parser.add_argument('--adam_alpha', type=float, default=0.9) 47 | parser.add_argument('--adam_beta', type=float, default=0.999) 48 | parser.add_argument('--epsilon', type=float, default=1e-8) 49 | parser.add_argument('--end_lr', type=float, default=0.0001, help='minimum end learning rate used by a polynomial decay learning rate') 50 | parser.add_argument('--lr_decay_type', type=str, default='fixed', help='One of "fixed" or "exponential"') 51 | parser.add_argument('--lr_decay_ratio', type=float, default=0.1) 52 | parser.add_argument('--epoches_decay', type=str, default='50_100', help='#epoches when learning rate decays') 53 | parser.add_argument('--epoch_ema', type=int, default=0) 54 | parser.add_argument('--ema_decay', type=float, default=0.9) 55 | 56 | # Default setting 57 | parser.add_argument('--gpus', type=str, default='0') 58 | 59 | args = parser.parse_args() 60 | return args -------------------------------------------------------------------------------- /deep_cmpl_model/code/utils/__pycache__/directory.cpython-37.pyc: 
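A note on the learning-rate flags defined in train_params.py above: --epoches_decay accepts either a single number or an underscore-separated list of epochs, and train.py's lr_scheduler() turns that string into a StepLR or MultiStepLR schedule (train.py's adjust_lr() additionally modifies the rate only when --lr_decay_type exponential is set). A minimal, self-contained sketch of the same logic, using a stand-in model and optimizer that are not part of the repository:

import torch
from torch.optim.lr_scheduler import MultiStepLR, StepLR

# Stand-in parameters with the Adam defaults from train_params.py
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.9, 0.999), eps=1e-8)

epoches_decay = '50_100'   # same string format as the --epoches_decay flag
if '_' in epoches_decay:
    # '50_100' -> multiply the learning rate by 0.1 (default gamma) at milestones 50 and 100
    scheduler = MultiStepLR(optimizer, [int(e) for e in epoches_decay.split('_')])
else:
    # a plain value such as '30' -> decay every 30 epochs
    scheduler = StepLR(optimizer, int(epoches_decay))

for epoch in range(120):
    optimizer.step()       # one (dummy) optimisation step per epoch
    scheduler.step()       # lr drops by 10x after the 50th and 100th scheduler steps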
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/utils/__pycache__/directory.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/utils/__pycache__/directory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/utils/__pycache__/directory.cpython-38.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/utils/__pycache__/helpers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/utils/__pycache__/helpers.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/utils/__pycache__/metric.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/utils/__pycache__/metric.cpython-37.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/utils/__pycache__/metric.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/utils/__pycache__/metric.cpython-38.pyc -------------------------------------------------------------------------------- /deep_cmpl_model/code/utils/directory.py: -------------------------------------------------------------------------------- 1 | # Directory related helper functions 2 | 3 | import os 4 | import json 5 | 6 | def makedir(root): 7 | if not os.path.exists(root): 8 | os.makedirs(root) 9 | 10 | 11 | def write_json(data, root): 12 | with open(root, 'w') as f: 13 | json.dump(data, f) 14 | 15 | 16 | def check_exists(root): 17 | if os.path.exists(root): 18 | return True 19 | return False 20 | 21 | def check_file(root, keyword): 22 | if not os.path.isfile(root): 23 | raise RuntimeError('===> No {} in {}'.format(keyword, root)) 24 | -------------------------------------------------------------------------------- /deep_cmpl_model/code/utils/helpers.py: -------------------------------------------------------------------------------- 1 | 2 | # Class to calculate average of a quantity over time 3 | class avg_calculator(): 4 | 5 | def __init__(self): 6 | self.reset() 7 | 8 | def reset(self): 9 | self.val = 0 10 | self.avg = 0 11 | self.sum = 0 12 | self.count = 0 13 | 14 | def update(self, val, n=1): 15 | self.val = val 16 | self.sum += n * val 17 | self.count += n 18 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /deep_cmpl_model/code/utils/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import numpy as np 6 | from torch.nn.parameter import Parameter 7 | from torch.autograd import Variable 8 | from statistics import median 9 | 
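# Overview of the functions defined below in this file:
# - pairwise_distance(A, B) builds the full matrix of squared Euclidean distances
#   in one matmul via the expansion ||a_i - b_j||^2 = ||a_i||^2 + ||b_j||^2 - 2*(a_i . b_j).
# - Loss.compute_cmpm_loss(x, t, labels) is the Cross-Modal Projection Matching (CMPM)
#   loss: with t_bar the L2-normalised text embeddings, the predicted matching
#   distribution for image i is p_ij = softmax_j(x_i . t_bar_j); the target q is the
#   binary label-match matrix divided by the L2 norm of each match vector (a norm,
#   not a sum); the image-to-text term is mean_i sum_j p_ij * (log p_ij - log(q_ij + epsilon)),
#   i.e. KL(p || q) up to the epsilon smoothing, and the symmetric text-to-image term
#   is added to it. pos_avg_sim / neg_avg_sim are the mean cosine similarities of
#   matched and unmatched pairs.
# - compute_topk / compute_mr report recall@k and the median rank (as a percentage of
#   the gallery size) over a query/gallery split, in both directions when reverse=True.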
10 | 11 | def pairwise_distance(A, B): 12 | 13 | A_square = torch.sum(A * A, dim=1, keepdim=True) 14 | B_square = torch.sum(B * B, dim=1, keepdim=True) 15 | 16 | distance = A_square + B_square.t() - 2 * torch.matmul(A, B.t()) 17 | 18 | return distance 19 | 20 | 21 | def one_hot_coding(index, k): 22 | if type(index) is torch.Tensor: 23 | length = len(index) 24 | else: 25 | length = 1 26 | out = torch.zeros((length, k), dtype=torch.int64).cuda() 27 | index = index.reshape((len(index), 1)) 28 | out.scatter_(1, index, 1) 29 | return out 30 | 31 | 32 | 33 | """ 34 | LOSS MODULE 35 | """ 36 | 37 | 38 | class Loss(nn.Module): 39 | 40 | def __init__(self, args): 41 | 42 | super(Loss, self).__init__() 43 | 44 | self.CMPM = True 45 | self.epsilon = args.epsilon 46 | 47 | self.num_images = args.num_images 48 | 49 | if args.resume: 50 | checkpoint = torch.load(args.model_path) 51 | self.W = Parameter(checkpoint['W']) 52 | print('=====> Loading weights from pretrained path') 53 | else: 54 | self.W = Parameter(torch.randn(args.feature_size, args.num_images)) 55 | nn.init.xavier_uniform_(self.W.data, gain=1) 56 | 57 | 58 | # CMPM Loss 59 | def compute_cmpm_loss(self, image_embeddings, text_embeddings, labels): 60 | """ 61 | Cross-Modal Projection Matching Loss(CMPM) 62 | :param image_embeddings: Tensor with dtype torch.float32 63 | :param text_embeddings: Tensor with dtype torch.float32 64 | :param labels: Tensor with dtype torch.int32 65 | :return: 66 | i2t_loss: cmpm loss for image projected to text 67 | t2i_loss: cmpm loss for text projected to image 68 | pos_avg_sim: average cosine-similarity for positive pairs 69 | neg_avg_sim: averate cosine-similarity for negative pairs 70 | """ 71 | 72 | batch_size = image_embeddings.shape[0] 73 | labels_reshape = torch.reshape(labels, (batch_size, 1)) 74 | labels_dist = labels_reshape - labels_reshape.t() 75 | labels_mask = (labels_dist == 0) 76 | 77 | image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True) 78 | text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True) 79 | image_proj_text = torch.matmul(image_embeddings, text_norm.t()) 80 | text_proj_image = torch.matmul(text_embeddings, image_norm.t()) 81 | 82 | # normalize the true matching distribution 83 | labels_mask_norm = labels_mask.float() / labels_mask.float().norm(dim=1) 84 | 85 | i2t_pred = F.softmax(image_proj_text, dim=1) 86 | #i2t_loss = i2t_pred * torch.log((i2t_pred + self.epsilon)/ (labels_mask_norm + self.epsilon)) 87 | i2t_loss = i2t_pred.to(device="cpu") * (F.log_softmax(image_proj_text.to(device="cpu"), dim=1) - torch.log(labels_mask_norm.to(device="cpu") + self.epsilon)) 88 | 89 | t2i_pred = F.softmax(text_proj_image, dim=1) 90 | #t2i_loss = t2i_pred * torch.log((t2i_pred + self.epsilon)/ (labels_mask_norm + self.epsilon)) 91 | t2i_loss = t2i_pred.to(device="cpu") * (F.log_softmax(text_proj_image.to(device="cpu"), dim=1) - torch.log(labels_mask_norm.to(device="cpu") + self.epsilon)) 92 | 93 | cmpm_loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1)) 94 | 95 | sim_cos = torch.matmul(image_norm, text_norm.t()) 96 | 97 | pos_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask)) 98 | neg_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask == 0)) 99 | 100 | return cmpm_loss, pos_avg_sim, neg_avg_sim 101 | 102 | 103 | def forward(self, image_embeddings, text_embeddings, labels): 104 | 105 | cmpm_loss = 0.0 106 | neg_avg_sim = 0.0 107 | pos_avg_sim =0.0 108 | 109 | if self.CMPM: 110 | cmpm_loss, pos_avg_sim, 
neg_avg_sim = self.compute_cmpm_loss(image_embeddings, text_embeddings, labels) 111 | 112 | return cmpm_loss, pos_avg_sim, neg_avg_sim 113 | 114 | 115 | """" 116 | Recall rate and Median Rank 117 | """ 118 | 119 | # Computes the recall rate 120 | def compute_topk(query, gallery, target_query, target_gallery, k=[1,5,10], reverse=False): 121 | result = [] 122 | query = query / query.norm(dim=1,keepdim=True) 123 | gallery = gallery / gallery.norm(dim=1,keepdim=True) 124 | sim_cosine = torch.matmul(query, gallery.t()) 125 | result.extend(topk(sim_cosine, target_gallery, target_query, k=[1,5,10])) 126 | if reverse: 127 | result.extend(topk(sim_cosine, target_query, target_gallery, k=[1,5,10], dim=0)) 128 | return result 129 | 130 | 131 | def topk(sim, target_gallery, target_query, k=[1,5,10], dim=1): 132 | result = [] 133 | maxk = max(k) 134 | size_total = len(target_gallery) 135 | _, pred_index = sim.topk(maxk, dim, True, True) 136 | pred_labels = target_gallery[pred_index] 137 | if dim == 1: 138 | pred_labels = pred_labels.t() 139 | correct = pred_labels.eq(target_query.view(1,-1).expand_as(pred_labels)) 140 | 141 | for topk in k: 142 | #correct_k = torch.sum(correct[:topk]).float() 143 | correct_k = torch.sum(correct[:topk], dim=0) 144 | correct_k = torch.sum(correct_k > 0).float() 145 | result.append(correct_k * 100 / size_total) 146 | return result 147 | 148 | 149 | """ 150 | Computes the Median Rank 151 | """ 152 | def compute_mr(query, gallery, target_query, target_gallery, k, reverse=False): 153 | result = [] 154 | query = query / query.norm(dim=1,keepdim=True) 155 | gallery = gallery / gallery.norm(dim=1,keepdim=True) 156 | sim_cosine = torch.matmul(query, gallery.t()) 157 | result.extend(mr(sim_cosine, target_gallery, target_query, k)) 158 | if reverse: 159 | result.extend(mr(sim_cosine, target_query, target_gallery, k, dim=0)) 160 | return result 161 | 162 | 163 | def mr(sim, target_gallery, target_query, k, dim=1): 164 | result = [] 165 | maxk = k 166 | size_total = len(target_gallery) 167 | _, pred_index = sim.topk(maxk, dim, True, True) 168 | pred_labels = target_gallery[pred_index] 169 | if dim == 1: 170 | pred_labels = pred_labels.t() 171 | correct = pred_labels.eq(target_query.view(1,-1).expand_as(pred_labels)) 172 | 173 | ranks = [] 174 | for row in correct.t(): 175 | temp = torch.where(row > 0)[0] 176 | if temp.shape[0] > 0: 177 | ranks.append(temp[0].item() + 1) 178 | else: 179 | ranks.append(k) 180 | # print('incr. 
k') 181 | 182 | result.append(median(ranks) * 100 / size_total) 183 | return result -------------------------------------------------------------------------------- /deep_cmpl_model/data/make_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import csv 3 | import random 4 | import os 5 | from math import floor 6 | 7 | random.seed(10) 8 | 9 | def get_len(csv_path,header=True): 10 | # Returns the number of samples in the dataset 11 | with open(csv_path,"r") as csv_file: 12 | csv_reader=csv.reader(csv_file,delimiter=',') 13 | line_count=0 14 | l=[] 15 | for row in csv_reader: 16 | l.append(row[2]) 17 | line_count+=1 18 | return line_count-1 19 | 20 | 21 | def generate_split(num_samples,train_perc,val_perc): 22 | ids=list(range(num_samples)) 23 | # print(len(ids)) 24 | random.shuffle(ids) 25 | train_size = floor(num_samples*train_perc) 26 | val_size = floor(num_samples*val_perc) 27 | train_ids = ids[:train_size] 28 | val_ids = ids[train_size:train_size+val_size] 29 | test_ids=ids[train_size+val_size:] 30 | 31 | return train_ids,val_ids,test_ids 32 | 33 | def make_file(num_samples=None): 34 | if num_samples==None: 35 | num_samples=get_len(csv_path) 36 | print(num_samples) 37 | # num_samples = 12305 38 | # train -> 10000, test-> 2305 39 | train_ids,val_ids,test_ids = generate_split(num_samples,0.8,0.1) 40 | 41 | data_list=[] 42 | id = 1 43 | with open(csv_path,"r") as csv_file: 44 | csv_reader=csv.reader(csv_file,delimiter=',') 45 | line_no=0 46 | for row in csv_reader: 47 | line_no+=1 48 | if line_no==1: 49 | continue 50 | # if(line_no>num_samples): 51 | # break 52 | description, image_id = row[1],int(row[2]) 53 | sample_dict={} 54 | if image_id in train_ids: 55 | split="train" 56 | elif image_id in val_ids: 57 | split="val" 58 | elif image_id in test_ids: 59 | split="test" 60 | else: 61 | # print("**",image_id) 62 | # raise Exception("Sample not allotted") 63 | continue 64 | 65 | sample_dict["split"] = split 66 | sample_dict["captions"] = [description] 67 | sample_dict["file_path"] = os.path.join(img_path,str(line_no-2)+".jpg") 68 | sample_dict["processed_tokens"]=[[]] 69 | sample_dict["id"]=image_id 70 | id+=1 71 | 72 | data_list.append(sample_dict) 73 | 74 | data_list.sort(key=lambda x:x["id"]) 75 | 76 | with open(out_path,"w") as f: 77 | json.dump(data_list,f) 78 | 79 | 80 | if __name__=="__main__": 81 | parent_folder = "/content/Image_Text_Retrieval/deep_cmpl_model" 82 | csv_path = parent_folder + "/data/images.csv" 83 | img_path = "dataset" 84 | out_path = parent_folder + "/data/reid_raw.json" 85 | make_file() 86 | 87 | 88 | 89 | # {"split": "train", 90 | # "captions": 91 | # ["A pedestrian with dark hair is wearing red and white shoes, a black hooded sweatshirt, and black pants.", 92 | # "The person has short black hair and is wearing black pants, a long sleeve black top, and red sneakers."], 93 | # "file_path": "CUHK01/0363004.png", 94 | # "processed_tokens": [["a", "pedestrian", "with", "dark", "hair", "is", "wearing", "red", "and", "white", "shoes", "a", "black", "hooded", "sweatshirt", "and", "black", "pants"], 95 | # ["the", "person", "has", "short", "black", "hair", "and", "is", "wearing", "black", "pants", "a", "long", "sleeve", "black", "top", "and", "red", "sneakers"]], 96 | # "id": 1} -------------------------------------------------------------------------------- /deep_cmpl_model/data/processed_data/metadata_info.txt: -------------------------------------------------------------------------------- 1 | Total 800 captions 800 
images 800 identities in train 2 | Total 100 captions 100 images 100 identities in val 3 | Total 100 captions 100 images 100 identities in test 4 | -------------------------------------------------------------------------------- /deep_cmpl_model/data/processed_data/test_reid.json: -------------------------------------------------------------------------------- 1 | [{"split": "test", "captions": ["Soch Women Navy Blue & Grey Dyed Straight Kurta"], "file_path": "dataset/4.jpg", "processed_tokens": [["", "Soch", "Women", "Navy", "Blue", "Grey", "Dyed", "Straight", "Kurta", ""]], "id": 4}, {"split": "test", "captions": ["Soch Women Off-White & Blue Printed A-Line Kurta"], "file_path": "dataset/15.jpg", "processed_tokens": [["", "Soch", "Women", "OffWhite", "Blue", "Printed", "ALine", "Kurta", ""]], "id": 15}, {"split": "test", "captions": ["Sztori Women Green Solid Semi Sheer Straight Kurta"], "file_path": "dataset/32.jpg", "processed_tokens": [["", "Sztori", "Women", "Green", "Solid", "Semi", "Sheer", "Straight", "Kurta", ""]], "id": 32}, {"split": "test", "captions": ["HERE&NOW Women Grey Solid Kurta with Palazzos"], "file_path": "dataset/33.jpg", "processed_tokens": [["", "HERENOW", "Women", "Grey", "Solid", "Kurta", "with", "Palazzos", ""]], "id": 33}, {"split": "test", "captions": ["Sangria Women Navy Blue & Golden Striped Straight Kurta"], "file_path": "dataset/35.jpg", "processed_tokens": [["", "Sangria", "Women", "Navy", "Blue", "Golden", "Striped", "Straight", "Kurta", ""]], "id": 35}, {"split": "test", "captions": ["Fabindia Women Pink & Green Printed Slim Fit A-Line Kurta"], "file_path": "dataset/37.jpg", "processed_tokens": [["", "Fabindia", "Women", "Pink", "Green", "Printed", "Slim", "Fit", "ALine", "Kurta", ""]], "id": 37}, {"split": "test", "captions": ["Libas Women Maroon Yoke Design Straight Kurta"], "file_path": "dataset/45.jpg", "processed_tokens": [["", "Libas", "Women", "Maroon", "Yoke", "Design", "Straight", "Kurta", ""]], "id": 45}, {"split": "test", "captions": ["Myshka Women Pink Yoke Design A-Line Kurta"], "file_path": "dataset/66.jpg", "processed_tokens": [["", "Myshka", "Women", "Pink", "Yoke", "Design", "ALine", "Kurta", ""]], "id": 66}, {"split": "test", "captions": ["Fabindia Women Grey & Black Slim Fit Checked A-Line Kurta"], "file_path": "dataset/73.jpg", "processed_tokens": [["", "Fabindia", "Women", "Grey", "Black", "Slim", "Fit", "Checked", "ALine", "Kurta", ""]], "id": 73}, {"split": "test", "captions": ["Soch Women Green Yoke Design Straight Kurta"], "file_path": "dataset/77.jpg", "processed_tokens": [["", "Soch", "Women", "Green", "Yoke", "Design", "Straight", "Kurta", ""]], "id": 77}, {"split": "test", "captions": ["Anubhutee Women Navy Blue & Off-White Yoke Design Kurta with Palazzos"], "file_path": "dataset/82.jpg", "processed_tokens": [["", "Anubhutee", "Women", "Navy", "Blue", "OffWhite", "Yoke", "Design", "Kurta", "with", "Palazzos", ""]], "id": 82}, {"split": "test", "captions": ["Libas Women Red & Golden Block Print Kurta with Palazzos & Dupatta"], "file_path": "dataset/115.jpg", "processed_tokens": [["", "Libas", "Women", "Red", "Golden", "Block", "Print", "Kurta", "with", "Palazzos", "Dupatta", ""]], "id": 115}, {"split": "test", "captions": ["Vishudh Women Peach-Coloured Solid Straight Kurta"], "file_path": "dataset/136.jpg", "processed_tokens": [["", "Vishudh", "Women", "PeachColoured", "Solid", "Straight", "Kurta", ""]], "id": 136}, {"split": "test", "captions": ["Soch Women Red & Gold Printed Straight Kurta"], "file_path": 
"dataset/137.jpg", "processed_tokens": [["", "Soch", "Women", "Red", "Gold", "Printed", "Straight", "Kurta", ""]], "id": 137}, {"split": "test", "captions": ["Anouk Women Pink Printed Kurta with Palazzos"], "file_path": "dataset/142.jpg", "processed_tokens": [["", "Anouk", "Women", "Pink", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 142}, {"split": "test", "captions": ["Vishudh Women Black Printed A-Line Kurta"], "file_path": "dataset/153.jpg", "processed_tokens": [["", "Vishudh", "Women", "Black", "Printed", "ALine", "Kurta", ""]], "id": 153}, {"split": "test", "captions": ["Jompers Women Beige & Green Printed A-Line Kurta"], "file_path": "dataset/160.jpg", "processed_tokens": [["", "Jompers", "Women", "Beige", "Green", "Printed", "ALine", "Kurta", ""]], "id": 160}, {"split": "test", "captions": ["Vishudh Women Green Printed Straight Kurta"], "file_path": "dataset/164.jpg", "processed_tokens": [["", "Vishudh", "Women", "Green", "Printed", "Straight", "Kurta", ""]], "id": 164}, {"split": "test", "captions": ["ZIYAA Women Cream-Coloured Printed Kurta with Trousers"], "file_path": "dataset/178.jpg", "processed_tokens": [["", "ZIYAA", "Women", "CreamColoured", "Printed", "Kurta", "with", "Trousers", ""]], "id": 178}, {"split": "test", "captions": ["HERE&NOW Women Navy Blue & White Printed Straight Kurta"], "file_path": "dataset/199.jpg", "processed_tokens": [["", "HERENOW", "Women", "Navy", "Blue", "White", "Printed", "Straight", "Kurta", ""]], "id": 199}, {"split": "test", "captions": ["Anouk Women Olive Green & Mustard Brown Printed Kurta with Churidar & Dupatta"], "file_path": "dataset/211.jpg", "processed_tokens": [["", "Anouk", "Women", "Olive", "Green", "Mustard", "Brown", "Printed", "Kurta", "with", "Churidar", "Dupatta", ""]], "id": 211}, {"split": "test", "captions": ["Nayo Women Pink & Green Printed Kurta with Palazzos"], "file_path": "dataset/229.jpg", "processed_tokens": [["", "Nayo", "Women", "Pink", "Green", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 229}, {"split": "test", "captions": ["Rain & Rainbow Women Black & Mustard Yellow Printed A-Line Kurta"], "file_path": "dataset/241.jpg", "processed_tokens": [["", "Rain", "Rainbow", "Women", "Black", "Mustard", "Yellow", "Printed", "ALine", "Kurta", ""]], "id": 241}, {"split": "test", "captions": ["Vishudh Women Burgundy & Solid Kurta with Palazzos "], "file_path": "dataset/244.jpg", "processed_tokens": [["", "Vishudh", "Women", "Burgundy", "Solid", "Kurta", "with", "Palazzos", ""]], "id": 244}, {"split": "test", "captions": ["SASSAFRAS Women Navy Blue & Orange Printed A-Line Kurta"], "file_path": "dataset/245.jpg", "processed_tokens": [["", "SASSAFRAS", "Women", "Navy", "Blue", "Orange", "Printed", "ALine", "Kurta", ""]], "id": 245}, {"split": "test", "captions": ["Vishudh Women Green Embroidered Kurta with Palazzos"], "file_path": "dataset/255.jpg", "processed_tokens": [["", "Vishudh", "Women", "Green", "Embroidered", "Kurta", "with", "Palazzos", ""]], "id": 255}, {"split": "test", "captions": ["Vedic Women Brown Solid A-Line Kurta"], "file_path": "dataset/268.jpg", "processed_tokens": [["", "Vedic", "Women", "Brown", "Solid", "ALine", "Kurta", ""]], "id": 268}, {"split": "test", "captions": ["anayna Women Pink Self-Striped Kurta with Trousers & Dupatta"], "file_path": "dataset/284.jpg", "processed_tokens": [["", "anayna", "Women", "Pink", "SelfStriped", "Kurta", "with", "Trousers", "Dupatta", ""]], "id": 284}, {"split": "test", "captions": ["Vishudh Women Green Printed Kurta with Palazzos"], "file_path": 
"dataset/290.jpg", "processed_tokens": [["", "Vishudh", "Women", "Green", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 290}, {"split": "test", "captions": ["Jashn Women Blue & Black Checked High-Low Straight Kurta"], "file_path": "dataset/308.jpg", "processed_tokens": [["", "Jashn", "Women", "Blue", "Black", "Checked", "HighLow", "Straight", "Kurta", ""]], "id": 308}, {"split": "test", "captions": ["Ives Women Navy Blue & Red Printed Anarkali Kurta"], "file_path": "dataset/310.jpg", "processed_tokens": [["", "Ives", "Women", "Navy", "Blue", "Red", "Printed", "Anarkali", "Kurta", ""]], "id": 310}, {"split": "test", "captions": ["Anubhutee Women Navy Blue & White Printed Kurta with Palazzos"], "file_path": "dataset/321.jpg", "processed_tokens": [["", "Anubhutee", "Women", "Navy", "Blue", "White", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 321}, {"split": "test", "captions": ["Vishudh Women Grey & Mustard Printed Kurta with Palazzos"], "file_path": "dataset/332.jpg", "processed_tokens": [["", "Vishudh", "Women", "Grey", "Mustard", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 332}, {"split": "test", "captions": ["AKS Women Maroon Solid Anarkali Kurta"], "file_path": "dataset/335.jpg", "processed_tokens": [["", "AKS", "Women", "Maroon", "Solid", "Anarkali", "Kurta", ""]], "id": 335}, {"split": "test", "captions": ["Rain & Rainbow Women Black & Pink Printed Kurta with Churidar & Dupatta"], "file_path": "dataset/363.jpg", "processed_tokens": [["", "Rain", "Rainbow", "Women", "Black", "Pink", "Printed", "Kurta", "with", "Churidar", "Dupatta", ""]], "id": 363}, {"split": "test", "captions": ["Anouk Women Mauve Striped Straight Kurta"], "file_path": "dataset/369.jpg", "processed_tokens": [["", "Anouk", "Women", "Mauve", "Striped", "Straight", "Kurta", ""]], "id": 369}, {"split": "test", "captions": ["Ishin Women Red & White Printed Bandhani Kurta with Skirt"], "file_path": "dataset/371.jpg", "processed_tokens": [["", "Ishin", "Women", "Red", "White", "Printed", "Bandhani", "Kurta", "with", "Skirt", ""]], "id": 371}, {"split": "test", "captions": ["Ishin Women Taupe & Off-White Printed Kurta with Palazzos"], "file_path": "dataset/374.jpg", "processed_tokens": [["", "Ishin", "Women", "Taupe", "OffWhite", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 374}, {"split": "test", "captions": ["Vishudh Women Blue Solid Kurta with Palazzos"], "file_path": "dataset/383.jpg", "processed_tokens": [["", "Vishudh", "Women", "Blue", "Solid", "Kurta", "with", "Palazzos", ""]], "id": 383}, {"split": "test", "captions": ["Anouk Women White & Black Checked A-Line High-Low Kurta"], "file_path": "dataset/384.jpg", "processed_tokens": [["", "Anouk", "Women", "White", "Black", "Checked", "ALine", "HighLow", "Kurta", ""]], "id": 384}, {"split": "test", "captions": ["Fabindia Women Yellow Woven Design Straight Kurta"], "file_path": "dataset/390.jpg", "processed_tokens": [["", "Fabindia", "Women", "Yellow", "Woven", "Design", "Straight", "Kurta", ""]], "id": 390}, {"split": "test", "captions": ["Anubhutee Women Pink & Golden Printed Straight Kurta"], "file_path": "dataset/393.jpg", "processed_tokens": [["", "Anubhutee", "Women", "Pink", "Golden", "Printed", "Straight", "Kurta", ""]], "id": 393}, {"split": "test", "captions": ["Vishudh Women Navy Blue Printed A-Line Kurta"], "file_path": "dataset/422.jpg", "processed_tokens": [["", "Vishudh", "Women", "Navy", "Blue", "Printed", "ALine", "Kurta", ""]], "id": 422}, {"split": "test", "captions": ["Fabindia Women Pink Solid Straight Kurta"], 
"file_path": "dataset/430.jpg", "processed_tokens": [["", "Fabindia", "Women", "Pink", "Solid", "Straight", "Kurta", ""]], "id": 430}, {"split": "test", "captions": ["Vishudh Women Black Printed A-Line Kurta"], "file_path": "dataset/431.jpg", "processed_tokens": [["", "Vishudh", "Women", "Black", "Printed", "ALine", "Kurta", ""]], "id": 431}, {"split": "test", "captions": ["Anouk Women Purple & Beige Printed Kurta Set"], "file_path": "dataset/439.jpg", "processed_tokens": [["", "Anouk", "Women", "Purple", "Beige", "Printed", "Kurta", "Set", ""]], "id": 439}, {"split": "test", "captions": ["HERE&NOW Women Mustard Yellow & Green Printed Kurta with Palazzos & Dupatta"], "file_path": "dataset/446.jpg", "processed_tokens": [["", "HERENOW", "Women", "Mustard", "Yellow", "Green", "Printed", "Kurta", "with", "Palazzos", "Dupatta", ""]], "id": 446}, {"split": "test", "captions": ["SIAH Women Peach-Coloured Solid Kurta with Trousers & Dupatta"], "file_path": "dataset/450.jpg", "processed_tokens": [["", "SIAH", "Women", "PeachColoured", "Solid", "Kurta", "with", "Trousers", "Dupatta", ""]], "id": 450}, {"split": "test", "captions": ["GERUA Women Black & Maroon Solid Kurta with Palazzos & Stole"], "file_path": "dataset/461.jpg", "processed_tokens": [["", "GERUA", "Women", "Black", "Maroon", "Solid", "Kurta", "with", "Palazzos", "Stole", ""]], "id": 461}, {"split": "test", "captions": ["Janasya Women Red & Beige Printed Kurti with Palazzos"], "file_path": "dataset/467.jpg", "processed_tokens": [["", "Janasya", "Women", "Red", "Beige", "Printed", "Kurti", "with", "Palazzos", ""]], "id": 467}, {"split": "test", "captions": ["Inddus Women Purple & Golden Woven Design A-Line Kurta"], "file_path": "dataset/473.jpg", "processed_tokens": [["", "Inddus", "Women", "Purple", "Golden", "Woven", "Design", "ALine", "Kurta", ""]], "id": 473}, {"split": "test", "captions": ["Libas Women Beige & White Embroidered Kurta with Palazzos"], "file_path": "dataset/481.jpg", "processed_tokens": [["", "Libas", "Women", "Beige", "White", "Embroidered", "Kurta", "with", "Palazzos", ""]], "id": 481}, {"split": "test", "captions": ["Khushal K Women Pink & Silver-Coloured Solid Kurta with Palazzos"], "file_path": "dataset/494.jpg", "processed_tokens": [["", "Khushal", "K", "Women", "Pink", "SilverColoured", "Solid", "Kurta", "with", "Palazzos", ""]], "id": 494}, {"split": "test", "captions": ["W Women Purple Solid Straight Kurta"], "file_path": "dataset/501.jpg", "processed_tokens": [["", "W", "Women", "Purple", "Solid", "Straight", "Kurta", ""]], "id": 501}, {"split": "test", "captions": ["Libas Women Blue Printed Anarkali Kurta"], "file_path": "dataset/503.jpg", "processed_tokens": [["", "Libas", "Women", "Blue", "Printed", "Anarkali", "Kurta", ""]], "id": 503}, {"split": "test", "captions": ["GERUA Women Maroon & Mustard Yellow Printed Kurta with Palazzos & Stole"], "file_path": "dataset/508.jpg", "processed_tokens": [["", "GERUA", "Women", "Maroon", "Mustard", "Yellow", "Printed", "Kurta", "with", "Palazzos", "Stole", ""]], "id": 508}, {"split": "test", "captions": ["Vishudh Women Beige & Gold-Toned Self Design Kurta with Trousers"], "file_path": "dataset/513.jpg", "processed_tokens": [["", "Vishudh", "Women", "Beige", "GoldToned", "Self", "Design", "Kurta", "with", "Trousers", ""]], "id": 513}, {"split": "test", "captions": ["Ritu Kumar Women Navy Blue & Beige Printed Kaftan Kurta"], "file_path": "dataset/533.jpg", "processed_tokens": [["", "Ritu", "Kumar", "Women", "Navy", "Blue", "Beige", "Printed", "Kaftan", "Kurta", ""]], 
"id": 533}, {"split": "test", "captions": ["SASSAFRAS Women Pink Solid Anarkali Kurta"], "file_path": "dataset/546.jpg", "processed_tokens": [["", "SASSAFRAS", "Women", "Pink", "Solid", "Anarkali", "Kurta", ""]], "id": 546}, {"split": "test", "captions": ["SASSAFRAS Women Burgundy Ikat Print A-Line Kurta"], "file_path": "dataset/549.jpg", "processed_tokens": [["", "SASSAFRAS", "Women", "Burgundy", "Ikat", "Print", "ALine", "Kurta", ""]], "id": 549}, {"split": "test", "captions": ["Libas Women Mustard Yellow Embroidered Detail Straight Kurta"], "file_path": "dataset/562.jpg", "processed_tokens": [["", "Libas", "Women", "Mustard", "Yellow", "Embroidered", "Detail", "Straight", "Kurta", ""]], "id": 562}, {"split": "test", "captions": ["Sangria Women Coral Orange Solid Straight Kurta"], "file_path": "dataset/580.jpg", "processed_tokens": [["", "Sangria", "Women", "Coral", "Orange", "Solid", "Straight", "Kurta", ""]], "id": 580}, {"split": "test", "captions": ["Anubhutee Women Pink & Navy Blue Printed Kurta with Palazzos"], "file_path": "dataset/585.jpg", "processed_tokens": [["", "Anubhutee", "Women", "Pink", "Navy", "Blue", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 585}, {"split": "test", "captions": ["Vishudh Women Blue Printed Kurta with Palazzos"], "file_path": "dataset/591.jpg", "processed_tokens": [["", "Vishudh", "Women", "Blue", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 591}, {"split": "test", "captions": ["GERUA Women Navy Blue & White Printed Kurta with Trousers"], "file_path": "dataset/596.jpg", "processed_tokens": [["", "GERUA", "Women", "Navy", "Blue", "White", "Printed", "Kurta", "with", "Trousers", ""]], "id": 596}, {"split": "test", "captions": ["anayna Women Green & White Printed Straight Kurta"], "file_path": "dataset/598.jpg", "processed_tokens": [["", "anayna", "Women", "Green", "White", "Printed", "Straight", "Kurta", ""]], "id": 598}, {"split": "test", "captions": ["Anouk Women Beige & Yellow Printed A-Line Kurta"], "file_path": "dataset/615.jpg", "processed_tokens": [["", "Anouk", "Women", "Beige", "Yellow", "Printed", "ALine", "Kurta", ""]], "id": 615}, {"split": "test", "captions": ["HERE&NOW Women Red & Golden Printed A-Line Kurta"], "file_path": "dataset/617.jpg", "processed_tokens": [["", "HERENOW", "Women", "Red", "Golden", "Printed", "ALine", "Kurta", ""]], "id": 617}, {"split": "test", "captions": ["Anouk Women White & Turquoise Blue Printed A-Line Tiered Kurta"], "file_path": "dataset/621.jpg", "processed_tokens": [["", "Anouk", "Women", "White", "Turquoise", "Blue", "Printed", "ALine", "Tiered", "Kurta", ""]], "id": 621}, {"split": "test", "captions": ["AASI - HOUSE OF NAYO Women Green & Golden Printed Kurta with Trousers & Dupatta"], "file_path": "dataset/628.jpg", "processed_tokens": [["", "AASI", "HOUSE", "OF", "NAYO", "Women", "Green", "Golden", "Printed", "Kurta", "with", "Trousers", "Dupatta", ""]], "id": 628}, {"split": "test", "captions": ["Idalia Women Pink & Golden Printed Kurta with Palazzos "], "file_path": "dataset/668.jpg", "processed_tokens": [["", "Idalia", "Women", "Pink", "Golden", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 668}, {"split": "test", "captions": ["House of Pataudi Women White Embroidered A-Line Kurta"], "file_path": "dataset/669.jpg", "processed_tokens": [["", "House", "of", "Pataudi", "Women", "White", "Embroidered", "ALine", "Kurta", ""]], "id": 669}, {"split": "test", "captions": ["AASI - HOUSE OF NAYO Women Lime Green & Navy Solid Kurta with Trousers & Dupatta"], "file_path": "dataset/673.jpg", 
"processed_tokens": [["", "AASI", "HOUSE", "OF", "NAYO", "Women", "Lime", "Green", "Navy", "Solid", "Kurta", "with", "Trousers", "Dupatta", ""]], "id": 673}, {"split": "test", "captions": ["Jaipur Kurti Women Blue & Green Printed Kurta with Salwar & Dupatta"], "file_path": "dataset/677.jpg", "processed_tokens": [["", "Jaipur", "Kurti", "Women", "Blue", "Green", "Printed", "Kurta", "with", "Salwar", "Dupatta", ""]], "id": 677}, {"split": "test", "captions": ["Rain & Rainbow Women Mustard Yellow Floral Print Anarkali Kurta"], "file_path": "dataset/682.jpg", "processed_tokens": [["", "Rain", "Rainbow", "Women", "Mustard", "Yellow", "Floral", "Print", "Anarkali", "Kurta", ""]], "id": 682}, {"split": "test", "captions": ["YASH GALLERY Women Black & White Printed A-Line Kurta"], "file_path": "dataset/691.jpg", "processed_tokens": [["", "YASH", "GALLERY", "Women", "Black", "White", "Printed", "ALine", "Kurta", ""]], "id": 691}, {"split": "test", "captions": ["Nayo Women Navy Blue & White Printed Kurta with Palazzos"], "file_path": "dataset/702.jpg", "processed_tokens": [["", "Nayo", "Women", "Navy", "Blue", "White", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 702}, {"split": "test", "captions": ["GERUA Women Green & White Checked Straight Kurta"], "file_path": "dataset/763.jpg", "processed_tokens": [["", "GERUA", "Women", "Green", "White", "Checked", "Straight", "Kurta", ""]], "id": 763}, {"split": "test", "captions": ["Vishudh Women Purple Printed Kurta with Palazzos"], "file_path": "dataset/786.jpg", "processed_tokens": [["", "Vishudh", "Women", "Purple", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 786}, {"split": "test", "captions": ["Vishudh Women Blue & Pink Printed A-Line Kurta"], "file_path": "dataset/790.jpg", "processed_tokens": [["", "Vishudh", "Women", "Blue", "Pink", "Printed", "ALine", "Kurta", ""]], "id": 790}, {"split": "test", "captions": ["Anouk Women Beige & Black Colourblocked Kurta with Palazzos"], "file_path": "dataset/794.jpg", "processed_tokens": [["", "Anouk", "Women", "Beige", "Black", "Colourblocked", "Kurta", "with", "Palazzos", ""]], "id": 794}, {"split": "test", "captions": ["Florence Women Red & Black Striped Kurta with Palazzos"], "file_path": "dataset/830.jpg", "processed_tokens": [["", "Florence", "Women", "Red", "Black", "Striped", "Kurta", "with", "Palazzos", ""]], "id": 830}, {"split": "test", "captions": ["AKS Women Blue Embroidered Detail Straight Kurta"], "file_path": "dataset/832.jpg", "processed_tokens": [["", "AKS", "Women", "Blue", "Embroidered", "Detail", "Straight", "Kurta", ""]], "id": 832}, {"split": "test", "captions": ["HERE&NOW Women Black & White Printed Straight Kurta"], "file_path": "dataset/837.jpg", "processed_tokens": [["", "HERENOW", "Women", "Black", "White", "Printed", "Straight", "Kurta", ""]], "id": 837}, {"split": "test", "captions": ["HERE&NOW Women Black & Golden Printed Kurta with Sharara & Dupatta"], "file_path": "dataset/843.jpg", "processed_tokens": [["", "HERENOW", "Women", "Black", "Golden", "Printed", "Kurta", "with", "Sharara", "Dupatta", ""]], "id": 843}, {"split": "test", "captions": ["HIGHLANDER Men White Solid Straight Kurta"], "file_path": "dataset/846.jpg", "processed_tokens": [["", "HIGHLANDER", "Men", "White", "Solid", "Straight", "Kurta", ""]], "id": 846}, {"split": "test", "captions": ["DEYANN Men Blue Solid Kurta with Pyjamas & Nehru Jacket"], "file_path": "dataset/862.jpg", "processed_tokens": [["", "DEYANN", "Men", "Blue", "Solid", "Kurta", "with", "Pyjamas", "Nehru", "Jacket", ""]], "id": 862}, 
{"split": "test", "captions": ["Taavi Men Peach-Coloured Handloom Woven Legacy Kurta with Half Button Placket"], "file_path": "dataset/864.jpg", "processed_tokens": [["", "Taavi", "Men", "PeachColoured", "Handloom", "Woven", "Legacy", "Kurta", "with", "Half", "Button", "Placket", ""]], "id": 864}, {"split": "test", "captions": ["Purple State Men Navy Blue Solid Straight Kurta"], "file_path": "dataset/875.jpg", "processed_tokens": [["", "Purple", "State", "Men", "Navy", "Blue", "Solid", "Straight", "Kurta", ""]], "id": 875}, {"split": "test", "captions": ["DEYANN Men Red & Cream Self Design Kurta with Patiala"], "file_path": "dataset/880.jpg", "processed_tokens": [["", "DEYANN", "Men", "Red", "Cream", "Self", "Design", "Kurta", "with", "Patiala", ""]], "id": 880}, {"split": "test", "captions": ["Purple State Men Maroon Solid Kurta"], "file_path": "dataset/907.jpg", "processed_tokens": [["", "Purple", "State", "Men", "Maroon", "Solid", "Kurta", ""]], "id": 907}, {"split": "test", "captions": ["Anouk Men Grey & Black Woven Design Straight Kurta"], "file_path": "dataset/919.jpg", "processed_tokens": [["", "Anouk", "Men", "Grey", "Black", "Woven", "Design", "Straight", "Kurta", ""]], "id": 919}, {"split": "test", "captions": ["KISAH Men Yellow Embroidered Straight Kurta"], "file_path": "dataset/931.jpg", "processed_tokens": [["", "KISAH", "Men", "Yellow", "Embroidered", "Straight", "Kurta", ""]], "id": 931}, {"split": "test", "captions": ["even Men Mustard Yellow Solid Straight Kurta"], "file_path": "dataset/952.jpg", "processed_tokens": [["", "even", "Men", "Mustard", "Yellow", "Solid", "Straight", "Kurta", ""]], "id": 952}, {"split": "test", "captions": ["Anouk Men Maroon & Black Solid Kurta Set"], "file_path": "dataset/959.jpg", "processed_tokens": [["", "Anouk", "Men", "Maroon", "Black", "Solid", "Kurta", "Set", ""]], "id": 959}, {"split": "test", "captions": ["Svanik Men Teal Green Woven Design Straight Kurta"], "file_path": "dataset/962.jpg", "processed_tokens": [["", "Svanik", "Men", "Teal", "Green", "Woven", "Design", "Straight", "Kurta", ""]], "id": 962}, {"split": "test", "captions": ["House of Pataudi Men Rose Pink Solid Straight Bib Kurta"], "file_path": "dataset/972.jpg", "processed_tokens": [["", "House", "of", "Pataudi", "Men", "Rose", "Pink", "Solid", "Straight", "Bib", "Kurta", ""]], "id": 972}, {"split": "test", "captions": ["SOJANYA Men Lime Green Solid Straight Kurta"], "file_path": "dataset/974.jpg", "processed_tokens": [["", "SOJANYA", "Men", "Lime", "Green", "Solid", "Straight", "Kurta", ""]], "id": 974}, {"split": "test", "captions": ["NEUDIS Men Red & White Self Design Kurta with Trousers"], "file_path": "dataset/975.jpg", "processed_tokens": [["", "NEUDIS", "Men", "Red", "White", "Self", "Design", "Kurta", "with", "Trousers", ""]], "id": 975}, {"split": "test", "captions": ["SOJANYA Men Black Solid Kurta with Dhoti Pants"], "file_path": "dataset/978.jpg", "processed_tokens": [["", "SOJANYA", "Men", "Black", "Solid", "Kurta", "with", "Dhoti", "Pants", ""]], "id": 978}] -------------------------------------------------------------------------------- /deep_cmpl_model/data/processed_data/test_sort.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/data/processed_data/test_sort.pkl -------------------------------------------------------------------------------- /deep_cmpl_model/data/processed_data/train_sort.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/data/processed_data/train_sort.pkl -------------------------------------------------------------------------------- /deep_cmpl_model/data/processed_data/val_sort.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/data/processed_data/val_sort.pkl -------------------------------------------------------------------------------- /deep_cmpl_model/data/processed_data/word_counts.txt: -------------------------------------------------------------------------------- 1 | Total words: 298 2 | Words in vocabulary: 163 3 | [('', 800), ('', 800), ('Kurta', 784), ('Women', 672), ('Printed', 398), ('with', 311), ('Straight', 304), ('Blue', 210), ('Solid', 156), ('Palazzos', 152), ('ALine', 132), ('Men', 124), ('White', 122), ('Green', 122), ('Black', 109), ('Pink', 108), ('Design', 99), ('Navy', 88), ('Trousers', 83), ('Vishudh', 79), ('Yellow', 76), ('Libas', 76), ('Anouk', 76), ('OffWhite', 69), ('Dupatta', 67), ('Red', 65), ('Embroidered', 58), ('Beige', 54), ('Grey', 54), ('Maroon', 51), ('Mustard', 49), ('Anarkali', 46), ('Yoke', 45), ('Woven', 43), ('Golden', 39), ('Print', 39), ('Kurti', 30), ('HERENOW', 30), ('Churidar', 29), ('Striped', 27), ('Melange', 27), ('by', 27), ('Block', 25), ('GoldToned', 25), ('Lifestyle', 25), ('Teal', 24), ('DEYANN', 23), ('Nayo', 22), ('Soch', 22), ('Orange', 21), ('Self', 21), ('Brown', 21), ('Varanga', 21), ('CreamColoured', 21), ('Jaipur', 19), ('AKS', 19), ('Jacket', 19), ('See', 18), ('Designs', 18), ('W', 16), ('Sangria', 16), ('Purple', 16), ('PeachColoured', 15), ('Nehru', 15), ('Fabindia', 14), ('Taavi', 14), ('Indo', 13), ('Era', 13), ('Sea', 13), ('Burgundy', 13), ('ZIYAA', 13), ('GERUA', 13), ('Olive', 13), ('Layered', 12), ('KISAH', 11), ('SOJANYA', 11), ('Pyjamas', 11), ('AHIKA', 10), ('Inddus', 10), ('Checked', 10), ('House', 10), ('of', 10), ('Pataudi', 10), ('With', 10), ('Legacy', 10), ('Fusion', 10), ('Vaamsi', 9), ('Moda', 9), ('Rapido', 9), ('Jompers', 9), ('Handloom', 9), ('Yufta', 9), ('Foil', 8), ('Coral', 8), ('Charcoal', 8), ('Rust', 8), ('IMARA', 8), ('Cross', 8), ('Court', 8), ('Patiala', 8), ('Silver', 7), ('Azira', 7), ('Rain', 7), ('Rainbow', 7), ('Skirt', 7), ('Pathani', 7), ('ADA', 6), ('Chikankari', 6), ('Hand', 6), ('GoldColoured', 6), ('Bandhani', 6), ('Turquoise', 6), ('Magenta', 6), ('Ishin', 6), ('Sharara', 6), ('even', 6), ('Lime', 5), ('Global', 5), ('Desi', 5), ('Dyed', 5), ('Janasya', 5), ('Anubhutee', 5), ('Tissu', 5), ('Mauve', 5), ('SASSAFRAS', 5), ('Silk', 5), ('Multicoloured', 5), ('Ahalyaa', 5), ('SilverToned', 4), ('Rasada', 4), ('AASI', 4), ('HOUSE', 4), ('OF', 4), ('NAYO', 4), ('Taupe', 4), ('Alena', 4), ('SelfStriped', 4), ('Ritu', 4), ('Kumar', 4), ('Manyavar', 4), ('Dupion', 4), ('NEUDIS', 4), ('Cotton', 3), ('Aline', 3), ('Floral', 3), ('Bitterlime', 3), ('Detail', 3), ('Kaftan', 3), ('Pockets', 3), ('Top', 3), ('Bhama', 3), ('Couture', 3), ('Biba', 3), ('all', 3), ('about', 3), ('you', 3), ('Idalia', 3), ('Shree', 3), ('Kalamkari', 3), ('I', 3), ('Set', 3), ('RollUp', 3), ('Sleeves', 3)] -------------------------------------------------------------------------------- /deep_cmpl_model/data/processed_data/word_to_index.pkl: 
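The pickled vocabulary files above are produced by the preprocessing step and consumed by the data loaders. A quick inspection sketch follows; it assumes word_to_index.pkl holds a plain {token: integer id} dict matching the vocabulary listed in word_counts.txt (the dump itself only references the binary file, so the structure is an assumption):

import pickle

# Path relative to the repository root; adjust to wherever the data lives.
with open('deep_cmpl_model/data/processed_data/word_to_index.pkl', 'rb') as f:
    word_to_index = pickle.load(f)   # assumed: {token: integer id}

print(len(word_to_index))            # vocabulary size
for token, idx in sorted(word_to_index.items(), key=lambda kv: kv[1])[:5]:
    print(idx, repr(token))          # first few (id, token) pairs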
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/data/processed_data/word_to_index.pkl -------------------------------------------------------------------------------- /deep_cmpl_model/deep_cmpl_jupyter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "deep_cmpl_jupyter.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "code", 22 | "metadata": { 23 | "id": "W5YoGYWqCNbp", 24 | "colab": { 25 | "base_uri": "https://localhost:8080/" 26 | }, 27 | "outputId": "f0f71f76-2444-4fa8-96a2-7e6c8583ffe7" 28 | }, 29 | "source": [ 30 | "from google.colab import drive\n", 31 | "drive.mount('/content/drive/',force_remount=True)" 32 | ], 33 | "execution_count": 1, 34 | "outputs": [ 35 | { 36 | "output_type": "stream", 37 | "text": [ 38 | "Mounted at /content/drive/\n" 39 | ], 40 | "name": "stdout" 41 | } 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "metadata": { 47 | "id": "ykxVopBOEHKT" 48 | }, 49 | "source": [ 50 | "!pip install PyDrive\n", 51 | "from pydrive.auth import GoogleAuth\n", 52 | "from pydrive.drive import GoogleDrive\n", 53 | "from google.colab import auth\n", 54 | "from oauth2client.client import GoogleCredentials\n", 55 | "auth.authenticate_user()\n", 56 | "gauth = GoogleAuth()\n", 57 | "gauth.credentials = GoogleCredentials.get_application_default()\n", 58 | "drive = GoogleDrive(gauth)\n", 59 | "downloaded = drive.CreateFile({'id':\"1XHUwpJDETylqrNJgeXRas89YswqfobG5\"})\n", 60 | "downloaded.GetContentFile('dataset.zip') \n", 61 | "!unzip \"dataset.zip\" -d \"/content\"" 62 | ], 63 | "execution_count": null, 64 | "outputs": [] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "metadata": { 69 | "id": "UovUjGvQRWKK" 70 | }, 71 | "source": [ 72 | "# !unzip /content/drive/Shareddrives/Image-Text-Retrieval/dataset.zip -d /content" 73 | ], 74 | "execution_count": null, 75 | "outputs": [] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "metadata": { 80 | "id": "ybIFQeAjFTg0", 81 | "colab": { 82 | "base_uri": "https://localhost:8080/" 83 | }, 84 | "outputId": "24b31585-077a-479a-9c1d-2284d43baf80" 85 | }, 86 | "source": [ 87 | "!rm -rf Image_Text_Retrieval/\n", 88 | "%cd /content/\n", 89 | "!git clone --single-branch --branch main https://github.com/raghavgoyal283/Image_Text_Retrieval" 90 | ], 91 | "execution_count": 4, 92 | "outputs": [ 93 | { 94 | "output_type": "stream", 95 | "text": [ 96 | "/content\n", 97 | "Cloning into 'Image_Text_Retrieval'...\n", 98 | "remote: Enumerating objects: 196, done.\u001b[K\n", 99 | "remote: Counting objects: 100% (196/196), done.\u001b[K\n", 100 | "remote: Compressing objects: 100% (160/160), done.\u001b[K\n", 101 | "remote: Total 196 (delta 62), reused 159 (delta 28), pack-reused 0\u001b[K\n", 102 | "Receiving objects: 100% (196/196), 3.67 MiB | 27.82 MiB/s, done.\n", 103 | "Resolving deltas: 100% (62/62), done.\n" 104 | ], 105 | "name": "stdout" 106 | } 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "metadata": { 112 | "colab": { 113 | "base_uri": "https://localhost:8080/" 114 | }, 115 | "id": "GFd9a065q9fU", 116 | "outputId": 
"b6480f25-5c4c-4584-913e-25d3d6218d22" 117 | }, 118 | "source": [ 119 | "%cd /content/Image_Text_Retrieval\n", 120 | "!git pull origin main" 121 | ], 122 | "execution_count": 17, 123 | "outputs": [ 124 | { 125 | "output_type": "stream", 126 | "text": [ 127 | "/content/Image_Text_Retrieval\n", 128 | "remote: Enumerating objects: 13, done.\u001b[K\n", 129 | "remote: Counting objects: 100% (13/13), done.\u001b[K\n", 130 | "remote: Compressing objects: 100% (3/3), done.\u001b[K\n", 131 | "remote: Total 7 (delta 4), reused 7 (delta 4), pack-reused 0\u001b[K\n", 132 | "Unpacking objects: 100% (7/7), done.\n", 133 | "From https://github.com/raghavgoyal283/Image_Text_Retrieval\n", 134 | " * branch main -> FETCH_HEAD\n", 135 | " 4776f27..8b4fdaf main -> origin/main\n", 136 | "Updating 4776f27..8b4fdaf\n", 137 | "Fast-forward\n", 138 | " deep_cmpl_model/code/scripts/tester.py | 2 \u001b[32m+\u001b[m\u001b[31m-\u001b[m\n", 139 | " deep_cmpl_model/code/scripts/trainer.py | 2 \u001b[32m+\u001b[m\u001b[31m-\u001b[m\n", 140 | " 2 files changed, 2 insertions(+), 2 deletions(-)\n" 141 | ], 142 | "name": "stdout" 143 | } 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "metadata": { 149 | "id": "3twCZ8ABDNJp", 150 | "colab": { 151 | "base_uri": "https://localhost:8080/" 152 | }, 153 | "outputId": "3bad414d-0078-4fe5-a2e9-4e3491caefb6" 154 | }, 155 | "source": [ 156 | "%cd /content/\n", 157 | "!python3 /content/Image_Text_Retrieval/deep_cmpl_model/data/make_json.py\n", 158 | "!sh /content/Image_Text_Retrieval/deep_cmpl_model/code/datasets/data.sh" 159 | ], 160 | "execution_count": 12, 161 | "outputs": [ 162 | { 163 | "output_type": "stream", 164 | "text": [ 165 | "/content\n", 166 | "12305\n", 167 | "Preprocessing dataset\n", 168 | "start build vodabulary\n", 169 | "Total words: 1418\n", 170 | "Words in vocab: 812\n", 171 | "number of bad words: 606/1418 = 42.74%\n", 172 | "number of words in vocab: 812/1418 = 57.26%\n", 173 | "number of Null: 787/1418 = 55.50%\n", 174 | "Process metadata done!\n", 175 | "Total 9844 captions 9844 images 9844 identities in train\n", 176 | "Process metadata done!\n", 177 | "Total 1230 captions 1230 images 1230 identities in val\n", 178 | "Process metadata done!\n", 179 | "Total 1231 captions 1231 images 1231 identities in test\n", 180 | "Process decodedata done!\n", 181 | "Process decodedata done!\n", 182 | "Process decodedata done!\n", 183 | "=========== Arrange by id=============================\n", 184 | "Save dataset\n", 185 | "=========== Arrange by id=============================\n", 186 | "Save dataset\n", 187 | "=========== Arrange by id=============================\n", 188 | "Save dataset\n" 189 | ], 190 | "name": "stdout" 191 | } 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": { 197 | "id": "G87T6Ntf1UOV" 198 | }, 199 | "source": [ 200 | "## Remember to push the pkl files" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "metadata": { 206 | "id": "rVxnw3Iv1QfX", 207 | "colab": { 208 | "base_uri": "https://localhost:8080/" 209 | }, 210 | "outputId": "ac403127-d98f-47fe-f288-491fb93745d1" 211 | }, 212 | "source": [ 213 | "# %cd /content/Image_Text_Retrieval\n", 214 | "# !git config --global user.email \"raghavgoyal283@gmail.com\"\n", 215 | "# !git config --global user.name \"raghavgoyal283\"\n", 216 | "# !git add .\n", 217 | "# !git commit -m \"add\"\n", 218 | "# !git push -u origin main" 219 | ], 220 | "execution_count": null, 221 | "outputs": [ 222 | { 223 | "output_type": "stream", 224 | "text": [ 225 | 
"/content/Image_Text_Retrieval\n", 226 | "[anothertry af39ca4] add\n", 227 | " 1 file changed, 3 insertions(+), 3 deletions(-)\n", 228 | "Counting objects: 4, done.\n", 229 | "Delta compression using up to 2 threads.\n", 230 | "Compressing objects: 100% (4/4), done.\n", 231 | "Writing objects: 100% (4/4), 422 bytes | 422.00 KiB/s, done.\n", 232 | "Total 4 (delta 3), reused 0 (delta 0)\n", 233 | "remote: Resolving deltas: 100% (3/3), completed with 3 local objects.\u001b[K\n", 234 | "To https://github.com/raghavgoyal283/Image_Text_Retrieval.git\n", 235 | " c1d7077..af39ca4 anothertry -> anothertry\n", 236 | "Branch 'anothertry' set up to track remote branch 'anothertry' from 'origin'.\n" 237 | ], 238 | "name": "stdout" 239 | } 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": { 245 | "id": "Z4iFxUTlRnWQ" 246 | }, 247 | "source": [ 248 | "# Train/Test" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "metadata": { 254 | "id": "CbFhoqYkoq-6", 255 | "colab": { 256 | "base_uri": "https://localhost:8080/" 257 | }, 258 | "outputId": "b905fe42-9ead-4887-9a77-84d62ef05db2" 259 | }, 260 | "source": [ 261 | "!pip install efficientnet_pytorch" 262 | ], 263 | "execution_count": 13, 264 | "outputs": [ 265 | { 266 | "output_type": "stream", 267 | "text": [ 268 | "Collecting efficientnet_pytorch\n", 269 | " Downloading https://files.pythonhosted.org/packages/2e/a0/dd40b50aebf0028054b6b35062948da01123d7be38d08b6b1e5435df6363/efficientnet_pytorch-0.7.1.tar.gz\n", 270 | "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from efficientnet_pytorch) (1.8.1+cu101)\n", 271 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->efficientnet_pytorch) (3.7.4.3)\n", 272 | "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torch->efficientnet_pytorch) (1.19.5)\n", 273 | "Building wheels for collected packages: efficientnet-pytorch\n", 274 | " Building wheel for efficientnet-pytorch (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 275 | " Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.7.1-cp37-none-any.whl size=16443 sha256=5a0fcea201810a6455bd03e301dc9617b6b4a66d8496abf193b0017c78d619a4\n", 276 | " Stored in directory: /root/.cache/pip/wheels/84/27/aa/c46d23c4e8cc72d41283862b1437e0b3ad318417e8ed7d5921\n", 277 | "Successfully built efficientnet-pytorch\n", 278 | "Installing collected packages: efficientnet-pytorch\n", 279 | "Successfully installed efficientnet-pytorch-0.7.1\n" 280 | ], 281 | "name": "stdout" 282 | } 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "metadata": { 288 | "id": "-aTdr_DrGd4x" 289 | }, 290 | "source": [ 291 | "# Clear previous checkpoints\n", 292 | "!rm -r /content/drive/Shareddrives/Image-Text-Retrieval/tempckpt/*" 293 | ], 294 | "execution_count": 21, 295 | "outputs": [] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "metadata": { 300 | "colab": { 301 | "base_uri": "https://localhost:8080/" 302 | }, 303 | "id": "IkRKNJVcDM-F", 304 | "outputId": "c078b79b-bdd3-4438-b5d9-d8aa0ba129bd" 305 | }, 306 | "source": [ 307 | "%cd /content/\n", 308 | "!python3 /content/Image_Text_Retrieval/deep_cmpl_model/code/scripts/trainer.py" 309 | ], 310 | "execution_count": 22, 311 | "outputs": [ 312 | { 313 | "output_type": "stream", 314 | "text": [ 315 | "/content\n", 316 | "Loaded pretrained weights for efficientnet-b0\n", 317 | "Total params: 17M\n", 318 | "epoch:0, step:0, cmpm_loss:28.980\n", 319 | "epoch:0, step:10, cmpm_loss:27.822\n", 320 | "epoch:0, step:20, cmpm_loss:25.797\n", 321 | "epoch:0, step:30, cmpm_loss:24.350\n", 322 | "epoch:0, step:40, cmpm_loss:22.872\n", 323 | "epoch:0, step:50, cmpm_loss:18.468\n", 324 | "epoch:0, step:60, cmpm_loss:17.170\n", 325 | "epoch:0, step:70, cmpm_loss:17.078\n", 326 | "epoch:0, step:80, cmpm_loss:18.581\n", 327 | "epoch:0, step:90, cmpm_loss:17.275\n", 328 | "epoch:0, step:100, cmpm_loss:16.918\n", 329 | "epoch:0, step:110, cmpm_loss:14.333\n", 330 | "epoch:0, step:120, cmpm_loss:20.252\n", 331 | "epoch:0, step:130, cmpm_loss:13.150\n", 332 | "epoch:0, step:140, cmpm_loss:13.900\n", 333 | "epoch:0, step:150, cmpm_loss:9.549\n", 334 | "epoch:0, step:160, cmpm_loss:13.966\n", 335 | "epoch:0, step:170, cmpm_loss:8.952\n", 336 | "epoch:0, step:180, cmpm_loss:10.939\n", 337 | "epoch:0, step:190, cmpm_loss:8.844\n", 338 | "epoch:0, step:200, cmpm_loss:3.907\n", 339 | "epoch:0, step:210, cmpm_loss:10.644\n", 340 | "epoch:0, step:220, cmpm_loss:10.646\n", 341 | "epoch:0, step:230, cmpm_loss:9.801\n", 342 | "epoch:0, step:240, cmpm_loss:6.492\n", 343 | "epoch:0, step:250, cmpm_loss:11.888\n", 344 | "epoch:0, step:260, cmpm_loss:16.473\n", 345 | "epoch:0, step:270, cmpm_loss:8.894\n", 346 | "epoch:0, step:280, cmpm_loss:10.306\n", 347 | "epoch:0, step:290, cmpm_loss:15.606\n", 348 | "epoch:0, step:300, cmpm_loss:7.734\n", 349 | "epoch:0, step:310, cmpm_loss:10.006\n", 350 | "epoch:0, step:320, cmpm_loss:5.520\n", 351 | "epoch:0, step:330, cmpm_loss:9.725\n", 352 | "epoch:0, step:340, cmpm_loss:8.853\n", 353 | "epoch:0, step:350, cmpm_loss:5.653\n", 354 | "epoch:0, step:360, cmpm_loss:8.572\n", 355 | "epoch:0, step:370, cmpm_loss:9.815\n", 356 | "epoch:0, step:380, cmpm_loss:6.975\n", 357 | "epoch:0, step:390, cmpm_loss:8.218\n", 358 | "epoch:0, step:400, cmpm_loss:12.703\n", 359 | "epoch:0, step:410, cmpm_loss:6.403\n", 360 | "epoch:0, step:420, cmpm_loss:7.904\n", 361 | "epoch:0, step:430, cmpm_loss:8.343\n", 362 | "epoch:0, step:440, cmpm_loss:5.121\n", 363 | "epoch:0, step:450, 
cmpm_loss:6.697\n", 364 | "epoch:0, step:460, cmpm_loss:8.433\n", 365 | "epoch:0, step:470, cmpm_loss:5.876\n", 366 | "epoch:0, step:480, cmpm_loss:0.763\n", 367 | "epoch:0, step:490, cmpm_loss:7.332\n", 368 | "epoch:0, step:500, cmpm_loss:13.782\n", 369 | "epoch:0, step:510, cmpm_loss:6.662\n", 370 | "epoch:0, step:520, cmpm_loss:3.625\n", 371 | "epoch:0, step:530, cmpm_loss:8.246\n", 372 | "epoch:0, step:540, cmpm_loss:7.430\n", 373 | "epoch:0, step:550, cmpm_loss:5.310\n", 374 | "epoch:0, step:560, cmpm_loss:5.444\n", 375 | "epoch:0, step:570, cmpm_loss:2.826\n", 376 | "epoch:0, step:580, cmpm_loss:4.238\n", 377 | "epoch:0, step:590, cmpm_loss:8.130\n", 378 | "epoch:0, step:600, cmpm_loss:8.824\n", 379 | "epoch:0, step:610, cmpm_loss:6.546\n", 380 | "Train done for epoch-0\n", 381 | "lr:0.0002\n", 382 | "epoch:1, step:0, cmpm_loss:6.692\n", 383 | "epoch:1, step:10, cmpm_loss:7.902\n", 384 | "epoch:1, step:20, cmpm_loss:4.313\n", 385 | "epoch:1, step:30, cmpm_loss:8.069\n", 386 | "epoch:1, step:40, cmpm_loss:3.060\n", 387 | "epoch:1, step:50, cmpm_loss:0.437\n", 388 | "epoch:1, step:60, cmpm_loss:2.161\n", 389 | "epoch:1, step:70, cmpm_loss:2.665\n", 390 | "epoch:1, step:80, cmpm_loss:9.616\n", 391 | "epoch:1, step:90, cmpm_loss:3.825\n", 392 | "epoch:1, step:100, cmpm_loss:2.910\n", 393 | "epoch:1, step:110, cmpm_loss:1.977\n", 394 | "epoch:1, step:120, cmpm_loss:1.750\n", 395 | "epoch:1, step:130, cmpm_loss:5.090\n", 396 | "epoch:1, step:140, cmpm_loss:7.062\n", 397 | "epoch:1, step:150, cmpm_loss:2.141\n", 398 | "epoch:1, step:160, cmpm_loss:7.578\n", 399 | "epoch:1, step:170, cmpm_loss:7.134\n", 400 | "epoch:1, step:180, cmpm_loss:2.916\n", 401 | "epoch:1, step:190, cmpm_loss:7.609\n", 402 | "epoch:1, step:200, cmpm_loss:2.978\n", 403 | "epoch:1, step:210, cmpm_loss:3.521\n", 404 | "epoch:1, step:220, cmpm_loss:3.619\n", 405 | "epoch:1, step:230, cmpm_loss:4.623\n", 406 | "epoch:1, step:240, cmpm_loss:6.489\n", 407 | "epoch:1, step:250, cmpm_loss:4.454\n", 408 | "epoch:1, step:260, cmpm_loss:2.816\n", 409 | "epoch:1, step:270, cmpm_loss:9.463\n", 410 | "epoch:1, step:280, cmpm_loss:3.138\n", 411 | "epoch:1, step:290, cmpm_loss:4.777\n", 412 | "epoch:1, step:300, cmpm_loss:5.296\n", 413 | "epoch:1, step:310, cmpm_loss:5.167\n", 414 | "epoch:1, step:320, cmpm_loss:4.451\n", 415 | "epoch:1, step:330, cmpm_loss:3.969\n", 416 | "epoch:1, step:340, cmpm_loss:4.741\n", 417 | "epoch:1, step:350, cmpm_loss:9.974\n", 418 | "epoch:1, step:360, cmpm_loss:7.483\n", 419 | "epoch:1, step:370, cmpm_loss:8.054\n", 420 | "epoch:1, step:380, cmpm_loss:5.053\n", 421 | "epoch:1, step:390, cmpm_loss:3.425\n", 422 | "epoch:1, step:400, cmpm_loss:2.688\n", 423 | "epoch:1, step:410, cmpm_loss:5.667\n", 424 | "epoch:1, step:420, cmpm_loss:4.761\n", 425 | "epoch:1, step:430, cmpm_loss:4.056\n", 426 | "epoch:1, step:440, cmpm_loss:6.989\n", 427 | "epoch:1, step:450, cmpm_loss:6.960\n", 428 | "epoch:1, step:460, cmpm_loss:9.581\n", 429 | "epoch:1, step:470, cmpm_loss:4.810\n", 430 | "epoch:1, step:480, cmpm_loss:4.324\n", 431 | "epoch:1, step:490, cmpm_loss:3.718\n", 432 | "epoch:1, step:500, cmpm_loss:3.811\n", 433 | "epoch:1, step:510, cmpm_loss:1.995\n", 434 | "epoch:1, step:520, cmpm_loss:5.277\n", 435 | "epoch:1, step:530, cmpm_loss:1.554\n", 436 | "epoch:1, step:540, cmpm_loss:5.259\n", 437 | "epoch:1, step:550, cmpm_loss:5.785\n", 438 | "epoch:1, step:560, cmpm_loss:7.259\n", 439 | "epoch:1, step:570, cmpm_loss:10.739\n", 440 | "epoch:1, step:580, cmpm_loss:5.838\n", 441 | "epoch:1, step:590, 
cmpm_loss:2.185\n", 442 | "epoch:1, step:600, cmpm_loss:4.268\n", 443 | "epoch:1, step:610, cmpm_loss:2.276\n", 444 | "Train done for epoch-1\n", 445 | "lr:0.0002\n", 446 | "epoch:2, step:0, cmpm_loss:5.452\n", 447 | "epoch:2, step:10, cmpm_loss:2.122\n", 448 | "epoch:2, step:20, cmpm_loss:10.590\n", 449 | "Traceback (most recent call last):\n", 450 | " File \"/content/Image_Text_Retrieval/deep_cmpl_model/code/train.py\", line 230, in \n", 451 | " main(args)\n", 452 | " File \"/content/Image_Text_Retrieval/deep_cmpl_model/code/train.py\", line 193, in main\n", 453 | " train_loss, train_time = train(args.start_epoch + epoch, train_loader, network, optimizer, compute_loss, args)\n", 454 | " File \"/content/Image_Text_Retrieval/deep_cmpl_model/code/train.py\", line 51, in train\n", 455 | " cmpm_loss.backward()\n", 456 | " File \"/usr/local/lib/python3.7/dist-packages/torch/tensor.py\", line 245, in backward\n", 457 | " torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)\n", 458 | " File \"/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py\", line 147, in backward\n", 459 | " allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n", 460 | "KeyboardInterrupt\n" 461 | ], 462 | "name": "stdout" 463 | } 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "metadata": { 469 | "id": "qHY6hk3oFhu-", 470 | "colab": { 471 | "base_uri": "https://localhost:8080/" 472 | }, 473 | "outputId": "c7752301-3279-4969-fd58-74707300c2e6" 474 | }, 475 | "source": [ 476 | "%cd /content/\n", 477 | "!python3 /content/Image_Text_Retrieval/deep_cmpl_model/code/scripts/tester.py" 478 | ], 479 | "execution_count": 23, 480 | "outputs": [ 481 | { 482 | "output_type": "stream", 483 | "text": [ 484 | "/content\n", 485 | "Loaded pretrained weights for efficientnet-b0\n", 486 | "==> Loading checkpoint \"drive/Shareddrives/Image-Text-Retrieval/tempckpt/data/model_data/lr-0.0002-decay-0.9-batch-16/0.pth.tar\"\n", 487 | "Loaded pretrained weights for efficientnet-b0\n", 488 | "==> Loading checkpoint \"drive/Shareddrives/Image-Text-Retrieval/tempckpt/data/model_data/lr-0.0002-decay-0.9-batch-16/1.pth.tar\"\n", 489 | "t2i_top1_best: 18.034, t2i_top5_best: 51.097, t2i_top10_best: 68.725, t2i_mr_best: 0.406\n", 490 | "i2t_top1: 17.872, i2t_top5: 49.228, i2t_top10: 67.100, i2t_mr_best: 0.487\n" 491 | ], 492 | "name": "stdout" 493 | } 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "metadata": { 499 | "id": "tsntUW3G_Os3" 500 | }, 501 | "source": [ 502 | "drive.flush_and_unmount()" 503 | ], 504 | "execution_count": null, 505 | "outputs": [] 506 | } 507 | ] 508 | } --------------------------------------------------------------------------------