├── .gitignore ├── .travis.yml ├── manage.py ├── project ├── __init__.py ├── config_sample.py ├── main │ ├── __init__.py │ └── views.py ├── models.py ├── static │ ├── main.css │ ├── main.js │ └── member.js ├── templates │ ├── _base.html │ ├── errors │ │ ├── 401.html │ │ ├── 403.html │ │ ├── 404.html │ │ └── 500.html │ ├── header.html │ ├── main │ │ └── home.html │ └── user │ │ ├── login.html │ │ ├── members.html │ │ └── register.html ├── user │ ├── __init__.py │ ├── forms.py │ └── views.py └── util.py ├── readme.md ├── requirements.txt └── tests ├── __init__.py ├── base.py ├── helpers.py ├── test_config.py ├── test_main.py ├── test_models.py └── test_user.py /.gitignore: -------------------------------------------------------------------------------- 1 | env 2 | temp 3 | tmp 4 | migrations 5 | 6 | *.pyc 7 | *.sqlite 8 | *.coverage 9 | .DS_Store 10 | config.py -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "2.7" 4 | install: 5 | - pip install -r requirements.txt 6 | script: nosetests -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | # manage.py 2 | 3 | 4 | import os 5 | import unittest 6 | import coverage 7 | 8 | from flask.ext.script import Manager 9 | from flask.ext.migrate import Migrate, MigrateCommand 10 | 11 | from project import app, db 12 | from project.models import User 13 | 14 | 15 | app.config.from_object(os.environ['APP_SETTINGS']) 16 | 17 | migrate = Migrate(app, db) 18 | manager = Manager(app) 19 | 20 | # migrations 21 | manager.add_command('db', MigrateCommand) 22 | 23 | 24 | @manager.command 25 | def test(): 26 | """Runs the unit tests without coverage.""" 27 | tests = unittest.TestLoader().discover('tests') 28 | unittest.TextTestRunner(verbosity=2).run(tests) 29 | 30 | 31 | @manager.command 32 | def cov(): 33 | """Runs the unit tests with coverage.""" 34 | cov = coverage.coverage( 35 | branch=True, 36 | include='project/*', 37 | omit=['*/__init__.py'] 38 | ) 39 | cov.start() 40 | tests = unittest.TestLoader().discover('tests') 41 | unittest.TextTestRunner(verbosity=2).run(tests) 42 | cov.stop() 43 | cov.save() 44 | print 'Coverage Summary:' 45 | cov.report() 46 | basedir = os.path.abspath(os.path.dirname(__file__)) 47 | covdir = os.path.join(basedir, 'tmp/coverage') 48 | cov.html_report(directory=covdir) 49 | print('HTML version: file://%s/index.html' % covdir) 50 | cov.erase() 51 | 52 | 53 | @manager.command 54 | def create_db(): 55 | """Creates the db tables.""" 56 | db.create_all() 57 | 58 | 59 | @manager.command 60 | def drop_db(): 61 | """Drops the db tables.""" 62 | db.drop_all() 63 | 64 | 65 | @manager.command 66 | def create_data(): 67 | """Creates sample data.""" 68 | pass 69 | 70 | 71 | @manager.command 72 | def create_admin(): 73 | """Creates the admin user.""" 74 | db.session.add(User(email="ad@min.com", password="admin", admin=True)) 75 | db.session.commit() 76 | 77 | 78 | if __name__ == '__main__': 79 | manager.run() 80 | -------------------------------------------------------------------------------- /project/__init__.py: -------------------------------------------------------------------------------- 1 | # project/__init__.py 2 | 3 | 4 | ################# 5 | #### imports #### 6 | ################# 7 | 8 | import os 9 | 10 | from flask import Flask, render_template 11 | from flask.ext.login import LoginManager 12 | from flask.ext.bcrypt import Bcrypt 13 | from flask_mail import Mail 14 | from flask.ext.debugtoolbar import DebugToolbarExtension 15 | from flask_bootstrap import Bootstrap 16 | from flask.ext.sqlalchemy import SQLAlchemy 17 | 18 | 19 | ################ 20 | #### config #### 21 | ################ 22 | 23 | app = Flask(__name__) 24 | app.config.from_object(os.environ['APP_SETTINGS']) 25 | 26 | stripe_keys = { 27 | 'stripe_secret_key': app.config['STRIPE_SECRET_KEY'], 28 | 'stripe_publishable_key': app.config['STRIPE_PUBLISHABLE_KEY'] 29 | } 30 | 31 | 32 | #################### 33 | #### extensions #### 34 | #################### 35 | 36 | login_manager = LoginManager() 37 | login_manager.init_app(app) 38 | bcrypt = Bcrypt(app) 39 | mail = Mail(app) 40 | toolbar = DebugToolbarExtension(app) 41 | bootstrap = Bootstrap(app) 42 | db = SQLAlchemy(app) 43 | 44 | 45 | #################### 46 | #### blueprints #### 47 | #################### 48 | 49 | from project.main.views import main_blueprint 50 | from project.user.views import user_blueprint 51 | app.register_blueprint(main_blueprint) 52 | app.register_blueprint(user_blueprint) 53 | 54 | 55 | #################### 56 | #### flask-login #### 57 | #################### 58 | 59 | from models import User 60 | 61 | login_manager.login_view = "user.login" 62 | login_manager.login_message_category = 'danger' 63 | 64 | 65 | @login_manager.user_loader 66 | def load_user(user_id): 67 | return User.query.filter(User.id == int(user_id)).first() 68 | 69 | 70 | ######################## 71 | #### error handlers #### 72 | ######################## 73 | 74 | @app.errorhandler(403) 75 | def forbidden_page(error): 76 | return render_template("errors/403.html"), 403 77 | 78 | 79 | @app.errorhandler(404) 80 | def page_not_found(error): 81 | return render_template("errors/404.html"), 404 82 | 83 | 84 | @app.errorhandler(500) 85 | def server_error_page(error): 86 | return render_template("errors/500.html"), 500 87 | -------------------------------------------------------------------------------- /project/config_sample.py: -------------------------------------------------------------------------------- 1 | # project/config.py 2 | 3 | import os 4 | basedir = os.path.abspath(os.path.dirname(__file__)) 5 | 6 | 7 | class BaseConfig(object): 8 | """Base configuration.""" 9 | SECRET_KEY = 'my_precious' 10 | DEBUG = False 11 | BCRYPT_LOG_ROUNDS = 13 12 | WTF_CSRF_ENABLED = True 13 | DEBUG_TB_ENABLED = False 14 | DEBUG_TB_INTERCEPT_REDIRECTS = False 15 | 16 | 17 | class DevelopmentConfig(BaseConfig): 18 | """Development configuration.""" 19 | DEBUG = True 20 | BCRYPT_LOG_ROUNDS = 1 21 | WTF_CSRF_ENABLED = False 22 | SQLALCHEMY_DATABASE_URI = 'sqlite:///' + os.path.join(basedir, 'dev.sqlite') 23 | DEBUG_TB_ENABLED = True 24 | STRIPE_SECRET_KEY = 'test key' 25 | STRIPE_PUBLISHABLE_KEY = 'test key' 26 | 27 | 28 | class ProductionConfig(BaseConfig): 29 | """Production configuration.""" 30 | SECRET_KEY = 'my_precious' 31 | DEBUG = False 32 | SQLALCHEMY_DATABASE_URI = 'postgresql://localhost/example' 33 | DEBUG_TB_ENABLED = False 34 | STRIPE_SECRET_KEY = 'live key' 35 | STRIPE_PUBLISHABLE_KEY = 'live key' 36 | -------------------------------------------------------------------------------- /project/main/__init__.py: -------------------------------------------------------------------------------- 1 | # project/users/__init__.py 2 | -------------------------------------------------------------------------------- /project/main/views.py: -------------------------------------------------------------------------------- 1 | # project/main/views.py 2 | 3 | 4 | ################# 5 | #### imports #### 6 | ################# 7 | 8 | from flask import render_template, Blueprint 9 | 10 | ################ 11 | #### config #### 12 | ################ 13 | 14 | main_blueprint = Blueprint('main', __name__,) 15 | 16 | 17 | ################ 18 | #### routes #### 19 | ################ 20 | 21 | 22 | @main_blueprint.route('/') 23 | def home(): 24 | return render_template('main/home.html') 25 | -------------------------------------------------------------------------------- /project/models.py: -------------------------------------------------------------------------------- 1 | # project/models.py 2 | 3 | 4 | import datetime 5 | 6 | from project import db, bcrypt 7 | 8 | 9 | class User(db.Model): 10 | 11 | __tablename__ = "users" 12 | 13 | id = db.Column(db.Integer, primary_key=True) 14 | email = db.Column(db.String, unique=True, nullable=False) 15 | password = db.Column(db.String, nullable=False) 16 | registered_on = db.Column(db.DateTime, nullable=False) 17 | paid = db.Column(db.Boolean, nullable=False, default=False) 18 | admin = db.Column(db.Boolean, nullable=False, default=False) 19 | 20 | def __init__(self, email, password, paid=False, admin=False): 21 | self.email = email 22 | self.password = bcrypt.generate_password_hash(password) 23 | self.registered_on = datetime.datetime.now() 24 | self.paid = paid 25 | self.admin = admin 26 | 27 | def is_authenticated(self): 28 | return True 29 | 30 | def is_active(self): 31 | return True 32 | 33 | def is_anonymous(self): 34 | return False 35 | 36 | def get_id(self): 37 | return unicode(self.id) 38 | 39 | def __repr__(self): 40 | return '').val(token)); 32 | // and submit 33 | $form.get(0).submit(); 34 | } 35 | } 36 | 37 | }); -------------------------------------------------------------------------------- /project/templates/_base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Flask Paywall{% block title %}{% endblock %} 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | {% block css %}{% endblock %} 15 | 16 | 17 | 18 | 19 | {% include 'header.html' %} 20 | 21 |
22 |
23 | 24 | 25 | {% with messages = get_flashed_messages(with_categories=true) %} 26 | {% if messages %} 27 |
28 |
29 |
30 | {% for category, message in messages %} 31 |
32 | × 33 | {{message}} 34 |
35 | {% endfor %} 36 |
37 |
38 | {% endif %} 39 | {% endwith %} 40 | 41 | 42 | 43 | {% block content %}{% endblock %} 44 | 45 |
46 | 47 | 48 | {% if error %} 49 |

Error: {{ error }}

50 | {% endif %} 51 | 52 |
53 |
54 | 55 | 56 | 57 | 58 | 59 | {% block js %}{% endblock %} 60 | 61 | 62 | -------------------------------------------------------------------------------- /project/templates/errors/401.html: -------------------------------------------------------------------------------- 1 | {% extends "_base.html" %} 2 | 3 | {% block page_title %}- Unauthorized{% endblock %} 4 | 5 | {% block content %} 6 |
7 |
8 |

401

9 |

You are not authorized to view this page. Please log in.

10 |
11 |
12 | {% endblock %} 13 | -------------------------------------------------------------------------------- /project/templates/errors/403.html: -------------------------------------------------------------------------------- 1 | {% extends "_base.html" %} 2 | {% block content %} 3 |

403

4 |

Run along!

5 |

Return Home?

6 | {% endblock %} 7 | -------------------------------------------------------------------------------- /project/templates/errors/404.html: -------------------------------------------------------------------------------- 1 | {% extends "_base.html" %} 2 | 3 | {% block page_title %}- Page Not Found{% endblock %} 4 | 5 | {% block content %} 6 |
7 |
8 |

404

9 |

Sorry. The requested page doesn't exist. Go home.

10 |
11 |
12 | {% endblock %} 13 | -------------------------------------------------------------------------------- /project/templates/errors/500.html: -------------------------------------------------------------------------------- 1 | {% extends "_base.html" %} 2 | 3 | {% block page_title %}- Server Error{% endblock %} 4 | 5 | {% block content %} 6 |
7 |
8 |

500

9 |

Sorry. Something went terribly wrong. Go home.

10 |
11 |
12 | {% endblock %} 13 | -------------------------------------------------------------------------------- /project/templates/header.html: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /project/templates/main/home.html: -------------------------------------------------------------------------------- 1 | {% extends "_base.html" %} 2 | 3 | {% block content %} 4 | 5 |
6 |

Welcome to Flask Paywall!

7 |

Setup a paywall with Flask and Stripe to offer paid access to your premium content.

8 |

Sign up today

9 |
10 | 11 | {% endblock %} -------------------------------------------------------------------------------- /project/templates/user/login.html: -------------------------------------------------------------------------------- 1 | {% extends '_base.html' %} 2 | {% import "bootstrap/wtf.html" as wtf %} 3 | 4 | {% block content %} 5 | 6 |
7 |

Please login

8 |
9 | 10 |
11 | 12 | 13 |
14 | 15 | {{ form.csrf_token }} 16 | {{ form.hidden_tag() }} 17 | {{ wtf.form_errors(form, hiddens="only") }} 18 | 19 | 20 |
21 | 22 | {{ wtf.form_field(form.email) }} 23 | {{ wtf.form_field(form.password) }} 24 | 25 | 26 |

27 |

Need to Register?

28 |
29 | 30 |
31 | 32 | 33 | {% endblock content %} 34 | -------------------------------------------------------------------------------- /project/templates/user/members.html: -------------------------------------------------------------------------------- 1 | {% extends "_base.html" %} 2 | 3 | {% block content %} 4 | 5 |

Member's page

6 | 7 |
8 | 9 |

10 | Have you paid? 11 | {% if current_user.paid %} 12 | Yes. 13 | {% else %} 14 | No. Run along. 15 | {% endif %} 16 | 17 | 18 | {% endblock %} 19 | 20 | -------------------------------------------------------------------------------- /project/templates/user/register.html: -------------------------------------------------------------------------------- 1 | {% extends '_base.html' %} 2 | {% import "bootstrap/wtf.html" as wtf %} 3 | 4 | {% block content %} 5 | 6 |

7 |

Please Register

8 |
9 | 10 |
11 | 12 |
13 | 14 | {{ form.csrf_token }} 15 | {{ form.hidden_tag() }} 16 | {{ wtf.form_errors(form, hiddens=True) }} 17 | 18 |
19 | 20 |
21 | 22 | {{ wtf.form_field(form.email) }} 23 | {{ wtf.form_field(form.password) }} 24 | {{ wtf.form_field(form.confirm) }} 25 | 26 |

27 | 28 | 29 |

30 |

Already have an account? Sign in.

31 | 32 |
33 | 34 |
35 | 36 | {{ wtf.form_field(form.card_number, **{'data-stripe': 'number'}) }} 37 | {{ wtf.form_field(form.cvc, **{'data-stripe': 'cvc'}) }} 38 | 39 | {{ wtf.form_field(form.expiration_month, **{'data-stripe': 'exp_month'}) }} 40 | 41 | 42 | {{ wtf.form_field(form.expiration_year, **{'data-stripe': 'exp_year'}) }} 43 | 44 | 45 |
46 | 47 |
48 |
49 | 50 |

51 |
52 | 53 |
54 | 55 |
56 | 57 |
58 | 59 | 60 | {% endblock content %} 61 | 62 | {% block js %} 63 | 64 | 65 | {% endblock %} 66 | -------------------------------------------------------------------------------- /project/user/__init__.py: -------------------------------------------------------------------------------- 1 | # project/users/__init__.py 2 | -------------------------------------------------------------------------------- /project/user/forms.py: -------------------------------------------------------------------------------- 1 | # project/users/forms.py 2 | 3 | 4 | from flask_wtf import Form 5 | from wtforms import TextField, PasswordField, SelectField 6 | from wtforms.validators import DataRequired, Email, Length, EqualTo 7 | 8 | from project.models import User 9 | 10 | 11 | class LoginForm(Form): 12 | email = TextField( 13 | 'Email Address', validators=[DataRequired(), Email()] 14 | ) 15 | password = PasswordField( 16 | 'Password', validators=[DataRequired()] 17 | ) 18 | 19 | 20 | class RegisterForm(Form): 21 | email = TextField( 22 | 'Email Address', 23 | validators=[DataRequired(), Email(message=None), Length(min=6, max=40)] 24 | ) 25 | password = PasswordField( 26 | 'Password', 27 | validators=[DataRequired(), Length(min=6, max=25)] 28 | ) 29 | confirm = PasswordField( 30 | 'Confirm password', 31 | validators=[ 32 | DataRequired(), 33 | EqualTo('password', message='Passwords must match.') 34 | ] 35 | ) 36 | card_number = TextField( 37 | 'Credit Card Number', 38 | validators=[DataRequired()] 39 | ) 40 | cvc = TextField( 41 | 'CVC Code', 42 | validators=[DataRequired()] 43 | ) 44 | expiration_month = SelectField( 45 | 'Expiration Month', 46 | validators=[DataRequired()] 47 | ) 48 | expiration_year = SelectField( 49 | 'Expiration Year', validators=[DataRequired()] 50 | ) 51 | 52 | def validate(self): 53 | initial_validation = super(RegisterForm, self).validate() 54 | if not initial_validation: 55 | return False 56 | user = User.query.filter_by(email=self.email.data).first() 57 | if user: 58 | self.email.errors.append("Email already registered") 59 | return False 60 | return True 61 | -------------------------------------------------------------------------------- /project/user/views.py: -------------------------------------------------------------------------------- 1 | # project/users/views.py 2 | 3 | 4 | ################# 5 | #### imports #### 6 | ################# 7 | 8 | import stripe 9 | 10 | from flask import render_template, Blueprint, url_for, \ 11 | redirect, flash, request 12 | from flask.ext.login import login_user, logout_user, \ 13 | login_required 14 | 15 | from project import db, bcrypt, stripe_keys 16 | from project.util import check_paid 17 | from project.models import User 18 | from project.user.forms import LoginForm, RegisterForm 19 | 20 | ################ 21 | #### config #### 22 | ################ 23 | 24 | stripe.api_key = stripe_keys['stripe_secret_key'] 25 | 26 | user_blueprint = Blueprint('user', __name__,) 27 | 28 | 29 | ################ 30 | #### routes #### 31 | ################ 32 | 33 | MONTHS = [ 34 | ("01", "01 - January"), 35 | ("02", "02 - February"), 36 | ("03", "03 - March"), 37 | ("04", "04 - April"), 38 | ("05", "05 - May"), 39 | ("06", "06 - June"), 40 | ("07", "07 - July"), 41 | ("08", "08 - August"), 42 | ("09", "09 - September"), 43 | ("10", "10 - October"), 44 | ("11", "11 - November"), 45 | ("12", "12 - Devember") 46 | ] 47 | 48 | 49 | @user_blueprint.route('/register', methods=['GET', 'POST']) 50 | def register(): 51 | form = RegisterForm(request.form) 52 | form.expiration_month.choices = MONTHS 53 | form.expiration_year.choices = [ 54 | (str(year), year) for year in (range(2015, 2026))] 55 | if form.validate_on_submit(): 56 | user = User( 57 | email=form.email.data, 58 | password=form.password.data 59 | ) 60 | db.session.add(user) 61 | db.session.commit() 62 | amount = 500 63 | customer = stripe.Customer.create( 64 | email=user.email, 65 | card=request.form['stripeToken'] 66 | ) 67 | try: 68 | charge = stripe.Charge.create( 69 | customer=customer.id, 70 | amount=amount, 71 | currency='usd', 72 | description='Flask Charge' 73 | ) 74 | if charge: 75 | User.query.filter_by( 76 | email=user.email).update(dict(paid=True)) 77 | db.session.commit() 78 | login_user(user) 79 | flash('Thanks for paying!', 'success') 80 | return redirect(url_for('user.members')) 81 | except stripe.CardError: 82 | flash('Oops. Something is wrong with your card info!', 'danger') 83 | return redirect(url_for('user.register')) 84 | return render_template( 85 | 'user/register.html', 86 | form=form, key=stripe_keys['stripe_publishable_key']) 87 | 88 | 89 | @user_blueprint.route('/login', methods=['GET', 'POST']) 90 | def login(): 91 | form = LoginForm(request.form) 92 | if form.validate_on_submit(): 93 | user = User.query.filter_by(email=form.email.data).first() 94 | if user and bcrypt.check_password_hash( 95 | user.password, request.form['password']): 96 | login_user(user) 97 | flash('You are logged in. Welcome!', 'success') 98 | return redirect(url_for('user.members')) 99 | else: 100 | flash('Invalid email and/or password.', 'danger') 101 | return render_template('user/login.html', form=form) 102 | return render_template('user/login.html', form=form) 103 | 104 | 105 | @user_blueprint.route('/logout') 106 | @login_required 107 | @check_paid 108 | def logout(): 109 | logout_user() 110 | flash('You are logged out. Bye!', 'success') 111 | return redirect(url_for('main.home')) 112 | 113 | 114 | @user_blueprint.route('/members') 115 | @login_required 116 | @check_paid 117 | def members(): 118 | return render_template('user/members.html') 119 | -------------------------------------------------------------------------------- /project/util.py: -------------------------------------------------------------------------------- 1 | # project/util.py 2 | 3 | 4 | from functools import wraps 5 | 6 | from flask import flash, redirect, url_for 7 | from flask.ext.login import current_user 8 | 9 | 10 | def check_paid(func): 11 | @wraps(func) 12 | def decorated_function(*args, **kwargs): 13 | if current_user.paid is False: 14 | flash("Sorry. You must pay to access this page.", 'danger') 15 | return redirect(url_for('user.register')) 16 | return func(*args, **kwargs) 17 | 18 | return decorated_function 19 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Flask Paywall 2 | 3 | Setup a paywall with Flask and Stripe to offer paid access to your premium content. 4 | 5 | ## Workflow 6 | 7 | After user register and pays (from the same form), s/he has access to premium content. 8 | 9 | ## QuickStart 10 | 11 | ### Set Environment Variables 12 | 13 | Rename *config_sample.py* to *config.py*, update the config settings, and then run: 14 | 15 | ```sh 16 | $ export APP_SETTINGS="project.config.DevelopmentConfig" 17 | ``` 18 | 19 | or 20 | 21 | ```sh 22 | $ export APP_SETTINGS="project.config.ProductionConfig" 23 | ``` 24 | 25 | ### Create DB 26 | 27 | ```sh 28 | $ python manage.py create_db 29 | $ python manage.py db init 30 | $ python manage.py db migrate 31 | $ python manage.py create_admin 32 | ``` 33 | 34 | ### Run 35 | 36 | ```sh 37 | $ python manage.py runserver 38 | ``` 39 | 40 | ### Test 41 | 42 | Without coverage: 43 | 44 | ```sh 45 | $ python manage.py test 46 | ``` 47 | 48 | With coverage: 49 | 50 | ```sh 51 | $ python manage.py cov 52 | ``` 53 | 54 | ## Todo 55 | 56 | 1. forgot password 57 | 1. change/update password 58 | 1. logging 59 | 1. admin charts 60 | 1. upgrade to python 3/update dependencies 61 | 1. add autoenv -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==0.10.1 2 | Flask-Bcrypt==0.6.0 3 | Flask-Bootstrap==3.3.0.1 4 | Flask-DebugToolbar==0.9.0 5 | Flask-Login==0.2.11 6 | Flask-Mail==0.9.1 7 | Flask-Migrate==1.2.0 8 | Flask-OAuth==0.12 9 | Flask-SQLAlchemy==2.0 10 | Flask-Script==2.0.5 11 | Flask-Testing==0.4.2 12 | Flask-WTF==0.10.2 13 | Jinja2==2.7.3 14 | Mako==1.0.0 15 | MarkupSafe==0.23 16 | SQLAlchemy==0.9.8 17 | WTForms==2.0.1 18 | Werkzeug==0.9.6 19 | alembic==0.6.7 20 | blinker==1.3 21 | coverage==4.0a1 22 | ecdsa==0.11 23 | httplib2==0.9 24 | itsdangerous==0.24 25 | oauth2==1.5.211 26 | paramiko==1.15.1 27 | psycopg2==2.5.4 28 | py-bcrypt==0.4 29 | pycrypto==2.6.1 30 | requests==2.4.3 31 | stripe==1.19.1 32 | wsgiref==0.1.2 33 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # tests/__init__.py 2 | -------------------------------------------------------------------------------- /tests/base.py: -------------------------------------------------------------------------------- 1 | # tests/base.py 2 | 3 | 4 | from flask.ext.testing import TestCase 5 | 6 | from project import app, db 7 | from project.models import User 8 | 9 | 10 | class BaseTestCase(TestCase): 11 | 12 | def create_app(self): 13 | app.config.from_object('project.config.TestingConfig') 14 | return app 15 | 16 | def setUp(self): 17 | db.create_all() 18 | user = User(email="test@testing.com", password="testing", paid=True) 19 | db.session.add(user) 20 | db.session.commit() 21 | 22 | def tearDown(self): 23 | db.session.remove() 24 | db.drop_all() 25 | -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | # tests/helpers.py 2 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | # tests/test_config.py 2 | 3 | 4 | import unittest 5 | 6 | from flask import current_app 7 | from flask.ext.testing import TestCase 8 | 9 | from project import app 10 | 11 | 12 | class TestDevelopmentConfig(TestCase): 13 | 14 | def create_app(self): 15 | app.config.from_object('project.config.DevelopmentConfig') 16 | return app 17 | 18 | def test_app_is_development(self): 19 | self.assertFalse(current_app.config['TESTING']) 20 | self.assertTrue(app.config['DEBUG'] is True) 21 | self.assertTrue(app.config['WTF_CSRF_ENABLED'] is False) 22 | self.assertTrue(app.config['DEBUG_TB_ENABLED'] is True) 23 | self.assertFalse(current_app is None) 24 | 25 | 26 | class TestTestingConfig(TestCase): 27 | 28 | def create_app(self): 29 | app.config.from_object('project.config.TestingConfig') 30 | return app 31 | 32 | def test_app_is_testing(self): 33 | self.assertTrue(current_app.config['TESTING']) 34 | self.assertTrue(app.config['DEBUG'] is True) 35 | self.assertTrue(app.config['BCRYPT_LOG_ROUNDS'] == 1) 36 | self.assertTrue(app.config['WTF_CSRF_ENABLED'] is False) 37 | 38 | 39 | class TestProductionConfig(TestCase): 40 | 41 | def create_app(self): 42 | app.config.from_object('project.config.ProductionConfig') 43 | return app 44 | 45 | def test_app_is_production(self): 46 | self.assertFalse(current_app.config['TESTING']) 47 | self.assertTrue(app.config['DEBUG'] is False) 48 | self.assertTrue(app.config['DEBUG_TB_ENABLED'] is False) 49 | self.assertTrue(app.config['WTF_CSRF_ENABLED'] is True) 50 | self.assertTrue(app.config['BCRYPT_LOG_ROUNDS'] == 13) 51 | 52 | 53 | if __name__ == '__main__': 54 | unittest.main() 55 | -------------------------------------------------------------------------------- /tests/test_main.py: -------------------------------------------------------------------------------- 1 | # tests/test_main.py 2 | 3 | 4 | import unittest 5 | 6 | from base import BaseTestCase 7 | 8 | 9 | class TestMainBlueprint(BaseTestCase): 10 | 11 | def test_index(self): 12 | # Ensure Flask is setup. 13 | response = self.client.get('/', follow_redirects=True) 14 | self.assertEqual(response.status_code, 200) 15 | self.assertIn('Welcome to Flask Paywall!', response.data) 16 | self.assertIn('Login', response.data) 17 | self.assertIn('Sign up', response.data) 18 | 19 | def test_404(self): 20 | # Ensure 404 error is handled. 21 | response = self.client.get('/404') 22 | self.assert404(response) 23 | self.assertTemplateUsed('errors/404.html') 24 | 25 | 26 | if __name__ == '__main__': 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | # tests/test_models.py 2 | 3 | 4 | import datetime 5 | import unittest 6 | 7 | from flask.ext.login import current_user 8 | 9 | from base import BaseTestCase 10 | from project import bcrypt 11 | from project.models import User 12 | 13 | 14 | class TestUser(BaseTestCase): 15 | 16 | def test_get_by_id(self): 17 | # Ensure id is correct for the current/logged in user. 18 | with self.client: 19 | self.client.post('/login', data=dict( 20 | email='test@testing.com', password='testing' 21 | ), follow_redirects=True) 22 | self.assertTrue(current_user.id == 1) 23 | 24 | def test_registered_on_defaults_to_datetime(self): 25 | # Ensure that registered_on is a datetime. 26 | with self.client: 27 | self.client.post('/login', data=dict( 28 | email='test@testing.com', password='testing' 29 | ), follow_redirects=True) 30 | user = User.query.filter_by(email='test@testing.com').first() 31 | self.assertIsInstance(user.registered_on, datetime.datetime) 32 | 33 | def test_check_password(self): 34 | # Ensure given password is correct after unhashing 35 | user = User.query.filter_by(email='test@testing.com').first() 36 | self.assertTrue(bcrypt.check_password_hash(user.password, 'testing')) 37 | self.assertFalse(bcrypt.check_password_hash(user.password, 'foobar')) 38 | 39 | def test_validate_invalid_password(self): 40 | # Ensure user can't login when the pasword is incorrect 41 | with self.client: 42 | response = self.client.post('/login', data=dict( 43 | email='test@testing.com', password='wrong_password' 44 | ), follow_redirects=True) 45 | self.assertIn('Invalid email and/or password.', response.data) 46 | 47 | 48 | if __name__ == '__main__': 49 | unittest.main() 50 | -------------------------------------------------------------------------------- /tests/test_user.py: -------------------------------------------------------------------------------- 1 | # tests/test_user.py 2 | 3 | 4 | import datetime 5 | import unittest 6 | import stripe 7 | 8 | from flask.ext.login import current_user 9 | 10 | from base import BaseTestCase 11 | from project import db 12 | from project.models import User 13 | from project.user.forms import LoginForm 14 | 15 | 16 | class TestUserBlueprint(BaseTestCase): 17 | 18 | def test_correct_login(self): 19 | # Ensure login behaves correctly with correct credentials. 20 | with self.client: 21 | response = self.client.post( 22 | '/login', 23 | data=dict(email="test@testing.com", password="testing"), 24 | follow_redirects=True 25 | ) 26 | self.assertIn('Welcome', response.data) 27 | self.assertIn('Logout', response.data) 28 | self.assertIn('Members', response.data) 29 | self.assertTrue(current_user.email == "test@testing.com") 30 | self.assertTrue(current_user.is_active()) 31 | self.assertEqual(response.status_code, 200) 32 | 33 | def test_logout_behaves_correctly(self): 34 | # Ensure logout behaves correctly - regarding the session. 35 | with self.client: 36 | self.client.post( 37 | '/login', 38 | data=dict(email="test@testing.com", password="testing"), 39 | follow_redirects=True 40 | ) 41 | response = self.client.get('/logout', follow_redirects=True) 42 | self.assertIn('You are logged out. Bye!\n', response.data) 43 | self.assertFalse(current_user.is_active()) 44 | 45 | def test_logout_route_requires_login(self): 46 | # Ensure logout route requires logged in user. 47 | response = self.client.get('/logout', follow_redirects=True) 48 | self.assertIn('Please log in to access this page', response.data) 49 | 50 | def test_member_route_requires_login(self): 51 | # Ensure member route requires logged in user. 52 | response = self.client.get('/members', follow_redirects=True) 53 | self.assertIn('Please log in to access this page', response.data) 54 | 55 | def test_member_route_requires_payment(self): 56 | # Ensure member route requires a paid user. 57 | user = User(email="unpaid@testing.com", password="testing", paid=False) 58 | db.session.add(user) 59 | db.session.commit() 60 | with self.client: 61 | response = self.client.post( 62 | '/login', 63 | data=dict(email="unpaid@testing.com", password="testing"), 64 | follow_redirects=True 65 | ) 66 | self.assertIn( 67 | 'Sorry. You must pay to access this page.', response.data) 68 | 69 | def test_validate_success_login_form(self): 70 | # Ensure correct data validates. 71 | form = LoginForm(email='test@testing.com', password='admin_user') 72 | self.assertTrue(form.validate()) 73 | 74 | def test_validate_invalid_email_format(self): 75 | # Ensure invalid email format throws error. 76 | form = LoginForm(email='unknown', password='example') 77 | self.assertFalse(form.validate()) 78 | 79 | def test_register_route(self): 80 | # Ensure about route behaves correctly. 81 | response = self.client.get('/register', follow_redirects=True) 82 | self.assertIn('

Please Register

\n', response.data) 83 | 84 | def test_user_registration_error(self): 85 | # Ensure registration behaves correctly. 86 | token = stripe.Token.create( 87 | card={ 88 | 'number': '4242424242424242', 89 | 'exp_month': '06', 90 | 'exp_year': str(datetime.datetime.today().year + 1), 91 | 'cvc': '123', 92 | } 93 | ) 94 | with self.client: 95 | response = self.client.post( 96 | '/register', 97 | data=dict( 98 | email="new@tester.com", 99 | password="testing", 100 | confirm="testing", 101 | card_number="4242424242424242", 102 | cvc="123", 103 | expiration_month="01", 104 | expiration_year="2015", 105 | stripeToken=token.id, 106 | ), 107 | follow_redirects=True 108 | ) 109 | user = User.query.filter_by(email='new@tester.com').first() 110 | self.assertEqual(user.email, 'new@tester.com') 111 | self.assertTrue(user.paid) 112 | self.assertIn('Thanks for paying!', response.data) 113 | self.assertTrue(current_user.email == "new@tester.com") 114 | self.assertTrue(current_user.is_active()) 115 | self.assertEqual(response.status_code, 200) 116 | 117 | 118 | if __name__ == '__main__': 119 | unittest.main() 120 | --------------------------------------------------------------------------------