├── .gitignore
├── LICENSE
├── README.md
├── demo
│   ├── 1.jpg
│   ├── 2.jpg
│   ├── 3.jpg
│   ├── 4.jpg
│   ├── 5.jpg
│   ├── 6.jpg
│   ├── test1.csv
│   └── test2.csv
├── demo_override
│   ├── README.md
│   ├── main_override.py
│   ├── test1.csv
│   └── test2.csv
├── image_util_cli.py
├── main_multi.py
├── model_util.py
└── requirements.txt

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .vscode
3 | __pycache__
4 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Ryan
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image Similarity
2 | 
3 | This is an efficient utility for computing image similarity using the [MobileNet](https://arxiv.org/abs/1704.04861) deep neural network.
4 | 
5 | Image similarity is mostly a matter of selecting good image features. Here, a Convolutional Neural Network (CNN) is used to extract features from the images, which gives the computer a representation it can compare effectively.
6 | 
7 | This repository uses a light-weight model, MobileNet, to extract image features, then calculates their pairwise cosine distances as a matrix. The distance between two features lies in `[-1, 1]`, where `-1` means the features are completely dissimilar and `1` means they are identical. With a proper threshold in `[-1, 1]`, the most similar images can be matched.
8 | 
9 | ## Usage
10 | 
11 | The code is written to match similar images among a large number of candidates as efficiently as possible.
12 | 
13 | To use it, two `.csv` source files should be prepared before running. By default, each `.csv` file should include at least one field that holds the image URLs [[1]](#notice). Here is an example of one source file.
14 | 
15 | ```text
16 | id,url
17 | 1,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/1.jpg
18 | 2,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/2.jpg
19 | 3,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/3.jpg
20 | 4,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/4.jpg
21 | 5,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/5.jpg
22 | 6,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/6.jpg
23 | ```
24 | 
25 | After that, we can set up the number of processes used to request images from the URLs in parallel. For example, we use 2 processes for this tiny demo.
26 | 
27 | ```python
28 | similarity.num_processes = 2
29 | ```
30 | 
31 | For feature extraction, a data generator is used to predict images with the model batch by batch. By default, the GPU will be used if it satisfies the requirements of [TensorFlow](https://www.tensorflow.org/install/gpu). We can then set a proper batch size based on the memory of our computer or server. In this demo, we set it to 16.
32 | 
33 | ```python
34 | similarity.batch_size = 16
35 | ```
36 | 
37 | After invoking `save_data()` twice, four generated files will be saved into the `__generated__` directory with the file names `_*_feature.h5` and `_*_fields.csv`. We can then calculate the similarities by calling `iteration()`, or load the generated files at any time later.
38 | 
39 | Altogether, the full example looks like this:
40 | 
41 | ```python
42 | similarity = ImageSimilarity()
43 | 
44 | '''Setup'''
45 | similarity.batch_size = 16
46 | similarity.num_processes = 2
47 | 
48 | '''Load source data'''
49 | test1 = similarity.load_data_csv('./demo/test1.csv', delimiter=',')
50 | test2 = similarity.load_data_csv('./demo/test2.csv', delimiter=',', cols=['id', 'url'])
51 | 
52 | '''Save features and fields'''
53 | similarity.save_data('test1', test1)
54 | similarity.save_data('test2', test2)
55 | 
56 | '''Calculate similarities'''
57 | result = similarity.iteration(['test1_id', 'test1_url', 'test2_id', 'test2_url'], thresh=0.845)
58 | print('Row for source file 1, and column for source file 2.')
59 | print(result)
60 | ```
61 | 
62 | Or, if the generated files already exist:
63 | 
64 | ```python
65 | similarity = ImageSimilarity()
66 | similarity.iteration(['test1_id', 'test1_url', 'test2_id', 'test2_url'], thresh=0.845, title1='test1', title2='test2')
67 | ```
68 | 
69 | For practical usage, the `thresh` argument of `iteration()` is recommended to be in `[0.84, 1)`. A balanced value is `0.845`.
70 | 
71 | For any other details, please check the usage of each function in `main_multi.py`.
72 | 
73 | ## Requirements and Installation
74 | 
75 | **NOTE**: TensorFlow is not included in `requirements.txt` due to platform differences; please install and configure it yourself based on your computer or server. Also note that `Python 3` is required.
76 | 
77 | ```shell
78 | $ git clone https://github.com/ryanfwy/image-similarity.git
79 | $ cd image-similarity
80 | $ pip3 install -r requirements.txt
81 | ```
82 | 
83 | The requirements are also listed below.
84 | 
85 | - tensorflow: the newest version for CPU, or the version that matches your GPU and CUDA.
86 | - h5py~=2.6.0
87 | - numpy~=1.14.5
88 | - requests~=2.21.0
89 | 
90 | ## Experiment
91 | 
92 | In the demo, 6 images (source 1) are matched against 3 images (source 2).
93 | 
94 | ### Accuracy
95 | 
96 | The cosine distances are shown in the table below; matched pairs are shown in bold.
97 | 
98 | | Source 1 \ Source 2 | 3.jpg | 4.jpg | 5.jpg |
99 | | --- | :---: | :---: | :---: |
100 | | 1.jpg | **0.9229318** | 0.5577963 | 0.5826051 |
101 | | 2.jpg | **0.84877944** | 0.538753 | 0.5624183 |
102 | | 3.jpg | **1.** | 0.5512465 | 0.57025677 |
103 | | 4.jpg | 0.5512465 | **0.99999994** | 0.54037786 |
104 | | 5.jpg | 0.57025677 | 0.54037786 | **0.9999998** |
105 | | 6.jpg | 0.5575757 | 0.5238174 | **0.91234696** |
106 | 
107 | As shown, image similarity using a deep neural network works well: the distances of matched images are roughly greater than `0.84`.
108 | 
109 | ### Efficiency
110 | 
111 | For running efficiency, multiprocessing and batch-wise prediction are used in the feature extraction procedure, so image requesting and preprocessing (on the CPU) and model prediction (on the GPU) run simultaneously. In the similarity analysis procedure, a matrix-wise calculation is used instead of iterating over all n*m pairs one by one, which avoids the poor efficiency of Python loops, especially on a large amount of data. A minimal sketch of this calculation is given in the appendix at the end of this README.
112 | 
113 | The table below shows the time consumption when running with 8 processes in a practical case. The results are only for reference; they may vary a lot depending on the number of processes, the quality of the network connection, the size of the online images, and so on.
114 | 
115 | | | Source 1 | Source 2 | Iteration |
116 | | :---: | :---: | :---: | :---: |
117 | | Amount | 13501 | 21221 | 13501 * 21221 |
118 | | Time Consumption | 0:35:53 | 0:17:50 | 0:00:03.913282 |
119 | 
120 | ## Notice
121 | 
122 | [1] By default, the program fetches the online images from the URLs prepared in the `.csv` file. If we want to run the code with a list of offline images, we need to override the `_sub_process()` class method ourselves. For a demo and details, please check [demo_override](./demo_override).
123 | 
124 | 
125 | ## Thanks
126 | 
127 | Demo images come from [ImageSimilarity](https://github.com/nivance/image-similarity) by [nivance](https://github.com/nivance), another image similarity implementation in Java based on the pHash algorithm.
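128 | 
129 | ## Appendix: Matrix-Wise Cosine Distance
130 | 
131 | For reference, the snippet below is a minimal, self-contained sketch of the matrix-wise similarity calculation described in the Efficiency section. It closely follows `DeepModel.cosine_distance()` in `model_util.py` and the matching step of `iteration()` in `main_multi.py`, but runs on made-up toy features instead of real MobileNet outputs.
132 | 
133 | ```python
134 | import numpy as np
135 | 
136 | def cosine_distance(features1, features2):
137 |     '''Pairwise cosine similarity between (n, d) and (m, d) feature matrices, shape (n, m).'''
138 |     norms1 = np.linalg.norm(features1, axis=1, keepdims=True)  # (n, 1)
139 |     norms2 = np.linalg.norm(features2, axis=1, keepdims=True)  # (m, 1)
140 |     return np.dot(features1, features2.T) / np.dot(norms1, norms2.T)
141 | 
142 | # Toy features with a made-up dimension, standing in for the real MobileNet features.
143 | features1 = np.random.rand(6, 8)
144 | features2 = np.random.rand(3, 8)
145 | 
146 | distances = cosine_distance(features1, features2)  # shape (6, 3), values in [-1, 1]
147 | best = np.argmax(distances, axis=1)                # best source-2 match for every source-1 row
148 | 
149 | thresh = 0.845
150 | for i, j in enumerate(best):
151 |     if distances[i][j] >= thresh:
152 |         print('source1[%d] matches source2[%d], similarity %.5f' % (i, j, distances[i][j]))
153 | ```
154 | 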
--------------------------------------------------------------------------------
/demo/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ryanfwy/image-similarity/fcf4856c1ea420e87bfa09f64e530d2d4d5a83d3/demo/1.jpg
--------------------------------------------------------------------------------
/demo/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ryanfwy/image-similarity/fcf4856c1ea420e87bfa09f64e530d2d4d5a83d3/demo/2.jpg
--------------------------------------------------------------------------------
/demo/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ryanfwy/image-similarity/fcf4856c1ea420e87bfa09f64e530d2d4d5a83d3/demo/3.jpg
--------------------------------------------------------------------------------
/demo/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ryanfwy/image-similarity/fcf4856c1ea420e87bfa09f64e530d2d4d5a83d3/demo/4.jpg
--------------------------------------------------------------------------------
/demo/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ryanfwy/image-similarity/fcf4856c1ea420e87bfa09f64e530d2d4d5a83d3/demo/5.jpg
--------------------------------------------------------------------------------
/demo/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ryanfwy/image-similarity/fcf4856c1ea420e87bfa09f64e530d2d4d5a83d3/demo/6.jpg
--------------------------------------------------------------------------------
/demo/test1.csv:
--------------------------------------------------------------------------------
1 | id,url
2 | 1,https://raw.githubusercontent.com/ryanfwy/image_similarity/master/demo/1.jpg
3 | 2,https://raw.githubusercontent.com/ryanfwy/image_similarity/master/demo/2.jpg
4 | 3,https://raw.githubusercontent.com/ryanfwy/image_similarity/master/demo/3.jpg
5 | 4,https://raw.githubusercontent.com/ryanfwy/image_similarity/master/demo/4.jpg
6 | 5,https://raw.githubusercontent.com/ryanfwy/image_similarity/master/demo/5.jpg
7 | 6,https://raw.githubusercontent.com/ryanfwy/image_similarity/master/demo/6.jpg
--------------------------------------------------------------------------------
/demo/test2.csv:
--------------------------------------------------------------------------------
1 | id,url
2 | 3,https://raw.githubusercontent.com/ryanfwy/image_similarity/master/demo/3.jpg
3 | 4,https://raw.githubusercontent.com/ryanfwy/image_similarity/master/demo/4.jpg
4 | 5,https://raw.githubusercontent.com/ryanfwy/image_similarity/master/demo/5.jpg
--------------------------------------------------------------------------------
/demo_override/README.md:
--------------------------------------------------------------------------------
1 | # Implement Your Own `_sub_process()`
2 | 
3 | By default, the `.csv` source file should include at least one field that holds the **urls**; in other words, the program has to fetch the online images from those URLs. However, if we want to run the code with a list of offline images, we need to override the `_sub_process()` class method ourselves.
4 | 
5 | ## Implement the Subclass
6 | 
7 | The implementation should look like this:
8 | 
9 | ```python
10 | class NewImageSimilarity(ImageSimilarity):
11 |     @staticmethod
12 |     def _sub_process(para):
13 |         # Override the method from the base class
14 |         path, fields = para['path'], para['fields']
15 |         try:
16 |             feature = DeepModel.preprocess_image(path)
17 |             return feature, fields
18 | 
19 |         except Exception as e:
20 |             print('Error file %s: %s' % (fields[0], e))
21 | 
22 |         return None, None
23 | ```
24 | 
25 | As shown, `_sub_process()` simply removes the `requests.get(path)` download step and passes the `path` argument directly to `DeepModel.preprocess_image()`.
26 | 
27 | Here, the `.csv` source file should include at least one field, such as `path`, that holds the local image paths. For example, it can be prepared like this:
28 | 
29 | ```
30 | id,path
31 | 3,../demo/3.jpg
32 | 4,../demo/4.jpg
33 | 5,../demo/5.jpg
34 | ```
35 | 
36 | The full example is also given in [main_override.py](./main_override.py). Please read it for more details about how to implement your own `_sub_process()` and run it.
37 | 
38 | ## Quick Preparation
39 | 
40 | To collect a batch of local image paths into a `.csv` source file, the [image_util_cli.py](../image_util_cli.py) helper script can do the job.
41 | 
42 | To run this script, first put a batch of images into a directory, such as `source1`. The directory tree will look like this:
43 | 
44 | ```
45 | ./source1
46 | |- image1.jpg
47 | |- image2.jpg
48 | |- ...
49 | |_ image100.jpg
50 | ```
51 | 
52 | After that, open `Terminal.app` (macOS), `cd` to the directory containing `image_util_cli.py`, and run it with the required arguments.
53 | 
54 | ```
55 | $ cd image-similarity
56 | $ python3 image_util_cli.py ./source1 -d '\t' -o ./images.csv
57 | ```
58 | 
59 | The usage of `image_util_cli.py` is given below. We can also check it at any time by passing the `-h` argument.
60 | 61 | ``` 62 | usage: image_util_cli [-h] [-d DELIMITER] [-o OUT_PATH] source 63 | positional arguments: 64 | source directory of the source images 65 | 66 | optional arguments: 67 | -h, --help show this help message and exit 68 | -d DELIMITER, --delimiter DELIMITER 69 | delimiter to the output file, default: ',' 70 | -o OUT_PATH, --out-path OUT_PATH 71 | path to the output file, default: name of the source directory 72 | ``` 73 | -------------------------------------------------------------------------------- /demo_override/main_override.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | 4 | from main_multi import ImageSimilarity, DeepModel 5 | 6 | class NewImageSimilarity(ImageSimilarity): 7 | @staticmethod 8 | def _sub_process(para): 9 | # Override the method from the base class 10 | path, fields = para['path'], para['fields'] 11 | try: 12 | feature = DeepModel.preprocess_image(path) 13 | return feature, fields 14 | 15 | except Exception as e: 16 | print('Error file %s: %s' % (fields[0], e)) 17 | 18 | return None, None 19 | 20 | 21 | if __name__ == "__main__": 22 | similarity = NewImageSimilarity() 23 | 24 | '''Setup''' 25 | similarity.batch_size = 16 26 | similarity.num_processes = 2 27 | 28 | '''Load source data''' 29 | test1 = similarity.load_data_csv('./test1.csv', delimiter=',') 30 | test2 = similarity.load_data_csv('./test2.csv', delimiter=',', cols=['id', 'path']) 31 | 32 | '''Save features and fields''' 33 | similarity.save_data('test1', test1) 34 | similarity.save_data('test2', test2) 35 | -------------------------------------------------------------------------------- /demo_override/test1.csv: -------------------------------------------------------------------------------- 1 | id,path 2 | 1,../demo/1.jpg 3 | 2,../demo/2.jpg 4 | 3,../demo/3.jpg 5 | 4,../demo/4.jpg 6 | 5,../demo/5.jpg 7 | 6,../demo/6.jpg -------------------------------------------------------------------------------- /demo_override/test2.csv: -------------------------------------------------------------------------------- 1 | id,path 2 | 3,../demo/3.jpg 3 | 4,../demo/4.jpg 4 | 5,../demo/5.jpg -------------------------------------------------------------------------------- /image_util_cli.py: -------------------------------------------------------------------------------- 1 | '''CLI utility for image preparation.''' 2 | 3 | import os 4 | import argparse 5 | import numpy as np 6 | 7 | 8 | def process(input_dir, delimiter=',', output_path=None): 9 | '''Generate a `.csv` file with image paths.''' 10 | result = [['name', 'path']] 11 | file_names = os.listdir(input_dir) 12 | file_names.sort() 13 | for file_name in file_names: 14 | file_path = os.path.join(input_dir, file_name) 15 | result.append([os.path.splitext(file_name)[0], os.path.abspath(file_path)]) 16 | 17 | if output_path is None: 18 | parent_dir = list(filter(lambda x: not x == '', input_dir.split('/')))[-1] 19 | output_path = parent_dir + '.csv' 20 | 21 | np.savetxt(output_path, result, delimiter=delimiter, fmt='%s', encoding='utf-8') 22 | 23 | print('File saved to `%s`.' 
% output_path) 24 | 25 | def main(): 26 | '''CLI entrance.''' 27 | parser = argparse.ArgumentParser(prog='image_util_cli') 28 | parser.add_argument('source', action='store', type=str, help='directory of the source images') 29 | parser.add_argument('-d', '--delimiter', required=False, type=str, default=',', help="delimiter to the output file, default: ','") 30 | parser.add_argument('-o', '--out-path', required=False, type=str, help='path to the output file, default: name of the source directory') 31 | 32 | args = parser.parse_args() 33 | if args.source: 34 | if os.path.isdir(args.source) is False: 35 | exit('No directory `%s`.' % args.source) 36 | 37 | process(args.source, delimiter=args.delimiter, output_path=args.out_path) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /main_multi.py: -------------------------------------------------------------------------------- 1 | '''Image similarity using deep features. 2 | 3 | Recommendation: the threshold of the `DeepModel.cosine_distance` can be set as the following values. 4 | 0.84 = greater matches amount 5 | 0.845 = balance, default 6 | 0.85 = better accuracy 7 | ''' 8 | 9 | from io import BytesIO 10 | from multiprocessing import Pool 11 | 12 | import os 13 | import datetime 14 | import numpy as np 15 | import requests 16 | import h5py 17 | 18 | from model_util import DeepModel, DataSequence 19 | 20 | 21 | class ImageSimilarity(): 22 | '''Image similarity.''' 23 | def __init__(self): 24 | self._tmp_dir = './__generated__' 25 | self._batch_size = 64 26 | self._num_processes = 4 27 | self._model = None 28 | self._title = [] 29 | 30 | @property 31 | def batch_size(self): 32 | '''Batch size of model prediction.''' 33 | return self._batch_size 34 | 35 | @property 36 | def num_processes(self): 37 | '''Number of processes using `Multiprocessing.Pool`.''' 38 | return self._num_processes 39 | 40 | @batch_size.setter 41 | def batch_size(self, batch_size): 42 | self._batch_size = batch_size 43 | 44 | @num_processes.setter 45 | def num_processes(self, num_processes): 46 | self._num_processes = num_processes 47 | 48 | def _data_generation(self, args): 49 | '''Generate input batches for predict generator. 50 | 51 | Args: 52 | args: parameters that pass to `sub_process`. 53 | - path: path of the image, online url by default. 54 | - fields: all other fields. 55 | 56 | Returns: 57 | batch_x: a batch of predict samples. 58 | batch_fields: a batch of fields that matches the samples. 59 | ''' 60 | # Multiprocessing 61 | pool = Pool(self._num_processes) 62 | res = pool.map(self._sub_process, args) 63 | pool.close() 64 | pool.join() 65 | 66 | batch_x, batch_fields = [], [] 67 | for x, fields in res: 68 | if x is not None: 69 | batch_x.append(x) 70 | batch_fields.append(fields) 71 | 72 | return batch_x, batch_fields 73 | 74 | def _predict_generator(self, paras): 75 | '''Build a predict generator. 76 | 77 | Args: 78 | paras: input parameters of all samples. 79 | - path: path of the image, online url by default. 80 | - fields: all other fields. 81 | 82 | Returns: 83 | The predict generator. 84 | ''' 85 | return DataSequence(paras, self._data_generation, batch_size=self._batch_size) 86 | 87 | @staticmethod 88 | def _sub_process(para): 89 | '''A sub-process function of `multiprocessing`. 90 | 91 | Download image from url and process it into a numpy array. 92 | 93 | Args: 94 | para: input parameters of one image. 95 | - path: path of the image, online url by default. 
96 | - fields: all other fields. 97 | 98 | Returns: 99 | feature: feature array of one image. 100 | fields: all other fields of one image that passed from `para`. 101 | 102 | Note: If error happens, `None` will be returned. 103 | ''' 104 | path, fields = para['path'], para['fields'] 105 | try: 106 | headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/67.0.3396.99 Safari/537.36'} 107 | res = requests.get(path, headers=headers) 108 | feature = DeepModel.preprocess_image(BytesIO(res.content)) 109 | return feature, fields 110 | 111 | except Exception as e: 112 | print('Error downloading %s: %s' % (fields[0], e)) 113 | 114 | return None, None 115 | 116 | @staticmethod 117 | def load_data_csv(fname, delimiter=None, include_header=True, cols=None): 118 | '''Load `.csv` file. Mostly it should be a file that list all fields to match. 119 | 120 | Args: 121 | fname: name or path to the file. 122 | delimiter: delimiter to split the content. 123 | include_header: whether the source file include header or not. 124 | cols: a list of columns to read. Pass `None` to read all columns. 125 | 126 | Returns: 127 | A list of data. 128 | ''' 129 | assert delimiter is not None, 'Delimiter is required.' 130 | 131 | if include_header: 132 | usecols = None 133 | skip_header = 1 134 | if cols: 135 | with open(fname, 'r', encoding='utf-8') as f: 136 | csv_head = f.readline().strip().split(delimiter) 137 | 138 | usecols = [csv_head.index(col) for col in cols] 139 | 140 | else: 141 | usecols = None 142 | skip_header = 0 143 | 144 | data = np.genfromtxt( 145 | fname, 146 | dtype=str, 147 | comments=None, 148 | delimiter=delimiter, 149 | encoding='utf-8', 150 | invalid_raise=False, 151 | usecols=usecols, 152 | skip_header=skip_header 153 | ) 154 | 155 | return data if len(data.shape) > 1 else data.reshape(1, -1) 156 | 157 | @staticmethod 158 | def load_data_h5(fname): 159 | '''Load `.h5` file. Mostly it should be a file with features that extracted from the model. 160 | 161 | Args: 162 | fname: name or path to the file. 163 | 164 | Returns: 165 | A list of data. 166 | ''' 167 | with h5py.File(fname, 'r') as h: 168 | data = np.array(h['data']) 169 | return data 170 | 171 | 172 | 173 | def save_data(self, title, lines): 174 | '''Load images from `url`, extract features and fields, save as `.h5` and `.csv` files. 175 | 176 | Args: 177 | title: title to save the results. 178 | lines: lines of the source data. `url` should be placed at the end of all the fields. 179 | 180 | Returns: 181 | None. `.h5` and `.csv` files will be saved instead. 182 | ''' 183 | # Load model 184 | if self._model is None: 185 | self._model = DeepModel() 186 | 187 | print('%s: download starts.' % title) 188 | start = datetime.datetime.now() 189 | 190 | args = [{'path': line[-1], 'fields': line} for line in lines] 191 | 192 | # Prediction 193 | generator = self._predict_generator(args) 194 | features = self._model.extract_feature(generator) 195 | 196 | # Save files 197 | if len(self._title) == 2: 198 | self._title = [] 199 | self._title.append(title) 200 | 201 | if not os.path.isdir(self._tmp_dir): 202 | os.mkdir(self._tmp_dir) 203 | 204 | fname_feature = os.path.join(self._tmp_dir, '_' + title + '_feature.h5') 205 | with h5py.File(fname_feature, mode='w') as h: 206 | h.create_dataset('data', data=features) 207 | print('%s: feature saved to `%s`.' 
% (title, fname_feature)) 208 | 209 | fname_fields = os.path.join(self._tmp_dir, '_' + title + '_fields.csv') 210 | np.savetxt(fname_fields, generator.list_of_label_fields, delimiter='\t', fmt='%s', encoding='utf-8') 211 | print('%s: fields saved to `%s`.' % (title, fname_fields)) 212 | 213 | print('%s: download succeeded.' % title) 214 | print('Amount:', len(generator.list_of_label_fields)) 215 | print('Time consumed:', datetime.datetime.now()-start) 216 | print() 217 | 218 | def iteration(self, save_header, thresh=0.845, title1=None, title2=None): 219 | '''Calculate the cosine distance of two inputs, save the matched fields to `.csv` file. 220 | 221 | Args: 222 | save_header: header of the result `.csv` file. 223 | thresh: threshold of the similarity. 224 | title1, title2: Optional. If `save_data()` is not invoked, titles of two inputs should be passed. 225 | 226 | Returns: 227 | A matrix of element-wise cosine distance. 228 | 229 | Note: 230 | 1. The threshold can be set as the following values. 231 | 0.84 = greater matches amount 232 | 0.845 = balance, default 233 | 0.85 = better accuracy 234 | 235 | 2. If the generated files are exist, set `title1` or `title2` as same as the title of their source files. 236 | For example, pass `test.csv` to `save_data()` will generate `_test_feature.h5` and `_test_fields.csv` files, 237 | so set `title1` or `title2` to `test`, and `save_data()` will not be required to invoke. 238 | ''' 239 | if title1 and title2: 240 | self._title = [title1, title2] 241 | 242 | assert len(self._title) == 2, 'Two inputs are required.' 243 | 244 | feature1 = self.load_data_h5(os.path.join(self._tmp_dir, '_' + self._title[0] + '_feature.h5')) 245 | feature2 = self.load_data_h5(os.path.join(self._tmp_dir, '_' + self._title[1] + '_feature.h5')) 246 | 247 | fields1 = self.load_data_csv(os.path.join(self._tmp_dir, '_' + self._title[0] + '_fields.csv'), delimiter='\t', include_header=False) 248 | fields2 = self.load_data_csv(os.path.join(self._tmp_dir, '_' + self._title[1] + '_fields.csv'), delimiter='\t', include_header=False) 249 | 250 | print('%s: feature loaded, shape' % self._title[0], feature1.shape) 251 | print('%s: fields loaded, length' % self._title[0], len(fields1)) 252 | 253 | print('%s: feature loaded, shape' % self._title[1], feature2.shape) 254 | print('%s: fields loaded, length' % self._title[1], len(fields2)) 255 | 256 | print('Iteration starts.') 257 | start = datetime.datetime.now() 258 | 259 | distances = DeepModel.cosine_distance(feature1, feature2) 260 | indexes = np.argmax(distances, axis=1) 261 | 262 | result = [save_header + ['similarity']] 263 | 264 | for x, y in enumerate(indexes): 265 | dis = distances[x][y] 266 | if dis >= thresh: 267 | result.append(np.concatenate((fields1[x], fields2[y], np.array(['%.5f' % dis])), axis=0)) 268 | 269 | if len(result) > 0: 270 | np.savetxt('result_similarity.csv', result, fmt='%s', delimiter='\t', encoding='utf-8') 271 | 272 | print('Iteration finished: results saved to `result_similarity.csv`.') 273 | print('Amount: %d (%d * %d)' % (len(fields1)*len(fields2), len(fields1), len(fields2))) 274 | print('Time consumed:', datetime.datetime.now()-start) 275 | print() 276 | 277 | return distances 278 | 279 | 280 | if __name__ == '__main__': 281 | similarity = ImageSimilarity() 282 | 283 | '''Setup''' 284 | similarity.batch_size = 16 285 | similarity.num_processes = 2 286 | 287 | '''Load source data''' 288 | test1 = similarity.load_data_csv('./demo/test1.csv', delimiter=',') 289 | test2 = 
similarity.load_data_csv('./demo/test2.csv', delimiter=',', cols=['id', 'url']) 290 | 291 | '''Save features and fields''' 292 | similarity.save_data('test1', test1) 293 | similarity.save_data('test2', test2) 294 | 295 | '''Calculate similarities''' 296 | result = similarity.iteration(['test1_id', 'test1_url', 'test2_id', 'test2_url'], thresh=0.845) 297 | print('Row for source file 1, and column for source file 2.') 298 | print(result) 299 | -------------------------------------------------------------------------------- /model_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 3 | 4 | import numpy as np 5 | 6 | from tensorflow.python.keras.applications.mobilenet import MobileNet, preprocess_input 7 | from tensorflow.python.keras.preprocessing import image as process_image 8 | from tensorflow.python.keras.utils import Sequence 9 | from tensorflow.python.keras.layers import GlobalAveragePooling2D 10 | from tensorflow.python.keras import Model 11 | 12 | 13 | class DeepModel(): 14 | '''MobileNet deep model.''' 15 | def __init__(self): 16 | self._model = self._define_model() 17 | 18 | print('Loading MobileNet.') 19 | print() 20 | 21 | @staticmethod 22 | def _define_model(output_layer=-1): 23 | '''Define a pre-trained MobileNet model. 24 | 25 | Args: 26 | output_layer: the number of layer that output. 27 | 28 | Returns: 29 | Class of keras model with weights. 30 | ''' 31 | base_model = MobileNet(weights='imagenet', include_top=False, input_shape=(224, 224, 3)) 32 | output = base_model.layers[output_layer].output 33 | output = GlobalAveragePooling2D()(output) 34 | model = Model(inputs=base_model.input, outputs=output) 35 | return model 36 | 37 | @staticmethod 38 | def preprocess_image(path): 39 | '''Process an image to numpy array. 40 | 41 | Args: 42 | path: the path of the image. 43 | 44 | Returns: 45 | Numpy array of the image. 46 | ''' 47 | img = process_image.load_img(path, target_size=(224, 224)) 48 | x = process_image.img_to_array(img) 49 | # x = np.expand_dims(x, axis=0) 50 | x = preprocess_input(x) 51 | return x 52 | 53 | @staticmethod 54 | def cosine_distance(input1, input2): 55 | '''Calculating the distance of two inputs. 56 | 57 | The return values lies in [-1, 1]. `-1` denotes two features are the most unlike, 58 | `1` denotes they are the most similar. 59 | 60 | Args: 61 | input1, input2: two input numpy arrays. 62 | 63 | Returns: 64 | Element-wise cosine distances of two inputs. 65 | ''' 66 | # return np.dot(input1, input2) / (np.linalg.norm(input1) * np.linalg.norm(input2)) 67 | return np.dot(input1, input2.T) / \ 68 | np.dot(np.linalg.norm(input1, axis=1, keepdims=True), \ 69 | np.linalg.norm(input2.T, axis=0, keepdims=True)) 70 | 71 | def extract_feature(self, generator): 72 | '''Extract deep feature using MobileNet model. 73 | 74 | Args: 75 | generator: a predict generator inherit from `keras.utils.Sequence`. 76 | 77 | Returns: 78 | The output features of all inputs. 
79 | ''' 80 | features = self._model.predict_generator(generator) 81 | return features 82 | 83 | 84 | class DataSequence(Sequence): 85 | '''Predict generator inherit from `keras.utils.Sequence`.''' 86 | def __init__(self, paras, generation, batch_size=32): 87 | self.list_of_label_fields = [] 88 | self.list_of_paras = paras 89 | self.data_generation = generation 90 | self.batch_size = batch_size 91 | self.__idx = 0 92 | 93 | def __len__(self): 94 | '''The number of batches per epoch.''' 95 | return int(np.ceil(len(self.list_of_paras) / self.batch_size)) 96 | 97 | def __getitem__(self, idx): 98 | '''Generate one batch of data.''' 99 | paras = self.list_of_paras[idx * self.batch_size : (idx+1) * self.batch_size] 100 | batch_x, batch_fields = self.data_generation(paras) 101 | 102 | if idx == self.__idx: 103 | self.list_of_label_fields += batch_fields 104 | self.__idx += 1 105 | 106 | return np.array(batch_x) 107 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py~=2.6.0 2 | numpy~=1.14.5 3 | requests~=2.21.0 4 | --------------------------------------------------------------------------------