├── crud
│   └── __init__.py
├── model
│   └── __init__.py
├── api
│   ├── api_v2
│   │   └── __init__.py
│   └── api_v1
│       ├── endpoints
│       │   ├── __init__.py
│       │   ├── register.py
│       │   ├── user.py
│       │   └── login.py
│       └── api.py
├── service
│   └── __init__.py
├── alembic
│   ├── README
│   ├── script.py.mako
│   └── env.py
├── test
│   ├── demo3.py
│   ├── test.py
│   ├── demo1.py
│   └── demo2.py
├── settings
│   ├── demo.env
│   └── config.py
├── .gitignore
├── db
│   ├── __init__.py
│   ├── todoModel.py
│   ├── init_db.py
│   └── userModel.py
├── demo.py
├── common
│   └── BaseResponse.py
├── requirements.txt
├── validator
│   └── schemas.py
├── main.py
├── README.md
└── utils
    └── jwt_service.py


/crud/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/api/api_v2/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/service/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/api/api_v1/endpoints/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/alembic/README:
--------------------------------------------------------------------------------
Generic single-database configuration.
--------------------------------------------------------------------------------
/test/demo3.py:
--------------------------------------------------------------------------------
from datetime import datetime
print(datetime.now())

--------------------------------------------------------------------------------
/settings/demo.env:
--------------------------------------------------------------------------------
# Rename this file to ".env" and fill in your own settings
HOST_IP="xxxx"
DATABASE_PASSWORD="xxxx"
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.fast
.fastapi
.idea
.vscode
__pycache__
alembic/versions/*
*.ini
.env
alembic.ini
--------------------------------------------------------------------------------
/db/__init__.py:
--------------------------------------------------------------------------------
from db.init_db import metadata

# Expose `metadata` from one place and import every model below so that each
# registers its table on it; alembic/env.py imports `db.metadata` to find the
# models during autogenerate.
from .todoModel import TodoList
from .userModel import User
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
'''
Because Python is a dynamic language, design patterns matter less than design principles:
1. Single responsibility principle: split large classes into small ones; use functions.
2. Open/closed principle: classes are open for extension and closed for modification. Use decorators
   and extra function parameters. Tip: find the method of the parent class that is likely to change,
   extract it into a new method, and let subclasses override it to change the class's behaviour.
3. Extend code with composition and dependency injection.
4. Design the backend with a front-end mindset: decide what changes and what stays fixed, and drive
   the changing parts with data, so that only the data changes while the code logic stays untouched.
'''
--------------------------------------------------------------------------------
/common/BaseResponse.py:
--------------------------------------------------------------------------------
# Basic response envelope


def BaseResponse(code, msg, data, status=None, err=None):
    """Build the uniform response body returned by the API endpoints.

    Returns a dict with the keys "code", "status", "msg", "error" and "data".
    """
    result = {
        "code": code,
        "status": status,
        "msg": msg,
        "error": err,
        "data": data
    }
    return result

--------------------------------------------------------------------------------
/api/api_v1/api.py:
--------------------------------------------------------------------------------
from fastapi import APIRouter

from api.api_v1.endpoints import user, login, register

api_router = APIRouter()

# Grouped routers
api_router.include_router(login.router, tags=["login"])
api_router.include_router(user.router, prefix="/user", tags=["user"])
# api_router.include_router(utils.router, prefix="/util", tags=["utils"])
api_router.include_router(register.router, tags=["register"])

--------------------------------------------------------------------------------
/db/todoModel.py:
--------------------------------------------------------------------------------
import orm
from db.init_db import database, metadata


class TodoList(orm.Model):

    __tablename__ = "todo"
    __database__ = database
    __metadata__ = metadata

    id = orm.Integer(primary_key=True)
    title = orm.String(max_length=30, allow_null=False)
    status = orm.Boolean(default=False)
    info = orm.Text()
    created_at = orm.DateTime()
    updated_at = orm.DateTime()
    deleted_at = orm.DateTime()

--------------------------------------------------------------------------------
/settings/config.py:
--------------------------------------------------------------------------------
#######################################################################################
# Param Data @
# Return @
# TODO @ Project configuration items (template), read from the environment / .env
# *
# !
# ?
#######################################################################################

import os
from dotenv import load_dotenv
from pathlib import Path  # python3 only

# Without a path, load_dotenv() locates a ".env" file via find_dotenv().
# It was previously called twice in a row with identical arguments (apart from
# `verbose`); once is enough.
load_dotenv(verbose=True)

# Additionally load ./.env relative to the current working directory.
env_path = Path('.') / '.env'
load_dotenv(dotenv_path=env_path)

HOST_IP = os.getenv("HOST_IP")
DATABASE_PASSWORD = os.getenv("DATABASE_PASSWORD")

--------------------------------------------------------------------------------
/alembic/script.py.mako:
--------------------------------------------------------------------------------
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}


def upgrade():
    ${upgrades if upgrades else "pass"}


def downgrade():
    ${downgrades if downgrades else "pass"}

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
aiomysql==0.0.20
alembic==1.4.1
autopep8==1.5
bcrypt==3.1.7
cffi==1.14.0
Click==7.0
cryptography==2.8
databases==0.2.6
fastapi==0.52.0
h11==0.9.0
httptools==0.1.1
Mako==1.1.2
MarkupSafe==1.1.1
mysql-connector==2.2.9
orm==0.1.5
passlib==1.7.2
pycodestyle==2.5.0
pycparser==2.20
pydantic==1.4
PyJWT==1.7.1
PyMySQL==0.9.2
python-dateutil==2.8.1
python-dotenv==0.12.0
python-editor==1.0.4
python-multipart==0.0.5
six==1.14.0
SQLAlchemy==1.3.13
starlette==0.13.2
typesystem==0.2.4
uvicorn==0.11.3
uvloop==0.14.0
websockets==8.1

--------------------------------------------------------------------------------
/test/test.py:
--------------------------------------------------------------------------------
import threading

# Shared resource
l = 0
# Lock (its acquire/release calls below are commented out)
lock = threading.Lock()


def task_add1():
    global l
    # Raised to 1,000,000 iterations so the thread-unsafe effect shows up
    for i in range(1000000):
        # lock.acquire()
        l += 1
        # lock.release()


def task_add2():

    global l
    for i in range(1000000):
        # lock.acquire()
        l -= 1

        # lock.release()


def run():
    # Threads can be given a name
    t1 = threading.Thread(name="t1", target=task_add1)
    t2 = threading.Thread(name="t2", target=task_add2)
    t1.start()
    t2.start()
    t1.join()
    t2.join()


def main():
    run()
    print(l)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/validator/schemas.py:
--------------------------------------------------------------------------------
#######################################################################################
# Param Data @
# Return @
# TODO @ Named "schema" to avoid confusion with the ORM models
# *
# !
# ?
8 | ####################################################################################### 9 | 10 | from typing import List 11 | from pydantic import BaseModel 12 | 13 | 14 | class UserValidator(BaseModel): 15 | id: int = None 16 | username: str 17 | password: str 18 | email: str = None 19 | permission: str = None 20 | 21 | class Config: 22 | orm_mode = True 23 | 24 | 25 | class Token(BaseModel): 26 | access_token: str 27 | token_type: str 28 | 29 | 30 | class TokenData(BaseModel): 31 | username: str = None 32 | scopes: List[str] = [] 33 | 34 | # TodoList Schema 35 | 36 | 37 | class TodoBase(BaseModel): 38 | title: str 39 | info: str = None 40 | 41 | 42 | class TodoCreate(TodoBase): 43 | pass 44 | 45 | 46 | class Todo(TodoBase): 47 | id: int 48 | user_id: int 49 | 50 | class Config: 51 | # :是声明, =是默认值 52 | orm_mode = True 53 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ####################################################################################### 2 | # Param Data @ 3 | # Return @ 4 | # TODO @ 项目入口文件 5 | # * 6 | # ! 7 | # ? 
8 | ####################################################################################### 9 | 10 | from fastapi import FastAPI 11 | from starlette.middleware.cors import CORSMiddleware 12 | from api.api_v1.api import api_router 13 | from db.init_db import create_connection, disconnect 14 | 15 | # 可以像flask一样自定义一些配置 16 | app = FastAPI(openapi_url="/api/v1/openapi.json") 17 | 18 | # 初始化数据库连接 19 | # init_db() 20 | 21 | app.add_event_handler("startup", create_connection) 22 | app.add_event_handler("shutdown", disconnect) 23 | 24 | # CORS配置 25 | origins = ["*"] 26 | 27 | app.add_middleware( 28 | CORSMiddleware, 29 | allow_origins=origins, 30 | allow_credentials=True, 31 | allow_methods=["*"], 32 | allow_headers=["*"], 33 | ) 34 | 35 | 36 | # *添加路由规则 37 | app.include_router(api_router, prefix="/api/v1") # 默认的前缀 38 | 39 | 40 | @app.get("/") 41 | async def index(): 42 | return {"ping": "pong"} 43 | 44 | 45 | if __name__ == '__main__': 46 | import uvicorn 47 | uvicorn.run(app, host="127.0.0.1", port=8000) 48 | -------------------------------------------------------------------------------- /test/demo1.py: -------------------------------------------------------------------------------- 1 | ####################################################################################### 2 | # Param Data @ 3 | # Return @ 4 | # TODO @ 信号量进行线程同步,信号量是操作系统内部管理的抽象数据 5 | # TODO @ 本质上信号量是内部数据, 用于标明当前的共享资源有多少并发读取 6 | # * 7 | # ! 8 | # ? 
9 | ####################################################################################### 10 | 11 | # 获取会减少信号量的内部变量,当信号量是负值的时候线程会被挂起,直到有其他线程释放资源 12 | # 当线程不再需要该共享资源,必须通过 release() 释放。这样,信号量的内部变量增加,在信号量等待队列中排在最前面的线程会拿到共享资源的权限。 13 | 14 | # 信号量是原子的并没有什么问题, 如果不是或者两个操作有一个终止就会导致情况槽糕 15 | # 可以用with的语法管理 信号量,条件变量,事件和锁 16 | # 但是更常见的方式是队列 17 | import threading 18 | import logging 19 | 20 | 21 | lock = threading.Lock() 22 | rlock = threading.RLock() 23 | condition = threading.Condition() 24 | mutex = threading.Semaphore(1) 25 | threading_synchronization_list = [lock, rlock, condition, mutex] 26 | 27 | 28 | def threading_with(statement): 29 | with statement: 30 | logging.debug('%s acquired via with' % statement) 31 | 32 | 33 | def threading_not_with(statement): 34 | statement.acquire() 35 | try: 36 | logging.debug('%s acquired directly' % statement) 37 | finally: 38 | statement.release() 39 | 40 | 41 | for statement in threading_synchronization_list: 42 | t1 = threading.Thread(target=threading_with, args=(statement,)) 43 | t2 = threading.Thread(target=threading_not_with, args=(statement,)) 44 | -------------------------------------------------------------------------------- /api/api_v1/endpoints/register.py: -------------------------------------------------------------------------------- 1 | ####################################################################################### 2 | # Param Data @ 3 | # Return @ 4 | # TODO @ 注册接口 5 | # * 6 | # ! 7 | # ? 
8 | ####################################################################################### 9 | 10 | from fastapi import APIRouter, HTTPException, status 11 | from validator import schemas 12 | from common.BaseResponse import BaseResponse 13 | from db.userModel import User 14 | from utils.jwt_service import create_access_token, get_password_hash 15 | from pymysql.err import IntegrityError 16 | 17 | 18 | router = APIRouter() 19 | 20 | 21 | @router.post("/register", status_code=status.HTTP_201_CREATED) 22 | async def create_user(user_schema: schemas.UserValidator): 23 | 24 | # 可以在schema里校验,但是比较费劲就在这里校验了 25 | if len(user_schema.username) <= 6 or len(user_schema.password) <= 6: 26 | raise HTTPException(422, detail="用户名和密码不能小于6位") 27 | 28 | # 对密码加密 29 | hash_password = get_password_hash(user_schema.password) 30 | try: 31 | # 在数据库中创建记录 32 | await User.objects.create(username=user_schema.username, password=hash_password) 33 | 34 | except IntegrityError: 35 | raise HTTPException( 36 | status.HTTP_422_UNPROCESSABLE_ENTITY, detail="该用户名已经有人创建") 37 | 38 | # 发放token 39 | token = create_access_token(data={"username": user_schema.username}) 40 | data = { 41 | "token": token, 42 | "token_type": "bearer" 43 | } 44 | 45 | return BaseResponse(code=201, msg="用户创建成功", data=data) 46 | -------------------------------------------------------------------------------- /db/init_db.py: -------------------------------------------------------------------------------- 1 | # from sqlalchemy import create_engine 2 | # from sqlalchemy.orm import sessionmaker 3 | import databases 4 | import sqlalchemy 5 | 6 | ####################################################################################### 7 | # Param Data @ 8 | # Return @ 9 | # TODO @ 基本的sqlalchemy配置 10 | # * 11 | # ! 12 | # ? 
13 | ####################################################################################### 14 | # def init_db(): 15 | # # echo=True: 输出sqlalchemy日志 16 | # engine = create_engine( 17 | # f"mysql+pymysql://root:{MYSQL_PASSWORD}@{HOST_IP}:3306/{MYSQL_NAME}", echo=False) 18 | # # 创建DBSession类型 19 | # SessionLocal = sessionmaker(bind=engine) 20 | # return SessionLocal, engine 21 | 22 | 23 | # Dependency 24 | # def get_db(): 25 | # sessionLocal, _ = init_db() 26 | # db = sessionLocal() 27 | # try: 28 | # yield db 29 | # finally: 30 | # db.close() 31 | 32 | ####################################################################################### 33 | # Param Data @ 34 | # Return @ 35 | # TODO @ orm为基于sqlalchemy的异步数据库迁移模型,依赖aiomysql 36 | # * 37 | # ! 38 | # ? 39 | ####################################################################################### 40 | from settings.config import DATABASE_PASSWORD, HOST_IP 41 | DATABASE_URL = f"mysql+pymysql://root:{DATABASE_PASSWORD}@{HOST_IP}:3306/fastapi_plus" 42 | 43 | database = databases.Database(DATABASE_URL) 44 | 45 | metadata = sqlalchemy.MetaData() 46 | 47 | engine = sqlalchemy.create_engine(str(database.url)) 48 | metadata.create_all(engine) 49 | 50 | 51 | async def create_connection(): 52 | await database.connect() 53 | 54 | 55 | async def disconnect(): 56 | await database.disconnect() 57 | -------------------------------------------------------------------------------- /test/demo2.py: -------------------------------------------------------------------------------- 1 | ####################################################################################### 2 | # Param Data @ 3 | # Return @ 4 | # TODO @ 使用队列进行线程之间通信 5 | # * 6 | # ! 7 | # ? 
8 | ####################################################################################### 9 | from threading import Thread, Event 10 | from queue import Queue 11 | import time 12 | import random 13 | 14 | # 生产者 15 | 16 | # 这里直接可以将线程类通过参数传递进来,不用单独创建了 17 | class producer(Thread): 18 | def __init__(self, queue): 19 | # 这里必须调用init初始化方法 20 | Thread.__init__(self) 21 | self.queue = queue 22 | 23 | def run(self): 24 | for i in range(10): 25 | 26 | item = random.randint(0, 256) 27 | # 向queue put数据 28 | self.queue.put(item) 29 | print('Producer notify: item N° %d appended to queue by %s' % 30 | (item, self.name)) 31 | time.sleep(1) 32 | 33 | 34 | class consumer(Thread): 35 | def __init__(self, queue): 36 | Thread.__init__(self) 37 | self.queue = queue 38 | 39 | def run(self): 40 | while True: 41 | # 消费数据 42 | item = self.queue.get() 43 | print('Consumer notify : %d popped from queue by %s' % 44 | (item, self.name)) 45 | self.queue.task_done() 46 | 47 | 48 | def main(): 49 | queue = Queue() 50 | # 一个生产者3个消费者 51 | t1 = producer(queue) 52 | t2 = consumer(queue) 53 | t3 = consumer(queue) 54 | t4 = consumer(queue) 55 | t1.start() 56 | t2.start() 57 | t3.start() 58 | t4.start() 59 | t1.join() 60 | t2.join() 61 | t3.join() 62 | t4.join() 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /api/api_v1/endpoints/user.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | 3 | from fastapi import Security, APIRouter, Depends, HTTPException 4 | from fastapi.security import OAuth2PasswordRequestForm 5 | from starlette import status 6 | 7 | from db.userModel import User 8 | from validator.schemas import Token, UserValidator 9 | from utils.jwt_service import get_current_user, create_access_token, get_password_hash, verify_password 10 | 11 | ####################################################################################### 12 | 
# Param Data @ 13 | # Return @ 14 | # TODO @ 获取当前用户,需要身份验证 15 | # * 16 | # ! 17 | # ? 18 | ####################################################################################### 19 | 20 | 21 | router = APIRouter() 22 | 23 | 24 | @router.post("/token", response_model=Token) 25 | # 第三方登录 26 | async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()): 27 | user = await User.objects.get(username=form_data.username) 28 | 29 | if not user: 30 | raise HTTPException( 31 | status_code=status.HTTP_401_UNAUTHORIZED, 32 | detail="该用户不存在", 33 | headers={"WWW-Authenticate": "Bearer"}, 34 | ) 35 | 36 | if not verify_password(form_data.password, user.password): 37 | return HTTPException(status.HTTP_401_UNAUTHORIZED, detail="密码错误", headers={"WWW-Authenticate": "Bearer"}) 38 | 39 | access_token_expires = timedelta(minutes=15) 40 | access_token = create_access_token( 41 | data={"sub": user.username, "scopes": form_data.scopes}, expires_delta=access_token_expires 42 | ) 43 | return {"access_token": access_token, "token_type": "bearer"} 44 | 45 | 46 | @router.get("/", response_model=UserValidator, response_model_exclude=["password"]) 47 | # Security可以允许传入scopes的权限范围和Depends一样可以实现依赖注入 48 | async def user(current_user: User = Security(get_current_user, scopes=["normal"])): 49 | return current_user 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 基于 FastAPI 项目记录 2 | 3 | --- 4 | 5 | ### 第一个 FastAPI 程序 6 | 7 | - 创建虚拟环境 8 | 9 | ``` 10 | python3 -m venv .fastapi 11 | ``` 12 | 13 | - 激活虚拟环境 14 | 15 | ``` 16 | source .fastapi/bin/activate 17 | ``` 18 | 19 | - 安装基础依赖 20 | 21 | ``` 22 | pip install fastapi unicorn 23 | ``` 24 | 25 | - 最简单的例子 26 | 27 | ``` 28 | from fastapi import FastAPI 29 | 30 | app = FastAPI() 31 | 32 | 33 | @app.get("/") 34 | def read_root(): 35 | return {"Hello": "World"} 36 | ``` 37 | 38 | - 启动项目 39 | 40 | ``` 41 | 
uvicorn main:app --reload
```

- 接口文档
```
/redoc /docs
```
- 安装依赖
```
pip install -r requirements.txt
```
- 导出依赖
```
pip freeze > requirements.txt
```
---

### 项目解耦

- 思路: main.py 中的流程

1. 初始化配置(mysql, redis, mongodb 等)
   > 单独拎出来一个文件再分别引 redis/init.py db/init.py 等初始化连接配置.可以参考 Gin 的拆分方式(注意要打好日志, 用 try finally 起服务)
2. add_router()添加路由, 在 api/api_v1/api.py 做分组, 然后引 endpoints 中的子路由模块
3. 最后 run 起项目来 `if __name__ == "__main__":`

- 解耦出来的包功能说明

1. utils: 工具包
2. alembic: 数据库迁移工具自动创建的包
3. api: 项目路由, api_v1:第一版, api.py:分组路由, endpoint:真正的子路由
4. crud: 各种资源的增删改查数据库操作, 与 service(endpoints)要解耦
5. db: 数据库的连接
6. model: ORM(sqlalchemy)
7. validator: pydantic 库的校验
8. settings: 项目的配置文件, 不能暴露给用户
9. common: 公共依赖
10. service : 业务逻辑
11. test : 测试

---

### 数据库迁移

- 配合 sqlalchemy 的数据库迁移工具

```
pip install alembic
```

- 项目初始化文件夹, 用来配置迁移文件和存放迁移历史记录

```
alembic init alembic
```

- 配置你的数据库连接

```
sqlalchemy.url = driver://user:pass@localhost/dbname
```

- 修改 env.py 文件

```
target_metadata = None
```

```
import sys
from os.path import abspath, dirname

sys.path.append(dirname(dirname(abspath(__file__))))
# 注意这个地方是要引入模型里面的Base,不是connect里面的
from sqlalchemy_demo.modules.user_module import Base
target_metadata = Base.metadata
```

> 本项目中对应的写法是 `from db import metadata` 与 `target_metadata = metadata`(见 alembic/env.py)

- 创建迁移文件

```
alembic revision --autogenerate -m "描述信息类似于git"
```

- 更新到最新的版本

```
alembic upgrade head
```

- 更多功能请查看官网

---

### 项目的设计细节

- 可以参考官方文档的 sqlalchemy 例子

--------------------------------------------------------------------------------
/api/api_v1/endpoints/login.py:
--------------------------------------------------------------------------------
-------------------------------------------------------------------------------- 1 | ####################################################################################### 2 | # Param Data @ 3 | # Return @ 4 | # TODO @ OAuth2.0 + JWT 登录接口,Bearer token,一般放在请求头中Authorization 5 | # * 例:“eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Imxpbmlhbmh1aSJ9.hnOfZb95jFwQsYj3qlgFbUu1rKpfTE6AzgXZidEGGTk” 6 | # * 由header.payload.signature三部分组成,通过Base64编码 7 | # * header:{"alg": "HS256","typ": "JWT"}token类型和签名算法 8 | # * payload:{"sub":xxxxxx, "name":xxxxxx}jwt中预定的了一些claims 9 | # * signature:对前两部分的摘要进行签名 10 | # ! 11 | # ? 12 | ####################################################################################### 13 | """ 14 | JWT中预制了一些Claim : 15 | 1.iss(Issuer签发者) 16 | 2.sub(subject签发给的受众,在Issuer范围内是唯一的) 17 | 3.aud(Audience接收方) 18 | 4.exp(Expiration Time过期时间) 19 | 5.iat(Issued At签发时间)等等 20 | 21 | 如果放在cookie中觉得浪费带宽可以放在请求头中 22 | 23 | OAuth2.0和JWT结合使用 : 24 | { 25 | "sub":"xxxxxx", 26 | "scope":"normal", 27 | "exp":xxx, 28 | } 29 | """ 30 | 31 | from fastapi import Body, APIRouter, HTTPException, status, Response 32 | from validator.schemas import Token 33 | from db.userModel import User 34 | from utils.jwt_service import create_access_token, verify_password 35 | 36 | 37 | router = APIRouter() 38 | 39 | 40 | # 登录接口 41 | @router.post("/login", response_model=Token) 42 | async def login(response: Response, username: str = Body(..., min_length=6), password: str = Body(..., min_length=6)): 43 | # 获取用户 44 | user = await User.objects.get(username=username) 45 | if not user: 46 | raise HTTPException( 47 | status_code=status.HTTP_401_UNAUTHORIZED, 48 | detail="用不存在", 49 | headers={"Authorization": "Bearer"}, 50 | ) 51 | 52 | # 验证密码 53 | if not verify_password(password, user.password): 54 | return HTTPException(status.HTTP_401_UNAUTHORIZED, detail="密码错误") 55 | 56 | # 根据username生成JWT token 57 | token = create_access_token( 58 | data={"sub": user.username, "scopes": 
[user.permission]}) 59 | 60 | if token: 61 | # 这里可以不在后端set_cookie,vue也可以做,比较通用的是在前端做,因为安卓和ios是不支持cookie的 62 | # cookie的过期时间要和jwt token过期时间保持一致 63 | response.set_cookie("JWT-token", token, expires=15*60, 64 | path="/", domain="127.0.0.1", httponly=True) 65 | 66 | return {"access_token": token, "token_type": "bearer"} 67 | 68 | else: 69 | raise HTTPException( 70 | status.HTTP_500_INTERNAL_SERVER_ERROR, detail="生成token失败") 71 | -------------------------------------------------------------------------------- /alembic/env.py: -------------------------------------------------------------------------------- 1 | from logging.config import fileConfig 2 | 3 | from sqlalchemy import engine_from_config 4 | from sqlalchemy import pool 5 | 6 | from alembic import context 7 | 8 | # this is the Alembic Config object, which provides 9 | # access to the values within the .ini file in use. 10 | config = context.config 11 | 12 | # Interpret the config file for Python logging. 13 | # This line sets up loggers basically. 14 | fileConfig(config.config_file_name) 15 | 16 | # add your model's MetaData object here 17 | # for 'autogenerate' support 18 | # from myapp import mymodel 19 | # target_metadata = mymodel.Base.metadata 20 | import sys 21 | from os.path import abspath, dirname 22 | 23 | sys.path.append(dirname(dirname(abspath(__file__)))) 24 | from db import metadata 25 | target_metadata = metadata 26 | 27 | # other values from the config, defined by the needs of env.py, 28 | # can be acquired: 29 | # my_important_option = config.get_main_option("my_important_option") 30 | # ... etc. 31 | 32 | 33 | def run_migrations_offline(): 34 | """Run migrations in 'offline' mode. 35 | 36 | This configures the context with just a URL 37 | and not an Engine, though an Engine is acceptable 38 | here as well. By skipping the Engine creation 39 | we don't even need a DBAPI to be available. 40 | 41 | Calls to context.execute() here emit the given string to the 42 | script output. 
43 | 44 | """ 45 | url = config.get_main_option("sqlalchemy.url") 46 | context.configure( 47 | url=url, 48 | target_metadata=target_metadata, 49 | literal_binds=True, 50 | dialect_opts={"paramstyle": "named"}, 51 | ) 52 | 53 | with context.begin_transaction(): 54 | context.run_migrations() 55 | 56 | 57 | def run_migrations_online(): 58 | """Run migrations in 'online' mode. 59 | 60 | In this scenario we need to create an Engine 61 | and associate a connection with the context. 62 | 63 | """ 64 | connectable = engine_from_config( 65 | config.get_section(config.config_ini_section), 66 | prefix="sqlalchemy.", 67 | poolclass=pool.NullPool, 68 | ) 69 | 70 | with connectable.connect() as connection: 71 | context.configure( 72 | connection=connection, target_metadata=target_metadata 73 | ) 74 | 75 | with context.begin_transaction(): 76 | context.run_migrations() 77 | 78 | 79 | if context.is_offline_mode(): 80 | run_migrations_offline() 81 | else: 82 | run_migrations_online() 83 | -------------------------------------------------------------------------------- /utils/jwt_service.py: -------------------------------------------------------------------------------- 1 | import jwt 2 | from datetime import datetime, timedelta 3 | from passlib.context import CryptContext 4 | from fastapi.security import OAuth2PasswordBearer, SecurityScopes 5 | # 注意fastapi包中的HTTPException才可以定义请求头 6 | from fastapi import Depends, status, HTTPException 7 | # from starlette.exceptions import HTTPException 8 | from pydantic import ValidationError 9 | from validator.schemas import TokenData 10 | from db.userModel import User 11 | 12 | ####################################################################################### 13 | # Param Data @ 14 | # Return @ 15 | # TODO @ jwt工具包 16 | # * 17 | # ! 18 | # ? 
19 | ####################################################################################### 20 | 21 | SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" 22 | ALGORITHM = "HS256" 23 | 24 | pwd_context = CryptContext(schemes=['bcrypt']) 25 | 26 | # 这里的tokenUrl路径一定要和token路径保持一致 27 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/user/token", scopes={ 28 | "normal": "Read information about current user", "admin": "admin user"}) 29 | 30 | 31 | # 验证密码 32 | def verify_password(plain_password, hashed_password): 33 | return pwd_context.verify(plain_password, hashed_password) 34 | 35 | 36 | # hash密码 37 | def get_password_hash(password): 38 | return pwd_context.hash(password) 39 | 40 | 41 | # 创建jwt-token 42 | def create_access_token(*, data: dict, expires_delta: timedelta = None): 43 | to_encode = data.copy() 44 | if expires_delta: 45 | expire = datetime.now() + expires_delta 46 | else: 47 | expire = datetime.now() + timedelta(minutes=15) 48 | to_encode.update({"exp": expire}) 49 | encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) 50 | return encoded_jwt 51 | 52 | 53 | # 获取当前用户 54 | async def get_current_user(*, token: str = Depends(oauth2_scheme), security_scopes: SecurityScopes): 55 | if security_scopes.scopes: 56 | # 固定的写法,见官网 57 | authenticate_value = f"Bearer scopes={security_scopes.scope_str}" 58 | 59 | else: 60 | authenticate_value = f"Bearer" 61 | 62 | # 这里定义一个通用的错误 63 | error = HTTPException(status.HTTP_401_UNAUTHORIZED, detail="无权访问", 64 | headers={"WWW-Authenticate": authenticate_value}) 65 | 66 | try: 67 | # 解码token 68 | payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) 69 | 70 | # 按照OAuth2.0通常定义成“sub” 71 | username = payload.get("sub") 72 | 73 | if username is None: 74 | raise error 75 | # scopes是权限,在上面已经定义 76 | token_scopes = payload.get("scopes", []) 77 | token_data = TokenData(scopes=token_scopes, username=username) 78 | 79 | except (PyJWTError, ValidationError): 80 | raise error 81 | 82 | 
# 解码成功获取到user 83 | user = await User.objects.limit(1).filter(username=username).all() 84 | 85 | if user is None: 86 | raise error 87 | 88 | # 判断权限 89 | for scope in security_scopes.scopes: 90 | if scope not in token_data.scopes: 91 | raise HTTPException( 92 | status_code=status.HTTP_401_UNAUTHORIZED, 93 | detail="Not enough permissions", 94 | headers={"WWW-Authenticate": authenticate_value}, 95 | ) 96 | 97 | # 返回用户 98 | return user[0] 99 | -------------------------------------------------------------------------------- /db/userModel.py: -------------------------------------------------------------------------------- 1 | # from sqlalchemy.ext.declarative import declarative_base 2 | # from sqlalchemy import Column, Integer, String, ForeignKey, Time, Boolean, Text 3 | # from sqlalchemy.orm import relationship 4 | # from db.init_db import init_db 5 | 6 | # Base = declarative_base() 7 | 8 | # _, engine = init_db() 9 | # Base.metadata.create_all(bind=engine) 10 | 11 | 12 | ####################################################################################### 13 | # Param Data @ 14 | # Return @ 15 | # TODO @ 基本的sqlalchemy orm 16 | # * 17 | # ! 18 | # ? 
19 | ####################################################################################### 20 | 21 | # class User(Base): 22 | 23 | # __tablename__ = "user" 24 | 25 | # id = Column(Integer, primary_key=True, autoincrement=True) 26 | # name = Column(String(32)) 27 | # username = Column(String(32), nullable=False, unique=True) 28 | # password = Column(String(60), nullable=False) 29 | # nickname = Column(String(32)) 30 | # phoneNum = Column(Integer) 31 | # email = Column(String(30)) 32 | # create_time = Column(Time) 33 | # delete_time = Column(Time) 34 | # birth_year = Column(Integer, default=2000) 35 | # birth_month = Column(Integer, default=1) 36 | # birth_day = Column(Integer, default=1) 37 | # avatar = Column(String(256)) 38 | # sex = Column(Boolean) 39 | # location = Column(String(32)) 40 | # # 用户类型 41 | # # CHOICES 42 | # # type 43 | # token = Column(String(60), default="666") 44 | # todo_list = relationship('todoList', back_populates="users") 45 | 46 | 47 | # class todoList(Base): 48 | 49 | # __tablename__ = "todoList" 50 | 51 | # id = Column(String(20), primary_key=True) 52 | # user_id = Column(Integer, ForeignKey("user.id")) 53 | # users = relationship('User', back_populates="todo_list") 54 | # title = Column(String(30), nullable=False) 55 | # status = Column(Boolean, default=0) 56 | # info = Column(Text) 57 | # created_at = Column(Time) 58 | # updated_at = Column(Time) 59 | # deleted_at = Column(Time) 60 | 61 | ####################################################################################### 62 | # Param Data @ 63 | # Return @ 64 | # TODO @ orm为基于sqlalchemy的异步数据库迁移模型,依赖aiomysql 65 | # * 66 | # ! 67 | # ? 
68 | ####################################################################################### 69 | import orm 70 | 71 | from db.init_db import database, metadata 72 | 73 | 74 | class User(orm.Model): 75 | 76 | __tablename__ = "user" 77 | __database__ = database 78 | __metadata__ = metadata 79 | 80 | id = orm.Integer(primary_key=True) 81 | username = orm.String(max_length=50, allow_null=False, 82 | allow_blank=False, index=True, unique=True) 83 | email = orm.String(max_length=50, unique=True, index=True, allow_null=True, 84 | allow_blank=True) 85 | password = orm.String(max_length=255, allow_null=False, 86 | allow_blank=False) 87 | phone = orm.String(max_length=11, min_length=11, 88 | allow_null=True, allow_blank=True) 89 | # 用户权限scopes字段 90 | permission = orm.String(max_length=50, default="normal", 91 | allow_null=True, allow_blank=True) 92 | created_at = orm.String(allow_null=True, allow_blank=True, max_length=50) 93 | updated_at = orm.String(allow_null=True, allow_blank=True, max_length=50) 94 | deleted_at = orm.String(allow_null=True, allow_blank=True, max_length=50) 95 | # 外键 96 | # todo = orm.ForeignKey(TodoList) 97 | --------------------------------------------------------------------------------