├── .gitignore ├── configs └── retrieval.yaml ├── modules ├── __init__.py ├── datasets │ ├── __init__.py │ ├── data_helpers.py │ └── generic_dataset.py ├── layers │ ├── __init__.py │ ├── functional.py │ ├── loss.py │ ├── normalization.py │ └── pooling.py ├── lshash │ ├── __init__.py │ ├── lshash.py │ └── storage.py ├── model_const.py ├── networks │ ├── __init__.py │ └── retrieval_net.py └── solver │ ├── __init__.py │ ├── feature_extractor.py │ ├── image_retriever.py │ └── model_initializer.py ├── readme.md ├── retrieval.py ├── retrieval_demo.py ├── scripts ├── download_imgs_mp.py ├── img_downloader.py ├── imgs.txt ├── lshash_indexer.py ├── remove_file.py └── test.py └── utils ├── __init__.py ├── common_util.py ├── config_util.py ├── download_util.py ├── evalute_util.py └── image_processor.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | .idea 3 | data/images/* 4 | data/output/* 5 | data/query_images/* 6 | workshop/* 7 | models/* -------------------------------------------------------------------------------- /configs/retrieval.yaml: -------------------------------------------------------------------------------- 1 | image_gallery: 'data/images' 2 | query_image: "data/query_images" 3 | query_number: 100 4 | checkpoint: 'models/image_retrieval_best.pth' 5 | out_similar_dir: 'data/output/' 6 | out_similar_file_dir: 'workshop/output/similar_file' 7 | all_csv_file: 'workshop/output/aaa.csv' 8 | 9 | 10 | lsh_config: 11 | hash_size: 0 12 | input_dim: 2048 13 | num_hash_tables: 1 14 | lsh_paths: 15 | hash_size_zero: 16 | - "/data1/changqing/ZyImage_Data/image_gallery/20211117/lsh_hash-size-00_input-idm-2048.pkl" 17 | - "/data1/changqing/ZyImage_Data/image_gallery/20211124/lsh_hash-size-00_input-idm-2048.pkl" 18 | - "/data1/changqing/ZyImage_Data/image_gallery/20211125/lsh_hash-size-00_input-idm-2048.pkl" 19 | - 
"/data1/changqing/ZyImage_Data/image_gallery/20211126/lsh_hash-size-00_input-idm-2048.pkl" 20 | - "/data1/changqing/ZyImage_Data/image_gallery/20211127/lsh_hash-size-00_input-idm-2048.pkl" 21 | - "/data1/changqing/ZyImage_Data/image_gallery/20211128/lsh_hash-size-00_input-idm-2048.pkl" 22 | - "/data1/changqing/ZyImage_Data/image_gallery/20211129/lsh_hash-size-00_input-idm-2048.pkl" 23 | - "/data1/changqing/ZyImage_Data/image_gallery/20211130/lsh_hash-size-00_input-idm-2048.pkl" 24 | - "/data1/changqing/ZyImage_Data/image_gallery/20211201/lsh_hash-size-00_input-idm-2048.pkl" 25 | - "/data1/changqing/ZyImage_Data/image_gallery/20211202/lsh_hash-size-00_input-idm-2048.pkl" 26 | - "/data1/changqing/ZyImage_Data/image_gallery/20211203/lsh_hash-size-00_input-idm-2048.pkl" 27 | - "/data1/changqing/ZyImage_Data/image_gallery/20211204/lsh_hash-size-00_input-idm-2048.pkl" 28 | - "/data1/changqing/ZyImage_Data/image_gallery/20211205/lsh_hash-size-00_input-idm-2048.pkl" 29 | - "/data1/changqing/ZyImage_Data/image_gallery/20211206/lsh_hash-size-00_input-idm-2048.pkl" 30 | - "/data1/changqing/ZyImage_Data/image_gallery/20211207/lsh_hash-size-00_input-idm-2048.pkl" 31 | - "/data1/changqing/ZyImage_Data/image_gallery/20211208/lsh_hash-size-00_input-idm-2048.pkl" 32 | - "/data1/changqing/ZyImage_Data/image_gallery/20211209/lsh_hash-size-00_input-idm-2048.pkl" 33 | - "/data1/changqing/ZyImage_Data/image_gallery/20211210/lsh_hash-size-00_input-idm-2048.pkl" 34 | - "/data1/changqing/ZyImage_Data/image_gallery/20211211/lsh_hash-size-00_input-idm-2048.pkl" 35 | - "/data1/changqing/ZyImage_Data/image_gallery/20211212/lsh_hash-size-00_input-idm-2048.pkl" 36 | - "/data1/changqing/ZyImage_Data/image_gallery/20211213/lsh_hash-size-00_input-idm-2048.pkl" 37 | hash_size_eight: 38 | - "/data1/changqing/ZyImage_Data/image_gallery/20211117/lsh_indexer-size-08_input-idm-2048.pkl" 39 | - "/data1/changqing/ZyImage_Data/image_gallery/20211124/lsh_indexer-size-08_input-idm-2048.pkl" 40 | - 
"/data1/changqing/ZyImage_Data/image_gallery/20211125/lsh_indexer-size-08_input-idm-2048.pkl" 41 | - "/data1/changqing/ZyImage_Data/image_gallery/20211126/lsh_indexer-size-08_input-idm-2048.pkl" 42 | - "/data1/changqing/ZyImage_Data/image_gallery/20211127/lsh_indexer-size-08_input-idm-2048.pkl" 43 | - "/data1/changqing/ZyImage_Data/image_gallery/20211128/lsh_indexer-size-08_input-idm-2048.pkl" 44 | - "/data1/changqing/ZyImage_Data/image_gallery/20211129/lsh_indexer-size-08_input-idm-2048.pkl" 45 | - "/data1/changqing/ZyImage_Data/image_gallery/20211130/lsh_indexer-size-08_input-idm-2048.pkl" 46 | - "/data1/changqing/ZyImage_Data/image_gallery/20211201/lsh_indexer-size-08_input-idm-2048.pkl" 47 | - "/data1/changqing/ZyImage_Data/image_gallery/20211202/lsh_indexer-size-08_input-idm-2048.pkl" 48 | - "/data1/changqing/ZyImage_Data/image_gallery/20211203/lsh_indexer-size-08_input-idm-2048.pkl" 49 | - "/data1/changqing/ZyImage_Data/image_gallery/20211204/lsh_indexer-size-08_input-idm-2048.pkl" 50 | - "/data1/changqing/ZyImage_Data/image_gallery/20211205/lsh_indexer-size-08_input-idm-2048.pkl" 51 | - "/data1/changqing/ZyImage_Data/image_gallery/20211206/lsh_indexer-size-08_input-idm-2048.pkl" 52 | - "/data1/changqing/ZyImage_Data/image_gallery/20211207/lsh_indexer-size-08_input-idm-2048.pkl" 53 | - "/data1/changqing/ZyImage_Data/image_gallery/20211208/lsh_indexer-size-08_input-idm-2048.pkl" 54 | - "/data1/changqing/ZyImage_Data/image_gallery/20211210/lsh_indexer-size-08_input-idm-2048.pkl" 55 | - "/data1/changqing/ZyImage_Data/image_gallery/20211211/lsh_indexer-size-08_input-idm-2048.pkl" 56 | - "/data1/changqing/ZyImage_Data/image_gallery/20211212/lsh_indexer-size-08_input-idm-2048.pkl" 57 | - "/data1/changqing/ZyImage_Data/image_gallery/20211213/lsh_indexer-size-08_input-idm-2048.pkl" 58 | 59 | feature_path: "data/test_feature.pkl" 60 | lsh_index_path: "data/test_lsh.pkl" -------------------------------------------------------------------------------- 
/modules/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :__init__.py.py 4 | # @Time :2021/12/10 下午7:21 5 | # @Author :Chang Qing 6 | 7 | from modules.layers.pooling import * 8 | 9 | FEATURES = { 10 | 'vgg16': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-vgg16-features-d369c8e.pth', 11 | 'resnet50': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet50-features-ac468af.pth', 12 | 'resnet101': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet101-features-10a101d.pth', 13 | 'resnet152': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet152-features-1011020.pth', 14 | } 15 | 16 | # TODO: pre-compute for more architectures and properly test variations (pre l2norm, post l2norm) 17 | # pre-computed local pca whitening that can be applied before the pooling layer 18 | L_WHITENING = { 19 | 'resnet101': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-9f830ef.pth', 20 | # no pre l2 norm 21 | # 'resnet101' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-da5c935.pth', # with pre l2 norm 22 | } 23 | 24 | # possible global pooling layers, each on of these can be made regional 25 | POOLING = { 26 | 'mac': MAC, 27 | 'spoc': SPoC, 28 | 'gem': GeM, 29 | 'gemmp': GeMmp, 30 | 'rmac': RMAC, 31 | } 32 | 33 | # TODO: pre-compute for: resnet50-gem-r, resnet50-mac-r, vgg16-mac-r, alexnet-mac-r 34 | # pre-computed regional whitening, for most commonly used architectures and pooling methods 35 | R_WHITENING = { 36 | 'alexnet-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-rwhiten-c8cf7e2.pth', 37 | 'vgg16-gem-r': 
'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-rwhiten-19b204e.pth', 38 | 'resnet101-mac-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-rwhiten-7f1ed8c.pth', 39 | 'resnet101-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-rwhiten-adace84.pth', 40 | } 41 | 42 | # TODO: pre-compute for more architectures 43 | # pre-computed final (global) whitening, for most commonly used architectures and pooling methods 44 | WHITENING = { 45 | 'alexnet-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-whiten-454ad53.pth', 46 | 'alexnet-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-whiten-4c9126b.pth', 47 | 'vgg16-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-whiten-eaa6695.pth', 48 | 'vgg16-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-whiten-83582df.pth', 49 | 'resnet50-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet50-gem-whiten-f15da7b.pth', 50 | 'resnet101-mac-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-whiten-9df41d3.pth', 51 | 'resnet101-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-whiten-22ab0c1.pth', 52 | 'resnet101-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-whiten-b379c0a.pth', 53 | 'resnet101-gemmp': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gemmp-whiten-770f53c.pth', 54 | 'resnet152-gem': 
'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet152-gem-whiten-abe7b93.pth', 55 | 'densenet121-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet121-gem-whiten-79e3eea.pth', 56 | 'densenet169-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet169-gem-whiten-6b2a76a.pth', 57 | 'densenet201-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet201-gem-whiten-22ea45c.pth', 58 | } 59 | 60 | # output dimensionality for supported architectures 61 | OUTPUT_DIM = { 62 | 'alexnet': 256, 63 | 'vgg11': 512, 64 | 'vgg13': 512, 65 | 'vgg16': 512, 66 | 'vgg19': 512, 67 | 'resnet18': 512, 68 | 'resnet34': 512, 69 | 'resnet50': 2048, 70 | 'resnet101': 2048, 71 | 'resnet152': 2048, 72 | 'densenet121': 1024, 73 | 'densenet169': 1664, 74 | 'densenet201': 1920, 75 | 'densenet161': 2208, # largest densenet 76 | 'squeezenet1_0': 512, 77 | 'squeezenet1_1': 512, 78 | } 79 | 80 | 81 | -------------------------------------------------------------------------------- /modules/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :__init__.py.py 4 | # @Time :2021/12/14 下午8:01 5 | # @Author :Chang Qing 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /modules/datasets/data_helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | 4 | import torch 5 | 6 | 7 | def cid2filename(cid, prefix): 8 | """ 9 | Creates a training image path out of its CID name 10 | 11 | Arguments 12 | --------- 13 | cid : name of the image 14 | prefix : root directory where images are saved 15 | 16 | Returns 17 | ------- 18 | filename : full image filename 19 | """ 20 
| return os.path.join(prefix, cid[-2:], cid[-4:-2], cid[-6:-4], cid) 21 | 22 | 23 | def pil_loader(path): 24 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 25 | with open(path, 'rb') as f: 26 | img = Image.open(f) 27 | return img.convert('RGB') 28 | 29 | 30 | def accimage_loader(path): 31 | import accimage 32 | try: 33 | return accimage.Image(path) 34 | except IOError: 35 | # Potentially a decoding problem, fall back to PIL.Image 36 | return pil_loader(path) 37 | 38 | 39 | def default_loader(path): 40 | from torchvision import get_image_backend 41 | if get_image_backend() == 'accimage': 42 | return accimage_loader(path) 43 | else: 44 | return pil_loader(path) 45 | 46 | 47 | def imresize(img, imsize): 48 | img.thumbnail((imsize, imsize), Image.ANTIALIAS) 49 | return img 50 | 51 | 52 | def flip(x, dim): 53 | xsize = x.size() 54 | dim = x.dim() + dim if dim < 0 else dim 55 | x = x.view(-1, *xsize[dim:]) 56 | x = x.view(x.size(0), x.size(1), -1)[:, 57 | getattr(torch.arange(x.size(1) - 1, -1, -1), ('cpu', 'cuda')[x.is_cuda])().long(), :] 58 | return x.view(xsize) 59 | 60 | 61 | def collate_tuples(batch): 62 | if len(batch) == 1: 63 | return [batch[0][0]], [batch[0][1]] 64 | return [batch[i][0] for i in range(len(batch))], [batch[i][1] for i in range(len(batch))] 65 | -------------------------------------------------------------------------------- /modules/datasets/generic_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | 4 | import torch 5 | import torch.utils.data as data 6 | 7 | from modules.datasets.data_helpers import default_loader, imresize 8 | 9 | 10 | class ImagesFromPathList(data.Dataset): 11 | """A generic data loader that loads images from a list 12 | (Based on ImageFolder from pytorch) 13 | Args: 14 | root (string): Root directory path. 15 | images (list): Relative image paths as strings. 
16 | img_resize (int, Default: None): Defines the maximum size of longer image side 17 | bbxs (list): List of (x1,y1,x2,y2) tuples to crop the query images 18 | transform (callable, optional): A function/transform that takes in an PIL image 19 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 20 | loader (callable, optional): A function to load an image given its path. 21 | Attributes: 22 | images_fn (list): List of full image filename 23 | """ 24 | 25 | def __init__(self, path_list, img_resize=None, bbxs=None, transform=None, loader=default_loader): 26 | 27 | if len(path_list) == 0: 28 | raise (RuntimeError("Dataset contains 0 images!")) 29 | 30 | self.path_list = path_list 31 | self.img_resize = img_resize 32 | self.bbxs = bbxs 33 | self.transform = transform 34 | self.loader = loader 35 | 36 | def __getitem__(self, index): 37 | """ 38 | Args: 39 | index (int): Index 40 | Returns: 41 | image (PIL): Loaded image 42 | """ 43 | try: 44 | path = self.path_list[index] 45 | img = self.loader(path) 46 | imfullsize = max(img.size) 47 | 48 | if self.bbxs is not None: 49 | img = img.crop(self.bbxs[index]) 50 | 51 | if self.img_resize is not None: 52 | if self.bbxs is not None: 53 | img = imresize(img, self.img_resize * max(img.size) / imfullsize) 54 | else: 55 | img = imresize(img, self.img_resize) 56 | 57 | if self.transform is not None: 58 | img = self.transform(img) 59 | return img, path 60 | except: 61 | print(path) 62 | return self.__getitem__(index+1) 63 | 64 | def __len__(self): 65 | return len(self.path_list) 66 | 67 | def __repr__(self): 68 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 69 | fmt_str += ' Number of images: {}\n'.format(self.__len__()) 70 | tmp = ' Transforms (if any): ' 71 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 72 | return fmt_str 73 | 74 | 75 | class ImagesFromDataList(data.Dataset): 76 | """A generic data loader that loads images given as an array of pytorch 
tensors 77 | (Based on ImageFolder from pytorch) 78 | Args: 79 | images (list): Images as tensors. 80 | transform (callable, optional): A function/transform that image as a tensors 81 | and returns a transformed version. E.g, ``normalize`` with mean and std 82 | """ 83 | 84 | def __init__(self, images, transform=None): 85 | 86 | if len(images) == 0: 87 | raise (RuntimeError("Dataset contains 0 images!")) 88 | 89 | self.images = images 90 | self.transform = transform 91 | 92 | def __getitem__(self, index): 93 | """ 94 | Args: 95 | index (int): Index 96 | Returns: 97 | image (Tensor): Loaded image 98 | """ 99 | img = self.images[index] 100 | if self.transform is not None: 101 | img = self.transform(img) 102 | 103 | if len(img.size()): 104 | img = img.unsqueeze(0) 105 | 106 | return img 107 | 108 | def __len__(self): 109 | return len(self.images) 110 | 111 | def __repr__(self): 112 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 113 | fmt_str += ' Number of images: {}\n'.format(self.__len__()) 114 | tmp = ' Transforms (if any): ' 115 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 116 | return fmt_str 117 | -------------------------------------------------------------------------------- /modules/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianruochen/ImageRetrieval-Pytorch/4c478caeb9d24b0cc714cea48cd84e318fd15494/modules/layers/__init__.py -------------------------------------------------------------------------------- /modules/layers/functional.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pdb 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | # -------------------------------------- 8 | # pooling 9 | # -------------------------------------- 10 | 11 | # 全局最大池化 12 | def mac(x): 13 | return F.max_pool2d(x, (x.size(-2), x.size(-1))) # 
(x.size(-2), x.size(-1))是kernel的大小 14 | # return F.adaptive_max_pool2d(x, (1,1)) # alternative (1,1) 是最终输出的大小 15 | 16 | # 全局平均池化 17 | def spoc(x): 18 | return F.avg_pool2d(x, (x.size(-2), x.size(-1))) 19 | # return F.adaptive_avg_pool2d(x, (1,1)) # alternative 20 | 21 | # 全局平均池化变种,涉及幂指操作 22 | def gem(x, p=3, eps=1e-6): 23 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p) 24 | # return F.lp_pool2d(F.threshold(x, eps, eps), p, (x.size(-2), x.size(-1))) # alternative 25 | 26 | 27 | def rmac(x, L=3, eps=1e-6): 28 | ovr = 0.4 # desired overlap of neighboring regions 29 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension 30 | 31 | W = x.size(3) 32 | H = x.size(2) 33 | 34 | w = min(W, H) 35 | w2 = math.floor(w/2.0 - 1) 36 | 37 | b = (max(H, W)-w)/(steps-1) 38 | (tmp, idx) = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0) # steps(idx) regions for long dimension 39 | 40 | # region overplus per dimension 41 | Wd = 0; 42 | Hd = 0; 43 | if H < W: 44 | Wd = idx.item() + 1 45 | elif H > W: 46 | Hd = idx.item() + 1 47 | 48 | v = F.max_pool2d(x, (x.size(-2), x.size(-1))) 49 | v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v) 50 | 51 | for l in range(1, L+1): 52 | wl = math.floor(2*w/(l+1)) 53 | wl2 = math.floor(wl/2 - 1) 54 | 55 | if l+Wd == 1: 56 | b = 0 57 | else: 58 | b = (W-wl)/(l+Wd-1) 59 | cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b) - wl2 # center coordinates 60 | if l+Hd == 1: 61 | b = 0 62 | else: 63 | b = (H-wl)/(l+Hd-1) 64 | cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b) - wl2 # center coordinates 65 | 66 | for i_ in cenH.tolist(): 67 | for j_ in cenW.tolist(): 68 | if wl == 0: 69 | continue 70 | R = x[:,:,(int(i_)+torch.Tensor(range(wl)).long()).tolist(),:] 71 | R = R[:,:,:,(int(j_)+torch.Tensor(range(wl)).long()).tolist()] 72 | vt = F.max_pool2d(R, (R.size(-2), R.size(-1))) 73 | vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt) 74 | v += 
vt 75 | 76 | return v 77 | 78 | 79 | def roipool(x, rpool, L=3, eps=1e-6): 80 | #x: (bs, C, H, W) 81 | ovr = 0.4 # desired overlap of neighboring regions 82 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension 83 | 84 | W = x.size(3) 85 | H = x.size(2) 86 | 87 | w = min(W, H) 88 | w2 = math.floor(w/2.0 - 1) 89 | 90 | b = (max(H, W)-w)/(steps-1) 91 | _, idx = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0) # steps(idx) regions for long dimension 92 | 93 | # region overplus per dimension 94 | Wd = 0; 95 | Hd = 0; 96 | if H < W: 97 | Wd = idx.item() + 1 98 | elif H > W: 99 | Hd = idx.item() + 1 100 | 101 | vecs = [] 102 | vecs.append(rpool(x).unsqueeze(1)) 103 | 104 | for l in range(1, L+1): 105 | wl = math.floor(2*w/(l+1)) 106 | wl2 = math.floor(wl/2 - 1) 107 | 108 | if l+Wd == 1: 109 | b = 0 110 | else: 111 | b = (W-wl)/(l+Wd-1) 112 | cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b).int() - wl2 # center coordinates 113 | if l+Hd == 1: 114 | b = 0 115 | else: 116 | b = (H-wl)/(l+Hd-1) 117 | cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b).int() - wl2 # center coordinates 118 | 119 | for i_ in cenH.tolist(): 120 | for j_ in cenW.tolist(): 121 | if wl == 0: 122 | continue 123 | vecs.append(rpool(x.narrow(2,i_,wl).narrow(3,j_,wl)).unsqueeze(1)) 124 | 125 | return torch.cat(vecs, dim=1) 126 | 127 | 128 | # -------------------------------------- 129 | # normalization 130 | # -------------------------------------- 131 | 132 | def l2n(x, eps=1e-6): 133 | return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x) 134 | 135 | def powerlaw(x, eps=1e-6): 136 | x = x + eps 137 | return x.abs().sqrt().mul(x.sign()) 138 | 139 | # -------------------------------------- 140 | # loss 141 | # -------------------------------------- 142 | 143 | def contrastive_loss(x, label, margin=0.7, eps=1e-6): 144 | # x is (embedding_dim, bs) 145 | dim = x.size(0) # embedding_dim 146 | nq = torch.sum(label.data==-1) # number of tuples 
147 | S = x.size(1) // nq # number of images per tuple including query: 1+1+n 148 | 149 | # (embedding_dim, nq) --> (nq, embedding_dim) --> (nq, embedding_dim * (1+n)) --> ((1+n)*nq, embedding_dim) 150 | x1 = x[:, ::S].permute(1,0).repeat(1,S-1).view((S-1)*nq, dim).permute(1, 0) 151 | idx = [i for i in range(len(label)) if label.data[i] != -1] 152 | x2 = x[:, idx] 153 | lbl = label[label!=-1] 154 | 155 | dif = x1 - x2 156 | D = torch.pow(dif+eps, 2).sum(dim=0).sqrt() 157 | 158 | y = 0.5*lbl*torch.pow(D,2) + 0.5*(1-lbl)*torch.pow(torch.clamp(margin-D, min=0),2) 159 | y = torch.sum(y) 160 | return y 161 | 162 | def triplet_loss(x, label, margin=0.1): 163 | # x is D x N 164 | dim = x.size(0) # D 165 | nq = torch.sum(label.data==-1).item() # number of tuples 166 | S = x.size(1) // nq # number of images per tuple including query: 1+1+n 167 | 168 | xa = x[:, label.data==-1].permute(1,0).repeat(1,S-2).view((S-2)*nq,dim).permute(1,0) 169 | xp = x[:, label.data==1].permute(1,0).repeat(1,S-2).view((S-2)*nq,dim).permute(1,0) 170 | xn = x[:, label.data==0] 171 | 172 | dist_pos = torch.sum(torch.pow(xa - xp, 2), dim=0) 173 | dist_neg = torch.sum(torch.pow(xa - xn, 2), dim=0) 174 | 175 | return torch.sum(torch.clamp(dist_pos - dist_neg + margin, min=0)) 176 | -------------------------------------------------------------------------------- /modules/layers/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import modules.layers.functional as LF 5 | 6 | # -------------------------------------- 7 | # Loss/Error layers 8 | # -------------------------------------- 9 | 10 | class ContrastiveLoss(nn.Module): 11 | r"""CONTRASTIVELOSS layer that computes contrastive loss for a batch of images: 12 | Q query tuples, each packed in the form of (q,p,n1,..nN) 13 | 14 | Args: 15 | x: tuples arranges in columns as [q,p,n1,nN, ... 
] 16 | label: -1 for query, 1 for corresponding positive, 0 for corresponding negative 17 | margin: contrastive loss margin. Default: 0.7 18 | 19 | >>> contrastive_loss = ContrastiveLoss(margin=0.7) 20 | >>> input = torch.randn(128, 35, requires_grad=True) 21 | >>> label = torch.Tensor([-1, 1, 0, 0, 0, 0, 0] * 5) 22 | >>> output = contrastive_loss(input, label) 23 | >>> output.backward() 24 | """ 25 | 26 | def __init__(self, margin=0.7, eps=1e-6): 27 | super(ContrastiveLoss, self).__init__() 28 | self.margin = margin 29 | self.eps = eps 30 | 31 | def forward(self, x, label): 32 | return LF.contrastive_loss(x, label, margin=self.margin, eps=self.eps) 33 | 34 | def __repr__(self): 35 | return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')' 36 | 37 | 38 | class TripletLoss(nn.Module): 39 | 40 | def __init__(self, margin=0.1): 41 | super(TripletLoss, self).__init__() 42 | self.margin = margin 43 | 44 | def forward(self, x, label): 45 | return LF.triplet_loss(x, label, margin=self.margin) 46 | 47 | def __repr__(self): 48 | return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')' 49 | -------------------------------------------------------------------------------- /modules/layers/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import modules.layers.functional as LF 5 | 6 | # -------------------------------------- 7 | # Normalization layers 8 | # -------------------------------------- 9 | 10 | class L2N(nn.Module): 11 | 12 | def __init__(self, eps=1e-6): 13 | super(L2N,self).__init__() 14 | self.eps = eps 15 | 16 | def forward(self, x): 17 | return LF.l2n(x, eps=self.eps) 18 | 19 | def __repr__(self): 20 | return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')' 21 | 22 | 23 | class PowerLaw(nn.Module): 24 | 25 | def __init__(self, eps=1e-6): 26 | super(PowerLaw, self).__init__() 27 | self.eps = eps 28 
| 29 | def forward(self, x): 30 | return LF.powerlaw(x, eps=self.eps) 31 | 32 | def __repr__(self): 33 | return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')' -------------------------------------------------------------------------------- /modules/layers/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | import modules.layers.functional as LF 6 | from modules.layers.normalization import L2N 7 | 8 | # -------------------------------------- 9 | # Pooling layers 10 | # -------------------------------------- 11 | 12 | class MAC(nn.Module): 13 | 14 | def __init__(self): 15 | super(MAC,self).__init__() 16 | 17 | def forward(self, x): 18 | return LF.mac(x) 19 | 20 | def __repr__(self): 21 | return self.__class__.__name__ + '()' 22 | 23 | 24 | class SPoC(nn.Module): 25 | 26 | def __init__(self): 27 | super(SPoC,self).__init__() 28 | 29 | def forward(self, x): 30 | return LF.spoc(x) 31 | 32 | def __repr__(self): 33 | return self.__class__.__name__ + '()' 34 | 35 | 36 | class GeM(nn.Module): 37 | 38 | def __init__(self, p=3, eps=1e-6): 39 | super(GeM,self).__init__() 40 | self.p = Parameter(torch.ones(1)*p) 41 | self.eps = eps 42 | 43 | def forward(self, x): 44 | return LF.gem(x, p=self.p, eps=self.eps) 45 | 46 | def __repr__(self): 47 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 48 | 49 | class GeMmp(nn.Module): 50 | 51 | def __init__(self, p=3, mp=1, eps=1e-6): 52 | super(GeMmp,self).__init__() 53 | self.p = Parameter(torch.ones(mp)*p) 54 | self.mp = mp 55 | self.eps = eps 56 | 57 | def forward(self, x): 58 | return LF.gem(x, p=self.p.unsqueeze(-1).unsqueeze(-1), eps=self.eps) 59 | 60 | def __repr__(self): 61 | return self.__class__.__name__ + '(' + 'p=' + '[{}]'.format(self.mp) + ', ' + 'eps=' + str(self.eps) + ')' 62 | 63 | class 
RMAC(nn.Module): 64 | 65 | def __init__(self, L=3, eps=1e-6): 66 | super(RMAC,self).__init__() 67 | self.L = L 68 | self.eps = eps 69 | 70 | def forward(self, x): 71 | return LF.rmac(x, L=self.L, eps=self.eps) 72 | 73 | def __repr__(self): 74 | return self.__class__.__name__ + '(' + 'L=' + '{}'.format(self.L) + ')' 75 | 76 | 77 | class Rpool(nn.Module): 78 | 79 | def __init__(self, rpool, whiten=None, L=3, eps=1e-6): 80 | super(Rpool,self).__init__() 81 | self.rpool = rpool 82 | self.L = L 83 | self.whiten = whiten 84 | self.norm = L2N() 85 | self.eps = eps 86 | 87 | def forward(self, x, aggregate=True): 88 | # features -> roipool 89 | # x: (bs, C, H, W) 90 | # roipool 在特征图上按照一定的规则划定roi(x个), 并对每一个roi进行pool操作,最终将池化结果concatenate起来 (bs, x, C, 1, 1) 91 | o = LF.roipool(x, self.rpool, self.L, self.eps) # size: #im, #reg, D, 1, 1 92 | 93 | # concatenate regions from all images in the batch 94 | s = o.size() # (bs, x, C, 1, 1) 95 | o = o.view(s[0]*s[1], s[2], s[3], s[4]) # size: (bs * x, C, 1, 1) 96 | 97 | # rvecs -> norm 98 | o = self.norm(o) 99 | 100 | # rvecs -> whiten -> norm 101 | if self.whiten is not None: 102 | o = self.norm(self.whiten(o.squeeze(-1).squeeze(-1))) 103 | 104 | # reshape back to regions per image 105 | o = o.view(s[0], s[1], s[2], s[3], s[4]) # (bs, x, C, 1, 1) 106 | 107 | # aggregate regions into a single global vector per image 108 | if aggregate: 109 | # rvecs -> sumpool -> norm 110 | # 这里已经norm过一次了,外面又norm了一次,感觉这里的norm没什么必要, 所以擅自去掉了了。 111 | # o = self.norm(o.sum(1, keepdim=False)) # size:(bs, C, 1, 1) 112 | #自己改过的: 113 | o = o.sum(1, keepdim=False) 114 | return o 115 | 116 | def __repr__(self): 117 | return super(Rpool, self).__repr__() + '(' + 'L=' + '{}'.format(self.L) + ')' -------------------------------------------------------------------------------- /modules/lshash/__init__.py: -------------------------------------------------------------------------------- 1 | import pkg_resources 2 | 3 | try: 4 | __version__ = 
pkg_resources.get_distribution(__name__).version 5 | except: 6 | __version__ = '0.0.4dev' 7 | 8 | from modules.lshash.lshash import LSHash -------------------------------------------------------------------------------- /modules/lshash/lshash.py: -------------------------------------------------------------------------------- 1 | # lshash/lshash.py 2 | # Copyright 2012 Kay Zhu (a.k.a He Zhu) and contributors (see CONTRIBUTORS.txt) 3 | # 4 | # This module is part of lshash and is released under 5 | # the MIT License: http://www.opensource.org/licenses/mit-license.php 6 | # -*- coding: utf-8 -*- 7 | from __future__ import print_function, unicode_literals, division, absolute_import 8 | from builtins import int, round, str, object # noqa 9 | from future import standard_library 10 | standard_library.install_aliases() # noqa: Counter, OrderedDict, 11 | from past.builtins import basestring # noqa: 12 | 13 | import future # noqa 14 | import builtins # noqa 15 | import past # noqa 16 | import six # noqa 17 | 18 | import os 19 | import json 20 | import numpy as np 21 | 22 | try: 23 | from storage import storage # py2 24 | except ImportError: 25 | from .storage import storage # py3 26 | 27 | try: 28 | from bitarray import bitarray 29 | except ImportError: 30 | bitarray = None 31 | 32 | 33 | try: 34 | xrange # py2 35 | except NameError: 36 | xrange = range # py3 37 | 38 | 39 | class LSHash(object): 40 | """ LSHash implments locality sensitive hashing using random projection for 41 | input vectors of dimension `input_dim`. 42 | 43 | Attributes: 44 | 45 | :param hash_size: 46 | The length of the resulting binary hash in integer. E.g., 32 means the 47 | resulting binary hash will be 32-bit long. 48 | :param input_dim: 49 | The dimension of the input vector. E.g., a grey-scale picture of 30x30 50 | pixels will have an input dimension of 900. 51 | :param num_hashtables: 52 | (optional) The number of hash tables used for multiple lookups. 
53 | :param storage_config: 54 | (optional) A dictionary of the form `{backend_name: config}` where 55 | `backend_name` is the either `dict` or `redis`, and `config` is the 56 | configuration used by the backend. For `redis` it should be in the 57 | format of `{"redis": {"host": hostname, "port": port_num}}`, where 58 | `hostname` is normally `localhost` and `port` is normally 6379. 59 | :param matrices_filename: 60 | (optional) Specify the path to the compressed numpy file ending with 61 | extension `.npz`, where the uniform random planes are stored, or to be 62 | stored if the file does not exist yet. 63 | :param overwrite: 64 | (optional) Whether to overwrite the matrices file if it already exist 65 | """ 66 | 67 | def __init__(self, hash_size, input_dim, num_hashtables=1, 68 | storage_config=None, matrices_filename=None, overwrite=False): 69 | 70 | self.hash_size = hash_size # 哈希值用几位数字编码 71 | self.input_dim = input_dim 72 | self.num_hashtables = num_hashtables 73 | 74 | if storage_config is None: 75 | storage_config = {'dict': None} 76 | self.storage_config = storage_config 77 | 78 | # 用来存储 uniform random planes (这些random planes可用来计算hash值) 79 | if matrices_filename and not matrices_filename.endswith('.npz'): 80 | raise ValueError("The specified file name must end with .npz") 81 | self.matrices_filename = matrices_filename 82 | self.overwrite = overwrite 83 | 84 | # 构建self.uniform_planes, shape: (num_hash_table, hash_size, feature_dim) 85 | self._init_uniform_planes() 86 | # 构建hash_tables list. 长度为num_hashtables, 每一个值为一个storage对象(实质上storage是一个字典,字典的值是list类型) 87 | self._init_hashtables() 88 | 89 | def _init_uniform_planes(self): 90 | """ Initialize uniform planes used to calculate the hashes 91 | 92 | if file `self.matrices_filename` exist and `self.overwrite` is 93 | selected, save the uniform planes to the specified file. 94 | 95 | if file `self.matrices_filename` exist and `self.overwrite` is not 96 | selected, load the matrix with `np.load`. 
97 | 98 | if file `self.matrices_filename` does not exist and regardless of 99 | `self.overwrite`, only set `self.uniform_planes`. 100 | """ 101 | 102 | if "uniform_planes" in self.__dict__: 103 | return 104 | 105 | if self.matrices_filename: 106 | file_exist = os.path.isfile(self.matrices_filename) 107 | if file_exist and not self.overwrite: 108 | try: 109 | npzfiles = np.load(self.matrices_filename) 110 | except IOError: 111 | print("Cannot load specified file as a numpy array") 112 | raise 113 | else: 114 | npzfiles = sorted(npzfiles.items(), key=lambda x: x[0]) 115 | self.uniform_planes = [t[1] for t in npzfiles] 116 | else: 117 | self.uniform_planes = [self._generate_uniform_planes() 118 | for _ in xrange(self.num_hashtables)] 119 | try: 120 | np.savez_compressed(self.matrices_filename, 121 | *self.uniform_planes) 122 | except IOError: 123 | print("IOError when saving matrices to specificed path") 124 | raise 125 | else: 126 | # shape: (num_hash_table, hash_size, feature_dim) 127 | self.uniform_planes = [self._generate_uniform_planes() 128 | for _ in xrange(self.num_hashtables)] 129 | 130 | def _init_hashtables(self): 131 | """ Initialize the hash tables such that each record will be in the 132 | form of "[storage1, storage2, ...]" """ 133 | # num_hashtables个storage 每一个storage 都是一个字典,字典的值是list类型 134 | self.hash_tables = [storage(self.storage_config, i) 135 | for i in xrange(self.num_hashtables)] 136 | 137 | def _generate_uniform_planes(self): 138 | """ Generate uniformly distributed hyperplanes and return it as a 2D 139 | numpy array. 140 | """ 141 | 142 | return np.random.randn(self.hash_size, self.input_dim) 143 | 144 | def _hash(self, planes, input_point): 145 | """ Generates the binary hash for `input_point` and returns it. 146 | 147 | :param planes: 148 | The planes are random uniform planes with a dimension of 149 | `hash_size` * `input_dim`. 150 | :param input_point: 151 | A Python tuple or list object that contains only numbers. 
152 | The dimension needs to be (input_dim,). 153 | :return: hash value shape:(hash_size,) 154 | """ 155 | 156 | try: 157 | # input_point: (embedding_dim,) planes(hash_size, embedding_dim) 158 | input_point = np.array(input_point) # for faster dot product 159 | # (hash_size,) 160 | projections = np.dot(planes, input_point) 161 | except TypeError as e: 162 | print("""The input point needs to be an array-like object with 163 | numbers only elements""") 164 | raise 165 | except ValueError as e: 166 | print("""The input point needs to be of the same dimension as 167 | `input_dim` when initializing this LSHash instance""", e) 168 | raise 169 | else: 170 | return "".join(['1' if i > 0 else '0' for i in projections]) 171 | 172 | def _as_np_array(self, json_or_tuple): 173 | """ Takes either a JSON-serialized data structure or a tuple that has 174 | the original input points stored, and returns the original input point 175 | in numpy array format. 176 | """ 177 | if isinstance(json_or_tuple, basestring): 178 | # JSON-serialized in the case of Redis 179 | try: 180 | # Return the point stored as list, without the extra data 181 | tuples = json.loads(json_or_tuple)[0] 182 | except TypeError: 183 | print("The value stored is not JSON-serilizable") 184 | raise 185 | else: 186 | # If extra_data exists, `tuples` is the entire 187 | # (point:tuple, extra_data). Otherwise (i.e., extra_data=None), 188 | # return the point stored as a tuple 189 | tuples = json_or_tuple 190 | 191 | if isinstance(tuples[0], tuple): 192 | # in this case extra data exists 193 | return np.asarray(tuples[0]) 194 | 195 | elif isinstance(tuples, (tuple, list)): 196 | try: 197 | return np.asarray(tuples) 198 | except ValueError as e: 199 | print("The input needs to be an array-like object", e) 200 | raise 201 | else: 202 | raise TypeError("query data is not supported") 203 | 204 | def index(self, input_point, extra_data=None): 205 | """ Index a single input point by adding it to the selected storage. 
206 | 207 | If `extra_data` is provided, it will become the value of the dictionary 208 | {input_point: extra_data}, which in turn will become the value of the 209 | hash table. `extra_data` needs to be JSON serializable if in-memory 210 | dict is not used as storage. 211 | 212 | :param input_point: 213 | A list, or tuple, or numpy ndarray object that contains numbers 214 | only. The dimension needs to be (input_dim,). 215 | This object will be converted to Python tuple and stored in the 216 | selected storage. 217 | :param extra_data: 218 | (optional) Needs to be a JSON-serializable object: list, dicts and 219 | basic types such as strings and integers. 220 | """ 221 | # input_point 特征向量 shape:(embedding_dim,) 一维向量 222 | if isinstance(input_point, np.ndarray): 223 | input_point = input_point.tolist() 224 | 225 | if extra_data: 226 | value = (tuple(input_point), extra_data) 227 | else: 228 | value = tuple(input_point) 229 | 230 | for i, table in enumerate(self.hash_tables): 231 | table.append_val(self._hash(self.uniform_planes[i], input_point), 232 | value) 233 | 234 | def query(self, query_point, num_results=None, distance_func=None): 235 | """ Takes `query_point` which is either a tuple or a list of numbers, 236 | returns `num_results` of results as a list of tuples that are ranked 237 | based on the supplied metric function `distance_func`. 238 | 239 | :param query_point: 240 | A list, or tuple, or numpy ndarray that only contains numbers. 241 | The dimension needs to be (input_dim,). 242 | Used by :meth:`._hash`. 243 | :param num_results: 244 | (optional) Integer, specifies the max amount of results to be 245 | returned. If not specified all candidates will be returned as a 246 | list in ranked order. 247 | :param distance_func: 248 | (optional) The distance function to be used. Currently it needs to 249 | be one of ("hamming", "euclidean", "true_euclidean", 250 | "centred_euclidean", "cosine", "l1norm"). By default "euclidean" 251 | will used. 
252 | """ 253 | 254 | candidates = set() 255 | if not distance_func: 256 | distance_func = "euclidean" 257 | 258 | if distance_func == "hamming": 259 | if not bitarray: 260 | raise ImportError(" Bitarray is required for hamming distance") 261 | 262 | for i, table in enumerate(self.hash_tables): 263 | binary_hash = self._hash(self.uniform_planes[i], query_point) 264 | for key in table.keys(): 265 | distance = LSHash.hamming_dist(key, binary_hash) 266 | if distance < 2: 267 | candidates.update(table.get_list(key)) 268 | 269 | d_func = LSHash.euclidean_dist_square 270 | 271 | else: 272 | 273 | if distance_func == "euclidean": 274 | d_func = LSHash.euclidean_dist_square 275 | elif distance_func == "true_euclidean": 276 | d_func = LSHash.euclidean_dist 277 | elif distance_func == "centred_euclidean": 278 | d_func = LSHash.euclidean_dist_centred 279 | elif distance_func == "cosine": 280 | d_func = LSHash.cosine_dist 281 | elif distance_func == "l1norm": 282 | d_func = LSHash.l1norm_dist 283 | else: 284 | raise ValueError("The distance function name is invalid.") 285 | 286 | for i, table in enumerate(self.hash_tables): 287 | binary_hash = self._hash(self.uniform_planes[i], query_point) 288 | candidates.update(table.get_list(binary_hash)) 289 | 290 | # rank candidates by distance function 291 | candidates = [(ix, d_func(query_point, self._as_np_array(ix))) 292 | for ix in candidates] 293 | candidates = sorted(candidates, key=lambda x: x[1]) 294 | 295 | return candidates[:num_results] if num_results else candidates 296 | 297 | ### distance functions 298 | 299 | @staticmethod 300 | def hamming_dist(bitarray1, bitarray2): 301 | xor_result = bitarray(bitarray1) ^ bitarray(bitarray2) 302 | return xor_result.count() 303 | 304 | @staticmethod 305 | def euclidean_dist(x, y): 306 | """ This is a hot function, hence some optimizations are made. 
""" 307 | diff = np.array(x) - y 308 | return np.sqrt(np.dot(diff, diff)) 309 | 310 | @staticmethod 311 | def euclidean_dist_square(x, y): 312 | """ This is a hot function, hence some optimizations are made. """ 313 | diff = np.array(x) - y 314 | return np.dot(diff, diff) 315 | 316 | @staticmethod 317 | def euclidean_dist_centred(x, y): 318 | """ This is a hot function, hence some optimizations are made. """ 319 | diff = np.mean(x) - np.mean(y) 320 | return np.dot(diff, diff) 321 | 322 | @staticmethod 323 | def l1norm_dist(x, y): 324 | return sum(abs(x - y)) 325 | 326 | @staticmethod 327 | def cosine_dist(x, y): 328 | return 1 - float(np.dot(x, y)) / ((np.dot(x, x) * np.dot(y, y)) ** 0.5) 329 | 330 | if __name__ == '__main__': 331 | lsh = LSHash(3, 8) 332 | lsh.index([1, 2, 3, 4, 5, 6, 7, 8]) 333 | lsh.index([2, 3, 4, 5, 6, 7, 8, 9]) 334 | lsh.index([10, 12, 99, 1, 5, 31, 2, 3]) 335 | res = lsh.query([1, 2, 3, 4, 5, 6, 7, 7]) 336 | print(res) -------------------------------------------------------------------------------- /modules/lshash/storage.py: -------------------------------------------------------------------------------- 1 | # lshash/storage.py 2 | # Copyright 2012 Kay Zhu (a.k.a He Zhu) and contributors (see CONTRIBUTORS.txt) 3 | # 4 | # This module is part of lshash and is released under 5 | # the MIT License: http://www.opensource.org/licenses/mit-license.php 6 | 7 | import json 8 | 9 | try: 10 | import redis 11 | except ImportError: 12 | redis = None 13 | 14 | __all__ = ['storage'] 15 | 16 | 17 | def storage(storage_config, index): 18 | """ Given the configuration for storage and the index, return the 19 | configured storage instance. 
20 | """ 21 | if 'dict' in storage_config: 22 | return InMemoryStorage(storage_config['dict']) 23 | elif 'redis' in storage_config: 24 | storage_config['redis']['db'] = index 25 | return RedisStorage(storage_config['redis']) 26 | else: 27 | raise ValueError("Only in-memory dictionary and Redis are supported.") 28 | 29 | 30 | class BaseStorage(object): 31 | def __init__(self, config): 32 | """ An abstract class used as an adapter for storages. """ 33 | raise NotImplementedError 34 | 35 | def keys(self): 36 | """ Returns a list of binary hashes that are used as dict keys. """ 37 | raise NotImplementedError 38 | 39 | def set_val(self, key, val): 40 | """ Set `val` at `key`, note that the `val` must be a string. """ 41 | raise NotImplementedError 42 | 43 | def get_val(self, key): 44 | """ Return `val` at `key`, note that the `val` must be a string. """ 45 | raise NotImplementedError 46 | 47 | def append_val(self, key, val): 48 | """ Append `val` to the list stored at `key`. 49 | 50 | If the key is not yet present in storage, create a list with `val` at 51 | `key`. 52 | """ 53 | raise NotImplementedError 54 | 55 | def get_list(self, key): 56 | """ Returns a list stored in storage at `key`. 57 | 58 | This method should return a list of values stored at `key`. `[]` should 59 | be returned if the list is empty or if `key` is not present in storage. 
60 | """ 61 | raise NotImplementedError 62 | 63 | 64 | class InMemoryStorage(BaseStorage): 65 | def __init__(self, config): 66 | self.name = 'dict' 67 | self.storage = dict() 68 | 69 | def keys(self): 70 | return self.storage.keys() 71 | 72 | def set_val(self, key, val): 73 | self.storage[key] = val 74 | 75 | def get_val(self, key): 76 | return self.storage[key] 77 | 78 | def append_val(self, key, val): 79 | # 根据指定的key查找value,如果key不存在,则设置为[] 80 | self.storage.setdefault(key, []).append(val) 81 | 82 | def get_list(self, key): 83 | return self.storage.get(key, []) 84 | 85 | 86 | class RedisStorage(BaseStorage): 87 | def __init__(self, config): 88 | if not redis: 89 | raise ImportError("redis-py is required to use Redis as storage.") 90 | self.name = 'redis' 91 | self.storage = redis.StrictRedis(**config) 92 | 93 | def keys(self, pattern="*"): 94 | return self.storage.keys(pattern) 95 | 96 | def set_val(self, key, val): 97 | self.storage.set(key, val) 98 | 99 | def get_val(self, key): 100 | return self.storage.get(key) 101 | 102 | def append_val(self, key, val): 103 | self.storage.rpush(key, json.dumps(val)) 104 | 105 | def get_list(self, key): 106 | return self.storage.lrange(key, 0, -1) 107 | -------------------------------------------------------------------------------- /modules/model_const.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :model_const.py.py 4 | # @Time :2021/12/14 下午4:22 5 | # @Author :Chang Qing 6 | 7 | 8 | FEATURES = { 9 | 'vgg16': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-vgg16-features-d369c8e.pth', 10 | 'resnet50': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet50-features-ac468af.pth', 11 | 'resnet101': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet101-features-10a101d.pth', 12 | 'resnet152': 
'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet152-features-1011020.pth', 13 | } 14 | 15 | # TODO: pre-compute for more architectures and properly test variations (pre l2norm, post l2norm) 16 | # pre-computed local pca whitening that can be applied before the pooling layer 17 | L_WHITENING = { 18 | 'resnet101': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-9f830ef.pth', 19 | # no pre l2 norm 20 | # 'resnet101' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-da5c935.pth', # with pre l2 norm 21 | } 22 | 23 | # possible global pooling layers, each on of these can be made regional 24 | POOLING = { 25 | 'mac': MAC, 26 | 'spoc': SPoC, 27 | 'gem': GeM, 28 | 'gemmp': GeMmp, 29 | 'rmac': RMAC, 30 | } 31 | 32 | # TODO: pre-compute for: resnet50-gem-r, resnet50-mac-r, vgg16-mac-r, alexnet-mac-r 33 | # pre-computed regional whitening, for most commonly used architectures and pooling methods 34 | R_WHITENING = { 35 | 'alexnet-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-rwhiten-c8cf7e2.pth', 36 | 'vgg16-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-rwhiten-19b204e.pth', 37 | 'resnet101-mac-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-rwhiten-7f1ed8c.pth', 38 | 'resnet101-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-rwhiten-adace84.pth', 39 | } 40 | 41 | # TODO: pre-compute for more architectures 42 | # pre-computed final (global) whitening, for most commonly used architectures and pooling methods 43 | WHITENING = { 44 | 'alexnet-gem': 
'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-whiten-454ad53.pth', 45 | 'alexnet-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-whiten-4c9126b.pth', 46 | 'vgg16-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-whiten-eaa6695.pth', 47 | 'vgg16-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-whiten-83582df.pth', 48 | 'resnet50-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet50-gem-whiten-f15da7b.pth', 49 | 'resnet101-mac-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-whiten-9df41d3.pth', 50 | 'resnet101-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-whiten-22ab0c1.pth', 51 | 'resnet101-gem-r': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-whiten-b379c0a.pth', 52 | 'resnet101-gemmp': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gemmp-whiten-770f53c.pth', 53 | 'resnet152-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet152-gem-whiten-abe7b93.pth', 54 | 'densenet121-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet121-gem-whiten-79e3eea.pth', 55 | 'densenet169-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet169-gem-whiten-6b2a76a.pth', 56 | 'densenet201-gem': 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet201-gem-whiten-22ea45c.pth', 57 | } 58 | 59 | # output dimensionality for supported 
architectures 60 | OUTPUT_DIM = { 61 | 'alexnet': 256, 62 | 'vgg11': 512, 63 | 'vgg13': 512, 64 | 'vgg16': 512, 65 | 'vgg19': 512, 66 | 'resnet18': 512, 67 | 'resnet34': 512, 68 | 'resnet50': 2048, 69 | 'resnet101': 2048, 70 | 'resnet152': 2048, 71 | 'densenet121': 1024, 72 | 'densenet169': 1664, 73 | 'densenet201': 1920, 74 | 'densenet161': 2208, # largest densenet 75 | 'squeezenet1_0': 512, 76 | 'squeezenet1_1': 512, 77 | } 78 | 79 | -------------------------------------------------------------------------------- /modules/networks/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :__init__.py.py 4 | # @Time :2021/12/14 下午4:25 5 | # @Author :Chang Qing 6 | 7 | from modules.networks.retrieval_net import ImageRetrievalNet 8 | 9 | -------------------------------------------------------------------------------- /modules/networks/retrieval_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :retrieval_net.py.py 4 | # @Time :2021/12/14 下午4:26 5 | # @Author :Chang Qing 6 | 7 | import torch.nn as nn 8 | 9 | from modules.layers.normalization import L2N 10 | 11 | 12 | class ImageRetrievalNet(nn.Module): 13 | 14 | def __init__(self, features, lwhiten, pool, whiten, meta): 15 | super(ImageRetrievalNet, self).__init__() 16 | self.features = nn.Sequential(*features) 17 | self.lwhiten = lwhiten 18 | self.pool = pool 19 | self.whiten = whiten 20 | self.norm = L2N() 21 | self.meta = meta 22 | 23 | def forward(self, x): 24 | # x -> features 25 | o = self.features(x) # (bs, C, H, W) 26 | 27 | # TODO: properly test (with pre-l2norm and/or post-l2norm) 28 | # if lwhiten exist: features -> local whiten 29 | if self.lwhiten is not None: 30 | # o = self.norm(o) 31 | s = o.size() 32 | # (bs,C,H,W) --> (bs, H, W, C) --> (bs*H*W, C) 33 | o = o.permute(0, 2, 3, 
1).contiguous().view(-1, s[1]) 34 | # o = self.norm(o) pre-l2norm 35 | o = self.lwhiten(o) # 本质是一个全连接层, 可以起到降维的作用,但作者这里没有这么用,输入输出都是相同的维度 36 | # o = self.norm(o) post-l2norm 37 | # 还原回原来的shape: (bs, C, H, W) 38 | o = o.view(s[0], s[2], s[3], self.lwhiten.out_features).permute(0, 3, 1, 2) 39 | 40 | # features -> pool -> norm 41 | o = self.norm(self.pool(o)).squeeze(-1).squeeze(-1) 42 | 43 | # if whiten exist: pooled features -> whiten -> norm 44 | if self.whiten is not None: 45 | o = self.norm(self.whiten(o)) 46 | 47 | # permute so that it is Dx1 column vector per image (DxN if many images) 48 | # (embedding_dim, bs) 49 | return o.permute(1, 0) 50 | 51 | def __repr__(self): 52 | tmpstr = super(ImageRetrievalNet, self).__repr__()[:-1] 53 | tmpstr += self.meta_repr() 54 | tmpstr = tmpstr + ')' 55 | return tmpstr 56 | 57 | def meta_repr(self): 58 | tmpstr = ' (' + 'meta' + '): dict( \n' # + self.meta.__repr__() + '\n' 59 | tmpstr += ' architecture: {}\n'.format(self.meta['architecture']) 60 | tmpstr += ' local_whitening: {}\n'.format(self.meta['local_whitening']) 61 | tmpstr += ' pooling: {}\n'.format(self.meta['pooling']) 62 | tmpstr += ' regional: {}\n'.format(self.meta['regional']) 63 | tmpstr += ' whitening: {}\n'.format(self.meta['whitening']) 64 | tmpstr += ' outputdim: {}\n'.format(self.meta['outputdim']) 65 | tmpstr += ' mean: {}\n'.format(self.meta['mean']) 66 | tmpstr += ' std: {}\n'.format(self.meta['std']) 67 | tmpstr = tmpstr + ' )\n' 68 | return tmpstr 69 | -------------------------------------------------------------------------------- /modules/solver/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :__init__.py.py 4 | # @Time :2021/12/10 下午8:02 5 | # @Author :Chang Qing 6 | 7 | 8 | from modules.solver.model_initializer import ModelInitializer 9 | from modules.solver.feature_extractor import FeatureExtractor 10 | from 
modules.solver.image_retriever import ImageRetriever 11 | -------------------------------------------------------------------------------- /modules/solver/feature_extractor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :feature_extractor.py 4 | # @Time :2021/12/10 下午8:02 5 | # @Author :Chang Qing 6 | 7 | import os 8 | import pickle 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torchvision.transforms as transforms 13 | 14 | from tqdm import tqdm 15 | from torch.utils.data import DataLoader 16 | 17 | from modules.lshash import LSHash 18 | from modules.datasets.generic_dataset import ImagesFromPathList 19 | from utils.image_processor import ImageProcessor 20 | 21 | 22 | class FeatureExtractor: 23 | def __init__(self, model, img_resize=1024): 24 | self.model = model 25 | self.img_resize = img_resize 26 | self.tfms = transforms.Compose([ 27 | transforms.ToTensor(), 28 | transforms.Normalize( 29 | mean=self.model.meta['mean'], 30 | std=self.model.meta['std'] 31 | ) 32 | ]) 33 | 34 | def _parse_img_path(self, img_path): 35 | if os.path.isdir(img_path): 36 | return ImageProcessor(img_path).process() 37 | else: 38 | return [img_path] 39 | 40 | def _extract_ss(self, input_tensor): 41 | return self.model(input_tensor).cpu().data.squeeze() 42 | 43 | def _extract_ms(self, input_tensor, ms, msp): 44 | v = torch.zeros(self.model.meta['outputdim']) 45 | 46 | for s in ms: 47 | if s == 1: 48 | input_t = input_tensor.clone() 49 | else: 50 | input_t = nn.functional.interpolate(input, scale_factor=s, mode='bilinear', align_corners=False) 51 | v += self.model(input_t).pow(msp).cpu().data.squeeze() 52 | 53 | v /= len(ms) 54 | v = v.pow(1. 
/ msp) 55 | v /= v.norm() 56 | return v 57 | 58 | def extract(self, img_path, feature_path="", lsh_config=None, index_path="", multi_scale=None, msp=1): 59 | # build image path list 60 | if multi_scale is None: 61 | multi_scale = [1] 62 | print(f">>> build dataset and dataloader: \n" 63 | f" ... image_path: {img_path}") 64 | img_list = self._parse_img_path(img_path) 65 | # print(img_list) 66 | if torch.cuda.is_available(): 67 | self.model.cuda() 68 | self.model.eval() 69 | 70 | # create data loader 71 | dataset = ImagesFromPathList(path_list=img_list, img_resize=self.img_resize, transform=self.tfms) 72 | del img_list 73 | loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 74 | 75 | print(f">>> done... image nums: {len(dataset)}") 76 | 77 | print(f">>> extract features:") 78 | # extract feature vectors 79 | feature_vectors = torch.zeros(self.model.meta["outputdim"], len(dataset)) 80 | img_paths = list() 81 | with torch.no_grad(): 82 | for i, (input_tensor, img_path) in enumerate(tqdm(loader, total=len(loader))): 83 | if torch.cuda.is_available(): 84 | input_tensor = input_tensor.cuda() 85 | if len(multi_scale) == 1 and multi_scale[0] == 1: 86 | # feature_vector = self._extract_ss(input_tensor) 87 | feature_vectors[:, i] = self._extract_ss(input_tensor) 88 | else: 89 | # feature_vector = self._extract_ms(input_tensor) 90 | feature_vectors[:, i] = self._extract_ms(input_tensor, multi_scale, msp) 91 | # feature_dict[img_path] = feature_vector.detach().cpu().numpy() 92 | img_paths.extend(img_path) 93 | 94 | feature_dict = dict(zip(img_paths, list(feature_vectors.detach().cpu().numpy().T))) 95 | # feature_dict = dict(zip(map(tuple, img_paths), list(feature_vectors.detach().cpu().numpy().T))) 96 | if feature_path: 97 | with open(feature_path, "wb") as f: 98 | pickle.dump(feature_dict, f) 99 | print(f" ... 
saved features to: {feature_path}") 100 | 101 | # build lsh index 102 | lsh = None 103 | if lsh_config: 104 | hash_size = lsh_config.get("hash_size", 0) 105 | input_dim = lsh_config.get("input_dim", 2048) 106 | num_hash_tables = lsh_config.get("num_hash_tables", 1) 107 | # lsh = LSHash(**self.lsh_config) 108 | lsh = LSHash(hash_size=int(hash_size), input_dim=int(input_dim), num_hashtables=num_hash_tables) 109 | for img_path, vec in feature_dict.items(): 110 | # 使用flatten展成1维向量 111 | lsh.index(vec.flatten(), extra_data=img_path) 112 | if index_path: 113 | with open(index_path, "wb") as f: 114 | pickle.dump(lsh, f) 115 | print(f" ... saved lsh_info to: {index_path}") 116 | 117 | print(">>> extract feature done...") 118 | return feature_dict, lsh 119 | -------------------------------------------------------------------------------- /modules/solver/image_retriever.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :image_retriever.py 4 | # @Time :2021/12/14 下午9:45 5 | # @Author :Chang Qing 6 | 7 | import pickle 8 | import traceback 9 | 10 | class ImageRetriever: 11 | def __init__(self, lsh_path=None): 12 | self.lsh_path = lsh_path 13 | 14 | def retrieval(self, feature_dict, lsh=None, num_results=3, limit_threshold=False, threshold=1): 15 | try: 16 | if not lsh: 17 | print(f">>> load lsh from: {self.lsh_path}") 18 | lsh = pickle.load(open(self.lsh_path, "rb")) 19 | print(">>> load lsh done...") 20 | except: 21 | traceback.print_exc() 22 | print("load lsh model error") 23 | return 24 | similar_img_dict = dict() 25 | similar_img_dict2 = dict() 26 | similar_img_list = [] 27 | for query_path, query_feature in feature_dict.items(): 28 | try: 29 | # res: (((与query_feature相似的特征向量, 图片路径), 距离得分), 30 | # ((与query_feature相似的特征向量, 图片路径), 距离得分)...) 
31 | res = lsh.query(query_feature.flatten(), num_results=int(num_results), distance_func="cosine") 32 | queried_paths = [] 33 | # print(len(res)) 34 | for i in range(min(num_results, len(res))): 35 | queried_path = res[i][0][1] 36 | similarity_score = res[i][1] 37 | if limit_threshold and similarity_score > threshold: 38 | continue 39 | queried_paths.append(queried_path) 40 | if queried_path in similar_img_dict2 and similar_img_dict2[queried_path] < similarity_score: 41 | continue 42 | else: 43 | similar_img_dict2[queried_path] = similarity_score 44 | # similar_img_list.append((queried_path, similarity_score)) 45 | similar_img_dict[query_path] = queried_paths 46 | except: 47 | traceback.print_exc() 48 | # similar_img_list = sorted(similar_img_list, key=lambda item: item[1])[:num_results] 49 | similar_img_list = sorted(similar_img_dict2.items(), key=lambda item: item[1])[:num_results] 50 | return similar_img_dict, similar_img_list 51 | -------------------------------------------------------------------------------- /modules/solver/model_initializer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :model_initializer.py 4 | # @Time :2021/12/14 上午10:26 5 | # @Author :Chang Qing 6 | 7 | import os 8 | import torch 9 | import torchvision 10 | import torch.nn as nn 11 | import torch.utils.model_zoo as model_zoo 12 | 13 | from modules.layers.pooling import Rpool 14 | from modules.networks import ImageRetrievalNet 15 | from modules import OUTPUT_DIM, FEATURES, POOLING 16 | from modules import L_WHITENING, WHITENING, R_WHITENING 17 | from utils.common_util import get_data_root 18 | 19 | 20 | class ModelInitializer: 21 | def __init__(self, model_params=None, checkpoint=None): 22 | assert model_params or checkpoint 23 | self.checkpoint = checkpoint 24 | self.model_params = self._parse_params_from_checkpoint() if self.checkpoint else model_params 25 | self.architecture = 
self.model_params.get('architecture', 'resnet101') 26 | self.local_whitening = self.model_params.get('local_whitening', False) 27 | self.pooling = self.model_params.get('pooling', 'gem') 28 | self.regional = self.model_params.get('regional', False) 29 | self.whitening = self.model_params.get('whitening', False) 30 | self.mean = self.model_params.get('mean', [0.485, 0.456, 0.406]) 31 | self.std = self.model_params.get('std', [0.229, 0.224, 0.225]) 32 | self.pretrained = self.model_params.get('pretrained', True) 33 | self.state_dict = self.model_params.get("state_dict", None) 34 | self.dim = OUTPUT_DIM[self.architecture] 35 | 36 | def _parse_params_from_checkpoint(self): 37 | state = torch.load(self.checkpoint) 38 | # parsing net params from meta 39 | # architecture, pooling, mean, std required 40 | # the rest has default values, in case that is doesnt exist 41 | model_params = {} 42 | model_params['architecture'] = state['meta']['architecture'] 43 | model_params['pooling'] = state['meta']['pooling'] 44 | model_params['local_whitening'] = state['meta'].get('local_whitening', False) 45 | model_params['regional'] = state['meta'].get('regional', False) 46 | model_params['whitening'] = state['meta'].get('whitening', False) 47 | model_params['mean'] = state['meta']['mean'] 48 | model_params['std'] = state['meta']['std'] 49 | model_params['pretrained'] = False 50 | model_params["state_dict"] = state["state_dict"] 51 | return model_params 52 | 53 | def build_network(self): 54 | # loading network from torchvision 55 | if self.pretrained: 56 | if self.architecture not in FEATURES: 57 | # initialize with network pretrained on imagenet in pytorch 58 | net_in = getattr(torchvision.models, self.architecture)(pretrained=True) 59 | else: 60 | # initialize with random weights, later on we will fill features with custom pretrained network 61 | net_in = getattr(torchvision.models, self.architecture)(pretrained=False) 62 | else: 63 | # initialize with random weights 64 | net_in = 
getattr(torchvision.models, self.architecture)(pretrained=False) 65 | 66 | # initialize features 67 | # take only convolutions for features, 68 | # always ends with ReLU to make last activations non-negative 69 | if self.architecture.startswith('alexnet'): 70 | features = list(net_in.features.children())[:-1] 71 | elif self.architecture.startswith('vgg'): 72 | features = list(net_in.features.children())[:-1] 73 | elif self.architecture.startswith('resnet'): 74 | features = list(net_in.children())[:-2] 75 | elif self.architecture.startswith('densenet'): 76 | features = list(net_in.features.children()) 77 | features.append(nn.ReLU(inplace=True)) 78 | elif self.architecture.startswith('squeezenet'): 79 | features = list(net_in.features.children()) 80 | else: 81 | raise ValueError('Unsupported or unknown architecture: {}!'.format(self.architecture)) 82 | 83 | # initialize local whitening 84 | if self.local_whitening: 85 | lwhiten = nn.Linear(self.dim, self.dim, bias=True) 86 | # TODO: lwhiten with possible dimensionality reduce 87 | 88 | if self.pretrained: 89 | lw = self.architecture 90 | if lw in L_WHITENING: 91 | print(">> {}: for '{}' custom computed local whitening '{}' is used" 92 | .format(os.path.basename(__file__), lw, os.path.basename(L_WHITENING[lw]))) 93 | whiten_dir = os.path.join(get_data_root(), 'whiten') 94 | lwhiten.load_state_dict(model_zoo.load_url(L_WHITENING[lw], model_dir=whiten_dir)) 95 | else: 96 | print(">> {}: for '{}' there is no local whitening computed, random weights are used" 97 | .format(os.path.basename(__file__), lw)) 98 | 99 | else: 100 | lwhiten = None 101 | 102 | # initialize pooling 103 | if self.pooling == 'gemmp': 104 | pool = POOLING[self.pooling](mp=self.dim) 105 | else: 106 | pool = POOLING[self.pooling]() 107 | 108 | # initialize regional pooling 109 | if self.regional: 110 | rpool = pool 111 | rwhiten = nn.Linear(self.dim, self.dim, bias=True) 112 | # TODO: rwhiten with possible dimensionality reduce 113 | 114 | if 
self.pretrained: 115 | rw = '{}-{}-r'.format(self.architecture, self.pooling) 116 | if rw in R_WHITENING: 117 | print(">> {}: for '{}' custom computed regional whitening '{}' is used" 118 | .format(os.path.basename(__file__), rw, os.path.basename(R_WHITENING[rw]))) 119 | whiten_dir = os.path.join(get_data_root(), 'whiten') 120 | rwhiten.load_state_dict(model_zoo.load_url(R_WHITENING[rw], model_dir=whiten_dir)) 121 | else: 122 | print(">> {}: for '{}' there is no regional whitening computed, random weights are used" 123 | .format(os.path.basename(__file__), rw)) 124 | 125 | pool = Rpool(rpool, rwhiten) 126 | 127 | # initialize whitening 128 | if self.whitening: 129 | whiten = nn.Linear(self.dim, self.dim, bias=True) 130 | # TODO: whiten with possible dimensionality reduce 131 | 132 | if self.pretrained: 133 | w = self.architecture 134 | if self.local_whitening: 135 | w += '-lw' 136 | w += '-' + self.pooling 137 | if self.regional: 138 | w += '-r' 139 | if w in WHITENING: 140 | print(">> {}: for '{}' custom computed whitening '{}' is used" 141 | .format(os.path.basename(__file__), w, os.path.basename(WHITENING[w]))) 142 | whiten_dir = os.path.join(get_data_root(), 'whiten') 143 | whiten.load_state_dict(model_zoo.load_url(WHITENING[w], model_dir=whiten_dir)) 144 | else: 145 | print(">> {}: for '{}' there is no whitening computed, random weights are used" 146 | .format(os.path.basename(__file__), w)) 147 | else: 148 | whiten = None 149 | 150 | # create meta information to be stored in the network 151 | meta = { 152 | 'architecture': self.architecture, 153 | 'local_whitening': self.local_whitening, 154 | 'pooling': self.pooling, 155 | 'regional': self.regional, 156 | 'whitening': self.whitening, 157 | 'mean': self.mean, 158 | 'std': self.std, 159 | 'outputdim': self.dim, 160 | } 161 | 162 | # create a generic image retrieval network 163 | net = ImageRetrievalNet(features, lwhiten, pool, whiten, meta) 164 | 165 | # initialize features with custom pretrained network if 
needed 166 | if self.pretrained and self.architecture in FEATURES: 167 | print(">> {}: for '{}' custom pretrained features '{}' are used" 168 | .format(os.path.basename(__file__), self.architecture, os.path.basename(FEATURES[self.architecture]))) 169 | model_dir = os.path.join(get_data_root(), 'networks') 170 | net.features.load_state_dict(model_zoo.load_url(FEATURES[self.architecture], model_dir=model_dir)) 171 | 172 | return net 173 | 174 | def init_model(self): 175 | # 1.build network 176 | print(f">>> build network and load weights: \n" 177 | f" ... network arch: {self.architecture}\n" 178 | f" ... weights path: {self.checkpoint}") 179 | net = self.build_network() 180 | # 2.load weights 181 | if self.state_dict: 182 | net.load_state_dict(self.state_dict) 183 | # 3.move to cuda 184 | if torch.cuda.is_available(): 185 | net.cuda() 186 | net.eval() 187 | print(f">>> init model done...") 188 | return net 189 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # ImageRetrieval-Pytroch 2 | [![python version](https://img.shields.io/badge/python-3.6%2B-brightgreen)]() 3 | [![coverage](https://img.shields.io/badge/coverage-100%25-orange)]() 4 | 5 | image retrieval 6 | 7 | ## Table of Contents 8 | 9 | - [Structure](#structure) 10 | - [Usage](#usage) 11 | - [Config_file](#config_file) 12 | 13 | 14 | 15 | 16 | ## structure 17 | ``` 18 | ├── config 19 | │ ├── retrieval.yaml 图片检索的相关配置 20 | ├── data 21 | │ ├── images 默认图片检索库目录,可在retrieval.yaml中更改 22 | │ ├── output 检索结果的存放目录 23 | │ ├── query_images 默认存放需要检索的图片,可在retrieval.yaml文件中更改 24 | ├── models 25 | │ ├── image_retrieval_best.pth 模型权重文件 26 | ├── modules 27 | │ ├── __init__.py 28 | │ ├── datasets 29 | │ │ ├── data_helpers.py 和dataset相关的一些工具函数 30 | │ │ ├── generic_dataset.py dataset 31 | │ ├── layers 32 | │ │ ├── loss.py 33 | │ │ ├── functional.py 34 | │ │ ├── normalization.py 标准化类 35 | │ ├── lshash 
lsh相关文件 36 | │ ├── networks 37 | │ │ ├── retrieval_net.py 检索模型代码 38 | │ ├── solver 39 | │ │ ├── feature_extractor.py 特征提取类 40 | │ │ ├── image_retriever.py 图像检索类 41 | │ │ ├── model_initializer.py 模型初始化类 42 | │ ├── train_log 43 | │ │ ├── ckpt a directory to save checkpoints 44 | │ │ ├── log a directory to save logs 45 | ├── utils 一些有用的工具函数 46 | ├── retrieval_demo.py demo文件,代码入口 47 | 48 | ``` 49 | ## usage 50 | use default params (all parameters are set in the config file) 51 | ``` 52 | python retrieval_demo.py --config_path configs/retrieval.yaml 53 | ``` 54 | use custom params 55 | ``` 56 | python retrieval_demo.py --config_path configs/retrieval.yaml --query_image XXX --image_gallery XXX --query_number XXX 57 | # custom params 58 | # --query_image: str, 可以是单个图片文件地址,或包含一批图片的文件目录 59 | # --image_gallery: str, 图片库目录 60 | # --query_number: int, 返回结果的数量 61 | 62 | ``` 63 | 64 | ## config_file 65 | ```yaml 66 | image_gallery: 'data/images' # 图片库地址 67 | query_image: "data/query_images" # query图片地址 68 | query_number: 10 # 返回结果的数量 69 | checkpoint: 'models/image_retrieval_best.pth' # 模型权重地址 70 | out_similar_dir: 'data/output/' # 检索结果的输出目录 71 | 72 | lsh_config: 73 | hash_size: 0 # hash值长度 74 | input_dim: 2048 # 特征维度 75 | num_hash_tables: 1 # hash表数量 76 | 77 | feature_path: "data/test_feature.pkl" # 图片库特征的存放地址 78 | lsh_index_path: "data/test_lsh.pkl" # lsh索引的存放地址 79 | ``` 80 | 81 | ## Contributing 82 | 83 | ## License 84 | 85 | -------------------------------------------------------------------------------- /retrieval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :retrieval_demo.py 4 | # @Time :2021/12/10 下午7:39 5 | # @Author :Chang Qing 6 | 7 | import os 8 | import pickle 9 | import shutil 10 | import argparse 11 | 12 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 13 | 14 | from tqdm import tqdm 15 | from modules.solver import ModelInitializer 16 | from modules.solver
import FeatureExtractor 17 | from modules.solver import ImageRetriever 18 | from utils.config_util import parse_config, merge_config 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser(description="Image Retrieval Script") 22 | parser.add_argument("--config_path", type=str, default="configs/retrieval.yaml", 23 | help="config file path for image retrieval") 24 | parser.add_argument("--query_image", type=str, help="query image (single image, dir, or image file)") 25 | parser.add_argument("--query_number", type=int, default=150, help="query number") 26 | 27 | args = parser.parse_args() 28 | 29 | config = parse_config(args.config_path) 30 | config = merge_config(config, vars(args)) 31 | 32 | # 初始化特征提取器 33 | model = ModelInitializer(checkpoint=config.checkpoint).init_model() 34 | feature_extractor = FeatureExtractor(model) 35 | 36 | # 提取query images的特征 37 | test_feature_dict, _ = feature_extractor.extract(img_path=config.query_image) 38 | 39 | # lsh_paths = config.lsh_paths.hash_size_zero 40 | lsh_paths = config.lsh_paths.hash_size_eight 41 | for lsh_path in lsh_paths: 42 | lsh = pickle.load(open(lsh_path, "rb")) 43 | print("=" * 60) 44 | # retrieval test 45 | similar_img_dict, similar_img_list = ImageRetriever().retrieval(test_feature_dict, 46 | lsh, num_results=config.query_number, threshold=0.7) 47 | print(similar_img_dict) 48 | print(similar_img_list) 49 | 50 | print(">>> copying...") 51 | if config.out_similar_dir and os.path.isdir(config.out_similar_dir): 52 | for (similar_img_path, similarity_score) in tqdm(similar_img_list): 53 | basename = os.path.basename(similar_img_path) 54 | new_path = os.path.join(config.out_similar_dir, basename) 55 | shutil.copy(similar_img_path, new_path) 56 | 57 | print("done...") 58 | -------------------------------------------------------------------------------- /retrieval_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 
-*- 3 | # @FileName :retrieval_demo.py 4 | # @Time :2021/12/10 下午7:39 5 | # @Author :Chang Qing 6 | 7 | import os 8 | import shutil 9 | import argparse 10 | 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "8" 12 | 13 | from tqdm import tqdm 14 | from modules.solver import ModelInitializer 15 | from modules.solver import FeatureExtractor 16 | from modules.solver import ImageRetriever 17 | from utils.config_util import parse_config, merge_config 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser(description="Image Retrieval Script") 21 | parser.add_argument("--config_path", type=str, default="configs/retrieval.yaml", 22 | help="config file path for image retrieval") 23 | parser.add_argument("--query_image", type=str, help="query image (single image, dir, or image file)") 24 | parser.add_argument("--image_gallery", type=str, default="data/images", help="image gallery") 25 | parser.add_argument("--query_number", type=int, default=400, help="query number") 26 | 27 | args = parser.parse_args() 28 | 29 | args.image_gallery = "/data1/changqing/ZyImage_Data/auto_ai_cls8/其他" 30 | config = parse_config(args.config_path) 31 | config = merge_config(config, vars(args)) 32 | 33 | # 初始化特征提取器 34 | model = ModelInitializer(checkpoint=config.checkpoint).init_model() 35 | feature_extractor = FeatureExtractor(model) 36 | # date_strs = ["20210713", "20210716", "20210726"] 37 | # for date_str in date_strs: 38 | # config.image_gallery = f"/data1/changqing/ZyImage_Data/image_gallery/{date_str}/imgs/" 39 | # config.feature_path = f"/data1/changqing/ZyImage_Data/image_gallery/{date_str}/features_dim-2048.pkl" 40 | # config.lsh_index_path = f"/data1/changqing/ZyImage_Data/image_gallery/{date_str}/lsh_hash-size-00_input-idm-2048.pkl" 41 | # gallery_feature_dict, lsh = feature_extractor.extract(img_path=config.image_gallery, 42 | # lsh_config=config.lsh_config, 43 | # feature_path=config.feature_path, 44 | # index_path=config.lsh_index_path) 45 | 46 | 47 | config.feature_path 
= "" 48 | # config.lsh_index_path = "" 49 | # gallery_feature_dict, lsh = feature_extractor.extract(img_path=config.image_gallery, 50 | # lsh_config=config.lsh_config, 51 | # feature_path=config.feature_path, 52 | # index_path=config.lsh_index_path) 53 | 54 | # 提取query images的特征 55 | test_feature_dict, _ = feature_extractor.extract(img_path=config.query_image) 56 | 57 | print("=" * 60) 58 | 59 | # retrieval test 60 | similar_img_dict, similar_img_list = ImageRetriever(lsh_path=config.lsh_index_path).retrieval(test_feature_dict, 61 | num_results=config.query_number, threshold=-1) 62 | print(similar_img_dict) 63 | print(similar_img_list) 64 | print("=" * 60) 65 | 66 | print(">>> copying...") 67 | if config.out_similar_dir and os.path.isdir(config.out_similar_dir): 68 | for (similar_img_path, similarity_score) in tqdm(similar_img_list): 69 | basename = os.path.basename(similar_img_path) 70 | new_path = os.path.join(config.out_similar_dir, basename) 71 | if os.path.exists(similar_img_path): 72 | shutil.copy(similar_img_path, new_path) 73 | 74 | # print("done...") 75 | -------------------------------------------------------------------------------- /scripts/download_imgs_mp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :download_imgs_mp.py 4 | # @Time :2021/10/27 下午3:29 5 | # @Author :Chang Qing 6 | 7 | # !/usr/bin/env python 8 | # -*- coding:utf-8 -*- 9 | # @FileName :imgs_downloader_mp.py 10 | # @Time :2021/9/3 下午4:11 11 | # @Author :Chang Qing 12 | 13 | import os 14 | import time 15 | import random 16 | import requests 17 | import argparse 18 | import traceback 19 | 20 | from glob import glob 21 | from tqdm import tqdm 22 | from multiprocessing import Pool 23 | 24 | requests.DEFAULT_RETRIES = 5 25 | s = requests.session() 26 | s.keep_alive = False 27 | random.seed(666) 28 | 29 | 30 | def download_img(item): 31 | # name, url = item 32 | name, url = 
item.split("\t") 33 | img_name = os.path.join(imgs_root, f"{name}.jpg") 34 | if not os.path.exists(img_name): 35 | try: 36 | res = requests.get(url, timeout=1) 37 | if res.status_code != 200: 38 | raise Exception 39 | with open(img_name, "wb") as f: 40 | f.write(res.content) 41 | except Exception as e: 42 | print(name, url) 43 | traceback.print_exc() 44 | 45 | 46 | def build_url_list(url_file): 47 | url_list = [] 48 | with open(url_file) as f: 49 | lines = f.readlines() 50 | for i, line in enumerate(lines): 51 | url_list.append([str(i), line.strip()]) 52 | return url_list 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser(description="Images Download Script") 57 | parser.add_argument("--img_root", default="/Users/zuiyou/PycharmProjects/Image_Process_Tool/images/生活_股票基金_股票k线", type=str, 58 | help="the directory of images") 59 | parser.add_argument("--workers", default=3, type=int, help="the nums of process") 60 | args = parser.parse_args() 61 | 62 | imgs_root = args.img_root 63 | workers = args.workers 64 | 65 | os.makedirs(imgs_root, exist_ok=True) 66 | 67 | url_file = "imgs.txt" 68 | # url_file = "others_20211214-20211223.txt" 69 | items = open(url_file).readlines() 70 | # other类太多,分批处理, 一次处理20000张 71 | items = [item.strip() for item in items if item][:10000] 72 | # url_list = build_url_list(url_file) 73 | print(items[:5]) 74 | print(f"total items: {len(items)}") 75 | 76 | # random.shuffle(url_list) 77 | # url_list = url_list[:180000] 78 | 79 | tik_time = time.time() 80 | # create multiprocess pool 81 | pool = Pool(workers) # process num: 20 82 | 83 | # 如果check_img函数仅有1个参数,用map方法 84 | # pool.map(check_img, img_paths) 85 | # 如果check_img函数有不止1个参数,用apply_async方法 86 | # for img_path in tqdm(img_paths): 87 | # pool.apply_async(check_img, (img_path, False)) 88 | list(tqdm(iterable=(pool.imap(download_img, items)), total=len(items))) 89 | pool.close() 90 | pool.join() 91 | tok_time = time.time() 92 | print(tok_time - tik_time) 93 | 94 | 95 | 
96 | 97 | -------------------------------------------------------------------------------- /scripts/img_downloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :img_downloader.py 4 | # @Time :2021/12/14 下午2:22 5 | # @Author :Chang Qing 6 | 7 | import os 8 | import json 9 | import requests 10 | import argparse 11 | import traceback 12 | 13 | from tqdm import tqdm 14 | from PIL import Image 15 | from io import BytesIO 16 | from multiprocessing import Pool 17 | 18 | 19 | def download_img(item): 20 | pid, img_id, url = item 21 | img_name = os.path.join(img_save_dir, f"{pid}_{img_id}.jpg") 22 | if not os.path.exists(img_name): 23 | try: 24 | res = requests.get(url, timeout=1) 25 | img = Image.open(BytesIO(res.content)).convert("RGB") 26 | img.verify() 27 | if res.status_code != 200: 28 | raise Exception 29 | with open(img_name, "wb") as f: 30 | f.write(res.content) 31 | except Exception as e: 32 | print(pid, img_id, url) 33 | traceback.print_exc() 34 | 35 | 36 | def build_task_list(data_info): 37 | task_list = [] 38 | for pid, inv_info in data_info.items(): 39 | img_id2img_info = inv_info["img_id2img_info"] 40 | for img_id, img_info in img_id2img_info.items(): 41 | url = img_info["image_url"] 42 | task_list.append((pid, img_id, url)) 43 | return task_list 44 | 45 | 46 | def build_task_list2(data_info): 47 | task_list = [] 48 | for name, url in data_info.items(): 49 | pid, img_id = name.split("_") 50 | url = f"http://tbfile.ixiaochuan.cn/img/view/id/{img_id}" 51 | task_list.append((pid, img_id, url)) 52 | return task_list 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description="Zuiyou Image Download Script") 57 | parser.add_argument("--img_root", type=str, default="/data1/changqing/ZyImage_Data/image_gallery") 58 | parser.add_argument("--date_str", type=str, default="20211231", help="the date str of images") 59 | 
parser.add_argument("--workers", type=int, default=4, help="the nums of process") 60 | 61 | args = parser.parse_args() 62 | 63 | workers = args.workers 64 | img_root = args.img_root 65 | date_str = args.date_str 66 | img_save_dir = os.path.join(img_root, date_str, "imgs") 67 | os.makedirs(img_save_dir, exist_ok=True) 68 | print(img_save_dir) 69 | img_info_path = os.path.join(img_root, date_str, f"imgtag_name2url_{date_str}.json") 70 | 71 | img_info = json.load(open(img_info_path)) 72 | # task_list = build_task_list(img_info) 73 | task_list = build_task_list2(img_info) 74 | 75 | pool = Pool(workers) 76 | list(tqdm(iterable=(pool.imap(download_img, task_list)), total=len(task_list))) 77 | pool.close() # 调用join之前,先调用close函数,否则会出错。执行完close后不会有新的进程加入到pool 78 | pool.join() # join函数等待所有子进程结束 79 | print("Done...") -------------------------------------------------------------------------------- /scripts/lshash_indexer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :lshash_indexer.py 4 | # @Time :2021/12/16 下午9:42 5 | # @Author :Chang Qing 6 | 7 | import os 8 | import pickle 9 | 10 | from tqdm import tqdm 11 | from modules.lshash import LSHash 12 | from utils.common_util import fix_seed 13 | 14 | feature_root = "/data1/changqing/ZyImage_Data/image_gallery" 15 | date_strs = os.listdir(feature_root) 16 | 17 | fix_seed() 18 | hash_size = 8 19 | input_dim = 2048 20 | num_hash_tables = 1 21 | 22 | for date_str in tqdm(date_strs): 23 | lsh_indexer = LSHash(hash_size=hash_size, input_dim=input_dim, num_hashtables=num_hash_tables) 24 | feature_path = f"{feature_root}/{date_str}/features_dim-2048.pkl" 25 | if os.path.isfile(feature_path): 26 | feature_dict = pickle.load(open(feature_path, "rb")) 27 | lsh_indexer_path = f"{feature_root}/{date_str}/" +\ 28 | "lsh_indexer-size-{:0>2}_input-idm-{:0>4}.pkl".format(hash_size, input_dim) 29 | if os.path.exists(lsh_indexer_path): 30 | 
continue 31 | for img_path, vec in feature_dict.items(): 32 | # 使用flatten展成1维向量 33 | lsh_indexer.index(vec.flatten(), extra_data=img_path) 34 | pickle.dump(lsh_indexer, open(lsh_indexer_path, "wb")) 35 | 36 | # lsh_indexer = pickle.load(open("/data1/changqing/ZyImage_Data/image_gallery/20211117/lsh_indexer-size-08_input-idm-2048.pkl", "rb")) 37 | # print(lsh_indexer) -------------------------------------------------------------------------------- /scripts/remove_file.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :remove_file.py 4 | # @Time :2022/4/14 下午9:16 5 | # @Author :Chang Qing 6 | 7 | 8 | import os 9 | from utils.common_util import save_to_txt 10 | 11 | file_names = os.listdir("../data/output") 12 | print(file_names) 13 | print(len(file_names)) 14 | dir_path = "/data1/changqing/ZyImage_Data/auto_ai_cls8/其他" 15 | for file_name in file_names: 16 | 17 | file_path = os.path.join(dir_path, file_name) 18 | if os.path.exists(file_path): 19 | os.remove(file_path) 20 | print(f"remove {file_path}") 21 | 22 | save_to_txt(file_names, os.path.join(dir_path, "removed6.txt")) 23 | 24 | 25 | -------------------------------------------------------------------------------- /scripts/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :test.py 4 | # @Time :2021/12/15 下午3:39 5 | # @Author :Chang Qing 6 | import os 7 | from utils.common_util import save_to_json 8 | img_root = "/data1/changqing/ZyImage_Data/image_gallery" 9 | 10 | date_strs = ["20211227", "20211228", "20211229", "20211230", "20211231", 11 | "20220101", "20220102", "20220103", "20220104", "20220105"] 12 | data_file = "/data1/changqing/ZyImage_Data/image_gallery/pid_imgid_urls.txt" 13 | data_root = "/data1/changqing/ZyImage_Data/image_gallery" 14 | lines = open(data_file).readlines() 15 | 16 | 17 | # for 
i, date_str in enumerate(date_strs): 18 | # temp_lines = lines[i * 80000: (i+1) * 80000] 19 | # data_dir = os.path.join(data_root, date_str) 20 | # imgtag_name2url = dict() 21 | # for line in temp_lines: 22 | # if not line: 23 | # continue 24 | # name, url = line.strip().split("\t") 25 | # imgtag_name2url[name] = url 26 | # save_dir = os.path.join(data_dir, f"imgtag_name2url_{date_str}.json") 27 | # save_to_json(imgtag_name2url, save_dir) 28 | 29 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :__init__.py.py 4 | # @Time :2021/12/10 下午7:22 5 | # @Author :Chang Qing 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /utils/common_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :comm_util.py 4 | # @Time :2021/3/26 上午11:18 5 | # @Author :Chang Qing 6 | 7 | import os 8 | import time 9 | import json 10 | import datetime 11 | import hashlib 12 | import torch 13 | import random 14 | import numpy as np 15 | 16 | from collections import OrderedDict 17 | 18 | 19 | ######################################################################### 20 | ############################ 用于数值记录 ################################# 21 | ######################################################################### 22 | 23 | class AverageMeter(object): 24 | """Computes and stores the average and current value 25 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 26 | """ 27 | 28 | def __init__(self): 29 | self.reset() 30 | 31 | def reset(self): 32 | self.val = 0 33 | self.avg = 0 34 | self.sum = 0 35 | self.count = 0 36 | 37 | def update(self, val, n=1): 38 | self.val = val 39 | self.sum += val * n 40 | 
self.count += n 41 | self.avg = self.sum / self.count 42 | 43 | 44 | ######################################################################### 45 | ############################ 时间相关 #################################### 46 | ######################################################################### 47 | 48 | def get_time_str(): 49 | timestamp = time.time() 50 | time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(timestamp)) 51 | return time_str 52 | 53 | def get_date_str(n_days_ago=0): 54 | focus_day = datetime.datetime.today().date() - datetime.timedelta(days=n_days_ago) 55 | focus_day_str = time.strftime("%Y%m%d", time.strptime(str(focus_day), '%Y-%m-%d')) 56 | # focus_day_str = time.mktime(time.strptime(str(focus_day), '%Y-%m-%d')) 57 | return focus_day_str 58 | 59 | def format_cost_time(cost_time): 60 | """ 61 | :param cost_time: 毫秒数, 一般指时间戳之差 62 | :return: 格式化后,易读的时间字符串 63 | """ 64 | cost_time = round(cost_time) 65 | 66 | days = cost_time // 86400 67 | hours = cost_time // 3600 % 24 68 | minutes = cost_time // 60 % 60 69 | seconds = cost_time % 60 70 | if days > 0: 71 | return '{:d}d {:d}h {:d}m {:d}s'.format(days, hours, minutes, seconds) 72 | if hours > 0: 73 | return '{:d}h {:d}m {:d}s'.format(hours, minutes, seconds) 74 | if minutes > 0: 75 | return '{:d}m {:d}s'.format(minutes, seconds) 76 | return '{:d}s'.format(seconds) 77 | 78 | 79 | ######################################################################### 80 | ############################ 文件相关 #################################### 81 | ######################################################################### 82 | 83 | def save_to_json(save_target, save_path): 84 | json.dump(save_target, fp=open(save_path, "w"), indent=4, ensure_ascii=False) 85 | 86 | def save_to_txt(item_list, save_path): 87 | item_list = [item + "\n" if not item.endswith("\n") else item for item in item_list] 88 | # print(item_list) 89 | with open(save_path, "w") as f: 90 | f.writelines(item_list) 91 | 92 | 
######################################################################### 93 | ############################ 路径相关 #################################### 94 | ######################################################################### 95 | 96 | def get_root(): 97 | return os.path.join(os.path.dirname(os.path.realpath(__file__))) 98 | 99 | def get_data_root(): 100 | return os.path.join(get_root(), 'data') 101 | 102 | 103 | ######################################################################### 104 | ############################ 其他 ####################################### 105 | ######################################################################### 106 | 107 | def sort_dict(ori_dict, by_key=False, reverse=False): 108 | """ 109 | sorted dict by key or value 110 | :param ori_dict: 111 | :param by_key: sorted by key or value 112 | :param reverse: if reverse is true, big to small. if false, small to big 113 | :return: OrderedDict 114 | """ 115 | ordered_list = sorted(ori_dict.items(), key=lambda item: item[0] if by_key else item[1]) 116 | ordered_list = ordered_list[::-1] if reverse else ordered_list 117 | new_dict = OrderedDict(ordered_list) 118 | return new_dict 119 | 120 | def build_mapping_from_list(name_list): 121 | name_list = [name for name in list(set(name_list)) if name] 122 | idx2name_map = OrderedDict() 123 | name2idx_map = OrderedDict() 124 | for idx, name in enumerate(name_list): 125 | idx2name_map[str(idx)] = name 126 | name2idx_map[name] = str(idx) 127 | return idx2name_map, name2idx_map 128 | 129 | def sha256_hash(filename, block_size=65536, length=8): 130 | sha256 = hashlib.sha256() 131 | with open(filename, 'rb') as f: 132 | for block in iter(lambda: f.read(block_size), b''): 133 | sha256.update(block) 134 | return sha256.hexdigest()[:length - 1] 135 | 136 | 137 | def fix_seed(seed=1): 138 | random.seed(seed) 139 | np.random.seed(seed) 140 | torch.manual_seed(seed) 141 | if torch.cuda.is_available(): 142 | torch.cuda.manual_seed(seed) 143 | 144 | if 
__name__ == "__main__": 145 | a = ["1", "2" + "\n", "3", "4" + "\n", "5"] 146 | # save_to_txt(a, "test.txt") 147 | 148 | ori_dict = { 149 | "a": 3, 150 | "c": 1, 151 | "b": 2, 152 | } 153 | print(ori_dict) 154 | print(sort_dict(ori_dict, by_key=False, reverse=False)) 155 | print(sort_dict(ori_dict, by_key=False, reverse=True)) 156 | print(sort_dict(ori_dict, by_key=True, reverse=False)) 157 | print(sort_dict(ori_dict, by_key=True, reverse=True)) 158 | -------------------------------------------------------------------------------- /utils/config_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :config_util.py 4 | # @Time :2021/3/26 上午11:18 5 | # @Author :Chang Qing 6 | 7 | import json 8 | import yaml 9 | 10 | from typing import Any 11 | 12 | __all__ = ["parse_config", "merge_config", "print_config"] 13 | 14 | 15 | class AttrDict(dict): 16 | def __setattr__(self, key: str, value: Any) -> None: 17 | self[key] = value 18 | 19 | def __getattr__(self, key): 20 | return self[key] 21 | 22 | 23 | def recursive_convert(attr_dict): 24 | if not isinstance(attr_dict, dict): 25 | return attr_dict 26 | obj_dict = AttrDict() 27 | for key, value in attr_dict.items(): 28 | obj_dict[key] = recursive_convert(value) 29 | return obj_dict 30 | 31 | 32 | def parse_config(cfg_file): 33 | with open(cfg_file, "r") as f: 34 | # == AttrDict(yaml.load(f.read())) 35 | attr_dict_conf = AttrDict(yaml.load(f, Loader=yaml.Loader)) 36 | obj_dict_conf = recursive_convert(attr_dict_conf) 37 | return obj_dict_conf 38 | 39 | 40 | def merge_config(cfg, args_dict): 41 | for key, value in args_dict.items(): 42 | if not value: 43 | continue 44 | try: 45 | if hasattr(cfg, key): 46 | setattr(cfg, key, value) 47 | except Exception as e: 48 | pass 49 | return cfg 50 | 51 | 52 | def print_config(config): 53 | try: 54 | print(json.dumps(config, indent=4)) 55 | except: 56 | print(json.dumps(config.__dict__, 
indent=4)) 57 | 58 | 59 | if __name__ == '__main__': 60 | # temp_config = {'task_name': 'test', 'task_type': 'multi_class', 'n_gpus': 2, 61 | # 'id2name': 'tasks/test/data/id2name.json', 'arch_type': 'efficentnet_b5', 'num_classes': 4, 62 | # 'train_file': 'tasks/test/data/train.txt', 'valid_file': 'tasks/test/data/valid.txt', 'batch_size': 4, 63 | # 'epochs': 2, 'save_dir': 'tasks/test/workshop'} 64 | temp_config = {"batch_size": 1} 65 | train_config = parse_config("../configs/model_config/imgtag_multi_class_train.yaml") 66 | print(train_config) 67 | new_config = merge_config(train_config, temp_config) 68 | print(new_config) 69 | -------------------------------------------------------------------------------- /utils/download_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def download_test(data_dir): 4 | """ 5 | DOWNLOAD_TEST Checks, and, if required, downloads the necessary datasets for the testing. 6 | 7 | download_test(DATA_ROOT) checks if the data necessary for running the example script exist. 
8 | If not it downloads it in the folder structure: 9 | DATA_ROOT/test/oxford5k/ : folder with Oxford images and ground truth file 10 | DATA_ROOT/test/paris6k/ : folder with Paris images and ground truth file 11 | DATA_ROOT/test/roxford5k/ : folder with Oxford images and revisited ground truth file 12 | DATA_ROOT/test/rparis6k/ : folder with Paris images and revisited ground truth file 13 | """ 14 | 15 | # Create data folder if it does not exist 16 | if not os.path.isdir(data_dir): 17 | os.mkdir(data_dir) 18 | 19 | # Create datasets folder if it does not exist 20 | datasets_dir = os.path.join(data_dir, 'test') 21 | if not os.path.isdir(datasets_dir): 22 | os.mkdir(datasets_dir) 23 | 24 | # Download datasets folders test/DATASETNAME/ 25 | datasets = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k'] 26 | for di in range(len(datasets)): 27 | dataset = datasets[di] 28 | 29 | if dataset == 'oxford5k': 30 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings' 31 | dl_files = ['oxbuild_images.tgz'] 32 | elif dataset == 'paris6k': 33 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings' 34 | dl_files = ['paris_1.tgz', 'paris_2.tgz'] 35 | elif dataset == 'roxford5k': 36 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings' 37 | dl_files = ['oxbuild_images.tgz'] 38 | elif dataset == 'rparis6k': 39 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings' 40 | dl_files = ['paris_1.tgz', 'paris_2.tgz'] 41 | else: 42 | raise ValueError('Unknown dataset: {}!'.format(dataset)) 43 | 44 | dst_dir = os.path.join(datasets_dir, dataset, 'jpg') 45 | if not os.path.isdir(dst_dir): 46 | 47 | # for oxford and paris download images 48 | if dataset == 'oxford5k' or dataset == 'paris6k': 49 | print('>> Dataset {} directory does not exist. 
Creating: {}'.format(dataset, dst_dir)) 50 | os.makedirs(dst_dir) 51 | for dli in range(len(dl_files)): 52 | dl_file = dl_files[dli] 53 | src_file = os.path.join(src_dir, dl_file) 54 | dst_file = os.path.join(dst_dir, dl_file) 55 | print('>> Downloading dataset {} archive {}...'.format(dataset, dl_file)) 56 | os.system('wget {} -O {}'.format(src_file, dst_file)) 57 | print('>> Extracting dataset {} archive {}...'.format(dataset, dl_file)) 58 | # create tmp folder 59 | dst_dir_tmp = os.path.join(dst_dir, 'tmp') 60 | os.system('mkdir {}'.format(dst_dir_tmp)) 61 | # extract in tmp folder 62 | os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir_tmp)) 63 | # remove all (possible) subfolders by moving only files in dst_dir 64 | os.system('find {} -type f -exec mv -i {{}} {} \\;'.format(dst_dir_tmp, dst_dir)) 65 | # remove tmp folder 66 | os.system('rm -rf {}'.format(dst_dir_tmp)) 67 | print('>> Extracted, deleting dataset {} archive {}...'.format(dataset, dl_file)) 68 | os.system('rm {}'.format(dst_file)) 69 | 70 | # for roxford and rparis just make sym links 71 | elif dataset == 'roxford5k' or dataset == 'rparis6k': 72 | print('>> Dataset {} directory does not exist. 
Creating: {}'.format(dataset, dst_dir)) 73 | dataset_old = dataset[1:] 74 | dst_dir_old = os.path.join(datasets_dir, dataset_old, 'jpg') 75 | os.mkdir(os.path.join(datasets_dir, dataset)) 76 | os.system('ln -s {} {}'.format(dst_dir_old, dst_dir)) 77 | print('>> Created symbolic link from {} jpg to {} jpg'.format(dataset_old, dataset)) 78 | 79 | gnd_src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'test', dataset) 80 | gnd_dst_dir = os.path.join(datasets_dir, dataset) 81 | gnd_dl_file = 'gnd_{}.pkl'.format(dataset) 82 | gnd_src_file = os.path.join(gnd_src_dir, gnd_dl_file) 83 | gnd_dst_file = os.path.join(gnd_dst_dir, gnd_dl_file) 84 | if not os.path.exists(gnd_dst_file): 85 | print('>> Downloading dataset {} ground truth file...'.format(dataset)) 86 | os.system('wget {} -O {}'.format(gnd_src_file, gnd_dst_file)) 87 | 88 | 89 | def download_train(data_dir): 90 | """ 91 | DOWNLOAD_TRAIN Checks, and, if required, downloads the necessary datasets for the training. 92 | 93 | download_train(DATA_ROOT) checks if the data necessary for running the example script exist. 
94 | If not it downloads it in the folder structure: 95 | DATA_ROOT/train/retrieval-SfM-120k/ : folder with rsfm120k images and db files 96 | DATA_ROOT/train/retrieval-SfM-30k/ : folder with rsfm30k images and db files 97 | """ 98 | 99 | # Create data folder if it does not exist 100 | if not os.path.isdir(data_dir): 101 | os.mkdir(data_dir) 102 | 103 | # Create datasets folder if it does not exist 104 | datasets_dir = os.path.join(data_dir, 'train') 105 | if not os.path.isdir(datasets_dir): 106 | os.mkdir(datasets_dir) 107 | 108 | # Download folder train/retrieval-SfM-120k/ 109 | src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'ims') 110 | dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims') 111 | dl_file = 'ims.tar.gz' 112 | if not os.path.isdir(dst_dir): 113 | src_file = os.path.join(src_dir, dl_file) 114 | dst_file = os.path.join(dst_dir, dl_file) 115 | print('>> Image directory does not exist. Creating: {}'.format(dst_dir)) 116 | os.makedirs(dst_dir) 117 | print('>> Downloading ims.tar.gz...') 118 | os.system('wget {} -O {}'.format(src_file, dst_file)) 119 | print('>> Extracting {}...'.format(dst_file)) 120 | os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir)) 121 | print('>> Extracted, deleting {}...'.format(dst_file)) 122 | os.system('rm {}'.format(dst_file)) 123 | 124 | # Create symlink for train/retrieval-SfM-30k/ 125 | dst_dir_old = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims') 126 | dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-30k', 'ims') 127 | if not os.path.isdir(dst_dir): 128 | os.makedirs(os.path.join(datasets_dir, 'retrieval-SfM-30k')) 129 | os.system('ln -s {} {}'.format(dst_dir_old, dst_dir)) 130 | print('>> Created symbolic link from retrieval-SfM-120k/ims to retrieval-SfM-30k/ims') 131 | 132 | # Download db files 133 | src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'dbs') 134 | datasets = ['retrieval-SfM-120k', 'retrieval-SfM-30k'] 135 | for 
dataset in datasets: 136 | dst_dir = os.path.join(datasets_dir, dataset) 137 | if dataset == 'retrieval-SfM-120k': 138 | dl_files = ['{}.pkl'.format(dataset), '{}-whiten.pkl'.format(dataset)] 139 | elif dataset == 'retrieval-SfM-30k': 140 | dl_files = ['{}-whiten.pkl'.format(dataset)] 141 | 142 | if not os.path.isdir(dst_dir): 143 | print('>> Dataset directory does not exist. Creating: {}'.format(dst_dir)) 144 | os.mkdir(dst_dir) 145 | 146 | for i in range(len(dl_files)): 147 | src_file = os.path.join(src_dir, dl_files[i]) 148 | dst_file = os.path.join(dst_dir, dl_files[i]) 149 | if not os.path.isfile(dst_file): 150 | print('>> DB file {} does not exist. Downloading...'.format(dl_files[i])) 151 | os.system('wget {} -O {}'.format(src_file, dst_file)) 152 | 153 | -------------------------------------------------------------------------------- /utils/evalute_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :evalute_util.py 4 | # @Time :2021/12/14 下午9:13 5 | # @Author :Chang Qing 6 | 7 | ''' 8 | Author: yinhao 9 | Email: yinhao_x@163.com 10 | Wechat: xss_yinhao 11 | Github: http://github.com/yinhaoxs 12 | data: 2019-11-23 18:27 13 | desc: 14 | ''' 15 | 16 | import os 17 | import shutil 18 | import numpy as np 19 | import pandas as pd 20 | 21 | 22 | class EvaluteMap(): 23 | def __init__(self, out_similar_dir='', out_similar_file_dir='', all_csv_file='', feature_path='', index_path=''): 24 | self.out_similar_dir = out_similar_dir 25 | self.out_similar_file_dir = out_similar_file_dir 26 | self.all_csv_file = all_csv_file 27 | self.feature_path = feature_path 28 | self.index_path = index_path 29 | 30 | 31 | def get_dict(self, query_no, query_id, simi_no, simi_id, num, score): 32 | new_dict = { 33 | 'index': str(num), 34 | 'id1': str(query_id), 35 | 'id2': str(simi_id), 36 | 'no1': str(query_no), 37 | 'no2': str(simi_no), 38 | 'score': score 39 | } 40 | return 
new_dict 41 | 42 | 43 | def find_similar_img_gyz(self, feature_dict, lsh, num_results): 44 | for q_path, q_vec in feature_dict.items(): 45 | try: 46 | response = lsh.query(q_vec.flatten(), num_results=int(num_results), distance_func="cosine") 47 | query_img_path0 = response[0][0][1] 48 | query_img_path1 = response[1][0][1] 49 | query_img_path2 = response[2][0][1] 50 | # score0 = response[0][1] 51 | # score0 = np.rint(100 * (1 - score0)) 52 | print('**********************************************') 53 | print('input img: {}'.format(q_path)) 54 | print('query0 img: {}'.format(query_img_path0)) 55 | print('query1 img: {}'.format(query_img_path1)) 56 | print('query2 img: {}'.format(query_img_path2)) 57 | except: 58 | continue 59 | 60 | 61 | def find_similar_img(self, feature_dict, lsh, num_results): 62 | num = 0 63 | result_list = list() 64 | for q_path, q_vec in feature_dict.items(): 65 | response = lsh.query(q_vec.flatten(), num_results=int(num_results), distance_func="cosine") 66 | s_path_list, s_vec_list, s_id_list, s_no_list, score_list = list(), list(), list(), list(), list() 67 | q_path = q_path[0] 68 | q_no, q_id = q_path.split("\\")[-2], q_path.split("\\")[-1] 69 | try: 70 | for i in range(int(num_results)): 71 | s_path, s_vec = response[i][0][1], response[i][0][0] 72 | s_path = s_path[0] 73 | s_no, s_id = s_path.split("\\")[-2], s_path.split("\\")[-1] 74 | if str(s_no) != str(q_no): 75 | score = np.rint(100 * (1 - response[i][1])) 76 | s_path_list.append(s_path) 77 | s_vec_list.append(s_vec) 78 | s_id_list.append(s_id) 79 | s_no_list.append(s_no) 80 | score_list.append(score) 81 | else: 82 | continue 83 | 84 | if len(s_path_list) != 0: 85 | index = score_list.index(max(score_list)) 86 | s_path, s_vec, s_id, s_no, score = s_path_list[index], s_vec_list[index], s_id_list[index], \ 87 | s_no_list[index], score_list[index] 88 | else: 89 | s_path, s_vec, s_id, s_no, score = None, None, None, None, None 90 | except: 91 | s_path, s_vec, s_id, s_no, score = None, 
None, None, None, None 92 | 93 | try: 94 | ##拷贝文件到指定文件夹 95 | num += 1 96 | des_path = os.path.join(self.out_similar_dir, str(num)) 97 | if not os.path.exists(des_path): 98 | os.makedirs(des_path) 99 | shutil.copy(q_path, des_path) 100 | os.rename(os.path.join(des_path, q_id), os.path.join(des_path, "query_" + q_no + "_" + q_id)) 101 | if s_path != None: 102 | shutil.copy(s_path, des_path) 103 | os.rename(os.path.join(des_path, s_id), os.path.join(des_path, s_no + "_" + s_id)) 104 | 105 | new_dict = self.get_dict(q_no, q_id, s_no, s_id, num, score) 106 | result_list.append(new_dict) 107 | except: 108 | continue 109 | 110 | try: 111 | result_s = pd.DataFrame.from_dict(result_list) 112 | result_s.to_csv(self.all_csv_file, encoding="gbk", index=False) 113 | except: 114 | print("write error") 115 | 116 | 117 | def filter_gap_score(self): 118 | for value in range(90, 101): 119 | try: 120 | pd_df = pd.read_csv(self.all_csv_file, encoding="gbk", error_bad_lines=False) 121 | pd_tmp = pd_df[pd_df["score"] == int(value)] 122 | if not os.path.exists(self.out_similar_file_dir): 123 | os.makedirs(self.out_similar_file_dir) 124 | 125 | try: 126 | results_split_csv = os.path.join(self.out_similar_file_dir + os.sep, 127 | "filter_{}.csv".format(str(value))) 128 | pd_tmp.to_csv(results_split_csv, encoding="gbk", index=False) 129 | except: 130 | print("write part error") 131 | 132 | lines = pd_df[pd_df["score"] == int(value)]["index"] 133 | num = 0 134 | for line in lines: 135 | des_path_temp = os.path.join(self.out_similar_file_dir + os.sep, str(value), str(line)) 136 | if not os.path.exists(des_path_temp): 137 | os.makedirs(des_path_temp) 138 | pairs_path = os.path.join(self.out_similar_dir + os.sep, str(line)) 139 | for img_id in os.listdir(pairs_path): 140 | img_path = os.path.join(pairs_path + os.sep, img_id) 141 | shutil.copy(img_path, des_path_temp) 142 | except: 143 | print("error") 144 | 145 | 146 | def retrieval_images(self, feature_dict, lsh, num_results=1): 147 | # load 
model 148 | # with open(self.feature_path, "rb") as f: 149 | # feature_dict = pickle.load(f) 150 | # with open(self.index_path, "rb") as f: 151 | # lsh = pickle.load(f) 152 | 153 | self.find_similar_img_gyz(feature_dict, lsh, num_results) 154 | # self.filter_gap_score() 155 | 156 | 157 | if __name__ == "__main__": 158 | pass 159 | 160 | 161 | -------------------------------------------------------------------------------- /utils/image_processor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # @FileName :image_processor.py 4 | # @Time :2021/12/10 下午8:04 5 | # @Author :Chang Qing 6 | 7 | import os 8 | from PIL import Image 9 | 10 | 11 | class ImageProcessor: 12 | def __init__(self, img_dir): 13 | self.img_dir = img_dir 14 | 15 | def process(self, ratio_limit=0): 16 | img_list = [] 17 | for root, dir, file_list in os.walk(self.img_dir): 18 | for file_name in file_list: 19 | if file_name[-4:].lower() in [".jpg", ".png", "jpeg"]: 20 | img_path = os.path.join(root, file_name) 21 | img_list.append(img_path) 22 | # try: 23 | # image = Image.open(img_path).convert("RGB") 24 | # if ratio_limit and max(image.size) / min(image.size) > ratio_limit: 25 | # continue 26 | # else: 27 | # img_list.append(img_path) 28 | # except: 29 | # pass 30 | return img_list 31 | 32 | 33 | if __name__ == '__main__': 34 | image_processor = ImageProcessor("/data1/changqing/ZyImage_Data/image_gallery/") 35 | image_list = image_processor.process() 36 | print(image_list[:10]) 37 | print(len(image_list)) 38 | --------------------------------------------------------------------------------