├── .env.local ├── .gitignore ├── .vscode └── settings.json ├── README.md ├── apps ├── app.py ├── auth │ ├── forms.py │ ├── templates │ │ └── auth │ │ │ ├── base.html │ │ │ ├── index.html │ │ │ ├── login.html │ │ │ └── signup.html │ └── views.py ├── config.py ├── crud │ ├── __init__.py │ ├── forms.py │ ├── models.py │ ├── static │ │ └── style.css │ ├── templates │ │ └── crud │ │ │ ├── base.html │ │ │ ├── create.html │ │ │ ├── edit.html │ │ │ └── index.html │ └── views.py ├── detector │ ├── __init__.py │ ├── forms.py │ ├── models.py │ ├── templates │ │ └── detector │ │ │ ├── 404.html │ │ │ ├── 500.html │ │ │ ├── base.html │ │ │ ├── index.html │ │ │ └── upload.html │ └── views.py ├── images │ └── .ignore ├── minimalapp │ ├── app.py │ ├── static │ │ └── style.css │ └── templates │ │ ├── contact.html │ │ ├── contact_complete.html │ │ ├── contact_mail.html │ │ ├── contact_mail.txt │ │ └── index.html ├── static │ └── css │ │ ├── bootstrap.min.css │ │ └── style.css └── templates │ ├── 404.html │ └── 500.html ├── flaskbook_api ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── api │ ├── __init__.py │ ├── calculation.py │ ├── config │ │ ├── __init__.py │ │ ├── base.py │ │ └── local.py │ ├── postprocess.py │ ├── preparation.py │ └── preprocess.py ├── data │ ├── original │ │ └── test.jpg │ └── output │ │ └── test.jpg ├── requirements.txt └── run.py ├── ml_api ├── .gitignore ├── LICENSE ├── README.md ├── analysis.ipynb ├── api │ ├── __init__.py │ ├── calculation.py │ ├── config │ │ ├── __init__.py │ │ ├── config.py │ │ └── json-schemas │ │ │ ├── __init__.py │ │ │ ├── check_dir_name.json │ │ │ ├── check_file_id.json │ │ │ └── check_file_schema.json │ ├── images.db │ ├── json_validate.py │ ├── models.py │ ├── preparation.py │ ├── preprocess.py │ └── run.py ├── handwriting_pics │ ├── 0.jpg │ ├── 1.jpg │ ├── 2.jpg │ ├── 3.jpg │ ├── 4.jpg │ ├── 5.jpg │ ├── 6.jpg │ ├── 7.jpg │ ├── 8.jpg │ └── 9.jpg ├── requirements-test.txt ├── requirements.txt ├── setup.cfg ├── setup.py 
└── test │ ├── __init__.py │ ├── conftest.py │ ├── test_calculation.py │ ├── test_preparation.py │ └── test_preprocess.py ├── requirements.txt └── tests ├── __init__.py ├── conftest.py ├── detector ├── __init__.py ├── test_views.py └── testdata │ ├── test_invalid_file.txt │ └── test_valid_image.jpg └── test_sample.py /.env.local: -------------------------------------------------------------------------------- 1 | # FLASK_APP=apps.minimalapp.app.py 2 | FLASK_APP=apps.app:create_app('local') 3 | FLASK_ENV=development 4 | 5 | # flask-mailコンフィグ設定 6 | MAIL_SERVER=smtp.gmail.com 7 | MAIL_PORT=587 8 | MAIL_USE_TLS=True 9 | MAIL_USERNAME=[Gmailのメールアドレス] 10 | MAIL_PASSWORD=[2段階認証後に発行したGmailのパスワード] 11 | MAIL_DEFAULT_SENDER=Flaskbook -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,flask,vscode 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,flask,vscode 4 | 5 | ### Flask ### 6 | instance/* 7 | !instance/.gitignore 8 | .webassets-cache 9 | .env 10 | 11 | ### Flask.Python Stack ### 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | cover/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ 145 | 146 | # Cython debug symbols 147 | cython_debug/ 148 | 149 | ### Python ### 150 | # Byte-compiled / optimized / DLL files 151 | 152 | # C extensions 153 | 154 | # Distribution / packaging 155 | 156 | # PyInstaller 157 | # Usually these files are written by a python script from a template 158 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 159 | 160 | # Installer logs 161 | 162 | # Unit test / coverage reports 163 | 164 | # Translations 165 | 166 | # Django stuff: 167 | 168 | # Flask stuff: 169 | 170 | # Scrapy stuff: 171 | 172 | # Sphinx documentation 173 | 174 | # PyBuilder 175 | 176 | # Jupyter Notebook 177 | 178 | # IPython 179 | 180 | # pyenv 181 | # For a library or package, you might want to ignore these files since the code is 182 | # intended to run in multiple environments; otherwise, check them in: 183 | # .python-version 184 | 185 | # pipenv 186 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 187 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 188 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 189 | # install all needed dependencies. 190 | 191 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 192 | 193 | # Celery stuff 194 | 195 | # SageMath parsed files 196 | 197 | # Environments 198 | 199 | # Spyder project settings 200 | 201 | # Rope project settings 202 | 203 | # mkdocs documentation 204 | 205 | # mypy 206 | 207 | # Pyre type checker 208 | 209 | # pytype static type analyzer 210 | 211 | # Cython debug symbols 212 | 213 | #!! ERROR: vscode is undefined. Use list command to see defined gitignore types !!# 214 | 215 | # End of https://www.toptal.com/developers/gitignore/api/python,flask,vscode 216 | 217 | migrations/ 218 | *.sqlite 219 | apps/detector/model.pt 220 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.linting.flake8Enabled": true, 3 | "python.formatting.provider": "black", 4 | "editor.formatOnSave": true, 5 | "editor.codeActionsOnSave": { 6 | "source.organizeImports": true 7 | }, 8 | "python.linting.mypyEnabled": true, 9 | "python.linting.flake8Args": [ 10 | "--max-line-length=88", 11 | ], 12 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Python Flask による Web アプリ開発入門 2 | ## Git Clone 3 | 4 | ``` 5 | $ git clone https://github.com/ml-flaskbook/flaskbook.git 6 | ``` 7 | 8 | ## 仮想環境を作成する 9 | 10 | ### Mac/Linux 11 | 12 | ``` 13 | $ python3 -m venv venv 14 | $ source venv/bin/activate 15 | ``` 16 | 17 | ### Windows(PowerShell) 18 | 19 | スクリプトを実行するために、Windows PowerShellで次のコマンドを実行し、実行ポリシーを変更する。 20 | 21 | ``` 22 | > PowerShell Set-ExecutionPolicy RemoteSigned CurrentUser 23 | ``` 24 | 25 | ポリシーを変更したら、次のコマンドを実行する 26 | 27 | ``` 28 | > py -m venv venv 29 | > venv\Scripts\Activate.ps1 30 | ``` 31 | 32 | ## 環境変数ファイル設置 33 | 34 | ``` 35 | $ cp -p .env.local .env 36 | ``` 37 | 38 | ## パッケージインストール 39 | 40 | ``` 41 | (venv) $ pip install -r
def create_app(config_key):
    """Build and configure the Flask application (application factory).

    Args:
        config_key: Key into the ``config`` mapping imported from
            ``apps.config``; selects the configuration class to load.

    Returns:
        The configured ``Flask`` instance with the extensions bound, the
        crud/auth/detector blueprints registered and the custom 404/500
        error handlers installed.

    Raises:
        KeyError: If ``config_key`` is not a key of ``config``.
    """
    flask_app = Flask(__name__)
    settings = config[config_key]
    flask_app.config.from_object(settings)

    # Bind the module-level extension objects to this application instance.
    db.init_app(flask_app)
    Migrate(flask_app, db)
    csrf.init_app(flask_app)
    login_manager.init_app(flask_app)

    # The view modules import ``db`` from this module, so they are imported
    # here inside the factory (after ``db`` exists) instead of at module top.
    from apps.crud import views as crud_views

    flask_app.register_blueprint(crud_views.crud, url_prefix="/crud")

    from apps.auth import views as auth_views

    flask_app.register_blueprint(auth_views.auth, url_prefix="/auth")

    from apps.detector import views as dt_views

    # The detector blueprint is mounted without a URL prefix.
    flask_app.register_blueprint(dt_views.dt)

    # Install the custom error pages.
    error_handlers = ((404, page_not_found), (500, internal_server_error))
    for status_code, handler in error_handlers:
        flask_app.register_error_handler(status_code, handler)

    return flask_app
SubmitField("ログイン") 35 | -------------------------------------------------------------------------------- /apps/auth/templates/auth/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | {% block title %}{% endblock %} 7 | 8 | 9 | 10 | {% block content %}{% endblock %} 11 | 12 | 13 | -------------------------------------------------------------------------------- /apps/auth/templates/auth/index.html: -------------------------------------------------------------------------------- 1 | {% extends "auth/base.html" %} 2 | {% block title %}認証ページ{% endblock %} 3 | {%block content %} 認証ページ表示確認 {% endblock %} -------------------------------------------------------------------------------- /apps/auth/templates/auth/login.html: -------------------------------------------------------------------------------- 1 | {% extends "detector/base.html" %} 2 | 3 | {% block title %}ログイン{% endblock %} 4 | 5 | {% block content %} 6 |
7 | 21 |
22 | {% endblock %} -------------------------------------------------------------------------------- /apps/auth/templates/auth/signup.html: -------------------------------------------------------------------------------- 1 | {% extends "detector/base.html" %} 2 | 3 | {% block title %}ユーザー新規登録{% endblock %} 4 | 5 | {% block content %} 6 |
7 | 24 |
def _is_local_path(target):
    """Return True if *target* is a same-site path that is safe to redirect to."""
    if not target or not target.startswith("/"):
        return False
    # "//host" and "/\host" start with "/" but browsers resolve them as
    # scheme-relative URLs pointing at another host (open redirect).
    return not target.startswith(("//", "/\\"))


@auth.route("/signup", methods=["GET", "POST"])
def signup():
    """Show the sign-up form and register a new user on a valid submit.

    On success the new user is logged in and redirected to the local path
    given by the ``next`` query parameter, or to ``detector.index`` when
    ``next`` is missing or is not a safe local path. If the e-mail address
    is already registered, a message is flashed and the client is sent back
    to the sign-up page.

    Returns:
        A redirect response after a handled submit, otherwise the rendered
        ``auth/signup.html`` template.
    """
    form = SignUpForm()

    if form.validate_on_submit():
        user = User(
            username=form.username.data,
            email=form.email.data,
            password=form.password.data,
        )

        # Reject e-mail addresses that are already registered.
        if user.is_duplicate_email():
            flash("指定のメールアドレスは登録済みです")
            return redirect(url_for("auth.signup"))

        # Persist the new user.
        db.session.add(user)
        db.session.commit()

        # Store the user in the session (log in).
        login_user(user)

        # Only follow ``next`` when it is a path on this site; otherwise
        # fall back to the detector top page.
        next_ = request.args.get("next")
        if not _is_local_path(next_):
            next_ = url_for("detector.index")
        return redirect(next_)

    return render_template("auth/signup.html", form=form)
| flash("メールアドレスかパスワードか不正です") 65 | return render_template("auth/login.html", form=form) 66 | 67 | 68 | @auth.route("/logout") 69 | def logout(): 70 | logout_user() 71 | return redirect(url_for("auth.login")) 72 | -------------------------------------------------------------------------------- /apps/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | basedir = Path(__file__).parent.parent 4 | 5 | 6 | # BaseConfigクラスを作成する 7 | class BaseConfig: 8 | SECRET_KEY = "2AZSMss3p5QPbcY2hBsJ" 9 | WTF_CSRF_SECRET_KEY = "AuwzyszU5sugKN7KZs6f" 10 | # 画像アップロード先にapps/imagesを指定する 11 | UPLOAD_FOLDER = str(Path(basedir, "apps", "images")) 12 | # 物体検知に利用するラベル 13 | LABELS = [ 14 | "unlabeled", 15 | "person", 16 | "bicycle", 17 | "car", 18 | "motorcycle", 19 | "airplane", 20 | "bus", 21 | "train", 22 | "truck", 23 | "boat", 24 | "traffic light", 25 | "fire hydrant", 26 | "street sign", 27 | "stop sign", 28 | "parking meter", 29 | "bench", 30 | "bird", 31 | "cat", 32 | "dog", 33 | "horse", 34 | "sheep", 35 | "cow", 36 | "elephant", 37 | "bear", 38 | "zebra", 39 | "giraffe", 40 | "hat", 41 | "backpack", 42 | "umbrella", 43 | "shoe", 44 | "eye glasses", 45 | "handbag", 46 | "tie", 47 | "suitcase", 48 | "frisbee", 49 | "skis", 50 | "snowboard", 51 | "sports ball", 52 | "kite", 53 | "baseball bat", 54 | "baseball glove", 55 | "skateboard", 56 | "surfboard", 57 | "tennis racket", 58 | "bottle", 59 | "plate", 60 | "wine glass", 61 | "cup", 62 | "fork", 63 | "knife", 64 | "spoon", 65 | "bowl", 66 | "banana", 67 | "apple", 68 | "sandwich", 69 | "orange", 70 | "broccoli", 71 | "carrot", 72 | "hot dog", 73 | "pizza", 74 | "donut", 75 | "cake", 76 | "chair", 77 | "couch", 78 | "potted plant", 79 | "bed", 80 | "mirror", 81 | "dining table", 82 | "window", 83 | "desk", 84 | "toilet", 85 | "door", 86 | "tv", 87 | "laptop", 88 | "mouse", 89 | "remote", 90 | "keyboard", 91 | "cell phone", 92 | "microwave", 93 | "oven", 94 | 
"toaster", 95 | "sink", 96 | "refrigerator", 97 | "blender", 98 | "book", 99 | "clock", 100 | "vase", 101 | "scissors", 102 | "teddy bear", 103 | "hair drier", 104 | "toothbrush", 105 | ] 106 | 107 | 108 | # BaseConfigクラスを継承してLocalConfigクラスを作成する 109 | class LocalConfig(BaseConfig): 110 | SQLALCHEMY_DATABASE_URI = f"sqlite:///{basedir / 'local.sqlite'}" 111 | SQLALCHEMY_TRACK_MODIFICATIONS = False 112 | SQLALCHEMY_ECHO = True 113 | 114 | 115 | # BaseConfigクラスを継承してTestingConfigクラスを作成する 116 | class TestingConfig(BaseConfig): 117 | SQLALCHEMY_DATABASE_URI = f"sqlite:///{basedir / 'testing.sqlite'}" 118 | SQLALCHEMY_TRACK_MODIFICATIONS = False 119 | WTF_CSRF_ENABLED = False 120 | # 画像アップロード先にtests/detector/imagesを指定する 121 | UPLOAD_FOLDER = str(Path(basedir, "tests", "detector", "images")) 122 | 123 | 124 | # config辞書にマッピングする 125 | config = { 126 | "testing": TestingConfig, 127 | "local": LocalConfig, 128 | } 129 | -------------------------------------------------------------------------------- /apps/crud/__init__.py: -------------------------------------------------------------------------------- 1 | import apps.crud.models 2 | -------------------------------------------------------------------------------- /apps/crud/forms.py: -------------------------------------------------------------------------------- 1 | from flask_wtf import FlaskForm 2 | from wtforms import PasswordField, StringField, SubmitField 3 | from wtforms.validators import DataRequired, Email, length 4 | 5 | 6 | # ユーザー新規作成とユーザー編集フォームクラス 7 | class UserForm(FlaskForm): 8 | # ユーザーフォームのusername属性のラベルとバリデータを設定する 9 | username = StringField( 10 | "ユーザー名", 11 | validators=[ 12 | DataRequired(message="ユーザー名は必須です。"), 13 | length(max=30, message="30文字以内で入力してください。"), 14 | ], 15 | ) 16 | 17 | # ユーザーフォームemail属性のラベルとバリデータを設定する 18 | email = StringField( 19 | "メールアドレス", 20 | validators=[ 21 | DataRequired(message="メールアドレスは必須です。"), 22 | Email(message="メールアドレスの形式で入力してください。"), 23 | ], 24 | ) 25 | 26 | # 
class User(db.Model, UserMixin):
    """Application user stored in the ``users`` table.

    The plain-text password is never stored: assigning to ``password``
    writes a hash to ``password_hash``, and reading ``password`` raises.
    """

    # Table name
    __tablename__ = "users"
    # Column definitions
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String, index=True)
    email = db.Column(db.String, unique=True, index=True)
    password_hash = db.Column(db.String)
    # datetime.now is passed uncalled so it is evaluated when a row is
    # written; the resulting timestamps are naive (no timezone attached).
    created_at = db.Column(db.DateTime, default=datetime.now)
    updated_at = db.Column(db.DateTime, default=datetime.now, onupdate=datetime.now)

    # Relationship to UserImage; backref exposes the owner as ``user`` on UserImage
    user_images = db.relationship("UserImage", backref="user")

    # Write-only password property: reading it always raises
    @property
    def password(self):
        """Refuse read access to the password.

        Raises:
            AttributeError: Always.
        """
        raise AttributeError("読み取り不可")

    # Setter that stores the hashed password instead of the plain text
    @password.setter
    def password(self, password):
        """Hash *password* and store the result in ``password_hash``."""
        self.password_hash = generate_password_hash(password)

    # Password check
    def verify_password(self, password):
        """Return whether *password* matches the stored ``password_hash``."""
        return check_password_hash(self.password_hash, password)

    # Duplicate e-mail check
    def is_duplicate_email(self):
        """Return True if a row in ``users`` already has this user's e-mail."""
        return User.query.filter_by(email=self.email).first() is not None
-------------------------------------------------------------------------------- 1 | table { 2 | /* 隣接するボーダーを重ねて1本にして表示する */ 3 | border-collapse: collapse; 4 | } 5 | table, 6 | th, 7 | td { 8 | /* 1pxの実線にして色を調節する */ 9 | border: 1px solid #c0c0c0; 10 | } -------------------------------------------------------------------------------- /apps/crud/templates/crud/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | {% block title %}{% endblock %} 8 | 9 | 10 | 11 | 12 |
13 | {% if current_user.is_authenticated %} 14 |

15 | {{ current_user.username }} - 16 | ログアウト 17 | 18 |

19 | {% endif %} 20 |
21 | 22 | {% block content %}{% endblock %} 23 | 24 | 25 | -------------------------------------------------------------------------------- /apps/crud/templates/crud/create.html: -------------------------------------------------------------------------------- 1 | {% extends "crud/base.html" %} 2 | {% block title %}ユーザー新規作成{% endblock %} 3 | {% block content %} 4 |

ユーザー新規作成

5 |
6 | {{ form.csrf_token }} 7 |

8 | {{ form.username.label }} {{ form.username(placeholder="ユーザー名") }} 9 |

10 | {% for error in form.username.errors %} 11 | {{ error }} 12 | {% endfor %} 13 |

14 | {{ form.email.label }} {{ form.email(placeholder="メールアドレス") }} 15 |

16 | {% for error in form.email.errors %} 17 | {{ error }} 18 | {% endfor %} 19 |

20 | {{ form.password.label }} {{ form.password(placeholder="パスワード") }} 21 |

22 | {% for error in form.password.errors %} 23 | {{ error }} 24 | {% endfor %} 25 |

{{ form.submit() }}

26 |
27 | {% endblock %} -------------------------------------------------------------------------------- /apps/crud/templates/crud/edit.html: -------------------------------------------------------------------------------- 1 | {% extends 'crud/base.html' %} 2 | {% block title %}ユーザー編集{% endblock %} 3 | {% block content %} 4 |

ユーザー編集

5 |
6 | {{ form.csrf_token }} 7 |

8 | {{ form.username.label }} {{ form.username(placeholder="ユーザー名", 9 | value=user.username) }} 10 |

11 | {% for error in form.username.errors %} 12 | {{ error }} 13 | {% endfor %} 14 |

15 | {{ form.email.label }} {{ form.email(placeholder="メールアドレス", 16 | value=user.email) }} 17 |

18 | {% for error in form.email.errors %} 19 | {{ error }} 20 | {% endfor %} 21 |

22 | {{ form.password.label }} {{ form.password(placeholder="パスワード") }} 23 |

24 | {% for error in form.password.errors %} 25 | {{ error }} 26 | {% endfor %} 27 |

28 |
29 | 30 |
31 | {{ form.csrf_token }} 32 | 33 |
34 | {% endblock %} -------------------------------------------------------------------------------- /apps/crud/templates/crud/index.html: -------------------------------------------------------------------------------- 1 | {% extends "crud/base.html" %} 2 | {% block title %}ユーザーの一覧{% endblock %} 3 | {% block content %} 4 |

ユーザーの一覧

5 | ユーザー新規作成 6 | 7 | 8 | 9 | 10 | 11 | 12 | {% for user in users %} 13 | 14 | 17 | 18 | 19 | 20 | {% endfor %} 21 |
ユーザーIDユーザー名メールアドレス
15 | {{ user.id }} 16 | {{ user.username }}{{ user.email }}
# Accept both GET (show the form) and POST (apply the update).
@crud.route("/users/<user_id>", methods=["GET", "POST"])
@login_required
def edit_user(user_id):
    """Show the edit form for a user and apply the update on a valid submit.

    Args:
        user_id: Primary key of the user to edit, taken from the URL.

    Returns:
        A redirect to the user list after a successful update, otherwise
        the rendered ``crud/edit.html`` template.

    Raises:
        werkzeug.exceptions.NotFound: If no user with ``user_id`` exists.
    """
    form = UserForm()

    # Respond with 404 for an unknown id instead of failing later on
    # attribute access against None.
    user = User.query.filter_by(id=user_id).first_or_404()

    # On a valid submit, update the user and go back to the user list.
    if form.validate_on_submit():
        user.username = form.username.data
        user.email = form.email.data
        user.password = form.password.data
        db.session.add(user)
        db.session.commit()
        return redirect(url_for("crud.users"))

    # GET (or invalid submit): render the form.
    return render_template("crud/edit.html", user=user, form=form)


@crud.route("/users/<user_id>/delete", methods=["POST"])
@login_required
def delete_user(user_id):
    """Delete a user and redirect to the user list.

    Args:
        user_id: Primary key of the user to delete, taken from the URL.

    Returns:
        A redirect to the user list.

    Raises:
        werkzeug.exceptions.NotFound: If no user with ``user_id`` exists.
    """
    # Respond with 404 for an unknown id rather than passing None to delete().
    user = User.query.filter_by(id=user_id).first_or_404()
    db.session.delete(user)
    db.session.commit()
    return redirect(url_for("crud.users"))
class UserImageTag(db.Model):
    """Tag attached to an uploaded image, stored in ``user_image_tags``."""

    # Table name
    __tablename__ = "user_image_tags"
    id = db.Column(db.Integer, primary_key=True)
    # Foreign key to user_images.id. Declared as Integer so the column type
    # matches the referenced integer primary key; a String column only works
    # on backends that do not enforce matching foreign-key types.
    user_image_id = db.Column(db.Integer, db.ForeignKey("user_images.id"))
    tag_name = db.Column(db.String)
    # datetime.now is passed uncalled so it is evaluated when a row is written.
    created_at = db.Column(db.DateTime, default=datetime.now)
    updated_at = db.Column(db.DateTime, default=datetime.now, onupdate=datetime.now)

404 Not Found(detector)

11 |

アプリケーショントップへ

12 | 13 | 14 | -------------------------------------------------------------------------------- /apps/detector/templates/detector/500.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 500 Internal Server Error(detector) 7 | 8 | 9 | 10 |

500 Internal Server Error(detector)

11 |

アプリケーショントップへ

12 | 13 | 14 | -------------------------------------------------------------------------------- /apps/detector/templates/detector/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | detector 7 | 8 | 9 | 10 | 11 | 12 | 13 | 50 | 51 | 52 |
53 | {% block content %}{% endblock %} 54 |
55 | 56 | 57 | -------------------------------------------------------------------------------- /apps/detector/templates/detector/index.html: -------------------------------------------------------------------------------- 1 | {% extends "detector/base.html" %} 2 | {% block content %} 3 | 4 | 5 | {% with messages = get_flashed_messages() %} 6 | {% if messages %} 7 |
    8 | {% for message in messages %} 9 |
  • {{ message }}
  • 10 | {% endfor %} 11 |
12 | {% endif %} 13 | {% endwith %} 14 | 15 | 16 |
17 | 画像新規登録 18 |
19 | 20 | {% for user_image in user_images %} 21 |
22 |
23 |
{{ user_image.User.username }}
24 | 25 |
26 | 27 |
28 |
29 | {{ delete_form.csrf_token }} 30 | {% if current_user.id == user_image.User.id %} 31 | {{ delete_form.submit(class="btn btn-danger") }} 32 | {% else %} 33 | {{ delete_form.submit(class="btn btn-danger", disabled="disabled") }} 34 | {% endif %} 35 |
36 |
37 |
38 |
39 | {{ detector_form.csrf_token }} 40 | {% if current_user.id == user_image.User.id and user_image.UserImage.is_detected == False %} 41 | {{detector_form.submit(class="btn btn-primary")}} 42 | {% else %} 43 | {{ detector_form.submit(class="btn btn-primary",disabled="disabled")}} 44 | {% endif %} 45 |
46 |
47 |
48 |
49 |
50 | アップロード画像 51 |
52 | 53 |
54 | {% for tag in user_image_tag_dict[user_image.UserImage.id] %} 55 | #{{tag.tag_name }} 56 | {% endfor %} 57 |
58 |
59 | {% endfor %} 60 | {% endblock %} -------------------------------------------------------------------------------- /apps/detector/templates/detector/upload.html: -------------------------------------------------------------------------------- 1 | {% extends "detector/base.html" %} 2 | {% block content %} 3 |
4 |

画像新規登録

5 |

アップロードする画像を選択してください

6 |
8 | {{ form.csrf_token }} 9 |
10 | 13 |
14 | {% for error in form.image.errors %} 15 | {{ error }} 16 | {% endfor %} 17 |
18 |
19 | 20 |
21 |
22 |
# --- tail of /apps/detector/templates/detector/upload.html ---
# {% endblock %}
# --- /apps/detector/views.py ---
"""Views of the ``detector`` blueprint: list, upload, detect, delete, search."""
import random
import uuid
from pathlib import Path

import cv2
import numpy as np
import torch
import torchvision
from apps.app import db
from apps.crud.models import User

# Forms used by the detector views
from apps.detector.forms import DeleteForm, DetectorForm, UploadImageForm
from apps.detector.models import UserImage, UserImageTag
from flask import (
    Blueprint,
    current_app,
    flash,
    redirect,
    render_template,
    request,
    send_from_directory,
    url_for,
)

# login_required / current_user protect the mutating endpoints
from flask_login import current_user, login_required
from PIL import Image

# Only template_folder is given (no static folder for this blueprint)
dt = Blueprint("detector", __name__, template_folder="templates")


def _is_owner(user_image):
    """Return True when the logged-in user owns *user_image*.

    Both ids are compared as strings because ``UserImage.user_id`` is declared
    as a String column while the user's primary key may be an integer.
    """
    return str(user_image.user_id) == str(current_user.id)


def _tags_query(image_id):
    """Return the query selecting every tag attached to the image *image_id*."""
    return db.session.query(UserImageTag).filter(
        UserImageTag.user_image_id == image_id
    )


@dt.route("/")
def index():
    """Render the list of all uploaded images together with their tags."""
    # Join User and UserImage to get (User, UserImage) rows
    user_images = (
        db.session.query(User, UserImage)
        .join(UserImage)
        .filter(User.id == UserImage.user_id)
        .all()
    )

    # Map image id -> list of tags attached to that image
    user_image_tag_dict = {
        user_image.UserImage.id: _tags_query(user_image.UserImage.id).all()
        for user_image in user_images
    }

    return render_template(
        "detector/index.html",
        user_images=user_images,
        user_image_tag_dict=user_image_tag_dict,
        detector_form=DetectorForm(),
        delete_form=DeleteForm(),
    )


# NOTE: the URL variable was lost in the extracted source ("/images/"); the
# view requires ``filename``, so the converter is restored here.
@dt.route("/images/<path:filename>")
def image_file(filename):
    """Serve an uploaded image from the configured UPLOAD_FOLDER."""
    return send_from_directory(current_app.config["UPLOAD_FOLDER"], filename)


@dt.route("/upload", methods=["GET", "POST"])
@login_required
def upload_image():
    """Show the upload form and store a validated image for the current user."""
    form = UploadImageForm()
    if form.validate_on_submit():
        # Uploaded file object
        file = form.image.data

        # Keep the original extension but replace the name with a uuid
        ext = Path(file.filename).suffix
        image_uuid_file_name = str(uuid.uuid4()) + ext

        # Save the file to disk
        image_path = Path(current_app.config["UPLOAD_FOLDER"], image_uuid_file_name)
        file.save(image_path)

        # Record the upload in the database
        user_image = UserImage(user_id=current_user.id, image_path=image_uuid_file_name)
        db.session.add(user_image)
        db.session.commit()

        return redirect(url_for("detector.index"))
    return render_template("detector/upload.html", form=form)


def make_color(labels):
    """Return a random ``[r, g, b]`` color for a bounding box.

    ``labels`` is kept for interface compatibility; the original built one
    random color per label and then picked one at random, which is equivalent
    to drawing a single random color.
    """
    return [random.randint(0, 255) for _ in range(3)]


def make_line(result_image):
    """Return the border thickness, scaled to the longer side of the image."""
    return round(0.002 * max(result_image.shape[0:2])) + 1


def draw_lines(c1, c2, result_image, line, color):
    """Draw the bounding rectangle (corners *c1*, *c2*) onto *result_image*.

    Returns the ``cv2`` module, as the original did, for backward compatibility.
    """
    cv2.rectangle(result_image, c1, c2, color, thickness=line)
    return cv2


def draw_texts(result_image, line, c1, cv2, color, labels, label):
    """Draw the label text (on a filled box) at the top-left corner *c1*.

    ``cv2`` is accepted as a parameter and returned unchanged to keep the
    original call signature.
    """
    display_txt = f"{labels[label]}"
    font = max(line - 1, 1)
    t_size = cv2.getTextSize(display_txt, 0, fontScale=line / 3, thickness=font)[0]
    c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
    # Filled background box for the text
    cv2.rectangle(result_image, c1, c2, color, -1)
    cv2.putText(
        result_image,
        display_txt,
        (c1[0], c1[1] - 2),
        0,
        line / 3,
        [225, 255, 255],
        thickness=font,
        lineType=cv2.LINE_AA,
    )
    return cv2


def exec_detect(target_image_path):
    """Run object detection on an image and save an annotated copy.

    Args:
        target_image_path: path of the image to analyse.

    Returns:
        ``(tags, detected_image_file_name)``: the distinct labels detected with
        a score above 0.5, and the file name (inside UPLOAD_FOLDER) of the
        annotated JPEG.
    """
    labels = current_app.config["LABELS"]

    # Force 3-channel RGB so grayscale / RGBA uploads work with the
    # RGB->BGR conversion below.
    image = Image.open(target_image_path).convert("RGB")

    # Convert the image to a tensor
    image_tensor = torchvision.transforms.functional.to_tensor(image)

    # Load the trained model and switch it to inference mode
    model = torch.load(Path(current_app.root_path, "detector", "model.pt"))
    model = model.eval()

    # Inference only: no autograd graph is needed
    with torch.no_grad():
        output = model([image_tensor])[0]

    tags = []
    result_image = np.array(image.copy())

    # Annotate the image once per distinct detected label
    for box, label, score in zip(output["boxes"], output["labels"], output["scores"]):
        if score > 0.5 and labels[label] not in tags:
            current_app.logger.debug(
                "detected %s (score=%s)", labels[label], float(score)
            )
            color = make_color(labels)
            line = make_line(result_image)
            # Top-left / bottom-right corners of the bounding box
            c1 = (int(box[0]), int(box[1]))
            c2 = (int(box[2]), int(box[3]))
            # The return values are deliberately NOT bound to ``cv2``: doing so
            # made ``cv2`` a local name and the imwrite call below raised
            # UnboundLocalError whenever nothing was detected.
            draw_lines(c1, c2, result_image, line, color)
            draw_texts(result_image, line, c1, cv2, color, labels, label)
            tags.append(labels[label])

    # File name and destination path of the annotated image
    detected_image_file_name = str(uuid.uuid4()) + ".jpg"
    detected_image_file_path = str(
        Path(current_app.config["UPLOAD_FOLDER"], detected_image_file_name)
    )
    # PIL gives RGB, OpenCV writes BGR
    cv2.imwrite(detected_image_file_path, cv2.cvtColor(result_image, cv2.COLOR_RGB2BGR))
    return tags, detected_image_file_name


def save_detected_image_tags(user_image, tags, detected_image_file_name):
    """Persist the annotated image path, the detected flag and the tag rows."""
    # Point the record at the annotated image and mark it as detected
    user_image.image_path = detected_image_file_name
    user_image.is_detected = True
    db.session.add(user_image)
    # One user_image_tags row per detected label
    for tag in tags:
        user_image_tag = UserImageTag(user_image_id=user_image.id, tag_name=tag)
        db.session.add(user_image_tag)
    db.session.commit()


# NOTE: URL variable restored (it was stripped from the extracted source).
@dt.route("/detect/<string:image_id>", methods=["POST"])
@login_required
def detect(image_id):
    """Run object detection on one of the current user's images."""
    user_image = db.session.query(UserImage).filter(UserImage.id == image_id).first()
    if user_image is None:
        flash("物体検知対象の画像が存在しません。")
        return redirect(url_for("detector.index"))

    # The template only disables the button for other users' images;
    # enforce the same rule on the server.
    if not _is_owner(user_image):
        flash("この画像を操作する権限がありません。")
        return redirect(url_for("detector.index"))

    target_image_path = Path(current_app.config["UPLOAD_FOLDER"], user_image.image_path)

    try:
        # Detection itself can fail (missing model/image), so it is covered by
        # the same error handling as the database update.
        tags, detected_image_file_name = exec_detect(target_image_path)
        save_detected_image_tags(user_image, tags, detected_image_file_name)
    except Exception:
        flash("物体検知処理でエラーが発生しました。")
        db.session.rollback()
        current_app.logger.exception("object detection failed for image %s", image_id)
    return redirect(url_for("detector.index"))


# NOTE: URL variable restored (it was stripped from the extracted source).
@dt.route("/images/delete/<string:image_id>", methods=["POST"])
@login_required
def delete_image(image_id):
    """Delete one of the current user's images together with its tags."""
    user_image = db.session.query(UserImage).filter(UserImage.id == image_id).first()
    # Only the owner may delete an image
    if user_image is not None and not _is_owner(user_image):
        flash("この画像を操作する権限がありません。")
        return redirect(url_for("detector.index"))

    try:
        # Delete the tag rows first, then the image row
        db.session.query(UserImageTag).filter(
            UserImageTag.user_image_id == image_id
        ).delete()
        db.session.query(UserImage).filter(UserImage.id == image_id).delete()
        db.session.commit()
    except Exception:
        flash("画像削除処理でエラーが発生しました。")
        current_app.logger.exception("failed to delete image %s", image_id)
        db.session.rollback()
    return redirect(url_for("detector.index"))


@dt.route("/images/search", methods=["GET"])
def search():
    """Render the images having at least one tag matching ``?search=``.

    As in the original, images without any tag are never listed, even when the
    search word is empty.
    """
    user_images = db.session.query(User, UserImage).join(
        UserImage, User.id == UserImage.user_id
    )

    # Search word from the query string
    search_text = request.args.get("search")

    user_image_tag_dict = {}
    filtered_user_images = []

    for user_image in user_images:
        image_id = user_image.UserImage.id

        if search_text:
            # Skip images that have no tag matching the search word
            matching_tag = (
                _tags_query(image_id)
                .filter(UserImageTag.tag_name.like("%" + search_text + "%"))
                .first()
            )
            if matching_tag is None:
                continue

        # Show every tag of a listed image, not only the matching ones
        user_image_tags = _tags_query(image_id).all()
        if not user_image_tags:
            continue

        user_image_tag_dict[image_id] = user_image_tags
        filtered_user_images.append(user_image)

    return render_template(
        "detector/index.html",
        user_images=filtered_user_images,
        user_image_tag_dict=user_image_tag_dict,
        delete_form=DeleteForm(),
        detector_form=DetectorForm(),
    )


@dt.errorhandler(404)
def page_not_found(e):
    """Render the blueprint-specific 404 page."""
    return render_template("detector/404.html"), 404

# --- /apps/images/.ignore ---
# https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/apps/images/.ignore
# --- /apps/minimalapp/app.py ---
"""Minimal Flask tutorial application: routing, contexts, contact form, mail."""
import logging
import os

from email_validator import EmailNotValidError, validate_email
from flask import (
    Flask,
    current_app,
    flash,
    g,
    make_response,
    redirect,
    render_template,
    request,
    session,
    url_for,
)
from flask_debugtoolbar import DebugToolbarExtension
from flask_mail import Mail, Message

# Create the Flask application
app = Flask(__name__)

# SECRET_KEY: prefer the environment; the literal is only a development
# fallback and must not be used in production.
app.config["SECRET_KEY"] = os.environ.get("SECRET_KEY", "2AZSMss3p5QPbcY2hBsJ")

# Log level
app.logger.setLevel(logging.DEBUG)

# Do not let the debug toolbar intercept redirects
app.config["DEBUG_TB_INTERCEPT_REDIRECTS"] = False

# Flask-Mail configuration. Environment values are strings, so the port and
# the TLS flag are converted explicitly ("False" would otherwise be truthy).
app.config["MAIL_SERVER"] = os.environ.get("MAIL_SERVER")
app.config["MAIL_PORT"] = int(os.environ.get("MAIL_PORT", 25))
app.config["MAIL_USE_TLS"] = os.environ.get("MAIL_USE_TLS", "").lower() in (
    "true",
    "1",
    "yes",
)
app.config["MAIL_USERNAME"] = os.environ.get("MAIL_USERNAME")
app.config["MAIL_PASSWORD"] = os.environ.get("MAIL_PASSWORD")
app.config["MAIL_DEFAULT_SENDER"] = os.environ.get("MAIL_DEFAULT_SENDER")

# Attach the debug toolbar to the application
toolbar = DebugToolbarExtension(app)

# Register the flask-mail extension
mail = Mail(app)


# Map a URL to a view function
@app.route("/")
def index():
    """Return a plain greeting."""
    return "Hello, Flaskbook!"


# NOTE: the URL variable was stripped from the extracted source ("/hello/");
# the view requires ``name``, so it is restored here.
@app.route("/hello/<name>", methods=["GET"], endpoint="hello-endpoint")
def hello(name):
    """Greet *name*."""
    return f"Hello, {name}"


# NOTE: URL variable restored, as above.
@app.route("/name/<name>")
def show_name(name):
    """Render index.html with *name* passed to the template."""
    return render_template("index.html", name=name)


# Since Flask 2, @app.get("/hello") / @app.post("/hello") can be used instead:
# @app.get("/hello")
# @app.post("/hello")
# def hello():
#     return "Hello, World!"


with app.test_request_context():
    # /
    print(url_for("index"))
    # /hello/world
    print(url_for("hello-endpoint", name="world"))
    # /name/ichiro?page=1
    print(url_for("show_name", name="ichiro", page="1"))


# Accessing current_app here (outside an application context) raises an error
# print(current_app)

# Create an application context and push it onto the stack
ctx = app.app_context()
ctx.push()

# current_app is now accessible
print(current_app.name)
# >> apps.minimalapp.app

# Store a value in the per-context global namespace
g.connection = "connection"
print(g.connection)
# >> connection

with app.test_request_context("/users?updated=true"):
    # prints "true"
    print(request.args.get("updated"))


@app.route("/contact")
def contact():
    """Render the contact form, setting a demo cookie and session value."""
    # Build a response object so a cookie can be attached
    response = make_response(render_template("contact.html"))

    # Set a cookie
    response.set_cookie("flaskbook key", "flaskbook value")

    # Set a session value
    session["username"] = "ichiro"

    return response


@app.route("/contact/complete", methods=["GET", "POST"])
def contact_complete():
    """Validate the contact form, send the mail, then show the completion page.

    POST validates the input; invalid input redirects back to the form with
    flashed messages, valid input sends the mail and redirects to this
    endpoint with GET (post/redirect/get).
    """
    if request.method == "POST":
        # Read the submitted form values
        username = request.form["username"]
        email = request.form["email"]
        description = request.form["description"]

        # Input validation
        is_valid = True
        if not username:
            flash("ユーザ名は必須です")
            is_valid = False

        if not email:
            flash("メールアドレスは必須です")
            is_valid = False
        else:
            # Only check the format when an address was entered, so an empty
            # field does not produce two error messages.
            try:
                validate_email(email)
            except EmailNotValidError:
                flash("メールアドレスの形式で入力してください")
                is_valid = False

        if not description:
            flash("問い合わせ内容は必須です")
            is_valid = False

        if not is_valid:
            return redirect(url_for("contact"))

        # Send the confirmation mail
        send_email(
            email,
            "問い合わせありがとうございました。",
            "contact_mail",
            username=username,
            description=description,
        )

        # Message shown on the completion page
        flash("問い合わせ内容はメールにて送信しました。問い合わせありがとうございます。")

        # Redirect to the completion endpoint (GET)
        return redirect(url_for("contact_complete"))
    return render_template("contact_complete.html")


app.logger.critical("fatal error")
app.logger.error("error")
app.logger.warning("warning")
app.logger.info("info")
app.logger.debug("debug")


def send_email(to, subject, template, **kwargs):
    """Send a mail rendered from ``<template>.txt`` and ``<template>.html``.

    Args:
        to: recipient address.
        subject: mail subject.
        template: template base name (without extension).
        **kwargs: variables passed to both templates.
    """
    msg = Message(subject, recipients=[to])
    msg.body = render_template(template + ".txt", **kwargs)
    msg.html = render_template(template + ".html", **kwargs)
    mail.send(msg)

# --- /apps/minimalapp/static/style.css ---
# https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/apps/minimalapp/static/style.css
# --- /apps/minimalapp/templates/contact.html ---
# 問い合わせフォーム

問い合わせフォーム

12 | 13 | {% with messages = get_flashed_messages() %} 14 | {% if messages %} 15 |
    16 | {% for message in messages %} 17 |
  • {{ message }}
  • 18 | {% endfor %} 19 |
20 | {% endif %} 21 | {% endwith %} 22 |
23 | 24 | 25 | 26 | 29 | 30 | 31 | 32 | 35 | 36 | 37 | 38 | 41 | 42 |
ユーザ名 27 | 28 |
メールアドレス 33 | 34 |
問い合わせ内容 39 | 40 |
43 | 44 |
45 | 46 | 47 | -------------------------------------------------------------------------------- /apps/minimalapp/templates/contact_complete.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 問い合わせ完了 7 | 8 | 9 | 10 | 11 |

問い合わせ完了

12 | 13 | {% with messages = get_flashed_messages() %} 14 | {% if messages %} 15 |
    16 | {% for message in messages %} 17 |
  • {{ message }}
  • 18 | {% endfor %} 19 |
20 | {% endif %} 21 | {% endwith %} 22 | 23 | 24 | -------------------------------------------------------------------------------- /apps/minimalapp/templates/contact_mail.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 問い合わせ完了 7 | 8 | 9 | 10 |

{{ username }} 様

11 |

問い合わせありがとうございました。問い合わせ内容はこちらになります。

12 |

問い合わせ内容

13 |

{{ description }}

14 | 15 | 16 | -------------------------------------------------------------------------------- /apps/minimalapp/templates/contact_mail.txt: -------------------------------------------------------------------------------- 1 | {{ username }} 様 2 | 3 | 問い合わせありがとうございました。問い合わせ内容はこちらになります。 4 | 5 | 問い合わせ内容 6 | 7 | {{ description }} -------------------------------------------------------------------------------- /apps/minimalapp/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Name 7 | 8 | 9 | 10 | 11 |

Name: {{ name }}

12 | 13 | 14 | -------------------------------------------------------------------------------- /apps/static/css/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | background-color: #f5f5f5; 3 | } 4 | h4 { 5 | margin-top: 20px; 6 | } 7 | input[type="search"] { 8 | background-color: #f5f5f5; 9 | } 10 | .dt-auth-main { 11 | width: 400px; 12 | margin-top: 45px; 13 | } 14 | .dt-auth-main .card { 15 | box-shadow: 0 12px 18px 2px rgba(34, 0, 51, 0.04), 0 6px 22px 4px 16 | rgba(7, 48, 114, 0.12), 17 | 0 6px 10px -4px rgba(14, 13, 26, 0.12) !important; 18 | border-radius: 16px; 19 | } 20 | .dt-auth-main .dt-auth-login { 21 | height: 300px !important; 22 | } 23 | .dt-auth-main .dt-auth-signup { 24 | height: 340px !important; 25 | } 26 | .dt-auth-main header { 27 | text-align: center; 28 | margin: 30px 0 0 0; 29 | font-size: 24px; 30 | } 31 | .dt-auth-main section { 32 | width: 300px; 33 | margin: 10px auto; 34 | } 35 | .dt-auth-flash { 36 | font-size: 14px; 37 | color: #9c1a1c; 38 | } 39 | .dt-auth-input { 40 | margin-top: 10px; 41 | } 42 | .dt-auth-btn { 43 | margin: 30px 0 0 0; 44 | } 45 | .dt-search { 46 | height: 28px !important; 47 | } 48 | .dt-image-content { 49 | margin: 20px auto; 50 | padding: 0; 51 | } 52 | .dt-image-username { 53 | padding-top: 15px; 54 | } 55 | .dt-image-register-btn { 56 | padding: 10px 47px 0 0; 57 | } 58 | .dt-image-content header { 59 | padding: 10px 10px 0 10px; 60 | } 61 | .dt-image-content section { 62 | padding: 10px 0; 63 | margin: auto; 64 | } 65 | .dt-image-content footer { 66 | padding: 10px; 67 | } 68 | .dt-image-content section img { 69 | width: 100%; 70 | } 71 | .dt-image-file { 72 | display: none; 73 | } 74 | .dt-image-submit { 75 | margin-left: 10px; 76 | } -------------------------------------------------------------------------------- /apps/templates/404.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 |   6 | 7 | 
404 Not Found 8 | 9 | 10 | 11 |

404 Not Found

12 |

アプリケーショントップへ

13 | 14 | 15 | -------------------------------------------------------------------------------- /apps/templates/500.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 500 Internal Server Error 7 | 8 | 9 | 10 |

500 Internal Server Error

11 |

アプリケーショントップへ

12 | 13 | 14 | -------------------------------------------------------------------------------- /flaskbook_api/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /flaskbook_api/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 ml-flaskbook 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /flaskbook_api/README.md: -------------------------------------------------------------------------------- 1 | # flaskbook_api 2 | 3 | ## Part 3 からはじめる場合 4 | 5 | ### プロジェクトのセットアップ 6 | 7 | #### Mac/Linux 8 | 9 | ``` 10 | $ python3 -m venv venv 11 | $ . venv/bin/activate 12 | (venv) $ pip install -r requirements.txt 13 | ``` 14 | 15 | #### Windows(PowerShell) 16 | 17 | ``` 18 | > py -3 -m venv venv 19 | > venv\Scripts\Activate.ps1 20 | > pip install -r requirements.txt 21 | ``` 22 | 23 | ### 実行 24 | 25 | ``` 26 | (venv) flask run 27 | ``` 28 | 29 | ### PyTorchの学習済みモデルの作成・保存 30 | 31 | ``` 32 | $ python 33 | Python 3.9.7 (v3.9.7:1016ef3790, Aug 30 2021, 16:39:15) 34 | [Clang 6.0 (clang-600.0.57)] on darwin 35 | Type "help", "copyright", "credits" or "license" for more information. 36 | >>> import torch 37 | >>> import torchvision 38 | >>> model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) 39 | >>> torch.save(model, "model.pt") 40 | ``` 41 | -------------------------------------------------------------------------------- /flaskbook_api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/flaskbook_api/__init__.py -------------------------------------------------------------------------------- /flaskbook_api/api/__init__.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint, jsonify, request 2 | 3 | from flaskbook_api.api import calculation 4 | 5 | api = Blueprint("api", __name__) 6 | 7 | 8 | @api.get("/") 9 | def index(): 10 | return jsonify({"column": "value"}), 201 11 | 12 | 13 | @api.post("/detect") 14 | def detection(): 15 | return calculation.detection(request) 16 | -------------------------------------------------------------------------------- 
# --- /flaskbook_api/api/calculation.py ---
"""Object-detection endpoint logic for the flaskbook API."""
from pathlib import Path

import numpy as np
import cv2
import torch

# ``abort`` previously came from ``os`` (which kills the process with SIGABRT);
# Flask's ``abort`` is the one meant for request handling.
from flask import abort, current_app, jsonify

from flaskbook_api.api.postprocess import draw_lines, draw_texts, make_color, make_line
from flaskbook_api.api.preparation import load_image
from flaskbook_api.api.preprocess import image_to_tensor

basedir = Path(__file__).parent.parent


def detection(request):
    """Detect objects in the requested image and save an annotated copy.

    Args:
        request: the Flask request; it is passed to ``load_image``, which
            returns the image and its file name.

    Returns:
        A ``(response, status)`` tuple: a JSON object mapping each detected
        label to its score in percent with status 201, a 404 when the model
        file is missing, or a 500 when the annotated image cannot be written.
    """
    dict_results = {}
    # Labels the model can detect
    labels = current_app.config["LABELS"]
    # Load the image
    image, filename = load_image(request)
    # Convert the image to a tensor
    image_tensor = image_to_tensor(image)

    # Load the trained model (path is relative to the working directory)
    try:
        model = torch.load("model.pt")
    except FileNotFoundError:
        return jsonify("The model is not found"), 404

    # Switch to inference mode; no autograd graph is needed
    model = model.eval()
    with torch.no_grad():
        output = model([image_tensor])[0]

    result_image = np.array(image.copy())
    # Draw a box and a label for each detected object
    for box, label, score in zip(output["boxes"], output["labels"], output["scores"]):
        # Keep detections scoring above 0.6, once per label
        if score > 0.6 and labels[label] not in dict_results:
            color = make_color(labels)
            line = make_line(result_image)
            # Top-left / bottom-right corners of the bounding box
            c1 = (int(box[0]), int(box[1]))
            c2 = (int(box[2]), int(box[3]))
            draw_lines(c1, c2, result_image, line, color)
            draw_texts(result_image, line, c1, color, labels[label])
            # Label -> score in percent
            dict_results[labels[label]] = round(100 * score.item())

    # Full path of the annotated output image
    dir_image = str(basedir / "data" / "output" / filename)

    # cv2.imwrite reports failure through its return value instead of raising
    if not cv2.imwrite(dir_image, cv2.cvtColor(result_image, cv2.COLOR_RGB2BGR)):
        return jsonify("Failed to save the detected image"), 500
    return jsonify(dict_results), 201
-------------------------------------------------------------------------------- /flaskbook_api/api/config/__init__.py: -------------------------------------------------------------------------------- 1 | from . import base, local 2 | 3 | config = { 4 | "base": base.Config, 5 | "local": local.LocalConfig, 6 | } 7 | -------------------------------------------------------------------------------- /flaskbook_api/api/config/base.py: -------------------------------------------------------------------------------- 1 | class Config: 2 | TESTING = False 3 | DEBUG = False 4 | # 検知するラベル 5 | LABELS = [ 6 | "unlabeled", 7 | "person", 8 | "bicycle", 9 | "car", 10 | "motorcycle", 11 | "airplane", 12 | "bus", 13 | "train", 14 | "truck", 15 | "boat", 16 | "traffic light", 17 | "fire hydrant", 18 | "street sign", 19 | "stop sign", 20 | "parking meter", 21 | "bench", 22 | "bird", 23 | "cat", 24 | "dog", 25 | "horse", 26 | "sheep", 27 | "cow", 28 | "elephant", 29 | "bear", 30 | "zebra", 31 | "giraffe", 32 | "hat", 33 | "backpack", 34 | "umbrella", 35 | "shoe", 36 | "eye glasses", 37 | "handbag", 38 | "tie", 39 | "suitcase", 40 | "frisbee", 41 | "skis", 42 | "snowboard", 43 | "sports ball", 44 | "kite", 45 | "baseball bat", 46 | "baseball glove", 47 | "skateboard", 48 | "surfboard", 49 | "tennis racket", 50 | "bottle", 51 | "plate", 52 | "wine glass", 53 | "cup", 54 | "fork", 55 | "knife", 56 | "spoon", 57 | "bowl", 58 | "banana", 59 | "apple", 60 | "sandwich", 61 | "orange", 62 | "broccoli", 63 | "carrot", 64 | "hot dog", 65 | "pizza", 66 | "donut", 67 | "cake", 68 | "chair", 69 | "couch", 70 | "potted plant", 71 | "bed", 72 | "mirror", 73 | "dining table", 74 | "window", 75 | "desk", 76 | "toilet", 77 | "door", 78 | "tv", 79 | "laptop", 80 | "mouse", 81 | "remote", 82 | "keyboard", 83 | "cell phone", 84 | "microwave", 85 | "oven", 86 | "toaster", 87 | "sink", 88 | "refrigerator", 89 | "blender", 90 | "book", 91 | "clock", 92 | "vase", 93 | "scissors", 94 | "teddy bear", 95 | "hair 
def load_image(request, reshaped_size=(256, 256)):
    """Load the image named in the request as RGB and resize it.

    Args:
        request: Object exposing the parsed JSON body as ``request.json``;
            the body must contain a ``"filename"`` key naming a file that
            lives directly inside ``<basedir>/data/original``.
        reshaped_size: Target size passed unchanged to ``Image.resize``.

    Returns:
        A ``(image, filename)`` tuple: the resized RGB image and the file
        name taken from the request.

    Raises:
        KeyError: If the JSON body has no ``"filename"`` key.
        ValueError: If the file name is empty, is not a string, or contains
            a path component (``/``, ``\\``, ``.`` or ``..``).
    """
    filename = request.json["filename"]

    # The name comes straight from the client, so refuse anything that could
    # step outside the image directory (absolute paths, "..", sub-paths).
    if (
        not isinstance(filename, str)
        or not filename
        or filename in {".", ".."}
        or "/" in filename
        or "\\" in filename
    ):
        raise ValueError(f"invalid image file name: {filename!r}")

    # A bare ``import PIL`` does not load the ``Image`` submodule, so
    # ``PIL.Image`` only worked when some other module had imported it first.
    from PIL import Image

    image_path = basedir / "data" / "original" / filename

    # ``Image.open`` keeps the file open until the pixel data is read;
    # the context manager closes it as soon as the RGB copy exists.
    with Image.open(image_path) as source:
        image_obj = source.convert("RGB")

    image = image_obj.resize(reshaped_size)
    return image, filename
# Name of the settings profile to load; the CONFIG environment variable
# overrides the "local" default.
config_name = os.getenv("CONFIG", "local")

# Create the application and apply the settings class registered under that name.
app = Flask(__name__)
app.config.from_object(config[config_name])

# Register the API blueprint on the application.
app.register_blueprint(api)
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /ml_api/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 ml-flaskbook 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /ml_api/README.md: -------------------------------------------------------------------------------- 1 | # ml_api 2 | 3 | ## Part 4 からはじめる場合 4 | 5 | ### プロジェクトのセットアップ 6 | 7 | #### Mac/Linux 8 | 9 | ``` 10 | $ python3 -m venv venv 11 | $ . venv/bin/activate 12 | (venv) $ pip install -r requirements.txt 13 | ``` 14 | 15 | #### Windows(PowerShell) 16 | 17 | ``` 18 | > py -3 -m venv venv 19 | > venv\Scripts\Activate.ps1 20 | > pip install -r requirements.txt 21 | ``` 22 | 23 | ### DB Migrate 24 | 25 | ``` 26 | (venv) $ flask db migrate 27 | (venv) $ flask db upgrade 28 | ``` 29 | 30 | ### 実行 31 | 32 | ``` 33 | (venv) flask run 34 | ``` 35 | -------------------------------------------------------------------------------- /ml_api/analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 1.1. 
コードリーディング/コードドキュメンテーション" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 86, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "手書き文字の判別結果\n", 20 | "観測結果: [0 1 2 3 4 5 6 7 8 9]\n", 21 | "予測結果: [4 4 4 4 4 4 4 7 4 4]\n", 22 | "正解率: 0.2\n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "import os\n", 28 | "import numpy as np\n", 29 | "from PIL import Image\n", 30 | "import sqlite3\n", 31 | "from sklearn.datasets import load_digits\n", 32 | "from sklearn.linear_model import LogisticRegression\n", 33 | "from sklearn.model_selection import train_test_split\n", 34 | "\n", 35 | "INCLUDED_EXTENTION = [\".png\", \".jpg\"]\n", 36 | "\n", 37 | "dbname = 'images.db'\n", 38 | "conn = sqlite3.connect(dbname)\n", 39 | "cur = conn.cursor()\n", 40 | "cur.execute('DROP TABLE image_info')\n", 41 | "cur.execute('CREATE TABLE image_info (id INTEGER PRIMARY KEY AUTOINCREMENT, filename STRING)')\n", 42 | "conn.commit()\n", 43 | "conn.close()\n", 44 | "\n", 45 | "conn = sqlite3.connect(dbname)\n", 46 | "cur = conn.cursor()\n", 47 | "filenames = sorted(os.listdir('handwriting_pics'))\n", 48 | "for filename in filenames:\n", 49 | " base, ext = os.path.splitext(filename)\n", 50 | " if ext not in INCLUDED_EXTENTION:\n", 51 | " continue\n", 52 | " cur.execute('INSERT INTO image_info(filename) values(?)', (filename,))\n", 53 | "conn.commit()\n", 54 | "cur.close()\n", 55 | "conn.close()\n", 56 | "\n", 57 | "conn = sqlite3.connect(dbname)\n", 58 | "cur = conn.cursor()\n", 59 | "cur.execute('SELECT * FROM image_info')\n", 60 | "pics_info = cur.fetchall()\n", 61 | "cur.close()\n", 62 | "conn.close()\n", 63 | "\n", 64 | "img_test = np.empty((0, 64))\n", 65 | "for pic_info in pics_info:\n", 66 | " filename = pic_info[1]\n", 67 | " base, ext = os.path.splitext(filename)\n", 68 | " if ext not in INCLUDED_EXTENTION:\n", 69 | " continue\n", 70 | " img = Image.open(f'handwriting_pics/{filename}').convert('L')\n", 71 | " 
img_data256 = 255 - np.array(img.resize((8, 8)))\n", 72 | "\n", 73 | " min_bright = img_data256.min()\n", 74 | " max_bright = img_data256.max()\n", 75 | " img_data16 = (img_data256 - min_bright) / (max_bright - min_bright) * 16\n", 76 | " img_test = np.r_[img_test, img_data16.astype(np.uint8).reshape(1, -1)]\n", 77 | "\n", 78 | "digits = load_digits()\n", 79 | "X = digits.data\n", 80 | "y = digits.target\n", 81 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)\n", 82 | "logreg = LogisticRegression(max_iter=2000)\n", 83 | "logreg_model = logreg.fit(X_train, y_train)\n", 84 | "\n", 85 | "X_true = []\n", 86 | "for filename in filenames:\n", 87 | " base, ext = os.path.splitext(filename)\n", 88 | " if ext not in INCLUDED_EXTENTION:\n", 89 | " continue\n", 90 | " X_true = X_true + [int(filename[:1])]\n", 91 | "X_true = np.array(X_true)\n", 92 | "pred_logreg = logreg_model.predict(img_test)\n", 93 | "\n", 94 | "print('手書き文字の判別結果')\n", 95 | "print('観測結果:', X_true)\n", 96 | "print('予測結果:', pred_logreg)\n", 97 | "print('正解率:', logreg_model.score(img_test, X_true))" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "metadata": {}, 103 | "source": [ 104 | "### データへのアクセスするコード" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 71, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "INCLUDED_EXTENTION = [\".png\", \".jpg\"]\n", 114 | "\n", 115 | "# 画像の入っているフォルダを指定し、中身のファイル名を取得\n", 116 | "# images.dbを新規作成。images.dbがすでに存在していれば、接続。\n", 117 | "dbname = 'images.db'\n", 118 | "# データベースへのコネクションオブジェクト作成\n", 119 | "conn = sqlite3.connect(dbname)\n", 120 | "# sqliteを操作するカーソルオブジェクトを作成\n", 121 | "cur = conn.cursor()\n", 122 | "# データベースの初期化\n", 123 | "cur.execute('DROP TABLE image_info')\n", 124 | "# image_infoというtableを作成。\n", 125 | "cur.execute('CREATE TABLE image_info (id INTEGER PRIMARY KEY AUTOINCREMENT, filename STRING)')\n", 126 | "# データベースへコミットし、変更を保存\n", 127 | "conn.commit()\n", 128 | 
"conn.close()\n", 129 | "\n", 130 | "# データベースに画像のファイル名を挿入\n", 131 | "conn = sqlite3.connect(dbname)\n", 132 | "cur = conn.cursor()\n", 133 | "filenames = sorted(os.listdir('handwriting_pics'))\n", 134 | "for filename in filenames:\n", 135 | " base, ext = os.path.splitext(filename)\n", 136 | " if ext not in INCLUDED_EXTENTION:\n", 137 | " continue\n", 138 | " cur.execute('INSERT INTO image_info(filename) values(?)', (filename,))\n", 139 | "conn.commit()\n", 140 | "cur.close()\n", 141 | "conn.close()\n", 142 | "\n", 143 | "# tableの中身を取得\n", 144 | "conn = sqlite3.connect(dbname)\n", 145 | "cur = conn.cursor()\n", 146 | "cur.execute('SELECT * FROM image_info')\n", 147 | "# fetchall()を使って中身を全て取得\n", 148 | "pics_info = cur.fetchall()\n", 149 | "cur.close()\n", 150 | "conn.close()" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "### データの前処理をするコード" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 78, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "img_test = np.empty((0, 64))\n", 167 | "# フォルダ内の全画像をデータ化\n", 168 | "for pic_info in pics_info:\n", 169 | " filename = pic_info[1]\n", 170 | " # 画像ファイルを取得、グレースケールにしてサイズ変更\n", 171 | " base, ext = os.path.splitext(filename)\n", 172 | " if ext not in INCLUDED_EXTENTION:\n", 173 | " continue\n", 174 | " img = Image.open(f'handwriting_pics/{filename}').convert('L')\n", 175 | " img_data256 = 255 - np.array(img.resize((8, 8)))\n", 176 | "\n", 177 | " #画像データ内の最小値が0、最大値が16になるように計算\n", 178 | " min_bright = img_data256.min()\n", 179 | " max_bright = img_data256.max()\n", 180 | " img_data16 = (img_data256 - min_bright) / (max_bright - min_bright) * 16\n", 181 | " #加工した画像データの配列をまとめる\n", 182 | " img_test = np.r_[img_test, img_data16.astype(np.uint8).reshape(1, -1)]" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": {}, 188 | "source": [ 189 | "### データを学習/予測/計算するコード" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 
194 | "execution_count": 73, 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | "手書き文字の判別結果\n", 202 | "観測結果: [0 1 2 3 4 5 6 7 8 9]\n", 203 | "予測結果: [4 4 4 4 4 4 4 7 4 4]\n", 204 | "正解率: 0.2\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "# 教師データからの学習\n", 210 | "# sklearnのデータセットから取得、目的変数Xと説明変数yに分ける\n", 211 | "digits = load_digits()\n", 212 | "X = digits.data\n", 213 | "y = digits.target\n", 214 | "#教師データとテストデータに分ける\n", 215 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)\n", 216 | "#ロジスティック回帰のモデルの作成し、教師データを使って学習させる。\n", 217 | "logreg = LogisticRegression(max_iter=2000)\n", 218 | "logreg_model = logreg.fit(X_train, y_train)\n", 219 | "\n", 220 | "# 画像データの判別\n", 221 | "# 画像データの正解を配列にします。\n", 222 | "X_true = []\n", 223 | "for filename in filenames:\n", 224 | " base, ext = os.path.splitext(filename)\n", 225 | " if ext not in INCLUDED_EXTENTION:\n", 226 | " continue\n", 227 | " X_true = X_true + [int(filename[:1])]\n", 228 | "X_true = np.array(X_true)\n", 229 | "\n", 230 | "#ロジスティック回帰の学習済みモデルに画像データを入れ、判別します。\n", 231 | "pred_logreg = logreg_model.predict(img_test)\n", 232 | "\n", 233 | "print('手書き文字の判別結果')\n", 234 | "print('観測結果:', X_true)\n", 235 | "print('予測結果:', pred_logreg)\n", 236 | "print('正解率:', logreg_model.score(img_test, X_true))" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "# 1.2. 
モジュール分割 / 関数分割" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "### データへのアクセスするコード" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 88, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "import os\n", 260 | "import numpy as np\n", 261 | "from PIL import Image\n", 262 | "import sqlite3\n", 263 | "from sklearn.datasets import load_digits\n", 264 | "from sklearn.linear_model import LogisticRegression\n", 265 | "from sklearn.model_selection import train_test_split" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 89, 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "name": "stdout", 275 | "output_type": "stream", 276 | "text": [ 277 | "table is successully created\n", 278 | "image file names are successully inserted\n" 279 | ] 280 | }, 281 | { 282 | "data": { 283 | "text/plain": [ 284 | "[(1, '0.jpg'),\n", 285 | " (2, '1.jpg'),\n", 286 | " (3, '2.jpg'),\n", 287 | " (4, '3.jpg'),\n", 288 | " (5, '4.jpg'),\n", 289 | " (6, '5.jpg'),\n", 290 | " (7, '6.jpg'),\n", 291 | " (8, '7.jpg'),\n", 292 | " (9, '8.jpg'),\n", 293 | " (10, '9.jpg')]" 294 | ] 295 | }, 296 | "execution_count": 89, 297 | "metadata": {}, 298 | "output_type": "execute_result" 299 | } 300 | ], 301 | "source": [ 302 | "INCLUDED_EXTENTION = [\".png\", \".jpg\"]\n", 303 | "dbname = 'images.db'\n", 304 | "dir_name = 'handwriting_pics'\n", 305 | "\n", 306 | "def load_filenames(dir_name, included_ext=INCLUDED_EXTENTION):\n", 307 | " \"\"\"手書き文字画像が置いてあるパスからファイル名を取得し、リストを作成\"\"\"\n", 308 | " files = []\n", 309 | " filenames = sorted(os.listdir(dir_name))\n", 310 | " for filename in filenames:\n", 311 | " base, ext = os.path.splitext(filename)\n", 312 | " if ext not in included_ext:\n", 313 | " continue\n", 314 | " files.append(filename)\n", 315 | " return files\n", 316 | "\n", 317 | "def create_table(dbname):\n", 318 | " \"\"\"テーブルを作成する関数\"\"\"\n", 319 | " conn = 
sqlite3.connect(dbname)\n", 320 | " cur = conn.cursor()\n", 321 | " cur.execute('DROP TABLE image_info')\n", 322 | " cur.execute( 'CREATE TABLE image_info (id INTEGER PRIMARY KEY AUTOINCREMENT, filename STRING)')\n", 323 | " conn.commit()\n", 324 | " conn.close()\n", 325 | " print(\"table is successully created\")\n", 326 | "\n", 327 | "def insert_filenames(dbname, dir_name):\n", 328 | " \"\"\"手書き文字画像のファイル名をデータベースに保存\"\"\"\n", 329 | " filenames = load_filenames(dir_name)\n", 330 | " conn = sqlite3.connect(dbname)\n", 331 | " cur = conn.cursor()\n", 332 | " for filename in filenames:\n", 333 | " cur.execute('INSERT INTO image_info(filename) values(?)', (filename,))\n", 334 | " conn.commit()\n", 335 | " cur.close()\n", 336 | " conn.close()\n", 337 | " print(\"image file names are successully inserted\")\n", 338 | "\n", 339 | "def extract_filenames(dbname):\n", 340 | " \"\"\"手書き文字画像のファイル名をデータベースから取得\"\"\"\n", 341 | " conn = sqlite3.connect(dbname)\n", 342 | " cur = conn.cursor()\n", 343 | " cur.execute( 'SELECT * FROM image_info')\n", 344 | " filenames = cur.fetchall()\n", 345 | " cur.close()\n", 346 | " conn.close()\n", 347 | " return filenames\n", 348 | "\n", 349 | "create_table(dbname)\n", 350 | "insert_filenames(dbname, dir_name)\n", 351 | "extract_filenames(dbname)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "### データの前処理をするコード" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 91, 364 | "metadata": {}, 365 | "outputs": [ 366 | { 367 | "data": { 368 | "text/plain": [ 369 | "array([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 370 | " 0., 0., 0., 0., 0., 0., 8., 16., 0., 0., 0., 0., 0.,\n", 371 | " 0., 16., 16., 8., 0., 0., 0., 0., 0., 8., 8., 8., 0.,\n", 372 | " 0., 0., 0., 0., 8., 16., 0., 0., 0., 0., 0., 0., 0.,\n", 373 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 374 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 375 | " 0., 0., 0., 0., 
0., 0., 16., 0., 0., 0., 0., 0., 0.,\n", 376 | " 0., 8., 0., 0., 0., 0., 0., 0., 0., 8., 0., 0., 0.,\n", 377 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 378 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 379 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 3.,\n", 380 | " 0., 0., 0., 0., 0., 0., 6., 3., 0., 0., 0., 0., 0.,\n", 381 | " 3., 3., 3., 0., 0., 0., 0., 0., 0., 0., 9., 3., 0.,\n", 382 | " 0., 0., 0., 3., 16., 9., 6., 0., 0., 0., 0., 3., 3.,\n", 383 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 384 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 8., 8., 10.,\n", 385 | " 0., 0., 0., 0., 2., 8., 10., 13., 0., 0., 0., 0., 5.,\n", 386 | " 16., 13., 10., 10., 2., 0., 0., 0., 0., 0., 0., 10., 5.,\n", 387 | " 0., 0., 8., 10., 10., 10., 8., 0., 0., 0., 0., 2., 0.,\n", 388 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 389 | " [ 0., 0., 0., 4., 2., 0., 0., 0., 0., 0., 2., 16., 2.,\n", 390 | " 0., 0., 0., 0., 2., 10., 8., 2., 0., 0., 0., 0., 8.,\n", 391 | " 12., 16., 12., 8., 0., 0., 0., 0., 0., 8., 4., 2., 0.,\n", 392 | " 0., 0., 0., 0., 8., 2., 0., 0., 0., 0., 0., 0., 2.,\n", 393 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 394 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 6., 6., 0., 0.,\n", 395 | " 0., 0., 0., 0., 12., 16., 12., 9., 0., 0., 0., 0., 9.,\n", 396 | " 9., 9., 6., 0., 0., 0., 0., 3., 9., 6., 9., 9., 0.,\n", 397 | " 0., 0., 0., 3., 3., 9., 12., 0., 0., 0., 0., 9., 9.,\n", 398 | " 6., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 399 | " [ 0., 0., 5., 2., 0., 0., 0., 0., 0., 8., 8., 2., 0.,\n", 400 | " 0., 0., 0., 0., 10., 0., 0., 0., 0., 0., 0., 0., 16.,\n", 401 | " 10., 10., 8., 0., 0., 0., 0., 10., 5., 0., 8., 5., 0.,\n", 402 | " 0., 0., 0., 10., 5., 8., 8., 0., 0., 0., 0., 0., 5.,\n", 403 | " 8., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 404 | " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 12., 12., 12., 12.,\n", 405 | " 8., 0., 0., 4., 16., 4., 4., 8., 16., 0., 0., 0., 0.,\n", 406 | " 0., 
0., 12., 8., 0., 0., 0., 0., 0., 0., 16., 0., 0.,\n", 407 | " 0., 0., 0., 0., 8., 12., 0., 0., 0., 0., 0., 0., 8.,\n", 408 | " 4., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 409 | " [ 0., 0., 4., 2., 0., 0., 0., 0., 0., 8., 6., 8., 4.,\n", 410 | " 0., 0., 0., 0., 8., 4., 0., 10., 2., 0., 0., 0., 0.,\n", 411 | " 10., 16., 8., 0., 0., 0., 0., 4., 8., 2., 8., 0., 0.,\n", 412 | " 0., 0., 10., 0., 0., 8., 0., 0., 0., 0., 4., 8., 8.,\n", 413 | " 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 414 | " [ 0., 2., 4., 2., 2., 0., 0., 0., 2., 8., 4., 4., 6.,\n", 415 | " 6., 2., 0., 0., 8., 2., 0., 0., 16., 4., 0., 0., 2.,\n", 416 | " 8., 6., 8., 12., 0., 0., 0., 0., 0., 0., 8., 2., 0.,\n", 417 | " 0., 0., 0., 0., 2., 6., 0., 0., 0., 0., 0., 0., 6.,\n", 418 | " 2., 0., 0., 0., 0., 0., 0., 2., 0., 0., 0., 0.]])" 419 | ] 420 | }, 421 | "execution_count": 91, 422 | "metadata": {}, 423 | "output_type": "execute_result" 424 | } 425 | ], 426 | "source": [ 427 | "# p425、アウトプットを要確認\n", 428 | "def load_filenames(dir_name, included_ext=INCLUDED_EXTENTION):\n", 429 | " \"\"\"手書き文字画像が置いてあるパスからファイル名を取得し、リストを作成する関数\"\"\"\n", 430 | " files = []\n", 431 | " filenames = sorted(os.listdir(dir_name))\n", 432 | " for filename in filenames:\n", 433 | " base, ext = os.path.splitext(filename)\n", 434 | " if ext not in included_ext:\n", 435 | " continue\n", 436 | " files.append(filename)\n", 437 | " return files\n", 438 | "\n", 439 | "def get_grayscale(dir_name):\n", 440 | " \"\"\"読み込んだ手書き文字画像の色をグレースケールに変換する関数 (グレースケールは色の濃淡の明暗を分ける技法のことです。)\"\"\"\n", 441 | " filenames = load_filenames(dir_name)\n", 442 | " for filename in filenames:\n", 443 | " img = Image.open(f'{dir_name}/{filename}').convert('L')\n", 444 | " yield img\n", 445 | "\n", 446 | "def get_shrinked_img(dir_name):\n", 447 | " \"\"\"画像サイズを8×8ピクセルのサイズに統一し、明るさも16階調のグレイスケールで白黒に変換する関数\"\"\"\n", 448 | " img_test = np.empty((0, 64))\n", 449 | " crop_size = 8\n", 450 | " for img in get_grayscale(dir_name):\n", 451 | " img_data256 = 255 - 
np.array(img.resize((crop_size, crop_size)))\n", 452 | " min_bright, max_bright = img_data256.min(), img_data256.max()\n", 453 | " img_data16 = (img_data256 - min_bright) / (max_bright - min_bright) * 16\n", 454 | " img_test = np.r_[img_test, img_data16.astype(np.uint8).reshape(1, -1)]\n", 455 | " return img_test\n", 456 | "\n", 457 | "img_test = get_shrinked_img(dir_name)\n", 458 | "get_shrinked_img(dir_name)" 459 | ] 460 | }, 461 | { 462 | "cell_type": "markdown", 463 | "metadata": {}, 464 | "source": [ 465 | "### データを学習/予測/計算するコード" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 82, 471 | "metadata": {}, 472 | "outputs": [ 473 | { 474 | "name": "stdout", 475 | "output_type": "stream", 476 | "text": [ 477 | "手書き文字の判別結果\n", 478 | "観測結果: [0 1 2 3 4 5 6 7 8 9]\n", 479 | "予測結果: [4 4 4 4 4 4 4 7 4 4]\n", 480 | "正解率: 0.2\n" 481 | ] 482 | }, 483 | { 484 | "data": { 485 | "text/plain": [ 486 | "'Propability calculation is successfully finished'" 487 | ] 488 | }, 489 | "execution_count": 82, 490 | "metadata": {}, 491 | "output_type": "execute_result" 492 | } 493 | ], 494 | "source": [ 495 | "import os\n", 496 | "import numpy as np\n", 497 | "from PIL import Image\n", 498 | "from sklearn.datasets import load_digits\n", 499 | "from sklearn.linear_model import LogisticRegression\n", 500 | "from sklearn.model_selection import train_test_split\n", 501 | "\n", 502 | "def load_filenames(dir_name, included_ext=INCLUDED_EXTENTION):\n", 503 | " \"\"\"手書き文字画像が置いてあるパスからファイル名を取得し、リストを作成\"\"\"\n", 504 | " files = []\n", 505 | " filenames = sorted(os.listdir(dir_name))\n", 506 | " for filename in filenames:\n", 507 | " base, ext = os.path.splitext(filename)\n", 508 | " if ext not in included_ext:\n", 509 | " continue\n", 510 | " files.append(filename)\n", 511 | " return files\n", 512 | "\n", 513 | "def create_logreg_model():\n", 514 | " \"\"\"ロジスティック回帰の学習済みモデルを生成\"\"\"\n", 515 | " digits = load_digits()\n", 516 | " X = digits.data\n", 517 | " y = 
digits.target\n", 518 | " X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)\n", 519 | " logreg = LogisticRegression(max_iter=2000)\n", 520 | " logreg_model = logreg.fit(X_train, y_train)\n", 521 | " return logreg_model\n", 522 | "\n", 523 | "def evaluate_probs(dir_name, img_test, logreg_model):\n", 524 | " \"\"\"テストデータを利用してロジスティック回帰の学習済みモデルのアウトプットを評価\"\"\"\n", 525 | " filenames = load_filenames(dir_name)\n", 526 | " X_true = [int(filename[:1]) for filename in filenames] \n", 527 | " X_true = np.array(X_true)\n", 528 | " pred_logreg = logreg_model.predict(img_test)\n", 529 | " \n", 530 | " print('手書き文字の判別結果')\n", 531 | " print('観測結果:', X_true)\n", 532 | " print('予測結果:', pred_logreg)\n", 533 | " print('正解率:', logreg_model.score(img_test, X_true))\n", 534 | " return \"Propability calculation is successfully finished\"\n", 535 | "\n", 536 | "logreg_model = create_logreg_model()\n", 537 | "evaluate_probs(dir_name, img_test, logreg_model)" 538 | ] 539 | }, 540 | { 541 | "cell_type": "markdown", 542 | "metadata": {}, 543 | "source": [ 544 | "# ロジスティック回帰の学習済みモデルを生成" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": 11, 550 | "metadata": {}, 551 | "outputs": [], 552 | "source": [ 553 | "import pickle\n", 554 | "from sklearn.datasets import load_digits\n", 555 | "from sklearn.linear_model import LogisticRegression\n", 556 | "from sklearn.model_selection import train_test_split\n", 557 | "\n", 558 | "digits = load_digits()\n", 559 | "X = digits.data\n", 560 | "y = digits.target\n", 561 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)\n", 562 | "\n", 563 | "logreg = LogisticRegression(max_iter=2000)\n", 564 | "model = logreg.fit(X_train, y_train)\n", 565 | "with open('model.pickle', mode='wb') as fp:\n", 566 | " pickle.dump(model, fp)" 567 | ] 568 | } 569 | ], 570 | "metadata": { 571 | "kernelspec": { 572 | "display_name": "Python 3", 573 | "language": "python", 574 | 
"name": "python3" 575 | }, 576 | "language_info": { 577 | "codemirror_mode": { 578 | "name": "ipython", 579 | "version": 3 580 | }, 581 | "file_extension": ".py", 582 | "mimetype": "text/x-python", 583 | "name": "python", 584 | "nbconvert_exporter": "python", 585 | "pygments_lexer": "ipython3", 586 | "version": "3.9.7" 587 | }, 588 | "varInspector": { 589 | "cols": { 590 | "lenName": 16, 591 | "lenType": 16, 592 | "lenVar": 40 593 | }, 594 | "kernels_config": { 595 | "python": { 596 | "delete_cmd_postfix": "", 597 | "delete_cmd_prefix": "del ", 598 | "library": "var_list.py", 599 | "varRefreshCmd": "print(var_dic_list())" 600 | }, 601 | "r": { 602 | "delete_cmd_postfix": ") ", 603 | "delete_cmd_prefix": "rm(", 604 | "library": "var_list.r", 605 | "varRefreshCmd": "cat(var_dic_list()) " 606 | } 607 | }, 608 | "types_to_exclude": [ 609 | "module", 610 | "function", 611 | "builtin_function_or_method", 612 | "instance", 613 | "_Feature" 614 | ], 615 | "window_display": false 616 | } 617 | }, 618 | "nbformat": 4, 619 | "nbformat_minor": 4 620 | } 621 | -------------------------------------------------------------------------------- /ml_api/api/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from flask import Blueprint, jsonify, request 4 | 5 | from api import calculation, preparation 6 | 7 | from .json_validate import validate_json, validate_schema 8 | 9 | api = Blueprint("api", __name__, url_prefix="/v1") 10 | 11 | 12 | @api.post("/file-id") 13 | @validate_json 14 | @validate_schema("check_dir_name") 15 | def file_id(): 16 | return preparation.insert_filenames(request) 17 | 18 | 19 | @api.post("/probabilities") 20 | @validate_json 21 | @validate_schema("check_file_id") 22 | def probabilities(): 23 | return calculation.evaluate_probs(request) 24 | 25 | 26 | @api.post("/check-schema") 27 | # json schemaの有無のチェックをするデコレーター 28 | @validate_json 29 | # json schemaの定義があっているかどうかのチェックをするデコレーター 30 | 
def evaluate_probs(request) -> tuple:
    """Score the stored handwriting images with the pre-trained model.

    The filenames registered under the request's ``file_id`` are loaded,
    converted to model input, and the model's predictions are compared with
    the digit given by the first character of each filename.

    Args:
        request: Flask request whose JSON body contains ``file_id``.

    Returns:
        A ``(response, status)`` tuple. The JSON body holds ``file_id``,
        ``observed_result``, ``predicted_result`` and ``accuracy`` under the
        ``results`` key; the status is 201.
    """
    file_id = request.json["file_id"]
    filenames = extract_filenames(file_id)
    img_test = get_shrinked_img(filenames)

    # NOTE(review): pickle executes code on load, so this file must only ever
    # be a trusted, locally generated model. The path is resolved against the
    # current working directory.
    with open("model.pickle", mode="rb") as fp:
        model = pickle.load(fp)

    # The observed label is the leading character of each filename
    # (e.g. "3.jpg" -> 3).
    X_true = np.array([int(filename[:1]) for filename in filenames])

    predicted_result = model.predict(img_test).tolist()
    accuracy = model.score(img_test, X_true).tolist()
    observed_result = X_true.tolist()

    # The status code must sit next to the response, not inside jsonify():
    # a second positional argument to jsonify() is serialised into the body
    # instead of being used as the HTTP status.
    return (
        jsonify(
            {
                "results": {
                    "file_id": file_id,
                    "observed_result": observed_result,
                    "predicted_result": predicted_result,
                    "accuracy": accuracy,
                }
            }
        ),
        201,
    )
class LocalConfig:
    """Flask settings used when the ``local`` configuration is selected."""

    # "sqlite:///" followed by the database path is the form SQLAlchemy
    # documents for both POSIX absolute paths (four slashes in total) and
    # Windows drive paths. The previous "sqlite:////" prefix added a
    # redundant slash in front of the already-absolute path.
    SQLALCHEMY_DATABASE_URI = f"sqlite:///{basedir / 'images.db'}"
    SQLALCHEMY_TRACK_MODIFICATIONS = False

    # File suffixes treated as handwriting images when listing a directory.
    INCLUDED_EXTENTION = [".png", ".jpg"]
    # Name of the directory that holds the handwriting images.
    DIR_NAME = "handwriting_pics"
def validate_json(f):
    """Decorate a view so that a malformed JSON body is rejected with 400.

    The body is only parsed when the request declares a JSON content type and
    the ``X-HTTP-Method-Override`` header (if any) matches the real method.

    Args:
        f: The view function to wrap.

    Returns:
        The wrapped view. It returns ``({"error": ...}, 400)`` as a JSON
        response when parsing the body raises ``BadRequest``; otherwise it
        returns whatever ``f`` returns.
    """

    @wraps(f)
    def wrapper(*args, **kw):
        # Default to "" so that a request without a Content-Type header does
        # not make the `"json" in ctype` test fail on None.
        ctype = request.headers.get("Content-Type", "")
        method_ = request.headers.get("X-HTTP-Method-Override", request.method)
        if method_.lower() == request.method.lower() and "json" in ctype:
            try:
                # Accessing request.json parses the body; this checks that a
                # parseable body is present at all.
                request.json
            except BadRequest:
                msg = "This is an invalid json"
                return jsonify({"error": msg}), 400
        return f(*args, **kw)

    return wrapper
def extract_filenames(file_id: str) -> list[str]:
    """Fetch the handwriting image filenames stored under a file id.

    Args:
        file_id: Value of the ``file_id`` column to look up.

    Returns:
        The non-empty filenames recorded for ``file_id``.

    Raises:
        werkzeug.exceptions.NotFound: If no filename is stored for
            ``file_id``. The description carries an ``error_message`` entry,
            matching the payload used by ``abort`` elsewhere in this module.
    """
    img_obj = db.session.query(ImageInfo).filter(ImageInfo.file_id == file_id)
    filenames = [img.filename for img in img_obj if img.filename]
    if not filenames:
        # Abort instead of returning a (response, status) tuple: the declared
        # return type is a list of filenames, and a caller iterating the
        # result would otherwise treat the response object as an image name.
        abort(404, {"error_message": "filenames are not found in database"})

    return filenames
def get_shrinked_img(filenames: list[str]):
    """Build the numeric model input for the given handwriting images.

    Each file is converted to grayscale, shrunk by ``shrink_image``, cast to
    ``uint8`` and flattened into a single row.

    Args:
        filenames: Image file names to process, in output row order.

    Returns:
        A float array with one row per image, stacked onto an empty
        ``(0, 64)`` base.
    """
    rows = [
        shrink_image(gray).astype(np.uint8).reshape(1, -1)
        for gray in get_grayscale(filenames)
    ]
    # Stacking onto an empty (0, 64) float array keeps the result's dtype and
    # column count fixed, including when no image was given.
    return np.vstack([np.empty((0, 64)), *rows])
-------------------------------------------------------------------------------- /ml_api/handwriting_pics/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/ml_api/handwriting_pics/1.jpg -------------------------------------------------------------------------------- /ml_api/handwriting_pics/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/ml_api/handwriting_pics/2.jpg -------------------------------------------------------------------------------- /ml_api/handwriting_pics/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/ml_api/handwriting_pics/3.jpg -------------------------------------------------------------------------------- /ml_api/handwriting_pics/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/ml_api/handwriting_pics/4.jpg -------------------------------------------------------------------------------- /ml_api/handwriting_pics/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/ml_api/handwriting_pics/5.jpg -------------------------------------------------------------------------------- /ml_api/handwriting_pics/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/ml_api/handwriting_pics/6.jpg 
-------------------------------------------------------------------------------- /ml_api/handwriting_pics/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/ml_api/handwriting_pics/7.jpg -------------------------------------------------------------------------------- /ml_api/handwriting_pics/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/ml_api/handwriting_pics/8.jpg -------------------------------------------------------------------------------- /ml_api/handwriting_pics/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/ml_api/handwriting_pics/9.jpg -------------------------------------------------------------------------------- /ml_api/requirements-test.txt: -------------------------------------------------------------------------------- 1 | appdirs==1.4.4 2 | attrs==21.2.0 3 | black==21.6b0 4 | click==8.0.1 5 | flake8==3.9.2 6 | iniconfig==1.1.1 7 | isort==5.9.1 8 | mccabe==0.6.1 9 | mypy==0.910 10 | mypy-extensions==0.4.3 11 | packaging==20.9 12 | pathspec==0.8.1 13 | pluggy==0.13.1 14 | py==1.10.0 15 | pycodestyle==2.7.0 16 | pyflakes==2.3.1 17 | pyparsing==2.4.7 18 | pytest==6.2.4 19 | regex==2021.4.4 20 | toml==0.10.2 21 | typing-extensions==3.10.0.0 22 | -------------------------------------------------------------------------------- /ml_api/requirements.txt: -------------------------------------------------------------------------------- 1 | alembic==1.6.5 2 | anyio==3.2.1 3 | appnope==0.1.2 4 | argon2-cffi==20.1.0 5 | async-generator==1.10 6 | attrs==21.2.0 7 | Babel==2.9.1 8 | backcall==0.2.0 9 | bleach==3.3.0 10 | certifi==2021.5.30 11 | cffi==1.14.5 
12 | chardet==4.0.0 13 | click==8.0.1 14 | cycler==0.10.0 15 | decorator==5.0.9 16 | defusedxml==0.7.1 17 | entrypoints==0.3 18 | Flask==2.0.1 19 | Flask-Migrate==3.0.1 20 | Flask-SQLAlchemy==2.5.1 21 | greenlet==1.1.0 22 | idna==2.10 23 | ipykernel==5.5.5 24 | ipython==7.25.0 25 | ipython-genutils==0.2.0 26 | itsdangerous==2.0.1 27 | jedi==0.18.0 28 | Jinja2==3.0.1 29 | joblib==1.0.1 30 | json5==0.9.6 31 | jsonschema==3.2.0 32 | jupyter-client==6.1.12 33 | jupyter-core==4.7.1 34 | jupyter-server==1.9.0 35 | jupyterlab==3.0.16 36 | jupyterlab-pygments==0.1.2 37 | jupyterlab-server==2.6.0 38 | kiwisolver==1.3.1 39 | Mako==1.1.4 40 | MarkupSafe==2.0.1 41 | matplotlib==3.4.2 42 | matplotlib-inline==0.1.2 43 | mistune==0.8.4 44 | nbclassic==0.3.1 45 | nbclient==0.5.3 46 | nbconvert==6.1.0 47 | nbformat==5.1.3 48 | nest-asyncio==1.5.1 49 | notebook==6.4.0 50 | numpy==1.21.0 51 | packaging==20.9 52 | pandocfilters==1.4.3 53 | parso==0.8.2 54 | pexpect==4.8.0 55 | pickleshare==0.7.5 56 | Pillow==8.2.0 57 | prometheus-client==0.11.0 58 | prompt-toolkit==3.0.19 59 | ptyprocess==0.7.0 60 | pycparser==2.20 61 | Pygments==2.9.0 62 | pyparsing==2.4.7 63 | pyrsistent==0.17.3 64 | python-dateutil==2.8.1 65 | python-editor==1.0.4 66 | pytz==2021.1 67 | pyzmq==22.1.0 68 | requests==2.25.1 69 | requests-unixsocket==0.2.0 70 | scikit-learn==0.24.2 71 | scipy==1.7.0 72 | Send2Trash==1.7.1 73 | six==1.16.0 74 | sniffio==1.2.0 75 | SQLAlchemy==1.4.19 76 | terminado==0.10.1 77 | testpath==0.5.0 78 | threadpoolctl==2.1.0 79 | tornado==6.1 80 | traitlets==5.0.5 81 | urllib3==1.26.6 82 | wcwidth==0.2.5 83 | webencodings==0.5.1 84 | websocket-client==1.1.0 85 | Werkzeug==2.0.1 86 | -------------------------------------------------------------------------------- /ml_api/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = flaskbook_api 3 | version = 1.0a1 4 | description = This package is hoge 5 | 6 | [options] 7 | 
install_requires = 8 | flask 9 | flask_sqlalchemy 10 | Pillow 11 | numpy 12 | sklearn 13 | jsonschema 14 | 15 | [options.extras_require] 16 | testing = 17 | pytest -------------------------------------------------------------------------------- /ml_api/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /ml_api/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/ml_api/test/__init__.py -------------------------------------------------------------------------------- /ml_api/test/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from api.run import create_app, db 3 | 4 | 5 | @pytest.fixture 6 | def app(): 7 | """ 8 | テスト用のデータベースを作成 9 | """ 10 | app = create_app() 11 | app_context = app.app_context() 12 | app_context.push() 13 | db.create_all() 14 | 15 | yield app 16 | 17 | db.session.remove() 18 | db.drop_all() 19 | app_context.pop() 20 | 21 | 22 | @pytest.fixture 23 | def client(app): 24 | """ 25 | テスト用のrequestオブジェクトを作成 26 | """ 27 | return app.test_client() 28 | -------------------------------------------------------------------------------- /ml_api/test/test_calculation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from api.calculation import evaluate_probs 3 | from api.models import ImageInfo, db 4 | 5 | 6 | @pytest.fixture 7 | def app_made_preparation(app): 8 | file_id = "test_file_id" 9 | filenames = [ 10 | "0.jpg", 11 | "1.jpg", 12 | "2.jpg", 13 | "3.jpg", 14 | "4.jpg", 15 | "5.jpg", 16 | "6.jpg", 17 | "7.jpg", 18 | "8.jpg", 19 | "9.jpg", 20 | ] 21 | 22 | with app.app_context(): 23 | for filename in filenames: 24 | 
image_info = ImageInfo(file_id=file_id, filename=filename) 25 | db.session.add(image_info) 26 | db.session.commit() 27 | return app 28 | 29 | 30 | def test_evaluate_probs(app_made_preparation): 31 | # SET 32 | path = "/v1/evaluate_probs" 33 | payload = {"file_id": "test_file_id"} 34 | 35 | expected_file_id = "test_file_id" 36 | expected_data_len = 10 37 | 38 | # EXECUTE 39 | with app_made_preparation.test_request_context( 40 | path, method="POST", json=payload 41 | ) as req: 42 | json_response = evaluate_probs(req.request) 43 | 44 | actual_file_id = json_response.json[0]["results"]["file_id"] 45 | actual_observed_result = json_response.json[0]["results"]["observed_result"] 46 | actual_predicted_result = json_response.json[0]["results"]["predicted_result"] 47 | actual_accuracy = json_response.json[0]["results"]["accuracy"] 48 | 49 | # CHECK 50 | assert expected_file_id == actual_file_id 51 | assert expected_data_len == len(actual_observed_result) 52 | assert expected_data_len == len(actual_predicted_result) 53 | assert isinstance(actual_accuracy, float) -------------------------------------------------------------------------------- /ml_api/test/test_preparation.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import pytest 4 | from api.models import ImageInfo, db 5 | from api.preparation import extract_filenames, insert_filenames, load_filenames 6 | 7 | 8 | @pytest.fixture 9 | def app_made_preparation(app): 10 | file_id = "test_file_id" 11 | filenames = [ 12 | "test0.jpg", 13 | "test1.jpg", 14 | "test2.jpg", 15 | "test3.jpg", 16 | "test4.jpg", 17 | "test5.jpg", 18 | "test6.jpg", 19 | "test7.jpg", 20 | "test8.jpg", 21 | "test9.jpg", 22 | ] 23 | 24 | with app.app_context(): 25 | for filename in filenames: 26 | image_info = ImageInfo(file_id=file_id, filename=filename) 27 | db.session.add(image_info) 28 | db.session.commit() 29 | return app 30 | 31 | 32 | def test_load_filenames(app): 33 | # SET 34 | 
test_data = "handwriting_pics" 35 | expected_data = [ 36 | "0.jpg", 37 | "1.jpg", 38 | "2.jpg", 39 | "3.jpg", 40 | "4.jpg", 41 | "5.jpg", 42 | "6.jpg", 43 | "7.jpg", 44 | "8.jpg", 45 | "9.jpg", 46 | ] 47 | 48 | # EXECUTE 49 | actual_data = load_filenames(test_data) 50 | 51 | # CHECK 52 | assert len(expected_data) == len(actual_data) 53 | assert sorted(expected_data) == sorted(actual_data) 54 | 55 | 56 | def test_insert_filenames(app): 57 | # SET 58 | path = "/v1/insert_filenames" 59 | payload = {"dir_name": "handwriting_pics"} 60 | 61 | # EXECUTE 62 | with app.test_request_context(path, method="POST", json=payload) as req: 63 | json_response = insert_filenames(req.request) 64 | 65 | actual_data = json_response[0].json["file_id"] 66 | 67 | # CHECK 68 | assert isinstance(actual_data, str) 69 | assert uuid.UUID(actual_data).version == 4 70 | 71 | 72 | def test_extract_filenames(app_made_preparation): 73 | # SET 74 | test_data = "test_file_id" 75 | expected_data = [ 76 | "test0.jpg", 77 | "test1.jpg", 78 | "test2.jpg", 79 | "test3.jpg", 80 | "test4.jpg", 81 | "test5.jpg", 82 | "test6.jpg", 83 | "test7.jpg", 84 | "test8.jpg", 85 | "test9.jpg", 86 | ] 87 | 88 | # EXECUTE 89 | actual_data = extract_filenames(test_data) 90 | 91 | # CHECK 92 | assert len(expected_data) == len(actual_data) 93 | assert sorted(expected_data) == sorted(actual_data) 94 | assert all([a == b for a, b in zip(expected_data, actual_data)]) 95 | -------------------------------------------------------------------------------- /ml_api/test/test_preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from api.preprocess import get_grayscale, get_shrinked_img, shrink_image 3 | from PIL import Image 4 | 5 | 6 | def test_get_grayscale(app): 7 | # SET 8 | test_data = [ 9 | "0.jpg", 10 | "1.jpg", 11 | "2.jpg", 12 | "3.jpg", 13 | "4.jpg", 14 | "5.jpg", 15 | "6.jpg", 16 | "7.jpg", 17 | "8.jpg", 18 | "9.jpg", 19 | ] 20 | 21 | def 
mock_generator(): 22 | yield from () 23 | 24 | expected_generator = mock_generator() 25 | expected_len = 10 26 | 27 | # EXECUTE 28 | actual_data = get_grayscale(test_data) 29 | # CHECK 30 | assert type(expected_generator) == type(actual_data) 31 | assert expected_len == len(list(actual_data)) 32 | 33 | 34 | def test_shrink_image(app): 35 | # SET 36 | mock_objs = [Image.new("L", (64, 64)) for _ in range(10)] 37 | expected_data = np.array([np.nan for _ in range(64)]) 38 | expected_len = 8 39 | # EXECUTE 40 | for mock_obj in mock_objs: 41 | actual_data = shrink_image(mock_obj) 42 | # CHECK 43 | assert expected_len == len(actual_data) 44 | assert isinstance(actual_data, (np.ndarray, np.generic)) 45 | assert expected_data.all() == actual_data.all() 46 | 47 | 48 | def test_get_shrinked_img(app): 49 | # SET 50 | test_data = [ 51 | "0.jpg", 52 | "1.jpg", 53 | "2.jpg", 54 | "3.jpg", 55 | "4.jpg", 56 | "5.jpg", 57 | "6.jpg", 58 | "7.jpg", 59 | "8.jpg", 60 | "9.jpg", 61 | ] 62 | expected_len = 10 63 | 64 | # EXECUTE 65 | actual_data = get_shrinked_img(test_data) 66 | 67 | # CHECK 68 | assert expected_len == len(actual_data) 69 | assert isinstance(actual_data, (np.ndarray, np.generic)) 70 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | alembic==1.7.5 2 | attrs==21.2.0 3 | black==21.12b0 4 | blinker==1.4 5 | click==8.0.3 6 | coverage==6.2 7 | dnspython==2.1.0 8 | email-validator==1.1.3 9 | flake8==4.0.1 10 | Flask==2.0.2 11 | Flask-DebugToolbar==0.11.0 12 | Flask-Login==0.5.0 13 | Flask-Mail==0.9.1 14 | Flask-Migrate==3.1.0 15 | Flask-SQLAlchemy==2.5.1 16 | Flask-WTF==1.0.0 17 | idna==3.3 18 | iniconfig==1.1.1 19 | isort==5.10.1 20 | itsdangerous==2.0.1 21 | Jinja2==3.0.3 22 | Mako==1.1.6 23 | MarkupSafe==2.0.1 24 | mccabe==0.6.1 25 | mypy==0.920 26 | mypy-extensions==0.4.3 27 | numpy==1.21.4 28 | opencv-python==4.5.4.60 29 | packaging==21.3 
# Fixture that builds the application used by the tests.
@pytest.fixture
def fixture_app():
    """Yield an app created with the ``testing`` config, then clean up.

    Setup pushes an application context, creates the database tables and the
    image upload directory. Teardown deletes the rows of the user, user image
    and user image tag tables, removes the upload directory and pops the
    context again.
    """
    # Setup
    # Pass "testing" so that the test configuration is used.
    app = create_app("testing")

    # Keep a handle on the pushed context so it can be popped at teardown
    # instead of staying on the context stack after the test.
    app_context = app.app_context()
    app_context.push()

    # Create the tables of the test database.
    db.create_all()

    # Create the upload directory for test images; tolerate a directory left
    # behind by an interrupted earlier run.
    upload_folder = app.config["UPLOAD_FOLDER"]
    os.makedirs(upload_folder, exist_ok=True)

    # Run the test.
    yield app

    # Cleanup
    try:
        # Delete the records of the user table.
        User.query.delete()

        # Delete the records of the user_image table.
        UserImage.query.delete()

        # Delete the records of the user_image_tags table.
        UserImageTag.query.delete()

        db.session.commit()
    finally:
        # Always remove the upload directory and release the context, even if
        # deleting the rows failed.
        shutil.rmtree(upload_folder, ignore_errors=True)
        app_context.pop()
def upload_image(client, image_path):
    """Upload an image through the ``/upload`` endpoint.

    Args:
        client: Flask test client used to send the request.
        image_path: Path of the file to upload, relative to the root path of
            the ``tests`` package.

    Returns:
        The response of the POST request, after following redirects.
    """
    image = Path(get_root_path("tests"), image_path)
    # Open the file in a context manager so the handle is released on every
    # path, including when sending the request raises.
    with open(image, "rb") as stream:
        test_file = (
            FileStorage(
                stream=stream,
                filename=Path(image_path).name,
                content_type="multipart/form-data",
            ),
        )
        data = dict(
            image=test_file,
        )
        return client.post("/upload", data=data, follow_redirects=True)
client.get("/images/search?search=dog") 109 | 110 | # dogタグの画像があることを確認する 111 | assert user_image.image_path in rv.data.decode() 112 | 113 | # dogタグがあることを確認する 114 | assert "dog" in rv.data.decode() 115 | 116 | # testワードで検索する 117 | rv = client.get("/images/search?search=test") 118 | 119 | # dogタグの画像がないことを確認する 120 | assert user_image.image_path not in rv.data.decode() 121 | 122 | # dogタグがないことを確認する 123 | assert "dog" not in rv.data.decode() 124 | 125 | 126 | def test_delete(client): 127 | signup(client, "admin", "flaskbook@example.com", "password") 128 | upload_image(client, "detector/testdata/test_valid_image.jpg") 129 | user_image = UserImage.query.first() 130 | image_path = user_image.image_path 131 | rv = client.post(f"/images/delete/{user_image.id}", follow_redirects=True) 132 | assert image_path not in rv.data.decode() 133 | 134 | 135 | def test_custom_error(client): 136 | rv = client.get("/notfound") 137 | assert "404 Not Found" in rv.data.decode() 138 | -------------------------------------------------------------------------------- /tests/detector/testdata/test_invalid_file.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/tests/detector/testdata/test_invalid_file.txt -------------------------------------------------------------------------------- /tests/detector/testdata/test_valid_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-flaskbook/flaskbook/df89aac3d31faff657049c70906ca1e36ae80502/tests/detector/testdata/test_valid_image.jpg -------------------------------------------------------------------------------- /tests/test_sample.py: -------------------------------------------------------------------------------- 1 | def test_func1(): 2 | assert 1 == 1 3 | 4 | 5 | def test_func2(): 6 | assert 2 == 2 7 | 8 | 9 | # フィクスチャの関数を引数で指定すると関数の実行結果が渡される 10 
| def test_func3(app_data): 11 | assert app_data == 3 12 | --------------------------------------------------------------------------------