├── cirtorch
│   ├── layers
│   │   ├── __init__.py
│   │   ├── normalization.py
│   │   ├── loss.py
│   │   ├── pooling.py
│   │   └── functional.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── general.py
│   │   ├── whiten.py
│   │   ├── evaluate.py
│   │   ├── download.py
│   │   └── download_win.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── testdataset.py
│   │   ├── datahelpers.py
│   │   ├── genericdataset.py
│   │   └── traindataset.py
│   ├── examples
│   │   ├── __init__.py
│   │   ├── test_e2e.py
│   │   └── test.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── imageretrievalnet_cpu.py
│   │   └── imageretrievalnet.py
│   ├── .DS_Store
│   └── __init__.py
├── utils
│   ├── .DS_Store
│   ├── retrieval_feature.py
│   ├── retrieval_index.py
│   └── classify.py
├── config.yaml
├── lshash
│   ├── __init__.py
│   ├── storage.py
│   └── lshash.py
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── IMAGE_Retrieval.iml
│   └── workspace.xml
├── nts
│   ├── config.py
│   ├── README.md
│   ├── test.py
│   ├── core
│   │   ├── utils.py
│   │   ├── dataset.py
│   │   ├── anchors.py
│   │   ├── model.py
│   │   └── resnet.py
│   └── train.py
├── demo.py
├── README.md
├── .gitignore
└── interface.py
/cirtorch/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cirtorch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cirtorch/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cirtorch/examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cirtorch/networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yinhaoxs/ImageRetrieval-LSH/HEAD/utils/.DS_Store -------------------------------------------------------------------------------- /cirtorch/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yinhaoxs/ImageRetrieval-LSH/HEAD/cirtorch/.DS_Store -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | websites: 2 | host: 0.0.0.0 3 | port: 15788 4 | 5 | model: 6 | network: /*.pth 7 | model_dir: /* 8 | type: [SA,SB] -------------------------------------------------------------------------------- /lshash/__init__.py: -------------------------------------------------------------------------------- 1 | import pkg_resources 2 | 3 | try: 4 | __version__ = pkg_resources.get_distribution(__name__).version 5 | except: 6 | __version__ = '0.0.4dev' 7 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /nts/config.py: -------------------------------------------------------------------------------- 1 | BATCH_SIZE = 16 2 | PROPOSAL_NUM = 6 3 | CAT_NUM = 4 4 | INPUT_SIZE = (448, 448) # (w, h) 5 | LR = 0.001
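# NOTE: per nts/README.md, PROPOSAL_NUM corresponds to "M" and CAT_NUM to "K" in the
# NTS-Net paper, i.e. the number of part proposals kept by the navigator and the
# number of parts whose features are concatenated for classification.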
WD = 1e-4 7 | SAVE_FREQ = 1 8 | resume = '' 9 | test_model = 'model.ckpt' 10 | save_dir = '/data_4t/yangz/models/' 11 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /cirtorch/__init__.py: -------------------------------------------------------------------------------- 1 | from . import datasets, examples, layers, networks, utils 2 | 3 | from .datasets import datahelpers, genericdataset, testdataset, traindataset 4 | from .layers import functional, loss, normalization, pooling 5 | from .networks import imageretrievalnet 6 | from .utils import general, download, evaluate, whiten -------------------------------------------------------------------------------- /.idea/IMAGE_Retrieval.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /cirtorch/layers/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import cirtorch.layers.functional as LF 5 | 6 | # -------------------------------------- 7 | # Normalization layers 8 | # -------------------------------------- 9 | 10 | class L2N(nn.Module): 11 | 12 | def __init__(self, eps=1e-6): 13 | super(L2N,self).__init__() 14 | self.eps = eps 15 | 16 | def forward(self, x): 17 | return LF.l2n(x, eps=self.eps) 18 | 19 | def __repr__(self): 20 | return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')' 21 | 22 | 23 | class PowerLaw(nn.Module): 24 | 25 | def __init__(self, eps=1e-6): 26 | super(PowerLaw, self).__init__() 27 | self.eps = eps 28 | 29 | def forward(self, x): 30 | return LF.powerlaw(x, eps=self.eps) 31 | 32 | def __repr__(self): 33 | return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')' -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from utils.retrieval_feature import AntiFraudFeatureDataset 2 | from utils.retrieval_index import EvaluteMap 3 | 4 | 5 | if __name__ == '__main__': 6 | """ 7 | img_dir holds all the gallery images; the images in test_img_dir are then matched against this gallery, and the top-3 matching image paths are written out. 8 | """ 9 | hash_size = 0 10 | input_dim = 2048 11 | num_hashtables = 1 12 | img_dir = 'ImageRetrieval/data' 13 | test_img_dir = './images' 14 | network = './weights/gl18-tl-resnet50-gem-w-83fdc30.pth' 15 | out_similar_dir = './output/similar' 16 | out_similar_file_dir = './output/similar_file' 17 | all_csv_file = './output/aaa.csv' 18 | 19 | feature_dict, lsh = AntiFraudFeatureDataset(img_dir, network).constructfeature(hash_size, input_dim, num_hashtables) 20 | test_feature_dict = AntiFraudFeatureDataset(test_img_dir, network).test_feature() 21 | EvaluteMap(out_similar_dir, out_similar_file_dir, all_csv_file).retrieval_images(test_feature_dict, lsh, 3) 22 | 23 | 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image Retrieval Model 2 | 3
| ## 1. Model Overview 4 | --Techniques: deep convolutional neural networks, LSH (locality-sensitive hashing), Flask web deployment, and NTS fine-grained classification 5 | 6 | ## 2. Pretrained Models 7 | --Image classification pretrained model: https://drive.google.com/file/d/1F-eKqPRjlya5GH2HwTlLKNSPEUaxCu9H/view?usp=sharing 8 | --Image retrieval pretrained model: http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-rwhiten-19b204e.pth 9 | resnet50 http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth 10 | 11 | ## 3. Dataset 12 | --Transfer learning and duplicate checking on the production dataset 13 | 14 | ## 4. Data Preprocessing 15 | --Images are resized to (224*224) 16 | --Filter out tiff/tif and similar file formats, and work around a low-level Pillow issue (use OpenCV to handle the convert("RGB") problem) 17 | --Class-based data filtering: fine-grained classification with the NTS network 18 | 19 | ## 5. Usage 20 | --Classification: python utils/classify.py 21 | --Feature extraction: python utils/retrieval_feature.py 22 | --Offline image retrieval: python utils/retrieval_index.py 23 | --Online deployment: python interface.py (deployed with the Flask framework; the database is updated offline) 24 | python app_test.py (tests the API) 25 | 26 | ## 6. Metrics 27 | --mAP: 0.93 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /cirtorch/utils/general.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | 4 | def get_root(): 5 | return os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))) 6 | 7 | 8 | def get_data_root(): 9 | return os.path.join(get_root(), 'data') 10 | 11 | 12 | def htime(c): 13 | c = round(c) 14 | 15 | days = c // 86400 16 | hours = c // 3600 % 24 17 | minutes = c // 60 % 60 18 | seconds = c % 60 19 | 20 | if days > 0: 21 | return '{:d}d {:d}h {:d}m {:d}s'.format(days, hours, minutes, seconds) 22 | if hours > 0: 23 | return '{:d}h {:d}m {:d}s'.format(hours, minutes, seconds) 24 | if minutes > 0: 25 | return '{:d}m {:d}s'.format(minutes, seconds) 26 | return '{:d}s'.format(seconds) 27 | 28 | 29 | def sha256_hash(filename, block_size=65536, length=8): 30 | sha256 = hashlib.sha256() 31 | with open(filename, 'rb') as f: 32 | for block in iter(lambda: f.read(block_size), b''): 33 | sha256.update(block) 34 | return sha256.hexdigest()[:length-1] -------------------------------------------------------------------------------- /cirtorch/datasets/testdataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | DATASETS = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k'] 5 | 6 | def configdataset(dataset, dir_main): 7 | 8 | dataset = dataset.lower() 9 | 10 | if dataset not in DATASETS: 11 | raise ValueError('Unknown dataset: {}!'.format(dataset)) 12 | 13 | # loading imlist, qimlist, and gnd, in cfg as a dict 14 | gnd_fname = os.path.join(dir_main, dataset, 'gnd_{}.pkl'.format(dataset)) 15 | with open(gnd_fname, 'rb') as f: 16 | cfg = pickle.load(f) 17 | cfg['gnd_fname'] = gnd_fname 18 | 19 | cfg['ext'] = '.jpg' 20 | cfg['qext'] = '.jpg' 21 | cfg['dir_data'] = os.path.join(dir_main, dataset) 22 | cfg['dir_images'] = os.path.join(cfg['dir_data'], 'jpg') 23 | 24 | cfg['n'] = len(cfg['imlist']) 25 | cfg['nq'] = len(cfg['qimlist']) 26 | 27 | cfg['im_fname'] = config_imname 28 | cfg['qim_fname'] = config_qimname 29 | 30 | cfg['dataset'] = dataset 31 | 32 | return cfg 33 | 34 | def config_imname(cfg, i): 35 | return os.path.join(cfg['dir_images'], cfg['imlist'][i] + cfg['ext']) 36 | 37 | def config_qimname(cfg, i): 38 | return os.path.join(cfg['dir_images'], cfg['qimlist'][i] + cfg['qext']) 39 | -------------------------------------------------------------------------------- /cirtorch/layers/loss.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import cirtorch.layers.functional as LF 5 | 6 | # -------------------------------------- 7 | # Loss/Error layers 8 | # -------------------------------------- 9 | 10 | class ContrastiveLoss(nn.Module): 11 | r"""CONTRASTIVELOSS layer that computes contrastive loss for a batch of images: 12 | Q query tuples, each packed in the form of (q,p,n1,..nN) 13 | 14 | Args: 15 | x: tuples arranged in columns as [q,p,n1,nN, ... ] 16 | label: -1 for query, 1 for corresponding positive, 0 for corresponding negative 17 | margin: contrastive loss margin. Default: 0.7 18 | 19 | >>> contrastive_loss = ContrastiveLoss(margin=0.7) 20 | >>> input = torch.randn(128, 35, requires_grad=True) 21 | >>> label = torch.Tensor([-1, 1, 0, 0, 0, 0, 0] * 5) 22 | >>> output = contrastive_loss(input, label) 23 | >>> output.backward() 24 | """ 25 | 26 | def __init__(self, margin=0.7, eps=1e-6): 27 | super(ContrastiveLoss, self).__init__() 28 | self.margin = margin 29 | self.eps = eps 30 | 31 | def forward(self, x, label): 32 | return LF.contrastive_loss(x, label, margin=self.margin, eps=self.eps) 33 | 34 | def __repr__(self): 35 | return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')' 36 | 37 | 38 | class TripletLoss(nn.Module): 39 | 40 | def __init__(self, margin=0.1): 41 | super(TripletLoss, self).__init__() 42 | self.margin = margin 43 | 44 | def forward(self, x, label): 45 | return LF.triplet_loss(x, label, margin=self.margin) 46 | 47 | def __repr__(self): 48 | return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')' 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /cirtorch/datasets/datahelpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | 4 | import torch 5 | 6 | def cid2filename(cid, prefix): 7 | """ 8 | Creates a training image path out of its CID name 9 | 10 | Arguments 11 | --------- 12 | cid : name of the image 13 | prefix : root directory where images are saved 14 | 15 | Returns 16 | ------- 17 | filename : full image filename 18 | """ 19 | return os.path.join(prefix, cid[-2:], cid[-4:-2], cid[-6:-4], cid) 20 | 21 | def pil_loader(path): 22 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 23 | with open(path, 'rb') as f: 24 | img = Image.open(f) 25 | return img.convert('RGB') 26 | 27 | def accimage_loader(path): 28 | import accimage 29 | try: 30 | return accimage.Image(path) 31 | except IOError: 32 | # Potentially a decoding problem, fall back to PIL.Image 33 | return pil_loader(path) 34 | 35 | def default_loader(path): 36 | from torchvision import get_image_backend 37 | if get_image_backend() == 'accimage': 38 | return accimage_loader(path) 39 | else: 40 | return pil_loader(path) 41 | 42 | def imresize(img, imsize): 43 | img.thumbnail((imsize, imsize), Image.ANTIALIAS) 44 | return img 45 | 46 | def flip(x, dim): 47 | xsize = x.size() 48 | dim = x.dim() + dim if dim < 0 else dim 49 | x = x.view(-1, *xsize[dim:]) 50 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :] 51 | return x.view(xsize) 52 | 53 | def collate_tuples(batch): 54 | if len(batch) == 1: 55 | return [batch[0][0]], [batch[0][1]] 56 | return [batch[i][0] for i in range(len(batch))], [batch[i][1] for i in range(len(batch))] -------------------------------------------------------------------------------- /nts/README.md: -------------------------------------------------------------------------------- 1 | # NTS-Net 2 | 3 | This is a PyTorch implementation of the ECCV2018 paper "Learning to Navigate for Fine-grained Classification" (Ze Yang, Tiange Luo, Dong Wang, Zhiqiang Hu, Jun Gao, Liwei Wang). 
4 | 5 | ## Requirements 6 | - python 3+ 7 | - pytorch 0.4+ 8 | - numpy 9 | - datetime 10 | 11 | ## Datasets 12 | Download the [CUB-200-2011](http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz) dataset and put it in the root directory, named **CUB_200_2011**. You can also try other fine-grained datasets. 13 | 14 | ## Train the model 15 | If you want to train the NTS-Net, just run ``python train.py``. You may need to change the configurations in ``config.py``. The parameter ``PROPOSAL_NUM`` is ``M`` in the original paper and the parameter ``CAT_NUM`` is ``K`` in the original paper. During training, the log file and checkpoint file will be saved in the ``save_dir`` directory. You can change the parameter ``resume`` to choose the checkpoint model to resume. 16 | 17 | ## Test the model 18 | If you want to test the NTS-Net, just run ``python test.py``. You need to specify the ``test_model`` in ``config.py`` to choose the checkpoint model for testing. 19 | 20 | ## Model 21 | We also provide the checkpoint model trained by ourselves; you can download it from [here](https://drive.google.com/file/d/1F-eKqPRjlya5GH2HwTlLKNSPEUaxCu9H/view?usp=sharing). If you test on our provided model, you will get an 87.6% test accuracy. 22 | 23 | ## Reference 24 | If you are interested in our work and want to cite it, please acknowledge the following paper: 25 | 26 | ``` 27 | @inproceedings{Yang2018Learning, 28 | author = {Yang, Ze and Luo, Tiange and Wang, Dong and Hu, Zhiqiang and Gao, Jun and Wang, Liwei}, 29 | title = {Learning to Navigate for Fine-grained Classification}, 30 | booktitle = {ECCV}, 31 | year = {2018} 32 | } 33 | ``` 34 | -------------------------------------------------------------------------------- /cirtorch/utils/whiten.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | def whitenapply(X, m, P, dimensions=None): 5 | 6 | if not dimensions: 7 | dimensions = P.shape[0] 8 | 9 | X = np.dot(P[:dimensions, :], X-m) 10 | X = X / (np.linalg.norm(X, ord=2, axis=0, keepdims=True) + 1e-6) 11 | 12 | return X 13 | 14 | def pcawhitenlearn(X): 15 | 16 | N = X.shape[1] 17 | 18 | # Learning PCA w/o annotations 19 | m = X.mean(axis=1, keepdims=True) 20 | Xc = X - m 21 | Xcov = np.dot(Xc, Xc.T) 22 | Xcov = (Xcov + Xcov.T) / (2*N) 23 | eigval, eigvec = np.linalg.eig(Xcov) 24 | order = eigval.argsort()[::-1] 25 | eigval = eigval[order] 26 | eigvec = eigvec[:, order] 27 | 28 | P = np.dot(np.linalg.inv(np.sqrt(np.diag(eigval))), eigvec.T) 29 | 30 | return m, P 31 | 32 | def whitenlearn(X, qidxs, pidxs): 33 | 34 | # Learning Lw with annotations 35 | m = X[:, qidxs].mean(axis=1, keepdims=True) 36 | df = X[:, qidxs] - X[:, pidxs] 37 | S = np.dot(df, df.T) / df.shape[1] 38 | P = np.linalg.inv(cholesky(S)) 39 | df = np.dot(P, X-m) 40 | D = np.dot(df, df.T) 41 | eigval, eigvec = np.linalg.eig(D) 42 | order = eigval.argsort()[::-1] 43 | eigval = eigval[order] 44 | eigvec = eigvec[:, order] 45 | 46 | P = np.dot(eigvec.T, P) 47 | 48 | return m, P 49 | 50 | def cholesky(S): 51 | # Cholesky decomposition 52 | # with adding a small value on the diagonal 53 | # until matrix is positive definite 54 | alpha = 0 55 | while 1: 56 | try: 57 | L = np.linalg.cholesky(S + alpha*np.eye(*S.shape)) 58 | return L 59 | except: 60 | if alpha == 0: 61 | alpha = 1e-10 62 | else: 63 | alpha *= 10 64 | print(">>>> {}::cholesky: Matrix is not positive definite, adding {:.0e} on the diagonal" 65 | .format(os.path.basename(__file__), alpha))
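A minimal usage sketch for the whitening utilities above (not a file from this repository; it assumes descriptor matrices are D x N, as in the functions above, with hypothetical sizes):

import numpy as np
from cirtorch.utils.whiten import pcawhitenlearn, whitenapply

X = np.random.randn(2048, 1000)             # hypothetical database: 2048-D descriptors for 1000 images
m, P = pcawhitenlearn(X)                    # mean vector and whitening projection
Xw = whitenapply(X, m, P, dimensions=512)   # project, keep 512 dimensions, L2-normalize columns
q = np.random.randn(2048, 1)                # one query descriptor
qw = whitenapply(q, m, P, dimensions=512)
scores = np.dot(Xw.T, qw)                   # cosine similarities, since columns are L2-normalized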
-------------------------------------------------------------------------------- /nts/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.autograd import Variable 3 | import torch.utils.data 4 | from torch.nn import DataParallel 5 | from config import BATCH_SIZE, PROPOSAL_NUM, test_model 6 | from core import model, dataset 7 | from core.utils import progress_bar 8 | 9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' 10 | if not test_model: 11 | raise NameError('please set the test_model file to choose the checkpoint!') 12 | # read dataset 13 | trainset = dataset.CUB(root='./CUB_200_2011', is_train=True, data_len=None) 14 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, 15 | shuffle=True, num_workers=8, drop_last=False) 16 | testset = dataset.CUB(root='./CUB_200_2011', is_train=False, data_len=None) 17 | testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, 18 | shuffle=False, num_workers=8, drop_last=False) 19 | # define model 20 | net = model.attention_net(topN=PROPOSAL_NUM) 21 | ckpt = torch.load(test_model) 22 | net.load_state_dict(ckpt['net_state_dict']) 23 | net = net.cuda() 24 | net = DataParallel(net) 25 | criterion = torch.nn.CrossEntropyLoss() 26 | 27 | # evaluate on train set 28 | train_loss = 0 29 | train_correct = 0 30 | total = 0 31 | net.eval() 32 | 33 | for i, data in enumerate(trainloader): 34 | with torch.no_grad(): 35 | img, label = data[0].cuda(), data[1].cuda() 36 | batch_size = img.size(0) 37 | _, concat_logits, _, _, _ = net(img) 38 | # calculate loss 39 | concat_loss = criterion(concat_logits, label) 40 | # calculate accuracy 41 | _, concat_predict = torch.max(concat_logits, 1) 42 | total += batch_size 43 | train_correct += torch.sum(concat_predict.data == label.data) 44 | train_loss += concat_loss.item() * batch_size 45 | progress_bar(i, len(trainloader), 'eval on train set') 46 | 47 | train_acc = float(train_correct) / total 48 | train_loss = train_loss / total 49 | print('train set loss: {:.3f} and train set acc: {:.3f} total sample: {}'.format(train_loss, train_acc, total)) 50 | 51 | 52 | # evaluate on test set 53 | test_loss = 0 54 | test_correct = 0 55 | total = 0 56 | for i, data in enumerate(testloader): 57 | with torch.no_grad(): 58 | img, label = data[0].cuda(), data[1].cuda() 59 | batch_size = img.size(0) 60 | _, concat_logits, _, _, _ = net(img) 61 | # calculate loss 62 | concat_loss = criterion(concat_logits, label) 63 | # calculate accuracy 64 | _, concat_predict = torch.max(concat_logits, 1) 65 | total += batch_size 66 | test_correct += torch.sum(concat_predict.data == label.data) 67 | test_loss += concat_loss.item() * batch_size 68 | progress_bar(i, len(testloader), 'eval on test set') 69 | 70 | test_acc = float(test_correct) / total 71 | test_loss = test_loss / total 72 | print('test set loss: {:.3f} and test set acc: {:.3f} total sample: {}'.format(test_loss, test_acc, total)) 73 | 74 | print('finished testing') 75 | -------------------------------------------------------------------------------- /nts/core/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import time 5 | import logging 6 | 7 | _, term_width = os.popen('stty size', 'r').read().split() 8 | term_width = int(term_width) 9 | 10 | TOTAL_BAR_LENGTH = 40.
11 | last_time = time.time() 12 | begin_time = last_time 13 | 14 | 15 | def progress_bar(current, total, msg=None): 16 | global last_time, begin_time 17 | if current == 0: 18 | begin_time = time.time() # Reset for new bar. 19 | 20 | cur_len = int(TOTAL_BAR_LENGTH * current / total) 21 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 22 | 23 | sys.stdout.write(' [') 24 | for i in range(cur_len): 25 | sys.stdout.write('=') 26 | sys.stdout.write('>') 27 | for i in range(rest_len): 28 | sys.stdout.write('.') 29 | sys.stdout.write(']') 30 | 31 | cur_time = time.time() 32 | step_time = cur_time - last_time 33 | last_time = cur_time 34 | tot_time = cur_time - begin_time 35 | 36 | L = [] 37 | L.append(' Step: %s' % format_time(step_time)) 38 | L.append(' | Tot: %s' % format_time(tot_time)) 39 | if msg: 40 | L.append(' | ' + msg) 41 | 42 | msg = ''.join(L) 43 | sys.stdout.write(msg) 44 | for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3): 45 | sys.stdout.write(' ') 46 | 47 | # Go back to the center of the bar. 48 | for i in range(term_width - int(TOTAL_BAR_LENGTH / 2)): 49 | sys.stdout.write('\b') 50 | sys.stdout.write(' %d/%d ' % (current + 1, total)) 51 | 52 | if current < total - 1: 53 | sys.stdout.write('\r') 54 | else: 55 | sys.stdout.write('\n') 56 | sys.stdout.flush() 57 | 58 | 59 | def format_time(seconds): 60 | days = int(seconds / 3600 / 24) 61 | seconds = seconds - days * 3600 * 24 62 | hours = int(seconds / 3600) 63 | seconds = seconds - hours * 3600 64 | minutes = int(seconds / 60) 65 | seconds = seconds - minutes * 60 66 | secondsf = int(seconds) 67 | seconds = seconds - secondsf 68 | millis = int(seconds * 1000) 69 | 70 | f = '' 71 | i = 1 72 | if days > 0: 73 | f += str(days) + 'D' 74 | i += 1 75 | if hours > 0 and i <= 2: 76 | f += str(hours) + 'h' 77 | i += 1 78 | if minutes > 0 and i <= 2: 79 | f += str(minutes) + 'm' 80 | i += 1 81 | if secondsf > 0 and i <= 2: 82 | f += str(secondsf) + 's' 83 | i += 1 84 | if millis > 0 and i <= 2: 85 | f += str(millis) + 'ms' 86 | i += 1 87 | if f == '': 88 | f = '0ms' 89 | return f 90 | 91 | 92 | def init_log(output_dir): 93 | logging.basicConfig(level=logging.DEBUG, 94 | format='%(asctime)s %(message)s', 95 | datefmt='%Y%m%d-%H:%M:%S', 96 | filename=os.path.join(output_dir, 'log.log'), 97 | filemode='w') 98 | console = logging.StreamHandler() 99 | console.setLevel(logging.INFO) 100 | logging.getLogger('').addHandler(console) 101 | return logging 102 | 103 | if __name__ == '__main__': 104 | pass 105 | -------------------------------------------------------------------------------- /lshash/storage.py: -------------------------------------------------------------------------------- 1 | # lshash/storage.py 2 | # Copyright 2012 Kay Zhu (a.k.a He Zhu) and contributors (see CONTRIBUTORS.txt) 3 | # 4 | # This module is part of lshash and is released under 5 | # the MIT License: http://www.opensource.org/licenses/mit-license.php 6 | 7 | import json 8 | 9 | try: 10 | import redis 11 | except ImportError: 12 | redis = None 13 | 14 | __all__ = ['storage'] 15 | 16 | 17 | def storage(storage_config, index): 18 | """ Given the configuration for storage and the index, return the 19 | configured storage instance. 
20 | """ 21 | if 'dict' in storage_config: 22 | return InMemoryStorage(storage_config['dict']) 23 | elif 'redis' in storage_config: 24 | storage_config['redis']['db'] = index 25 | return RedisStorage(storage_config['redis']) 26 | else: 27 | raise ValueError("Only in-memory dictionary and Redis are supported.") 28 | 29 | 30 | class BaseStorage(object): 31 | def __init__(self, config): 32 | """ An abstract class used as an adapter for storages. """ 33 | raise NotImplementedError 34 | 35 | def keys(self): 36 | """ Returns a list of binary hashes that are used as dict keys. """ 37 | raise NotImplementedError 38 | 39 | def set_val(self, key, val): 40 | """ Set `val` at `key`, note that the `val` must be a string. """ 41 | raise NotImplementedError 42 | 43 | def get_val(self, key): 44 | """ Return `val` at `key`, note that the `val` must be a string. """ 45 | raise NotImplementedError 46 | 47 | def append_val(self, key, val): 48 | """ Append `val` to the list stored at `key`. 49 | 50 | If the key is not yet present in storage, create a list with `val` at 51 | `key`. 52 | """ 53 | raise NotImplementedError 54 | 55 | def get_list(self, key): 56 | """ Returns a list stored in storage at `key`. 57 | 58 | This method should return a list of values stored at `key`. `[]` should 59 | be returned if the list is empty or if `key` is not present in storage. 60 | """ 61 | raise NotImplementedError 62 | 63 | 64 | class InMemoryStorage(BaseStorage): 65 | def __init__(self, config): 66 | self.name = 'dict' 67 | self.storage = dict() 68 | 69 | def keys(self): 70 | return self.storage.keys() 71 | 72 | def set_val(self, key, val): 73 | self.storage[key] = val 74 | 75 | def get_val(self, key): 76 | return self.storage[key] 77 | 78 | def append_val(self, key, val): 79 | self.storage.setdefault(key, []).append(val) 80 | 81 | def get_list(self, key): 82 | return self.storage.get(key, []) 83 | 84 | 85 | class RedisStorage(BaseStorage): 86 | def __init__(self, config): 87 | if not redis: 88 | raise ImportError("redis-py is required to use Redis as storage.") 89 | self.name = 'redis' 90 | self.storage = redis.StrictRedis(**config) 91 | 92 | def keys(self, pattern="*"): 93 | return self.storage.keys(pattern) 94 | 95 | def set_val(self, key, val): 96 | self.storage.set(key, val) 97 | 98 | def get_val(self, key): 99 | return self.storage.get(key) 100 | 101 | def append_val(self, key, val): 102 | self.storage.rpush(key, json.dumps(val)) 103 | 104 | def get_list(self, key): 105 | return self.storage.lrange(key, 0, -1) 106 | -------------------------------------------------------------------------------- /nts/core/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import os 4 | from PIL import Image 5 | from torchvision import transforms 6 | from config import INPUT_SIZE 7 | 8 | 9 | class CUB(): 10 | def __init__(self, root, is_train=True, data_len=None): 11 | self.root = root 12 | self.is_train = is_train 13 | img_txt_file = open(os.path.join(self.root, 'images.txt')) 14 | label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt')) 15 | train_val_file = open(os.path.join(self.root, 'train_test_split.txt')) 16 | img_name_list = [] 17 | for line in img_txt_file: 18 | img_name_list.append(line[:-1].split(' ')[-1]) 19 | label_list = [] 20 | for line in label_txt_file: 21 | label_list.append(int(line[:-1].split(' ')[-1]) - 1) 22 | train_test_list = [] 23 | for line in train_val_file: 24 | 
train_test_list.append(int(line[:-1].split(' ')[-1])) 25 | train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i] 26 | test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i] 27 | if self.is_train: 28 | self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in 29 | train_file_list[:data_len]] 30 | self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len] 31 | if not self.is_train: 32 | self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in 33 | test_file_list[:data_len]] 34 | self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len] 35 | 36 | def __getitem__(self, index): 37 | if self.is_train: 38 | img, target = self.train_img[index], self.train_label[index] 39 | if len(img.shape) == 2: 40 | img = np.stack([img] * 3, 2) 41 | img = Image.fromarray(img, mode='RGB') 42 | img = transforms.Resize((600, 600), Image.BILINEAR)(img) 43 | img = transforms.RandomCrop(INPUT_SIZE)(img) 44 | img = transforms.RandomHorizontalFlip()(img) 45 | img = transforms.ToTensor()(img) 46 | img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img) 47 | 48 | else: 49 | img, target = self.test_img[index], self.test_label[index] 50 | if len(img.shape) == 2: 51 | img = np.stack([img] * 3, 2) 52 | img = Image.fromarray(img, mode='RGB') 53 | img = transforms.Resize((600, 600), Image.BILINEAR)(img) 54 | img = transforms.CenterCrop(INPUT_SIZE)(img) 55 | img = transforms.ToTensor()(img) 56 | img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img) 57 | 58 | return img, target 59 | 60 | def __len__(self): 61 | if self.is_train: 62 | return len(self.train_label) 63 | else: 64 | return len(self.test_label) 65 | 66 | 67 | if __name__ == '__main__': 68 | dataset = CUB(root='./CUB_200_2011') 69 | print(len(dataset.train_img)) 70 | print(len(dataset.train_label)) 71 | for data in dataset: 72 | print(data[0].size(), data[1]) 73 | dataset = CUB(root='./CUB_200_2011', is_train=False) 74 | print(len(dataset.test_img)) 75 | print(len(dataset.test_label)) 76 | for data in dataset: 77 | print(data[0].size(), data[1]) 78 | -------------------------------------------------------------------------------- /cirtorch/layers/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | import cirtorch.layers.functional as LF 6 | from cirtorch.layers.normalization import L2N 7 | 8 | # -------------------------------------- 9 | # Pooling layers 10 | # -------------------------------------- 11 | 12 | class MAC(nn.Module): 13 | 14 | def __init__(self): 15 | super(MAC,self).__init__() 16 | 17 | def forward(self, x): 18 | return LF.mac(x) 19 | 20 | def __repr__(self): 21 | return self.__class__.__name__ + '()' 22 | 23 | 24 | class SPoC(nn.Module): 25 | 26 | def __init__(self): 27 | super(SPoC,self).__init__() 28 | 29 | def forward(self, x): 30 | return LF.spoc(x) 31 | 32 | def __repr__(self): 33 | return self.__class__.__name__ + '()' 34 | 35 | 36 | class GeM(nn.Module): 37 | 38 | def __init__(self, p=3, eps=1e-6): 39 | super(GeM,self).__init__() 40 | self.p = Parameter(torch.ones(1)*p) 41 | self.eps = eps 42 | 43 | def forward(self, x): 44 | return LF.gem(x, p=self.p, eps=self.eps) 45 | 46 | def __repr__(self): 47 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + 
', ' + 'eps=' + str(self.eps) + ')' 48 | 49 | class GeMmp(nn.Module): 50 | 51 | def __init__(self, p=3, mp=1, eps=1e-6): 52 | super(GeMmp,self).__init__() 53 | self.p = Parameter(torch.ones(mp)*p) 54 | self.mp = mp 55 | self.eps = eps 56 | 57 | def forward(self, x): 58 | return LF.gem(x, p=self.p.unsqueeze(-1).unsqueeze(-1), eps=self.eps) 59 | 60 | def __repr__(self): 61 | return self.__class__.__name__ + '(' + 'p=' + '[{}]'.format(self.mp) + ', ' + 'eps=' + str(self.eps) + ')' 62 | 63 | class RMAC(nn.Module): 64 | 65 | def __init__(self, L=3, eps=1e-6): 66 | super(RMAC,self).__init__() 67 | self.L = L 68 | self.eps = eps 69 | 70 | def forward(self, x): 71 | return LF.rmac(x, L=self.L, eps=self.eps) 72 | 73 | def __repr__(self): 74 | return self.__class__.__name__ + '(' + 'L=' + '{}'.format(self.L) + ')' 75 | 76 | 77 | class Rpool(nn.Module): 78 | 79 | def __init__(self, rpool, whiten=None, L=3, eps=1e-6): 80 | super(Rpool,self).__init__() 81 | self.rpool = rpool 82 | self.L = L 83 | self.whiten = whiten 84 | self.norm = L2N() 85 | self.eps = eps 86 | 87 | def forward(self, x, aggregate=True): 88 | # features -> roipool 89 | o = LF.roipool(x, self.rpool, self.L, self.eps) # size: #im, #reg, D, 1, 1 90 | 91 | # concatenate regions from all images in the batch 92 | s = o.size() 93 | o = o.view(s[0]*s[1], s[2], s[3], s[4]) # size: #im x #reg, D, 1, 1 94 | 95 | # rvecs -> norm 96 | o = self.norm(o) 97 | 98 | # rvecs -> whiten -> norm 99 | if self.whiten is not None: 100 | o = self.norm(self.whiten(o.squeeze(-1).squeeze(-1))) 101 | 102 | # reshape back to regions per image 103 | o = o.view(s[0], s[1], s[2], s[3], s[4]) # size: #im, #reg, D, 1, 1 104 | 105 | # aggregate regions into a single global vector per image 106 | if aggregate: 107 | # rvecs -> sumpool -> norm 108 | o = self.norm(o.sum(1, keepdim=False)) # size: #im, D, 1, 1 109 | 110 | return o 111 | 112 | def __repr__(self): 113 | return super(Rpool, self).__repr__() + '(' + 'L=' + '{}'.format(self.L) + ')' -------------------------------------------------------------------------------- /cirtorch/datasets/genericdataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | 4 | import torch 5 | import torch.utils.data as data 6 | 7 | from cirtorch.datasets.datahelpers import default_loader, imresize 8 | 9 | 10 | class ImagesFromList(data.Dataset): 11 | """A generic data loader that loads images from a list 12 | (Based on ImageFolder from pytorch) 13 | Args: 14 | root (string): Root directory path. 15 | images (list): Relative image paths as strings. 16 | imsize (int, Default: None): Defines the maximum size of longer image side 17 | bbxs (list): List of (x1,y1,x2,y2) tuples to crop the query images 18 | transform (callable, optional): A function/transform that takes in an PIL image 19 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 20 | loader (callable, optional): A function to load an image given its path. 
21 | Attributes: 22 | images_fn (list): List of full image filename 23 | """ 24 | 25 | def __init__(self, root, images, imsize=None, bbxs=None, transform=None, loader=default_loader): 26 | 27 | images_fn = [os.path.join(root,images[i]) for i in range(len(images))] 28 | 29 | if len(images_fn) == 0: 30 | raise(RuntimeError("Dataset contains 0 images!")) 31 | 32 | self.root = root 33 | self.images = images 34 | self.imsize = imsize 35 | self.images_fn = images_fn 36 | self.bbxs = bbxs 37 | self.transform = transform 38 | self.loader = loader 39 | 40 | def __getitem__(self, index): 41 | """ 42 | Args: 43 | index (int): Index 44 | Returns: 45 | image (PIL): Loaded image 46 | """ 47 | path = self.images_fn[index] 48 | img = self.loader(path) 49 | imfullsize = max(img.size) 50 | 51 | if self.bbxs is not None: 52 | img = img.crop(self.bbxs[index]) 53 | 54 | if self.imsize is not None: 55 | if self.bbxs is not None: 56 | img = imresize(img, self.imsize * max(img.size) / imfullsize) 57 | else: 58 | img = imresize(img, self.imsize) 59 | 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | 63 | return img, path 64 | 65 | def __len__(self): 66 | return len(self.images_fn) 67 | 68 | def __repr__(self): 69 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 70 | fmt_str += ' Number of images: {}\n'.format(self.__len__()) 71 | fmt_str += ' Root Location: {}\n'.format(self.root) 72 | tmp = ' Transforms (if any): ' 73 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 74 | return fmt_str 75 | 76 | class ImagesFromDataList(data.Dataset): 77 | """A generic data loader that loads images given as an array of pytorch tensors 78 | (Based on ImageFolder from pytorch) 79 | Args: 80 | images (list): Images as tensors. 81 | transform (callable, optional): A function/transform that image as a tensors 82 | and returns a transformed version. E.g, ``normalize`` with mean and std 83 | """ 84 | 85 | def __init__(self, images, transform=None): 86 | 87 | if len(images) == 0: 88 | raise(RuntimeError("Dataset contains 0 images!")) 89 | 90 | self.images = images 91 | self.transform = transform 92 | 93 | def __getitem__(self, index): 94 | """ 95 | Args: 96 | index (int): Index 97 | Returns: 98 | image (Tensor): Loaded image 99 | """ 100 | img = self.images[index] 101 | if self.transform is not None: 102 | img = self.transform(img) 103 | 104 | if len(img.size()): 105 | img = img.unsqueeze(0) 106 | 107 | return img 108 | 109 | def __len__(self): 110 | return len(self.images) 111 | 112 | def __repr__(self): 113 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 114 | fmt_str += ' Number of images: {}\n'.format(self.__len__()) 115 | tmp = ' Transforms (if any): ' 116 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 117 | return fmt_str 118 | -------------------------------------------------------------------------------- /nts/core/anchors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from config import INPUT_SIZE 3 | 4 | _default_anchors_setting = ( 5 | dict(layer='p3', stride=32, size=48, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]), 6 | dict(layer='p4', stride=64, size=96, scale=[2 ** (1. / 3.), 2 ** (2. / 3.)], aspect_ratio=[0.667, 1, 1.5]), 7 | dict(layer='p5', stride=128, size=192, scale=[1, 2 ** (1. / 3.), 2 ** (2. 
/ 3.)], aspect_ratio=[0.667, 1, 1.5]), 8 | ) 9 | 10 | 11 | def generate_default_anchor_maps(anchors_setting=None, input_shape=INPUT_SIZE): 12 | """ 13 | generate default anchor 14 | 15 | :param anchors_setting: all informations of anchors 16 | :param input_shape: shape of input images, e.g. (h, w) 17 | :return: center_anchors: # anchors * 4 (oy, ox, h, w) 18 | edge_anchors: # anchors * 4 (y0, x0, y1, x1) 19 | anchor_area: # anchors * 1 (area) 20 | """ 21 | if anchors_setting is None: 22 | anchors_setting = _default_anchors_setting 23 | 24 | center_anchors = np.zeros((0, 4), dtype=np.float32) 25 | edge_anchors = np.zeros((0, 4), dtype=np.float32) 26 | anchor_areas = np.zeros((0,), dtype=np.float32) 27 | input_shape = np.array(input_shape, dtype=int) 28 | 29 | for anchor_info in anchors_setting: 30 | 31 | stride = anchor_info['stride'] 32 | size = anchor_info['size'] 33 | scales = anchor_info['scale'] 34 | aspect_ratios = anchor_info['aspect_ratio'] 35 | 36 | output_map_shape = np.ceil(input_shape.astype(np.float32) / stride) 37 | output_map_shape = output_map_shape.astype(np.int) 38 | output_shape = tuple(output_map_shape) + (4,) 39 | ostart = stride / 2. 40 | oy = np.arange(ostart, ostart + stride * output_shape[0], stride) 41 | oy = oy.reshape(output_shape[0], 1) 42 | ox = np.arange(ostart, ostart + stride * output_shape[1], stride) 43 | ox = ox.reshape(1, output_shape[1]) 44 | center_anchor_map_template = np.zeros(output_shape, dtype=np.float32) 45 | center_anchor_map_template[:, :, 0] = oy 46 | center_anchor_map_template[:, :, 1] = ox 47 | for scale in scales: 48 | for aspect_ratio in aspect_ratios: 49 | center_anchor_map = center_anchor_map_template.copy() 50 | center_anchor_map[:, :, 2] = size * scale / float(aspect_ratio) ** 0.5 51 | center_anchor_map[:, :, 3] = size * scale * float(aspect_ratio) ** 0.5 52 | 53 | edge_anchor_map = np.concatenate((center_anchor_map[..., :2] - center_anchor_map[..., 2:4] / 2., 54 | center_anchor_map[..., :2] + center_anchor_map[..., 2:4] / 2.), 55 | axis=-1) 56 | anchor_area_map = center_anchor_map[..., 2] * center_anchor_map[..., 3] 57 | center_anchors = np.concatenate((center_anchors, center_anchor_map.reshape(-1, 4))) 58 | edge_anchors = np.concatenate((edge_anchors, edge_anchor_map.reshape(-1, 4))) 59 | anchor_areas = np.concatenate((anchor_areas, anchor_area_map.reshape(-1))) 60 | 61 | return center_anchors, edge_anchors, anchor_areas 62 | 63 | 64 | def hard_nms(cdds, topn=10, iou_thresh=0.25): 65 | if not (type(cdds).__module__ == 'numpy' and len(cdds.shape) == 2 and cdds.shape[1] >= 5): 66 | raise TypeError('edge_box_map should be N * 5+ ndarray') 67 | 68 | cdds = cdds.copy() 69 | indices = np.argsort(cdds[:, 0]) 70 | cdds = cdds[indices] 71 | cdd_results = [] 72 | 73 | res = cdds 74 | 75 | while res.any(): 76 | cdd = res[-1] 77 | cdd_results.append(cdd) 78 | if len(cdd_results) == topn: 79 | return np.array(cdd_results) 80 | res = res[:-1] 81 | 82 | start_max = np.maximum(res[:, 1:3], cdd[1:3]) 83 | end_min = np.minimum(res[:, 3:5], cdd[3:5]) 84 | lengths = end_min - start_max 85 | intersec_map = lengths[:, 0] * lengths[:, 1] 86 | intersec_map[np.logical_or(lengths[:, 0] < 0, lengths[:, 1] < 0)] = 0 87 | iou_map_cur = intersec_map / ((res[:, 3] - res[:, 1]) * (res[:, 4] - res[:, 2]) + (cdd[3] - cdd[1]) * ( 88 | cdd[4] - cdd[2]) - intersec_map) 89 | res = res[iou_map_cur < iou_thresh] 90 | 91 | return np.array(cdd_results) 92 | 93 | 94 | if __name__ == '__main__': 95 | a = hard_nms(np.array([ 96 | [0.4, 1, 10, 12, 20], 97 | [0.5, 1, 11, 11, 20], 
98 | [0.55, 20, 30, 40, 50] 99 | ]), topn=100, iou_thresh=0.4) 100 | print(a) 101 | -------------------------------------------------------------------------------- /nts/core/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from core import resnet 6 | import numpy as np 7 | from core.anchors import generate_default_anchor_maps, hard_nms 8 | from config import CAT_NUM, PROPOSAL_NUM 9 | 10 | 11 | class ProposalNet(nn.Module): 12 | def __init__(self): 13 | super(ProposalNet, self).__init__() 14 | self.down1 = nn.Conv2d(2048, 128, 3, 1, 1) 15 | self.down2 = nn.Conv2d(128, 128, 3, 2, 1) 16 | self.down3 = nn.Conv2d(128, 128, 3, 2, 1) 17 | self.ReLU = nn.ReLU() 18 | self.tidy1 = nn.Conv2d(128, 6, 1, 1, 0) 19 | self.tidy2 = nn.Conv2d(128, 6, 1, 1, 0) 20 | self.tidy3 = nn.Conv2d(128, 9, 1, 1, 0) 21 | 22 | def forward(self, x): 23 | batch_size = x.size(0) 24 | d1 = self.ReLU(self.down1(x)) 25 | d2 = self.ReLU(self.down2(d1)) 26 | d3 = self.ReLU(self.down3(d2)) 27 | t1 = self.tidy1(d1).view(batch_size, -1) 28 | t2 = self.tidy2(d2).view(batch_size, -1) 29 | t3 = self.tidy3(d3).view(batch_size, -1) 30 | return torch.cat((t1, t2, t3), dim=1) 31 | 32 | 33 | class attention_net(nn.Module): 34 | def __init__(self, topN=4): 35 | super(attention_net, self).__init__() 36 | self.pretrained_model = resnet.resnet50(pretrained=True) 37 | self.pretrained_model.avgpool = nn.AdaptiveAvgPool2d(1) 38 | self.pretrained_model.fc = nn.Linear(512 * 4, 200) 39 | self.proposal_net = ProposalNet() 40 | self.topN = topN 41 | self.concat_net = nn.Linear(2048 * (CAT_NUM + 1), 200) 42 | self.partcls_net = nn.Linear(512 * 4, 200) 43 | _, edge_anchors, _ = generate_default_anchor_maps() 44 | self.pad_side = 224 45 | self.edge_anchors = (edge_anchors + 224).astype(np.int) 46 | 47 | def forward(self, x): 48 | resnet_out, rpn_feature, feature = self.pretrained_model(x) 49 | x_pad = F.pad(x, (self.pad_side, self.pad_side, self.pad_side, self.pad_side), mode='constant', value=0) 50 | batch = x.size(0) 51 | # we will reshape rpn to shape: batch * nb_anchor 52 | rpn_score = self.proposal_net(rpn_feature.detach()) 53 | all_cdds = [ 54 | np.concatenate((x.reshape(-1, 1), self.edge_anchors.copy(), np.arange(0, len(x)).reshape(-1, 1)), axis=1) 55 | for x in rpn_score.data.cpu().numpy()] 56 | top_n_cdds = [hard_nms(x, topn=self.topN, iou_thresh=0.25) for x in all_cdds] 57 | top_n_cdds = np.array(top_n_cdds) 58 | top_n_index = top_n_cdds[:, :, -1].astype(np.int) 59 | top_n_index = torch.from_numpy(top_n_index).cuda() 60 | top_n_prob = torch.gather(rpn_score, dim=1, index=top_n_index) 61 | part_imgs = torch.zeros([batch, self.topN, 3, 224, 224]).cuda() 62 | for i in range(batch): 63 | for j in range(self.topN): 64 | [y0, x0, y1, x1] = top_n_cdds[i][j, 1:5].astype(np.int) 65 | part_imgs[i:i + 1, j] = F.interpolate(x_pad[i:i + 1, :, y0:y1, x0:x1], size=(224, 224), mode='bilinear', 66 | align_corners=True) 67 | part_imgs = part_imgs.view(batch * self.topN, 3, 224, 224) 68 | _, _, part_features = self.pretrained_model(part_imgs.detach()) 69 | part_feature = part_features.view(batch, self.topN, -1) 70 | part_feature = part_feature[:, :CAT_NUM, ...].contiguous() 71 | part_feature = part_feature.view(batch, -1) 72 | # concat_logits have the shape: B*200 73 | concat_out = torch.cat([part_feature, feature], dim=1) 74 | concat_logits = self.concat_net(concat_out) 75 | raw_logits = resnet_out 
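# raw_logits: predictions of the plain ResNet-50 backbone on the full image (shape B*200)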
76 | # part_logits have the shape: B*N*200 77 | part_logits = self.partcls_net(part_features).view(batch, self.topN, -1) 78 | return [raw_logits, concat_logits, part_logits, top_n_index, top_n_prob] 79 | 80 | 81 | def list_loss(logits, targets): 82 | temp = F.log_softmax(logits, -1) 83 | loss = [-temp[i][targets[i].item()] for i in range(logits.size(0))] 84 | return torch.stack(loss) 85 | 86 | 87 | def ranking_loss(score, targets, proposal_num=PROPOSAL_NUM): 88 | loss = Variable(torch.zeros(1).cuda()) 89 | batch_size = score.size(0) 90 | for i in range(proposal_num): 91 | targets_p = (targets > targets[:, i].unsqueeze(1)).type(torch.cuda.FloatTensor) 92 | pivot = score[:, i].unsqueeze(1) 93 | loss_p = (1 - pivot + score) * targets_p 94 | loss_p = torch.sum(F.relu(loss_p)) 95 | loss += loss_p 96 | return loss / batch_size 97 | -------------------------------------------------------------------------------- /cirtorch/utils/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def compute_ap(ranks, nres): 4 | """ 5 | Computes average precision for given ranked indexes. 6 | 7 | Arguments 8 | --------- 9 | ranks : zero-based ranks of positive images 10 | nres : number of positive images 11 | 12 | Returns 13 | ------- 14 | ap : average precision 15 | """ 16 | 17 | # number of images ranked by the system 18 | nimgranks = len(ranks) 19 | 20 | # accumulate trapezoids in PR-plot 21 | ap = 0 22 | 23 | recall_step = 1. / nres 24 | 25 | for j in np.arange(nimgranks): 26 | rank = ranks[j] 27 | 28 | if rank == 0: 29 | precision_0 = 1. 30 | else: 31 | precision_0 = float(j) / rank 32 | 33 | precision_1 = float(j + 1) / (rank + 1) 34 | 35 | ap += (precision_0 + precision_1) * recall_step / 2. 36 | 37 | return ap 38 | 39 | def compute_map(ranks, gnd, kappas=[]): 40 | """ 41 | Computes the mAP for a given set of returned results. 42 | 43 | Usage: 44 | map = compute_map (ranks, gnd) 45 | computes mean average precision (map) only 46 | 47 | map, aps, pr, prs = compute_map (ranks, gnd, kappas) 48 | computes mean average precision (map), average precision (aps) for each query 49 | computes mean precision at kappas (pr), precision at kappas (prs) for each query 50 | 51 | Notes: 52 | 1) ranks starts from 0, ranks.shape = db_size X #queries 53 | 2) The junk results (e.g., the query itself) should be declared in the gnd struct array 54 | 3) If there are no positive images for some query, that query is excluded from the evaluation 55 | """ 56 | 57 | map = 0.
58 | nq = len(gnd) # number of queries 59 | aps = np.zeros(nq) 60 | pr = np.zeros(len(kappas)) 61 | prs = np.zeros((nq, len(kappas))) 62 | nempty = 0 63 | 64 | for i in np.arange(nq): 65 | qgnd = np.array(gnd[i]['ok']) 66 | 67 | # no positive images, skip from the average 68 | if qgnd.shape[0] == 0: 69 | aps[i] = float('nan') 70 | prs[i, :] = float('nan') 71 | nempty += 1 72 | continue 73 | 74 | try: 75 | qgndj = np.array(gnd[i]['junk']) 76 | except: 77 | qgndj = np.empty(0) 78 | 79 | # sorted positions of positive and junk images (0 based) 80 | pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)] 81 | junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)] 82 | 83 | k = 0; 84 | ij = 0; 85 | if len(junk): 86 | # decrease positions of positives based on the number of 87 | # junk images appearing before them 88 | ip = 0 89 | while (ip < len(pos)): 90 | while (ij < len(junk) and pos[ip] > junk[ij]): 91 | k += 1 92 | ij += 1 93 | pos[ip] = pos[ip] - k 94 | ip += 1 95 | 96 | # compute ap 97 | ap = compute_ap(pos, len(qgnd)) 98 | map = map + ap 99 | aps[i] = ap 100 | 101 | # compute precision @ k 102 | pos += 1 # get it to 1-based 103 | for j in np.arange(len(kappas)): 104 | kq = min(max(pos), kappas[j]); 105 | prs[i, j] = (pos <= kq).sum() / kq 106 | pr = pr + prs[i, :] 107 | 108 | map = map / (nq - nempty) 109 | pr = pr / (nq - nempty) 110 | 111 | return map, aps, pr, prs 112 | 113 | 114 | def compute_map_and_print(dataset, ranks, gnd, kappas=[1, 5, 10]): 115 | 116 | # old evaluation protocol 117 | if dataset.startswith('oxford5k') or dataset.startswith('paris6k'): 118 | map, aps, _, _ = compute_map(ranks, gnd) 119 | print('>> {}: mAP {:.2f}'.format(dataset, np.around(map*100, decimals=2))) 120 | 121 | # new evaluation protocol 122 | elif dataset.startswith('roxford5k') or dataset.startswith('rparis6k'): 123 | 124 | gnd_t = [] 125 | for i in range(len(gnd)): 126 | g = {} 127 | g['ok'] = np.concatenate([gnd[i]['easy']]) 128 | g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['hard']]) 129 | gnd_t.append(g) 130 | mapE, apsE, mprE, prsE = compute_map(ranks, gnd_t, kappas) 131 | 132 | gnd_t = [] 133 | for i in range(len(gnd)): 134 | g = {} 135 | g['ok'] = np.concatenate([gnd[i]['easy'], gnd[i]['hard']]) 136 | g['junk'] = np.concatenate([gnd[i]['junk']]) 137 | gnd_t.append(g) 138 | mapM, apsM, mprM, prsM = compute_map(ranks, gnd_t, kappas) 139 | 140 | gnd_t = [] 141 | for i in range(len(gnd)): 142 | g = {} 143 | g['ok'] = np.concatenate([gnd[i]['hard']]) 144 | g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['easy']]) 145 | gnd_t.append(g) 146 | mapH, apsH, mprH, prsH = compute_map(ranks, gnd_t, kappas) 147 | 148 | print('>> {}: mAP E: {}, M: {}, H: {}'.format(dataset, np.around(mapE*100, decimals=2), np.around(mapM*100, decimals=2), np.around(mapH*100, decimals=2))) 149 | print('>> {}: mP@k{} E: {}, M: {}, H: {}'.format(dataset, kappas, np.around(mprE*100, decimals=2), np.around(mprM*100, decimals=2), np.around(mprH*100, decimals=2))) -------------------------------------------------------------------------------- /cirtorch/layers/functional.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pdb 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | # -------------------------------------- 8 | # pooling 9 | # -------------------------------------- 10 | 11 | def mac(x): 12 | return F.max_pool2d(x, (x.size(-2), x.size(-1))) 13 | # return F.adaptive_max_pool2d(x, (1,1)) # alternative 14 | 15 | 16 | def spoc(x): 17 | 
return F.avg_pool2d(x, (x.size(-2), x.size(-1))) 18 | # return F.adaptive_avg_pool2d(x, (1,1)) # alternative 19 | 20 | 21 | def gem(x, p=3, eps=1e-6): 22 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p) 23 | # return F.lp_pool2d(F.threshold(x, eps, eps), p, (x.size(-2), x.size(-1))) # alternative 24 | 25 | 26 | def rmac(x, L=3, eps=1e-6): 27 | ovr = 0.4 # desired overlap of neighboring regions 28 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension 29 | 30 | W = x.size(3) 31 | H = x.size(2) 32 | 33 | w = min(W, H) 34 | w2 = math.floor(w/2.0 - 1) 35 | 36 | b = (max(H, W)-w)/(steps-1) 37 | (tmp, idx) = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0) # steps(idx) regions for long dimension 38 | 39 | # region overplus per dimension 40 | Wd = 0; 41 | Hd = 0; 42 | if H < W: 43 | Wd = idx.item() + 1 44 | elif H > W: 45 | Hd = idx.item() + 1 46 | 47 | v = F.max_pool2d(x, (x.size(-2), x.size(-1))) 48 | v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v) 49 | 50 | for l in range(1, L+1): 51 | wl = math.floor(2*w/(l+1)) 52 | wl2 = math.floor(wl/2 - 1) 53 | 54 | if l+Wd == 1: 55 | b = 0 56 | else: 57 | b = (W-wl)/(l+Wd-1) 58 | cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b) - wl2 # center coordinates 59 | if l+Hd == 1: 60 | b = 0 61 | else: 62 | b = (H-wl)/(l+Hd-1) 63 | cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b) - wl2 # center coordinates 64 | 65 | for i_ in cenH.tolist(): 66 | for j_ in cenW.tolist(): 67 | if wl == 0: 68 | continue 69 | R = x[:,:,(int(i_)+torch.Tensor(range(wl)).long()).tolist(),:] 70 | R = R[:,:,:,(int(j_)+torch.Tensor(range(wl)).long()).tolist()] 71 | vt = F.max_pool2d(R, (R.size(-2), R.size(-1))) 72 | vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt) 73 | v += vt 74 | 75 | return v 76 | 77 | 78 | def roipool(x, rpool, L=3, eps=1e-6): 79 | ovr = 0.4 # desired overlap of neighboring regions 80 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension 81 | 82 | W = x.size(3) 83 | H = x.size(2) 84 | 85 | w = min(W, H) 86 | w2 = math.floor(w/2.0 - 1) 87 | 88 | b = (max(H, W)-w)/(steps-1) 89 | _, idx = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0) # steps(idx) regions for long dimension 90 | 91 | # region overplus per dimension 92 | Wd = 0; 93 | Hd = 0; 94 | if H < W: 95 | Wd = idx.item() + 1 96 | elif H > W: 97 | Hd = idx.item() + 1 98 | 99 | vecs = [] 100 | vecs.append(rpool(x).unsqueeze(1)) 101 | 102 | for l in range(1, L+1): 103 | wl = math.floor(2*w/(l+1)) 104 | wl2 = math.floor(wl/2 - 1) 105 | 106 | if l+Wd == 1: 107 | b = 0 108 | else: 109 | b = (W-wl)/(l+Wd-1) 110 | cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b).int() - wl2 # center coordinates 111 | if l+Hd == 1: 112 | b = 0 113 | else: 114 | b = (H-wl)/(l+Hd-1) 115 | cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b).int() - wl2 # center coordinates 116 | 117 | for i_ in cenH.tolist(): 118 | for j_ in cenW.tolist(): 119 | if wl == 0: 120 | continue 121 | vecs.append(rpool(x.narrow(2,i_,wl).narrow(3,j_,wl)).unsqueeze(1)) 122 | 123 | return torch.cat(vecs, dim=1) 124 | 125 | 126 | # -------------------------------------- 127 | # normalization 128 | # -------------------------------------- 129 | 130 | def l2n(x, eps=1e-6): 131 | return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x) 132 | 133 | def powerlaw(x, eps=1e-6): 134 | x = x + eps 135 | return x.abs().sqrt().mul(x.sign()) 136 | 137 | # --------------------------------------
138 | # loss
139 | # --------------------------------------
140 | 
141 | def contrastive_loss(x, label, margin=0.7, eps=1e-6):
142 |     # x is D x N
143 |     dim = x.size(0) # D
144 |     nq = torch.sum(label.data==-1).item() # number of tuples, as a plain int as in triplet_loss below
145 |     S = x.size(1) // nq # number of images per tuple including query: 1+1+n
146 | 
147 |     x1 = x[:, ::S].permute(1,0).repeat(1,S-1).view((S-1)*nq,dim).permute(1,0)
148 |     idx = [i for i in range(len(label)) if label.data[i] != -1]
149 |     x2 = x[:, idx]
150 |     lbl = label[label!=-1]
151 | 
152 |     dif = x1 - x2
153 |     D = torch.pow(dif+eps, 2).sum(dim=0).sqrt()
154 | 
155 |     y = 0.5*lbl*torch.pow(D,2) + 0.5*(1-lbl)*torch.pow(torch.clamp(margin-D, min=0),2)
156 |     y = torch.sum(y)
157 |     return y
158 | 
159 | def triplet_loss(x, label, margin=0.1):
160 |     # x is D x N
161 |     dim = x.size(0) # D
162 |     nq = torch.sum(label.data==-1).item() # number of tuples
163 |     S = x.size(1) // nq # number of images per tuple including query: 1+1+n
164 | 
165 |     xa = x[:, label.data==-1].permute(1,0).repeat(1,S-2).view((S-2)*nq,dim).permute(1,0)
166 |     xp = x[:, label.data==1].permute(1,0).repeat(1,S-2).view((S-2)*nq,dim).permute(1,0)
167 |     xn = x[:, label.data==0]
168 | 
169 |     dist_pos = torch.sum(torch.pow(xa - xp, 2), dim=0)
170 |     dist_neg = torch.sum(torch.pow(xa - xn, 2), dim=0)
171 | 
172 |     return torch.sum(torch.clamp(dist_pos - dist_neg + margin, min=0))
173 | 
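A minimal sanity-check sketch for the helpers above (not part of the original file). It assumes this module is importable as cirtorch.layers.functional and follows the shape conventions in the code: pooling operates on N x C x H x W maps, and the losses take a D x N matrix whose columns are descriptors grouped per tuple as query, positive, negatives (labels -1, 1, 0).

import torch
import torch.nn.functional as F
import cirtorch.layers.functional as LF

# GeM with p=1 reduces to average pooling; large p approaches max pooling.
x = torch.rand(1, 512, 30, 40)
assert torch.allclose(LF.gem(x, p=1), F.avg_pool2d(x, (30, 40)), atol=1e-4)

# Build one 1+1+5 tuple of unit-norm descriptors; l2n normalizes along dim=1,
# so transpose the D x N matrix to normalize each descriptor column.
D, nneg = 128, 5
desc = LF.l2n(torch.randn(D, 2 + nneg).t()).t()
label = torch.tensor([-1, 1] + [0] * nneg)   # query, positive, negatives
print(LF.contrastive_loss(desc, label, margin=0.7).item())
print(LF.triplet_loss(desc, label, margin=0.1).item())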
--------------------------------------------------------------------------------
/utils/retrieval_feature.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # /usr/bin/env python
3 | 
4 | '''
5 | Author: yinhao
6 | Email: yinhao_x@163.com
7 | Wechat: xss_yinhao
8 | Github: http://github.com/yinhaoxs
9 | data: 2019-11-23 18:26
10 | desc:
11 | '''
12 | 
13 | import os
14 | from PIL import Image
15 | from lshash.lshash import LSHash
16 | import torch
17 | from torchvision import transforms
18 | from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
19 | 
20 | # setting up the visible GPU
21 | os.environ['CUDA_VISIBLE_DEVICES'] = "0"
22 | 
23 | 
24 | class ImageProcess():
25 |     def __init__(self, img_dir):
26 |         self.img_dir = img_dir
27 | 
28 |     def process(self):
29 |         imgs = list()
30 |         for root, dirs, files in os.walk(self.img_dir):
31 |             for file in files:
32 |                 img_path = os.path.join(root + os.sep, file)
33 |                 try:
34 |                     image = Image.open(img_path)
35 |                     if max(image.size) / min(image.size) < 5:  # skip extreme aspect ratios
36 |                         imgs.append(img_path)
37 |                     else:
38 |                         continue
39 |                 except:
40 |                     print("cannot open image: {}".format(img_path))
41 | 
42 |         return imgs
43 | 
44 | 
45 | class AntiFraudFeatureDataset():
46 |     def __init__(self, img_dir, network, feature_path='', index_path=''):
47 |         self.img_dir = img_dir
48 |         self.network = network
49 |         self.feature_path = feature_path
50 |         self.index_path = index_path
51 | 
52 |     def constructfeature(self, hash_size, input_dim, num_hashtables):
53 |         multiscale = '[1]'
54 |         print(">> Loading network:\n>>>> '{}'".format(self.network))
55 |         # state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
56 |         state = torch.load(self.network)
57 |         # parsing net params from meta
58 |         # architecture, pooling, mean, std required
59 |         # the rest have default values, in case they don't exist
60 |         net_params = {}
61 |         net_params['architecture'] = state['meta']['architecture']
62 |         net_params['pooling'] = state['meta']['pooling']
63 |         net_params['local_whitening'] = state['meta'].get('local_whitening', False)
64 |         net_params['regional'] = state['meta'].get('regional', False)
65 |         net_params['whitening'] = state['meta'].get('whitening', False)
66 |         net_params['mean'] = state['meta']['mean']
67 |         net_params['std'] = state['meta']['std']
68 |         net_params['pretrained'] = False
69 |         # network initialization
70 |         net = init_network(net_params)
71 |         net.load_state_dict(state['state_dict'])
72 |         print(">>>> loaded network: ")
73 |         print(net.meta_repr())
74 |         # setting up the multi-scale parameters
75 |         ms = list(eval(multiscale))
76 |         print(">>>> Evaluating scales: {}".format(ms))
77 |         # moving network to gpu and eval mode
78 |         if torch.cuda.is_available():
79 |             net.cuda()
80 |         net.eval()
81 | 
82 |         # set up the transform
83 |         normalize = transforms.Normalize(
84 |             mean=net.meta['mean'],
85 |             std=net.meta['std']
86 |         )
87 |         transform = transforms.Compose([
88 |             transforms.ToTensor(),
89 |             normalize
90 |         ])
91 | 
92 |         # extract database and query vectors
93 |         print('>> database images...')
94 |         images = ImageProcess(self.img_dir).process()
95 |         vecs, img_paths = extract_vectors(net, images, 1024, transform, ms=ms)
96 |         feature_dict = dict(zip(img_paths, list(vecs.detach().cpu().numpy().T)))
97 |         # index
98 |         lsh = LSHash(hash_size=int(hash_size), input_dim=int(input_dim), num_hashtables=int(num_hashtables))
99 |         for img_path, vec in feature_dict.items():
100 |             lsh.index(vec.flatten(), extra_data=img_path)
101 | 
102 |         # ## save the index model
103 |         # with open(self.feature_path, "wb") as f:
104 |         #     pickle.dump(feature_dict, f)
105 |         # with open(self.index_path, "wb") as f:
106 |         #     pickle.dump(lsh, f)
107 | 
108 |         print("feature extraction is done")
109 |         return feature_dict, lsh
110 | 
111 |     def test_feature(self):  # mirrors constructfeature, but returns features only (no LSH index)
112 |         multiscale = '[1]'
113 |         print(">> Loading network:\n>>>> '{}'".format(self.network))
114 |         # state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
115 |         state = torch.load(self.network)
116 |         # parsing net params from meta
117 |         # architecture, pooling, mean, std required
118 |         # the rest have default values, in case they don't exist
119 |         net_params = {}
120 |         net_params['architecture'] = state['meta']['architecture']
121 |         net_params['pooling'] = state['meta']['pooling']
122 |         net_params['local_whitening'] = state['meta'].get('local_whitening', False)
123 |         net_params['regional'] = state['meta'].get('regional', False)
124 |         net_params['whitening'] = state['meta'].get('whitening', False)
125 |         net_params['mean'] = state['meta']['mean']
126 |         net_params['std'] = state['meta']['std']
127 |         net_params['pretrained'] = False
128 |         # network initialization
129 |         net = init_network(net_params)
130 |         net.load_state_dict(state['state_dict'])
131 |         print(">>>> loaded network: ")
132 |         print(net.meta_repr())
133 |         # setting up the multi-scale parameters
134 |         ms = list(eval(multiscale))
135 |         print(">>>> Evaluating scales: {}".format(ms))
136 |         # moving network to gpu and eval mode
137 |         if torch.cuda.is_available():
138 |             net.cuda()
139 |         net.eval()
140 | 
141 |         # set up the transform
142 |         normalize = transforms.Normalize(
143 |             mean=net.meta['mean'],
144 |             std=net.meta['std']
145 |         )
146 |         transform = transforms.Compose([
147 |             transforms.ToTensor(),
148 |             normalize
149 |         ])
150 | 
151 |         # extract database and query vectors
152 |         print('>> database images...')
153 |         images = ImageProcess(self.img_dir).process()
154 |         vecs, img_paths = extract_vectors(net, images, 1024, transform, ms=ms)
155 |         feature_dict = dict(zip(img_paths, list(vecs.detach().cpu().numpy().T)))
156 |         return feature_dict
157 | 
158 | 
159 | if __name__ == '__main__':
160 |     pass
161 | 
--------------------------------------------------------------------------------
/cirtorch/examples/test_e2e.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import pickle
5 | import pdb
6 | 
7 | import numpy as np
8 | 
9 | import torch
10 | from torch.utils.model_zoo import load_url
11 | from torchvision import transforms
12 | 
13 | from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
14 | from cirtorch.datasets.testdataset import configdataset
15 | from cirtorch.utils.download import download_train, download_test
16 | from cirtorch.utils.evaluate import compute_map_and_print
17 | from cirtorch.utils.general import get_data_root, htime
18 | 
19 | PRETRAINED = {
20 |     'rSfM120k-tl-resnet50-gem-w'  : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',
21 |     'rSfM120k-tl-resnet101-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',
22 |     'rSfM120k-tl-resnet152-gem-w' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet152-gem-w-f39cada.pth',
23 |     'gl18-tl-resnet50-gem-w'      : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',
24 |     'gl18-tl-resnet101-gem-w'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',
25 |     'gl18-tl-resnet152-gem-w'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet152-gem-w-21278d5.pth',
26 | }
27 | 
28 | datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
29 | 
30 | parser = argparse.ArgumentParser(description='PyTorch CNN Image Retrieval Testing End-to-End')
31 | 
32 | # test options
33 | parser.add_argument('--network', '-n', metavar='NETWORK',
34 |                     help="network to be evaluated: " +
35 |                         " | ".join(PRETRAINED.keys()))
36 | parser.add_argument('--datasets', '-d', metavar='DATASETS', default='roxford5k,rparis6k',
37 |                     help="comma separated list of test datasets: " +
38 |                         " | ".join(datasets_names) +
39 |                         " (default: 'roxford5k,rparis6k')")
40 | parser.add_argument('--image-size', '-imsize', default=1024, type=int, metavar='N',
41 |                     help="maximum size of longer image side used for testing (default: 1024)")
42 | parser.add_argument('--multiscale', '-ms', metavar='MULTISCALE', default='[1]',
43 |                     help="use multiscale vectors for testing, " +
44 |                         " examples: '[1]' | '[1, 1/2**(1/2), 1/2]' | '[1, 2**(1/2), 1/2**(1/2)]' (default: '[1]')")
45 | 
46 | # GPU ID
47 | parser.add_argument('--gpu-id', '-g', default='0', metavar='N',
48 |                     help="gpu id used for testing (default: '0')")
49 | 
50 | def main():
51 |     args = parser.parse_args()
52 | 
53 |     # check if there are unknown datasets
54 |     for dataset in args.datasets.split(','):
55 |         if dataset not in datasets_names:
56 |             raise ValueError('Unsupported or unknown dataset: {}!'.format(dataset))
57 | 
58 |     # check if test datasets are downloaded
59 |     # and download them if they are not
60 |     download_train(get_data_root())
61 |     download_test(get_data_root())
62 | 
63 |     # setting up the visible GPU
64 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
65 | 
66 |     # loading network
67 |     # pretrained networks (downloaded automatically)
68 |     print(">> Loading network:\n>>>> '{}'".format(args.network))
69 |     state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
70 |     # state = torch.load(args.network)
71 |     # parsing net params from meta
72 |     # architecture, pooling, mean, std required
73 |     # the rest have default values, in case they don't exist
74 |     net_params = {}
75 |     net_params['architecture'] = state['meta']['architecture']
76 |     net_params['pooling'] = state['meta']['pooling']
77 |     net_params['local_whitening'] = state['meta'].get('local_whitening', False)
78 |     net_params['regional'] = state['meta'].get('regional', False)
79 |     net_params['whitening'] = state['meta'].get('whitening', False)
80 |     net_params['mean'] = state['meta']['mean']
81 |     net_params['std'] = state['meta']['std']
82 |     net_params['pretrained'] = False
83 |     # network initialization
84 |     net = init_network(net_params)
85 |     net.load_state_dict(state['state_dict'])
86 | 
87 |     print(">>>> loaded network: ")
88 |     print(net.meta_repr())
89 | 
90 |     # setting up the multi-scale parameters
91 |     ms = list(eval(args.multiscale))
92 |     print(">>>> Evaluating scales: {}".format(ms))
93 | 
94 |     # moving network to gpu and eval mode
95 |     net.cuda()
96 |     net.eval()
97 | 
98 |     # set up the transform
99 |     normalize = transforms.Normalize(
100 |         mean=net.meta['mean'],
101 |         std=net.meta['std']
102 |     )
103 |     transform = transforms.Compose([
104 |         transforms.ToTensor(),
105 |         normalize
106 |     ])
107 | 
108 |     # evaluate on test datasets
109 |     datasets = args.datasets.split(',')
110 |     for dataset in datasets:
111 |         start = time.time()
112 | 
113 |         print('>> {}: Extracting...'.format(dataset))
114 | 
115 |         # prepare config structure for the test dataset
116 |         cfg = configdataset(dataset, os.path.join(get_data_root(), 'test'))
117 |         images = [cfg['im_fname'](cfg,i) for i in range(cfg['n'])]
118 |         qimages = [cfg['qim_fname'](cfg,i) for i in range(cfg['nq'])]
119 |         try:
120 |             bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]
121 |         except:
122 |             bbxs = None  # for holidaysmanrot and copydays
123 | 
124 |         # extract database and query vectors
125 |         print('>> {}: database images...'.format(dataset))
126 |         vecs = extract_vectors(net, images, args.image_size, transform, ms=ms)
127 |         print('>> {}: query images...'.format(dataset))
128 |         qvecs = extract_vectors(net, qimages, args.image_size, transform, bbxs=bbxs, ms=ms)
129 | 
130 |         print('>> {}: Evaluating...'.format(dataset))
131 | 
132 |         # convert to numpy
133 |         vecs = vecs.numpy()
134 |         qvecs = qvecs.numpy()
135 | 
136 |         # search, rank, and print
137 |         scores = np.dot(vecs.T, qvecs)
138 |         ranks = np.argsort(-scores, axis=0)
139 |         compute_map_and_print(dataset, ranks, cfg['gnd'])
140 | 
141 |         print('>> {}: elapsed time: {}'.format(dataset, htime(time.time()-start)))
142 | 
143 | 
144 | if __name__ == '__main__':
145 |     main()
146 | 
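The ranking step above is a plain cosine-similarity search over L2-normalized descriptor columns; the script itself would be launched as, e.g., `python -m cirtorch.examples.test_e2e -n 'gl18-tl-resnet50-gem-w' -d 'roxford5k'`. A small shape sketch of the search (sizes here are illustrative assumptions, not values from the script):

import numpy as np

D, N, Q = 4, 5, 2                       # descriptor dim, database size, query count
vecs = np.random.randn(D, N)
qvecs = np.random.randn(D, Q)
vecs /= np.linalg.norm(vecs, axis=0)    # columns are unit-length descriptors
qvecs /= np.linalg.norm(qvecs, axis=0)

scores = np.dot(vecs.T, qvecs)          # N x Q cosine similarities
ranks = np.argsort(-scores, axis=0)     # ranks[:, q]: database indices, best match first
print(ranks[:, 0])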
--------------------------------------------------------------------------------
/nts/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.data
3 | from torch.nn import DataParallel
4 | from datetime import datetime
5 | from torch.optim.lr_scheduler import MultiStepLR
6 | from config import BATCH_SIZE, PROPOSAL_NUM, SAVE_FREQ, LR, WD, resume, save_dir
7 | from core import model, dataset
8 | from core.utils import init_log, progress_bar
9 | 
10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
11 | start_epoch = 1
12 | save_dir = os.path.join(save_dir, datetime.now().strftime('%Y%m%d_%H%M%S'))
13 | if os.path.exists(save_dir):
14 |     raise NameError('model dir exists!')
15 | os.makedirs(save_dir)
16 | logging = init_log(save_dir)
17 | _print = logging.info
18 | 
19 | # read dataset
20 | trainset = dataset.CUB(root='./CUB_200_2011', is_train=True, data_len=None)
21 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
22 |                                           shuffle=True, num_workers=8, drop_last=False)
23 | testset = dataset.CUB(root='./CUB_200_2011', is_train=False, data_len=None)
24 | testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
25 |                                          shuffle=False, num_workers=8, drop_last=False)
26 | # define model
27 | net = model.attention_net(topN=PROPOSAL_NUM)
28 | if resume:
29 |     ckpt = torch.load(resume)
30 |     net.load_state_dict(ckpt['net_state_dict'])
31 |     start_epoch = ckpt['epoch'] + 1
32 | criterion = torch.nn.CrossEntropyLoss()
33 | 
34 | # define optimizers
35 | raw_parameters = list(net.pretrained_model.parameters())
36 | part_parameters = list(net.proposal_net.parameters())
37 | concat_parameters = list(net.concat_net.parameters())
38 | partcls_parameters = list(net.partcls_net.parameters())
39 | 
40 | raw_optimizer = torch.optim.SGD(raw_parameters, lr=LR, momentum=0.9, weight_decay=WD)
41 | concat_optimizer = torch.optim.SGD(concat_parameters, lr=LR, momentum=0.9, weight_decay=WD)
42 | part_optimizer = torch.optim.SGD(part_parameters, lr=LR, momentum=0.9, weight_decay=WD)
43 | partcls_optimizer = torch.optim.SGD(partcls_parameters, lr=LR, momentum=0.9, weight_decay=WD)
44 | schedulers = [MultiStepLR(raw_optimizer, milestones=[60, 100], gamma=0.1),
45 |               MultiStepLR(concat_optimizer, milestones=[60, 100], gamma=0.1),
46 |               MultiStepLR(part_optimizer, milestones=[60, 100], gamma=0.1),
47 |               MultiStepLR(partcls_optimizer, milestones=[60, 100], gamma=0.1)]
48 | net = net.cuda()
49 | net = DataParallel(net)
50 | 
51 | for epoch in range(start_epoch, 500):
52 |     for scheduler in schedulers:
53 |         scheduler.step()  # note: PyTorch >= 1.1 expects scheduler.step() after the epoch's optimizer steps
54 | 
55 |     # begin training
56 |     _print('--' * 50)
57 |     net.train()
58 |     for i, data in enumerate(trainloader):
59 |         img, label = data[0].cuda(), data[1].cuda()
60 |         batch_size = img.size(0)
61 |         raw_optimizer.zero_grad()
62 |         part_optimizer.zero_grad()
63 |         concat_optimizer.zero_grad()
64 |         partcls_optimizer.zero_grad()
65 | 
66 |         raw_logits, concat_logits, part_logits, _, top_n_prob = net(img)
67 |         part_loss = model.list_loss(part_logits.view(batch_size * PROPOSAL_NUM, -1),
68 |                                     label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1)).view(batch_size, PROPOSAL_NUM)
69 |         raw_loss = criterion(raw_logits, label)
70 |         concat_loss = criterion(concat_logits, label)
71 |         rank_loss = model.ranking_loss(top_n_prob, part_loss)
72 |         partcls_loss = criterion(part_logits.view(batch_size * PROPOSAL_NUM, -1),
73 |                                  label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1))
74 | 
75 |         total_loss = raw_loss + rank_loss + concat_loss + partcls_loss
76 |         total_loss.backward()
77 |         raw_optimizer.step()
78 |         part_optimizer.step()
79 |         concat_optimizer.step()
80 |         partcls_optimizer.step()
81 |         progress_bar(i, len(trainloader), 'train')
82 | 
83 |     if epoch % SAVE_FREQ == 0:
84 |         train_loss = 0
85 |         train_correct = 0
86 |         total = 0
87 |         net.eval()
88 |         for i, data in enumerate(trainloader):
89 |             with torch.no_grad():
90 |                 img, label = data[0].cuda(), data[1].cuda()
91 |                 batch_size = img.size(0)
92 |                 _, concat_logits, _, _, _ = net(img)
93 |                 # calculate loss
94 |                 concat_loss = criterion(concat_logits, label)
95 |                 # calculate accuracy
96 |                 _, concat_predict = torch.max(concat_logits, 1)
97 |                 total += batch_size
98 |                 train_correct += torch.sum(concat_predict.data == label.data)
99 |                 train_loss += concat_loss.item() * batch_size
100 |                 progress_bar(i, len(trainloader), 'eval train set')
101 | 
102 |         train_acc = float(train_correct) / total
103 |         train_loss = train_loss / total
104 | 
105 |         _print(
106 |             'epoch:{} - train loss: {:.3f} and train acc: {:.3f} total sample: {}'.format(
107 |                 epoch,
108 |                 train_loss,
109 |                 train_acc,
110 |                 total))
111 | 
112 |         # evaluate on test set
113 |         test_loss = 0
114 |         test_correct = 0
115 |         total = 0
116 |         for i, data in enumerate(testloader):
117 |             with torch.no_grad():
118 |                 img, label = data[0].cuda(), data[1].cuda()
119 |                 batch_size = img.size(0)
120 |                 _, concat_logits, _, _, _ = net(img)
121 |                 # calculate loss
122 |                 concat_loss = criterion(concat_logits, label)
123 |                 # calculate accuracy
124 |                 _, concat_predict = torch.max(concat_logits, 1)
125 |                 total += batch_size
126 |                 test_correct += torch.sum(concat_predict.data == label.data)
127 |                 test_loss += concat_loss.item() * batch_size
128 |                 progress_bar(i, len(testloader), 'eval test set')
129 | 
130 |         test_acc = float(test_correct) / total
131 |         test_loss = test_loss / total
132 |         _print(
133 |             'epoch:{} - test loss: {:.3f} and test acc: {:.3f} total sample: {}'.format(
134 |                 epoch,
135 |                 test_loss,
136 |                 test_acc,
137 |                 total))
138 | 
139 |         # save model
140 |         net_state_dict = net.module.state_dict()
141 |         if not os.path.exists(save_dir):
142 |             os.mkdir(save_dir)
143 |         torch.save({
144 |             'epoch': epoch,
145 |             'train_loss': train_loss,
146 |             'train_acc': train_acc,
147 |             'test_loss': test_loss,
148 |             'test_acc': test_acc,
149 |             'net_state_dict': net_state_dict},
150 |             os.path.join(save_dir, '%03d.ckpt' % epoch))
151 | 
152 | print('finished training')
153 | 
--------------------------------------------------------------------------------
/utils/retrieval_index.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # /usr/bin/env python
3 | 
4 | '''
5 | Author: yinhao
6 | Email: yinhao_x@163.com
7 | Wechat: xss_yinhao
8 | Github: http://github.com/yinhaoxs
9 | data: 2019-11-23 18:27
10 | desc:
11 | '''
12 | 
13 | import os
14 | import shutil
15 | import numpy as np
16 | import pandas as pd
17 | 
18 | 
19 | class EvaluteMap():
20 |     def __init__(self, out_similar_dir='', out_similar_file_dir='', all_csv_file='', feature_path='', index_path=''):
21 |         self.out_similar_dir = out_similar_dir
22 |         self.out_similar_file_dir = out_similar_file_dir
23 |         self.all_csv_file = all_csv_file
24 |         self.feature_path = feature_path
25 |         self.index_path = index_path
26 | 
27 | 
28 |     def get_dict(self, query_no, query_id, simi_no, simi_id, num, score):
29 |         new_dict = {
30 |             'index': str(num),
31 |             'id1': str(query_id),
32 |             'id2': str(simi_id),
33 |             'no1': str(query_no),
34 |             'no2': str(simi_no),
35 |             'score': score
36 |         }
37 |         return new_dict
38 | 
39 | 
40 |     def find_similar_img_gyz(self, feature_dict, lsh, num_results):
41 |         for q_path, q_vec in feature_dict.items():
42 |             try:
43 |                 response = lsh.query(q_vec.flatten(), num_results=int(num_results), distance_func="cosine")
44 |                 query_img_path0 = response[0][0][1]
45 |                 query_img_path1 = response[1][0][1]
46 |                 query_img_path2 = response[2][0][1]
47 |                 # score0 = response[0][1]
48 |                 # score0 = np.rint(100 * (1 - score0))
49 |                 print('**********************************************')
50 |                 print('input img: {}'.format(q_path))
51 |                 print('query0 img: {}'.format(query_img_path0))
52 |                 print('query1 img: {}'.format(query_img_path1))
53 |                 print('query2 img: {}'.format(query_img_path2))
54 |             except:
55 |                 continue
56 | 
57 | 
58 |     def find_similar_img(self, feature_dict, lsh, num_results):
59 |         num = 0
60 |         result_list = list()
61 |         for q_path, q_vec in feature_dict.items():
62 |             response = lsh.query(q_vec.flatten(), num_results=int(num_results), distance_func="cosine")
63 |             s_path_list, s_vec_list, s_id_list, s_no_list, score_list = list(), list(), list(), list(), list()
64 |             q_path = q_path[0]
65 |             q_no, q_id = q_path.split("\\")[-2], q_path.split("\\")[-1]
66 |             try:
67 |                 for i in range(int(num_results)):
68 |                     s_path, s_vec = response[i][0][1], response[i][0][0]
69 |                     s_path = s_path[0]
70 |                     s_no, s_id = s_path.split("\\")[-2], s_path.split("\\")[-1]
71 |                     if str(s_no) != str(q_no):
72 |                         score = np.rint(100 * (1 - response[i][1]))
73 |                         s_path_list.append(s_path)
74 |                         s_vec_list.append(s_vec)
75 |                         s_id_list.append(s_id)
76 |                         s_no_list.append(s_no)
77 |                         score_list.append(score)
78 |                     else:
79 |                         continue
80 | 
81 |                 if len(s_path_list) != 0:
82 |                     index = score_list.index(max(score_list))
83 |                     s_path, s_vec, s_id, s_no, score = s_path_list[index], s_vec_list[index], s_id_list[index], \
84 |                                                        s_no_list[index], score_list[index]
85 |                 else:
86 |                     s_path, s_vec, s_id, s_no, score = None, None, None, None, None
87 |             except:
88 |                 s_path, s_vec, s_id, s_no, score = None, None, None, None, None
89 | 
90 |             try:
91 |                 ## copy the query/result pair into a numbered folder
92 |                 num += 1
93 |                 des_path = os.path.join(self.out_similar_dir, str(num))
94 |                 if not os.path.exists(des_path):
95 |                     os.makedirs(des_path)
96 |                 shutil.copy(q_path, des_path)
97 |                 os.rename(os.path.join(des_path, q_id), os.path.join(des_path, "query_" + q_no + "_" + q_id))
98 |                 if s_path is not None:
99 |                     shutil.copy(s_path, des_path)
100 |                     os.rename(os.path.join(des_path, s_id), os.path.join(des_path, s_no + "_" + s_id))
101 | 
102 |                 new_dict = self.get_dict(q_no, q_id, s_no, s_id, num, score)
103 |                 result_list.append(new_dict)
104 |             except:
105 |                 continue
106 | 
107 |         try:
108 |             result_s = pd.DataFrame.from_dict(result_list)
109 |             result_s.to_csv(self.all_csv_file, encoding="gbk", index=False)
110 |         except:
111 |             print("failed to write {}".format(self.all_csv_file))
112 | 
113 | 
114 |     def filter_gap_score(self):
115 |         for value in range(90, 101):
116 |             try:
117 |                 pd_df = pd.read_csv(self.all_csv_file, encoding="gbk", error_bad_lines=False)
118 |                 pd_tmp = pd_df[pd_df["score"] == int(value)]
119 |                 if not os.path.exists(self.out_similar_file_dir):
120 |                     os.makedirs(self.out_similar_file_dir)
121 | 
122 |                 try:
123 |                     results_split_csv = os.path.join(self.out_similar_file_dir + os.sep,
124 |                                                      "filter_{}.csv".format(str(value)))
125 |                     pd_tmp.to_csv(results_split_csv, encoding="gbk", index=False)
126 |                 except:
127 |                     print("failed to write filter_{}.csv".format(str(value)))
128 | 
129 |                 lines = pd_df[pd_df["score"] == int(value)]["index"]
130 |                 num = 0
131 |                 for line in lines:
132 |                     des_path_temp = os.path.join(self.out_similar_file_dir + os.sep, str(value), str(line))
133 |                     if not os.path.exists(des_path_temp):
134 |                         os.makedirs(des_path_temp)
135 |                     pairs_path = os.path.join(self.out_similar_dir + os.sep, str(line))
136 |                     for img_id in os.listdir(pairs_path):
137 |                         img_path = os.path.join(pairs_path + os.sep, img_id)
138 |                         shutil.copy(img_path, des_path_temp)
139 |             except:
140 |                 print("failed to filter score {}".format(value))
141 | 
142 | 
143 |     def retrieval_images(self, feature_dict, lsh, num_results=1):
144 |         # load model
145 |         # with open(self.feature_path, "rb") as f:
146 |         #     feature_dict = pickle.load(f)
147 |         # with open(self.index_path, "rb") as f:
148 |         #     lsh = pickle.load(f)
149 | 
150 |         self.find_similar_img_gyz(feature_dict, lsh, num_results)
151 |         # self.filter_gap_score()
152 | 
153 | 
154 | if __name__ == "__main__":
155 |     pass
156 | 
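A hedged sketch of how the two utility classes appear intended to be wired together; it is not part of the original repository. The checkpoint path, image folder, and LSH parameters below are placeholder assumptions (input_dim must match the descriptor length produced by the checkpoint, e.g. 2048 for a ResNet-GeM backbone):

from utils.retrieval_feature import AntiFraudFeatureDataset
from utils.retrieval_index import EvaluteMap

network = "/path/to/gem-checkpoint.pth"    # placeholder checkpoint
img_dir = "/path/to/database/images"       # placeholder image folder

# build descriptors and an LSH index over them
affd = AntiFraudFeatureDataset(img_dir, network)
feature_dict, lsh = affd.constructfeature(hash_size=10, input_dim=2048, num_hashtables=3)

# print the top-3 neighbours of every database image
EvaluteMap().retrieval_images(feature_dict, lsh, num_results=3)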
-------------------------------------------------------------------------------- /nts/core/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = conv3x3(inplanes, planes, stride) 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn2 = nn.BatchNorm2d(planes) 65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 66 | self.bn3 = nn.BatchNorm2d(planes * 4) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 | def forward(self, x): 72 | residual = x 73 | 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class ResNet(nn.Module): 95 | def __init__(self, block, layers, num_classes=1000): 96 | self.inplanes = 64 97 | super(ResNet, self).__init__() 98 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 99 | bias=False) 100 | self.bn1 = nn.BatchNorm2d(64) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 103 | self.layer1 = self._make_layer(block, 64, layers[0]) 104 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 105 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 106 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 107 | self.avgpool = nn.AvgPool2d(7) 108 | 
self.fc = nn.Linear(512 * block.expansion, num_classes)
109 | 
110 |         for m in self.modules():
111 |             if isinstance(m, nn.Conv2d):
112 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
113 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
114 |             elif isinstance(m, nn.BatchNorm2d):
115 |                 m.weight.data.fill_(1)
116 |                 m.bias.data.zero_()
117 | 
118 |     def _make_layer(self, block, planes, blocks, stride=1):
119 |         downsample = None
120 |         if stride != 1 or self.inplanes != planes * block.expansion:
121 |             downsample = nn.Sequential(
122 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
123 |                           kernel_size=1, stride=stride, bias=False),
124 |                 nn.BatchNorm2d(planes * block.expansion),
125 |             )
126 | 
127 |         layers = []
128 |         layers.append(block(self.inplanes, planes, stride, downsample))
129 |         self.inplanes = planes * block.expansion
130 |         for i in range(1, blocks):
131 |             layers.append(block(self.inplanes, planes))
132 | 
133 |         return nn.Sequential(*layers)
134 | 
135 |     def forward(self, x):
136 |         x = self.conv1(x)
137 |         x = self.bn1(x)
138 |         x = self.relu(x)
139 |         x = self.maxpool(x)
140 | 
141 |         x = self.layer1(x)
142 |         x = self.layer2(x)
143 |         x = self.layer3(x)
144 |         x = self.layer4(x)
145 |         feature1 = x
146 |         x = self.avgpool(x)
147 |         x = x.view(x.size(0), -1)
148 |         x = nn.functional.dropout(x, p=0.5, training=self.training)  # a fresh nn.Dropout() built here would stay active in eval mode
149 |         feature2 = x
150 |         x = self.fc(x)
151 | 
152 |         return x, feature1, feature2
153 | 
154 | 
155 | def resnet18(pretrained=False, **kwargs):
156 |     """Constructs a ResNet-18 model.
157 | 
158 |     Args:
159 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
160 |     """
161 |     model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
162 |     if pretrained:
163 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
164 |     return model
165 | 
166 | 
167 | def resnet34(pretrained=False, **kwargs):
168 |     """Constructs a ResNet-34 model.
169 | 
170 |     Args:
171 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
172 |     """
173 |     model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
174 |     if pretrained:
175 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
176 |     return model
177 | 
178 | 
179 | def resnet50(pretrained=False, **kwargs):
180 |     """Constructs a ResNet-50 model.
181 | 
182 |     Args:
183 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
184 |     """
185 |     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
186 |     if pretrained:
187 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
188 |     return model
189 | 
190 | 
191 | def resnet101(pretrained=False, **kwargs):
192 |     """Constructs a ResNet-101 model.
193 | 
194 |     Args:
195 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
196 |     """
197 |     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
198 |     if pretrained:
199 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
200 |     return model
201 | 
202 | 
203 | def resnet152(pretrained=False, **kwargs):
204 |     """Constructs a ResNet-152 model.
205 | 206 | Args: 207 | pretrained (bool): If True, returns a model pre-trained on ImageNet 208 | """ 209 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 210 | if pretrained: 211 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 212 | return model 213 | -------------------------------------------------------------------------------- /cirtorch/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def download_test(data_dir): 4 | """ 5 | DOWNLOAD_TEST Checks, and, if required, downloads the necessary datasets for the testing. 6 | 7 | download_test(DATA_ROOT) checks if the data necessary for running the example script exist. 8 | If not it downloads it in the folder structure: 9 | DATA_ROOT/test/oxford5k/ : folder with Oxford images and ground truth file 10 | DATA_ROOT/test/paris6k/ : folder with Paris images and ground truth file 11 | DATA_ROOT/test/roxford5k/ : folder with Oxford images and revisited ground truth file 12 | DATA_ROOT/test/rparis6k/ : folder with Paris images and revisited ground truth file 13 | """ 14 | 15 | # Create data folder if it does not exist 16 | if not os.path.isdir(data_dir): 17 | os.mkdir(data_dir) 18 | 19 | # Create datasets folder if it does not exist 20 | datasets_dir = os.path.join(data_dir, 'test') 21 | if not os.path.isdir(datasets_dir): 22 | os.mkdir(datasets_dir) 23 | 24 | # Download datasets folders test/DATASETNAME/ 25 | datasets = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k'] 26 | for di in range(len(datasets)): 27 | dataset = datasets[di] 28 | 29 | if dataset == 'oxford5k': 30 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings' 31 | dl_files = ['oxbuild_images.tgz'] 32 | elif dataset == 'paris6k': 33 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings' 34 | dl_files = ['paris_1.tgz', 'paris_2.tgz'] 35 | elif dataset == 'roxford5k': 36 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings' 37 | dl_files = ['oxbuild_images.tgz'] 38 | elif dataset == 'rparis6k': 39 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings' 40 | dl_files = ['paris_1.tgz', 'paris_2.tgz'] 41 | else: 42 | raise ValueError('Unknown dataset: {}!'.format(dataset)) 43 | 44 | dst_dir = os.path.join(datasets_dir, dataset, 'jpg') 45 | if not os.path.isdir(dst_dir): 46 | 47 | # for oxford and paris download images 48 | if dataset == 'oxford5k' or dataset == 'paris6k': 49 | print('>> Dataset {} directory does not exist. 
Creating: {}'.format(dataset, dst_dir)) 50 | os.makedirs(dst_dir) 51 | for dli in range(len(dl_files)): 52 | dl_file = dl_files[dli] 53 | src_file = os.path.join(src_dir, dl_file) 54 | dst_file = os.path.join(dst_dir, dl_file) 55 | print('>> Downloading dataset {} archive {}...'.format(dataset, dl_file)) 56 | os.system('wget {} -O {}'.format(src_file, dst_file)) 57 | print('>> Extracting dataset {} archive {}...'.format(dataset, dl_file)) 58 | # create tmp folder 59 | dst_dir_tmp = os.path.join(dst_dir, 'tmp') 60 | os.system('mkdir {}'.format(dst_dir_tmp)) 61 | # extract in tmp folder 62 | os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir_tmp)) 63 | # remove all (possible) subfolders by moving only files in dst_dir 64 | os.system('find {} -type f -exec mv -i {{}} {} \\;'.format(dst_dir_tmp, dst_dir)) 65 | # remove tmp folder 66 | os.system('rm -rf {}'.format(dst_dir_tmp)) 67 | print('>> Extracted, deleting dataset {} archive {}...'.format(dataset, dl_file)) 68 | os.system('rm {}'.format(dst_file)) 69 | 70 | # for roxford and rparis just make sym links 71 | elif dataset == 'roxford5k' or dataset == 'rparis6k': 72 | print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir)) 73 | dataset_old = dataset[1:] 74 | dst_dir_old = os.path.join(datasets_dir, dataset_old, 'jpg') 75 | os.mkdir(os.path.join(datasets_dir, dataset)) 76 | os.system('ln -s {} {}'.format(dst_dir_old, dst_dir)) 77 | print('>> Created symbolic link from {} jpg to {} jpg'.format(dataset_old, dataset)) 78 | 79 | 80 | gnd_src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'test', dataset) 81 | gnd_dst_dir = os.path.join(datasets_dir, dataset) 82 | gnd_dl_file = 'gnd_{}.pkl'.format(dataset) 83 | gnd_src_file = os.path.join(gnd_src_dir, gnd_dl_file) 84 | gnd_dst_file = os.path.join(gnd_dst_dir, gnd_dl_file) 85 | if not os.path.exists(gnd_dst_file): 86 | print('>> Downloading dataset {} ground truth file...'.format(dataset)) 87 | os.system('wget {} -O {}'.format(gnd_src_file, gnd_dst_file)) 88 | 89 | 90 | def download_train(data_dir): 91 | """ 92 | DOWNLOAD_TRAIN Checks, and, if required, downloads the necessary datasets for the training. 93 | 94 | download_train(DATA_ROOT) checks if the data necessary for running the example script exist. 95 | If not it downloads it in the folder structure: 96 | DATA_ROOT/train/retrieval-SfM-120k/ : folder with rsfm120k images and db files 97 | DATA_ROOT/train/retrieval-SfM-30k/ : folder with rsfm30k images and db files 98 | """ 99 | 100 | # Create data folder if it does not exist 101 | if not os.path.isdir(data_dir): 102 | os.mkdir(data_dir) 103 | 104 | # Create datasets folder if it does not exist 105 | datasets_dir = os.path.join(data_dir, 'train') 106 | if not os.path.isdir(datasets_dir): 107 | os.mkdir(datasets_dir) 108 | 109 | # Download folder train/retrieval-SfM-120k/ 110 | src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'ims') 111 | dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims') 112 | dl_file = 'ims.tar.gz' 113 | if not os.path.isdir(dst_dir): 114 | src_file = os.path.join(src_dir, dl_file) 115 | dst_file = os.path.join(dst_dir, dl_file) 116 | print('>> Image directory does not exist. 
Creating: {}'.format(dst_dir)) 117 | os.makedirs(dst_dir) 118 | print('>> Downloading ims.tar.gz...') 119 | os.system('wget {} -O {}'.format(src_file, dst_file)) 120 | print('>> Extracting {}...'.format(dst_file)) 121 | os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir)) 122 | print('>> Extracted, deleting {}...'.format(dst_file)) 123 | os.system('rm {}'.format(dst_file)) 124 | 125 | # Create symlink for train/retrieval-SfM-30k/ 126 | dst_dir_old = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims') 127 | dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-30k', 'ims') 128 | if not os.path.isdir(dst_dir): 129 | os.makedirs(os.path.join(datasets_dir, 'retrieval-SfM-30k')) 130 | os.system('ln -s {} {}'.format(dst_dir_old, dst_dir)) 131 | print('>> Created symbolic link from retrieval-SfM-120k/ims to retrieval-SfM-30k/ims') 132 | 133 | # Download db files 134 | src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'dbs') 135 | datasets = ['retrieval-SfM-120k', 'retrieval-SfM-30k'] 136 | for dataset in datasets: 137 | dst_dir = os.path.join(datasets_dir, dataset) 138 | if dataset == 'retrieval-SfM-120k': 139 | dl_files = ['{}.pkl'.format(dataset), '{}-whiten.pkl'.format(dataset)] 140 | elif dataset == 'retrieval-SfM-30k': 141 | dl_files = ['{}-whiten.pkl'.format(dataset)] 142 | 143 | if not os.path.isdir(dst_dir): 144 | print('>> Dataset directory does not exist. Creating: {}'.format(dst_dir)) 145 | os.mkdir(dst_dir) 146 | 147 | for i in range(len(dl_files)): 148 | src_file = os.path.join(src_dir, dl_files[i]) 149 | dst_file = os.path.join(dst_dir, dl_files[i]) 150 | if not os.path.isfile(dst_file): 151 | print('>> DB file {} does not exist. Downloading...'.format(dl_files[i])) 152 | os.system('wget {} -O {}'.format(src_file, dst_file)) 153 | -------------------------------------------------------------------------------- /cirtorch/utils/download_win.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def download_test(data_dir): 4 | """ 5 | DOWNLOAD_TEST Checks, and, if required, downloads the necessary datasets for the testing. 6 | 7 | download_test(DATA_ROOT) checks if the data necessary for running the example script exist. 
8 |     If not it downloads it in the folder structure:
9 |         DATA_ROOT/test/oxford5k/   : folder with Oxford images and ground truth file
10 |         DATA_ROOT/test/paris6k/    : folder with Paris images and ground truth file
11 |         DATA_ROOT/test/roxford5k/  : folder with Oxford images and revisited ground truth file
12 |         DATA_ROOT/test/rparis6k/   : folder with Paris images and revisited ground truth file
13 |     """
14 | 
15 |     # Create data folder if it does not exist
16 |     if not os.path.isdir(data_dir):
17 |         os.mkdir(data_dir)
18 | 
19 |     # Create datasets folder if it does not exist
20 |     datasets_dir = os.path.join(data_dir, 'test')
21 |     if not os.path.isdir(datasets_dir):
22 |         os.mkdir(datasets_dir)
23 | 
24 |     # Download datasets folders test/DATASETNAME/
25 |     datasets = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
26 |     for di in range(len(datasets)):
27 |         dataset = datasets[di]
28 | 
29 |         if dataset == 'oxford5k':
30 |             src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings'
31 |             dl_files = ['oxbuild_images.tgz']
32 |         elif dataset == 'paris6k':
33 |             src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings'
34 |             dl_files = ['paris_1.tgz', 'paris_2.tgz']
35 |         elif dataset == 'roxford5k':
36 |             src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings'
37 |             dl_files = ['oxbuild_images.tgz']
38 |         elif dataset == 'rparis6k':
39 |             src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings'
40 |             dl_files = ['paris_1.tgz', 'paris_2.tgz']
41 |         else:
42 |             raise ValueError('Unknown dataset: {}!'.format(dataset))
43 | 
44 |         dst_dir = os.path.join(datasets_dir, dataset, 'jpg')
45 |         if not os.path.isdir(dst_dir):
46 | 
47 |             # for oxford and paris download images
48 |             if dataset == 'oxford5k' or dataset == 'paris6k':
49 |                 print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir))
50 |                 os.makedirs(dst_dir)
51 |                 for dli in range(len(dl_files)):
52 |                     dl_file = dl_files[dli]
53 |                     src_file = os.path.join(src_dir, dl_file)
54 |                     dst_file = os.path.join(dst_dir, dl_file)
55 |                     print('>> Downloading dataset {} archive {}...'.format(dataset, dl_file))
56 |                     os.system('wget {} -O {}'.format(src_file, dst_file))  # assumes Unix-style tools (wget/tar/find) are on PATH
57 |                     print('>> Extracting dataset {} archive {}...'.format(dataset, dl_file))
58 |                     # create tmp folder
59 |                     dst_dir_tmp = os.path.join(dst_dir, 'tmp')
60 |                     os.system('mkdir {}'.format(dst_dir_tmp))
61 |                     # extract in tmp folder
62 |                     os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir_tmp))
63 |                     # remove all (possible) subfolders by moving only files in dst_dir
64 |                     os.system('find {} -type f -exec mv -i {{}} {} \\;'.format(dst_dir_tmp, dst_dir))
65 |                     # remove tmp folder (/s /q so leftover subfolders do not abort the delete)
66 |                     os.system('rd /s /q {}'.format(dst_dir_tmp))
67 |                     print('>> Extracted, deleting dataset {} archive {}...'.format(dataset, dl_file))
68 |                     os.system('del {}'.format(dst_file))
69 | 
70 |             # for roxford and rparis just make sym links
71 |             elif dataset == 'roxford5k' or dataset == 'rparis6k':
72 |                 print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir))
73 |                 dataset_old = dataset[1:]
74 |                 dst_dir_old = os.path.join(datasets_dir, dataset_old, 'jpg')
75 |                 os.mkdir(os.path.join(datasets_dir, dataset))
76 |                 os.system('cmd /c mklink /d {} {}'.format(dst_dir, dst_dir_old))  # mklink takes the link first, then the target
77 |                 print('>> Created symbolic link from {} jpg to {} jpg'.format(dataset_old, dataset))
78 | 
79 | 
80 |         gnd_src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'test', dataset)
81 |         gnd_dst_dir = os.path.join(datasets_dir, dataset)
82 |         gnd_dl_file = 'gnd_{}.pkl'.format(dataset)
83 |         gnd_src_file = os.path.join(gnd_src_dir, gnd_dl_file)
84 |         gnd_dst_file = os.path.join(gnd_dst_dir, gnd_dl_file)
85 |         if not os.path.exists(gnd_dst_file):
86 |             print('>> Downloading dataset {} ground truth file...'.format(dataset))
87 |             os.system('wget {} -O {}'.format(gnd_src_file, gnd_dst_file))
88 | 
89 | 
90 | def download_train(data_dir):
91 |     """
92 |     DOWNLOAD_TRAIN Checks, and, if required, downloads the necessary datasets for the training.
93 | 
94 |     download_train(DATA_ROOT) checks if the data necessary for running the example script exist.
95 |     If not it downloads it in the folder structure:
96 |         DATA_ROOT/train/retrieval-SfM-120k/ : folder with rsfm120k images and db files
97 |         DATA_ROOT/train/retrieval-SfM-30k/  : folder with rsfm30k images and db files
98 |     """
99 | 
100 |     # Create data folder if it does not exist
101 |     if not os.path.isdir(data_dir):
102 |         os.mkdir(data_dir)
103 |     print(data_dir)
104 |     # Create datasets folder if it does not exist
105 |     datasets_dir = os.path.join(data_dir, 'train')
106 |     if not os.path.isdir(datasets_dir):
107 |         os.mkdir(datasets_dir)
108 | 
109 |     # Download folder train/retrieval-SfM-120k/
110 |     src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'ims')
111 |     dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
112 |     dl_file = 'ims.tar.gz'
113 |     if not os.path.isdir(dst_dir):
114 |         src_file = os.path.join(src_dir, dl_file)
115 |         dst_file = os.path.join(dst_dir, dl_file)
116 |         print('>> Image directory does not exist. Creating: {}'.format(dst_dir))
117 |         os.makedirs(dst_dir)
118 |         print('>> Downloading ims.tar.gz...')
119 |         # os.system('wget {} -O {}'.format(src_file, dst_file))  # download disabled here; place ims.tar.gz in dst_dir manually
120 |         print('>> Extracting {}...'.format(dst_file))
121 |         os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir))
122 |         print('>> Extracted, deleting {}...'.format(dst_file))
123 |         os.system('del {}'.format(dst_file))
124 | 
125 |     # Create symlink for train/retrieval-SfM-30k/
126 |     dst_dir_old = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
127 |     dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-30k', 'ims')
128 |     if not os.path.isdir(dst_dir):
129 |         os.makedirs(os.path.join(datasets_dir, 'retrieval-SfM-30k'))  # create only the parent; mklink creates the 'ims' link itself
130 |         os.system('cmd /c mklink /d {} {}'.format(dst_dir, dst_dir_old))
131 |         print('>> Created symbolic link from retrieval-SfM-120k/ims to retrieval-SfM-30k/ims')
132 | 
133 |     # Download db files
134 |     src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'dbs')
135 |     datasets = ['retrieval-SfM-120k', 'retrieval-SfM-30k']
136 |     for dataset in datasets:
137 |         dst_dir = os.path.join(datasets_dir, dataset)
138 |         if dataset == 'retrieval-SfM-120k':
139 |             dl_files = ['{}.pkl'.format(dataset), '{}-whiten.pkl'.format(dataset)]
140 |         elif dataset == 'retrieval-SfM-30k':
141 |             dl_files = ['{}-whiten.pkl'.format(dataset)]
142 | 
143 |         if not os.path.isdir(dst_dir):
144 |             print('>> Dataset directory does not exist. Creating: {}'.format(dst_dir))
145 |             os.mkdir(dst_dir)
146 | 
147 |         for i in range(len(dl_files)):
148 |             src_file = os.path.join(src_dir, dl_files[i])
149 |             dst_file = os.path.join(dst_dir, dl_files[i])
150 |             if not os.path.isfile(dst_file):
151 |                 print('>> DB file {} does not exist. Downloading...'.format(dl_files[i]))
152 |                 os.system('wget {} -O {}'.format(src_file, dst_file))
153 | 
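Rather than shelling out to platform-specific mklink/ln commands, a hedged, more portable sketch using the standard library (assumption: Python >= 3.3; on Windows, directory symlinks require admin rights or Developer Mode):

import os

def link_dir(target, link):
    # os.symlink takes (target, link) on every platform and handles the
    # Windows directory flag via target_is_directory.
    if not os.path.isdir(link):
        os.symlink(target, link, target_is_directory=True)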
--------------------------------------------------------------------------------
/interface.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # /usr/bin/env python
3 | 
4 | '''
5 | Author: yinhao
6 | Email: yinhao_x@163.com
7 | Wechat: xss_yinhao
8 | Github: http://github.com/yinhaoxs
9 | 
10 | data: 2019-11-23 21:51
11 | desc:
12 | '''
13 | import torch
14 | from torch.utils.model_zoo import load_url
15 | from torchvision import transforms
16 | from cirtorch.datasets.testdataset import configdataset
17 | from cirtorch.utils.download import download_train, download_test
18 | from cirtorch.utils.evaluate import compute_map_and_print
19 | from cirtorch.utils.general import get_data_root, htime
20 | from cirtorch.networks.imageretrievalnet_cpu import init_network, extract_vectors
21 | from cirtorch.datasets.datahelpers import imresize
22 | 
23 | from PIL import Image
24 | import numpy as np
25 | import pandas as pd
26 | from flask import Flask, request
27 | import json, io, sys, time, traceback, argparse, logging, subprocess, pickle, os, yaml, shutil
28 | import cv2
29 | import pdb
30 | from werkzeug.utils import cached_property
31 | from apscheduler.schedulers.background import BackgroundScheduler
32 | from multiprocessing import Pool
33 | 
34 | app = Flask(__name__)
35 | 
36 | @app.route("/")
37 | def index():
38 |     return ""
39 | 
40 | @app.route("/images/*", methods=['GET', 'POST'])  # note: this static rule matches the literal path "/images/*"
41 | def accInsurance():
42 |     """
43 |     flask request process handle
44 |     :return:
45 |     """
46 |     try:
47 |         if request.method == 'GET':
48 |             return json.dumps({'err': 1, 'msg': 'POST only'})
49 |         else:
50 |             app.logger.debug("print headers------")
51 |             headers = request.headers
52 |             headers_info = ""
53 |             for k, v in headers.items():
54 |                 headers_info += "{}: {}\n".format(k, v)
55 |             app.logger.debug(headers_info)
56 | 
57 |             app.logger.debug("print forms------")
58 |             forms_info = ""
59 |             for k, v in request.form.items():
60 |                 forms_info += "{}: {}\n".format(k, v)
61 |             app.logger.debug(forms_info)
62 | 
63 |             if 'query' not in request.files:
64 |                 return json.dumps({'err': 2, 'msg': 'query image is empty'})
65 | 
66 |             if 'sig' not in request.form:
67 |                 return json.dumps({'err': 3, 'msg': 'sig is empty'})
68 | 
69 |             if 'q_no' not in request.form:
70 |                 return json.dumps({'err': 4, 'msg': 'no is empty'})
71 | 
72 |             if 'q_did' not in request.form:
73 |                 return json.dumps({'err': 5, 'msg': 'did is empty'})
74 | 
75 |             if 'q_id' not in request.form:
76 |                 return json.dumps({'err': 6, 'msg': 'id is empty'})
77 | 
78 |             if 'type' not in request.form:
79 |                 return json.dumps({'err': 7, 'msg': 'type is empty'})
80 | 
81 |             img_name = request.files['query'].filename
82 |             img_bytes = request.files['query'].read()
83 |             img = request.files['query']; img.seek(0)  # rewind: read() above consumed the stream
84 |             sig = request.form['sig']
85 |             q_no = request.form['q_no']
86 |             q_did = request.form['q_did']
87 |             q_id = request.form['q_id']
88 |             type = request.form['type']
89 | 
90 |             if str(type) not in types:
91 |                 return json.dumps({'err': 8, 'msg': 'type does not exist'})
92 | 
93 |             if img_bytes is None:
94 |                 return json.dumps({'err': 10, 'msg': 'img is none'})
95 | 
96 |             results = imageRetrieval().retrieval_online_v0(img, q_no, q_did, q_id, type)
97 | 
98 |             data = dict()
99 |             data['query'] = img_name
100 |             data['sig'] = sig
101 |             data['type'] = type
102 |             data['q_no'] = q_no
103 |             data['q_did'] = q_did
104 |             data['q_id'] = q_id
105 |             data['results'] = results
106 | 
107 |             return json.dumps({'err': 0, 'msg': 'success', 'data': data})
108 | 
109 |     except:
110 |         app.logger.exception(sys.exc_info())
111 |         return json.dumps({'err': 9, 'msg': 'unknown error'})
112 | 
113 | 
114 | class imageRetrieval():
115 |     def __init__(self):
116 |         pass
117 | 
118 |     def cosine_dist(self, x, y):
119 |         return 100 * float(np.dot(x, y)) / (np.dot(x, x) * np.dot(y, y)) ** 0.5
120 | 
121 |     def inference(self, img):
122 |         try:
123 |             input = Image.open(img).convert("RGB")
124 |             input = imresize(input, 224)
125 |             input = transforms(input).unsqueeze(0)  # transforms here is the Compose returned by initModel().init()
126 |             with torch.no_grad():
127 |                 vect = net(input)
128 |             return vect
129 |         except:
130 |             print('cannot identify image')
131 | 
132 |     def retrieval_online_v0(self, img, q_no, q_did, q_id, type):
133 |         # load model
134 |         query_vect = self.inference(img)
135 |         query_vect = list(query_vect.detach().numpy().T[0])
136 | 
137 |         lsh = lsh_dict[str(type)]
138 |         response = lsh.query(query_vect, num_results=1, distance_func="cosine")
139 | 
140 |         try:
141 |             similar_path = response[0][0][1]
142 |             score = np.rint(self.cosine_dist(list(query_vect), list(response[0][0][0])))
143 |             rank_list = similar_path.split("/")
144 |             s_id, s_did, s_no = rank_list[-1].split("_")[-1].split(".")[0], rank_list[-1].split("_")[0], rank_list[-2]
145 |             results = [{"s_no": s_no, "r_did": s_did, "s_id": s_id, "score": score}]
146 |         except:
147 |             results = []
148 | 
149 |         img_path = "/{}/{}_{}".format(q_no, q_did, q_id)
150 |         lsh.index(query_vect, extra_data=img_path)
151 |         lsh_dict[str(type)] = lsh
152 | 
153 |         return results
154 | 
155 | 
156 | 
157 | class initModel():
158 |     def __init__(self):
159 |         pass
160 | 
161 |     def init_model(self, network, model_dir, types):
162 |         print(">> Loading network:\n>>>> '{}'".format(network))
163 |         # state = load_url(PRETRAINED[args.network], model_dir=os.path.join(get_data_root(), 'networks'))
164 |         state = torch.load(network)
165 |         # parsing net params from meta
166 |         # architecture, pooling, mean, std required
167 |         # the rest have default values, in case they don't exist
168 |         net_params = {}
169 |         net_params['architecture'] = state['meta']['architecture']
170 |         net_params['pooling'] = state['meta']['pooling']
171 |         net_params['local_whitening'] = state['meta'].get('local_whitening', False)
172 |         net_params['regional'] = state['meta'].get('regional', False)
173 |         net_params['whitening'] = state['meta'].get('whitening', False)
174 |         net_params['mean'] = state['meta']['mean']
175 |         net_params['std'] = state['meta']['std']
176 |         net_params['pretrained'] = False
177 |         # network initialization
178 |         net = init_network(net_params)
179 |         net.load_state_dict(state['state_dict'])
180 |         print(">>>> loaded network: ")
181 |         print(net.meta_repr())
182 |         # moving network to eval mode (CPU-only deployment)
183 |         # net.cuda()
184 |         net.eval()
185 | 
186 |         # set up the transform
187 |         normalize = transforms.Normalize(
188 |             mean=net.meta['mean'],
189 |             std=net.meta['std']
190 |         )
191 |         transform = transforms.Compose([
192 |             transforms.ToTensor(),
193 |             normalize
194 |         ])
195 | 
196 |         lsh_dict = dict()
197 |         for type in types:
198 |             with open(os.path.join(model_dir, "dataset_index_{}.pkl".format(str(type))), "rb") as f:
199 |                 lsh = pickle.load(f)
200 | 
201 |             lsh_dict[str(type)] = lsh
202 | 
203 |         return net, lsh_dict, transform  # return the composed transform, not the torchvision.transforms module
204 | 
205 |     def init(self):
206 |         with open('config.yaml', 'r') as f:
207 |             conf = yaml.safe_load(f)
208 | 
209 |         app.logger.info(conf)
210 |         host = conf['websites']['host']
211 |         port = conf['websites']['port']
212 |         network = conf['model']['network']
213 |         model_dir = conf['model']['model_dir']
214 |         types = conf['model']['type']
215 | 
216 |         net, lsh_dict, transforms = self.init_model(network, model_dir, types)
217 | 
218 |         return host, port, net, lsh_dict, transforms, model_dir, types
219 | 
220 | 
221 | def job():
222 |     for type in types:
223 |         with open(os.path.join(model_dir, "dataset_index_{}_v0.pkl".format(str(type))), "wb") as f:
224 |             pickle.dump(lsh_dict[str(type)], f)
225 | 
226 | 
227 | if __name__ == "__main__":
228 |     """
229 |     start app from ssh
230 |     """
231 |     scheduler = BackgroundScheduler()
232 |     host, port, net, lsh_dict, transforms, model_dir, types = initModel().init()
233 |     scheduler.add_job(job, 'interval', seconds=30)
234 |     scheduler.start()
235 |     print("start server {}:{}".format(host, port))
236 |     # app.run blocks, so the scheduler must be started before it
237 |     app.run(host=host, port=port, debug=True)
238 | 
239 | else:
240 |     """
241 |     start app from gunicorn
242 |     """
243 |     scheduler = BackgroundScheduler()
244 |     gunicorn_logger = logging.getLogger("gunicorn.error")
245 |     app.logger.handlers = gunicorn_logger.handlers
246 |     app.logger.setLevel(gunicorn_logger.level)
247 | 
248 |     host, port, net, lsh_dict, transforms, model_dir, types = initModel().init()
249 |     app.logger.info("started from gunicorn...")
250 | 
251 |     scheduler.add_job(job, 'interval', seconds=30)
252 |     scheduler.start()
253 | 
254 | 
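To close the loop, a hedged client sketch for the service above (not part of the original repository). The port and type values mirror config.yaml; the query file name is a placeholder, and the URL is the literal route registered by the app:

import requests

url = "http://127.0.0.1:15788/images/*"   # port from config.yaml; literal route path
form = {"sig": "test-sig", "q_no": "001", "q_did": "d01", "q_id": "img01.jpg", "type": "SA"}
with open("query.jpg", "rb") as f:        # placeholder query image
    resp = requests.post(url, data=form, files={"query": f})
print(resp.json())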