├── .gitignore
├── LICENSE
├── README.md
├── data_preprocess
│   ├── .gitignore
│   ├── feature_extraction.py
│   └── filter_images.py
├── demo_result.png
├── index_construction
│   ├── .gitignore
│   ├── SPTAG_rpc_search_client.py
│   ├── export_SPTAG_indexbuilder_input.py
│   ├── search_test.py
│   └── vgg_model.py
├── mm_images
│   ├── .gitignore
│   └── faces
│       └── .gitignore
├── requirements.txt
├── search_test_result.jpg
├── test_img.jpeg
└── web_demo
    ├── .gitignore
    ├── README.md
    ├── babel.config.js
    ├── backend
    │   ├── .gitignore
    │   ├── SPTAG_rpc_search_client.py
    │   ├── vgg_model.py
    │   ├── web_service.py
    │   └── web_utils.py
    ├── dist
    │   ├── index.html
    │   └── static
    │       ├── css
    │       │   └── app.af6337e4.css
    │       ├── fonts
    │       │   ├── element-icons.535877f5.woff
    │       │   └── element-icons.732389de.ttf
    │       ├── js
    │       │   ├── app.3d295f3c.js
    │       │   ├── app.3d295f3c.js.map
    │       │   ├── chunk-vendors.c354b2dd.js
    │       │   └── chunk-vendors.c354b2dd.js.map
    │       └── upload_images
    │           └── .gitignore
    ├── main.py
    ├── package-lock.json
    ├── package.json
    ├── public
    │   ├── index.html
    │   └── static
    │       └── upload_images
    │           └── .gitignore
    ├── src
    │   ├── App.vue
    │   ├── actions
    │   │   └── api.js
    │   ├── components
    │   │   └── SearchView.vue
    │   ├── element-variables.scss
    │   ├── main.js
    │   └── plugins
    │       └── element.js
    └── vue.config.js
/.gitignore:
--------------------------------------------------------------------------------
1 | build-desktop-client-*/
2 | .vscode/
3 | .idea/
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Kalen Blue
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MMFinder
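2 | A demo of an image search application for photos of beautiful women ("MM").
3 |
4 | > An ElasticSearch version has been added that is simpler to configure; see the [elasticsearch branch](https://github.com/nladuo/MMFinder/tree/elasticsearch) for details.
5 |
6 | ## Environment
7 | Python 3.6 or 3.7 (TensorFlow 1.15 provides no wheels for newer versions), MongoDB, and SPTAG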
8 |
9 |
10 | ## Installing dependencies
11 | ### Installing Dlib
12 | #### Mac or Linux
13 | Install CMake first, then install dlib with pip:
14 | ```
15 | pip3 install dlib
16 | ```
17 |
18 | #### Windows
19 | For Python 3.6, a prebuilt wheel can be installed directly:
20 | ```
21 | pip install https://pypi.python.org/packages/da/06/bd3e241c4eb0a662914b3b4875fc52dd176a9db0d4a2c915ac2ad8800e9e/dlib-19.7.0-cp36-cp36m-win_amd64.whl#md5=b7330a5b2d46420343fbed5df69e6a3f
22 | ```
23 | For other Python versions, see the guides available online.
24 |
25 | ### Installing Python dependencies
26 | ```
27 | cd MMFinder
28 | pip3 install -r requirements.txt
29 | ```
30 |
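31 | ## Data preparation
32 | ### 1. Prepare the data
33 | Crawl photos of beautiful women ("MM") and put them in `mm_images/`.
34 | > If you do not have any such photos, you can use my data. Google Drive download link: [https://drive.google.com/file/d/1shZ3gx9nHPHUgylsZIrvWliwCh9TucAo/view?usp=sharing](https://drive.google.com/file/d/1shZ3gx9nHPHUgylsZIrvWliwCh9TucAo/view?usp=sharing). Archive password: nladuo.
35 |
36 | ### 2. Filter the images
37 | Keep only the images that contain exactly one face: the cropped face is saved as `mm_images/faces/face-<filename>` and the image path is stored in MongoDB.
38 | ```bash
39 | cd data_preprocess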
40 | python3 filter_images.py
41 | ```
42 |
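43 | ## Feature engineering
44 | Extract features from the face images with a VGG network (VGG-Face weights) and convert each face into a 2622-dimensional dense vector.
45 | ### 1. Download the pretrained VGG model
46 | Google Drive: https://drive.google.com/file/d/1CPSeum3HpopfomUEK1gybeuIVoeJT_Eo/view?usp=sharing
47 |
48 | Baidu Cloud link: https://pan.baidu.com/s/1Dk40tW2lx1ezTda9IyIO9g (password: 0vc7). Save the weights as `data_preprocess/vgg_face_weights.h5`, the path all scripts load them from.
49 | ### 2. Extract features with VGG and build the dataset
50 | ```bash
51 | cd data_preprocess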
52 | python3 feature_extraction.py
53 | ```
54 |
55 | ## Building the index
56 | ### 1. Install SPTAG and start the RPC service
57 | See: [Installing and testing SPTAG with Docker](https://www.jianshu.com/p/fcedf00eac32) (in Chinese).
58 |
59 |
60 | ### 2. Index the images
61 | ```
62 | cd index_construction
63 | python3 export_SPTAG_indexbuilder_input.py
64 | ```
65 | Copy `mm_index_input.txt` into the SPTAG Docker container (replace `25042d741f07` with your own container ID):
66 | ```bash
67 | docker cp mm_index_input.txt 25042d741f07:/app/Release/
68 | ```
69 |
70 | Attach to the SPTAG Docker container and build the index (`-d 2622` is the dimension of the VGG-Face vectors):
71 | ```bash
72 | docker attach 25042d741f07
73 | ./indexbuilder -d 2622 -v Float -i ./mm_index_input.txt -o data/mm_index -a BKT -t 2
74 | ```
75 |
76 | Start the SPTAG search service:
77 | ```bash
78 | python3 SPTAG_rpc_search_service.py
79 | ```
80 |
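81 | ### 3. Search test
82 | On macOS, install `imgcat` first, then run `search_test.py` from inside `index_construction/` (it uses paths relative to that directory).
83 |
84 | The result looks like this: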
85 | 
86 |
87 | ## Running the demo
88 | ### Start the demo website
89 | ```
90 | cd web_demo
91 | python3 main.py
92 | ```
93 |
94 | ### Try it out
95 | Open [http://localhost:3889](http://localhost:3889).
96 |
97 | Upload an image to test it. The result looks like this:
98 | 
99 | ## Reference
100 | - https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/
101 |
102 | ## LICENSE
103 | MIT
104 |
--------------------------------------------------------------------------------
/data_preprocess/.gitignore:
--------------------------------------------------------------------------------
1 | vgg_face_weights.h5
2 | data.pickle
--------------------------------------------------------------------------------
/data_preprocess/feature_extraction.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from keras.layers import ZeroPadding2D, Convolution2D, MaxPooling2D, Dropout, Flatten, Activation
3 | from keras import Sequential, Model
4 | from keras.preprocessing import image
5 | from keras.applications.vgg19 import preprocess_input
6 | import pymongo
7 |
8 |
9 | def build_model():
10 |     # VGG-Face: five blocks of 3x3 convolutions, each block followed by 2x2 max-pooling
11 |     model = Sequential()
12 |     model.add(ZeroPadding2D((1, 1), input_shape=(224, 224, 3)))
13 |     model.add(Convolution2D(64, (3, 3), activation='relu'))
14 |     model.add(ZeroPadding2D((1, 1)))
15 |     model.add(Convolution2D(64, (3, 3), activation='relu'))
16 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
17 |
18 |     model.add(ZeroPadding2D((1, 1)))
19 |     model.add(Convolution2D(128, (3, 3), activation='relu'))
20 |     model.add(ZeroPadding2D((1, 1)))
21 |     model.add(Convolution2D(128, (3, 3), activation='relu'))
22 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
23 |
24 |     model.add(ZeroPadding2D((1, 1)))
25 |     model.add(Convolution2D(256, (3, 3), activation='relu'))
26 |     model.add(ZeroPadding2D((1, 1)))
27 |     model.add(Convolution2D(256, (3, 3), activation='relu'))
28 |     model.add(ZeroPadding2D((1, 1)))
29 |     model.add(Convolution2D(256, (3, 3), activation='relu'))
30 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
31 |
32 |     model.add(ZeroPadding2D((1, 1)))
33 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
34 |     model.add(ZeroPadding2D((1, 1)))
35 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
36 |     model.add(ZeroPadding2D((1, 1)))
37 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
38 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
39 |
40 |     model.add(ZeroPadding2D((1, 1)))
41 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
42 |     model.add(ZeroPadding2D((1, 1)))
43 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
44 |     model.add(ZeroPadding2D((1, 1)))
45 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
46 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
47 |
48 |     # Fully connected layers implemented as convolutions (7x7 -> 1x1), then 2622 identity scores
49 |     model.add(Convolution2D(4096, (7, 7), activation='relu'))
50 |     model.add(Dropout(0.5))
51 |     model.add(Convolution2D(4096, (1, 1), activation='relu'))
52 |     model.add(Dropout(0.5))
53 |     model.add(Convolution2D(2622, (1, 1)))
54 |     model.add(Flatten())
55 |     model.add(Activation('softmax'))
56 |
57 |     return model
58 |
59 |
60 | def preprocess_image(image_path):
61 |     # Load as a 224x224 RGB array, add a batch axis, then apply VGG-style preprocessing
62 |     img = image.load_img(image_path, target_size=(224, 224))
63 |     img = image.img_to_array(img)
64 |     img = np.expand_dims(img, axis=0)
65 |     img = preprocess_input(img)
66 |     return img
67 |
68 |
69 | model = build_model()
70 | model.load_weights('vgg_face_weights.h5')
71 | # Drop the softmax: the Flatten output (layers[-2]) is the 2622-d face descriptor
72 | vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)
73 |
74 |
75 | client = pymongo.MongoClient()
76 | db = client.MMFinder
77 | images_coll = db.images
78 |
79 | image_datas = []
80 |
81 | IMAGES_PATH = "../mm_images"
82 |
83 | for image_data in images_coll.find({}):
84 |     # To process only new images, append only when "vec" not in image_data
85 |     image_datas.append(image_data)
86 |
87 |
88 | for i, image_data in enumerate(image_datas):
89 |     img_path = f"{IMAGES_PATH}/faces/face-{image_data['path']}"
90 |     img = preprocess_image(img_path)
91 |     features = vgg_face_descriptor.predict(img)
92 |     print(i, img_path)
93 |     vec = features[0].tolist()
94 |     # update_one replaces Collection.update, which PyMongo 4 removed
95 |     images_coll.update_one({"_id": image_data["_id"]}, {
96 |         "$set": {
97 |             "vec": vec
98 |         }
99 |     })
100 |
--------------------------------------------------------------------------------
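`preprocess_image` relies on `keras.applications.vgg19.preprocess_input`, which uses Keras's "caffe" mode: channels are reordered from RGB to BGR and the ImageNet per-channel mean is subtracted, with no scaling. A minimal pure-Python sketch of that per-pixel transform, without Keras or NumPy (the helpers `caffe_preprocess_pixel` and `caffe_preprocess_image` are hypothetical, not Keras APIs):

```python
# Per-channel ImageNet means in BGR order, as used by Keras "caffe"-mode preprocessing
IMAGENET_BGR_MEAN = (103.939, 116.779, 123.68)


def caffe_preprocess_pixel(rgb):
    """Hypothetical helper mirroring preprocess_input(mode='caffe') for one RGB pixel."""
    r, g, b = rgb
    bgr = (b, g, r)  # RGB -> BGR channel order
    # Zero-center each channel; values are NOT rescaled to [0, 1] or [-1, 1]
    return tuple(float(c) - m for c, m in zip(bgr, IMAGENET_BGR_MEAN))


def caffe_preprocess_image(pixels):
    """Apply the same transform to a nested list of rows of RGB pixels."""
    return [[caffe_preprocess_pixel(p) for p in row] for row in pixels]


if __name__ == "__main__":
    print(caffe_preprocess_pixel((255, 0, 0)))  # a pure-red pixel
```

Because no scaling is applied, the descriptor model sees inputs roughly in the range -124 to 152 rather than normalized values.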
/data_preprocess/filter_images.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pymongo
3 | import face_recognition
4 | from PIL import Image
5 |
6 | client = pymongo.MongoClient()
7 | db = client.MMFinder
8 | images_coll = db.images
9 |
10 |
11 | IMAGES_PATH = "../mm_images"
12 |
13 |
14 | def get_face_and_save(path):
15 |     """Detect faces; if there is exactly one, save the crop as faces/face-<path>. Returns the face count."""
16 |     image_path = f'{IMAGES_PATH}/{path}'
17 |     image = face_recognition.load_image_file(image_path)
18 |     locations = face_recognition.face_locations(image)
19 |     if len(locations) == 1:  # save the face of mm
20 |         top, right, bottom, left = locations[0]
21 |         face_image = image[top:bottom, left:right]
22 |         pil_image = Image.fromarray(face_image)
23 |         with open(f'{IMAGES_PATH}/faces/face-{path}', "wb") as f:
24 |             pil_image.save(f)
25 |     return len(locations)
26 |
27 |
28 | def check_file_type(path):
29 |     allow_types = [".png", ".jpg", ".jpeg"]
30 |     for t in allow_types:
31 |         if path.endswith(t):
32 |             return True
33 |     return False
34 |
35 |
36 | for i, path in enumerate(os.listdir(IMAGES_PATH)):
37 |     if not check_file_type(path):
38 |         continue
39 |
40 |     # Skip images that are already stored (count_documents replaces Cursor.count, removed in PyMongo 4)
41 |     if images_coll.count_documents({"path": path}) != 0:
42 |         continue
43 |
44 |     print(i, path)
45 |     try:
46 |         if get_face_and_save(path) != 1:
47 |             continue
48 |     except Exception:  # skip unreadable or corrupt images
49 |         continue
50 |     images_coll.insert_one({
51 |         "path": path
52 |     })
53 |
--------------------------------------------------------------------------------
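`face_recognition.face_locations` returns one `(top, right, bottom, left)` tuple per detected face, and `filter_images.py` keeps an image only when exactly one face is found. A minimal sketch of that decision as a pure function, converting the tuple into the `(left, upper, right, lower)` box that Pillow's `Image.crop` expects (the script itself slices the NumPy array instead). The names `single_face_box` and `has_allowed_extension` are hypothetical; the latter is a case-insensitive variant of `check_file_type`:

```python
from typing import List, Optional, Tuple

# (top, right, bottom, left): the order used by face_recognition.face_locations
FaceLocation = Tuple[int, int, int, int]


def single_face_box(locations: List[FaceLocation]) -> Optional[Tuple[int, int, int, int]]:
    """Return a Pillow crop box for the only face, or None for zero or several faces."""
    if len(locations) != 1:
        return None
    top, right, bottom, left = locations[0]
    # Pillow's Image.crop takes (left, upper, right, lower);
    # this box covers the same pixels as image[top:bottom, left:right]
    return (left, top, right, bottom)


def has_allowed_extension(path: str) -> bool:
    """Case-insensitive extension check; str.endswith accepts a tuple of suffixes."""
    return path.lower().endswith((".png", ".jpg", ".jpeg"))
```

With Pillow, `Image.open(path).crop(box)` would then yield the face crop for a non-`None` box.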
/demo_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nladuo/MMFinder/6650595007135c37f9bb0a1be9b8db23333545e3/demo_result.png
--------------------------------------------------------------------------------
/index_construction/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.txt
--------------------------------------------------------------------------------
/index_construction/SPTAG_rpc_search_client.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import rpyc
5 |
6 |
7 | class DataBean:
8 |     """One query item: an id (metadata string) and a 1-d float32 vector."""
9 |
10 |     def __init__(self, _id: str, vec: np.ndarray):
11 |         self._id = _id
12 |         self.vec = vec
13 |         if '\n' in _id:
14 |             raise Exception("_id cannot contain \\n")
15 |
16 |         if len(vec.shape) != 1:
17 |             raise Exception("vec must be 1-d vector")
18 |
19 |         if vec.dtype != np.float32:
20 |             raise Exception("the dtype of vec must be np.float32")
21 |
22 |
23 | class SPTAG_RpcSearchClient:
24 |
25 |     ALGO_BKT = "BKT"  # SPTAG-BKT is advantageous in search accuracy in very high-dimensional data
26 |     ALGO_KDT = "KDT"  # SPTAG-KDT is advantageous in index building cost
27 |
28 |     DIST_L2 = "L2"
29 |     DIST_Cosine = "Cosine"
30 |
31 |     def __init__(self, host, port):
32 |         c = rpyc.connect(host, port)
33 |         c._config['sync_request_timeout'] = None  # disable rpyc's request timeout for slow searches
34 |         self.proxy = c.root
35 |
36 |     def search(self, beans: List[DataBean], p_resultNum):
37 |         """Find the p_resultNum nearest neighbours of each bean; returns one hit list per bean."""
38 |         _, vecs = self.__get_meta_and_vec_from_beans(beans)
39 |         vecs_ = vecs.tolist()  # send plain Python lists over RPC
40 |         return self.proxy.search(vecs_, p_resultNum)
41 |
42 |     def __get_meta_and_vec_from_beans(self, beans: List[DataBean]):
43 |         """Join the bean ids into newline-separated bytes and stack their vectors into a 2-d array."""
44 |         if len(beans) == 0:
45 |             raise Exception("beans length cannot be zero!")
46 |
47 |         if len(beans) > 1000:
48 |             raise Exception("cannot add more than 1000 beans at once!")
49 |
50 |         dim = beans[0].vec.shape[0]
51 |         meta = ""
52 |         vecs = np.zeros((len(beans), dim))
53 |
54 |         for i, bean in enumerate(beans):
55 |             meta += bean._id + '\n'
56 |             vecs[i] = bean.vec
57 |
58 |         meta = meta.encode()
59 |         return meta, vecs
60 |
61 |
62 | if __name__ == "__main__":
63 |     client = SPTAG_RpcSearchClient("127.0.0.1", "8888")
64 |     print("Test Search")
65 |     q = DataBean(_id=f"s{0}", vec=0 * np.ones((10,), dtype=np.float32))
66 |     print(client.search([q], 3))
67 |
--------------------------------------------------------------------------------
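`search` packs the query beans into a 2-D array (one row per bean, at most 1000 per call), and the scripts in this repo read the reply as one hit list per query, where each hit is a single-entry dict keyed by the indexed image path (its value is printed and is assumed here to be a distance). A pure-Python sketch of both sides, without `rpyc` or NumPy. The names `pack_queries` and `hit_paths` are hypothetical, and the equal-length check is an addition; the original relies on NumPy raising on a shape mismatch:

```python
from typing import Dict, List, Sequence, Tuple

MAX_BEANS = 1000  # same limit as __get_meta_and_vec_from_beans


def pack_queries(ids: Sequence[str], vecs: Sequence[Sequence[float]]) -> Tuple[bytes, List[List[float]]]:
    """Validate the queries and return (newline-joined metadata bytes, list of float rows)."""
    if len(vecs) == 0:
        raise ValueError("beans length cannot be zero!")
    if len(vecs) > MAX_BEANS:
        raise ValueError(f"cannot add more than {MAX_BEANS} beans at once!")
    if len(ids) != len(vecs):
        raise ValueError("ids and vecs must have the same length")
    dim = len(vecs[0])
    for _id, vec in zip(ids, vecs):
        if "\n" in _id:
            raise ValueError("_id cannot contain \\n")
        if len(vec) != dim:
            raise ValueError("all vectors must have the same dimension")
    meta = "".join(_id + "\n" for _id in ids).encode()
    return meta, [[float(x) for x in vec] for vec in vecs]


def hit_paths(results: List[List[Dict[str, float]]]) -> List[str]:
    """Return the only key of each hit for the first query, as search_test.py does."""
    return [next(iter(hit)) for hit in results[0]]
```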
/index_construction/export_SPTAG_indexbuilder_input.py:
--------------------------------------------------------------------------------
1 | import pymongo
2 |
3 |
4 | client = pymongo.MongoClient()
5 | db = client.MMFinder
6 | images_coll = db.images
7 |
8 | count = 0
9 |
10 | # Write one "<path>\t<v1>|<v2>|...|<vn>" line per image: the text input for SPTAG's indexbuilder
11 | with open("mm_index_input.txt", "w") as f:
12 |     # Only export images that already have a VGG feature vector
13 |     for image in images_coll.find({"vec": {"$exists": True}}):
14 |         path = image["path"]
15 |         vec = "|".join([str(i) for i in image["vec"]])
16 |
17 |         print(count, path)
18 |         f.write(f"{path}\t{vec}\n")
19 |         count += 1
20 |
--------------------------------------------------------------------------------
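The export script writes one `<path>\t<v1>|<v2>|...` line per image, which the README then passes to `indexbuilder -i`. A minimal round-trip sketch of that line format (the helpers `format_index_line` and `parse_index_line` are hypothetical), which also rejects paths containing a tab or newline, since either would corrupt the format:

```python
from typing import List, Sequence, Tuple


def format_index_line(path: str, vec: Sequence[float]) -> str:
    """Build one '<path>\\t<v1>|<v2>|...' line, as export_SPTAG_indexbuilder_input.py writes."""
    if "\t" in path or "\n" in path:
        raise ValueError("path must not contain tabs or newlines")
    values = "|".join(str(x) for x in vec)
    return f"{path}\t{values}\n"


def parse_index_line(line: str) -> Tuple[str, List[float]]:
    """Inverse of format_index_line: split off the metadata, then the '|'-separated values."""
    path, _, values = line.rstrip("\n").partition("\t")
    return path, [float(x) for x in values.split("|")]
```

Python's `str(float)` produces the shortest representation that round-trips, so `parse_index_line` recovers the stored values exactly.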
/index_construction/search_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import pymongo
6 |
7 | from SPTAG_rpc_search_client import SPTAG_RpcSearchClient, DataBean
8 | from vgg_model import get_feature_extractor, preprocess_image
9 |
10 | client = pymongo.MongoClient()
11 | db = client.MMFinder
12 | images_coll = db.images
13 |
14 | # Pick a random query among (up to) 100 images that passed filter_images.py,
15 | # so that its cropped face is known to exist in mm_images/faces
16 | image_names = [doc["path"] for doc in images_coll.find({}, {"path": 1}).limit(100)]
17 | random.shuffle(image_names)
18 | test_img = image_names[0]
19 | test_face_path = f"../mm_images/faces/face-{test_img}"
20 |
21 | search_client = SPTAG_RpcSearchClient("127.0.0.1", "8888")
22 |
23 | vgg_feature_extractor = get_feature_extractor()
24 |
25 |
26 | def get_face_representation(path):
27 |     """Return the VGG-Face descriptor of a face image as a list of floats."""
28 |     img = preprocess_image(path)
29 |     features = vgg_feature_extractor.predict(img)
30 |     vec = features[0].tolist()
31 |     return vec
32 |
33 |
34 | vec = np.array(get_face_representation(test_face_path), dtype=np.float32)
35 | bean = DataBean(_id="", vec=vec)
36 | os.system(f"imgcat {test_face_path}")  # display the query face in the terminal
37 | results = search_client.search([bean], 20)
38 | for item in results[0]:
39 |     k = [i for i in item.keys()][0]  # each hit is a one-entry dict keyed by image path
40 |     print(k, item[k])
41 |     os.system(f"imgcat ../mm_images/faces/face-{k}")
42 |
--------------------------------------------------------------------------------
/index_construction/vgg_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from keras.layers import ZeroPadding2D, Convolution2D, MaxPooling2D, Dropout, Flatten, Activation
3 | from keras import Sequential, Model
4 | from keras.preprocessing import image
5 | from keras.applications.vgg19 import preprocess_input
6 |
7 |
8 | def get_feature_extractor():
9 |     """Build VGG-Face, load the pretrained weights, and return a model that outputs the 2622-d descriptor."""
10 |     model = build_model()
11 |     model.load_weights('../data_preprocess/vgg_face_weights.h5')  # relative to index_construction/
12 |     # Skip the final softmax: layers[-2] is the Flatten output
13 |     vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)
14 |     return vgg_face_descriptor
15 |
16 |
17 | def build_model():
18 |     # VGG-Face: five blocks of 3x3 convolutions, each block followed by 2x2 max-pooling
19 |     model = Sequential()
20 |     model.add(ZeroPadding2D((1, 1), input_shape=(224, 224, 3)))
21 |     model.add(Convolution2D(64, (3, 3), activation='relu'))
22 |     model.add(ZeroPadding2D((1, 1)))
23 |     model.add(Convolution2D(64, (3, 3), activation='relu'))
24 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
25 |
26 |     model.add(ZeroPadding2D((1, 1)))
27 |     model.add(Convolution2D(128, (3, 3), activation='relu'))
28 |     model.add(ZeroPadding2D((1, 1)))
29 |     model.add(Convolution2D(128, (3, 3), activation='relu'))
30 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
31 |
32 |     model.add(ZeroPadding2D((1, 1)))
33 |     model.add(Convolution2D(256, (3, 3), activation='relu'))
34 |     model.add(ZeroPadding2D((1, 1)))
35 |     model.add(Convolution2D(256, (3, 3), activation='relu'))
36 |     model.add(ZeroPadding2D((1, 1)))
37 |     model.add(Convolution2D(256, (3, 3), activation='relu'))
38 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
39 |
40 |     model.add(ZeroPadding2D((1, 1)))
41 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
42 |     model.add(ZeroPadding2D((1, 1)))
43 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
44 |     model.add(ZeroPadding2D((1, 1)))
45 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
46 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
47 |
48 |     model.add(ZeroPadding2D((1, 1)))
49 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
50 |     model.add(ZeroPadding2D((1, 1)))
51 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
52 |     model.add(ZeroPadding2D((1, 1)))
53 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
54 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
55 |
56 |     # Fully connected layers implemented as convolutions (7x7 -> 1x1), then 2622 identity scores
57 |     model.add(Convolution2D(4096, (7, 7), activation='relu'))
58 |     model.add(Dropout(0.5))
59 |     model.add(Convolution2D(4096, (1, 1), activation='relu'))
60 |     model.add(Dropout(0.5))
61 |     model.add(Convolution2D(2622, (1, 1)))
62 |     model.add(Flatten())
63 |     model.add(Activation('softmax'))
64 |
65 |     return model
66 |
67 |
68 | def preprocess_image(image_path):
69 |     # Load as a 224x224 RGB array, add a batch axis, then apply VGG-style preprocessing
70 |     img = image.load_img(image_path, target_size=(224, 224))
71 |     img = image.img_to_array(img)
72 |     img = np.expand_dims(img, axis=0)
73 |     img = preprocess_input(img)
74 |     return img
75 |
--------------------------------------------------------------------------------
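The `-d 2622` passed to `indexbuilder` in the README matches the descriptor size this model produces. Every `ZeroPadding2D((1, 1))` followed by a 3x3 convolution keeps the spatial size, and each of the five 2x2 max-pools halves it (224 -> 7). The unpadded 7x7 convolution then reduces 7x7 to 1x1, so `Flatten` (the `layers[-2]` output) yields 1 * 1 * 2622 values. A small sketch of that arithmetic; the helper names `out_size` and `vgg_face_descriptor_dim` are hypothetical:

```python
def out_size(size: int, kernel: int, pad: int = 0, stride: int = 1) -> int:
    """Spatial output size of a convolution or pooling layer with symmetric zero padding."""
    return (size + 2 * pad - kernel) // stride + 1


def vgg_face_descriptor_dim(input_size: int = 224) -> int:
    """Trace the spatial size through build_model() and return the Flatten output length."""
    size = input_size
    # Number of 3x3 conv layers in each of the five blocks of build_model()
    for n_convs in (2, 2, 3, 3, 3):
        for _ in range(n_convs):
            size = out_size(size, kernel=3, pad=1)  # ZeroPadding2D((1, 1)) + 3x3 conv: size unchanged
        size = out_size(size, kernel=2, stride=2)   # MaxPooling2D((2, 2), strides=(2, 2)): halves it
    size = out_size(size, kernel=7)  # Convolution2D(4096, (7, 7)) without padding: 7 -> 1
    size = out_size(size, kernel=1)  # Convolution2D(4096, (1, 1))
    size = out_size(size, kernel=1)  # Convolution2D(2622, (1, 1))
    return size * size * 2622        # Flatten(): the output taken from model.layers[-2]


if __name__ == "__main__":
    print(vgg_face_descriptor_dim())  # prints 2622
```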
/mm_images/.gitignore:
--------------------------------------------------------------------------------
1 | *.jpg
2 | *.png
3 | *.jpeg
4 |
--------------------------------------------------------------------------------
/mm_images/faces/.gitignore:
--------------------------------------------------------------------------------
1 | *.jpg
2 | *.png
3 | *.jpeg
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | flask
2 | pillow
3 | pymongo>=3.7
4 | scikit-learn
5 | tensorflow==1.15.2
6 | Keras==2.3.1
7 | face_recognition
8 | rpyc
--------------------------------------------------------------------------------
/search_test_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nladuo/MMFinder/6650595007135c37f9bb0a1be9b8db23333545e3/search_test_result.jpg
--------------------------------------------------------------------------------
/test_img.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nladuo/MMFinder/6650595007135c37f9bb0a1be9b8db23333545e3/test_img.jpeg
--------------------------------------------------------------------------------
/web_demo/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | node_modules
3 |
4 | # local env files
5 | .env.local
6 | .env.*.local
7 |
8 | # Log files
9 | npm-debug.log*
10 | yarn-debug.log*
11 | yarn-error.log*
12 |
13 | # Editor directories and files
14 | .idea
15 | .vscode
16 | *.suo
17 | *.ntvs*
18 | *.njsproj
19 | *.sln
20 | *.sw?
21 |
--------------------------------------------------------------------------------
/web_demo/README.md:
--------------------------------------------------------------------------------
1 | # image search web demo
2 |
3 | > MMFinder Image Search Web Demo.
4 |
5 | ## Project setup
6 | ```
7 | npm install
8 | ```
9 |
10 | ### Compiles and hot-reloads for development
11 | ```
12 | npm run serve
13 | ```
14 |
15 | ### Compiles and minifies for production
16 | ```
17 | npm run build
18 | ```
19 |
20 | ### Customize configuration
21 | See [Configuration Reference](https://cli.vuejs.org/config/).
22 |
--------------------------------------------------------------------------------
/web_demo/babel.config.js:
--------------------------------------------------------------------------------
1 | module.exports = {
2 | presets: [
3 | '@vue/cli-plugin-babel/preset'
4 | ]
5 | }
6 |
--------------------------------------------------------------------------------
/web_demo/backend/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/web_demo/backend/SPTAG_rpc_search_client.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import rpyc
5 |
6 |
7 | class DataBean:
8 |     """One query item: an id (metadata string) and a 1-d float32 vector."""
9 |
10 |     def __init__(self, _id: str, vec: np.ndarray):
11 |         self._id = _id
12 |         self.vec = vec
13 |         if '\n' in _id:
14 |             raise Exception("_id cannot contain \\n")
15 |
16 |         if len(vec.shape) != 1:
17 |             raise Exception("vec must be 1-d vector")
18 |
19 |         if vec.dtype != np.float32:
20 |             raise Exception("the dtype of vec must be np.float32")
21 |
22 |
23 | class SPTAG_RpcSearchClient:
24 |
25 |     ALGO_BKT = "BKT"  # SPTAG-BKT is advantageous in search accuracy in very high-dimensional data
26 |     ALGO_KDT = "KDT"  # SPTAG-KDT is advantageous in index building cost
27 |
28 |     DIST_L2 = "L2"
29 |     DIST_Cosine = "Cosine"
30 |
31 |     def __init__(self, host, port):
32 |         c = rpyc.connect(host, port)
33 |         c._config['sync_request_timeout'] = None  # disable rpyc's request timeout for slow searches
34 |         self.proxy = c.root
35 |
36 |     def search(self, beans: List[DataBean], p_resultNum):
37 |         """Find the p_resultNum nearest neighbours of each bean; returns one hit list per bean."""
38 |         _, vecs = self.__get_meta_and_vec_from_beans(beans)
39 |         vecs_ = vecs.tolist()  # send plain Python lists over RPC
40 |         return self.proxy.search(vecs_, p_resultNum)
41 |
42 |     def __get_meta_and_vec_from_beans(self, beans: List[DataBean]):
43 |         """Join the bean ids into newline-separated bytes and stack their vectors into a 2-d array."""
44 |         if len(beans) == 0:
45 |             raise Exception("beans length cannot be zero!")
46 |
47 |         if len(beans) > 1000:
48 |             raise Exception("cannot add more than 1000 beans at once!")
49 |
50 |         dim = beans[0].vec.shape[0]
51 |         meta = ""
52 |         vecs = np.zeros((len(beans), dim))
53 |
54 |         for i, bean in enumerate(beans):
55 |             meta += bean._id + '\n'
56 |             vecs[i] = bean.vec
57 |
58 |         meta = meta.encode()
59 |         return meta, vecs
60 |
61 |
62 | if __name__ == "__main__":
63 |     client = SPTAG_RpcSearchClient("127.0.0.1", "8888")
64 |     print("Test Search")
65 |     q = DataBean(_id=f"s{0}", vec=0 * np.ones((10,), dtype=np.float32))
66 |     print(client.search([q], 3))
67 |
--------------------------------------------------------------------------------
/web_demo/backend/vgg_model.py:
--------------------------------------------------------------------------------
1 | from keras.layers import ZeroPadding2D, Convolution2D, MaxPooling2D, Dropout, Flatten, Activation
2 | from keras import Sequential, Model
3 |
4 |
5 | def get_feature_extractor():
6 |     """Build VGG-Face, load the pretrained weights, and return a model that outputs the 2622-d descriptor."""
7 |     model = build_model()
8 |     model.load_weights('../data_preprocess/vgg_face_weights.h5')  # relative to web_demo/, where main.py runs
9 |     # Skip the final softmax: layers[-2] is the Flatten output
10 |     vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)
11 |     return vgg_face_descriptor
12 |
13 |
14 | def build_model():
15 |     # VGG-Face: five blocks of 3x3 convolutions, each block followed by 2x2 max-pooling
16 |     model = Sequential()
17 |     model.add(ZeroPadding2D((1, 1), input_shape=(224, 224, 3)))
18 |     model.add(Convolution2D(64, (3, 3), activation='relu'))
19 |     model.add(ZeroPadding2D((1, 1)))
20 |     model.add(Convolution2D(64, (3, 3), activation='relu'))
21 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
22 |
23 |     model.add(ZeroPadding2D((1, 1)))
24 |     model.add(Convolution2D(128, (3, 3), activation='relu'))
25 |     model.add(ZeroPadding2D((1, 1)))
26 |     model.add(Convolution2D(128, (3, 3), activation='relu'))
27 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
28 |
29 |     model.add(ZeroPadding2D((1, 1)))
30 |     model.add(Convolution2D(256, (3, 3), activation='relu'))
31 |     model.add(ZeroPadding2D((1, 1)))
32 |     model.add(Convolution2D(256, (3, 3), activation='relu'))
33 |     model.add(ZeroPadding2D((1, 1)))
34 |     model.add(Convolution2D(256, (3, 3), activation='relu'))
35 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
36 |
37 |     model.add(ZeroPadding2D((1, 1)))
38 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
39 |     model.add(ZeroPadding2D((1, 1)))
40 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
41 |     model.add(ZeroPadding2D((1, 1)))
42 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
43 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
44 |
45 |     model.add(ZeroPadding2D((1, 1)))
46 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
47 |     model.add(ZeroPadding2D((1, 1)))
48 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
49 |     model.add(ZeroPadding2D((1, 1)))
50 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
51 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
52 |
53 |     # Fully connected layers implemented as convolutions (7x7 -> 1x1), then 2622 identity scores
54 |     model.add(Convolution2D(4096, (7, 7), activation='relu'))
55 |     model.add(Dropout(0.5))
56 |     model.add(Convolution2D(4096, (1, 1), activation='relu'))
57 |     model.add(Dropout(0.5))
58 |     model.add(Convolution2D(2622, (1, 1)))
59 |     model.add(Flatten())
60 |     model.add(Activation('softmax'))
61 |
62 |     return model
63 |
--------------------------------------------------------------------------------
/web_demo/backend/web_service.py:
--------------------------------------------------------------------------------
1 | from .vgg_model import get_feature_extractor
2 | from .web_utils import preprocess_image
3 | import face_recognition
4 | from PIL import Image
5 | from .SPTAG_rpc_search_client import SPTAG_RpcSearchClient, DataBean
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | # TF 1.x: build the model in its own graph and session so that later predictions
10 | # (e.g. from web request handlers) run against the same graph
11 | graph = tf.Graph()
12 | with graph.as_default():
13 |     session = tf.Session()
14 |     with session.as_default():
15 |         vgg_feature_extractor = get_feature_extractor()
16 |         # Dummy prediction so the model is fully initialized before the first real request
17 |         vgg_feature_extractor.predict(np.zeros((1, 224, 224, 3)))
18 |
19 | IMAGES_PATH = "../mm_images"
20 | UPLOAD_DIR = './dist/static/upload_images'
21 |
22 |
23 | def get_face_and_save(filename):
24 |     """Detect faces in an uploaded image; if there is exactly one, save the crop as face-<filename>."""
25 |     img_path = f"{UPLOAD_DIR}/{filename}"
26 |     image = face_recognition.load_image_file(img_path)
27 |     locations = face_recognition.face_locations(image)
28 |     if len(locations) == 1:  # save the face of mm
29 |         top, right, bottom, left = locations[0]
30 |         face_image = image[top:bottom, left:right]
31 |         pil_image = Image.fromarray(face_image)
32 |         with open(f"{UPLOAD_DIR}/face-{filename}", "wb") as f:
33 |             pil_image.save(f)
34 |     return len(locations)
35 |
36 |
37 | def get_face_representation(filename):
38 |     """Return the 2622-d VGG-Face descriptor of the saved face crop as a list of floats."""
39 |     face_img_path = f"{UPLOAD_DIR}/face-{filename}"
40 |     img = preprocess_image(face_img_path)
41 |     with graph.as_default():
42 |         with session.as_default():
43 |             features = vgg_feature_extractor.predict(img)
44 |     vec = features[0].tolist()
45 |     return vec
46 |
47 |
48 | def call_SPTAG_search(vec):
49 |     """Query the SPTAG RPC service and return the image paths of the 30 nearest faces."""
50 |     vec = np.array(vec, dtype=np.float32)
51 |     search_client = SPTAG_RpcSearchClient("127.0.0.1", "8888")
52 |     bean = DataBean(_id="", vec=vec)
53 |     results = search_client.search([bean], 30)
54 |     result_images = []
55 |     for item in results[0]:
56 |         k = [i for i in item.keys()][0]  # each hit is a one-entry dict keyed by image path
57 |         print(k, item[k])
58 |         result_images.append(k)
59 |     return result_images
60 |
--------------------------------------------------------------------------------
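`call_SPTAG_search` asks the SPTAG service for the 30 nearest stored descriptors. SPTAG is an approximate nearest-neighbour library, and here it searches the tree/graph index built with `-a BKT`. On a small sample, the exact answer can be computed by brute force to sanity-check the returned images. A sketch of exact k-nearest-neighbour search under the two distances the client names (`DIST_L2`, `DIST_Cosine`); `l2_distance`, `cosine_distance` and `knn` are hypothetical helpers, not SPTAG APIs:

```python
import heapq
import math
from typing import Callable, Dict, List, Sequence, Tuple

Metric = Callable[[Sequence[float], Sequence[float]], float]


def l2_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity: 0 for identical directions, up to 2 for opposite ones."""
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    if norm == 0:
        raise ValueError("cosine distance is undefined for zero vectors")
    return 1.0 - sum(x * y for x, y in zip(a, b)) / norm


def knn(query: Sequence[float], vectors: Dict[str, Sequence[float]], k: int,
        metric: Metric = l2_distance) -> List[Tuple[str, float]]:
    """Exact k nearest neighbours: compare the query against every stored vector."""
    scored = ((metric(query, vec), path) for path, vec in vectors.items())
    return [(path, dist) for dist, path in heapq.nsmallest(k, scored)]
```

Brute force costs one distance computation per stored vector per query, which is exactly the cost an index such as SPTAG's avoids on large collections.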
/web_demo/backend/web_utils.py:
--------------------------------------------------------------------------------
1 | from keras.preprocessing import image
2 | from keras.applications.vgg19 import preprocess_input
3 | import numpy as np
4 | import os
5 | import uuid
6 | from PIL import Image
7 |
8 |
9 | def get_file_extension(filename):
10 |     if "." not in filename:
11 |         return ""
12 |
13 |     return filename.rsplit('.', 1)[1].lower()
14 |
15 |
16 | def allowed_file(filename):
17 |     ALLOWED_EXTENSIONS = [
18 |         "png",
19 |         "jpg",
20 |         "jpeg",
21 |     ]
22 |     return get_file_extension(filename) in ALLOWED_EXTENSIONS
23 |
24 |
25 | def preprocess_image(image_path):
26 |     # Load as a 224x224 RGB array, add a batch axis, then apply VGG-style preprocessing
27 |     img = image.load_img(image_path, target_size=(224, 224))
28 |     img = image.img_to_array(img)
29 |     img = np.expand_dims(img, axis=0)
30 |     img = preprocess_input(img)
31 |     return img
32 |
33 |
34 | def save_upload_file(original_name, file):
35 |     """Save an uploaded file under a random UUID name (keeping its extension) and return that name."""
36 |     UPLOAD_DIR = "./dist/static/upload_images"
37 |
38 |     _, ext = os.path.splitext(original_name)
39 |
40 |     encrypted_name = str(uuid.uuid4()) + ext
41 |
42 |     print(encrypted_name)
43 |
44 |     file.save(os.path.join(UPLOAD_DIR, encrypted_name))
45 |     return encrypted_name
46 |
47 |
48 | def get_image_scale(image_name):
49 |     """Return height / width of an image in mm_images (PIL's size is (width, height))."""
50 |     path = f"../mm_images/{image_name}"
51 |     size = Image.open(path).size
52 |     return size[1] / size[0]
53 |
54 |
55 | def re_arrange_images(image_names):
56 |     """
57 |     The CSS lays the images out column by column (top to bottom), but people usually read
58 |     across rows, so the images are reordered here: each image goes to the currently shortest
59 |     of 5 columns, and the columns are then concatenated.
60 |     (This really belongs in the frontend, but I don't know CSS well, so I changed the code instead.)
61 |     """
62 |     image_scales = [get_image_scale(name) for name in image_names]
63 |     columns = {}
64 |     scales = []
65 |     for i in range(5):
66 |         columns[i] = []
67 |         scales.append(0)
68 |
69 |     for i, name in enumerate(image_names):
70 |         which_col = scales.index(min(scales))  # the shortest column so far
71 |         image_scale = image_scales[i]
72 |
73 |         columns[which_col].append(name)
74 |         scales[which_col] += image_scale
75 |
76 |     results = []
77 |     for i in range(5):
78 |         results += columns[i]
79 |
80 |     return results
81 |
--------------------------------------------------------------------------------
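`re_arrange_images` is a greedy "shortest column first" layout. Each image, in ranking order, goes to the column whose accumulated height is currently smallest, where height is the sum of height/width ratios, i.e. the height at a fixed column width. The columns are then concatenated, so a column-wise CSS layout shows the best matches roughly across the top rows. A sketch of the same assignment on precomputed ratios, without opening any files (`arrange_by_columns` is a hypothetical name):

```python
from typing import List, Sequence


def arrange_by_columns(names: Sequence[str], scales: Sequence[float], n_columns: int = 5) -> List[str]:
    """Assign each image to the currently shortest column, then concatenate the columns.

    scales[i] is height / width of names[i], i.e. its height when the column width is 1.
    """
    columns: List[List[str]] = [[] for _ in range(n_columns)]
    heights = [0.0] * n_columns
    for name, scale in zip(names, scales):
        col = heights.index(min(heights))  # ties go to the leftmost column, as in the original
        columns[col].append(name)
        heights[col] += scale
    return [name for column in columns for name in column]
```

In `re_arrange_images`, the ratios come from `get_image_scale`, and the number of columns is fixed at 5.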
/web_demo/dist/index.html:
--------------------------------------------------------------------------------
1 |