├── .gitignore
├── .travis.yml
├── manage.py
├── project
├── __init__.py
├── config_sample.py
├── main
│ ├── __init__.py
│ └── views.py
├── models.py
├── static
│ ├── main.css
│ ├── main.js
│ └── member.js
├── templates
│ ├── _base.html
│ ├── errors
│ │ ├── 401.html
│ │ ├── 403.html
│ │ ├── 404.html
│ │ └── 500.html
│ ├── header.html
│ ├── main
│ │ └── home.html
│ └── user
│ │ ├── login.html
│ │ ├── members.html
│ │ └── register.html
├── user
│ ├── __init__.py
│ ├── forms.py
│ └── views.py
└── util.py
├── readme.md
├── requirements.txt
└── tests
├── __init__.py
├── base.py
├── helpers.py
├── test_config.py
├── test_main.py
├── test_models.py
└── test_user.py
/.gitignore:
--------------------------------------------------------------------------------
1 | env
2 | temp
3 | tmp
4 | migrations
5 |
6 | *.pyc
7 | *.sqlite
8 | *.coverage
9 | .DS_Store
10 | config.py
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
language: python
python:
  - "2.7"
# project/__init__.py reads os.environ['APP_SETTINGS'] at import time, so the
# variable must name a class that exists in the config module.
env:
  - APP_SETTINGS=project.config.DevelopmentConfig
install:
  - pip install -r requirements.txt
# project/config.py is git-ignored (see .gitignore), so build it from the
# committed sample before the tests import it.
before_script:
  - cp project/config_sample.py project/config.py
script: nosetests
--------------------------------------------------------------------------------
/manage.py:
--------------------------------------------------------------------------------
1 | # manage.py
2 |
3 |
4 | import os
5 | import unittest
6 | import coverage
7 |
8 | from flask.ext.script import Manager
9 | from flask.ext.migrate import Migrate, MigrateCommand
10 |
11 | from project import app, db
12 | from project.models import User
13 |
14 |
# Load settings from the class named by APP_SETTINGS
# (e.g. "project.config.DevelopmentConfig").
# NOTE(review): project/__init__.py already applies the same object when
# `project` is imported above, so this re-applies identical settings.
app.config.from_object(os.environ['APP_SETTINGS'])

# Bind Flask-Migrate to the app/db pair used by the `db` sub-commands.
migrate = Migrate(app, db)
# Flask-Script manager; the @manager.command functions below register on it.
manager = Manager(app)

# migrations: expose Flask-Migrate's commands as `python manage.py db ...`
manager.add_command('db', MigrateCommand)
22 |
23 |
@manager.command
def test():
    """Runs the unit tests without coverage."""
    # The docstring above doubles as Flask-Script's help text, so details
    # live here instead.
    # Discovers every test*.py module under tests/ and runs it verbosely.
    # Returns 0 when all tests pass and 1 otherwise. Flask-Script's
    # Manager.run exits with the command's return value, so failures become
    # a non-zero process exit status that CI can detect.
    tests = unittest.TestLoader().discover('tests')
    result = unittest.TextTestRunner(verbosity=2).run(tests)
    return 0 if result.wasSuccessful() else 1
29 |
30 |
@manager.command
def cov():
    """Runs the unit tests with coverage."""
    # Measures branch coverage of project/* (excluding __init__.py files),
    # prints a summary, writes an HTML report to tmp/coverage next to this
    # file, then erases the collected data.
    # Returns 0 when all tests pass and 1 otherwise (the exit status).
    #
    # NOTE(review): `project` is imported at the top of manage.py, before
    # cov.start() runs here, so module-level lines executed at import time
    # are reported as missed. Starting coverage before those imports would
    # give accurate numbers.
    cov = coverage.coverage(
        branch=True,
        include='project/*',
        omit=['*/__init__.py']
    )
    cov.start()
    tests = unittest.TestLoader().discover('tests')
    result = unittest.TextTestRunner(verbosity=2).run(tests)
    cov.stop()
    cov.save()
    # print() with one argument prints the same text on Python 2 and 3;
    # the `print '...'` statement form is a SyntaxError on Python 3.
    print('Coverage Summary:')
    cov.report()
    basedir = os.path.abspath(os.path.dirname(__file__))
    covdir = os.path.join(basedir, 'tmp/coverage')
    cov.html_report(directory=covdir)
    print('HTML version: file://%s/index.html' % covdir)
    cov.erase()
    return 0 if result.wasSuccessful() else 1
51 |
52 |
@manager.command
def create_db():
    """Creates the db tables."""
    # Creates a table for every model registered on `db`; SQLAlchemy's
    # create_all skips tables that already exist.
    db.create_all()
57 |
58 |
@manager.command
def drop_db():
    """Drops the db tables."""
    # Destructive: drops every table registered on `db`, data included.
    db.drop_all()
63 |
64 |
@manager.command
def create_data():
    """Creates sample data."""
    # Placeholder command: no sample data is generated yet.
69 |
70 |
@manager.command
def create_admin(email="ad@min.com", password="admin"):
    """Creates the admin user."""
    # Flask-Script turns these keyword arguments into optional --email and
    # --password command-line options, so a deployment need not keep the
    # well-known default credentials:
    #     python manage.py create_admin --email=me@example.com --password=...
    # With no options the command behaves exactly as before.
    db.session.add(User(email=email, password=password, admin=True))
    db.session.commit()
76 |
77 |
if __name__ == "__main__":
    # Dispatch to the command named on the command line.
    manager.run()
80 |
--------------------------------------------------------------------------------
/project/__init__.py:
--------------------------------------------------------------------------------
1 | # project/__init__.py
2 |
3 |
4 | #################
5 | #### imports ####
6 | #################
7 |
8 | import os
9 |
10 | from flask import Flask, render_template
11 | from flask.ext.login import LoginManager
12 | from flask.ext.bcrypt import Bcrypt
13 | from flask_mail import Mail
14 | from flask.ext.debugtoolbar import DebugToolbarExtension
15 | from flask_bootstrap import Bootstrap
16 | from flask.ext.sqlalchemy import SQLAlchemy
17 |
18 |
19 | ################
20 | #### config ####
21 | ################
22 |
# Create the application and load settings from the class named by the
# APP_SETTINGS environment variable (e.g. "project.config.DevelopmentConfig").
app = Flask(__name__)
app.config.from_object(os.environ['APP_SETTINGS'])

# Stripe credentials taken from the loaded config; the selected config class
# must define both keys. Imported by project.user.views.
stripe_keys = {
    'stripe_secret_key': app.config['STRIPE_SECRET_KEY'],
    'stripe_publishable_key': app.config['STRIPE_PUBLISHABLE_KEY']
}
30 |
31 |
32 | ####################
33 | #### extensions ####
34 | ####################
35 |
# Every extension is bound to the already-configured app (extensions may read
# config values when initialised, so the config must be loaded first).
# `db` and `bcrypt` are imported from `project` by project.models and
# project.user.views.
login_manager = LoginManager()
login_manager.init_app(app)
bcrypt = Bcrypt(app)
mail = Mail(app)
toolbar = DebugToolbarExtension(app)
bootstrap = Bootstrap(app)
db = SQLAlchemy(app)
43 |
44 |
45 | ####################
46 | #### blueprints ####
47 | ####################
48 |
# Imported here rather than at the top of the module: project.user.views
# does `from project import db, bcrypt, stripe_keys`, so those names must
# already exist when this import runs (circular import).
from project.main.views import main_blueprint
from project.user.views import user_blueprint
app.register_blueprint(main_blueprint)
app.register_blueprint(user_blueprint)
53 |
54 |
55 | ####################
56 | #### flask-login ####
57 | ####################
58 |
# Absolute import, matching the rest of the package (e.g. project/user/views.py).
# The implicit relative form `from models import User` only works on Python 2.
from project.models import User

# Endpoint that @login_required redirects anonymous users to, and the flash
# category used for its "please log in" message.
login_manager.login_view = "user.login"
login_manager.login_message_category = 'danger'
63 |
64 |
@login_manager.user_loader
def load_user(user_id):
    """Return the User whose id Flask-Login stored in the session.

    Args:
        user_id: the text id produced by ``User.get_id()``.

    Returns:
        The matching User, or None when ``user_id`` is not a valid integer
        or no such user exists. Flask-Login expects None, not an exception,
        for an unknown or invalid id.
    """
    try:
        user_pk = int(user_id)
    except (TypeError, ValueError):
        return None
    return User.query.filter(User.id == user_pk).first()
68 |
69 |
70 | ########################
71 | #### error handlers ####
72 | ########################
73 |
@app.errorhandler(403)
def forbidden_page(error):
    """Serve the custom 403 page; the error object itself is unused."""
    page = render_template("errors/403.html")
    return page, 403
77 |
78 |
@app.errorhandler(404)
def page_not_found(error):
    """Serve the custom 404 page; the error object itself is unused."""
    page = render_template("errors/404.html")
    return page, 404
82 |
83 |
@app.errorhandler(500)
def server_error_page(error):
    """Serve the custom 500 page; the error object itself is unused."""
    page = render_template("errors/500.html")
    return page, 500
87 |
--------------------------------------------------------------------------------
/project/config_sample.py:
--------------------------------------------------------------------------------
1 | # project/config.py
2 |
3 | import os
4 | basedir = os.path.abspath(os.path.dirname(__file__))
5 |
6 |
class BaseConfig(object):
    """Base configuration.

    Settings shared by every environment. Flask's ``config.from_object``
    copies only the upper-case attributes the class exposes (inherited ones
    included) and leaves other keys untouched. Any key an environment must
    reset when the same app is reconfigured, as the test suite does, has to
    be set explicitly here.
    """
    SECRET_KEY = 'my_precious'  # placeholder; change it in your config.py
    DEBUG = False
    # Explicit so that switching back from a testing config clears the flag.
    TESTING = False
    BCRYPT_LOG_ROUNDS = 13
    WTF_CSRF_ENABLED = True
    DEBUG_TB_ENABLED = False
    DEBUG_TB_INTERCEPT_REDIRECTS = False
15 |
16 |
class DevelopmentConfig(BaseConfig):
    """Development configuration."""
    DEBUG = True
    BCRYPT_LOG_ROUNDS = 1  # fast hashing for local work
    WTF_CSRF_ENABLED = False
    SQLALCHEMY_DATABASE_URI = 'sqlite:///' + os.path.join(basedir, 'dev.sqlite')
    DEBUG_TB_ENABLED = True
    STRIPE_SECRET_KEY = 'test key'
    STRIPE_PUBLISHABLE_KEY = 'test key'


class TestingConfig(BaseConfig):
    """Testing configuration.

    Loaded by tests/base.py and tests/test_config.py as
    ``project.config.TestingConfig``; without it the test suite cannot
    configure the app.
    """
    TESTING = True
    DEBUG = True
    BCRYPT_LOG_ROUNDS = 1  # fast hashing; tests/test_config.py asserts 1
    WTF_CSRF_ENABLED = False  # the tests post forms without CSRF tokens
    SQLALCHEMY_DATABASE_URI = 'sqlite://'  # in-memory database
    DEBUG_TB_ENABLED = False
    # Placeholder Stripe keys so the config is complete.
    STRIPE_SECRET_KEY = 'test key'
    STRIPE_PUBLISHABLE_KEY = 'test key'
26 |
27 |
class ProductionConfig(BaseConfig):
    """Production configuration."""
    # Every value below is a placeholder: fill in real ones in your own
    # config.py, which .gitignore keeps out of version control.
    DEBUG = False
    DEBUG_TB_ENABLED = False
    SECRET_KEY = 'my_precious'
    SQLALCHEMY_DATABASE_URI = 'postgresql://localhost/example'
    STRIPE_SECRET_KEY = 'live key'
    STRIPE_PUBLISHABLE_KEY = 'live key'
36 |
--------------------------------------------------------------------------------
/project/main/__init__.py:
--------------------------------------------------------------------------------
1 | # project/main/__init__.py
2 |
--------------------------------------------------------------------------------
/project/main/views.py:
--------------------------------------------------------------------------------
1 | # project/main/views.py
2 |
3 |
4 | #################
5 | #### imports ####
6 | #################
7 |
8 | from flask import render_template, Blueprint
9 |
10 | ################
11 | #### config ####
12 | ################
13 |
# Endpoints in this module are addressed as "main.<view>", e.g. url_for('main.home').
main_blueprint = Blueprint('main', __name__)
15 |
16 |
17 | ################
18 | #### routes ####
19 | ################
20 |
21 |
@main_blueprint.route('/')
def home():
    """Render the public landing page."""
    template_name = 'main/home.html'
    return render_template(template_name)
25 |
--------------------------------------------------------------------------------
/project/models.py:
--------------------------------------------------------------------------------
1 | # project/models.py
2 |
3 |
4 | import datetime
5 |
6 | from project import db, bcrypt
7 |
8 |
9 | class User(db.Model):
10 |
    # Database table backing this model.
    __tablename__ = "users"

    id = db.Column(db.Integer, primary_key=True)
    # Login identifier; unique at the database level (RegisterForm.validate
    # also checks for an existing row before insert).
    email = db.Column(db.String, unique=True, nullable=False)
    # bcrypt hash computed in __init__, never the plain-text password.
    password = db.Column(db.String, nullable=False)
    # Stamped in __init__ with datetime.datetime.now() (naive local time).
    registered_on = db.Column(db.DateTime, nullable=False)
    # Set to True by the register view after a successful Stripe charge;
    # read by project.util.check_paid.
    paid = db.Column(db.Boolean, nullable=False, default=False)
    admin = db.Column(db.Boolean, nullable=False, default=False)
19 |
20 | def __init__(self, email, password, paid=False, admin=False):
21 | self.email = email
22 | self.password = bcrypt.generate_password_hash(password)
23 | self.registered_on = datetime.datetime.now()
24 | self.paid = paid
25 | self.admin = admin
26 |
27 | def is_authenticated(self):
28 | return True
29 |
30 | def is_active(self):
31 | return True
32 |
33 | def is_anonymous(self):
34 | return False
35 |
36 | def get_id(self):
37 | return unicode(self.id)
38 |
39 | def __repr__(self):
40 | return '').val(token));
32 | // and submit
33 | $form.get(0).submit();
34 | }
35 | }
36 |
37 | });
--------------------------------------------------------------------------------
/project/templates/_base.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Flask Paywall{% block title %}{% endblock %}
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | {% block css %}{% endblock %}
15 |
16 |
17 |
18 |
19 | {% include 'header.html' %}
20 |
21 |
22 |
23 |
24 |
25 | {% with messages = get_flashed_messages(with_categories=true) %}
26 | {% if messages %}
27 |
28 |
29 |
30 | {% for category, message in messages %}
31 |
32 |
×
33 | {{message}}
34 |
35 | {% endfor %}
36 |
37 |
38 | {% endif %}
39 | {% endwith %}
40 |
41 |
42 |
43 | {% block content %}{% endblock %}
44 |
45 |
46 |
47 |
48 | {% if error %}
49 |
Error: {{ error }}
50 | {% endif %}
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 | {% block js %}{% endblock %}
60 |
61 |
62 |
--------------------------------------------------------------------------------
/project/templates/errors/401.html:
--------------------------------------------------------------------------------
1 | {% extends "_base.html" %}
2 |
{# _base.html defines a "title" block (appended after "Flask Paywall"); a "page_title" override would never render. #}
{% block title %}- Unauthorized{% endblock %}
4 |
5 | {% block content %}
6 |
7 |
8 |
401
9 |
You are not authorized to view this page. Please log in.
10 |
11 |
12 | {% endblock %}
13 |
--------------------------------------------------------------------------------
/project/templates/errors/403.html:
--------------------------------------------------------------------------------
1 | {% extends "_base.html" %}
2 | {% block content %}
3 | 403
4 | Run along!
5 | Return Home?
6 | {% endblock %}
7 |
--------------------------------------------------------------------------------
/project/templates/errors/404.html:
--------------------------------------------------------------------------------
1 | {% extends "_base.html" %}
2 |
{# _base.html defines a "title" block (appended after "Flask Paywall"); a "page_title" override would never render. #}
{% block title %}- Page Not Found{% endblock %}
4 |
5 | {% block content %}
6 |
7 |
8 |
404
9 |
Sorry. The requested page doesn't exist. Go home.
10 |
11 |
12 | {% endblock %}
13 |
--------------------------------------------------------------------------------
/project/templates/errors/500.html:
--------------------------------------------------------------------------------
1 | {% extends "_base.html" %}
2 |
{# _base.html defines a "title" block (appended after "Flask Paywall"); a "page_title" override would never render. #}
{% block title %}- Server Error{% endblock %}
4 |
5 | {% block content %}
6 |
7 |
8 |
500
9 |
Sorry. Something went terribly wrong. Go home.
10 |
11 |
12 | {% endblock %}
13 |
--------------------------------------------------------------------------------
/project/templates/header.html:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/project/templates/main/home.html:
--------------------------------------------------------------------------------
1 | {% extends "_base.html" %}
2 |
3 | {% block content %}
4 |
5 |
6 |
Welcome to Flask Paywall!
7 |
Setup a paywall with Flask and Stripe to offer paid access to your premium content.
8 |
Sign up today
9 |
10 |
11 | {% endblock %}
--------------------------------------------------------------------------------
/project/templates/user/login.html:
--------------------------------------------------------------------------------
1 | {% extends '_base.html' %}
2 | {% import "bootstrap/wtf.html" as wtf %}
3 |
4 | {% block content %}
5 |
6 |
9 |
10 |
11 |
12 |
13 |
31 |
32 |
33 | {% endblock content %}
34 |
--------------------------------------------------------------------------------
/project/templates/user/members.html:
--------------------------------------------------------------------------------
1 | {% extends "_base.html" %}
2 |
3 | {% block content %}
4 |
5 | Member's page
6 |
7 |
8 |
9 |
10 | Have you paid?
11 | {% if current_user.paid %}
12 | Yes.
13 | {% else %}
14 | No. Run along.
15 | {% endif %}
16 |
17 |
18 | {% endblock %}
19 |
20 |
--------------------------------------------------------------------------------
/project/templates/user/register.html:
--------------------------------------------------------------------------------
1 | {% extends '_base.html' %}
2 | {% import "bootstrap/wtf.html" as wtf %}
3 |
4 | {% block content %}
5 |
6 |
9 |
10 |
11 |
12 |
58 |
59 |
60 | {% endblock content %}
61 |
62 | {% block js %}
63 |
64 |
65 | {% endblock %}
66 |
--------------------------------------------------------------------------------
/project/user/__init__.py:
--------------------------------------------------------------------------------
1 | # project/user/__init__.py
2 |
--------------------------------------------------------------------------------
/project/user/forms.py:
--------------------------------------------------------------------------------
1 | # project/user/forms.py
2 |
3 |
4 | from flask_wtf import Form
5 | from wtforms import TextField, PasswordField, SelectField
6 | from wtforms.validators import DataRequired, Email, Length, EqualTo
7 |
8 | from project.models import User
9 |
10 |
class LoginForm(Form):
    """Credentials for /login: a well-formed email address plus a password."""
    email = TextField('Email Address', validators=[DataRequired(), Email()])
    password = PasswordField('Password', validators=[DataRequired()])
18 |
19 |
class RegisterForm(Form):
    """Combined sign-up and payment form used by the /register view.

    The two SelectFields are declared without choices; the view assigns
    ``choices`` before validating.

    NOTE(review): ``card_number`` and ``cvc`` are required server-side
    fields, so raw card details are posted to this application alongside
    the Stripe token. If card data should only ever reach Stripe, these
    inputs would have to leave the server form -- confirm the intended
    PCI scope.
    """
    email = TextField(
        'Email Address',
        validators=[DataRequired(), Email(message=None), Length(min=6, max=40)]
    )
    password = PasswordField(
        'Password', validators=[DataRequired(), Length(min=6, max=25)]
    )
    confirm = PasswordField(
        'Confirm password',
        validators=[
            DataRequired(),
            EqualTo('password', message='Passwords must match.'),
        ]
    )
    card_number = TextField('Credit Card Number', validators=[DataRequired()])
    cvc = TextField('CVC Code', validators=[DataRequired()])
    expiration_month = SelectField(
        'Expiration Month', validators=[DataRequired()]
    )
    expiration_year = SelectField(
        'Expiration Year', validators=[DataRequired()]
    )

    def validate(self):
        """Run the field validators, then reject an already-registered email.

        Returns:
            True only if every field validates and no User has this email;
            otherwise False, with the reason recorded in the field errors.
        """
        if not super(RegisterForm, self).validate():
            return False
        existing = User.query.filter_by(email=self.email.data).first()
        if existing is None:
            return True
        self.email.errors.append("Email already registered")
        return False
61 |
--------------------------------------------------------------------------------
/project/user/views.py:
--------------------------------------------------------------------------------
1 | # project/user/views.py
2 |
3 |
4 | #################
5 | #### imports ####
6 | #################
7 |
8 | import stripe
9 |
10 | from flask import render_template, Blueprint, url_for, \
11 | redirect, flash, request
12 | from flask.ext.login import login_user, logout_user, \
13 | login_required
14 |
15 | from project import db, bcrypt, stripe_keys
16 | from project.util import check_paid
17 | from project.models import User
18 | from project.user.forms import LoginForm, RegisterForm
19 |
20 | ################
21 | #### config ####
22 | ################
23 |
# Global key for the stripe library; used by the stripe.Customer and
# stripe.Charge calls in the register view.
stripe.api_key = stripe_keys['stripe_secret_key']

# Endpoints below are addressed as "user.<view>", e.g. url_for('user.login').
user_blueprint = Blueprint('user', __name__)
27 |
28 |
29 | ################
30 | #### routes ####
31 | ################
32 |
# (value, label) choices for RegisterForm.expiration_month: the zero-padded
# month number is the submitted value, the label is shown in the drop-down.
MONTHS = [
    ("01", "01 - January"),
    ("02", "02 - February"),
    ("03", "03 - March"),
    ("04", "04 - April"),
    ("05", "05 - May"),
    ("06", "06 - June"),
    ("07", "07 - July"),
    ("08", "08 - August"),
    ("09", "09 - September"),
    ("10", "10 - October"),
    ("11", "11 - November"),
    ("12", "12 - December"),
]
47 |
48 |
@user_blueprint.route('/register', methods=['GET', 'POST'])
def register():
    """Sign up and pay in a single step.

    GET renders the registration form with the Stripe publishable key.
    A valid POST must include the ``stripeToken`` field. It stages the new
    User, creates a Stripe customer and charges it, and only then commits
    the user as paid and logs them in.

    If Stripe rejects the card, the staged user is rolled back. This
    matters: a committed-but-unpaid row would make RegisterForm reject the
    same email forever, so the person could never retry.
    """
    form = RegisterForm(request.form)
    form.expiration_month.choices = MONTHS
    form.expiration_year.choices = [
        (str(year), year) for year in range(2015, 2026)]
    if form.validate_on_submit():
        # Stripe amounts are in the smallest currency unit: 500 = $5.00 USD.
        amount = 500
        user = User(
            email=form.email.data,
            password=form.password.data
        )
        db.session.add(user)
        # Send the INSERT now, surfacing DB errors such as a duplicate email
        # before any money moves, but keep the transaction uncommitted.
        db.session.flush()
        try:
            # Stripe validates the card when attaching it to the customer,
            # so this call can raise CardError as well as the charge.
            customer = stripe.Customer.create(
                email=user.email,
                card=request.form['stripeToken']
            )
            stripe.Charge.create(
                customer=customer.id,
                amount=amount,
                currency='usd',
                description='Flask Charge'
            )
        except stripe.CardError:
            db.session.rollback()
            flash('Oops. Something is wrong with your card info!', 'danger')
            return redirect(url_for('user.register'))
        user.paid = True
        db.session.commit()
        login_user(user)
        flash('Thanks for paying!', 'success')
        return redirect(url_for('user.members'))
    return render_template(
        'user/register.html',
        form=form, key=stripe_keys['stripe_publishable_key'])
87 |
88 |
@user_blueprint.route('/login', methods=['GET', 'POST'])
def login():
    """Log a user in with email and password.

    GET requests and invalid forms just render the form. A wrong
    email/password pair flashes an error and re-renders it. Success logs
    the user in and redirects to the members page.
    """
    form = LoginForm(request.form)
    if not form.validate_on_submit():
        return render_template('user/login.html', form=form)
    user = User.query.filter_by(email=form.email.data).first()
    credentials_ok = user is not None and bcrypt.check_password_hash(
        user.password, request.form['password'])
    if not credentials_ok:
        flash('Invalid email and/or password.', 'danger')
        return render_template('user/login.html', form=form)
    login_user(user)
    flash('You are logged in. Welcome!', 'success')
    return redirect(url_for('user.members'))
103 |
104 |
@user_blueprint.route('/logout')
@login_required
def logout():
    """Log the current user out and return to the home page.

    Only login is required. @check_paid is deliberately not applied: it
    would redirect unpaid users to /register, leaving them unable to log
    out at all.
    """
    logout_user()
    flash('You are logged out. Bye!', 'success')
    return redirect(url_for('main.home'))
112 |
113 |
@user_blueprint.route('/members')
@login_required
@check_paid
def members():
    """Premium page: requires a logged-in user who has paid."""
    template = 'user/members.html'
    return render_template(template)
119 |
--------------------------------------------------------------------------------
/project/util.py:
--------------------------------------------------------------------------------
1 | # project/util.py
2 |
3 |
4 | from functools import wraps
5 |
6 | from flask import flash, redirect, url_for
7 | from flask.ext.login import current_user
8 |
9 |
def check_paid(func):
    """Decorator: only let users who have paid reach the wrapped view.

    Anyone else is redirected to the registration/payment page with a
    flash message. Apply it beneath ``@login_required`` so anonymous users
    are asked to log in first.

    Args:
        func: the view function to protect.

    Returns:
        The wrapped view function.
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        # Fail closed. Any falsy value (False, None, 0) counts as unpaid, as
        # does a user object with no `paid` attribute, such as Flask-Login's
        # anonymous user. An `is False` identity test would instead let
        # every value except the False singleton through.
        if not getattr(current_user, 'paid', False):
            flash("Sorry. You must pay to access this page.", 'danger')
            return redirect(url_for('user.register'))
        return func(*args, **kwargs)

    return decorated_function
19 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Flask Paywall
2 |
3 | Set up a paywall with Flask and Stripe to offer paid access to your premium content.
4 |
5 | ## Workflow
6 |
7 | After a user registers and pays (both on the same form), they have access to the premium content.
8 |
9 | ## QuickStart
10 |
11 | ### Set Environment Variables
12 |
13 | Copy *project/config_sample.py* to *project/config.py* (which is git-ignored), update the config settings, and then set the `APP_SETTINGS` environment variable:
14 |
15 | ```sh
16 | $ export APP_SETTINGS="project.config.DevelopmentConfig"
17 | ```
18 |
19 | or
20 |
21 | ```sh
22 | $ export APP_SETTINGS="project.config.ProductionConfig"
23 | ```
24 |
25 | ### Create DB
26 |
27 | ```sh
28 | $ python manage.py create_db
29 | $ python manage.py db init
30 | $ python manage.py db migrate
31 | $ python manage.py create_admin
32 | ```
33 |
34 | ### Run
35 |
36 | ```sh
37 | $ python manage.py runserver
38 | ```
39 |
40 | ### Test
41 |
42 | Without coverage:
43 |
44 | ```sh
45 | $ python manage.py test
46 | ```
47 |
48 | With coverage:
49 |
50 | ```sh
51 | $ python manage.py cov
52 | ```
53 |
54 | ## Todo
55 |
56 | 1. forgot password
57 | 1. change/update password
58 | 1. logging
59 | 1. admin charts
60 | 1. upgrade to python 3/update dependencies
61 | 1. add autoenv
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Flask==0.10.1
2 | Flask-Bcrypt==0.6.0
3 | Flask-Bootstrap==3.3.0.1
4 | Flask-DebugToolbar==0.9.0
5 | Flask-Login==0.2.11
6 | Flask-Mail==0.9.1
7 | Flask-Migrate==1.2.0
8 | Flask-OAuth==0.12
9 | Flask-SQLAlchemy==2.0
10 | Flask-Script==2.0.5
11 | Flask-Testing==0.4.2
12 | Flask-WTF==0.10.2
13 | Jinja2==2.7.3
14 | Mako==1.0.0
15 | MarkupSafe==0.23
16 | SQLAlchemy==0.9.8
17 | WTForms==2.0.1
18 | Werkzeug==0.9.6
19 | alembic==0.6.7
20 | blinker==1.3
21 | coverage==4.0a1
22 | ecdsa==0.11
23 | httplib2==0.9
24 | itsdangerous==0.24
25 | oauth2==1.5.211
26 | paramiko==1.15.1
27 | psycopg2==2.5.4
28 | py-bcrypt==0.4
29 | pycrypto==2.6.1
30 | requests==2.4.3
31 | stripe==1.19.1
32 | wsgiref==0.1.2
33 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # tests/__init__.py
2 |
--------------------------------------------------------------------------------
/tests/base.py:
--------------------------------------------------------------------------------
1 | # tests/base.py
2 |
3 |
4 | from flask.ext.testing import TestCase
5 |
6 | from project import app, db
7 | from project.models import User
8 |
9 |
class BaseTestCase(TestCase):
    """Shared fixture: the app under TestingConfig with one paid user."""

    def create_app(self):
        """Reconfigure the shared app for testing and hand it to Flask-Testing."""
        app.config.from_object('project.config.TestingConfig')
        return app

    def setUp(self):
        """Create all tables and seed the paid user test@testing.com / testing."""
        db.create_all()
        seeded = User(email="test@testing.com", password="testing", paid=True)
        db.session.add(seeded)
        db.session.commit()

    def tearDown(self):
        """Discard the session, then drop every table."""
        db.session.remove()
        db.drop_all()
25 |
--------------------------------------------------------------------------------
/tests/helpers.py:
--------------------------------------------------------------------------------
1 | # tests/helpers.py
2 |
--------------------------------------------------------------------------------
/tests/test_config.py:
--------------------------------------------------------------------------------
1 | # tests/test_config.py
2 |
3 |
4 | import unittest
5 |
6 | from flask import current_app
7 | from flask.ext.testing import TestCase
8 |
9 | from project import app
10 |
11 |
class TestDevelopmentConfig(TestCase):
    """DevelopmentConfig: debug features on, CSRF off, not in testing mode."""

    def create_app(self):
        app.config.from_object('project.config.DevelopmentConfig')
        return app

    def test_app_is_development(self):
        # assertIs / assertIsNotNone report the actual value on failure,
        # unlike assertTrue(x is True), which only says "False is not true".
        self.assertFalse(current_app.config['TESTING'])
        self.assertIs(app.config['DEBUG'], True)
        self.assertIs(app.config['WTF_CSRF_ENABLED'], False)
        self.assertIs(app.config['DEBUG_TB_ENABLED'], True)
        self.assertIsNotNone(current_app)
24 |
25 |
class TestTestingConfig(TestCase):
    """TestingConfig: testing and debug on, fast bcrypt, CSRF off."""

    def create_app(self):
        app.config.from_object('project.config.TestingConfig')
        return app

    def test_app_is_testing(self):
        # assertIs / assertEqual show the offending value when they fail.
        self.assertTrue(current_app.config['TESTING'])
        self.assertIs(app.config['DEBUG'], True)
        self.assertEqual(app.config['BCRYPT_LOG_ROUNDS'], 1)
        self.assertIs(app.config['WTF_CSRF_ENABLED'], False)
37 |
38 |
class TestProductionConfig(TestCase):
    """ProductionConfig: debug features off, CSRF on, full-strength bcrypt."""

    def create_app(self):
        app.config.from_object('project.config.ProductionConfig')
        return app

    def test_app_is_production(self):
        # assertIs / assertEqual show the offending value when they fail.
        self.assertFalse(current_app.config['TESTING'])
        self.assertIs(app.config['DEBUG'], False)
        self.assertIs(app.config['DEBUG_TB_ENABLED'], False)
        self.assertIs(app.config['WTF_CSRF_ENABLED'], True)
        self.assertEqual(app.config['BCRYPT_LOG_ROUNDS'], 13)
51 |
52 |
if __name__ == "__main__":
    # Run this module's tests when it is executed directly.
    unittest.main()
55 |
--------------------------------------------------------------------------------
/tests/test_main.py:
--------------------------------------------------------------------------------
1 | # tests/test_main.py
2 |
3 |
4 | import unittest
5 |
6 | from base import BaseTestCase
7 |
8 |
class TestMainBlueprint(BaseTestCase):
    """Tests for the public routes of the main blueprint."""

    def test_index(self):
        # The landing page renders with its welcome text and auth links.
        response = self.client.get('/', follow_redirects=True)
        page = response.data
        self.assertEqual(response.status_code, 200)
        for snippet in ('Welcome to Flask Paywall!', 'Login', 'Sign up'):
            self.assertIn(snippet, page)

    def test_404(self):
        # Unknown URLs are served by the custom 404 handler and template.
        response = self.client.get('/404')
        self.assert404(response)
        self.assertTemplateUsed('errors/404.html')
24 |
25 |
if __name__ == "__main__":
    # Run this module's tests when it is executed directly.
    unittest.main()
28 |
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | # tests/test_models.py
2 |
3 |
4 | import datetime
5 | import unittest
6 |
7 | from flask.ext.login import current_user
8 |
9 | from base import BaseTestCase
10 | from project import bcrypt
11 | from project.models import User
12 |
13 |
class TestUser(BaseTestCase):
    """Model-level tests for project.models.User."""

    def _post_login(self, password):
        """POST the seeded user's email with `password` to /login."""
        return self.client.post('/login', data=dict(
            email='test@testing.com', password=password
        ), follow_redirects=True)

    def test_get_by_id(self):
        # Ensure id is correct for the current/logged in user.
        with self.client:
            self._post_login('testing')
            self.assertTrue(current_user.id == 1)

    def test_registered_on_defaults_to_datetime(self):
        # Ensure that registered_on is a datetime.
        with self.client:
            self._post_login('testing')
            user = User.query.filter_by(email='test@testing.com').first()
            self.assertIsInstance(user.registered_on, datetime.datetime)

    def test_check_password(self):
        # Ensure the stored hash accepts the right password and rejects a wrong one.
        user = User.query.filter_by(email='test@testing.com').first()
        self.assertTrue(bcrypt.check_password_hash(user.password, 'testing'))
        self.assertFalse(bcrypt.check_password_hash(user.password, 'foobar'))

    def test_validate_invalid_password(self):
        # Ensure the user can't log in when the password is incorrect.
        with self.client:
            response = self._post_login('wrong_password')
            self.assertIn('Invalid email and/or password.', response.data)
46 |
47 |
if __name__ == "__main__":
    # Run this module's tests when it is executed directly.
    unittest.main()
50 |
--------------------------------------------------------------------------------
/tests/test_user.py:
--------------------------------------------------------------------------------
1 | # tests/test_user.py
2 |
3 |
4 | import datetime
5 | import unittest
6 | import stripe
7 |
8 | from flask.ext.login import current_user
9 |
10 | from base import BaseTestCase
11 | from project import db
12 | from project.models import User
13 | from project.user.forms import LoginForm
14 |
15 |
class TestUserBlueprint(BaseTestCase):
    """End-to-end tests for the /login, /logout, /members and /register views.

    BaseTestCase seeds a paid user: test@testing.com / testing.
    """

    def test_correct_login(self):
        # Ensure login behaves correctly with correct credentials.
        with self.client:
            response = self.client.post(
                '/login',
                data=dict(email="test@testing.com", password="testing"),
                follow_redirects=True
            )
            self.assertIn('Welcome', response.data)
            self.assertIn('Logout', response.data)
            self.assertIn('Members', response.data)
            self.assertTrue(current_user.email == "test@testing.com")
            self.assertTrue(current_user.is_active())
            self.assertEqual(response.status_code, 200)

    def test_logout_behaves_correctly(self):
        # Ensure logout behaves correctly - regarding the session.
        with self.client:
            self.client.post(
                '/login',
                data=dict(email="test@testing.com", password="testing"),
                follow_redirects=True
            )
            response = self.client.get('/logout', follow_redirects=True)
            self.assertIn('You are logged out. Bye!\n', response.data)
            self.assertFalse(current_user.is_active())

    def test_logout_route_requires_login(self):
        # Ensure logout route requires logged in user.
        response = self.client.get('/logout', follow_redirects=True)
        self.assertIn('Please log in to access this page', response.data)

    def test_member_route_requires_login(self):
        # Ensure member route requires logged in user.
        response = self.client.get('/members', follow_redirects=True)
        self.assertIn('Please log in to access this page', response.data)

    def test_member_route_requires_payment(self):
        # Ensure member route requires a paid user.
        user = User(email="unpaid@testing.com", password="testing", paid=False)
        db.session.add(user)
        db.session.commit()
        with self.client:
            response = self.client.post(
                '/login',
                data=dict(email="unpaid@testing.com", password="testing"),
                follow_redirects=True
            )
            self.assertIn(
                'Sorry. You must pay to access this page.', response.data)

    def test_validate_success_login_form(self):
        # Ensure correct data validates.
        form = LoginForm(email='test@testing.com', password='admin_user')
        self.assertTrue(form.validate())

    def test_validate_invalid_email_format(self):
        # Ensure invalid email format throws error.
        form = LoginForm(email='unknown', password='example')
        self.assertFalse(form.validate())

    def test_register_route(self):
        # Ensure the register page renders. The heading is matched as plain
        # text so the assertion does not depend on the surrounding markup
        # (the literal here was previously split across two lines, which is
        # a SyntaxError).
        response = self.client.get('/register', follow_redirects=True)
        self.assertEqual(response.status_code, 200)
        self.assertIn('Please Register', response.data)

    def test_user_registration_error(self):
        # Ensure a successful sign-up creates a paid user and logs them in.
        # NOTE: this calls the live Stripe API (stripe.Token.create here and
        # the charge in the register view). It needs network access and a
        # real Stripe test secret key in the config selected by APP_SETTINGS,
        # which project/__init__.py reads once at import time.
        token = stripe.Token.create(
            card={
                'number': '4242424242424242',
                'exp_month': '06',
                'exp_year': str(datetime.datetime.today().year + 1),
                'cvc': '123',
            }
        )
        with self.client:
            response = self.client.post(
                '/register',
                data=dict(
                    email="new@tester.com",
                    password="testing",
                    confirm="testing",
                    card_number="4242424242424242",
                    cvc="123",
                    expiration_month="01",
                    expiration_year="2015",
                    stripeToken=token.id,
                ),
                follow_redirects=True
            )
            user = User.query.filter_by(email='new@tester.com').first()
            self.assertEqual(user.email, 'new@tester.com')
            self.assertTrue(user.paid)
            self.assertIn('Thanks for paying!', response.data)
            self.assertTrue(current_user.email == "new@tester.com")
            self.assertTrue(current_user.is_active())
            self.assertEqual(response.status_code, 200)
116 |
117 |
if __name__ == "__main__":
    # Run this module's tests when it is executed directly.
    unittest.main()
120 |
--------------------------------------------------------------------------------