├── .gitignore ├── LICENSE ├── README.md ├── app ├── __init__.py ├── main │ ├── __init__.py │ ├── errors.py │ ├── forms.py │ └── views.py ├── static │ ├── css │ │ ├── base.css │ │ ├── index.css │ │ └── result.css │ ├── img │ │ ├── close.png │ │ ├── demo.gif │ │ ├── dropfocus.gif │ │ ├── favicon.ico │ │ ├── logo.png │ │ ├── search.png │ │ └── uploading.gif │ ├── js │ │ ├── base.js │ │ ├── index.js │ │ ├── jquery-3.3.1.js │ │ └── result.js │ └── uploads │ │ └── .gitkeep └── templates │ ├── 403.html │ ├── 404.html │ ├── 500.html │ ├── base.html │ ├── csrf_error.html │ ├── index.html │ └── result.html ├── config.py ├── data └── sift │ └── .gitkeep ├── requirements.txt ├── sotu.py ├── utils.py └── vision ├── __init__.py ├── bof.py ├── he.py ├── inv.py ├── sift.py ├── ukbench.py └── wgc.py /.gitignore: -------------------------------------------------------------------------------- 1 | app/static/uploads/ 2 | data/ 3 | venv/ 4 | .vscode/ 5 | __pycache__/ 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 zy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sotu 2 | 3 | 利用Flask[^1]框架实现的基于内容的图像检索(Content Based Image Retrieval, CBIR)系统 4 | 5 | ## 要求 6 | 7 | `python`版本:`3.5.2` 8 | 9 | `pip`版本:`9.0.3` 10 | 11 | 相关组件要求见`requirements.txt` 12 | 13 | ## 初始化 14 | 15 | 程序在虚拟环境下运行,首先确保安装`virtualenv`,之后创建并激活虚拟环境,接着在虚拟环境下安装所有必要的组件,并设置相关环境变量的值. 16 | 17 | 在运行应用之前,需要提取图像特征,其中初始数据集使用的是ukbench[^2]数据集的前4096幅图,提取的特征是SIFT的改进版,即RootSIFT特征. 18 | 19 | 在Linux下所有的命令为: 20 | 21 | ```sh 22 | # 克隆仓库并进入相应目录 23 | $ git clone https://github.com/zysite/SoTu.git 24 | $ cd SoTu 25 | 26 | # 创建并激活虚拟环境,然后安装所有必要的组件 27 | $ virtualenv venv 28 | $ . venv/bin/activate 29 | $ pip install -r requirements.txt 30 | 31 | # 设置环境变量FLASK_APP的值 32 | $ export FLASK_APP=sotu.py 33 | 34 | # 提取图像特征 35 | $ flask extract 36 | ``` 37 | 38 | 其中激活虚拟环境的命令在Windows的环境下有所不同: 39 | 40 | ```sh 41 | $ venv\Scripts\activate 42 | ``` 43 | 44 | 如果用cmd设置环境变量,需要用`set`代替上面的`export`. 45 | 46 | 如果用powershell设置,则命令为: 47 | 48 | ```sh 49 | $ $env:FLASK_APP="sotu.py" 50 | ``` 51 | 52 | 最后退出虚拟环境的命令为: 53 | 54 | ```sh 55 | $ deactivate 56 | ``` 57 | 58 | ## 运行 59 | 60 | 检索系统的实现基于特征袋模型(Bag of Features, BoF),并在此基础上利用了汉明嵌入(Hamming Embedding, HE)方法、弱几何一致性(Weak Geometric Consistency, WGC)约束和基于RANSAC算法的几何重排.
61 | 62 | 运行web应用使用下面的命令,可以指定主机和端口: 63 | 64 | ```sh 65 | $ flask run -h localhost -p 8080 66 | ``` 67 | 68 | 检索系统支持文件上传、拖拽上传和URL上传三种图片上传方式. 69 | 70 | ![demo](app/static/img/demo.gif) 71 | 72 | ## 评估 73 | 74 | 这里使用的评估指标是mAP(mean Average Precision)指标,执行评估使用下面的命令: 75 | 76 | ```sh 77 | $ flask evaluate 78 | ``` 79 | 80 | 不同方法的评价结果如下表,其中BoF模型设置的聚类数为5000,HE的阈值**ht**为23: 81 | 82 | | methods | mAP | 83 | | :----------------: | :------: | 84 | | *BoF* | 0.713298 | 85 | | *BoF+HE* | 0.878229 | 86 | | *BoF+HE+Reranking* | 0.898573 | 87 | 88 | Jégou提到对于ukbench数据集而言,WGC方法的效果较差[^3],因此评估没有采用WGC方法. 89 | 90 | ## 参考文献 91 | 92 | * Lowe D G. Distinctive image features from scale-invariant keypoints[J]. International journal of computer vision, 2004, 60(2): 91-110. 93 | * Arandjelović R, Zisserman A. Three things everyone should know to improve object retrieval [C]. IEEE Conference on Computer Vision and Pattern Recognition. Providence, Rhode Island, USA, 2012: 2911-2918. 94 | * Sivic J, Zisserman A. Video Google: A text retrieval approach to object matching in videos [C]. IEEE International Conference on Computer Vision, 2003, 2: 1470-1477. 95 | * Jégou H, Douze M, Schmid C. Hamming Embedding and Weak Geometric Consistency for Large Scale Image Search[C]. European Conference on Computer Vision (ECCV), 2008, LNCS 5302: 304-317. 96 | * Jégou H, Douze M, Schmid C. Improving bag-of-features for large scale image search[J]. International journal of computer vision, 2010, 87(3): 316-336. 97 | * Zhao W L, Wu X, Ngo C W. On the Annotation of Web Videos by Efficient Near-Duplicate Search[J]. IEEE Transactions on Multimedia, 2010, 12(5):448-461. 98 | * Philbin J, Chum O, Isard M, et al. Object retrieval with large vocabularies and fast spatial matching [C]. IEEE Conference on Computer Vision and Pattern Recognition. Minneapolis, USA, 2007: 1-8.
99 | 100 | 101 | 102 | [^1]: http://flask.pocoo.org/docs/0.12/ 103 | [^2]: https://archive.org/download/ukbench/ukbench.zip 104 | [^3]: https://hal.inria.fr/inria-00514760/document 105 | 106 | -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from flask import Flask 4 | from flask_wtf.csrf import CSRFProtect 5 | 6 | from config import config 7 | from vision import BoF 8 | 9 | bof = BoF() 10 | csrf = CSRFProtect() 11 | 12 | 13 | def create_app(config_name): 14 | app = Flask(__name__) 15 | app.config["SECRET_KEY"] = "12345678" 16 | app.config.from_object(config[config_name]) 17 | config[config_name].init_app(app) 18 | 19 | bof.init_app(app) 20 | csrf.init_app(app) 21 | # 注册蓝本 22 | from .main import main 23 | app.register_blueprint(main) 24 | return app 25 | -------------------------------------------------------------------------------- /app/main/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from flask import Blueprint 4 | 5 | main = Blueprint('main', __name__) 6 | # 避免循环导入依赖 7 | from . import views, errors 8 | -------------------------------------------------------------------------------- /app/main/errors.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from flask import render_template 4 | from flask_wtf.csrf import CSRFError 5 | 6 | from . 
import main 7 | 8 | 9 | @main.errorhandler(CSRFError) 10 | def handle_csrf_error(e): 11 | return render_template('csrf_error.html', reason=e.description), 400 12 | 13 | 14 | @main.app_errorhandler(403) 15 | def forbidden(e): 16 | return render_template('403.html'), 403 17 | 18 | 19 | @main.app_errorhandler(404) 20 | def page_not_found(e): 21 | return render_template('404.html'), 404 22 | 23 | 24 | @main.app_errorhandler(500) 25 | def internal_server_error(e): 26 | return render_template('500.html'), 500 27 | -------------------------------------------------------------------------------- /app/main/forms.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from flask_wtf import FlaskForm 4 | from flask_wtf.file import FileAllowed, FileField, FileRequired 5 | from wtforms import StringField 6 | from wtforms.validators import DataRequired, Regexp 7 | 8 | 9 | class ImgForm(FlaskForm): 10 | fileimg = FileField(validators=[ 11 | FileRequired(), 12 | FileAllowed(['png', 'jpg', 'jpeg', 'gif']) 13 | ]) 14 | 15 | 16 | class URLForm(FlaskForm): 17 | txturl = StringField(validators=[ 18 | DataRequired(), 19 | Regexp(r'(?:http\:|https\:)?\/\/.*\.(?:png|jpg|jpeg|gif)$', 20 | message="Invalid image url.") 21 | ]) 22 | -------------------------------------------------------------------------------- /app/main/views.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import posixpath 5 | 6 | from flask import (current_app, flash, redirect, render_template, request, 7 | send_from_directory, url_for) 8 | from werkzeug.utils import secure_filename 9 | 10 | from utils import download 11 | 12 | from . import main 13 | from .. 
import bof 14 | from .forms import ImgForm, URLForm 15 | 16 | 17 | @main.route('/', methods=['GET', 'POST']) 18 | def index(): 19 | imgform = ImgForm() 20 | urlform = URLForm() 21 | 22 | if imgform.validate_on_submit(): 23 | file = imgform.fileimg.data 24 | filename = secure_filename(file.filename) 25 | filepath = os.path.join(current_app.config['UPLOAD_DIR'], filename) 26 | if not os.path.exists(filepath): 27 | file.save(filepath) 28 | return redirect(url_for('.result', filename=filename)) 29 | elif urlform.validate_on_submit(): 30 | url = urlform.txturl.data 31 | filename = secure_filename(url.split('/')[-1]) 32 | filepath = os.path.join(current_app.config['UPLOAD_DIR'], filename) 33 | download(url, current_app.config['UPLOAD_DIR'], filename) 34 | if not os.path.exists(filepath): 35 | flash('无法取回指定URL的图片') 36 | return redirect(url_for('.index')) 37 | else: 38 | return redirect(url_for('.result', filename=filename)) 39 | return render_template('index.html') 40 | 41 | 42 | @main.route('/result', methods=['GET']) 43 | def result(): 44 | filename = request.args.get('filename') 45 | uri = os.path.join(current_app.config['UPLOAD_DIR'], filename) 46 | images = bof.match(uri, top_k=20) 47 | return render_template('result.html', filename=filename, images=images) 48 | 49 | 50 | @main.route('/images/') 51 | def download_file(uri): 52 | return send_from_directory(current_app.config['BASE_DIR'], 53 | uri, as_attachment=True) 54 | -------------------------------------------------------------------------------- /app/static/css/base.css: -------------------------------------------------------------------------------- 1 | html, body { 2 | margin: 0; 3 | padding: 0; 4 | height: 100%; 5 | border: 0; 6 | font-size: 14px; 7 | font-family: 'Microsoft Yahei', '微软雅黑', '宋体', Arial; 8 | } 9 | 10 | body h1 { 11 | margin-left: 10px; 12 | color: white; 13 | } 14 | 15 | .wrapper { 16 | position: relative; 17 | display: flex; 18 | flex-direction: column; 19 | min-width: 1280px; 20 | height: 
100%; 21 | background-color: #6d6d6d; 22 | } 23 | 24 | .wrapper .header { 25 | height: 60px; 26 | } 27 | .header > a { 28 | float: left; 29 | padding: 5px; 30 | height: 50px; 31 | color: white; 32 | text-decoration: none; 33 | font-weight: bold; 34 | font-size: 35px; 35 | font-family: sans-serif; 36 | line-height: 100%; 37 | } 38 | .header > a img { 39 | vertical-align: text-bottom; 40 | } 41 | 42 | .wrapper .content { 43 | flex: 1; 44 | } 45 | .wrapper .footer { 46 | height: 40px; 47 | } -------------------------------------------------------------------------------- /app/static/css/index.css: -------------------------------------------------------------------------------- 1 | .content .container { 2 | position: absolute; 3 | top: 225px; 4 | left: 50%; 5 | z-index: 8; 6 | display: block; 7 | margin-left: -350px; 8 | width: auto; 9 | height: auto; 10 | color: black; 11 | } 12 | 13 | .container .forms { 14 | position: relative; 15 | margin: 36px 10px 12px 10px; 16 | height: 54px; 17 | } 18 | 19 | .forms #imgform, 20 | .forms #urlform { 21 | display: inline-block; 22 | } 23 | 24 | #imgform { 25 | vertical-align: top; 26 | } 27 | #imgform #upload { 28 | display: inline-block; 29 | width: 84px; 30 | height: 54px; 31 | background-color: white; 32 | text-align: center; 33 | line-height: 54px; 34 | cursor: pointer; 35 | } 36 | #imgform input[type='file'] { 37 | display: none; 38 | } 39 | 40 | #urlform input[type='text'] { 41 | margin-left: 15px; 42 | padding: 4px 40px 4px 10px; 43 | width: 468px; 44 | height: 42px; 45 | outline: none; 46 | border: 2px solid white; 47 | border-radius: 2px; 48 | background-color: transparent; 49 | color: white; 50 | vertical-align: top; 51 | } 52 | #urlform input[type='text'].warn { 53 | border: 2px solid red; 54 | } 55 | #urlform input[type=text]::-ms-clear { 56 | display: none; 57 | } 58 | #urlform ::placeholder { 59 | color: white; 60 | opacity: 1; 61 | } 62 | #urlform :-ms-input-placeholder { 63 | color: white; 64 | } 65 | #urlform 
::-ms-input-placeholder { 66 | color: white; 67 | } 68 | 69 | #urlform #btnsubmit { 70 | margin: 15px 15px 15px -40px; 71 | width: 24px; 72 | height: 24px; 73 | outline: none; 74 | border: 0 none; 75 | background-color: transparent; 76 | background-image: url(../img/search.png); 77 | cursor: pointer; 78 | } 79 | 80 | .forms #btnclose { 81 | display: inline-block; 82 | margin: 18px 0 18px 25px; 83 | width: 18px; 84 | height: 18px; 85 | outline: none; 86 | border: 0 none; 87 | background-color: transparent; 88 | background-image: url(../img/close.png); 89 | vertical-align: top; 90 | cursor: pointer; 91 | } 92 | 93 | .container .dragtip { 94 | margin: 0 auto 10px auto; 95 | width: 250px; 96 | height: 45px; 97 | border-bottom: 3px solid white; 98 | color: white; 99 | text-align: center; 100 | font-size: 20px; 101 | line-height: 45px; 102 | } 103 | 104 | .container .dropzone, 105 | .container .uploadtip { 106 | position: absolute; 107 | top: 0; 108 | left: 0; 109 | width: 100%; 110 | height: 100%; 111 | text-align: center; 112 | pointer-events: none; 113 | } 114 | 115 | .dropzone { 116 | z-index: 2; 117 | background: rgba(255, 255, 255, 0.7); 118 | } 119 | .dropzone span { 120 | position: absolute; 121 | top: 50%; 122 | transform: translate(-50%, -50%); 123 | } 124 | 125 | .dropzone .dropfocus { 126 | position: absolute; 127 | display: block; 128 | width: 15px; 129 | height: 15px; 130 | background: url(../img/dropfocus.gif) no-repeat; 131 | } 132 | .dropzone .focus_top_left { 133 | top: 16%; 134 | left: 36%; 135 | background-position: top left; 136 | } 137 | .dropzone .focus_bottom_left { 138 | bottom: 16%; 139 | left: 36%; 140 | background-position: bottom left; 141 | } 142 | .dropzone .focus_top_right { 143 | top: 16%; 144 | right: 36%; 145 | background-position: top right; 146 | } 147 | .dropzone .focus_bottom_right { 148 | right: 36%; 149 | bottom: 16%; 150 | background-position: bottom right; 151 | } 152 | 153 | .uploadtip { 154 | z-index: 5; 155 | 
background-color: #f5f5f5; 156 | font-size: 16px; 157 | line-height: 160px; 158 | opacity: 1; 159 | } 160 | .uploadtip img { 161 | margin-right: 8px; 162 | vertical-align: middle; 163 | } 164 | -------------------------------------------------------------------------------- /app/static/css/result.css: -------------------------------------------------------------------------------- 1 | html, body, .wrapper { 2 | min-height: 100vh; 3 | height: auto; 4 | } 5 | 6 | body { 7 | display: none; 8 | } 9 | 10 | .image-wrapper { 11 | margin: 20px; 12 | min-width: 400px; 13 | } 14 | 15 | .image-target img { 16 | height: 100px; 17 | } 18 | 19 | .image-target .image-info { 20 | display: inline-block; 21 | margin-left: 20px; 22 | color: white; 23 | vertical-align: top; 24 | font-size: 16px; 25 | } 26 | 27 | .image-tip { 28 | color: white; 29 | font-size: 20px; 30 | line-height: 40px; 31 | } 32 | .image-tip span { 33 | border-bottom: 2px solid white; 34 | } 35 | 36 | .image-grid { 37 | display: flex; 38 | display: -webkit-flex; 39 | overflow: hidden; 40 | -webkit-flex-wrap: wrap; 41 | flex-wrap: wrap; 42 | max-height: 100%; 43 | } 44 | 45 | .image-grid img { 46 | flex: 1 1 auto; 47 | min-width: 100%; 48 | max-width: 100%; 49 | height: 150px; 50 | vertical-align: middle; 51 | 52 | object-fit: cover; 53 | } 54 | /* ugly but useful */ 55 | .image-grid::after { 56 | flex-grow: 1000; 57 | content: ''; 58 | } 59 | 60 | .image-grid .image-item { 61 | position: relative; 62 | flex-grow: 1; 63 | margin-top: 10px; 64 | margin-right: 10px; 65 | } 66 | 67 | .image-item .image-overlay { 68 | position: absolute; 69 | bottom: 0; 70 | width: 100%; 71 | background: rgba(0, 0, 0, 0.5); /* Black see-through */ 72 | color: white; 73 | text-align: center; 74 | font-size: small; 75 | opacity:0; 76 | transition: .5s ease; 77 | } 78 | 79 | .image-item:hover .image-overlay { 80 | opacity: 1; 81 | } -------------------------------------------------------------------------------- /app/static/img/close.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/willard-yuan/SoTu/3e39739764b498d057bab1e627da3a559e5faaab/app/static/img/close.png -------------------------------------------------------------------------------- /app/static/img/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willard-yuan/SoTu/3e39739764b498d057bab1e627da3a559e5faaab/app/static/img/demo.gif -------------------------------------------------------------------------------- /app/static/img/dropfocus.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willard-yuan/SoTu/3e39739764b498d057bab1e627da3a559e5faaab/app/static/img/dropfocus.gif -------------------------------------------------------------------------------- /app/static/img/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willard-yuan/SoTu/3e39739764b498d057bab1e627da3a559e5faaab/app/static/img/favicon.ico -------------------------------------------------------------------------------- /app/static/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willard-yuan/SoTu/3e39739764b498d057bab1e627da3a559e5faaab/app/static/img/logo.png -------------------------------------------------------------------------------- /app/static/img/search.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willard-yuan/SoTu/3e39739764b498d057bab1e627da3a559e5faaab/app/static/img/search.png -------------------------------------------------------------------------------- /app/static/img/uploading.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/willard-yuan/SoTu/3e39739764b498d057bab1e627da3a559e5faaab/app/static/img/uploading.gif -------------------------------------------------------------------------------- /app/static/js/base.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function () { 2 | $('body').children().not('.wrapper').remove(); 3 | }); -------------------------------------------------------------------------------- /app/static/js/index.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function () { 2 | $("#fileimg").change(function (e) { 3 | var file = $(this)[0].files[0]; 4 | if (validate_img(file)) 5 | submit($("#imgform")); 6 | }); 7 | 8 | $("#btnsubmit").click(function () { 9 | var url = $("#txturl").val(); 10 | if (validate_url(url)) 11 | submit($("#urlform")); 12 | }); 13 | 14 | $(document).keypress(function (e) { 15 | var key = e.which; 16 | // 回车事件触发提交按钮 17 | if (key == 13) { 18 | e.preventDefault(); 19 | $("#btnsubmit").click(); 20 | } 21 | }); 22 | 23 | $(".container").on({ 24 | dragenter: function (e) { 25 | e.stopPropagation(); 26 | e.preventDefault(); 27 | }, 28 | dragover: function (e) { 29 | e.stopPropagation(); 30 | e.preventDefault(); 31 | }, 32 | drop: function (e) { 33 | e.stopPropagation(); 34 | e.preventDefault(); 35 | $("#dropzone").hide(); 36 | $("#fileimg")[0].files = e.originalEvent.dataTransfer.files; 37 | if (validate_img($("#fileimg")[0].files[0])) 38 | submit($("#imgform")); 39 | } 40 | }); 41 | $(document).on({ 42 | dragenter: function (e) { 43 | e.stopPropagation(); 44 | e.preventDefault(); 45 | $("#dropzone").show(); 46 | }, 47 | dragover: function (e) { 48 | e.stopPropagation(); 49 | e.preventDefault(); 50 | }, 51 | dragleave: function (e) { 52 | e.stopPropagation(); 53 | e.preventDefault(); 54 | if (e.clientX <= 0 || 55 | e.clientX >= $(window).width() || 56 | e.clientY <= 0 || 57 | e.clientY >= 
$(window).height()) 58 | $("#dropzone").hide(); 59 | }, 60 | drop: function (e) { 61 | e.stopPropagation(); 62 | e.preventDefault(); 63 | $("#dropzone").hide(); 64 | } 65 | }); 66 | 67 | $("#btnclose").click(function () { 68 | $("#txturl").val(""); 69 | }); 70 | }); 71 | 72 | function validate_img(file) { 73 | var type = file['type']; 74 | if (type.split('/')[0] != 'image') { 75 | alert("只接受图片格式的文件"); 76 | return false; 77 | } 78 | else if (file.size >= 3 * 1024 * 1024) { 79 | alert("请上传小于3M的图片"); 80 | return false; 81 | } 82 | return true; 83 | } 84 | 85 | function validate_url(url) { 86 | var imgregex = /(http(s?):)([/|.|\w|\s|-])*\.(?:jpg|jpeg|png|gif)/g; 87 | 88 | if (!url) { 89 | $("#txturl").addClass("warn"); 90 | setTimeout(function () { 91 | $("#txturl").removeClass("warn"); 92 | }, 2e3); 93 | return false; 94 | } 95 | else if (url.length > 1000) { 96 | alert("URL长度不超过1000"); 97 | return false; 98 | } 99 | else if (!imgregex.test(url)) { 100 | alert("图片URL不合法"); 101 | return false; 102 | } 103 | return true; 104 | } 105 | 106 | function submit(form) { 107 | $("#uploadtip").show(); 108 | try { 109 | form.submit(); 110 | } 111 | catch (err) { 112 | alert(err); 113 | $("#uploadtip").hide(); 114 | } 115 | } -------------------------------------------------------------------------------- /app/static/js/result.js: -------------------------------------------------------------------------------- 1 | $(window).on('load', function () { 2 | if ($(".image-target img").length) { 3 | var size = get_size($(".image-target img")); 4 | $(".image-info").append("

图片尺寸:
" + size.width + " × " + size.height + "

"); 5 | } 6 | $(".image-grid .image-item").each(function () { 7 | var img = $(this).children("img"); 8 | var div = $(this).children("div"); 9 | if (img.length) { 10 | var size = get_size(img); 11 | div.append(size.width + " × " + size.height); 12 | } 13 | }); 14 | $("body").show(); 15 | }); 16 | 17 | function get_size(img) { 18 | return { 19 | 'width': img.get(0).naturalWidth, 20 | 'height': img.get(0).naturalHeight 21 | }; 22 | } 23 | -------------------------------------------------------------------------------- /app/static/uploads/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willard-yuan/SoTu/3e39739764b498d057bab1e627da3a559e5faaab/app/static/uploads/.gitkeep -------------------------------------------------------------------------------- /app/templates/403.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block title %}SoTu - Forbidden{% endblock %} 4 | 5 | {% block content %} 6 |

Forbidden

7 | {% endblock %} 8 | -------------------------------------------------------------------------------- /app/templates/404.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block title %}SoTu - Page Not Found{% endblock %} 4 | 5 | {% block content %} 6 |

Page Not Found

7 | {% endblock %} 8 | -------------------------------------------------------------------------------- /app/templates/500.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block title %}SoTu - Internal Server Error{% endblock %} 4 | 5 | {% block content %} 6 |

Internal Server Error

7 | {% endblock %} 8 | -------------------------------------------------------------------------------- /app/templates/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | {% block head %} 6 | {% block title %}{% endblock %} 7 | {% block styles %} 8 | 9 | 10 | 11 | {% endblock %} 12 | {% endblock %} 13 | 14 | 15 | 16 | {% block body %} 17 |
18 |
19 | {% block header %} 20 | 21 | 22 | SOTU 23 | 24 | {% endblock %} 25 |
26 |
27 | {% block content %} 28 | {% endblock %} 29 |
30 | 34 |
35 | {% block scripts %} 36 | 37 | 38 | {% endblock %} 39 | {% with messages = get_flashed_messages() %} 40 | {% if messages %} 41 | {% for message in messages %} 42 | 43 | {% endfor %} 44 | {% endif %} 45 | {% endwith %} 46 | {% endblock %} 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /app/templates/csrf_error.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block title %}SoTu - CSRF ERROR{% endblock %} 4 | 5 | {% block content %} 6 |

{{ reason }}

7 | {% endblock %} 8 | -------------------------------------------------------------------------------- /app/templates/index.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block title %}SoTu{% endblock %} 4 | 5 | {% block styles %} 6 | {{ super() }} 7 | 8 | {% endblock %} 9 | 10 | {% block content %} 11 |
12 |
13 |
14 | 15 | 16 | 17 |
18 |
19 | 20 | 21 | 22 |
23 | 24 |
25 |
拖放图片到此处试试
26 | 33 | 37 |
38 | {% endblock %} 39 | 40 | {% block scripts %} 41 | {{ super() }} 42 | 43 | {% endblock %} 44 | -------------------------------------------------------------------------------- /app/templates/result.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block title %}SoTu{% endblock %} 4 | 5 | {% block styles %} 6 | {{ super() }} 7 | 8 | {% endblock %} 9 | 10 | {% block content %} 11 |
12 |
13 | 14 |
目标图片
15 |
16 | {% if images %} 17 |
18 | 结果图片 | 按相似度排序 19 |
20 |
21 | {% for image in images %} 22 |
23 | 24 |
25 |
26 | {% endfor %} 27 |
28 | {% else %} 29 |

没有类似图片

30 | {% endif %} 31 |
32 | {% endblock %} 33 | 34 | {% block scripts %} 35 | {{ super() }} 36 | 37 | {% endblock %} 38 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | 6 | class Config: 7 | DEBUG = False 8 | # 路径配置 9 | BASE_DIR = os.path.dirname(__file__) 10 | UPLOAD_DIR = os.path.join(BASE_DIR, 'app/static/uploads') 11 | 12 | @staticmethod 13 | def init_app(app): 14 | pass 15 | 16 | 17 | class DevelopmentConfig(Config): 18 | DEBUG = True 19 | WTF_CSRF_ENABLED = False 20 | 21 | 22 | config = { 23 | 'development': DevelopmentConfig, 24 | 'default': DevelopmentConfig 25 | } 26 | -------------------------------------------------------------------------------- /data/sift/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/willard-yuan/SoTu/3e39739764b498d057bab1e627da3a559e5faaab/data/sift/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flask == 0.12.2 2 | flask-wtf == 0.14.2 3 | jinja2 == 2.10 4 | matplotlib == 2.2.2 5 | numpy == 1.14.3 6 | opencv-contrib-python == 3.4.1.15 7 | scikit-learn == 0.19.1 8 | scipy == 1.1.0 9 | werkzeug == 0.14.1 10 | -------------------------------------------------------------------------------- /sotu.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | from app import create_app 6 | 7 | app = create_app(os.getenv('FLASK_CONFIG') or 'default') 8 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import posixpath 
5 | import tarfile 6 | from urllib.error import HTTPError, URLError 7 | from urllib.request import Request, urlretrieve 8 | 9 | 10 | def download(url, root, filename, untar=False): 11 | fpath = os.path.join(root, filename) 12 | if not os.path.exists(root): 13 | os.mkdir(root) 14 | if os.path.exists(fpath): 15 | print("Data already downloaded") 16 | else: 17 | print("Downloading %s to %s" % (url, fpath)) 18 | err_msg = "URL fetch failure on {}: {} -- {}" 19 | try: 20 | try: 21 | urlretrieve(url, fpath) 22 | except URLError as e: 23 | raise Exception(err_msg.format(url, e.errno, e.reason)) 24 | except HTTPError as e: 25 | raise Exception(err_msg.format(url, e.code, e.msg)) 26 | except (Exception, KeyboardInterrupt) as e: 27 | print(e) 28 | if os.path.exists(fpath): 29 | os.remove(fpath) 30 | if untar is True: 31 | with tarfile.open(fpath) as tar: 32 | tar.extractall(os.path.dirname(fpath)) 33 | 34 | 35 | def list_files(root, suffix): 36 | names = [] 37 | for name in os.listdir(root): 38 | fd = posixpath.join(root, name) 39 | if os.path.isfile(fd) and fd.endswith(suffix): 40 | names.append(fd) 41 | if os.path.isdir(fd): 42 | names.extend(list_files(fd, suffix)) 43 | return names 44 | -------------------------------------------------------------------------------- /vision/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .bof import BoF 4 | 5 | __all__ = ('BoF') 6 | -------------------------------------------------------------------------------- /vision/bof.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import pickle 5 | import time 6 | 7 | import click 8 | import cv2 9 | import numpy as np 10 | from sklearn.cluster import MiniBatchKMeans 11 | from werkzeug.utils import cached_property 12 | 13 | from .he import HE 14 | from .inv import InvFile 15 | from .sift import SIFT 16 | from .ukbench import 
UKBENCH
from .wgc import WGC


class BoF(object):
    """Bag-of-Features image retrieval over the UKBench dataset.

    Combines a k-means visual vocabulary, Hamming Embedding (HE) signatures
    stored in an inverted file, idf weighting and optional RANSAC-based
    geometric re-ranking of the best candidates.
    """

    def __init__(self):
        self.k = 5000  # number of visual words (k-means clusters)
        self.bof_path = 'data/bof.pkl'
        self.inv_path = 'data/inv.pkl'
        self.ukbench = UKBENCH('data')
        self.n = len(self.ukbench)
        self.sift = SIFT('data')
        self.inv = InvFile(self.k, self.n)

    def init_app(self, app):
        """Register the ``train``, ``extract`` and ``evaluate`` CLI commands on *app*."""

        @click.command('train')
        # TODO: training is not implemented yet (BoF.train only prints).
        def train():
            self.train()

        @click.command('extract')
        def extract():
            self.extract()

        @click.command('evaluate')
        def evaluate():
            # Query with the first image of every group of 4 and report mAP.
            queries = []
            for i in range(0, self.n, 4):
                start = time.time()
                matches = self.match(self.ukbench[i])
                ap = self.ukbench.evaluate(self.ukbench[i], matches)
                elapse = time.time() - start
                print("Query %s: ap = %4f, %4fs elapsed" %
                      (self.ukbench[i], ap, elapse))
                queries.append((ap, elapse))
            mAP, mT = np.mean(queries, axis=0)
            print("mAP of the %d images is %4f, %4fs per query" %
                  (len(queries), mAP, mT))

        app.cli.add_command(train)
        app.cli.add_command(extract)
        app.cli.add_command(evaluate)

    def train(self):
        """Placeholder for a separate training step (TODO: not implemented).

        Currently it only prints how many images would be processed; all of
        the indexing work is done by :meth:`extract`.
        """
        print("Get sift features of %d images" % self.n)

    def extract(self):
        """Build the whole index from the UKBench images and write it to disk.

        Steps: RootSIFT features per image (dumped through ``self.sift``),
        a k-means vocabulary, Hamming Embedding projection / thresholds /
        signatures, the inverted file (``self.inv_path``) and the per-image
        norms plus idf weights (``self.bof_path``).

        Images that cannot be read or yield no keypoints are removed from
        ``self.ukbench``. Image ids are positions among the kept images.
        """
        print("Get sift features of %d images" % self.n)
        # Keypoints and descriptors of every usable image
        keypoints = []
        descriptors = []
        bad_imgs = []
        for i in range(self.n):
            uri = self.ukbench[i]
            print("%d (%d), %s" % ((i + 1), self.n, uri))
            img = cv2.imread(uri, cv2.IMREAD_GRAYSCALE)
            if img is None:
                bad_imgs.append(uri)
                continue
            status, kpt, des = self.sift.extract(img, rootsift=True)
            if status == 0:
                bad_imgs.append(uri)
                continue
            keypoints.append(kpt)
            descriptors.append(des)
        # Drop unusable images so that image ids (indices into the remaining
        # list) line up with the dumped features and the inverted file.
        # NOTE(review): the dropped list is not persisted. A fresh process
        # re-lists every file in UKBENCH('data'), so ids returned by match()
        # drift if any image was dropped here.
        for bad_img in bad_imgs:
            self.ukbench.remove(bad_img)
        self.n -= len(bad_imgs)
        self.inv.n = self.n
        for i, (kp, des) in enumerate(zip(keypoints, descriptors)):
            self.sift.dump(kp, des, str(i))

        # Stack all descriptors vertically (one 128-d row each)
        des_all = np.vstack(descriptors)

        print("Start kmeans with %d clusters" % self.k)
        kmeans = MiniBatchKMeans(
            n_clusters=self.k, batch_size=1000,
            random_state=0, init_size=self.k * 3
        ).fit(des_all)
        # Assign every descriptor of every image to its nearest visual word
        labels = [kmeans.predict(des) for des in descriptors]

        print("Project %d descriptors from 128d to 64d" % len(des_all))
        he = HE(64, 128, self.k)
        projections = [he.project(des) for des in descriptors]
        prj_all = np.vstack(projections)
        label_all = np.hstack(labels)

        print("Calculate medians of %d visual words" % self.k)
        he.fit(prj_all, label_all)

        print("Calculate binary signatures of %d projections" % len(des_all))
        signatures = [
            [he.signature(p, l) for p, l in zip(prj, label)]
            for prj, label in zip(projections, labels)
        ]

        # Inverted file over the visual words
        self.inv.dump(keypoints, signatures, labels, self.inv_path)

        # Visual-word frequency vector of each image
        freqs = np.array([
            np.bincount(label, minlength=self.k) for label in labels
        ])
        # L2 norm of each image's frequency vector
        norms = np.linalg.norm(freqs, axis=1)
        # Smoothed idf of each visual word (sklearn TfidfTransformer formula)
        idf = np.log((self.n + 1) / (np.sum((freqs > 0), axis=0) + 1)) + 1

        with open(self.bof_path, 'wb') as bof_pkl:
            pickle.dump((kmeans, he, norms, idf), bof_pkl)

    def match(self, uri, top_k=20, ht=23, rerank=True):
        """Return the URIs of the indexed images most similar to image *uri*.

        Args:
            uri: path of the query image.
            top_k: maximum number of results.
            ht: Hamming threshold. A database feature in the same visual word
                votes only if its signature distance to the query is < ht.
            rerank: if True, re-order the candidates by the number of RANSAC
                homography inliers they share with the query.

        Returns:
            At most ``top_k`` image URIs, best first. The list is empty if the
            query cannot be read or yields no SIFT keypoints.
        """
        kmeans, he, norms, idf = self.bof
        img = cv2.imread(uri, cv2.IMREAD_GRAYSCALE)
        if img is None:
            return []
        # Keypoints and RootSIFT descriptors of the query
        status, kp, des = self.sift.extract(img, rootsift=True)
        if status == 0:
            return []
        # Orientation (radians) and log2 scale of each keypoint (for WGC)
        geo = [(np.radians(k.angle), np.log2(k.size)) for k in kp]
        # Nearest visual word of each descriptor
        label = kmeans.predict(des)
        # Project to the HE subspace and binarize into 64-bit signatures
        prj = he.project(des)
        signature = [he.signature(p, l) for p, l in zip(prj, label)]

        # wgc = WGC(self.n, 17, 7)
        scores = np.zeros(self.n)
        # Every database feature in the same visual word whose signature lies
        # within the Hamming threshold votes with that word's idf
        for (ang_q, sca_q), sig_q, lbl_q in zip(geo, signature, label):
            for img_id, ang_t, sca_t, sig_t in self.entries[lbl_q]:
                if he.distance(sig_q, sig_t) < ht:
                    scores[img_id] += idf[lbl_q]
                    # wgc.vote(img_id,
                    #          np.arctan2(np.sin(ang_t - ang_q),
                    #                     np.cos(ang_t - ang_q)),
                    #          sca_t - sca_q)
        # scores *= wgc.filter()
        scores = scores / norms
        rank = np.argsort(-scores)[:top_k]

        # The index may hold fewer than top_k images, so size by len(rank)
        if rerank and len(rank) > 0:
            keypoints, descriptors = zip(
                *[self.sift.load(str(r)) for r in rank]
            )
            inliers = np.zeros(len(rank))
            for i, (kp_t, des_t) in enumerate(zip(keypoints, descriptors)):
                # Point coordinates of the ratio-test matches
                pairs = [(kp[q].pt, kp_t[t].pt)
                         for q, t in self.sift.match(des, des_t)]
                inliers[i] = np.sum(self.sift.filter(pairs))
            # Stable sort: candidates with equal inlier counts keep their
            # BoF order instead of being re-ordered by image id
            order = sorted(range(len(rank)), key=lambda i: -inliers[i])
            rank = [rank[i] for i in order]
        return [self.ukbench[r] for r in rank]

    @cached_property
    def bof(self):
        """``(kmeans, he, norms, idf)`` loaded once from ``self.bof_path``.

        Loading also resets ``self.n`` to the number of indexed images.
        """
        with open(self.bof_path, 'rb') as bof_pkl:
            kmeans, he, norms, idf = pickle.load(bof_pkl)
        self.n = norms.shape[0]
        return kmeans, he, norms, idf

    @cached_property
    def entries(self):
        """Inverted file loaded once from ``self.inv_path`` (one list per visual word)."""
        return self.inv.load(self.inv_path)

# --------------------------------------------------------------------------------
# /vision/he.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import numpy as np


class HE(object):
    """Hamming Embedding: binary signatures that refine a visual-word match.

    Descriptors are projected with rows of a random orthogonal matrix. Each
    projected component is then compared with a per-visual-word threshold
    to give one bit.
    """

    def __init__(self, db, d, k):
        """
        Args:
            db: number of signature bits (projected dimensions).
            d: descriptor dimensionality.
            k: number of visual words.
        """
        self.k = k
        # (d, d) random matrix with standard-normal entries
        self.M = np.random.randn(d, d)
        # Its QR decomposition yields an orthogonal matrix Q
        self.Q, _ = np.linalg.qr(self.M)
        # First db rows of Q form the (db, d) projection matrix
        self.P = self.Q[:db, :]
        # (k, db) per-visual-word binarization thresholds
        self.medians = np.zeros([self.k, db])

    def project(self, des):
        """Project descriptor rows of *des* onto the db-dimensional subspace."""
        return np.dot(des, self.P.T)

    def fit(self, prj_all, label_all, eps=1e-7):
        """Set each visual word's thresholds to the mean of its projections.

        NOTE: despite the ``medians`` attribute name, this computes per-word
        means. ``eps`` keeps the division finite for words with no data.

        Args:
            prj_all: (N, db) projected descriptors.
            label_all: (N,) visual-word index of each row of ``prj_all``.
        """
        prj_all = np.asarray(prj_all, dtype=float)
        label_all = np.asarray(label_all, dtype=np.intp)
        sums = np.zeros((self.k, prj_all.shape[1]))
        # Unbuffered scatter-add: repeated labels accumulate correctly
        np.add.at(sums, label_all, prj_all)
        counts = np.bincount(label_all, minlength=self.k) + eps
        self.medians = sums / counts[:, None]

    def signature(self, prj, label):
        """Binarize *prj* against the thresholds of visual word *label*.

        Component j becomes bit ``len(prj) - 1 - j``, so the first component
        is the most significant bit. The bits are packed into a np.uint64.
        Plain int arithmetic avoids NumPy-2 int64 overflow at bit 63.
        """
        value = 0
        for bit in prj > self.medians[label]:
            value = (value << 1) | int(bit)
        return np.uint64(value)

    def distance(self, sig_q, sig_t):
        """Hamming distance (number of differing bits) between two signatures."""
        return bin(int(sig_q) ^ int(sig_t)).count("1")

# --------------------------------------------------------------------------------
# /vision/inv.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import pickle

import numpy as np


class InvFile(object):
    """Inverted file mapping each visual word to the features assigned to it."""

    def __init__(self, k, n):
        self.k = k  # number of visual words
        self.n = n  # number of indexed images

    def dump(self, keypoints, signatures, labels, path):
        """Build the inverted file and pickle it to *path*.

        ``entries[w]`` lists ``(image_id, angle_rad, log2_size, signature)``
        for every feature of the first ``self.n`` images assigned to word w.
        """
        entries = [[] for _ in range(self.k)]
        for img_id in range(self.n):
            for kp, sig, word in zip(keypoints[img_id], signatures[img_id],
                                     labels[img_id]):
                entries[word].append(
                    (img_id, np.radians(kp.angle), np.log2(kp.size), sig)
                )
        with open(path, 'wb') as inv_pkl:
            pickle.dump(entries, inv_pkl)

    def load(self, path):
        """Load and return the inverted file pickled at *path* by :meth:`dump`."""
        with open(path, 'rb') as inv_pkl:
            return pickle.load(inv_pkl)

# --------------------------------------------------------------------------------
# /vision/sift.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import os
import pickle

import numpy as np

import cv2


class SIFT(object):
    """SIFT / RootSIFT extraction, matching, geometric verification and I/O."""

    def __init__(self, root):
        # Directory holding the per-image feature dumps
        self.path = os.path.join(root, 'sift')
        # SIFT is in the main module from OpenCV 4.4 on; older builds only
        # expose it through the contrib xfeatures2d module.
        if hasattr(cv2, 'SIFT_create'):
            self.extractor = cv2.SIFT_create()
        else:
            self.extractor = cv2.xfeatures2d.SIFT_create()

    def extract(self, gray, rootsift=True):
        """Detect keypoints and compute descriptors of a grayscale image.

        Returns:
            ``(status, keypoints, descriptors)``. status is 0 when no keypoint
            was found (descriptors are returned as OpenCV produced them) and
            1 otherwise. With ``rootsift`` the descriptors are RootSIFT.
        """
        kp, des = self.extractor.detectAndCompute(gray, None)
        if len(kp) <= 0:
            return 0, kp, des
        if rootsift:
            des = self.rootsift(des)
        return 1, kp, des

    def match(self, des_q, des_t):
        """Return ``(query_idx, train_idx)`` pairs passing Lowe's ratio test."""
        ratio = 0.7  # Lowe's ratio-test threshold
        flann = cv2.FlannBasedMatcher()
        # Two nearest neighbours in des_t of every descriptor in des_q
        two_nn = flann.knnMatch(des_q, des_t, k=2)
        # Keep the best match only when clearly better than the second best.
        # Entries with fewer than two neighbours cannot be ratio-tested.
        return [(nn[0].queryIdx, nn[0].trainIdx) for nn in two_nn
                if len(nn) == 2 and nn[0].distance < ratio * nn[1].distance]

    def filter(self, pt_qt):
        """RANSAC inlier mask of a homography fitted to point pairs *pt_qt*.

        Returns:
            A list of 0/1 flags, one per pair. The list is empty when there
            are fewer than 4 pairs (the minimum for a homography) or when no
            homography could be estimated.
        """
        if len(pt_qt) < 4:
            return []
        pt_q, pt_t = zip(*pt_qt)
        _, mask = cv2.findHomography(np.float32(pt_q).reshape(-1, 1, 2),
                                     np.float32(pt_t).reshape(-1, 1, 2),
                                     cv2.RANSAC, 3)
        if mask is None:
            return []
        return mask.ravel().tolist()

    def draw(self, img_q, img_t, pt_qt):
        """Plot the query/target images side by side with lines between matches.

        NOTE(review): with the 'Agg' backend selected here, plt.show()
        displays nothing; presumably a savefig was intended. Confirm.
        """
        import matplotlib
        matplotlib.use('Agg')
        from matplotlib import pyplot as plt
        from matplotlib.patches import ConnectionPatch

        fig, (ax_q, ax_t) = plt.subplots(1, 2, figsize=(8, 4))
        for pt_q, pt_t in pt_qt:
            con = ConnectionPatch(pt_t, pt_q,
                                  coordsA='data', coordsB='data',
                                  axesA=ax_t, axesB=ax_q,
                                  color='g', linewidth=0.5)
            ax_t.add_artist(con)
            ax_q.plot(pt_q[0], pt_q[1], 'rx')
            ax_t.plot(pt_t[0], pt_t[1], 'rx')
        ax_q.imshow(img_q)
        ax_t.imshow(img_t)
        ax_q.axis('off')
        ax_t.axis('off')
        plt.subplots_adjust(wspace=0, hspace=0)
        plt.show()

    @classmethod
    def rootsift(cls, des, eps=1e-7):
        """Convert SIFT descriptors to RootSIFT.

        Each row is L1-normalized (in place; eps avoids division by zero),
        then square-rooted element-wise. None is returned unchanged.
        """
        if des is not None:
            des /= (des.sum(axis=1, keepdims=True) + eps)
            des = np.sqrt(des)
        return des

    def dump(self, kp, des, filename):
        """Pickle keypoints (as plain tuples) and descriptors to self.path/filename."""
        tmp = [
            (k.pt, k.size, k.angle, k.response, k.octave, k.class_id)
            for k in kp
        ]
        with open(os.path.join(self.path, filename), 'wb') as sift_pkl:
            pickle.dump((tmp, des), sift_pkl)

    def load(self, filename):
        """Load ``(keypoints, descriptors)`` written by :meth:`dump`."""
        with open(os.path.join(self.path, filename), 'rb') as sift_pkl:
            tmp, des = pickle.load(sift_pkl)
        # Positional arguments: the constructor's keyword names differ
        # between OpenCV versions (e.g. `_size` vs `size`).
        kp = [
            cv2.KeyPoint(t[0][0], t[0][1], t[1], t[2], t[3], t[4], t[5])
            for t in tmp
        ]
        return kp, des

# --------------------------------------------------------------------------------
# /vision/ukbench.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import os
import posixpath
import re

from utils import download, list_files


class UKBENCH(object):
    """UKBench dataset: sorted image paths plus an average-precision helper.

    Relevance is defined by consecutive groups of 4 image ids.
    """
    url = 'https://archive.org/download/ukbench/ukbench.zip'
    filename = 'ukbench.zip'
    ukbench_dir = 'ukbench'

    def __init__(self, root):
        self.root = root
        # Download and unpack the archive on first use
        if not posixpath.exists(posixpath.join(self.root, self.ukbench_dir)):
            download(self.url, self.root, self.filename, untar=True)
        self.uris = list_files(posixpath.join(self.root,
                                              self.ukbench_dir,
                                              'full'),
                               ('png', 'jpg', 'jpeg', 'gif'))
        self.uris.sort()

    def __getitem__(self, index):
        return self.uris[index]

    def __len__(self):
        return len(self.uris)

    def remove(self, uri):
        """Drop *uri* from the dataset (raises ValueError if it is absent)."""
        self.uris.remove(uri)

    def evaluate(self, img_q, images):
        """Average precision of the ranked result list *images* for query *img_q*.

        Relevant results share the query's group of 4 ids. AP is the sum of
        the precision at each relevant hit divided by 4. An empty result list
        scores 0.0.
        """
        img_id = self.id_of(img_q)
        min_id = img_id - img_id % 4
        max_id = min_id + 4
        hits = 0
        pr_sum = 0
        for position, img in enumerate(images, 1):
            if min_id <= self.id_of(img) < max_id:
                hits += 1
                pr_sum += hits / position
        return pr_sum / 4

    def id_of(self, uri):
        """Numeric image id: the first run of digits in the file name."""
        return int(re.split(r'(\d+)', os.path.basename(uri))[1])

# --------------------------------------------------------------------------------
# /vision/wgc.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import numpy as np


class WGC(object):
    """Weak Geometric Consistency voting.

    Keeps per-image histograms of the angle and log-scale differences between
    matched features.
    """

    def __init__(self, n, angle_bins, scale_bins):
        self.n = n
        self.angle_bins, self.scale_bins = angle_bins, scale_bins
        self.angle_hists = np.zeros((self.n, self.angle_bins))
        self.scale_hists = np.zeros((self.n, self.scale_bins))

    def vote(self, i, diff, scale_diff):
        """Add a match to image *i*'s histograms.

        *diff* is the angle difference in radians and *scale_diff* the
        log-scale difference. Values outside the quantization range are
        ignored.
        """
        angle_index = self.quantize_angle(diff)
        scale_index = self.quantize_scale(scale_diff)
        if 0 <= angle_index < self.angle_bins:
            self.angle_hists[i][angle_index] += 1
        if 0 <= scale_index < self.scale_bins:
            self.scale_hists[i][scale_index] += 1

    def filter(self):
        """Per-image score: the min of the peak smoothed angle and scale votes."""
        am = np.max([self.movmean(h, 3) for h in self.angle_hists], axis=1)
        sm = np.max([self.movmean(h, 3) for h in self.scale_hists], axis=1)
        return np.min(np.vstack((am, sm)), axis=0)

    def quantize_angle(self, diff):
        """Bin index of an angle difference; [-pi, pi) maps to [0, angle_bins)."""
        # floor (not int truncation) so values just below -pi fall outside
        return int(np.floor((diff + np.pi) * self.angle_bins / (2 * np.pi)))

    def quantize_scale(self, diff):
        """Bin index of a log-scale difference; [-3, 3) maps to [0, scale_bins)."""
        return int(np.floor((diff + 3) * self.scale_bins / 6))

    def movmean(self, hist, window):
        """Moving average of *hist* over *window* bins ('valid' positions only)."""
        cumsum = np.cumsum(np.insert(hist, 0, 0))
        return (cumsum[window:] - cumsum[: -window]) / float(window)
# --------------------------------------------------------------------------------