├── .gitignore
├── Corrnet
│   ├── Corrnet_Pytorch.ipynb
│   ├── README.md
│   └── word2vec.ipynb
├── Demo_Web_App
│   ├── README.md
│   ├── config.py
│   ├── directory.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── bi_lstm.cpython-37.pyc
│   │   │   ├── bi_lstm.cpython-38.pyc
│   │   │   ├── mobilenet.cpython-37.pyc
│   │   │   ├── mobilenet.cpython-38.pyc
│   │   │   ├── model.cpython-37.pyc
│   │   │   ├── model.cpython-38.pyc
│   │   │   ├── resnet.cpython-37.pyc
│   │   │   └── resnet.cpython-38.pyc
│   │   ├── bi_lstm.py
│   │   ├── mobilenet.py
│   │   ├── model.py
│   │   └── resnet.py
│   ├── saved_model
│   │   ├── test_sort.pkl
│   │   └── word_to_index.pkl
│   ├── static
│   │   └── temp
│   │       └── 1620480640.jpg
│   ├── templates
│   │   ├── image_results.html
│   │   ├── index.html
│   │   └── text_results.html
│   ├── test_config.py
│   ├── tester.py
│   └── web_app.py
├── README.md
├── Siamese
│   ├── README.md
│   ├── Word2vecgensim.ipynb
│   └── siamese_network.ipynb
├── assets
│   ├── correlational.png
│   ├── deepcmpl.png
│   ├── i2t.PNG
│   ├── siamese.png
│   └── t2i.PNG
└── deep_cmpl_model
    ├── .gitignore
    ├── README.md
    ├── code
    │   ├── __pycache__
    │   │   ├── config.cpython-37.pyc
    │   │   ├── config.cpython-38.pyc
    │   │   ├── test.cpython-37.pyc
    │   │   ├── test.cpython-38.pyc
    │   │   ├── test_config.cpython-37.pyc
    │   │   ├── test_config.cpython-38.pyc
    │   │   ├── test_params.cpython-37.pyc
    │   │   ├── train_config.cpython-37.pyc
    │   │   ├── train_config.cpython-38.pyc
    │   │   └── train_params.cpython-37.pyc
    │   ├── datasets
    │   │   ├── __pycache__
    │   │   │   ├── directory.cpython-37.pyc
    │   │   │   ├── fashion.cpython-37.pyc
    │   │   │   ├── pedes.cpython-37.pyc
    │   │   │   └── pedes.cpython-38.pyc
    │   │   ├── data.sh
    │   │   ├── directory.py
    │   │   ├── fashion.py
    │   │   └── preprocess.py
    │   ├── models
    │   │   ├── __pycache__
    │   │   │   ├── bi_lstm.cpython-37.pyc
    │   │   │   ├── bi_lstm.cpython-38.pyc
    │   │   │   ├── eff_net.cpython-37.pyc
    │   │   │   ├── efficient_net.cpython-37.pyc
    │   │   │   ├── mobilenet.cpython-37.pyc
    │   │   │   ├── mobilenet.cpython-38.pyc
    │   │   │   ├── model.cpython-37.pyc
    │   │   │   ├── model.cpython-38.pyc
    │   │   │   ├── resnet.cpython-37.pyc
    │   │   │   └── resnet.cpython-38.pyc
    │   │   ├── bi_lstm.py
    │   │   ├── eff_net.py
    │   │   ├── mobilenet.py
    │   │   ├── model.py
    │   │   └── resnet.py
    │   ├── scripts
    │   │   ├── tester.py
    │   │   └── trainer.py
    │   ├── test.py
    │   ├── test_params.py
    │   ├── train.py
    │   ├── train_params.py
    │   └── utils
    │       ├── __pycache__
    │       │   ├── directory.cpython-37.pyc
    │       │   ├── directory.cpython-38.pyc
    │       │   ├── helpers.cpython-37.pyc
    │       │   ├── metric.cpython-37.pyc
    │       │   └── metric.cpython-38.pyc
    │       ├── directory.py
    │       ├── helpers.py
    │       └── metric.py
    ├── data
    │   ├── images.csv
    │   ├── make_json.py
    │   ├── processed_data
    │   │   ├── metadata_info.txt
    │   │   ├── test_reid.json
    │   │   ├── test_sort.pkl
    │   │   ├── train_reid.json
    │   │   ├── train_sort.pkl
    │   │   ├── val_reid.json
    │   │   ├── val_sort.pkl
    │   │   ├── word_counts.txt
    │   │   └── word_to_index.pkl
    │   └── reid_raw.json
    └── deep_cmpl_jupyter.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Images of the dataset
2 | dataset/
3 |
4 | # Checkpoint directories
5 | model_data
6 | logs
7 |
8 | # Zip files
9 | *.zip
10 |
11 | # Miscellaneous
12 | data/f8k
13 | drive/
14 | dgpu/drive/
15 | dgpu/data/logs
16 |
17 | *.pth.tar
18 | Demo_Web_App/static/dataset/
19 |
--------------------------------------------------------------------------------
/Corrnet/README.md:
--------------------------------------------------------------------------------
1 | # Corrnet
2 |
3 | > How to load, train and test this model is shown in the Jupyter notebooks. The code is also well documented, detailing the use of every function.
4 |
5 | ## Instructions to Run Code
6 | 1. Generate the text CSV file from your text data using the word2vec.ipynb notebook.
7 | 2. Add the images to the appropriate folder.
8 | 3. Set all the paths in Corrnet_Pytorch.ipynb.
9 | 4. Run the Corrnet_Pytorch.ipynb notebook.
--------------------------------------------------------------------------------
/Demo_Web_App/README.md:
--------------------------------------------------------------------------------
1 | # Image-Text-Retrieval-Web-App
2 | Flask Web App for ES-654 Machine Learning course project
3 |
4 |
5 | ## Instructions to Run
6 |
7 | To run the web app, follow the instructions in the following Colab Notebook: [link](https://colab.research.google.com/drive/1Obk2DQpHAyQANObUnth_gWuxi-CR5lMn?usp=sharing)
8 |
9 | ## Requirements
10 |
11 | * Flask
12 | * Flask-ngrok
13 | * Werkzeug
14 | * PyTorch
15 | * imageio
16 |
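All of these are available from PyPI; a typical installation (package names assumed here, with PyTorch's wheel named `torch`) is `pip install flask flask-ngrok Werkzeug torch imageio`.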
--------------------------------------------------------------------------------
/Demo_Web_App/config.py:
--------------------------------------------------------------------------------
1 | import torch.backends.cudnn as cudnn
2 | import torch.nn as nn
3 | import torch
4 | from models.model import Model
5 | from directory import check_file
6 | import random
7 | import numpy as np
8 |
9 |
10 | def network_config(args, split='train', param=None, resume=False, model_path=None, ema=False):
11 | network = Model(args)
12 | network = nn.DataParallel(network).cuda()
13 | cudnn.benchmark = True
14 | args.start_epoch = 0
15 |
16 | # process network params
17 | if resume:
18 | check_file(model_path, 'model_file')
19 | checkpoint = torch.load(model_path)
20 | args.start_epoch = checkpoint['epoch'] + 1
21 | # best_prec1 = checkpoint['best_prec1']
22 | #network.load_state_dict(checkpoint['state_dict'])
23 | network_dict = checkpoint['network']
24 | # if ema:
25 | # logging.info('==> EMA Loading')
26 | # network_dict.update(checkpoint['network_ema'])
27 | network.load_state_dict(network_dict)
28 | print('==> Loading checkpoint "{}"'.format(model_path))
29 | else:
30 | # pretrained
31 | if model_path is not None:
32 | print('==> Loading from pretrained models')
33 | network_dict = network.state_dict()
34 | if args.image_model == 'mobilenet_v1':
35 | cnn_pretrained = torch.load(model_path)['state_dict']
36 | start = 7
37 | else:
38 | cnn_pretrained = torch.load(model_path)
39 | start = 0
40 | # process keyword of pretrained model
41 | prefix = 'module.image_model.'
42 | pretrained_dict = {prefix + k[start:] :v for k,v in cnn_pretrained.items()}
43 | pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in network_dict}
44 | network_dict.update(pretrained_dict)
45 | network.load_state_dict(network_dict)
46 |
47 | # process optimizer params
48 | if split == 'test':
49 | optimizer = None
50 | else:
51 | # optimizer
52 | # different params for different part
53 | cnn_params = list(map(id, network.module.image_model.parameters()))
54 | other_params = filter(lambda p: id(p) not in cnn_params, network.parameters())
55 | other_params = list(other_params)
56 | if param is not None:
57 | other_params.extend(list(param))
58 | param_groups = [{'params':other_params},
59 | {'params':network.module.image_model.parameters(), 'weight_decay':args.wd}]
60 | optimizer = torch.optim.Adam(
61 | param_groups,
62 | lr = args.lr, betas=(args.adam_alpha, args.adam_beta), eps=args.epsilon)
63 | if resume:
64 | optimizer.load_state_dict(checkpoint['optimizer'])
65 |
66 | print('Total params: %.2fM' % (sum(p.numel() for p in network.parameters()) / 1000000.0))
67 | # seed
68 | manualSeed = random.randint(1, 10000)
69 | random.seed(manualSeed)
70 | np.random.seed(manualSeed)
71 | torch.manual_seed(manualSeed)
72 | torch.cuda.manual_seed_all(manualSeed)
73 |
74 | return network, optimizer
75 |
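For reference, `Demo_Web_App/web_app.py` restores a trained checkpoint through this function as follows (condensed from its `load_model` function):

```python
from test_config import config      # CLI arguments defined in test_config.py
from config import network_config

args = config()
# Build the model in test mode and resume from the saved checkpoint
network, _ = network_config(args, 'test', None, True, 'saved_model/299.pth.tar', False)
network.eval()
```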
--------------------------------------------------------------------------------
/Demo_Web_App/directory.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | def makedir(root):
5 | if not os.path.exists(root):
6 | os.makedirs(root)
7 |
8 |
9 | def write_json(data, root):
10 | with open(root, 'w') as f:
11 | json.dump(data, f)
12 |
13 |
14 | def check_exists(root):
15 | if os.path.exists(root):
16 | return True
17 | return False
18 |
19 | def check_file(root, keyword):
20 | if not os.path.isfile(root):
21 | raise RuntimeError('===> No {} in {}'.format(keyword, root))
22 |
--------------------------------------------------------------------------------
/Demo_Web_App/models/__pycache__/bi_lstm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/bi_lstm.cpython-37.pyc
--------------------------------------------------------------------------------
/Demo_Web_App/models/__pycache__/bi_lstm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/bi_lstm.cpython-38.pyc
--------------------------------------------------------------------------------
/Demo_Web_App/models/__pycache__/mobilenet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/mobilenet.cpython-37.pyc
--------------------------------------------------------------------------------
/Demo_Web_App/models/__pycache__/mobilenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/mobilenet.cpython-38.pyc
--------------------------------------------------------------------------------
/Demo_Web_App/models/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/Demo_Web_App/models/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/Demo_Web_App/models/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/Demo_Web_App/models/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/models/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/Demo_Web_App/models/bi_lstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random
4 |
5 | seed_num = 223
6 | torch.manual_seed(seed_num)
7 | random.seed(seed_num)
8 |
9 | """
10 | Neural network model: bidirectional LSTM (built from two unidirectional LSTM passes)
11 | """
12 |
13 |
14 | class BiLSTM(nn.Module):
15 | def __init__(self, args):
16 | super(BiLSTM, self).__init__()
17 |
18 | self.hidden_dim = args.num_lstm_units
19 |
20 | V = args.vocab_size
21 | D = args.embedding_size
22 |
23 | # word embedding
24 | self.embed = nn.Embedding(V, D, padding_idx=0)
25 |
26 | self.bilstm = nn.ModuleList()
27 | self.bilstm.append(nn.LSTM(D, args.num_lstm_units, num_layers=1, dropout=0, bidirectional=False, bias=False))
28 |
29 | self.bidirectional = args.bidirectional
30 | if self.bidirectional:
31 | self.bilstm.append(nn.LSTM(D, args.num_lstm_units, num_layers=1, dropout=0, bidirectional=False, bias=False))
32 |
33 |
34 | def forward(self, text, text_length):
35 | embed = self.embed(text)
36 |
37 | # unidirectional lstm
38 | bilstm_out = self.bilstm_out(embed, text_length, 0)
39 |
40 | if self.bidirectional:
41 | index_reverse = list(range(embed.shape[0]-1, -1, -1))
42 | index_reverse = torch.LongTensor(index_reverse).cuda()
43 | embed_reverse = embed.index_select(0, index_reverse)
44 | text_length_reverse = text_length.index_select(0, index_reverse)
45 | bilstm_out_bidirection = self.bilstm_out(embed_reverse, text_length_reverse, 1)
46 | bilstm_out_bidirection_reverse = bilstm_out_bidirection.index_select(0, index_reverse)
47 | bilstm_out = torch.cat([bilstm_out, bilstm_out_bidirection_reverse], dim=2)
48 | bilstm_out, _ = torch.max(bilstm_out, dim=1)
49 | bilstm_out = bilstm_out.unsqueeze(2).unsqueeze(2)
50 | return bilstm_out
51 |
52 |
53 | def bilstm_out(self, embed, text_length, index):
54 |
55 | _, idx_sort = torch.sort(text_length, dim=0, descending=True)
56 | _, idx_unsort = torch.sort(idx_sort, dim=0)
57 |
58 | embed_sort = embed.index_select(0, idx_sort)
59 | length_list = text_length[idx_sort]
60 | pack = nn.utils.rnn.pack_padded_sequence(embed_sort, length_list.cpu(), batch_first=True)
61 |
62 | bilstm_sort_out, _ = self.bilstm[index](pack)
63 | bilstm_sort_out = nn.utils.rnn.pad_packed_sequence(bilstm_sort_out, batch_first=True)
64 | bilstm_sort_out = bilstm_sort_out[0]
65 |
66 | bilstm_out = bilstm_sort_out.index_select(0, idx_unsort)
67 |
68 | return bilstm_out
69 |
70 |
71 | def weight_init(self, m):
72 | if isinstance(m, nn.Conv2d):
73 | nn.init.xavier_uniform_(m.weight.data, 1)
74 | nn.init.constant(m.bias.data, 0)
75 |
--------------------------------------------------------------------------------
/Demo_Web_App/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 | """
5 | Adapted from https://github.com/marvis/pytorch-mobilenet/blob/master/main.py
6 | """
7 |
8 |
9 | class MobileNetV1(nn.Module):
10 | def __init__(self, dropout_keep_prob=0.999):
11 | super(MobileNetV1, self).__init__()
12 | self.dropout_keep_prob = dropout_keep_prob
13 | self.dropout = nn.Dropout(1 - dropout_keep_prob)
14 | def conv_bn(inp, oup, stride):
15 | return nn.Sequential(
16 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
17 | nn.BatchNorm2d(oup),
18 | nn.ReLU6(inplace=True)
19 | )
20 |
21 | def conv_dw(inp, oup, stride):
22 | return nn.Sequential(
23 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
24 | nn.BatchNorm2d(inp),
25 | nn.ReLU6(inplace=True),
26 |
27 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
28 | nn.BatchNorm2d(oup),
29 | nn.ReLU6(inplace=True),
30 | )
31 |
32 | self.model = nn.Sequential(
33 | conv_bn(3, 32, 2),
34 | conv_dw(32, 64, 1),
35 | conv_dw(64, 128, 2),
36 | conv_dw(128, 128, 1),
37 | conv_dw(128, 256, 2),
38 | conv_dw(256, 256, 1),
39 | conv_dw(256, 512, 2),
40 | conv_dw(512, 512, 1),
41 | conv_dw(512, 512, 1),
42 | conv_dw(512, 512, 1),
43 | conv_dw(512, 512, 1),
44 | conv_dw(512, 512, 1),
45 | conv_dw(512, 1024, 2),
46 | conv_dw(1024, 1024, 1),
47 | nn.AvgPool2d(7),
48 | )
49 |
50 |
51 | def weight_init(self, m):
52 | if isinstance(m, nn.Conv2d):
53 | # truncated_normal_initializer in tensorflow
54 | nn.init.normal_(m.weight.data, std=0.09)
55 | #nn.init.constant(m.bias.data, 0)
56 |
57 | def forward(self, x):
58 | x = self.model(x)
59 | x = self.dropout(x)
60 | return x
61 |
--------------------------------------------------------------------------------
/Demo_Web_App/models/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .bi_lstm import BiLSTM
3 | from .mobilenet import MobileNetV1
4 | from .resnet import resnet50, resnet101
5 |
6 |
7 | class Model(nn.Module):
8 | def __init__(self, args):
9 | super(Model, self).__init__()
10 | if args.image_model == 'mobilenet_v1':
11 | self.image_model = MobileNetV1()
12 | self.image_model.apply(self.image_model.weight_init)
13 | elif args.image_model == 'resnet50':
14 | self.image_model = resnet50()
15 | elif args.image_model == 'resnet101':
16 | self.image_model = resnet101()
17 |
18 | self.bilstm = BiLSTM(args)
19 | self.bilstm.apply(self.bilstm.weight_init)
20 |
21 | inp_size = 1024
22 | if args.image_model == 'resnet50' or args.image_model == 'resnet101':
23 | inp_size = 2048
24 | # shorten the tensor using 1*1 conv
25 | self.conv_images = nn.Conv2d(inp_size, args.feature_size, 1)
26 | self.conv_text = nn.Conv2d(1024, args.feature_size, 1)
27 |
28 |
29 | def forward(self, images, text, text_length):
30 | image_features = self.image_model(images)
31 | text_features = self.bilstm(text, text_length)
32 | image_embeddings, text_embeddings= self.build_joint_embeddings(image_features, text_features)
33 |
34 | return image_embeddings, text_embeddings
35 |
36 |
37 | def build_joint_embeddings(self, images_features, text_features):
38 |
39 | #images_features = images_features.permute(0,2,3,1)
40 | #text_features = text_features.permute(0,3,1,2)
41 | image_embeddings = self.conv_images(images_features).squeeze()
42 | text_embeddings = self.conv_text(text_features).squeeze()
43 |
44 | return image_embeddings, text_embeddings
45 |
--------------------------------------------------------------------------------
/Demo_Web_App/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 |
5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
6 | 'resnet152']
7 |
8 |
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | }
16 |
17 |
18 | def conv3x3(in_planes, out_planes, stride=1):
19 | """3x3 convolution with padding"""
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
21 | padding=1, bias=False)
22 |
23 |
24 | def conv1x1(in_planes, out_planes, stride=1):
25 | """1x1 convolution"""
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
27 |
28 |
29 | class BasicBlock(nn.Module):
30 | expansion = 1
31 |
32 | def __init__(self, inplanes, planes, stride=1, downsample=None):
33 | super(BasicBlock, self).__init__()
34 | self.conv1 = conv3x3(inplanes, planes, stride)
35 | self.bn1 = nn.BatchNorm2d(planes)
36 | self.relu = nn.ReLU(inplace=True)
37 | self.conv2 = conv3x3(planes, planes)
38 | self.bn2 = nn.BatchNorm2d(planes)
39 | self.downsample = downsample
40 | self.stride = stride
41 |
42 | def forward(self, x):
43 | residual = x
44 |
45 | out = self.conv1(x)
46 | out = self.bn1(out)
47 | out = self.relu(out)
48 |
49 | out = self.conv2(out)
50 | out = self.bn2(out)
51 |
52 | if self.downsample is not None:
53 | residual = self.downsample(x)
54 |
55 | out += residual
56 | out = self.relu(out)
57 |
58 | return out
59 |
60 |
61 | class Bottleneck(nn.Module):
62 | expansion = 4
63 |
64 | def __init__(self, inplanes, planes, stride=1, downsample=None):
65 | super(Bottleneck, self).__init__()
66 | self.conv1 = conv1x1(inplanes, planes)
67 | self.bn1 = nn.BatchNorm2d(planes)
68 | self.conv2 = conv3x3(planes, planes, stride)
69 | self.bn2 = nn.BatchNorm2d(planes)
70 | self.conv3 = conv1x1(planes, planes * self.expansion)
71 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
72 | self.relu = nn.ReLU(inplace=True)
73 | self.downsample = downsample
74 | self.stride = stride
75 |
76 | def forward(self, x):
77 | residual = x
78 |
79 | out = self.conv1(x)
80 | out = self.bn1(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv2(out)
84 | out = self.bn2(out)
85 | out = self.relu(out)
86 |
87 | out = self.conv3(out)
88 | out = self.bn3(out)
89 |
90 | if self.downsample is not None:
91 | residual = self.downsample(x)
92 |
93 | out += residual
94 | out = self.relu(out)
95 |
96 | return out
97 |
98 |
99 | class ResNet(nn.Module):
100 |
101 | def __init__(self, block, layers, num_classes=1000):
102 | super(ResNet, self).__init__()
103 | self.inplanes = 64
104 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
105 | bias=False)
106 | self.bn1 = nn.BatchNorm2d(64)
107 | self.relu = nn.ReLU(inplace=True)
108 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
109 | self.layer1 = self._make_layer(block, 64, layers[0])
110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
112 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
113 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
114 | self.fc = nn.Linear(512 * block.expansion, num_classes)
115 |
116 | for m in self.modules():
117 | if isinstance(m, nn.Conv2d):
118 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
119 | elif isinstance(m, nn.BatchNorm2d):
120 | nn.init.constant_(m.weight, 1)
121 | nn.init.constant_(m.bias, 0)
122 |
123 | def _make_layer(self, block, planes, blocks, stride=1):
124 | downsample = None
125 | if stride != 1 or self.inplanes != planes * block.expansion:
126 | downsample = nn.Sequential(
127 | conv1x1(self.inplanes, planes * block.expansion, stride),
128 | nn.BatchNorm2d(planes * block.expansion),
129 | )
130 |
131 | layers = []
132 | layers.append(block(self.inplanes, planes, stride, downsample))
133 | self.inplanes = planes * block.expansion
134 | for _ in range(1, blocks):
135 | layers.append(block(self.inplanes, planes))
136 |
137 | return nn.Sequential(*layers)
138 |
139 | def forward(self, x):
140 | x = self.conv1(x)
141 | x = self.bn1(x)
142 | x = self.relu(x)
143 | x = self.maxpool(x)
144 |
145 | x = self.layer1(x)
146 | x = self.layer2(x)
147 | x = self.layer3(x)
148 | x = self.layer4(x)
149 |
150 | x = self.avgpool(x)
151 | #x = x.view(x.size(0), -1)
152 | #x = self.fc(x)
153 |
154 | return x
155 |
156 |
157 | def resnet18(pretrained=False, **kwargs):
158 | """Constructs a ResNet-18 model.
159 |
160 | Args:
161 | pretrained (bool): If True, returns a model pre-trained on ImageNet
162 | """
163 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
164 | if pretrained:
165 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
166 | return model
167 |
168 |
169 | def resnet34(pretrained=False, **kwargs):
170 | """Constructs a ResNet-34 model.
171 |
172 | Args:
173 | pretrained (bool): If True, returns a model pre-trained on ImageNet
174 | """
175 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
176 | if pretrained:
177 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
178 | return model
179 |
180 |
181 | def resnet50(pretrained=False, **kwargs):
182 | """Constructs a ResNet-50 model.
183 |
184 | Args:
185 | pretrained (bool): If True, returns a model pre-trained on ImageNet
186 | """
187 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
188 | if pretrained:
189 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
190 | return model
191 |
192 |
193 | def resnet101(pretrained=False, **kwargs):
194 | """Constructs a ResNet-101 model.
195 |
196 | Args:
197 | pretrained (bool): If True, returns a model pre-trained on ImageNet
198 | """
199 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
200 | if pretrained:
201 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
202 | return model
203 |
204 |
205 | def resnet152(pretrained=False, **kwargs):
206 | """Constructs a ResNet-152 model.
207 |
208 | Args:
209 | pretrained (bool): If True, returns a model pre-trained on ImageNet
210 | """
211 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
212 | if pretrained:
213 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
214 | return model
215 |
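Note that `forward` returns the pooled convolutional feature map rather than class logits (the flatten and `fc` steps are commented out), which is what `Model` in `model.py` expects before its 1x1 convolution. A quick shape check, assuming it is run from the `Demo_Web_App` directory with a 224x224 input:

```python
import torch
from models.resnet import resnet50

net = resnet50()
net.eval()
with torch.no_grad():
    feats = net(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 2048, 1, 1]) - fed to the 1x1 conv in model.py
```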
--------------------------------------------------------------------------------
/Demo_Web_App/saved_model/test_sort.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/saved_model/test_sort.pkl
--------------------------------------------------------------------------------
/Demo_Web_App/saved_model/word_to_index.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/saved_model/word_to_index.pkl
--------------------------------------------------------------------------------
/Demo_Web_App/static/temp/1620480640.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/Demo_Web_App/static/temp/1620480640.jpg
--------------------------------------------------------------------------------
/Demo_Web_App/templates/image_results.html:
--------------------------------------------------------------------------------
[HTML markup was stripped from this dump. The template renders a page titled "Cross-modal learning for Fashion Retrieval" with the heading "Closest images": the query description followed by a grid of the 20 retrieved images (cells labelled 1-20), populated from the `data` list passed in by web_app.py.]
--------------------------------------------------------------------------------
/Demo_Web_App/templates/index.html:
--------------------------------------------------------------------------------
[HTML markup was stripped from this dump. The template renders the landing page "Cross-modal learning for Fashion Retrieval" with the prompt "Query using an image or a text description" and the note "If you search using an image, we will retrieve the most similar text descriptions and if you search using a text description, we will retrieve the most similar images", followed by the query form (fields `textquery` and `imagequery`, as read by web_app.py).]
--------------------------------------------------------------------------------
/Demo_Web_App/templates/text_results.html:
--------------------------------------------------------------------------------
[HTML markup was stripped from this dump. The template renders a page titled "Cross-modal learning for Fashion Retrieval" with the heading "Closest descriptions": the query image followed by the 20 retrieved descriptions, inserted via the Jinja placeholders {{data[1]}} through {{data[20]}}.]
--------------------------------------------------------------------------------
/Demo_Web_App/test_config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 |
4 |
5 | def parse_args():
6 | parser = argparse.ArgumentParser(description='command for evaluation on CUHK-PEDES')
7 | # Directory
8 | parser.add_argument('--image_dir', type=str, help='directory to store dataset')
9 | parser.add_argument('--anno_dir', type=str, help='directory to store anno')
10 | parser.add_argument('--model_path', type=str, help='directory to load checkpoint')
11 | parser.add_argument('--log_dir', type=str, help='directory to store log')
12 |
13 | # LSTM setting
14 | parser.add_argument('--embedding_size', type=int, default=512)
15 | parser.add_argument('--num_lstm_units', type=int, default=512)
16 | parser.add_argument('--vocab_size', type=int, default=12000)
17 | parser.add_argument('--lstm_dropout_ratio', type=float, default=0.7)
18 | parser.add_argument('--bidirectional', action='store_true')
19 |
20 | parser.add_argument('--max_length', type=int, default=100)
21 | parser.add_argument('--feature_size', type=int, default=512)
22 |
23 | parser.add_argument('--image_model', type=str, default='mobilenet_v1')
24 | parser.add_argument('--cnn_dropout_keep', type=float, default=0.999)
25 |
26 | parser.add_argument('--epoch_ema', type=int, default=0)
27 |
28 | # Default setting
29 | parser.add_argument('--gpus', type=str, default='0')
30 | args = parser.parse_args()
31 | return args
32 |
33 |
34 |
35 | def config():
36 | args = parse_args()
37 | return args
38 |
--------------------------------------------------------------------------------
/Demo_Web_App/tester.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | GPUS='0'
4 | os.environ['CUDA_VISIBLE_DEVICES'] = GPUS  # set in this process (inherited by os.system below); a shelled-out 'export' would not persist
5 |
6 | BASE_ROOT=''
7 | IMAGE_DIR='/content'
8 | ANNO_DIR=BASE_ROOT+'/data/processed_data'
9 | CKPT_DIR=BASE_ROOT+'/data/model_data'
10 | LOG_DIR=BASE_ROOT+'/data/logs'
11 | IMAGE_MODEL='mobilenet_v1'
12 | lr='0.0002'
13 | batch_size='16'
14 | lr_decay_ratio='0.9'
15 | epoches_decay='80_150_200'
16 |
17 |
18 | string = 'python3 {BASE_ROOT}web_app.py --bidirectional --model_path {CKPT_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size} --image_model {IMAGE_MODEL} --log_dir {LOG_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size} --image_dir {IMAGE_DIR} --anno_dir {ANNO_DIR} --gpus {GPUS} --epoch_ema 0'.format(BASE_ROOT=BASE_ROOT, IMAGE_DIR=IMAGE_DIR, IMAGE_MODEL=IMAGE_MODEL, ANNO_DIR=ANNO_DIR, CKPT_DIR=CKPT_DIR, LOG_DIR=LOG_DIR, lr=lr, batch_size=batch_size, lr_decay_ratio=lr_decay_ratio, GPUS=GPUS)
19 |
20 | os.system(string)
21 |
22 |
--------------------------------------------------------------------------------
/Demo_Web_App/web_app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from PIL import Image
4 | import shutil
5 | import pickle
6 | import string
7 | import time
8 | from imageio import imread
9 | import numpy as np
10 | import torch
11 | import torchvision.transforms as transforms
12 | from test_config import config
13 | from config import network_config
14 | from werkzeug.utils import secure_filename
15 | from flask import Flask, render_template, flash, request, redirect, url_for
16 | from flask_ngrok import run_with_ngrok
17 |
18 | app = Flask(__name__)
19 | app.debug = True
20 | run_with_ngrok(app)
21 | network = None
22 | model_path = 'saved_model/299.pth.tar'
23 | test_sort_path = 'saved_model/test_sort.pkl'
24 | word_to_index_path = 'saved_model/word_to_index.pkl'
25 | test_sort = None
26 | word_to_index = None
27 | test_transform = transforms.Compose([
28 | transforms.ToTensor(),
29 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
30 | ])
31 |
32 | def load_model(args):
33 | network, _ = network_config(args,'test', None, True, model_path, False)
34 | network.eval()
35 | with open(test_sort_path, 'rb') as f_pkl:
36 | test_sort = pickle.load(f_pkl)
37 | with open(word_to_index_path, 'rb') as f_2_pkl:
38 | word_to_index = pickle.load(f_2_pkl)
39 | word_to_index = {k.lower(): v for k, v in word_to_index.items()}
40 | return network, test_sort, word_to_index
41 |
42 | def index_to_word(indexed_caption):
43 | index_to_word_dict = {value:key for key, value in word_to_index.items()}
44 |
45 | caption = []
46 | for token in indexed_caption:
47 | if token in index_to_word_dict:
48 | caption.append(index_to_word_dict[token])
49 | else:
50 | caption.append(' ')
51 |
52 | return ' '.join(caption[1:-1])
53 |
54 | def retrieve_captions():
55 | img = imread('static/temp/'+os.listdir('static/temp/')[-1])
56 | img = np.array(Image.fromarray(img).resize(size=(224,224)))
57 | images = test_transform(img)
58 | images = torch.reshape(images, (1, images.shape[0], images.shape[1], images.shape[2]))
59 |
60 | captions = test_sort['caption_id']
61 | caption_lengths = torch.tensor([len(c) for c in captions])
62 | captions = torch.tensor([c + [0]*(100-len(c)) for c in captions])
63 |
64 | with torch.no_grad():
65 | image_embeddings, text_embeddings = network(images, captions, caption_lengths)
66 |
67 | sim = torch.matmul(text_embeddings, image_embeddings.t())
68 | ind = sim.topk(20, 0)[1].reshape(-1)
69 |
70 | retrieved_captions = ['temp/'+os.listdir('static/temp/')[-1]]
71 | for i in ind:
72 | retrieved_captions.append(index_to_word(test_sort['caption_id'][i]))
73 |
74 | return retrieved_captions
75 |
76 | def retrieve_images(caption):
77 | exclude = set(string.punctuation)
78 | cap = ''.join(c for c in caption if c not in exclude)
79 | tokens = cap.split()
80 | tokens = [''] + tokens + ['']
81 | indexed_caption = []
82 |
83 | for token in tokens:
84 | if token.lower() in word_to_index:
85 | indexed_caption.append(word_to_index[token.lower()])
86 | else:
87 | indexed_caption.append(0)
88 |
89 | caption_lengths = torch.tensor([len(indexed_caption)])
90 | indexed_caption += [0] * (100 - len(indexed_caption))
91 | captions = torch.tensor([indexed_caption])
92 |
93 | paths = test_sort['images_path']
94 | images = []
95 | for img_path in paths:
96 | img = imread('static/'+img_path)
97 | img = np.array(Image.fromarray(img).resize(size=(224,224)))
98 | images.append(test_transform(img))
99 | images = torch.stack(images)
100 |
101 | with torch.no_grad():
102 | image_embeddings, text_embeddings = network(images, captions, caption_lengths)
103 |
104 | sim = torch.matmul(image_embeddings, text_embeddings.t())
105 | ind = sim.topk(20, 0)[1].reshape(-1)
106 |
107 | retrieved_images = [caption]
108 | for i in ind:
109 | retrieved_images.append(paths[i])
110 |
111 | return retrieved_images
112 |
113 | def get_query(request):
114 | try:
115 | text = request.form['textquery']
116 | except:
117 | text = None
118 |
119 | try:
120 | image = request.files['imagequery']
121 | for fname in os.listdir('static/temp'):
122 | os.remove('static/temp/'+fname)
123 | image.save('static/temp/'+str(int(time.time()))+'.jpg')
124 | except:
125 | image = None
126 |
127 | if text is None:
128 | return (image, 'image')
129 | else:
130 | return (text, 'text')
131 |
132 |
133 | @app.route('/')
134 | def index():
135 | return render_template('index.html')
136 |
137 | @app.route('/', methods=['POST'])
138 | def predict_from_location():
139 | query, query_type = get_query(request)
140 | if query_type == 'text':
141 | retrieved_images = retrieve_images(query)
142 | return render_template('image_results.html', data=retrieved_images)
143 | else:
144 | retrieved_captions = retrieve_captions()
145 | return render_template('text_results.html', data=retrieved_captions)
146 |
147 | if __name__ == '__main__':
148 | print('Parsing arguments...')
149 | args = config()
150 | print('Loading model weights...')
151 | network, test_sort, word_to_index = load_model(args)
152 | app.run()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image_Text_Retrieval
2 |
3 | ## Motivation
4 | Cross-modal learning involves information obtained from more than one modality. The fashion clothing industry is one such field, where product retrieval based on multiple modalities such as image and text has become important. In the online fashion industry, being able to search for a product that matches either an image query or a text query is in high demand.
5 |
6 | ## Our Work
7 | In this work, we implement different cross-modal learning schemes, namely the Siamese Network, the Correlational Network and the Deep Cross-Modal Projection Learning model, and study their performance. We also propose a modified Deep Cross-Modal Projection Learning model that uses a different image feature extractor. We evaluate the models' performance on image-text retrieval on a fashion clothing dataset.
8 |
9 | ## Instructions to run the code
10 |
11 | The repository contains 3 folders, each of which contains the source code for one of the model architectures we experimented with. Specifically, these are:
12 | * Deep CMPL model
13 | * Siamese Network
14 | * Correlational Network
15 |
16 | Each folder contains a dedicated README detailing the instructions to run the source code for that model. The source code is well commented and readable.
17 |
18 |
19 |
20 | # Theoretical Details
21 |
22 | ## Model Architectures
23 |
24 | ### Siamese Network
25 | A Siamese Network is a neural network architecture that contains two or more identical sub-networks with the same weights and parameters. It is commonly used to measure the similarity of inputs by comparing their feature vector outputs. We implemented a two-branch neural network inspired by the Siamese Network architecture and used a contrastive loss function for our task.
26 |
27 |
28 |
29 | #### *Network Architecture*
30 | ![Siamese network architecture](assets/siamese.png)
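For reference, the distance and loss used in `Siamese/siamese_network.ipynb` boil down to the following (margin of 1 as in the notebook; shown here with the `tensorflow.keras` backend import, whereas the notebook imports standalone Keras):

```python
from tensorflow.keras import backend as K

def euclidean_distance(vects):
    # Distance between the 512-d text embedding and the 512-d image embedding
    x, y = vects
    return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))

def contrastive_loss(y_true, y_pred, margin=1.0):
    # y_true = 1 for a matching image-text pair, 0 otherwise; y_pred is the distance
    return K.mean(y_true * K.square(y_pred) +
                  (1 - y_true) * K.square(K.maximum(margin - y_pred, 0)))
```

Matching pairs are pulled towards zero distance, while non-matching pairs are pushed out beyond the margin.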
31 |
32 |
33 |
34 | ### Correlational Network
35 | The Correlational Network is an autoencoder-based approach that explicitly maximises the correlation between the image and text embedded vectors, in addition to minimising the error of reconstructing the two views (image and text). This model also has two branches, one for images and one for text, but at the same time it also has an encoder and a decoder.
36 |
37 |
38 |
39 | #### *Network Architecture*
40 | ![Correlational network architecture](assets/correlational.png)
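A minimal PyTorch sketch of this objective is given below. This is an illustrative simplification, not the implementation in `Corrnet/Corrnet_Pytorch.ipynb`: the layer sizes, the single linear encoder/decoder per view, and the correlation weight `lam` are all placeholder assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CorrNet(nn.Module):
    def __init__(self, img_dim=2048, txt_dim=512, common_dim=256):
        super().__init__()
        self.enc_img, self.enc_txt = nn.Linear(img_dim, common_dim), nn.Linear(txt_dim, common_dim)
        self.dec_img, self.dec_txt = nn.Linear(common_dim, img_dim), nn.Linear(common_dim, txt_dim)

    def forward(self, img, txt):
        h_img = torch.sigmoid(self.enc_img(img))   # image-view embedding
        h_txt = torch.sigmoid(self.enc_txt(txt))   # text-view embedding
        return h_img, h_txt, self.dec_img(h_img), self.dec_txt(h_txt)

def corrnet_loss(img, txt, h_img, h_txt, rec_img, rec_txt, lam=0.02):
    # Reconstruction error of each view plus a subtracted correlation term,
    # so minimising the loss maximises the correlation between the embeddings.
    rec = F.mse_loss(rec_img, img) + F.mse_loss(rec_txt, txt)
    hi, ht = h_img - h_img.mean(0), h_txt - h_txt.mean(0)
    corr = (hi * ht).sum() / (torch.sqrt((hi ** 2).sum() * (ht ** 2).sum()) + 1e-8)
    return rec - lam * corr
```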
41 |
42 |
43 |
44 | ### Deep CMPL Network
45 | Cross-Modal Projection Learning includes the Cross-Modal Projection Matching (CMPM) loss for learning discriminative image-text embeddings. This novel image-text matching loss minimizes the relative entropy between the projection distributions and the normalized matching distributions.
46 |
47 | #### Modified Deep CMPL
48 | We modified the Deep Cross-Modal Projection Learning model by using the EfficientNet architecture instead of MobileNet as the image feature extractor. EfficientNet is a recently proposed convolutional neural architecture which outperforms other state-of-the-art convolutional neural networks both in terms of efficiency and accuracy.
49 |
50 |
51 |
52 | #### *Network Architecture*
53 | ![Deep CMPL architecture](assets/deepcmpl.png)
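A minimal PyTorch sketch of the CMPM idea described above, in the image-to-text direction only; this is an illustrative simplification, not the exact implementation in `deep_cmpl_model/code`.

```python
import torch
import torch.nn.functional as F

def cmpm_loss(image_emb, text_emb, labels, eps=1e-8):
    """Cross-Modal Projection Matching loss (image -> text direction).

    image_emb, text_emb: (B, D) batch embeddings; labels: (B,) identity labels.
    """
    # q: normalized matching distribution (1 where image i and text j share an identity)
    match = (labels.unsqueeze(1) == labels.unsqueeze(0)).float()
    q = match / (match.sum(dim=1, keepdim=True) + eps)

    # p: projection distribution - scalar projections of the image features onto
    # the normalized text features, turned into a softmax over the batch
    text_norm = F.normalize(text_emb, dim=1)
    p = F.softmax(image_emb @ text_norm.t(), dim=1)

    # Relative entropy KL(p || q), averaged over the batch
    return (p * (torch.log(p + eps) - torch.log(q + eps))).sum(dim=1).mean()
```

The full loss adds the symmetric text-to-image term, with `labels` being the identity/product labels of the batch.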
54 |
55 |
56 |
57 | ## Experimentations
58 |
59 | We experimented with different combinations of text and image feature extractors for learning common image-text embeddings. The tested combinations include:
60 | * Experiment 1: Siamese Network
61 | * Experiment 2: Correlational Network
62 | * Experiment 3: Deep CMPL with MobileNet
63 | * Experiment 4: Deep CMPL with EfficientNet on Indian Fashion
64 | * Experiment 5: Deep CMPL with EfficientNet on DeepFashion
65 |
66 |
67 | The metrics obtained from these experiments are as follows:
68 |
69 |
70 | ![Image-to-text retrieval metrics](assets/i2t.PNG)
71 |
72 | ![Text-to-image retrieval metrics](assets/t2i.PNG)
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/Siamese/README.md:
--------------------------------------------------------------------------------
1 | # Siamese Network
2 |
3 | > How to load, train and test this model is shown in the Jupyter notebooks. The code is also well documented, detailing the use of every function.
4 |
5 | ## Instructions to Run Code
6 | 1. Generate the text CSV file from your text data using the Word2vecgensim.ipynb notebook (a minimal sketch of this step is shown after this list).
7 | 2. Add the images to the appropriate folder; comments are added to help.
8 | 3. Set all the paths in siamese_network.ipynb.
9 | 4. Run the siamese_network.ipynb notebook.
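For step 1, a minimal sketch of producing such a CSV of 512-dimensional caption vectors with gensim is shown below. The example captions, whitespace tokenisation, mean-pooling of word vectors and the gensim >= 4 API are assumptions for illustration; the actual Word2vecgensim.ipynb may differ.

```python
import numpy as np
import pandas as pd
from gensim.models import Word2Vec

# captions[i] is the text description paired with image i (placeholder data)
captions = ["red cotton shirt with full sleeves", "blue denim jeans"]
tokenised = [c.lower().split() for c in captions]

model = Word2Vec(sentences=tokenised, vector_size=512, window=5, min_count=1, workers=4)

# One 512-d vector per caption: mean of its word vectors
vectors = np.stack([np.mean([model.wv[w] for w in toks], axis=0) for toks in tokenised])
pd.DataFrame(vectors).to_csv("word2vec_gensim.csv", header=False, index=False)
```

The resulting file matches what the notebook later reads with `pd.read_csv('./word2vec_gensim.csv', header=None)`.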
--------------------------------------------------------------------------------
/Siamese/siamese_network.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "kernelspec": {
6 | "language": "python",
7 | "display_name": "Python 3",
8 | "name": "python3"
9 | },
10 | "language_info": {
11 | "name": "python",
12 | "version": "3.7.9",
13 | "mimetype": "text/x-python",
14 | "codemirror_mode": {
15 | "name": "ipython",
16 | "version": 3
17 | },
18 | "pygments_lexer": "ipython3",
19 | "nbconvert_exporter": "python",
20 | "file_extension": ".py"
21 | },
22 | "colab": {
23 | "name": "siamese-network.ipynb",
24 | "provenance": []
25 | }
26 | },
27 | "cells": [
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {
31 | "id": "t37o4I-3obHH"
32 | },
33 | "source": [
34 | "## Importing libraries, initialising global variables"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "metadata": {
40 | "_uuid": "4df1608c-a2a9-48f9-8812-af1b5cdf1b16",
41 | "_cell_guid": "b9556bde-e5b6-4643-9afe-82298daf444e",
42 | "trusted": true,
43 | "id": "7CNvDBi7obHL"
44 | },
45 | "source": [
46 | "import imageio\n",
47 | "from statistics import median\n",
48 | "from random import randint\n",
49 | "from glob import glob\n",
50 | "import pandas as pd\n",
51 | "import numpy as np\n",
52 | "from keras.layers.core import Flatten, Dropout\n",
53 | "from keras.layers import Input, Dense, Lambda, Layer\n",
54 | "from keras import backend as K\n",
55 | "from keras import applications\n",
56 | "from keras.models import Sequential, Model\n",
57 | "from keras.optimizers import RMSprop, Adam\n",
58 | "from keras.callbacks import ModelCheckpoint\n",
59 | "from tensorflow.keras.preprocessing.image import load_img\n",
60 | "from tensorflow.keras.applications.resnet import preprocess_input\n",
61 | "from tensorflow.keras.preprocessing.image import img_to_array\n",
62 | "from tensorflow.keras.applications.resnet import ResNet152\n",
63 | "\n",
64 | "# Path to folder containing images\n",
65 | "DATASET_PATH = './Images'\n",
66 | "\n",
67 | "num_samples = 12000"
68 | ],
69 | "execution_count": 1,
70 | "outputs": []
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {
75 | "id": "xPpnmyfIobHO"
76 | },
77 | "source": [
78 | "## Generator function"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "metadata": {
84 | "trusted": true,
85 | "id": "T37PuCDJobHQ"
86 | },
87 | "source": [
88 | "\n",
89 | "# data generator for neural network\n",
90 | "# forms correct and incorrect pairings of images with text descriptions and labels them as correct (1) or incorrect (0)\n",
91 | "\n",
92 | "def generator(batch_size, df):\n",
93 | " \n",
94 | " batch_img = np.zeros((batch_size, 224, 224, 3))\n",
95 | " batch_txt = np.zeros((batch_size, 512))\n",
96 | " batch_labels = np.zeros((batch_size,1))\n",
97 | " \n",
98 | " video_ids = df['image']\n",
99 | " video_txt = df['txt_enc']\n",
100 | " \n",
101 | " length = len(df) -1\n",
102 | " \n",
103 | " while True:\n",
104 | " for i in range(batch_size//2):\n",
105 | " \n",
106 | " i = i*2\n",
107 | " \n",
108 | " #correct\n",
109 | " sample = randint(0,length)\n",
110 | " file = video_ids.iloc[sample]\n",
111 | " \n",
112 | " correct_txt = video_txt.iloc[sample]\n",
113 | " \n",
114 | " im = load_img(file, target_size=(224, 224))\n",
115 | " im = img_to_array(im)\n",
116 | " im = np.expand_dims(im, axis=0)\n",
117 | " im = preprocess_input(im)\n",
118 | " \n",
119 | " batch_img[i-2] = im\n",
120 | " batch_txt[i-2] = correct_txt\n",
121 | " batch_labels[i-2] = 1\n",
122 | " \n",
123 | " #incorrect \n",
124 | " file = video_ids.iloc[randint(0,length)]\n",
125 | " \n",
126 | " im = load_img(file, target_size=(224, 224))\n",
127 | " im = img_to_array(im)\n",
128 | " im = np.expand_dims(im, axis=0)\n",
129 | " im = preprocess_input(im)\n",
130 | "\n",
131 | " batch_img[i-1] = im\n",
132 | " batch_txt[i-1] = correct_txt\n",
133 | " batch_labels[i-1] = 0\n",
134 | " \n",
135 | " yield [batch_txt, batch_img], batch_labels"
136 | ],
137 | "execution_count": 2,
138 | "outputs": []
139 | },
140 | {
141 | "cell_type": "markdown",
142 | "metadata": {
143 | "id": "9Jd2CQ73obHS"
144 | },
145 | "source": [
146 | "## Utils"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "metadata": {
152 | "trusted": true,
153 | "id": "Ye8dgWGgobHT"
154 | },
155 | "source": [
156 | "def euclidean_distance(vects):\n",
157 | " x, y = vects\n",
158 | " return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))\n",
159 | "\n",
160 | "def eucl_dist_output_shape(shapes):\n",
161 | " shape1, shape2 = shapes\n",
162 | " return (shape1[0], 1)\n",
163 | "\n",
164 | "def contrastive_loss(y_true, y_pred):\n",
165 | " margin = 1\n",
166 | " return K.mean(y_true * K.square(y_pred) + (1 - y_true) * K.square(K.maximum(margin - y_pred, 0)))\n",
167 | "\n",
168 | "def create_img_encoder(input_dim, resnet):\n",
169 | " x = Sequential()\n",
170 | " x.add(resnet)\n",
171 | " x.add(Dense(500, activation=\"relu\"))\n",
172 | " x.add(Dropout(0.5))\n",
173 | " x.add(Dense(512, activation=\"relu\"))\n",
174 | " return x\n",
175 | "\n",
176 | "def create_txt_encoder(input_dim):\n",
177 | " x = Sequential()\n",
178 | " x.add(Dense(500, input_shape = (512,), activation=\"relu\"))\n",
179 | " x.add(Dropout(0.5))\n",
180 | " x.add(Dense(512, activation=\"relu\"))\n",
181 | " return x\n",
182 | "\n",
183 | "def compute_accuracy(predictions, labels):\n",
184 | " return labels[predictions.ravel() < 0.5].mean()"
185 | ],
186 | "execution_count": 3,
187 | "outputs": []
188 | },
189 | {
190 | "cell_type": "markdown",
191 | "metadata": {
192 | "id": "3T5g0oNbobHV"
193 | },
194 | "source": [
195 | "## Initialise ResNet152"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "metadata": {
201 | "trusted": true,
202 | "colab": {
203 | "base_uri": "https://localhost:8080/"
204 | },
205 | "id": "-77a0Xr8obHV",
206 | "outputId": "7e5a9a53-9a96-4212-826e-07eaab5a6a25"
207 | },
208 | "source": [
209 | "resnet = ResNet152(include_top=True, weights='imagenet')\n",
210 | "\n",
211 | "for layer in resnet.layers:\n",
212 | " layer.trainable = False"
213 | ],
214 | "execution_count": 4,
215 | "outputs": [
216 | {
217 | "output_type": "stream",
218 | "text": [
219 | "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet152_weights_tf_dim_ordering_tf_kernels.h5\n",
220 | "242900992/242900224 [==============================] - 2s 0us/step\n"
221 | ],
222 | "name": "stdout"
223 | }
224 | ]
225 | },
226 | {
227 | "cell_type": "markdown",
228 | "metadata": {
229 | "id": "luQkh4-XobHX"
230 | },
231 | "source": [
232 | "## Creating model and loading data"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "metadata": {
238 | "trusted": true,
239 | "colab": {
240 | "base_uri": "https://localhost:8080/"
241 | },
242 | "id": "x92fCwLmobHX",
243 | "outputId": "57462219-c1f8-4477-bbfb-885fba5c4d2f"
244 | },
245 | "source": [
246 | "input_txt = Input(shape=(512,))\n",
247 | "input_img = Input(shape=(224, 224, 3))\n",
248 | "\n",
249 | "txt_enc = create_txt_encoder(input_txt)\n",
250 | "img_enc = create_img_encoder(input_img, resnet)\n",
251 | "\n",
252 | "encoded_txt = txt_enc(input_txt)\n",
253 | "encoded_img = img_enc(input_img)\n",
254 | "\n",
255 | "distance = Lambda(euclidean_distance, output_shape=eucl_dist_output_shape)([encoded_txt, encoded_img])\n",
256 | "\n",
257 | "model = Model([input_txt, input_img], distance)\n",
258 | "\n",
259 | "adam = Adam(lr=0.00001)\n",
260 | "model.compile(loss=contrastive_loss, optimizer=adam)\n",
261 | "\n",
262 | "model.summary()\n",
263 | "\n",
264 | "\n"
265 | ],
266 | "execution_count": 5,
267 | "outputs": [
268 | {
269 | "output_type": "stream",
270 | "text": [
271 | "Model: \"model\"\n",
272 | "__________________________________________________________________________________________________\n",
273 | "Layer (type) Output Shape Param # Connected to \n",
274 | "==================================================================================================\n",
275 | "input_2 (InputLayer) [(None, 512)] 0 \n",
276 | "__________________________________________________________________________________________________\n",
277 | "input_3 (InputLayer) [(None, 224, 224, 3) 0 \n",
278 | "__________________________________________________________________________________________________\n",
279 | "sequential (Sequential) (None, 512) 513012 input_2[0][0] \n",
280 | "__________________________________________________________________________________________________\n",
281 | "sequential_1 (Sequential) (None, 512) 61176956 input_3[0][0] \n",
282 | "__________________________________________________________________________________________________\n",
283 | "lambda (Lambda) (None, 1) 0 sequential[0][0] \n",
284 | " sequential_1[0][0] \n",
285 | "==================================================================================================\n",
286 | "Total params: 61,689,968\n",
287 | "Trainable params: 1,270,024\n",
288 | "Non-trainable params: 60,419,944\n",
289 | "__________________________________________________________________________________________________\n"
290 | ],
291 | "name": "stdout"
292 | }
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "metadata": {
298 | "trusted": true,
299 | "id": "e2mh48irobHY"
300 | },
301 | "source": [
302 | "# The CSV generated by the word2vec(gensim) model\n",
303 | "data = pd.read_csv('./word2vec_gensim.csv', header=None)\n",
304 | "data = list(np.array(data))\n",
305 | "\n",
306 | "img_paths = [DATASET_PATH + str(i) + '.jpg' for i in range(12305)]\n",
307 | "\n",
308 | "dataset = pd.DataFrame()\n",
309 | "dataset['image'] = pd.Series(img_paths)\n",
310 | "dataset['txt_enc'] = pd.Series(data)\n",
311 | "\n",
312 | "df_test = dataset[num_samples:]\n",
313 | "dataset = dataset[:num_samples]\n",
314 | "\n",
315 | "df_train = dataset[:int(num_samples*0.8)]\n",
316 | "df_val = dataset[int(num_samples*0.8):]\n"
317 | ],
318 | "execution_count": null,
319 | "outputs": []
320 | },
321 | {
322 | "cell_type": "markdown",
323 | "metadata": {
324 | "id": "aES0ha1fobHa"
325 | },
326 | "source": [
327 | "## Training"
328 | ]
329 | },
330 | {
331 | "cell_type": "code",
332 | "metadata": {
333 | "trusted": true,
334 | "id": "mrlWYM3aobHb"
335 | },
336 | "source": [
337 | "model.fit_generator(generator(30, df_train), steps_per_epoch= int(int(num_samples*0.8)/30), validation_data= generator(30, df_val), validation_steps=int(int(num_samples*0.2)/30), epochs=200, verbose=1)\n",
338 | "model.save_weights('./weights.h5')"
339 | ],
340 | "execution_count": null,
341 | "outputs": []
342 | },
343 | {
344 | "cell_type": "markdown",
345 | "metadata": {
346 | "id": "n6aJV10WobHb"
347 | },
348 | "source": [
349 | "## Load saved weights"
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "metadata": {
355 | "trusted": true,
356 | "id": "wgSxvtsrobHb"
357 | },
358 | "source": [
359 | "# Load from where you stored the weights\n",
360 | "model.load_weights('./weights.h5')"
361 | ],
362 | "execution_count": null,
363 | "outputs": []
364 | },
365 | {
366 | "cell_type": "markdown",
367 | "metadata": {
368 | "id": "2SLZVox5obHd"
369 | },
370 | "source": [
371 | "## Decide size of test set"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "metadata": {
377 | "trusted": true,
378 | "id": "mZgSAMXxobHd"
379 | },
380 | "source": [
381 | "subset_size = 300\n",
382 | "subset = df_test.iloc[:subset_size]"
383 | ],
384 | "execution_count": null,
385 | "outputs": []
386 | },
387 | {
388 | "cell_type": "markdown",
389 | "metadata": {
390 | "id": "IQnoQLbZobHd"
391 | },
392 | "source": [
393 | "## Metrics"
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "metadata": {
399 | "trusted": true,
400 | "id": "4J7OWSyGobHd"
401 | },
402 | "source": [
403 | "# metrics - img -> text\n",
404 | "\n",
405 | "mr = []\n",
406 | "top_1_count = 0\n",
407 | "top_5_count = 0\n",
408 | "top_10_count = 0\n",
409 | "\n",
410 | "for i in range(subset_size):\n",
411 | " file = subset['image'].iloc[i]\n",
412 | " im = load_img(file, target_size=(224, 224))\n",
413 | " im = img_to_array(im)\n",
414 | " im = np.expand_dims(im, axis=0)\n",
415 | " im = preprocess_input(im)\n",
416 | " \n",
417 | " image_array = np.zeros((subset_size, 224, 224, 3))\n",
418 | " for k in range(subset_size):\n",
419 | " image_array[k] = im\n",
420 | " \n",
421 | " txt_array = np.zeros((subset_size, 512))\n",
422 | " for j in range(subset_size):\n",
423 | " txt = subset['txt_enc'].iloc[j]\n",
424 | " txt_array[j] = txt\n",
425 | " \n",
426 | " predictions = [pred[0] for pred in model.predict([txt_array, image_array])]\n",
427 | " pred_i = predictions[i]\n",
428 | " predictions.sort()\n",
429 | " rank = predictions.index(pred_i)\n",
430 | " if rank < 10:\n",
431 | " top_10_count += 1\n",
432 | " if rank < 5:\n",
433 | " top_5_count += 1\n",
434 | " if rank < 1:\n",
435 | " top_1_count += 1\n",
436 | " mr.append(rank+1) \n",
437 | "\n",
438 | "print('Median Rank(img->txt):', median(mr)*100/subset_size, '%')\n",
439 | "print('R@1(img->txt):', top_1_count*100/subset_size, '%')\n",
440 | "print('R@5(img->txt):', top_5_count*100/subset_size, '%')\n",
441 | "print('R@10(img->txt):', top_10_count*100/subset_size, '%')"
442 | ],
443 | "execution_count": null,
444 | "outputs": []
445 | },
446 | {
447 | "cell_type": "code",
448 | "metadata": {
449 | "trusted": true,
450 | "id": "a7YkBwZoobHd",
451 | "outputId": "e20587e6-d8d5-45af-8e8d-7678adef6643"
452 | },
453 | "source": [
454 | "# metrics - txt -> img\n",
455 | "\n",
456 | "mr = []\n",
457 | "top_1_count = 0\n",
458 | "top_5_count = 0\n",
459 | "top_10_count = 0\n",
460 | "\n",
461 | "for i in range(subset_size):\n",
462 | " txt = subset['txt_enc'].iloc[i] \n",
463 | " txt_array = np.zeros((subset_size, 512))\n",
464 | " for k in range(subset_size):\n",
465 | " txt_array[k] = txt\n",
466 | " \n",
467 | " \n",
468 | " image_array = np.zeros((subset_size, 224, 224, 3))\n",
469 | " for j in range(subset_size):\n",
470 | " file = subset['image'].iloc[j]\n",
471 | " im = load_img(file, target_size=(224, 224))\n",
472 | " im = img_to_array(im)\n",
473 | " im = np.expand_dims(im, axis=0)\n",
474 | " im = preprocess_input(im)\n",
475 |        image_array[j] = im\n",
476 | " \n",
477 | " predictions = [pred[0] for pred in model.predict([txt_array, image_array])]\n",
478 | " pred_i = predictions[i]\n",
479 | " predictions.sort()\n",
480 | " rank = predictions.index(pred_i)\n",
481 | " if rank < 10:\n",
482 | " top_10_count += 1\n",
483 | " if rank < 5:\n",
484 | " top_5_count += 1\n",
485 | " if rank < 1:\n",
486 | " top_1_count += 1\n",
487 | " mr.append(rank+1) \n",
488 | "\n",
489 | "print('Median Rank(txt->img):', median(mr)*100/subset_size, '%')\n",
490 | "print('R@1(txt->img):', top_1_count*100/subset_size, '%')\n",
491 | "print('R@5(txt->img):', top_5_count*100/subset_size, '%')\n",
492 | "print('R@10(txt->img):', top_10_count*100/subset_size, '%')"
493 | ],
494 | "execution_count": null,
495 | "outputs": [
496 | {
497 | "output_type": "stream",
498 | "text": [
499 | "Median Rank(txt->img): 0.6666666666666666 %\n",
500 | "R@1(txt->img): 33.666666666666664 %\n",
501 | "R@5(txt->img): 95.33333333333333 %\n",
502 | "R@10(txt->img): 95.33333333333333 %\n"
503 | ],
504 | "name": "stdout"
505 | }
506 | ]
507 | },
508 | {
509 | "cell_type": "markdown",
510 | "metadata": {
511 | "id": "z8LTe6sGobHf"
512 | },
513 | "source": [
514 | "## Download Weights"
515 | ]
516 | },
517 | {
518 | "cell_type": "code",
519 | "metadata": {
520 | "trusted": true,
521 | "id": "osTWoctwobHf"
522 | },
523 | "source": [
524 | "# download weights\n",
525 | "\n",
526 | "from IPython.display import FileLink\n",
527 | "\n",
528 | "FileLink(r'./weights.h5')"
529 | ],
530 | "execution_count": null,
531 | "outputs": []
532 | },
533 | {
534 | "cell_type": "markdown",
535 | "metadata": {
536 | "id": "XObkIosAobHf"
537 | },
538 | "source": [
539 | "## Try predicting"
540 | ]
541 | },
542 | {
543 | "cell_type": "code",
544 | "metadata": {
545 | "trusted": true,
546 | "id": "s7pceQ4dobHg"
547 | },
548 | "source": [
549 | "# trying out predict\n",
550 | "\n",
551 | "text = np.zeros((2, 512))\n",
552 | "image = np.zeros((2, 224, 224, 3))\n",
553 | "\n",
554 | "file = dataset['image'].iloc[21] \n",
555 | "correct_txt = dataset['txt_enc'].iloc[21]\n",
556 | "\n",
557 | "im = load_img(file, target_size=(224, 224))\n",
558 | "im = img_to_array(im)\n",
559 | "im = np.expand_dims(im, axis=0)\n",
560 | "im = preprocess_input(im)\n",
561 | "\n",
562 | "image[0] = im\n",
563 | "\n",
564 | "text[0] = correct_txt\n",
565 | "\n",
566 | "file = dataset['image'].iloc[21] \n",
567 | "mismatched_txt = dataset['txt_enc'].iloc[90] # deliberately mismatched caption\n",
568 | "\n",
569 | "im = load_img(file, target_size=(224, 224))\n",
570 | "im = img_to_array(im)\n",
571 | "im = np.expand_dims(im, axis=0)\n",
572 | "im = preprocess_input(im)\n",
573 | "\n",
574 | "image[1] = im\n",
575 | "\n",
576 | "text[1] = mismatched_txt\n",
577 | "\n",
578 | "model.predict([text, image])"
579 | ],
580 | "execution_count": null,
581 | "outputs": []
582 | }
583 | ]
584 | }
--------------------------------------------------------------------------------
/assets/correlational.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/assets/correlational.png
--------------------------------------------------------------------------------
/assets/deepcmpl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/assets/deepcmpl.png
--------------------------------------------------------------------------------
/assets/i2t.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/assets/i2t.PNG
--------------------------------------------------------------------------------
/assets/siamese.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/assets/siamese.png
--------------------------------------------------------------------------------
/assets/t2i.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/assets/t2i.PNG
--------------------------------------------------------------------------------
/deep_cmpl_model/.gitignore:
--------------------------------------------------------------------------------
1 | # Images of the dataset
2 | dataset/
3 | 
4 | # Checkpoint directories
5 | model_data
6 | logs
7 | drive/
8 | dgpu/drive/
9 | dgpu/data/logs
10 | 
11 | # Zip files
12 | *.zip
13 | 
14 | # Miscellaneous
15 | data/f8k
16 | 
--------------------------------------------------------------------------------
/deep_cmpl_model/README.md:
--------------------------------------------------------------------------------
1 | # Deep Cross-Modal Projection Matching Model
2 |
3 | > The way to load, train and test this model has been shown in the deep_cmpl_jupyter.ipynb file. The code has also been well documented detailing the use of every function.
4 |
5 |
6 |
7 |
8 | ## Directory Structure
9 | 1. The `code` folder contains the source code for the model.
10 | * The image and text encoders are stored in the `models` directory. These include MobileNet, EfficientNet and ResNet for images, and a BiLSTM for text.
11 | * The `scripts` folder contains the scripts for training and testing the model.
12 | * The `utils` folder contains metric calculation functions and other helper functions.
13 | * The `datasets` folder contains `data.sh`, which preprocesses the dataset and stores it in pickled form. The `fashion.py` file contains the class for creating the dataset object.
14 | 2. The `data` folder contains dataset-related files.
15 | * The `processed_data` folder contains the pickled dataset.
16 | * `images.csv` stores the image paths and their corresponding captions.
17 | * `reid_raw.json` is the file accepted as input by `preprocess.py`.
18 |
19 |
20 |
21 |
22 | ## Instructions to run
23 | 1. Set the appropriate path for storing checkpoints in the `trainer.py` and `tester.py` scripts. Specifically, change the `BASE_ROOT_2` variable in these files to a Google Drive folder to which you have access.
24 | 2. The sample code in the jupyter notebook loads the Indian Fashion Dataset. For training the model on some other dataset, you need to change `make_json.py`, `images.csv` and `preprocess.py` appropriately.
25 | 3. Execute the instructions in the notebook (a minimal sketch of the underlying script calls is shown below).
26 |
27 |
28 |
29 |
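30 | For reference, here is a minimal sketch of what the notebook cells boil down to. The paths are assumptions taken from `data.sh` and `scripts/trainer.py`, which expect the repository to be available as `/content/Image_Text_Retrieval` (e.g. on Colab); adjust them for your setup.
31 | 
32 | ```python
33 | import os
34 | 
35 | REPO = '/content/Image_Text_Retrieval/deep_cmpl_model'  # assumed clone location
36 | 
37 | # 1. Preprocess the raw JSON and images into pickled splits (written to data/processed_data)
38 | os.system('bash {}/code/datasets/data.sh'.format(REPO))
39 | 
40 | # 2. Train; checkpoints are written under BASE_ROOT_2, which is set inside trainer.py
41 | os.system('python3 {}/code/scripts/trainer.py'.format(REPO))
42 | 
43 | # 3. Evaluate the saved checkpoints on the test split
44 | os.system('python3 {}/code/scripts/tester.py'.format(REPO))
45 | ```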
--------------------------------------------------------------------------------
/deep_cmpl_model/code/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/__pycache__/test.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/test.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/__pycache__/test.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/test.cpython-38.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/__pycache__/test_config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/test_config.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/__pycache__/test_config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/test_config.cpython-38.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/__pycache__/test_params.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/test_params.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/__pycache__/train_config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/train_config.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/__pycache__/train_config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/train_config.cpython-38.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/__pycache__/train_params.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/__pycache__/train_params.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/datasets/__pycache__/directory.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/datasets/__pycache__/directory.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/datasets/__pycache__/fashion.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/datasets/__pycache__/fashion.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/datasets/__pycache__/pedes.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/datasets/__pycache__/pedes.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/datasets/__pycache__/pedes.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/datasets/__pycache__/pedes.cpython-38.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/datasets/data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | BASE_ROOT=Image_Text_Retrieval/deep_cmpl_model
4 | IMAGE_ROOT=/content/dataset
5 |
6 | JSON_ROOT=$BASE_ROOT/data/reid_raw.json
7 | OUT_ROOT=$BASE_ROOT/data/processed_data
8 |
9 | echo "Preprocessing dataset"
10 |
11 | rm -rf $OUT_ROOT
12 |
13 | python3 $BASE_ROOT/code/datasets/preprocess.py \
14 | --img_root=${IMAGE_ROOT} \
15 | --json_root=${JSON_ROOT} \
16 | --out_root=${OUT_ROOT} \
17 | --min_word_count 3 \
18 | --first
19 |
--------------------------------------------------------------------------------
/deep_cmpl_model/code/datasets/directory.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | def makedir(root):
5 | if not os.path.exists(root):
6 | os.makedirs(root)
7 |
8 |
9 | def write_json(data, root):
10 | with open(root, 'w') as f:
11 | json.dump(data, f)
12 |
--------------------------------------------------------------------------------
/deep_cmpl_model/code/datasets/fashion.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import numpy as np
3 | import os
4 | import pickle
5 | import h5py
6 | from PIL import Image
7 | from utils.directory import check_exists
8 | # from scipy.misc import imresize
9 | from PIL import Image
10 | from imageio import imread
11 |
12 |
13 | """
14 | Creates the dataset object
15 | """
16 |
17 | class Fashion(data.Dataset):
18 | '''
19 | Args:
20 | root (string): Base root directory of dataset where [split].pkl and [split].h5 exists
21 | split (string): 'train', 'val' or 'test'
22 | transform (callable, optional): A function/transform that takes in a PIL image
23 | and returns a transformed version. E.g., ``transforms.RandomCrop``
24 | target_transform (callable, optional): A function/transform that takes in the
25 | target and transforms it.
26 | '''
27 | pklname_list = ['train_sort.pkl', 'val_sort.pkl', 'test_sort.pkl']
28 | h5name_list = ['train.h5', 'val.h5', 'test.h5']
29 |
30 | def __init__(self, image_root, anno_root, split, max_length, transform=None, target_transform=None, cap_transform=None):
31 |
32 | self.image_root = image_root
33 | self.anno_root = anno_root
34 | self.max_length = max_length
35 | self.transform = transform
36 | self.target_transform = target_transform
37 | self.cap_transform = cap_transform
38 | self.split = split.lower()
39 |
40 | if not check_exists(self.image_root):
41 | raise RuntimeError('Dataset not found or corrupted. ' +
42 | 'Please follow the directions to generate datasets')
43 |
44 | if self.split == 'train':
45 | self.pklname = self.pklname_list[0]
46 | #self.h5name = self.h5name_list[0]
47 |
48 | with open(os.path.join(self.anno_root, self.pklname), 'rb') as f_pkl:
49 | data = pickle.load(f_pkl)
50 | self.train_labels = data['labels']
51 | self.train_captions = data['caption_id']
52 | self.train_images = data['images_path']
53 | #data_h5py = h5py.File(os.path.join(self.root, self.h5name), 'r')
54 | #self.train_images = data_h5py['images']
55 |
56 |
57 | elif self.split == 'val':
58 | self.pklname = self.pklname_list[1]
59 | #self.h5name = self.h5name_list[1]
60 | with open(os.path.join(self.anno_root, self.pklname), 'rb') as f_pkl:
61 | data = pickle.load(f_pkl)
62 | self.val_labels = data['labels']
63 | self.val_captions = data['caption_id']
64 | self.val_images = data['images_path']
65 | #data_h5py = h5py.File(os.path.join(self.root, self.h5name), 'r')
66 | #self.val_images = data_h5py['images']
67 |
68 | elif self.split == 'test':
69 | self.pklname = self.pklname_list[2]
70 | #self.h5name = self.h5name_list[2]
71 |
72 | with open(os.path.join(self.anno_root, self.pklname), 'rb') as f_pkl:
73 | data = pickle.load(f_pkl)
74 | self.test_labels = data['labels']
75 | self.test_captions = data['caption_id']
76 | self.test_images = data['images_path']
77 |
78 | #data_h5py = h5py.File(os.path.join(self.root, self.h5name), 'r')
79 | #self.test_images = data_h5py['images']
80 |
81 | else:
82 | raise RuntimeError('Wrong split; it should be one of "train", "val" or "test"')
83 |
84 | def __getitem__(self, index):
85 | """
86 | Args:
87 | index(int): Index
88 | Returns:
89 | tuple: (images, labels, captions)
90 | """
91 | if self.split == 'train':
92 | img_path, caption, label = self.train_images[index], self.train_captions[index], self.train_labels[index]
93 | elif self.split == 'val':
94 | img_path, caption, label = self.val_images[index], self.val_captions[index], self.val_labels[index]
95 | else:
96 | img_path, caption, label = self.test_images[index], self.test_captions[index], self.test_labels[index]
97 | img_path = os.path.join(self.image_root, img_path)
98 | img = imread(img_path)
99 | img=np.array(Image.fromarray(img).resize(size=(224,224)))
100 | # img = imresize(img, (224,224))
101 | if len(img.shape) == 2:
102 | img = np.dstack((img,img,img))
103 | img = Image.fromarray(img)
104 |
105 | if self.transform is not None:
106 | img = self.transform(img)
107 |
108 | if self.target_transform is not None:
109 | label = self.target_transform(label)
110 |
111 | if self.cap_transform is not None:
112 | caption = self.cap_transform(caption)
113 | caption = caption[1:-1]
114 | caption = np.array(caption)
115 | caption, mask = self.fix_length(caption)
116 | return img, caption, label, mask
117 |
118 | def fix_length(self, caption):
119 | caption_len = caption.shape[0]
120 | if caption_len < self.max_length:
121 | pad = np.zeros((self.max_length - caption_len, 1), dtype=np.int64)
122 | caption = np.append(caption, pad)
123 | return caption, caption_len
124 |
125 | def __len__(self):
126 | if self.split == 'train':
127 | return len(self.train_labels)
128 | elif self.split == 'val':
129 | return len(self.val_labels)
130 | else:
131 | return len(self.test_labels)
132 |
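133 | # Minimal usage sketch (illustrative only; the paths below are assumptions;
134 | # see get_data_loader() in code/test.py for how this class is actually wired up):
135 | #
136 | #   import torchvision.transforms as transforms
137 | #   t = transforms.Compose([transforms.ToTensor(),
138 | #                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
139 | #   test_set = Fashion('/content', 'deep_cmpl_model/data/processed_data', 'test',
140 | #                      max_length=100, transform=t)
141 | #   img, caption, label, caption_length = test_set[0]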
--------------------------------------------------------------------------------
/deep_cmpl_model/code/datasets/preprocess.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pickle
3 | import json
4 | import argparse
5 | import string
6 | import os
7 | from directory import write_json, makedir
8 | from collections import namedtuple
9 |
10 |
11 | """
12 | Preprocesses the json file
13 | """
14 |
15 | ImageMetaData = namedtuple('ImageMetaData', ['id', 'image_path', 'captions', 'split'])
16 | ImageDecodeData = namedtuple('ImageDecodeData', ['id', 'image_path', 'captions_id', 'split'])
17 |
18 |
19 | class Vocabulary(object):
20 | """
21 | Vocabulary wrapper
22 | """
23 | def __init__(self, vocab, unk_id):
24 | """
25 | :param vocab: A dictionary of word to word_id
26 | :param unk_id: Id of the bad/unknown words
27 | """
28 | self._vocab = vocab
29 | self._unk_id = unk_id
30 |
31 | def word_to_id(self, word):
32 | if word not in self._vocab:
33 | return self._unk_id
34 | return self._vocab[word]
35 |
36 |
37 | def cap2tokens(cap):
38 | exclude = set(string.punctuation)
39 | caption = ''.join(c for c in cap if c not in exclude)
40 | tokens = caption.split()
41 | tokens = add_start_end(tokens)
42 | return tokens
43 |
44 |
45 | def add_start_end(tokens, start_word='', end_word=''):
46 | """
47 | Add start and end words for a caption
48 | """
49 | tokens_processed = [start_word]
50 | tokens_processed.extend(tokens)
51 | tokens_processed.append(end_word)
52 | return tokens_processed
53 |
54 |
55 | def process_captions(imgs):
56 | for img in imgs:
57 | img['processed_tokens'] = []
58 | for s in img['captions']:
59 | tokens = cap2tokens(s)
60 | img['processed_tokens'].append(tokens)
61 |
62 |
63 | def build_vocab(imgs, args):
64 | print('start building vocabulary')
65 | counts = {}
66 | for img in imgs:
67 | for tokens in img['processed_tokens']:
68 | for word in tokens:
69 | counts[word] = counts.get(word, 0) + 1
70 | print('Total words:', len(counts))
71 |
72 | # filter uncommon words and sort by descending count.
73 | # word_counts: a list of (words, count) for words satisfying the condition.
74 | word_counts = [(w,n) for w,n in counts.items() if n >= args.min_word_count]
75 | word_counts.sort(key = lambda x : x[1], reverse=True)
76 | print('Words in vocab:', len(word_counts))
77 |
78 | # words_out: a list of (words, count) for words unsatisfying the condition.
79 | words_out = [(w,n) for w,n in counts.items() if n < args.min_word_count]
80 | bad_words = len(words_out)
81 | bad_count = sum(x[1] for x in words_out)
82 |
83 | # save the word counts file
84 | word_counts_root = os.path.join(args.out_root + '/word_counts.txt')
85 | with open(word_counts_root, 'w') as f:
86 | f.write('Total words: %d \n' % len(counts))
87 | f.write('Words in vocabulary: %d \n' % len(word_counts))
88 | f.write(str(word_counts))
89 |
90 | word_reverse = [w for (w,n) in word_counts]
91 | vocab_dict = dict([(word, index) for (index, word) in enumerate(word_reverse)])
92 | vocab = Vocabulary(vocab_dict, len(vocab_dict))
93 |
94 | # Save word index as pickle form
95 | word_to_idx = {}
96 | for index, word in enumerate(word_reverse):
97 | word_to_idx[word] = index
98 |
99 | with open(os.path.join(args.out_root, 'word_to_index.pkl'), 'wb') as f:
100 | pickle.dump(word_to_idx, f)
101 |
102 | print('number of bad words: %d/%d = %.2f%%' % (bad_words, len(counts), bad_words * 100.0 / len(counts)))
103 | print('number of words in vocab: %d/%d = %.2f%%' % (len(word_counts), len(counts), len(word_counts) * 100.0 / len(counts)))
104 | print('number of UNK occurrences: %d/%d = %.2f%%' % (bad_count, len(counts), bad_count * 100.0 / len(counts)))
105 |
106 | return vocab
107 |
108 | def load_vocab(args):
109 |
110 | with open(os.path.join(args.out_root, 'word_to_index.pkl'), 'rb') as f:
111 | word_to_idx = pickle.load(f)
112 |
113 | vocab = Vocabulary(word_to_idx, len(word_to_idx))
114 | print('load vocabulary done')
115 | return vocab
116 |
117 |
118 | def process_metadata(split, data, args):
119 | """
120 | Wrap data into ImageMetaData form
121 | """
122 | id_to_captions = {}
123 | image_metadata = []
124 | num_captions = 0
125 | count = 0
126 |
127 | for img in data:
128 | count += 1
129 | # absolute image path
130 | # filepath = os.path.join(args.img_root, img['file_path'])
131 | # relative image path
132 | filepath = img['file_path']
133 | # assert os.path.exists(filepath)
134 | id = img['id'] - 1
135 | captions = img['processed_tokens']
136 | id_to_captions.setdefault(id, [])
137 | id_to_captions[id].append(captions)
138 | assert split == img['split'], 'error: wrong split'
139 | image_metadata.append(ImageMetaData(id, filepath, captions, split))
140 | num_captions += len(captions)
141 |
142 | print("Process metadata done!")
143 | print("Total %d captions %d images %d identities in %s" % (num_captions, count, len(id_to_captions), split))
144 | with open(os.path.join(args.out_root, 'metadata_info.txt') ,'a') as f:
145 | f.write("Total %d captions %d images %d identities in %s" % (num_captions, count, len(id_to_captions), split))
146 | f.write('\n')
147 |
148 | return image_metadata
149 |
150 |
151 | def process_decodedata(data, vocab):
152 | """
153 | Decode ImageMetaData to ImageDecodeData
154 | Each item in imagedecodedata has 2 captions. (len(captions_id) = 2)
155 | """
156 | image_decodedata = []
157 | for img in data:
158 | image_path = img.image_path
159 | #image = imread(img.filepath)
160 | #image = imresize(image, (args.default_image_size, args.default_image_size))
161 | # handle grayscale input images
162 | #if len(image.shape) == 2:
163 | # image = np.dstack((image, image, image))
164 | # (height, width, channel) to (channel, height, width)
165 | # (224,224,3) to (3,224,224))
166 | #image = image.transpose(2,0,1)
167 | cap_to_vec = []
168 | for cap in img.captions:
169 | cap_to_vec.append([vocab.word_to_id(word) for word in cap])
170 | image_decodedata.append(ImageDecodeData(img.id, image_path, cap_to_vec, img.split))
171 |
172 | print('Process decodedata done!')
173 |
174 | return image_decodedata
175 |
176 |
177 | def process_dataset(split, decodedata):
178 | # Process dataset
179 |
180 | # Arrange by caption in a sorted form
181 | dataset, label_range = create_dataset_sort(split, decodedata)
182 | write_dataset(split, dataset, args, label_range)
183 |
184 |
185 | def create_dataset_sort(split, data):
186 | images_sort = []
187 | label_range = {}
188 | images = {}
189 | for img in data:
190 | label = img.id
191 | image = [ImageDecodeData(img.id, img.image_path, [caption_id], img.split) for caption_id in img.captions_id]
192 | if label in images:
193 | images[label].extend(image)
194 | label_range[label].append(len(image))
195 | else:
196 | images[label] = image
197 | label_range[label] = [len(image)]
198 |
199 | print('=========== Arrange by id=============================')
200 | index = -1
201 | for label in images.keys():
202 | # all captions arrange together
203 | images_sort.extend(images[label])
204 | # label_range is arranged according to their actual index
205 | # label_range[label] = (previous, current]
206 | start = index
207 | for index_image in range(len(label_range[label])):
208 | label_range[label][index_image] += index
209 | index = label_range[label][index_image]
210 | label_range[label].append(start)
211 |
212 | return images_sort, label_range
213 |
214 |
215 | def write_dataset(split, data, args, label_range=None):
216 | """
217 | Separate each component
218 | Write dataset into binary file
219 | """
220 | caption_id = []
221 | images_path = []
222 | labels = []
223 |
224 | for img in data:
225 | assert len(img.captions_id) == 1
226 | caption_id.append(img.captions_id[0])
227 | labels.append(img.id)
228 | images_path.append(img.image_path)
229 |
230 | #N = len(images)
231 | data = {'caption_id':caption_id, 'labels':labels, 'images_path':images_path}
232 |
233 | if label_range is not None:
234 | data['label_range'] = label_range
235 | pickle_root = os.path.join(args.out_root, split + '_sort.pkl')
236 | else:
237 | pickle_root = os.path.join(args.out_root, split + '.pkl')
238 | # Write caption_id and labels as pickle form
239 | with open(pickle_root, 'wb') as f:
240 | pickle.dump(data, f)
241 |
242 | #h5py_root = os.path.join(args.out_root, split + '.h5')
243 | #f = h5py.File(h5py_root, 'w')
244 | #f.create_dataset('images', (N, 3, args.default_image_size, args.default_image_size), data=images)
245 |
246 | print('Save dataset')
247 |
248 |
249 | def generate_split(args):
250 |
251 | with open(args.json_root,'r') as f:
252 | imgs = json.load(f)
253 | # process caption
254 | process_captions(imgs)
255 | val_data = []
256 | train_data = []
257 | test_data = []
258 | for img in imgs:
259 | if img['split'] == 'train':
260 | train_data.append(img)
261 | elif img['split'] =='val':
262 | val_data.append(img)
263 | else:
264 | test_data.append(img)
265 | write_json(train_data, os.path.join(args.out_root, 'train_reid.json'))
266 | write_json(val_data, os.path.join(args.out_root, 'val_reid.json'))
267 | write_json(test_data, os.path.join(args.out_root, 'test_reid.json'))
268 |
269 | return [train_data, val_data, test_data]
270 |
271 |
272 | def load_split(args):
273 |
274 | data = []
275 | splits = ['train', 'val', 'test']
276 | for split in splits:
277 | split_root = os.path.join(args.out_root, split + '_reid.json')
278 | with open(split_root, 'r') as f:
279 | split_data = json.load(f)
280 | data.append(split_data)
281 |
282 | print('load data done')
283 | return data
284 |
285 |
286 | def process_data(args):
287 |
288 | if args.first:
289 | train_data, val_data, test_data = generate_split(args)
290 | vocab = build_vocab(train_data, args)
291 | else:
292 | train_data, val_data, test_data = load_split(args)
293 | vocab = load_vocab(args)
294 |
295 | # Transform original data to Imagedata form.
296 | train_metadata = process_metadata('train', train_data, args)
297 | val_metadata = process_metadata('val', val_data, args)
298 | test_metadata = process_metadata('test', test_data, args)
299 |
300 |
301 | # Decode ImageMetaData to indexed captions and replace the image file_root with the image vector.
302 | train_decodedata = process_decodedata(train_metadata, vocab)
303 | val_decodedata = process_decodedata(val_metadata, vocab)
304 | test_decodedata = process_decodedata(test_metadata, vocab)
305 |
306 |
307 | process_dataset('train', train_decodedata)
308 | process_dataset('val', val_decodedata)
309 | process_dataset('test', test_decodedata)
310 |
311 |
312 | def parse_args():
313 | parser = argparse.ArgumentParser(description='Command for data preprocessing')
314 | parser.add_argument('--img_root', type=str)
315 | parser.add_argument('--json_root', type=str)
316 | parser.add_argument('--out_root',type=str)
317 | parser.add_argument('--min_word_count', type=int)
318 | parser.add_argument('--default_image_size', type=int, default=224)
319 | parser.add_argument('--first', action='store_true')
320 | args = parser.parse_args()
321 | return args
322 |
323 | if __name__ == '__main__':
324 | args = parse_args()
325 | makedir(args.out_root)
326 | process_data(args)
327 |
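328 | # Expected invocation (mirrors code/datasets/data.sh, which wraps this script):
329 | #
330 | #   python3 $BASE_ROOT/code/datasets/preprocess.py \
331 | #       --img_root=/content/dataset \
332 | #       --json_root=$BASE_ROOT/data/reid_raw.json \
333 | #       --out_root=$BASE_ROOT/data/processed_data \
334 | #       --min_word_count 3 \
335 | #       --first
336 | #
337 | #   where BASE_ROOT points at the deep_cmpl_model directory.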
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/__pycache__/bi_lstm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/bi_lstm.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/__pycache__/bi_lstm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/bi_lstm.cpython-38.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/__pycache__/eff_net.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/eff_net.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/__pycache__/efficient_net.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/efficient_net.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/__pycache__/mobilenet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/mobilenet.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/__pycache__/mobilenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/mobilenet.cpython-38.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/models/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/bi_lstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random
4 |
5 | seed_num = 223
6 | torch.manual_seed(seed_num)
7 | random.seed(seed_num)
8 |
9 | """
10 | Neural network model: Bidirectional LSTM (emulated with two unidirectional LSTM passes)
11 | """
12 |
13 |
14 | class BiLSTM(nn.Module):
15 | def __init__(self, args):
16 | super(BiLSTM, self).__init__()
17 |
18 | self.hidden_dim = args.num_lstm_units
19 |
20 | V = args.vocab_size
21 | D = args.embedding_size
22 |
23 | # word embedding
24 | self.embed = nn.Embedding(V, D, padding_idx=0)
25 |
26 | self.bilstm = nn.ModuleList()
27 | self.bilstm.append(nn.LSTM(D, args.num_lstm_units, num_layers=1, dropout=0, bidirectional=False, bias=False))
28 |
29 | self.bidirectional = args.bidirectional
30 | if self.bidirectional:
31 | self.bilstm.append(nn.LSTM(D, args.num_lstm_units, num_layers=1, dropout=0, bidirectional=False, bias=False))
32 |
33 |
34 | def forward(self, text, text_length):
35 | embed = self.embed(text)
36 |
37 | # unidirectional lstm
38 | bilstm_out = self.bilstm_out(embed, text_length, 0)
39 |
40 | if self.bidirectional:
41 | index_reverse = list(range(embed.shape[0]-1, -1, -1))
42 | index_reverse = torch.LongTensor(index_reverse).cuda()
43 | embed_reverse = embed.index_select(0, index_reverse)
44 | text_length_reverse = text_length.index_select(0, index_reverse)
45 | bilstm_out_bidirection = self.bilstm_out(embed_reverse, text_length_reverse, 1)
46 | bilstm_out_bidirection_reverse = bilstm_out_bidirection.index_select(0, index_reverse)
47 | bilstm_out = torch.cat([bilstm_out, bilstm_out_bidirection_reverse], dim=2)
48 | bilstm_out, _ = torch.max(bilstm_out, dim=1)
49 | bilstm_out = bilstm_out.unsqueeze(2).unsqueeze(2)
50 | return bilstm_out
51 |
52 |
53 | def bilstm_out(self, embed, text_length, index):
54 |
55 | _, idx_sort = torch.sort(text_length, dim=0, descending=True)
56 | _, idx_unsort = torch.sort(idx_sort, dim=0)
57 |
58 | embed_sort = embed.index_select(0, idx_sort)
59 | length_list = text_length[idx_sort]
60 | pack = nn.utils.rnn.pack_padded_sequence(embed_sort, length_list.cpu(), batch_first=True)
61 |
62 | bilstm_sort_out, _ = self.bilstm[index](pack)
63 | bilstm_sort_out = nn.utils.rnn.pad_packed_sequence(bilstm_sort_out, batch_first=True)
64 | bilstm_sort_out = bilstm_sort_out[0]
65 |
66 | bilstm_out = bilstm_sort_out.index_select(0, idx_unsort)
67 |
68 | return bilstm_out
69 |
70 |
71 | def weight_init(self, m):
72 | if isinstance(m, nn.Conv2d):
73 | nn.init.xavier_uniform_(m.weight.data, 1)
74 | nn.init.constant_(m.bias.data, 0)
75 |
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/eff_net.py:
--------------------------------------------------------------------------------
1 | import json
2 | from PIL import Image
3 | import torch.nn as nn
4 | import math
5 |
6 | import torch
7 | from torchvision import transforms
8 | from efficientnet_pytorch import EfficientNet
9 |
10 | """
11 | Efficient Net
12 | """
13 |
14 | class EffNet(nn.Module):
15 | def __init__(self):
16 | super(EffNet,self).__init__()
17 | self.main_model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=1024)
18 |
19 | def forward(self,x):
20 | x = self.main_model(x)
21 | x = x.unsqueeze(-1)
22 | x = x.unsqueeze(-1)
23 | return x
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 | """
5 | Adapted from https://github.com/marvis/pytorch-mobilenet/blob/master/main.py
6 | """
7 |
8 | """
9 | Mobilenet
10 | """
11 |
12 | class MobileNetV1(nn.Module):
13 | def __init__(self, dropout_keep_prob=0.999):
14 | super(MobileNetV1, self).__init__()
15 | self.dropout_keep_prob = dropout_keep_prob
16 | self.dropout = nn.Dropout(1 - dropout_keep_prob)
17 | def conv_bn(inp, oup, stride):
18 | return nn.Sequential(
19 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
20 | nn.BatchNorm2d(oup),
21 | nn.ReLU6(inplace=True)
22 | )
23 |
24 | def conv_dw(inp, oup, stride):
25 | return nn.Sequential(
26 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
27 | nn.BatchNorm2d(inp),
28 | nn.ReLU6(inplace=True),
29 |
30 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
31 | nn.BatchNorm2d(oup),
32 | nn.ReLU6(inplace=True),
33 | )
34 |
35 | self.model = nn.Sequential(
36 | conv_bn(3, 32, 2),
37 | conv_dw(32, 64, 1),
38 | conv_dw(64, 128, 2),
39 | conv_dw(128, 128, 1),
40 | conv_dw(128, 256, 2),
41 | conv_dw(256, 256, 1),
42 | conv_dw(256, 512, 2),
43 | conv_dw(512, 512, 1),
44 | conv_dw(512, 512, 1),
45 | conv_dw(512, 512, 1),
46 | conv_dw(512, 512, 1),
47 | conv_dw(512, 512, 1),
48 | conv_dw(512, 1024, 2),
49 | conv_dw(1024, 1024, 1),
50 | nn.AvgPool2d(7),
51 | )
52 |
53 |
54 | def weight_init(self, m):
55 | if isinstance(m, nn.Conv2d):
56 | # truncated_normal_initializer in tensorflow
57 | nn.init.normal_(m.weight.data, std=0.09)
58 | #nn.init.constant(m.bias.data, 0)
59 |
60 | def forward(self, x):
61 | x = self.model(x)
62 | x = self.dropout(x)
63 | return x
64 |
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .bi_lstm import BiLSTM
3 | from .mobilenet import MobileNetV1
4 | from .resnet import resnet50, resnet101
5 | from .eff_net import EffNet
6 |
7 | """
8 | The entire model pipeline
9 | """
10 |
11 | class Model(nn.Module):
12 | def __init__(self, args):
13 | super(Model, self).__init__()
14 | if args.image_model == 'mobilenet_v1':
15 | self.image_model = MobileNetV1()
16 | self.image_model.apply(self.image_model.weight_init)
17 | elif args.image_model == 'resnet50':
18 | self.image_model = resnet50()
19 | elif args.image_model == 'resnet101':
20 | self.image_model = resnet101()
21 | elif args.image_model == 'efficient_net':
22 | self.image_model = EffNet()
23 |
24 | # self.bilstm = BiLSTM(args.num_lstm_units, num_stacked_layers = 1, vocab_size = args.vocab_size, embedding_dim = 512)
25 | self.bilstm = BiLSTM(args)
26 | self.bilstm.apply(self.bilstm.weight_init)
27 |
28 | inp_size = 1024
29 | if args.image_model == 'resnet50' or args.image_model == 'resnet101':
30 | inp_size = 2048
31 | # shorten the tensor using 1*1 conv
32 | self.conv_images = nn.Conv2d(inp_size, args.feature_size, 1)
33 | self.conv_text = nn.Conv2d(1024, args.feature_size, 1)
34 |
35 |
36 | def forward(self, images, text, text_length):
37 | image_features = self.image_model(images)
38 | # print("Image shape", image_features.shape)
39 | text_features = self.bilstm(text, text_length)
40 | image_embeddings, text_embeddings= self.build_joint_embeddings(image_features, text_features)
41 |
42 | return image_embeddings, text_embeddings
43 |
44 |
45 | def build_joint_embeddings(self, images_features, text_features):
46 |
47 | #images_features = images_features.permute(0,2,3,1)
48 | #text_features = text_features.permute(0,3,1,2)
49 | image_embeddings = self.conv_images(images_features).squeeze()
50 | text_embeddings = self.conv_text(text_features).squeeze()
51 |
52 | return image_embeddings, text_embeddings
53 |
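54 | # Minimal usage sketch (illustrative only; the argument names and values below
55 | # mirror the defaults in code/test_params.py and are assumptions, not a tested
56 | # entry point):
57 | #
58 | #   from argparse import Namespace
59 | #   args = Namespace(image_model='efficient_net', feature_size=512,
60 | #                    num_lstm_units=512, vocab_size=12000,
61 | #                    embedding_size=512, bidirectional=True)
62 | #   model = Model(args)
63 | #   image_emb, text_emb = model(images, captions, caption_lengths)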
--------------------------------------------------------------------------------
/deep_cmpl_model/code/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 |
5 | """
6 | Resnet
7 | """
8 |
9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 | 'resnet152']
11 |
12 |
13 | model_urls = {
14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
19 | }
20 |
21 |
22 | def conv3x3(in_planes, out_planes, stride=1):
23 | """3x3 convolution with padding"""
24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
25 | padding=1, bias=False)
26 |
27 |
28 | def conv1x1(in_planes, out_planes, stride=1):
29 | """1x1 convolution"""
30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
31 |
32 |
33 | class BasicBlock(nn.Module):
34 | expansion = 1
35 |
36 | def __init__(self, inplanes, planes, stride=1, downsample=None):
37 | super(BasicBlock, self).__init__()
38 | self.conv1 = conv3x3(inplanes, planes, stride)
39 | self.bn1 = nn.BatchNorm2d(planes)
40 | self.relu = nn.ReLU(inplace=True)
41 | self.conv2 = conv3x3(planes, planes)
42 | self.bn2 = nn.BatchNorm2d(planes)
43 | self.downsample = downsample
44 | self.stride = stride
45 |
46 | def forward(self, x):
47 | residual = x
48 |
49 | out = self.conv1(x)
50 | out = self.bn1(out)
51 | out = self.relu(out)
52 |
53 | out = self.conv2(out)
54 | out = self.bn2(out)
55 |
56 | if self.downsample is not None:
57 | residual = self.downsample(x)
58 |
59 | out += residual
60 | out = self.relu(out)
61 |
62 | return out
63 |
64 |
65 | class Bottleneck(nn.Module):
66 | expansion = 4
67 |
68 | def __init__(self, inplanes, planes, stride=1, downsample=None):
69 | super(Bottleneck, self).__init__()
70 | self.conv1 = conv1x1(inplanes, planes)
71 | self.bn1 = nn.BatchNorm2d(planes)
72 | self.conv2 = conv3x3(planes, planes, stride)
73 | self.bn2 = nn.BatchNorm2d(planes)
74 | self.conv3 = conv1x1(planes, planes * self.expansion)
75 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
76 | self.relu = nn.ReLU(inplace=True)
77 | self.downsample = downsample
78 | self.stride = stride
79 |
80 | def forward(self, x):
81 | residual = x
82 |
83 | out = self.conv1(x)
84 | out = self.bn1(out)
85 | out = self.relu(out)
86 |
87 | out = self.conv2(out)
88 | out = self.bn2(out)
89 | out = self.relu(out)
90 |
91 | out = self.conv3(out)
92 | out = self.bn3(out)
93 |
94 | if self.downsample is not None:
95 | residual = self.downsample(x)
96 |
97 | out += residual
98 | out = self.relu(out)
99 |
100 | return out
101 |
102 |
103 | class ResNet(nn.Module):
104 |
105 | def __init__(self, block, layers, num_classes=1000):
106 | super(ResNet, self).__init__()
107 | self.inplanes = 64
108 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
109 | bias=False)
110 | self.bn1 = nn.BatchNorm2d(64)
111 | self.relu = nn.ReLU(inplace=True)
112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
113 | self.layer1 = self._make_layer(block, 64, layers[0])
114 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
115 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
116 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
117 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
118 | self.fc = nn.Linear(512 * block.expansion, num_classes)
119 |
120 | for m in self.modules():
121 | if isinstance(m, nn.Conv2d):
122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
123 | elif isinstance(m, nn.BatchNorm2d):
124 | nn.init.constant_(m.weight, 1)
125 | nn.init.constant_(m.bias, 0)
126 |
127 | def _make_layer(self, block, planes, blocks, stride=1):
128 | downsample = None
129 | if stride != 1 or self.inplanes != planes * block.expansion:
130 | downsample = nn.Sequential(
131 | conv1x1(self.inplanes, planes * block.expansion, stride),
132 | nn.BatchNorm2d(planes * block.expansion),
133 | )
134 |
135 | layers = []
136 | layers.append(block(self.inplanes, planes, stride, downsample))
137 | self.inplanes = planes * block.expansion
138 | for _ in range(1, blocks):
139 | layers.append(block(self.inplanes, planes))
140 |
141 | return nn.Sequential(*layers)
142 |
143 | def forward(self, x):
144 | x = self.conv1(x)
145 | x = self.bn1(x)
146 | x = self.relu(x)
147 | x = self.maxpool(x)
148 |
149 | x = self.layer1(x)
150 | x = self.layer2(x)
151 | x = self.layer3(x)
152 | x = self.layer4(x)
153 |
154 | x = self.avgpool(x)
155 | #x = x.view(x.size(0), -1)
156 | #x = self.fc(x)
157 |
158 | return x
159 |
160 |
161 | def resnet18(pretrained=False, **kwargs):
162 | """Constructs a ResNet-18 model.
163 |
164 | Args:
165 | pretrained (bool): If True, returns a model pre-trained on ImageNet
166 | """
167 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
168 | if pretrained:
169 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
170 | return model
171 |
172 |
173 | def resnet34(pretrained=False, **kwargs):
174 | """Constructs a ResNet-34 model.
175 |
176 | Args:
177 | pretrained (bool): If True, returns a model pre-trained on ImageNet
178 | """
179 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
180 | if pretrained:
181 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
182 | return model
183 |
184 |
185 | def resnet50(pretrained=False, **kwargs):
186 | """Constructs a ResNet-50 model.
187 |
188 | Args:
189 | pretrained (bool): If True, returns a model pre-trained on ImageNet
190 | """
191 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
192 | if pretrained:
193 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
194 | return model
195 |
196 |
197 | def resnet101(pretrained=False, **kwargs):
198 | """Constructs a ResNet-101 model.
199 |
200 | Args:
201 | pretrained (bool): If True, returns a model pre-trained on ImageNet
202 | """
203 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
204 | if pretrained:
205 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
206 | return model
207 |
208 |
209 | def resnet152(pretrained=False, **kwargs):
210 | """Constructs a ResNet-152 model.
211 |
212 | Args:
213 | pretrained (bool): If True, returns a model pre-trained on ImageNet
214 | """
215 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
216 | if pretrained:
217 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
218 | return model
219 |
--------------------------------------------------------------------------------
/deep_cmpl_model/code/scripts/tester.py:
--------------------------------------------------------------------------------
1 | # Used for testing the model
2 |
3 | import os
4 |
5 | GPUS='1'
6 | os.environ['CUDA_VISIBLE_DEVICES'] = GPUS  # set in-process so the spawned command below inherits it
7 |
8 | BASE_ROOT='/content/Image_Text_Retrieval/deep_cmpl_model'
9 | BASE_ROOT_2='drive/Shareddrives/Image-Text-Retrieval/tempckpt'
10 | IMAGE_DIR='/content'
11 | ANNO_DIR=BASE_ROOT+'/data/processed_data'
12 | # CKPT_DIR=BASE_ROOT+'/data/model_data'
13 | # LOG_DIR=BASE_ROOT+'/data/logs'
14 | CKPT_DIR=BASE_ROOT_2+'/data/model_data'
15 | LOG_DIR=BASE_ROOT_2+'/data/logs'
16 | IMAGE_MODEL='efficient_net'
17 | lr='0.0002'
18 | batch_size='16'
19 | lr_decay_ratio='0.9'
20 | epoches_decay='80_150_200'
21 |
22 |
23 | string = 'python3 {BASE_ROOT}/code/test.py --bidirectional --model_path {CKPT_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size} --image_model {IMAGE_MODEL} --log_dir {LOG_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size} --image_dir {IMAGE_DIR} --anno_dir {ANNO_DIR} --gpus {GPUS} --epoch_ema 0'.format(BASE_ROOT=BASE_ROOT, IMAGE_DIR=IMAGE_DIR, IMAGE_MODEL=IMAGE_MODEL, ANNO_DIR=ANNO_DIR, CKPT_DIR=CKPT_DIR, LOG_DIR=LOG_DIR, lr=lr, batch_size=batch_size, lr_decay_ratio=lr_decay_ratio, GPUS=GPUS)
24 |
25 | os.system(string)
26 |
27 |
--------------------------------------------------------------------------------
/deep_cmpl_model/code/scripts/trainer.py:
--------------------------------------------------------------------------------
1 | # Used for training the model
2 |
3 | import os
4 |
5 | GPUS = '1'
6 |
7 | os.environ['CUDA_VISIBLE_DEVICES'] = GPUS  # set in-process so the spawned command below inherits it
8 |
9 |
10 | BASE_ROOT='/content/Image_Text_Retrieval/deep_cmpl_model'
11 | BASE_ROOT_2='drive/Shareddrives/Image-Text-Retrieval/tempckpt'
12 | IMAGE_DIR='/content'
13 | ANNO_DIR=BASE_ROOT+'/data/processed_data'
14 | CKPT_DIR=BASE_ROOT_2+'/data/model_data'
15 | LOG_DIR=BASE_ROOT_2+'/data/logs'
16 | IMAGE_MODEL='efficient_net'
17 | lr='0.0002'
18 | num_epoches='300'
19 | batch_size='16'
20 | lr_decay_ratio='0.9'
21 | epoches_decay='80_150_200'
22 |
23 | models_path='{CKPT_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size}'.format(CKPT_DIR=CKPT_DIR,lr=lr,lr_decay_ratio=lr_decay_ratio,batch_size=batch_size)
24 |
25 | def get_model_path():
26 | # print(models_path)
27 | MODEL_PATH=None
28 | if os.path.exists(models_path):
29 | l=os.listdir(models_path)
30 | try:
31 | l.remove('model_best')
32 | except:
33 | pass
34 | # print(l)
35 | if len(l)==0:
36 | return None
37 | # print(len(l))
38 |
39 | l.sort(key=lambda x:int(x.split('.')[0]))
40 | if len(l)>=2:
41 | x=l[-2]
42 | else:
43 | x=l[-1]
44 | MODEL_PATH=os.path.join(models_path,x)
45 |
46 | return MODEL_PATH
47 |
48 | string = 'python3 {BASE_ROOT}/code/train.py --CMPM --bidirectional --image_model {IMAGE_MODEL} --log_dir {LOG_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size} --checkpoint_dir {CKPT_DIR}/lr-{lr}-decay-{lr_decay_ratio}-batch-{batch_size} --image_dir {IMAGE_DIR} --anno_dir {ANNO_DIR} --batch_size {batch_size} --gpus {GPUS} --num_epoches {num_epoches} --lr {lr} --lr_decay_ratio {lr_decay_ratio} --epoches_decay {epoches_decay} --num_images 12305'.format(BASE_ROOT=BASE_ROOT, IMAGE_DIR=IMAGE_DIR, IMAGE_MODEL=IMAGE_MODEL, ANNO_DIR=ANNO_DIR, CKPT_DIR=CKPT_DIR, LOG_DIR=LOG_DIR, lr=lr, num_epoches=num_epoches, batch_size=batch_size, lr_decay_ratio=lr_decay_ratio, epoches_decay=epoches_decay, GPUS=GPUS)
49 |
50 | MODEL_PATH = get_model_path()
51 | if(MODEL_PATH!=None):
52 | string += ' --resume --model_path {MODEL_PATH}'.format(MODEL_PATH=MODEL_PATH)
53 |
54 | os.system(string)
55 |
56 |
--------------------------------------------------------------------------------
/deep_cmpl_model/code/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import shutil
5 | import gc
6 | import torch
7 | import torchvision.transforms as transforms
8 | import torch.nn as nn
9 | import torch.utils.data as data
10 |
11 | from utils.helpers import avg_calculator
12 | from utils.metric import compute_topk, compute_mr
13 | from utils.directory import makedir, check_file
14 |
15 | from datasets.fashion import Fashion
16 | from models.model import Model
17 |
18 | from test_params import get_test_args
19 |
20 |
21 |
22 | """
23 | Calculates the recall rate and median rank metrics
24 | """
25 | def get_metrics(test_loader, network, args):
26 |
27 | batch_time = avg_calculator()
28 |
29 | # switch to evaluate mode
30 | network.eval()
31 |
32 | num_samples = args.batch_size * len(test_loader)
33 |
34 | images_bank = torch.zeros((num_samples, args.feature_size))
35 | text_bank = torch.zeros((num_samples,args.feature_size))
36 | labels_bank = torch.zeros(num_samples)
37 |
38 | index = 0
39 | with torch.no_grad():
40 |
41 | timer = time.time()
42 |
43 | for images, captions, labels, captions_length in test_loader:
44 |
45 | test_images = images
46 | test_captions = captions
47 |
48 | image_embeddings, text_embeddings = network(test_images, test_captions, captions_length)
49 |
50 | tsize = images.shape[0]
51 |
52 | images_bank[index: index + tsize] = image_embeddings
53 | text_bank[index: index + tsize] = text_embeddings
54 | labels_bank[index:index + tsize] = labels
55 | index+=tsize
56 |
57 | batch_time.update(time.time() - timer)
58 | timer = time.time()
59 |
60 | images_bank = images_bank[:index]
61 | text_bank = text_bank[:index]
62 | labels_bank = labels_bank[:index]
63 |
64 | i2t_top1, i2t_top5, i2t_top10, t2i_top1, t2i_top5, t2i_top10 = compute_topk(images_bank, text_bank, labels_bank, labels_bank, [1,5,10], True)
65 | i2t_mr, t2i_mr = compute_mr(images_bank, text_bank, labels_bank, labels_bank, 50, True)
66 |
67 |
68 | return i2t_top1, i2t_top5, i2t_top10, i2t_mr, t2i_top1, t2i_top5, t2i_top10, t2i_mr, batch_time.avg
69 |
70 |
71 | """
72 | Initialise the data loader
73 | """
74 | def get_data_loader(image_dir, anno_dir, batch_size, split, max_length):
75 |
76 | test_transform = transforms.Compose([
77 | transforms.ToTensor(),
78 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
79 | ])
80 |
81 | data_split = Fashion(image_dir, anno_dir, split, max_length, test_transform)
82 |
83 | loader = data.DataLoader(data_split, batch_size, shuffle=True, num_workers=2)
84 |
85 | return loader
86 |
87 |
88 | """
89 | Returns a list of model checkpoint paths
90 | """
91 | def get_test_model_paths(ckpt_path):
92 |
93 | test_models = os.listdir(ckpt_path)
94 | test_models.remove('model_best')
95 | test_models = sorted(test_models,key=lambda x: int(x.split(".")[0]))
96 | test_models = [os.path.join(ckpt_path,x) for x in test_models]
97 | # print(test_models)
98 | return test_models
99 |
100 |
101 |
102 | """
103 | Initialise the network object
104 | """
105 | def get_network(args, model_path=None):
106 |
107 | network = Model(args)
108 | network = nn.DataParallel(network)
109 | # cudnn.benchmark = True
110 | args.start_epoch = 0
111 |
112 | # process network params
113 | if model_path==None:
114 | raise ValueError('Supply the model path with --model_path while testing')
115 | check_file(model_path, 'model_file')
116 | checkpoint = torch.load(model_path)
117 | args.start_epoch = checkpoint['epoch'] + 1
118 | network_dict = checkpoint['network']
119 | network.load_state_dict(network_dict)
120 | print('==> Loading checkpoint "{}"'.format(model_path))
121 |
122 | return network
123 |
124 |
125 |
126 | """
127 | Tests the model on the test set using the checkpoints stored, prints the metrics
128 | """
129 | def main(args):
130 |
131 | test_loader = get_data_loader(args.image_dir, args.anno_dir, args.batch_size, 'test', args.max_length)
132 |
133 | i2t_top1 = 0.0
134 | i2t_top5 = 0.0
135 | i2t_top10 = 0.0
136 | i2t_mr = 0.0
137 |
138 | t2i_top1 = 0.0
139 | t2i_top5 = 0.0
140 | t2i_top10 = 0.0
141 | t2i_mr = 0.0
142 |
143 | test_models = get_test_model_paths(args.model_path)
144 | best_model_path = None
145 |
146 | for model_path in test_models:
147 |
148 | network= get_network(args, model_path)
149 |
150 | i2t_top1_cur, i2t_top5_cur, i2t_top10_cur, i2t_mr_cur, t2i_top1_cur, t2i_top5_cur, t2i_top10_cur, t2i_mr_cur, test_time = get_metrics(test_loader, network, args)
151 |
152 | if t2i_top1_cur > t2i_top1:
153 |
154 | i2t_top1 = i2t_top1_cur
155 | i2t_top5 = i2t_top5_cur
156 | i2t_top10 = i2t_top10_cur
157 | i2t_mr = i2t_mr_cur
158 |
159 | t2i_top1 = t2i_top1_cur
160 | t2i_top5 = t2i_top5_cur
161 | t2i_top10 = t2i_top10_cur
162 | t2i_mr = t2i_mr_cur
163 |
164 | best_model_path = model_path
165 | dst_best = os.path.join(args.model_path, 'model_best', 'best.pth.tar')
166 | shutil.copyfile(model_path, dst_best)
167 |
168 |
169 | # print("Best model: {}".format(best_model_path))
170 | print('t2i_top1_best: {:.3f}, t2i_top5_best: {:.3f}, t2i_top10_best: {:.3f}, t2i_mr_best: {:.3f}'.format(t2i_top1, t2i_top5, t2i_top10, t2i_mr))
171 | print('i2t_top1: {:.3f}, i2t_top5: {:.3f}, i2t_top10: {:.3f}, i2t_mr_best: {:.3f}'.format(i2t_top1, i2t_top5, i2t_top10, i2t_mr))
172 |
173 |
174 |
175 | if __name__ == '__main__':
176 | # get user's arguments
177 | args = get_test_args()
178 | main(args)
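179 | 
180 | # Expected invocation (mirrors the command built in code/scripts/tester.py;
181 | # <CKPT_DIR>, <LOG_DIR> and <BASE_ROOT> are placeholders):
182 | #
183 | #   python3 code/test.py --bidirectional --image_model efficient_net \
184 | #       --model_path <CKPT_DIR>/lr-0.0002-decay-0.9-batch-16 \
185 | #       --log_dir <LOG_DIR>/lr-0.0002-decay-0.9-batch-16 \
186 | #       --image_dir /content --anno_dir <BASE_ROOT>/data/processed_data \
187 | #       --gpus 1 --epoch_ema 0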
--------------------------------------------------------------------------------
/deep_cmpl_model/code/test_params.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | """
5 | Used for setting the parameters while testing
6 | """
7 |
8 | def get_test_args():
9 | parser = argparse.ArgumentParser(description='command for evaluating on the Fashion Dataset')
10 |
11 | # Directory
12 | parser.add_argument('--image_dir', type=str, help='directory to store dataset')
13 | parser.add_argument('--anno_dir', type=str, help='directory to store anno')
14 | parser.add_argument('--model_path', type=str, help='directory to load checkpoint')
15 | parser.add_argument('--log_dir', type=str, help='directory to store log')
16 |
17 | # LSTM setting
18 | parser.add_argument('--embedding_size', type=int, default=512)
19 | parser.add_argument('--num_lstm_units', type=int, default=512)
20 | parser.add_argument('--vocab_size', type=int, default=12000)
21 | parser.add_argument('--lstm_dropout_ratio', type=float, default=0.7)
22 | parser.add_argument('--bidirectional', action='store_true')
23 | parser.add_argument('--max_length', type=int, default=100)
24 | parser.add_argument('--feature_size', type=int, default=512)
25 |
26 | # Model Setting
27 | parser.add_argument('--image_model', type=str, default='mobilenet_v1')
28 | parser.add_argument('--cnn_dropout_keep', type=float, default=0.999)
29 | parser.add_argument('--batch_size', type=int, default=64)
30 |
31 | # Optimization Settings
32 | parser.add_argument('--epoch_ema', type=int, default=0)
33 |
34 | # Default setting
35 | parser.add_argument('--gpus', type=str, default='0')
36 | args = parser.parse_args()
37 | return args
38 |
--------------------------------------------------------------------------------
/deep_cmpl_model/code/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import time
5 | import random
6 | import numpy as np
7 | import torch
8 | import torch.utils.data as data
9 | import torch.nn as nn
10 | import torchvision.transforms as transforms
11 |
12 | from utils.metric import Loss
13 | from utils.helpers import avg_calculator
14 | from utils.directory import makedir, check_file
15 |
16 | from datasets.fashion import Fashion
17 | from models.model import Model
18 |
19 | from train_params import get_train_args
20 |
21 |
22 | """
23 | Train for one epoch, apply backpropagation, return loss after current epoch
24 | """
25 | def train(epoch, train_loader, network, optimizer, compute_loss, args):
26 |
27 | # Trains for 1 epoch
28 | batch_time = avg_calculator()
29 | train_loss = avg_calculator()
30 |
31 | #switch to train mode, req in some modules
32 | network.train()
33 |
34 | end = time.time()
35 |
36 | for step, (images, captions, labels, captions_length) in enumerate(train_loader):
37 |
38 | images = images        # no-op; move tensors to GPU here (e.g. images = images.cuda()) when training on CUDA
39 | labels = labels
40 | captions = captions
41 |
42 | # compute loss
43 | image_embeddings, text_embeddings = network(images, captions, captions_length)
44 | cmpm_loss, pos_avg_sim, neg_arg_sim = compute_loss(image_embeddings, text_embeddings, labels)
45 |
46 |
47 | if step % 10 == 0:
48 | print('epoch:{}, step:{}, cmpm_loss:{:.3f}'.format(epoch, step, cmpm_loss))
49 |
50 |
51 | # compute gradient and do ADAM step
52 | optimizer.zero_grad()
53 | cmpm_loss.backward()
54 | #nn.utils.clip_grad_norm(network.parameters(), 5)
55 | optimizer.step()
56 |
57 | # measure elapsed time
58 | batch_time.update(time.time() - end)
59 | end = time.time()
60 |
61 | train_loss.update(cmpm_loss, images.shape[0])
62 |
63 | return train_loss.avg, batch_time.avg
64 |
65 |
66 |
67 | """
68 | Initialise the data loader
69 | """
70 | def get_data_loader(image_dir, anno_dir, batch_size, split, max_length):
71 |
72 | train_transform = transforms.Compose([
73 | transforms.RandomHorizontalFlip(),
74 | transforms.ToTensor(),
75 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
76 | ])
77 |
78 | data_split = Fashion(image_dir, anno_dir, split, max_length, train_transform)
79 |
80 | loader = data.DataLoader(data_split, batch_size, shuffle=True, num_workers=2)
81 |
82 | return loader
83 |
84 |
85 | """
86 | Initialise the network object
87 | """
88 | def get_network(args, resume=False, model_path=None):
89 |
90 | network = Model(args)
91 | network = nn.DataParallel(network)
92 | # cudnn.benchmark = True
93 | args.start_epoch = 0
94 |
95 | # process network params
96 | if resume:
97 | if model_path==None:
98 | raise ValueError('Supply the model path with --model_path while using --resume')
99 | check_file(model_path, 'model_file')
100 | checkpoint = torch.load(model_path)
101 | args.start_epoch = checkpoint['epoch'] + 1
102 | network_dict = checkpoint['network']
103 | network.load_state_dict(network_dict)
104 | print('==> Loading checkpoint "{}"'.format(model_path))
105 |
106 | return network
107 |
108 |
109 | """
110 | Initialise optimizer object
111 | """
112 | def get_optimizer(args, network=None, param=None, resume=False, model_path=None):
113 |
114 | #process optimizer params
115 |
116 | # optimizer
117 | # different params for different part
118 | cnn_params = list(map(id, network.module.image_model.parameters()))
119 | other_params = filter(lambda p: id(p) not in cnn_params, network.parameters())
120 | other_params = list(other_params)
121 | if param is not None:
122 | other_params.extend(list(param))
123 | param_groups = [{'params':other_params}, {'params':network.module.image_model.parameters(), 'weight_decay':args.wd}]
124 |
125 | optimizer = torch.optim.Adam(
126 | param_groups,
127 | lr = args.lr, betas=(args.adam_alpha, args.adam_beta), eps=args.epsilon
128 | )
129 |
130 | if resume:
131 | check_file(model_path, 'model_file')
132 | checkpoint = torch.load(model_path)
133 | optimizer.load_state_dict(checkpoint['optimizer'])
134 |
135 | print('Total params: %.2fM' % (sum(p.numel() for p in network.parameters()) / 1000000.0))
136 | # seed
137 |
138 | manualSeed = random.randint(1, 10000)
139 | random.seed(manualSeed)
140 | np.random.seed(manualSeed)
141 | torch.manual_seed(manualSeed)
142 | # torch.cuda.manual_seed_all(manualSeed)
143 |
144 | return optimizer
145 |
146 |
147 | """
148 | Modify learning rate
149 | """
150 | def adjust_lr(optimizer, epoch, args):
151 |
152 | # Decay learning rate by args.lr_decay_ratio every args.epoches_decay
153 |
154 | if args.lr_decay_type == 'exponential':
155 |
156 | lr = args.lr * (1 - args.lr_decay_ratio)
157 |
158 | for param_group in optimizer.param_groups:
159 | param_group['lr'] = lr
160 |
161 |
162 | """
163 | Multistep Learning Rate Scheduler which decays learning rate after a no. of epochs specified by the user
164 | """
165 | def lr_scheduler(optimizer, args):
166 |
167 | if '_' in args.epoches_decay:
168 | epoches_list = args.epoches_decay.split('_')
169 | epoches_list = [int(e) for e in epoches_list]
170 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, epoches_list)
171 | else:
172 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(args.epoches_decay))
173 |
174 | return scheduler
175 |
176 |
177 | """
178 | Saves the checkpoints
179 | """
180 | def save_checkpoint(state, epoch, ckpt_dir, is_best):
181 |
182 | filename = os.path.join(ckpt_dir, str(args.start_epoch + epoch)) + '.pth.tar'
183 | torch.save(state, filename)
184 | if is_best:
185 | dst_best = os.path.join(ckpt_dir, 'model_best', str(epoch)) + '.pth.tar'
186 | shutil.copyfile(filename, dst_best)
187 |
188 |
189 |
190 |
191 | """
192 | Initializes data loader, loss object, network, optimizer, runs the desired no. of epochs
193 | """
194 | def main(args):
195 |
196 | train_loader = get_data_loader(args.image_dir, args.anno_dir, args.batch_size, 'train', args.max_length)
197 |
198 | # loss
199 | compute_loss = Loss(args)
200 | nn.DataParallel(compute_loss)   # note: the wrapped copy is discarded; the plain loss module is used below
201 |
202 | # network
203 | network = get_network(args, args.resume, args.model_path)
204 | optimizer = get_optimizer(args, network, compute_loss.parameters(), args.resume, args.model_path)
205 |
206 | # lr_scheduler
207 | scheduler = lr_scheduler(optimizer, args)
208 |
209 | for epoch in range(args.num_epoches - args.start_epoch):
210 | # train for one epoch
211 | train_loss, train_time = train(args.start_epoch + epoch, train_loader, network, optimizer, compute_loss, args)
212 | # evaluate on validation set
213 | print('Train done for epoch-{}'.format(args.start_epoch + epoch))
214 |
215 |
216 | state = {'network': network.state_dict(),
217 | 'optimizer': optimizer.state_dict(),
218 | 'W': compute_loss.W,
219 | 'epoch': args.start_epoch + epoch
220 | }
221 |
222 | save_checkpoint(state, epoch, args.checkpoint_dir, False)
223 |
224 | adjust_lr(optimizer, args.start_epoch + epoch, args)
225 | scheduler.step()
226 |
227 |
228 | for param in optimizer.param_groups:
229 | print('lr:{}'.format(param['lr']))
230 | break
231 |
232 |
233 |
234 |
235 | if __name__ == "__main__":
236 |
237 | # Get arguments passed by user
238 | args = get_train_args()
239 |
240 | # Validate existence of image and annotation directory
241 | if not os.path.exists(args.image_dir):
242 | raise ValueError('Supply the dataset directory with --image_dir')
243 | if not os.path.exists(args.anno_dir):
244 | raise ValueError('Supply the anno file with --anno_dir')
245 |
246 | # save checkpoint
247 | makedir(args.checkpoint_dir)
248 | makedir(os.path.join(args.checkpoint_dir,'model_best'))
249 |
250 | main(args)
--------------------------------------------------------------------------------
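A note on the checkpoint format used by train.py above: save_checkpoint() serializes a plain dict, and the --resume path in get_network(), get_optimizer() and Loss() reads the same keys back. A minimal sketch (the checkpoint file name below is hypothetical):

```python
import torch

# Checkpoints are written as <checkpoint_dir>/<epoch>.pth.tar by save_checkpoint()
checkpoint = torch.load('model_data/10.pth.tar', map_location='cpu')
print(checkpoint.keys())  # dict_keys(['network', 'optimizer', 'W', 'epoch'])

# When training is resumed with --resume --model_path <file>:
#   get_network()   loads checkpoint['network'] and sets args.start_epoch = checkpoint['epoch'] + 1
#   get_optimizer() loads checkpoint['optimizer']
#   Loss()          re-creates its projection matrix W from checkpoint['W']
```
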
/deep_cmpl_model/code/train_params.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | """
5 | Used for setting the parameters for the model, BiLSTM, etc. while training
6 | """
7 |
8 | def get_train_args():
9 | parser = argparse.ArgumentParser(description='command to train on the Fashion Dataset')
10 |
11 | # Directory
12 | parser.add_argument('--image_dir', type=str, help='directory to store dataset')
13 | parser.add_argument('--anno_dir', type=str, help='directory to store anno file')
14 | parser.add_argument('--checkpoint_dir', type=str, help='directory to store checkpoint')
15 | parser.add_argument('--log_dir', type=str, help='directory to store log')
16 | parser.add_argument('--model_path', type=str, default = None, help='directory to pretrained model, whole model or just visual part')
17 |
18 | # LSTM setting
19 | parser.add_argument('--embedding_size', type=int, default=512)
20 | parser.add_argument('--num_lstm_units', type=int, default=512)
21 | parser.add_argument('--vocab_size', type=int, default=12000)
22 | parser.add_argument('--lstm_dropout_ratio', type=float, default=0.7)
23 | parser.add_argument('--max_length', type=int, default=100)
24 | parser.add_argument('--bidirectional', action='store_true')
25 |
26 | # Model setting
27 | parser.add_argument('--image_model', type=str, default='mobilenet_v1')
28 | parser.add_argument('--resume', action='store_true', help='whether or not to restore the pretrained whole model')
29 | parser.add_argument('--batch_size', type=int, default=16)
30 | parser.add_argument('--num_epoches', type=int, default=100)
31 | parser.add_argument('--ckpt_steps', type=int, default=5000, help='#steps to save checkpoint')
32 | parser.add_argument('--feature_size', type=int, default=512)
33 | parser.add_argument('--img_model', type=str, default='mobilenet_v1', help='model to train images')
34 | parser.add_argument('--loss_weight', type=float, default=1)
35 | parser.add_argument('--CMPM', action='store_true')
36 | parser.add_argument('--cnn_dropout_keep', type=float, default=0.999)
37 | parser.add_argument('--constraints_text', action='store_true')
38 | parser.add_argument('--constraints_images', action='store_true')
39 | parser.add_argument('--num_images', type=int, default=12305)
40 | parser.add_argument('--pretrained', action='store_true', help='whether or not to restore the pretrained visual model')
41 |
42 | # Optimization setting
43 | parser.add_argument('--optimizer', type=str, default='adam', help='one of "sgd", "adam", "rmsprop", "adadelta", or "adagrad"')
44 | parser.add_argument('--lr', type=float, default=0.0002)
45 | parser.add_argument('--wd', type=float, default=0.00004)
46 | parser.add_argument('--adam_alpha', type=float, default=0.9)
47 | parser.add_argument('--adam_beta', type=float, default=0.999)
48 | parser.add_argument('--epsilon', type=float, default=1e-8)
49 | parser.add_argument('--end_lr', type=float, default=0.0001, help='minimum end learning rate used by a polynomial decay learning rate')
50 | parser.add_argument('--lr_decay_type', type=str, default='fixed', help='One of "fixed" or "exponential"')
51 | parser.add_argument('--lr_decay_ratio', type=float, default=0.1)
52 | parser.add_argument('--epoches_decay', type=str, default='50_100', help='#epoches when learning rate decays')
53 | parser.add_argument('--epoch_ema', type=int, default=0)
54 | parser.add_argument('--ema_decay', type=float, default=0.9)
55 |
56 | # Default setting
57 | parser.add_argument('--gpus', type=str, default='0')
58 |
59 | args = parser.parse_args()
60 | return args
--------------------------------------------------------------------------------
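The --epoches_decay string above (default '50_100') is consumed by lr_scheduler() in code/train.py: an underscore-separated value becomes MultiStepLR milestones, while a single number becomes a StepLR period. A minimal sketch with a stand-in model and optimizer:

```python
import torch

model = torch.nn.Linear(8, 2)                              # stand-in network
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

epoches_decay = '50_100'                                   # the default above
if '_' in epoches_decay:
    milestones = [int(e) for e in epoches_decay.split('_')]
    # learning rate is multiplied by the default gamma (0.1) at epochs 50 and 100
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones)
else:
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(epoches_decay))

for epoch in range(120):
    # ... one training epoch would run here ...
    scheduler.step()
```
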
/deep_cmpl_model/code/utils/__pycache__/directory.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/utils/__pycache__/directory.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/utils/__pycache__/directory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/utils/__pycache__/directory.cpython-38.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/utils/__pycache__/helpers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/utils/__pycache__/helpers.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/utils/__pycache__/metric.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/utils/__pycache__/metric.cpython-37.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/utils/__pycache__/metric.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/code/utils/__pycache__/metric.cpython-38.pyc
--------------------------------------------------------------------------------
/deep_cmpl_model/code/utils/directory.py:
--------------------------------------------------------------------------------
1 | # Directory related helper functions
2 |
3 | import os
4 | import json
5 |
6 | def makedir(root):
7 | if not os.path.exists(root):
8 | os.makedirs(root)
9 |
10 |
11 | def write_json(data, root):
12 | with open(root, 'w') as f:
13 | json.dump(data, f)
14 |
15 |
16 | def check_exists(root):
17 | if os.path.exists(root):
18 | return True
19 | return False
20 |
21 | def check_file(root, keyword):
22 | if not os.path.isfile(root):
23 | raise RuntimeError('===> No {} in {}'.format(keyword, root))
24 |
--------------------------------------------------------------------------------
/deep_cmpl_model/code/utils/helpers.py:
--------------------------------------------------------------------------------
1 |
2 | # Class to calculate average of a quantity over time
3 | class avg_calculator():
4 |
5 | def __init__(self):
6 | self.reset()
7 |
8 | def reset(self):
9 | self.val = 0
10 | self.avg = 0
11 | self.sum = 0
12 | self.count = 0
13 |
14 | def update(self, val, n=1):
15 | self.val = val
16 | self.sum += n * val
17 | self.count += n
18 | self.avg = self.sum / self.count
--------------------------------------------------------------------------------
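As used in train(), avg_calculator keeps a running (optionally weighted) mean: update(val, n) treats val as an average over n samples. A minimal usage sketch, assuming the class from utils/helpers.py is in scope:

```python
meter = avg_calculator()
meter.update(0.8, n=16)   # batch of 16 samples with mean loss 0.8
meter.update(0.6, n=16)   # next batch
print(meter.avg)          # 0.7 = (0.8*16 + 0.6*16) / 32
```
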
/deep_cmpl_model/code/utils/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | import numpy as np
6 | from torch.nn.parameter import Parameter
7 | from torch.autograd import Variable
8 | from statistics import median
9 |
10 |
11 | def pairwise_distance(A, B):
12 |
13 | A_square = torch.sum(A * A, dim=1, keepdim=True)
14 | B_square = torch.sum(B * B, dim=1, keepdim=True)
15 |
16 | distance = A_square + B_square.t() - 2 * torch.matmul(A, B.t())
17 |
18 | return distance
19 |
20 |
21 | def one_hot_coding(index, k):
22 | if type(index) is torch.Tensor:
23 | length = len(index)
24 | else:
25 | length = 1
26 | out = torch.zeros((length, k), dtype=torch.int64).cuda()
27 | index = index.reshape((len(index), 1))
28 | out.scatter_(1, index, 1)
29 | return out
30 |
31 |
32 |
33 | """
34 | LOSS MODULE
35 | """
36 |
37 |
38 | class Loss(nn.Module):
39 |
40 | def __init__(self, args):
41 |
42 | super(Loss, self).__init__()
43 |
44 | self.CMPM = True
45 | self.epsilon = args.epsilon
46 |
47 | self.num_images = args.num_images
48 |
49 | if args.resume:
50 | checkpoint = torch.load(args.model_path)
51 | self.W = Parameter(checkpoint['W'])
52 | print('=====> Loading weights from pretrained path')
53 | else:
54 | self.W = Parameter(torch.randn(args.feature_size, args.num_images))
55 | nn.init.xavier_uniform_(self.W.data, gain=1)
56 |
57 |
58 | # CMPM Loss
59 | def compute_cmpm_loss(self, image_embeddings, text_embeddings, labels):
60 | """
61 | Cross-Modal Projection Matching Loss(CMPM)
62 | :param image_embeddings: Tensor with dtype torch.float32
63 | :param text_embeddings: Tensor with dtype torch.float32
64 | :param labels: Tensor with dtype torch.int32
65 | :return:
66 | i2t_loss: cmpm loss for image projected to text
67 | t2i_loss: cmpm loss for text projected to image
68 | pos_avg_sim: average cosine-similarity for positive pairs
69 | neg_avg_sim: average cosine-similarity for negative pairs
70 | """
71 |
72 | batch_size = image_embeddings.shape[0]
73 | labels_reshape = torch.reshape(labels, (batch_size, 1))
74 | labels_dist = labels_reshape - labels_reshape.t()
75 | labels_mask = (labels_dist == 0)
76 |
77 | image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
78 | text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
79 | image_proj_text = torch.matmul(image_embeddings, text_norm.t())
80 | text_proj_image = torch.matmul(text_embeddings, image_norm.t())
81 |
82 | # normalize the true matching distribution
83 | labels_mask_norm = labels_mask.float() / labels_mask.float().norm(dim=1)
84 |
85 | i2t_pred = F.softmax(image_proj_text, dim=1)
86 | #i2t_loss = i2t_pred * torch.log((i2t_pred + self.epsilon)/ (labels_mask_norm + self.epsilon))
87 | i2t_loss = i2t_pred.to(device="cpu") * (F.log_softmax(image_proj_text.to(device="cpu"), dim=1) - torch.log(labels_mask_norm.to(device="cpu") + self.epsilon))
88 |
89 | t2i_pred = F.softmax(text_proj_image, dim=1)
90 | #t2i_loss = t2i_pred * torch.log((t2i_pred + self.epsilon)/ (labels_mask_norm + self.epsilon))
91 | t2i_loss = t2i_pred.to(device="cpu") * (F.log_softmax(text_proj_image.to(device="cpu"), dim=1) - torch.log(labels_mask_norm.to(device="cpu") + self.epsilon))
92 |
93 | cmpm_loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1))
94 |
95 | sim_cos = torch.matmul(image_norm, text_norm.t())
96 |
97 | pos_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask))
98 | neg_avg_sim = torch.mean(torch.masked_select(sim_cos, labels_mask == 0))
99 |
100 | return cmpm_loss, pos_avg_sim, neg_avg_sim
101 |
102 |
103 | def forward(self, image_embeddings, text_embeddings, labels):
104 |
105 | cmpm_loss = 0.0
106 | neg_avg_sim = 0.0
107 | pos_avg_sim = 0.0
108 |
109 | if self.CMPM:
110 | cmpm_loss, pos_avg_sim, neg_avg_sim = self.compute_cmpm_loss(image_embeddings, text_embeddings, labels)
111 |
112 | return cmpm_loss, pos_avg_sim, neg_avg_sim
113 |
114 |
115 | """"
116 | Recall rate and Median Rank
117 | """
118 |
119 | # Computes the recall rate
120 | def compute_topk(query, gallery, target_query, target_gallery, k=[1,5,10], reverse=False):
121 | result = []
122 | query = query / query.norm(dim=1,keepdim=True)
123 | gallery = gallery / gallery.norm(dim=1,keepdim=True)
124 | sim_cosine = torch.matmul(query, gallery.t())
125 | result.extend(topk(sim_cosine, target_gallery, target_query, k=k))
126 | if reverse:
127 | result.extend(topk(sim_cosine, target_query, target_gallery, k=k, dim=0))
128 | return result
129 |
130 |
131 | def topk(sim, target_gallery, target_query, k=[1,5,10], dim=1):
132 | result = []
133 | maxk = max(k)
134 | size_total = len(target_gallery)
135 | _, pred_index = sim.topk(maxk, dim, True, True)
136 | pred_labels = target_gallery[pred_index]
137 | if dim == 1:
138 | pred_labels = pred_labels.t()
139 | correct = pred_labels.eq(target_query.view(1,-1).expand_as(pred_labels))
140 |
141 | for topk in k:
142 | #correct_k = torch.sum(correct[:topk]).float()
143 | correct_k = torch.sum(correct[:topk], dim=0)
144 | correct_k = torch.sum(correct_k > 0).float()
145 | result.append(correct_k * 100 / size_total)
146 | return result
147 |
148 |
149 | """
150 | Computes the Median Rank
151 | """
152 | def compute_mr(query, gallery, target_query, target_gallery, k, reverse=False):
153 | result = []
154 | query = query / query.norm(dim=1,keepdim=True)
155 | gallery = gallery / gallery.norm(dim=1,keepdim=True)
156 | sim_cosine = torch.matmul(query, gallery.t())
157 | result.extend(mr(sim_cosine, target_gallery, target_query, k))
158 | if reverse:
159 | result.extend(mr(sim_cosine, target_query, target_gallery, k, dim=0))
160 | return result
161 |
162 |
163 | def mr(sim, target_gallery, target_query, k, dim=1):
164 | result = []
165 | maxk = k
166 | size_total = len(target_gallery)
167 | _, pred_index = sim.topk(maxk, dim, True, True)
168 | pred_labels = target_gallery[pred_index]
169 | if dim == 1:
170 | pred_labels = pred_labels.t()
171 | correct = pred_labels.eq(target_query.view(1,-1).expand_as(pred_labels))
172 |
173 | ranks = []
174 | for row in correct.t():
175 | temp = torch.where(row > 0)[0]
176 | if temp.shape[0] > 0:
177 | ranks.append(temp[0].item() + 1)
178 | else:
179 | ranks.append(k)
180 | # print('incr. k')
181 |
182 | result.append(median(ranks) * 100 / size_total)
183 | return result
--------------------------------------------------------------------------------
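The docstring of compute_cmpm_loss() above describes the objective informally; as implemented, each direction is a KL-style divergence between the predicted matching distribution and the row-normalized label mask. In the code's notation (B is the batch size, v_i an image embedding, t̄_j an L2-normalized text embedding, y the label mask):

$$
p_{ij} = \frac{\exp(v_i^{\top}\bar{t}_j)}{\sum_k \exp(v_i^{\top}\bar{t}_k)},\qquad
q_{ij} = \frac{y_{ij}}{\lVert y_i \rVert_2},\qquad
\mathcal{L}_{i2t} = \frac{1}{B}\sum_i\sum_j p_{ij}\bigl(\log p_{ij} - \log(q_{ij}+\epsilon)\bigr),
$$

and cmpm_loss is the sum of the i2t and t2i terms, with the text-to-image direction defined symmetrically.

For the retrieval metrics, compute_topk() returns Recall@K percentages from query/gallery embeddings and their identity labels. A minimal sketch on random tensors (shapes and labels are made up; with the code/ directory on sys.path the import matches the one used elsewhere in the repo):

```python
import torch
from utils.metric import compute_topk

num_text, num_images, dim = 12, 16, 512
text_emb  = torch.randn(num_text, dim)                 # query embeddings
image_emb = torch.randn(num_images, dim)               # gallery embeddings
text_ids  = torch.randint(0, num_images, (num_text,))  # identity label per caption
image_ids = torch.arange(num_images)                   # identity label per image

# Returns [R@1, R@5, R@10] as percentages; reverse=True appends the
# image-to-text direction computed from the same similarity matrix.
recalls = compute_topk(text_emb, image_emb, text_ids, image_ids,
                       k=[1, 5, 10], reverse=True)
print(recalls)
```
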
/deep_cmpl_model/data/make_json.py:
--------------------------------------------------------------------------------
1 | import json
2 | import csv
3 | import random
4 | import os
5 | from math import floor
6 |
7 | random.seed(10)
8 |
9 | def get_len(csv_path,header=True):
10 | # Returns no of samples in the dataset
11 | with open(csv_path,"r") as csv_file:
12 | csv_reader=csv.reader(csv_file,delimiter=',')
13 | line_count=0
14 | l=[]
15 | for row in csv_reader:
16 | l.append(row[2])
17 | line_count+=1
18 | return line_count-1
19 |
20 |
21 | def generate_split(num_samples,train_perc,val_perc):
22 | ids=list(range(num_samples))
23 | # print(len(ids))
24 | random.shuffle(ids)
25 | train_size = floor(num_samples*train_perc)
26 | val_size = floor(num_samples*val_perc)
27 | train_ids = ids[:train_size]
28 | val_ids = ids[train_size:train_size+val_size]
29 | test_ids=ids[train_size+val_size:]
30 |
31 | return train_ids,val_ids,test_ids
32 |
33 | def make_file(num_samples=None):
34 | if num_samples==None:
35 | num_samples=get_len(csv_path)
36 | print(num_samples)
37 | # num_samples = 12305
38 | # train -> 10000, test-> 2305
39 | train_ids,val_ids,test_ids = generate_split(num_samples,0.8,0.1)
40 |
41 | data_list=[]
42 | id = 1
43 | with open(csv_path,"r") as csv_file:
44 | csv_reader=csv.reader(csv_file,delimiter=',')
45 | line_no=0
46 | for row in csv_reader:
47 | line_no+=1
48 | if line_no==1:
49 | continue
50 | # if(line_no>num_samples):
51 | # break
52 | description, image_id = row[1],int(row[2])
53 | sample_dict={}
54 | if image_id in train_ids:
55 | split="train"
56 | elif image_id in val_ids:
57 | split="val"
58 | elif image_id in test_ids:
59 | split="test"
60 | else:
61 | # print("**",image_id)
62 | # raise Exception("Sample not alloted")
63 | continue
64 |
65 | sample_dict["split"] = split
66 | sample_dict["captions"] = [description]
67 | sample_dict["file_path"] = os.path.join(img_path,str(line_no-2)+".jpg")
68 | sample_dict["processed_tokens"]=[[]]
69 | sample_dict["id"]=image_id
70 | id+=1
71 |
72 | data_list.append(sample_dict)
73 |
74 | data_list = sorted(data_list, key=lambda x: x["id"])  # sort entries by id before writing
75 |
76 | with open(out_path,"w") as f:
77 | json.dump(data_list,f)
78 |
79 |
80 | if __name__=="__main__":
81 | parent_folder = "/content/Image_Text_Retrieval/deep_cmpl_model"
82 | csv_path = parent_folder + "/data/images.csv"
83 | img_path = "dataset"
84 | out_path = parent_folder + "/data/reid_raw.json"
85 | make_file()
86 |
87 |
88 |
89 | # {"split": "train",
90 | # "captions":
91 | # ["A pedestrian with dark hair is wearing red and white shoes, a black hooded sweatshirt, and black pants.",
92 | # "The person has short black hair and is wearing black pants, a long sleeve black top, and red sneakers."],
93 | # "file_path": "CUHK01/0363004.png",
94 | # "processed_tokens": [["a", "pedestrian", "with", "dark", "hair", "is", "wearing", "red", "and", "white", "shoes", "a", "black", "hooded", "sweatshirt", "and", "black", "pants"],
95 | # ["the", "person", "has", "short", "black", "hair", "and", "is", "wearing", "black", "pants", "a", "long", "sleeve", "black", "top", "and", "red", "sneakers"]],
96 | # "id": 1}
--------------------------------------------------------------------------------
/deep_cmpl_model/data/processed_data/metadata_info.txt:
--------------------------------------------------------------------------------
1 | Total 800 captions 800 images 800 identities in train
2 | Total 100 captions 100 images 100 identities in val
3 | Total 100 captions 100 images 100 identities in test
4 |
--------------------------------------------------------------------------------
/deep_cmpl_model/data/processed_data/test_reid.json:
--------------------------------------------------------------------------------
1 | [{"split": "test", "captions": ["Soch Women Navy Blue & Grey Dyed Straight Kurta"], "file_path": "dataset/4.jpg", "processed_tokens": [["", "Soch", "Women", "Navy", "Blue", "Grey", "Dyed", "Straight", "Kurta", ""]], "id": 4}, {"split": "test", "captions": ["Soch Women Off-White & Blue Printed A-Line Kurta"], "file_path": "dataset/15.jpg", "processed_tokens": [["", "Soch", "Women", "OffWhite", "Blue", "Printed", "ALine", "Kurta", ""]], "id": 15}, {"split": "test", "captions": ["Sztori Women Green Solid Semi Sheer Straight Kurta"], "file_path": "dataset/32.jpg", "processed_tokens": [["", "Sztori", "Women", "Green", "Solid", "Semi", "Sheer", "Straight", "Kurta", ""]], "id": 32}, {"split": "test", "captions": ["HERE&NOW Women Grey Solid Kurta with Palazzos"], "file_path": "dataset/33.jpg", "processed_tokens": [["", "HERENOW", "Women", "Grey", "Solid", "Kurta", "with", "Palazzos", ""]], "id": 33}, {"split": "test", "captions": ["Sangria Women Navy Blue & Golden Striped Straight Kurta"], "file_path": "dataset/35.jpg", "processed_tokens": [["", "Sangria", "Women", "Navy", "Blue", "Golden", "Striped", "Straight", "Kurta", ""]], "id": 35}, {"split": "test", "captions": ["Fabindia Women Pink & Green Printed Slim Fit A-Line Kurta"], "file_path": "dataset/37.jpg", "processed_tokens": [["", "Fabindia", "Women", "Pink", "Green", "Printed", "Slim", "Fit", "ALine", "Kurta", ""]], "id": 37}, {"split": "test", "captions": ["Libas Women Maroon Yoke Design Straight Kurta"], "file_path": "dataset/45.jpg", "processed_tokens": [["", "Libas", "Women", "Maroon", "Yoke", "Design", "Straight", "Kurta", ""]], "id": 45}, {"split": "test", "captions": ["Myshka Women Pink Yoke Design A-Line Kurta"], "file_path": "dataset/66.jpg", "processed_tokens": [["", "Myshka", "Women", "Pink", "Yoke", "Design", "ALine", "Kurta", ""]], "id": 66}, {"split": "test", "captions": ["Fabindia Women Grey & Black Slim Fit Checked A-Line Kurta"], "file_path": "dataset/73.jpg", "processed_tokens": [["", "Fabindia", "Women", "Grey", "Black", "Slim", "Fit", "Checked", "ALine", "Kurta", ""]], "id": 73}, {"split": "test", "captions": ["Soch Women Green Yoke Design Straight Kurta"], "file_path": "dataset/77.jpg", "processed_tokens": [["", "Soch", "Women", "Green", "Yoke", "Design", "Straight", "Kurta", ""]], "id": 77}, {"split": "test", "captions": ["Anubhutee Women Navy Blue & Off-White Yoke Design Kurta with Palazzos"], "file_path": "dataset/82.jpg", "processed_tokens": [["", "Anubhutee", "Women", "Navy", "Blue", "OffWhite", "Yoke", "Design", "Kurta", "with", "Palazzos", ""]], "id": 82}, {"split": "test", "captions": ["Libas Women Red & Golden Block Print Kurta with Palazzos & Dupatta"], "file_path": "dataset/115.jpg", "processed_tokens": [["", "Libas", "Women", "Red", "Golden", "Block", "Print", "Kurta", "with", "Palazzos", "Dupatta", ""]], "id": 115}, {"split": "test", "captions": ["Vishudh Women Peach-Coloured Solid Straight Kurta"], "file_path": "dataset/136.jpg", "processed_tokens": [["", "Vishudh", "Women", "PeachColoured", "Solid", "Straight", "Kurta", ""]], "id": 136}, {"split": "test", "captions": ["Soch Women Red & Gold Printed Straight Kurta"], "file_path": "dataset/137.jpg", "processed_tokens": [["", "Soch", "Women", "Red", "Gold", "Printed", "Straight", "Kurta", ""]], "id": 137}, {"split": "test", "captions": ["Anouk Women Pink Printed Kurta with Palazzos"], "file_path": "dataset/142.jpg", "processed_tokens": [["", "Anouk", "Women", "Pink", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 142}, {"split": "test", "captions": 
["Vishudh Women Black Printed A-Line Kurta"], "file_path": "dataset/153.jpg", "processed_tokens": [["", "Vishudh", "Women", "Black", "Printed", "ALine", "Kurta", ""]], "id": 153}, {"split": "test", "captions": ["Jompers Women Beige & Green Printed A-Line Kurta"], "file_path": "dataset/160.jpg", "processed_tokens": [["", "Jompers", "Women", "Beige", "Green", "Printed", "ALine", "Kurta", ""]], "id": 160}, {"split": "test", "captions": ["Vishudh Women Green Printed Straight Kurta"], "file_path": "dataset/164.jpg", "processed_tokens": [["", "Vishudh", "Women", "Green", "Printed", "Straight", "Kurta", ""]], "id": 164}, {"split": "test", "captions": ["ZIYAA Women Cream-Coloured Printed Kurta with Trousers"], "file_path": "dataset/178.jpg", "processed_tokens": [["", "ZIYAA", "Women", "CreamColoured", "Printed", "Kurta", "with", "Trousers", ""]], "id": 178}, {"split": "test", "captions": ["HERE&NOW Women Navy Blue & White Printed Straight Kurta"], "file_path": "dataset/199.jpg", "processed_tokens": [["", "HERENOW", "Women", "Navy", "Blue", "White", "Printed", "Straight", "Kurta", ""]], "id": 199}, {"split": "test", "captions": ["Anouk Women Olive Green & Mustard Brown Printed Kurta with Churidar & Dupatta"], "file_path": "dataset/211.jpg", "processed_tokens": [["", "Anouk", "Women", "Olive", "Green", "Mustard", "Brown", "Printed", "Kurta", "with", "Churidar", "Dupatta", ""]], "id": 211}, {"split": "test", "captions": ["Nayo Women Pink & Green Printed Kurta with Palazzos"], "file_path": "dataset/229.jpg", "processed_tokens": [["", "Nayo", "Women", "Pink", "Green", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 229}, {"split": "test", "captions": ["Rain & Rainbow Women Black & Mustard Yellow Printed A-Line Kurta"], "file_path": "dataset/241.jpg", "processed_tokens": [["", "Rain", "Rainbow", "Women", "Black", "Mustard", "Yellow", "Printed", "ALine", "Kurta", ""]], "id": 241}, {"split": "test", "captions": ["Vishudh Women Burgundy & Solid Kurta with Palazzos "], "file_path": "dataset/244.jpg", "processed_tokens": [["", "Vishudh", "Women", "Burgundy", "Solid", "Kurta", "with", "Palazzos", ""]], "id": 244}, {"split": "test", "captions": ["SASSAFRAS Women Navy Blue & Orange Printed A-Line Kurta"], "file_path": "dataset/245.jpg", "processed_tokens": [["", "SASSAFRAS", "Women", "Navy", "Blue", "Orange", "Printed", "ALine", "Kurta", ""]], "id": 245}, {"split": "test", "captions": ["Vishudh Women Green Embroidered Kurta with Palazzos"], "file_path": "dataset/255.jpg", "processed_tokens": [["", "Vishudh", "Women", "Green", "Embroidered", "Kurta", "with", "Palazzos", ""]], "id": 255}, {"split": "test", "captions": ["Vedic Women Brown Solid A-Line Kurta"], "file_path": "dataset/268.jpg", "processed_tokens": [["", "Vedic", "Women", "Brown", "Solid", "ALine", "Kurta", ""]], "id": 268}, {"split": "test", "captions": ["anayna Women Pink Self-Striped Kurta with Trousers & Dupatta"], "file_path": "dataset/284.jpg", "processed_tokens": [["", "anayna", "Women", "Pink", "SelfStriped", "Kurta", "with", "Trousers", "Dupatta", ""]], "id": 284}, {"split": "test", "captions": ["Vishudh Women Green Printed Kurta with Palazzos"], "file_path": "dataset/290.jpg", "processed_tokens": [["", "Vishudh", "Women", "Green", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 290}, {"split": "test", "captions": ["Jashn Women Blue & Black Checked High-Low Straight Kurta"], "file_path": "dataset/308.jpg", "processed_tokens": [["", "Jashn", "Women", "Blue", "Black", "Checked", "HighLow", "Straight", "Kurta", ""]], "id": 308}, 
{"split": "test", "captions": ["Ives Women Navy Blue & Red Printed Anarkali Kurta"], "file_path": "dataset/310.jpg", "processed_tokens": [["", "Ives", "Women", "Navy", "Blue", "Red", "Printed", "Anarkali", "Kurta", ""]], "id": 310}, {"split": "test", "captions": ["Anubhutee Women Navy Blue & White Printed Kurta with Palazzos"], "file_path": "dataset/321.jpg", "processed_tokens": [["", "Anubhutee", "Women", "Navy", "Blue", "White", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 321}, {"split": "test", "captions": ["Vishudh Women Grey & Mustard Printed Kurta with Palazzos"], "file_path": "dataset/332.jpg", "processed_tokens": [["", "Vishudh", "Women", "Grey", "Mustard", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 332}, {"split": "test", "captions": ["AKS Women Maroon Solid Anarkali Kurta"], "file_path": "dataset/335.jpg", "processed_tokens": [["", "AKS", "Women", "Maroon", "Solid", "Anarkali", "Kurta", ""]], "id": 335}, {"split": "test", "captions": ["Rain & Rainbow Women Black & Pink Printed Kurta with Churidar & Dupatta"], "file_path": "dataset/363.jpg", "processed_tokens": [["", "Rain", "Rainbow", "Women", "Black", "Pink", "Printed", "Kurta", "with", "Churidar", "Dupatta", ""]], "id": 363}, {"split": "test", "captions": ["Anouk Women Mauve Striped Straight Kurta"], "file_path": "dataset/369.jpg", "processed_tokens": [["", "Anouk", "Women", "Mauve", "Striped", "Straight", "Kurta", ""]], "id": 369}, {"split": "test", "captions": ["Ishin Women Red & White Printed Bandhani Kurta with Skirt"], "file_path": "dataset/371.jpg", "processed_tokens": [["", "Ishin", "Women", "Red", "White", "Printed", "Bandhani", "Kurta", "with", "Skirt", ""]], "id": 371}, {"split": "test", "captions": ["Ishin Women Taupe & Off-White Printed Kurta with Palazzos"], "file_path": "dataset/374.jpg", "processed_tokens": [["", "Ishin", "Women", "Taupe", "OffWhite", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 374}, {"split": "test", "captions": ["Vishudh Women Blue Solid Kurta with Palazzos"], "file_path": "dataset/383.jpg", "processed_tokens": [["", "Vishudh", "Women", "Blue", "Solid", "Kurta", "with", "Palazzos", ""]], "id": 383}, {"split": "test", "captions": ["Anouk Women White & Black Checked A-Line High-Low Kurta"], "file_path": "dataset/384.jpg", "processed_tokens": [["", "Anouk", "Women", "White", "Black", "Checked", "ALine", "HighLow", "Kurta", ""]], "id": 384}, {"split": "test", "captions": ["Fabindia Women Yellow Woven Design Straight Kurta"], "file_path": "dataset/390.jpg", "processed_tokens": [["", "Fabindia", "Women", "Yellow", "Woven", "Design", "Straight", "Kurta", ""]], "id": 390}, {"split": "test", "captions": ["Anubhutee Women Pink & Golden Printed Straight Kurta"], "file_path": "dataset/393.jpg", "processed_tokens": [["", "Anubhutee", "Women", "Pink", "Golden", "Printed", "Straight", "Kurta", ""]], "id": 393}, {"split": "test", "captions": ["Vishudh Women Navy Blue Printed A-Line Kurta"], "file_path": "dataset/422.jpg", "processed_tokens": [["", "Vishudh", "Women", "Navy", "Blue", "Printed", "ALine", "Kurta", ""]], "id": 422}, {"split": "test", "captions": ["Fabindia Women Pink Solid Straight Kurta"], "file_path": "dataset/430.jpg", "processed_tokens": [["", "Fabindia", "Women", "Pink", "Solid", "Straight", "Kurta", ""]], "id": 430}, {"split": "test", "captions": ["Vishudh Women Black Printed A-Line Kurta"], "file_path": "dataset/431.jpg", "processed_tokens": [["", "Vishudh", "Women", "Black", "Printed", "ALine", "Kurta", ""]], "id": 431}, {"split": "test", "captions": ["Anouk Women 
Purple & Beige Printed Kurta Set"], "file_path": "dataset/439.jpg", "processed_tokens": [["", "Anouk", "Women", "Purple", "Beige", "Printed", "Kurta", "Set", ""]], "id": 439}, {"split": "test", "captions": ["HERE&NOW Women Mustard Yellow & Green Printed Kurta with Palazzos & Dupatta"], "file_path": "dataset/446.jpg", "processed_tokens": [["", "HERENOW", "Women", "Mustard", "Yellow", "Green", "Printed", "Kurta", "with", "Palazzos", "Dupatta", ""]], "id": 446}, {"split": "test", "captions": ["SIAH Women Peach-Coloured Solid Kurta with Trousers & Dupatta"], "file_path": "dataset/450.jpg", "processed_tokens": [["", "SIAH", "Women", "PeachColoured", "Solid", "Kurta", "with", "Trousers", "Dupatta", ""]], "id": 450}, {"split": "test", "captions": ["GERUA Women Black & Maroon Solid Kurta with Palazzos & Stole"], "file_path": "dataset/461.jpg", "processed_tokens": [["", "GERUA", "Women", "Black", "Maroon", "Solid", "Kurta", "with", "Palazzos", "Stole", ""]], "id": 461}, {"split": "test", "captions": ["Janasya Women Red & Beige Printed Kurti with Palazzos"], "file_path": "dataset/467.jpg", "processed_tokens": [["", "Janasya", "Women", "Red", "Beige", "Printed", "Kurti", "with", "Palazzos", ""]], "id": 467}, {"split": "test", "captions": ["Inddus Women Purple & Golden Woven Design A-Line Kurta"], "file_path": "dataset/473.jpg", "processed_tokens": [["", "Inddus", "Women", "Purple", "Golden", "Woven", "Design", "ALine", "Kurta", ""]], "id": 473}, {"split": "test", "captions": ["Libas Women Beige & White Embroidered Kurta with Palazzos"], "file_path": "dataset/481.jpg", "processed_tokens": [["", "Libas", "Women", "Beige", "White", "Embroidered", "Kurta", "with", "Palazzos", ""]], "id": 481}, {"split": "test", "captions": ["Khushal K Women Pink & Silver-Coloured Solid Kurta with Palazzos"], "file_path": "dataset/494.jpg", "processed_tokens": [["", "Khushal", "K", "Women", "Pink", "SilverColoured", "Solid", "Kurta", "with", "Palazzos", ""]], "id": 494}, {"split": "test", "captions": ["W Women Purple Solid Straight Kurta"], "file_path": "dataset/501.jpg", "processed_tokens": [["", "W", "Women", "Purple", "Solid", "Straight", "Kurta", ""]], "id": 501}, {"split": "test", "captions": ["Libas Women Blue Printed Anarkali Kurta"], "file_path": "dataset/503.jpg", "processed_tokens": [["", "Libas", "Women", "Blue", "Printed", "Anarkali", "Kurta", ""]], "id": 503}, {"split": "test", "captions": ["GERUA Women Maroon & Mustard Yellow Printed Kurta with Palazzos & Stole"], "file_path": "dataset/508.jpg", "processed_tokens": [["", "GERUA", "Women", "Maroon", "Mustard", "Yellow", "Printed", "Kurta", "with", "Palazzos", "Stole", ""]], "id": 508}, {"split": "test", "captions": ["Vishudh Women Beige & Gold-Toned Self Design Kurta with Trousers"], "file_path": "dataset/513.jpg", "processed_tokens": [["", "Vishudh", "Women", "Beige", "GoldToned", "Self", "Design", "Kurta", "with", "Trousers", ""]], "id": 513}, {"split": "test", "captions": ["Ritu Kumar Women Navy Blue & Beige Printed Kaftan Kurta"], "file_path": "dataset/533.jpg", "processed_tokens": [["", "Ritu", "Kumar", "Women", "Navy", "Blue", "Beige", "Printed", "Kaftan", "Kurta", ""]], "id": 533}, {"split": "test", "captions": ["SASSAFRAS Women Pink Solid Anarkali Kurta"], "file_path": "dataset/546.jpg", "processed_tokens": [["", "SASSAFRAS", "Women", "Pink", "Solid", "Anarkali", "Kurta", ""]], "id": 546}, {"split": "test", "captions": ["SASSAFRAS Women Burgundy Ikat Print A-Line Kurta"], "file_path": "dataset/549.jpg", "processed_tokens": [["", "SASSAFRAS", "Women", 
"Burgundy", "Ikat", "Print", "ALine", "Kurta", ""]], "id": 549}, {"split": "test", "captions": ["Libas Women Mustard Yellow Embroidered Detail Straight Kurta"], "file_path": "dataset/562.jpg", "processed_tokens": [["", "Libas", "Women", "Mustard", "Yellow", "Embroidered", "Detail", "Straight", "Kurta", ""]], "id": 562}, {"split": "test", "captions": ["Sangria Women Coral Orange Solid Straight Kurta"], "file_path": "dataset/580.jpg", "processed_tokens": [["", "Sangria", "Women", "Coral", "Orange", "Solid", "Straight", "Kurta", ""]], "id": 580}, {"split": "test", "captions": ["Anubhutee Women Pink & Navy Blue Printed Kurta with Palazzos"], "file_path": "dataset/585.jpg", "processed_tokens": [["", "Anubhutee", "Women", "Pink", "Navy", "Blue", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 585}, {"split": "test", "captions": ["Vishudh Women Blue Printed Kurta with Palazzos"], "file_path": "dataset/591.jpg", "processed_tokens": [["", "Vishudh", "Women", "Blue", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 591}, {"split": "test", "captions": ["GERUA Women Navy Blue & White Printed Kurta with Trousers"], "file_path": "dataset/596.jpg", "processed_tokens": [["", "GERUA", "Women", "Navy", "Blue", "White", "Printed", "Kurta", "with", "Trousers", ""]], "id": 596}, {"split": "test", "captions": ["anayna Women Green & White Printed Straight Kurta"], "file_path": "dataset/598.jpg", "processed_tokens": [["", "anayna", "Women", "Green", "White", "Printed", "Straight", "Kurta", ""]], "id": 598}, {"split": "test", "captions": ["Anouk Women Beige & Yellow Printed A-Line Kurta"], "file_path": "dataset/615.jpg", "processed_tokens": [["", "Anouk", "Women", "Beige", "Yellow", "Printed", "ALine", "Kurta", ""]], "id": 615}, {"split": "test", "captions": ["HERE&NOW Women Red & Golden Printed A-Line Kurta"], "file_path": "dataset/617.jpg", "processed_tokens": [["", "HERENOW", "Women", "Red", "Golden", "Printed", "ALine", "Kurta", ""]], "id": 617}, {"split": "test", "captions": ["Anouk Women White & Turquoise Blue Printed A-Line Tiered Kurta"], "file_path": "dataset/621.jpg", "processed_tokens": [["", "Anouk", "Women", "White", "Turquoise", "Blue", "Printed", "ALine", "Tiered", "Kurta", ""]], "id": 621}, {"split": "test", "captions": ["AASI - HOUSE OF NAYO Women Green & Golden Printed Kurta with Trousers & Dupatta"], "file_path": "dataset/628.jpg", "processed_tokens": [["", "AASI", "HOUSE", "OF", "NAYO", "Women", "Green", "Golden", "Printed", "Kurta", "with", "Trousers", "Dupatta", ""]], "id": 628}, {"split": "test", "captions": ["Idalia Women Pink & Golden Printed Kurta with Palazzos "], "file_path": "dataset/668.jpg", "processed_tokens": [["", "Idalia", "Women", "Pink", "Golden", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 668}, {"split": "test", "captions": ["House of Pataudi Women White Embroidered A-Line Kurta"], "file_path": "dataset/669.jpg", "processed_tokens": [["", "House", "of", "Pataudi", "Women", "White", "Embroidered", "ALine", "Kurta", ""]], "id": 669}, {"split": "test", "captions": ["AASI - HOUSE OF NAYO Women Lime Green & Navy Solid Kurta with Trousers & Dupatta"], "file_path": "dataset/673.jpg", "processed_tokens": [["", "AASI", "HOUSE", "OF", "NAYO", "Women", "Lime", "Green", "Navy", "Solid", "Kurta", "with", "Trousers", "Dupatta", ""]], "id": 673}, {"split": "test", "captions": ["Jaipur Kurti Women Blue & Green Printed Kurta with Salwar & Dupatta"], "file_path": "dataset/677.jpg", "processed_tokens": [["", "Jaipur", "Kurti", "Women", "Blue", "Green", "Printed", "Kurta", "with", 
"Salwar", "Dupatta", ""]], "id": 677}, {"split": "test", "captions": ["Rain & Rainbow Women Mustard Yellow Floral Print Anarkali Kurta"], "file_path": "dataset/682.jpg", "processed_tokens": [["", "Rain", "Rainbow", "Women", "Mustard", "Yellow", "Floral", "Print", "Anarkali", "Kurta", ""]], "id": 682}, {"split": "test", "captions": ["YASH GALLERY Women Black & White Printed A-Line Kurta"], "file_path": "dataset/691.jpg", "processed_tokens": [["", "YASH", "GALLERY", "Women", "Black", "White", "Printed", "ALine", "Kurta", ""]], "id": 691}, {"split": "test", "captions": ["Nayo Women Navy Blue & White Printed Kurta with Palazzos"], "file_path": "dataset/702.jpg", "processed_tokens": [["", "Nayo", "Women", "Navy", "Blue", "White", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 702}, {"split": "test", "captions": ["GERUA Women Green & White Checked Straight Kurta"], "file_path": "dataset/763.jpg", "processed_tokens": [["", "GERUA", "Women", "Green", "White", "Checked", "Straight", "Kurta", ""]], "id": 763}, {"split": "test", "captions": ["Vishudh Women Purple Printed Kurta with Palazzos"], "file_path": "dataset/786.jpg", "processed_tokens": [["", "Vishudh", "Women", "Purple", "Printed", "Kurta", "with", "Palazzos", ""]], "id": 786}, {"split": "test", "captions": ["Vishudh Women Blue & Pink Printed A-Line Kurta"], "file_path": "dataset/790.jpg", "processed_tokens": [["", "Vishudh", "Women", "Blue", "Pink", "Printed", "ALine", "Kurta", ""]], "id": 790}, {"split": "test", "captions": ["Anouk Women Beige & Black Colourblocked Kurta with Palazzos"], "file_path": "dataset/794.jpg", "processed_tokens": [["", "Anouk", "Women", "Beige", "Black", "Colourblocked", "Kurta", "with", "Palazzos", ""]], "id": 794}, {"split": "test", "captions": ["Florence Women Red & Black Striped Kurta with Palazzos"], "file_path": "dataset/830.jpg", "processed_tokens": [["", "Florence", "Women", "Red", "Black", "Striped", "Kurta", "with", "Palazzos", ""]], "id": 830}, {"split": "test", "captions": ["AKS Women Blue Embroidered Detail Straight Kurta"], "file_path": "dataset/832.jpg", "processed_tokens": [["", "AKS", "Women", "Blue", "Embroidered", "Detail", "Straight", "Kurta", ""]], "id": 832}, {"split": "test", "captions": ["HERE&NOW Women Black & White Printed Straight Kurta"], "file_path": "dataset/837.jpg", "processed_tokens": [["", "HERENOW", "Women", "Black", "White", "Printed", "Straight", "Kurta", ""]], "id": 837}, {"split": "test", "captions": ["HERE&NOW Women Black & Golden Printed Kurta with Sharara & Dupatta"], "file_path": "dataset/843.jpg", "processed_tokens": [["", "HERENOW", "Women", "Black", "Golden", "Printed", "Kurta", "with", "Sharara", "Dupatta", ""]], "id": 843}, {"split": "test", "captions": ["HIGHLANDER Men White Solid Straight Kurta"], "file_path": "dataset/846.jpg", "processed_tokens": [["", "HIGHLANDER", "Men", "White", "Solid", "Straight", "Kurta", ""]], "id": 846}, {"split": "test", "captions": ["DEYANN Men Blue Solid Kurta with Pyjamas & Nehru Jacket"], "file_path": "dataset/862.jpg", "processed_tokens": [["", "DEYANN", "Men", "Blue", "Solid", "Kurta", "with", "Pyjamas", "Nehru", "Jacket", ""]], "id": 862}, {"split": "test", "captions": ["Taavi Men Peach-Coloured Handloom Woven Legacy Kurta with Half Button Placket"], "file_path": "dataset/864.jpg", "processed_tokens": [["", "Taavi", "Men", "PeachColoured", "Handloom", "Woven", "Legacy", "Kurta", "with", "Half", "Button", "Placket", ""]], "id": 864}, {"split": "test", "captions": ["Purple State Men Navy Blue Solid Straight Kurta"], 
"file_path": "dataset/875.jpg", "processed_tokens": [["", "Purple", "State", "Men", "Navy", "Blue", "Solid", "Straight", "Kurta", ""]], "id": 875}, {"split": "test", "captions": ["DEYANN Men Red & Cream Self Design Kurta with Patiala"], "file_path": "dataset/880.jpg", "processed_tokens": [["", "DEYANN", "Men", "Red", "Cream", "Self", "Design", "Kurta", "with", "Patiala", ""]], "id": 880}, {"split": "test", "captions": ["Purple State Men Maroon Solid Kurta"], "file_path": "dataset/907.jpg", "processed_tokens": [["", "Purple", "State", "Men", "Maroon", "Solid", "Kurta", ""]], "id": 907}, {"split": "test", "captions": ["Anouk Men Grey & Black Woven Design Straight Kurta"], "file_path": "dataset/919.jpg", "processed_tokens": [["", "Anouk", "Men", "Grey", "Black", "Woven", "Design", "Straight", "Kurta", ""]], "id": 919}, {"split": "test", "captions": ["KISAH Men Yellow Embroidered Straight Kurta"], "file_path": "dataset/931.jpg", "processed_tokens": [["", "KISAH", "Men", "Yellow", "Embroidered", "Straight", "Kurta", ""]], "id": 931}, {"split": "test", "captions": ["even Men Mustard Yellow Solid Straight Kurta"], "file_path": "dataset/952.jpg", "processed_tokens": [["", "even", "Men", "Mustard", "Yellow", "Solid", "Straight", "Kurta", ""]], "id": 952}, {"split": "test", "captions": ["Anouk Men Maroon & Black Solid Kurta Set"], "file_path": "dataset/959.jpg", "processed_tokens": [["", "Anouk", "Men", "Maroon", "Black", "Solid", "Kurta", "Set", ""]], "id": 959}, {"split": "test", "captions": ["Svanik Men Teal Green Woven Design Straight Kurta"], "file_path": "dataset/962.jpg", "processed_tokens": [["", "Svanik", "Men", "Teal", "Green", "Woven", "Design", "Straight", "Kurta", ""]], "id": 962}, {"split": "test", "captions": ["House of Pataudi Men Rose Pink Solid Straight Bib Kurta"], "file_path": "dataset/972.jpg", "processed_tokens": [["", "House", "of", "Pataudi", "Men", "Rose", "Pink", "Solid", "Straight", "Bib", "Kurta", ""]], "id": 972}, {"split": "test", "captions": ["SOJANYA Men Lime Green Solid Straight Kurta"], "file_path": "dataset/974.jpg", "processed_tokens": [["", "SOJANYA", "Men", "Lime", "Green", "Solid", "Straight", "Kurta", ""]], "id": 974}, {"split": "test", "captions": ["NEUDIS Men Red & White Self Design Kurta with Trousers"], "file_path": "dataset/975.jpg", "processed_tokens": [["", "NEUDIS", "Men", "Red", "White", "Self", "Design", "Kurta", "with", "Trousers", ""]], "id": 975}, {"split": "test", "captions": ["SOJANYA Men Black Solid Kurta with Dhoti Pants"], "file_path": "dataset/978.jpg", "processed_tokens": [["", "SOJANYA", "Men", "Black", "Solid", "Kurta", "with", "Dhoti", "Pants", ""]], "id": 978}]
--------------------------------------------------------------------------------
/deep_cmpl_model/data/processed_data/test_sort.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/data/processed_data/test_sort.pkl
--------------------------------------------------------------------------------
/deep_cmpl_model/data/processed_data/train_sort.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/data/processed_data/train_sort.pkl
--------------------------------------------------------------------------------
/deep_cmpl_model/data/processed_data/val_sort.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/data/processed_data/val_sort.pkl
--------------------------------------------------------------------------------
/deep_cmpl_model/data/processed_data/word_counts.txt:
--------------------------------------------------------------------------------
1 | Total words: 298
2 | Words in vocabulary: 163
3 | [('', 800), ('', 800), ('Kurta', 784), ('Women', 672), ('Printed', 398), ('with', 311), ('Straight', 304), ('Blue', 210), ('Solid', 156), ('Palazzos', 152), ('ALine', 132), ('Men', 124), ('White', 122), ('Green', 122), ('Black', 109), ('Pink', 108), ('Design', 99), ('Navy', 88), ('Trousers', 83), ('Vishudh', 79), ('Yellow', 76), ('Libas', 76), ('Anouk', 76), ('OffWhite', 69), ('Dupatta', 67), ('Red', 65), ('Embroidered', 58), ('Beige', 54), ('Grey', 54), ('Maroon', 51), ('Mustard', 49), ('Anarkali', 46), ('Yoke', 45), ('Woven', 43), ('Golden', 39), ('Print', 39), ('Kurti', 30), ('HERENOW', 30), ('Churidar', 29), ('Striped', 27), ('Melange', 27), ('by', 27), ('Block', 25), ('GoldToned', 25), ('Lifestyle', 25), ('Teal', 24), ('DEYANN', 23), ('Nayo', 22), ('Soch', 22), ('Orange', 21), ('Self', 21), ('Brown', 21), ('Varanga', 21), ('CreamColoured', 21), ('Jaipur', 19), ('AKS', 19), ('Jacket', 19), ('See', 18), ('Designs', 18), ('W', 16), ('Sangria', 16), ('Purple', 16), ('PeachColoured', 15), ('Nehru', 15), ('Fabindia', 14), ('Taavi', 14), ('Indo', 13), ('Era', 13), ('Sea', 13), ('Burgundy', 13), ('ZIYAA', 13), ('GERUA', 13), ('Olive', 13), ('Layered', 12), ('KISAH', 11), ('SOJANYA', 11), ('Pyjamas', 11), ('AHIKA', 10), ('Inddus', 10), ('Checked', 10), ('House', 10), ('of', 10), ('Pataudi', 10), ('With', 10), ('Legacy', 10), ('Fusion', 10), ('Vaamsi', 9), ('Moda', 9), ('Rapido', 9), ('Jompers', 9), ('Handloom', 9), ('Yufta', 9), ('Foil', 8), ('Coral', 8), ('Charcoal', 8), ('Rust', 8), ('IMARA', 8), ('Cross', 8), ('Court', 8), ('Patiala', 8), ('Silver', 7), ('Azira', 7), ('Rain', 7), ('Rainbow', 7), ('Skirt', 7), ('Pathani', 7), ('ADA', 6), ('Chikankari', 6), ('Hand', 6), ('GoldColoured', 6), ('Bandhani', 6), ('Turquoise', 6), ('Magenta', 6), ('Ishin', 6), ('Sharara', 6), ('even', 6), ('Lime', 5), ('Global', 5), ('Desi', 5), ('Dyed', 5), ('Janasya', 5), ('Anubhutee', 5), ('Tissu', 5), ('Mauve', 5), ('SASSAFRAS', 5), ('Silk', 5), ('Multicoloured', 5), ('Ahalyaa', 5), ('SilverToned', 4), ('Rasada', 4), ('AASI', 4), ('HOUSE', 4), ('OF', 4), ('NAYO', 4), ('Taupe', 4), ('Alena', 4), ('SelfStriped', 4), ('Ritu', 4), ('Kumar', 4), ('Manyavar', 4), ('Dupion', 4), ('NEUDIS', 4), ('Cotton', 3), ('Aline', 3), ('Floral', 3), ('Bitterlime', 3), ('Detail', 3), ('Kaftan', 3), ('Pockets', 3), ('Top', 3), ('Bhama', 3), ('Couture', 3), ('Biba', 3), ('all', 3), ('about', 3), ('you', 3), ('Idalia', 3), ('Shree', 3), ('Kalamkari', 3), ('I', 3), ('Set', 3), ('RollUp', 3), ('Sleeves', 3)]
--------------------------------------------------------------------------------
/deep_cmpl_model/data/processed_data/word_to_index.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frank-chris/ImageTextRetrieval/020a7dc8eb4b463e4f13697854d08fee4b2ae979/deep_cmpl_model/data/processed_data/word_to_index.pkl
--------------------------------------------------------------------------------
/deep_cmpl_model/deep_cmpl_jupyter.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "deep_cmpl_jupyter.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | },
17 | "accelerator": "GPU"
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "code",
22 | "metadata": {
23 | "id": "W5YoGYWqCNbp",
24 | "colab": {
25 | "base_uri": "https://localhost:8080/"
26 | },
27 | "outputId": "f0f71f76-2444-4fa8-96a2-7e6c8583ffe7"
28 | },
29 | "source": [
30 | "from google.colab import drive\n",
31 | "drive.mount('/content/drive/',force_remount=True)"
32 | ],
33 | "execution_count": 1,
34 | "outputs": [
35 | {
36 | "output_type": "stream",
37 | "text": [
38 | "Mounted at /content/drive/\n"
39 | ],
40 | "name": "stdout"
41 | }
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "metadata": {
47 | "id": "ykxVopBOEHKT"
48 | },
49 | "source": [
50 | "!pip install PyDrive\n",
51 | "from pydrive.auth import GoogleAuth\n",
52 | "from pydrive.drive import GoogleDrive\n",
53 | "from google.colab import auth\n",
54 | "from oauth2client.client import GoogleCredentials\n",
55 | "auth.authenticate_user()\n",
56 | "gauth = GoogleAuth()\n",
57 | "gauth.credentials = GoogleCredentials.get_application_default()\n",
58 | "drive = GoogleDrive(gauth)\n",
59 | "downloaded = drive.CreateFile({'id':\"1XHUwpJDETylqrNJgeXRas89YswqfobG5\"})\n",
60 | "downloaded.GetContentFile('dataset.zip') \n",
61 | "!unzip \"dataset.zip\" -d \"/content\""
62 | ],
63 | "execution_count": null,
64 | "outputs": []
65 | },
66 | {
67 | "cell_type": "code",
68 | "metadata": {
69 | "id": "UovUjGvQRWKK"
70 | },
71 | "source": [
72 | "# !unzip /content/drive/Shareddrives/Image-Text-Retrieval/dataset.zip -d /content"
73 | ],
74 | "execution_count": null,
75 | "outputs": []
76 | },
77 | {
78 | "cell_type": "code",
79 | "metadata": {
80 | "id": "ybIFQeAjFTg0",
81 | "colab": {
82 | "base_uri": "https://localhost:8080/"
83 | },
84 | "outputId": "24b31585-077a-479a-9c1d-2284d43baf80"
85 | },
86 | "source": [
87 | "!rm -rf Image_Text_Retrieval/\n",
88 | "%cd /content/\n",
89 | "!git clone --single-branch --branch main https://github.com/raghavgoyal283/Image_Text_Retrieval"
90 | ],
91 | "execution_count": 4,
92 | "outputs": [
93 | {
94 | "output_type": "stream",
95 | "text": [
96 | "/content\n",
97 | "Cloning into 'Image_Text_Retrieval'...\n",
98 | "remote: Enumerating objects: 196, done.\u001b[K\n",
99 | "remote: Counting objects: 100% (196/196), done.\u001b[K\n",
100 | "remote: Compressing objects: 100% (160/160), done.\u001b[K\n",
101 | "remote: Total 196 (delta 62), reused 159 (delta 28), pack-reused 0\u001b[K\n",
102 | "Receiving objects: 100% (196/196), 3.67 MiB | 27.82 MiB/s, done.\n",
103 | "Resolving deltas: 100% (62/62), done.\n"
104 | ],
105 | "name": "stdout"
106 | }
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "metadata": {
112 | "colab": {
113 | "base_uri": "https://localhost:8080/"
114 | },
115 | "id": "GFd9a065q9fU",
116 | "outputId": "b6480f25-5c4c-4584-913e-25d3d6218d22"
117 | },
118 | "source": [
119 | "%cd /content/Image_Text_Retrieval\n",
120 | "!git pull origin main"
121 | ],
122 | "execution_count": 17,
123 | "outputs": [
124 | {
125 | "output_type": "stream",
126 | "text": [
127 | "/content/Image_Text_Retrieval\n",
128 | "remote: Enumerating objects: 13, done.\u001b[K\n",
129 | "remote: Counting objects: 100% (13/13), done.\u001b[K\n",
130 | "remote: Compressing objects: 100% (3/3), done.\u001b[K\n",
131 | "remote: Total 7 (delta 4), reused 7 (delta 4), pack-reused 0\u001b[K\n",
132 | "Unpacking objects: 100% (7/7), done.\n",
133 | "From https://github.com/raghavgoyal283/Image_Text_Retrieval\n",
134 | " * branch main -> FETCH_HEAD\n",
135 | " 4776f27..8b4fdaf main -> origin/main\n",
136 | "Updating 4776f27..8b4fdaf\n",
137 | "Fast-forward\n",
138 | " deep_cmpl_model/code/scripts/tester.py | 2 \u001b[32m+\u001b[m\u001b[31m-\u001b[m\n",
139 | " deep_cmpl_model/code/scripts/trainer.py | 2 \u001b[32m+\u001b[m\u001b[31m-\u001b[m\n",
140 | " 2 files changed, 2 insertions(+), 2 deletions(-)\n"
141 | ],
142 | "name": "stdout"
143 | }
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "metadata": {
149 | "id": "3twCZ8ABDNJp",
150 | "colab": {
151 | "base_uri": "https://localhost:8080/"
152 | },
153 | "outputId": "3bad414d-0078-4fe5-a2e9-4e3491caefb6"
154 | },
155 | "source": [
156 | "%cd /content/\n",
157 | "!python3 /content/Image_Text_Retrieval/deep_cmpl_model/data/make_json.py\n",
158 | "!sh /content/Image_Text_Retrieval/deep_cmpl_model/code/datasets/data.sh  # preprocess: build the vocabulary and the train/val/test splits"
159 | ],
160 | "execution_count": 12,
161 | "outputs": [
162 | {
163 | "output_type": "stream",
164 | "text": [
165 | "/content\n",
166 | "12305\n",
167 | "Preprocessing dataset\n",
168 | "start build vodabulary\n",
169 | "Total words: 1418\n",
170 | "Words in vocab: 812\n",
171 | "number of bad words: 606/1418 = 42.74%\n",
172 | "number of words in vocab: 812/1418 = 57.26%\n",
173 | "number of Null: 787/1418 = 55.50%\n",
174 | "Process metadata done!\n",
175 | "Total 9844 captions 9844 images 9844 identities in train\n",
176 | "Process metadata done!\n",
177 | "Total 1230 captions 1230 images 1230 identities in val\n",
178 | "Process metadata done!\n",
179 | "Total 1231 captions 1231 images 1231 identities in test\n",
180 | "Process decodedata done!\n",
181 | "Process decodedata done!\n",
182 | "Process decodedata done!\n",
183 | "=========== Arrange by id=============================\n",
184 | "Save dataset\n",
185 | "=========== Arrange by id=============================\n",
186 | "Save dataset\n",
187 | "=========== Arrange by id=============================\n",
188 | "Save dataset\n"
189 | ],
190 | "name": "stdout"
191 | }
192 | ]
193 | },
194 | {
195 | "cell_type": "markdown",
196 | "metadata": {
197 | "id": "G87T6Ntf1UOV"
198 | },
199 | "source": [
200 | "## Remember to push the newly generated .pkl files to the repo"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "metadata": {
206 | "id": "rVxnw3Iv1QfX",
207 | "colab": {
208 | "base_uri": "https://localhost:8080/"
209 | },
210 | "outputId": "ac403127-d98f-47fe-f288-491fb93745d1"
211 | },
212 | "source": [
213 | "# %cd /content/Image_Text_Retrieval\n",
214 | "# !git config --global user.email \"raghavgoyal283@gmail.com\"\n",
215 | "# !git config --global user.name \"raghavgoyal283\"\n",
216 | "# !git add .\n",
217 | "# !git commit -m \"add\"\n",
218 | "# !git push -u origin main"
219 | ],
220 | "execution_count": null,
221 | "outputs": [
222 | {
223 | "output_type": "stream",
224 | "text": [
225 | "/content/Image_Text_Retrieval\n",
226 | "[anothertry af39ca4] add\n",
227 | " 1 file changed, 3 insertions(+), 3 deletions(-)\n",
228 | "Counting objects: 4, done.\n",
229 | "Delta compression using up to 2 threads.\n",
230 | "Compressing objects: 100% (4/4), done.\n",
231 | "Writing objects: 100% (4/4), 422 bytes | 422.00 KiB/s, done.\n",
232 | "Total 4 (delta 3), reused 0 (delta 0)\n",
233 | "remote: Resolving deltas: 100% (3/3), completed with 3 local objects.\u001b[K\n",
234 | "To https://github.com/raghavgoyal283/Image_Text_Retrieval.git\n",
235 | " c1d7077..af39ca4 anothertry -> anothertry\n",
236 | "Branch 'anothertry' set up to track remote branch 'anothertry' from 'origin'.\n"
237 | ],
238 | "name": "stdout"
239 | }
240 | ]
241 | },
242 | {
243 | "cell_type": "markdown",
244 | "metadata": {
245 | "id": "Z4iFxUTlRnWQ"
246 | },
247 | "source": [
248 | "# Train/Test"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "metadata": {
254 | "id": "CbFhoqYkoq-6",
255 | "colab": {
256 | "base_uri": "https://localhost:8080/"
257 | },
258 | "outputId": "b905fe42-9ead-4887-9a77-84d62ef05db2"
259 | },
260 | "source": [
261 | "!pip install efficientnet_pytorch"
262 | ],
263 | "execution_count": 13,
264 | "outputs": [
265 | {
266 | "output_type": "stream",
267 | "text": [
268 | "Collecting efficientnet_pytorch\n",
269 | " Downloading https://files.pythonhosted.org/packages/2e/a0/dd40b50aebf0028054b6b35062948da01123d7be38d08b6b1e5435df6363/efficientnet_pytorch-0.7.1.tar.gz\n",
270 | "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from efficientnet_pytorch) (1.8.1+cu101)\n",
271 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->efficientnet_pytorch) (3.7.4.3)\n",
272 | "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torch->efficientnet_pytorch) (1.19.5)\n",
273 | "Building wheels for collected packages: efficientnet-pytorch\n",
274 | " Building wheel for efficientnet-pytorch (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
275 | " Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.7.1-cp37-none-any.whl size=16443 sha256=5a0fcea201810a6455bd03e301dc9617b6b4a66d8496abf193b0017c78d619a4\n",
276 | " Stored in directory: /root/.cache/pip/wheels/84/27/aa/c46d23c4e8cc72d41283862b1437e0b3ad318417e8ed7d5921\n",
277 | "Successfully built efficientnet-pytorch\n",
278 | "Installing collected packages: efficientnet-pytorch\n",
279 | "Successfully installed efficientnet-pytorch-0.7.1\n"
280 | ],
281 | "name": "stdout"
282 | }
283 | ]
284 | },
285 | {
286 | "cell_type": "code",
287 | "metadata": {
288 | "id": "-aTdr_DrGd4x"
289 | },
290 | "source": [
291 | "# Clear previous checkpoints\n",
292 | "!rm -r /content/drive/Shareddrives/Image-Text-Retrieval/tempckpt/*"
293 | ],
294 | "execution_count": 21,
295 | "outputs": []
296 | },
297 | {
298 | "cell_type": "code",
299 | "metadata": {
300 | "colab": {
301 | "base_uri": "https://localhost:8080/"
302 | },
303 | "id": "IkRKNJVcDM-F",
304 | "outputId": "c078b79b-bdd3-4438-b5d9-d8aa0ba129bd"
305 | },
306 | "source": [
307 | "%cd /content/\n",
308 | "!python3 /content/Image_Text_Retrieval/deep_cmpl_model/code/scripts/trainer.py  # train; per-epoch checkpoints are written under the Drive tempckpt folder"
309 | ],
310 | "execution_count": 22,
311 | "outputs": [
312 | {
313 | "output_type": "stream",
314 | "text": [
315 | "/content\n",
316 | "Loaded pretrained weights for efficientnet-b0\n",
317 | "Total params: 17M\n",
318 | "epoch:0, step:0, cmpm_loss:28.980\n",
319 | "epoch:0, step:10, cmpm_loss:27.822\n",
320 | "epoch:0, step:20, cmpm_loss:25.797\n",
321 | "epoch:0, step:30, cmpm_loss:24.350\n",
322 | "epoch:0, step:40, cmpm_loss:22.872\n",
323 | "epoch:0, step:50, cmpm_loss:18.468\n",
324 | "epoch:0, step:60, cmpm_loss:17.170\n",
325 | "epoch:0, step:70, cmpm_loss:17.078\n",
326 | "epoch:0, step:80, cmpm_loss:18.581\n",
327 | "epoch:0, step:90, cmpm_loss:17.275\n",
328 | "epoch:0, step:100, cmpm_loss:16.918\n",
329 | "epoch:0, step:110, cmpm_loss:14.333\n",
330 | "epoch:0, step:120, cmpm_loss:20.252\n",
331 | "epoch:0, step:130, cmpm_loss:13.150\n",
332 | "epoch:0, step:140, cmpm_loss:13.900\n",
333 | "epoch:0, step:150, cmpm_loss:9.549\n",
334 | "epoch:0, step:160, cmpm_loss:13.966\n",
335 | "epoch:0, step:170, cmpm_loss:8.952\n",
336 | "epoch:0, step:180, cmpm_loss:10.939\n",
337 | "epoch:0, step:190, cmpm_loss:8.844\n",
338 | "epoch:0, step:200, cmpm_loss:3.907\n",
339 | "epoch:0, step:210, cmpm_loss:10.644\n",
340 | "epoch:0, step:220, cmpm_loss:10.646\n",
341 | "epoch:0, step:230, cmpm_loss:9.801\n",
342 | "epoch:0, step:240, cmpm_loss:6.492\n",
343 | "epoch:0, step:250, cmpm_loss:11.888\n",
344 | "epoch:0, step:260, cmpm_loss:16.473\n",
345 | "epoch:0, step:270, cmpm_loss:8.894\n",
346 | "epoch:0, step:280, cmpm_loss:10.306\n",
347 | "epoch:0, step:290, cmpm_loss:15.606\n",
348 | "epoch:0, step:300, cmpm_loss:7.734\n",
349 | "epoch:0, step:310, cmpm_loss:10.006\n",
350 | "epoch:0, step:320, cmpm_loss:5.520\n",
351 | "epoch:0, step:330, cmpm_loss:9.725\n",
352 | "epoch:0, step:340, cmpm_loss:8.853\n",
353 | "epoch:0, step:350, cmpm_loss:5.653\n",
354 | "epoch:0, step:360, cmpm_loss:8.572\n",
355 | "epoch:0, step:370, cmpm_loss:9.815\n",
356 | "epoch:0, step:380, cmpm_loss:6.975\n",
357 | "epoch:0, step:390, cmpm_loss:8.218\n",
358 | "epoch:0, step:400, cmpm_loss:12.703\n",
359 | "epoch:0, step:410, cmpm_loss:6.403\n",
360 | "epoch:0, step:420, cmpm_loss:7.904\n",
361 | "epoch:0, step:430, cmpm_loss:8.343\n",
362 | "epoch:0, step:440, cmpm_loss:5.121\n",
363 | "epoch:0, step:450, cmpm_loss:6.697\n",
364 | "epoch:0, step:460, cmpm_loss:8.433\n",
365 | "epoch:0, step:470, cmpm_loss:5.876\n",
366 | "epoch:0, step:480, cmpm_loss:0.763\n",
367 | "epoch:0, step:490, cmpm_loss:7.332\n",
368 | "epoch:0, step:500, cmpm_loss:13.782\n",
369 | "epoch:0, step:510, cmpm_loss:6.662\n",
370 | "epoch:0, step:520, cmpm_loss:3.625\n",
371 | "epoch:0, step:530, cmpm_loss:8.246\n",
372 | "epoch:0, step:540, cmpm_loss:7.430\n",
373 | "epoch:0, step:550, cmpm_loss:5.310\n",
374 | "epoch:0, step:560, cmpm_loss:5.444\n",
375 | "epoch:0, step:570, cmpm_loss:2.826\n",
376 | "epoch:0, step:580, cmpm_loss:4.238\n",
377 | "epoch:0, step:590, cmpm_loss:8.130\n",
378 | "epoch:0, step:600, cmpm_loss:8.824\n",
379 | "epoch:0, step:610, cmpm_loss:6.546\n",
380 | "Train done for epoch-0\n",
381 | "lr:0.0002\n",
382 | "epoch:1, step:0, cmpm_loss:6.692\n",
383 | "epoch:1, step:10, cmpm_loss:7.902\n",
384 | "epoch:1, step:20, cmpm_loss:4.313\n",
385 | "epoch:1, step:30, cmpm_loss:8.069\n",
386 | "epoch:1, step:40, cmpm_loss:3.060\n",
387 | "epoch:1, step:50, cmpm_loss:0.437\n",
388 | "epoch:1, step:60, cmpm_loss:2.161\n",
389 | "epoch:1, step:70, cmpm_loss:2.665\n",
390 | "epoch:1, step:80, cmpm_loss:9.616\n",
391 | "epoch:1, step:90, cmpm_loss:3.825\n",
392 | "epoch:1, step:100, cmpm_loss:2.910\n",
393 | "epoch:1, step:110, cmpm_loss:1.977\n",
394 | "epoch:1, step:120, cmpm_loss:1.750\n",
395 | "epoch:1, step:130, cmpm_loss:5.090\n",
396 | "epoch:1, step:140, cmpm_loss:7.062\n",
397 | "epoch:1, step:150, cmpm_loss:2.141\n",
398 | "epoch:1, step:160, cmpm_loss:7.578\n",
399 | "epoch:1, step:170, cmpm_loss:7.134\n",
400 | "epoch:1, step:180, cmpm_loss:2.916\n",
401 | "epoch:1, step:190, cmpm_loss:7.609\n",
402 | "epoch:1, step:200, cmpm_loss:2.978\n",
403 | "epoch:1, step:210, cmpm_loss:3.521\n",
404 | "epoch:1, step:220, cmpm_loss:3.619\n",
405 | "epoch:1, step:230, cmpm_loss:4.623\n",
406 | "epoch:1, step:240, cmpm_loss:6.489\n",
407 | "epoch:1, step:250, cmpm_loss:4.454\n",
408 | "epoch:1, step:260, cmpm_loss:2.816\n",
409 | "epoch:1, step:270, cmpm_loss:9.463\n",
410 | "epoch:1, step:280, cmpm_loss:3.138\n",
411 | "epoch:1, step:290, cmpm_loss:4.777\n",
412 | "epoch:1, step:300, cmpm_loss:5.296\n",
413 | "epoch:1, step:310, cmpm_loss:5.167\n",
414 | "epoch:1, step:320, cmpm_loss:4.451\n",
415 | "epoch:1, step:330, cmpm_loss:3.969\n",
416 | "epoch:1, step:340, cmpm_loss:4.741\n",
417 | "epoch:1, step:350, cmpm_loss:9.974\n",
418 | "epoch:1, step:360, cmpm_loss:7.483\n",
419 | "epoch:1, step:370, cmpm_loss:8.054\n",
420 | "epoch:1, step:380, cmpm_loss:5.053\n",
421 | "epoch:1, step:390, cmpm_loss:3.425\n",
422 | "epoch:1, step:400, cmpm_loss:2.688\n",
423 | "epoch:1, step:410, cmpm_loss:5.667\n",
424 | "epoch:1, step:420, cmpm_loss:4.761\n",
425 | "epoch:1, step:430, cmpm_loss:4.056\n",
426 | "epoch:1, step:440, cmpm_loss:6.989\n",
427 | "epoch:1, step:450, cmpm_loss:6.960\n",
428 | "epoch:1, step:460, cmpm_loss:9.581\n",
429 | "epoch:1, step:470, cmpm_loss:4.810\n",
430 | "epoch:1, step:480, cmpm_loss:4.324\n",
431 | "epoch:1, step:490, cmpm_loss:3.718\n",
432 | "epoch:1, step:500, cmpm_loss:3.811\n",
433 | "epoch:1, step:510, cmpm_loss:1.995\n",
434 | "epoch:1, step:520, cmpm_loss:5.277\n",
435 | "epoch:1, step:530, cmpm_loss:1.554\n",
436 | "epoch:1, step:540, cmpm_loss:5.259\n",
437 | "epoch:1, step:550, cmpm_loss:5.785\n",
438 | "epoch:1, step:560, cmpm_loss:7.259\n",
439 | "epoch:1, step:570, cmpm_loss:10.739\n",
440 | "epoch:1, step:580, cmpm_loss:5.838\n",
441 | "epoch:1, step:590, cmpm_loss:2.185\n",
442 | "epoch:1, step:600, cmpm_loss:4.268\n",
443 | "epoch:1, step:610, cmpm_loss:2.276\n",
444 | "Train done for epoch-1\n",
445 | "lr:0.0002\n",
446 | "epoch:2, step:0, cmpm_loss:5.452\n",
447 | "epoch:2, step:10, cmpm_loss:2.122\n",
448 | "epoch:2, step:20, cmpm_loss:10.590\n",
449 | "Traceback (most recent call last):\n",
450 | " File \"/content/Image_Text_Retrieval/deep_cmpl_model/code/train.py\", line 230, in \n",
451 | " main(args)\n",
452 | " File \"/content/Image_Text_Retrieval/deep_cmpl_model/code/train.py\", line 193, in main\n",
453 | " train_loss, train_time = train(args.start_epoch + epoch, train_loader, network, optimizer, compute_loss, args)\n",
454 | " File \"/content/Image_Text_Retrieval/deep_cmpl_model/code/train.py\", line 51, in train\n",
455 | " cmpm_loss.backward()\n",
456 | " File \"/usr/local/lib/python3.7/dist-packages/torch/tensor.py\", line 245, in backward\n",
457 | " torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)\n",
458 | " File \"/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py\", line 147, in backward\n",
459 | " allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n",
460 | "KeyboardInterrupt\n"
461 | ],
462 | "name": "stdout"
463 | }
464 | ]
465 | },
466 | {
467 | "cell_type": "code",
468 | "metadata": {
469 | "id": "qHY6hk3oFhu-",
470 | "colab": {
471 | "base_uri": "https://localhost:8080/"
472 | },
473 | "outputId": "c7752301-3279-4969-fd58-74707300c2e6"
474 | },
475 | "source": [
476 | "%cd /content/\n",
477 | "!python3 /content/Image_Text_Retrieval/deep_cmpl_model/code/scripts/tester.py  # evaluate text-to-image and image-to-text retrieval (top-1/5/10) from the saved checkpoints"
478 | ],
479 | "execution_count": 23,
480 | "outputs": [
481 | {
482 | "output_type": "stream",
483 | "text": [
484 | "/content\n",
485 | "Loaded pretrained weights for efficientnet-b0\n",
486 | "==> Loading checkpoint \"drive/Shareddrives/Image-Text-Retrieval/tempckpt/data/model_data/lr-0.0002-decay-0.9-batch-16/0.pth.tar\"\n",
487 | "Loaded pretrained weights for efficientnet-b0\n",
488 | "==> Loading checkpoint \"drive/Shareddrives/Image-Text-Retrieval/tempckpt/data/model_data/lr-0.0002-decay-0.9-batch-16/1.pth.tar\"\n",
489 | "t2i_top1_best: 18.034, t2i_top5_best: 51.097, t2i_top10_best: 68.725, t2i_mr_best: 0.406\n",
490 | "i2t_top1: 17.872, i2t_top5: 49.228, i2t_top10: 67.100, i2t_mr_best: 0.487\n"
491 | ],
492 | "name": "stdout"
493 | }
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "metadata": {
499 | "id": "tsntUW3G_Os3"
500 | },
501 | "source": [
502 | "drive.flush_and_unmount()"
503 | ],
504 | "execution_count": null,
505 | "outputs": []
506 | }
507 | ]
508 | }
--------------------------------------------------------------------------------