├── cirtorch
│   ├── layers
│   │   ├── __init__.py
│   │   ├── normalization.py
│   │   ├── loss.py
│   │   ├── pooling.py
│   │   └── functional.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── general.py
│   │   ├── whiten.py
│   │   ├── evaluate.py
│   │   ├── download.py
│   │   └── download_win.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── testdataset.py
│   │   ├── datahelpers.py
│   │   ├── genericdataset.py
│   │   └── traindataset.py
│   ├── examples
│   │   ├── __init__.py
│   │   ├── test_e2e.py
│   │   └── test.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── imageretrievalnet_cpu.py
│   │   └── imageretrievalnet.py
│   ├── .DS_Store
│   └── __init__.py
├── utils
│   ├── .DS_Store
│   ├── retrieval_feature.py
│   ├── retrieval_index.py
│   └── classify.py
├── config.yaml
├── lshash
│   ├── __init__.py
│   ├── storage.py
│   └── lshash.py
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── IMAGE_Retrieval.iml
│   └── workspace.xml
├── nts
│   ├── config.py
│   ├── README.md
│   ├── test.py
│   ├── core
│   │   ├── utils.py
│   │   ├── dataset.py
│   │   ├── anchors.py
│   │   ├── model.py
│   │   └── resnet.py
│   └── train.py
├── demo.py
├── README.md
├── .gitignore
└── interface.py
/cirtorch/layers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cirtorch/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cirtorch/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cirtorch/examples/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cirtorch/networks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yinhaoxs/ImageRetrieval-LSH/HEAD/utils/.DS_Store
--------------------------------------------------------------------------------
/cirtorch/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yinhaoxs/ImageRetrieval-LSH/HEAD/cirtorch/.DS_Store
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | websites:
2 | host: 0.0.0.0
3 | port: 15788
4 |
5 | model:
6 | network: /*.pth
7 | model_dir: /*
8 | type: [SA,SB]
--------------------------------------------------------------------------------
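
config.yaml drives the Flask service; a minimal sketch of loading it with PyYAML (hypothetical, since interface.py, the likely consumer, is not included in this dump):

```python
# Hypothetical loader for config.yaml; the real consumer (interface.py)
# is not shown here, so this only illustrates the schema above.
import yaml  # PyYAML

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

host = cfg["websites"]["host"]     # "0.0.0.0"
port = cfg["websites"]["port"]     # 15788
network = cfg["model"]["network"]  # path to the *.pth retrieval checkpoint
```
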
/lshash/__init__.py:
--------------------------------------------------------------------------------
1 | import pkg_resources
2 |
3 | try:
4 | __version__ = pkg_resources.get_distribution(__name__).version
5 | except pkg_resources.DistributionNotFound:
6 | __version__ = '0.0.4dev'
7 |
--------------------------------------------------------------------------------
/nts/config.py:
--------------------------------------------------------------------------------
1 | BATCH_SIZE = 16
2 | PROPOSAL_NUM = 6
3 | CAT_NUM = 4
4 | INPUT_SIZE = (448, 448) # (w, h)
5 | LR = 0.001
6 | WD = 1e-4
7 | SAVE_FREQ = 1
8 | resume = ''
9 | test_model = 'model.ckpt'
10 | save_dir = '/data_4t/yangz/models/'
11 |
--------------------------------------------------------------------------------
/cirtorch/__init__.py:
--------------------------------------------------------------------------------
1 | from . import datasets, examples, layers, networks, utils
2 |
3 | from .datasets import datahelpers, genericdataset, testdataset, traindataset
4 | from .layers import functional, loss, normalization, pooling
5 | from .networks import imageretrievalnet
6 | from .utils import general, download, evaluate, whiten
--------------------------------------------------------------------------------
/cirtorch/layers/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import cirtorch.layers.functional as LF
5 |
6 | # --------------------------------------
7 | # Normalization layers
8 | # --------------------------------------
9 |
10 | class L2N(nn.Module):
11 |
12 | def __init__(self, eps=1e-6):
13 | super(L2N,self).__init__()
14 | self.eps = eps
15 |
16 | def forward(self, x):
17 | return LF.l2n(x, eps=self.eps)
18 |
19 | def __repr__(self):
20 | return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')'
21 |
22 |
23 | class PowerLaw(nn.Module):
24 |
25 | def __init__(self, eps=1e-6):
26 | super(PowerLaw, self).__init__()
27 | self.eps = eps
28 |
29 | def forward(self, x):
30 | return LF.powerlaw(x, eps=self.eps)
31 |
32 | def __repr__(self):
33 | return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')'
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | from utils.retrieval_feature import AntiFraudFeatureDataset
2 | from utils.retrieval_index import EvaluteMap
3 |
4 |
5 | if __name__ == '__main__':
6 | """
7 | img_dir holds the full image gallery; each image in test_img_dir is matched against the gallery and the top-3 matching image paths are written out.
8 | """
9 | hash_size = 0
10 | input_dim = 2048
11 | num_hashtables = 1
12 | img_dir = 'ImageRetrieval/data'
13 | test_img_dir = './images'
14 | network = './weights/gl18-tl-resnet50-gem-w-83fdc30.pth'
15 | out_similar_dir = './output/similar'
16 | out_similar_file_dir = './output/similar_file'
17 | all_csv_file = './output/aaa.csv'
18 |
19 | feature_dict, lsh = AntiFraudFeatureDataset(img_dir, network).constructfeature(hash_size, input_dim, num_hashtables)
20 | test_feature_dict = AntiFraudFeatureDataset(test_img_dir, network).test_feature()
21 | EvaluteMap(out_similar_dir, out_similar_file_dir, all_csv_file).retrieval_images(test_feature_dict, lsh, 3)
22 |
23 |
24 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image Retrieval Model
2 |
3 | ## 1. Model Overview
4 | --Techniques: deep convolutional neural networks, LSH (locality-sensitive hashing), Flask web deployment, NTS fine-grained classification
5 |
6 | ## 2. Pretrained Models
7 | --Image classification pretrained model: https://drive.google.com/file/d/1F-eKqPRjlya5GH2HwTlLKNSPEUaxCu9H/view?usp=sharing
8 | --Image retrieval pretrained model: http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-rwhiten-19b204e.pth
9 | resnet50: http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth
10 |
11 | ## 3. Dataset
12 | --Transfer learning on the production dataset, followed by duplicate detection
13 |
14 | ## 4. Data Preprocessing
15 | --Resize images to 224x224
16 | --Filter out tiff/tif files and work around a low-level Pillow issue (OpenCV handles the convert("RGB") problem)
17 | --Class filtering: fine-grained classification with the NTS network
18 |
19 | ## 5. Usage
20 | --Classification: python utils/classify.py
21 | --Feature extraction: python utils/retrieval_feature.py
22 | --Offline image retrieval: python utils/retrieval_index.py
23 | --Online deployment: python interface.py (served with Flask; the gallery database is updated offline)
24 | python app_test.py (test the API)
25 |
26 | ## 6. Metrics
27 | --mAP: 0.93
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
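
Section 4 of the README mentions working around Pillow's convert("RGB") problem with OpenCV; a minimal sketch of that idea (the helper name load_rgb is hypothetical, and the repo's actual workaround is not shown in this dump):

```python
# Hypothetical helper sketching the OpenCV fallback the README alludes to.
import cv2
from PIL import Image

def load_rgb(path):
    """Load an image as RGB, falling back to OpenCV when Pillow fails."""
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        bgr = cv2.imread(path, cv2.IMREAD_COLOR)  # OpenCV decodes to BGR
        if bgr is None:
            raise IOError("cannot decode {}".format(path))
        return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
```
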
/cirtorch/utils/general.py:
--------------------------------------------------------------------------------
1 | import os
2 | import hashlib
3 |
4 | def get_root():
5 | return os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))
6 |
7 |
8 | def get_data_root():
9 | return os.path.join(get_root(), 'data')
10 |
11 |
12 | def htime(c):
13 | c = round(c)
14 |
15 | days = c // 86400
16 | hours = c // 3600 % 24
17 | minutes = c // 60 % 60
18 | seconds = c % 60
19 |
20 | if days > 0:
21 | return '{:d}d {:d}h {:d}m {:d}s'.format(days, hours, minutes, seconds)
22 | if hours > 0:
23 | return '{:d}h {:d}m {:d}s'.format(hours, minutes, seconds)
24 | if minutes > 0:
25 | return '{:d}m {:d}s'.format(minutes, seconds)
26 | return '{:d}s'.format(seconds)
27 |
28 |
29 | def sha256_hash(filename, block_size=65536, length=8):
30 | sha256 = hashlib.sha256()
31 | with open(filename, 'rb') as f:
32 | for block in iter(lambda: f.read(block_size), b''):
33 | sha256.update(block)
34 | return sha256.hexdigest()[:length-1]
--------------------------------------------------------------------------------
/cirtorch/datasets/testdataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | DATASETS = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
5 |
6 | def configdataset(dataset, dir_main):
7 |
8 | dataset = dataset.lower()
9 |
10 | if dataset not in DATASETS:
11 | raise ValueError('Unknown dataset: {}!'.format(dataset))
12 |
13 | # loading imlist, qimlist, and gnd, in cfg as a dict
14 | gnd_fname = os.path.join(dir_main, dataset, 'gnd_{}.pkl'.format(dataset))
15 | with open(gnd_fname, 'rb') as f:
16 | cfg = pickle.load(f)
17 | cfg['gnd_fname'] = gnd_fname
18 |
19 | cfg['ext'] = '.jpg'
20 | cfg['qext'] = '.jpg'
21 | cfg['dir_data'] = os.path.join(dir_main, dataset)
22 | cfg['dir_images'] = os.path.join(cfg['dir_data'], 'jpg')
23 |
24 | cfg['n'] = len(cfg['imlist'])
25 | cfg['nq'] = len(cfg['qimlist'])
26 |
27 | cfg['im_fname'] = config_imname
28 | cfg['qim_fname'] = config_qimname
29 |
30 | cfg['dataset'] = dataset
31 |
32 | return cfg
33 |
34 | def config_imname(cfg, i):
35 | return os.path.join(cfg['dir_images'], cfg['imlist'][i] + cfg['ext'])
36 |
37 | def config_qimname(cfg, i):
38 | return os.path.join(cfg['dir_images'], cfg['qimlist'][i] + cfg['qext'])
39 |
--------------------------------------------------------------------------------
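
A usage sketch for configdataset above; the dir_main path is hypothetical and must already contain the downloaded ground-truth pickle for the chosen dataset:

```python
# Usage sketch for configdataset; '/data/test' is a hypothetical dir_main
# that must contain roxford5k/gnd_roxford5k.pkl.
from cirtorch.datasets.testdataset import configdataset

cfg = configdataset('roxford5k', '/data/test')
print(cfg['n'], cfg['nq'])      # number of database and query images
print(cfg['im_fname'](cfg, 0))  # full path of the first database image
```
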
/cirtorch/layers/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import cirtorch.layers.functional as LF
5 |
6 | # --------------------------------------
7 | # Loss/Error layers
8 | # --------------------------------------
9 |
10 | class ContrastiveLoss(nn.Module):
11 | r"""CONTRASTIVELOSS layer that computes contrastive loss for a batch of images:
12 | Q query tuples, each packed in the form of (q,p,n1,..nN)
13 |
14 | Args:
15 | x: tuples arranged in columns as [q, p, n1, ..., nN, ...]
16 | label: -1 for query, 1 for corresponding positive, 0 for corresponding negative
17 | margin: contrastive loss margin. Default: 0.7
18 |
19 | >>> contrastive_loss = ContrastiveLoss(margin=0.7)
20 | >>> input = torch.randn(128, 35, requires_grad=True)
21 | >>> label = torch.Tensor([-1, 1, 0, 0, 0, 0, 0] * 5)
22 | >>> output = contrastive_loss(input, label)
23 | >>> output.backward()
24 | """
25 |
26 | def __init__(self, margin=0.7, eps=1e-6):
27 | super(ContrastiveLoss, self).__init__()
28 | self.margin = margin
29 | self.eps = eps
30 |
31 | def forward(self, x, label):
32 | return LF.contrastive_loss(x, label, margin=self.margin, eps=self.eps)
33 |
34 | def __repr__(self):
35 | return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')'
36 |
37 |
38 | class TripletLoss(nn.Module):
39 |
40 | def __init__(self, margin=0.1):
41 | super(TripletLoss, self).__init__()
42 | self.margin = margin
43 |
44 | def forward(self, x, label):
45 | return LF.triplet_loss(x, label, margin=self.margin)
46 |
47 | def __repr__(self):
48 | return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')'
49 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/cirtorch/datasets/datahelpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 |
4 | import torch
5 |
6 | def cid2filename(cid, prefix):
7 | """
8 | Creates a training image path out of its CID name
9 |
10 | Arguments
11 | ---------
12 | cid : name of the image
13 | prefix : root directory where images are saved
14 |
15 | Returns
16 | -------
17 | filename : full image filename
18 | """
19 | return os.path.join(prefix, cid[-2:], cid[-4:-2], cid[-6:-4], cid)
20 |
21 | def pil_loader(path):
22 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
23 | with open(path, 'rb') as f:
24 | img = Image.open(f)
25 | return img.convert('RGB')
26 |
27 | def accimage_loader(path):
28 | import accimage
29 | try:
30 | return accimage.Image(path)
31 | except IOError:
32 | # Potentially a decoding problem, fall back to PIL.Image
33 | return pil_loader(path)
34 |
35 | def default_loader(path):
36 | from torchvision import get_image_backend
37 | if get_image_backend() == 'accimage':
38 | return accimage_loader(path)
39 | else:
40 | return pil_loader(path)
41 |
42 | def imresize(img, imsize):
43 | img.thumbnail((imsize, imsize), Image.LANCZOS)  # LANCZOS is the former ANTIALIAS
44 | return img
45 |
46 | def flip(x, dim):
47 | xsize = x.size()
48 | dim = x.dim() + dim if dim < 0 else dim
49 | x = x.view(-1, *xsize[dim:])
50 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :]
51 | return x.view(xsize)
52 |
53 | def collate_tuples(batch):
54 | if len(batch) == 1:
55 | return [batch[0][0]], [batch[0][1]]
56 | return [batch[i][0] for i in range(len(batch))], [batch[i][1] for i in range(len(batch))]
--------------------------------------------------------------------------------
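
The cid2filename docstring above describes the bucketing scheme; a worked example with a made-up CID and prefix:

```python
# Worked example for cid2filename: the last six characters of the CID,
# read in pairs from the end, select three nested subdirectories.
from cirtorch.datasets.datahelpers import cid2filename

print(cid2filename('0123456789abcdef', '/data/train'))
# -> /data/train/ef/cd/ab/0123456789abcdef
```
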
/nts/README.md:
--------------------------------------------------------------------------------
1 | # NTS-Net
2 |
3 | This is a PyTorch implementation of the ECCV2018 paper "Learning to Navigate for Fine-grained Classification" (Ze Yang, Tiange Luo, Dong Wang, Zhiqiang Hu, Jun Gao, Liwei Wang).
4 |
5 | ## Requirements
6 | - python 3+
7 | - pytorch 0.4+
8 | - numpy
9 | - datetime
10 |
11 | ## Datasets
12 | Download the [CUB-200-2011](http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz) dataset and put it in the root directory, named **CUB_200_2011**. You can also try other fine-grained datasets.
13 |
14 | ## Train the model
15 | If you want to train the NTS-Net, just run ``python train.py``. You may need to change the configurations in ``config.py``. The parameter ``PROPOSAL_NUM`` is ``M`` in the original paper and the parameter ``CAT_NUM`` is ``K`` in the original paper. During training, the log file and checkpoint file will be saved in ``save_dir`` directory. You can change the parameter ``resume`` to choose the checkpoint model to resume.
16 |
17 | ## Test the model
18 | If you want to test the NTS-Net, just run ``python test.py``. You need to specify the ``test_model`` in ``config.py`` to choose the checkpoint model for testing.
19 |
20 | ## Model
21 | We also provide a checkpoint model trained by ourselves; you can download it from [here](https://drive.google.com/file/d/1F-eKqPRjlya5GH2HwTlLKNSPEUaxCu9H/view?usp=sharing). Testing with this model gives an 87.6% test accuracy.
22 |
23 | ## Reference
24 | If you are interested in our work and want to cite it, please acknowledge the following paper:
25 |
26 | ```
27 | @inproceedings{Yang2018Learning,
28 | author = {Yang, Ze and Luo, Tiange and Wang, Dong and Hu, Zhiqiang and Gao, Jun and Wang, Liwei},
29 | title = {Learning to Navigate for Fine-grained Classification},
30 | booktitle = {ECCV},
31 | year = {2018}
32 | }
33 | ```
34 |
--------------------------------------------------------------------------------
/cirtorch/utils/whiten.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | def whitenapply(X, m, P, dimensions=None):
5 |
6 | if not dimensions:
7 | dimensions = P.shape[0]
8 |
9 | X = np.dot(P[:dimensions, :], X-m)
10 | X = X / (np.linalg.norm(X, ord=2, axis=0, keepdims=True) + 1e-6)
11 |
12 | return X
13 |
14 | def pcawhitenlearn(X):
15 |
16 | N = X.shape[1]
17 |
18 | # Learning PCA w/o annotations
19 | m = X.mean(axis=1, keepdims=True)
20 | Xc = X - m
21 | Xcov = np.dot(Xc, Xc.T)
22 | Xcov = (Xcov + Xcov.T) / (2*N)
23 | eigval, eigvec = np.linalg.eig(Xcov)
24 | order = eigval.argsort()[::-1]
25 | eigval = eigval[order]
26 | eigvec = eigvec[:, order]
27 |
28 | P = np.dot(np.linalg.inv(np.sqrt(np.diag(eigval))), eigvec.T)
29 |
30 | return m, P
31 |
32 | def whitenlearn(X, qidxs, pidxs):
33 |
34 | # Learning Lw w annotations
35 | m = X[:, qidxs].mean(axis=1, keepdims=True)
36 | df = X[:, qidxs] - X[:, pidxs]
37 | S = np.dot(df, df.T) / df.shape[1]
38 | P = np.linalg.inv(cholesky(S))
39 | df = np.dot(P, X-m)
40 | D = np.dot(df, df.T)
41 | eigval, eigvec = np.linalg.eig(D)
42 | order = eigval.argsort()[::-1]
43 | eigval = eigval[order]
44 | eigvec = eigvec[:, order]
45 |
46 | P = np.dot(eigvec.T, P)
47 |
48 | return m, P
49 |
50 | def cholesky(S):
51 | # Cholesky decomposition
52 | # with adding a small value on the diagonal
53 | # until matrix is positive definite
54 | alpha = 0
55 | while True:
56 | try:
57 | L = np.linalg.cholesky(S + alpha*np.eye(*S.shape))
58 | return L
59 | except np.linalg.LinAlgError:
60 | if alpha == 0:
61 | alpha = 1e-10
62 | else:
63 | alpha *= 10
64 | print(">>>> {}::cholesky: Matrix is not positive definite, adding {:.0e} on the diagonal"
65 | .format(os.path.basename(__file__), alpha))
66 |
--------------------------------------------------------------------------------
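
A sanity sketch for the unsupervised path above, pcawhitenlearn followed by whitenapply; random descriptors stand in for real ones:

```python
# Sanity sketch: learn PCA whitening on a random D x N descriptor matrix,
# then project to 64 dimensions; whitenapply L2-normalizes each column.
import numpy as np
from cirtorch.utils.whiten import pcawhitenlearn, whitenapply

X = np.random.randn(128, 1000)            # D x N
m, P = pcawhitenlearn(X)
Xw = whitenapply(X, m, P, dimensions=64)
print(Xw.shape)                  # (64, 1000)
print(np.linalg.norm(Xw[:, 0]))  # ~1.0: columns come back unit-normalized
```
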
/nts/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.autograd import Variable
3 | import torch.utils.data
4 | from torch.nn import DataParallel
5 | from config import BATCH_SIZE, PROPOSAL_NUM, test_model
6 | from core import model, dataset
7 | from core.utils import progress_bar
8 |
9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
10 | if not test_model:
11 | raise NameError('please set the test_model file to choose the checkpoint!')
12 | # read dataset
13 | trainset = dataset.CUB(root='./CUB_200_2011', is_train=True, data_len=None)
14 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
15 | shuffle=True, num_workers=8, drop_last=False)
16 | testset = dataset.CUB(root='./CUB_200_2011', is_train=False, data_len=None)
17 | testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
18 | shuffle=False, num_workers=8, drop_last=False)
19 | # define model
20 | net = model.attention_net(topN=PROPOSAL_NUM)
21 | ckpt = torch.load(test_model)
22 | net.load_state_dict(ckpt['net_state_dict'])
23 | net = net.cuda()
24 | net = DataParallel(net)
25 | criterion = torch.nn.CrossEntropyLoss()
26 |
27 | # evaluate on train set
28 | train_loss = 0
29 | train_correct = 0
30 | total = 0
31 | net.eval()
32 |
33 | for i, data in enumerate(trainloader):
34 | with torch.no_grad():
35 | img, label = data[0].cuda(), data[1].cuda()
36 | batch_size = img.size(0)
37 | _, concat_logits, _, _, _ = net(img)
38 | # calculate loss
39 | concat_loss = criterion(concat_logits, label)
40 | # calculate accuracy
41 | _, concat_predict = torch.max(concat_logits, 1)
42 | total += batch_size
43 | train_correct += torch.sum(concat_predict.data == label.data)
44 | train_loss += concat_loss.item() * batch_size
45 | progress_bar(i, len(trainloader), 'eval on train set')
46 |
47 | train_acc = float(train_correct) / total
48 | train_loss = train_loss / total
49 | print('train set loss: {:.3f} and train set acc: {:.3f} total sample: {}'.format(train_loss, train_acc, total))
50 |
51 |
52 | # evaluate on test set
53 | test_loss = 0
54 | test_correct = 0
55 | total = 0
56 | for i, data in enumerate(testloader):
57 | with torch.no_grad():
58 | img, label = data[0].cuda(), data[1].cuda()
59 | batch_size = img.size(0)
60 | _, concat_logits, _, _, _ = net(img)
61 | # calculate loss
62 | concat_loss = criterion(concat_logits, label)
63 | # calculate accuracy
64 | _, concat_predict = torch.max(concat_logits, 1)
65 | total += batch_size
66 | test_correct += torch.sum(concat_predict.data == label.data)
67 | test_loss += concat_loss.item() * batch_size
68 | progress_bar(i, len(testloader), 'eval on test set')
69 |
70 | test_acc = float(test_correct) / total
71 | test_loss = test_loss / total
72 | print('test set loss: {:.3f} and test set acc: {:.3f} total sample: {}'.format(test_loss, test_acc, total))
73 |
74 | print('finished testing')
75 |
--------------------------------------------------------------------------------
/nts/core/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import sys
4 | import time
5 | import logging
6 |
7 | _, term_width = os.popen('stty size', 'r').read().split()
8 | term_width = int(term_width)
9 |
10 | TOTAL_BAR_LENGTH = 40.
11 | last_time = time.time()
12 | begin_time = last_time
13 |
14 |
15 | def progress_bar(current, total, msg=None):
16 | global last_time, begin_time
17 | if current == 0:
18 | begin_time = time.time() # Reset for new bar.
19 |
20 | cur_len = int(TOTAL_BAR_LENGTH * current / total)
21 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
22 |
23 | sys.stdout.write(' [')
24 | for i in range(cur_len):
25 | sys.stdout.write('=')
26 | sys.stdout.write('>')
27 | for i in range(rest_len):
28 | sys.stdout.write('.')
29 | sys.stdout.write(']')
30 |
31 | cur_time = time.time()
32 | step_time = cur_time - last_time
33 | last_time = cur_time
34 | tot_time = cur_time - begin_time
35 |
36 | L = []
37 | L.append(' Step: %s' % format_time(step_time))
38 | L.append(' | Tot: %s' % format_time(tot_time))
39 | if msg:
40 | L.append(' | ' + msg)
41 |
42 | msg = ''.join(L)
43 | sys.stdout.write(msg)
44 | for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3):
45 | sys.stdout.write(' ')
46 |
47 | # Go back to the center of the bar.
48 | for i in range(term_width - int(TOTAL_BAR_LENGTH / 2)):
49 | sys.stdout.write('\b')
50 | sys.stdout.write(' %d/%d ' % (current + 1, total))
51 |
52 | if current < total - 1:
53 | sys.stdout.write('\r')
54 | else:
55 | sys.stdout.write('\n')
56 | sys.stdout.flush()
57 |
58 |
59 | def format_time(seconds):
60 | days = int(seconds / 3600 / 24)
61 | seconds = seconds - days * 3600 * 24
62 | hours = int(seconds / 3600)
63 | seconds = seconds - hours * 3600
64 | minutes = int(seconds / 60)
65 | seconds = seconds - minutes * 60
66 | secondsf = int(seconds)
67 | seconds = seconds - secondsf
68 | millis = int(seconds * 1000)
69 |
70 | f = ''
71 | i = 1
72 | if days > 0:
73 | f += str(days) + 'D'
74 | i += 1
75 | if hours > 0 and i <= 2:
76 | f += str(hours) + 'h'
77 | i += 1
78 | if minutes > 0 and i <= 2:
79 | f += str(minutes) + 'm'
80 | i += 1
81 | if secondsf > 0 and i <= 2:
82 | f += str(secondsf) + 's'
83 | i += 1
84 | if millis > 0 and i <= 2:
85 | f += str(millis) + 'ms'
86 | i += 1
87 | if f == '':
88 | f = '0ms'
89 | return f
90 |
91 |
92 | def init_log(output_dir):
93 | logging.basicConfig(level=logging.DEBUG,
94 | format='%(asctime)s %(message)s',
95 | datefmt='%Y%m%d-%H:%M:%S',
96 | filename=os.path.join(output_dir, 'log.log'),
97 | filemode='w')
98 | console = logging.StreamHandler()
99 | console.setLevel(logging.INFO)
100 | logging.getLogger('').addHandler(console)
101 | return logging
102 |
103 | if __name__ == '__main__':
104 | pass
105 |
--------------------------------------------------------------------------------
/lshash/storage.py:
--------------------------------------------------------------------------------
1 | # lshash/storage.py
2 | # Copyright 2012 Kay Zhu (a.k.a He Zhu) and contributors (see CONTRIBUTORS.txt)
3 | #
4 | # This module is part of lshash and is released under
5 | # the MIT License: http://www.opensource.org/licenses/mit-license.php
6 |
7 | import json
8 |
9 | try:
10 | import redis
11 | except ImportError:
12 | redis = None
13 |
14 | __all__ = ['storage']
15 |
16 |
17 | def storage(storage_config, index):
18 | """ Given the configuration for storage and the index, return the
19 | configured storage instance.
20 | """
21 | if 'dict' in storage_config:
22 | return InMemoryStorage(storage_config['dict'])
23 | elif 'redis' in storage_config:
24 | storage_config['redis']['db'] = index
25 | return RedisStorage(storage_config['redis'])
26 | else:
27 | raise ValueError("Only in-memory dictionary and Redis are supported.")
28 |
29 |
30 | class BaseStorage(object):
31 | def __init__(self, config):
32 | """ An abstract class used as an adapter for storages. """
33 | raise NotImplementedError
34 |
35 | def keys(self):
36 | """ Returns a list of binary hashes that are used as dict keys. """
37 | raise NotImplementedError
38 |
39 | def set_val(self, key, val):
40 | """ Set `val` at `key`, note that the `val` must be a string. """
41 | raise NotImplementedError
42 |
43 | def get_val(self, key):
44 | """ Return `val` at `key`, note that the `val` must be a string. """
45 | raise NotImplementedError
46 |
47 | def append_val(self, key, val):
48 | """ Append `val` to the list stored at `key`.
49 |
50 | If the key is not yet present in storage, create a list with `val` at
51 | `key`.
52 | """
53 | raise NotImplementedError
54 |
55 | def get_list(self, key):
56 | """ Returns a list stored in storage at `key`.
57 |
58 | This method should return a list of values stored at `key`. `[]` should
59 | be returned if the list is empty or if `key` is not present in storage.
60 | """
61 | raise NotImplementedError
62 |
63 |
64 | class InMemoryStorage(BaseStorage):
65 | def __init__(self, config):
66 | self.name = 'dict'
67 | self.storage = dict()
68 |
69 | def keys(self):
70 | return self.storage.keys()
71 |
72 | def set_val(self, key, val):
73 | self.storage[key] = val
74 |
75 | def get_val(self, key):
76 | return self.storage[key]
77 |
78 | def append_val(self, key, val):
79 | self.storage.setdefault(key, []).append(val)
80 |
81 | def get_list(self, key):
82 | return self.storage.get(key, [])
83 |
84 |
85 | class RedisStorage(BaseStorage):
86 | def __init__(self, config):
87 | if not redis:
88 | raise ImportError("redis-py is required to use Redis as storage.")
89 | self.name = 'redis'
90 | self.storage = redis.StrictRedis(**config)
91 |
92 | def keys(self, pattern="*"):
93 | return self.storage.keys(pattern)
94 |
95 | def set_val(self, key, val):
96 | self.storage.set(key, val)
97 |
98 | def get_val(self, key):
99 | return self.storage.get(key)
100 |
101 | def append_val(self, key, val):
102 | self.storage.rpush(key, json.dumps(val))
103 |
104 | def get_list(self, key):
105 | return self.storage.lrange(key, 0, -1)
106 |
--------------------------------------------------------------------------------
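
A quick usage sketch of the storage factory above with the in-memory backend; the bucket key and stored values are illustrative:

```python
# Usage sketch for storage(): buckets keyed by binary hash strings.
from lshash.storage import storage

s = storage({'dict': None}, 0)       # the config value is unused for 'dict'
s.append_val('0110', ('vec-id-1',))  # creates the bucket on first append
s.append_val('0110', ('vec-id-2',))
print(s.get_list('0110'))            # [('vec-id-1',), ('vec-id-2',)]
print(s.get_list('1001'))            # [] for an absent key
```
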
/nts/core/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.misc
3 | import os
4 | from PIL import Image
5 | from torchvision import transforms
6 | from config import INPUT_SIZE
7 |
8 |
9 | class CUB():
10 | def __init__(self, root, is_train=True, data_len=None):
11 | self.root = root
12 | self.is_train = is_train
13 | img_txt_file = open(os.path.join(self.root, 'images.txt'))
14 | label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
15 | train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
16 | img_name_list = []
17 | for line in img_txt_file:
18 | img_name_list.append(line[:-1].split(' ')[-1])
19 | label_list = []
20 | for line in label_txt_file:
21 | label_list.append(int(line[:-1].split(' ')[-1]) - 1)
22 | train_test_list = []
23 | for line in train_val_file:
24 | train_test_list.append(int(line[:-1].split(' ')[-1]))
25 | train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
26 | test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
27 | if self.is_train:
28 | self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in
29 | train_file_list[:data_len]]
30 | self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
31 | if not self.is_train:
32 | self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in
33 | test_file_list[:data_len]]
34 | self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
35 |
36 | def __getitem__(self, index):
37 | if self.is_train:
38 | img, target = self.train_img[index], self.train_label[index]
39 | if len(img.shape) == 2:
40 | img = np.stack([img] * 3, 2)
41 | img = Image.fromarray(img, mode='RGB')
42 | img = transforms.Resize((600, 600), Image.BILINEAR)(img)
43 | img = transforms.RandomCrop(INPUT_SIZE)(img)
44 | img = transforms.RandomHorizontalFlip()(img)
45 | img = transforms.ToTensor()(img)
46 | img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)
47 |
48 | else:
49 | img, target = self.test_img[index], self.test_label[index]
50 | if len(img.shape) == 2:
51 | img = np.stack([img] * 3, 2)
52 | img = Image.fromarray(img, mode='RGB')
53 | img = transforms.Resize((600, 600), Image.BILINEAR)(img)
54 | img = transforms.CenterCrop(INPUT_SIZE)(img)
55 | img = transforms.ToTensor()(img)
56 | img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)
57 |
58 | return img, target
59 |
60 | def __len__(self):
61 | if self.is_train:
62 | return len(self.train_label)
63 | else:
64 | return len(self.test_label)
65 |
66 |
67 | if __name__ == '__main__':
68 | dataset = CUB(root='./CUB_200_2011')
69 | print(len(dataset.train_img))
70 | print(len(dataset.train_label))
71 | for data in dataset:
72 | print(data[0].size(), data[1])
73 | dataset = CUB(root='./CUB_200_2011', is_train=False)
74 | print(len(dataset.test_img))
75 | print(len(dataset.test_label))
76 | for data in dataset:
77 | print(data[0].size(), data[1])
78 |
--------------------------------------------------------------------------------
/cirtorch/layers/pooling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.parameter import Parameter
4 |
5 | import cirtorch.layers.functional as LF
6 | from cirtorch.layers.normalization import L2N
7 |
8 | # --------------------------------------
9 | # Pooling layers
10 | # --------------------------------------
11 |
12 | class MAC(nn.Module):
13 |
14 | def __init__(self):
15 | super(MAC,self).__init__()
16 |
17 | def forward(self, x):
18 | return LF.mac(x)
19 |
20 | def __repr__(self):
21 | return self.__class__.__name__ + '()'
22 |
23 |
24 | class SPoC(nn.Module):
25 |
26 | def __init__(self):
27 | super(SPoC,self).__init__()
28 |
29 | def forward(self, x):
30 | return LF.spoc(x)
31 |
32 | def __repr__(self):
33 | return self.__class__.__name__ + '()'
34 |
35 |
36 | class GeM(nn.Module):
37 |
38 | def __init__(self, p=3, eps=1e-6):
39 | super(GeM,self).__init__()
40 | self.p = Parameter(torch.ones(1)*p)
41 | self.eps = eps
42 |
43 | def forward(self, x):
44 | return LF.gem(x, p=self.p, eps=self.eps)
45 |
46 | def __repr__(self):
47 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'
48 |
49 | class GeMmp(nn.Module):
50 |
51 | def __init__(self, p=3, mp=1, eps=1e-6):
52 | super(GeMmp,self).__init__()
53 | self.p = Parameter(torch.ones(mp)*p)
54 | self.mp = mp
55 | self.eps = eps
56 |
57 | def forward(self, x):
58 | return LF.gem(x, p=self.p.unsqueeze(-1).unsqueeze(-1), eps=self.eps)
59 |
60 | def __repr__(self):
61 | return self.__class__.__name__ + '(' + 'p=' + '[{}]'.format(self.mp) + ', ' + 'eps=' + str(self.eps) + ')'
62 |
63 | class RMAC(nn.Module):
64 |
65 | def __init__(self, L=3, eps=1e-6):
66 | super(RMAC,self).__init__()
67 | self.L = L
68 | self.eps = eps
69 |
70 | def forward(self, x):
71 | return LF.rmac(x, L=self.L, eps=self.eps)
72 |
73 | def __repr__(self):
74 | return self.__class__.__name__ + '(' + 'L=' + '{}'.format(self.L) + ')'
75 |
76 |
77 | class Rpool(nn.Module):
78 |
79 | def __init__(self, rpool, whiten=None, L=3, eps=1e-6):
80 | super(Rpool,self).__init__()
81 | self.rpool = rpool
82 | self.L = L
83 | self.whiten = whiten
84 | self.norm = L2N()
85 | self.eps = eps
86 |
87 | def forward(self, x, aggregate=True):
88 | # features -> roipool
89 | o = LF.roipool(x, self.rpool, self.L, self.eps) # size: #im, #reg, D, 1, 1
90 |
91 | # concatenate regions from all images in the batch
92 | s = o.size()
93 | o = o.view(s[0]*s[1], s[2], s[3], s[4]) # size: #im x #reg, D, 1, 1
94 |
95 | # rvecs -> norm
96 | o = self.norm(o)
97 |
98 | # rvecs -> whiten -> norm
99 | if self.whiten is not None:
100 | o = self.norm(self.whiten(o.squeeze(-1).squeeze(-1)))
101 |
102 | # reshape back to regions per image
103 | o = o.view(s[0], s[1], s[2], s[3], s[4]) # size: #im, #reg, D, 1, 1
104 |
105 | # aggregate regions into a single global vector per image
106 | if aggregate:
107 | # rvecs -> sumpool -> norm
108 | o = self.norm(o.sum(1, keepdim=False)) # size: #im, D, 1, 1
109 |
110 | return o
111 |
112 | def __repr__(self):
113 | return super(Rpool, self).__repr__() + '(' + 'L=' + '{}'.format(self.L) + ')'
--------------------------------------------------------------------------------
/cirtorch/datasets/genericdataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 |
4 | import torch
5 | import torch.utils.data as data
6 |
7 | from cirtorch.datasets.datahelpers import default_loader, imresize
8 |
9 |
10 | class ImagesFromList(data.Dataset):
11 | """A generic data loader that loads images from a list
12 | (Based on ImageFolder from pytorch)
13 | Args:
14 | root (string): Root directory path.
15 | images (list): Relative image paths as strings.
16 | imsize (int, Default: None): Defines the maximum size of longer image side
17 | bbxs (list): List of (x1,y1,x2,y2) tuples to crop the query images
18 | transform (callable, optional): A function/transform that takes in a PIL image
19 | and returns a transformed version. E.g., ``transforms.RandomCrop``
20 | loader (callable, optional): A function to load an image given its path.
21 | Attributes:
22 | images_fn (list): List of full image filename
23 | """
24 |
25 | def __init__(self, root, images, imsize=None, bbxs=None, transform=None, loader=default_loader):
26 |
27 | images_fn = [os.path.join(root,images[i]) for i in range(len(images))]
28 |
29 | if len(images_fn) == 0:
30 | raise(RuntimeError("Dataset contains 0 images!"))
31 |
32 | self.root = root
33 | self.images = images
34 | self.imsize = imsize
35 | self.images_fn = images_fn
36 | self.bbxs = bbxs
37 | self.transform = transform
38 | self.loader = loader
39 |
40 | def __getitem__(self, index):
41 | """
42 | Args:
43 | index (int): Index
44 | Returns:
45 | image (PIL): Loaded image
46 | """
47 | path = self.images_fn[index]
48 | img = self.loader(path)
49 | imfullsize = max(img.size)
50 |
51 | if self.bbxs is not None:
52 | img = img.crop(self.bbxs[index])
53 |
54 | if self.imsize is not None:
55 | if self.bbxs is not None:
56 | img = imresize(img, self.imsize * max(img.size) / imfullsize)
57 | else:
58 | img = imresize(img, self.imsize)
59 |
60 | if self.transform is not None:
61 | img = self.transform(img)
62 |
63 | return img, path
64 |
65 | def __len__(self):
66 | return len(self.images_fn)
67 |
68 | def __repr__(self):
69 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
70 | fmt_str += ' Number of images: {}\n'.format(self.__len__())
71 | fmt_str += ' Root Location: {}\n'.format(self.root)
72 | tmp = ' Transforms (if any): '
73 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
74 | return fmt_str
75 |
76 | class ImagesFromDataList(data.Dataset):
77 | """A generic data loader that loads images given as an array of pytorch tensors
78 | (Based on ImageFolder from pytorch)
79 | Args:
80 | images (list): Images as tensors.
81 | transform (callable, optional): A function/transform that takes an image tensor
82 | and returns a transformed version. E.g., ``normalize`` with mean and std
83 | """
84 |
85 | def __init__(self, images, transform=None):
86 |
87 | if len(images) == 0:
88 | raise(RuntimeError("Dataset contains 0 images!"))
89 |
90 | self.images = images
91 | self.transform = transform
92 |
93 | def __getitem__(self, index):
94 | """
95 | Args:
96 | index (int): Index
97 | Returns:
98 | image (Tensor): Loaded image
99 | """
100 | img = self.images[index]
101 | if self.transform is not None:
102 | img = self.transform(img)
103 |
104 | if len(img.size()):
105 | img = img.unsqueeze(0)
106 |
107 | return img
108 |
109 | def __len__(self):
110 | return len(self.images)
111 |
112 | def __repr__(self):
113 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
114 | fmt_str += ' Number of images: {}\n'.format(self.__len__())
115 | tmp = ' Transforms (if any): '
116 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
117 | return fmt_str
118 |
--------------------------------------------------------------------------------
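
A usage sketch for ImagesFromList above; the root and file names are hypothetical and must exist on disk:

```python
# Usage sketch for ImagesFromList (paths are hypothetical).
from torchvision import transforms
from cirtorch.datasets.genericdataset import ImagesFromList

dataset = ImagesFromList(root='/data/images',
                         images=['a.jpg', 'b.jpg'],
                         imsize=1024,  # longer side resized to 1024
                         transform=transforms.ToTensor())
img, path = dataset[0]
print(img.shape, path)
```
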
/nts/core/anchors.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from config import INPUT_SIZE
3 |
4 | _default_anchors_setting = (
5 | dict(layer='p3', stride=32, size=48, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
6 | dict(layer='p4', stride=64, size=96, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
7 | dict(layer='p5', stride=128, size=192, scale=[1, 2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
8 | )
9 |
10 |
11 | def generate_default_anchor_maps(anchors_setting=None, input_shape=INPUT_SIZE):
12 | """
13 | generate default anchor
14 |
15 | :param anchors_setting: all anchor settings
16 | :param input_shape: shape of input images, e.g. (h, w)
17 | :return: center_anchors: # anchors * 4 (oy, ox, h, w)
18 | edge_anchors: # anchors * 4 (y0, x0, y1, x1)
19 | anchor_area: # anchors * 1 (area)
20 | """
21 | if anchors_setting is None:
22 | anchors_setting = _default_anchors_setting
23 |
24 | center_anchors = np.zeros((0, 4), dtype=np.float32)
25 | edge_anchors = np.zeros((0, 4), dtype=np.float32)
26 | anchor_areas = np.zeros((0,), dtype=np.float32)
27 | input_shape = np.array(input_shape, dtype=int)
28 |
29 | for anchor_info in anchors_setting:
30 |
31 | stride = anchor_info['stride']
32 | size = anchor_info['size']
33 | scales = anchor_info['scale']
34 | aspect_ratios = anchor_info['aspect_ratio']
35 |
36 | output_map_shape = np.ceil(input_shape.astype(np.float32) / stride)
37 | output_map_shape = output_map_shape.astype(int)
38 | output_shape = tuple(output_map_shape) + (4,)
39 | ostart = stride / 2.
40 | oy = np.arange(ostart, ostart + stride * output_shape[0], stride)
41 | oy = oy.reshape(output_shape[0], 1)
42 | ox = np.arange(ostart, ostart + stride * output_shape[1], stride)
43 | ox = ox.reshape(1, output_shape[1])
44 | center_anchor_map_template = np.zeros(output_shape, dtype=np.float32)
45 | center_anchor_map_template[:, :, 0] = oy
46 | center_anchor_map_template[:, :, 1] = ox
47 | for scale in scales:
48 | for aspect_ratio in aspect_ratios:
49 | center_anchor_map = center_anchor_map_template.copy()
50 | center_anchor_map[:, :, 2] = size * scale / float(aspect_ratio) ** 0.5
51 | center_anchor_map[:, :, 3] = size * scale * float(aspect_ratio) ** 0.5
52 |
53 | edge_anchor_map = np.concatenate((center_anchor_map[..., :2] - center_anchor_map[..., 2:4] / 2.,
54 | center_anchor_map[..., :2] + center_anchor_map[..., 2:4] / 2.),
55 | axis=-1)
56 | anchor_area_map = center_anchor_map[..., 2] * center_anchor_map[..., 3]
57 | center_anchors = np.concatenate((center_anchors, center_anchor_map.reshape(-1, 4)))
58 | edge_anchors = np.concatenate((edge_anchors, edge_anchor_map.reshape(-1, 4)))
59 | anchor_areas = np.concatenate((anchor_areas, anchor_area_map.reshape(-1)))
60 |
61 | return center_anchors, edge_anchors, anchor_areas
62 |
63 |
64 | def hard_nms(cdds, topn=10, iou_thresh=0.25):
65 | if not (type(cdds).__module__ == 'numpy' and len(cdds.shape) == 2 and cdds.shape[1] >= 5):
66 | raise TypeError('edge_box_map should be N * 5+ ndarray')
67 |
68 | cdds = cdds.copy()
69 | indices = np.argsort(cdds[:, 0])
70 | cdds = cdds[indices]
71 | cdd_results = []
72 |
73 | res = cdds
74 |
75 | while res.any():
76 | cdd = res[-1]
77 | cdd_results.append(cdd)
78 | if len(cdd_results) == topn:
79 | return np.array(cdd_results)
80 | res = res[:-1]
81 |
82 | start_max = np.maximum(res[:, 1:3], cdd[1:3])
83 | end_min = np.minimum(res[:, 3:5], cdd[3:5])
84 | lengths = end_min - start_max
85 | intersec_map = lengths[:, 0] * lengths[:, 1]
86 | intersec_map[np.logical_or(lengths[:, 0] < 0, lengths[:, 1] < 0)] = 0
87 | iou_map_cur = intersec_map / ((res[:, 3] - res[:, 1]) * (res[:, 4] - res[:, 2]) + (cdd[3] - cdd[1]) * (
88 | cdd[4] - cdd[2]) - intersec_map)
89 | res = res[iou_map_cur < iou_thresh]
90 |
91 | return np.array(cdd_results)
92 |
93 |
94 | if __name__ == '__main__':
95 | a = hard_nms(np.array([
96 | [0.4, 1, 10, 12, 20],
97 | [0.5, 1, 11, 11, 20],
98 | [0.55, 20, 30, 40, 50]
99 | ]), topn=100, iou_thresh=0.4)
100 | print(a)
101 |
--------------------------------------------------------------------------------
/nts/core/model.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from core import resnet
6 | import numpy as np
7 | from core.anchors import generate_default_anchor_maps, hard_nms
8 | from config import CAT_NUM, PROPOSAL_NUM
9 |
10 |
11 | class ProposalNet(nn.Module):
12 | def __init__(self):
13 | super(ProposalNet, self).__init__()
14 | self.down1 = nn.Conv2d(2048, 128, 3, 1, 1)
15 | self.down2 = nn.Conv2d(128, 128, 3, 2, 1)
16 | self.down3 = nn.Conv2d(128, 128, 3, 2, 1)
17 | self.ReLU = nn.ReLU()
18 | self.tidy1 = nn.Conv2d(128, 6, 1, 1, 0)
19 | self.tidy2 = nn.Conv2d(128, 6, 1, 1, 0)
20 | self.tidy3 = nn.Conv2d(128, 9, 1, 1, 0)
21 |
22 | def forward(self, x):
23 | batch_size = x.size(0)
24 | d1 = self.ReLU(self.down1(x))
25 | d2 = self.ReLU(self.down2(d1))
26 | d3 = self.ReLU(self.down3(d2))
27 | t1 = self.tidy1(d1).view(batch_size, -1)
28 | t2 = self.tidy2(d2).view(batch_size, -1)
29 | t3 = self.tidy3(d3).view(batch_size, -1)
30 | return torch.cat((t1, t2, t3), dim=1)
31 |
32 |
33 | class attention_net(nn.Module):
34 | def __init__(self, topN=4):
35 | super(attention_net, self).__init__()
36 | self.pretrained_model = resnet.resnet50(pretrained=True)
37 | self.pretrained_model.avgpool = nn.AdaptiveAvgPool2d(1)
38 | self.pretrained_model.fc = nn.Linear(512 * 4, 200)
39 | self.proposal_net = ProposalNet()
40 | self.topN = topN
41 | self.concat_net = nn.Linear(2048 * (CAT_NUM + 1), 200)
42 | self.partcls_net = nn.Linear(512 * 4, 200)
43 | _, edge_anchors, _ = generate_default_anchor_maps()
44 | self.pad_side = 224
45 | self.edge_anchors = (edge_anchors + 224).astype(int)
46 |
47 | def forward(self, x):
48 | resnet_out, rpn_feature, feature = self.pretrained_model(x)
49 | x_pad = F.pad(x, (self.pad_side, self.pad_side, self.pad_side, self.pad_side), mode='constant', value=0)
50 | batch = x.size(0)
51 | # we will reshape rpn to shape: batch * nb_anchor
52 | rpn_score = self.proposal_net(rpn_feature.detach())
53 | all_cdds = [
54 | np.concatenate((x.reshape(-1, 1), self.edge_anchors.copy(), np.arange(0, len(x)).reshape(-1, 1)), axis=1)
55 | for x in rpn_score.data.cpu().numpy()]
56 | top_n_cdds = [hard_nms(x, topn=self.topN, iou_thresh=0.25) for x in all_cdds]
57 | top_n_cdds = np.array(top_n_cdds)
58 | top_n_index = top_n_cdds[:, :, -1].astype(int)
59 | top_n_index = torch.from_numpy(top_n_index).cuda()
60 | top_n_prob = torch.gather(rpn_score, dim=1, index=top_n_index)
61 | part_imgs = torch.zeros([batch, self.topN, 3, 224, 224]).cuda()
62 | for i in range(batch):
63 | for j in range(self.topN):
64 | [y0, x0, y1, x1] = top_n_cdds[i][j, 1:5].astype(int)
65 | part_imgs[i:i + 1, j] = F.interpolate(x_pad[i:i + 1, :, y0:y1, x0:x1], size=(224, 224), mode='bilinear',
66 | align_corners=True)
67 | part_imgs = part_imgs.view(batch * self.topN, 3, 224, 224)
68 | _, _, part_features = self.pretrained_model(part_imgs.detach())
69 | part_feature = part_features.view(batch, self.topN, -1)
70 | part_feature = part_feature[:, :CAT_NUM, ...].contiguous()
71 | part_feature = part_feature.view(batch, -1)
72 | # concat_logits have the shape: B*200
73 | concat_out = torch.cat([part_feature, feature], dim=1)
74 | concat_logits = self.concat_net(concat_out)
75 | raw_logits = resnet_out
76 | # part_logits have the shape: B*N*200
77 | part_logits = self.partcls_net(part_features).view(batch, self.topN, -1)
78 | return [raw_logits, concat_logits, part_logits, top_n_index, top_n_prob]
79 |
80 |
81 | def list_loss(logits, targets):
82 | temp = F.log_softmax(logits, -1)
83 | loss = [-temp[i][targets[i].item()] for i in range(logits.size(0))]
84 | return torch.stack(loss)
85 |
86 |
87 | def ranking_loss(score, targets, proposal_num=PROPOSAL_NUM):
88 | loss = Variable(torch.zeros(1).cuda())
89 | batch_size = score.size(0)
90 | for i in range(proposal_num):
91 | targets_p = (targets > targets[:, i].unsqueeze(1)).type(torch.cuda.FloatTensor)
92 | pivot = score[:, i].unsqueeze(1)
93 | loss_p = (1 - pivot + score) * targets_p
94 | loss_p = torch.sum(F.relu(loss_p))
95 | loss += loss_p
96 | return loss / batch_size
97 |
--------------------------------------------------------------------------------
/cirtorch/utils/evaluate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def compute_ap(ranks, nres):
4 | """
5 | Computes average precision for given ranked indexes.
6 |
7 | Arguments
8 | ---------
9 | ranks : zero-based ranks of positive images
10 | nres : number of positive images
11 |
12 | Returns
13 | -------
14 | ap : average precision
15 | """
16 |
17 | # number of images ranked by the system
18 | nimgranks = len(ranks)
19 |
20 | # accumulate trapezoids in PR-plot
21 | ap = 0
22 |
23 | recall_step = 1. / nres
24 |
25 | for j in np.arange(nimgranks):
26 | rank = ranks[j]
27 |
28 | if rank == 0:
29 | precision_0 = 1.
30 | else:
31 | precision_0 = float(j) / rank
32 |
33 | precision_1 = float(j + 1) / (rank + 1)
34 |
35 | ap += (precision_0 + precision_1) * recall_step / 2.
36 |
37 | return ap
38 |
39 | def compute_map(ranks, gnd, kappas=[]):
40 | """
41 | Computes the mAP for a given set of returned results.
42 |
43 | Usage:
44 | map = compute_map (ranks, gnd)
45 | computes mean average precision (map) only
46 |
47 | map, aps, pr, prs = compute_map (ranks, gnd, kappas)
48 | computes mean average precision (map), average precision (aps) for each query
49 | computes mean precision at kappas (pr), precision at kappas (prs) for each query
50 |
51 | Notes:
52 | 1) ranks starts from 0, ranks.shape = db_size X #queries
53 | 2) The junk results (e.g., the query itself) should be declared in the gnd struct array
54 | 3) If there are no positive images for some query, that query is excluded from the evaluation
55 | """
56 |
57 | map = 0.
58 | nq = len(gnd) # number of queries
59 | aps = np.zeros(nq)
60 | pr = np.zeros(len(kappas))
61 | prs = np.zeros((nq, len(kappas)))
62 | nempty = 0
63 |
64 | for i in np.arange(nq):
65 | qgnd = np.array(gnd[i]['ok'])
66 |
67 | # no positive images, skip from the average
68 | if qgnd.shape[0] == 0:
69 | aps[i] = float('nan')
70 | prs[i, :] = float('nan')
71 | nempty += 1
72 | continue
73 |
74 | try:
75 | qgndj = np.array(gnd[i]['junk'])
76 | except KeyError:
77 | qgndj = np.empty(0)
78 |
79 | # sorted positions of positive and junk images (0 based)
80 | pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)]
81 | junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)]
82 |
83 | k = 0
84 | ij = 0
85 | if len(junk):
86 | # decrease positions of positives based on the number of
87 | # junk images appearing before them
88 | ip = 0
89 | while (ip < len(pos)):
90 | while (ij < len(junk) and pos[ip] > junk[ij]):
91 | k += 1
92 | ij += 1
93 | pos[ip] = pos[ip] - k
94 | ip += 1
95 |
96 | # compute ap
97 | ap = compute_ap(pos, len(qgnd))
98 | map = map + ap
99 | aps[i] = ap
100 |
101 | # compute precision @ k
102 | pos += 1 # get it to 1-based
103 | for j in np.arange(len(kappas)):
104 | kq = min(max(pos), kappas[j])
105 | prs[i, j] = (pos <= kq).sum() / kq
106 | pr = pr + prs[i, :]
107 |
108 | map = map / (nq - nempty)
109 | pr = pr / (nq - nempty)
110 |
111 | return map, aps, pr, prs
112 |
113 |
114 | def compute_map_and_print(dataset, ranks, gnd, kappas=[1, 5, 10]):
115 |
116 | # old evaluation protocol
117 | if dataset.startswith('oxford5k') or dataset.startswith('paris6k'):
118 | map, aps, _, _ = compute_map(ranks, gnd)
119 | print('>> {}: mAP {:.2f}'.format(dataset, np.around(map*100, decimals=2)))
120 |
121 | # new evaluation protocol
122 | elif dataset.startswith('roxford5k') or dataset.startswith('rparis6k'):
123 |
124 | gnd_t = []
125 | for i in range(len(gnd)):
126 | g = {}
127 | g['ok'] = np.concatenate([gnd[i]['easy']])
128 | g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['hard']])
129 | gnd_t.append(g)
130 | mapE, apsE, mprE, prsE = compute_map(ranks, gnd_t, kappas)
131 |
132 | gnd_t = []
133 | for i in range(len(gnd)):
134 | g = {}
135 | g['ok'] = np.concatenate([gnd[i]['easy'], gnd[i]['hard']])
136 | g['junk'] = np.concatenate([gnd[i]['junk']])
137 | gnd_t.append(g)
138 | mapM, apsM, mprM, prsM = compute_map(ranks, gnd_t, kappas)
139 |
140 | gnd_t = []
141 | for i in range(len(gnd)):
142 | g = {}
143 | g['ok'] = np.concatenate([gnd[i]['hard']])
144 | g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['easy']])
145 | gnd_t.append(g)
146 | mapH, apsH, mprH, prsH = compute_map(ranks, gnd_t, kappas)
147 |
148 | print('>> {}: mAP E: {}, M: {}, H: {}'.format(dataset, np.around(mapE*100, decimals=2), np.around(mapM*100, decimals=2), np.around(mapH*100, decimals=2)))
149 | print('>> {}: mP@k{} E: {}, M: {}, H: {}'.format(dataset, kappas, np.around(mprE*100, decimals=2), np.around(mprM*100, decimals=2), np.around(mprH*100, decimals=2)))
--------------------------------------------------------------------------------
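
compute_ap accumulates trapezoids in the PR plot; a tiny worked example with the arithmetic spelled out:

```python
# Worked example for compute_ap: two positives (nres=2) ranked at 0 and 2.
import numpy as np
from cirtorch.utils.evaluate import compute_ap

ap = compute_ap(np.array([0, 2]), 2)
# j=0, rank=0: (1 + 1/1) * 0.5 / 2   = 0.5
# j=1, rank=2: (1/2 + 2/3) * 0.5 / 2 ~= 0.2917
print(ap)  # ~0.7917
```
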
/cirtorch/layers/functional.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pdb
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | # --------------------------------------
8 | # pooling
9 | # --------------------------------------
10 |
11 | def mac(x):
12 | return F.max_pool2d(x, (x.size(-2), x.size(-1)))
13 | # return F.adaptive_max_pool2d(x, (1,1)) # alternative
14 |
15 |
16 | def spoc(x):
17 | return F.avg_pool2d(x, (x.size(-2), x.size(-1)))
18 | # return F.adaptive_avg_pool2d(x, (1,1)) # alternative
19 |
20 |
21 | def gem(x, p=3, eps=1e-6):
22 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
23 | # return F.lp_pool2d(F.threshold(x, eps, eps), p, (x.size(-2), x.size(-1))) # alternative
24 |
25 |
26 | def rmac(x, L=3, eps=1e-6):
27 | ovr = 0.4 # desired overlap of neighboring regions
28 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension
29 |
30 | W = x.size(3)
31 | H = x.size(2)
32 |
33 | w = min(W, H)
34 | w2 = math.floor(w/2.0 - 1)
35 |
36 | b = (max(H, W)-w)/(steps-1)
37 | (tmp, idx) = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0) # steps(idx) regions for long dimension
38 |
39 | # region overplus per dimension
40 | Wd = 0
41 | Hd = 0
42 | if H < W:
43 | Wd = idx.item() + 1
44 | elif H > W:
45 | Hd = idx.item() + 1
46 |
47 | v = F.max_pool2d(x, (x.size(-2), x.size(-1)))
48 | v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v)
49 |
50 | for l in range(1, L+1):
51 | wl = math.floor(2*w/(l+1))
52 | wl2 = math.floor(wl/2 - 1)
53 |
54 | if l+Wd == 1:
55 | b = 0
56 | else:
57 | b = (W-wl)/(l+Wd-1)
58 | cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b) - wl2 # center coordinates
59 | if l+Hd == 1:
60 | b = 0
61 | else:
62 | b = (H-wl)/(l+Hd-1)
63 | cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b) - wl2 # center coordinates
64 |
65 | for i_ in cenH.tolist():
66 | for j_ in cenW.tolist():
67 | if wl == 0:
68 | continue
69 | R = x[:,:,(int(i_)+torch.Tensor(range(wl)).long()).tolist(),:]
70 | R = R[:,:,:,(int(j_)+torch.Tensor(range(wl)).long()).tolist()]
71 | vt = F.max_pool2d(R, (R.size(-2), R.size(-1)))
72 | vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt)
73 | v += vt
74 |
75 | return v
76 |
77 |
78 | def roipool(x, rpool, L=3, eps=1e-6):
79 | ovr = 0.4 # desired overlap of neighboring regions
80 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension
81 |
82 | W = x.size(3)
83 | H = x.size(2)
84 |
85 | w = min(W, H)
86 | w2 = math.floor(w/2.0 - 1)
87 |
88 | b = (max(H, W)-w)/(steps-1)
89 | _, idx = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0) # steps(idx) regions for long dimension
90 |
91 | # region overplus per dimension
92 | Wd = 0
93 | Hd = 0
94 | if H < W:
95 | Wd = idx.item() + 1
96 | elif H > W:
97 | Hd = idx.item() + 1
98 |
99 | vecs = []
100 | vecs.append(rpool(x).unsqueeze(1))
101 |
102 | for l in range(1, L+1):
103 | wl = math.floor(2*w/(l+1))
104 | wl2 = math.floor(wl/2 - 1)
105 |
106 | if l+Wd == 1:
107 | b = 0
108 | else:
109 | b = (W-wl)/(l+Wd-1)
110 | cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b).int() - wl2 # center coordinates
111 | if l+Hd == 1:
112 | b = 0
113 | else:
114 | b = (H-wl)/(l+Hd-1)
115 | cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b).int() - wl2 # center coordinates
116 |
117 | for i_ in cenH.tolist():
118 | for j_ in cenW.tolist():
119 | if wl == 0:
120 | continue
121 | vecs.append(rpool(x.narrow(2,i_,wl).narrow(3,j_,wl)).unsqueeze(1))
122 |
123 | return torch.cat(vecs, dim=1)
124 |
125 |
126 | # --------------------------------------
127 | # normalization
128 | # --------------------------------------
129 |
130 | def l2n(x, eps=1e-6):
131 | return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x)
132 |
133 | def powerlaw(x, eps=1e-6):
134 | x = x + eps
135 | return x.abs().sqrt().mul(x.sign())
136 |
137 | # --------------------------------------
138 | # loss
139 | # --------------------------------------
140 |
141 | def contrastive_loss(x, label, margin=0.7, eps=1e-6):
142 | # x is D x N
143 | dim = x.size(0) # D
144 |     nq = torch.sum(label.data==-1).item() # number of tuples
145 | S = x.size(1) // nq # number of images per tuple including query: 1+1+n
146 |
147 | x1 = x[:, ::S].permute(1,0).repeat(1,S-1).view((S-1)*nq,dim).permute(1,0)
148 | idx = [i for i in range(len(label)) if label.data[i] != -1]
149 | x2 = x[:, idx]
150 | lbl = label[label!=-1]
151 |
152 | dif = x1 - x2
153 | D = torch.pow(dif+eps, 2).sum(dim=0).sqrt()
154 |
155 | y = 0.5*lbl*torch.pow(D,2) + 0.5*(1-lbl)*torch.pow(torch.clamp(margin-D, min=0),2)
156 | y = torch.sum(y)
157 | return y
158 |
159 | def triplet_loss(x, label, margin=0.1):
160 | # x is D x N
161 | dim = x.size(0) # D
162 | nq = torch.sum(label.data==-1).item() # number of tuples
163 | S = x.size(1) // nq # number of images per tuple including query: 1+1+n
164 |
165 | xa = x[:, label.data==-1].permute(1,0).repeat(1,S-2).view((S-2)*nq,dim).permute(1,0)
166 | xp = x[:, label.data==1].permute(1,0).repeat(1,S-2).view((S-2)*nq,dim).permute(1,0)
167 | xn = x[:, label.data==0]
168 |
169 | dist_pos = torch.sum(torch.pow(xa - xp, 2), dim=0)
170 | dist_neg = torch.sum(torch.pow(xa - xn, 2), dim=0)
171 |
172 | return torch.sum(torch.clamp(dist_pos - dist_neg + margin, min=0))
173 |
--------------------------------------------------------------------------------
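A quick sanity check of the pooling helpers above (a sketch, not part of the repo; it assumes this file is importable as cirtorch.layers.functional). GeM raises activations to the power p, average-pools, and takes the p-th root, so p=1 recovers average pooling and large p approaches max pooling:

import torch
from cirtorch.layers.functional import gem, l2n

x = torch.rand(2, 512, 7, 7)            # (batch, channels, H, W) feature map
v = gem(x, p=3)                         # GeM pooling -> (2, 512, 1, 1)
desc = l2n(v.squeeze(-1).squeeze(-1))   # L2-normalized (2, 512) descriptors
print(v.shape, desc.shape)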
/utils/retrieval_feature.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # /usr/bin/env python
3 |
4 | '''
5 | Author: yinhao
6 | Email: yinhao_x@163.com
7 | Wechat: xss_yinhao
8 | Github: http://github.com/yinhaoxs
9 | date: 2019-11-23 18:26
10 | desc:
11 | '''
12 |
13 | import os
14 | from PIL import Image
15 | from lshash.lshash import LSHash
16 | import torch
17 | from torchvision import transforms
18 | from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
19 |
20 | # setting up the visible GPU
21 | os.environ['CUDA_VISIBLE_DEVICES'] = "0"
22 |
23 |
24 | class ImageProcess():
25 | def __init__(self, img_dir):
26 | self.img_dir = img_dir
27 |
28 | def process(self):
29 | imgs = list()
30 | for root, dirs, files in os.walk(self.img_dir):
31 | for file in files:
32 |                 img_path = os.path.join(root, file)
33 | try:
34 | image = Image.open(img_path)
35 | if max(image.size) / min(image.size) < 5:
36 | imgs.append(img_path)
37 | else:
38 | continue
39 | except:
40 |                     print("cannot open image: {}".format(img_path))
41 |
42 | return imgs
43 |
44 |
45 | class AntiFraudFeatureDataset():
46 | def __init__(self, img_dir, network, feature_path='', index_path=''):
47 | self.img_dir = img_dir
48 | self.network = network
49 | self.feature_path = feature_path
50 | self.index_path = index_path
51 |
52 | def constructfeature(self, hash_size, input_dim, num_hashtables):
53 | multiscale = '[1]'
54 | print(">> Loading network:\n>>>> '{}'".format(self.network))
55 | # state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
56 | state = torch.load(self.network)
57 | # parsing net params from meta
58 | # architecture, pooling, mean, std required
59 |         # the rest have default values, in case they don't exist
60 | net_params = {}
61 | net_params['architecture'] = state['meta']['architecture']
62 | net_params['pooling'] = state['meta']['pooling']
63 | net_params['local_whitening'] = state['meta'].get('local_whitening', False)
64 | net_params['regional'] = state['meta'].get('regional', False)
65 | net_params['whitening'] = state['meta'].get('whitening', False)
66 | net_params['mean'] = state['meta']['mean']
67 | net_params['std'] = state['meta']['std']
68 | net_params['pretrained'] = False
69 | # network initialization
70 | net = init_network(net_params)
71 | net.load_state_dict(state['state_dict'])
72 | print(">>>> loaded network: ")
73 | print(net.meta_repr())
74 | # setting up the multi-scale parameters
75 | ms = list(eval(multiscale))
76 | print(">>>> Evaluating scales: {}".format(ms))
77 | # moving network to gpu and eval mode
78 | if torch.cuda.is_available():
79 | net.cuda()
80 | net.eval()
81 |
82 | # set up the transform
83 | normalize = transforms.Normalize(
84 | mean=net.meta['mean'],
85 | std=net.meta['std']
86 | )
87 | transform = transforms.Compose([
88 | transforms.ToTensor(),
89 | normalize
90 | ])
91 |
92 | # extract database and query vectors
93 | print('>> database images...')
94 | images = ImageProcess(self.img_dir).process()
95 | vecs, img_paths = extract_vectors(net, images, 1024, transform, ms=ms)
96 | feature_dict = dict(zip(img_paths, list(vecs.detach().cpu().numpy().T)))
97 | # index
98 | lsh = LSHash(hash_size=int(hash_size), input_dim=int(input_dim), num_hashtables=int(num_hashtables))
99 | for img_path, vec in feature_dict.items():
100 | lsh.index(vec.flatten(), extra_data=img_path)
101 |
102 |         # ## save the index model to disk
103 | # with open(self.feature_path, "wb") as f:
104 | # pickle.dump(feature_dict, f)
105 | # with open(self.index_path, "wb") as f:
106 | # pickle.dump(lsh, f)
107 |
108 |         print("feature extraction done")
109 | return feature_dict, lsh
110 |
111 | def test_feature(self):
112 | multiscale = '[1]'
113 | print(">> Loading network:\n>>>> '{}'".format(self.network))
114 | # state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
115 | state = torch.load(self.network)
116 | # parsing net params from meta
117 | # architecture, pooling, mean, std required
118 |         # the rest have default values, in case they don't exist
119 | net_params = {}
120 | net_params['architecture'] = state['meta']['architecture']
121 | net_params['pooling'] = state['meta']['pooling']
122 | net_params['local_whitening'] = state['meta'].get('local_whitening', False)
123 | net_params['regional'] = state['meta'].get('regional', False)
124 | net_params['whitening'] = state['meta'].get('whitening', False)
125 | net_params['mean'] = state['meta']['mean']
126 | net_params['std'] = state['meta']['std']
127 | net_params['pretrained'] = False
128 | # network initialization
129 | net = init_network(net_params)
130 | net.load_state_dict(state['state_dict'])
131 | print(">>>> loaded network: ")
132 | print(net.meta_repr())
133 | # setting up the multi-scale parameters
134 | ms = list(eval(multiscale))
135 | print(">>>> Evaluating scales: {}".format(ms))
136 | # moving network to gpu and eval mode
137 | if torch.cuda.is_available():
138 | net.cuda()
139 | net.eval()
140 |
141 | # set up the transform
142 | normalize = transforms.Normalize(
143 | mean=net.meta['mean'],
144 | std=net.meta['std']
145 | )
146 | transform = transforms.Compose([
147 | transforms.ToTensor(),
148 | normalize
149 | ])
150 |
151 | # extract database and query vectors
152 | print('>> database images...')
153 | images = ImageProcess(self.img_dir).process()
154 | vecs, img_paths = extract_vectors(net, images, 1024, transform, ms=ms)
155 | feature_dict = dict(zip(img_paths, list(vecs.detach().cpu().numpy().T)))
156 | return feature_dict
157 |
158 |
159 | if __name__ == '__main__':
160 | pass
161 |
--------------------------------------------------------------------------------
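A hypothetical driver for the classes above; the checkpoint path and LSH hyper-parameters are placeholders, not values shipped with the repo, and input_dim must match the descriptor length the loaded network produces:

from utils.retrieval_feature import AntiFraudFeatureDataset

affd = AntiFraudFeatureDataset(img_dir='./images', network='./retrieval_model.pth')
feature_dict, lsh = affd.constructfeature(hash_size=10, input_dim=2048, num_hashtables=5)
for img_path, vec in list(feature_dict.items())[:3]:
    print(img_path, vec.shape)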
/cirtorch/examples/test_e2e.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import pickle
5 | import pdb
6 |
7 | import numpy as np
8 |
9 | import torch
10 | from torch.utils.model_zoo import load_url
11 | from torchvision import transforms
12 |
13 | from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
14 | from cirtorch.datasets.testdataset import configdataset
15 | from cirtorch.utils.download import download_train, download_test
16 | from cirtorch.utils.evaluate import compute_map_and_print
17 | from cirtorch.utils.general import get_data_root, htime
18 |
19 | PRETRAINED = {
20 | 'rSfM120k-tl-resnet50-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',
21 | 'rSfM120k-tl-resnet101-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',
22 | 'rSfM120k-tl-resnet152-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet152-gem-w-f39cada.pth',
23 | 'gl18-tl-resnet50-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',
24 | 'gl18-tl-resnet101-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',
25 | 'gl18-tl-resnet152-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet152-gem-w-21278d5.pth',
26 | }
27 |
28 | datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
29 |
30 | parser = argparse.ArgumentParser(description='PyTorch CNN Image Retrieval Testing End-to-End')
31 |
32 | # test options
33 | parser.add_argument('--network', '-n', metavar='NETWORK',
34 | help="network to be evaluated: " +
35 | " | ".join(PRETRAINED.keys()))
36 | parser.add_argument('--datasets', '-d', metavar='DATASETS', default='roxford5k,rparis6k',
37 | help="comma separated list of test datasets: " +
38 | " | ".join(datasets_names) +
39 | " (default: 'roxford5k,rparis6k')")
40 | parser.add_argument('--image-size', '-imsize', default=1024, type=int, metavar='N',
41 | help="maximum size of longer image side used for testing (default: 1024)")
42 | parser.add_argument('--multiscale', '-ms', metavar='MULTISCALE', default='[1]',
43 | help="use multiscale vectors for testing, " +
44 | " examples: '[1]' | '[1, 1/2**(1/2), 1/2]' | '[1, 2**(1/2), 1/2**(1/2)]' (default: '[1]')")
45 |
46 | # GPU ID
47 | parser.add_argument('--gpu-id', '-g', default='0', metavar='N',
48 | help="gpu id used for testing (default: '0')")
49 |
50 | def main():
51 | args = parser.parse_args()
52 |
53 | # check if there are unknown datasets
54 | for dataset in args.datasets.split(','):
55 | if dataset not in datasets_names:
56 | raise ValueError('Unsupported or unknown dataset: {}!'.format(dataset))
57 |
58 | # check if test dataset are downloaded
59 | # and download if they are not
60 | download_train(get_data_root())
61 | download_test(get_data_root())
62 |
63 | # setting up the visible GPU
64 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
65 |
66 | # loading network
67 | # pretrained networks (downloaded automatically)
68 | print(">> Loading network:\n>>>> '{}'".format(args.network))
69 | state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
70 | # state = torch.load(args.network)
71 | # parsing net params from meta
72 | # architecture, pooling, mean, std required
73 |     # the rest have default values, in case they don't exist
74 | net_params = {}
75 | net_params['architecture'] = state['meta']['architecture']
76 | net_params['pooling'] = state['meta']['pooling']
77 | net_params['local_whitening'] = state['meta'].get('local_whitening', False)
78 | net_params['regional'] = state['meta'].get('regional', False)
79 | net_params['whitening'] = state['meta'].get('whitening', False)
80 | net_params['mean'] = state['meta']['mean']
81 | net_params['std'] = state['meta']['std']
82 | net_params['pretrained'] = False
83 | # network initialization
84 | net = init_network(net_params)
85 | net.load_state_dict(state['state_dict'])
86 |
87 | print(">>>> loaded network: ")
88 | print(net.meta_repr())
89 |
90 | # setting up the multi-scale parameters
91 | ms = list(eval(args.multiscale))
92 | print(">>>> Evaluating scales: {}".format(ms))
93 |
94 | # moving network to gpu and eval mode
95 | net.cuda()
96 | net.eval()
97 |
98 | # set up the transform
99 | normalize = transforms.Normalize(
100 | mean=net.meta['mean'],
101 | std=net.meta['std']
102 | )
103 | transform = transforms.Compose([
104 | transforms.ToTensor(),
105 | normalize
106 | ])
107 |
108 | # evaluate on test datasets
109 | datasets = args.datasets.split(',')
110 | for dataset in datasets:
111 | start = time.time()
112 |
113 | print('>> {}: Extracting...'.format(dataset))
114 |
115 | # prepare config structure for the test dataset
116 | cfg = configdataset(dataset, os.path.join(get_data_root(), 'test'))
117 | images = [cfg['im_fname'](cfg,i) for i in range(cfg['n'])]
118 | qimages = [cfg['qim_fname'](cfg,i) for i in range(cfg['nq'])]
119 | try:
120 | bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]
121 | except:
122 | bbxs = None # for holidaysmanrot and copydays
123 |
124 | # extract database and query vectors
125 | print('>> {}: database images...'.format(dataset))
126 | vecs = extract_vectors(net, images, args.image_size, transform, ms=ms)
127 | print('>> {}: query images...'.format(dataset))
128 | qvecs = extract_vectors(net, qimages, args.image_size, transform, bbxs=bbxs, ms=ms)
129 |
130 | print('>> {}: Evaluating...'.format(dataset))
131 |
132 | # convert to numpy
133 | vecs = vecs.numpy()
134 | qvecs = qvecs.numpy()
135 |
136 | # search, rank, and print
137 | scores = np.dot(vecs.T, qvecs)
138 | ranks = np.argsort(-scores, axis=0)
139 | compute_map_and_print(dataset, ranks, cfg['gnd'])
140 |
141 | print('>> {}: elapsed time: {}'.format(dataset, htime(time.time()-start)))
142 |
143 |
144 | if __name__ == '__main__':
145 | main()
146 |
--------------------------------------------------------------------------------
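The search step above relies on the descriptors being L2-normalized, so the plain dot product is cosine similarity and argsort on the negated scores ranks the database per query. A standalone illustration of just that step:

import numpy as np

vecs = np.random.randn(4, 5)                 # D x N database descriptors
vecs /= np.linalg.norm(vecs, axis=0)
qvecs = np.random.randn(4, 2)                # D x Q query descriptors
qvecs /= np.linalg.norm(qvecs, axis=0)
scores = np.dot(vecs.T, qvecs)               # N x Q cosine similarities
ranks = np.argsort(-scores, axis=0)          # ranks[0, q] = best match for query q
print(ranks[:, 0])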
/nts/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.data
3 | from torch.nn import DataParallel
4 | from datetime import datetime
5 | from torch.optim.lr_scheduler import MultiStepLR
6 | from config import BATCH_SIZE, PROPOSAL_NUM, SAVE_FREQ, LR, WD, resume, save_dir
7 | from core import model, dataset
8 | from core.utils import init_log, progress_bar
9 |
10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
11 | start_epoch = 1
12 | save_dir = os.path.join(save_dir, datetime.now().strftime('%Y%m%d_%H%M%S'))
13 | if os.path.exists(save_dir):
14 | raise NameError('model dir exists!')
15 | os.makedirs(save_dir)
16 | logging = init_log(save_dir)
17 | _print = logging.info
18 |
19 | # read dataset
20 | trainset = dataset.CUB(root='./CUB_200_2011', is_train=True, data_len=None)
21 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
22 | shuffle=True, num_workers=8, drop_last=False)
23 | testset = dataset.CUB(root='./CUB_200_2011', is_train=False, data_len=None)
24 | testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
25 | shuffle=False, num_workers=8, drop_last=False)
26 | # define model
27 | net = model.attention_net(topN=PROPOSAL_NUM)
28 | if resume:
29 | ckpt = torch.load(resume)
30 | net.load_state_dict(ckpt['net_state_dict'])
31 | start_epoch = ckpt['epoch'] + 1
32 | criterion = torch.nn.CrossEntropyLoss()
33 |
34 | # define optimizers
35 | raw_parameters = list(net.pretrained_model.parameters())
36 | part_parameters = list(net.proposal_net.parameters())
37 | concat_parameters = list(net.concat_net.parameters())
38 | partcls_parameters = list(net.partcls_net.parameters())
39 |
40 | raw_optimizer = torch.optim.SGD(raw_parameters, lr=LR, momentum=0.9, weight_decay=WD)
41 | concat_optimizer = torch.optim.SGD(concat_parameters, lr=LR, momentum=0.9, weight_decay=WD)
42 | part_optimizer = torch.optim.SGD(part_parameters, lr=LR, momentum=0.9, weight_decay=WD)
43 | partcls_optimizer = torch.optim.SGD(partcls_parameters, lr=LR, momentum=0.9, weight_decay=WD)
44 | schedulers = [MultiStepLR(raw_optimizer, milestones=[60, 100], gamma=0.1),
45 | MultiStepLR(concat_optimizer, milestones=[60, 100], gamma=0.1),
46 | MultiStepLR(part_optimizer, milestones=[60, 100], gamma=0.1),
47 | MultiStepLR(partcls_optimizer, milestones=[60, 100], gamma=0.1)]
48 | net = net.cuda()
49 | net = DataParallel(net)
50 |
51 | for epoch in range(start_epoch, 500):
52 | for scheduler in schedulers:
53 | scheduler.step()
54 |
55 | # begin training
56 | _print('--' * 50)
57 | net.train()
58 | for i, data in enumerate(trainloader):
59 | img, label = data[0].cuda(), data[1].cuda()
60 | batch_size = img.size(0)
61 | raw_optimizer.zero_grad()
62 | part_optimizer.zero_grad()
63 | concat_optimizer.zero_grad()
64 | partcls_optimizer.zero_grad()
65 |
66 | raw_logits, concat_logits, part_logits, _, top_n_prob = net(img)
67 | part_loss = model.list_loss(part_logits.view(batch_size * PROPOSAL_NUM, -1),
68 | label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1)).view(batch_size, PROPOSAL_NUM)
69 |         raw_loss = criterion(raw_logits, label)
70 |         concat_loss = criterion(concat_logits, label)
71 | rank_loss = model.ranking_loss(top_n_prob, part_loss)
72 |         partcls_loss = criterion(part_logits.view(batch_size * PROPOSAL_NUM, -1),
73 |                                  label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1))
74 |
75 | total_loss = raw_loss + rank_loss + concat_loss + partcls_loss
76 | total_loss.backward()
77 | raw_optimizer.step()
78 | part_optimizer.step()
79 | concat_optimizer.step()
80 | partcls_optimizer.step()
81 | progress_bar(i, len(trainloader), 'train')
82 |
83 | if epoch % SAVE_FREQ == 0:
84 | train_loss = 0
85 | train_correct = 0
86 | total = 0
87 | net.eval()
88 | for i, data in enumerate(trainloader):
89 | with torch.no_grad():
90 | img, label = data[0].cuda(), data[1].cuda()
91 | batch_size = img.size(0)
92 | _, concat_logits, _, _, _ = net(img)
93 | # calculate loss
94 |                 concat_loss = criterion(concat_logits, label)
95 | # calculate accuracy
96 | _, concat_predict = torch.max(concat_logits, 1)
97 | total += batch_size
98 | train_correct += torch.sum(concat_predict.data == label.data)
99 | train_loss += concat_loss.item() * batch_size
100 | progress_bar(i, len(trainloader), 'eval train set')
101 |
102 | train_acc = float(train_correct) / total
103 | train_loss = train_loss / total
104 |
105 | _print(
106 | 'epoch:{} - train loss: {:.3f} and train acc: {:.3f} total sample: {}'.format(
107 | epoch,
108 | train_loss,
109 | train_acc,
110 | total))
111 |
112 | # evaluate on test set
113 | test_loss = 0
114 | test_correct = 0
115 | total = 0
116 | for i, data in enumerate(testloader):
117 | with torch.no_grad():
118 | img, label = data[0].cuda(), data[1].cuda()
119 | batch_size = img.size(0)
120 | _, concat_logits, _, _, _ = net(img)
121 | # calculate loss
122 |                 concat_loss = criterion(concat_logits, label)
123 | # calculate accuracy
124 | _, concat_predict = torch.max(concat_logits, 1)
125 | total += batch_size
126 | test_correct += torch.sum(concat_predict.data == label.data)
127 | test_loss += concat_loss.item() * batch_size
128 | progress_bar(i, len(testloader), 'eval test set')
129 |
130 | test_acc = float(test_correct) / total
131 | test_loss = test_loss / total
132 | _print(
133 | 'epoch:{} - test loss: {:.3f} and test acc: {:.3f} total sample: {}'.format(
134 | epoch,
135 | test_loss,
136 | test_acc,
137 | total))
138 |
139 | # save model
140 | net_state_dict = net.module.state_dict()
141 | if not os.path.exists(save_dir):
142 | os.mkdir(save_dir)
143 | torch.save({
144 | 'epoch': epoch,
145 | 'train_loss': train_loss,
146 | 'train_acc': train_acc,
147 | 'test_loss': test_loss,
148 | 'test_acc': test_acc,
149 | 'net_state_dict': net_state_dict},
150 | os.path.join(save_dir, '%03d.ckpt' % epoch))
151 |
152 | print('finished training')
153 |
--------------------------------------------------------------------------------
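One caveat in the loop above: since PyTorch 1.1, MultiStepLR.step() is expected after the epoch's optimizer steps; calling it first, as this script does, shifts the schedule by one epoch. A minimal sketch of the recommended ordering (toy parameter, same milestones):

import torch
from torch.optim.lr_scheduler import MultiStepLR

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=0.001, momentum=0.9)
sched = MultiStepLR(opt, milestones=[60, 100], gamma=0.1)
for epoch in range(3):
    # ... forward/backward and opt.step() for each batch ...
    opt.step()
    sched.step()    # advance the LR schedule after the optimizer
print(opt.param_groups[0]['lr'])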
/utils/retrieval_index.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # /usr/bin/env python
3 |
4 | '''
5 | Author: yinhao
6 | Email: yinhao_x@163.com
7 | Wechat: xss_yinhao
8 | Github: http://github.com/yinhaoxs
9 | date: 2019-11-23 18:27
10 | desc:
11 | '''
12 |
13 | import os
14 | import shutil
15 | import numpy as np
16 | import pandas as pd
17 |
18 |
19 | class EvaluteMap():
20 | def __init__(self, out_similar_dir='', out_similar_file_dir='', all_csv_file='', feature_path='', index_path=''):
21 | self.out_similar_dir = out_similar_dir
22 | self.out_similar_file_dir = out_similar_file_dir
23 | self.all_csv_file = all_csv_file
24 | self.feature_path = feature_path
25 | self.index_path = index_path
26 |
27 |
28 | def get_dict(self, query_no, query_id, simi_no, simi_id, num, score):
29 | new_dict = {
30 | 'index': str(num),
31 | 'id1': str(query_id),
32 | 'id2': str(simi_id),
33 | 'no1': str(query_no),
34 | 'no2': str(simi_no),
35 | 'score': score
36 | }
37 | return new_dict
38 |
39 |
40 | def find_similar_img_gyz(self, feature_dict, lsh, num_results):
41 | for q_path, q_vec in feature_dict.items():
42 | try:
43 | response = lsh.query(q_vec.flatten(), num_results=int(num_results), distance_func="cosine")
44 | query_img_path0 = response[0][0][1]
45 | query_img_path1 = response[1][0][1]
46 | query_img_path2 = response[2][0][1]
47 | # score0 = response[0][1]
48 | # score0 = np.rint(100 * (1 - score0))
49 | print('**********************************************')
50 | print('input img: {}'.format(q_path))
51 | print('query0 img: {}'.format(query_img_path0))
52 | print('query1 img: {}'.format(query_img_path1))
53 | print('query2 img: {}'.format(query_img_path2))
54 | except:
55 | continue
56 |
57 |
58 | def find_similar_img(self, feature_dict, lsh, num_results):
59 | num = 0
60 | result_list = list()
61 | for q_path, q_vec in feature_dict.items():
62 | response = lsh.query(q_vec.flatten(), num_results=int(num_results), distance_func="cosine")
63 | s_path_list, s_vec_list, s_id_list, s_no_list, score_list = list(), list(), list(), list(), list()
64 | q_path = q_path[0]
65 | q_no, q_id = q_path.split("\\")[-2], q_path.split("\\")[-1]
66 | try:
67 | for i in range(int(num_results)):
68 | s_path, s_vec = response[i][0][1], response[i][0][0]
69 | s_path = s_path[0]
70 | s_no, s_id = s_path.split("\\")[-2], s_path.split("\\")[-1]
71 | if str(s_no) != str(q_no):
72 | score = np.rint(100 * (1 - response[i][1]))
73 | s_path_list.append(s_path)
74 | s_vec_list.append(s_vec)
75 | s_id_list.append(s_id)
76 | s_no_list.append(s_no)
77 | score_list.append(score)
78 | else:
79 | continue
80 |
81 | if len(s_path_list) != 0:
82 | index = score_list.index(max(score_list))
83 | s_path, s_vec, s_id, s_no, score = s_path_list[index], s_vec_list[index], s_id_list[index], \
84 | s_no_list[index], score_list[index]
85 | else:
86 | s_path, s_vec, s_id, s_no, score = None, None, None, None, None
87 | except:
88 | s_path, s_vec, s_id, s_no, score = None, None, None, None, None
89 |
90 | try:
91 |                 ## copy the query/match pair into its own numbered folder
92 | num += 1
93 | des_path = os.path.join(self.out_similar_dir, str(num))
94 | if not os.path.exists(des_path):
95 | os.makedirs(des_path)
96 | shutil.copy(q_path, des_path)
97 | os.rename(os.path.join(des_path, q_id), os.path.join(des_path, "query_" + q_no + "_" + q_id))
98 | if s_path != None:
99 | shutil.copy(s_path, des_path)
100 | os.rename(os.path.join(des_path, s_id), os.path.join(des_path, s_no + "_" + s_id))
101 |
102 | new_dict = self.get_dict(q_no, q_id, s_no, s_id, num, score)
103 | result_list.append(new_dict)
104 | except:
105 | continue
106 |
107 | try:
108 | result_s = pd.DataFrame.from_dict(result_list)
109 | result_s.to_csv(self.all_csv_file, encoding="gbk", index=False)
110 | except:
111 |             print("failed to write result csv")
112 |
113 |
114 | def filter_gap_score(self):
115 | for value in range(90, 101):
116 | try:
117 | pd_df = pd.read_csv(self.all_csv_file, encoding="gbk", error_bad_lines=False)
118 | pd_tmp = pd_df[pd_df["score"] == int(value)]
119 | if not os.path.exists(self.out_similar_file_dir):
120 | os.makedirs(self.out_similar_file_dir)
121 |
122 | try:
123 | results_split_csv = os.path.join(self.out_similar_file_dir + os.sep,
124 | "filter_{}.csv".format(str(value)))
125 | pd_tmp.to_csv(results_split_csv, encoding="gbk", index=False)
126 | except:
127 |                     print("failed to write filtered csv")
128 |
129 | lines = pd_df[pd_df["score"] == int(value)]["index"]
130 | num = 0
131 | for line in lines:
132 | des_path_temp = os.path.join(self.out_similar_file_dir + os.sep, str(value), str(line))
133 | if not os.path.exists(des_path_temp):
134 | os.makedirs(des_path_temp)
135 | pairs_path = os.path.join(self.out_similar_dir + os.sep, str(line))
136 | for img_id in os.listdir(pairs_path):
137 | img_path = os.path.join(pairs_path + os.sep, img_id)
138 | shutil.copy(img_path, des_path_temp)
139 | except:
140 |                 print("failed to process score {}".format(value))
141 |
142 |
143 | def retrieval_images(self, feature_dict, lsh, num_results=1):
144 | # load model
145 | # with open(self.feature_path, "rb") as f:
146 | # feature_dict = pickle.load(f)
147 | # with open(self.index_path, "rb") as f:
148 | # lsh = pickle.load(f)
149 |
150 | self.find_similar_img_gyz(feature_dict, lsh, num_results)
151 | # self.filter_gap_score()
152 |
153 |
154 | if __name__ == "__main__":
155 | pass
156 |
--------------------------------------------------------------------------------
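For reference, the response layout that find_similar_img unpacks, per the bundled lshash module: each hit is ((vector, extra_data), distance), so response[i][0][1] is the stored image path and response[i][1] the distance. A self-contained sketch with random vectors:

import numpy as np
from lshash.lshash import LSHash

lsh = LSHash(hash_size=8, input_dim=16, num_hashtables=1)
for i in range(5):
    lsh.index(np.random.rand(16), extra_data="img_{}.jpg".format(i))
for (vec, path), dist in lsh.query(np.random.rand(16), num_results=3, distance_func="cosine"):
    print(path, round(float(dist), 4))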
/nts/core/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
6 | 'resnet152']
7 |
8 | model_urls = {
9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
14 | }
15 |
16 |
17 | def conv3x3(in_planes, out_planes, stride=1):
18 | "3x3 convolution with padding"
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20 | padding=1, bias=False)
21 |
22 |
23 | class BasicBlock(nn.Module):
24 | expansion = 1
25 |
26 | def __init__(self, inplanes, planes, stride=1, downsample=None):
27 | super(BasicBlock, self).__init__()
28 | self.conv1 = conv3x3(inplanes, planes, stride)
29 | self.bn1 = nn.BatchNorm2d(planes)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.conv2 = conv3x3(planes, planes)
32 | self.bn2 = nn.BatchNorm2d(planes)
33 | self.downsample = downsample
34 | self.stride = stride
35 |
36 | def forward(self, x):
37 | residual = x
38 |
39 | out = self.conv1(x)
40 | out = self.bn1(out)
41 | out = self.relu(out)
42 |
43 | out = self.conv2(out)
44 | out = self.bn2(out)
45 |
46 | if self.downsample is not None:
47 | residual = self.downsample(x)
48 |
49 | out += residual
50 | out = self.relu(out)
51 |
52 | return out
53 |
54 |
55 | class Bottleneck(nn.Module):
56 | expansion = 4
57 |
58 | def __init__(self, inplanes, planes, stride=1, downsample=None):
59 | super(Bottleneck, self).__init__()
60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
61 | self.bn1 = nn.BatchNorm2d(planes)
62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
63 | padding=1, bias=False)
64 | self.bn2 = nn.BatchNorm2d(planes)
65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
66 | self.bn3 = nn.BatchNorm2d(planes * 4)
67 | self.relu = nn.ReLU(inplace=True)
68 | self.downsample = downsample
69 | self.stride = stride
70 |
71 | def forward(self, x):
72 | residual = x
73 |
74 | out = self.conv1(x)
75 | out = self.bn1(out)
76 | out = self.relu(out)
77 |
78 | out = self.conv2(out)
79 | out = self.bn2(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv3(out)
83 | out = self.bn3(out)
84 |
85 | if self.downsample is not None:
86 | residual = self.downsample(x)
87 |
88 | out += residual
89 | out = self.relu(out)
90 |
91 | return out
92 |
93 |
94 | class ResNet(nn.Module):
95 | def __init__(self, block, layers, num_classes=1000):
96 | self.inplanes = 64
97 | super(ResNet, self).__init__()
98 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
99 | bias=False)
100 | self.bn1 = nn.BatchNorm2d(64)
101 | self.relu = nn.ReLU(inplace=True)
102 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
103 | self.layer1 = self._make_layer(block, 64, layers[0])
104 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
105 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
106 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
107 | self.avgpool = nn.AvgPool2d(7)
108 | self.fc = nn.Linear(512 * block.expansion, num_classes)
109 |
110 | for m in self.modules():
111 | if isinstance(m, nn.Conv2d):
112 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
113 | m.weight.data.normal_(0, math.sqrt(2. / n))
114 | elif isinstance(m, nn.BatchNorm2d):
115 | m.weight.data.fill_(1)
116 | m.bias.data.zero_()
117 |
118 | def _make_layer(self, block, planes, blocks, stride=1):
119 | downsample = None
120 | if stride != 1 or self.inplanes != planes * block.expansion:
121 | downsample = nn.Sequential(
122 | nn.Conv2d(self.inplanes, planes * block.expansion,
123 | kernel_size=1, stride=stride, bias=False),
124 | nn.BatchNorm2d(planes * block.expansion),
125 | )
126 |
127 | layers = []
128 | layers.append(block(self.inplanes, planes, stride, downsample))
129 | self.inplanes = planes * block.expansion
130 | for i in range(1, blocks):
131 | layers.append(block(self.inplanes, planes))
132 |
133 | return nn.Sequential(*layers)
134 |
135 | def forward(self, x):
136 | x = self.conv1(x)
137 | x = self.bn1(x)
138 | x = self.relu(x)
139 | x = self.maxpool(x)
140 |
141 | x = self.layer1(x)
142 | x = self.layer2(x)
143 | x = self.layer3(x)
144 | x = self.layer4(x)
145 | feature1 = x
146 | x = self.avgpool(x)
147 | x = x.view(x.size(0), -1)
148 |         x = nn.functional.dropout(x, p=0.5, training=self.training)  # respects net.eval()
149 | feature2 = x
150 | x = self.fc(x)
151 |
152 | return x, feature1, feature2
153 |
154 |
155 | def resnet18(pretrained=False, **kwargs):
156 | """Constructs a ResNet-18 model.
157 |
158 | Args:
159 | pretrained (bool): If True, returns a model pre-trained on ImageNet
160 | """
161 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
162 | if pretrained:
163 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
164 | return model
165 |
166 |
167 | def resnet34(pretrained=False, **kwargs):
168 | """Constructs a ResNet-34 model.
169 |
170 | Args:
171 | pretrained (bool): If True, returns a model pre-trained on ImageNet
172 | """
173 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
174 | if pretrained:
175 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
176 | return model
177 |
178 |
179 | def resnet50(pretrained=False, **kwargs):
180 | """Constructs a ResNet-50 model.
181 |
182 | Args:
183 | pretrained (bool): If True, returns a model pre-trained on ImageNet
184 | """
185 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
186 | if pretrained:
187 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
188 | return model
189 |
190 |
191 | def resnet101(pretrained=False, **kwargs):
192 | """Constructs a ResNet-101 model.
193 |
194 | Args:
195 | pretrained (bool): If True, returns a model pre-trained on ImageNet
196 | """
197 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
198 | if pretrained:
199 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
200 | return model
201 |
202 |
203 | def resnet152(pretrained=False, **kwargs):
204 | """Constructs a ResNet-152 model.
205 |
206 | Args:
207 | pretrained (bool): If True, returns a model pre-trained on ImageNet
208 | """
209 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
210 | if pretrained:
211 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
212 | return model
213 |
--------------------------------------------------------------------------------
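A quick shape check for this ResNet variant (a sketch run from inside nts/, matching the `from core import ...` convention train.py uses; num_classes=200 assumes CUB-200): forward returns the logits plus the last conv map and the pooled vector fed to the classifier.

import torch
from core.resnet import resnet50

net = resnet50(pretrained=False, num_classes=200)
net.eval()
with torch.no_grad():
    logits, fmap, fvec = net(torch.rand(1, 3, 224, 224))
print(logits.shape, fmap.shape, fvec.shape)  # (1, 200), (1, 2048, 7, 7), (1, 2048)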
/cirtorch/utils/download.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def download_test(data_dir):
4 | """
5 | DOWNLOAD_TEST Checks, and, if required, downloads the necessary datasets for the testing.
6 |
7 | download_test(DATA_ROOT) checks if the data necessary for running the example script exist.
8 | If not it downloads it in the folder structure:
9 | DATA_ROOT/test/oxford5k/ : folder with Oxford images and ground truth file
10 | DATA_ROOT/test/paris6k/ : folder with Paris images and ground truth file
11 | DATA_ROOT/test/roxford5k/ : folder with Oxford images and revisited ground truth file
12 | DATA_ROOT/test/rparis6k/ : folder with Paris images and revisited ground truth file
13 | """
14 |
15 | # Create data folder if it does not exist
16 | if not os.path.isdir(data_dir):
17 | os.mkdir(data_dir)
18 |
19 | # Create datasets folder if it does not exist
20 | datasets_dir = os.path.join(data_dir, 'test')
21 | if not os.path.isdir(datasets_dir):
22 | os.mkdir(datasets_dir)
23 |
24 | # Download datasets folders test/DATASETNAME/
25 | datasets = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
26 | for di in range(len(datasets)):
27 | dataset = datasets[di]
28 |
29 | if dataset == 'oxford5k':
30 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings'
31 | dl_files = ['oxbuild_images.tgz']
32 | elif dataset == 'paris6k':
33 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings'
34 | dl_files = ['paris_1.tgz', 'paris_2.tgz']
35 | elif dataset == 'roxford5k':
36 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings'
37 | dl_files = ['oxbuild_images.tgz']
38 | elif dataset == 'rparis6k':
39 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings'
40 | dl_files = ['paris_1.tgz', 'paris_2.tgz']
41 | else:
42 | raise ValueError('Unknown dataset: {}!'.format(dataset))
43 |
44 | dst_dir = os.path.join(datasets_dir, dataset, 'jpg')
45 | if not os.path.isdir(dst_dir):
46 |
47 | # for oxford and paris download images
48 | if dataset == 'oxford5k' or dataset == 'paris6k':
49 | print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir))
50 | os.makedirs(dst_dir)
51 | for dli in range(len(dl_files)):
52 | dl_file = dl_files[dli]
53 | src_file = os.path.join(src_dir, dl_file)
54 | dst_file = os.path.join(dst_dir, dl_file)
55 | print('>> Downloading dataset {} archive {}...'.format(dataset, dl_file))
56 | os.system('wget {} -O {}'.format(src_file, dst_file))
57 | print('>> Extracting dataset {} archive {}...'.format(dataset, dl_file))
58 | # create tmp folder
59 | dst_dir_tmp = os.path.join(dst_dir, 'tmp')
60 | os.system('mkdir {}'.format(dst_dir_tmp))
61 | # extract in tmp folder
62 | os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir_tmp))
63 | # remove all (possible) subfolders by moving only files in dst_dir
64 | os.system('find {} -type f -exec mv -i {{}} {} \\;'.format(dst_dir_tmp, dst_dir))
65 | # remove tmp folder
66 | os.system('rm -rf {}'.format(dst_dir_tmp))
67 | print('>> Extracted, deleting dataset {} archive {}...'.format(dataset, dl_file))
68 | os.system('rm {}'.format(dst_file))
69 |
70 | # for roxford and rparis just make sym links
71 | elif dataset == 'roxford5k' or dataset == 'rparis6k':
72 | print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir))
73 | dataset_old = dataset[1:]
74 | dst_dir_old = os.path.join(datasets_dir, dataset_old, 'jpg')
75 | os.mkdir(os.path.join(datasets_dir, dataset))
76 | os.system('ln -s {} {}'.format(dst_dir_old, dst_dir))
77 | print('>> Created symbolic link from {} jpg to {} jpg'.format(dataset_old, dataset))
78 |
79 |
80 | gnd_src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'test', dataset)
81 | gnd_dst_dir = os.path.join(datasets_dir, dataset)
82 | gnd_dl_file = 'gnd_{}.pkl'.format(dataset)
83 | gnd_src_file = os.path.join(gnd_src_dir, gnd_dl_file)
84 | gnd_dst_file = os.path.join(gnd_dst_dir, gnd_dl_file)
85 | if not os.path.exists(gnd_dst_file):
86 | print('>> Downloading dataset {} ground truth file...'.format(dataset))
87 | os.system('wget {} -O {}'.format(gnd_src_file, gnd_dst_file))
88 |
89 |
90 | def download_train(data_dir):
91 | """
92 | DOWNLOAD_TRAIN Checks, and, if required, downloads the necessary datasets for the training.
93 |
94 | download_train(DATA_ROOT) checks if the data necessary for running the example script exist.
95 | If not it downloads it in the folder structure:
96 | DATA_ROOT/train/retrieval-SfM-120k/ : folder with rsfm120k images and db files
97 | DATA_ROOT/train/retrieval-SfM-30k/ : folder with rsfm30k images and db files
98 | """
99 |
100 | # Create data folder if it does not exist
101 | if not os.path.isdir(data_dir):
102 | os.mkdir(data_dir)
103 |
104 | # Create datasets folder if it does not exist
105 | datasets_dir = os.path.join(data_dir, 'train')
106 | if not os.path.isdir(datasets_dir):
107 | os.mkdir(datasets_dir)
108 |
109 | # Download folder train/retrieval-SfM-120k/
110 | src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'ims')
111 | dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
112 | dl_file = 'ims.tar.gz'
113 | if not os.path.isdir(dst_dir):
114 | src_file = os.path.join(src_dir, dl_file)
115 | dst_file = os.path.join(dst_dir, dl_file)
116 | print('>> Image directory does not exist. Creating: {}'.format(dst_dir))
117 | os.makedirs(dst_dir)
118 | print('>> Downloading ims.tar.gz...')
119 | os.system('wget {} -O {}'.format(src_file, dst_file))
120 | print('>> Extracting {}...'.format(dst_file))
121 | os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir))
122 | print('>> Extracted, deleting {}...'.format(dst_file))
123 | os.system('rm {}'.format(dst_file))
124 |
125 | # Create symlink for train/retrieval-SfM-30k/
126 | dst_dir_old = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
127 | dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-30k', 'ims')
128 | if not os.path.isdir(dst_dir):
129 | os.makedirs(os.path.join(datasets_dir, 'retrieval-SfM-30k'))
130 | os.system('ln -s {} {}'.format(dst_dir_old, dst_dir))
131 | print('>> Created symbolic link from retrieval-SfM-120k/ims to retrieval-SfM-30k/ims')
132 |
133 | # Download db files
134 | src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'dbs')
135 | datasets = ['retrieval-SfM-120k', 'retrieval-SfM-30k']
136 | for dataset in datasets:
137 | dst_dir = os.path.join(datasets_dir, dataset)
138 | if dataset == 'retrieval-SfM-120k':
139 | dl_files = ['{}.pkl'.format(dataset), '{}-whiten.pkl'.format(dataset)]
140 | elif dataset == 'retrieval-SfM-30k':
141 | dl_files = ['{}-whiten.pkl'.format(dataset)]
142 |
143 | if not os.path.isdir(dst_dir):
144 | print('>> Dataset directory does not exist. Creating: {}'.format(dst_dir))
145 | os.mkdir(dst_dir)
146 |
147 | for i in range(len(dl_files)):
148 | src_file = os.path.join(src_dir, dl_files[i])
149 | dst_file = os.path.join(dst_dir, dl_files[i])
150 | if not os.path.isfile(dst_file):
151 | print('>> DB file {} does not exist. Downloading...'.format(dl_files[i]))
152 | os.system('wget {} -O {}'.format(src_file, dst_file))
153 |
--------------------------------------------------------------------------------
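The downloads above shell out to wget, which must be on PATH. A dependency-free alternative using only the standard library would look like this (a sketch, not what the repo ships):

import os
import urllib.request

def fetch(src_url, dst_file):
    """Download src_url to dst_file, creating the parent directory if needed."""
    os.makedirs(os.path.dirname(dst_file) or '.', exist_ok=True)
    print('>> Downloading {}...'.format(src_url))
    urllib.request.urlretrieve(src_url, dst_file)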
/cirtorch/utils/download_win.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def download_test(data_dir):
4 | """
5 | DOWNLOAD_TEST Checks, and, if required, downloads the necessary datasets for the testing.
6 |
7 | download_test(DATA_ROOT) checks if the data necessary for running the example script exist.
8 | If not it downloads it in the folder structure:
9 | DATA_ROOT/test/oxford5k/ : folder with Oxford images and ground truth file
10 | DATA_ROOT/test/paris6k/ : folder with Paris images and ground truth file
11 | DATA_ROOT/test/roxford5k/ : folder with Oxford images and revisited ground truth file
12 | DATA_ROOT/test/rparis6k/ : folder with Paris images and revisited ground truth file
13 | """
14 |
15 | # Create data folder if it does not exist
16 | if not os.path.isdir(data_dir):
17 | os.mkdir(data_dir)
18 |
19 | # Create datasets folder if it does not exist
20 | datasets_dir = os.path.join(data_dir, 'test')
21 | if not os.path.isdir(datasets_dir):
22 | os.mkdir(datasets_dir)
23 |
24 | # Download datasets folders test/DATASETNAME/
25 | datasets = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
26 | for di in range(len(datasets)):
27 | dataset = datasets[di]
28 |
29 | if dataset == 'oxford5k':
30 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings'
31 | dl_files = ['oxbuild_images.tgz']
32 | elif dataset == 'paris6k':
33 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings'
34 | dl_files = ['paris_1.tgz', 'paris_2.tgz']
35 | elif dataset == 'roxford5k':
36 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings'
37 | dl_files = ['oxbuild_images.tgz']
38 | elif dataset == 'rparis6k':
39 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings'
40 | dl_files = ['paris_1.tgz', 'paris_2.tgz']
41 | else:
42 | raise ValueError('Unknown dataset: {}!'.format(dataset))
43 |
44 | dst_dir = os.path.join(datasets_dir, dataset, 'jpg')
45 | if not os.path.isdir(dst_dir):
46 |
47 | # for oxford and paris download images
48 | if dataset == 'oxford5k' or dataset == 'paris6k':
49 | print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir))
50 | os.makedirs(dst_dir)
51 | for dli in range(len(dl_files)):
52 | dl_file = dl_files[dli]
53 | src_file = os.path.join(src_dir, dl_file)
54 | dst_file = os.path.join(dst_dir, dl_file)
55 | print('>> Downloading dataset {} archive {}...'.format(dataset, dl_file))
56 | os.system('wget {} -O {}'.format(src_file, dst_file))
57 | print('>> Extracting dataset {} archive {}...'.format(dataset, dl_file))
58 | # create tmp folder
59 | dst_dir_tmp = os.path.join(dst_dir, 'tmp')
60 | os.system('mkdir {}'.format(dst_dir_tmp))
61 | # extract in tmp folder
62 | os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir_tmp))
63 | # remove all (possible) subfolders by moving only files in dst_dir
64 | os.system('find {} -type f -exec mv -i {{}} {} \\;'.format(dst_dir_tmp, dst_dir))
65 | # remove tmp folder
66 | os.system('rd {}'.format(dst_dir_tmp))
67 | print('>> Extracted, deleting dataset {} archive {}...'.format(dataset, dl_file))
68 | os.system('del {}'.format(dst_file))
69 |
70 | # for roxford and rparis just make sym links
71 | elif dataset == 'roxford5k' or dataset == 'rparis6k':
72 | print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir))
73 | dataset_old = dataset[1:]
74 | dst_dir_old = os.path.join(datasets_dir, dataset_old, 'jpg')
75 | os.mkdir(os.path.join(datasets_dir, dataset))
76 |                 os.system('cmd /c mklink /d {} {}'.format(dst_dir, dst_dir_old))  # mklink expects link first, then target
77 | print('>> Created symbolic link from {} jpg to {} jpg'.format(dataset_old, dataset))
78 |
79 |
80 | gnd_src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'test', dataset)
81 | gnd_dst_dir = os.path.join(datasets_dir, dataset)
82 | gnd_dl_file = 'gnd_{}.pkl'.format(dataset)
83 | gnd_src_file = os.path.join(gnd_src_dir, gnd_dl_file)
84 | gnd_dst_file = os.path.join(gnd_dst_dir, gnd_dl_file)
85 | if not os.path.exists(gnd_dst_file):
86 | print('>> Downloading dataset {} ground truth file...'.format(dataset))
87 | os.system('wget {} -O {}'.format(gnd_src_file, gnd_dst_file))
88 |
89 |
90 | def download_train(data_dir):
91 | """
92 | DOWNLOAD_TRAIN Checks, and, if required, downloads the necessary datasets for the training.
93 |
94 | download_train(DATA_ROOT) checks if the data necessary for running the example script exist.
95 | If not it downloads it in the folder structure:
96 | DATA_ROOT/train/retrieval-SfM-120k/ : folder with rsfm120k images and db files
97 | DATA_ROOT/train/retrieval-SfM-30k/ : folder with rsfm30k images and db files
98 | """
99 |
100 | # Create data folder if it does not exist
101 | if not os.path.isdir(data_dir):
102 | os.mkdir(data_dir)
103 | print(data_dir)
104 | # Create datasets folder if it does not exist
105 | datasets_dir = os.path.join(data_dir, 'train')
106 | if not os.path.isdir(datasets_dir):
107 | os.mkdir(datasets_dir)
108 |
109 | # Download folder train/retrieval-SfM-120k/
110 | src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'ims')
111 | dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
112 | dl_file = 'ims.tar.gz'
113 | if not os.path.isdir(dst_dir):
114 | src_file = os.path.join(src_dir, dl_file)
115 | dst_file = os.path.join(dst_dir, dl_file)
116 | print('>> Image directory does not exist. Creating: {}'.format(dst_dir))
117 | os.makedirs(dst_dir)
118 | print('>> Downloading ims.tar.gz...')
119 |         os.system('wget {} -O {}'.format(src_file, dst_file))
120 | print('>> Extracting {}...'.format(dst_file))
121 | os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir))
122 | print('>> Extracted, deleting {}...'.format(dst_file))
123 | os.system('del {}'.format(dst_file))
124 |
125 | # Create symlink for train/retrieval-SfM-30k/
126 | dst_dir_old = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
127 | dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-30k', 'ims')
128 | if not os.path.isdir(dst_dir):
129 |         os.makedirs(os.path.join(datasets_dir, 'retrieval-SfM-30k'))
130 |         os.system('mklink /d {} {}'.format(dst_dir, dst_dir_old))  # mklink expects link first, then target
131 | print('>> Created symbolic link from retrieval-SfM-120k/ims to retrieval-SfM-30k/ims')
132 |
133 | # Download db files
134 | src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'dbs')
135 | datasets = ['retrieval-SfM-120k', 'retrieval-SfM-30k']
136 | for dataset in datasets:
137 | dst_dir = os.path.join(datasets_dir, dataset)
138 | if dataset == 'retrieval-SfM-120k':
139 | dl_files = ['{}.pkl'.format(dataset), '{}-whiten.pkl'.format(dataset)]
140 | elif dataset == 'retrieval-SfM-30k':
141 | dl_files = ['{}-whiten.pkl'.format(dataset)]
142 |
143 | if not os.path.isdir(dst_dir):
144 | print('>> Dataset directory does not exist. Creating: {}'.format(dst_dir))
145 | os.mkdir(dst_dir)
146 |
147 | for i in range(len(dl_files)):
148 | src_file = os.path.join(src_dir, dl_files[i])
149 | dst_file = os.path.join(dst_dir, dl_files[i])
150 | if not os.path.isfile(dst_file):
151 | print('>> DB file {} does not exist. Downloading...'.format(dl_files[i]))
152 | os.system('wget {} -O {}'.format(src_file, dst_file))
153 |
--------------------------------------------------------------------------------
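Rather than shelling out to platform-specific mklink/ln commands, os.symlink is portable across the two download modules (a sketch; on Windows it requires administrator rights or developer mode):

import os

def link_dir(target_dir, link_path):
    # create link_path pointing at the existing target_dir
    if not os.path.isdir(link_path):
        os.symlink(target_dir, link_path, target_is_directory=True)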
/interface.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # /usr/bin/env python
3 |
4 | '''
5 | Author: yinhao
6 | Email: yinhao_x@163.com
7 | Wechat: xss_yinhao
8 | Github: http://github.com/yinhaoxs
9 |
10 | date: 2019-11-23 21:51
11 | desc:
12 | '''
13 | import torch
14 | from torch.utils.model_zoo import load_url
15 | from torchvision import transforms
16 | from cirtorch.datasets.testdataset import configdataset
17 | from cirtorch.utils.download import download_train, download_test
18 | from cirtorch.utils.evaluate import compute_map_and_print
19 | from cirtorch.utils.general import get_data_root, htime
20 | from cirtorch.networks.imageretrievalnet_cpu import init_network, extract_vectors
21 | from cirtorch.datasets.datahelpers import imresize
22 |
23 | from PIL import Image
24 | import numpy as np
25 | import pandas as pd
26 | from flask import Flask, request
27 | import json, io, sys, time, traceback, argparse, logging, subprocess, pickle, os, yaml,shutil
28 | import cv2
29 | import pdb
30 | from werkzeug.utils import cached_property
31 | from apscheduler.schedulers.background import BackgroundScheduler
32 | from multiprocessing import Pool
33 |
34 | app = Flask(__name__)
35 |
36 | @app.route("/")
37 | def index():
38 | return ""
39 |
40 | @app.route("/images/*", methods=['GET','POST'])
41 | def accInsurance():
42 | """
43 | flask request process handle
44 | :return:
45 | """
46 | try:
47 | if request.method == 'GET':
48 | return json.dumps({'err': 1, 'msg': 'POST only'})
49 | else:
50 | app.logger.debug("print headers------")
51 | headers = request.headers
52 | headers_info = ""
53 | for k, v in headers.items():
54 | headers_info += "{}: {}\n".format(k, v)
55 | app.logger.debug(headers_info)
56 |
57 | app.logger.debug("print forms------")
58 | forms_info = ""
59 | for k, v in request.form.items():
60 | forms_info += "{}: {}\n".format(k, v)
61 | app.logger.debug(forms_info)
62 |
63 | if 'query' not in request.files:
64 | return json.dumps({'err': 2, 'msg': 'query image is empty'})
65 |
66 | if 'sig' not in request.form:
67 | return json.dumps({'err': 3, 'msg': 'sig is empty'})
68 |
69 | if 'q_no' not in request.form:
70 | return json.dumps({'err': 4, 'msg': 'no is empty'})
71 |
72 | if 'q_did' not in request.form:
73 | return json.dumps({'err': 5, 'msg': 'did is empty'})
74 |
75 | if 'q_id' not in request.form:
76 | return json.dumps({'err': 6, 'msg': 'id is empty'})
77 |
78 | if 'type' not in request.form:
79 | return json.dumps({'err': 7, 'msg': 'type is empty'})
80 |
81 | img_name = request.files['query'].filename
82 | img_bytes = request.files['query'].read()
83 | img = request.files['query']
84 | sig = request.form['sig']
85 | q_no = request.form['q_no']
86 | q_did = request.form['q_did']
87 | q_id = request.form['q_id']
88 | type = request.form['type']
89 |
90 | if str(type) not in types:
91 | return json.dumps({'err': 8, 'msg': 'type is not exist'})
92 |
93 | if img_bytes is None:
94 | return json.dumps({'err': 10, 'msg': 'img is none'})
95 |
96 | results = imageRetrieval().retrieval_online_v0(img, q_no, q_did, q_id, type)
97 |
98 | data = dict()
99 | data['query'] = img_name
100 | data['sig'] = sig
101 | data['type'] = type
102 | data['q_no'] = q_no
103 | data['q_did'] = q_did
104 | data['q_id'] = q_id
105 | data['results'] = results
106 |
107 | return json.dumps({'err': 0, 'msg': 'success', 'data': data})
108 |
109 | except:
110 | app.logger.exception(sys.exc_info())
111 |         return json.dumps({'err': 9, 'msg': 'unknown error'})
112 |
113 |
114 | class imageRetrieval():
115 | def __init__(self):
116 | pass
117 |
118 | def cosine_dist(self, x, y):
119 | return 100 * float(np.dot(x, y))/(np.dot(x,x)*np.dot(y,y)) ** 0.5
120 |
121 | def inference(self, img):
122 | try:
123 | input = Image.open(img).convert("RGB")
124 | input = imresize(input, 224)
125 |             input = transforms(input).unsqueeze(0)
126 | with torch.no_grad():
127 | vect = net(input)
128 | return vect
129 | except:
130 |             print('cannot identify image file')
131 |
132 | def retrieval_online_v0(self, img, q_no, q_did, q_id, type):
133 | # load model
134 | query_vect = self.inference(img)
135 | query_vect = list(query_vect.detach().numpy().T[0])
136 |
137 | lsh = lsh_dict[str(type)]
138 | response = lsh.query(query_vect, num_results=1, distance_func = "cosine")
139 |
140 | try:
141 | similar_path = response[0][0][1]
142 | score = np.rint(self.cosine_dist(list(query_vect), list(response[0][0][0])))
143 | rank_list = similar_path.split("/")
144 | s_id, s_did, s_no = rank_list[-1].split("_")[-1].split(".")[0], rank_list[-1].split("_")[0], rank_list[-2]
145 | results = [{"s_no": s_no, "r_did": s_did, "s_id": s_id, "score": score}]
146 | except:
147 | results = []
148 |
149 | img_path = "/{}/{}_{}".format(q_no, q_did, q_id)
150 | lsh.index(query_vect, extra_data=img_path)
151 | lsh_dict[str(type)] = lsh
152 |
153 | return results
154 |
155 |
156 |
157 | class initModel():
158 | def __init__(self):
159 | pass
160 |
161 | def init_model(self, network, model_dir, types):
162 | print(">> Loading network:\n>>>> '{}'".format(network))
163 | # state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
164 | state = torch.load(network)
165 | # parsing net params from meta
166 | # architecture, pooling, mean, std required
167 |         # the rest have default values, in case they don't exist
168 | net_params = {}
169 | net_params['architecture'] = state['meta']['architecture']
170 | net_params['pooling'] = state['meta']['pooling']
171 | net_params['local_whitening'] = state['meta'].get('local_whitening', False)
172 | net_params['regional'] = state['meta'].get('regional', False)
173 | net_params['whitening'] = state['meta'].get('whitening', False)
174 | net_params['mean'] = state['meta']['mean']
175 | net_params['std'] = state['meta']['std']
176 | net_params['pretrained'] = False
177 | # network initialization
178 | net = init_network(net_params)
179 | net.load_state_dict(state['state_dict'])
180 | print(">>>> loaded network: ")
181 | print(net.meta_repr())
182 | # moving network to gpu and eval mode
183 | # net.cuda()
184 | net.eval()
185 |
186 | # set up the transform
187 | normalize = transforms.Normalize(
188 | mean=net.meta['mean'],
189 | std=net.meta['std']
190 | )
191 | transform = transforms.Compose([
192 | transforms.ToTensor(),
193 | normalize
194 | ])
195 |
196 | lsh_dict = dict()
197 | for type in types:
198 | with open(os.path.join(model_dir, "dataset_index_{}.pkl".format(str(type))), "rb") as f:
199 | lsh = pickle.load(f)
200 |
201 | lsh_dict[str(type)] = lsh
202 |
203 |         return net, lsh_dict, transform
204 |
205 | def init(self):
206 | with open('config.yaml', 'r') as f:
207 |             conf = yaml.safe_load(f)
208 |
209 | app.logger.info(conf)
210 |         host = conf['websites']['host']
211 |         port = conf['websites']['port']
212 | network = conf['model']['network']
213 | model_dir = conf['model']['model_dir']
214 | types = conf['model']['type']
215 |
216 | net, lsh_dict, transforms = self.init_model(network, model_dir, types)
217 |
218 | return host, port, net, lsh_dict, transforms, model_dir, types
219 |
220 |
221 | def job():
222 | for type in types:
223 | with open(os.path.join(model_dir, "dataset_index_{}_v0.pkl".format(str(type))), "wb") as f:
224 | pickle.dump(lsh_dict[str(type)], f)
225 |
226 |
227 | if __name__ == "__main__":
228 | """
229 | start app from ssh
230 | """
231 |     scheduler = BackgroundScheduler()
232 |     host, port, net, lsh_dict, transforms, model_dir, types = initModel().init()
233 |     scheduler.add_job(job, 'interval', seconds=30)
234 |     scheduler.start()
235 |     print("start server {}:{}".format(host, port))
236 |     # app.run() blocks, so it runs last; the debug reloader would start a second scheduler
237 |     app.run(host=host, port=port, debug=True, use_reloader=False)
238 |
239 | else:
240 | """
241 | start app from gunicorn
242 | """
243 | scheduler = BackgroundScheduler()
244 | gunicorn_logger = logging.getLogger("gunicorn.error")
245 | app.logger.handlers = gunicorn_logger.handlers
246 | app.logger.setLevel(gunicorn_logger.level)
247 |
248 | host, port, net, lsh_dict, transforms, model_dir, types = initModel().init()
249 | app.logger.info("started from gunicorn...")
250 |
251 | scheduler.add_job(job, 'interval', seconds=30)
252 | scheduler.start()
253 |
254 |
255 |
256 |
--------------------------------------------------------------------------------
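A client-side sketch for the endpoint above; the field names come from the handler, while the host, port, and trailing path segment are illustrative (host and port are read from config.yaml):

import requests

with open("query.jpg", "rb") as f:
    r = requests.post(
        "http://127.0.0.1:15788/images/query",
        files={"query": f},
        data={"sig": "abc", "q_no": "1", "q_did": "2", "q_id": "3", "type": "SA"},
    )
print(r.json())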
/cirtorch/datasets/traindataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import pdb
4 |
5 | import torch
6 | import torch.utils.data as data
7 |
8 | from cirtorch.datasets.datahelpers import default_loader, imresize, cid2filename
9 | from cirtorch.datasets.genericdataset import ImagesFromList
10 | from cirtorch.utils.general import get_data_root
11 |
12 | class TuplesDataset(data.Dataset):
13 | """Data loader that loads training and validation tuples of
14 |     Radenovic et al., ECCV16: CNN image retrieval learns from BoW
15 |
16 | Args:
17 | name (string): dataset name: 'retrieval-sfm-120k'
18 | mode (string): 'train' or 'val' for training and validation parts of dataset
19 | imsize (int, Default: None): Defines the maximum size of longer image side
20 |         transform (callable, optional): A function/transform that takes in a PIL image
21 | and returns a transformed version. E.g, ``transforms.RandomCrop``
22 | loader (callable, optional): A function to load an image given its path.
23 |         nnum (int, Default: 5): Number of negatives for a query image in a training tuple
24 |         qsize (int, Default: 2000): Number of query images, i.e., the number of (q,p,n1,...,nN) tuples, to be processed in one epoch
25 |         poolsize (int, Default: 20000): Pool size for negative image re-mining
26 |
27 | Attributes:
28 | images (list): List of full filenames for each image
29 | clusters (list): List of clusterID per image
30 | qpool (list): List of all query image indexes
31 | ppool (list): List of positive image indexes, each corresponding to query at the same position in qpool
32 |
33 | qidxs (list): List of qsize query image indexes to be processed in an epoch
34 | pidxs (list): List of qsize positive image indexes, each corresponding to query at the same position in qidxs
35 | nidxs (list): List of qsize tuples of negative images
36 | Each nidxs tuple contains nnum images corresponding to query image at the same position in qidxs
37 |
38 | Lists qidxs, pidxs, nidxs are refreshed by calling the ``create_epoch_tuples()`` method,
39 |     i.e., new q-p pairs are picked and negative images are re-mined
40 | """
41 |
42 | def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, transform=None, loader=default_loader):
43 |
44 | if not (mode == 'train' or mode == 'val'):
45 | raise(RuntimeError("MODE should be either train or val, passed as string"))
46 |
47 | if name.startswith('retrieval-SfM'):
48 | # setting up paths
49 | data_root = get_data_root()
50 | db_root = os.path.join(data_root, 'train', name)
51 | ims_root = os.path.join(db_root, 'ims')
52 |
53 | # loading db
54 | db_fn = os.path.join(db_root, '{}.pkl'.format(name))
55 | with open(db_fn, 'rb') as f:
56 | db = pickle.load(f)[mode]
57 |
58 | # setting fullpath for images
59 | self.images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))]
60 |
61 | elif name.startswith('gl'):
62 |         ## TODO: NOT IMPLEMENTED YET PROPERLY (WITH AUTOMATIC DOWNLOAD)
63 |
64 | # setting up paths
65 | db_root = '/mnt/fry2/users/datasets/landmarkscvprw18/recognition/'
66 | ims_root = os.path.join(db_root, 'images', 'train')
67 |
68 | # loading db
69 | db_fn = os.path.join(db_root, '{}.pkl'.format(name))
70 | with open(db_fn, 'rb') as f:
71 | db = pickle.load(f)[mode]
72 |
73 | # setting fullpath for images
74 | self.images = [os.path.join(ims_root, db['cids'][i]+'.jpg') for i in range(len(db['cids']))]
75 | else:
76 | raise(RuntimeError("Unknown dataset name!"))
77 |
78 | # initializing tuples dataset
79 | self.name = name
80 | self.mode = mode
81 | self.imsize = imsize
82 | self.clusters = db['cluster']
83 | self.qpool = db['qidxs']
84 | self.ppool = db['pidxs']
85 |
86 | ## If we want to keep only unique q-p pairs
87 | ## However, ordering of pairs will change, although that is not important
88 | # qpidxs = list(set([(self.qidxs[i], self.pidxs[i]) for i in range(len(self.qidxs))]))
89 | # self.qidxs = [qpidxs[i][0] for i in range(len(qpidxs))]
90 | # self.pidxs = [qpidxs[i][1] for i in range(len(qpidxs))]
91 |
92 | # size of training subset for an epoch
93 | self.nnum = nnum
94 | self.qsize = min(qsize, len(self.qpool))
95 | self.poolsize = min(poolsize, len(self.images))
96 | self.qidxs = None
97 | self.pidxs = None
98 | self.nidxs = None
99 |
100 | self.transform = transform
101 | self.loader = loader
102 |
103 | self.print_freq = 10
104 |
105 | def __getitem__(self, index):
106 | """
107 | Args:
108 | index (int): Index
109 |
110 | Returns:
111 | images tuple (q,p,n1,...,nN): Loaded train/val tuple at index of self.qidxs
112 | """
113 | if self.__len__() == 0:
114 | raise(RuntimeError("List qidxs is empty. Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!"))
115 |
116 | output = []
117 | # query image
118 | output.append(self.loader(self.images[self.qidxs[index]]))
119 | # positive image
120 | output.append(self.loader(self.images[self.pidxs[index]]))
121 | # negative images
122 | for i in range(len(self.nidxs[index])):
123 | output.append(self.loader(self.images[self.nidxs[index][i]]))
124 |
125 | if self.imsize is not None:
126 | output = [imresize(img, self.imsize) for img in output]
127 |
128 | if self.transform is not None:
129 | output = [self.transform(output[i]).unsqueeze_(0) for i in range(len(output))]
130 |
131 | target = torch.Tensor([-1, 1] + [0]*len(self.nidxs[index]))
132 |
133 | return output, target
134 |
135 | def __len__(self):
136 | # if not self.qidxs:
137 | # return 0
138 | # return len(self.qidxs)
139 | return self.qsize
140 |
141 | def __repr__(self):
142 | fmt_str = self.__class__.__name__ + '\n'
143 | fmt_str += ' Name and mode: {} {}\n'.format(self.name, self.mode)
144 | fmt_str += ' Number of images: {}\n'.format(len(self.images))
145 | fmt_str += ' Number of training tuples: {}\n'.format(len(self.qpool))
146 | fmt_str += ' Number of negatives per tuple: {}\n'.format(self.nnum)
147 | fmt_str += ' Number of tuples processed in an epoch: {}\n'.format(self.qsize)
148 | fmt_str += ' Pool size for negative remining: {}\n'.format(self.poolsize)
149 | tmp = ' Transforms (if any): '
150 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
151 | return fmt_str
152 |
153 | def create_epoch_tuples(self, net):
154 |
155 | print('>> Creating tuples for an epoch of {}-{}...'.format(self.name, self.mode))
156 | print(">>>> used network: ")
157 | print(net.meta_repr())
158 |
159 | ## ------------------------
160 | ## SELECTING POSITIVE PAIRS
161 | ## ------------------------
162 |
163 | # draw qsize random queries for tuples
164 | idxs2qpool = torch.randperm(len(self.qpool))[:self.qsize]
165 | self.qidxs = [self.qpool[i] for i in idxs2qpool]
166 | self.pidxs = [self.ppool[i] for i in idxs2qpool]
167 |
168 | ## ------------------------
169 | ## SELECTING NEGATIVE PAIRS
170 | ## ------------------------
171 |
172 | # if nnum = 0 create dummy nidxs
173 | # useful when only positives used for training
174 | if self.nnum == 0:
175 | self.nidxs = [[] for _ in range(len(self.qidxs))]
176 | return 0
177 |
178 |         # draw poolsize random images for the pool of negative images
179 | idxs2images = torch.randperm(len(self.images))[:self.poolsize]
180 |
181 | # prepare network
182 | net.cuda()
183 | net.eval()
184 |
185 | # no gradients computed, to reduce memory and increase speed
186 | with torch.no_grad():
187 |
188 | print('>> Extracting descriptors for query images...')
189 | # prepare query loader
190 | loader = torch.utils.data.DataLoader(
191 | ImagesFromList(root='', images=[self.images[i] for i in self.qidxs], imsize=self.imsize, transform=self.transform),
192 | batch_size=1, shuffle=False, num_workers=8, pin_memory=True
193 | )
194 | # extract query vectors
195 | qvecs = torch.zeros(net.meta['outputdim'], len(self.qidxs)).cuda()
196 | for i, input in enumerate(loader):
197 | qvecs[:, i] = net(input.cuda()).data.squeeze()
198 | if (i+1) % self.print_freq == 0 or (i+1) == len(self.qidxs):
199 | print('\r>>>> {}/{} done...'.format(i+1, len(self.qidxs)), end='')
200 | print('')
201 |
202 | print('>> Extracting descriptors for negative pool...')
203 | # prepare negative pool data loader
204 | loader = torch.utils.data.DataLoader(
205 | ImagesFromList(root='', images=[self.images[i] for i in idxs2images], imsize=self.imsize, transform=self.transform),
206 | batch_size=1, shuffle=False, num_workers=8, pin_memory=True
207 | )
208 | # extract negative pool vectors
209 | poolvecs = torch.zeros(net.meta['outputdim'], len(idxs2images)).cuda()
210 | for i, input in enumerate(loader):
211 | poolvecs[:, i] = net(input.cuda()).data.squeeze()
212 | if (i+1) % self.print_freq == 0 or (i+1) == len(idxs2images):
213 | print('\r>>>> {}/{} done...'.format(i+1, len(idxs2images)), end='')
214 | print('')
215 |
216 | print('>> Searching for hard negatives...')
217 | # compute dot product scores and ranks on GPU
218 | scores = torch.mm(poolvecs.t(), qvecs)
219 | scores, ranks = torch.sort(scores, dim=0, descending=True)
220 | avg_ndist = torch.tensor(0).float().cuda() # for statistics
221 | n_ndist = torch.tensor(0).float().cuda() # for statistics
222 | # selection of negative examples
223 | self.nidxs = []
224 | for q in range(len(self.qidxs)):
225 | # do not use query cluster,
226 | # those images are potentially positive
227 | qcluster = self.clusters[self.qidxs[q]]
228 | clusters = [qcluster]
229 | nidxs = []
230 | r = 0
231 | while len(nidxs) < self.nnum:
232 | potential = idxs2images[ranks[r, q]]
233 | # take at most one image from the same cluster
234 | if not self.clusters[potential] in clusters:
235 | nidxs.append(potential)
236 | clusters.append(self.clusters[potential])
237 | avg_ndist += torch.pow(qvecs[:,q]-poolvecs[:,ranks[r, q]]+1e-6, 2).sum(dim=0).sqrt()
238 | n_ndist += 1
239 | r += 1
240 | self.nidxs.append(nidxs)
241 | print('>>>> Average negative l2-distance: {:.2f}'.format(avg_ndist/n_ndist))
242 | print('>>>> Done')
243 |
244 | return (avg_ndist/n_ndist).item() # return average negative l2-distance
245 |
--------------------------------------------------------------------------------
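The core of create_epoch_tuples() above is the cluster-aware hard-negative selection; a self-contained toy sketch of just that step (all sizes, descriptors, and cluster ids below are made up, and everything stays on the cpu):

    import torch

    D, NQ, NP, nnum = 4, 3, 10, 2
    qvecs = torch.nn.functional.normalize(torch.randn(D, NQ), dim=0)
    poolvecs = torch.nn.functional.normalize(torch.randn(D, NP), dim=0)
    clusters = torch.randint(0, 4, (NP,))   # cluster id per pool image
    qclusters = torch.randint(0, 4, (NQ,))  # cluster id per query

    scores = torch.mm(poolvecs.t(), qvecs)  # dot-product similarity
    _, ranks = torch.sort(scores, dim=0, descending=True)

    nidxs = []
    for q in range(NQ):
        seen = {int(qclusters[q])}  # never draw negatives from the query's own cluster
        picked, r = [], 0
        while len(picked) < nnum and r < NP:
            cand = int(ranks[r, q]); r += 1
            if int(clusters[cand]) not in seen:  # at most one image per cluster
                picked.append(cand)
                seen.add(int(clusters[cand]))
        nidxs.append(picked)
    print(nidxs)  # hardest non-conflicting negatives per query
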
/cirtorch/examples/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import pickle
5 | import pdb
6 |
7 | import numpy as np
8 |
9 | import torch
10 | from torch.utils.model_zoo import load_url
11 | from torchvision import transforms
12 |
13 | from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
14 | from cirtorch.datasets.datahelpers import cid2filename
15 | from cirtorch.datasets.testdataset import configdataset
16 | from cirtorch.utils.download import download_train, download_test
17 | from cirtorch.utils.whiten import whitenlearn, whitenapply
18 | from cirtorch.utils.evaluate import compute_map_and_print
19 | from cirtorch.utils.general import get_data_root, htime
20 |
21 | PRETRAINED = {
22 | 'retrievalSfM120k-vgg16-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-vgg16-gem-b4dcdc6.pth',
23 | 'retrievalSfM120k-resnet101-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-resnet101-gem-b80fb85.pth',
24 | # new networks with whitening learned end-to-end
25 | 'rSfM120k-tl-resnet50-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',
26 | 'rSfM120k-tl-resnet101-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',
27 | 'rSfM120k-tl-resnet152-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet152-gem-w-f39cada.pth',
28 | 'gl18-tl-resnet50-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',
29 | 'gl18-tl-resnet101-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',
30 | 'gl18-tl-resnet152-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet152-gem-w-21278d5.pth',
31 | }
32 |
33 | datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
34 | whitening_names = ['retrieval-SfM-30k', 'retrieval-SfM-120k']
35 |
36 | parser = argparse.ArgumentParser(description='PyTorch CNN Image Retrieval Testing')
37 |
38 | # network
39 | group = parser.add_mutually_exclusive_group(required=True)
40 | group.add_argument('--network-path', '-npath', metavar='NETWORK',
41 | help="pretrained network or network path (destination where network is saved)")
42 | group.add_argument('--network-offtheshelf', '-noff', metavar='NETWORK',
43 | help="off-the-shelf network, in the format 'ARCHITECTURE-POOLING' or 'ARCHITECTURE-POOLING-{reg-lwhiten-whiten}'," +
44 | " examples: 'resnet101-gem' | 'resnet101-gem-reg' | 'resnet101-gem-whiten' | 'resnet101-gem-lwhiten' | 'resnet101-gem-reg-whiten'")
45 |
46 | # test options
47 | parser.add_argument('--datasets', '-d', metavar='DATASETS', default='oxford5k,paris6k',
48 | help="comma separated list of test datasets: " +
49 | " | ".join(datasets_names) +
50 | " (default: 'oxford5k,paris6k')")
51 | parser.add_argument('--image-size', '-imsize', default=1024, type=int, metavar='N',
52 | help="maximum size of longer image side used for testing (default: 1024)")
53 | parser.add_argument('--multiscale', '-ms', metavar='MULTISCALE', default='[1]',
54 | help="use multiscale vectors for testing, " +
55 | " examples: '[1]' | '[1, 1/2**(1/2), 1/2]' | '[1, 2**(1/2), 1/2**(1/2)]' (default: '[1]')")
56 | parser.add_argument('--whitening', '-w', metavar='WHITENING', default=None, choices=whitening_names,
57 | help="dataset used to learn whitening for testing: " +
58 | " | ".join(whitening_names) +
59 | " (default: None)")
60 |
61 | # GPU ID
62 | parser.add_argument('--gpu-id', '-g', default='0', metavar='N',
63 | help="gpu id used for testing (default: '0')")
64 |
65 | def main():
66 | args = parser.parse_args()
67 |
68 | # check if there are unknown datasets
69 | for dataset in args.datasets.split(','):
70 | if dataset not in datasets_names:
71 | raise ValueError('Unsupported or unknown dataset: {}!'.format(dataset))
72 |
73 |     # check if the test datasets are downloaded
74 | # and download if they are not
75 | download_train(get_data_root())
76 | download_test(get_data_root())
77 |
78 | # setting up the visible GPU
79 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
80 |
81 | # loading network from path
82 | if args.network_path is not None:
83 |
84 | print(">> Loading network:\n>>>> '{}'".format(args.network_path))
85 | if args.network_path in PRETRAINED:
86 | # pretrained networks (downloaded automatically)
87 | state = load_url(PRETRAINED[args.network_path], model_dir=os.path.join(get_data_root(), 'networks'))
88 | else:
89 | # fine-tuned network from path
90 | state = torch.load(args.network_path)
91 |
92 | # parsing net params from meta
93 | # architecture, pooling, mean, std required
94 |     # the rest have default values, in case they don't exist
95 | net_params = {}
96 | net_params['architecture'] = state['meta']['architecture']
97 | net_params['pooling'] = state['meta']['pooling']
98 | net_params['local_whitening'] = state['meta'].get('local_whitening', False)
99 | net_params['regional'] = state['meta'].get('regional', False)
100 | net_params['whitening'] = state['meta'].get('whitening', False)
101 | net_params['mean'] = state['meta']['mean']
102 | net_params['std'] = state['meta']['std']
103 | net_params['pretrained'] = False
104 |
105 | # load network
106 | net = init_network(net_params)
107 | net.load_state_dict(state['state_dict'])
108 |
109 | # if whitening is precomputed
110 | if 'Lw' in state['meta']:
111 | net.meta['Lw'] = state['meta']['Lw']
112 |
113 | print(">>>> loaded network: ")
114 | print(net.meta_repr())
115 |
116 | # loading offtheshelf network
117 | elif args.network_offtheshelf is not None:
118 |
119 | # parse off-the-shelf parameters
120 | offtheshelf = args.network_offtheshelf.split('-')
121 | net_params = {}
122 | net_params['architecture'] = offtheshelf[0]
123 | net_params['pooling'] = offtheshelf[1]
124 | net_params['local_whitening'] = 'lwhiten' in offtheshelf[2:]
125 | net_params['regional'] = 'reg' in offtheshelf[2:]
126 | net_params['whitening'] = 'whiten' in offtheshelf[2:]
127 | net_params['pretrained'] = True
128 |
129 | # load off-the-shelf network
130 | print(">> Loading off-the-shelf network:\n>>>> '{}'".format(args.network_offtheshelf))
131 | net = init_network(net_params)
132 | print(">>>> loaded network: ")
133 | print(net.meta_repr())
134 |
135 | # setting up the multi-scale parameters
136 | ms = list(eval(args.multiscale))
137 | if len(ms)>1 and net.meta['pooling'] == 'gem' and not net.meta['regional'] and not net.meta['whitening']:
138 | msp = net.pool.p.item()
139 | print(">> Set-up multiscale:")
140 | print(">>>> ms: {}".format(ms))
141 | print(">>>> msp: {}".format(msp))
142 | else:
143 | msp = 1
144 |
145 | # moving network to gpu and eval mode
146 | net.cuda()
147 | net.eval()
148 |
149 | # set up the transform
150 | normalize = transforms.Normalize(
151 | mean=net.meta['mean'],
152 | std=net.meta['std']
153 | )
154 | transform = transforms.Compose([
155 | transforms.ToTensor(),
156 | normalize
157 | ])
158 |
159 | # compute whitening
160 | if args.whitening is not None:
161 | start = time.time()
162 |
163 | if 'Lw' in net.meta and args.whitening in net.meta['Lw']:
164 |
165 | print('>> {}: Whitening is precomputed, loading it...'.format(args.whitening))
166 |
167 | if len(ms)>1:
168 | Lw = net.meta['Lw'][args.whitening]['ms']
169 | else:
170 | Lw = net.meta['Lw'][args.whitening]['ss']
171 |
172 | else:
173 |
174 | # if we evaluate networks from path we should save/load whitening
175 | # not to compute it every time
176 | if args.network_path is not None:
177 | whiten_fn = args.network_path + '_{}_whiten'.format(args.whitening)
178 | if len(ms) > 1:
179 | whiten_fn += '_ms'
180 | whiten_fn += '.pth'
181 | else:
182 | whiten_fn = None
183 |
184 | if whiten_fn is not None and os.path.isfile(whiten_fn):
185 | print('>> {}: Whitening is precomputed, loading it...'.format(args.whitening))
186 | Lw = torch.load(whiten_fn)
187 |
188 | else:
189 | print('>> {}: Learning whitening...'.format(args.whitening))
190 |
191 | # loading db
192 | db_root = os.path.join(get_data_root(), 'train', args.whitening)
193 | ims_root = os.path.join(db_root, 'ims')
194 | db_fn = os.path.join(db_root, '{}-whiten.pkl'.format(args.whitening))
195 | with open(db_fn, 'rb') as f:
196 | db = pickle.load(f)
197 | images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))]
198 |
199 | # extract whitening vectors
200 | print('>> {}: Extracting...'.format(args.whitening))
201 | wvecs = extract_vectors(net, images, args.image_size, transform, ms=ms, msp=msp)
202 |
203 | # learning whitening
204 | print('>> {}: Learning...'.format(args.whitening))
205 | wvecs = wvecs.numpy()
206 | m, P = whitenlearn(wvecs, db['qidxs'], db['pidxs'])
207 | Lw = {'m': m, 'P': P}
208 |
209 | # saving whitening if whiten_fn exists
210 | if whiten_fn is not None:
211 | print('>> {}: Saving to {}...'.format(args.whitening, whiten_fn))
212 | torch.save(Lw, whiten_fn)
213 |
214 | print('>> {}: elapsed time: {}'.format(args.whitening, htime(time.time()-start)))
215 |
216 | else:
217 | Lw = None
218 |
219 | # evaluate on test datasets
220 | datasets = args.datasets.split(',')
221 | for dataset in datasets:
222 | start = time.time()
223 |
224 | print('>> {}: Extracting...'.format(dataset))
225 |
226 | # prepare config structure for the test dataset
227 | cfg = configdataset(dataset, os.path.join(get_data_root(), 'test'))
228 | images = [cfg['im_fname'](cfg,i) for i in range(cfg['n'])]
229 | qimages = [cfg['qim_fname'](cfg,i) for i in range(cfg['nq'])]
230 | try:
231 | bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]
232 | except:
233 | bbxs = None # for holidaysmanrot and copydays
234 |
235 | # extract database and query vectors
236 | print('>> {}: database images...'.format(dataset))
237 | vecs = extract_vectors(net, images, args.image_size, transform, ms=ms, msp=msp)
238 | print('>> {}: query images...'.format(dataset))
239 | qvecs = extract_vectors(net, qimages, args.image_size, transform, bbxs=bbxs, ms=ms, msp=msp)
240 |
241 | print('>> {}: Evaluating...'.format(dataset))
242 |
243 | # convert to numpy
244 | vecs = vecs.numpy()
245 | qvecs = qvecs.numpy()
246 |
247 | # search, rank, and print
248 | scores = np.dot(vecs.T, qvecs)
249 | ranks = np.argsort(-scores, axis=0)
250 | compute_map_and_print(dataset, ranks, cfg['gnd'])
251 |
252 | if Lw is not None:
253 | # whiten the vectors
254 | vecs_lw = whitenapply(vecs, Lw['m'], Lw['P'])
255 | qvecs_lw = whitenapply(qvecs, Lw['m'], Lw['P'])
256 |
257 | # search, rank, and print
258 | scores = np.dot(vecs_lw.T, qvecs_lw)
259 | ranks = np.argsort(-scores, axis=0)
260 | compute_map_and_print(dataset + ' + whiten', ranks, cfg['gnd'])
261 |
262 | print('>> {}: elapsed time: {}'.format(dataset, htime(time.time()-start)))
263 |
264 |
265 | if __name__ == '__main__':
266 | main()
--------------------------------------------------------------------------------
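The evaluation loop in test.py reduces search and ranking to one matrix product followed by an argsort; a toy sketch on random descriptors (the dimensions are made up, real descriptors come from extract_vectors):

    import numpy as np

    vecs = np.random.randn(4, 6)           # database: D x N, one column per image
    qvecs = np.random.randn(4, 2)          # queries:  D x NQ
    vecs /= np.linalg.norm(vecs, axis=0)   # L2-normalize columns, as the nets do
    qvecs /= np.linalg.norm(qvecs, axis=0)

    scores = np.dot(vecs.T, qvecs)         # cosine similarity (unit-length columns)
    ranks = np.argsort(-scores, axis=0)    # per query: database indexes, best first
    print(ranks[:, 0])                     # ranking for the first query
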
/lshash/lshash.py:
--------------------------------------------------------------------------------
1 | # lshash/lshash.py
2 | # Copyright 2012 Kay Zhu (a.k.a He Zhu) and contributors (see CONTRIBUTORS.txt)
3 | #
4 | # This module is part of lshash and is released under
5 | # the MIT License: http://www.opensource.org/licenses/mit-license.php
6 | # -*- coding: utf-8 -*-
7 | from __future__ import print_function, unicode_literals, division, absolute_import
8 | from builtins import int, round, str, object # noqa
9 | from future import standard_library
10 | standard_library.install_aliases() # noqa: Counter, OrderedDict,
11 | from past.builtins import basestring # noqa:
12 |
13 | import future # noqa
14 | import builtins # noqa
15 | import past # noqa
16 | import six # noqa
17 |
18 | import os
19 | import json
20 | import numpy as np
21 |
22 | try:
23 | from storage import storage # py2
24 | except ImportError:
25 | from .storage import storage # py3
26 |
27 | try:
28 | from bitarray import bitarray
29 | except ImportError:
30 | bitarray = None
31 |
32 |
33 | try:
34 | xrange # py2
35 | except NameError:
36 | xrange = range # py3
37 |
38 |
39 | class LSHash(object):
40 |     """ LSHash implements locality-sensitive hashing using random projection for
41 | input vectors of dimension `input_dim`.
42 |
43 | Attributes:
44 |
45 | :param hash_size:
46 | The length of the resulting binary hash in integer. E.g., 32 means the
47 | resulting binary hash will be 32-bit long.
48 | :param input_dim:
49 | The dimension of the input vector. E.g., a grey-scale picture of 30x30
50 | pixels will have an input dimension of 900.
51 | :param num_hashtables:
52 | (optional) The number of hash tables used for multiple lookups.
53 | :param storage_config:
54 | (optional) A dictionary of the form `{backend_name: config}` where
55 |         `backend_name` is either `dict` or `redis`, and `config` is the
56 | configuration used by the backend. For `redis` it should be in the
57 | format of `{"redis": {"host": hostname, "port": port_num}}`, where
58 | `hostname` is normally `localhost` and `port` is normally 6379.
59 | :param matrices_filename:
60 | (optional) Specify the path to the compressed numpy file ending with
61 | extension `.npz`, where the uniform random planes are stored, or to be
62 | stored if the file does not exist yet.
63 | :param overwrite:
64 |         (optional) Whether to overwrite the matrices file if it already exists
65 | """
66 |
67 | def __init__(self, hash_size, input_dim, num_hashtables=1,
68 | storage_config=None, matrices_filename=None, overwrite=False):
69 |
70 | self.hash_size = hash_size
71 | self.input_dim = input_dim
72 | self.num_hashtables = num_hashtables
73 |
74 | if storage_config is None:
75 | storage_config = {'dict': None}
76 | self.storage_config = storage_config
77 |
78 | if matrices_filename and not matrices_filename.endswith('.npz'):
79 | raise ValueError("The specified file name must end with .npz")
80 | self.matrices_filename = matrices_filename
81 | self.overwrite = overwrite
82 |
83 | self._init_uniform_planes()
84 | self._init_hashtables()
85 |
86 | def _init_uniform_planes(self):
87 |         """ Initialize the uniform random planes used to calculate the hashes.
88 | 
89 |         If the file `self.matrices_filename` exists and `self.overwrite` is
90 |         set, generate new uniform planes and save them to the specified file.
91 | 
92 |         If the file `self.matrices_filename` exists and `self.overwrite` is
93 |         not set, load the matrices with `np.load`.
94 | 
95 |         If the file does not exist (or no filename is given), generate the
96 |         planes in memory; they are saved only when a filename is given.
97 |         """
98 |
99 | if "uniform_planes" in self.__dict__:
100 | return
101 |
102 | if self.matrices_filename:
103 | file_exist = os.path.isfile(self.matrices_filename)
104 | if file_exist and not self.overwrite:
105 | try:
106 | npzfiles = np.load(self.matrices_filename)
107 | except IOError:
108 | print("Cannot load specified file as a numpy array")
109 | raise
110 | else:
111 | npzfiles = sorted(npzfiles.items(), key=lambda x: x[0])
112 | self.uniform_planes = [t[1] for t in npzfiles]
113 | else:
114 | self.uniform_planes = [self._generate_uniform_planes()
115 | for _ in xrange(self.num_hashtables)]
116 | try:
117 | np.savez_compressed(self.matrices_filename,
118 | *self.uniform_planes)
119 | except IOError:
120 |                     print("IOError when saving matrices to specified path")
121 | raise
122 | else:
123 | self.uniform_planes = [self._generate_uniform_planes()
124 | for _ in xrange(self.num_hashtables)]
125 |
126 | def _init_hashtables(self):
127 | """ Initialize the hash tables such that each record will be in the
128 | form of "[storage1, storage2, ...]" """
129 |
130 | self.hash_tables = [storage(self.storage_config, i)
131 | for i in xrange(self.num_hashtables)]
132 |
133 | def _generate_uniform_planes(self):
134 | """ Generate uniformly distributed hyperplanes and return it as a 2D
135 | numpy array.
136 | """
137 |
138 | return np.random.randn(self.hash_size, self.input_dim)
139 |
140 | def _hash(self, planes, input_point):
141 | """ Generates the binary hash for `input_point` and returns it.
142 |
143 | :param planes:
144 | The planes are random uniform planes with a dimension of
145 | `hash_size` * `input_dim`.
146 | :param input_point:
147 | A Python tuple or list object that contains only numbers.
148 | The dimension needs to be 1 * `input_dim`.
149 | """
150 |
151 | try:
152 | input_point = np.array(input_point) # for faster dot product
153 | projections = np.dot(planes, input_point)
154 | except TypeError as e:
155 | print("""The input point needs to be an array-like object with
156 | numbers only elements""")
157 | raise
158 | except ValueError as e:
159 | print("""The input point needs to be of the same dimension as
160 | `input_dim` when initializing this LSHash instance""", e)
161 | raise
162 | else:
163 | return "".join(['1' if i > 0 else '0' for i in projections])
164 |
165 | def _as_np_array(self, json_or_tuple):
166 | """ Takes either a JSON-serialized data structure or a tuple that has
167 | the original input points stored, and returns the original input point
168 | in numpy array format.
169 | """
170 | if isinstance(json_or_tuple, basestring):
171 | # JSON-serialized in the case of Redis
172 | try:
173 | # Return the point stored as list, without the extra data
174 | tuples = json.loads(json_or_tuple)[0]
175 | except TypeError:
176 |                 print("The value stored is not JSON-serializable")
177 | raise
178 | else:
179 | # If extra_data exists, `tuples` is the entire
180 | # (point:tuple, extra_data). Otherwise (i.e., extra_data=None),
181 | # return the point stored as a tuple
182 | tuples = json_or_tuple
183 |
184 | if isinstance(tuples[0], tuple):
185 | # in this case extra data exists
186 | return np.asarray(tuples[0])
187 |
188 | elif isinstance(tuples, (tuple, list)):
189 | try:
190 | return np.asarray(tuples)
191 | except ValueError as e:
192 | print("The input needs to be an array-like object", e)
193 | raise
194 | else:
195 | raise TypeError("query data is not supported")
196 |
197 | def index(self, input_point, extra_data=None):
198 | """ Index a single input point by adding it to the selected storage.
199 |
200 | If `extra_data` is provided, it will become the value of the dictionary
201 | {input_point: extra_data}, which in turn will become the value of the
202 | hash table. `extra_data` needs to be JSON serializable if in-memory
203 | dict is not used as storage.
204 |
205 | :param input_point:
206 | A list, or tuple, or numpy ndarray object that contains numbers
207 | only. The dimension needs to be 1 * `input_dim`.
208 | This object will be converted to Python tuple and stored in the
209 | selected storage.
210 | :param extra_data:
211 | (optional) Needs to be a JSON-serializable object: list, dicts and
212 | basic types such as strings and integers.
213 | """
214 |
215 | if isinstance(input_point, np.ndarray):
216 | input_point = input_point.tolist()
217 |
218 | if extra_data:
219 | value = (tuple(input_point), extra_data)
220 | else:
221 | value = tuple(input_point)
222 |
223 | for i, table in enumerate(self.hash_tables):
224 | table.append_val(self._hash(self.uniform_planes[i], input_point),
225 | value)
226 |
227 | def query(self, query_point, num_results=None, distance_func=None):
228 | """ Takes `query_point` which is either a tuple or a list of numbers,
229 | returns `num_results` of results as a list of tuples that are ranked
230 | based on the supplied metric function `distance_func`.
231 |
232 | :param query_point:
233 | A list, or tuple, or numpy ndarray that only contains numbers.
234 | The dimension needs to be 1 * `input_dim`.
235 | Used by :meth:`._hash`.
236 | :param num_results:
237 | (optional) Integer, specifies the max amount of results to be
238 | returned. If not specified all candidates will be returned as a
239 | list in ranked order.
240 | :param distance_func:
241 | (optional) The distance function to be used. Currently it needs to
242 | be one of ("hamming", "euclidean", "true_euclidean",
243 | "centred_euclidean", "cosine", "l1norm"). By default "euclidean"
244 |                 will be used.
245 | """
246 |
247 | candidates = set()
248 | if not distance_func:
249 | distance_func = "euclidean"
250 |
251 | if distance_func == "hamming":
252 | if not bitarray:
253 | raise ImportError(" Bitarray is required for hamming distance")
254 |
255 | for i, table in enumerate(self.hash_tables):
256 | binary_hash = self._hash(self.uniform_planes[i], query_point)
257 | for key in table.keys():
258 | distance = LSHash.hamming_dist(key, binary_hash)
259 | if distance < 2:
260 | candidates.update(table.get_list(key))
261 |
262 | d_func = LSHash.euclidean_dist_square
263 |
264 | else:
265 |
266 | if distance_func == "euclidean":
267 | d_func = LSHash.euclidean_dist_square
268 | elif distance_func == "true_euclidean":
269 | d_func = LSHash.euclidean_dist
270 | elif distance_func == "centred_euclidean":
271 | d_func = LSHash.euclidean_dist_centred
272 | elif distance_func == "cosine":
273 | d_func = LSHash.cosine_dist
274 | elif distance_func == "l1norm":
275 | d_func = LSHash.l1norm_dist
276 | else:
277 | raise ValueError("The distance function name is invalid.")
278 |
279 | for i, table in enumerate(self.hash_tables):
280 | binary_hash = self._hash(self.uniform_planes[i], query_point)
281 | candidates.update(table.get_list(binary_hash))
282 |
283 | # rank candidates by distance function
284 | candidates = [(ix, d_func(query_point, self._as_np_array(ix)))
285 | for ix in candidates]
286 | candidates = sorted(candidates, key=lambda x: x[1])
287 |
288 | return candidates[:num_results] if num_results else candidates
289 |
290 | ### distance functions
291 |
292 | @staticmethod
293 | def hamming_dist(bitarray1, bitarray2):
294 | xor_result = bitarray(bitarray1) ^ bitarray(bitarray2)
295 | return xor_result.count()
296 |
297 | @staticmethod
298 | def euclidean_dist(x, y):
299 | """ This is a hot function, hence some optimizations are made. """
300 | diff = np.array(x) - y
301 | return np.sqrt(np.dot(diff, diff))
302 |
303 | @staticmethod
304 | def euclidean_dist_square(x, y):
305 | """ This is a hot function, hence some optimizations are made. """
306 | diff = np.array(x) - y
307 | return np.dot(diff, diff)
308 |
309 | @staticmethod
310 | def euclidean_dist_centred(x, y):
311 | """ This is a hot function, hence some optimizations are made. """
312 | diff = np.mean(x) - np.mean(y)
313 | return np.dot(diff, diff)
314 |
315 | @staticmethod
316 | def l1norm_dist(x, y):
317 | return sum(abs(x - y))
318 |
319 | @staticmethod
320 | def cosine_dist(x, y):
321 | return 1 - float(np.dot(x, y)) / ((np.dot(x, x) * np.dot(y, y)) ** 0.5)
322 |
--------------------------------------------------------------------------------
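A minimal usage sketch for the LSHash class above, assuming the default in-memory dict backend; the vector sizes and the img_* labels are illustrative:

    import numpy as np
    from lshash.lshash import LSHash

    # 8-bit hashes over 32-dimensional vectors, two hash tables for higher recall
    lsh = LSHash(hash_size=8, input_dim=32, num_hashtables=2)

    rng = np.random.RandomState(0)
    for i in range(100):
        lsh.index(rng.randn(32), extra_data="img_{}".format(i))

    # candidates sharing a bucket with the query, ranked by squared euclidean distance
    for (point, extra), dist in lsh.query(rng.randn(32), num_results=3):
        print(extra, dist)
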
/cirtorch/networks/imageretrievalnet_cpu.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.utils.model_zoo as model_zoo
7 |
8 | import torchvision
9 |
10 | from cirtorch.layers.pooling import MAC, SPoC, GeM, GeMmp, RMAC, Rpool
11 | from cirtorch.layers.normalization import L2N, PowerLaw
12 | from cirtorch.datasets.genericdataset import ImagesFromList
13 | from cirtorch.utils.general import get_data_root
14 |
15 | # for some models, we have imported features (convolutions) from caffe because the image retrieval performance is higher for them
16 | FEATURES = {
17 | 'vgg16' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-vgg16-features-d369c8e.pth',
18 | 'resnet50' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet50-features-ac468af.pth',
19 | 'resnet101' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet101-features-10a101d.pth',
20 | 'resnet152' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet152-features-1011020.pth',
21 | }
22 |
23 | # TODO: pre-compute for more architectures and properly test variations (pre l2norm, post l2norm)
24 | # pre-computed local pca whitening that can be applied before the pooling layer
25 | L_WHITENING = {
26 | 'resnet101' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-9f830ef.pth', # no pre l2 norm
27 | # 'resnet101' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-da5c935.pth', # with pre l2 norm
28 | }
29 |
30 | # possible global pooling layers, each one of these can be made regional
31 | POOLING = {
32 | 'mac' : MAC,
33 | 'spoc' : SPoC,
34 | 'gem' : GeM,
35 | 'gemmp' : GeMmp,
36 | 'rmac' : RMAC,
37 | }
38 |
39 | # TODO: pre-compute for: resnet50-gem-r, resnet50-mac-r, vgg16-mac-r, alexnet-mac-r
40 | # pre-computed regional whitening, for most commonly used architectures and pooling methods
41 | R_WHITENING = {
42 | 'alexnet-gem-r' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-rwhiten-c8cf7e2.pth',
43 | 'vgg16-gem-r' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-rwhiten-19b204e.pth',
44 | 'resnet101-mac-r' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-rwhiten-7f1ed8c.pth',
45 | 'resnet101-gem-r' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-rwhiten-adace84.pth',
46 | }
47 |
48 | # TODO: pre-compute for more architectures
49 | # pre-computed final (global) whitening, for most commonly used architectures and pooling methods
50 | WHITENING = {
51 | 'alexnet-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-whiten-454ad53.pth',
52 | 'alexnet-gem-r' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-whiten-4c9126b.pth',
53 | 'vgg16-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-whiten-eaa6695.pth',
54 | 'vgg16-gem-r' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-whiten-83582df.pth',
55 | 'resnet50-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet50-gem-whiten-f15da7b.pth',
56 | 'resnet101-mac-r' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-whiten-9df41d3.pth',
57 | 'resnet101-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-whiten-22ab0c1.pth',
58 | 'resnet101-gem-r' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-whiten-b379c0a.pth',
59 | 'resnet101-gemmp' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gemmp-whiten-770f53c.pth',
60 | 'resnet152-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet152-gem-whiten-abe7b93.pth',
61 | 'densenet121-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet121-gem-whiten-79e3eea.pth',
62 | 'densenet169-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet169-gem-whiten-6b2a76a.pth',
63 | 'densenet201-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet201-gem-whiten-22ea45c.pth',
64 | }
65 |
66 | # output dimensionality for supported architectures
67 | OUTPUT_DIM = {
68 | 'alexnet' : 256,
69 | 'vgg11' : 512,
70 | 'vgg13' : 512,
71 | 'vgg16' : 512,
72 | 'vgg19' : 512,
73 | 'resnet18' : 512,
74 | 'resnet34' : 512,
75 | 'resnet50' : 2048,
76 | 'resnet101' : 2048,
77 | 'resnet152' : 2048,
78 | 'densenet121' : 1024,
79 | 'densenet169' : 1664,
80 | 'densenet201' : 1920,
81 | 'densenet161' : 2208, # largest densenet
82 | 'squeezenet1_0' : 512,
83 | 'squeezenet1_1' : 512,
84 | }
85 |
86 |
87 | class ImageRetrievalNet(nn.Module):
88 |
89 | def __init__(self, features, lwhiten, pool, whiten, meta):
90 | super(ImageRetrievalNet, self).__init__()
91 | self.features = nn.Sequential(*features)
92 | self.lwhiten = lwhiten
93 | self.pool = pool
94 | self.whiten = whiten
95 | self.norm = L2N()
96 | self.meta = meta
97 |
98 | def forward(self, x):
99 | # x -> features
100 | o = self.features(x)
101 |
102 | # TODO: properly test (with pre-l2norm and/or post-l2norm)
103 | # if lwhiten exist: features -> local whiten
104 | if self.lwhiten is not None:
105 | # o = self.norm(o)
106 | s = o.size()
107 | o = o.permute(0,2,3,1).contiguous().view(-1, s[1])
108 | o = self.lwhiten(o)
109 | o = o.view(s[0],s[2],s[3],self.lwhiten.out_features).permute(0,3,1,2)
110 | # o = self.norm(o)
111 |
112 | # features -> pool -> norm
113 | o = self.norm(self.pool(o)).squeeze(-1).squeeze(-1)
114 |
115 | # if whiten exist: pooled features -> whiten -> norm
116 | if self.whiten is not None:
117 | o = self.norm(self.whiten(o))
118 |
119 | # permute so that it is Dx1 column vector per image (DxN if many images)
120 | return o.permute(1,0)
121 |
122 | def __repr__(self):
123 | tmpstr = super(ImageRetrievalNet, self).__repr__()[:-1]
124 | tmpstr += self.meta_repr()
125 | tmpstr = tmpstr + ')'
126 | return tmpstr
127 |
128 | def meta_repr(self):
129 | tmpstr = ' (' + 'meta' + '): dict( \n' # + self.meta.__repr__() + '\n'
130 | tmpstr += ' architecture: {}\n'.format(self.meta['architecture'])
131 | tmpstr += ' local_whitening: {}\n'.format(self.meta['local_whitening'])
132 | tmpstr += ' pooling: {}\n'.format(self.meta['pooling'])
133 | tmpstr += ' regional: {}\n'.format(self.meta['regional'])
134 | tmpstr += ' whitening: {}\n'.format(self.meta['whitening'])
135 | tmpstr += ' outputdim: {}\n'.format(self.meta['outputdim'])
136 | tmpstr += ' mean: {}\n'.format(self.meta['mean'])
137 | tmpstr += ' std: {}\n'.format(self.meta['std'])
138 | tmpstr = tmpstr + ' )\n'
139 | return tmpstr
140 |
141 |
142 | def init_network(params):
143 |
144 | # parse params with default values
145 | architecture = params.get('architecture', 'resnet101')
146 | local_whitening = params.get('local_whitening', False)
147 | pooling = params.get('pooling', 'gem')
148 | regional = params.get('regional', False)
149 | whitening = params.get('whitening', False)
150 | mean = params.get('mean', [0.485, 0.456, 0.406])
151 | std = params.get('std', [0.229, 0.224, 0.225])
152 | pretrained = params.get('pretrained', True)
153 |
154 | # get output dimensionality size
155 | dim = OUTPUT_DIM[architecture]
156 |
157 | # loading network from torchvision
158 | if pretrained:
159 | if architecture not in FEATURES:
160 | # initialize with network pretrained on imagenet in pytorch
161 | net_in = getattr(torchvision.models, architecture)(pretrained=True)
162 | else:
163 | # initialize with random weights, later on we will fill features with custom pretrained network
164 | net_in = getattr(torchvision.models, architecture)(pretrained=False)
165 | else:
166 | # initialize with random weights
167 | net_in = getattr(torchvision.models, architecture)(pretrained=False)
168 |
169 | # initialize features
170 | # take only convolutions for features,
171 | # always ends with ReLU to make last activations non-negative
172 | if architecture.startswith('alexnet'):
173 | features = list(net_in.features.children())[:-1]
174 | elif architecture.startswith('vgg'):
175 | features = list(net_in.features.children())[:-1]
176 | elif architecture.startswith('resnet'):
177 | features = list(net_in.children())[:-2]
178 | elif architecture.startswith('densenet'):
179 | features = list(net_in.features.children())
180 | features.append(nn.ReLU(inplace=True))
181 | elif architecture.startswith('squeezenet'):
182 | features = list(net_in.features.children())
183 | else:
184 | raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))
185 |
186 | # initialize local whitening
187 | if local_whitening:
188 | lwhiten = nn.Linear(dim, dim, bias=True)
189 | # TODO: lwhiten with possible dimensionality reduce
190 |
191 | if pretrained:
192 | lw = architecture
193 | if lw in L_WHITENING:
194 | print(">> {}: for '{}' custom computed local whitening '{}' is used"
195 | .format(os.path.basename(__file__), lw, os.path.basename(L_WHITENING[lw])))
196 | whiten_dir = os.path.join(get_data_root(), 'whiten')
197 | lwhiten.load_state_dict(model_zoo.load_url(L_WHITENING[lw], model_dir=whiten_dir))
198 | else:
199 | print(">> {}: for '{}' there is no local whitening computed, random weights are used"
200 | .format(os.path.basename(__file__), lw))
201 |
202 | else:
203 | lwhiten = None
204 |
205 | # initialize pooling
206 | if pooling == 'gemmp':
207 | pool = POOLING[pooling](mp=dim)
208 | else:
209 | pool = POOLING[pooling]()
210 |
211 | # initialize regional pooling
212 | if regional:
213 | rpool = pool
214 | rwhiten = nn.Linear(dim, dim, bias=True)
215 | # TODO: rwhiten with possible dimensionality reduce
216 |
217 | if pretrained:
218 | rw = '{}-{}-r'.format(architecture, pooling)
219 | if rw in R_WHITENING:
220 | print(">> {}: for '{}' custom computed regional whitening '{}' is used"
221 | .format(os.path.basename(__file__), rw, os.path.basename(R_WHITENING[rw])))
222 | whiten_dir = os.path.join(get_data_root(), 'whiten')
223 | rwhiten.load_state_dict(model_zoo.load_url(R_WHITENING[rw], model_dir=whiten_dir))
224 | else:
225 | print(">> {}: for '{}' there is no regional whitening computed, random weights are used"
226 | .format(os.path.basename(__file__), rw))
227 |
228 | pool = Rpool(rpool, rwhiten)
229 |
230 | # initialize whitening
231 | if whitening:
232 | whiten = nn.Linear(dim, dim, bias=True)
233 | # TODO: whiten with possible dimensionality reduce
234 |
235 | if pretrained:
236 | w = architecture
237 | if local_whitening:
238 | w += '-lw'
239 | w += '-' + pooling
240 | if regional:
241 | w += '-r'
242 | if w in WHITENING:
243 | print(">> {}: for '{}' custom computed whitening '{}' is used"
244 | .format(os.path.basename(__file__), w, os.path.basename(WHITENING[w])))
245 | whiten_dir = os.path.join(get_data_root(), 'whiten')
246 | whiten.load_state_dict(model_zoo.load_url(WHITENING[w], model_dir=whiten_dir))
247 | else:
248 | print(">> {}: for '{}' there is no whitening computed, random weights are used"
249 | .format(os.path.basename(__file__), w))
250 | else:
251 | whiten = None
252 |
253 | # create meta information to be stored in the network
254 | meta = {
255 | 'architecture' : architecture,
256 | 'local_whitening' : local_whitening,
257 | 'pooling' : pooling,
258 | 'regional' : regional,
259 | 'whitening' : whitening,
260 | 'mean' : mean,
261 | 'std' : std,
262 | 'outputdim' : dim,
263 | }
264 |
265 | # create a generic image retrieval network
266 | net = ImageRetrievalNet(features, lwhiten, pool, whiten, meta)
267 |
268 | # initialize features with custom pretrained network if needed
269 | if pretrained and architecture in FEATURES:
270 | print(">> {}: for '{}' custom pretrained features '{}' are used"
271 | .format(os.path.basename(__file__), architecture, os.path.basename(FEATURES[architecture])))
272 | model_dir = os.path.join(get_data_root(), 'networks')
273 | net.features.load_state_dict(model_zoo.load_url(FEATURES[architecture], model_dir=model_dir))
274 |
275 | return net
276 |
277 |
278 | def extract_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
279 |     # cpu variant: keep the network on the cpu, switch to eval mode
280 |     net.cpu()
281 | net.eval()
282 |
283 | # creating dataset loader
284 | loader = torch.utils.data.DataLoader(
285 | ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
286 | batch_size=1, shuffle=False, num_workers=8, pin_memory=True
287 | )
288 |
289 | # extracting vectors
290 | with torch.no_grad():
291 | vecs = torch.zeros(net.meta['outputdim'], len(images))
292 | for i, input in enumerate(loader):
293 |             # cpu variant: the input tensor stays on the cpu
294 |
295 | if len(ms) == 1 and ms[0] == 1:
296 | vecs[:, i] = extract_ss(net, input)
297 | else:
298 | vecs[:, i] = extract_ms(net, input, ms, msp)
299 |
300 | if (i+1) % print_freq == 0 or (i+1) == len(images):
301 | print('\r>>>> {}/{} done...'.format((i+1), len(images)), end='')
302 | print('')
303 |
304 | return vecs
305 |
306 | def extract_ss(net, input):
307 | return net(input).cpu().data.squeeze()
308 |
309 | def extract_ms(net, input, ms, msp):
310 |
311 | v = torch.zeros(net.meta['outputdim'])
312 |
313 | for s in ms:
314 | if s == 1:
315 | input_t = input.clone()
316 | else:
317 | input_t = nn.functional.interpolate(input, scale_factor=s, mode='bilinear', align_corners=False)
318 | v += net(input_t).pow(msp).cpu().data.squeeze()
319 |
320 | v /= len(ms)
321 | v = v.pow(1./msp)
322 | v /= v.norm()
323 |
324 | return v
325 |
326 |
327 | def extract_regional_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
328 |     # cpu variant: keep the network on the cpu, switch to eval mode
329 |     net.cpu()
330 | net.eval()
331 |
332 | # creating dataset loader
333 | loader = torch.utils.data.DataLoader(
334 | ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
335 | batch_size=1, shuffle=False, num_workers=8, pin_memory=True
336 | )
337 |
338 | # extracting vectors
339 | with torch.no_grad():
340 | vecs = []
341 | for i, input in enumerate(loader):
342 |             # cpu variant: the input tensor stays on the cpu
343 |
344 | if len(ms) == 1:
345 | vecs.append(extract_ssr(net, input))
346 | else:
347 | # TODO: not implemented yet
348 | # vecs.append(extract_msr(net, input, ms, msp))
349 | raise NotImplementedError
350 |
351 | if (i+1) % print_freq == 0 or (i+1) == len(images):
352 | print('\r>>>> {}/{} done...'.format((i+1), len(images)), end='')
353 | print('')
354 |
355 | return vecs
356 |
357 | def extract_ssr(net, input):
358 | return net.pool(net.features(input), aggregate=False).squeeze(0).squeeze(-1).squeeze(-1).permute(1,0).cpu().data
359 |
360 |
361 | def extract_local_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
362 |     # cpu variant: keep the network on the cpu, switch to eval mode
363 |     net.cpu()
364 | net.eval()
365 |
366 | # creating dataset loader
367 | loader = torch.utils.data.DataLoader(
368 | ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
369 | batch_size=1, shuffle=False, num_workers=8, pin_memory=True
370 | )
371 |
372 | # extracting vectors
373 | with torch.no_grad():
374 | vecs = []
375 | for i, input in enumerate(loader):
376 |             # cpu variant: the input tensor stays on the cpu
377 |
378 | if len(ms) == 1:
379 | vecs.append(extract_ssl(net, input))
380 | else:
381 | # TODO: not implemented yet
382 | # vecs.append(extract_msl(net, input, ms, msp))
383 | raise NotImplementedError
384 |
385 | if (i+1) % print_freq == 0 or (i+1) == len(images):
386 | print('\r>>>> {}/{} done...'.format((i+1), len(images)), end='')
387 | print('')
388 |
389 | return vecs
390 |
391 | def extract_ssl(net, input):
392 | return net.norm(net.features(input)).squeeze(0).view(net.meta['outputdim'], -1).cpu().data
--------------------------------------------------------------------------------
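A minimal end-to-end sketch of the API above, assuming the cpu variant and two local image files (the paths and the 1024px longer-side size are placeholders; torchvision weights are downloaded on first use):

    from torchvision import transforms
    from cirtorch.networks.imageretrievalnet_cpu import init_network, extract_vectors

    # off-the-shelf backbone with GeM pooling; resnet18 keeps the example small
    net = init_network({'architecture': 'resnet18', 'pooling': 'gem', 'pretrained': True})

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=net.meta['mean'], std=net.meta['std']),
    ])

    images = ['query.jpg', 'db.jpg']  # hypothetical image paths
    vecs = extract_vectors(net, images, 1024, transform)  # D x N, one column per image
    print(vecs.shape)
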
/cirtorch/networks/imageretrievalnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.utils.model_zoo as model_zoo
7 |
8 | import torchvision
9 |
10 | from cirtorch.layers.pooling import MAC, SPoC, GeM, GeMmp, RMAC, Rpool
11 | from cirtorch.layers.normalization import L2N, PowerLaw
12 | from cirtorch.datasets.genericdataset import ImagesFromList
13 | from cirtorch.utils.general import get_data_root
14 | from PIL import Image
15 | from ModelHelper.Common.CommonUtils.ImageAugmentation import Padding
16 | import cv2
17 |
18 | # for some models, we have imported features (convolutions) from caffe because the image retrieval performance is higher for them
19 | FEATURES = {
20 | 'vgg16': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-vgg16-features-d369c8e.pth',
21 | 'resnet50': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet50-features-ac468af.pth',
22 | 'resnet101': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet101-features-10a101d.pth',
23 | 'resnet152': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet152-features-1011020.pth',
24 | }
25 |
26 | # TODO: pre-compute for more architectures and properly test variations (pre l2norm, post l2norm)
27 | # pre-computed local pca whitening that can be applied before the pooling layer
28 | L_WHITENING = {
29 | 'resnet101': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-9f830ef.pth',
30 | # no pre l2 norm
31 | # 'resnet101' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-da5c935.pth', # with pre l2 norm
32 | }
33 |
34 | # possible global pooling layers, each one of these can be made regional
35 | POOLING = {
36 | 'mac': MAC,
37 | 'spoc': SPoC,
38 | 'gem': GeM,
39 | 'gemmp': GeMmp,
40 | 'rmac': RMAC,
41 | }
42 |
43 | # TODO: pre-compute for: resnet50-gem-r, resnet50-mac-r, vgg16-mac-r, alexnet-mac-r
44 | # pre-computed regional whitening, for most commonly used architectures and pooling methods
45 | R_WHITENING = {
46 | 'alexnet-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-rwhiten-c8cf7e2.pth',
47 | 'vgg16-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-rwhiten-19b204e.pth',
48 | 'resnet101-mac-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-rwhiten-7f1ed8c.pth',
49 | 'resnet101-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-rwhiten-adace84.pth',
50 | }
51 |
52 | # TODO: pre-compute for more architectures
53 | # pre-computed final (global) whitening, for most commonly used architectures and pooling methods
54 | WHITENING = {
55 | 'alexnet-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-whiten-454ad53.pth',
56 | 'alexnet-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-whiten-4c9126b.pth',
57 | 'vgg16-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-whiten-eaa6695.pth',
58 | 'vgg16-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-whiten-83582df.pth',
59 | 'resnet50-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet50-gem-whiten-f15da7b.pth',
60 | 'resnet101-mac-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-whiten-9df41d3.pth',
61 | 'resnet101-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-whiten-22ab0c1.pth',
62 | 'resnet101-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-whiten-b379c0a.pth',
63 | 'resnet101-gemmp': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gemmp-whiten-770f53c.pth',
64 | 'resnet152-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet152-gem-whiten-abe7b93.pth',
65 | 'densenet121-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet121-gem-whiten-79e3eea.pth',
66 | 'densenet169-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet169-gem-whiten-6b2a76a.pth',
67 | 'densenet201-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet201-gem-whiten-22ea45c.pth',
68 | }
69 |
70 | # output dimensionality for supported architectures
71 | OUTPUT_DIM = {
72 | 'alexnet': 256,
73 | 'vgg11': 512,
74 | 'vgg13': 512,
75 | 'vgg16': 512,
76 | 'vgg19': 512,
77 | 'resnet18': 512,
78 | 'resnet34': 512,
79 | 'resnet50': 2048,
80 | 'resnet101': 2048,
81 | 'resnet152': 2048,
82 | 'densenet121': 1024,
83 | 'densenet169': 1664,
84 | 'densenet201': 1920,
85 | 'densenet161': 2208, # largest densenet
86 | 'squeezenet1_0': 512,
87 | 'squeezenet1_1': 512,
88 | }
89 |
90 |
91 | class ImageRetrievalNet(nn.Module):
92 |
93 | def __init__(self, features, lwhiten, pool, whiten, meta):
94 | super(ImageRetrievalNet, self).__init__()
95 | self.features = nn.Sequential(*features)
96 | self.lwhiten = lwhiten
97 | self.pool = pool
98 | self.whiten = whiten
99 | self.norm = L2N()
100 | self.meta = meta
101 |
102 | def forward(self, x):
103 | # x -> features
104 | o = self.features(x)
105 |
106 | # TODO: properly test (with pre-l2norm and/or post-l2norm)
107 | # if lwhiten exist: features -> local whiten
108 | if self.lwhiten is not None:
109 | # o = self.norm(o)
110 | s = o.size()
111 | o = o.permute(0, 2, 3, 1).contiguous().view(-1, s[1])
112 | o = self.lwhiten(o)
113 | o = o.view(s[0], s[2], s[3], self.lwhiten.out_features).permute(0, 3, 1, 2)
114 | # o = self.norm(o)
115 |
116 | # features -> pool -> norm
117 | o = self.norm(self.pool(o)).squeeze(-1).squeeze(-1)
118 |
119 | # if whiten exist: pooled features -> whiten -> norm
120 | if self.whiten is not None:
121 | o = self.norm(self.whiten(o))
122 |
123 | # permute so that it is Dx1 column vector per image (DxN if many images)
124 | return o.permute(1, 0)
125 |
126 | def __repr__(self):
127 | tmpstr = super(ImageRetrievalNet, self).__repr__()[:-1]
128 | tmpstr += self.meta_repr()
129 | tmpstr = tmpstr + ')'
130 | return tmpstr
131 |
132 | def meta_repr(self):
133 | tmpstr = ' (' + 'meta' + '): dict( \n' # + self.meta.__repr__() + '\n'
134 | tmpstr += ' architecture: {}\n'.format(self.meta['architecture'])
135 | tmpstr += ' local_whitening: {}\n'.format(self.meta['local_whitening'])
136 | tmpstr += ' pooling: {}\n'.format(self.meta['pooling'])
137 | tmpstr += ' regional: {}\n'.format(self.meta['regional'])
138 | tmpstr += ' whitening: {}\n'.format(self.meta['whitening'])
139 | tmpstr += ' outputdim: {}\n'.format(self.meta['outputdim'])
140 | tmpstr += ' mean: {}\n'.format(self.meta['mean'])
141 | tmpstr += ' std: {}\n'.format(self.meta['std'])
142 | tmpstr = tmpstr + ' )\n'
143 | return tmpstr
144 |
145 |
146 | def init_network(params):
147 | # parse params with default values
148 | architecture = params.get('architecture', 'resnet101')
149 | local_whitening = params.get('local_whitening', False)
150 | pooling = params.get('pooling', 'gem')
151 | regional = params.get('regional', False)
152 | whitening = params.get('whitening', False)
153 | mean = params.get('mean', [0.485, 0.456, 0.406])
154 | std = params.get('std', [0.229, 0.224, 0.225])
155 | pretrained = params.get('pretrained', True)
156 |
157 | # get output dimensionality size
158 | dim = OUTPUT_DIM[architecture]
159 |
160 | # loading network from torchvision
161 | if pretrained:
162 | if architecture not in FEATURES:
163 | # initialize with network pretrained on imagenet in pytorch
164 | net_in = getattr(torchvision.models, architecture)(pretrained=True)
165 | else:
166 | # initialize with random weights, later on we will fill features with custom pretrained network
167 | net_in = getattr(torchvision.models, architecture)(pretrained=False)
168 | else:
169 | # initialize with random weights
170 | net_in = getattr(torchvision.models, architecture)(pretrained=False)
171 |
172 | # initialize features
173 | # take only convolutions for features,
174 | # always ends with ReLU to make last activations non-negative
175 | if architecture.startswith('alexnet'):
176 | features = list(net_in.features.children())[:-1]
177 | elif architecture.startswith('vgg'):
178 | features = list(net_in.features.children())[:-1]
179 | elif architecture.startswith('resnet'):
180 | features = list(net_in.children())[:-2]
181 | elif architecture.startswith('densenet'):
182 | features = list(net_in.features.children())
183 | features.append(nn.ReLU(inplace=True))
184 | elif architecture.startswith('squeezenet'):
185 | features = list(net_in.features.children())
186 | else:
187 | raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))
188 |
189 | # initialize local whitening
190 | if local_whitening:
191 | lwhiten = nn.Linear(dim, dim, bias=True)
192 | # TODO: lwhiten with possible dimensionality reduce
193 |
194 | if pretrained:
195 | lw = architecture
196 | if lw in L_WHITENING:
197 | print(">> {}: for '{}' custom computed local whitening '{}' is used"
198 | .format(os.path.basename(__file__), lw, os.path.basename(L_WHITENING[lw])))
199 | whiten_dir = os.path.join(get_data_root(), 'whiten')
200 | lwhiten.load_state_dict(model_zoo.load_url(L_WHITENING[lw], model_dir=whiten_dir))
201 | else:
202 | print(">> {}: for '{}' there is no local whitening computed, random weights are used"
203 | .format(os.path.basename(__file__), lw))
204 |
205 | else:
206 | lwhiten = None
207 |
208 | # initialize pooling
209 | if pooling == 'gemmp':
210 | pool = POOLING[pooling](mp=dim)
211 | else:
212 | pool = POOLING[pooling]()
213 |
214 | # initialize regional pooling
215 | if regional:
216 | rpool = pool
217 | rwhiten = nn.Linear(dim, dim, bias=True)
218 | # TODO: rwhiten with possible dimensionality reduce
219 |
220 | if pretrained:
221 | rw = '{}-{}-r'.format(architecture, pooling)
222 | if rw in R_WHITENING:
223 | print(">> {}: for '{}' custom computed regional whitening '{}' is used"
224 | .format(os.path.basename(__file__), rw, os.path.basename(R_WHITENING[rw])))
225 | whiten_dir = os.path.join(get_data_root(), 'whiten')
226 | rwhiten.load_state_dict(model_zoo.load_url(R_WHITENING[rw], model_dir=whiten_dir))
227 | else:
228 | print(">> {}: for '{}' there is no regional whitening computed, random weights are used"
229 | .format(os.path.basename(__file__), rw))
230 |
231 | pool = Rpool(rpool, rwhiten)
232 |
233 | # initialize whitening
234 | if whitening:
235 | whiten = nn.Linear(dim, dim, bias=True)
236 | # TODO: whiten with possible dimensionality reduce
237 |
238 | if pretrained:
239 | w = architecture
240 | if local_whitening:
241 | w += '-lw'
242 | w += '-' + pooling
243 | if regional:
244 | w += '-r'
245 | if w in WHITENING:
246 | print(">> {}: for '{}' custom computed whitening '{}' is used"
247 | .format(os.path.basename(__file__), w, os.path.basename(WHITENING[w])))
248 | whiten_dir = os.path.join(get_data_root(), 'whiten')
249 | whiten.load_state_dict(model_zoo.load_url(WHITENING[w], model_dir=whiten_dir))
250 | else:
251 | print(">> {}: for '{}' there is no whitening computed, random weights are used"
252 | .format(os.path.basename(__file__), w))
253 | else:
254 | whiten = None
255 |
256 | # create meta information to be stored in the network
257 | meta = {
258 | 'architecture': architecture,
259 | 'local_whitening': local_whitening,
260 | 'pooling': pooling,
261 | 'regional': regional,
262 | 'whitening': whitening,
263 | 'mean': mean,
264 | 'std': std,
265 | 'outputdim': dim,
266 | }
267 |
268 | # create a generic image retrieval network
269 | net = ImageRetrievalNet(features, lwhiten, pool, whiten, meta)
270 |
271 | # initialize features with custom pretrained network if needed
272 | if pretrained and architecture in FEATURES:
273 | print(">> {}: for '{}' custom pretrained features '{}' are used"
274 | .format(os.path.basename(__file__), architecture, os.path.basename(FEATURES[architecture])))
275 | model_dir = os.path.join(get_data_root(), 'networks')
276 | net.features.load_state_dict(model_zoo.load_url(FEATURES[architecture], model_dir=model_dir))
277 |
278 | return net
279 |
280 |
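# --- editor's usage sketch (added; not part of the upstream file) ---
# Minimal driving example for init_network(); the keys mirror the
# params.get(...) defaults above, and pretrained=False keeps the sketch
# download-free. Defined but never called.
def _demo_init_network():
    params = {'architecture': 'resnet101', 'pooling': 'gem',
              'whitening': False, 'pretrained': False}
    net = init_network(params)
    net.eval()
    with torch.no_grad():
        v = net(torch.randn(2, 3, 224, 224))  # two dummy RGB images
    # forward() returns a D x N matrix: one descriptor column per image
    assert v.shape == (net.meta['outputdim'], 2)
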
281 | # def img2tensor(img_path, imsize, transform):
282 | # img = cv2.imread(img_path)
283 | # padding = Padding((imsize, imsize))
284 | # img = padding(img)
285 | # if transform is not None:
286 | # img = transform(img)
287 | #
288 | # return img
289 |
290 |
291 | def extract_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
292 | # moving network to gpu and eval mode
293 | if torch.cuda.is_available():
294 | net.cuda()
295 | net.eval()
296 |
297 | # creating dataset loader
298 | loader = torch.utils.data.DataLoader(
299 | ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
300 | batch_size=1, shuffle=False, num_workers=1, pin_memory=True
301 | )
302 |
303 | # extracting vectors
304 | with torch.no_grad():
305 | vecs = torch.zeros(net.meta['outputdim'], len(images))
306 | img_paths = list()
307 | for i, (input, path) in enumerate(loader):
308 | if torch.cuda.is_available():
309 | input = input.cuda()
310 |
311 | if len(ms) == 1 and ms[0] == 1:
312 | vecs[:, i] = extract_ss(net, input)
313 | else:
314 | vecs[:, i] = extract_ms(net, input, ms, msp)
315 | img_paths.append(path)
316 |
317 | if (i + 1) % print_freq == 0 or (i + 1) == len(images):
318 | print('\r>>>> {}/{} done...'.format((i + 1), len(images)), end='')
319 | print('')
320 |
321 | # vecs = torch.zeros(net.meta['outputdim'], len(images))
322 | # img_path_list = list()
323 | # for i in range(len(images)):
324 | # img_path = images[i]
325 | # img_path_list.append(img_path)
326 | # input = img2tensor(img_path, image_size, transform)
327 | # if torch.cuda.is_available():
328 | # input = input.cuda()
329 | #
330 | # if len(ms) == 1 and ms[0] == 1:
331 | # vecs[:, i] = extract_ss(net, input)
332 | # else:
333 | # vecs[:, i] = extract_ms(net, input, ms, msp)
334 | #
335 | # if (i + 1) % print_freq == 0 or (i + 1) == len(images):
336 | # print('\r>>>> {}/{} done...'.format((i + 1), len(images)), end='')
337 |
338 | return vecs, img_paths
339 |
340 |
341 | def extract_ss(net, input):
342 | return net(input).cpu().data.squeeze()
343 |
344 |
345 | def extract_ms(net, input, ms, msp):
346 | v = torch.zeros(net.meta['outputdim'])
347 |
348 | for s in ms:
349 | if s == 1:
350 | input_t = input.clone()
351 | else:
352 | input_t = nn.functional.interpolate(input, scale_factor=s, mode='bilinear', align_corners=False)
353 | v += net(input_t).pow(msp).cpu().data.squeeze()
354 |
355 | v /= len(ms)
356 | v = v.pow(1. / msp)
357 | v /= v.norm()
358 |
359 | return v
360 |
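# --- editor's note (added): extract_ms() is a generalized mean over scales ---
# Each rescaled pass yields an L2-normalised descriptor; raising to msp,
# averaging, and taking the 1/msp root pools across scales the same way GeM
# pools across spatial locations. A typical call (values are assumptions,
# not prescribed by this repo):
#     ms = [1, 1 / 2 ** 0.5, 2 ** 0.5]   # three image scales
#     msp = 1                            # or a learned GeM exponent
#     v = extract_ms(net, img, ms, msp)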
361 |
362 | def extract_regional_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
363 | # moving network to gpu and eval mode
364 | net.cuda()
365 | net.eval()
366 |
367 | # creating dataset loader
368 | loader = torch.utils.data.DataLoader(
369 | ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
370 | batch_size=1, shuffle=False, num_workers=8, pin_memory=True
371 | )
372 |
373 | # extracting vectors
374 | with torch.no_grad():
375 | vecs = []
376 | for i, input in enumerate(loader):
377 | input = input.cuda()
378 |
379 | if len(ms) == 1:
380 | vecs.append(extract_ssr(net, input))
381 | else:
382 | # TODO: not implemented yet
383 | # vecs.append(extract_msr(net, input, ms, msp))
384 | raise NotImplementedError
385 |
386 | if (i + 1) % print_freq == 0 or (i + 1) == len(images):
387 | print('\r>>>> {}/{} done...'.format((i + 1), len(images)), end='')
388 | print('')
389 |
390 | return vecs
391 |
392 |
393 | def extract_ssr(net, input):
394 | return net.pool(net.features(input), aggregate=False).squeeze(0).squeeze(-1).squeeze(-1).permute(1, 0).cpu().data
395 |
396 |
397 | def extract_local_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10):
398 | # moving network to gpu and eval mode
399 | net.cuda()
400 | net.eval()
401 |
402 | # creating dataset loader
403 | loader = torch.utils.data.DataLoader(
404 | ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
405 | batch_size=1, shuffle=False, num_workers=8, pin_memory=True
406 | )
407 |
408 | # extracting vectors
409 | with torch.no_grad():
410 | vecs = []
411 | for i, input in enumerate(loader):
412 | input = input.cuda()
413 |
414 | if len(ms) == 1:
415 | vecs.append(extract_ssl(net, input))
416 | else:
417 | # TODO: not implemented yet
418 | # vecs.append(extract_msl(net, input, ms, msp))
419 | raise NotImplementedError
420 |
421 | if (i + 1) % print_freq == 0 or (i + 1) == len(images):
422 | print('\r>>>> {}/{} done...'.format((i + 1), len(images)), end='')
423 | print('')
424 |
425 | return vecs
426 |
427 |
428 | def extract_ssl(net, input):
429 | return net.norm(net.features(input)).squeeze(0).view(net.meta['outputdim'], -1).cpu().data
430 |
--------------------------------------------------------------------------------
/utils/classify.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # /usr/bin/env python
3 |
4 | '''
5 | Author: yinhao
6 | Email: yinhao_x@163.com
7 | Wechat: xss_yinhao
8 | Github: http://github.com/yinhaoxs
9 | data: 2019-11-23 18:29
10 | desc:
11 | '''
12 |
13 | import math
14 | import os
15 | import shutil
16 | import time
17 | from collections import OrderedDict
18 |
19 | import cv2
20 | import numpy as np
21 | import pandas as pd
22 | import torch
23 | import torch.nn as nn
24 | import torch.nn.functional as F
25 | import torch.utils.model_zoo as model_zoo
26 | from PIL import Image
27 | from torch.autograd import Variable
28 | from torch.utils.data import DataLoader, Dataset
29 | from torchvision import transforms
30 |
31 | # config.py
32 | BATCH_SIZE = 16
33 | PROPOSAL_NUM = 6
34 | CAT_NUM = 4
35 | INPUT_SIZE = (448, 448) # (w, h)
36 | DROP_OUT = 0.5
37 | CLASS_NUM = 37
38 |
39 |
40 | # resnet.py
41 | def conv3x3(in_planes, out_planes, stride=1):
42 | "3x3 convolution with padding"
43 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
44 | padding=1, bias=False)
45 |
46 |
47 | class BasicBlock(nn.Module):
48 | expansion = 1
49 |
50 | def __init__(self, inplanes, planes, stride=1, downsample=None):
51 | super(BasicBlock, self).__init__()
52 | self.conv1 = conv3x3(inplanes, planes, stride)
53 | self.bn1 = nn.BatchNorm2d(planes)
54 | self.relu = nn.ReLU(inplace=True)
55 | self.conv2 = conv3x3(planes, planes)
56 | self.bn2 = nn.BatchNorm2d(planes)
57 | self.downsample = downsample
58 | self.stride = stride
59 |
60 | def forward(self, x):
61 | residual = x
62 |
63 | out = self.conv1(x)
64 | out = self.bn1(out)
65 | out = self.relu(out)
66 |
67 | out = self.conv2(out)
68 | out = self.bn2(out)
69 |
70 | if self.downsample is not None:
71 | residual = self.downsample(x)
72 |
73 | out += residual
74 | out = self.relu(out)
75 |
76 | return out
77 |
78 |
79 | class Bottleneck(nn.Module):
80 | expansion = 4
81 |
82 | def __init__(self, inplanes, planes, stride=1, downsample=None):
83 | super(Bottleneck, self).__init__()
84 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
85 | self.bn1 = nn.BatchNorm2d(planes)
86 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
87 | padding=1, bias=False)
88 | self.bn2 = nn.BatchNorm2d(planes)
89 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
90 | self.bn3 = nn.BatchNorm2d(planes * 4)
91 | self.relu = nn.ReLU(inplace=True)
92 | self.downsample = downsample
93 | self.stride = stride
94 |
95 | def forward(self, x):
96 | residual = x
97 |
98 | out = self.conv1(x)
99 | out = self.bn1(out)
100 | out = self.relu(out)
101 |
102 | out = self.conv2(out)
103 | out = self.bn2(out)
104 | out = self.relu(out)
105 |
106 | out = self.conv3(out)
107 | out = self.bn3(out)
108 |
109 | if self.downsample is not None:
110 | residual = self.downsample(x)
111 |
112 | out += residual
113 | out = self.relu(out)
114 |
115 | return out
116 |
117 |
118 | class ResNet(nn.Module):
119 | def __init__(self, block, layers, num_classes=1000):
120 | self.inplanes = 64
121 | super(ResNet, self).__init__()
122 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
123 | bias=False)
124 | self.bn1 = nn.BatchNorm2d(64)
125 | self.relu = nn.ReLU(inplace=True)
126 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
127 | self.layer1 = self._make_layer(block, 64, layers[0])
128 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
129 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
130 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
131 | self.avgpool = nn.AvgPool2d(7)
132 | self.fc = nn.Linear(512 * block.expansion, num_classes)
133 |
134 | for m in self.modules():
135 | if isinstance(m, nn.Conv2d):
136 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
137 | m.weight.data.normal_(0, math.sqrt(2. / n))
138 | elif isinstance(m, nn.BatchNorm2d):
139 | m.weight.data.fill_(1)
140 | m.bias.data.zero_()
141 |
142 | def _make_layer(self, block, planes, blocks, stride=1):
143 | downsample = None
144 | if stride != 1 or self.inplanes != planes * block.expansion:
145 | downsample = nn.Sequential(
146 | nn.Conv2d(self.inplanes, planes * block.expansion,
147 | kernel_size=1, stride=stride, bias=False),
148 | nn.BatchNorm2d(planes * block.expansion),
149 | )
150 |
151 | layers = []
152 | layers.append(block(self.inplanes, planes, stride, downsample))
153 | self.inplanes = planes * block.expansion
154 | for i in range(1, blocks):
155 | layers.append(block(self.inplanes, planes))
156 |
157 | return nn.Sequential(*layers)
158 |
159 | def forward(self, x):
160 | x = self.conv1(x)
161 | x = self.bn1(x)
162 | x = self.relu(x)
163 | x = self.maxpool(x)
164 |
165 | x = self.layer1(x)
166 | x = self.layer2(x)
167 | x = self.layer3(x)
168 | x = self.layer4(x)
169 | feature1 = x
170 | x = self.avgpool(x)
171 | x = x.view(x.size(0), -1)
172 |         x = F.dropout(x, p=DROP_OUT, training=self.training)  # a Dropout module built inside forward() never leaves training mode; the functional form respects eval()
173 | feature2 = x
174 | x = self.fc(x)
175 |
176 | return x, feature1, feature2
177 |
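# --- editor's note (added): ResNet.forward returns a triple ---
#   x        : class logits                    (B, num_classes)
#   feature1 : last conv feature map, pre-pool (B, 2048, H/32, W/32) with Bottleneck
#   feature2 : pooled feature after dropout    (B, 2048)
# Below, ProposalNet scores regions from feature1, while the concat and
# part-classification heads consume feature2.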
178 |
179 | # model.py
180 | class ProposalNet(nn.Module):
181 | def __init__(self):
182 | super(ProposalNet, self).__init__()
183 | self.down1 = nn.Conv2d(2048, 128, 3, 1, 1)
184 | self.down2 = nn.Conv2d(128, 128, 3, 2, 1)
185 | self.down3 = nn.Conv2d(128, 128, 3, 2, 1)
186 | self.ReLU = nn.ReLU()
187 | self.tidy1 = nn.Conv2d(128, 6, 1, 1, 0)
188 | self.tidy2 = nn.Conv2d(128, 6, 1, 1, 0)
189 | self.tidy3 = nn.Conv2d(128, 9, 1, 1, 0)
190 |
191 | def forward(self, x):
192 | batch_size = x.size(0)
193 | d1 = self.ReLU(self.down1(x))
194 | d2 = self.ReLU(self.down2(d1))
195 | d3 = self.ReLU(self.down3(d2))
196 | t1 = self.tidy1(d1).view(batch_size, -1)
197 | t2 = self.tidy2(d2).view(batch_size, -1)
198 | t3 = self.tidy3(d3).view(batch_size, -1)
199 | return torch.cat((t1, t2, t3), dim=1)
200 |
201 |
202 | class AttentionNet(nn.Module):
203 | def __init__(self, topN=4):
204 |         super(AttentionNet, self).__init__()
205 | self.pretrained_model = ResNet(Bottleneck, [3, 4, 6, 3])
206 | self.pretrained_model.avgpool = nn.AdaptiveAvgPool2d(1)
207 | self.pretrained_model.fc = nn.Linear(512 * 4, 200)
208 | self.proposal_net = ProposalNet()
209 | self.topN = topN
210 | self.concat_net = nn.Linear(2048 * (CAT_NUM + 1), 200)
211 | self.partcls_net = nn.Linear(512 * 4, 200)
212 | _, edge_anchors, _ = generate_default_anchor_maps()
213 | self.pad_side = 224
214 |         self.edge_anchors = (edge_anchors + 224).astype(int)  # np.int was removed from NumPy; plain int is equivalent
215 |
216 | def forward(self, x):
217 | resnet_out, rpn_feature, feature = self.pretrained_model(x)
218 | x_pad = F.pad(x, (self.pad_side, self.pad_side, self.pad_side, self.pad_side), mode='constant', value=0)
219 | batch = x.size(0)
220 | # we will reshape rpn to shape: batch * nb_anchor
221 | rpn_score = self.proposal_net(rpn_feature.detach())
222 | all_cdds = [
223 | np.concatenate((x.reshape(-1, 1), self.edge_anchors.copy(), np.arange(0, len(x)).reshape(-1, 1)), axis=1)
224 | for x in rpn_score.data.cpu().numpy()]
225 | top_n_cdds = [hard_nms(x, topn=self.topN, iou_thresh=0.25) for x in all_cdds]
226 | top_n_cdds = np.array(top_n_cdds)
227 |         top_n_index = top_n_cdds[:, :, -1].astype(int)
228 | top_n_index = torch.from_numpy(top_n_index).cuda()
229 | top_n_prob = torch.gather(rpn_score, dim=1, index=top_n_index)
230 | part_imgs = torch.zeros([batch, self.topN, 3, 224, 224]).cuda()
231 | for i in range(batch):
232 | for j in range(self.topN):
233 |                 [y0, x0, y1, x1] = top_n_cdds[i][j, 1:5].astype(int)
234 | part_imgs[i:i + 1, j] = F.interpolate(x_pad[i:i + 1, :, y0:y1, x0:x1], size=(224, 224), mode='bilinear',
235 | align_corners=True)
236 | part_imgs = part_imgs.view(batch * self.topN, 3, 224, 224)
237 | _, _, part_features = self.pretrained_model(part_imgs.detach())
238 | part_feature = part_features.view(batch, self.topN, -1)
239 | part_feature = part_feature[:, :CAT_NUM, ...].contiguous()
240 | part_feature = part_feature.view(batch, -1)
241 | # concat_logits have the shape: B*200
242 | concat_out = torch.cat([part_feature, feature], dim=1)
243 | concat_logits = self.concat_net(concat_out)
244 | raw_logits = resnet_out
245 | # part_logits have the shape: B*N*200
246 | part_logits = self.partcls_net(part_features).view(batch, self.topN, -1)
247 | return [raw_logits, concat_logits, part_logits, top_n_index, top_n_prob]
248 |
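# --- editor's note (added): AttentionNet.forward returns five tensors ---
#   raw_logits    (B, 200)       : whole-image prediction from the backbone
#   concat_logits (B, 200)       : image + top CAT_NUM parts (used by Classifier below)
#   part_logits   (B, topN, 200) : per-part predictions
#   top_n_index   (B, topN)      : indices of the selected proposals
#   top_n_prob    (B, topN)      : navigator scores for those proposals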
249 |
250 | def list_loss(logits, targets):
251 | temp = F.log_softmax(logits, -1)
252 | loss = [-temp[i][targets[i].item()] for i in range(logits.size(0))]
253 | return torch.stack(loss)
254 |
255 |
256 | def ranking_loss(score, targets, proposal_num=PROPOSAL_NUM):
257 | loss = Variable(torch.zeros(1).cuda())
258 | batch_size = score.size(0)
259 | for i in range(proposal_num):
260 | targets_p = (targets > targets[:, i].unsqueeze(1)).type(torch.cuda.FloatTensor)
261 | pivot = score[:, i].unsqueeze(1)
262 | loss_p = (1 - pivot + score) * targets_p
263 | loss_p = torch.sum(F.relu(loss_p))
264 | loss += loss_p
265 | return loss / batch_size
266 |
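# --- editor's note (added): how the two losses cooperate in NTS-Net ---
# list_loss gives a per-proposal cross-entropy (how informative each part
# is), and ranking_loss is a pairwise hinge that pushes the navigator's
# proposal scores to rank parts consistently with that informativeness; in
# the usual training loop the per-part list_loss values are what get passed
# in as `targets`.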
267 |
268 | # anchors.py
269 | _default_anchors_setting = (
270 | dict(layer='p3', stride=32, size=48, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
271 | dict(layer='p4', stride=64, size=96, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
272 | dict(layer='p5', stride=128, size=192, scale=[1, 2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]),
273 | )
274 |
275 |
276 | def generate_default_anchor_maps(anchors_setting=None, input_shape=INPUT_SIZE):
277 | """
278 | generate default anchor
279 | :param anchors_setting: all informations of anchors
280 | :param input_shape: shape of input images, e.g. (h, w)
281 | :return: center_anchors: # anchors * 4 (oy, ox, h, w)
282 | edge_anchors: # anchors * 4 (y0, x0, y1, x1)
283 | anchor_area: # anchors * 1 (area)
284 | """
285 | if anchors_setting is None:
286 | anchors_setting = _default_anchors_setting
287 |
288 | center_anchors = np.zeros((0, 4), dtype=np.float32)
289 | edge_anchors = np.zeros((0, 4), dtype=np.float32)
290 | anchor_areas = np.zeros((0,), dtype=np.float32)
291 | input_shape = np.array(input_shape, dtype=int)
292 |
293 | for anchor_info in anchors_setting:
294 |
295 | stride = anchor_info['stride']
296 | size = anchor_info['size']
297 | scales = anchor_info['scale']
298 | aspect_ratios = anchor_info['aspect_ratio']
299 |
300 | output_map_shape = np.ceil(input_shape.astype(np.float32) / stride)
301 |         output_map_shape = output_map_shape.astype(int)
302 | output_shape = tuple(output_map_shape) + (4,)
303 | ostart = stride / 2.
304 | oy = np.arange(ostart, ostart + stride * output_shape[0], stride)
305 | oy = oy.reshape(output_shape[0], 1)
306 | ox = np.arange(ostart, ostart + stride * output_shape[1], stride)
307 | ox = ox.reshape(1, output_shape[1])
308 | center_anchor_map_template = np.zeros(output_shape, dtype=np.float32)
309 | center_anchor_map_template[:, :, 0] = oy
310 | center_anchor_map_template[:, :, 1] = ox
311 | for scale in scales:
312 | for aspect_ratio in aspect_ratios:
313 | center_anchor_map = center_anchor_map_template.copy()
314 | center_anchor_map[:, :, 2] = size * scale / float(aspect_ratio) ** 0.5
315 | center_anchor_map[:, :, 3] = size * scale * float(aspect_ratio) ** 0.5
316 |
317 | edge_anchor_map = np.concatenate((center_anchor_map[..., :2] - center_anchor_map[..., 2:4] / 2.,
318 | center_anchor_map[..., :2] + center_anchor_map[..., 2:4] / 2.),
319 | axis=-1)
320 | anchor_area_map = center_anchor_map[..., 2] * center_anchor_map[..., 3]
321 | center_anchors = np.concatenate((center_anchors, center_anchor_map.reshape(-1, 4)))
322 | edge_anchors = np.concatenate((edge_anchors, edge_anchor_map.reshape(-1, 4)))
323 | anchor_areas = np.concatenate((anchor_areas, anchor_area_map.reshape(-1)))
324 |
325 | return center_anchors, edge_anchors, anchor_areas
326 |
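# --- editor's sanity check (added; defined but never called) ---
# With INPUT_SIZE = (448, 448), the three layers above yield
# 14*14*6 + 7*7*6 + 4*4*9 = 1614 anchors, matching the combined output
# width of ProposalNet's tidy1/tidy2/tidy3 heads.
def _demo_anchor_count():
    _, edge_anchors, _ = generate_default_anchor_maps()
    assert edge_anchors.shape == (1614, 4)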
327 |
328 | def hard_nms(cdds, topn=10, iou_thresh=0.25):
329 | if not (type(cdds).__module__ == 'numpy' and len(cdds.shape) == 2 and cdds.shape[1] >= 5):
330 | raise TypeError('edge_box_map should be N * 5+ ndarray')
331 |
332 | cdds = cdds.copy()
333 | indices = np.argsort(cdds[:, 0])
334 | cdds = cdds[indices]
335 | cdd_results = []
336 |
337 | res = cdds
338 |
339 | while res.any():
340 | cdd = res[-1]
341 | cdd_results.append(cdd)
342 | if len(cdd_results) == topn:
343 | return np.array(cdd_results)
344 | res = res[:-1]
345 |
346 | start_max = np.maximum(res[:, 1:3], cdd[1:3])
347 | end_min = np.minimum(res[:, 3:5], cdd[3:5])
348 | lengths = end_min - start_max
349 | intersec_map = lengths[:, 0] * lengths[:, 1]
350 | intersec_map[np.logical_or(lengths[:, 0] < 0, lengths[:, 1] < 0)] = 0
351 | iou_map_cur = intersec_map / ((res[:, 3] - res[:, 1]) * (res[:, 4] - res[:, 2]) + (cdd[3] - cdd[1]) * (
352 | cdd[4] - cdd[2]) - intersec_map)
353 | res = res[iou_map_cur < iou_thresh]
354 |
355 | return np.array(cdd_results)
356 |
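# --- editor's toy example (added; defined but never called) ---
# Rows are [score, y0, x0, y1, x1, index]; the two overlapping boxes
# collapse to the higher-scoring one, and the distant box survives.
def _demo_hard_nms():
    cdds = np.array([
        [0.9, 0, 0, 10, 10, 0],
        [0.8, 1, 1, 11, 11, 1],     # IoU with the first box ~0.68 > 0.25
        [0.7, 50, 50, 60, 60, 2],
    ], dtype=np.float32)
    keep = hard_nms(cdds, topn=2, iou_thresh=0.25)
    assert [int(c[-1]) for c in keep] == [0, 2]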
357 |
358 | #### ------------------------------- defining the batch loading pipeline -------------------------------
359 | # default image loader
360 | def default_loader(path):
361 |     try:
362 |         img = Image.open(path).convert("RGB")
363 |         return img
364 |     except Exception:
365 |         print("error image:{}".format(path))
366 |         return None
367 |
368 |
369 | def opencv_isvalid(img_path):
370 |     img_bgr = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), -1)
371 |     cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)  # probe conversion: raises on corrupt or non-colour data, caught by the caller
372 |     return img_bgr
373 |
374 |
375 | # check whether an image file is valid
376 | def IsValidImage(img_path):
377 |     valid = True
378 |     if img_path.endswith(".tif") or img_path.endswith(".tiff"):
379 |         valid = False
380 |         return valid
381 |     try:
382 |         img = opencv_isvalid(img_path)
383 |         if img is None:
384 |             valid = False
385 |             return valid
386 |     except Exception:
387 |         valid = False
388 |     return valid
389 |
390 |
391 | class MyDataset(Dataset):
392 | def __init__(self, dir_path, transform=None, loader=default_loader):
393 | fh, imgs = list(), list()
394 | num = 0
395 | for root, dirs, files in os.walk(dir_path):
396 | for file in files:
397 | try:
398 |                     img_path = os.path.join(root, file)  # os.path.join already inserts the separator
399 | num += 1
400 | if IsValidImage(img_path):
401 | fh.append(img_path)
402 | else:
403 | os.remove(img_path)
404 |
405 |                 except Exception:
406 |                     print("image is broken")
407 | print("total images is:{}".format(num))
408 |
409 | for line in fh:
410 | line = line.strip()
411 | imgs.append(line)
412 |
413 | self.imgs = imgs
414 | self.transform = transform
415 | self.loader = loader
416 |
417 | def __getitem__(self, item):
418 | fh = self.imgs[item]
419 | img = self.loader(fh)
420 | if self.transform is not None:
421 | img = self.transform(img)
422 | return fh, img
423 |
424 | def __len__(self):
425 | return len(self.imgs)
426 |
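# --- editor's usage sketch (added): MyDataset walks a directory tree,
# removes invalid files, and yields (path, image) pairs, so a DataLoader
# built on it delivers the batch of paths first and the image tensor second:
#     data = MyDataset("/path/to/images", transform=transforms.ToTensor())
#     loader = DataLoader(dataset=data, batch_size=32)
#     for paths, imgs in loader:
#         ...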
427 |
428 | #### ------------------------------- defining the batch loading pipeline -------------------------------
429 |
430 |
431 | #### ------------------------------- blur detection -------------------------------
432 | def variance_of_laplacian(image):
433 |     return cv2.Laplacian(image, cv2.CV_64F).var()  # cv2.CV_64f does not exist; the depth flag is CV_64F
434 |
435 |
436 | ## the quality-gate interface function
437 | def imgQualJudge(img, QA_THRESHOLD):
438 |     '''
439 |     :param img: image as a numpy array
440 |     :param QA_THRESHOLD: blur threshold; higher demands a sharper image
441 |     :return: 0 if the image passes, '10001' if too blurry, '10002' if too small
442 |     '''
443 |
444 |     norheight = 1707
445 |     norwidth = 1280
446 |     flag = 0
447 |     # size filter
448 |     if max(img.shape[0], img.shape[1]) < 320:
449 |         flag = '10002'
450 |         return flag
451 |
452 |     # blur screening
453 |     if img.shape[0] <= img.shape[1]:
454 |         size1 = (norheight, norwidth)
455 |         timage = cv2.resize(img, size1)
456 |     else:
457 |         size2 = (norwidth, norheight)
458 |         timage = cv2.resize(img, size2)
459 |
460 |     tgray = cv2.cvtColor(timage, cv2.COLOR_BGR2GRAY)
461 |     halfgray = tgray[0:int(tgray.shape[0] / 2), 0:tgray.shape[1]]
462 |     norgrayImg = np.zeros(halfgray.shape, np.uint8)  # uint8: int8 would overflow at 255
463 |     cv2.normalize(halfgray, norgrayImg, 0, 255, cv2.NORM_MINMAX)
464 |     fm = variance_of_laplacian(norgrayImg)  # blur measure
465 |     if fm < QA_THRESHOLD:
466 |         flag = '10001'
467 |         return flag
468 |     return flag
469 |
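# --- editor's sketch (added; defined but never called): the Laplacian
# variance collapses once an image is smoothed, which is exactly the
# quantity imgQualJudge thresholds on.
def _demo_blur_metric():
    rng = np.random.RandomState(0)
    sharp = (rng.rand(256, 256) * 255).astype(np.uint8)   # high-frequency noise
    blurred = cv2.GaussianBlur(sharp, (15, 15), 0)
    assert variance_of_laplacian(sharp) > variance_of_laplacian(blurred)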
470 |
471 | def process(img_path):
472 | img = Image.open(img_path).convert("RGB")
473 | valid = True
474 | low_quality = "10001"
475 | size_error = "10002"
476 |
477 | flag = imgQualJudge(np.array(img), 5)
478 | if flag == low_quality or flag == size_error or not img or 0 in np.asarray(img).shape[:2]:
479 | valid = False
480 |
481 | return valid
482 |
483 |
484 | #### ------------------------------- blur detection -------------------------------
485 |
486 | def build_dict():
487 | dict_club = dict()
488 |     dict_club[0] = ["身份证", 0.999999]  # ID card
489 |     dict_club[1] = ["校园卡", 0.890876]  # campus card
490 | return dict_club
491 |
492 |
493 | class Classifier():
494 | def __init__(self):
495 | self.device = torch.device('cuda')
496 | self.class_id_name_dict = build_dict()
497 | self.mean = [0.485, 0.456, 0.406]
498 | self.std = [0.229, 0.224, 0.225]
499 | self.input_size = 448
500 | self.use_cuda = torch.cuda.is_available()
501 | self.model = AttentionNet(topN=4)
502 | self.model.eval()
503 |
504 | checkpoint = torch.load("./.ckpt")
505 | newweights = checkpoint['net_state_dict']
506 |
507 |         # a checkpoint saved from multi-GPU (DataParallel) training carries a "module." prefix; strip it for single-GPU loading
508 | new_state_dic = OrderedDict()
509 | for k, v in newweights.items():
510 |             name = k[7:] if k.startswith("module.") else k
511 | new_state_dic[name] = v
512 |
513 | self.model.load_state_dict(new_state_dic)
514 | self.model = self.model.to(self.device)
515 |
516 |     def evaluate(self, dir_path):
517 | data = MyDataset(dir_path, transform=self.preprocess)
518 | dataloader = DataLoader(dataset=data, batch_size=32, num_workers=8)
519 |
520 | self.model.eval()
521 | with torch.no_grad():
522 | num = 0
523 |             for i, (path, data) in enumerate(dataloader, 1):  # MyDataset.__getitem__ yields (path, image)
524 | data = data.to(self.device)
525 | output = self.model(data)
526 | for j in range(len(data)):
527 | img_path = path[j]
528 | img_output = output[1][j]
529 | score, label, type = self.postprocess(img_output)
530 | out_dict, score = self.process(score, label, type)
531 | class_id = out_dict["results"]["class2"]["code"]
532 | num += 1
533 | if class_id != '00038':
534 | os.remove(img_path)
535 | else:
536 | continue
537 |
538 | def preprocess(self, img):
539 | img = transforms.Resize((600, 600), Image.BILINEAR)(img)
540 | img = transforms.CenterCrop(self.input_size)(img)
541 | img = transforms.ToTensor()(img)
542 |         return transforms.Normalize(self.mean, self.std)(img)  # the original neither applied the transform nor returned the image
543 |
544 | def postprocess(self, output):
545 | pred_logits = F.softmax(output, dim=0)
546 | score, label = pred_logits.max(0)
547 | score = score.item()
548 | label = label.item()
549 | type = self.class_id_name_dict[label][0]
550 | return score, label, type
551 |
552 | def process(self, score, label, type):
553 | success_code = "200"
554 | lower_conf_code = "10008"
555 |
556 | threshold = float(self.class_id_name_dict[label][1])
557 | if threshold > 0.99:
558 | threshold = 0.99
559 | if threshold < 0.9:
560 | threshold = 0.9
561 |         ## use a lower threshold for the inspection-photo class
562 | if label == 38:
563 | threshold = 0.5
564 |
565 | if score > threshold:
566 | status_code = success_code
567 | pred_label = str(label).zfill(5)
568 | print("pred_label:", pred_label)
569 |             return {"code": status_code, "message": '图像分类成功',  # message: "classification succeeded"
570 |                     "results": {"class2": {'code': pred_label, 'name': type}}}, score
571 | else:
572 | status_code = lower_conf_code
573 |             return {"code": status_code, "message": '图像分类置信度低,不返回结果',  # message: "confidence too low, no result returned"
574 |                     "results": {"class2": {'code': '', 'name': ''}}}, score
575 |
576 |
577 | def class_results(img_dir):
578 |     Classifier().evaluate(img_dir)
579 |
580 |
581 | if __name__ == "__main__":
582 | pass
583 |
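# editor's usage note (added): class_results() walks img_dir and deletes
# every image whose predicted class code is not '00038'; the path below is
# a placeholder.
#     class_results("/path/to/image_dir")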
--------------------------------------------------------------------------------