├── .gitignore ├── .python-version ├── .travis.yml ├── LICENSE ├── README.md ├── manage.py ├── project ├── __init__.py ├── server │ ├── __init__.py │ ├── auth │ │ ├── __init__.py │ │ └── views.py │ ├── config.py │ └── models.py └── tests │ ├── __init__.py │ ├── base.py │ ├── helpers.py │ ├── test__config.py │ ├── test_auth.py │ └── test_user_model.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | env 3 | venv 4 | temp 5 | tmp 6 | __pycache__ 7 | 8 | *.pyc 9 | *.sqlite 10 | *.coverage 11 | .DS_Store 12 | env.sh 13 | migrations 14 | *.idea 15 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.6.0 2 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | python: 4 | - "3.6" 5 | - "3.5" 6 | - "3.4" 7 | - "2.7" 8 | 9 | service: 10 | - postgresql 11 | 12 | before_install: 13 | - export APP_SETTINGS="project.server.config.TestingConfig" 14 | - export SECRET_KEY="justatest" 15 | 16 | install: 17 | - pip install -r requirements.txt 18 | - pip install coveralls 19 | 20 | before_script: 21 | - psql -c 'create database flask_jwt_auth_test;' -U postgres 22 | - python manage.py db init 23 | - python manage.py db migrate 24 | - python manage.py db upgrade 25 | 26 | script: 27 | - python manage.py cov 28 | 29 | after_success: 30 | coveralls 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | Copyright (c) 2016 Michael Herman 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the 
"Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Flask JWT Auth 2 | 3 | [![Build Status](https://travis-ci.org/realpython/flask-jwt-auth.svg?branch=master)](https://travis-ci.org/realpython/flask-jwt-auth) 4 | 5 | ## Want to learn how to build this project? 6 | 7 | Check out the [blog post](https://realpython.com/blog/python/token-based-authentication-with-flask/). 8 | 9 | ## Want to use this project? 10 | 11 | ### Basics 12 | 13 | 1. Fork/Clone 14 | 1. Activate a virtualenv 15 | 1. 
Install the requirements 16 | 17 | ### Set Environment Variables 18 | 19 | Update *project/server/config.py*, and then run: 20 | 21 | ```sh 22 | $ export APP_SETTINGS="project.server.config.DevelopmentConfig" 23 | ``` 24 | 25 | or 26 | 27 | ```sh 28 | $ export APP_SETTINGS="project.server.config.ProductionConfig" 29 | ``` 30 | 31 | Set a SECRET_KEY: 32 | 33 | ```sh 34 | $ export SECRET_KEY="change_me" 35 | ``` 36 | 37 | ### Create DB 38 | 39 | Create the databases in `psql`: 40 | 41 | ```sh 42 | $ psql 43 | # create database flask_jwt_auth; 44 | # create database flask_jwt_auth_test; 45 | # \q 46 | ``` 47 | 48 | Create the tables and run the migrations: 49 | 50 | ```sh 51 | $ python manage.py create_db 52 | $ python manage.py db init 53 | $ python manage.py db migrate 54 | ``` 55 | 56 | ### Run the Application 57 | 58 | ```sh 59 | $ python manage.py runserver 60 | ``` 61 | 62 | Access the application at the address [http://localhost:5000/](http://localhost:5000/) 63 | 64 | > Want to specify a different port? 
65 | 66 | > ```sh 67 | > $ python manage.py runserver -h 0.0.0.0 -p 8080 68 | > ``` 69 | 70 | ### Testing 71 | 72 | Without coverage: 73 | 74 | ```sh 75 | $ python manage.py test 76 | ``` 77 | 78 | With coverage: 79 | 80 | ```sh 81 | $ python manage.py cov 82 | ``` 83 | -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | # manage.py 2 | 3 | 4 | import os 5 | import unittest 6 | import coverage 7 | 8 | from flask_script import Manager 9 | from flask_migrate import Migrate, MigrateCommand 10 | 11 | COV = coverage.coverage( 12 | branch=True, 13 | include='project/*', 14 | omit=[ 15 | 'project/tests/*', 16 | 'project/server/config.py', 17 | 'project/server/*/__init__.py' 18 | ] 19 | ) 20 | COV.start() 21 | 22 | from project.server import app, db, models 23 | 24 | migrate = Migrate(app, db) 25 | manager = Manager(app) 26 | 27 | # migrations 28 | manager.add_command('db', MigrateCommand) 29 | 30 | 31 | @manager.command 32 | def test(): 33 | """Runs the unit tests without test coverage.""" 34 | tests = unittest.TestLoader().discover('project/tests', pattern='test*.py') 35 | result = unittest.TextTestRunner(verbosity=2).run(tests) 36 | if result.wasSuccessful(): 37 | return 0 38 | return 1 39 | 40 | 41 | @manager.command 42 | def cov(): 43 | """Runs the unit tests with coverage.""" 44 | tests = unittest.TestLoader().discover('project/tests') 45 | result = unittest.TextTestRunner(verbosity=2).run(tests) 46 | if result.wasSuccessful(): 47 | COV.stop() 48 | COV.save() 49 | print('Coverage Summary:') 50 | COV.report() 51 | basedir = os.path.abspath(os.path.dirname(__file__)) 52 | covdir = os.path.join(basedir, 'tmp/coverage') 53 | COV.html_report(directory=covdir) 54 | print('HTML version: file://%s/index.html' % covdir) 55 | COV.erase() 56 | return 0 57 | return 1 58 | 59 | 60 | @manager.command 61 | def create_db(): 62 | """Creates the db tables.""" 
63 | db.create_all() 64 | 65 | 66 | @manager.command 67 | def drop_db(): 68 | """Drops the db tables.""" 69 | db.drop_all() 70 | 71 | 72 | if __name__ == '__main__': 73 | manager.run() 74 | -------------------------------------------------------------------------------- /project/__init__.py: -------------------------------------------------------------------------------- 1 | # project/__init__.py 2 | -------------------------------------------------------------------------------- /project/server/__init__.py: -------------------------------------------------------------------------------- 1 | # project/server/__init__.py 2 | 3 | import os 4 | 5 | from flask import Flask 6 | from flask_bcrypt import Bcrypt 7 | from flask_sqlalchemy import SQLAlchemy 8 | from flask_cors import CORS 9 | 10 | app = Flask(__name__) 11 | CORS(app) 12 | 13 | app_settings = os.getenv( 14 | 'APP_SETTINGS', 15 | 'project.server.config.DevelopmentConfig' 16 | ) 17 | app.config.from_object(app_settings) 18 | 19 | bcrypt = Bcrypt(app) 20 | db = SQLAlchemy(app) 21 | 22 | from project.server.auth.views import auth_blueprint 23 | app.register_blueprint(auth_blueprint) 24 | -------------------------------------------------------------------------------- /project/server/auth/__init__.py: -------------------------------------------------------------------------------- 1 | # project/server/auth/__init__.py 2 | -------------------------------------------------------------------------------- /project/server/auth/views.py: -------------------------------------------------------------------------------- 1 | # project/server/auth/views.py 2 | 3 | 4 | from flask import Blueprint, request, make_response, jsonify 5 | from flask.views import MethodView 6 | 7 | from project.server import bcrypt, db 8 | from project.server.models import User, BlacklistToken 9 | 10 | auth_blueprint = Blueprint('auth', __name__) 11 | 12 | 13 | class RegisterAPI(MethodView): 14 | """ 15 | User Registration Resource 16 | """ 17 | 18 
| def post(self): 19 | # get the post data 20 | post_data = request.get_json() 21 | # check if user already exists 22 | user = User.query.filter_by(email=post_data.get('email')).first() 23 | if not user: 24 | try: 25 | user = User( 26 | email=post_data.get('email'), 27 | password=post_data.get('password') 28 | ) 29 | # insert the user 30 | db.session.add(user) 31 | db.session.commit() 32 | # generate the auth token 33 | auth_token = user.encode_auth_token(user.id) 34 | responseObject = { 35 | 'status': 'success', 36 | 'message': 'Successfully registered.', 37 | 'auth_token': auth_token.decode() 38 | } 39 | return make_response(jsonify(responseObject)), 201 40 | except Exception as e: 41 | responseObject = { 42 | 'status': 'fail', 43 | 'message': 'Some error occurred. Please try again.' 44 | } 45 | return make_response(jsonify(responseObject)), 401 46 | else: 47 | responseObject = { 48 | 'status': 'fail', 49 | 'message': 'User already exists. Please Log in.', 50 | } 51 | return make_response(jsonify(responseObject)), 202 52 | 53 | 54 | class LoginAPI(MethodView): 55 | """ 56 | User Login Resource 57 | """ 58 | def post(self): 59 | # get the post data 60 | post_data = request.get_json() 61 | try: 62 | # fetch the user data 63 | user = User.query.filter_by( 64 | email=post_data.get('email') 65 | ).first() 66 | if user and bcrypt.check_password_hash( 67 | user.password, post_data.get('password') 68 | ): 69 | auth_token = user.encode_auth_token(user.id) 70 | if auth_token: 71 | responseObject = { 72 | 'status': 'success', 73 | 'message': 'Successfully logged in.', 74 | 'auth_token': auth_token.decode() 75 | } 76 | return make_response(jsonify(responseObject)), 200 77 | else: 78 | responseObject = { 79 | 'status': 'fail', 80 | 'message': 'User does not exist.' 
81 | } 82 | return make_response(jsonify(responseObject)), 404 83 | except Exception as e: 84 | print(e) 85 | responseObject = { 86 | 'status': 'fail', 87 | 'message': 'Try again' 88 | } 89 | return make_response(jsonify(responseObject)), 500 90 | 91 | 92 | class UserAPI(MethodView): 93 | """ 94 | User Resource 95 | """ 96 | def get(self): 97 | # get the auth token 98 | auth_header = request.headers.get('Authorization') 99 | if auth_header: 100 | try: 101 | auth_token = auth_header.split(" ")[1] 102 | except IndexError: 103 | responseObject = { 104 | 'status': 'fail', 105 | 'message': 'Bearer token malformed.' 106 | } 107 | return make_response(jsonify(responseObject)), 401 108 | else: 109 | auth_token = '' 110 | if auth_token: 111 | resp = User.decode_auth_token(auth_token) 112 | if not isinstance(resp, str): 113 | user = User.query.filter_by(id=resp).first() 114 | responseObject = { 115 | 'status': 'success', 116 | 'data': { 117 | 'user_id': user.id, 118 | 'email': user.email, 119 | 'admin': user.admin, 120 | 'registered_on': user.registered_on 121 | } 122 | } 123 | return make_response(jsonify(responseObject)), 200 124 | responseObject = { 125 | 'status': 'fail', 126 | 'message': resp 127 | } 128 | return make_response(jsonify(responseObject)), 401 129 | else: 130 | responseObject = { 131 | 'status': 'fail', 132 | 'message': 'Provide a valid auth token.' 
133 | } 134 | return make_response(jsonify(responseObject)), 401 135 | 136 | 137 | class LogoutAPI(MethodView): 138 | """ 139 | Logout Resource 140 | """ 141 | def post(self): 142 | # get auth token 143 | auth_header = request.headers.get('Authorization') 144 | if auth_header: 145 | auth_token = auth_header.split(" ")[1] 146 | else: 147 | auth_token = '' 148 | if auth_token: 149 | resp = User.decode_auth_token(auth_token) 150 | if not isinstance(resp, str): 151 | # mark the token as blacklisted 152 | blacklist_token = BlacklistToken(token=auth_token) 153 | try: 154 | # insert the token 155 | db.session.add(blacklist_token) 156 | db.session.commit() 157 | responseObject = { 158 | 'status': 'success', 159 | 'message': 'Successfully logged out.' 160 | } 161 | return make_response(jsonify(responseObject)), 200 162 | except Exception as e: 163 | responseObject = { 164 | 'status': 'fail', 165 | 'message': e 166 | } 167 | return make_response(jsonify(responseObject)), 200 168 | else: 169 | responseObject = { 170 | 'status': 'fail', 171 | 'message': resp 172 | } 173 | return make_response(jsonify(responseObject)), 401 174 | else: 175 | responseObject = { 176 | 'status': 'fail', 177 | 'message': 'Provide a valid auth token.' 
178 | } 179 | return make_response(jsonify(responseObject)), 403 180 | 181 | # define the API resources 182 | registration_view = RegisterAPI.as_view('register_api') 183 | login_view = LoginAPI.as_view('login_api') 184 | user_view = UserAPI.as_view('user_api') 185 | logout_view = LogoutAPI.as_view('logout_api') 186 | 187 | # add Rules for API Endpoints 188 | auth_blueprint.add_url_rule( 189 | '/auth/register', 190 | view_func=registration_view, 191 | methods=['POST'] 192 | ) 193 | auth_blueprint.add_url_rule( 194 | '/auth/login', 195 | view_func=login_view, 196 | methods=['POST'] 197 | ) 198 | auth_blueprint.add_url_rule( 199 | '/auth/status', 200 | view_func=user_view, 201 | methods=['GET'] 202 | ) 203 | auth_blueprint.add_url_rule( 204 | '/auth/logout', 205 | view_func=logout_view, 206 | methods=['POST'] 207 | ) 208 | -------------------------------------------------------------------------------- /project/server/config.py: -------------------------------------------------------------------------------- 1 | # project/server/config.py 2 | 3 | import os 4 | basedir = os.path.abspath(os.path.dirname(__file__)) 5 | postgres_local_base = 'postgresql://postgres:@localhost/' 6 | database_name = 'flask_jwt_auth' 7 | 8 | 9 | class BaseConfig: 10 | """Base configuration.""" 11 | SECRET_KEY = os.getenv('SECRET_KEY', 'my_precious') 12 | DEBUG = False 13 | BCRYPT_LOG_ROUNDS = 13 14 | SQLALCHEMY_TRACK_MODIFICATIONS = False 15 | 16 | 17 | class DevelopmentConfig(BaseConfig): 18 | """Development configuration.""" 19 | DEBUG = True 20 | BCRYPT_LOG_ROUNDS = 4 21 | SQLALCHEMY_DATABASE_URI = postgres_local_base + database_name 22 | 23 | 24 | class TestingConfig(BaseConfig): 25 | """Testing configuration.""" 26 | DEBUG = True 27 | TESTING = True 28 | BCRYPT_LOG_ROUNDS = 4 29 | SQLALCHEMY_DATABASE_URI = postgres_local_base + database_name + '_test' 30 | PRESERVE_CONTEXT_ON_EXCEPTION = False 31 | 32 | 33 | class ProductionConfig(BaseConfig): 34 | """Production configuration.""" 35 | 
SECRET_KEY = 'my_precious' 36 | DEBUG = False 37 | SQLALCHEMY_DATABASE_URI = 'postgresql:///example' 38 | -------------------------------------------------------------------------------- /project/server/models.py: -------------------------------------------------------------------------------- 1 | # project/server/models.py 2 | 3 | 4 | import jwt 5 | import datetime 6 | 7 | from project.server import app, db, bcrypt 8 | 9 | 10 | class User(db.Model): 11 | """ User Model for storing user related details """ 12 | __tablename__ = "users" 13 | 14 | id = db.Column(db.Integer, primary_key=True, autoincrement=True) 15 | email = db.Column(db.String(255), unique=True, nullable=False) 16 | password = db.Column(db.String(255), nullable=False) 17 | registered_on = db.Column(db.DateTime, nullable=False) 18 | admin = db.Column(db.Boolean, nullable=False, default=False) 19 | 20 | def __init__(self, email, password, admin=False): 21 | self.email = email 22 | self.password = bcrypt.generate_password_hash( 23 | password, app.config.get('BCRYPT_LOG_ROUNDS') 24 | ).decode() 25 | self.registered_on = datetime.datetime.now() 26 | self.admin = admin 27 | 28 | def encode_auth_token(self, user_id): 29 | """ 30 | Generates the Auth Token 31 | :return: string 32 | """ 33 | try: 34 | payload = { 35 | 'exp': datetime.datetime.utcnow() + datetime.timedelta(days=0, seconds=5), 36 | 'iat': datetime.datetime.utcnow(), 37 | 'sub': user_id 38 | } 39 | return jwt.encode( 40 | payload, 41 | app.config.get('SECRET_KEY'), 42 | algorithm='HS256' 43 | ) 44 | except Exception as e: 45 | return e 46 | 47 | @staticmethod 48 | def decode_auth_token(auth_token): 49 | """ 50 | Validates the auth token 51 | :param auth_token: 52 | :return: integer|string 53 | """ 54 | try: 55 | payload = jwt.decode(auth_token, app.config.get('SECRET_KEY')) 56 | is_blacklisted_token = BlacklistToken.check_blacklist(auth_token) 57 | if is_blacklisted_token: 58 | return 'Token blacklisted. Please log in again.' 
59 | else: 60 | return payload['sub'] 61 | except jwt.ExpiredSignatureError: 62 | return 'Signature expired. Please log in again.' 63 | except jwt.InvalidTokenError: 64 | return 'Invalid token. Please log in again.' 65 | 66 | 67 | class BlacklistToken(db.Model): 68 | """ 69 | Token Model for storing JWT tokens 70 | """ 71 | __tablename__ = 'blacklist_tokens' 72 | 73 | id = db.Column(db.Integer, primary_key=True, autoincrement=True) 74 | token = db.Column(db.String(500), unique=True, nullable=False) 75 | blacklisted_on = db.Column(db.DateTime, nullable=False) 76 | 77 | def __init__(self, token): 78 | self.token = token 79 | self.blacklisted_on = datetime.datetime.now() 80 | 81 | def __repr__(self): 82 | return '