├── .gitignore
├── LICENSE
├── README.md
├── demo
│   ├── 1.jpg
│   ├── 2.jpg
│   ├── 3.jpg
│   ├── 4.jpg
│   ├── 5.jpg
│   ├── 6.jpg
│   ├── test1.csv
│   └── test2.csv
├── demo_override
│   ├── README.md
│   ├── main_override.py
│   ├── test1.csv
│   └── test2.csv
├── image_util_cli.py
├── main_multi.py
├── model_util.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .vscode
3 | __pycache__
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Ryan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image Similarity
2 |
3 | This is an efficient utility for image similarity matching, built on the [MobileNet](https://arxiv.org/abs/1704.04861) deep neural network.
4 |
5 | Image similarity is largely a problem of choosing good image features. Here, a Convolutional Neural Network (CNN) is used to extract features from the images, which gives the computer a much more effective representation to compare.
6 |
7 | This repository uses a lightweight model, MobileNet, to extract image features, then calculates their cosine distances as matrices. The distance between two features lies in `[-1, 1]`, where `-1` means the features are the most dissimilar and `1` means they are the most similar. With a proper threshold in `[-1, 1]`, the most similar images can be matched.
8 |
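As a rough sketch (not part of this repository's API), the comparison of two feature vectors boils down to a cosine similarity plus a threshold check:

```python
import numpy as np

def cosine_similarity(feat1, feat2):
    '''Cosine similarity of two 1-D feature vectors, in [-1, 1].'''
    return np.dot(feat1, feat2) / (np.linalg.norm(feat1) * np.linalg.norm(feat2))

# Illustrative only: random vectors standing in for the extracted MobileNet features.
feat1, feat2 = np.random.rand(1024), np.random.rand(1024)
print(cosine_similarity(feat1, feat2) >= 0.845)  # a match if above the chosen threshold
```
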
9 | ## Usage
10 |
11 | The code is written to match similar images among a huge number of candidates as efficiently as possible.
12 |
13 | To use it, prepare two `.csv` source files before running. Here is an example of one source file. By default, each `.csv` file should include at least one field that holds the URLs [[1]](#notice).
14 |
15 | ```text
16 | id,url
17 | 1,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/1.jpg
18 | 2,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/2.jpg
19 | 3,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/3.jpg
20 | 4,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/4.jpg
21 | 5,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/5.jpg
22 | 6,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/6.jpg
23 | ```
24 |
25 | After that, we can set up the number of processes used to request images from the URLs in parallel. For example, we use 2 processes for this tiny demo.
26 |
27 | ```python
28 | similarity.num_processes = 2
29 | ```
30 |
31 | For feature extraction, a data generator is used to feed images to the model batch by batch. By default, the GPU will be used if it satisfies the requirements of [Tensorflow](https://www.tensorflow.org/install/gpu). Now we can set a proper batch size based on the memory of our computer or server. In this demo, we set it to 16.
32 |
33 | ```python
34 | similarity.batch_size = 16
35 | ```
36 |
37 | After invoking `save_data()` twice, four generated files named `_*_feature.h5` and `_*_fields.csv` will be saved into the `__generated__` directory. We can then calculate the similarities by calling `iteration()`, or load the generated files at any time afterward.
38 |
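For reference, a saved feature file can be read back directly with `h5py` (a minimal sketch; `_test1_feature.h5` is the file produced by `save_data('test1', ...)`):

```python
import h5py
import numpy as np

with h5py.File('./__generated__/_test1_feature.h5', 'r') as h:
    features = np.array(h['data'])  # one row of deep features per image
print(features.shape)
```
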
39 | Altogether, the full example looks like this:
40 |
41 | ```python
42 | similarity = ImageSimilarity()
43 |
44 | '''Setup'''
45 | similarity.batch_size = 16
46 | similarity.num_processes = 2
47 |
48 | '''Load source data'''
49 | test1 = similarity.load_data_csv('./demo/test1.csv', delimiter=',')
50 | test2 = similarity.load_data_csv('./demo/test2.csv', delimiter=',', cols=['id', 'url'])
51 |
52 | '''Save features and fields'''
53 | similarity.save_data('test1', test1)
54 | similarity.save_data('test2', test2)
55 |
56 | '''Calculate similarities'''
57 | result = similarity.iteration(['test1_id', 'test1_url', 'test2_id', 'test2_url'], thresh=0.845)
58 | print('Row for source file 1, and column for source file 2.')
59 | print(result)
60 | ```
61 |
62 | Or, if the files have been generated before:
63 |
64 | ```python
65 | similarity = ImageSimilarity()
66 | similarity.iteration(['test1_id', 'test1_url', 'test2_id', 'test2_url'], thresh=0.845, title1='test1', title2='test2')
67 | ```
68 |
69 | For practical usage, the `thresh` argument of `iteration()` is recommended to be in `[0.84, 1)`. A balanced value is `0.845`.
70 |
71 | For any other details, please check the usage of each function in `main_multi.py`.
72 |
73 | ## Requirements and Installation
74 |
75 | **NOTE**: Tensorflow is not included in `requirements.txt` due to platform differences; please install and configure it yourself based on your computer or server. Also note that `Python 3` is required.
76 |
77 | ```sh
78 | $ git clone https://github.com/ryanfwy/image-similarity.git
79 | $ cd image-similarity
80 | $ pip3 install -r requirements.txt
81 | ```
82 |
83 | The requirements are also listed below.
84 |
85 | - tensorflow: the newest version for CPU, or the version that matches your GPU and CUDA.
86 | - h5py~=2.6.0
87 | - numpy~=1.14.5
88 | - requests~=2.21.0
89 |
90 | ## Experiment
91 |
92 | In the demo, the 6 images from the first source file are matched against the 3 images from the second.
93 |
94 | ### Accuracy
95 |
96 | The cosine distances are shown in the table below, with the 6 images of `test1.csv` as rows and the 3 images of `test2.csv` as columns.
97 |
98 | | | `3.jpg` | `4.jpg` | `5.jpg` |
99 | | --- | :---: | :---: | :---: |
100 | | `1.jpg` | **0.9229318** | 0.5577963 | 0.5826051 |
101 | | `2.jpg` | **0.84877944** | 0.538753 | 0.5624183 |
102 | | `3.jpg` | **1.** | 0.5512465 | 0.57025677 |
103 | | `4.jpg` | 0.5512465 | **0.99999994** | 0.54037786 |
104 | | `5.jpg` | 0.57025677 | 0.54037786 | **0.9999998** |
105 | | `6.jpg` | 0.5575757 | 0.5238174 | **0.91234696** |
106 |
107 | As shown, image similarity using a deep neural network works well: the distances of the matched images are roughly greater than `0.84`.
108 |
109 | ### Efficiency
110 |
111 | For running efficiency, multiprocessing and batch-wise prediction are used in the feature extraction procedure, so image downloading and preprocessing on the CPU run concurrently with model prediction on the GPU. In the similarity analysis procedure, a matrix-wise computation is used instead of iterating over all n*m pairs one by one, which helps a lot given how slow Python-level iteration is, especially for large amounts of data.
112 |
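The matrix-wise step is roughly equivalent to the following sketch (illustrative only; `features1` of shape `(n, d)` and `features2` of shape `(m, d)` stand in for the saved feature matrices):

```python
import numpy as np

def match(features1, features2, thresh=0.845):
    # One vectorized n x m cosine-distance matrix instead of a Python loop over every pair.
    norm1 = np.linalg.norm(features1, axis=1, keepdims=True)
    norm2 = np.linalg.norm(features2, axis=1, keepdims=True)
    distances = features1 @ features2.T / (norm1 * norm2.T)

    # For each image of source 1, keep its best match in source 2 if it passes the threshold.
    best = np.argmax(distances, axis=1)
    return [(i, j, distances[i, j]) for i, j in enumerate(best) if distances[i, j] >= thresh]

# Illustrative only: random features in place of the extracted ones.
matches = match(np.random.rand(6, 1024), np.random.rand(3, 1024))
```
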
113 | The table below shows the time consumption when running with 8 processes in a practical case. The results are for reference only; they may vary a lot depending on the number of processes, the quality of the network, the size of the online images and so on.
114 |
115 | | | Source 1 | Source 2 | Iteration |
116 | | :---: | :---: | :---: | :---: |
117 | | Amount | 13501 | 21221 | 13501 * 21221 |
118 | | Time Consumption | 0:35:53 | 0:17:50 | 0:00:03.913282 |
119 |
120 | ## Notice
121 |
122 | [1] By default, the program fetches the online images from the URLs prepared in the `.csv` file. If we want to run the code with a list of offline images, we need to override the `_sub_process()` class method ourselves. For a demo and details, please check [demo_override](./demo_override).
123 |
124 |
125 | ## Thanks
126 |
127 | Demo images come from [ImageSimilarity](https://github.com/nivance/image-similarity) by [nivance](https://github.com/nivance), another image similarity implementation (pHash) written in Java.
128 |
--------------------------------------------------------------------------------
/demo/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ryanfwy/image-similarity/fcf4856c1ea420e87bfa09f64e530d2d4d5a83d3/demo/1.jpg
--------------------------------------------------------------------------------
/demo/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ryanfwy/image-similarity/fcf4856c1ea420e87bfa09f64e530d2d4d5a83d3/demo/2.jpg
--------------------------------------------------------------------------------
/demo/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ryanfwy/image-similarity/fcf4856c1ea420e87bfa09f64e530d2d4d5a83d3/demo/3.jpg
--------------------------------------------------------------------------------
/demo/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ryanfwy/image-similarity/fcf4856c1ea420e87bfa09f64e530d2d4d5a83d3/demo/4.jpg
--------------------------------------------------------------------------------
/demo/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ryanfwy/image-similarity/fcf4856c1ea420e87bfa09f64e530d2d4d5a83d3/demo/5.jpg
--------------------------------------------------------------------------------
/demo/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ryanfwy/image-similarity/fcf4856c1ea420e87bfa09f64e530d2d4d5a83d3/demo/6.jpg
--------------------------------------------------------------------------------
/demo/test1.csv:
--------------------------------------------------------------------------------
1 | id,url
2 | 1,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/1.jpg
3 | 2,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/2.jpg
4 | 3,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/3.jpg
5 | 4,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/4.jpg
6 | 5,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/5.jpg
7 | 6,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/6.jpg
--------------------------------------------------------------------------------
/demo/test2.csv:
--------------------------------------------------------------------------------
1 | id,url
2 | 3,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/3.jpg
3 | 4,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/4.jpg
4 | 5,https://raw.githubusercontent.com/ryanfwy/image-similarity/master/demo/5.jpg
--------------------------------------------------------------------------------
/demo_override/README.md:
--------------------------------------------------------------------------------
1 | # Implement Your Own `_sub_process()`
2 |
3 | By default, the `.csv` source file should include at least one field that holds the **URLs**; in other words, the program fetches the images online from those URLs. However, if we want to run the code with a list of offline images, we need to override the `_sub_process()` class method ourselves.
4 |
5 | ## Implement the Subclass
6 |
7 | The implementation should look like:
8 |
9 | ```python
10 | class NewImageSimilarity(ImageSimilarity):
11 | @staticmethod
12 | def _sub_process(para):
13 | # Override the method from the base class
14 | path, fields = para['path'], para['fields']
15 | try:
16 | feature = DeepModel.preprocess_image(path)
17 | return feature, fields
18 |
19 | except Exception as e:
20 | print('Error file %s: %s' % (fields[0], e))
21 |
22 | return None, None
23 | ```
24 |
25 | As shown, the overridden `_sub_process()` simply removes the `requests.get(path)` call and passes the `path` argument to `DeepModel.preprocess_image()` directly.
26 |
27 | Here, the `.csv` source file should include at least one field, such as `path`, that holds the local image paths. For example, it can be prepared like this.
28 |
29 | ```
30 | id,path
31 | 3,../demo/3.jpg
32 | 4,../demo/4.jpg
33 | 5,../demo/5.jpg
34 | ```
35 |
36 | The full example is also given in [main_override.py](./main_override.py). Please read it for more details about how to implement and run your own `_sub_process()`.
37 |
38 | ## Quick Preparation
39 |
40 | If we want to collect a batch of offline image paths from a local directory into a `.csv` source file, the [image_util_cli.py](../image_util_cli.py) script can easily do this job.
41 |
42 | To run this script, first put a batch of images into a directory, such as `source1`. The directory tree will look like this.
43 |
44 | ```
45 | ./source1
46 | |- image1.jpg
47 | |- image2.jpg
48 | |- ...
49 | |_ image100.jpg
50 | ```
51 |
52 | After that, open a terminal (e.g. `Terminal.app` on macOS), `cd` to the directory containing `image_util_cli.py`, and run it with the required arguments.
53 |
54 | ```
55 | $ cd image-similarity
56 | $ python3 image_util_cli.py ./source1 -d '\t' -o ./images.csv
57 | ```
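
The generated `.csv` contains a `name`/`path` header followed by one row per image with its absolute path (tab-separated in the command above because of `-d '\t'`; shown here with the default `,` delimiter, and the paths are only illustrative):

```
name,path
image1,/absolute/path/to/source1/image1.jpg
image2,/absolute/path/to/source1/image2.jpg
...
```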
58 |
59 | The usage of `image_util_cli.py` is given below. We can also check it at any time by passing the `-h` argument.
60 |
61 | ```
62 | usage: image_util_cli [-h] [-d DELIMITER] [-o OUT_PATH] source
63 | positional arguments:
64 | source directory of the source images
65 |
66 | optional arguments:
67 | -h, --help show this help message and exit
68 | -d DELIMITER, --delimiter DELIMITER
69 | delimiter to the output file, default: ','
70 | -o OUT_PATH, --out-path OUT_PATH
71 | path to the output file, default: name of the source directory
72 | ```
73 |
--------------------------------------------------------------------------------
/demo_override/main_override.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 |
4 | from main_multi import ImageSimilarity, DeepModel
5 |
6 | class NewImageSimilarity(ImageSimilarity):
7 | @staticmethod
8 | def _sub_process(para):
9 | # Override the method from the base class
10 | path, fields = para['path'], para['fields']
11 | try:
12 | feature = DeepModel.preprocess_image(path)
13 | return feature, fields
14 |
15 | except Exception as e:
16 | print('Error file %s: %s' % (fields[0], e))
17 |
18 | return None, None
19 |
20 |
21 | if __name__ == "__main__":
22 | similarity = NewImageSimilarity()
23 |
24 | '''Setup'''
25 | similarity.batch_size = 16
26 | similarity.num_processes = 2
27 |
28 | '''Load source data'''
29 | test1 = similarity.load_data_csv('./test1.csv', delimiter=',')
30 | test2 = similarity.load_data_csv('./test2.csv', delimiter=',', cols=['id', 'path'])
31 |
32 | '''Save features and fields'''
33 | similarity.save_data('test1', test1)
34 | similarity.save_data('test2', test2)
35 |
--------------------------------------------------------------------------------
/demo_override/test1.csv:
--------------------------------------------------------------------------------
1 | id,path
2 | 1,../demo/1.jpg
3 | 2,../demo/2.jpg
4 | 3,../demo/3.jpg
5 | 4,../demo/4.jpg
6 | 5,../demo/5.jpg
7 | 6,../demo/6.jpg
--------------------------------------------------------------------------------
/demo_override/test2.csv:
--------------------------------------------------------------------------------
1 | id,path
2 | 3,../demo/3.jpg
3 | 4,../demo/4.jpg
4 | 5,../demo/5.jpg
--------------------------------------------------------------------------------
/image_util_cli.py:
--------------------------------------------------------------------------------
1 | '''CLI utility for image preparation.'''
2 |
3 | import os
4 | import argparse
5 | import numpy as np
6 |
7 |
8 | def process(input_dir, delimiter=',', output_path=None):
9 | '''Generate a `.csv` file with image paths.'''
10 | result = [['name', 'path']]
11 | file_names = os.listdir(input_dir)
12 | file_names.sort()
13 | for file_name in file_names:
14 | file_path = os.path.join(input_dir, file_name)
15 | result.append([os.path.splitext(file_name)[0], os.path.abspath(file_path)])
16 |
17 | if output_path is None:
18 | parent_dir = list(filter(lambda x: not x == '', input_dir.split('/')))[-1]
19 | output_path = parent_dir + '.csv'
20 |
21 | np.savetxt(output_path, result, delimiter=delimiter, fmt='%s', encoding='utf-8')
22 |
23 | print('File saved to `%s`.' % output_path)
24 |
25 | def main():
26 | '''CLI entrance.'''
27 | parser = argparse.ArgumentParser(prog='image_util_cli')
28 | parser.add_argument('source', action='store', type=str, help='directory of the source images')
29 | parser.add_argument('-d', '--delimiter', required=False, type=str, default=',', help="delimiter to the output file, default: ','")
30 | parser.add_argument('-o', '--out-path', required=False, type=str, help='path to the output file, default: name of the source directory')
31 |
32 | args = parser.parse_args()
33 | if args.source:
34 | if os.path.isdir(args.source) is False:
35 | exit('No directory `%s`.' % args.source)
36 |
37 | process(args.source, delimiter=args.delimiter, output_path=args.out_path)
38 |
39 |
40 | if __name__ == '__main__':
41 | main()
42 |
--------------------------------------------------------------------------------
/main_multi.py:
--------------------------------------------------------------------------------
1 | '''Image similarity using deep features.
2 |
3 | Recommendation: the threshold of the `DeepModel.cosine_distance` can be set as the following values.
4 | 0.84 = greater matches amount
5 | 0.845 = balance, default
6 | 0.85 = better accuracy
7 | '''
8 |
9 | from io import BytesIO
10 | from multiprocessing import Pool
11 |
12 | import os
13 | import datetime
14 | import numpy as np
15 | import requests
16 | import h5py
17 |
18 | from model_util import DeepModel, DataSequence
19 |
20 |
21 | class ImageSimilarity():
22 | '''Image similarity.'''
23 | def __init__(self):
24 | self._tmp_dir = './__generated__'
25 | self._batch_size = 64
26 | self._num_processes = 4
27 | self._model = None
28 | self._title = []
29 |
30 | @property
31 | def batch_size(self):
32 | '''Batch size of model prediction.'''
33 | return self._batch_size
34 |
35 | @property
36 | def num_processes(self):
37 | '''Number of processes using `Multiprocessing.Pool`.'''
38 | return self._num_processes
39 |
40 | @batch_size.setter
41 | def batch_size(self, batch_size):
42 | self._batch_size = batch_size
43 |
44 | @num_processes.setter
45 | def num_processes(self, num_processes):
46 | self._num_processes = num_processes
47 |
48 | def _data_generation(self, args):
49 | '''Generate input batches for predict generator.
50 |
51 | Args:
52 | args: parameters that pass to `sub_process`.
53 | - path: path of the image, online url by default.
54 | - fields: all other fields.
55 |
56 | Returns:
57 | batch_x: a batch of predict samples.
58 | batch_fields: a batch of fields that matches the samples.
59 | '''
60 | # Multiprocessing
61 | pool = Pool(self._num_processes)
62 | res = pool.map(self._sub_process, args)
63 | pool.close()
64 | pool.join()
65 |
66 | batch_x, batch_fields = [], []
67 | for x, fields in res:
68 | if x is not None:
69 | batch_x.append(x)
70 | batch_fields.append(fields)
71 |
72 | return batch_x, batch_fields
73 |
74 | def _predict_generator(self, paras):
75 | '''Build a predict generator.
76 |
77 | Args:
78 | paras: input parameters of all samples.
79 | - path: path of the image, online url by default.
80 | - fields: all other fields.
81 |
82 | Returns:
83 | The predict generator.
84 | '''
85 | return DataSequence(paras, self._data_generation, batch_size=self._batch_size)
86 |
87 | @staticmethod
88 | def _sub_process(para):
89 | '''A sub-process function of `multiprocessing`.
90 |
91 | Download image from url and process it into a numpy array.
92 |
93 | Args:
94 | para: input parameters of one image.
95 | - path: path of the image, online url by default.
96 | - fields: all other fields.
97 |
98 | Returns:
99 | feature: feature array of one image.
100 | fields: all other fields of one image that passed from `para`.
101 |
102 | Note: If error happens, `None` will be returned.
103 | '''
104 | path, fields = para['path'], para['fields']
105 | try:
106 | headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/67.0.3396.99 Safari/537.36'}
107 | res = requests.get(path, headers=headers)
108 | feature = DeepModel.preprocess_image(BytesIO(res.content))
109 | return feature, fields
110 |
111 | except Exception as e:
112 | print('Error downloading %s: %s' % (fields[0], e))
113 |
114 | return None, None
115 |
116 | @staticmethod
117 | def load_data_csv(fname, delimiter=None, include_header=True, cols=None):
118 |         '''Load `.csv` file. Typically it is a file that lists all the fields to match.
119 |
120 | Args:
121 | fname: name or path to the file.
122 | delimiter: delimiter to split the content.
123 |             include_header: whether the source file includes a header or not.
124 | cols: a list of columns to read. Pass `None` to read all columns.
125 |
126 | Returns:
127 | A list of data.
128 | '''
129 | assert delimiter is not None, 'Delimiter is required.'
130 |
131 | if include_header:
132 | usecols = None
133 | skip_header = 1
134 | if cols:
135 | with open(fname, 'r', encoding='utf-8') as f:
136 | csv_head = f.readline().strip().split(delimiter)
137 |
138 | usecols = [csv_head.index(col) for col in cols]
139 |
140 | else:
141 | usecols = None
142 | skip_header = 0
143 |
144 | data = np.genfromtxt(
145 | fname,
146 | dtype=str,
147 | comments=None,
148 | delimiter=delimiter,
149 | encoding='utf-8',
150 | invalid_raise=False,
151 | usecols=usecols,
152 | skip_header=skip_header
153 | )
154 |
155 | return data if len(data.shape) > 1 else data.reshape(1, -1)
156 |
157 | @staticmethod
158 | def load_data_h5(fname):
159 |         '''Load `.h5` file. Typically it is a file with features that were extracted by the model.
160 |
161 | Args:
162 | fname: name or path to the file.
163 |
164 | Returns:
165 | A list of data.
166 | '''
167 | with h5py.File(fname, 'r') as h:
168 | data = np.array(h['data'])
169 | return data
170 |
171 |
172 |
173 | def save_data(self, title, lines):
174 | '''Load images from `url`, extract features and fields, save as `.h5` and `.csv` files.
175 |
176 | Args:
177 | title: title to save the results.
178 | lines: lines of the source data. `url` should be placed at the end of all the fields.
179 |
180 | Returns:
181 | None. `.h5` and `.csv` files will be saved instead.
182 | '''
183 | # Load model
184 | if self._model is None:
185 | self._model = DeepModel()
186 |
187 | print('%s: download starts.' % title)
188 | start = datetime.datetime.now()
189 |
190 | args = [{'path': line[-1], 'fields': line} for line in lines]
191 |
192 | # Prediction
193 | generator = self._predict_generator(args)
194 | features = self._model.extract_feature(generator)
195 |
196 | # Save files
197 | if len(self._title) == 2:
198 | self._title = []
199 | self._title.append(title)
200 |
201 | if not os.path.isdir(self._tmp_dir):
202 | os.mkdir(self._tmp_dir)
203 |
204 | fname_feature = os.path.join(self._tmp_dir, '_' + title + '_feature.h5')
205 | with h5py.File(fname_feature, mode='w') as h:
206 | h.create_dataset('data', data=features)
207 | print('%s: feature saved to `%s`.' % (title, fname_feature))
208 |
209 | fname_fields = os.path.join(self._tmp_dir, '_' + title + '_fields.csv')
210 | np.savetxt(fname_fields, generator.list_of_label_fields, delimiter='\t', fmt='%s', encoding='utf-8')
211 | print('%s: fields saved to `%s`.' % (title, fname_fields))
212 |
213 | print('%s: download succeeded.' % title)
214 | print('Amount:', len(generator.list_of_label_fields))
215 | print('Time consumed:', datetime.datetime.now()-start)
216 | print()
217 |
218 | def iteration(self, save_header, thresh=0.845, title1=None, title2=None):
219 | '''Calculate the cosine distance of two inputs, save the matched fields to `.csv` file.
220 |
221 | Args:
222 | save_header: header of the result `.csv` file.
223 | thresh: threshold of the similarity.
224 | title1, title2: Optional. If `save_data()` is not invoked, titles of two inputs should be passed.
225 |
226 | Returns:
227 |             A matrix of pairwise cosine distances.
228 |
229 | Note:
230 | 1. The threshold can be set as the following values.
231 | 0.84 = greater matches amount
232 | 0.845 = balance, default
233 | 0.85 = better accuracy
234 |
235 |             2. If the generated files already exist, set `title1` and `title2` to the titles of their source files.
236 |                For example, calling `save_data('test', ...)` generates the files `_test_feature.h5` and `_test_fields.csv`,
237 |                so set `title1` or `title2` to `test`, and `save_data()` will not need to be invoked again.
238 | '''
239 | if title1 and title2:
240 | self._title = [title1, title2]
241 |
242 | assert len(self._title) == 2, 'Two inputs are required.'
243 |
244 | feature1 = self.load_data_h5(os.path.join(self._tmp_dir, '_' + self._title[0] + '_feature.h5'))
245 | feature2 = self.load_data_h5(os.path.join(self._tmp_dir, '_' + self._title[1] + '_feature.h5'))
246 |
247 | fields1 = self.load_data_csv(os.path.join(self._tmp_dir, '_' + self._title[0] + '_fields.csv'), delimiter='\t', include_header=False)
248 | fields2 = self.load_data_csv(os.path.join(self._tmp_dir, '_' + self._title[1] + '_fields.csv'), delimiter='\t', include_header=False)
249 |
250 | print('%s: feature loaded, shape' % self._title[0], feature1.shape)
251 | print('%s: fields loaded, length' % self._title[0], len(fields1))
252 |
253 | print('%s: feature loaded, shape' % self._title[1], feature2.shape)
254 | print('%s: fields loaded, length' % self._title[1], len(fields2))
255 |
256 | print('Iteration starts.')
257 | start = datetime.datetime.now()
258 |
259 | distances = DeepModel.cosine_distance(feature1, feature2)
260 | indexes = np.argmax(distances, axis=1)
261 |
262 | result = [save_header + ['similarity']]
263 |
264 | for x, y in enumerate(indexes):
265 | dis = distances[x][y]
266 | if dis >= thresh:
267 | result.append(np.concatenate((fields1[x], fields2[y], np.array(['%.5f' % dis])), axis=0))
268 |
269 | if len(result) > 0:
270 | np.savetxt('result_similarity.csv', result, fmt='%s', delimiter='\t', encoding='utf-8')
271 |
272 | print('Iteration finished: results saved to `result_similarity.csv`.')
273 | print('Amount: %d (%d * %d)' % (len(fields1)*len(fields2), len(fields1), len(fields2)))
274 | print('Time consumed:', datetime.datetime.now()-start)
275 | print()
276 |
277 | return distances
278 |
279 |
280 | if __name__ == '__main__':
281 | similarity = ImageSimilarity()
282 |
283 | '''Setup'''
284 | similarity.batch_size = 16
285 | similarity.num_processes = 2
286 |
287 | '''Load source data'''
288 | test1 = similarity.load_data_csv('./demo/test1.csv', delimiter=',')
289 | test2 = similarity.load_data_csv('./demo/test2.csv', delimiter=',', cols=['id', 'url'])
290 |
291 | '''Save features and fields'''
292 | similarity.save_data('test1', test1)
293 | similarity.save_data('test2', test2)
294 |
295 | '''Calculate similarities'''
296 | result = similarity.iteration(['test1_id', 'test1_url', 'test2_id', 'test2_url'], thresh=0.845)
297 | print('Row for source file 1, and column for source file 2.')
298 | print(result)
299 |
--------------------------------------------------------------------------------
/model_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
3 |
4 | import numpy as np
5 |
6 | from tensorflow.python.keras.applications.mobilenet import MobileNet, preprocess_input
7 | from tensorflow.python.keras.preprocessing import image as process_image
8 | from tensorflow.python.keras.utils import Sequence
9 | from tensorflow.python.keras.layers import GlobalAveragePooling2D
10 | from tensorflow.python.keras import Model
11 |
12 |
13 | class DeepModel():
14 | '''MobileNet deep model.'''
15 | def __init__(self):
16 | self._model = self._define_model()
17 |
18 | print('Loading MobileNet.')
19 | print()
20 |
21 | @staticmethod
22 | def _define_model(output_layer=-1):
23 | '''Define a pre-trained MobileNet model.
24 |
25 | Args:
26 |             output_layer: index of the layer whose output is used.
27 |
28 | Returns:
29 | Class of keras model with weights.
30 | '''
31 | base_model = MobileNet(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
32 | output = base_model.layers[output_layer].output
33 | output = GlobalAveragePooling2D()(output)
34 | model = Model(inputs=base_model.input, outputs=output)
35 | return model
36 |
37 | @staticmethod
38 | def preprocess_image(path):
39 | '''Process an image to numpy array.
40 |
41 | Args:
42 | path: the path of the image.
43 |
44 | Returns:
45 | Numpy array of the image.
46 | '''
47 | img = process_image.load_img(path, target_size=(224, 224))
48 | x = process_image.img_to_array(img)
49 | # x = np.expand_dims(x, axis=0)
50 | x = preprocess_input(x)
51 | return x
52 |
53 | @staticmethod
54 | def cosine_distance(input1, input2):
55 |         '''Calculate the pairwise cosine distances of two inputs.
56 | 
57 |         The return values lie in [-1, 1]. `-1` denotes two features are the most dissimilar,
58 |         and `1` denotes they are the most similar.
59 |
60 | Args:
61 | input1, input2: two input numpy arrays.
62 |
63 | Returns:
64 |             A matrix of pairwise cosine distances between the rows of the two inputs.
65 | '''
66 | # return np.dot(input1, input2) / (np.linalg.norm(input1) * np.linalg.norm(input2))
67 | return np.dot(input1, input2.T) / \
68 | np.dot(np.linalg.norm(input1, axis=1, keepdims=True), \
69 | np.linalg.norm(input2.T, axis=0, keepdims=True))
70 |
71 | def extract_feature(self, generator):
72 | '''Extract deep feature using MobileNet model.
73 |
74 | Args:
75 | generator: a predict generator inherit from `keras.utils.Sequence`.
76 |
77 | Returns:
78 | The output features of all inputs.
79 | '''
80 | features = self._model.predict_generator(generator)
81 | return features
82 |
83 |
84 | class DataSequence(Sequence):
85 | '''Predict generator inherit from `keras.utils.Sequence`.'''
86 | def __init__(self, paras, generation, batch_size=32):
87 | self.list_of_label_fields = []
88 | self.list_of_paras = paras
89 | self.data_generation = generation
90 | self.batch_size = batch_size
91 | self.__idx = 0
92 |
93 | def __len__(self):
94 | '''The number of batches per epoch.'''
95 | return int(np.ceil(len(self.list_of_paras) / self.batch_size))
96 |
97 | def __getitem__(self, idx):
98 | '''Generate one batch of data.'''
99 | paras = self.list_of_paras[idx * self.batch_size : (idx+1) * self.batch_size]
100 | batch_x, batch_fields = self.data_generation(paras)
101 |
102 | if idx == self.__idx:
103 | self.list_of_label_fields += batch_fields
104 | self.__idx += 1
105 |
106 | return np.array(batch_x)
107 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | h5py~=2.6.0
2 | numpy~=1.14.5
3 | requests~=2.21.0
4 |
--------------------------------------------------------------------------------