├── api ├── __init__.py ├── routers │ ├── __init__.py │ ├── followers.py │ ├── auth.py │ ├── tweet_likes.py │ ├── comment_likes.py │ ├── tweets.py │ ├── follows.py │ ├── messages.py │ ├── comments.py │ └── users.py ├── core │ ├── utilities.py │ ├── sendgrid │ │ ├── constants.py │ │ ├── schema.py │ │ ├── utils.py │ │ └── __init__.py │ ├── cors.py │ ├── security.py │ ├── websocket │ │ └── connection_manager.py │ └── config.py ├── schemas │ ├── basic_user.py │ ├── generic.py │ ├── token.py │ ├── base_class.py │ ├── counts.py │ ├── __init__.py │ ├── followers.py │ ├── tweets.py │ ├── follows.py │ ├── chat.py │ ├── messages.py │ ├── auth.py │ ├── comments.py │ ├── websockets.py │ ├── tweet_likes.py │ ├── comment_likes.py │ └── users.py ├── database.py ├── background_functions │ └── email_notifications.py ├── dependencies.py ├── models.py ├── main.py └── crud.py ├── alembic ├── README ├── script.py.mako ├── versions │ ├── 8bf9f70dbe14_create_users_tweets_relationships.py │ ├── 54983669459a_create_follows_table.py │ ├── ee2884e1c996_create_tweet_likes_table.py │ ├── 809293dcf958_create_comment_likes_table.py │ ├── e6e1e11c251b_create_tweets_table.py │ ├── 88c139446b32_create_comments_table.py │ ├── f9b8f7b3cd8a_add_messages_table.py │ └── 64d575f54bd3_create_user_table.py └── env.py ├── .dockerignore ├── .gitignore ├── docker-compose.override.yml ├── .vscode └── settings.json ├── requirements-short.txt ├── Dockerfile-prod ├── Dockerfile-dev ├── LICENSE ├── requirements.txt ├── docker-compose-dev.yml ├── docker-compose-prod.yml ├── alembic.ini ├── docker-compose.traefik.yml └── README.md /api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /api/routers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /alembic/README: 
-------------------------------------------------------------------------------- 1 | Generic single-database configuration. -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | venv 2 | documentation 3 | __pycache__ 4 | *.db 5 | .env 6 | pgdata -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | documentation 3 | __pycache__ 4 | **/.DS_Store 5 | *.db 6 | .env 7 | pgdata -------------------------------------------------------------------------------- /docker-compose.override.yml: -------------------------------------------------------------------------------- 1 | services: 2 | backend: 3 | ports: 4 | - 80:80 5 | 6 | networks: 7 | traefik-public: 8 | external: false 9 | -------------------------------------------------------------------------------- /api/core/utilities.py: -------------------------------------------------------------------------------- 1 | import uuid as uuidlib 2 | 3 | 4 | def generate_random_uuid(): 5 | # Generate a UUID as a confirmation key 6 | return str(uuidlib.uuid4()) 7 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "./venv/bin/python", 3 | "python.formatting.provider": "autopep8", 4 | "python.formatting.autopep8Args": ["--ignore", "E402"] 5 | } 6 | -------------------------------------------------------------------------------- /requirements-short.txt: -------------------------------------------------------------------------------- 1 | fastapi[all] 2 | psycopg2-binary 3 | sqlalchemy 4 | uvicorn[standard] 5 | python-jose[cryptography] 6 | bcrypt 7 | passlib[bcrypt] 8 | alembic 9 | sendgrid 10 | beautifulsoup4 
-------------------------------------------------------------------------------- /Dockerfile-prod: -------------------------------------------------------------------------------- 1 | FROM tiangolo/uvicorn-gunicorn-fastapi:python3.7 2 | 3 | RUN mkdir -p app/api 4 | 5 | COPY ./requirements.txt /app 6 | 7 | COPY ./api /app/api 8 | 9 | RUN pip install --upgrade pip && pip install -r requirements.txt 10 | 11 | 12 | -------------------------------------------------------------------------------- /Dockerfile-dev: -------------------------------------------------------------------------------- 1 | FROM tiangolo/uvicorn-gunicorn-fastapi:python3.7 2 | 3 | RUN mkdir -p app/api 4 | 5 | WORKDIR /app 6 | 7 | COPY ./requirements.txt /app 8 | 9 | COPY ./api /app/api 10 | 11 | RUN pip install --upgrade pip && pip install -r requirements.txt 12 | 13 | 14 | -------------------------------------------------------------------------------- /api/core/sendgrid/constants.py: -------------------------------------------------------------------------------- 1 | # Unsubscribe groups 2 | MAIN_UNSUBSCRIBE_GROUP_ID = 15703 3 | 4 | # Dynamic Template IDs 5 | REGISTRATION_CONFIRMATION_DYNAMIC_TEMPLATE_ID = "d-398f7ce806b84c878126d9d5ae4e9a8e" 6 | NEW_NOTIFICATION_DYNAMIC_TEMPLATE_ID = "d-4e340a8932ac473c91fd3a351e805fa5" 7 | -------------------------------------------------------------------------------- /api/schemas/basic_user.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from datetime import datetime, date 3 | 4 | from pydantic import BaseModel 5 | 6 | from .users import UserBase 7 | 8 | class BasicUser(UserBase): 9 | id: int 10 | 11 | class Config: 12 | orm_mode = True -------------------------------------------------------------------------------- /api/schemas/generic.py: -------------------------------------------------------------------------------- 1 | # Standard Library 2 | from datetime import datetime, date 
3 | 4 | # Types 5 | from typing import List, Optional, Any 6 | from pydantic import BaseModel, validator 7 | 8 | class EmptyResponse(BaseModel): 9 | """Empty HTTP Response (No Data) 10 | """ 11 | pass -------------------------------------------------------------------------------- /api/schemas/token.py: -------------------------------------------------------------------------------- 1 | # Standard Library 2 | from datetime import datetime, date 3 | 4 | # Types 5 | from typing import List, Optional, Any 6 | from pydantic import BaseModel, validator 7 | 8 | 9 | class Token(BaseModel): 10 | access_token: str 11 | token_type: str 12 | 13 | 14 | class TokenData(BaseModel): 15 | email: Optional[str] = None 16 | -------------------------------------------------------------------------------- /api/schemas/base_class.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from sqlalchemy.ext.declarative import as_declarative, declared_attr 4 | 5 | 6 | @as_declarative() 7 | class Base: 8 | id: Any 9 | __name__: str 10 | # Generate __tablename__ automatically 11 | @declared_attr 12 | def __tablename__(cls) -> str:  # table name = lower-cased class name 13 | return cls.__name__.lower() -------------------------------------------------------------------------------- /api/schemas/counts.py: -------------------------------------------------------------------------------- 1 | # Standard Library 2 | from datetime import datetime, date 3 | 4 | # Types 5 | from typing import List, Optional, Any 6 | from pydantic import BaseModel, validator 7 | 8 | 9 | class CountBase(BaseModel): 10 | count: int 11 | 12 | 13 | class TweetCommentCount(CountBase): 14 | """The number of comments for a given tweet 15 | """ 16 | pass 17 | -------------------------------------------------------------------------------- /api/core/sendgrid/schema.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import os 3 | 4 | from
pydantic import EmailStr 5 | 6 | 7 | class EmailSender(EmailStr, Enum): 8 | ACCOUNT = os.environ.get("SEND_GRID_ACCOUNT_VERIFICATION_FROM_EMAIL") 9 | NOTIFICATIONS = os.environ.get("SEND_GRID_NOTIFICATIONS_FROM_EMAIL") 10 | PASSWORD_RECOVERY = os.environ.get( 11 | "SEND_GRID_PASSWORD_RECOVERY_FROM_EMAIL") 12 | -------------------------------------------------------------------------------- /api/core/sendgrid/utils.py: -------------------------------------------------------------------------------- 1 | # Standard Library 2 | import os 3 | import http.client 4 | 5 | # SendGrid 6 | from sendgrid import SendGridAPIClient 7 | from sendgrid.helpers.mail import Content 8 | 9 | # Beautiful Soup 10 | from bs4 import BeautifulSoup 11 | 12 | 13 | def get_text_from_html(html): 14 | soup = BeautifulSoup(html, "html.parser")  # explicit stdlib parser: no GuessedAtParserWarning, same output on every host 15 | plain_text = soup.get_text() 16 | return Content("text/plain", plain_text) 17 | -------------------------------------------------------------------------------- /api/schemas/__init__.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from pydantic import BaseModel 3 | from .websockets import * 4 | from .comments import * 5 | from .tweets import * 6 | from .users import * 7 | from .basic_user import * 8 | from .followers import * 9 | from .follows import * 10 | from .tweet_likes import * 11 | from .comment_likes import * 12 | from .messages import * 13 | from .chat import * 14 | from .auth import * 15 | from .token import * 16 | from .generic import * 17 | from .counts import * 18 | -------------------------------------------------------------------------------- /api/schemas/followers.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, ForwardRef, Generic, Any 2 | from pydantic import BaseModel 3 | 4 | from datetime import datetime, date 5 | 6 | from .
import BasicUser 7 | 8 | class Follower(BaseModel): 9 | user: BasicUser 10 | followsUser: BasicUser 11 | 12 | class Config: 13 | orm_mode = True 14 | 15 | 16 | class FollowersRequestBody(BaseModel): 17 | userId: int 18 | 19 | class FollowersResponse(BaseModel): 20 | userId: int 21 | email: str 22 | username: str 23 | bio: Optional[str]  # users.bio is a nullable column 24 | birthdate: Optional[date]  # users.birthdate is a nullable column 25 | -------------------------------------------------------------------------------- /alembic/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | ${imports if imports else ""} 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = ${repr(up_revision)} 14 | down_revision = ${repr(down_revision)} 15 | branch_labels = ${repr(branch_labels)} 16 | depends_on = ${repr(depends_on)} 17 | 18 | 19 | def upgrade(): 20 | ${upgrades if upgrades else "pass"} 21 | 22 | 23 | def downgrade(): 24 | ${downgrades if downgrades else "pass"} 25 | -------------------------------------------------------------------------------- /api/core/sendgrid/__init__.py: -------------------------------------------------------------------------------- 1 | # Standard Library 2 | import os 3 | from typing import Union, List 4 | from enum import Enum 5 | 6 | from pydantic import EmailStr 7 | 8 | 9 | # SendGrid API 10 | from sendgrid import SendGridAPIClient, Personalization, Asm 11 | from sendgrid.helpers.mail import Mail 12 | 13 | # Utilities 14 | from .schema import EmailSender 15 | from ...
import models 16 | 17 | 18 | async def send_email(message: Mail): 19 | try: 20 | sg = SendGridAPIClient(os.environ.get("SEND_GRID_API_KEY")) 21 | response = sg.send(message) 22 | 23 | except Exception as e: 24 | print("error sending email", e) 25 | -------------------------------------------------------------------------------- /api/schemas/tweets.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, ForwardRef, Generic, Any 2 | from pydantic import BaseModel 3 | 4 | from datetime import datetime, date 5 | 6 | 7 | class TweetBase(BaseModel): 8 | content: str 9 | 10 | 11 | class TweetCreate(TweetBase): 12 | pass 13 | 14 | class TweetUpdate(BaseModel): 15 | newContent: str 16 | 17 | 18 | class Tweet(TweetBase): 19 | id: int 20 | userId: int 21 | username: str 22 | content: str 23 | created_at: datetime 24 | 25 | class Config: 26 | orm_mode = True 27 | 28 | class TweetResponse(BaseModel): 29 | tweetId: int 30 | userId: int 31 | username: str 32 | content: str 33 | createdAt: datetime -------------------------------------------------------------------------------- /api/schemas/follows.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, ForwardRef, Generic, Any 2 | from pydantic import BaseModel 3 | 4 | from datetime import datetime, date 5 | 6 | from . 
import BasicUser 7 | 8 | 9 | class Follow(BaseModel): 10 | user: BasicUser 11 | followsUser: BasicUser 12 | 13 | 14 | class FollowsCreateRequestBody(BaseModel): 15 | followUserId: int 16 | 17 | 18 | class FollowsDeleteRequestBody(BaseModel): 19 | followUserId: int 20 | 21 | 22 | class FollowsResponse(BaseModel): 23 | userId: int 24 | email: str 25 | username: str 26 | bio: Optional[str]  # users.bio is a nullable column 27 | birthdate: Optional[date]  # users.birthdate is a nullable column 28 | 29 | class Config: 30 | orm_mode = True 31 | 32 | 33 | class WSFollowsUpdateBody(BaseModel): 34 | userId: int 35 | followUserId: int 36 | -------------------------------------------------------------------------------- /api/schemas/chat.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from datetime import datetime, date 3 | 4 | from pydantic import BaseModel 5 | 6 | from . import Message 7 | 8 | 9 | class ChatUserOnlineRequestBody(BaseModel): 10 | userId: int 11 | 12 | 13 | class ChatUserOnlineResponseBody(BaseModel): 14 | userId: Optional[int] 15 | username: Optional[str] 16 | isOnline: bool 17 | 18 | 19 | class ChatUserTypingRequestBody(BaseModel): 20 | userId: int # other user 21 | 22 | 23 | class ChatUserTypingResponseBody(BaseModel): 24 | isTyping: bool 25 | 26 | 27 | class NewChatMessageResponseBody(Message): 28 | # Just a message 29 | pass 30 | 31 | 32 | class DeletedChatMessageResponseBody(BaseModel): 33 | messageId: int 34 | userId: int 35 | -------------------------------------------------------------------------------- /alembic/versions/8bf9f70dbe14_create_users_tweets_relationships.py: -------------------------------------------------------------------------------- 1 | """create users-tweets relationships 2 | 3 | Revision ID: 8bf9f70dbe14 4 | Revises: e6e1e11c251b 5 | Create Date: 2021-03-16 22:07:56.593450 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic.
13 | revision = '8bf9f70dbe14' 14 | down_revision = 'e6e1e11c251b' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade(): 20 | """Since relationships are only defined in SQLAlchemy. This migration version 21 | is not needed. Skip 22 | """ 23 | pass 24 | 25 | 26 | def downgrade(): 27 | """Since relationships are only defined in SQLAlchemy. This migration version 28 | is not needed. Skip 29 | """ 30 | pass 31 | -------------------------------------------------------------------------------- /api/database.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine 2 | from sqlalchemy.ext.declarative import declarative_base 3 | from sqlalchemy.orm import sessionmaker 4 | 5 | from .core.config import get_db_connection_url 6 | 7 | import time 8 | import os 9 | 10 | engine = None 11 | SessionLocal = None 12 | Base = None 13 | 14 | retries = 5 15 | while retries > 0: 16 | try: 17 | engine = create_engine( 18 | get_db_connection_url() 19 | ) 20 | # create_engine() is lazy and never contacts the server, so open (and 21 | # immediately close) one connection; otherwise this loop can never retry. 22 | with engine.connect(): 23 | pass 24 | SessionLocal = sessionmaker( 25 | autocommit=False, autoflush=False, bind=engine) 26 | Base = declarative_base() 27 | print("DB Connected") 28 | break 29 | except Exception as e: 30 | print("Error connecting..."
+ str(e)) 31 | retries -= 1 32 | time.sleep(3) 33 | else: 34 | # Every attempt failed: fail fast instead of leaving Base/SessionLocal as None. 35 | raise RuntimeError("Could not connect to the database after 5 attempts") 36 | -------------------------------------------------------------------------------- /api/schemas/messages.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from datetime import datetime, date 3 | 4 | from pydantic import BaseModel 5 | 6 | 7 | class Message(BaseModel): 8 | id: Optional[int] 9 | userFromId: int 10 | userFromUsername: str 11 | userToId: int 12 | userToUsername: str 13 | content: str 14 | createdAt: datetime 15 | 16 | 17 | class Conversation(BaseModel): 18 | __root__: List[Message]  # pydantic v1 custom root type: a conversation is a list of messages 19 | 20 | 21 | class MessageResponse(BaseModel): 22 | conversations: List[Conversation] 23 | 24 | 25 | class MessageCreateRequestBody(BaseModel): 26 | content: str 27 | userToId: int 28 | 29 | 30 | class MessageDeleteRequestBody(BaseModel): 31 | messageId: int 32 | 33 | 34 | class MessageUpdateRequestBody(BaseModel): 35 | messageId: int 36 | newContent: str 37 | -------------------------------------------------------------------------------- /alembic/versions/54983669459a_create_follows_table.py: -------------------------------------------------------------------------------- 1 | """create follows table 2 | 3 | Revision ID: 54983669459a 4 | Revises: 8bf9f70dbe14 5 | Create Date: 2021-03-17 14:23:19.751988 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic.
13 | revision = '54983669459a' 14 | down_revision = '8bf9f70dbe14' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade(): 20 | op.create_table( 21 | 'follows', 22 | sa.Column('id', sa.Integer, primary_key=True, index=True), 23 | sa.Column('user_id', sa.Integer, sa.ForeignKey("users.id", ondelete="CASCADE")), 24 | sa.Column('follows_user_id', sa.Integer, sa.ForeignKey("users.id", ondelete="CASCADE")) 25 | ) 26 | 27 | 28 | def downgrade(): 29 | op.drop_table('follows') 30 | -------------------------------------------------------------------------------- /alembic/versions/ee2884e1c996_create_tweet_likes_table.py: -------------------------------------------------------------------------------- 1 | """create tweet_likes table 2 | 3 | Revision ID: ee2884e1c996 4 | Revises: 54983669459a 5 | Create Date: 2021-03-17 14:25:31.779045 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = 'ee2884e1c996' 14 | down_revision = '54983669459a' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade(): 20 | op.create_table( 21 | 'tweet_likes', 22 | sa.Column('id', sa.Integer, primary_key=True, index=True), 23 | sa.Column('user_id', sa.Integer, sa.ForeignKey("users.id", ondelete="CASCADE")), 24 | sa.Column('tweet_id', sa.Integer, sa.ForeignKey("tweets.id", ondelete="CASCADE")) 25 | ) 26 | 27 | 28 | def downgrade(): 29 | op.drop_table('tweet_likes') 30 | -------------------------------------------------------------------------------- /alembic/versions/809293dcf958_create_comment_likes_table.py: -------------------------------------------------------------------------------- 1 | """create comment_likes table 2 | 3 | Revision ID: 809293dcf958 4 | Revises: 88c139446b32 5 | Create Date: 2021-03-17 14:28:19.459816 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 
13 | revision = '809293dcf958' 14 | down_revision = '88c139446b32' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade(): 20 | op.create_table( 21 | 'comment_likes', 22 | sa.Column('id', sa.Integer, primary_key=True, index=True), 23 | sa.Column('user_id', sa.Integer, sa.ForeignKey("users.id", ondelete="CASCADE")), 24 | sa.Column('comment_id', sa.Integer, sa.ForeignKey("comments.id", ondelete="CASCADE")), 25 | ) 26 | 27 | 28 | def downgrade(): 29 | op.drop_table('comment_likes') -------------------------------------------------------------------------------- /alembic/versions/e6e1e11c251b_create_tweets_table.py: -------------------------------------------------------------------------------- 1 | """create tweets table 2 | 3 | Revision ID: e6e1e11c251b 4 | Revises: 64d575f54bd3 5 | Create Date: 2021-03-16 22:04:34.778473 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = 'e6e1e11c251b' 14 | down_revision = '64d575f54bd3' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade(): 20 | op.create_table( 21 | 'tweets', 22 | sa.Column('id', sa.Integer, primary_key=True, index=True), 23 | sa.Column('content', sa.String, index=True), 24 | sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()), 25 | sa.Column('user_id', sa.Integer, sa.ForeignKey("users.id", ondelete="CASCADE")) 26 | ) 27 | 28 | def downgrade(): 29 | op.drop_table('tweets') 30 | -------------------------------------------------------------------------------- /api/schemas/auth.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from datetime import datetime, date 3 | 4 | from pydantic import BaseModel, AnyHttpUrl, EmailStr, HttpUrl, validator 5 | 6 | 7 | class LoginRequest(BaseModel): 8 | email: EmailStr 9 | password: str 10 | 11 | class Config: 12 | orm_mode = False 13 | 14 | 15 
| class LoginResponse(BaseModel): 16 | success: Optional[bool] 17 | error: Optional[bool] 18 | message: Optional[str] = None 19 | 20 | class Config: 21 | orm_mode = False 22 | 23 | 24 | class LogoutResponse(BaseModel): 25 | success: bool 26 | error: bool 27 | message: Optional[str] = None 28 | 29 | class Config: 30 | orm_mode = False 31 | 32 | 33 | class RegisterResponse(BaseModel): 34 | success: bool 35 | error: bool 36 | message: Optional[str] = None 37 | 38 | class Config: 39 | orm_mode = False 40 | -------------------------------------------------------------------------------- /api/core/cors.py: -------------------------------------------------------------------------------- 1 | # C.O.R.S. 2 | cors_origins = [ 3 | "http://localhost", 4 | "http://localhost:8080", 5 | "http://localhost:3000", 6 | # Production Client on Vercel 7 | "https://twitter-clone.programmertutor.com", 8 | "https://www.twitter-clone.programmertutor.com", 9 | "https://twitter.dericfagnan.com", 10 | "https://www.twitter.dericfagnan.com", 11 | # Websocket Origins 12 | "ws://localhost", 13 | "wss://localhost", 14 | "ws://localhost:8080", 15 | "wss://localhost:8080", 16 | "ws://twitter-clone.programmertutor.com", 17 | "ws://www.twitter-clone.programmertutor.com", 18 | "wss://twitter-clone.programmertutor.com", 19 | "wss://www.twitter-clone.programmertutor.com", 20 | "ws://twitter.dericfagnan.com", 21 | "ws://www.twitter.dericfagnan.com", 22 | "wss://twitter.dericfagnan.com", 23 | "wss://www.twitter.dericfagnan.com", 24 | ] 25 | -------------------------------------------------------------------------------- /api/schemas/comments.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from datetime import datetime, date 3 | 4 | from pydantic import BaseModel 5 | 6 | 7 | class CommentBase(BaseModel): 8 | id: Optional[int] 9 | userId: int 10 | tweetId: int 11 | content: str 12 | createdAt: datetime 13 | 14 | 15 | class 
Comment(CommentBase): 16 | # tweet: List[Tweet] = [] 17 | username: str 18 | 19 | class Config: 20 | orm_mode = True 21 | 22 | 23 | class CommentCreate(BaseModel): 24 | content: str 25 | tweetId: int 26 | 27 | 28 | class CommentDelete(BaseModel): 29 | commentId: int 30 | 31 | 32 | class CommentUpdate(BaseModel): 33 | commentId: int 34 | newContent: str 35 | 36 | 37 | class WSCommentCreated(BaseModel): 38 | comment: Comment 39 | 40 | 41 | class WSCommentUpdated(WSCommentCreated): 42 | pass # same as created 43 | 44 | 45 | class WSCommentDeleted(BaseModel): 46 | tweetId: int 47 | commentId: int 48 | -------------------------------------------------------------------------------- /alembic/versions/88c139446b32_create_comments_table.py: -------------------------------------------------------------------------------- 1 | """create comments table 2 | 3 | Revision ID: 88c139446b32 4 | Revises: ee2884e1c996 5 | Create Date: 2021-03-17 14:26:59.875478 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 
13 | revision = '88c139446b32' 14 | down_revision = 'ee2884e1c996' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade(): 20 | op.create_table( 21 | 'comments', 22 | sa.Column('id', sa.Integer, primary_key=True, index=True), 23 | sa.Column('user_id', sa.Integer, sa.ForeignKey("users.id", ondelete="CASCADE")), 24 | sa.Column('tweet_id', sa.Integer, sa.ForeignKey("tweets.id", ondelete="CASCADE")), 25 | sa.Column('content', sa.String, index=True), 26 | sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()) 27 | ) 28 | 29 | 30 | def downgrade(): 31 | op.drop_table('comments') -------------------------------------------------------------------------------- /alembic/versions/f9b8f7b3cd8a_add_messages_table.py: -------------------------------------------------------------------------------- 1 | """add messages_table 2 | 3 | Revision ID: f9b8f7b3cd8a 4 | Revises: 809293dcf958 5 | Create Date: 2021-04-27 23:55:17.717957 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 
13 | revision = 'f9b8f7b3cd8a' 14 | down_revision = '809293dcf958' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade(): 20 | op.create_table( 21 | 'messages', 22 | sa.Column('id', sa.Integer, primary_key=True, index=True), 23 | sa.Column('user_from_id', sa.Integer, sa.ForeignKey( 24 | "users.id", ondelete="CASCADE"), index=True), 25 | sa.Column('user_to_id', sa.Integer, sa.ForeignKey( 26 | "users.id", ondelete="CASCADE"), index=True,), 27 | sa.Column('content', sa.String), 28 | sa.Column('is_read', sa.Boolean, default=False), 29 | sa.Column('created_at', sa.DateTime(timezone=True), 30 | server_default=sa.func.now()) 31 | ) 32 | 33 | 34 | def downgrade(): 35 | op.drop_table('messages') 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Deric Fagnan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /api/schemas/websockets.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import List, Optional, Generic, TypeVar 3 | from datetime import datetime, date 4 | 5 | from pydantic import BaseModel 6 | from pydantic.generics import GenericModel 7 | 8 | 9 | class WSMessageAction(str, Enum): 10 | ChatMessageNew = "chat.message.new" 11 | ChatMessageDeleted = "chat.message.deleted" 12 | ChatUserOnline = "chat.user.online" 13 | ChatUserTyping = "chat.user.typing" 14 | AuthRequired = "auth.required" 15 | NewFollower = "followers.followed" 16 | LostFollower = "followers.unfollowed" 17 | NewComment = "comments.new" 18 | DeletedComment = "comments.deleted" 19 | UpdatedComment = "comments.updated" 20 | UpdatedCommentLike = "comments.likes.changed" 21 | UpdatedTweetLike = "tweets.likes.changed" 22 | 23 | 24 | WSMessageBody = TypeVar("WSMessageBody")  # TypeVar name must match the variable it is bound to 25 | 26 | 27 | class WSMessageError(BaseModel): 28 | code: int 29 | message: str 30 | 31 | # WSAction = 32 | 33 | 34 | class WSMessage(GenericModel, Generic[WSMessageBody]): 35 | action: WSMessageAction 36 | body: WSMessageBody 37 | status: Optional[int] 38 | error: Optional[WSMessageError] 39 | -------------------------------------------------------------------------------- /api/schemas/tweet_likes.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, ForwardRef, Generic, Any 2 | from pydantic import BaseModel 3 | 4 | from datetime import datetime, date 5 | 6 | 7 | class BasicTweet(BaseModel): 8 | id: int 9 | userId: int 10 | 11 | class
Config: 12 | orm_mode = True 13 | 14 | 15 | class BasicUser(BaseModel): 16 | id: int 17 | email: str 18 | username: str 19 | bio: Optional[str] 20 | 21 | class Config: 22 | orm_mode = True 23 | 24 | 25 | class TweetLikeCreateResponseBody(BaseModel): 26 | tweetId: int 27 | userId: int 28 | username: str 29 | 30 | 31 | class TweetLikeResponseBody(BaseModel): 32 | tweetId: int 33 | userId: int 34 | username: str 35 | 36 | 37 | class TweetLike(BaseModel): 38 | tweetId: int 39 | tweet: BasicTweet 40 | userId: BasicUser 41 | 42 | class Config: 43 | orm_mode = True 44 | 45 | 46 | class TweetLikeCreateRequestBody(BaseModel): 47 | tweetId: int 48 | 49 | 50 | class TweetLikeDeleteRequestBody(BaseModel): 51 | tweetId: int 52 | 53 | 54 | class WSTweetLikeUpdated(BaseModel): 55 | isLiked: bool 56 | tweetLike: TweetLikeResponseBody 57 | -------------------------------------------------------------------------------- /api/schemas/comment_likes.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, ForwardRef, Generic, Any 2 | from pydantic import BaseModel 3 | 4 | from datetime import datetime, date 5 | 6 | 7 | class BasicComment(BaseModel): 8 | id: int 9 | userId: int 10 | 11 | class Config: 12 | orm_mode = True 13 | 14 | 15 | class BasicUser(BaseModel): 16 | id: int 17 | email: str 18 | username: str 19 | bio: Optional[str] 20 | 21 | class Config: 22 | orm_mode = True 23 | 24 | 25 | class CommentLikeCreateResponseBody(BaseModel): 26 | commentId: int 27 | userId: int 28 | username: str 29 | 30 | 31 | class CommentLikeResponseBody(BaseModel): 32 | commentLikeId: int 33 | commentId: int 34 | userId: int 35 | username: str 36 | 37 | 38 | class CommentLike(BaseModel): 39 | commentId: int 40 | comment: BasicComment 41 | userId: BasicUser 42 | 43 | class Config: 44 | orm_mode = True 45 | 46 | 47 | class CommentLikeCreateRequestBody(BaseModel): 48 | commentId: Optional[int] 49 | 50 | 51 | class 
CommentLikeDeleteRequestBody(BaseModel): 52 | commentId: int 53 | 54 | 55 | class WSCommentLikeUpdated(BaseModel): 56 | isLiked: bool 57 | commentLike: CommentLikeResponseBody 58 | -------------------------------------------------------------------------------- /alembic/versions/64d575f54bd3_create_user_table.py: -------------------------------------------------------------------------------- 1 | """create user table 2 | 3 | Revision ID: 64d575f54bd3 4 | Revises: 5 | Create Date: 2021-03-16 21:26:48.338701 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = '64d575f54bd3' 14 | down_revision = None 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade(): 20 | op.create_table( 21 | 'users', 22 | sa.Column('id', sa.Integer, primary_key=True, index=True), 23 | sa.Column('email', sa.String, unique=True, index=True, nullable=False), 24 | sa.Column('username', sa.String, index=True), 25 | sa.Column('created_at', sa.DateTime(timezone=True), 26 | server_default=sa.func.now()), 27 | sa.Column('updated_at', sa.DateTime(timezone=True), 28 | server_default=sa.func.now()), 29 | sa.Column('bio', sa.String, index=True), 30 | sa.Column('birthdate', sa.Date, index=True), 31 | sa.Column('hashed_password', sa.String, nullable=False), 32 | sa.Column('confirmation_key', sa.String, nullable=True), 33 | sa.Column('account_verified', sa.Boolean, nullable=True, default=False) 34 | 35 | ) 36 | 37 | 38 | def downgrade(): 39 | op.drop_table('users') 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiofiles==0.5.0 2 | alembic==1.5.8 3 | aniso8601==7.0.0 4 | appdirs==1.4.4 5 | async-exit-stack==1.0.1 6 | async-generator==1.10 7 | autopep8==1.5.6 8 | bcrypt==3.2.0 9 | beautifulsoup4==4.9.3 10 | black==21.4b2 11 | certifi==2020.12.5 12 | cffi==1.14.5 
13 | chardet==4.0.0 14 | click==7.1.2 15 | cryptography==3.4.7 16 | dnspython==2.1.0 17 | ecdsa==0.14.1 18 | email-validator==1.1.2 19 | fastapi==0.65.2 20 | graphene==2.1.8 21 | graphql-core==2.3.2 22 | graphql-relay==2.0.1 23 | greenlet==1.0.0 24 | h11==0.12.0 25 | httptools==0.1.1 26 | idna==2.10 27 | itsdangerous==1.1.0 28 | Jinja2==2.11.3 29 | Mako==1.1.4 30 | MarkupSafe==1.1.1 31 | mypy-extensions==0.4.3 32 | orjson==3.5.1 33 | passlib==1.7.4 34 | pathspec==0.8.1 35 | promise==2.3 36 | psycopg2-binary==2.8.6 37 | pyasn1==0.4.8 38 | pycodestyle==2.7.0 39 | pycparser==2.20 40 | pydantic==1.8.2 41 | python-dateutil==2.8.1 42 | python-dotenv==0.17.0 43 | python-editor==1.0.4 44 | python-http-client==3.3.2 45 | python-jose==3.2.0 46 | python-multipart==0.0.5 47 | PyYAML==5.4.1 48 | regex==2021.4.4 49 | requests==2.25.1 50 | rsa==4.7.2 51 | Rx==1.6.1 52 | sendgrid==6.7.0 53 | six==1.15.0 54 | soupsieve==2.2.1 55 | SQLAlchemy==1.4.5 56 | starkbank-ecdsa==1.1.0 57 | starlette==0.14.2 58 | toml==0.10.2 59 | typing-extensions==3.7.4.3 60 | ujson==3.2.0 61 | urllib3==1.26.5 62 | uvicorn==0.13.4 63 | uvloop==0.15.2 64 | watchgod==0.7 65 | websockets==9.1 66 | -------------------------------------------------------------------------------- /api/schemas/users.py: -------------------------------------------------------------------------------- 1 | # Standard Library 2 | from datetime import datetime, date 3 | 4 | # Types 5 | from typing import List, Optional, Any 6 | from pydantic import BaseModel, validator 7 | 8 | # SQLAlchemy 9 | from sqlalchemy.orm import Query 10 | 11 | 12 | class UserBase(BaseModel): 13 | email: str 14 | username: str 15 | bio: Optional[str] 16 | birthdate: Optional[date] 17 | 18 | 19 | class UserCreate(UserBase): 20 | password: str 21 | 22 | 23 | class User(UserBase): 24 | id: int 25 | 26 | class Config: 27 | orm_mode = True 28 | 29 | 30 | class UserWithPassword(User): 31 | hashed_password: str 32 | 33 | class Config: 34 | orm_mode = True 35 | 36 | 
37 | class UserDeleteRequestBody(BaseModel): 38 | password: str 39 | 40 | 41 | class UserUpdateResponseBody(BaseModel): 42 | id: int 43 | email: str 44 | username: str 45 | bio: Optional[str] 46 | birthdate: Optional[date] 47 | 48 | class Config: 49 | orm_mode = True 50 | 51 | 52 | class UserUpdateRequestBody(BaseModel): 53 | password: str 54 | newUsername: Optional[str] 55 | newBio: Optional[str] 56 | 57 | 58 | class BasicTweet(BaseModel): 59 | id: int 60 | content: str 61 | createdAt: datetime 62 | 63 | 64 | class UserResponse(BaseModel): 65 | id: int 66 | email: str 67 | username: str 68 | bio: Optional[str] 69 | birthdate: Optional[date] 70 | # tweets: Optional[List[BasicTweet]] = [] 71 | 72 | 73 | class UserAccountConfirmationRequestBody(BaseModel): 74 | confirmationKey: str 75 | -------------------------------------------------------------------------------- /docker-compose-dev.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: "3.9" 3 | services: 4 | db: 5 | image: postgres 6 | restart: always 7 | environment: 8 | POSTGRES_DB: $POSTGRES_DB 9 | POSTGRES_USER: $POSTGRES_USER 10 | POSTGRES_PASSWORD: $POSTGRES_PASSWORD 11 | PGDATA: /var/lib/postgresql/data 12 | volumes: 13 | - ./pgdata:/var/lib/postgresql/data 14 | ports: 15 | - "5433:5432" 16 | # graphql-engine: 17 | # image: hasura/graphql-engine:v1.3.3 18 | # ports: 19 | # - "8432:8080" 20 | # depends_on: 21 | # - "db" 22 | # restart: always 23 | # environment: 24 | # HASURA_GRAPHQL_DATABASE_URL: $LOCAL_DOCKER_INTERNAL_POSTGRES_URL 25 | # ## enable the console served by server 26 | # HASURA_GRAPHQL_ENABLE_CONSOLE: "true" # set to "false" to disable console 27 | # ## enable debugging mode. 
It is recommended to disable this in production 28 | # HASURA_GRAPHQL_DEV_MODE: "true" 29 | # HASURA_GRAPHQL_ENABLED_LOG_TYPES: startup, http-log, webhook-log, websocket-log, query-log 30 | # ## uncomment next line to set an admin secret 31 | # HASURA_GRAPHQL_ADMIN_SECRET: $SECRET_KEY 32 | web: 33 | build: 34 | context: . 35 | dockerfile: Dockerfile-dev 36 | ports: 37 | - "8001:80" 38 | depends_on: 39 | - db 40 | volumes: 41 | - ./api:/app/api 42 | environment: 43 | MODULE_NAME: "api.main" 44 | POSTGRES_URL: $POSTGRES_URL 45 | env_file: .env 46 | entrypoint: /start-reload.sh # For dev only - adds hot-reloading 47 | volumes: 48 | db-data: null 49 | pgadmin-data: null 50 | -------------------------------------------------------------------------------- /api/routers/followers.py: -------------------------------------------------------------------------------- 1 | # FastAPI 2 | from fastapi import APIRouter, HTTPException, Request, Depends, status 3 | 4 | # SQLAlchemy 5 | from sqlalchemy.orm import Session 6 | 7 | # Types 8 | from typing import List, Optional 9 | 10 | # Custom Modules 11 | from .. import schemas, crud 12 | from ..dependencies import get_db, get_current_user 13 | from ..core import security 14 | from ..core.config import settings 15 | 16 | # FastAPI router object 17 | router = APIRouter(prefix="/followers", tags=['followers']) 18 | 19 | 20 | @router.get("/{userId}", response_model=List[schemas.FollowersResponse]) 21 | def get_all_tweets(userId: int, db: Session = Depends(get_db)): 22 | """ 23 | The GET method for this endpoint requires a userId and will send 24 | back information about all users the follow that user. 25 | 26 | This endpoint will always return an array of objects. 
27 | """ 28 | followers: List[schemas.Follower] = crud.get_all_followers(db, userId) 29 | return [ 30 | schemas.FollowersResponse( 31 | userId=follower.user.id, 32 | email=follower.user.email, 33 | username=follower.user.username, 34 | bio=follower.user.bio, 35 | birthdate=follower.user.birthdate 36 | ) for follower in followers 37 | ] 38 | 39 | 40 | @router.get("/count/{userId}", response_model=schemas.CountBase) 41 | def get_followers_count_for_user( 42 | userId: int, 43 | db: Session = Depends(get_db) 44 | ): 45 | count = crud.get_followers_for_user(db, user_id=userId) 46 | 47 | return schemas.CountBase( 48 | count=count 49 | ) 50 | -------------------------------------------------------------------------------- /api/core/security.py: -------------------------------------------------------------------------------- 1 | # Standard Library 2 | from datetime import datetime, timedelta 3 | 4 | # Types 5 | from typing import Any, Union, Optional 6 | 7 | # SQLAlchemy 8 | from sqlalchemy.orm import Session 9 | 10 | # JWT 11 | from jose import jwt, JWTError 12 | 13 | # Hashing 14 | from passlib.context import CryptContext 15 | 16 | # Custom Modules 17 | from .config import settings 18 | from ..schemas import User 19 | from .. 
import crud 20 | 21 | pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") 22 | 23 | 24 | ALGORITHM = "HS256" 25 | 26 | 27 | def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: 28 | """Create a JWT (access token) based on the provided data 29 | """ 30 | to_encode = data.copy() 31 | if expires_delta: 32 | expire = datetime.utcnow() + expires_delta 33 | else: 34 | expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) 35 | to_encode.update({"exp": expire}) 36 | encoded_jwt = jwt.encode( 37 | to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) 38 | return encoded_jwt 39 | 40 | 41 | def verify_password(plain_password: str, hashed_password: str) -> bool: 42 | """Check that hashed(plain_password) matches hashed_password. 43 | """ 44 | return pwd_context.verify(plain_password, hashed_password) 45 | 46 | 47 | def get_password_hash(password: str) -> str: 48 | """Return the hashed version of password 49 | """ 50 | return pwd_context.hash(password) 51 | 52 | 53 | def decode_token(token: str): 54 | """Return a dictionary that represents the decoded JWT. 55 | """ 56 | return jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) 57 | 58 | 59 | def authenticate_user(db: Session, email: str, password: str) -> Union[bool, User]: 60 | """Based on the provided email & password, verify that the credentials match 61 | the records contained in the database. 
62 | """ 63 | user = crud.get_user_by_email_or_username(db, email) 64 | if not user: 65 | # No user with that email exists in the database 66 | return False 67 | if not verify_password(password, user.hashed_password): 68 | # The user exists but the password was incorrect 69 | return False 70 | return user 71 | -------------------------------------------------------------------------------- /alembic/env.py: -------------------------------------------------------------------------------- 1 | from logging.config import fileConfig 2 | 3 | from sqlalchemy import engine_from_config, create_engine 4 | from sqlalchemy import pool 5 | 6 | from alembic import context 7 | 8 | from api.core.config import get_db_connection_url 9 | 10 | import os 11 | 12 | # this is the Alembic Config object, which provides 13 | # access to the values within the .ini file in use. 14 | config = context.config 15 | 16 | # Interpret the config file for Python logging. 17 | # This line sets up loggers basically. 18 | fileConfig(config.config_file_name) 19 | 20 | # add your model's MetaData object here 21 | # for 'autogenerate' support 22 | # from myapp import mymodel 23 | # target_metadata = mymodel.Base.metadata 24 | target_metadata = None 25 | 26 | # other values from the config, defined by the needs of env.py, 27 | # can be acquired: 28 | # my_important_option = config.get_main_option("my_important_option") 29 | # ... etc. 30 | 31 | 32 | def run_migrations_offline(): 33 | """Run migrations in 'offline' mode. 34 | 35 | This configures the context with just a URL 36 | and not an Engine, though an Engine is acceptable 37 | here as well. By skipping the Engine creation 38 | we don't even need a DBAPI to be available. 39 | 40 | Calls to context.execute() here emit the given string to the 41 | script output. 
42 | 43 | """ 44 | context.configure( 45 | url=get_db_connection_url(), 46 | target_metadata=target_metadata, 47 | literal_binds=True, 48 | dialect_opts={"paramstyle": "named"}, 49 | ) 50 | 51 | with context.begin_transaction(): 52 | context.run_migrations() 53 | 54 | 55 | def run_migrations_online(): 56 | """Run migrations in 'online' mode. 57 | 58 | In this scenario we need to create an Engine 59 | and associate a connection with the context. 60 | 61 | """ 62 | connectable = create_engine(get_db_connection_url()) 63 | 64 | with connectable.connect() as connection: 65 | context.configure( 66 | connection=connection, target_metadata=target_metadata 67 | ) 68 | 69 | with context.begin_transaction(): 70 | context.run_migrations() 71 | 72 | 73 | if context.is_offline_mode(): 74 | run_migrations_offline() 75 | else: 76 | run_migrations_online() 77 | -------------------------------------------------------------------------------- /docker-compose-prod.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | backend: 5 | build: 6 | context: . 
7 | dockerfile: Dockerfile-prod 8 | restart: always 9 | environment: 10 | MODULE_NAME: "api.main" 11 | POSTGRES_URL: $POSTGRES_URL 12 | env_file: .env 13 | labels: 14 | # 15 | # Enable Traefik for this specific "backend" service 16 | # 17 | - traefik.enable=true 18 | # 19 | # Define the port inside of the Docker service to use 20 | # 21 | - traefik.http.services.app.loadbalancer.server.port=80 22 | # 23 | # Make Traefik use this domain in HTTP 24 | # 25 | - traefik.http.routers.app-http.entrypoints=http 26 | - traefik.http.routers.app-http.rule=Host(`${PRODUCTION_DOMAIN_HOST}`) 27 | # 28 | # Use the traefik-public network (declared below) 29 | # 30 | - traefik.docker.network=traefik-public 31 | # 32 | # Make Traefik use this domain in HTTPS 33 | # 34 | - traefik.http.routers.app-https.entrypoints=https 35 | - traefik.http.routers.app-https.rule=Host(`${PRODUCTION_DOMAIN_HOST}`) 36 | - traefik.http.routers.app-https.tls=true 37 | # 38 | # Websockets 39 | # 40 | - traefik.wss.routers.ws-service-wss.rule=Host(`${PRODUCTION_DOMAIN_HOST}`) 41 | - traefik.wss.routers.ws-service-wss.entrypoints=https 42 | - traefik.wss.routers.ws-service-wss.tls=true 43 | # 44 | # Use the "le" (Let's Encrypt) resolver 45 | # 46 | - traefik.http.routers.app-https.tls.certresolver=le 47 | # 48 | # https-redirect middleware to redirect HTTP to HTTPS 49 | # 50 | - traefik.http.middlewares.https-redirect.redirectscheme.scheme=https 51 | - traefik.http.middlewares.https-redirect.redirectscheme.permanent=true 52 | # 53 | # Middleware to redirect HTTP to HTTPS 54 | # 55 | - traefik.http.routers.app-http.middlewares=https-redirect 56 | # 57 | # - traefik.http.routers.app-https.middlewares=admin-auth 58 | # 59 | networks: 60 | # Use the public network created to be shared between Traefik and 61 | # any other service that needs to be publicly available with HTTPS 62 | - traefik-public 63 | 64 | networks: 65 | traefik-public: 66 | external: true 67 | 
-------------------------------------------------------------------------------- /alembic.ini: -------------------------------------------------------------------------------- 1 | # A generic, single database configuration. 2 | 3 | [alembic] 4 | # path to migration scripts 5 | script_location = alembic 6 | 7 | # template used to generate migration files 8 | # file_template = %%(rev)s_%%(slug)s 9 | 10 | # sys.path path, will be prepended to sys.path if present. 11 | # defaults to the current working directory. 12 | prepend_sys_path = . 13 | 14 | # timezone to use when rendering the date 15 | # within the migration file as well as the filename. 16 | # string value is passed to dateutil.tz.gettz() 17 | # leave blank for localtime 18 | # timezone = 19 | 20 | # max length of characters to apply to the 21 | # "slug" field 22 | # truncate_slug_length = 40 23 | 24 | # set to 'true' to run the environment during 25 | # the 'revision' command, regardless of autogenerate 26 | # revision_environment = false 27 | 28 | # set to 'true' to allow .pyc and .pyo files without 29 | # a source .py file to be detected as revisions in the 30 | # versions/ directory 31 | # sourceless = false 32 | 33 | # version location specification; this defaults 34 | # to alembic/versions. When using multiple version 35 | # directories, initial revisions must be specified with --version-path 36 | # version_locations = %(here)s/bar %(here)s/bat alembic/versions 37 | 38 | # the output encoding used when revision files 39 | # are written from script.py.mako 40 | # output_encoding = utf-8 41 | 42 | 43 | [post_write_hooks] 44 | # post_write_hooks defines scripts or Python functions that are run 45 | # on newly generated revision scripts. 
See the documentation for further 46 | # detail and examples 47 | 48 | # format using "black" - use the console_scripts runner, against the "black" entrypoint 49 | # hooks=black 50 | # black.type=console_scripts 51 | # black.entrypoint=black 52 | # black.options=-l 79 53 | 54 | # Logging configuration 55 | [loggers] 56 | keys = root,sqlalchemy,alembic 57 | 58 | [handlers] 59 | keys = console 60 | 61 | [formatters] 62 | keys = generic 63 | 64 | [logger_root] 65 | level = WARN 66 | handlers = console 67 | qualname = 68 | 69 | [logger_sqlalchemy] 70 | level = WARN 71 | handlers = 72 | qualname = sqlalchemy.engine 73 | 74 | [logger_alembic] 75 | level = INFO 76 | handlers = 77 | qualname = alembic 78 | 79 | [handler_console] 80 | class = StreamHandler 81 | args = (sys.stderr,) 82 | level = NOTSET 83 | formatter = generic 84 | 85 | [formatter_generic] 86 | format = %(levelname)-5.5s [%(name)s] %(message)s 87 | datefmt = %H:%M:%S 88 | -------------------------------------------------------------------------------- /api/core/websocket/connection_manager.py: -------------------------------------------------------------------------------- 1 | # FastAPI 2 | from fastapi import WebSocket 3 | from fastapi.encoders import jsonable_encoder 4 | 5 | # Standard Library 6 | from typing import Union, Dict, Any 7 | import json 8 | from datetime import date, datetime 9 | 10 | from ...schemas.websockets import WSMessage, WSMessageAction 11 | 12 | 13 | class ConnectionManager: 14 | def __init__(self): 15 | # Map user id to a websocket 16 | self.active_connections: Dict[int, WebSocket] = {} 17 | 18 | async def connect(self, websocket: WebSocket, user_id: int): 19 | try: 20 | await websocket.accept() 21 | self.active_connections[user_id] = websocket 22 | except Exception as e: 23 | print("Error connecting", e) 24 | 25 | async def disconnect(self, user_id: int): 26 | try: 27 | await self.active_connections[user_id].close() 28 | self.active_connections.pop(user_id) 29 | except: 30 | 
print("Error disconnecting...") 31 | 32 | async def send_personal_message(self, message: Union[Dict, str], user_id: int): 33 | # TODO: I need to work out a better way to send proper json 34 | # print("user is online? ", user_id in self.active_connections) 35 | if user_id not in self.active_connections: 36 | return 37 | if type(message) is str: 38 | message = {"message": message} 39 | try: 40 | await self.active_connections[user_id].send_text(json.dumps(jsonable_encoder(message))) 41 | except: 42 | print("Error Sending Message") 43 | 44 | async def broadcast(self, message: Union[Dict, WSMessage[Any]], current_user_id: int): 45 | for user_id, connection in self.active_connections.items(): 46 | if(current_user_id != user_id): 47 | try: 48 | await connection.send_json(jsonable_encoder(message)) 49 | except: 50 | print(f"could not send to user: {user_id}") 51 | 52 | def user_is_online(self, user_id: int): 53 | return user_id in self.active_connections 54 | 55 | async def show_all_connections(self): 56 | print("\n** Listing Active WebSocket Connections **") 57 | for user_id, connection in self.active_connections.items(): 58 | print(f"\tUser ID: {user_id} | Conn: {connection}") 59 | print("********************************\n") 60 | 61 | 62 | ws_manager = ConnectionManager() 63 | -------------------------------------------------------------------------------- /api/routers/auth.py: -------------------------------------------------------------------------------- 1 | # FastAPI 2 | from fastapi import APIRouter, HTTPException, Depends 3 | from fastapi.security import OAuth2PasswordRequestForm, OAuth2 4 | 5 | from starlette.status import HTTP_403_FORBIDDEN 6 | from starlette.responses import RedirectResponse, Response, JSONResponse 7 | from starlette.requests import Request 8 | 9 | # SQLAlchemy 10 | from sqlalchemy.orm import Session 11 | 12 | # Types 13 | from typing import List 14 | 15 | # Custom Modules 16 | from .. 
import schemas, crud 17 | from ..core import security 18 | from ..core.config import settings 19 | from ..dependencies import get_db, get_current_user 20 | 21 | import os 22 | 23 | router = APIRouter(tags=['auth']) 24 | 25 | 26 | @router.post("/token") 27 | async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): 28 | """User will attempt to authenticate with a email/password flow 29 | """ 30 | 31 | user = security.authenticate_user( 32 | db, form_data.username, form_data.password) 33 | if not user: 34 | # Wrong email or password provided 35 | raise HTTPException( 36 | status_code=401, detail="Incorrect username or password") 37 | 38 | if user.account_verified == False: 39 | raise HTTPException( 40 | status_code=400, detail="Email is not verified. Please check your email.") 41 | 42 | token = security.create_access_token(data={"sub": user.email}) 43 | response.set_cookie( 44 | key="Authorization", 45 | value=f'Bearer {token}', 46 | samesite="Lax" if "dev" in os.environ.get("ENV") else "None", 47 | domain="localhost" 48 | if "dev" in os.environ.get("ENV") 49 | else "dericfagnan.com", 50 | secure="dev" not in os.environ.get("ENV"), 51 | httponly=True, 52 | max_age=60 * 30, 53 | expires=60 * 30, 54 | ) 55 | 56 | return {"access_token": token, "token_type": "bearer"} 57 | 58 | 59 | @router.post("/logout") 60 | async def logout_and_expire_cookie(response: Response, current_user: schemas.User = Depends(get_current_user)): 61 | # response.delete_cookie("Authorization") 62 | response.set_cookie( 63 | key="Authorization", 64 | value=f'', 65 | samesite="Lax" if "dev" in os.environ.get("ENV") else "None", 66 | domain="localhost" 67 | if "dev" in os.environ.get("ENV") 68 | else "dericfagnan.com", 69 | secure="dev" not in os.environ.get("ENV"), 70 | httponly=True, 71 | max_age=1, 72 | expires=1, 73 | ) 74 | 75 | return {} 76 | -------------------------------------------------------------------------------- 
/api/core/config.py: -------------------------------------------------------------------------------- 1 | # Standard Library 2 | import os 3 | 4 | # Types 5 | from typing import Any, Dict, List, Optional, Union 6 | from pydantic import AnyHttpUrl, BaseSettings, EmailStr, HttpUrl, PostgresDsn, validator 7 | 8 | 9 | class Settings(BaseSettings): 10 | API_V1_STR: str = "/api/v1" 11 | SECRET_KEY: str = str(os.environ.get("SECRET_KEY")) 12 | # 60 minutes * 24 hours * 8 days = 8 days 13 | ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 14 | 15 | # SERVER_NAME: str 16 | # SERVER_HOST: AnyHttpUrl 17 | 18 | # BACKEND_CORS_ORIGINS is a JSON-formatted list of origins 19 | # e.g: '["http://localhost", "http://localhost:4200", "http://localhost:3000", \ 20 | # "http://localhost:8080", "http://local.dockertoolbox.tiangolo.com"]' 21 | BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = [] 22 | 23 | @validator("BACKEND_CORS_ORIGINS", pre=True) 24 | def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]: 25 | if isinstance(v, str) and not v.startswith("["): 26 | return [i.strip() for i in v.split(",")] 27 | elif isinstance(v, (list, str)): 28 | return v 29 | raise ValueError(v) 30 | 31 | PROJECT_NAME: str = "Twitter Clone Project" 32 | 33 | EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48 34 | EMAIL_TEMPLATES_DIR: str = "/app/app/email-templates/build" 35 | EMAILS_ENABLED: bool = False 36 | 37 | @validator("EMAILS_ENABLED", pre=True) 38 | def get_emails_enabled(cls, v: bool, values: Dict[str, Any]) -> bool: 39 | return bool( 40 | values.get("SMTP_HOST") 41 | and values.get("SMTP_PORT", "") 42 | and values.get("EMAILS_FROM_EMAIL", "") 43 | ) 44 | 45 | EMAIL_TEST_USER: EmailStr = os.environ.get( 46 | "EMAIL_TEST_USER") # type: ignore 47 | FIRST_SUPERUSER: EmailStr = os.environ.get("FIRST_SUPERUSER_EMAIL") 48 | FIRST_SUPERUSER_PASSWORD: str = os.environ.get("FIRST_SUPERUSER_PASSWORD") 49 | USERS_OPEN_REGISTRATION: bool = False 50 | 51 | class Config: 52 | case_sensitive = True 
53 | 54 | 55 | settings = Settings() 56 | 57 | 58 | def get_db_connection_url(): 59 | env = os.environ.get("ENV") 60 | if not env: 61 | return "" 62 | 63 | if env == "localhost-development": 64 | return os.environ.get("LOCAL_POSTGRES_URL") 65 | 66 | elif env == "development": 67 | return os.environ.get("LOCAL_DOCKER_INTERNAL_POSTGRES_URL") 68 | 69 | elif env == "staging": 70 | return os.environ.get("STAGING_POSTGRES_URL") 71 | 72 | elif env == "production": 73 | return os.environ.get("PRODUCTION_POSTGRES_URL") 74 | -------------------------------------------------------------------------------- /api/routers/tweet_likes.py: -------------------------------------------------------------------------------- 1 | # FastAPI 2 | from fastapi import APIRouter, HTTPException, Request, Depends, status 3 | 4 | # SQLAlchemy 5 | from sqlalchemy.orm import Session 6 | 7 | # Types 8 | from typing import List, Optional 9 | 10 | # Custom Modules 11 | from .. import schemas, crud 12 | from ..dependencies import get_db, get_current_user 13 | from ..core import security 14 | from ..core.config import settings 15 | 16 | from ..core.websocket.connection_manager import ws_manager 17 | from ..schemas.websockets import WSMessage, WSMessageAction 18 | 19 | # FastAPI router object 20 | router = APIRouter(prefix="/tweet-likes", tags=['tweet-likes']) 21 | 22 | 23 | @router.get("", response_model=List[schemas.TweetLikeResponseBody]) 24 | async def get_all_tweet_likes(tweetId: Optional[int] = None, db: Session = Depends(get_db)): 25 | """ 26 | The GET method for this endpoint will send back either all, or specific likes based on tweet. This endpoint will always return an array of objects. 27 | 28 | If you want all likes, simply make the GET request and send no data. If you want likes from a specific tweet, send the tweet Id 29 | 30 | In the example, we send the numeric id 1. The API returns all likes on tweets 1. If you want all likes on all tweets, send no data. 
31 | 32 | An error will be returned if any tweetId does not exist. 33 | """ 34 | tweet_likes = [] 35 | if tweetId: 36 | tweet_likes = crud.get_all_tweet_likes_for_tweet(db, tweetId) 37 | 38 | else: 39 | tweet_likes = crud.get_all_tweet_likes(db) 40 | 41 | return [ 42 | schemas.TweetLikeResponseBody( 43 | tweetId=like.tweet_id, 44 | userId=like.user.id, 45 | username=like.user.username 46 | ) for like in tweet_likes 47 | ] 48 | 49 | 50 | @router.post("", response_model=schemas.EmptyResponse) 51 | async def like_a_tweet( 52 | tweet_body: schemas.TweetLikeCreateRequestBody, 53 | db: Session = Depends(get_db), 54 | current_user: schemas.User = Depends(get_current_user) 55 | ): 56 | # validate & create the like record 57 | tweet_like = crud.create_tweet_like_for_tweet( 58 | db=db, tweet_id=tweet_body.tweetId, user_id=current_user.id) 59 | 60 | message = WSMessage[schemas.WSTweetLikeUpdated]( 61 | action=WSMessageAction.UpdatedTweetLike, 62 | body=schemas.WSTweetLikeUpdated( 63 | isLiked=True, 64 | tweetLike=schemas.TweetLikeResponseBody( 65 | tweetId=tweet_like.tweet_id, 66 | userId=tweet_like.user.id, 67 | username=tweet_like.user.username 68 | ) 69 | ) 70 | ) 71 | 72 | await ws_manager.broadcast(message, current_user.id) 73 | 74 | return schemas.EmptyResponse() 75 | 76 | 77 | @router.delete("", response_model=schemas.EmptyResponse) 78 | async def delete_tweet_like( 79 | request_body: schemas.TweetLikeDeleteRequestBody, 80 | db: Session = Depends(get_db), 81 | current_user: schemas.User = Depends(get_current_user)): 82 | 83 | tweet_like = crud.get_tweet_like_by_tweet_id_and_user_id( 84 | db, current_user.id, request_body.tweetId) 85 | 86 | crud.delete_tweet_like( 87 | db, current_user.id, request_body.tweetId) 88 | 89 | message = WSMessage[schemas.WSTweetLikeUpdated]( 90 | action=WSMessageAction.UpdatedTweetLike, 91 | body=schemas.WSTweetLikeUpdated( 92 | isLiked=False, 93 | tweetLike=schemas.TweetLikeResponseBody( 94 | tweetId=tweet_like.tweet_id, 95 | 
userId=tweet_like.user.id, 96 | username=tweet_like.user.username 97 | ) 98 | ) 99 | ) 100 | 101 | await ws_manager.broadcast(message, current_user.id) 102 | 103 | return schemas.EmptyResponse() 104 | -------------------------------------------------------------------------------- /api/background_functions/email_notifications.py: -------------------------------------------------------------------------------- 1 | # Standard Library 2 | import os 3 | 4 | # Core 5 | from ..core.sendgrid import send_email, constants 6 | from ..core.sendgrid.schema import EmailSender 7 | 8 | # Models 9 | from ..models import User, Comments 10 | 11 | # SendGrid API 12 | from sendgrid import SendGridAPIClient, Personalization, Asm 13 | from sendgrid.helpers.mail import Mail 14 | 15 | # TODO : implement function for: 16 | # - password recovery email 17 | 18 | 19 | async def send_registration_confirmation_email(username: str, email: str, confirmation_key: str): 20 | # 21 | # Build Sendgrid Mail Object 22 | # 23 | message = Mail(from_email=EmailSender.ACCOUNT.value, 24 | to_emails=email 25 | ) 26 | 27 | message.template_id = constants.REGISTRATION_CONFIRMATION_DYNAMIC_TEMPLATE_ID 28 | message.asm = Asm( 29 | constants.MAIN_UNSUBSCRIBE_GROUP_ID 30 | ) 31 | message.dynamic_template_data = { 32 | "confirmation_url": f'{os.environ.get("PRODUCTION_CLIENT_HOST_URL")}/confirm-email?confirmationKey={confirmation_key}', 33 | "username": username 34 | } 35 | await send_email(message) 36 | 37 | 38 | async def send_new_message_notification_email(user: User, other_user: User): 39 | # 40 | # Build Sendgrid Mail Object 41 | # 42 | message = Mail(from_email=EmailSender.NOTIFICATIONS.value, 43 | to_emails=other_user.email 44 | ) 45 | 46 | notification_text = f"You have a new message from {user.username}! Please log in to view your messages." 
47 | message.template_id = constants.NEW_NOTIFICATION_DYNAMIC_TEMPLATE_ID 48 | message.asm = Asm( 49 | constants.MAIN_UNSUBSCRIBE_GROUP_ID 50 | ) 51 | message.dynamic_template_data = { 52 | "subject": "You have a new message", 53 | "username": other_user.username, 54 | "notification_text": notification_text, 55 | "log_in_url": f'{os.environ.get("PRODUCTION_CLIENT_HOST_URL")}/tweets' 56 | } 57 | await send_email(message) 58 | 59 | 60 | async def send_new_comment_notification_email(tweet_owner: User, commenter: User, comment: Comments): 61 | # 62 | # Build Sendgrid Mail Object 63 | # 64 | message = Mail(from_email=EmailSender.NOTIFICATIONS.value, 65 | to_emails=tweet_owner.email 66 | ) 67 | 68 | notification_text = f"{commenter.username} has commented \" {comment.content} \" on your tweet!" 69 | message.template_id = constants.NEW_NOTIFICATION_DYNAMIC_TEMPLATE_ID 70 | message.asm = Asm( 71 | constants.MAIN_UNSUBSCRIBE_GROUP_ID 72 | ) 73 | message.dynamic_template_data = { 74 | "subject": f"A user commented on your tweet", 75 | "username": tweet_owner.username, 76 | "notification_text": notification_text, 77 | "log_in_url": f'{os.environ.get("PRODUCTION_CLIENT_HOST_URL")}/tweets' 78 | } 79 | await send_email(message) 80 | 81 | 82 | async def send_new_follower_notification_email(user: User, new_follower: User): 83 | # 84 | # Build Sendgrid Mail Object 85 | # 86 | message = Mail(from_email=EmailSender.NOTIFICATIONS.value, 87 | to_emails=user.email 88 | ) 89 | 90 | notification_text = f"{new_follower.username} started following you!" 
91 | message.template_id = constants.NEW_NOTIFICATION_DYNAMIC_TEMPLATE_ID 92 | message.asm = Asm( 93 | constants.MAIN_UNSUBSCRIBE_GROUP_ID 94 | ) 95 | message.dynamic_template_data = { 96 | "subject": f"You have a new follower", 97 | "username": user.username, 98 | "notification_text": notification_text, 99 | "log_in_url": f'{os.environ.get("PRODUCTION_CLIENT_HOST_URL")}/followers' 100 | } 101 | await send_email(message) 102 | -------------------------------------------------------------------------------- /api/routers/comment_likes.py: -------------------------------------------------------------------------------- 1 | # FastAPI 2 | from fastapi import APIRouter, HTTPException, Request, Depends, status 3 | 4 | # SQLAlchemy 5 | from sqlalchemy.orm import Session 6 | 7 | # Types 8 | from typing import List, Optional 9 | 10 | # Custom Modules 11 | from .. import schemas, crud, models 12 | from ..dependencies import get_db, get_current_user 13 | from ..core import security 14 | from ..core.config import settings 15 | from ..core.websocket.connection_manager import ws_manager 16 | from ..schemas.websockets import WSMessage, WSMessageAction 17 | 18 | # FastAPI router object 19 | router = APIRouter(prefix="/comment-likes", tags=['comment-likes']) 20 | 21 | 22 | @router.get("", response_model=List[schemas.CommentLikeResponseBody]) 23 | def get_all_comment_likes(commentId: Optional[int] = None, db: Session = Depends(get_db)): 24 | """ 25 | The GET method for this endpoint will send back either all, or specific likes based on comment. This endpoint will always return an array of objects. 26 | 27 | If you want all likes, simply make the GET request and send no data. If you want likes from a specific comment, send the comment Id 28 | 29 | In the example, we send the numeric id 1. The API returns all likes on comments 1. If you want all likes on all comments, send no data. 30 | 31 | An error will be returned if any commentId does not exist. 
32 | """ 33 | comment_likes: List[models.CommentLikes] = [] 34 | if commentId: 35 | comment_likes = crud.get_all_comment_likes_for_comment(db, commentId) 36 | 37 | else: 38 | comment_likes = crud.get_all_comment_likes(db) 39 | 40 | return [ 41 | schemas.CommentLikeResponseBody( 42 | commentLikeId=like.id, 43 | commentId=like.comment_id, 44 | userId=like.user.id, 45 | username=like.user.username 46 | ) for like in comment_likes 47 | ] 48 | 49 | 50 | @router.post("", response_model=schemas.CommentLikeResponseBody) 51 | async def like_a_comment( 52 | comment_body: schemas.CommentLikeCreateRequestBody, 53 | db: Session = Depends(get_db), 54 | current_user: schemas.User = Depends(get_current_user) 55 | ): 56 | # validate & create the like record 57 | comment_like = crud.create_comment_like_for_comment( 58 | db=db, comment_id=comment_body.commentId, user_id=current_user.id) 59 | 60 | return_like = schemas.CommentLikeResponseBody( 61 | commentLikeId=comment_like.id, 62 | commentId=comment_like.comment_id, 63 | userId=comment_like.user.id, 64 | username=comment_like.user.username 65 | ) 66 | 67 | message = WSMessage[schemas.WSCommentLikeUpdated]( 68 | action=WSMessageAction.UpdatedCommentLike, 69 | body=schemas.WSCommentLikeUpdated( 70 | isLiked=True, 71 | commentLike=return_like 72 | ) 73 | ) 74 | await ws_manager.broadcast(message, current_user.id) 75 | 76 | return return_like 77 | 78 | 79 | @router.delete("", response_model=schemas.EmptyResponse) 80 | async def delete_comment_like( 81 | request_body: schemas.CommentLikeDeleteRequestBody, 82 | db: Session = Depends(get_db), 83 | current_user: schemas.User = Depends(get_current_user)): 84 | 85 | comment_like = crud.get_comment_like_by_comment_id_and_user_id( 86 | db, current_user.id, request_body.commentId) 87 | 88 | delete_successful = crud.delete_comment_like_by_user_and_comment_id( 89 | db, current_user.id, request_body.commentId) 90 | 91 | message = WSMessage[schemas.WSCommentLikeUpdated]( 92 | 
action=WSMessageAction.UpdatedCommentLike, 93 | body=schemas.WSCommentLikeUpdated( 94 | isLiked=False, 95 | commentLike=schemas.CommentLikeResponseBody( 96 | commentLikeId=comment_like.id, 97 | commentId=comment_like.comment_id, 98 | userId=comment_like.user.id, 99 | username=comment_like.user.username 100 | ) 101 | ) 102 | ) 103 | await ws_manager.broadcast(message, current_user.id) 104 | 105 | # TODO return status for delete? 106 | return schemas.EmptyResponse() 107 | -------------------------------------------------------------------------------- /api/routers/tweets.py: -------------------------------------------------------------------------------- 1 | # FastAPI 2 | from fastapi import APIRouter, HTTPException, Request, Depends, status 3 | 4 | # SQLAlchemy 5 | from sqlalchemy.orm import Session 6 | 7 | # Types 8 | from typing import List, Optional 9 | 10 | # Custom Modules 11 | from .. import schemas, crud 12 | from ..dependencies import get_db, get_current_user 13 | from ..core import security 14 | from ..core.config import settings 15 | 16 | # FastAPI router object 17 | router = APIRouter(prefix="/tweets", tags=['tweets']) 18 | 19 | 20 | @router.get("", response_model=List[schemas.TweetResponse]) 21 | def get_all_tweets(userId: Optional[int] = None, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): 22 | """The GET method for this endpoint will send back all tweets 23 | """ 24 | if userId: 25 | user = crud.get_user_by_id(db, userId) 26 | if not user: 27 | raise HTTPException(status.HTTP_400_BAD_REQUEST, 28 | detail="Error.Bad userId. 
User does not exist.") 29 | tweets = crud.get_tweets_for_user(db, userId, skip=skip, limit=limit) 30 | else: 31 | tweets = crud.get_tweets(db, skip=skip, limit=limit) 32 | return [ 33 | schemas.TweetResponse( 34 | tweetId=tweet.id, 35 | content=tweet.content, 36 | createdAt=tweet.created_at, 37 | userId=tweet.user.id, 38 | username=tweet.user.username 39 | ) for tweet in tweets 40 | ] 41 | 42 | 43 | @router.get("/liked", response_model=List[schemas.TweetResponse]) 44 | def get_all_tweets_liked_by_user( 45 | skip: int = 0, 46 | limit: int = 100, 47 | db: Session = Depends(get_db), 48 | current_user: schemas.User = Depends(get_current_user)): 49 | """Return all tweets liked by the authenticated user 50 | """ 51 | 52 | tweets = crud.get_tweets_liked_by_user( 53 | db, current_user.id, skip=skip, limit=limit) 54 | 55 | return [ 56 | schemas.TweetResponse( 57 | tweetId=tweet.id, 58 | content=tweet.content, 59 | createdAt=tweet.created_at, 60 | userId=tweet.user.id, 61 | username=tweet.user.username 62 | ) for tweet in tweets 63 | ] 64 | 65 | 66 | @router.get("/one/{tweetId}", response_model=schemas.TweetResponse) 67 | def get_single_tweet_by_id( 68 | tweetId: int, 69 | db: Session = Depends(get_db) 70 | ): 71 | """Return a single tweet based on a tweetId 72 | """ 73 | 74 | tweet = crud.get_tweet_by_id(db, tweetId) 75 | 76 | return schemas.TweetResponse( 77 | tweetId=tweet.id, 78 | content=tweet.content, 79 | createdAt=tweet.created_at, 80 | userId=tweet.user.id, 81 | username=tweet.user.username 82 | ) 83 | 84 | 85 | @router.post("", response_model=schemas.TweetResponse) 86 | def create_tweet_for_user(tweet_body: schemas.TweetCreate, 87 | db: Session = Depends(get_db), 88 | current_user: schemas.User = Depends(get_current_user)): 89 | tweet = crud.create_user_tweet( 90 | db=db, tweet=tweet_body, user_id=current_user.id) 91 | return schemas.TweetResponse( 92 | tweetId=tweet.id, 93 | content=tweet.content, 94 | createdAt=tweet.created_at, 95 | userId=tweet.user.id, 96 | 
username=tweet.user.username 97 | ) 98 | 99 | 100 | @router.put('/{tweetId}', response_model=schemas.EmptyResponse) 101 | def update_tweet( 102 | tweetId: int, 103 | request_body: schemas.TweetUpdate, 104 | db: Session = Depends(get_db), 105 | current_user: schemas.User = Depends(get_current_user)): 106 | update_successful = crud.update_tweet( 107 | db, current_user.id, tweetId, request_body.newContent) 108 | 109 | return {} 110 | 111 | 112 | @router.delete('/{tweetId}', response_model=schemas.EmptyResponse) 113 | def delete_tweet(tweetId: int, 114 | db: Session = Depends(get_db), 115 | current_user: schemas.User = Depends(get_current_user)): 116 | delete_successful = crud.delete_tweet(db, current_user.id, tweetId) 117 | 118 | return {} 119 | -------------------------------------------------------------------------------- /api/dependencies.py: -------------------------------------------------------------------------------- 1 | # FastAPI 2 | from fastapi import Depends, status, HTTPException, WebSocket, WebSocketDisconnect 3 | from fastapi.exceptions import WebSocketRequestValidationError 4 | from fastapi.security import OAuth2PasswordBearer, OAuth2 5 | from fastapi.security.utils import get_authorization_scheme_param 6 | from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel 7 | 8 | from starlette.responses import RedirectResponse, Response, JSONResponse 9 | from starlette.requests import Request 10 | 11 | # SQLAlchemy 12 | from sqlalchemy.orm.session import Session 13 | 14 | # JWT 15 | from jose import JWTError, jwt 16 | 17 | from typing import Optional 18 | 19 | # Custom Modules 20 | from . 
import crud, schemas 21 | from .database import SessionLocal 22 | from .core import security 23 | from .core.config import settings 24 | 25 | 26 | class OAuth2PasswordBearerCookie(OAuth2): 27 | def __init__( 28 | self, 29 | tokenUrl: str, 30 | scheme_name: str = None, 31 | scopes: dict = None, 32 | auto_error: bool = True, 33 | ): 34 | if not scopes: 35 | scopes = {} 36 | flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": {}}) 37 | super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error) 38 | 39 | async def __call__(self, request: Request = None, websocket: WebSocket = None) -> Optional[str]: 40 | header_authorization: str = None 41 | cookie_authorization: str = None 42 | 43 | if request and not websocket: 44 | header_authorization = request.headers.get("Authorization") 45 | cookie_authorization = request.cookies.get("Authorization") 46 | elif websocket and not request: 47 | cookie_authorization = websocket.cookies.get("Authorization") 48 | header_authorization = websocket.headers.get("Authorization") 49 | 50 | header_scheme, header_param = get_authorization_scheme_param( 51 | header_authorization 52 | ) 53 | cookie_scheme, cookie_param = get_authorization_scheme_param( 54 | cookie_authorization 55 | ) 56 | 57 | if header_scheme.lower() == "bearer": 58 | authorization = True 59 | scheme = header_scheme 60 | param = header_param 61 | 62 | elif cookie_scheme.lower() == "bearer": 63 | authorization = True 64 | scheme = cookie_scheme 65 | param = cookie_param 66 | 67 | else: 68 | authorization = False 69 | 70 | if not authorization or scheme.lower() != "bearer": 71 | if self.auto_error and request and not websocket: 72 | raise HTTPException( 73 | status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated" 74 | ) 75 | elif websocket and not request or not self.auto_error: 76 | return None 77 | return param 78 | 79 | 80 | # oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") 81 | oauth2_scheme = 
OAuth2PasswordBearerCookie(tokenUrl="/token") 82 | 83 | 84 | def get_db(): 85 | """Yield a SQLAlchemy database session 86 | """ 87 | db = SessionLocal() 88 | try: 89 | yield db 90 | finally: 91 | db.close() 92 | 93 | 94 | def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)): 95 | """Decode the provided jwt and extract the user using the [sub] field. 96 | """ 97 | token_data: schemas.TokenData = None 98 | 99 | credentials_exception = HTTPException( 100 | status_code=status.HTTP_401_UNAUTHORIZED, 101 | detail="Could not validate credentials", 102 | headers={"WWW-Authenticate": "Bearer"}, 103 | ) 104 | if not token: 105 | return None 106 | try: 107 | payload = security.decode_token(token) 108 | email: str = payload.get("sub") 109 | if email is None: 110 | # Something wrong with the token 111 | raise credentials_exception 112 | token_data = schemas.TokenData(email=email) 113 | except JWTError: 114 | # Something wrong with the token 115 | raise credentials_exception 116 | # 117 | # Get user from database 118 | # 119 | user = crud.get_user_by_email(db, token_data.email) 120 | if user is None: 121 | raise credentials_exception 122 | return user 123 | -------------------------------------------------------------------------------- /api/routers/follows.py: -------------------------------------------------------------------------------- 1 | # FastAPI 2 | from fastapi import APIRouter, HTTPException, Request, Depends, status, BackgroundTasks 3 | 4 | # SQLAlchemy 5 | from sqlalchemy.orm import Session 6 | 7 | # Types 8 | from typing import List, Optional 9 | 10 | # Custom Modules 11 | from .. 
import schemas, crud 12 | from ..dependencies import get_db, get_current_user 13 | from ..background_functions.email_notifications import send_new_follower_notification_email 14 | from ..core import security 15 | from ..core.config import settings 16 | 17 | from ..core.websocket.connection_manager import ws_manager 18 | 19 | 20 | # FastAPI router object 21 | router = APIRouter(prefix="/follows", tags=['follows']) 22 | 23 | 24 | @router.get("/{userId}", response_model=List[schemas.FollowsResponse]) 25 | def get_follows(userId: int, db: Session = Depends(get_db)): 26 | """ 27 | The GET method for this endpoint requires a userId and will send 28 | back information about all users the userId follows . 29 | 30 | Returns: 31 | This endpoint will always return an array of objects. 32 | 33 | Errors: 34 | An error will be returned if the userId does not exist. 35 | """ 36 | 37 | user = crud.get_user_by_id(db, userId) 38 | 39 | if not user: 40 | raise HTTPException(status.HTTP_404_NOT_FOUND, 41 | detail="User does not exist") 42 | 43 | follows = crud.get_all_users_following(db, userId) 44 | return [ 45 | schemas.FollowsResponse( 46 | userId=following.follows_user.id, 47 | email=following.follows_user.email, 48 | username=following.follows_user.username, 49 | bio=following.follows_user.bio, 50 | birthdate=following.follows_user.birthdate 51 | ) for following in follows 52 | ] 53 | 54 | 55 | @router.get("/count/{userId}", response_model=schemas.CountBase) 56 | def get_follows_count_for_user( 57 | userId: int, 58 | db: Session = Depends(get_db) 59 | ): 60 | count = crud.get_following_for_user(db, user_id=userId) 61 | 62 | return schemas.CountBase( 63 | count=count 64 | ) 65 | 66 | 67 | @router.post("", response_model=schemas.EmptyResponse) 68 | async def create_follow_record_for_user( 69 | request_body: schemas.FollowsCreateRequestBody, 70 | bg_tasks: BackgroundTasks, 71 | db: Session = Depends(get_db), 72 | current_user: schemas.User = Depends(get_current_user) 73 | ): 74 | 
""" 75 | The POST method for this endpoint will create a follow relationship between two users. 76 | 77 | current_user requests to follow a new user 78 | 79 | """ 80 | crud.create_follow_relationship( 81 | db, current_user.id, request_body.followUserId) 82 | 83 | # 84 | # Broadcast WS message so user components can update 85 | # 86 | message = schemas.WSMessage[schemas.WSFollowsUpdateBody]( 87 | action=schemas.WSMessageAction.NewFollower, 88 | body=schemas.WSFollowsUpdateBody( 89 | userId=current_user.id, 90 | followUserId=request_body.followUserId 91 | ) 92 | ) 93 | 94 | if not ws_manager.user_is_online(request_body.followUserId): 95 | # Send a notification email 96 | new_follower = crud.get_user_by_id(db, request_body.followUserId) 97 | bg_tasks.add_task(send_new_follower_notification_email, 98 | new_follower, current_user) 99 | 100 | await ws_manager.broadcast(message, current_user.id) 101 | 102 | return schemas.EmptyResponse() 103 | 104 | 105 | @router.delete('', response_model=schemas.EmptyResponse) 106 | async def delete_follow_relationship( 107 | request_body: schemas.FollowsDeleteRequestBody, 108 | db: Session = Depends(get_db), 109 | current_user: schemas.User = Depends(get_current_user) 110 | ): 111 | delete_successful = crud.delete_follow_relationship( 112 | db, current_user.id, request_body.followUserId) 113 | 114 | # 115 | # Broadcast WS message so user components can update 116 | # 117 | message = schemas.WSMessage[schemas.WSFollowsUpdateBody]( 118 | action=schemas.WSMessageAction.LostFollower, 119 | body=schemas.WSFollowsUpdateBody( 120 | userId=current_user.id, 121 | followUserId=request_body.followUserId 122 | ) 123 | ) 124 | await ws_manager.broadcast(message, current_user.id) 125 | 126 | return schemas.EmptyResponse() 127 | -------------------------------------------------------------------------------- /api/routers/messages.py: -------------------------------------------------------------------------------- 1 | # FastAPI 2 | from fastapi import 
( 3 | APIRouter, HTTPException, status, 4 | Request, Depends, BackgroundTasks, 5 | WebSocket, WebSocketDisconnect, Cookie, Query 6 | ) 7 | from fastapi.responses import HTMLResponse 8 | from fastapi.encoders import jsonable_encoder 9 | 10 | # SQLAlchemy 11 | from sqlalchemy.orm import Session 12 | 13 | # Types 14 | from typing import List, Optional 15 | 16 | # Custom Modules 17 | from .. import schemas, crud, models 18 | from ..background_functions.email_notifications import send_new_message_notification_email 19 | from ..dependencies import get_db, get_current_user 20 | from ..core import security 21 | from ..core.config import settings 22 | from ..core.utilities import generate_random_uuid 23 | from ..core.websocket.connection_manager import ws_manager 24 | 25 | # Schema 26 | from ..schemas.websockets import WSMessage, WSMessageAction 27 | from ..core.sendgrid.schema import EmailSender 28 | 29 | router = APIRouter(prefix="/messages", tags=['messages']) 30 | 31 | 32 | @router.get("/conversations") 33 | # respone_model=schemas.MessageResponse 34 | def messages( 35 | db: Session = Depends(get_db), 36 | current_user: schemas.User = Depends(get_current_user) 37 | ): 38 | 39 | return crud.get_messages_for_user(db, current_user.id) 40 | 41 | 42 | @router.get("") 43 | # respone_model=schemas.MessageResponse 44 | def messages( 45 | db: Session = Depends(get_db), 46 | current_user: schemas.User = Depends(get_current_user) 47 | ): 48 | 49 | messages = crud.get_messages_for_user(db, current_user.id) 50 | return [schemas.Message( 51 | id=message.id, 52 | userFromId=message.user_from_id, 53 | userFromUsername=message.user_from.username, 54 | userToId=message.user_to_id, 55 | userToUsername=message.user_to.username, 56 | content=message.content, 57 | createdAt=message.created_at 58 | ) for message in messages] 59 | 60 | 61 | @router.post("", response_model=schemas.Message) 62 | async def create_message( 63 | request_body: schemas.MessageCreateRequestBody, 64 | bg_tasks: 
BackgroundTasks, 65 | db: Session = Depends(get_db), 66 | current_user: schemas.User = Depends(get_current_user) 67 | ): 68 | newMessage = crud.create_message(db, current_user.id, request_body) 69 | return_value = schemas.Message( 70 | id=newMessage.id, 71 | userFromId=newMessage.user_from_id, 72 | userFromUsername=current_user.username, 73 | userToId=newMessage.user_to_id, 74 | userToUsername=newMessage.user_to.username, 75 | content=newMessage.content, 76 | createdAt=newMessage.created_at 77 | ) 78 | other_user = crud.get_user_by_id(db, newMessage.user_to_id) 79 | # Send a websocket message to the user who this message is sent to 80 | # If that user is not online, they will not receive the websocket message. 81 | wsMessage = WSMessage[schemas.Message]( 82 | action=WSMessageAction.ChatMessageNew, 83 | body=return_value 84 | ) 85 | if ws_manager.user_is_online(request_body.userToId): 86 | print('user is online') 87 | await ws_manager.send_personal_message(wsMessage, request_body.userToId) 88 | else: 89 | print("sending a notification email") 90 | # Send an email notification to the user. 91 | # TODO - perhaps implement some rate limiting here so it only sends 92 | # a max amount per minute or something. 93 | # this could be another dict holding last time sent...? 94 | bg_tasks.add_task(send_new_message_notification_email, 95 | current_user, other_user) 96 | 97 | return return_value 98 | 99 | 100 | @router.delete("/", response_model=schemas.EmptyResponse) 101 | async def delete_message( 102 | request_body: schemas.MessageDeleteRequestBody, 103 | db: Session = Depends(get_db), 104 | current_user: schemas.User = Depends(get_current_user) 105 | ): 106 | message: models.Messages = crud.get_message_by_id( 107 | db, request_body.messageId) 108 | result = crud.delete_message(db, current_user.id, request_body) 109 | 110 | # Send a websocket message to the user who this message is sent to 111 | # If that user is not online, they will not receive the websocket message. 
112 | wsMessage = WSMessage[schemas.DeletedChatMessageResponseBody]( 113 | action=WSMessageAction.ChatMessageDeleted, 114 | body=schemas.DeletedChatMessageResponseBody( 115 | messageId=request_body.messageId, 116 | userId=message.user_from_id 117 | ) 118 | ) 119 | 120 | if ws_manager.user_is_online(message.user_from_id): 121 | await ws_manager.send_personal_message(wsMessage, message.user_to_id) 122 | 123 | return result 124 | -------------------------------------------------------------------------------- /docker-compose.traefik.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | traefik: 5 | # 6 | # Use the latest v2.3.x Traefik image available 7 | # 8 | image: traefik:v2.3 9 | ports: 10 | # 11 | # Listen on port 80, default for HTTP, necessary to redirect to HTTPS 12 | # 13 | - 80:80 14 | # 15 | # Listen on port 443, default for HTTPS 16 | # 17 | - 443:443 18 | restart: always 19 | labels: 20 | # 21 | # Enable Traefik for this service, to make it available in the public network 22 | # 23 | - traefik.enable=true 24 | # 25 | # Define the port inside of the Docker service to use 26 | # 27 | - traefik.http.services.traefik-dashboard.loadbalancer.server.port=8080 28 | # 29 | # Make Traefik use this domain in HTTP 30 | # 31 | - traefik.http.routers.traefik-dashboard-http.entrypoints=http 32 | - traefik.http.routers.traefik-dashboard-http.rule=Host(`traefik.${PRODUCTION_DOMAIN_HOST}`) 33 | # 34 | # Use the traefik-public network (declared below) 35 | # 36 | - traefik.docker.network=traefik-public 37 | # 38 | # traefik-https the actual router using HTTPS 39 | # 40 | - traefik.http.routers.traefik-dashboard-https.entrypoints=https 41 | - traefik.http.routers.traefik-dashboard-https.rule=Host(`traefik.${PRODUCTION_DOMAIN_HOST}`) 42 | - traefik.http.routers.traefik-dashboard-https.tls=true 43 | # 44 | # Use the "le" (Let's Encrypt) resolver created below 45 | # 46 | - 
traefik.http.routers.traefik-dashboard-https.tls.certresolver=le 47 | # 48 | # Use the special Traefik service api@internal with the web UI/Dashboard 49 | # 50 | - traefik.http.routers.traefik-dashboard-https.service=api@internal 51 | # 52 | # https-redirect middleware to redirect HTTP to HTTPS 53 | # 54 | - traefik.http.middlewares.https-redirect.redirectscheme.scheme=https 55 | - traefik.http.middlewares.https-redirect.redirectscheme.permanent=true 56 | # 57 | # traefik-http set up only to use the middleware to redirect to https 58 | # 59 | - traefik.http.routers.traefik-dashboard-http.middlewares=https-redirect 60 | # 61 | # admin-auth middleware with HTTP Basic auth 62 | # Using the environment variables USERNAME and HASHED_PASSWORD 63 | # 64 | - traefik.http.middlewares.admin-auth.basicauth.users=${USERNAME?Variable not set}:${HASHED_PASSWORD?Variable not set} 65 | # 66 | # Enable HTTP Basic auth, using the middleware created above 67 | # 68 | - traefik.http.routers.traefik-dashboard-https.middlewares=admin-auth 69 | volumes: 70 | # 71 | # Add Docker as a mounted volume, so that Traefik can read the labels of other services 72 | # 73 | - /var/run/docker.sock:/var/run/docker.sock:ro 74 | # 75 | # Mount the volume to store the certificates 76 | # 77 | - traefik-public-certificates:/certificates 78 | command: 79 | # 80 | # Enable Docker in Traefik, so that it reads labels from Docker services 81 | # 82 | - --providers.docker 83 | # 84 | # Do not expose all Docker services, only the ones explicitly exposed 85 | # 86 | - --providers.docker.exposedbydefault=false 87 | # 88 | # Create an entrypoint "http" listening on port 80 89 | # 90 | - --entrypoints.http.address=:80 91 | # 92 | # Create an entrypoint "https" listening on port 443 93 | # 94 | - --entrypoints.https.address=:443 95 | # 96 | # Create the certificate resolver "le" for Let's Encrypt, uses the environment variable EMAIL 97 | # 98 | - --certificatesresolvers.le.acme.email=deric@programmertutor.com 99 | # 
100 | # Store the Let's Encrypt certificates in the mounted volume 101 | # 102 | - --certificatesresolvers.le.acme.storage=/certificates/acme.json 103 | # 104 | # Use the TLS Challenge for Let's Encrypt 105 | # 106 | - --certificatesresolvers.le.acme.tlschallenge=true 107 | # 108 | # Enable the access log, with HTTP requests 109 | # 110 | - --accesslog 111 | # 112 | # Enable the Traefik log, for configurations and errors 113 | # 114 | - --log 115 | # 116 | # Enable the Dashboard and API 117 | # 118 | - --api 119 | networks: 120 | # 121 | # Use the public network created to be shared between Traefik and 122 | # any other service that needs to be publicly available with HTTPS 123 | # 124 | - traefik-public 125 | 126 | volumes: 127 | # Create a volume to store the certificates, there is a constraint to make sure 128 | # Traefik is always deployed to the same Docker node with the same volume containing 129 | # the HTTPS certificates 130 | traefik-public-certificates: 131 | 132 | networks: 133 | # Use the previously created public network "traefik-public", shared with other 134 | # services that need to be publicly available via this Traefik 135 | traefik-public: 136 | external: true 137 | -------------------------------------------------------------------------------- /api/models.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Boolean, Column, Integer, String, DateTime, Date, ForeignKey 2 | from sqlalchemy.orm import relationship 3 | from sqlalchemy.sql import func 4 | # from sqlalchemy.dialects.postgresql import JSONB 5 | 6 | from .database import Base 7 | 8 | 9 | class User(Base): 10 | """ 11 | TODO: Need to sort relationships by date - most recent first 12 | """ 13 | __tablename__ = "users" 14 | 15 | id = Column(Integer, primary_key=True, index=True) 16 | email = Column(String, unique=True, index=True, nullable=False) 17 | username = Column(String, index=True) 18 | bio = Column(String, index=True) 19 | 
birthdate = Column(Date, index=True) 20 | hashed_password = Column(String, nullable=False) 21 | confirmation_key = Column(String, nullable=True) 22 | account_verified = Column(Boolean, nullable=False, default=False) 23 | created_at = Column(DateTime(timezone=True), server_default=func.now()) 24 | updated_at = Column(DateTime(timezone=True), server_default=func.now()) 25 | 26 | tweets = relationship("Tweet", back_populates="user") 27 | followers = relationship( 28 | "Follows", back_populates="user", foreign_keys="Follows.user_id") 29 | follows = relationship( 30 | "Follows", back_populates="follows_user", foreign_keys="Follows.follows_user_id") 31 | comments = relationship("Comments", back_populates="user") 32 | tweet_likes = relationship("TweetLikes", back_populates="user") 33 | comment_likes = relationship("CommentLikes", back_populates="user") 34 | 35 | inbox = relationship("Messages", back_populates="user_from", 36 | foreign_keys="Messages.user_from_id") 37 | outbox = relationship("Messages", back_populates="user_to", 38 | foreign_keys="Messages.user_to_id") 39 | 40 | def __repr__(self): 41 | return f"{self.id} | {self.username}" 42 | 43 | 44 | class Follows(Base): 45 | """ follows 46 | """ 47 | __tablename__ = "follows" 48 | 49 | id = Column(Integer, primary_key=True, index=True) 50 | user_id = Column(Integer, ForeignKey("users.id")) 51 | follows_user_id = Column(Integer, ForeignKey("users.id")) 52 | 53 | user = relationship("User", back_populates="follows", 54 | foreign_keys=[user_id]) 55 | follows_user = relationship( 56 | "User", back_populates="followers", foreign_keys=[follows_user_id]) 57 | 58 | 59 | class Tweet(Base): 60 | __tablename__ = "tweets" 61 | 62 | id = Column(Integer, primary_key=True, index=True) 63 | content = Column(String, index=True) 64 | created_at = Column(DateTime(timezone=True), server_default=func.now()) 65 | user_id = Column(Integer, ForeignKey("users.id")) 66 | 67 | user = relationship("User", back_populates="tweets", 68 | 
foreign_keys=[user_id]) 69 | comments = relationship("Comments", back_populates="tweet") 70 | likes = relationship("TweetLikes", back_populates="tweet") 71 | 72 | 73 | class TweetLikes(Base): 74 | __tablename__ = "tweet_likes" 75 | 76 | id = Column(Integer, primary_key=True, index=True) 77 | user_id = Column(Integer, ForeignKey("users.id")) 78 | tweet_id = Column(Integer, ForeignKey("tweets.id")) 79 | 80 | user = relationship("User", back_populates="tweet_likes", 81 | foreign_keys=[user_id]) 82 | tweet = relationship("Tweet", back_populates="likes", 83 | foreign_keys=[tweet_id]) 84 | 85 | 86 | class Comments(Base): 87 | __tablename__ = "comments" 88 | 89 | id = Column(Integer, primary_key=True, index=True) 90 | user_id = Column(Integer, ForeignKey("users.id")) 91 | tweet_id = Column(Integer, ForeignKey("tweets.id")) 92 | content = Column(String, index=True) 93 | created_at = Column(DateTime(timezone=True), server_default=func.now()) 94 | 95 | user = relationship("User", back_populates="comments", 96 | foreign_keys=[user_id]) 97 | tweet = relationship("Tweet", back_populates="comments", 98 | foreign_keys=[tweet_id]) 99 | likes = relationship("CommentLikes", back_populates="comment", 100 | foreign_keys="CommentLikes.comment_id") 101 | 102 | 103 | class CommentLikes(Base): 104 | __tablename__ = "comment_likes" 105 | 106 | id = Column(Integer, primary_key=True, index=True) 107 | user_id = Column(Integer, ForeignKey("users.id")) 108 | comment_id = Column(Integer, ForeignKey("comments.id")) 109 | 110 | user = relationship( 111 | "User", back_populates="comment_likes", foreign_keys=[user_id]) 112 | comment = relationship( 113 | "Comments", back_populates="likes", foreign_keys=[comment_id]) 114 | 115 | 116 | class Messages(Base): 117 | __tablename__ = "messages" 118 | 119 | id = Column(Integer, primary_key=True, index=True) 120 | user_from_id = Column(Integer, ForeignKey("users.id")) 121 | user_to_id = Column(Integer, ForeignKey("users.id")) 122 | content = Column(String) 
123 | is_read = Column(Boolean, default=False) 124 | created_at = Column(DateTime(timezone=True), server_default=func.now()) 125 | 126 | user_from = relationship( 127 | "User", back_populates="inbox", foreign_keys=[user_from_id]) 128 | 129 | user_to = relationship( 130 | "User", back_populates="outbox", foreign_keys=[user_to_id]) 131 | -------------------------------------------------------------------------------- /api/routers/comments.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Request, Depends, BackgroundTasks 2 | from sqlalchemy.orm import Session 3 | from typing import List 4 | 5 | from .. import schemas, crud 6 | from ..core import security 7 | from ..core.config import settings 8 | from ..core.websocket.connection_manager import ws_manager 9 | from ..dependencies import get_db, get_current_user 10 | from ..background_functions.email_notifications import send_new_comment_notification_email 11 | from ..schemas.websockets import WSMessage, WSMessageAction 12 | 13 | router = APIRouter(prefix="/comments", tags=['comments']) 14 | 15 | 16 | @router.get("/user/{userId}", response_model=List[schemas.Comment]) 17 | def get_comments_for_user( 18 | userId: int, 19 | skip: int = 0, 20 | limit: int = 100, 21 | db: Session = Depends(get_db) 22 | ): 23 | comments = crud.get_comments_for_user( 24 | db, user_id=userId, skip=skip, limit=limit) 25 | return [ 26 | schemas.Comment( 27 | id=comment.id, 28 | userId=comment.user_id, 29 | tweetId=comment.tweet_id, 30 | content=comment.content, 31 | username=comment.user.username, 32 | createdAt=comment.created_at 33 | ) for comment in comments 34 | ] 35 | 36 | 37 | @router.get("/tweet/{tweetId}", response_model=List[schemas.Comment]) 38 | def get_comments_for_tweet( 39 | tweetId: int, 40 | skip: int = 0, 41 | limit: int = 0, 42 | db: Session = Depends(get_db) 43 | ): 44 | comments = crud.get_comments_for_tweet( 45 | db, tweet_id=tweetId, 
skip=skip, limit=limit) 46 | return [ 47 | schemas.Comment( 48 | id=comment.id, 49 | userId=comment.user_id, 50 | tweetId=comment.tweet_id, 51 | content=comment.content, 52 | username=comment.user.username, 53 | createdAt=comment.created_at 54 | ) for comment in comments 55 | ] 56 | 57 | 58 | @router.get("/count/tweet/{tweetId}", response_model=schemas.TweetCommentCount) 59 | def get_comment_count_for_tweet( 60 | tweetId: int, 61 | db: Session = Depends(get_db) 62 | ): 63 | count = crud.get_comment_count_for_tweet(db, tweet_id=tweetId) 64 | 65 | return schemas.TweetCommentCount( 66 | count=count 67 | ) 68 | 69 | 70 | @router.post("", response_model=schemas.Comment) 71 | async def create_comment_for_tweet( 72 | request_body: schemas.CommentCreate, 73 | bg_tasks: BackgroundTasks, 74 | db: Session = Depends(get_db), 75 | current_user: schemas.User = Depends(get_current_user) 76 | ): 77 | newComment = crud.create_tweet_comment(db, current_user.id, request_body) 78 | 79 | # Broadcast a WS message so users can see the new comment get updated in real-time 80 | return_comment = schemas.Comment( 81 | id=newComment.id, 82 | userId=newComment.user_id, 83 | tweetId=newComment.tweet_id, 84 | content=newComment.content, 85 | username=newComment.user.username, 86 | createdAt=newComment.created_at 87 | ) 88 | message = WSMessage[schemas.WSCommentCreated]( 89 | action=WSMessageAction.NewComment, 90 | body=schemas.WSCommentCreated( 91 | comment=return_comment 92 | ) 93 | ) 94 | 95 | if not ws_manager.user_is_online(newComment.tweet.user.id): 96 | bg_tasks.add_task(send_new_comment_notification_email, 97 | tweet_owner=newComment.tweet.user, commenter=newComment.user, comment=newComment) 98 | 99 | # Push the new comment to all online users 100 | await ws_manager.broadcast(message, current_user.id) 101 | return return_comment 102 | 103 | 104 | @router.put("", response_model=schemas.Comment) 105 | async def update_comment( 106 | request_body: schemas.CommentUpdate, 107 | db: Session = 
Depends(get_db), 108 | current_user: schemas.User = Depends(get_current_user) 109 | ): 110 | comment = crud.update_comment(db, current_user.id, request_body) 111 | return_comment = schemas.Comment( 112 | id=comment.id, 113 | userId=comment.user_id, 114 | tweetId=comment.tweet_id, 115 | content=comment.content, 116 | username=comment.user.username, 117 | createdAt=comment.created_at 118 | ) 119 | # 120 | # Broadcast a WS message so users can see the new comment get updated in real-time 121 | # 122 | message = WSMessage[schemas.WSCommentUpdated]( 123 | action=WSMessageAction.UpdatedComment, 124 | body=schemas.WSCommentUpdated( 125 | tweetId=comment.tweet_id, 126 | comment=return_comment 127 | ) 128 | ) 129 | await ws_manager.broadcast(message, current_user.id) 130 | 131 | return return_comment 132 | 133 | 134 | @router.delete("", response_model=schemas.EmptyResponse) 135 | async def delete_comment( 136 | request_body: schemas.CommentDelete, 137 | db: Session = Depends(get_db), 138 | current_user: schemas.User = Depends(get_current_user) 139 | ): 140 | 141 | comment = crud.get_comment_by_id(db, request_body.commentId) 142 | # Broadcast a WS message so users can see the new comment get updated in real-time 143 | message = WSMessage[schemas.WSCommentDeleted]( 144 | action=WSMessageAction.DeletedComment, 145 | body=schemas.WSCommentDeleted( 146 | tweetId=comment.tweet_id, 147 | commentId=request_body.commentId 148 | ) 149 | ) 150 | await ws_manager.broadcast(message, current_user.id) 151 | return crud.delete_comment(db, current_user.id, request_body) 152 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastAPI - Twitter Clone 2 | 3 | This is a project that is meant to explore the capabilities of FastAPI by building a real-world application - in this case a simple twitter clone. 
4 | 5 | The frontend Next.js application code can be found here: [https://github.com/dericf/twitter-clone-frontend](https://github.com/dericf/twitter-clone-frontend) 6 | 7 | You can see the live, hosted, auto-generated API documentation [here](https://twitter-clone.dericfagnan.com/docs) 8 | 9 | Example Screenshot: 10 | ![image](https://user-images.githubusercontent.com/14207512/123458406-43921f00-d5a2-11eb-834c-71902b118c77.png) 11 | 12 | 13 | ## Installation/Setup 14 | 15 | Make sure you have Docker installed on your local machine. 16 | 17 | Generate an application secret key by running `openssl rand -hex 32` and use it as the value of `SECRET_KEY` below. 18 | 19 | Create a `.env` file and add the following environment variables: 20 | 21 | ``` 22 | export LOCAL_POSTGRES_URL="postgresql://admin:secret@localhost:5433/twitterdb" 23 | export LOCAL_DOCKER_INTERNAL_POSTGRES_URL="postgresql://admin:secret@db:5432/twitterdb" 24 | export PRODUCTION_POSTGRES_URL="postgresql://admin:secret@your-db-host:5432/twitterdb" 25 | export ENV=development 26 | export SECRET_KEY="ASuperSecretKey" 27 | export FIRST_SUPERUSER_EMAIL="super@user.com" 28 | export FIRST_SUPERUSER_PASSWORD="asupersecretpassword" 29 | export EMAIL_TEST_USER=john@doe.com 30 | export SEND_GRID_API_KEY= 31 | export PRODUCTION_CLIENT_HOST_URL="localhost:3000" 32 | export SEND_GRID_FROM_EMAIL="account-verification@your-domain.com" 33 | export PRODUCTION_DOMAIN_HOST="your-api-host-domain.com" 34 | ``` 35 | 36 | Run `source .env` to apply the environment variables. 37 | 38 | ## Running in Development 39 | 40 | All you need to run is `docker-compose -f docker-compose-dev.yml up --build`; it builds the Docker image based on `Dockerfile-dev` and automatically handles hot-reloading while developing.
41 | 42 | ### Running Database Migrations (Alembic) 43 | 44 | #### Downgrade to the initial state (blank DB) 45 | 46 | Run `alembic downgrade base` 47 | 48 | #### Run all migrations to current/highest state 49 | 50 | Run `alembic upgrade head` 51 | 52 | #### Create a new migration version 53 | 54 | Run `alembic revision -m "version tag"` 55 | 56 | ## Access Auto-Generated API Documentation 57 | 58 | Navigate to `localhost:8001/docs` (Swagger UI) or `localhost:8001/redoc` (ReDoc) 59 | 60 | ## Deploy to Production Server 61 | 62 | [Great blog post here](https://dev.to/tiangolo/deploying-fastapi-and-other-apps-with-https-powered-by-traefik-5dik) 63 | 64 | Create a new Linode server 65 | 66 | Add a CNAME record that points to that server 67 | 68 | SSH into your server: `ssh root@your-domain.com` 69 | 70 | Run `apt update` to update the package listing 71 | 72 | Run `apt upgrade` to install the latest versions of all packages 73 | 74 | [Docker Install Scripts](https://docs.docker.com/engine/install/ubuntu/#install-using-the-convenience-script) 75 | 76 | [Docker Compose Install](https://docs.docker.com/compose/install/) 77 | 78 | TODO: `apt install haveged` 79 | 80 | `rsync -a --exclude 'pgdata' --exclude '__pycache__' --exclude 'venv' ./* root@twitter-clone-fastapi.programmertutor.com:/root/code/twitter-clone-server-fastapi` 81 | 82 | _Note: this doesn't copy the `.env` file - you must create it manually on the server once._ 83 | 84 | Create the `docker-compose.traefik.yml` file 85 | Create the `docker-compose.override.yml` file 86 | 87 | ### Set up credentials for traefik dashboard basic auth.
88 | 89 | Create 3 new `.env` variables (`USERNAME`, `PASSWORD` and `HASHED_PASSWORD`) 90 | 91 | Get the value of `HASHED_PASSWORD` with `echo $(openssl passwd -apr1 $PASSWORD)` 92 | 93 | ### Set up traefik network for the first time 94 | 95 | Create the traefik public network with `docker network create traefik-public` 96 | 97 | ### Run the main traefik container in the background 98 | 99 | Run the traefik Docker container with `docker-compose -f docker-compose.traefik.yml up -d` 100 | 101 | ### Finally, run the FastAPI server Docker container 102 | 103 | Run the FastAPI server with `docker-compose -f docker-compose-prod.yml up --build -d` 104 | 105 | ### Database 106 | 107 | Create a new Postgres database on AWS RDS and update the `PRODUCTION_POSTGRES_URL` in your `.env` file 108 | 109 | --- 110 | 111 | ## Developer Docs and Examples 112 | 113 | ### Python Types & Pydantic Models 114 | 115 | [https://fastapi.tiangolo.com/python-types/](https://fastapi.tiangolo.com/python-types/) 116 | 117 | [https://mypy.readthedocs.io/en/latest/cheat_sheet_py3.html](https://mypy.readthedocs.io/en/latest/cheat_sheet_py3.html) 118 | 119 | ### FastAPI Docs 120 | 121 | [https://fastapi.tiangolo.com/](https://fastapi.tiangolo.com/) 122 | 123 | ### Starlette Docs 124 | 125 | [https://www.starlette.io/](https://www.starlette.io/) 126 | 127 | ### FastAPI + SQLAlchemy Tutorial 128 | 129 | [https://fastapi.tiangolo.com/tutorial/sql-databases/](https://fastapi.tiangolo.com/tutorial/sql-databases/) 130 | 131 | ### Alembic Docs 132 | 133 | ### Docker Docs 134 | 135 | ### Docker Compose Docs 136 | 137 | ### Manual FastAPI Deployment Docs 138 | 139 | https://fastapi.tiangolo.com/deployment/manually/ 140 | 141 | ### Traefik docker image Docs 142 | 143 | ### FastAPI docker image Docs 144 | 145 | https://fastapi.tiangolo.com/deployment/docker/ 146 | 147 | ### Terraform Docs 148 | 149 | ### AWS RDS Docs 150 | 151 | ### AWS ECS Docs 152 | 153 | ### OpenAPI Spec Docs 154 | 155 |
https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.1.0.md#format 156 | 157 | ### Containerized VS Code Instances for development 158 | 159 | https://code.visualstudio.com/docs/remote/containers 160 | 161 | ## References 162 | 163 | [FastAPI — How to add basic and cookie authentication (by Nils de Bruin)](https://medium.com/data-rebels/fastapi-how-to-add-basic-and-cookie-authentication-a45c85ef47d3) 164 | -------------------------------------------------------------------------------- /api/routers/users.py: -------------------------------------------------------------------------------- 1 | # FastAPI 2 | from fastapi import APIRouter, HTTPException, status, Request, Depends, BackgroundTasks 3 | 4 | # SQLAlchemy 5 | from sqlalchemy.orm import Session 6 | 7 | # Types 8 | from typing import List, Optional 9 | 10 | # Custom Modules 11 | from .. import schemas, crud 12 | from ..background_functions.email_notifications import send_registration_confirmation_email 13 | from ..dependencies import get_db, get_current_user 14 | 15 | # Core Modules 16 | from ..core import security 17 | from ..core.config import settings 18 | from ..core.utilities import generate_random_uuid 19 | 20 | # Standard Library 21 | import os 22 | 23 | router = APIRouter(prefix="/users", tags=['users']) 24 | 25 | 26 | @router.get("", response_model=List[schemas.UserResponse]) 27 | def get_one_or_all_users(userId: Optional[int] = None, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): 28 | """Return either all users, or a single user with id == userId. Always returns a list. 29 | """ 30 | if userId: 31 | users = [crud.get_user_by_id(db, userId)] 32 | else: 33 | users = crud.get_users(db, skip=skip, limit=limit) 34 | 35 | # TODO perhaps there is a better way of returning this model. 
36 | # It seems like its trying to immidate graphql 37 | return [ 38 | schemas.UserResponse( 39 | id=user.id, 40 | email=user.email, 41 | username=user.username, 42 | bio=user.bio, 43 | birthdate=user.birthdate 44 | ) for user in users] 45 | 46 | 47 | @router.get("/search/{usernameFragment}", response_model=List[schemas.UserResponse]) 48 | def get_one_or_all_users(usernameFragment: str = None, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)): 49 | """Search for a user based on username. 50 | """ 51 | users = crud.search_user_by_username_fragment( 52 | db, usernameFragment, skip, limit) 53 | 54 | # TODO perhaps there is a better way of returning this model. 55 | # It seems like its trying to immidate graphql 56 | return [ 57 | schemas.UserResponse( 58 | id=user.id, 59 | email=user.email, 60 | username=user.username, 61 | bio=user.bio, 62 | birthdate=user.birthdate 63 | ) for user in users] 64 | 65 | 66 | @router.get("/me", response_model=schemas.User) 67 | def get_authenticated_user(db: Session = Depends(get_db), current_user: schemas.User = Depends(get_current_user)): 68 | """Get the currently logged in user if there is one (testing purposes only) 69 | """ 70 | return current_user 71 | 72 | 73 | @router.post("", response_model=schemas.User) 74 | async def create_user(user: schemas.UserCreate, bg_tasks: BackgroundTasks, db: Session = Depends(get_db)): 75 | """Create a new user record in the database and send a registration confirmation email 76 | """ 77 | db_user = crud.get_user_by_email(db, email=user.email) 78 | if db_user: 79 | raise HTTPException(status_code=400, detail="Email already registered") 80 | # 81 | # Generate a random uuid to email to the user 82 | # 83 | confirmation_key = generate_random_uuid() 84 | # 85 | # Create the new user 86 | # 87 | newUser: schemas.User = crud.create_user( 88 | db=db, user=user, confirmation_key=confirmation_key) 89 | bg_tasks.add_task( 90 | send_registration_confirmation_email, 91 | username=newUser.username, 
92 | email=user.email, 93 | confirmation_key=confirmation_key 94 | ) 95 | return newUser 96 | 97 | 98 | @router.put('', response_model=schemas.UserUpdateResponseBody) 99 | def update_user(request_body: schemas.UserUpdateRequestBody, 100 | db: Session = Depends(get_db), 101 | current_user: schemas.UserWithPassword = Depends(get_current_user)): 102 | """Update an authenticated user's username and/or bio. 103 | """ 104 | # Check that password is correct 105 | if not security.verify_password(request_body.password, current_user.hashed_password): 106 | raise HTTPException( 107 | status_code=status.HTTP_401_UNAUTHORIZED, detail="Wrong password") 108 | 109 | # Check if they are trying to update the username 110 | if request_body.newUsername is not None: 111 | # see if that username is available 112 | db_user_with_username = crud.get_user_by_username( 113 | db, request_body.newUsername) 114 | if db_user_with_username is not None: 115 | raise HTTPException( 116 | status_code=status.HTTP_406_NOT_ACCEPTABLE, detail="Username already exists") 117 | 118 | # Update user attributes 119 | user = crud.update_user(db, current_user.id, request_body) 120 | return user 121 | 122 | 123 | @router.delete('/', response_model=schemas.EmptyResponse) 124 | def delete_user(request_body: schemas.UserDeleteRequestBody, 125 | db: Session = Depends(get_db), 126 | current_user: schemas.UserWithPassword = Depends(get_current_user)): 127 | if not security.verify_password(request_body.password, current_user.hashed_password): 128 | raise HTTPException( 129 | status_code=status.HTTP_401_UNAUTHORIZED, detail="Wrong password") 130 | delete_successful = crud.delete_user(db, current_user.id) 131 | 132 | return schemas.EmptyResponse() 133 | 134 | 135 | @router.post('/confirm-account/', response_model=schemas.EmptyResponse) 136 | async def confirm_account(request_body: schemas.UserAccountConfirmationRequestBody, db: Session = Depends(get_db)): 137 | user: schemas.User = crud.get_user_by_confirmation_key( 138 | 
db, request_body.confirmationKey) 139 | 140 | if not user: 141 | raise HTTPException( 142 | status_code=status.HTTP_400_BAD_REQUEST, detail="Bad key.") 143 | 144 | # correct confirmation key was passed 145 | 146 | crud.verify_account(db, user.id) 147 | 148 | return schemas.EmptyResponse() 149 | -------------------------------------------------------------------------------- /api/main.py: -------------------------------------------------------------------------------- 1 | # Standard Library 2 | import os 3 | import time 4 | import json 5 | from typing import List, Dict, Union 6 | 7 | # FastAPI 8 | from fastapi import ( 9 | Depends, 10 | FastAPI, 11 | HTTPException, 12 | BackgroundTasks, 13 | Request, 14 | WebSocket, 15 | WebSocketDisconnect, 16 | Cookie, 17 | Query, 18 | ) 19 | from fastapi.middleware.cors import CORSMiddleware 20 | from fastapi.encoders import jsonable_encoder 21 | 22 | # Routers 23 | from .routers import ( 24 | auth, 25 | users, 26 | tweets, 27 | comments, 28 | followers, 29 | follows, 30 | tweet_likes, 31 | comment_likes, 32 | messages, 33 | ) 34 | 35 | # SQLAlchemy 36 | from sqlalchemy.orm import Session 37 | 38 | # Core 39 | from .core import security 40 | from .core.config import settings 41 | from .core.cors import cors_origins 42 | from .core.websocket.connection_manager import ws_manager 43 | 44 | # Database 45 | from .database import SessionLocal, engine 46 | from . 
import crud, models, schemas, dependencies 47 | 48 | # Schema 49 | from .schemas.websockets import WSMessage, WSMessageAction, WSMessageError 50 | 51 | # Instantiate Main FastAPI Instance 52 | app = FastAPI( 53 | # root_path=settings.API_V1_STR, 54 | title="Twitter Clone (For Educational Purposes)", 55 | description="This API replicates some of the very basic functionality of twitter, including users, tweets, likes, comments and", 56 | version="0.0.1", 57 | ) 58 | 59 | # Add CORS middleware 60 | app.add_middleware( 61 | CORSMiddleware, 62 | allow_origins=cors_origins, 63 | allow_credentials=True, 64 | allow_methods=["*"], 65 | allow_headers=["*"], 66 | expose_headers=["set-cookie"], 67 | ) 68 | 69 | # Include All Routers 70 | app.include_router(auth.router) 71 | app.include_router(users.router) 72 | app.include_router(tweets.router) 73 | app.include_router(comments.router) 74 | app.include_router(follows.router) 75 | app.include_router(followers.router) 76 | app.include_router(tweet_likes.router) 77 | app.include_router(comment_likes.router) 78 | app.include_router(messages.router) 79 | 80 | # Needed to resolve an unknown http bug 81 | 82 | 83 | @app.middleware("http") 84 | async def modify_location_header(request: Request, call_next): 85 | """This is a very hacky fix for a glitch in the "location" response header 86 | For some reason it sends back http instead of https so I manually 87 | overwrite it here. Definitely something that should be fixed upstream in 88 | configuration but this will work temporarily until I find the correct location. 
89 | """ 90 | response: Response = await call_next(request) 91 | 92 | # Check for location response header 93 | location = response.headers.get("location") 94 | if location and os.environ.get("ENV") != "development": 95 | response.headers["location"] = location.replace("http:", "https:") 96 | return response 97 | 98 | # Dummy route at the index 99 | 100 | 101 | @app.get("/") 102 | def index(request: Request): 103 | """Index route. Only used for testing purposes.""" 104 | return {"api_status": "ok"} 105 | 106 | 107 | @app.websocket("/ws/{user_id}") 108 | async def websocket_endpoint( 109 | websocket: WebSocket, 110 | user_id: int, 111 | db: Session = Depends(dependencies.get_db), 112 | current_user: schemas.User = Depends(dependencies.get_current_user) 113 | ): 114 | """ 115 | Websocket endpoint for authenticated users. 116 | 117 | Listens for incoming CONNECTIONS and uses the connection_manager class to 118 | update the dictionary of active connections. 119 | 120 | Listens for incoming MESSAGES and handles them accordingly 121 | 122 | TODO: This should be moved to its own websocket module 123 | TODO: The /user_id param should not be needed anymore since it gets it from the token 124 | """ 125 | # print("\n********************\nNew Websocket Connection Incoming: ") 126 | # print("user_id: ", user_id) 127 | # print("current user", current_user) 128 | # await ws_manager.show_all_connections() 129 | # 130 | # Attempt to connect new client 131 | # 132 | await ws_manager.connect(websocket, user_id) 133 | if not current_user: 134 | # print("No authenticated user - Alert client and close connection...") 135 | auth_failed_message = WSMessage[None]( 136 | action=WSMessageAction.AuthRequired, 137 | message="error", 138 | error=WSMessageError( 139 | message="Authentication failed", 140 | code=401 141 | ) 142 | ) 143 | await ws_manager.send_personal_message(auth_failed_message, user_id) 144 | await ws_manager.disconnect(user_id) 145 | return 146 | # 147 | # New client has 
connected 148 | # 149 | # user: schemas.User = crud.get_user_by_id(db, user_id) 150 | await ws_manager.broadcast({"action": "chat.user.online", "body": jsonable_encoder( 151 | schemas.ChatUserOnlineResponseBody( 152 | isOnline=True, 153 | userId=user_id, 154 | username=current_user.username 155 | ) 156 | )}, user_id) 157 | 158 | try: 159 | while True: 160 | # 161 | # Receive incoming message 162 | # ! Note: so far the client does not send any WS messages - instead relies on the http rest api 163 | data = await websocket.receive_json() 164 | 165 | user: schemas.User = crud.get_user_by_id(db, user_id) 166 | 167 | # data = json.loads(data) 168 | # print("\n*************** New Websocket Message *************") 169 | # print(data) 170 | # print(f"user: {user} ") 171 | action = data.get("action") 172 | 173 | if (action == "chat.user.online"): 174 | body = schemas.ChatUserOnlineRequestBody(**data.get("body")) 175 | # print("Action: ", action) 176 | response_body = schemas.ChatUserOnlineResponseBody( 177 | isOnline=ws_manager.user_is_online(body.userId), 178 | userId=body.userId 179 | ) 180 | # Send a message back to notify if the other user is online or not 181 | await websocket.send_json( 182 | {"action": action, "body": jsonable_encoder(response_body)}) 183 | 184 | elif (action == "chat.user.typing"): 185 | body = schemas.ChatUserTypingRequestBody(**data.get("body")) 186 | response_body = schemas.ChatUserTypingResponseBody( 187 | isTyping=True 188 | ) 189 | 190 | await ws_manager.send_personal_message( 191 | {"action": action, "body": jsonable_encoder(response_body)}, body.userId) 192 | 193 | except WebSocketDisconnect as error: 194 | # 195 | # Client has disconnected 196 | # 197 | # print("Client Disconnected !: ", error) 198 | await ws_manager.broadcast({"action": "chat.user.online", "body": jsonable_encoder( 199 | schemas.ChatUserOnlineResponseBody( 200 | isOnline=False, 201 | userId=user_id, 202 | username=current_user.username 203 | ) 204 | )}, user_id) 205 | await 
ws_manager.disconnect(user_id) 206 | # await ws_manager.broadcast({"action": "notification", "message": f"{user.username} disconnected"}) 207 | -------------------------------------------------------------------------------- /api/crud.py: -------------------------------------------------------------------------------- 1 | # FastAPI 2 | from fastapi import status, HTTPException 3 | 4 | # SQLAlchemy 5 | from sqlalchemy.orm import Session 6 | from sqlalchemy import or_, func, case, column 7 | 8 | # Types 9 | from typing import Optional, List, Union 10 | 11 | # Custom Modules 12 | from . import models, schemas 13 | from .database import engine 14 | from .core import security 15 | import datetime 16 | from api.core.utilities import generate_random_uuid 17 | 18 | # Standard Library 19 | import os 20 | # 21 | # TODO : split into multiple files. 22 | # TODO : add return types to all functions. 23 | # 24 | 25 | 26 | def get_user_by_id(db: Session, user_id: int) -> Union[models.User, None]: 27 | """Get a single user by id 28 | """ 29 | return db.query(models.User).filter(models.User.id == user_id).one_or_none() 30 | 31 | 32 | def get_user_by_email(db: Session, email: str) -> Union[models.User, None]: 33 | """Get a single user by email 34 | """ 35 | query = db.query(models.User).filter( 36 | func.lower(models.User.email) == func.lower(email)) 37 | # print(query.statement.compile(engine)) 38 | return query.one_or_none() 39 | 40 | 41 | def get_user_by_confirmation_key(db: Session, confirmation_key: str) -> Union[models.User, None]: 42 | """Get a single user by confirmation key sent by email 43 | """ 44 | query = db.query(models.User).filter( 45 | func.lower(models.User.confirmation_key) == func.lower(confirmation_key)) 46 | return query.one_or_none() 47 | 48 | 49 | def get_user_by_username(db: Session, username: str) -> Union[models.User, None]: 50 | """Get a single user by username 51 | """ 52 | return db.query(models.User).filter(func.lower(models.User.username) == 
func.lower(username)).first() 53 | 54 | 55 | def search_user_by_username_fragment(db: Session, username_fragment: str, skip: int = 0, limit: int = 100) -> List[models.User]: 56 | """Search for users by username 57 | """ 58 | return db.query(models.User).filter(models.User.username.ilike(f"%{username_fragment}%")).limit(limit).offset(skip).all() 59 | 60 | 61 | def get_user_by_email_or_username(db: Session, email: str): 62 | """Get a single user by email or by uername 63 | """ 64 | query = db.query(models.User).filter( 65 | or_(func.lower(models.User.email) == func.lower(email), func.lower(models.User.username) == func.lower(email))) 66 | # print(query.statement.compile(engine)) 67 | return query.first() 68 | 69 | 70 | def get_users(db: Session, skip: int = 0, limit: int = 100): 71 | """Get all users 72 | """ 73 | query = db.query(models.User).offset(skip).limit(limit) 74 | # print(query.statement.compile(engine)) 75 | return query.all() 76 | 77 | 78 | def create_user(db: Session, user: schemas.UserCreate, confirmation_key: str): 79 | """Add a user 80 | """ 81 | account_verified = "dev" in os.environ.get( 82 | "ENV") 83 | db_user = models.User( 84 | email=user.email.lower(), 85 | username=user.username.lower(), 86 | bio=user.bio, 87 | birthdate=user.birthdate, 88 | hashed_password=security.get_password_hash(user.password), 89 | confirmation_key=confirmation_key, 90 | account_verified=account_verified 91 | ) 92 | db.add(db_user) 93 | db.commit() 94 | db.refresh(db_user) 95 | return db_user 96 | 97 | 98 | def update_user(db: Session, user_id: int, user_update: schemas.UserUpdateRequestBody): 99 | user_db = db.query(models.User).filter( 100 | models.User.id == user_id).one_or_none() 101 | if user_update.newBio: 102 | user_db.bio = user_update.newBio 103 | if user_update.newUsername: 104 | user_db.username = user_update.newUsername 105 | db.commit() 106 | db.refresh(user_db) 107 | return user_db 108 | 109 | 110 | def verify_account(db: Session, user_id: int): 111 | 
user_db = db.query(models.User).filter( 112 | models.User.id == user_id).one_or_none() 113 | 114 | user_db.confirmationKey = None 115 | user_db.account_verified = True 116 | db.commit() 117 | db.refresh(user_db) 118 | return user_db 119 | 120 | 121 | def delete_user(db: Session, user_id: int): 122 | try: 123 | db.query(models.User).filter(models.User.id == user_id).delete() 124 | db.commit() 125 | 126 | except Exception as e: 127 | raise HTTPException( 128 | status_code=status.HTTP_400_BAD_REQUEST, detail="Something went wrong") 129 | 130 | 131 | ########## 132 | # TWEETS # 133 | ########## 134 | def get_tweet_by_id(db: Session, tweet_id: int): 135 | return db.query(models.Tweet).filter(models.Tweet.id == tweet_id).one_or_none() 136 | 137 | 138 | def get_tweets(db: Session, skip: int = 0, limit: int = 100): 139 | return db.query(models.Tweet).order_by(models.Tweet.created_at.desc()).offset(skip).limit(limit).all() 140 | 141 | 142 | def get_tweets_for_user(db: Session, user_id: int, skip: int = 0, limit: int = 100): 143 | # First check if user exists 144 | db_user = db.query(models.User).filter( 145 | models.User.id == user_id).one_or_none() 146 | 147 | if not db_user: 148 | raise HTTPException( 149 | status_code=status.HTTP_400_BAD_REQUEST, detail="User does not exist") 150 | 151 | # user exists - proceed to return tweets 152 | tweets = db.query(models.Tweet).filter(models.Tweet.user_id == user_id).order_by( 153 | models.Tweet.created_at.desc()).limit(limit).all() 154 | return tweets 155 | 156 | 157 | def get_tweets_liked_by_user(db: Session, user_id: int, skip: int = 0, limit: int = 100): 158 | # First check if user exists 159 | 160 | db_user = db.query(models.User).filter( 161 | models.User.id == user_id).one_or_none() 162 | 163 | if not db_user: 164 | raise HTTPException( 165 | status_code=status.HTTP_400_BAD_REQUEST, detail="User does not exist") 166 | 167 | # user exists - proceed to return tweets 168 | likes = 
db.query(models.TweetLikes).filter(models.TweetLikes.user_id == user_id).order_by( 169 | models.TweetLikes.id.desc()).offset(skip).limit(limit).all() 170 | 171 | # tweets = [tweet_like.tweet for tweet_like in db_user.tweet_likes] 172 | return [like.tweet for like in likes] 173 | 174 | 175 | def create_user_tweet(db: Session, tweet: schemas.TweetCreate, user_id: int): 176 | db_tweet = models.Tweet(**tweet.dict()) 177 | 178 | db_tweet.user_id = user_id 179 | 180 | db.add(db_tweet) 181 | db.commit() 182 | db.refresh(db_tweet) 183 | return db_tweet 184 | 185 | 186 | def update_tweet(db: Session, user_id: int, tweet_id: int, new_content: str): 187 | db_tweet: schemas.Tweet = db.query(models.Tweet).filter( 188 | models.Tweet.id == tweet_id).one_or_none() 189 | 190 | if not db_tweet: 191 | raise HTTPException(status.HTTP_404_NOT_FOUND, 192 | detail="tweet not found") 193 | 194 | if db_tweet.user_id != user_id: 195 | # Tweet does not belong to the user. Cannot delete. 196 | raise HTTPException(status.HTTP_401_UNAUTHORIZED, 197 | detail="user does not own that tweet") 198 | 199 | db_tweet.content = new_content 200 | db.commit() 201 | db.refresh(db_tweet) 202 | return db_tweet 203 | 204 | 205 | def delete_tweet(db: Session, user_id: int, tweet_id: int): 206 | db_tweet: schemas.Tweet = db.query(models.Tweet).filter( 207 | models.Tweet.id == tweet_id).one_or_none() 208 | if not db_tweet: 209 | raise HTTPException(status.HTTP_404_NOT_FOUND, 210 | detail="Tweet not found") 211 | if db_tweet.user_id != user_id: 212 | # Tweet does not belong to the user 213 | raise HTTPException(status.HTTP_401_UNAUTHORIZED, 214 | detail="You are not authorized to delete that tweet") 215 | try: 216 | db.delete(db_tweet) 217 | db.commit() 218 | 219 | except Exception as e: 220 | raise HTTPException(status.HTTP_400_BAD_REQUEST, 221 | detail="Something went wrong") 222 | 223 | 224 | ############ 225 | # Comments # 226 | ############ 227 | def create_tweet_comment(db: Session, user_id: int, comment: 
schemas.CommentCreate): 228 | # First check that tweet exists 229 | db_tweet: schemas.Tweet = db.query(models.Tweet).filter( 230 | models.Tweet.id == comment.tweetId).one_or_none() 231 | if not db_tweet: 232 | raise HTTPException( 233 | status_code=status.HTTP_400_BAD_REQUEST, detail="Tweet does not exist") 234 | 235 | # tweet exists - proceed to create comment 236 | db_comment = models.Comments( 237 | tweet_id=comment.tweetId, user_id=user_id, content=comment.content) 238 | try: 239 | db.add(db_comment) 240 | db.commit() 241 | db.refresh(db_comment) 242 | return db_comment 243 | except Exception as e: 244 | raise HTTPException( 245 | status_code=status.HTTP_400_BAD_REQUEST, detail="Tweet does not exist") 246 | 247 | 248 | def get_comment_by_id(db: Session, comment_id: int) -> models.Comments: 249 | return db.query(models.Comments).filter( 250 | models.Comments.id == comment_id).one_or_none() 251 | 252 | 253 | def get_comments_for_user(db: Session, user_id: int, skip: int = 0, limit: int = 100): 254 | # Check the user exists first 255 | db_user = db.query(models.User).filter( 256 | models.User.id == user_id).one_or_none() 257 | 258 | if not db_user: 259 | raise HTTPException( 260 | status_code=status.HTTP_400_BAD_REQUEST, detail="User does not exist") 261 | # TODO: Should be joining and setting relationship order either here 262 | # TODO: or refereably higher up in the model definitions 263 | # TODO: temporarily sort list here to have newest first 264 | # User exists - proceed to return comments 265 | # db_user.comments.sort(key=lambda comment: datetime.strptime(comment.created_at, "%d-%b-%y")) 266 | return db.query(models.Comments).filter(models.Comments.user_id == user_id).order_by(models.Comments.created_at.desc()).offset(skip).limit(limit).all() 267 | 268 | 269 | def get_comments_for_tweet(db: Session, tweet_id: int, skip: int = 0, limit: int = 100): 270 | db_tweet = db.query(models.Tweet).filter( 271 | models.Tweet.id == tweet_id).one_or_none() 272 | 273 | if 
not db_tweet: 274 | raise HTTPException( 275 | status_code=status.HTTP_400_BAD_REQUEST, detail="Tweet does not exist") 276 | # return db_tweet.comments 277 | return db.query(models.Comments).filter(models.Comments.tweet_id == tweet_id).order_by(models.Comments.created_at.asc()).offset(skip).all() 278 | 279 | 280 | def update_comment(db: Session, user_id: int, comment: schemas.CommentUpdate) -> models.Comments: 281 | db_comment: schemas.Comment = db.query(models.Comments).filter( 282 | models.Comments.id == comment.commentId).one_or_none() 283 | 284 | if not db_comment: 285 | raise HTTPException( 286 | status_code=status.HTTP_400_BAD_REQUEST, detail="Comment does not exist") 287 | 288 | if db_comment.user_id != user_id: 289 | raise HTTPException(status.HTTP_401_UNAUTHORIZED, 290 | detail="You are not authorized to update that comment") 291 | 292 | db_comment.content = comment.newContent 293 | db.commit() 294 | db.refresh(db_comment) 295 | return db_comment 296 | 297 | 298 | def delete_comment(db: Session, user_id: int, comment: schemas.CommentCreate): 299 | comment_db: schemas.Comment = db.query(models.Comments).filter( 300 | models.Comments.id == comment.commentId).one_or_none() 301 | 302 | if not comment_db: 303 | raise HTTPException( 304 | status_code=status.HTTP_400_BAD_REQUEST, detail="Comment does not exist") 305 | 306 | if comment_db.user_id != user_id: 307 | # comment does not belong to the user 308 | raise HTTPException(status.HTTP_401_UNAUTHORIZED, 309 | detail="You are not authorized to delete that comment") 310 | 311 | try: 312 | db.delete(comment_db) 313 | db.commit() 314 | 315 | except Exception as e: 316 | raise HTTPException( 317 | status_code=status.HTTP_400_BAD_REQUEST, detail="Something went wrong") 318 | 319 | ############# 320 | # FOLLOWS # 321 | ############# 322 | 323 | 324 | def get_all_users_following(db: Session, user_id: int): 325 | return db.query(models.Follows).filter( 326 | models.Follows.user_id == user_id).all() 327 | 328 | 329 | def 
create_follow_relationship(db: Session, user_id: int, follow_user_id: int): 330 | # First check if the user id is valid 331 | check_user = db.query(models.User).filter( 332 | models.User.id == follow_user_id).one_or_none() 333 | 334 | if not check_user: 335 | raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, 336 | detail="Bad userId. User does not exist") 337 | 338 | # User exists - proceed to check if user already follows follow_user_id 339 | 340 | existing_follow = db.query(models.Follows).filter( 341 | models.Follows.user_id == user_id, models.Follows.follows_user_id == follow_user_id).one_or_none() 342 | 343 | if existing_follow: 344 | raise HTTPException( 345 | status_code=status.HTTP_400_BAD_REQUEST, detail="Already following user.") 346 | 347 | # User exists & does not already follow - proceed to create follow relationship/link 348 | db_follows = models.Follows( 349 | user_id=user_id, follows_user_id=follow_user_id) 350 | db.add(db_follows) 351 | db.commit() 352 | db.refresh(db_follows) 353 | return db_follows 354 | 355 | 356 | def delete_follow_relationship(db: Session, user_id: int, follow_user_id: int): 357 | # First check if the user id is valid 358 | check_user = db.query(models.User).filter( 359 | models.User.id == user_id).one_or_none() 360 | 361 | if not check_user: 362 | raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, 363 | detail="Bad userId. 
User does not exist") 364 | 365 | # User exists - proceed to check if follow relationship/link exists 366 | 367 | existing_follow = db.query(models.Follows).filter( 368 | models.Follows.user_id == user_id, models.Follows.follows_user_id == follow_user_id).one_or_none() 369 | 370 | if not existing_follow: 371 | raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, 372 | detail="Cannot un-follow user that is not being followed") 373 | 374 | # Follow relationship exists - proceed to un-follow 375 | db.delete(existing_follow) 376 | db.commit() 377 | return 378 | 379 | ############# 380 | # FOLLOWERS # 381 | ############# 382 | 383 | 384 | def get_all_followers(db: Session, user_id: int): 385 | # Check if user_id is valid 386 | existing_user = db.query(models.User.id).filter( 387 | models.User.id == user_id).one_or_none() 388 | 389 | if not existing_user: 390 | raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, 391 | detail="Bad userId. User does not exist.") 392 | 393 | # User is valid - proceed to get all followers 394 | db_followers = db.query(models.Follows).filter( 395 | models.Follows.follows_user_id == user_id).all() 396 | return db_followers 397 | 398 | ############### 399 | # Tweet Likes # 400 | ############### 401 | 402 | 403 | def get_tweet_like_by_id(db: Session, tweet_like_id: int): 404 | """Get a single tweet_like object/row 405 | """ 406 | return db.query(models.TweetLikes).filter(models.TweetLikes.id == tweet_like_id).one_or_none() 407 | 408 | 409 | def get_tweet_like_by_tweet_id_and_user_id(db: Session, user_id, tweet_id: int) -> models.TweetLikes: 410 | """Get a single tweet_like object/row 411 | """ 412 | return db.query(models.TweetLikes).filter( 413 | models.TweetLikes.tweet_id == tweet_id, models.TweetLikes.user_id == user_id).one_or_none() 414 | 415 | 416 | def get_all_tweet_likes(db: Session): 417 | return db.query(models.TweetLikes).limit(2).all() 418 | 419 | 420 | def get_all_tweet_likes_for_tweet(db: Session, tweet_id: int): 
421 | # First check if tweet exists 422 | existing_tweet = get_tweet_by_id(db, tweet_id) 423 | if not existing_tweet: 424 | raise HTTPException(status.HTTP_400_BAD_REQUEST, 425 | detail="Error. Tweet does not exist") 426 | 427 | # user exists - proceed to return tweets 428 | return db.query(models.TweetLikes).filter( 429 | models.TweetLikes.tweet_id == tweet_id).all() 430 | 431 | 432 | def create_tweet_like_for_tweet(db: Session, tweet_id: int, user_id: int): 433 | """Add a tweet like for tweet && user 434 | """ 435 | #! TODO: first check if user has already like this tweet 436 | db_tweet_like = models.TweetLikes( 437 | user_id=user_id, 438 | tweet_id=tweet_id 439 | ) 440 | db.add(db_tweet_like) 441 | db.commit() 442 | db.refresh(db_tweet_like) 443 | return db_tweet_like 444 | 445 | 446 | def delete_tweet_like(db: Session, user_id: int, tweet_id: int,): 447 | """Delete (Unlike) a tweet 448 | """ 449 | db_tweet_like = db.query(models.TweetLikes).filter( 450 | models.TweetLikes.tweet_id == tweet_id, models.TweetLikes.user_id == user_id).one_or_none() 451 | 452 | # if not db_tweet_like: 453 | # raise HTTPException(status.HTTP_400_BAD_REQUEST, 454 | # detail="Error. Cannot Delete. Bad ID for Like") 455 | 456 | # # First check if tweet like and user match 457 | # db_user = get_user_by_id(db, user_id) 458 | # if not db_user: 459 | # raise HTTPException(status.HTTP_400_BAD_REQUEST, 460 | # detail="Error. 
User does not exist.") 461 | 462 | # if db_user.id != db_tweet_like.user_id: 463 | # raise HTTPException(status.HTTP_401_UNAUTHORIZED, 464 | # detail="Unauthorized to unike this tweet.") 465 | 466 | # Data is valid - proceed to delete tweet like (un-like) 467 | db.delete(db_tweet_like) 468 | db.commit() 469 | 470 | 471 | # -------------------- 472 | 473 | ############### 474 | # Comment Likes # 475 | ############### 476 | 477 | def get_comment_like_by_id(db: Session, comment_like_id: int) -> models.CommentLikes: 478 | """Get a single comment_like object/row 479 | """ 480 | return db.query(models.CommentLikes).filter(models.CommentLikes.id == comment_like_id).one_or_none() 481 | 482 | 483 | def get_comment_like_by_comment_id_and_user_id(db: Session, user_id, comment_id: int) -> models.CommentLikes: 484 | """Get a single comment_like object/row 485 | """ 486 | return db.query(models.CommentLikes).filter( 487 | models.CommentLikes.comment_id == comment_id, models.CommentLikes.user_id == user_id).one_or_none() 488 | 489 | 490 | def get_all_comment_likes(db: Session) -> List[models.CommentLikes]: 491 | return db.query(models.CommentLikes).all() 492 | 493 | 494 | def get_all_comment_likes_for_comment(db: Session, comment_id: int) -> List[models.CommentLikes]: 495 | # First check if comment exists 496 | 497 | existing_comment = get_comment_by_id(db, comment_id) 498 | if not existing_comment: 499 | raise HTTPException(status.HTTP_400_BAD_REQUEST, 500 | detail="Error. 
Comment does not exist") 501 | 502 | # user exists - proceed to return comments 503 | return db.query(models.CommentLikes).filter( 504 | models.CommentLikes.comment_id == comment_id).all() 505 | 506 | 507 | def create_comment_like_for_comment(db: Session, comment_id: int, user_id: int): 508 | """Add a comment like for comment && user 509 | """ 510 | # Check if the comment exists 511 | existing_comment = get_comment_by_id(db, comment_id) 512 | if not existing_comment: 513 | raise HTTPException(status.HTTP_400_BAD_REQUEST, 514 | detail="Error. Comment does not exist") 515 | 516 | # Check if comment like already exists 517 | db_comment_like = db.query(models.CommentLikes).filter( 518 | models.CommentLikes.comment_id == comment_id, models.CommentLikes.user_id == user_id).first() 519 | 520 | if db_comment_like: 521 | raise HTTPException(status.HTTP_400_BAD_REQUEST, 522 | detail="Error. Already liked comment.") 523 | 524 | db_comment_like = models.CommentLikes( 525 | user_id=user_id, 526 | comment_id=comment_id 527 | ) 528 | 529 | db.add(db_comment_like) 530 | db.commit() 531 | db.refresh(db_comment_like) 532 | return db_comment_like 533 | 534 | 535 | def delete_comment_like_by_user_and_comment_id(db: Session, user_id: int, comment_id: int,): 536 | """Delete (Unlike) a comment 537 | """ 538 | db_comment_like = get_comment_like_by_comment_id_and_user_id( 539 | db, user_id, comment_id) 540 | 541 | if not db_comment_like: 542 | raise HTTPException(status.HTTP_400_BAD_REQUEST, 543 | detail="Error. Cannot Delete. 
Bad ID for Like") 544 | 545 | # Data is valid - proceed to delete comment like (un-like) 546 | db.delete(db_comment_like) 547 | db.commit() 548 | return 549 | 550 | 551 | ############### 552 | # Counts # 553 | ############### 554 | 555 | def get_comment_count_for_tweet(db: Session, tweet_id: int): 556 | return db.query(models.Comments).filter(models.Comments.tweet_id == tweet_id).count() 557 | 558 | 559 | def get_like_count_for_tweet(db: Session, tweet_id: int): 560 | return db.query(models.Comments).filter(models.Comments.tweet_id == tweet_id).count() 561 | 562 | 563 | def get_followers_for_user(db: Session, user_id: int): 564 | return db.query(models.Follows).filter(models.Follows.follows_user_id == user_id).count() 565 | 566 | 567 | def get_following_for_user(db: Session, user_id: int): 568 | return db.query(models.Follows).filter(models.Follows.user_id == user_id).count() 569 | 570 | 571 | ################# 572 | # Messages # 573 | ################# 574 | 575 | def get_message_by_id(db: Session, message_id): 576 | return db.query(models.Messages).filter_by(id=message_id).one_or_none() 577 | 578 | # !TODO: Need to actually implement the limit and skip... 
579 | 580 | 581 | def get_messages_for_user(db: Session, user_id: int, skip: int = 0, limit: int = 10000): 582 | # Check the user exists first 583 | db_user = db.query(models.User).filter( 584 | models.User.id == user_id).one_or_none() 585 | 586 | if not db_user: 587 | raise HTTPException( 588 | status_code=status.HTTP_400_BAD_REQUEST, detail="User does not exist") 589 | return db.query(models.Messages).filter(or_(models.Messages.user_from_id == user_id, models.Messages.user_to_id == user_id)).order_by(models.Messages.created_at.asc()).offset(skip).limit(limit).all() 590 | 591 | # User exists - proceed to return Messages 592 | # TODO - there might be a better way to acheive this same result -> since I have to re-shape the data on the client 593 | # TODO anyway (mandatory traversal - 0(n)), I might as well just return all the messages and manually aggregate the conversations 594 | # TODO on the client... Need to think about this one a bit more. 595 | # conversation_id = case([(models.Messages.user_from_id == user_id, models.Messages.user_to_id), 596 | # (models.Messages.user_to_id == user_id, models.Messages.user_from_id)]).label("conversation_id") 597 | # return db.query(column("content"), column("id"), column("user_from_id"), column("user_to_id"), conversation_id).filter(or_(models.Messages.user_from_id == user_id, models.Messages.user_to_id == user_id)).all() 598 | 599 | 600 | def create_message(db: Session, user_id, body: schemas.MessageCreateRequestBody): 601 | db_message = models.Messages( 602 | user_from_id=user_id, 603 | content=body.content, 604 | user_to_id=body.userToId 605 | ) 606 | db.add(db_message) 607 | db.commit() 608 | db.refresh(db_message) 609 | return db_message 610 | 611 | 612 | def update_message(db: Session, user_id: int, message: schemas.MessageUpdateRequestBody): 613 | db_message: schemas.Message = db.query(models.Messages).filter( 614 | models.Messages.id == message.messageId).one_or_none() 615 | 616 | if not db_message: 617 | raise 
HTTPException( 618 | status_code=status.HTTP_400_BAD_REQUEST, detail="Message does not exist") 619 | 620 | if db_message.user_id != user_id: 621 | raise HTTPException(status.HTTP_401_UNAUTHORIZED, 622 | detail="You are not authorized to update that message") 623 | 624 | db_message.content = message.newContent 625 | db.commit() 626 | db.refresh(db_message) 627 | return db_message 628 | 629 | 630 | def delete_message(db: Session, user_id: int, message: schemas.MessageDeleteRequestBody): 631 | message_db: schemas.Message = db.query(models.Messages).filter( 632 | models.Messages.id == message.messageId).one_or_none() 633 | 634 | if not message_db: 635 | raise HTTPException( 636 | status_code=status.HTTP_400_BAD_REQUEST, detail="Message does not exist") 637 | 638 | if message_db.user_from_id != user_id: 639 | # message does not belong to the user 640 | raise HTTPException(status.HTTP_401_UNAUTHORIZED, 641 | detail="You are not authorized to delete that message") 642 | 643 | try: 644 | db.delete(message_db) 645 | db.commit() 646 | 647 | except Exception as e: 648 | raise HTTPException( 649 | status_code=status.HTTP_400_BAD_REQUEST, detail="Something went wrong") 650 | --------------------------------------------------------------------------------