├── src ├── _init__.py ├── notification │ ├── __init__.py │ ├── notification_service.py │ ├── consumer.py │ ├── celery_worker.py │ ├── models.py │ ├── triger_router.py │ ├── websocket_manager.py │ └── messaging_bq.py └── app │ ├── main.py │ └── database.py ├── .gitignore ├── migrations ├── README ├── script.py.mako ├── versions │ └── 19f18cceb2de_websocket_table.py └── env.py ├── docker-compose.yml └── alembic.ini /src/_init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/notification/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/notification/notification_service.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | .vscode 3 | sql_app.db 4 | -------------------------------------------------------------------------------- /migrations/README: -------------------------------------------------------------------------------- 1 | Generic single-database configuration. -------------------------------------------------------------------------------- /src/notification/consumer.py: -------------------------------------------------------------------------------- 1 | from src.notification.messaging_bq import mq 2 | 3 | 4 | def main(): 5 | mq.consume_messages() 6 | 7 | 8 | if __name__ == "__main__": 9 | main() 10 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.9' 2 | services: 3 | rabbitmq: 4 | image: rabbitmq:3-management-alpine 5 | container_name: 'rabbitmq' 6 | hostname: "0.0.0.0" 7 | ports: 8 | - 5672:5672 9 | - 4369:4369 10 | - 15672:15672 11 | tty: true 12 | stdin_open: true -------------------------------------------------------------------------------- /src/app/main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from src.notification.triger_router import trigger, manager 3 | 4 | 5 | app = FastAPI() 6 | 7 | app.include_router(trigger) 8 | 9 | 10 | @app.get("/", status_code=200) 11 | def root(): 12 | return { 13 | "message": "notification proof of concept", 14 | } 15 | -------------------------------------------------------------------------------- /src/notification/celery_worker.py: -------------------------------------------------------------------------------- 1 | from celery import Celery 2 | from src.notification.websocket_manager import manager 3 | from asgiref.sync import async_to_sync 4 | from celery.utils.log import get_task_logger 5 | 6 | notifier = Celery("notifier", broker="amqp://guest:guest@127.0.0.1:5672//") 7 | 8 | # Create logger-enable to display messages on task logger. 9 | celery_log = get_task_logger(__name__) 10 | 11 | 12 | @notifier.task 13 | def notification(message: dict): 14 | message_status = async_to_sync(manager.send_personal_message)(message) 15 | return message_status 16 | -------------------------------------------------------------------------------- /migrations/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | ${imports if imports else ""} 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = ${repr(up_revision)} 14 | down_revision = ${repr(down_revision)} 15 | branch_labels = ${repr(branch_labels)} 16 | depends_on = ${repr(depends_on)} 17 | 18 | 19 | def upgrade() -> None: 20 | ${upgrades if upgrades else "pass"} 21 | 22 | 23 | def downgrade() -> None: 24 | ${downgrades if downgrades else "pass"} 25 | -------------------------------------------------------------------------------- /src/app/database.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import create_engine 2 | from sqlalchemy.orm import declarative_base 3 | from sqlalchemy.orm import sessionmaker 4 | 5 | SQLALCHEMY_DATABASE_URL = "sqlite:///./sql_app.db" 6 | 7 | engine = create_engine( 8 | SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} 9 | ) 10 | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) 11 | 12 | 13 | Base = declarative_base() 14 | 15 | 16 | from contextlib import contextmanager 17 | 18 | 19 | @contextmanager 20 | def session_scope(): 21 | """Provide a transactional scope around a series of operations.""" 22 | session = SessionLocal() 23 | 24 | try: 25 | yield session 26 | except: 27 | session.rollback() 28 | raise 29 | finally: 30 | session.close() 31 | -------------------------------------------------------------------------------- /migrations/versions/19f18cceb2de_websocket_table.py: -------------------------------------------------------------------------------- 1 | """websocket table 2 | 3 | Revision ID: 19f18cceb2de 4 | Revises: 5 | Create Date: 2023-05-07 18:31:41.663564 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = "19f18cceb2de" 14 | down_revision = None 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade() -> None: 20 | # ### commands auto generated by Alembic - please adjust! ### 21 | op.create_table( 22 | "websocket_connections", 23 | sa.Column("id", sa.Integer(), nullable=False), 24 | sa.Column("wsock", sa.JSON(), nullable=False), 25 | sa.Column("user_id", sa.Integer(), nullable=False), 26 | sa.PrimaryKeyConstraint("id"), 27 | sa.UniqueConstraint("user_id"), 28 | ) 29 | # ### end Alembic commands ### 30 | 31 | 32 | def downgrade() -> None: 33 | # ### commands auto generated by Alembic - please adjust! ### 34 | op.drop_table("websocket_connections") 35 | # ### end Alembic commands ### 36 | -------------------------------------------------------------------------------- /src/notification/models.py: -------------------------------------------------------------------------------- 1 | from src.app.database import Base 2 | 3 | from sqlalchemy import Column, Integer, JSON 4 | from sqlalchemy.orm import Session 5 | 6 | 7 | class NoticationConnections(Base): 8 | __tablename__ = "websocket_connections" 9 | id = Column(Integer, primary_key=True, nullable=False) 10 | wsock = Column(JSON, nullable=False) 11 | user_id = Column(Integer, nullable=False, unique=True) 12 | 13 | 14 | class Repo: 15 | def __init__(self, db: Session): 16 | self.db = db 17 | 18 | @property 19 | def base(self): 20 | return self.db.query(NoticationConnections) 21 | 22 | def get_connections(self): 23 | return self.base.all() 24 | 25 | def get_wsock(self, user_id: int): 26 | return self.base.filter(NoticationConnections.user_id == user_id).first() 27 | 28 | def create(self, user_id: int, wsock: dict) -> NoticationConnections: 29 | wbsock = NoticationConnections(user_id=user_id, wsock=wsock) 30 | self.db.add(wbsock) 31 | self.db.commit() 32 | 33 | def delete(self, user_id: int): 34 | wbsock = self.get_wsock(user_id) 35 | if wbsock: 36 | self.db.delete(wbsock) 37 | self.db.commit() 38 | -------------------------------------------------------------------------------- /migrations/env.py: -------------------------------------------------------------------------------- 1 | from logging.config import fileConfig 2 | 3 | from sqlalchemy import engine_from_config 4 | from sqlalchemy import pool 5 | 6 | from alembic import context 7 | from src.app.database import SQLALCHEMY_DATABASE_URL 8 | from src.notification.models import Base 9 | 10 | # this is the Alembic Config object, which provides 11 | # access to the values within the .ini file in use. 12 | config = context.config 13 | config.set_main_option("sqlalchemy.url", SQLALCHEMY_DATABASE_URL) 14 | # Interpret the config file for Python logging. 15 | # This line sets up loggers basically. 16 | if config.config_file_name is not None: 17 | fileConfig(config.config_file_name) 18 | 19 | # add your model's MetaData object here 20 | # for 'autogenerate' support 21 | # from myapp import mymodel 22 | # target_metadata = mymodel.Base.metadata 23 | target_metadata = Base.metadata 24 | 25 | # other values from the config, defined by the needs of env.py, 26 | # can be acquired: 27 | # my_important_option = config.get_main_option("my_important_option") 28 | # ... etc. 29 | 30 | 31 | def run_migrations_offline() -> None: 32 | """Run migrations in 'offline' mode. 33 | 34 | This configures the context with just a URL 35 | and not an Engine, though an Engine is acceptable 36 | here as well. By skipping the Engine creation 37 | we don't even need a DBAPI to be available. 38 | 39 | Calls to context.execute() here emit the given string to the 40 | script output. 41 | 42 | """ 43 | url = config.get_main_option("sqlalchemy.url") 44 | context.configure( 45 | url=url, 46 | target_metadata=target_metadata, 47 | literal_binds=True, 48 | dialect_opts={"paramstyle": "named"}, 49 | ) 50 | 51 | with context.begin_transaction(): 52 | context.run_migrations() 53 | 54 | 55 | def run_migrations_online() -> None: 56 | """Run migrations in 'online' mode. 57 | 58 | In this scenario we need to create an Engine 59 | and associate a connection with the context. 60 | 61 | """ 62 | connectable = engine_from_config( 63 | config.get_section(config.config_ini_section, {}), 64 | prefix="sqlalchemy.", 65 | poolclass=pool.NullPool, 66 | ) 67 | 68 | with connectable.connect() as connection: 69 | context.configure(connection=connection, target_metadata=target_metadata) 70 | 71 | with context.begin_transaction(): 72 | context.run_migrations() 73 | 74 | 75 | if context.is_offline_mode(): 76 | run_migrations_offline() 77 | else: 78 | run_migrations_online() 79 | -------------------------------------------------------------------------------- /src/notification/triger_router.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends 2 | from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK 3 | from src.notification.websocket_manager import manager 4 | from src.notification.messaging_bq import mq 5 | from pydantic import BaseModel 6 | import time 7 | import asyncio 8 | 9 | trigger = APIRouter(prefix="/api/trigger", tags=["trigger_in"]) 10 | 11 | 12 | class Demonstrate(BaseModel): 13 | body: dict 14 | user_id: int 15 | 16 | 17 | State = [] 18 | 19 | 20 | def get_current_user(): 21 | return {"id": 1} 22 | 23 | 24 | @trigger.post("/push-in/", status_code=200) 25 | async def registration(demo: Demonstrate, user: dict = Depends(get_current_user)): 26 | demo = demo.dict() 27 | 28 | print(manager.active_connections) 29 | 30 | if manager.get_ws(demo["user_id"]): 31 | ws_alive = await manager.pong(manager.get_ws(demo["user_id"])) 32 | if ws_alive: 33 | await manager.send_personal_message(demo) 34 | else: 35 | mq.publish_notification(demo) 36 | else: 37 | mq.publish_notification(demo) 38 | 39 | State.append(demo) 40 | 41 | # print(State) 42 | return {"message": "This has been published"} 43 | 44 | 45 | {"message": {"body": {"testing": "God abeg"}, "user_id": 1}, "delivery_tag": 2} 46 | 47 | 48 | @trigger.websocket("/notifier/ws/") 49 | async def notification_socket( 50 | websocket: WebSocket, user: dict = Depends(get_current_user) 51 | ): 52 | await manager.connect(websocket, user["id"]) 53 | 54 | try: 55 | if manager.get_ws(user["id"]): 56 | user_meesage = mq.get_user_messages(user["id"]) 57 | 58 | if user_meesage != None: 59 | for message in user_meesage: 60 | if message != None: 61 | message_status = await manager.personal_notification(message) 62 | print(message_status) 63 | # delete the message from the queue if successfully sent via WebSocket 64 | if message_status: 65 | mq.channel.basic_ack(delivery_tag=message["delivery_tag"]) 66 | 67 | hang = True 68 | while hang: 69 | try: 70 | await asyncio.sleep(1) 71 | await manager.ping(websocket) 72 | except asyncio.exceptions.CancelledError: 73 | break 74 | 75 | except (WebSocketDisconnect, ConnectionClosedError, ConnectionClosedOK): 76 | manager.disconnect(user["id"]) 77 | -------------------------------------------------------------------------------- /src/notification/websocket_manager.py: -------------------------------------------------------------------------------- 1 | from fastapi import WebSocket 2 | from typing import Dict, List 3 | import asyncio 4 | 5 | CONNECTIONS = [] 6 | 7 | 8 | class ConnectionManager: 9 | # INITIALIZE THE LIST AND CONNECTION 10 | def __init__(self): 11 | self.active_connections: List[Dict[int, WebSocket]] = CONNECTIONS 12 | 13 | # CONNECT TO WEBSOCKET AND APPEND TO THE LIST 14 | async def connect(self, websocket: WebSocket, connection_id: int): 15 | await websocket.accept() 16 | self.active_connections.append({connection_id: websocket}) 17 | 18 | # PURGE WEBSOCKET LIST STORE 19 | def disconnect(self, user_id: int): 20 | for web_dict in self.active_connections: 21 | if web_dict.get(user_id): 22 | self.active_connections.remove(web_dict) 23 | 24 | # SEND MESSAGE AFTER WEBSOCKET IS ALIVE 25 | async def send_personal_message(self, message: dict): 26 | user_ws = self.get_ws(message["user_id"]) 27 | if user_ws != None: 28 | websocket: WebSocket = user_ws 29 | await websocket.send_json(message) 30 | return True 31 | return False 32 | 33 | # Keep the WebSocket alive. 34 | async def ping(self, websocket: WebSocket): 35 | await websocket.send_text("Nil") 36 | 37 | # Trigger to recieve message from the WebSocket 38 | async def reply(self, websocket: WebSocket): 39 | await websocket.send_text("Reply Pong") 40 | 41 | # Listening to the Websocket and sending message. 42 | async def pong(self, websocket: WebSocket): 43 | await self.reply(websocket) 44 | try: 45 | pong = await asyncio.wait_for(websocket.receive_text(), timeout=5) 46 | if pong == "pong": 47 | return True 48 | else: 49 | return False 50 | 51 | except asyncio.exceptions.TimeoutError as e: 52 | return False 53 | 54 | # Fetch WebSocket and return them if it exists. 55 | def get_ws(self, user_id: int): 56 | for web_dict in self.active_connections: 57 | if web_dict.get(user_id): 58 | return web_dict[user_id] 59 | 60 | # Send a retry message to a user WebSocket 61 | async def personal_notification(self, message: dict): 62 | connection_check = self.get_ws(message["message"]["user_id"]) 63 | if connection_check: 64 | connection_check: WebSocket 65 | await connection_check.send_json(message) 66 | await asyncio.sleep(2) 67 | return True 68 | else: 69 | del self.active_connections[message["message"]["user_id"]] 70 | return False 71 | 72 | 73 | manager = ConnectionManager() 74 | -------------------------------------------------------------------------------- /alembic.ini: -------------------------------------------------------------------------------- 1 | # A generic, single database configuration. 2 | 3 | [alembic] 4 | # path to migration scripts 5 | script_location = migrations 6 | 7 | # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s 8 | # Uncomment the line below if you want the files to be prepended with date and time 9 | # see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file 10 | # for all available tokens 11 | # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s 12 | 13 | # sys.path path, will be prepended to sys.path if present. 14 | # defaults to the current working directory. 15 | prepend_sys_path = . 16 | 17 | # timezone to use when rendering the date within the migration file 18 | # as well as the filename. 19 | # If specified, requires the python-dateutil library that can be 20 | # installed by adding `alembic[tz]` to the pip requirements 21 | # string value is passed to dateutil.tz.gettz() 22 | # leave blank for localtime 23 | # timezone = 24 | 25 | # max length of characters to apply to the 26 | # "slug" field 27 | # truncate_slug_length = 40 28 | 29 | # set to 'true' to run the environment during 30 | # the 'revision' command, regardless of autogenerate 31 | # revision_environment = false 32 | 33 | # set to 'true' to allow .pyc and .pyo files without 34 | # a source .py file to be detected as revisions in the 35 | # versions/ directory 36 | # sourceless = false 37 | 38 | # version location specification; This defaults 39 | # to migrations/versions. When using multiple version 40 | # directories, initial revisions must be specified with --version-path. 41 | # The path separator used here should be the separator specified by "version_path_separator" below. 42 | # version_locations = %(here)s/bar:%(here)s/bat:migrations/versions 43 | 44 | # version path separator; As mentioned above, this is the character used to split 45 | # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. 46 | # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. 47 | # Valid values for version_path_separator are: 48 | # 49 | # version_path_separator = : 50 | # version_path_separator = ; 51 | # version_path_separator = space 52 | version_path_separator = os # Use os.pathsep. Default configuration used for new projects. 53 | 54 | # set to 'true' to search source files recursively 55 | # in each "version_locations" directory 56 | # new in Alembic version 1.10 57 | # recursive_version_locations = false 58 | 59 | # the output encoding used when revision files 60 | # are written from script.py.mako 61 | # output_encoding = utf-8 62 | 63 | ; sqlalchemy.url = driver://user:pass@localhost/dbname 64 | 65 | 66 | [post_write_hooks] 67 | # post_write_hooks defines scripts or Python functions that are run 68 | # on newly generated revision scripts. See the documentation for further 69 | # detail and examples 70 | 71 | # format using "black" - use the console_scripts runner, against the "black" entrypoint 72 | # hooks = black 73 | # black.type = console_scripts 74 | # black.entrypoint = black 75 | # black.options = -l 79 REVISION_SCRIPT_FILENAME 76 | 77 | # Logging configuration 78 | [loggers] 79 | keys = root,sqlalchemy,alembic 80 | 81 | [handlers] 82 | keys = console 83 | 84 | [formatters] 85 | keys = generic 86 | 87 | [logger_root] 88 | level = WARN 89 | handlers = console 90 | qualname = 91 | 92 | [logger_sqlalchemy] 93 | level = WARN 94 | handlers = 95 | qualname = sqlalchemy.engine 96 | 97 | [logger_alembic] 98 | level = INFO 99 | handlers = 100 | qualname = alembic 101 | 102 | [handler_console] 103 | class = StreamHandler 104 | args = (sys.stderr,) 105 | level = NOTSET 106 | formatter = generic 107 | 108 | [formatter_generic] 109 | format = %(levelname)-5.5s [%(name)s] %(message)s 110 | datefmt = %H:%M:%S 111 | -------------------------------------------------------------------------------- /src/notification/messaging_bq.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pika 3 | from src.notification.websocket_manager import manager 4 | from asgiref.sync import async_to_sync 5 | 6 | 7 | class MessageQueue: 8 | def __init__(self) -> None: 9 | # Initializing the Message Queue 10 | self.connection = pika.BlockingConnection( 11 | pika.ConnectionParameters(host="localhost", port=5672, heartbeat=600) 12 | ) 13 | self.channel = self.connection.channel() 14 | self.wsok_manager = manager 15 | self.exchange = "demo_notification" 16 | self.queue = "test_case_queue" 17 | self.channel.exchange_declare(exchange=self.exchange, exchange_type="fanout") 18 | self.channel.queue_declare(queue=self.queue) 19 | self.channel.queue_bind( 20 | exchange=self.exchange, queue=self.queue, routing_key="notfy-x" 21 | ) 22 | self.channel.basic_qos(prefetch_count=1) 23 | 24 | # PUBLISHING MESSAGES TO THE QUEUE 25 | def publish_notification(self, message: dict): 26 | # publishing to the queue 27 | self.channel.basic_publish( 28 | exchange=self.exchange, routing_key="notfy-x", body=json.dumps(message) 29 | ) 30 | return True 31 | 32 | # CONSUMER (BUT LARGELY INACTIVE) 33 | def consume_messages(self): 34 | try: 35 | print("messages are now consumed") 36 | 37 | async def callback_func(ch, method, properties, body): 38 | # parse the message from the queue 39 | 40 | if self.wsok_manager.active_connections: 41 | message_status = await self.wsok_manager.send_personal_message( 42 | json.loads(body) 43 | ) 44 | if message_status: 45 | ch.basic_ack(delivery_tag=method.delivery_tag) 46 | 47 | else: 48 | ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False) 49 | 50 | # async_tosyn 51 | self.channel.basic_consume( 52 | queue=self.queue, 53 | on_message_callback=async_to_sync(callback_func), 54 | auto_ack=False, 55 | ) 56 | self.channel.start_consuming() 57 | 58 | except KeyboardInterrupt: 59 | print("Consumer closed") 60 | 61 | # FETCH ALL THE MESSAGES 62 | def fetch_all_messages(self): 63 | messages = [] 64 | method_frame, header_frame, body = self.channel.basic_get(queue=self.queue) 65 | while method_frame: 66 | method_frame, header_frame, body = self.channel.basic_get(queue=self.queue) 67 | 68 | if body: 69 | messages.append((method_frame, json.loads(body))) 70 | 71 | return messages 72 | 73 | # USERS MESSAGE FILTER 74 | def get_user_messages(self, user_id: int): 75 | messages = self.fetch_all_messages() 76 | if messages: 77 | user_messages = [ 78 | {"message": message, "delivery_tag": method_frame.delivery_tag} 79 | for method_frame, message in messages 80 | if message["user_id"] == user_id 81 | ] 82 | if user_messages: 83 | return user_messages 84 | else: 85 | return None 86 | return None 87 | 88 | # RESEND MESSAGES VIA WEBSOCKET PER PERSON 89 | async def retry_unsent_messages(self, user_id: int): 90 | user_meesage = self.get_user_messages(user_id) 91 | if user_meesage != None: 92 | for message in user_meesage: 93 | if message != None: 94 | message_status = await self.wsok_manager.personal_notification( 95 | message 96 | ) 97 | if message_status: 98 | self.channel.basic_ack(delivery_tag=message["delivery_tag"]) 99 | return True 100 | return True 101 | 102 | # CLOSING CONNECTIONS 103 | def __del__(self): 104 | # try catch exceptions 105 | self.connection.close() 106 | 107 | 108 | mq = MessageQueue() 109 | --------------------------------------------------------------------------------