├── .DS_Store
├── .github
│   └── workflows
│       └── tests.yml
├── .gitignore
├── Dockerfile
├── LICENSE.md
├── Makefile
├── README.md
├── app
│   ├── .DS_Store
│   ├── __init__.py
│   ├── config.py
│   ├── controllers
│   │   ├── __init__.py
│   │   ├── account.py
│   │   ├── auth.py
│   │   ├── billing.py
│   │   ├── members.py
│   │   └── tenant.py
│   ├── database.py
│   ├── dependencies
│   │   ├── __init__.py
│   │   ├── auth.py
│   │   └── tenant.py
│   ├── main.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── account.py
│   │   ├── base.py
│   │   └── tenant.py
│   ├── routers
│   │   ├── __init__.py
│   │   ├── accounts.py
│   │   ├── auth.py
│   │   ├── email_addresses.py
│   │   ├── members.py
│   │   └── tenants.py
│   ├── schemas
│   │   ├── __init__.py
│   │   ├── account.py
│   │   ├── auth.py
│   │   ├── member.py
│   │   └── tenant.py
│   ├── settings.py
│   ├── templates
│   │   ├── .DS_Store
│   │   └── email
│   │       └── empty.html
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftests.py
│   │   ├── test_accounts.py
│   │   ├── test_email_addresses.py
│   │   ├── test_login.py
│   │   ├── test_members.py
│   │   └── test_register.py
│   └── utils
│       └── email.py
├── docker-compose.yml
├── prestart.sh
└── requirements.txt


/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ryanb58/fastapi-saas-base/c985834dd3c5d84abd208bd7dae42a73b8fc8f1e/.DS_Store
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
name: tests

on: [push]

jobs:
  build:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v1
    - name: Build the containers
      run: make build
    - name: Run tests
      run: make test

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
*.pyc

.vscode/
venv/
*.db
.env
.coverage
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM tiangolo/uvicorn-gunicorn-fastapi:python3.7

ADD ./requirements.txt /app
RUN pip install --no-cache-dir --upgrade pip
RUN pip install -r requirements.txt

ADD ./ /app
WORKDIR /app/


ENV PYTHONPATH=/app

EXPOSE 80

--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
Copyright 2021 Taylor Brazelton

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
.PHONY: cleanup
cleanup: ## Cleanup
# alias pyclean='find . \( -name "*.py[co]" -o -name __pycache__ \) -exec rm -rf {} +'
	find . \( -name "*.py[co]" -o -name __pycache__ \) -exec rm -rf {} +

.PHONY: dbshell
dbshell: ## Open a shell to the db running in the docker container (must run in container)
	pip install pgcli
	pgcli -h db -p 5432 -u postgres

.PHONY: test
test: ## Run unit tests.
	docker-compose run --service-ports app make run-tests

.PHONY: run-tests
run-tests: ## Run tests (must run in container)
	pytest -s --cov=app --no-cov-on-fail --cov-fail-under=80

.PHONY: shell
shell: ## Run container.
	docker-compose run --service-ports app /bin/bash

.PHONY: fmt
fmt: ## Format files.
	docker-compose run app make run-fmt

.PHONY: run-fmt
run-fmt: ## Format files.
	black app/

.PHONY: build
build: ## Build the docker container
	git config core.hooksPath .githooks
	docker-compose build
	docker-compose down

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Building a robust API base for a SAAS application.

### Getting Started:

Create a `.env` file in this project's root and fill out the following:

```
FRONTEND_BASE_URL=http://localhost
SECRET_KEY=

STRIPE_API_SECRET=
STRIPE_DEFAULT_PRODUCT_ID=
STRIPE_DEFAULT_PLAN_ID=

SMTP_PORT=1025
SMTP_HOST=mail
SMTP_USERNAME=username
SMTP_PASSWORD=password
```
--------------------------------------------------------------------------------
/app/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ryanb58/fastapi-saas-base/c985834dd3c5d84abd208bd7dae42a73b8fc8f1e/app/.DS_Store
--------------------------------------------------------------------------------
/app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ryanb58/fastapi-saas-base/c985834dd3c5d84abd208bd7dae42a73b8fc8f1e/app/__init__.py
--------------------------------------------------------------------------------
/app/config.py:
--------------------------------------------------------------------------------
from starlette.config import Config

config = Config(".env")

--------------------------------------------------------------------------------
/app/controllers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ryanb58/fastapi-saas-base/c985834dd3c5d84abd208bd7dae42a73b8fc8f1e/app/controllers/__init__.py
--------------------------------------------------------------------------------
/app/controllers/account.py:
--------------------------------------------------------------------------------
from datetime import datetime, timedelta

from sqlalchemy.orm import Session
from fastapi import Depends
import jwt

from app.schemas import account as schemas
from app.models.account import Account
from app.models.account import EmailAddress
from app.models.account import Password
from app.utils.email import send_email
from app.settings import SECRET_KEY
from app.settings import FRONTEND_BASE_URL
from app.settings import EMAIL_TOKEN_EXPIRE_MINUTES
#### Accounts

# SECRET_KEY comes from app.settings; to generate one, run:
# openssl rand -hex 32
ALGORITHM = "HS256"

def mark_account_as_verified_and_active(db_session: Session, token: str):
    """Mark an account as verified and active."""
    account_id = get_id_from_token(token)
    if not account_id:
        raise NotImplementedError("Invalid token.")
    account_obj = db_session.query(Account).get(account_id)

    # Mark account as active.
    account_obj.is_active = True

    # Mark email as verified.
    email_obj = account_obj.primary_email_address
    email_obj.verified = True
    email_obj.verified_on = datetime.now()

    db_session.add(account_obj)
    db_session.add(email_obj)
    db_session.commit()


def get_account(db_session: Session, id: int):
    return db_session.query(Account).filter(Account.id == id).first()


def get_account_by_email(db_session: Session, email: str):
    email_obj = (
        db_session.query(EmailAddress).filter(EmailAddress.email == email).first()
    )
    if not email_obj:
        return None
    return email_obj.account


def get_accounts(db_session: Session, skip: int = 0, limit: int = 100):
    return db_session.query(Account).offset(skip).limit(limit).all()


def create_account(
    db_session: Session,
    first_name: str,
    last_name: str,
    email: str,
    password: str,
    is_system_admin: bool = False,
    is_active: bool = False,
    send_registration_email: bool = True,
    is_verified: bool = False,
):
    """Create a user account."""
    account_obj = Account(
        first_name=first_name,
        last_name=last_name,
        is_system_admin=is_system_admin,
        is_active=is_active,
    )
    db_session.add(account_obj)
    db_session.flush()

    email_obj = EmailAddress(
        account_id=account_obj.id, email=email, primary=True, verified=is_verified
    )

    password_obj = Password(account_id=account_obj.id, password=password)

    db_session.add(email_obj)
    db_session.add(password_obj)
    db_session.commit()

    # Send registration email.
    if send_registration_email:
        token = create_token_from_id(email_obj.id)
        registration_link = "{}/{}/verify?token={}".format(
            FRONTEND_BASE_URL, email_obj.id, token
        )
        send_email(
            to_email=email,
            subject="Welcome!",
            body="""Welcome to the website.

Please use the following link to continue your registration.

{}
""".format(
                registration_link, registration_link
            ),
        )

    db_session.refresh(account_obj)

    return account_obj


#### Email Addresses


def get_email_addresses(
    db_session: Session, account_id: int = None, skip: int = 0, limit: int = 100
):
    return (
        db_session.query(EmailAddress)
        .filter(EmailAddress.account_id == account_id)
        .all()
        or []
    )


def create_email_address(
    db_session: Session,
    email: str,
    account_id: int,
    send_verification_email: bool = True,
):
    """Add an email address to a user's account."""
    email_obj = EmailAddress(
        account_id=account_id, email=email, primary=False, verified=False
    )

    db_session.add(email_obj)
    db_session.commit()
    db_session.refresh(email_obj)

    # Send verification email.
    if send_verification_email:
        token = create_token_from_id(email_obj.id)
        verification_link = "{}/{}/verify?token={}".format(
            FRONTEND_BASE_URL, email_obj.id, token
        )
        send_email(
            to_email=email,
            subject="Verify your email!",
            body="""Email Verification.

Please use the following link to verify your email address.

{}
""".format(
                verification_link, verification_link
            ),
        )

    return email_obj


def create_token_from_id(id):
    """
    Create a signed token that can be used to verify an email address.

    Expires after EMAIL_TOKEN_EXPIRE_MINUTES minutes.
    """
    to_encode = {
        "id": id,
    }
    expire = datetime.utcnow() + timedelta(minutes=EMAIL_TOKEN_EXPIRE_MINUTES)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt.decode("utf-8")


def get_id_from_token(token):
    """Get an id from a signed token."""
    try:
        token = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        # Signature has expired
        return False
    except Exception:
        return False

    return token.get("id")


def mark_email_as_verified(db_session: Session, email_id: int):
    """Mark an email address as verified."""
    email_obj = db_session.query(EmailAddress).get(email_id)

    # Mark email as verified.
    email_obj.verified = True
    email_obj.verified_on = datetime.now()

    db_session.add(email_obj)
    db_session.commit()


def verify_email_address(db_session: Session, token: str):
    email_id = get_id_from_token(token)
    if not email_id:
        raise NotImplementedError("Invalid token.")

    mark_email_as_verified(db_session, email_id)

--------------------------------------------------------------------------------
/app/controllers/auth.py:
--------------------------------------------------------------------------------
from datetime import datetime, timedelta

from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from starlette.status import HTTP_401_UNAUTHORIZED
import jwt

from app.dependencies import get_db
from app.schemas.auth import Token
from app.schemas.account import Account
from app.controllers.account import get_account
from app.controllers.account import get_account_by_email
from app.models.account import Password
from app.settings import SECRET_KEY
from app.settings import ACCESS_TOKEN_EXPIRE_MINUTES

# SECRET_KEY comes from app.settings; to generate one, run:
# openssl rand -hex 32
ALGORITHM = "HS256"

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token")


def authenticate_user(db_session: Session, username: str, plaintext_password: str):
    account_obj = get_account_by_email(db_session, email=username)

    # Must have an account.
    if not account_obj:
        return False

    # Use the most recently created password for this account.
    password_obj = (
        db_session.query(Password)
        .filter_by(account_id=account_obj.id)
        .order_by(Password.created_on.desc())
        .first()
    )
    if not password_obj or not password_obj.is_correct_password(plaintext_password):
        return False
    return account_obj


def create_access_token(*, data: dict, expires_delta: timedelta = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

--------------------------------------------------------------------------------
/app/controllers/billing.py:
--------------------------------------------------------------------------------
import stripe

from app.settings import STRIPE_API_SECRET
from app.settings import STRIPE_DEFAULT_PLAN_ID

stripe.api_key = STRIPE_API_SECRET

def create_customer(email, full_name, tenant_obj):
    """Create a customer object in Stripe."""

    customer_resp = stripe.Customer.create(
        email=email,
        description="Customer for {}".format(email),
        name=full_name,
        metadata={"tenant_id": tenant_obj.id},
    )

    # TODO: Record the customer_id from stripe.
    return customer_resp


def create_subscription(customer_stripe_id, plan_stripe_ids=None):
    """Subscribe a customer to Stripe plans, falling back to the default plan."""
    if not plan_stripe_ids:
        plan_stripe_ids = [STRIPE_DEFAULT_PLAN_ID]

    return stripe.Subscription.create(
        customer=customer_stripe_id,
        items=[{"plan": plan_id} for plan_id in plan_stripe_ids],
    )


def cancel_subscription(subscription_stripe_id):
    """Cancel someone's subscription."""
    stripe.Subscription.delete(subscription_stripe_id)


def get_product_plans(product_stripe_id):
    pass

--------------------------------------------------------------------------------
/app/controllers/members.py:
--------------------------------------------------------------------------------
"""
Control the members that are a part of a Tenant.
"""
from sqlalchemy.orm import Session
from fastapi import Depends

from app.schemas import tenant as schemas
from app.models.tenant import Tenant, TenantAccount
from app.models.account import EmailAddress
from app.controllers.billing import stripe
from app.controllers.account import create_account
from app.controllers.account import get_account_by_email
from app.utils.email import send_email
from app.settings import FRONTEND_BASE_URL
from app.settings import LOGIN_URL_PATH


def get_members(db_session: Session, tenant_id: int, skip: int = 0, limit: int = 100):
    return (
        db_session.query(TenantAccount)
        .filter(TenantAccount.tenant_id == tenant_id)
        .offset(skip)
        .limit(limit)
        .all()
    )


def get_member_by_email(db_session: Session, tenant_id: int, email: str):
    """Return the member if they are a part of this tenant."""
    account_obj = get_account_by_email(db_session, email)
    # Check if email address exists.
    return (
        db_session.query(TenantAccount)
        .filter(
            TenantAccount.tenant_id == tenant_id, TenantAccount.account == account_obj
        )
        .first()
    )


def add_member(
    db_session: Session, tenant_id: int, email: str, do_send_email: bool = True
):
    """Add a new member to the tenant."""
    tenant_obj = db_session.query(Tenant).get(tenant_id)
    # Look up an existing email address.
    email_obj = (
        db_session.query(EmailAddress).filter(EmailAddress.email == email).first()
    )
    if email_obj and email_obj.account:
        # Account already exists, go ahead and add them.
        tenant_account_obj = TenantAccount()
        tenant_account_obj.tenant_id = tenant_id
        tenant_account_obj.account_id = email_obj.account_id
        db_session.add(tenant_account_obj)
        db_session.commit()

        # Send an email telling them we added them.
        if do_send_email:
            send_email(
                to_email=email,
                subject=f"You've been added to {tenant_obj.name}",
                body=(
                    f"Welcome to {tenant_obj.name}."
                    "\n\nYou have been invited into the new group. Please use the link below to log in."
                    "\n\nLogin"
                ),
            )
    else:
        # Has never been a part of this site.
        # Create relationship
        tenant_account_obj = TenantAccount()
        tenant_account_obj.tenant_id = tenant_id
        tenant_account_obj.email_address = email
        db_session.add(tenant_account_obj)
        db_session.commit()
        # Send registration invite.

--------------------------------------------------------------------------------
/app/controllers/tenant.py:
--------------------------------------------------------------------------------
from sqlalchemy.orm import Session
from fastapi import Depends

from app.schemas import tenant as schemas
from app.models.tenant import Tenant, TenantAccount

from app.controllers.billing import stripe
from app.controllers.account import create_account


def get_tenant(db_session: Session, id: int):
    return db_session.query(Tenant).filter(Tenant.id == id).first()


def get_tenants(db_session: Session, skip: int = 0, limit: int = 100):
    return db_session.query(Tenant).offset(skip).limit(limit).all()


def add_account_to_tenant(db_session: Session, account_id, tenant_id):
    """Create relationship between tenant and account."""

    tenant_account_obj = TenantAccount(tenant_id=tenant_id, account_id=account_id)

    db_session.add(tenant_account_obj)
    db_session.commit()
    return tenant_account_obj


def create_tenant_and_account(
    db_session: Session,
    name: str,
    slug: str,
    first_name: str,
    last_name: str,
    email: str,
    password: str,
    is_active: bool = False,
    is_verified: bool = False,
    do_send_emails: bool = True
):
    """Create a tenant and an account."""

    tenant_obj = Tenant(name=name, slug=slug, billing_email=email)
    db_session.add(tenant_obj)
    db_session.flush()

    # New tenant = New Customer in stripe.
    customer_resp = stripe.Customer.create(
        email=email,
        description="Customer for {}<{}>".format(name, email),
        name=name,
        metadata={"tenant_id": tenant_obj.id},
    )

    # Record the Customer ID from stripe.
    tenant_obj.stripe_customer_id = customer_resp.get("id")

    db_session.commit()

    # Create account
    account_obj = create_account(
        db_session,
        first_name,
        last_name,
        email,
        password,
        is_active=is_active,
        is_verified=is_verified,
        send_registration_email=do_send_emails
    )

    # Add relationship between account and tenant.
    add_account_to_tenant(db_session, account_obj.id, tenant_obj.id)

    db_session.refresh(tenant_obj)

    return tenant_obj


def get_tenant_by_name(db_session: Session, name: str):
    """Get a tenant by name."""
    return db_session.query(Tenant).filter(Tenant.name == name).first()

--------------------------------------------------------------------------------
/app/database.py:
--------------------------------------------------------------------------------
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

# SQLALCHEMY_DATABASE_URL = "sqlite://"
SQLALCHEMY_DATABASE_URL = "postgresql://postgres:password@db/db"


engine = create_engine(SQLALCHEMY_DATABASE_URL)

DBSession = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()
Base.metadata.bind = engine

--------------------------------------------------------------------------------
/app/dependencies/__init__.py:
--------------------------------------------------------------------------------
from app.database import DBSession

# Dependency
def get_db():
    # Create the session before `try` so `finally` never sees an unbound name.
    db_session = DBSession()
    try:
        yield db_session
    finally:
        db_session.close()

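FastAPI runs a `yield` dependency such as `get_db` as a context manager: the code before `yield` runs when a request needs the session, and the `finally` block runs after the request has been handled. The sketch below reproduces that lifecycle with `contextlib.contextmanager`. It assumes a hypothetical `FakeSession` class in place of `DBSession` so it runs without a database, and it illustrates the pattern rather than FastAPI's internal code.

```python
from contextlib import contextmanager


class FakeSession:
    """Hypothetical stand-in for the SQLAlchemy session that DBSession() returns."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


def get_db():
    # Same shape as app.dependencies.get_db, with FakeSession in place of DBSession.
    db_session = FakeSession()
    try:
        yield db_session
    finally:
        db_session.close()


# Wrapping the generator in a context manager mimics the request lifecycle.
with contextmanager(get_db)() as session:
    print("open while handling the request:", not session.closed)

print("closed afterwards:", session.closed)

# The finally block also runs when the handler raises.
try:
    with contextmanager(get_db)() as failing_session:
        raise RuntimeError("handler failed")
except RuntimeError:
    pass
print("closed after an error:", failing_session.closed)
```

The sketch creates the session before `try`. If `DBSession()` raised inside the `try`, the `finally` block would reference an unbound `db_session`, and the resulting `UnboundLocalError` would mask the original error.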
--------------------------------------------------------------------------------
/app/dependencies/auth.py:
--------------------------------------------------------------------------------
from fastapi import Depends, FastAPI, HTTPException

from sqlalchemy.orm import Session
from starlette.status import HTTP_401_UNAUTHORIZED
import jwt
from jwt import PyJWTError

from app.dependencies import get_db
from app.controllers.auth import oauth2_scheme, SECRET_KEY, ALGORITHM, get_account
from app.schemas.auth import TokenData


async def get_current_account(
    db_session: Session = Depends(get_db), token: str = Depends(oauth2_scheme)
):
    credentials_exception = HTTPException(
        status_code=HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        account_id: str = payload.get("sub")
        if account_id is None:
            raise credentials_exception
        token_data = TokenData(account_id=account_id)
    except PyJWTError:
        raise credentials_exception
    account = get_account(db_session, id=token_data.account_id)
    if account is None:
        raise credentials_exception
    return account


# async def get_current_active_account(current_user: Account = Depends(get_current_account)):
#     if current_user.disabled:
#         raise HTTPException(status_code=400, detail="Inactive user")
#     return current_user

--------------------------------------------------------------------------------
/app/dependencies/tenant.py:
--------------------------------------------------------------------------------
from fastapi import Depends, FastAPI, HTTPException

from sqlalchemy.orm import Session

from app.dependencies import get_db
from app.models.tenant import Tenant


async def get_tenant(tenant_id: int = None, db_session: Session = Depends(get_db)):
    """Return the tenant whose id is given in the URL, or None if absent."""
    if not tenant_id:
        return None
    return db_session.query(Tenant).get(tenant_id)

--------------------------------------------------------------------------------
/app/main.py:
--------------------------------------------------------------------------------
from typing import List

from fastapi import Depends, FastAPI, Header, HTTPException
import uvicorn
from sqlalchemy.orm import Session

from app.dependencies import get_db
from app.dependencies.auth import get_current_account
from app.database import engine, Base

from app.routers import accounts
from app.routers import auth
from app.routers import email_addresses
from app.routers import members
from app.routers import tenants

# Create tables in database.
Base.metadata.create_all(bind=engine)

# Initialize FastAPI
app = FastAPI(
    title="FastAPI Base",
    description="This is a base app to be used in the future for real SAAS apps and hackathons.",
    version="0.0.1",
    docs_url="/docs",
    redoc_url=None,
)

# Startup Actions
@app.on_event("startup")
async def create_admin():
    """If admin account doesn't exist, create it."""
    from app.database import DBSession
    from app.controllers.account import create_account
    from app.controllers.account import get_account_by_email

    db_session = DBSession()
    account_data = {
        "email": "admin@example.com",
        "password": "password123",
        "first_name": "Admin",
        "last_name": "Istrator",
        "is_system_admin": True,
        "is_active": True,
    }
    try:
        account_obj = get_account_by_email(db_session, email=account_data["email"])
        if account_obj:
            return

        create_account(db_session, **account_data)
    finally:
        # Close the session on both the early-return and the create path.
        db_session.close()


# Add routers
app.include_router(
    auth.router,
    prefix="/auth",
    tags=["auth"],
    responses={404: {"description": "Not found"}},
)
app.include_router(
    accounts.router,
    prefix="/accounts",
    tags=["accounts"],
    dependencies=[Depends(get_current_account)],
    responses={404: {"description": "Not found"}},
)
app.include_router(
    email_addresses.router,
    prefix="/email_addresses",
    tags=["email_addresses"],
    responses={404: {"description": "Not found"}},
)
app.include_router(
    members.router,
    prefix="/members",
    tags=["members"],
    dependencies=[Depends(get_current_account)],
    responses={404: {"description": "Not found"}},
)
app.include_router(
    tenants.router,
    prefix="/tenants",
    tags=["tenants"],
    dependencies=[Depends(get_current_account)],
    responses={404: {"description": "Not found"}},
)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

--------------------------------------------------------------------------------
/app/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ryanb58/fastapi-saas-base/c985834dd3c5d84abd208bd7dae42a73b8fc8f1e/app/models/__init__.py
--------------------------------------------------------------------------------
/app/models/account.py:
--------------------------------------------------------------------------------
"""Account based models."""
from sqlalchemy import (
    Boolean,
    Column,
    ForeignKey,
    Integer,
    String,
    DateTime,
    LargeBinary,
)
from sqlalchemy.orm import relationship, backref
from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
from passlib.context import CryptContext

from app.models.base import BaseModel


pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


class Account(BaseModel):
    __tablename__ = "accounts"

    id = Column(Integer, primary_key=True, index=True)

    first_name = Column(String(100))
    last_name = Column(String(100))

    is_system_admin = Column(Boolean, default=False)
    is_active = Column(Boolean, default=False)

    def __repr__(self):
        return "".format(self.id, self.first_name, self.last_name)

    @property
    def email(self):
        return self.email_addresses.filter_by(primary=True).first().email

    @property
    def primary_email_address(self):
        return self.email_addresses.filter_by(primary=True).first()

    @property
    def full_name(self):
        return "{} {}".format(self.first_name, self.last_name)


class EmailAddress(BaseModel):
    __tablename__ = "email_addresses"

    id = Column(Integer, primary_key=True)
    # uuid = Column(
    #     UUID(as_uuid=True), unique=True, nullable=False, default=uuid.uuid4
    # )

    account_id = Column(
        Integer, ForeignKey("accounts.id", ondelete="CASCADE"), nullable=True
    )
    account = relationship(
        "Account",
        backref=backref("email_addresses", passive_deletes=True, lazy="dynamic"),
    )

    email = Column(String(256), unique=True, nullable=False)
    primary = Column(Boolean(), nullable=True)
    verified = Column(Boolean(), nullable=True)
    verified_on = Column(DateTime, nullable=True)

    def __repr__(self):
        return "".format(self.email)


class Password(BaseModel):
    __tablename__ = "passwords"

    id = Column(Integer, primary_key=True)

    account_id = Column(
        Integer, ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False
    )
    account = relationship(
        "Account", backref=backref("passwords", passive_deletes=True), lazy=True
    )

    # _password = Column(LargeBinary(256), nullable=False)
    _password = Column(String(256), nullable=False)

    def __repr__(self):
        return "".format(self.account_id)

    def validate_password(self, plaintext_password):
        """If invalid, raise an Exception; otherwise return True."""

        if len(plaintext_password) < 8:
            raise Exception("Password must be 8 or more characters long.")

        return True

    @hybrid_property
    def password(self):
        return self._password

    @password.setter
    def password(self, plain_password):
        self.validate_password(plain_password)
        self._password = pwd_context.hash(plain_password)

    @hybrid_method
    def is_correct_password(self, plain_password):
        return pwd_context.verify(plain_password, self.password)


# class PasswordReset(BaseModel):
#     id = Column(Integer, primary_key=True)

#     account_id = Column(
#         Integer, ForeignKey("account.id", ondelete="CASCADE"), nullable=False
#     )
#     account = relationship(
#         "Account",
#         backref=backref("password_resets", passive_deletes=True),
#         lazy=True,
#     )

#     token = Column(String(1024), nullable=False)

#     def __repr__(self):
#         return "".format(self.account_id)

#     @hybrid_method
#     def is_valid(self):
#         from flask_jwt_extended import decode_token

#         try:
#             decode_token(self.token)
#             return True
#         except (jwt.DecodeError, jwt.ExpiredSignatureError):
#             return False

--------------------------------------------------------------------------------
/app/models/base.py:
--------------------------------------------------------------------------------
import datetime

from sqlalchemy import Column, DateTime

from app.database import Base


class BaseModel(Base):
    """Base data model for all objects"""

    __abstract__ = True

    created_on = Column(DateTime, nullable=False, default=datetime.datetime.utcnow)
    updated_on = Column(DateTime, nullable=True, onupdate=datetime.datetime.utcnow)

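The `Password` model in app/models/account.py above enforces an 8-character minimum, stores only a salted hash (bcrypt through passlib's `CryptContext`), and checks logins with `pwd_context.verify`. What follows is a stdlib-only sketch of the same hash-then-verify flow. It swaps bcrypt for `hashlib.pbkdf2_hmac` so it runs without passlib. The function names, the 16-byte salt, the iteration count, and the `salt$digest` storage format are all assumptions made for illustration.

```python
import hashlib
import hmac
import os

ITERATIONS = 100_000  # hypothetical PBKDF2 work factor


def validate_password(plaintext: str) -> bool:
    # Mirrors the 8-character minimum in Password.validate_password.
    if len(plaintext) < 8:
        raise ValueError("Password must be 8 or more characters long.")
    return True


def hash_password(plaintext: str) -> str:
    validate_password(plaintext)
    salt = os.urandom(16)  # fresh random salt for every stored password
    digest = hashlib.pbkdf2_hmac("sha256", plaintext.encode(), salt, ITERATIONS)
    # Keep salt and digest in one string, like the single _password column.
    return salt.hex() + "$" + digest.hex()


def is_correct_password(plaintext: str, stored: str) -> bool:
    salt_hex, digest_hex = stored.split("$")
    candidate = hashlib.pbkdf2_hmac(
        "sha256", plaintext.encode(), bytes.fromhex(salt_hex), ITERATIONS
    )
    # Constant-time comparison avoids leaking how many bytes matched.
    return hmac.compare_digest(candidate, bytes.fromhex(digest_hex))


stored = hash_password("password123")
print(is_correct_password("password123", stored))  # True
print(is_correct_password("wrong-password", stored))  # False
```

In the real model, passlib's bcrypt output already embeds its salt and cost factor in the stored string. That is why a single `String(256)` column is enough for `_password`.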
-------------------------------------------------------------------------------- /app/models/tenant.py: -------------------------------------------------------------------------------- 1 | """Tenant models.""" 2 | from sqlalchemy import ( 3 | Boolean, 4 | Column, 5 | ForeignKey, 6 | Integer, 7 | String, 8 | DateTime, 9 | LargeBinary, 10 | ) 11 | from sqlalchemy.orm import relationship, backref 12 | from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method 13 | 14 | from app.models.base import BaseModel 15 | 16 | 17 | class Tenant(BaseModel): 18 | """A customer/workspace in the system.""" 19 | 20 | __tablename__ = "tenants" 21 | 22 | id = Column(Integer, primary_key=True, index=True) 23 | name = Column(String(128)) 24 | slug = Column(String(128), unique=True) # Slug of the name for the URL. 25 | 26 | billing_email = Column(String(256), unique=True, nullable=True) 27 | stripe_customer_id = Column(String(128), unique=True, nullable=True) 28 | stripe_subscription_id = Column(String(128), unique=True, nullable=True) 29 | 30 | def __repr__(self): 31 | return "".format(self.id, self.name) 32 | 33 | 34 | class TenantAccount(BaseModel): 35 | """ 36 | Through table for M2M relationship between a Tenant and Accounts. 37 | """ 38 | 39 | __tablename__ = "tenant_account" 40 | id = Column(Integer, primary_key=True, index=True) 41 | 42 | tenant_id = Column( 43 | Integer, ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False 44 | ) 45 | tenant = relationship( 46 | "Tenant", 47 | backref=backref("accounts", passive_deletes=True, lazy="dynamic"), 48 | lazy=True, 49 | ) 50 | 51 | account_id = Column( 52 | Integer, ForeignKey("accounts.id", ondelete="CASCADE"), nullable=True 53 | ) 54 | account = relationship( 55 | "Account", 56 | backref=backref("tenants", passive_deletes=True, lazy="dynamic"), 57 | lazy=True, 58 | ) 59 | 60 | # Use this field in case the user doesn't have an account yet. 
61 | email_address = Column(String, nullable=True) 62 | 63 | def __repr__(self): 64 | return "".format( 65 | self.id, self.tenant_id, self.account_id 66 | ) 67 | -------------------------------------------------------------------------------- /app/routers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ryanb58/fastapi-saas-base/c985834dd3c5d84abd208bd7dae42a73b8fc8f1e/app/routers/__init__.py -------------------------------------------------------------------------------- /app/routers/accounts.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from fastapi import APIRouter, Depends, HTTPException 4 | from sqlalchemy.orm import Session 5 | 6 | from app.dependencies import get_db 7 | from app.schemas import account as schemas 8 | from app.controllers.account import get_account_by_email 9 | from app.controllers.account import get_account 10 | from app.controllers.account import get_accounts 11 | from app.controllers.account import create_account 12 | from app.dependencies.auth import get_current_account 13 | 14 | router = APIRouter() 15 | 16 | 17 | @router.post("/", response_model=schemas.Account) 18 | def create_one( 19 | account: schemas.AccountCreate, 20 | db_session: Session = Depends(get_db), 21 | current_user: schemas.Account = Depends(get_current_account), 22 | ): 23 | db_account = get_account_by_email(db_session, email=account.email) 24 | if db_account: 25 | raise HTTPException(status_code=400, detail="Email already registered") 26 | 27 | # Can not create a system admin if you yourself are not one. 
28 | if not current_user.is_system_admin: 29 | account.is_system_admin = False 30 | 31 | return create_account( 32 | db_session, 33 | account.first_name, 34 | account.last_name, 35 | account.email, 36 | account.password, 37 | account.is_system_admin, 38 | ) 39 | 40 | 41 | @router.get("/", response_model=List[schemas.Account]) 42 | def read_many(skip: int = 0, limit: int = 100, db_session: Session = Depends(get_db)): 43 | accounts = get_accounts(db_session, skip=skip, limit=limit) 44 | return accounts 45 | 46 | 47 | @router.get("/me", response_model=schemas.Account) 48 | async def read_me(current_user: schemas.Account = Depends(get_current_account)): 49 | """Get logged in user details.""" 50 | return current_user 51 | 52 | 53 | @router.get("/{id}", response_model=schemas.Account) 54 | def read_one(id: int, db_session: Session = Depends(get_db)): 55 | db_account = get_account(db_session, id=id) 56 | if db_account is None: 57 | raise HTTPException(status_code=404, detail="User not found") 58 | return db_account 59 | -------------------------------------------------------------------------------- /app/routers/auth.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | 3 | from fastapi import APIRouter, Depends, HTTPException 4 | from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm 5 | from sqlalchemy.orm import Session 6 | from starlette.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED, HTTP_201_CREATED, HTTP_409_CONFLICT 7 | 8 | from app.dependencies import get_db 9 | from app.controllers.auth import ( 10 | ACCESS_TOKEN_EXPIRE_MINUTES, 11 | Token, 12 | authenticate_user, 13 | create_access_token, 14 | ) 15 | from app.dependencies.auth import get_current_account 16 | from app.schemas.tenant import TenantAccountCreate 17 | from app.controllers.tenant import create_tenant_and_account 18 | from app.controllers.account import mark_account_as_verified_and_active 19 | 20 | 
router = APIRouter() 21 | 22 | 23 | @router.post("/verify", status_code=HTTP_200_OK) 24 | def verify_account(token: str = "", db_session: Session = Depends(get_db)): 25 | mark_account_as_verified_and_active(db_session, token) 26 | return 27 | 28 | 29 | @router.post("/token", response_model=Token) 30 | async def login_for_access_token( 31 | form_data: OAuth2PasswordRequestForm = Depends(), 32 | db_session: Session = Depends(get_db), 33 | ): 34 | account_obj = authenticate_user(db_session, form_data.username, form_data.password) 35 | if not account_obj: 36 | raise HTTPException( 37 | status_code=HTTP_401_UNAUTHORIZED, 38 | detail="Incorrect username or password", 39 | headers={"WWW-Authenticate": "Bearer"}, 40 | ) 41 | 42 | # Account must be active. 43 | if not account_obj.is_active: 44 | raise HTTPException( 45 | status_code=HTTP_409_CONFLICT, detail="Account disabled", 46 | ) 47 | 48 | access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) 49 | access_token = create_access_token( 50 | data={"sub": account_obj.id}, expires_delta=access_token_expires 51 | ) 52 | return {"access_token": access_token, "token_type": "bearer"} 53 | 54 | 55 | @router.post("/register", status_code=HTTP_201_CREATED) 56 | def register( 57 | tenant_account: TenantAccountCreate, db_session: Session = Depends(get_db) 58 | ): 59 | tenant_obj = create_tenant_and_account( 60 | db_session, 61 | tenant_account.name, 62 | tenant_account.slug, 63 | tenant_account.first_name, 64 | tenant_account.last_name, 65 | tenant_account.email, 66 | tenant_account.password, 67 | ) 68 | if not tenant_obj: 69 | raise HTTPException( 70 | status_code=HTTP_401_UNAUTHORIZED, 71 | detail="Error creating new account/tenant.", 72 | ) 73 | return {"msg": "Please check your email."} 74 | -------------------------------------------------------------------------------- /app/routers/email_addresses.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 
3 | from fastapi import APIRouter, Depends, HTTPException 4 | from sqlalchemy.orm import Session 5 | from starlette.status import HTTP_201_CREATED, HTTP_200_OK 6 | 7 | from app.dependencies import get_db 8 | from app.schemas import account as schemas 9 | from app.controllers.account import ( 10 | get_email_addresses, 11 | create_email_address, 12 | get_account_by_email, 13 | verify_email_address, 14 | ) 15 | from app.dependencies.auth import get_current_account 16 | 17 | router = APIRouter() 18 | 19 | 20 | @router.post("/verify", status_code=HTTP_200_OK) 21 | def verify_email(token: str = "", db_session: Session = Depends(get_db)): 22 | verify_email_address(db_session, token) 23 | return 24 | 25 | 26 | @router.post("", response_model=schemas.EmailAddress, status_code=HTTP_201_CREATED) 27 | def create_one( 28 | email: schemas.EmailAddressCreate, 29 | db_session: Session = Depends(get_db), 30 | current_user: schemas.Account = Depends(get_current_account), 31 | ): 32 | # Can not add an email address that is already in use. 33 | db_email = get_account_by_email(db_session, email=email.email) 34 | if db_email: 35 | raise HTTPException(status_code=400, detail="Email already registered") 36 | 37 | # Set the account_id from the current logged in user. 
38 | account_id = current_user.id 39 | 40 | return create_email_address(db_session, email.email, account_id) 41 | 42 | 43 | @router.get("", response_model=List[schemas.EmailAddress], status_code=HTTP_200_OK) 44 | def read_many( 45 | skip: int = 0, 46 | limit: int = 100, 47 | db_session: Session = Depends(get_db), 48 | current_user: schemas.Account = Depends(get_current_account), 49 | ): 50 | return get_email_addresses( 51 | db_session, account_id=current_user.id, skip=skip, limit=limit 52 | ) 53 | -------------------------------------------------------------------------------- /app/routers/members.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from fastapi import APIRouter, Depends, HTTPException 4 | from sqlalchemy.orm import Session 5 | from starlette.status import HTTP_201_CREATED, HTTP_200_OK 6 | 7 | from app.dependencies import get_db 8 | from app.schemas import member as schemas 9 | from app.schemas.account import Account as AccountSchema 10 | from app.controllers.members import get_members 11 | from app.controllers.members import get_member_by_email 12 | from app.controllers.members import add_member 13 | from app.dependencies.auth import get_current_account 14 | 15 | router = APIRouter() 16 | 17 | 18 | @router.get("", response_model=List[schemas.TenantAccount], status_code=HTTP_200_OK) 19 | def read_many( 20 | skip: int = 0, 21 | limit: int = 100, 22 | tenant_id: int = None, 23 | db_session: Session = Depends(get_db), 24 | ): 25 | return get_members(db_session, tenant_id=tenant_id, skip=skip, limit=limit) 26 | 27 | 28 | @router.post("", response_model=schemas.MemberCreate, status_code=HTTP_201_CREATED) 29 | def add_one( 30 | member: schemas.MemberCreate, 31 | db_session: Session = Depends(get_db), 32 | current_user: AccountSchema = Depends(get_current_account), 33 | ): 34 | # Make sure email isn't already added. 
35 | db_member = get_member_by_email( 36 | db_session, tenant_id=member.tenant_id, email=member.email 37 | ) 38 | if db_member: 39 | raise HTTPException(status_code=400, detail="Account has already been invited.") 40 | 41 | # Invite the user. 42 | return add_member(db_session, tenant_id=member.tenant_id, email=member.email) 43 | -------------------------------------------------------------------------------- /app/routers/tenants.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from fastapi import APIRouter, Depends, HTTPException 4 | from sqlalchemy.orm import Session 5 | 6 | from app.dependencies import get_db 7 | import app.schemas as schemas 8 | from app.controllers.tenant import get_tenant_by_name 9 | from app.controllers.tenant import get_tenant 10 | from app.controllers.tenant import get_tenants 11 | 12 | # from app.controllers.tenant import create_tenant 13 | from app.dependencies.auth import get_current_account 14 | 15 | router = APIRouter() 16 | 17 | # @router.post("/", response_model=schemas.Tenant) 18 | # def create_one(tenant: schemas.TenantCreate, db_session: Session = Depends(get_db), current_user: schemas.Account = Depends(get_current_account)): 19 | # db_tenant = get_tenant_by_name(db_session, name=tenant.name) 20 | # if db_tenant: 21 | # raise HTTPException(status_code=400, detail="Name already registered") 22 | 23 | # return create_tenant(db_session, tenant=tenant) 24 | 25 | 26 | @router.get("/", response_model=List[schemas.Tenant]) 27 | def read_many(skip: int = 0, limit: int = 100, db_session: Session = Depends(get_db)): 28 | tenants = get_tenants(db_session, skip=skip, limit=limit) 29 | return tenants 30 | 31 | 32 | @router.get("/{id}", response_model=schemas.TenantDetails) 33 | def read_one(id: int, db_session: Session = Depends(get_db)): 34 | db_tenant = get_tenant(db_session, id=id) 35 | if db_tenant is None: 36 | raise HTTPException(status_code=404, detail="Tenant not 
found") 37 | return db_tenant 38 | -------------------------------------------------------------------------------- /app/schemas/__init__.py: -------------------------------------------------------------------------------- 1 | from .account import Account, AccountCreate 2 | from .auth import Token, TokenData 3 | from .tenant import Tenant, TenantCreate, TenantDetails 4 | -------------------------------------------------------------------------------- /app/schemas/account.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List 3 | 4 | from pydantic import BaseModel 5 | 6 | 7 | class AccountBase(BaseModel): 8 | email: str 9 | first_name: str 10 | last_name: str 11 | 12 | 13 | class AccountCreate(AccountBase): 14 | password: str 15 | is_system_admin: bool 16 | 17 | 18 | class Account(AccountBase): 19 | id: int 20 | is_system_admin: bool 21 | 22 | class Config: 23 | orm_mode = True 24 | 25 | 26 | class EmailAddressBase(BaseModel): 27 | email: str 28 | 29 | 30 | class EmailAddressCreate(EmailAddressBase): 31 | account_id: int = None 32 | 33 | 34 | class EmailAddress(EmailAddressBase): 35 | id: int 36 | primary: bool 37 | verified: bool 38 | verified_on: datetime = None 39 | 40 | class Config: 41 | orm_mode = True 42 | -------------------------------------------------------------------------------- /app/schemas/auth.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class Token(BaseModel): 5 | access_token: str 6 | token_type: str 7 | 8 | 9 | class TokenData(BaseModel): 10 | account_id: int = None 11 | -------------------------------------------------------------------------------- /app/schemas/member.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List 3 | 4 | from pydantic import BaseModel 5 | 6 | 7 | class 
TenantAccount(BaseModel): 8 | id: int 9 | account_id: int 10 | tenant_id: int 11 | 12 | class Config: 13 | orm_mode = True 14 | 15 | 16 | class MemberCreate(BaseModel): 17 | tenant_id: int 18 | email: str 19 | -------------------------------------------------------------------------------- /app/schemas/tenant.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import List 3 | 4 | from pydantic import BaseModel 5 | 6 | 7 | class TenantBase(BaseModel): 8 | name: str 9 | 10 | 11 | class TenantCreate(TenantBase): 12 | billing_email: str 13 | 14 | 15 | class Tenant(TenantBase): 16 | id: int 17 | slug: str 18 | 19 | class Config: 20 | orm_mode = True 21 | 22 | 23 | class TenantDetails(Tenant): 24 | billing_email: str 25 | 26 | 27 | class TenantAccountCreate(BaseModel): 28 | name: str 29 | slug: str 30 | first_name: str 31 | last_name: str 32 | email: str 33 | password: str 34 | -------------------------------------------------------------------------------- /app/settings.py: -------------------------------------------------------------------------------- 1 | from starlette.config import Config 2 | from starlette.datastructures import URL, Secret 3 | 4 | 5 | config = Config(".env") 6 | 7 | TESTING = config('TESTING', cast=bool, default=False) 8 | 9 | SECRET_KEY = config("SECRET_KEY", cast=str, default=False) 10 | 11 | FRONTEND_BASE_URL = config("FRONTEND_BASE_URL", cast=str, default="") 12 | LOGIN_URL_PATH = config("LOGIN_URL_PATH", cast=str, default="/login") 13 | 14 | EMAIL_TOKEN_EXPIRE_MINUTES = config("EMAIL_TOKEN_EXPIRE_MINUTES", cast=int, default=30) 15 | ACCESS_TOKEN_EXPIRE_MINUTES = config("ACCESS_TOKEN_EXPIRE_MINUTES", cast=int, default=30) 16 | 17 | SMTP_HOST = config("SMTP_HOST", cast=str, default=False) 18 | SMTP_PORT = config("SMTP_PORT", cast=str, default=False) 19 | SMTP_USERNAME = config("SMTP_USERNAME", cast=str, default=False) 20 | SMTP_PASSWORD = config("SMTP_PASSWORD", 
cast=str, default=False) 21 | 22 | STRIPE_API_SECRET = config("STRIPE_API_SECRET", cast=str, default=False) 23 | STRIPE_DEFAULT_PLAN_ID = config("STRIPE_DEFAULT_PLAN_ID", cast=str, default=False) -------------------------------------------------------------------------------- /app/templates/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ryanb58/fastapi-saas-base/c985834dd3c5d84abd208bd7dae42a73b8fc8f1e/app/templates/.DS_Store -------------------------------------------------------------------------------- /app/templates/email/empty.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Simple Transactional Email 7 | 96 | 97 | 98 | 99 | 100 | 101 | 145 | 146 | 147 |
102-144 | [layout table markup lost in extraction; it wraps the {{ body }} placeholder]
148 | 149 | -------------------------------------------------------------------------------- /app/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ryanb58/fastapi-saas-base/c985834dd3c5d84abd208bd7dae42a73b8fc8f1e/app/tests/__init__.py -------------------------------------------------------------------------------- /app/tests/conftests.py: -------------------------------------------------------------------------------- 1 | # From @euri10 -- https://gitter.im/tiangolo/fastapi?at=5cd915ed56271260f95275ac 2 | 3 | import asyncio 4 | from unittest import TestCase 5 | 6 | import pytest 7 | from sqlalchemy import create_engine 8 | from sqlalchemy_utils import create_database, database_exists, drop_database 9 | 10 | from starlette.config import environ 11 | from starlette.testclient import TestClient 12 | 13 | # This sets `os.environ`, but provides some additional protection. 14 | # If we placed it below the application import, it would raise an error 15 | # informing us that 'TESTING' had already been read from the environment. 16 | 17 | environ["TESTING"] = "True" 18 | environ["EMAILS_ENABLED"] = "False" 19 | 20 | from app.main import app # isort:skip 21 | from app.database import engine, Base, DBSession 22 | 23 | 24 | class TestBase(TestCase): 25 | def setUp(self): 26 | self.db_session = DBSession() 27 | self.connection = engine.connect() 28 | 29 | # # Configure Search DDL triggers. 
30 | Base.metadata.drop_all(self.connection) 31 | Base.metadata.create_all(self.connection) 32 | 33 | self.client = TestClient(app) 34 | 35 | def tearDown(self): 36 | self.db_session.rollback() 37 | self.db_session.close() 38 | 39 | def create_system_admin(self, *args, **kwargs): 40 | from app.controllers.account import create_account 41 | from app.schemas.account import AccountCreate 42 | 43 | return create_account( 44 | self.db_session, 45 | first_name="Admin", 46 | last_name="Istrator", 47 | email="admin@example.com", 48 | password="password123", 49 | is_system_admin=True, 50 | is_active=True, 51 | send_registration_email=False, 52 | ) 53 | 54 | def auth_headers(self, email="admin@example.com", password="password123"): 55 | payload = {"username": email, "password": password} 56 | resp = self.client.post("/auth/token", data=payload) 57 | return {"Authorization": "Bearer " + resp.json().get("access_token")} 58 | -------------------------------------------------------------------------------- /app/tests/test_accounts.py: -------------------------------------------------------------------------------- 1 | from app.tests.conftests import TestBase 2 | 3 | 4 | class AccountsTestCase(TestBase): 5 | def test_unauth_list(self): 6 | resp = self.client.get("/accounts") 7 | assert resp.status_code == 401 8 | 9 | def test_list(self): 10 | self.create_system_admin() 11 | resp = self.client.get("/accounts", headers=self.auth_headers()) 12 | assert resp.status_code == 200 13 | -------------------------------------------------------------------------------- /app/tests/test_email_addresses.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | from app.tests.conftests import TestBase 4 | from app.models.account import Account 5 | from app.models.account import EmailAddress 6 | from app.controllers.account import create_email_address 7 | from app.controllers.account import create_token_from_id 8 | 9 | class 
EmailAddressTestCase(TestBase): 10 | @patch("app.controllers.account.send_email") 11 | def test_create(self, mock_send_email): 12 | self.create_system_admin() 13 | assert self.db_session.query(EmailAddress).count() == 1 14 | payload = {"email": "another@example.com"} 15 | response = self.client.post( 16 | "/email_addresses", json=payload, headers=self.auth_headers() 17 | ) 18 | assert response.status_code == 201 19 | assert self.db_session.query(EmailAddress).count() == 2 20 | # Adding a new email address should send a verification email. 21 | mock_send_email.assert_called_once() 22 | 23 | def test_get(self): 24 | self.create_system_admin() 25 | payload = {"email": "another@example.com"} 26 | response = self.client.get( 27 | "/email_addresses", json=payload, headers=self.auth_headers() 28 | ) 29 | assert response.status_code == 200 30 | assert len(response.json()) == 1 31 | 32 | def test_verify(self): 33 | admin_obj = self.create_system_admin() 34 | 35 | # Add new unverified email address to user. 
36 | email_obj = create_email_address( 37 | self.db_session, 38 | "michael.scott@gmail.com", 39 | admin_obj.id, 40 | send_verification_email=False) 41 | 42 | token = create_token_from_id(email_obj.id) 43 | response = self.client.post( 44 | "/email_addresses/verify?token={}".format(token) 45 | ) 46 | assert response.status_code == 200 47 | -------------------------------------------------------------------------------- /app/tests/test_login.py: -------------------------------------------------------------------------------- 1 | from app.tests.conftests import TestBase 2 | from app.models.account import Account 3 | 4 | 5 | class LoginTestCase(TestBase): 6 | def test_success(self): 7 | self.create_system_admin() 8 | payload = {"username": "admin@example.com", "password": "password123"} 9 | response = self.client.post("/auth/token", data=payload) 10 | assert response.status_code == 200 11 | assert response.json().get("access_token", False) 12 | assert response.json().get("token_type", False) == "bearer" 13 | 14 | def test_invalid(self): 15 | self.create_system_admin() 16 | payload = {"username": "admin@example.com", "password": "password"} 17 | response = self.client.post("/auth/token", data=payload) 18 | assert response.status_code == 401 19 | -------------------------------------------------------------------------------- /app/tests/test_members.py: -------------------------------------------------------------------------------- 1 | import responses 2 | from unittest.mock import patch 3 | 4 | from app.tests.conftests import TestBase 5 | from app.models.account import Account 6 | from app.models.account import EmailAddress 7 | from app.controllers.tenant import create_tenant_and_account 8 | from app.controllers.account import create_account 9 | 10 | 11 | class MembersTestCase(TestBase): 12 | @patch("app.controllers.account.send_email") 13 | @responses.activate 14 | def test_get(self, mock_send_email): 15 | # Mock out request to stripe. 
16 | responses.add( 17 | responses.POST, 18 | "https://api.stripe.com/v1/customers", 19 | json={"id": 1}, 20 | status=200, 21 | ) 22 | 23 | tenant_account = { 24 | "first_name": "Andy", 25 | "last_name": "Benard", 26 | "email": "andy.bernard@example.com", 27 | "password": "password123", 28 | "name": "Dunder Mifflin Scranton", 29 | "slug": "dunder-mifflin-scranton", 30 | "is_active": True, 31 | "is_verified": True, 32 | } 33 | tenant_obj = create_tenant_and_account(self.db_session, **tenant_account) 34 | 35 | headers = self.auth_headers(email=tenant_account.get("email")) 36 | response = self.client.get( 37 | "/members?tenant_id={}".format(tenant_obj.id), headers=headers 38 | ) 39 | 40 | assert response.status_code == 200 41 | assert len(response.json()) == 1 42 | 43 | @responses.activate 44 | def test_add_existing_account_as_member(self): 45 | # Mock out request to stripe. 46 | responses.add( 47 | responses.POST, 48 | "https://api.stripe.com/v1/customers", 49 | json={"id": 1}, 50 | status=200, 51 | ) 52 | 53 | # Create the account that will join the tenant. 54 | account_obj = create_account( 55 | self.db_session, 56 | first_name="Jim", 57 | last_name="Halpert", 58 | email="jim.halpert@example.com", 59 | password="password123", 60 | is_active=True, 61 | is_verified=True, 62 | send_registration_email=False, 63 | ) 64 | 65 | # Setup tenant and admin account. 
66 | tenant_account = { 67 | "first_name": "Andy", 68 | "last_name": "Benard", 69 | "email": "andy.bernard@example.com", 70 | "password": "password123", 71 | "name": "Dunder Mifflin Scranton", 72 | "slug": "dunder-mifflin-scranton", 73 | "is_active": True, 74 | "is_verified": True, 75 | "do_send_emails": False 76 | } 77 | tenant_obj = create_tenant_and_account(self.db_session, **tenant_account) 78 | 79 | # Add member 80 | payload = { 81 | "tenant_id": tenant_obj.id, 82 | "email": "jim.halpert@example.com", 83 | } 84 | headers = self.auth_headers(email=tenant_account.get("email")) 85 | response = self.client.post("/members", json=payload, headers=headers) 86 | assert response.status_code == 201 87 | 88 | headers = self.auth_headers(email=tenant_account.get("email")) 89 | response = self.client.get( 90 | "/members?tenant_id={}".format(tenant_obj.id), headers=headers 91 | ) 92 | assert response.status_code == 200 93 | assert len(response.json()) == 2 94 | 95 | # @patch("app.controllers.account.send_email") 96 | # @responses.activate 97 | # def test_add_new_account_as_member(self, mock_send_email): 98 | # # Mock out request to stripe. 99 | # responses.add( 100 | # responses.POST, 101 | # "https://api.stripe.com/v1/customers", 102 | # json={"id": 1}, 103 | # status=200, 104 | # ) 105 | 106 | # # Setup tenant and admin account. 
107 | # tenant_account = { 108 | # "first_name": "Andy", 109 | # "last_name": "Benard", 110 | # "email": "andy.bernard@example.com", 111 | # "password": "password123", 112 | # "name": "Dunder Mifflin Scranton", 113 | # "slug": "dunder-mifflin-scranton", 114 | # "is_active": True, 115 | # "is_verified": True, 116 | # } 117 | # tenant_obj = create_tenant_and_account(self.db_session, **tenant_account) 118 | 119 | # # Add member 120 | # payload = { 121 | # "tenant_id": tenant_obj.id, 122 | # "email": "jim.halpert@example.com", 123 | # } 124 | # headers = self.auth_headers(email=tenant_account.get("email")) 125 | # response = self.client.post("/members", json=payload, headers=headers) 126 | # assert response.status_code == 201 127 | 128 | # headers = self.auth_headers(email=tenant_account.get("email")) 129 | # response = self.client.get( 130 | # "/members?tenant_id={}".format(tenant_obj.id), headers=headers 131 | # ) 132 | # assert response.status_code == 200 133 | # assert len(response.json()) == 2 134 | -------------------------------------------------------------------------------- /app/tests/test_register.py: -------------------------------------------------------------------------------- 1 | import responses 2 | from unittest.mock import patch 3 | 4 | from app.tests.conftests import TestBase 5 | from app.models.account import Account, EmailAddress 6 | from app.models.tenant import Tenant 7 | from app.controllers.account import ( 8 | create_token_from_id, 9 | verify_email_address, 10 | ) 11 | 12 | 13 | class RegisterTestCase(TestBase): 14 | @patch("app.controllers.account.send_email") 15 | @responses.activate 16 | def test_success(self, mock_send_email): 17 | # Mock out request to stripe. 18 | responses.add( 19 | responses.POST, 20 | "https://api.stripe.com/v1/customers", 21 | json={"id": 1}, 22 | status=200, 23 | ) 24 | 25 | # Make sure no emails have been sent. 
26 | mock_send_email.assert_not_called() 27 | 28 | payload = { 29 | "first_name": "Andy", 30 | "last_name": "Benard", 31 | "email": "andy.bernard@example.com", 32 | "password": "password123", 33 | "name": "Dunder Mifflin Scranton", 34 | "slug": "dunder-mifflin-scranton", 35 | } 36 | response = self.client.post("/auth/register", json=payload) 37 | 38 | # Validate that the registration email was sent. 39 | mock_send_email.assert_called_once() 40 | 41 | assert response.status_code == 201 42 | assert self.db_session.query(Account).count() == 1 43 | assert self.db_session.query(Tenant).count() == 1 44 | assert self.db_session.query(Tenant).first().accounts.count() == 1 45 | 46 | # Should not be able to login until after they have verified their email. 47 | payload = {"username": "andy.bernard@example.com", "password": "password123"} 48 | response = self.client.post("/auth/token", data=payload) 49 | assert response.status_code == 409 50 | 51 | # Get the signed jwt and try to validate the account. 
52 | email_obj = ( 53 | self.db_session.query(EmailAddress) 54 | .filter(EmailAddress.email == "andy.bernard@example.com") 55 | .first() 56 | ) 57 | token = create_token_from_id(email_obj.id) 58 | response = self.client.post( 59 | "/auth/verify?token={}".format(token) 60 | ) 61 | assert response.status_code == 200 62 | 63 | payload = {"username": "andy.bernard@example.com", "password": "password123"} 64 | response = self.client.post("/auth/token", data=payload) 65 | assert response.status_code == 200 66 | -------------------------------------------------------------------------------- /app/utils/email.py: -------------------------------------------------------------------------------- 1 | from email.mime.multipart import MIMEMultipart 2 | from email.mime.text import MIMEText 3 | from typing import Union, Optional, List 4 | import smtplib 5 | 6 | from jinja2 import Template 7 | 8 | from app.settings import TESTING 9 | from app.settings import SMTP_HOST 10 | from app.settings import SMTP_PORT 11 | from app.settings import SMTP_USERNAME 12 | from app.settings import SMTP_PASSWORD 13 | 14 | def send_email(to_email: str, subject: str, body: str): 15 | """Send an email with plain-text and HTML versions of the body.""" 16 | msg = MIMEMultipart("alternative") 17 | # "alternative" parts are ordered by preference: clients that can render 18 | # HTML show the last (HTML) part, others fall back to plain text. 19 | msg["Subject"] = subject 20 | msg["From"] = "noreply@example.com" 21 | msg["To"] = to_email 22 | 23 | # Read the HTML template from disk. 24 | with open("app/templates/email/empty.html", "r") as template_file: 25 | # Parse it as a Jinja2 template. 26 | template = Template(template_file.read()) 27 | 28 | # Use the body as-is for the plain-text part and inject it into the HTML template. 29 | text = body 30 | html = template.render(body=body) 31 | 32 | msg.attach(MIMEText(text, "plain")) 33 | msg.attach(MIMEText(html, "html")) 34 | 35 | if TESTING: 36 | # Do not send an actual email during unit tests.
37 | return 38 | 39 | with smtplib.SMTP(SMTP_HOST) as smtp_conn: # sends QUIT and closes the socket on exit 40 | smtp_conn.send_message(msg) 41 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | 5 | mail: 6 | image: djfarrelly/maildev 7 | restart: unless-stopped 8 | ports: 9 | - 1025:25 10 | - 1080:80 11 | db: 12 | image: postgres:latest 13 | restart: unless-stopped 14 | environment: 15 | POSTGRES_PASSWORD: password # username: postgres 16 | POSTGRES_DB: db 17 | ports: 18 | - 5432:5432 19 | 20 | app: 21 | build: . 22 | command: /start-reload.sh 23 | depends_on: 24 | - db 25 | - mail 26 | ports: 27 | - 80:80 28 | volumes: 29 | - ${PWD}/:/app 30 | # - ${HOME}/.aws:/root/.aws 31 | # environment: 32 | # - AWS_ACCESS_KEY_ID 33 | # - AWS_SECRET_ACCESS_KEY 34 | -------------------------------------------------------------------------------- /prestart.sh: -------------------------------------------------------------------------------- 1 | pip install -r requirements.txt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiofiles==0.4.0 2 | aniso8601==7.0.0 3 | appdirs==1.4.3 4 | async-exit-stack==1.0.1 5 | async-generator==1.10 6 | attrs==19.3.0 7 | backcall==0.1.0 8 | bcrypt==3.1.7 9 | black==19.10b0 10 | certifi==2019.11.28 11 | cffi==1.13.2 12 | chardet==3.0.4 13 | Click==7.0 14 | coverage==5.0.2 15 | decorator==4.4.1 16 | dnspython==1.16.0 17 | email-validator==1.0.5 18 | fastapi==0.45.0 19 | graphene==2.1.8 20 | graphql-core==2.2.1 21 | graphql-relay==2.0.1 22 | gunicorn==19.9.0 23 | h11==0.9.0 24 | httptools==0.0.13 25 | idna==2.8 26 | importlib-metadata==1.3.0 27 | ipdb==0.12.3 28 | ipython==7.10.2 29 | ipython-genutils==0.2.0 30 | itsdangerous==1.1.0 31 | jedi==0.15.2 32 | Jinja2==2.10.3 33 | MarkupSafe==1.1.1 34 |
more-itertools==8.0.2 35 | packaging==19.2 36 | parso==0.5.2 37 | passlib==1.7.2 38 | pathspec==0.7.0 39 | pexpect==4.7.0 40 | pickleshare==0.7.5 41 | pluggy==0.13.1 42 | promise==2.3 43 | prompt-toolkit==3.0.2 44 | psycopg2==2.8.4 45 | ptyprocess==0.6.0 46 | py==1.8.0 47 | pycparser==2.19 48 | pydantic==1.2 49 | Pygments==2.5.2 50 | PyJWT==1.7.1 51 | pyparsing==2.4.5 52 | pytest==5.3.2 53 | pytest-cov==2.8.1 54 | python-multipart==0.0.5 55 | PyYAML==5.2 56 | regex==2019.12.20 57 | requests==2.22.0 58 | responses==0.10.9 59 | Rx==1.6.1 60 | six==1.13.0 61 | slugify==0.0.1 62 | SQLAlchemy==1.3.12 63 | SQLAlchemy-Utils==0.36.1 64 | starlette==0.12.9 65 | stripe==2.41.0 66 | toml==0.10.0 67 | traitlets==4.3.3 68 | typed-ast==1.4.0 69 | ujson==1.35 70 | urllib3==1.25.7 71 | uvicorn==0.10.9 72 | uvloop==0.14.0 73 | wcwidth==0.1.7 74 | websockets==8.1 75 | zipp==0.6.0 76 | --------------------------------------------------------------------------------
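`send_email` in `app/utils/email.py` imports `SMTP_PORT`, `SMTP_USERNAME` and `SMTP_PASSWORD` from `app.settings` but connects using only `SMTP_HOST`. The sketch below shows one way a delivery helper could use all four settings. It assumes the port arrives as a string (as `settings.py` declares it with `cast=str`) and that credentials are optional. `build_message`, `deliver` and the `smtp_factory` parameter are hypothetical names; the factory exists only so the helper can be exercised without a real mail server.

```python
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Callable, Optional


def build_message(to_email: str, subject: str, text: str, html: str) -> MIMEMultipart:
    """Assemble a multipart/alternative message: plain text first, HTML preferred."""
    msg = MIMEMultipart("alternative")
    msg["Subject"] = subject
    msg["From"] = "noreply@example.com"  # sender address used by send_email
    msg["To"] = to_email
    msg.attach(MIMEText(text, "plain"))
    msg.attach(MIMEText(html, "html"))
    return msg


def deliver(
    msg: MIMEMultipart,
    host: str,
    port: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
    smtp_factory: Callable[..., smtplib.SMTP] = smtplib.SMTP,
) -> None:
    """Send msg via host:port, logging in only when a username is configured."""
    # A missing or non-numeric port becomes 0, which smtplib treats as
    # "use the default SMTP port (25)".
    port_number = int(port) if port and str(port).isdigit() else 0
    # The context manager sends QUIT and closes the connection even if sending fails.
    with smtp_factory(host, port_number) as conn:
        if username:
            conn.login(username, password or "")
        conn.send_message(msg)
```

Given the `1025:25` mapping in `docker-compose.yml`, the `app` container would presumably reach `maildev` at host `mail` on port 25, while a process running on the Docker host would use port 1025.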