├── tests ├── __init__.py ├── test_config.py ├── test_forms.py ├── test_functional.py └── test_models.py ├── project ├── main │ ├── __init__.py │ └── views.py ├── user │ ├── __init__.py │ ├── forms.py │ └── views.py ├── static │ ├── main.js │ └── main.css ├── templates │ ├── main │ │ └── index.html │ ├── errors │ │ ├── 403.html │ │ ├── 404.html │ │ └── 500.html │ ├── user │ │ ├── login.html │ │ ├── profile.html │ │ └── register.html │ ├── navigation.html │ └── _base.html ├── util.py ├── models.py ├── config.py └── __init__.py ├── .gitignore ├── .travis.yml ├── requirements.txt ├── readme.md ├── LICENSE └── manage.py /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /project/main/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /project/user/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /project/static/main.js: -------------------------------------------------------------------------------- 1 | // custom javascript -------------------------------------------------------------------------------- /project/static/main.css: -------------------------------------------------------------------------------- 1 | /* custom styles */ 2 | 3 | body { 4 | padding-top: 70px; 5 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.sqlite 3 | .DS_Store 4 | .coverage 5 | env 6 | migrations 7 | tmp 8 | __pycache__ 9 | env.sh -------------------------------------------------------------------------------- /.travis.yml: 
-------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.9" 4 | - "3.8" 5 | install: 6 | - pip install -r requirements.txt 7 | script: 8 | - python manage.py test 9 | -------------------------------------------------------------------------------- /project/templates/main/index.html: -------------------------------------------------------------------------------- 1 | {% extends "_base.html" %} 2 | {% block content %} 3 | 4 |

Welcome!

5 |
6 |

Logout 7 | 8 | {% endblock %} -------------------------------------------------------------------------------- /project/templates/errors/403.html: -------------------------------------------------------------------------------- 1 | {% extends "_base.html" %} 2 | {% block content %} 3 |

403

4 |

Run along!

5 |

Return Home?

6 | {% endblock %} 7 | -------------------------------------------------------------------------------- /project/templates/errors/404.html: -------------------------------------------------------------------------------- 1 | {% extends "_base.html" %} 2 | {% block content %} 3 |

404

4 |

There's nothing here!

5 |

Return Home?

6 | {% endblock %} 7 | -------------------------------------------------------------------------------- /project/templates/errors/500.html: -------------------------------------------------------------------------------- 1 | {% extends "_base.html" %} 2 | {% block content %} 3 |

500

4 |

Something's wrong! We are on the job.

5 |

Return Home?

6 | {% endblock %} 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | coverage==5.5 2 | email-validator==1.1.2 3 | Flask==1.1.2 4 | Flask-Bcrypt==0.7.1 5 | Flask-DebugToolbar==0.11.0 6 | Flask-Login==0.5.0 7 | Flask-Mail==0.9.1 8 | Flask-Migrate==2.7.0 9 | Flask-Script==2.0.6 10 | Flask-SQLAlchemy==2.4.4 11 | Flask-Testing==0.8.1 12 | Flask-WTF==0.14.3 13 | itsdangerous==2.0.1 14 | SQLAlchemy==1.3.23 15 | wtforms==2.3.3 16 | -------------------------------------------------------------------------------- /project/main/views.py: -------------------------------------------------------------------------------- 1 | # project/main/views.py 2 | 3 | 4 | ################# 5 | #### imports #### 6 | ################# 7 | 8 | from flask import render_template, Blueprint 9 | from flask_login import login_required 10 | 11 | 12 | ################ 13 | #### config #### 14 | ################ 15 | 16 | main_blueprint = Blueprint('main', __name__,) 17 | 18 | 19 | ################ 20 | #### routes #### 21 | ################ 22 | 23 | @main_blueprint.route('/') 24 | @login_required 25 | def home(): 26 | return render_template('main/index.html') 27 | -------------------------------------------------------------------------------- /project/util.py: -------------------------------------------------------------------------------- 1 | # project/util.py 2 | 3 | 4 | from flask_testing import TestCase 5 | 6 | from project import app, db 7 | from project.models import User 8 | 9 | 10 | class BaseTestCase(TestCase): 11 | 12 | def create_app(self): 13 | app.config.from_object('project.config.TestingConfig') 14 | return app 15 | 16 | def setUp(self): 17 | db.create_all() 18 | user = User(email="ad@min.com", password="admin_user", paid=False) 19 | db.session.add(user) 20 | db.session.commit() 21 | 22 | def tearDown(self): 23 | db.session.remove() 24 | 
db.drop_all() 25 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | # tests/test_config.py 2 | 3 | 4 | import unittest 5 | 6 | from flask import current_app 7 | from flask_testing import TestCase 8 | 9 | from project import app 10 | 11 | 12 | class TestTestingConfig(TestCase): 13 | 14 | def create_app(self): 15 | app.config.from_object('project.config.TestingConfig') 16 | return app 17 | 18 | def test_app_is_testing(self): 19 | self.assertTrue(current_app.config['TESTING']) 20 | self.assertTrue(app.config['DEBUG'] is True) 21 | self.assertTrue(app.config['BCRYPT_LOG_ROUNDS'] == 1) 22 | self.assertTrue(app.config['WTF_CSRF_ENABLED'] is False) 23 | 24 | def test_app_exists(self): 25 | self.assertFalse(current_app is None) 26 | 27 | 28 | if __name__ == '__main__': 29 | unittest.main() 30 | -------------------------------------------------------------------------------- /project/templates/user/login.html: -------------------------------------------------------------------------------- 1 | {% extends "_base.html" %} 2 | 3 | {% block content %} 4 | 5 |

Please login

6 |
7 |
8 | {{ form.csrf_token }} 9 |

10 | {{ form.email(placeholder="email") }} 11 | 12 | {% if form.email.errors %} 13 | {% for error in form.email.errors %} 14 | {{ error }} 15 | {% endfor %} 16 | {% endif %} 17 | 18 |

19 |

20 | {{ form.password(placeholder="password") }} 21 | 22 | {% if form.password.errors %} 23 | {% for error in form.password.errors %} 24 | {{ error }} 25 | {% endfor %} 26 | {% endif %} 27 | 28 |

29 | 30 |

31 |

Need to Register?

32 |
33 | 34 | {% endblock %} -------------------------------------------------------------------------------- /project/templates/user/profile.html: -------------------------------------------------------------------------------- 1 | {% extends "_base.html" %} 2 | 3 | {% block content %} 4 | 5 |

Your Profile

6 |
7 | 8 | {% if current_user.is_authenticated() %} 9 |

Email: {{current_user.email}}

10 | {% endif %} 11 | 12 |

Change Password

13 |
14 |
15 | {{ form.csrf_token }} 16 |

17 | {{ form.password(placeholder="password") }} 18 | 19 | {% if form.password.errors %} 20 | {% for error in form.password.errors %} 21 | {{ error }} 22 | {% endfor %} 23 | {% endif %} 24 | 25 |

26 |

27 | {{ form.confirm(placeholder="confirm") }} 28 | 29 | {% if form.confirm.errors %} 30 | {% for error in form.confirm.errors %} 31 | {{ error }} 32 | {% endfor %} 33 | {% endif %} 34 | 35 |

36 | 37 |
38 | 39 | 40 | {% endblock %} -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Flask User Management 2 | 3 | [![Build Status](https://travis-ci.org/mjhea0/flask-basic-registration.svg?branch=master)](https://travis-ci.org/mjhea0/flask-basic-registration) 4 | 5 | Starter app for managing users - login/logout and registration. 6 | 7 | ## QuickStart 8 | 9 | ### Set Environment Variables 10 | 11 | ```sh 12 | $ export APP_SETTINGS="project.config.DevelopmentConfig" 13 | ``` 14 | 15 | or 16 | 17 | ```sh 18 | $ export APP_SETTINGS="project.config.ProductionConfig" 19 | ``` 20 | 21 | ### Update Settings in Production 22 | 23 | 1. `SECRET_KEY` 24 | 1. `SQLALCHEMY_DATABASE_URI` 25 | 26 | ### Create DB 27 | 28 | ```sh 29 | $ python manage.py create_db 30 | $ python manage.py db init 31 | $ python manage.py db migrate 32 | $ python manage.py create_admin 33 | ``` 34 | 35 | ### Run 36 | 37 | ```sh 38 | $ python manage.py runserver 39 | ``` 40 | 41 | ### Testing 42 | 43 | Without coverage: 44 | 45 | ```sh 46 | $ python manage.py test 47 | ``` 48 | 49 | With coverage: 50 | 51 | ```sh 52 | $ python manage.py cov 53 | ``` 54 | -------------------------------------------------------------------------------- /project/models.py: -------------------------------------------------------------------------------- 1 | # project/models.py 2 | 3 | 4 | import datetime 5 | 6 | from project import db, bcrypt 7 | 8 | 9 | class User(db.Model): 10 | 11 | __tablename__ = "users" 12 | 13 | id = db.Column(db.Integer, primary_key=True) 14 | email = db.Column(db.String, unique=True, nullable=False) 15 | password = db.Column(db.String, nullable=False) 16 | registered_on = db.Column(db.DateTime, nullable=False) 17 | admin = db.Column(db.Boolean, nullable=False, default=False) 18 | 19 | def __init__(self, email, password, paid=False, admin=False): 20 | self.email = email 21 | 
self.password = bcrypt.generate_password_hash(password) 22 | self.registered_on = datetime.datetime.now() 23 | self.admin = admin 24 | 25 | def is_authenticated(self): 26 | return True 27 | 28 | def is_active(self): 29 | return True 30 | 31 | def is_anonymous(self): 32 | return False 33 | 34 | def get_id(self): 35 | return self.id 36 | 37 | def __repr__(self): 38 | return 'Please Register 6 |
7 |
8 | {{ form.csrf_token }} 9 | {{ form.email(placeholder="email") }} 10 | 11 | {% if form.email.errors %} 12 | {% for error in form.email.errors %} 13 | {{ error }} 14 | {% endfor %} 15 | {% endif %} 16 | 17 |

18 |

19 | {{ form.password(placeholder="password") }} 20 | 21 | {% if form.password.errors %} 22 | {% for error in form.password.errors %} 23 | {{ error }} 24 | {% endfor %} 25 | {% endif %} 26 | 27 |

28 |

29 | {{ form.confirm(placeholder="confirm") }} 30 | 31 | {% if form.confirm.errors %} 32 | {% for error in form.confirm.errors %} 33 | {{ error }} 34 | {% endfor %} 35 | {% endif %} 36 | 37 |

38 | 39 |

40 |

Already have an account? Sign in.

41 |
42 | 43 | {% endblock %} -------------------------------------------------------------------------------- /project/config.py: -------------------------------------------------------------------------------- 1 | # project/config.py 2 | 3 | import os 4 | basedir = os.path.abspath(os.path.dirname(__file__)) 5 | 6 | 7 | class BaseConfig(object): 8 | """Base configuration.""" 9 | SECRET_KEY = 'my_precious' 10 | DEBUG = False 11 | BCRYPT_LOG_ROUNDS = 13 12 | WTF_CSRF_ENABLED = True 13 | DEBUG_TB_ENABLED = False 14 | DEBUG_TB_INTERCEPT_REDIRECTS = False 15 | SQLALCHEMY_TRACK_MODIFICATIONS = False 16 | 17 | 18 | class DevelopmentConfig(BaseConfig): 19 | """Development configuration.""" 20 | DEBUG = True 21 | WTF_CSRF_ENABLED = False 22 | SQLALCHEMY_DATABASE_URI = 'sqlite:///' + os.path.join(basedir, 'dev.sqlite') 23 | DEBUG_TB_ENABLED = True 24 | 25 | 26 | class TestingConfig(BaseConfig): 27 | """Testing configuration.""" 28 | TESTING = True 29 | DEBUG = True 30 | BCRYPT_LOG_ROUNDS = 1 31 | WTF_CSRF_ENABLED = False 32 | SQLALCHEMY_DATABASE_URI = 'sqlite://' 33 | 34 | 35 | class ProductionConfig(BaseConfig): 36 | """Production configuration.""" 37 | SECRET_KEY = 'my_precious' 38 | DEBUG = False 39 | SQLALCHEMY_DATABASE_URI = 'postgresql://localhost/example' 40 | DEBUG_TB_ENABLED = False 41 | STRIPE_SECRET_KEY = 'foo' 42 | STRIPE_PUBLISHABLE_KEY = 'bar' 43 | -------------------------------------------------------------------------------- /project/templates/navigation.html: -------------------------------------------------------------------------------- 1 | 2 | 34 | -------------------------------------------------------------------------------- /tests/test_forms.py: -------------------------------------------------------------------------------- 1 | # tests/test_forms.py 2 | 3 | 4 | import unittest 5 | 6 | from project.util import BaseTestCase 7 | from project.user.forms import RegisterForm, LoginForm 8 | 9 | 10 | class TestRegisterForm(BaseTestCase): 11 | 12 | def 
test_validate_success_register_form(self): 13 | # Ensure correct data validates. 14 | form = RegisterForm( 15 | email='new@test.test', 16 | password='example', confirm='example') 17 | self.assertTrue(form.validate()) 18 | 19 | def test_validate_invalid_password_format(self): 20 | # Ensure incorrect data does not validate. 21 | form = RegisterForm( 22 | email='new@test.test', 23 | password='example', confirm='') 24 | self.assertFalse(form.validate()) 25 | 26 | def test_validate_email_already_registered(self): 27 | # Ensure user can't register when a duplicate email is used 28 | form = RegisterForm( 29 | email='ad@min.com', 30 | password='admin_user', 31 | confirm='admin_user' 32 | ) 33 | self.assertFalse(form.validate()) 34 | 35 | 36 | class TestLoginForm(BaseTestCase): 37 | 38 | def test_validate_success_login_form(self): 39 | # Ensure correct data validates. 40 | form = LoginForm(email='ad@min.com', password='admin_user') 41 | self.assertTrue(form.validate()) 42 | 43 | def test_validate_invalid_email_format(self): 44 | # Ensure invalid email format throws error. 
45 | form = LoginForm(email='unknown', password='example') 46 | self.assertFalse(form.validate()) 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /project/user/forms.py: -------------------------------------------------------------------------------- 1 | # project/user/forms.py 2 | 3 | 4 | from flask_wtf import FlaskForm 5 | from wtforms import TextField, PasswordField 6 | from wtforms.validators import DataRequired, Email, Length, EqualTo 7 | 8 | from project.models import User 9 | 10 | 11 | class LoginForm(FlaskForm): 12 | email = TextField('email', validators=[DataRequired(), Email()]) 13 | password = PasswordField('password', validators=[DataRequired()]) 14 | 15 | 16 | class RegisterForm(FlaskForm): 17 | email = TextField( 18 | 'email', 19 | validators=[DataRequired(), Email(message=None), Length(min=6, max=40)]) 20 | password = PasswordField( 21 | 'password', 22 | validators=[DataRequired(), Length(min=6, max=25)] 23 | ) 24 | confirm = PasswordField( 25 | 'Repeat password', 26 | validators=[ 27 | DataRequired(), 28 | EqualTo('password', message='Passwords must match.') 29 | ] 30 | ) 31 | 32 | def validate(self): 33 | initial_validation = super(RegisterForm, self).validate() 34 | if not initial_validation: 35 | return False 36 | user = User.query.filter_by(email=self.email.data).first() 37 | if user: 38 | self.email.errors.append("Email already registered") 39 | return False 40 | return True 41 | 42 | 43 | class ChangePasswordForm(FlaskForm): 44 | password = PasswordField( 45 | 'password', 46 | validators=[DataRequired(), Length(min=6, max=25)] 47 | ) 48 | confirm = PasswordField( 49 | 'Repeat password', 50 | validators=[ 51 | DataRequired(), 52 | EqualTo('password', message='Passwords must match.') 53 | ] 54 | ) 55 | -------------------------------------------------------------------------------- /manage.py: 
-------------------------------------------------------------------------------- 1 | # manage.py 2 | 3 | 4 | import os 5 | import unittest 6 | 7 | import coverage 8 | from flask_script import Manager 9 | from flask_migrate import Migrate, MigrateCommand 10 | 11 | from project import app, db 12 | from project.models import User 13 | 14 | 15 | app.config.from_object(os.environ['APP_SETTINGS']) 16 | 17 | migrate = Migrate(app, db) 18 | manager = Manager(app) 19 | 20 | # migrations 21 | manager.add_command('db', MigrateCommand) 22 | 23 | 24 | @manager.command 25 | def test(): 26 | """Runs the unit tests without coverage.""" 27 | tests = unittest.TestLoader().discover('tests') 28 | result = unittest.TextTestRunner(verbosity=2).run(tests) 29 | if result.wasSuccessful(): 30 | return 0 31 | else: 32 | return 1 33 | 34 | 35 | @manager.command 36 | def cov(): 37 | """Runs the unit tests with coverage.""" 38 | cov = coverage.coverage(branch=True, include='project/*') 39 | cov.start() 40 | tests = unittest.TestLoader().discover('tests') 41 | unittest.TextTestRunner(verbosity=2).run(tests) 42 | cov.stop() 43 | cov.save() 44 | print('Coverage Summary:') 45 | cov.report() 46 | basedir = os.path.abspath(os.path.dirname(__file__)) 47 | covdir = os.path.join(basedir, 'tmp/coverage') 48 | cov.html_report(directory=covdir) 49 | print('HTML version: file://%s/index.html' % covdir) 50 | cov.erase() 51 | 52 | 53 | @manager.command 54 | def create_db(): 55 | """Creates the db tables.""" 56 | db.create_all() 57 | 58 | 59 | @manager.command 60 | def drop_db(): 61 | """Drops the db tables.""" 62 | db.drop_all() 63 | 64 | 65 | @manager.command 66 | def create_admin(): 67 | """Creates the admin user.""" 68 | db.session.add(User("ad@min.com", "admin")) 69 | db.session.commit() 70 | 71 | 72 | if __name__ == "__main__": 73 | manager.run() 74 | -------------------------------------------------------------------------------- /project/__init__.py: 
-------------------------------------------------------------------------------- 1 | # project/__init__.py 2 | 3 | 4 | ################# 5 | #### imports #### 6 | ################# 7 | 8 | import os 9 | 10 | from flask import Flask, render_template 11 | from flask_login import LoginManager 12 | from flask_bcrypt import Bcrypt 13 | from flask_mail import Mail 14 | from flask_debugtoolbar import DebugToolbarExtension 15 | from flask_sqlalchemy import SQLAlchemy 16 | 17 | 18 | ################ 19 | #### config #### 20 | ################ 21 | 22 | app = Flask(__name__) 23 | 24 | app.config.from_object(os.environ['APP_SETTINGS']) 25 | 26 | #################### 27 | #### extensions #### 28 | #################### 29 | 30 | login_manager = LoginManager() 31 | login_manager.init_app(app) 32 | bcrypt = Bcrypt(app) 33 | mail = Mail(app) 34 | toolbar = DebugToolbarExtension(app) 35 | db = SQLAlchemy(app) 36 | 37 | 38 | #################### 39 | #### blueprints #### 40 | #################### 41 | 42 | from project.main.views import main_blueprint 43 | from project.user.views import user_blueprint 44 | app.register_blueprint(main_blueprint) 45 | app.register_blueprint(user_blueprint) 46 | 47 | 48 | #################### 49 | #### flask-login #### 50 | #################### 51 | 52 | from project.models import User 53 | 54 | login_manager.login_view = "user.login" 55 | login_manager.login_message_category = "danger" 56 | 57 | 58 | @login_manager.user_loader 59 | def load_user(user_id): 60 | return User.query.filter(User.id == int(user_id)).first() 61 | 62 | 63 | ######################## 64 | #### error handlers #### 65 | ######################## 66 | 67 | @app.errorhandler(403) 68 | def forbidden_page(error): 69 | return render_template("errors/403.html"), 403 70 | 71 | 72 | @app.errorhandler(404) 73 | def page_not_found(error): 74 | return render_template("errors/404.html"), 404 75 | 76 | 77 | @app.errorhandler(500) 78 | def server_error_page(error): 79 | return 
render_template("errors/500.html"), 500 80 | -------------------------------------------------------------------------------- /project/templates/_base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Flask User Management 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | {% block css %}{% endblock %} 14 | 15 | 16 | 17 | {% include "navigation.html" %} 18 | 19 |
20 | 21 |
22 | 23 | 24 | {% with messages = get_flashed_messages(with_categories=true) %} 25 | {% if messages %} 26 |
27 |
28 | {% for category, message in messages %} 29 |
30 | × 31 | {{message}} 32 |
33 | {% endfor %} 34 |
35 |
36 | {% endif %} 37 | {% endwith %} 38 | 39 | 40 | {% block content %}{% endblock %} 41 | 42 |
43 | 44 | 45 | {% if error %} 46 |

Error: {{ error }}

47 | {% endif %} 48 | 49 |
50 | 51 | 52 | 53 | 54 | 55 | {% block js %}{% endblock %} 56 | 57 | 58 | -------------------------------------------------------------------------------- /tests/test_functional.py: -------------------------------------------------------------------------------- 1 | # tests/test_functional.py 2 | 3 | 4 | import unittest 5 | 6 | from flask_login import current_user 7 | 8 | from project.util import BaseTestCase 9 | 10 | 11 | class TestPublic(BaseTestCase): 12 | 13 | def test_main_route_requires_login(self): 14 | # Ensure main route requires logged in user. 15 | response = self.client.get('/', follow_redirects=True) 16 | self.assertTrue(response.status_code == 200) 17 | self.assertIn(b'Please log in to access this page', response.data) 18 | 19 | def test_logout_route_requires_login(self): 20 | # Ensure logout route requires logged in user. 21 | response = self.client.get('/logout', follow_redirects=True) 22 | self.assertIn(b'Please log in to access this page', response.data) 23 | 24 | 25 | class TestLoggingInOut(BaseTestCase): 26 | 27 | def test_correct_login(self): 28 | # Ensure login behaves correctly with correct credentials 29 | with self.client: 30 | response = self.client.post( 31 | '/login', 32 | data=dict(email="ad@min.com", password="admin_user"), 33 | follow_redirects=True 34 | ) 35 | self.assertIn(b'Welcome', response.data) 36 | self.assertTrue(current_user.email == "ad@min.com") 37 | self.assertTrue(current_user.is_active()) 38 | self.assertTrue(response.status_code == 200) 39 | 40 | def test_logout_behaves_correctly(self): 41 | # Ensure logout behaves correctly, regarding the session 42 | with self.client: 43 | self.client.post( 44 | '/login', 45 | data=dict(email="ad@min.com", password="admin_user"), 46 | follow_redirects=True 47 | ) 48 | response = self.client.get('/logout', follow_redirects=True) 49 | self.assertIn(b'You were logged out.', response.data) 50 | self.assertFalse(current_user.is_active) 51 | 52 | 53 | if __name__ == '__main__': 54 |
unittest.main() 55 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | # tests/test_models.py 2 | 3 | 4 | import datetime 5 | import unittest 6 | 7 | from flask_login import current_user 8 | 9 | from project import bcrypt 10 | from project.util import BaseTestCase 11 | from project.models import User 12 | 13 | 14 | class TestUser(BaseTestCase): 15 | 16 | def test_user_registration(self): 17 | # Ensure user registration behaves correctly. 18 | with self.client: 19 | self.client.post('/register', data=dict( 20 | email='test@user.com', 21 | password='test_user', confirm='test_user' 22 | ), follow_redirects=True) 23 | user = User.query.filter_by(email='test@user.com').first() 24 | self.assertTrue(user.id) 25 | self.assertTrue(user.email == 'test@user.com') 26 | self.assertFalse(user.admin) 27 | 28 | def test_get_by_id(self): 29 | # Ensure id is correct for the current/logged in user 30 | with self.client: 31 | self.client.post('/login', data=dict( 32 | email='ad@min.com', password='admin_user' 33 | ), follow_redirects=True) 34 | self.assertTrue(current_user.id == 1) 35 | 36 | def test_registered_on_defaults_to_datetime(self): 37 | # Ensure that registered_on is a datetime 38 | with self.client: 39 | self.client.post('/login', data=dict( 40 | email='ad@min.com', password='admin_user' 41 | ), follow_redirects=True) 42 | user = User.query.filter_by(email='ad@min.com').first() 43 | self.assertIsInstance(user.registered_on, datetime.datetime) 44 | 45 | def test_check_password(self): 46 | # Ensure given password is correct after unhashing 47 | user = User.query.filter_by(email='ad@min.com').first() 48 | self.assertTrue(bcrypt.check_password_hash(user.password, 'admin_user')) 49 | self.assertFalse(bcrypt.check_password_hash(user.password, 'foobar')) 50 | 51 | def test_validate_invalid_password(self): 52 | # Ensure user can't login when the 
password is incorrect 53 | with self.client: 54 | response = self.client.post('/login', data=dict( 55 | email='ad@min.com', password='foo_bar' 56 | ), follow_redirects=True) 57 | self.assertIn(b'Invalid email and/or password.', response.data) 58 | 59 | 60 | if __name__ == '__main__': 61 | unittest.main() 62 | -------------------------------------------------------------------------------- /project/user/views.py: -------------------------------------------------------------------------------- 1 | # project/user/views.py 2 | 3 | 4 | ################# 5 | #### imports #### 6 | ################# 7 | 8 | from flask import render_template, Blueprint, url_for, \ 9 | redirect, flash, request 10 | from flask_login import login_user, logout_user, \ 11 | login_required, current_user 12 | 13 | from project.models import User 14 | # from project.email import send_email 15 | from project import db, bcrypt 16 | from .forms import LoginForm, RegisterForm, ChangePasswordForm 17 | 18 | 19 | ################ 20 | #### config #### 21 | ################ 22 | 23 | user_blueprint = Blueprint('user', __name__,) 24 | 25 | 26 | ################ 27 | #### routes #### 28 | ################ 29 | 30 | @user_blueprint.route('/register', methods=['GET', 'POST']) 31 | def register(): 32 | form = RegisterForm(request.form) 33 | if form.validate_on_submit(): 34 | user = User( 35 | email=form.email.data, 36 | password=form.password.data 37 | ) 38 | db.session.add(user) 39 | db.session.commit() 40 | 41 | login_user(user) 42 | flash('You registered and are now logged in.
Welcome!', 'success') 43 | 44 | return redirect(url_for('main.home')) 45 | 46 | return render_template('user/register.html', form=form) 47 | 48 | 49 | @user_blueprint.route('/login', methods=['GET', 'POST']) 50 | def login(): 51 | form = LoginForm(request.form) 52 | if form.validate_on_submit(): 53 | user = User.query.filter_by(email=form.email.data).first() 54 | if user and bcrypt.check_password_hash( 55 | user.password, request.form['password']): 56 | login_user(user) 57 | flash('Welcome.', 'success') 58 | return redirect(url_for('main.home')) 59 | else: 60 | flash('Invalid email and/or password.', 'danger') 61 | return render_template('user/login.html', form=form) 62 | return render_template('user/login.html', form=form) 63 | 64 | 65 | @user_blueprint.route('/logout') 66 | @login_required 67 | def logout(): 68 | logout_user() 69 | flash('You were logged out.', 'success') 70 | return redirect(url_for('user.login')) 71 | 72 | 73 | @user_blueprint.route('/profile', methods=['GET', 'POST']) 74 | @login_required 75 | def profile(): 76 | form = ChangePasswordForm(request.form) 77 | if form.validate_on_submit(): 78 | user = User.query.filter_by(email=current_user.email).first() 79 | if user: 80 | user.password = bcrypt.generate_password_hash(form.password.data) 81 | db.session.commit() 82 | flash('Password successfully changed.', 'success') 83 | return redirect(url_for('user.profile')) 84 | else: 85 | flash('Password change was unsuccessful.', 'danger') 86 | return redirect(url_for('user.profile')) 87 | return render_template('user/profile.html', form=form) 88 | --------------------------------------------------------------------------------