├── .gitignore ├── LICENSE ├── README.md ├── data_preprocess ├── .gitignore ├── feature_extraction.py └── filter_images.py ├── demo_result.png ├── index_construction ├── .gitignore ├── SPTAG_rpc_search_client.py ├── export_SPTAG_indexbuilder_input.py ├── search_test.py └── vgg_model.py ├── mm_images ├── .gitignore └── faces │ └── .gitignore ├── requirements.txt ├── search_test_result.jpg ├── test_img.jpeg └── web_demo ├── .gitignore ├── README.md ├── babel.config.js ├── backend ├── .gitignore ├── SPTAG_rpc_search_client.py ├── vgg_model.py ├── web_service.py └── web_utils.py ├── dist ├── index.html └── static │ ├── css │ └── app.af6337e4.css │ ├── fonts │ ├── element-icons.535877f5.woff │ └── element-icons.732389de.ttf │ ├── js │ ├── app.3d295f3c.js │ ├── app.3d295f3c.js.map │ ├── chunk-vendors.c354b2dd.js │ └── chunk-vendors.c354b2dd.js.map │ └── upload_images │ └── .gitignore ├── main.py ├── package-lock.json ├── package.json ├── public ├── index.html └── static │ └── upload_images │ └── .gitignore ├── src ├── App.vue ├── actions │ └── api.js ├── components │ └── SearchView.vue ├── element-variables.scss ├── main.js └── plugins │ └── element.js └── vue.config.js /.gitignore: -------------------------------------------------------------------------------- 1 | build-desktop-client-*/ 2 | .vscode/ 3 | .idea/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Kalen Blue 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MMFinder
2 | A demo of an image search application for MM (pretty-girl) photos.
3 |
4 | > An Elasticsearch version has been added; it is much simpler to set up. See the [elasticsearch branch](https://github.com/nladuo/MMFinder/tree/elasticsearch) for details.
5 |
6 | ## Environment
7 | Python 3.6 or later + MongoDB + SPTAG
8 |
9 |
10 | ## Installing the dependencies
11 | ### Install dlib
12 | #### Mac or Linux
13 | Install CMake first, then dlib can be installed via pip.
14 | ```
15 | pip3 install dlib
16 | ```
17 |
18 | #### Windows
19 | For Python 3.6 you can install it quickly from a pre-built wheel.
20 | ```
21 | pip install https://pypi.python.org/packages/da/06/bd3e241c4eb0a662914b3b4875fc52dd176a9db0d4a2c915ac2ad8800e9e/dlib-19.7.0-cp36-cp36m-win_amd64.whl#md5=b7330a5b2d46420343fbed5df69e6a3f
22 | ```
23 | For other Python versions, refer to the tutorials available online.
24 |
25 | ### Install the Python dependencies
26 | ```
27 | cd MMFinder
28 | pip3 install -r requirements.txt
29 | ```
30 |
31 | ## Preparing the data
32 | ### 1. Collect the data
33 | Crawl a set of MM photos.
34 | > If you don't have any photos, you can use my dataset. Google Drive download link: [https://drive.google.com/file/d/1shZ3gx9nHPHUgylsZIrvWliwCh9TucAo/view?usp=sharing](https://drive.google.com/file/d/1shZ3gx9nHPHUgylsZIrvWliwCh9TucAo/view?usp=sharing). Unzip password: nladuo.
35 |
36 | ### 2. Filter the images
37 | Keep only the photos that contain exactly one face, then store them in MongoDB.
38 | ```bash
39 | cd data_preprocess
40 | python3 filter_images.py
41 | ```
42 |
43 | ## Feature engineering
44 | Extract features from each face image with a VGG network, turning every face into a dense vector.
45 | ### 1. Download the pre-trained VGG model
46 | Google Drive: https://drive.google.com/file/d/1CPSeum3HpopfomUEK1gybeuIVoeJT_Eo/view?usp=sharing
47 |
48 | Baidu Pan link: https://pan.baidu.com/s/1Dk40tW2lx1ezTda9IyIO9g  password: 0vc7
49 | ### 2. Extract features with VGG and build the dataset
50 | ```bash
51 | cd data_preprocess
52 | python3 feature_extraction.py
53 | ```
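The hypothetical snippet below (not part of the repo) is a minimal sanity check for this step. It assumes MongoDB is running locally, that `filter_images.py` and `feature_extraction.py` have populated the `MMFinder.images` collection, and that a reasonably recent pymongo (with `count_documents`) is installed; the script name is made up for the example.

```python
# sanity_check_vectors.py - hypothetical helper, not part of the repo
import pymongo

client = pymongo.MongoClient()            # default localhost:27017
images_coll = client.MMFinder.images      # collection written by the preprocessing scripts

total = images_coll.count_documents({})
with_vec = images_coll.count_documents({"vec": {"$exists": True}})
print(f"{with_vec}/{total} documents carry a feature vector")

doc = images_coll.find_one({"vec": {"$exists": True}})
if doc is not None:
    # Each vector should have 2622 dimensions: the VGG-Face descriptor length,
    # and the -d value used later when building the SPTAG index.
    assert len(doc["vec"]) == 2622, "unexpected vector dimension"
    print("example:", doc["path"], "->", len(doc["vec"]), "dims")
```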
54 |
55 | ## Building the index
56 | ### 1. Install SPTAG and start the RPC service
57 | See: [Installing and testing SPTAG with Docker](https://www.jianshu.com/p/fcedf00eac32) (article in Chinese)
58 |
59 |
60 | ### 2. Build the index over the images
61 | ```
62 | cd index_construction
63 | python3 export_SPTAG_indexbuilder_input.py
64 | ```
65 | Copy mm_index_input.txt into the SPTAG Docker container (replace `25042d741f07` with your own container ID):
66 | ```bash
67 | docker cp mm_index_input.txt 25042d741f07:/app/Release/
68 | ```
69 |
70 | Attach to the SPTAG Docker container and build the index:
71 | ```bash
72 | docker attach 25042d741f07
73 | ./indexbuilder -d 2622 -v Float -i ./mm_index_input.txt -o data/mm_index -a BKT -t 2
74 | ```
75 |
76 | Start the SPTAG search service:
77 | ```bash
78 | python3 SPTAG_rpc_search_service.py
79 | ```
80 |
81 | ### 3. Search test
82 | On macOS you can install ``imgcat`` first and then run ``index_construction/search_test.py``.
83 |
84 | The result looks like this:
85 | ![](search_test_result.jpg)
86 |
87 | ## Running the demo
88 | ### Start the demo website
89 | ```
90 | cd web_demo
91 | python3 main.py
92 | ```
93 |
94 | ### Try it out
95 | Open [http://localhost:3889](http://localhost:3889)
96 |
97 | Upload an image to test it; the result looks like this:
98 | ![demo_result](demo_result.png)
99 | ## Reference
100 | - https://sefiks.com/2018/08/06/deep-face-recognition-with-keras/
101 |
102 | ## LICENSE
103 | MIT
104 |
--------------------------------------------------------------------------------
/data_preprocess/.gitignore:
--------------------------------------------------------------------------------
1 | vgg_face_weights.h5
2 | data.pickle
--------------------------------------------------------------------------------
/data_preprocess/feature_extraction.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | from keras.layers import ZeroPadding2D, Convolution2D, MaxPooling2D, Dropout, Flatten, Activation
5 | from keras import Sequential, Model
6 | from keras.preprocessing import image
7 | from keras.applications.vgg19 import preprocess_input
8 | import pymongo
9 |
10 |
11 | def build_model():
12 |     model = Sequential()
13 |     model.add(ZeroPadding2D((1, 1), input_shape=(224, 224, 3)))
14 |     model.add(Convolution2D(64, (3, 3), activation='relu'))
15 |     model.add(ZeroPadding2D((1, 1)))
16 |     model.add(Convolution2D(64, (3, 3), activation='relu'))
17 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
18 |
19 |     model.add(ZeroPadding2D((1, 1)))
20 |     model.add(Convolution2D(128, (3, 3), activation='relu'))
21 |     model.add(ZeroPadding2D((1, 1)))
22 |     model.add(Convolution2D(128, (3, 3), activation='relu'))
23 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
24 |
25 |     model.add(ZeroPadding2D((1, 1)))
26 |     model.add(Convolution2D(256, (3, 3), activation='relu'))
27 |     model.add(ZeroPadding2D((1, 1)))
28 |     model.add(Convolution2D(256, (3, 3), activation='relu'))
29 |     model.add(ZeroPadding2D((1, 1)))
30 |     model.add(Convolution2D(256, (3, 3), activation='relu'))
31 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
32 |
33 |     model.add(ZeroPadding2D((1, 1)))
34 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
35 |     model.add(ZeroPadding2D((1, 1)))
36 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
37 |     model.add(ZeroPadding2D((1, 1)))
38 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
39 |     model.add(MaxPooling2D((2, 2), strides=(2, 2)))
40 |
41 |     model.add(ZeroPadding2D((1, 1)))
42 |     model.add(Convolution2D(512, (3, 3), activation='relu'))
43 |     model.add(ZeroPadding2D((1, 1)))
44 |     model.add(Convolution2D(512, (3, 3),
activation='relu')) 45 | model.add(ZeroPadding2D((1, 1))) 46 | model.add(Convolution2D(512, (3, 3), activation='relu')) 47 | model.add(MaxPooling2D((2, 2), strides=(2, 2))) 48 | 49 | model.add(Convolution2D(4096, (7, 7), activation='relu')) 50 | model.add(Dropout(0.5)) 51 | model.add(Convolution2D(4096, (1, 1), activation='relu')) 52 | model.add(Dropout(0.5)) 53 | model.add(Convolution2D(2622, (1, 1))) 54 | model.add(Flatten()) 55 | model.add(Activation('softmax')) 56 | 57 | return model 58 | 59 | 60 | def preprocess_image(image_path): 61 | img = image.load_img(image_path, target_size=(224, 224)) 62 | img = image.img_to_array(img) 63 | img = np.expand_dims(img, axis=0) 64 | img = preprocess_input(img) 65 | return img 66 | 67 | 68 | model = build_model() 69 | model.load_weights('vgg_face_weights.h5') 70 | vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output) 71 | 72 | 73 | client = pymongo.MongoClient() 74 | db = client.MMFinder 75 | images_coll = db.images 76 | 77 | image_datas = [] 78 | 79 | IMAGES_PATH = "../mm_images" 80 | 81 | for image_data in images_coll.find({}): 82 | # if "vec" not in image_data: 83 | image_datas.append(image_data) 84 | 85 | 86 | for i, image_data in enumerate(image_datas): 87 | img_path = f"{IMAGES_PATH}/faces/face-{image_data['path']}" 88 | img = preprocess_image(img_path) 89 | features = vgg_face_descriptor.predict(img) 90 | print(i, img_path) 91 | vec = features[0].tolist() 92 | images_coll.update({"_id": image_data["_id"]}, { 93 | "$set": { 94 | "vec": vec 95 | } 96 | }) 97 | 98 | -------------------------------------------------------------------------------- /data_preprocess/filter_images.py: -------------------------------------------------------------------------------- 1 | import pymongo 2 | import face_recognition 3 | from PIL import Image 4 | import os 5 | 6 | client = pymongo.MongoClient() 7 | db = client.MMFinder 8 | images_coll = db.images 9 | 10 | 11 | IMAGES_PATH = "../mm_images" 12 | 13 | 14 | def get_face_and_save(path): 15 | image_path = f'{IMAGES_PATH}/{path}' 16 | image = face_recognition.load_image_file(image_path) 17 | locations = face_recognition.face_locations(image) 18 | if len(locations) == 1: # save the face of mm 19 | top, right, bottom, left = locations[0] 20 | face_image = image[top:bottom, left:right] 21 | pil_image = Image.fromarray(face_image) 22 | with open(f'{IMAGES_PATH}/faces/face-{path}', "wb") as f: 23 | pil_image.save(f) 24 | return len(locations) 25 | 26 | 27 | def check_file_type(path): 28 | allow_types = [".png", ".jpg", ".jpeg"] 29 | for t in allow_types: 30 | if path.endswith(t): 31 | return True 32 | return False 33 | 34 | 35 | for i, path in enumerate(os.listdir(IMAGES_PATH)): 36 | if not check_file_type(path): 37 | continue 38 | 39 | if images_coll.find({"path": path}).count() != 0: 40 | continue 41 | 42 | print(i, path) 43 | try: 44 | if get_face_and_save(path) != 1: 45 | continue 46 | except: 47 | continue 48 | images_coll.insert({ 49 | "path": path 50 | }) 51 | -------------------------------------------------------------------------------- /demo_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nladuo/MMFinder/6650595007135c37f9bb0a1be9b8db23333545e3/demo_result.png -------------------------------------------------------------------------------- /index_construction/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.txt 
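As a quick way to try the face-detection and cropping step from `data_preprocess/filter_images.py` (shown above) in isolation, here is a hypothetical sketch, not part of the repo. It assumes `face_recognition` and Pillow are installed and uses the bundled `test_img.jpeg` from the repository root; the script and output filenames are made up for the example.

```python
# crop_test_face.py - hypothetical example mirroring get_face_and_save()
import face_recognition
from PIL import Image

image = face_recognition.load_image_file("test_img.jpeg")
locations = face_recognition.face_locations(image)
print(f"found {len(locations)} face(s)")

if len(locations) == 1:
    # face_recognition returns boxes as (top, right, bottom, left)
    top, right, bottom, left = locations[0]
    Image.fromarray(image[top:bottom, left:right]).save("face-test_img.jpeg")
```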
-------------------------------------------------------------------------------- /index_construction/SPTAG_rpc_search_client.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import rpyc 3 | 4 | 5 | class DataBean: 6 | def __init__(self, _id: str, vec: np.array): 7 | self._id = _id 8 | self.vec = vec 9 | if '\n' in _id: 10 | raise Exception("_id cannot contain \\n") 11 | 12 | if len(vec.shape) != 1: 13 | raise Exception("vec must be 1-d vector") 14 | 15 | if vec.dtype != np.float32: 16 | raise Exception("the dtype of vec must be np.float32") 17 | 18 | 19 | class SPTAG_RpcSearchClient: 20 | 21 | ALGO_BKT = "BKT" # SPTAG-BKT is advantageous in search accuracy in very high-dimensional data 22 | ALGO_KDT = "KDT" # SPTAG-KDT is advantageous in index building cost, 23 | 24 | DIST_L2 = "L2" 25 | DIST_Cosine = "Cosine" 26 | 27 | def __init__(self, host, port): 28 | c = rpyc.connect(host, port) 29 | c._config['sync_request_timeout'] = None 30 | self.proxy = c.root 31 | 32 | def search(self, beans: [DataBean], p_resultNum): 33 | _, vecs = self.__get_meta_and_vec_from_beans(beans) 34 | vecs_ = vecs.tolist() 35 | return self.proxy.search(vecs_, p_resultNum) 36 | 37 | def __get_meta_and_vec_from_beans(self, beans: [DataBean]): 38 | if len(beans) == 0: 39 | raise Exception("beans length cannot be zero!") 40 | 41 | if len(beans) > 1000: 42 | raise Exception("cannot add more than 1000 beans at once!") 43 | 44 | dim = beans[0].vec.shape[0] 45 | meta = "" 46 | vecs = np.zeros((len(beans), dim)) 47 | 48 | for i, bean in enumerate(beans): 49 | meta += bean._id + '\n' 50 | vecs[i] = bean.vec 51 | 52 | meta = meta.encode() 53 | return meta, vecs 54 | 55 | 56 | if __name__ == "__main__": 57 | client = SPTAG_RpcSearchClient("127.0.0.1", "8888") 58 | print("Test Search") 59 | q = DataBean(_id=f"s{0}", vec=0 * np.ones((10,), dtype=np.float32)) 60 | print(client.search([q], 3)) 61 | -------------------------------------------------------------------------------- /index_construction/export_SPTAG_indexbuilder_input.py: -------------------------------------------------------------------------------- 1 | import pymongo 2 | import numpy as np 3 | 4 | 5 | client = pymongo.MongoClient() 6 | db = client.MMFinder 7 | images_coll = db.images 8 | 9 | count = 0 10 | 11 | with open("mm_index_input.txt", "w") as f: 12 | for image in images_coll.find(): 13 | path = image["path"] 14 | vec = "|".join([str(i) for i in image["vec"]]) 15 | 16 | print(count, path) 17 | f.write(f"{path}\t{vec}\n") 18 | count += 1 19 | -------------------------------------------------------------------------------- /index_construction/search_test.py: -------------------------------------------------------------------------------- 1 | from SPTAG_rpc_search_client import SPTAG_RpcSearchClient, DataBean 2 | import pymongo 3 | import os 4 | import random 5 | from vgg_model import get_feature_extractor, preprocess_image 6 | import numpy as np 7 | 8 | image_names = [] 9 | for i, path in enumerate(os.listdir("../mm_images")): 10 | if "jpg" in path: 11 | image_names.append(path) 12 | if i > 100: 13 | break 14 | 15 | random.shuffle(image_names) 16 | test_img = image_names[0] 17 | test_face_path = f"../mm_images/faces/face-{test_img}" 18 | 19 | search_client = SPTAG_RpcSearchClient("127.0.0.1", "8888") 20 | 21 | 22 | client = pymongo.MongoClient() 23 | db = client.MMFinder 24 | images_coll = db.images 25 | 26 | vgg_feature_extractor = get_feature_extractor() 27 | 28 | 29 | def 
get_face_representation(path): 30 | img = preprocess_image(path) 31 | features = vgg_feature_extractor.predict(img) 32 | vec = features[0].tolist() 33 | return vec 34 | 35 | 36 | vec = np.array(get_face_representation(test_face_path), dtype=np.float32) 37 | bean = DataBean(_id="", vec=vec) 38 | os.system(f"imgcat {test_face_path}") 39 | results = search_client.search([bean], 20) 40 | for item in results[0]: 41 | k = [i for i in item.keys()][0] 42 | print(k, item[k]) 43 | os.system(f"imgcat ../mm_images/faces/face-{k}") 44 | -------------------------------------------------------------------------------- /index_construction/vgg_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras.layers import ZeroPadding2D, Convolution2D, MaxPooling2D, Dropout, Flatten, Activation 3 | from keras import Sequential, Model 4 | from keras.preprocessing import image 5 | from keras.applications.vgg19 import preprocess_input 6 | 7 | 8 | def get_feature_extractor(): 9 | model = build_model() 10 | model.load_weights('../data_preprocess/vgg_face_weights.h5') 11 | vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output) 12 | return vgg_face_descriptor 13 | 14 | 15 | def build_model(): 16 | model = Sequential() 17 | model.add(ZeroPadding2D((1, 1), input_shape=(224, 224, 3))) 18 | model.add(Convolution2D(64, (3, 3), activation='relu')) 19 | model.add(ZeroPadding2D((1, 1))) 20 | model.add(Convolution2D(64, (3, 3), activation='relu')) 21 | model.add(MaxPooling2D((2, 2), strides=(2, 2))) 22 | 23 | model.add(ZeroPadding2D((1, 1))) 24 | model.add(Convolution2D(128, (3, 3), activation='relu')) 25 | model.add(ZeroPadding2D((1, 1))) 26 | model.add(Convolution2D(128, (3, 3), activation='relu')) 27 | model.add(MaxPooling2D((2, 2), strides=(2, 2))) 28 | 29 | model.add(ZeroPadding2D((1, 1))) 30 | model.add(Convolution2D(256, (3, 3), activation='relu')) 31 | model.add(ZeroPadding2D((1, 1))) 32 | model.add(Convolution2D(256, (3, 3), activation='relu')) 33 | model.add(ZeroPadding2D((1, 1))) 34 | model.add(Convolution2D(256, (3, 3), activation='relu')) 35 | model.add(MaxPooling2D((2, 2), strides=(2, 2))) 36 | 37 | model.add(ZeroPadding2D((1, 1))) 38 | model.add(Convolution2D(512, (3, 3), activation='relu')) 39 | model.add(ZeroPadding2D((1, 1))) 40 | model.add(Convolution2D(512, (3, 3), activation='relu')) 41 | model.add(ZeroPadding2D((1, 1))) 42 | model.add(Convolution2D(512, (3, 3), activation='relu')) 43 | model.add(MaxPooling2D((2, 2), strides=(2, 2))) 44 | 45 | model.add(ZeroPadding2D((1, 1))) 46 | model.add(Convolution2D(512, (3, 3), activation='relu')) 47 | model.add(ZeroPadding2D((1, 1))) 48 | model.add(Convolution2D(512, (3, 3), activation='relu')) 49 | model.add(ZeroPadding2D((1, 1))) 50 | model.add(Convolution2D(512, (3, 3), activation='relu')) 51 | model.add(MaxPooling2D((2, 2), strides=(2, 2))) 52 | 53 | model.add(Convolution2D(4096, (7, 7), activation='relu')) 54 | model.add(Dropout(0.5)) 55 | model.add(Convolution2D(4096, (1, 1), activation='relu')) 56 | model.add(Dropout(0.5)) 57 | model.add(Convolution2D(2622, (1, 1))) 58 | model.add(Flatten()) 59 | model.add(Activation('softmax')) 60 | 61 | return model 62 | 63 | 64 | def preprocess_image(image_path): 65 | img = image.load_img(image_path, target_size=(224, 224)) 66 | img = image.img_to_array(img) 67 | img = np.expand_dims(img, axis=0) 68 | img = preprocess_input(img) 69 | return img 70 | 
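Note that `get_feature_extractor()` intentionally truncates the network at its second-to-last layer: the final `softmax` classification layer is dropped and the flattened 2622-dimensional activation is used as the face descriptor, which is why the SPTAG index is built with `-d 2622`. The following hypothetical check (not part of the repo) assumes `vgg_face_weights.h5` has been downloaded into `data_preprocess/` and that the snippet is run from `index_construction/`.

```python
# check_descriptor_dim.py - hypothetical, run from index_construction/
import numpy as np
from vgg_model import get_feature_extractor, preprocess_image

extractor = get_feature_extractor()
print(extractor.output_shape)               # expected: (None, 2622)

img = preprocess_image("../test_img.jpeg")  # any image that can be resized to 224x224
vec = extractor.predict(img)[0].astype(np.float32)
print(vec.shape)                            # (2622,); matches `indexbuilder -d 2622`
```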
-------------------------------------------------------------------------------- /mm_images/.gitignore: -------------------------------------------------------------------------------- 1 | *.jpg 2 | *.png 3 | *.jpeg 4 | -------------------------------------------------------------------------------- /mm_images/faces/.gitignore: -------------------------------------------------------------------------------- 1 | *.jpg 2 | *.png 3 | *.jpeg 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flask 2 | pillow 3 | pymongo 4 | scikit-learn 5 | Tensorflow==1.15.2 6 | Keras 7 | face_recognition -------------------------------------------------------------------------------- /search_test_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nladuo/MMFinder/6650595007135c37f9bb0a1be9b8db23333545e3/search_test_result.jpg -------------------------------------------------------------------------------- /test_img.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nladuo/MMFinder/6650595007135c37f9bb0a1be9b8db23333545e3/test_img.jpeg -------------------------------------------------------------------------------- /web_demo/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | node_modules 3 | 4 | # local env files 5 | .env.local 6 | .env.*.local 7 | 8 | # Log files 9 | npm-debug.log* 10 | yarn-debug.log* 11 | yarn-error.log* 12 | 13 | # Editor directories and files 14 | .idea 15 | .vscode 16 | *.suo 17 | *.ntvs* 18 | *.njsproj 19 | *.sln 20 | *.sw? 21 | -------------------------------------------------------------------------------- /web_demo/README.md: -------------------------------------------------------------------------------- 1 | # image search web demo 2 | 3 | > MMFinder Image Search Web Demo. 4 | 5 | ## Project setup 6 | ``` 7 | npm install 8 | ``` 9 | 10 | ### Compiles and hot-reloads for development 11 | ``` 12 | npm run serve 13 | ``` 14 | 15 | ### Compiles and minifies for production 16 | ``` 17 | npm run build 18 | ``` 19 | 20 | ### Customize configuration 21 | See [Configuration Reference](https://cli.vuejs.org/config/). 
22 | -------------------------------------------------------------------------------- /web_demo/babel.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | presets: [ 3 | '@vue/cli-plugin-babel/preset' 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /web_demo/backend/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /web_demo/backend/SPTAG_rpc_search_client.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import rpyc 3 | 4 | 5 | class DataBean: 6 | def __init__(self, _id: str, vec: np.array): 7 | self._id = _id 8 | self.vec = vec 9 | if '\n' in _id: 10 | raise Exception("_id cannot contain \\n") 11 | 12 | if len(vec.shape) != 1: 13 | raise Exception("vec must be 1-d vector") 14 | 15 | if vec.dtype != np.float32: 16 | raise Exception("the dtype of vec must be np.float32") 17 | 18 | 19 | class SPTAG_RpcSearchClient: 20 | 21 | ALGO_BKT = "BKT" # SPTAG-BKT is advantageous in search accuracy in very high-dimensional data 22 | ALGO_KDT = "KDT" # SPTAG-KDT is advantageous in index building cost, 23 | 24 | DIST_L2 = "L2" 25 | DIST_Cosine = "Cosine" 26 | 27 | def __init__(self, host, port): 28 | c = rpyc.connect(host, port) 29 | c._config['sync_request_timeout'] = None 30 | self.proxy = c.root 31 | 32 | def search(self, beans: [DataBean], p_resultNum): 33 | _, vecs = self.__get_meta_and_vec_from_beans(beans) 34 | vecs_ = vecs.tolist() 35 | return self.proxy.search(vecs_, p_resultNum) 36 | 37 | def __get_meta_and_vec_from_beans(self, beans: [DataBean]): 38 | if len(beans) == 0: 39 | raise Exception("beans length cannot be zero!") 40 | 41 | if len(beans) > 1000: 42 | raise Exception("cannot add more than 1000 beans at once!") 43 | 44 | dim = beans[0].vec.shape[0] 45 | meta = "" 46 | vecs = np.zeros((len(beans), dim)) 47 | 48 | for i, bean in enumerate(beans): 49 | meta += bean._id + '\n' 50 | vecs[i] = bean.vec 51 | 52 | meta = meta.encode() 53 | return meta, vecs 54 | 55 | 56 | if __name__ == "__main__": 57 | client = SPTAG_RpcSearchClient("127.0.0.1", "8888") 58 | print("Test Search") 59 | q = DataBean(_id=f"s{0}", vec=0 * np.ones((10,), dtype=np.float32)) 60 | print(client.search([q], 3)) 61 | -------------------------------------------------------------------------------- /web_demo/backend/vgg_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras.layers import ZeroPadding2D, Convolution2D, MaxPooling2D, Dropout, Flatten, Activation 3 | from keras import Sequential, Model 4 | 5 | 6 | def get_feature_extractor(): 7 | model = build_model() 8 | model.load_weights('../data_preprocess/vgg_face_weights.h5') 9 | vgg_face_descriptor = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output) 10 | return vgg_face_descriptor 11 | 12 | 13 | def build_model(): 14 | model = Sequential() 15 | model.add(ZeroPadding2D((1, 1), input_shape=(224, 224, 3))) 16 | model.add(Convolution2D(64, (3, 3), activation='relu')) 17 | model.add(ZeroPadding2D((1, 1))) 18 | model.add(Convolution2D(64, (3, 3), activation='relu')) 19 | model.add(MaxPooling2D((2, 2), strides=(2, 2))) 20 | 21 | model.add(ZeroPadding2D((1, 1))) 22 | model.add(Convolution2D(128, (3, 3), activation='relu')) 23 | model.add(ZeroPadding2D((1, 1))) 24 | 
model.add(Convolution2D(128, (3, 3), activation='relu')) 25 | model.add(MaxPooling2D((2, 2), strides=(2, 2))) 26 | 27 | model.add(ZeroPadding2D((1, 1))) 28 | model.add(Convolution2D(256, (3, 3), activation='relu')) 29 | model.add(ZeroPadding2D((1, 1))) 30 | model.add(Convolution2D(256, (3, 3), activation='relu')) 31 | model.add(ZeroPadding2D((1, 1))) 32 | model.add(Convolution2D(256, (3, 3), activation='relu')) 33 | model.add(MaxPooling2D((2, 2), strides=(2, 2))) 34 | 35 | model.add(ZeroPadding2D((1, 1))) 36 | model.add(Convolution2D(512, (3, 3), activation='relu')) 37 | model.add(ZeroPadding2D((1, 1))) 38 | model.add(Convolution2D(512, (3, 3), activation='relu')) 39 | model.add(ZeroPadding2D((1, 1))) 40 | model.add(Convolution2D(512, (3, 3), activation='relu')) 41 | model.add(MaxPooling2D((2, 2), strides=(2, 2))) 42 | 43 | model.add(ZeroPadding2D((1, 1))) 44 | model.add(Convolution2D(512, (3, 3), activation='relu')) 45 | model.add(ZeroPadding2D((1, 1))) 46 | model.add(Convolution2D(512, (3, 3), activation='relu')) 47 | model.add(ZeroPadding2D((1, 1))) 48 | model.add(Convolution2D(512, (3, 3), activation='relu')) 49 | model.add(MaxPooling2D((2, 2), strides=(2, 2))) 50 | 51 | model.add(Convolution2D(4096, (7, 7), activation='relu')) 52 | model.add(Dropout(0.5)) 53 | model.add(Convolution2D(4096, (1, 1), activation='relu')) 54 | model.add(Dropout(0.5)) 55 | model.add(Convolution2D(2622, (1, 1))) 56 | model.add(Flatten()) 57 | model.add(Activation('softmax')) 58 | 59 | return model 60 | 61 | 62 | -------------------------------------------------------------------------------- /web_demo/backend/web_service.py: -------------------------------------------------------------------------------- 1 | from .vgg_model import get_feature_extractor 2 | from .web_utils import preprocess_image 3 | import face_recognition 4 | from PIL import Image 5 | from .SPTAG_rpc_search_client import SPTAG_RpcSearchClient, DataBean 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | graph = tf.Graph() 10 | with graph.as_default(): 11 | session = tf.Session() 12 | with session.as_default(): 13 | vgg_feature_extractor = get_feature_extractor() 14 | vgg_feature_extractor.predict(np.zeros((1, 224, 224, 3))) 15 | 16 | IMAGES_PATH = "../mm_images" 17 | UPLOAD_DIR = './dist/static/upload_images' 18 | 19 | 20 | def get_face_and_save(filename): 21 | img_path = f"{UPLOAD_DIR}/{filename}" 22 | image = face_recognition.load_image_file(img_path) 23 | locations = face_recognition.face_locations(image) 24 | if len(locations) == 1: # save the face of mm 25 | top, right, bottom, left = locations[0] 26 | face_image = image[top:bottom, left:right] 27 | pil_image = Image.fromarray(face_image) 28 | with open(f"{UPLOAD_DIR}/face-{filename}", "wb") as f: 29 | pil_image.save(f) 30 | return len(locations) 31 | 32 | 33 | def get_face_representation(filename): 34 | face_img_path = f"{UPLOAD_DIR}/face-{filename}" 35 | img = preprocess_image(face_img_path) 36 | with graph.as_default(): 37 | with session.as_default(): 38 | features = vgg_feature_extractor.predict(img) 39 | vec = features[0].tolist() 40 | return vec 41 | 42 | 43 | def call_SPTAG_search(vec): 44 | vec = np.array(vec, dtype=np.float32) 45 | search_client = SPTAG_RpcSearchClient("127.0.0.1", "8888") 46 | bean = DataBean(_id="", vec=vec) 47 | results = search_client.search([bean], 30) 48 | result_images = [] 49 | for item in results[0]: 50 | k = [i for i in item.keys()][0] 51 | print(k, item[k]) 52 | result_images.append(k) 53 | return result_images 54 | 
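Outside of Flask, the three helpers above can be chained directly to reproduce the whole search flow. This is a minimal sketch (hypothetical, not part of the repo), assuming it is run from `web_demo/`, the SPTAG RPC service is listening on 127.0.0.1:8888, the VGG weights are in `data_preprocess/`, and an image with the made-up name `example.jpg` has already been placed in `./dist/static/upload_images/`.

```python
# search_flow_example.py - hypothetical, mirrors what main.py does per request
from backend.web_service import get_face_and_save, get_face_representation, call_SPTAG_search

filename = "example.jpg"  # assumed to exist in ./dist/static/upload_images/

if get_face_and_save(filename) == 1:           # detect and crop exactly one face
    vec = get_face_representation(filename)    # 2622-d VGG-Face descriptor of the crop
    matches = call_SPTAG_search(vec)           # nearest-neighbour image paths from SPTAG
    print(matches[:5])
else:
    print("need exactly one face in the uploaded image")
```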
--------------------------------------------------------------------------------
/web_demo/backend/web_utils.py:
--------------------------------------------------------------------------------
1 | from keras.preprocessing import image
2 | from keras.applications.vgg19 import preprocess_input
3 | import numpy as np
4 | import os
5 | import uuid
6 | from PIL import Image
7 |
8 |
9 | def get_file_extension(filename):
10 |     if "." not in filename:
11 |         return ""
12 |
13 |     return filename.rsplit('.', 1)[1].lower()
14 |
15 |
16 | def allowed_file(filename):
17 |     ALLOWED_EXTENSIONS = [
18 |         "png",
19 |         "jpg",
20 |         "jpeg",
21 |     ]
22 |     return get_file_extension(filename) in ALLOWED_EXTENSIONS
23 |
24 |
25 | def preprocess_image(image_path):
26 |     img = image.load_img(image_path, target_size=(224, 224))
27 |     img = image.img_to_array(img)
28 |     img = np.expand_dims(img, axis=0)
29 |     img = preprocess_input(img)
30 |     return img
31 |
32 |
33 | def save_upload_file(original_name, file):
34 |     UPLOAD_DIR = "./dist/static/upload_images"
35 |
36 |     _, ext = os.path.splitext(original_name)
37 |
38 |     encrypted_name = str(uuid.uuid4()) + ext
39 |
40 |     print(encrypted_name)
41 |
42 |     file.save(os.path.join(UPLOAD_DIR, encrypted_name))
43 |     return encrypted_name
44 |
45 |
46 | def get_image_scale(image_name):
47 |     path = f"../mm_images/{image_name}"
48 |     size = Image.open(path).size
49 |     return size[1] / size[0]
50 |
51 |
52 | def re_arrange_images(image_names):
53 |     """
54 |     The CSS lays the result images out column by column, but they are viewed row by row, so the images are re-ordered here.
55 |     (This really belongs in the front end, but I'm not good at CSS, so it is handled in the backend code instead.)
56 |     """
57 |     image_scales = [get_image_scale(name) for name in image_names]
58 |     columns = {}
59 |     scales = []
60 |     for i in range(5):
61 |         columns[i] = []
62 |         scales.append(0)
63 |
64 |     for i, name in enumerate(image_names):
65 |         which_col = scales.index(min(scales))
66 |         image_scale = image_scales[i]
67 |
68 |         columns[which_col].append(name)
69 |         scales[which_col] += image_scale
70 |
71 |     results = []
72 |     for i in range(5):
73 |         results += columns[i]
74 |
75 |     return results
76 |
--------------------------------------------------------------------------------
/web_demo/dist/index.html:
--------------------------------------------------------------------------------
1 | MMFinder Web Demo
--------------------------------------------------------------------------------
/web_demo/dist/static/fonts/element-icons.535877f5.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nladuo/MMFinder/6650595007135c37f9bb0a1be9b8db23333545e3/web_demo/dist/static/fonts/element-icons.535877f5.woff
--------------------------------------------------------------------------------
/web_demo/dist/static/fonts/element-icons.732389de.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nladuo/MMFinder/6650595007135c37f9bb0a1be9b8db23333545e3/web_demo/dist/static/fonts/element-icons.732389de.ttf
--------------------------------------------------------------------------------
/web_demo/dist/static/js/app.3d295f3c.js:
--------------------------------------------------------------------------------
(minified webpack production bundle and accompanying source map; machine-generated output of `npm run build`, contents omitted)
--------------------------------------------------------------------------------
/web_demo/dist/static/upload_images/.gitignore:
--------------------------------------------------------------------------------
1 | *.jpg
2 | *.png
3 | *.jpeg
--------------------------------------------------------------------------------
/web_demo/main.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, send_from_directory, request
2 | import json
3 | from backend.web_utils import allowed_file, save_upload_file, re_arrange_images
4 | from backend.web_service import get_face_and_save, get_face_representation, call_SPTAG_search, UPLOAD_DIR
5 |
import os 6 | 7 | 8 | app = Flask(__name__, static_folder='dist') 9 | 10 | 11 | @app.route('/') 12 | def index(): 13 | return app.send_static_file('index.html') 14 | 15 | 16 | @app.route('/static/') 17 | def serve_static(path): 18 | return send_from_directory('./dist/static', path) 19 | 20 | 21 | @app.route('/api/get_images/') 22 | def serve_images(path): 23 | return send_from_directory('../mm_images', path) 24 | 25 | 26 | @app.route('/api/get_upload_images/') 27 | def serve_upload_images(path): 28 | return send_from_directory('./dist/static/upload_images', path) 29 | 30 | 31 | @app.route('/api/upload_image', methods=["POST"]) 32 | def api_upload_image(): 33 | if request.method == 'POST': 34 | if 'file' not in request.files: 35 | return json.dumps({'success': False, 'msg': '请求参数错误'}) 36 | file = request.files['file'] 37 | if file.filename == '': 38 | return json.dumps({'success': False, 'msg': '没选择文件'}) 39 | else: 40 | if file and allowed_file(file.filename): 41 | origin_file_name = file.filename 42 | # 保存文件 43 | filename = save_upload_file(origin_file_name, file) 44 | face_count = get_face_and_save(filename) 45 | if face_count == 0: 46 | os.remove(f"{UPLOAD_DIR}/{filename}") 47 | return json.dumps({'success': False, 'msg': '未检测出图片中的人脸'}) 48 | elif face_count > 1: 49 | os.remove(f"{UPLOAD_DIR}/{filename}") 50 | return json.dumps({'success': False, 'msg': '检测出图片不止一张人脸(必须保证上传图片只有一张人脸)'}) 51 | 52 | return json.dumps({'success': True, 'filename': filename, 'msg': '成功'}) 53 | else: 54 | return json.dumps({'success': False, 'msg': '文件类型错误,请上传png,jpg,jpeg格式的图片'}) 55 | 56 | 57 | @app.route('/api/search') 58 | def api_search_image(): 59 | filename = request.args.get("filename") 60 | face_img_path = f"{UPLOAD_DIR}/face-{filename}" 61 | if not os.path.exists(face_img_path): 62 | return json.dumps({ 63 | "success": False, 64 | "msg": "请求图片不存在,参数错误" 65 | }) 66 | 67 | vec = get_face_representation(filename) 68 | image_names = call_SPTAG_search(vec) 69 | 70 | return json.dumps({ 71 | "success": True, 72 | "data": re_arrange_images(image_names) 73 | }) 74 | 75 | 76 | if __name__ == '__main__': 77 | app.run(port=3889, debug=True) 78 | 79 | -------------------------------------------------------------------------------- /web_demo/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "web_demo", 3 | "version": "0.1.0", 4 | "private": true, 5 | "scripts": { 6 | "serve": "vue-cli-service serve", 7 | "build": "vue-cli-service build" 8 | }, 9 | "dependencies": { 10 | "core-js": "^3.3.2", 11 | "element-ui": "^2.4.5", 12 | "jquery": "^3.4.1", 13 | "vue": "^2.6.10" 14 | }, 15 | "devDependencies": { 16 | "@vue/cli-plugin-babel": "^4.0.0", 17 | "@vue/cli-service": "^4.0.0", 18 | "node-sass": "^4.9.2", 19 | "sass-loader": "^7.0.3", 20 | "vue-cli-plugin-element": "^1.0.1", 21 | "vue-template-compiler": "^2.6.10" 22 | }, 23 | "postcss": { 24 | "plugins": { 25 | "autoprefixer": {} 26 | } 27 | }, 28 | "browserslist": [ 29 | "> 1%", 30 | "last 2 versions" 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /web_demo/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | MMFinder Web Demo 8 | 9 | 10 | 13 |
14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /web_demo/public/static/upload_images/.gitignore: -------------------------------------------------------------------------------- 1 | *.jpg 2 | *.png 3 | *.jpeg -------------------------------------------------------------------------------- /web_demo/src/App.vue: -------------------------------------------------------------------------------- 1 | 16 | 17 | 27 | 28 | 36 | 37 | -------------------------------------------------------------------------------- /web_demo/src/actions/api.js: -------------------------------------------------------------------------------- 1 | import $ from "jquery"; 2 | 3 | export default { 4 | get(url, data, _emit) { 5 | url = '/api' + url; 6 | 7 | $.ajax({ 8 | type : "GET", 9 | url : url, 10 | dataType: 'json', 11 | data : data, 12 | async: true, 13 | success(data) { 14 | _emit(data); 15 | }, 16 | error() { 17 | _emit(null); 18 | } 19 | }); 20 | }, 21 | 22 | post(url, data, _emit) { 23 | url = '/api' + url; 24 | 25 | $.ajax({ 26 | type : "POST", 27 | url : url, 28 | dataType: 'json', 29 | data : data, 30 | async: true, 31 | success(data) { 32 | _emit(data); 33 | }, 34 | error() { 35 | _emit(null); 36 | } 37 | }); 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /web_demo/src/components/SearchView.vue: -------------------------------------------------------------------------------- 1 | 42 | 43 | 103 | 104 | 110 | 160 | -------------------------------------------------------------------------------- /web_demo/src/element-variables.scss: -------------------------------------------------------------------------------- 1 | /* 2 | Write your variables here. All available variables can be 3 | found in element-ui/packages/theme-chalk/src/common/var.scss. 4 | For example, to overwrite the theme color: 5 | */ 6 | $--color-primary: teal; 7 | 8 | /* icon font path, required */ 9 | $--font-path: '~element-ui/lib/theme-chalk/fonts'; 10 | 11 | @import "~element-ui/packages/theme-chalk/src/index"; 12 | -------------------------------------------------------------------------------- /web_demo/src/main.js: -------------------------------------------------------------------------------- 1 | import Vue from 'vue' 2 | import App from './App.vue' 3 | import './plugins/element.js' 4 | 5 | Vue.config.productionTip = false 6 | 7 | new Vue({ 8 | render: h => h(App), 9 | }).$mount('#app') 10 | -------------------------------------------------------------------------------- /web_demo/src/plugins/element.js: -------------------------------------------------------------------------------- 1 | import Vue from 'vue' 2 | import Element from 'element-ui' 3 | import '../element-variables.scss' 4 | 5 | Vue.use(Element) 6 | -------------------------------------------------------------------------------- /web_demo/vue.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | assetsDir: "static", 3 | devServer: { 4 | proxy: { 5 | '/api': { 6 | target: 'http://127.0.0.1:3889', // 接口的域名 7 | changeOrigin: true, 8 | secure: false 9 | } 10 | } 11 | } 12 | } --------------------------------------------------------------------------------
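Once the demo site is running with `python3 main.py`, the two JSON endpoints defined in `web_demo/main.py` can also be exercised without the Vue front end. The sketch below is hypothetical and not part of the repo: it assumes the `requests` package (which is not listed in `requirements.txt`) and uses the bundled `test_img.jpeg` as the query image.

```python
# api_smoke_test.py - hypothetical client for the Flask demo API on port 3889
import requests

BASE = "http://localhost:3889"

# 1) upload an image that contains exactly one face
with open("test_img.jpeg", "rb") as f:
    r = requests.post(f"{BASE}/api/upload_image", files={"file": ("test_img.jpeg", f)})
upload = r.json()
print(upload)

# 2) if the upload succeeded, search for similar faces by the returned filename
if upload.get("success"):
    r = requests.get(f"{BASE}/api/search", params={"filename": upload["filename"]})
    print(r.json()["data"][:5])  # first few matched image paths
```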