├── .gitignore ├── Pipfile ├── app_utils.py ├── crud.py ├── database.py ├── main.py ├── models.py ├── requirements.txt ├── schemas.py └── test_debug.py /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cuongld2/rest_api_service_fastapi/1a941713e9845669689bd8cd00036d453908c57b/.gitignore -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | 8 | [packages] 9 | fastapi = "*" 10 | uvicorn = "*" 11 | sqlalchemy = "*" 12 | mysql-connector-python = "*" 13 | bcrypt = "*" 14 | pytest = "*" 15 | pyjwt = {extras = ["crypto"], version = "*"} 16 | 17 | [requires] 18 | python_version = "3.9" 19 | -------------------------------------------------------------------------------- /app_utils.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta, datetime 2 | import jwt 3 | secret_key = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" 4 | algorithm = "HS256" 5 | 6 | 7 | def create_access_token(*, data: dict, expires_delta: timedelta = None): 8 | to_encode = data.copy() 9 | if expires_delta: 10 | expire = datetime.utcnow() + expires_delta 11 | else: 12 | expire = datetime.utcnow() + timedelta(minutes=15) 13 | to_encode.update({"exp": expire}) 14 | encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm) 15 | return encoded_jwt 16 | 17 | 18 | def decode_access_token(*, data: str): 19 | to_decode = data 20 | return jwt.decode(to_decode, secret_key, algorithms=[algorithm]) 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /crud.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy.orm import Session 2 | import models, schemas 3 | import bcrypt 4 | 5 | 6 | def get_user_by_username(db: Session, username: str): 7 | return db.query(models.UserInfo).filter(models.UserInfo.username == username).first() 8 | 9 | 10 | 11 | def create_user(db: Session, user: schemas.UserCreate): 12 | hashed_password = bcrypt.hashpw(user.password.encode('utf-8'), bcrypt.gensalt()) 13 | db_user = models.UserInfo(username=user.username, password=hashed_password, fullname=user.fullname) 14 | db.add(db_user) 15 | db.commit() 16 | db.refresh(db_user) 17 | return db_user 18 | 19 | 20 | def check_username_password(db: Session, user: schemas.UserAuthenticate): 21 | db_user_info: models.UserInfo = get_user_by_username(db, username=user.username) 22 | return bcrypt.checkpw(user.password.encode('utf-8'), db_user_info.password.encode('utf-8')) 23 | 24 | 25 | def create_new_blog(db: Session, blog: schemas.BlogBase): 26 | db_blog = models.Blog(title=blog.title, content=blog.content) 27 | db.add(db_blog) 28 | db.commit() 29 | db.refresh(db_blog) 30 | return db_blog 31 | 32 | 33 | def get_all_blogs(db: Session): 34 | return db.query(models.Blog).all() 35 | 36 | 37 | def get_blog_by_id(db: Session, blog_id: int): 38 | return db.query(models.Blog).filter(models.Blog.id == blog_id).first() 39 | 40 | 41 | 42 | def delete_blog_by_id(db:Session, blog: schemas.Blog): 43 | db.delete(blog) 44 | db.commit() -------------------------------------------------------------------------------- /database.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine 2 | from sqlalchemy.ext.declarative import declarative_base 3 | from sqlalchemy.orm import sessionmaker 4 | 5 | SQLALCHEMY_DATABASE_URL = "mysql+mysqlconnector://root:cuong1990@localhost:3306/restapi" 6 | 7 | engine = create_engine( 8 | SQLALCHEMY_DATABASE_URL, 9 | ) 10 | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) 11 | 12 | Base = declarative_base() 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import uvicorn 2 | from fastapi.security import OAuth2PasswordBearer 3 | from jwt import PyJWTError 4 | from sqlalchemy.orm import Session 5 | from fastapi import Depends, FastAPI, HTTPException 6 | from starlette import status 7 | import crud 8 | import models 9 | import schemas 10 | from app_utils import decode_access_token 11 | from crud import get_user_by_username 12 | from database import engine, SessionLocal 13 | from schemas import UserInfo, TokenData, UserCreate, Token 14 | 15 | models.Base.metadata.create_all(bind=engine) 16 | 17 | ACCESS_TOKEN_EXPIRE_MINUTES = 30 18 | 19 | app = FastAPI(debug=True) 20 | 21 | 22 | # Dependency 23 | 24 | 25 | def get_db(): 26 | db = None 27 | try: 28 | db = SessionLocal() 29 | yield db 30 | finally: 31 | db.close() 32 | 33 | 34 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="authenticate") 35 | 36 | 37 | async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)): 38 | credentials_exception = HTTPException( 39 | status_code=status.HTTP_401_UNAUTHORIZED, 40 | detail="Could not validate credentials", 41 | headers={"WWW-Authenticate": "Bearer"}, 42 | ) 43 | try: 44 | payload = decode_access_token(data=token) 45 | username: str = payload.get("sub") 46 | if username is None: 47 | raise credentials_exception 48 | token_data = TokenData(username=username) 49 | except PyJWTError: 50 | raise credentials_exception 51 | user = get_user_by_username(db, username=token_data.username) 52 | if user is None: 53 | raise credentials_exception 54 | return user 55 | 56 | 57 | @app.post("/user", response_model=UserInfo) 58 | def create_user(user: UserCreate, db: Session = Depends(get_db)): 59 | db_user = crud.get_user_by_username(db, username=user.username) 60 | if db_user: 61 | raise HTTPException(status_code=400, detail="Username already registered") 62 | return crud.create_user(db=db, user=user) 63 | 64 | 65 | @app.post("/authenticate", response_model=Token) 66 | def authenticate_user(user: schemas.UserAuthenticate, db: Session = Depends(get_db)): 67 | db_user = crud.get_user_by_username(db, username=user.username) 68 | if db_user is None: 69 | raise HTTPException(status_code=400, detail="Username not existed") 70 | else: 71 | is_password_correct = crud.check_username_password(db, user) 72 | if is_password_correct is False: 73 | raise HTTPException(status_code=400, detail="Password is not correct") 74 | else: 75 | from datetime import timedelta 76 | access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) 77 | from app_utils import create_access_token 78 | access_token = create_access_token( 79 | data={"sub": user.username}, expires_delta=access_token_expires) 80 | return {"access_token": access_token, "token_type": "Bearer"} 81 | 82 | 83 | @app.post("/blog", response_model=schemas.Blog) 84 | async def create_new_blog(blog: schemas.BlogBase, current_user: UserInfo = Depends(get_current_user) 85 | , db: Session = Depends(get_db)): 86 | return crud.create_new_blog(db=db, blog=blog) 87 | 88 | 89 | @app.get("/blog") 90 | async def get_all_blogs(current_user: UserInfo = Depends(get_current_user) 91 | , db: Session = Depends(get_db)): 92 | return crud.get_all_blogs(db=db) 93 | 94 | 95 | @app.get("/blog/{blog_id}") 96 | async def get_blog_by_id(blog_id, current_user: UserInfo = Depends(get_current_user) 97 | , db: Session = Depends(get_db)): 98 | return crud.get_blog_by_id(db=db, blog_id=blog_id) 99 | 100 | @app.delete("/blog/{blog_id}",status_code=204) 101 | async def delete_blog_by_id(blog_id,current_user: UserInfo = Depends(get_current_user) 102 | , db: Session = Depends(get_db)): 103 | blog_delete = crud.get_blog_by_id(db=db,blog_id=blog_id) 104 | if blog_delete: 105 | crud.delete_blog_by_id(db=db,blog=blog_delete) 106 | 107 | 108 | if __name__ == "__main__": 109 | log_config = uvicorn.config.LOGGING_CONFIG 110 | log_config["formatters"]["access"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s" 111 | log_config["formatters"]["default"]["fmt"] = "%(asctime)s - %(levelname)s - %(message)s" 112 | uvicorn.run(app, log_config=log_config) 113 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, Integer, String 2 | from database import Base 3 | 4 | 5 | class UserInfo(Base): 6 | __tablename__ = "user_info" 7 | 8 | id = Column(Integer, primary_key=True, index=True) 9 | username = Column(String, unique=True) 10 | password = Column(String) 11 | fullname = Column(String, unique=True) 12 | 13 | 14 | class Blog(Base): 15 | __tablename__ = "blog" 16 | 17 | id = Column(Integer, primary_key=True, index=True) 18 | title = Column(String) 19 | content = Column(String) 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asgiref==3.4.1 2 | atomicwrites==1.4.0 3 | attrs==21.2.0 4 | bcrypt==3.2.0 5 | cffi==1.14.6 6 | click==8.0.1 7 | colorama==0.4.4 8 | cryptography==3.4.8 9 | fastapi==0.68.1 10 | greenlet==1.1.1 11 | h11==0.12.0 12 | iniconfig==1.1.1 13 | mysql-connector-python==8.0.26 14 | packaging==21.0 15 | pluggy==1.0.0 16 | protobuf==3.17.3 17 | py==1.10.0 18 | pycparser==2.20 19 | pydantic==1.8.2 20 | PyJWT==2.1.0 21 | pyparsing==2.4.7 22 | pytest==6.2.5 23 | six==1.16.0 24 | SQLAlchemy==1.4.23 25 | starlette==0.14.2 26 | toml==0.10.2 27 | typing-extensions==3.10.0.2 28 | uvicorn==0.15.0 29 | -------------------------------------------------------------------------------- /schemas.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class UserInfoBase(BaseModel): 5 | username: str 6 | 7 | 8 | class UserCreate(UserInfoBase): 9 | fullname: str 10 | password: str 11 | 12 | 13 | class UserAuthenticate(UserInfoBase): 14 | password: str 15 | 16 | 17 | class UserInfo(UserInfoBase): 18 | id: int 19 | 20 | class Config: 21 | orm_mode = True 22 | 23 | 24 | class Token(BaseModel): 25 | access_token: str 26 | token_type: str 27 | 28 | 29 | class TokenData(BaseModel): 30 | username: str = None 31 | 32 | 33 | class BlogBase(BaseModel): 34 | title: str 35 | content: str 36 | 37 | 38 | class Blog(BlogBase): 39 | id: int 40 | 41 | class Config: 42 | orm_mode = True 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /test_debug.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def test_debug(): 4 | from app_utils import create_access_token 5 | from datetime import timedelta 6 | access_token_expires = timedelta(minutes=10) 7 | access_token = create_access_token(data={"sub": "cuongld"}, expires_delta=access_token_expires) 8 | print(access_token) 9 | from app_utils import decode_access_token 10 | decoded_access_token = decode_access_token(data=access_token) 11 | print(decoded_access_token) 12 | 13 | 14 | --------------------------------------------------------------------------------