├── .gitignore ├── .python-version ├── 2-data-ingestion ├── cdc.py ├── config.py ├── db.py ├── mq.py └── test_cdc.py ├── 3-feature-pipeline ├── config.py ├── data_flow │ ├── __init__.py │ ├── stream_input.py │ └── stream_output.py ├── data_logic │ ├── __init__.py │ ├── chunking_data_handlers.py │ ├── cleaning_data_handlers.py │ ├── dispatchers.py │ └── embedding_data_handlers.py ├── db.py ├── finetuning │ ├── __init__.py │ ├── exceptions.py │ ├── file_handler.py │ ├── generate_data.py │ └── llm_communication.py ├── llm │ ├── __init__.py │ ├── chain.py │ └── prompt_templates.py ├── main.py ├── models │ ├── __init__.py │ ├── base.py │ ├── chunk.py │ ├── clean.py │ ├── embedded_chunk.py │ └── raw.py ├── mq.py ├── rag │ ├── __init__.py │ ├── query_expanison.py │ ├── reranking.py │ ├── retriever.py │ └── self_query.py ├── retriever.py ├── scripts │ └── bytewax_entrypoint.sh └── utils │ ├── __init__.py │ ├── chunking.py │ ├── cleaning.py │ ├── embeddings.py │ └── logging.py ├── 4-finetuning ├── .env.example ├── README.md ├── build_config.yaml ├── finetuning │ ├── __init__.py │ ├── config.yaml │ ├── dataset_client.py │ ├── model.py │ ├── requirements.txt │ └── settings.py ├── media │ └── fine-tuning-workflow.png └── test_local.py ├── 5-inference ├── .env.example ├── Makefile ├── README.txt ├── config.py ├── evaluation │ ├── __init__.py │ ├── model.py │ └── rag.py ├── finetuning │ ├── __init__.py │ ├── config.yaml │ ├── dataset_client.py │ ├── model.py │ └── settings.py ├── inference_pipeline.py ├── llm │ ├── __init__.py │ ├── chain.py │ └── prompt_templates.py ├── main.py ├── monitoring │ ├── __init__.py │ └── prompt_monitoring.py ├── pyproject.toml ├── rag │ ├── __init__.py │ ├── query_expanison.py │ ├── reranking.py │ ├── retriever.py │ └── self_query.py └── utils │ ├── __init__.py │ ├── chunking.py │ ├── cleaning.py │ └── embeddings.py ├── GENERATE_INSTRUCT_DATASET.md ├── INSTALL_AND_USAGE.md ├── LICENSE ├── Makefile ├── RAG.md ├── README.md ├── TRAINING.md ├── config.py ├── crawlers ├── __init__.py ├── base.py ├── github.py ├── linkedin.py └── medium.py ├── data-ingestion ├── cdc.py ├── db.py └── mq.py ├── db ├── __init__.py ├── documents.py └── mongo.py ├── dispatcher.py ├── docker-bake.hcl ├── docker-compose-superlinked.yml ├── docker-compose.yml ├── errors.py ├── lib.py ├── main.py ├── ops ├── .gitignore ├── Pulumi.yaml ├── components │ ├── cdc.ts │ ├── config.ts │ ├── crawler.ts │ ├── docdb.ts │ ├── ecs │ │ ├── cluster.ts │ │ ├── iam.ts │ │ └── service.ts │ ├── mq.ts │ ├── nat.ts │ ├── repository.ts │ └── vpc.ts ├── index.ts ├── package-lock.json ├── package.json ├── tsconfig.json └── yarn.lock ├── poetry.lock ├── pyproject.toml ├── requirements.txt ├── settings.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # IDEs 156 | .idea/ 157 | .vscode 158 | 159 | # MacOS 160 | .DS_Store 161 | 162 | # Ruff 163 | .ruff_cache 164 | 165 | data/ 166 | dataset/ 167 | data 168 | 169 | # Data 170 | output 171 | .cache 172 | training_pipeline_output 173 | mistral_instruct_generation 174 | cache -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11.4 2 | -------------------------------------------------------------------------------- /2-data-ingestion/cdc.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | from bson import json_util 5 | from mq import publish_to_rabbitmq 6 | 7 | from config import settings 8 | from db import MongoDatabaseConnector 9 | 10 | # Configure logging 11 | logging.basicConfig( 12 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 13 | ) 14 | 15 | 16 | def stream_process(): 17 | try: 18 | # Setup MongoDB connection 19 | client = MongoDatabaseConnector() 20 | db = client["scrabble"] 21 | logging.info("Connected to MongoDB.") 22 | 23 | # Watch changes in a specific collection 24 | changes = db.watch([{"$match": {"operationType": {"$in": ["insert"]}}}]) 25 | for change in changes: 26 | data_type = change["ns"]["coll"] 27 | entry_id = str(change["fullDocument"]["_id"]) # Convert ObjectId to string 28 | change["fullDocument"].pop("_id") 29 | change["fullDocument"]["type"] = data_type 30 | change["fullDocument"]["entry_id"] = entry_id 31 | 32 | # Use json_util to serialize the document 33 | data = json.dumps(change["fullDocument"], default=json_util.default) 34 | logging.info(f"Change detected and serialized: {data}") 35 | 36 | # Send data to rabbitmq 37 | publish_to_rabbitmq(queue_name=settings.RABBITMQ_QUEUE_NAME, data=data) 38 | logging.info("Data published to RabbitMQ.") 39 | 40 | except Exception as e: 41 | logging.error(f"An error occurred: {e}") 42 | 43 | 44 | if __name__ == "__main__": 45 | stream_process() 46 | -------------------------------------------------------------------------------- /2-data-ingestion/config.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings 2 | 3 | 4 | class Settings(BaseSettings): 5 | # MongoDB configs 6 | MONGO_DATABASE_HOST: str = ( 7 | "mongodb://mongo1:30001,mongo2:30002,mongo3:30003/?replicaSet=my-replica-set" 8 | ) 9 | MONGO_DATABASE_NAME: str = "scrabble" 10 | 11 | RABBITMQ_HOST: str = "mq" # or localhost if running outside Docker 12 | RABBITMQ_PORT: int = 5672 13 | RABBITMQ_DEFAULT_USERNAME: str = "guest" 14 | RABBITMQ_DEFAULT_PASSWORD: str = "guest" 15 | RABBITMQ_QUEUE_NAME: str = "default" 16 | 17 | 18 | 
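# Because Settings extends pydantic's BaseSettings, every field above can be
# overridden with an environment variable of the same name when the object
# below is created. Hypothetical overrides for running outside Docker
# (illustrative values, not part of this repo):
#
#   export RABBITMQ_HOST=localhost
#   export MONGO_DATABASE_HOST="mongodb://localhost:30001,localhost:30002,localhost:30003/?replicaSet=my-replica-set"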
settings = Settings() 19 | -------------------------------------------------------------------------------- /2-data-ingestion/db.py: -------------------------------------------------------------------------------- 1 | from pymongo import MongoClient 2 | from pymongo.errors import ConnectionFailure 3 | 4 | from config import settings 5 | 6 | 7 | class MongoDatabaseConnector: 8 | """Singleton class to connect to the MongoDB database (note: __new__ returns the shared MongoClient itself).""" 9 | 10 | _instance: MongoClient = None 11 | 12 | def __new__(cls, *args, **kwargs): 13 | if cls._instance is None: 14 | try: 15 | cls._instance = MongoClient(settings.MONGO_DATABASE_HOST) 16 | except ConnectionFailure as e: 17 | print(f"Couldn't connect to the database: {str(e)}") 18 | raise 19 | 20 | print( 21 | f"Connection to database with uri: {settings.MONGO_DATABASE_HOST} successful" 22 | ) 23 | return cls._instance 24 | 25 | def get_database(self): 26 | return self._instance[settings.MONGO_DATABASE_NAME] 27 | 28 | def close(self): 29 | if self._instance: 30 | self._instance.close() 31 | print("Connection to the database has been closed.") 32 | 33 | 34 | connection = MongoDatabaseConnector() 35 | -------------------------------------------------------------------------------- /2-data-ingestion/mq.py: -------------------------------------------------------------------------------- 1 | import pika 2 | 3 | from config import settings 4 | 5 | 6 | class RabbitMQConnection: 7 | """Singleton class to manage RabbitMQ connection.""" 8 | 9 | _instance = None 10 | 11 | def __new__( 12 | cls, 13 | host: str = None, 14 | port: int = None, 15 | username: str = None, 16 | password: str = None, 17 | virtual_host: str = "/", 18 | ): 19 | if not cls._instance: 20 | cls._instance = super().__new__(cls) 21 | return cls._instance 22 | 23 | def __init__( 24 | self, 25 | host: str = None, 26 | port: int = None, 27 | username: str = None, 28 | password: str = None, 29 | virtual_host: str = "/", 30 | fail_silently: bool = False, 31 | **kwargs, 32 | ): 33 | self.host = host or settings.RABBITMQ_HOST 34 | self.port = port or settings.RABBITMQ_PORT 35 | self.username = username or settings.RABBITMQ_DEFAULT_USERNAME 36 | self.password = password or settings.RABBITMQ_DEFAULT_PASSWORD 37 | self.virtual_host = virtual_host 38 | self.fail_silently = fail_silently 39 | self._connection = None 40 | 41 | def __enter__(self): 42 | self.connect() 43 | return self 44 | 45 | def __exit__(self, exc_type, exc_val, exc_tb): 46 | self.close() 47 | 48 | def connect(self): 49 | try: 50 | credentials = pika.PlainCredentials(self.username, self.password) 51 | self._connection = pika.BlockingConnection( 52 | pika.ConnectionParameters( 53 | host=self.host, 54 | port=self.port, 55 | virtual_host=self.virtual_host, 56 | credentials=credentials, 57 | ) 58 | ) 59 | except pika.exceptions.AMQPConnectionError as e: 60 | print("Failed to connect to RabbitMQ:", e) 61 | if not self.fail_silently: 62 | raise e 63 | 64 | def is_connected(self) -> bool: 65 | return self._connection is not None and self._connection.is_open 66 | 67 | def get_channel(self): 68 | if self.is_connected(): 69 | return self._connection.channel() 70 | 71 | def close(self): 72 | if self.is_connected(): 73 | self._connection.close() 74 | self._connection = None 75 | print("Closed RabbitMQ connection") 76 | 77 | 78 | def publish_to_rabbitmq(queue_name: str, data: str): 79 | """Publish data to a RabbitMQ queue.""" 80 | try: 81 | # Create an instance of RabbitMQConnection 82 | rabbitmq_conn = RabbitMQConnection() 83 | 84 | # Establish connection 85 |
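# RabbitMQConnection is both a singleton and a context manager: __enter__
# calls connect() and __exit__ calls close(), so each publish opens a fresh
# connection and tears it down afterwards. A sketch of a longer-lived
# alternative (an assumption, not what this module does):
#
#   conn = RabbitMQConnection()
#   conn.connect()
#   channel = conn.get_channel()  # reuse this channel across many publishes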
with rabbitmq_conn: 86 | channel = rabbitmq_conn.get_channel() 87 | 88 | # Ensure the queue exists 89 | channel.queue_declare(queue=queue_name, durable=True) 90 | 91 | # Delivery confirmation 92 | channel.confirm_delivery() 93 | 94 | # Send data to the queue 95 | channel.basic_publish( 96 | exchange="", 97 | routing_key=queue_name, 98 | body=data, 99 | properties=pika.BasicProperties( 100 | delivery_mode=2, # make message persistent 101 | ), 102 | ) 103 | print("Sent data to RabbitMQ:", data) 104 | except pika.exceptions.UnroutableError: 105 | print("Message could not be routed") 106 | except Exception as e: 107 | print(f"Error publishing to RabbitMQ: {e}") 108 | 109 | 110 | if __name__ == "__main__": 111 | publish_to_rabbitmq("test_queue", "Hello, World!") 112 | -------------------------------------------------------------------------------- /2-data-ingestion/test_cdc.py: -------------------------------------------------------------------------------- 1 | from pymongo import MongoClient 2 | 3 | 4 | def insert_data_to_mongodb(uri, database_name, collection_name, data): 5 | """ 6 | Insert data into a MongoDB collection. 7 | 8 | :param uri: MongoDB URI 9 | :param database_name: Name of the database 10 | :param collection_name: Name of the collection 11 | :param data: Data to be inserted (dict) 12 | """ 13 | client = MongoClient(uri) 14 | db = client[database_name] 15 | collection = db[collection_name] 16 | 17 | try: 18 | result = collection.insert_one(data) 19 | print(f"Data inserted with _id: {result.inserted_id}") 20 | except Exception as e: 21 | print(f"An error occurred: {e}") 22 | finally: 23 | client.close() 24 | 25 | 26 | if __name__ == "__main__": 27 | insert_data_to_mongodb( 28 | "mongodb://localhost:30001,localhost:30002,localhost:30003/?replicaSet=my-replica-set", 29 | "scrabble", 30 | "posts", 31 | {"platform": "linkedin", "content": "Test content"} 32 | ) 33 | -------------------------------------------------------------------------------- /3-feature-pipeline/config.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings 2 | 3 | 4 | class Settings(BaseSettings): 5 | # CometML config 6 | COMET_API_KEY: str | None = None 7 | COMET_WORKSPACE: str | None = None 8 | COMET_PROJECT: str | None = None 9 | 10 | # Embeddings config 11 | EMBEDDING_MODEL_ID: str = "sentence-transformers/all-MiniLM-L6-v2" 12 | EMBEDDING_MODEL_MAX_INPUT_LENGTH: int = 256 13 | EMBEDDING_SIZE: int = 384 14 | EMBEDDING_MODEL_DEVICE: str = "cpu" 15 | 16 | # OpenAI 17 | OPENAI_MODEL_ID: str = "gpt-4-1106-preview" 18 | OPENAI_API_KEY: str | None = None 19 | 20 | # MQ config 21 | RABBITMQ_DEFAULT_USERNAME: str = "guest" 22 | RABBITMQ_DEFAULT_PASSWORD: str = "guest" 23 | RABBITMQ_HOST: str = "mq" # or localhost if running outside Docker 24 | RABBITMQ_PORT: int = 5672 25 | RABBITMQ_QUEUE_NAME: str = "default" 26 | 27 | # QdrantDB config 28 | QDRANT_DATABASE_HOST: str = "qdrant" # or localhost if running outside Docker 29 | QDRANT_DATABASE_PORT: int = 6333 30 | USE_QDRANT_CLOUD: bool = False # if True, fill in QDRANT_CLOUD_URL and QDRANT_APIKEY 31 | QDRANT_CLOUD_URL: str | None = None 32 | QDRANT_APIKEY: str | None = None 33 | 34 | 35 | settings = Settings() 36 | -------------------------------------------------------------------------------- /3-feature-pipeline/data_flow/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/J-coder118/LLM-Twin/707f9d8bb1cf402e04644bff9c5c521ce0938087/3-feature-pipeline/data_flow/__init__.py -------------------------------------------------------------------------------- /3-feature-pipeline/data_flow/stream_input.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datetime import datetime 3 | import time 4 | from typing import Generic, Iterable, List, Optional, TypeVar 5 | 6 | from bytewax.inputs import FixedPartitionedSource, StatefulSourcePartition 7 | from config import settings 8 | from mq import RabbitMQConnection 9 | from utils.logging import get_logger 10 | 11 | logger = get_logger(__name__) 12 | 13 | DataT = TypeVar("DataT") 14 | MessageT = TypeVar("MessageT") 15 | 16 | 17 | class RabbitMQPartition(StatefulSourcePartition, Generic[DataT, MessageT]): 18 | """ 19 | Creates a connection between bytewax and RabbitMQ to facilitate the transfer of data from the message queue into the bytewax streaming pipeline. 20 | Inherits StatefulSourcePartition for snapshot functionality, which enables saving the state of the queue. 21 | """ 22 | 23 | def __init__(self, queue_name: str, resume_state: MessageT | None = None) -> None: 24 | self._in_flight_msg_ids = resume_state or set() 25 | self.queue_name = queue_name 26 | self.connection = RabbitMQConnection() 27 | self.connection.connect() 28 | self.channel = self.connection.get_channel() 29 | 30 | def next_batch(self, sched: Optional[datetime]) -> Iterable[DataT]: 31 | try: 32 | method_frame, header_frame, body = self.channel.basic_get( 33 | queue=self.queue_name, auto_ack=True 34 | ) 35 | except Exception: 36 | logger.error( 37 | "Error while fetching message from queue.", queue_name=self.queue_name 38 | ) 39 | time.sleep(10)  # Sleep for 10 seconds before retrying to access the queue.
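# Best-effort recovery: rebuild the connection and channel, then return an
# empty batch so the bytewax flow keeps polling instead of crashing.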
40 | 41 | self.connection.connect() 42 | self.channel = self.connection.get_channel() 43 | 44 | return [] 45 | 46 | if method_frame: 47 | message_id = method_frame.delivery_tag 48 | self._in_flight_msg_ids.add(message_id) 49 | 50 | return [json.loads(body)] 51 | else: 52 | return [] 53 | 54 | def snapshot(self) -> MessageT: 55 | return self._in_flight_msg_ids 56 | 57 | def garbage_collect(self, state): 58 | closed_in_flight_msg_ids = state 59 | for msg_id in closed_in_flight_msg_ids: 60 | self.channel.basic_ack(delivery_tag=msg_id) 61 | self._in_flight_msg_ids.remove(msg_id) 62 | 63 | def close(self): 64 | self.channel.close() 65 | 66 | 67 | class RabbitMQSource(FixedPartitionedSource): 68 | def list_parts(self) -> List[str]: 69 | return ["single partition"] 70 | 71 | def build_part( 72 | self, now: datetime, for_part: str, resume_state: MessageT | None = None 73 | ) -> StatefulSourcePartition[DataT, MessageT]: 74 | return RabbitMQPartition(queue_name=settings.RABBITMQ_QUEUE_NAME, resume_state=resume_state)  # forward resume_state so the snapshot state survives restarts 75 | -------------------------------------------------------------------------------- /3-feature-pipeline/data_flow/stream_output.py: -------------------------------------------------------------------------------- 1 | from bytewax.outputs import DynamicSink, StatelessSinkPartition 2 | from qdrant_client.http.api_client import UnexpectedResponse 3 | from qdrant_client.models import Batch 4 | 5 | from utils.logging import get_logger 6 | from db import QdrantDatabaseConnector 7 | from models.base import VectorDBDataModel 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | class QdrantOutput(DynamicSink): 13 | """ 14 | Bytewax class that facilitates the connection to a Qdrant vector DB. 15 | Inherits DynamicSink because of the ability to create different sink sources (e.g., vector and non-vector collections). 16 | """ 17 | 18 | def __init__(self, connection: QdrantDatabaseConnector, sink_type: str): 19 | self._connection = connection 20 | self._sink_type = sink_type 21 | 22 | try: 23 | self._connection.get_collection(collection_name="cleaned_posts") 24 | except UnexpectedResponse: 25 | logger.info( 26 | "Couldn't access the collection. Creating a new one...", 27 | collection_name="cleaned_posts", 28 | ) 29 | 30 | self._connection.create_non_vector_collection( 31 | collection_name="cleaned_posts" 32 | ) 33 | 34 | try: 35 | self._connection.get_collection(collection_name="cleaned_articles") 36 | except UnexpectedResponse: 37 | logger.info( 38 | "Couldn't access the collection. Creating a new one...", 39 | collection_name="cleaned_articles", 40 | ) 41 | 42 | self._connection.create_non_vector_collection( 43 | collection_name="cleaned_articles" 44 | ) 45 | 46 | try: 47 | self._connection.get_collection(collection_name="cleaned_repositories") 48 | except UnexpectedResponse: 49 | logger.info( 50 | "Couldn't access the collection. Creating a new one...", 51 | collection_name="cleaned_repositories", 52 | ) 53 | 54 | self._connection.create_non_vector_collection( 55 | collection_name="cleaned_repositories" 56 | ) 57 | 58 | try: 59 | self._connection.get_collection(collection_name="vector_posts") 60 | except UnexpectedResponse: 61 | logger.info( 62 | "Couldn't access the collection. Creating a new one...", 63 | collection_name="vector_posts", 64 | ) 65 | 66 | self._connection.create_vector_collection(collection_name="vector_posts") 67 | 68 | try: 69 | self._connection.get_collection(collection_name="vector_articles") 70 | except UnexpectedResponse: 71 | logger.info( 72 | "Couldn't access the collection.
Creating a new one...", 73 | collection_name="vector_articles", 74 | ) 75 | 76 | self._connection.create_vector_collection(collection_name="vector_articles") 77 | 78 | try: 79 | self._connection.get_collection(collection_name="vector_repositories") 80 | except UnexpectedResponse: 81 | logger.info( 82 | "Couldn't access the collection. Creating a new one...", 83 | collection_name="vector_repositories", 84 | ) 85 | 86 | self._connection.create_vector_collection( 87 | collection_name="vector_repositories" 88 | ) 89 | 90 | def build(self, worker_index: int, worker_count: int) -> StatelessSinkPartition: 91 | if self._sink_type == "clean": 92 | return QdrantCleanedDataSink(connection=self._connection) 93 | elif self._sink_type == "vector": 94 | return QdrantVectorDataSink(connection=self._connection) 95 | else: 96 | raise ValueError(f"Unsupported sink type: {self._sink_type}") 97 | 98 | 99 | class QdrantCleanedDataSink(StatelessSinkPartition): 100 | def __init__(self, connection: QdrantDatabaseConnector): 101 | self._client = connection 102 | 103 | def write_batch(self, items: list[VectorDBDataModel]) -> None: 104 | payloads = [item.to_payload() for item in items] 105 | ids, data = zip(*payloads) 106 | collection_name = get_clean_collection(data_type=data[0]["type"]) 107 | self._client.write_data( 108 | collection_name=collection_name, 109 | points=Batch(ids=ids, vectors={}, payloads=data), 110 | ) 111 | 112 | logger.info( 113 | "Successfully inserted requested cleaned point(s)", 114 | collection_name=collection_name, 115 | num=len(ids), 116 | ) 117 | 118 | 119 | class QdrantVectorDataSink(StatelessSinkPartition): 120 | def __init__(self, connection: QdrantDatabaseConnector): 121 | self._client = connection 122 | 123 | def write_batch(self, items: list[VectorDBDataModel]) -> None: 124 | payloads = [item.to_payload() for item in items] 125 | ids, vectors, meta_data = zip(*payloads) 126 | collection_name = get_vector_collection(data_type=meta_data[0]["type"]) 127 | self._client.write_data( 128 | collection_name=collection_name, 129 | points=Batch(ids=ids, vectors=vectors, payloads=meta_data), 130 | ) 131 | 132 | logger.info( 133 | "Successfully inserted requested vector point(s)", 134 | collection_name=collection_name, 135 | num=len(ids), 136 | ) 137 | 138 | 139 | def get_clean_collection(data_type: str) -> str: 140 | if data_type == "posts": 141 | return "cleaned_posts" 142 | elif data_type == "articles": 143 | return "cleaned_articles" 144 | elif data_type == "repositories": 145 | return "cleaned_repositories" 146 | else: 147 | raise ValueError(f"Unsupported data type: {data_type}") 148 | 149 | 150 | def get_vector_collection(data_type: str) -> str: 151 | if data_type == "posts": 152 | return "vector_posts" 153 | elif data_type == "articles": 154 | return "vector_articles" 155 | elif data_type == "repositories": 156 | return "vector_repositories" 157 | else: 158 | raise ValueError(f"Unsupported data type: {data_type}") 159 | -------------------------------------------------------------------------------- /3-feature-pipeline/data_logic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/J-coder118/LLM-Twin/707f9d8bb1cf402e04644bff9c5c521ce0938087/3-feature-pipeline/data_logic/__init__.py -------------------------------------------------------------------------------- /3-feature-pipeline/data_logic/chunking_data_handlers.py: -------------------------------------------------------------------------------- 1 | import hashlib 
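# hashlib.md5 below is used only to derive a deterministic chunk_id from the
# chunk text (the same chunk always maps to the same id); it is not used for
# anything security-sensitive.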
2 | from abc import ABC, abstractmethod 3 | 4 | from models.base import DataModel 5 | from models.chunk import ArticleChunkModel, PostChunkModel, RepositoryChunkModel 6 | from models.clean import ArticleCleanedModel, PostCleanedModel, RepositoryCleanedModel 7 | from utils.chunking import chunk_text 8 | 9 | 10 | class ChunkingDataHandler(ABC): 11 | """ 12 | Abstract class for all chunking data handlers. 13 | All data transformation logic for the chunking step is done here. 14 | """ 15 | 16 | @abstractmethod 17 | def chunk(self, data_model: DataModel) -> list[DataModel]: 18 | pass 19 | 20 | 21 | class PostChunkingHandler(ChunkingDataHandler): 22 | def chunk(self, data_model: PostCleanedModel) -> list[PostChunkModel]: 23 | data_models_list = [] 24 | 25 | text_content = data_model.cleaned_content 26 | chunks = chunk_text(text_content) 27 | 28 | for chunk in chunks: 29 | model = PostChunkModel( 30 | entry_id=data_model.entry_id, 31 | platform=data_model.platform, 32 | chunk_id=hashlib.md5(chunk.encode()).hexdigest(), 33 | chunk_content=chunk, 34 | author_id=data_model.author_id, 35 | image=data_model.image if data_model.image else None, 36 | type=data_model.type, 37 | ) 38 | data_models_list.append(model) 39 | 40 | return data_models_list 41 | 42 | 43 | class ArticleChunkingHandler(ChunkingDataHandler): 44 | def chunk(self, data_model: ArticleCleanedModel) -> list[ArticleChunkModel]: 45 | data_models_list = [] 46 | 47 | text_content = data_model.cleaned_content 48 | chunks = chunk_text(text_content) 49 | 50 | for chunk in chunks: 51 | model = ArticleChunkModel( 52 | entry_id=data_model.entry_id, 53 | platform=data_model.platform, 54 | link=data_model.link, 55 | chunk_id=hashlib.md5(chunk.encode()).hexdigest(), 56 | chunk_content=chunk, 57 | author_id=data_model.author_id, 58 | type=data_model.type, 59 | ) 60 | data_models_list.append(model) 61 | 62 | return data_models_list 63 | 64 | 65 | class RepositoryChunkingHandler(ChunkingDataHandler): 66 | def chunk(self, data_model: RepositoryCleanedModel) -> list[RepositoryChunkModel]: 67 | data_models_list = [] 68 | 69 | text_content = data_model.cleaned_content 70 | chunks = chunk_text(text_content) 71 | 72 | for chunk in chunks: 73 | model = RepositoryChunkModel( 74 | entry_id=data_model.entry_id, 75 | name=data_model.name, 76 | link=data_model.link, 77 | chunk_id=hashlib.md5(chunk.encode()).hexdigest(), 78 | chunk_content=chunk, 79 | owner_id=data_model.owner_id, 80 | type=data_model.type, 81 | ) 82 | data_models_list.append(model) 83 | 84 | return data_models_list 85 | -------------------------------------------------------------------------------- /3-feature-pipeline/data_logic/cleaning_data_handlers.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from models.base import DataModel 4 | from models.clean import ArticleCleanedModel, PostCleanedModel, RepositoryCleanedModel 5 | from models.raw import ArticleRawModel, PostsRawModel, RepositoryRawModel 6 | from utils.cleaning import clean_text 7 | 8 | 9 | class CleaningDataHandler(ABC): 10 | """ 11 | Abstract class for all cleaning data handlers.
12 | All data transformation logic for the cleaning step is done here. 13 | """ 14 | 15 | @abstractmethod 16 | def clean(self, data_model: DataModel) -> DataModel: 17 | pass 18 | 19 | 20 | class PostCleaningHandler(CleaningDataHandler): 21 | def clean(self, data_model: PostsRawModel) -> PostCleanedModel: 22 | return PostCleanedModel( 23 | entry_id=data_model.entry_id, 24 | platform=data_model.platform, 25 | cleaned_content=clean_text("".join(data_model.content.values())), 26 | author_id=data_model.author_id, 27 | image=data_model.image if data_model.image else None, 28 | type=data_model.type, 29 | ) 30 | 31 | 32 | class ArticleCleaningHandler(CleaningDataHandler): 33 | def clean(self, data_model: ArticleRawModel) -> ArticleCleanedModel: 34 | return ArticleCleanedModel( 35 | entry_id=data_model.entry_id, 36 | platform=data_model.platform, 37 | link=data_model.link, 38 | cleaned_content=clean_text("".join(data_model.content.values())), 39 | author_id=data_model.author_id, 40 | type=data_model.type, 41 | ) 42 | 43 | 44 | class RepositoryCleaningHandler(CleaningDataHandler): 45 | def clean(self, data_model: RepositoryRawModel) -> RepositoryCleanedModel: 46 | return RepositoryCleanedModel( 47 | entry_id=data_model.entry_id, 48 | name=data_model.name, 49 | link=data_model.link, 50 | cleaned_content=clean_text("".join(data_model.content.values())), 51 | owner_id=data_model.owner_id, 52 | type=data_model.type, 53 | ) 54 | -------------------------------------------------------------------------------- /3-feature-pipeline/data_logic/dispatchers.py: -------------------------------------------------------------------------------- 1 | from utils.logging import get_logger 2 | 3 | from data_logic.chunking_data_handlers import ( 4 | ArticleChunkingHandler, 5 | ChunkingDataHandler, 6 | PostChunkingHandler, 7 | RepositoryChunkingHandler, 8 | ) 9 | from data_logic.cleaning_data_handlers import ( 10 | ArticleCleaningHandler, 11 | CleaningDataHandler, 12 | PostCleaningHandler, 13 | RepositoryCleaningHandler, 14 | ) 15 | from data_logic.embedding_data_handlers import ( 16 | ArticleEmbeddingHandler, 17 | EmbeddingDataHandler, 18 | PostEmbeddingHandler, 19 | RepositoryEmbeddingHandler, 20 | ) 21 | from models.base import DataModel 22 | from models.raw import ArticleRawModel, PostsRawModel, RepositoryRawModel 23 | 24 | logger = get_logger(__name__) 25 | 26 | 27 | class RawDispatcher: 28 | @staticmethod 29 | def handle_mq_message(message: dict) -> DataModel: 30 | data_type = message.get("type") 31 | 32 | logger.info("Received message.", data_type=data_type) 33 | 34 | if data_type == "posts": 35 | return PostsRawModel(**message) 36 | elif data_type == "articles": 37 | return ArticleRawModel(**message) 38 | elif data_type == "repositories": 39 | return RepositoryRawModel(**message) 40 | else: 41 | raise ValueError("Unsupported data type") 42 | 43 | 44 | class CleaningHandlerFactory: 45 | @staticmethod 46 | def create_handler(data_type) -> CleaningDataHandler: 47 | if data_type == "posts": 48 | return PostCleaningHandler() 49 | elif data_type == "articles": 50 | return ArticleCleaningHandler() 51 | elif data_type == "repositories": 52 | return RepositoryCleaningHandler() 53 | else: 54 | raise ValueError("Unsupported data type") 55 | 56 | 57 | class CleaningDispatcher: 58 | cleaning_factory = CleaningHandlerFactory() 59 | 60 | @classmethod 61 | def dispatch_cleaner(cls, data_model: DataModel) -> DataModel: 62 | data_type = data_model.type 63 | handler = cls.cleaning_factory.create_handler(data_type) 64 | clean_model =
handler.clean(data_model) 65 | 66 | logger.info( 67 | "Data cleaned successfully.", 68 | data_type=data_type, 69 | cleaned_content_len=len(clean_model.cleaned_content), 70 | ) 71 | 72 | return clean_model 73 | 74 | 75 | class ChunkingHandlerFactory: 76 | @staticmethod 77 | def create_handler(data_type) -> ChunkingDataHandler: 78 | if data_type == "posts": 79 | return PostChunkingHandler() 80 | elif data_type == "articles": 81 | return ArticleChunkingHandler() 82 | elif data_type == "repositories": 83 | return RepositoryChunkingHandler() 84 | else: 85 | raise ValueError("Unsupported data type") 86 | 87 | 88 | class ChunkingDispatcher: 89 | chunking_factory = ChunkingHandlerFactory 90 | 91 | @classmethod 92 | def dispatch_chunker(cls, data_model: DataModel) -> list[DataModel]: 93 | data_type = data_model.type 94 | handler = cls.chunking_factory.create_handler(data_type) 95 | chunk_models = handler.chunk(data_model) 96 | 97 | logger.info( 98 | "Cleaned content chunked successfully.", 99 | num=len(chunk_models), 100 | data_type=data_type, 101 | ) 102 | 103 | return chunk_models 104 | 105 | 106 | class EmbeddingHandlerFactory: 107 | @staticmethod 108 | def create_handler(data_type) -> EmbeddingDataHandler: 109 | if data_type == "posts": 110 | return PostEmbeddingHandler() 111 | elif data_type == "articles": 112 | return ArticleEmbeddingHandler() 113 | elif data_type == "repositories": 114 | return RepositoryEmbeddingHandler() 115 | else: 116 | raise ValueError("Unsupported data type") 117 | 118 | 119 | class EmbeddingDispatcher: 120 | embedding_factory = EmbeddingHandlerFactory 121 | 122 | @classmethod 123 | def dispatch_embedder(cls, data_model: DataModel) -> DataModel: 124 | data_type = data_model.type 125 | handler = cls.embedding_factory.create_handler(data_type) 126 | embedded_chunk_model = handler.embedd(data_model) 127 | 128 | logger.info( 129 | "Chunk embedded successfully.", 130 | data_type=data_type, 131 | embedding_len=len(embedded_chunk_model.embedded_content), 132 | ) 133 | 134 | return embedded_chunk_model 135 | -------------------------------------------------------------------------------- /3-feature-pipeline/data_logic/embedding_data_handlers.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from models.base import DataModel 4 | from models.chunk import ArticleChunkModel, PostChunkModel, RepositoryChunkModel 5 | from models.embedded_chunk import ( 6 | ArticleEmbeddedChunkModel, 7 | PostEmbeddedChunkModel, 8 | RepositoryEmbeddedChunkModel, 9 | ) 10 | from utils.embeddings import embedd_text 11 | 12 | 13 | class EmbeddingDataHandler(ABC): 14 | """ 15 | Abstract class for all embedding data handlers.
16 | All data transformation logic for the embedding step is done here. 17 | """ 18 | 19 | @abstractmethod 20 | def embedd(self, data_model: DataModel) -> DataModel: 21 | pass 22 | 23 | 24 | class PostEmbeddingHandler(EmbeddingDataHandler): 25 | def embedd(self, data_model: PostChunkModel) -> PostEmbeddedChunkModel: 26 | return PostEmbeddedChunkModel( 27 | entry_id=data_model.entry_id, 28 | platform=data_model.platform, 29 | chunk_id=data_model.chunk_id, 30 | chunk_content=data_model.chunk_content, 31 | embedded_content=embedd_text(data_model.chunk_content), 32 | author_id=data_model.author_id, 33 | type=data_model.type, 34 | ) 35 | 36 | 37 | class ArticleEmbeddingHandler(EmbeddingDataHandler): 38 | def embedd(self, data_model: ArticleChunkModel) -> ArticleEmbeddedChunkModel: 39 | return ArticleEmbeddedChunkModel( 40 | entry_id=data_model.entry_id, 41 | platform=data_model.platform, 42 | link=data_model.link, 43 | chunk_content=data_model.chunk_content, 44 | chunk_id=data_model.chunk_id, 45 | embedded_content=embedd_text(data_model.chunk_content), 46 | author_id=data_model.author_id, 47 | type=data_model.type, 48 | ) 49 | 50 | 51 | class RepositoryEmbeddingHandler(EmbeddingDataHandler): 52 | def embedd(self, data_model: RepositoryChunkModel) -> RepositoryEmbeddedChunkModel: 53 | return RepositoryEmbeddedChunkModel( 54 | entry_id=data_model.entry_id, 55 | name=data_model.name, 56 | link=data_model.link, 57 | chunk_id=data_model.chunk_id, 58 | chunk_content=data_model.chunk_content, 59 | embedded_content=embedd_text(data_model.chunk_content), 60 | owner_id=data_model.owner_id, 61 | type=data_model.type, 62 | ) 63 | -------------------------------------------------------------------------------- /3-feature-pipeline/db.py: -------------------------------------------------------------------------------- 1 | from qdrant_client import QdrantClient, models 2 | from qdrant_client.http.exceptions import UnexpectedResponse 3 | from qdrant_client.http.models import Batch, Distance, VectorParams 4 | 5 | from utils.logging import get_logger 6 | from config import settings 7 | 8 | logger = get_logger(__name__) 9 | 10 | 11 | class QdrantDatabaseConnector: 12 | _instance: QdrantClient | None = None 13 | 14 | def __init__(self) -> None: 15 | if self._instance is None: 16 | try: 17 | if settings.USE_QDRANT_CLOUD: 18 | self._instance = QdrantClient( 19 | url=settings.QDRANT_CLOUD_URL, 20 | api_key=settings.QDRANT_APIKEY, 21 | ) 22 | else: 23 | self._instance = QdrantClient( 24 | host=settings.QDRANT_DATABASE_HOST, 25 | port=settings.QDRANT_DATABASE_PORT, 26 | ) 27 | except UnexpectedResponse: 28 | logger.exception( 29 | "Couldn't connect to Qdrant.", 30 | host=settings.QDRANT_DATABASE_HOST, 31 | port=settings.QDRANT_DATABASE_PORT, 32 | url=settings.QDRANT_CLOUD_URL, 33 | ) 34 | 35 | raise 36 | 37 | def get_collection(self, collection_name: str): 38 | return self._instance.get_collection(collection_name=collection_name) 39 | 40 | def create_non_vector_collection(self, collection_name: str): 41 | self._instance.create_collection( 42 | collection_name=collection_name, vectors_config={} 43 | ) 44 | 45 | def create_vector_collection(self, collection_name: str): 46 | self._instance.create_collection( 47 | collection_name=collection_name, 48 | vectors_config=VectorParams( 49 | size=settings.EMBEDDING_SIZE, distance=Distance.COSINE 50 | ), 51 | ) 52 | 53 | def write_data(self, collection_name: str, points: Batch): 54 | try: 55 | self._instance.upsert(collection_name=collection_name, points=points) 56 | except
Exception: 57 | logger.exception("An error occurred while inserting data.") 58 | 59 | raise 60 | 61 | def search( 62 | self, 63 | collection_name: str, 64 | query_vector: list, 65 | query_filter: models.Filter | None = None, 66 | limit: int = 3, 67 | ) -> list: 68 | return self._instance.search( 69 | collection_name=collection_name, 70 | query_vector=query_vector, 71 | query_filter=query_filter, 72 | limit=limit, 73 | ) 74 | 75 | def scroll(self, collection_name: str, limit: int): 76 | return self._instance.scroll(collection_name=collection_name, limit=limit) 77 | 78 | def close(self): 79 | if self._instance: 80 | self._instance.close() 81 | 82 | logger.info("Connection to the database has been closed.") 83 | -------------------------------------------------------------------------------- /3-feature-pipeline/finetuning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/J-coder118/LLM-Twin/707f9d8bb1cf402e04644bff9c5c521ce0938087/3-feature-pipeline/finetuning/__init__.py -------------------------------------------------------------------------------- /3-feature-pipeline/finetuning/exceptions.py: -------------------------------------------------------------------------------- 1 | class DatasetError(Exception): 2 | pass 3 | 4 | 5 | class FileNotFoundError(DatasetError): 6 | pass 7 | 8 | 9 | class JSONDecodeError(DatasetError): 10 | pass 11 | 12 | 13 | class APICommunicationError(DatasetError): 14 | pass 15 | -------------------------------------------------------------------------------- /3-feature-pipeline/finetuning/file_handler.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from finetuning.exceptions import JSONDecodeError 4 | 5 | 6 | class FileHandler: 7 | def read_json(self, filename: str) -> list: 8 | try: 9 | with open(filename, "r") as file: 10 | return json.load(file) 11 | except FileNotFoundError: 12 | raise FileNotFoundError(f"The file '{filename}' does not exist.") 13 | except json.JSONDecodeError: 14 | raise JSONDecodeError( 15 | f"The file '{filename}' is not properly formatted as JSON." 16 | ) 17 | 18 | def write_json(self, filename: str, data: list): 19 | with open(filename, "w") as file: 20 | json.dump(data, file, indent=4) 21 | -------------------------------------------------------------------------------- /3-feature-pipeline/finetuning/generate_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | from comet_ml import Artifact, Experiment 5 | 6 | from utils.logging import get_logger 7 | from db import QdrantDatabaseConnector 8 | from finetuning.file_handler import FileHandler 9 | from finetuning.llm_communication import GptCommunicator 10 | from config import settings 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | client = QdrantDatabaseConnector() 16 | 17 | 18 | class DataFormatter: 19 | @classmethod 20 | def get_system_prompt(cls, data_type: str) -> str: 21 | return ( 22 | f"I will give you batches of contents of {data_type}. Please generate exactly 1 instruction for each of them. The {data_type} text " 23 | f"for which you have to generate the instructions is under Content number x lines. Please structure the answer in json format, " 24 | f"ready to be loaded by json.loads(), a list of objects only with fields called instruction and content. For the content field, copy the number of the content only!"
25 | f"Please do not add any extra characters and make sure it is a list with objects in valid json format!\n" 26 | ) 27 | 28 | @classmethod 29 | def format_data(cls, data_points: list, is_example: bool, start_index: int) -> str: 30 | text = "" 31 | for index, data_point in enumerate(data_points): 32 | if not is_example: 33 | text += f"Content number {start_index + index}\n" 34 | text += str(data_point) + "\n" 35 | 36 | return text 37 | 38 | @classmethod 39 | def format_batch(cls, context_msg: str, data_points: list, start_index: int) -> str: 40 | delimiter_msg = context_msg 41 | delimiter_msg += cls.format_data(data_points, False, start_index) 42 | 43 | return delimiter_msg 44 | 45 | @classmethod 46 | def format_prompt(cls, inference_posts: list, data_type: str, start_index: int) -> str: 47 | initial_prompt = cls.get_system_prompt(data_type) 48 | initial_prompt += f"You must generate exactly a list of {len(inference_posts)} json objects, using the contents provided under CONTENTS FOR GENERATION\n" 49 | initial_prompt += cls.format_batch( 50 | "\nCONTENTS FOR GENERATION: \n", inference_posts, start_index 51 | ) 52 | 53 | return initial_prompt 54 | 55 | 56 | class DatasetGenerator: 57 | def __init__( 58 | self, 59 | file_handler: FileHandler, 60 | api_communicator: GptCommunicator, 61 | data_formatter: DataFormatter, 62 | ): 63 | self.file_handler = file_handler 64 | self.api_communicator = api_communicator 65 | self.data_formatter = data_formatter 66 | 67 | def generate_training_data(self, collection_name: str, data_type: str, batch_size: int = 1): 68 | all_contents = self.fetch_all_cleaned_content(collection_name) 69 | response = [] 70 | for i in range(0, len(all_contents), batch_size): 71 | batch = all_contents[i : i + batch_size] 72 | prompt = self.data_formatter.format_prompt(batch, data_type, i) 73 | response += self.api_communicator.send_prompt(prompt) 74 | for j in range(i, min(i + batch_size, len(all_contents))): 75 | response[j]["content"] = all_contents[j]  # assumes the API returned one generated object per content item 76 | 77 | self.push_to_comet(response, data_type, collection_name) 78 | 79 | def push_to_comet(self, data: list, data_type: str, collection_name: str): 80 | try: 81 | logger.info(f"Starting to push data to Comet: {collection_name}") 82 | 83 | # Assuming the settings module has been properly configured with the required attributes 84 | experiment = Experiment( 85 | api_key=settings.COMET_API_KEY, 86 | project_name=settings.COMET_PROJECT, 87 | workspace=settings.COMET_WORKSPACE, 88 | ) 89 | 90 | file_name = f"{collection_name}.json" 91 | logger.info(f"Writing data to file: {file_name}") 92 | 93 | with open(file_name, "w") as f: 94 | json.dump(data, f) 95 | 96 | logger.info("Data written to file successfully") 97 | 98 | artifact = Artifact(f"{data_type}-instruct-dataset") 99 | artifact.add(file_name) 100 | logger.info(f"Artifact created and file added: {file_name}") 101 | 102 | experiment.log_artifact(artifact) 103 | experiment.end() 104 | logger.info("Data pushed to Comet successfully and experiment ended") 105 | 106 | except Exception as e: 107 | logger.error(f"Failed to push data to Comet: {e}", exc_info=True) 108 | 109 | def fetch_all_cleaned_content(self, collection_name: str) -> list: 110 | all_cleaned_contents = [] 111 | 112 | scroll_response = client.scroll(collection_name=collection_name, limit=10000) 113 | points = scroll_response[0] 114 | 115 | for point in points: 116 | cleaned_content = point.payload["cleaned_content"] 117 | if cleaned_content: 118 | all_cleaned_contents.append(cleaned_content) 119 | 120 | return all_cleaned_contents 121 | 122
| 123 | if __name__ == "__main__": 124 | file_handler = FileHandler() 125 | api_communicator = GptCommunicator() 126 | data_formatter = DataFormatter() 127 | dataset_generator = DatasetGenerator(file_handler, api_communicator, data_formatter) 128 | 129 | collections = [("cleaned_articles", "articles"), ("cleaned_posts", "posts"), ("cleaned_repositories", "repositories")] 130 | for (collection_name, data_type) in collections: 131 | logger.info("Generating training data.", collection_name=collection_name, data_type=data_type) 132 | 133 | dataset_generator.generate_training_data( 134 | collection_name=collection_name, data_type=data_type, batch_size=1 135 | ) 136 | -------------------------------------------------------------------------------- /3-feature-pipeline/finetuning/llm_communication.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from openai import OpenAI 4 | 5 | from utils.logging import get_logger 6 | from config import settings 7 | 8 | MAX_LENGTH = 16384 9 | SYSTEM_PROMPT = ( 10 | "You are a technical writer handling someone's account to post about AI and MLOps." 11 | ) 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | class GptCommunicator: 17 | def __init__(self, gpt_model: str = "gpt-3.5-turbo"): 18 | self.api_key = settings.OPENAI_API_KEY 19 | self.gpt_model = gpt_model 20 | 21 | def send_prompt(self, prompt: str) -> list: 22 | try: 23 | client = OpenAI(api_key=self.api_key) 24 | logger.info("Sending batch to LLM") 25 | chat_completion = client.chat.completions.create( 26 | messages=[ 27 | {"role": "system", "content": SYSTEM_PROMPT}, 28 | {"role": "user", "content": prompt[:MAX_LENGTH]}, 29 | ], 30 | model=self.gpt_model, 31 | ) 32 | response = chat_completion.choices[0].message.content 33 | return json.loads(self.clean_response(response)) 34 | except Exception: 35 | logger.exception( 36 | "Skipping batch! An error occurred while communicating with API."
37 | ) 38 | 39 | return [] 40 | 41 | @staticmethod 42 | def clean_response(response: str) -> str: 43 | start_index = response.find("[") 44 | end_index = response.rfind("]") 45 | return response[start_index : end_index + 1] 46 | -------------------------------------------------------------------------------- /3-feature-pipeline/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/J-coder118/LLM-Twin/707f9d8bb1cf402e04644bff9c5c521ce0938087/3-feature-pipeline/llm/__init__.py -------------------------------------------------------------------------------- /3-feature-pipeline/llm/chain.py: -------------------------------------------------------------------------------- 1 | from langchain.chains.llm import LLMChain 2 | from langchain.prompts import PromptTemplate 3 | 4 | 5 | class GeneralChain: 6 | @staticmethod 7 | def get_chain(llm, template: PromptTemplate, output_key: str, verbose=True): 8 | return LLMChain( 9 | llm=llm, prompt=template, output_key=output_key, verbose=verbose 10 | ) 11 | -------------------------------------------------------------------------------- /3-feature-pipeline/llm/prompt_templates.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from langchain.prompts import PromptTemplate 4 | from pydantic import BaseModel 5 | 6 | 7 | class BasePromptTemplate(ABC, BaseModel): 8 | @abstractmethod 9 | def create_template(self) -> PromptTemplate: 10 | pass 11 | 12 | 13 | class QueryExpansionTemplate(BasePromptTemplate): 14 | prompt: str = """You are an AI language model assistant. Your task is to generate {to_expand_to_n} 15 | different versions of the given user question to retrieve relevant documents from a vector 16 | database. By generating multiple perspectives on the user question, your goal is to help 17 | the user overcome some of the limitations of the distance-based similarity search. 18 | Provide these alternative questions separated by '{separator}'. 19 | Original question: {question}""" 20 | 21 | @property 22 | def separator(self) -> str: 23 | return "#next-question#" 24 | 25 | def create_template(self, to_expand_to_n: int) -> PromptTemplate: 26 | return PromptTemplate( 27 | template=self.prompt, 28 | input_variables=["question"], 29 | partial_variables={ 30 | "separator": self.separator, 31 | "to_expand_to_n": to_expand_to_n, 32 | }, 33 | ) 34 | 35 | 36 | class SelfQueryTemplate(BasePromptTemplate): 37 | prompt: str = """You are an AI language model assistant. Your task is to extract information from a user question. 38 | The required information that needs to be extracted is the user or author id. 39 | Your response should consist of only the extracted id (e.g. 1345256), nothing else. 40 | If you cannot find the author id, return the string "None". 41 | User question: {question}""" 42 | 43 | def create_template(self) -> PromptTemplate: 44 | return PromptTemplate(template=self.prompt, input_variables=["question"]) 45 | 46 | 47 | class RerankingTemplate(BasePromptTemplate): 48 | prompt: str = """You are an AI language model assistant. Your task is to rerank passages related to a query 49 | based on their relevance. 50 | The most relevant passages should be put at the beginning. 51 | You should only pick at most {keep_top_k} passages. 52 | The provided and reranked documents are separated by '{separator}'. 53 | 54 | The following are passages related to this query: {question}.
55 | 56 | Passages: 57 | {passages} 58 | """ 59 | 60 | def create_template(self, keep_top_k: int) -> PromptTemplate: 61 | return PromptTemplate( 62 | template=self.prompt, 63 | input_variables=["question", "passages"], 64 | partial_variables={"keep_top_k": keep_top_k, "separator": self.separator}, 65 | ) 66 | 67 | @property 68 | def separator(self) -> str: 69 | return "\n#next-document#\n" 70 | -------------------------------------------------------------------------------- /3-feature-pipeline/main.py: -------------------------------------------------------------------------------- 1 | import bytewax.operators as op 2 | from bytewax.dataflow import Dataflow 3 | 4 | from db import QdrantDatabaseConnector 5 | 6 | from data_flow.stream_input import RabbitMQSource 7 | from data_flow.stream_output import QdrantOutput 8 | from data_logic.dispatchers import ( 9 | ChunkingDispatcher, 10 | CleaningDispatcher, 11 | EmbeddingDispatcher, 12 | RawDispatcher, 13 | ) 14 | 15 | connection = QdrantDatabaseConnector() 16 | 17 | flow = Dataflow("Streaming ingestion pipeline") 18 | stream = op.input("input", flow, RabbitMQSource()) 19 | stream = op.map("raw dispatch", stream, RawDispatcher.handle_mq_message) 20 | stream = op.map("clean dispatch", stream, CleaningDispatcher.dispatch_cleaner) 21 | op.output( 22 | "cleaned data insert to qdrant", 23 | stream, 24 | QdrantOutput(connection=connection, sink_type="clean"), 25 | ) 26 | stream = op.flat_map("chunk dispatch", stream, ChunkingDispatcher.dispatch_chunker) 27 | stream = op.map( 28 | "embedded chunk dispatch", stream, EmbeddingDispatcher.dispatch_embedder 29 | ) 30 | op.output( 31 | "embedded data insert to qdrant", 32 | stream, 33 | QdrantOutput(connection=connection, sink_type="vector"), 34 | ) 35 | -------------------------------------------------------------------------------- /3-feature-pipeline/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/J-coder118/LLM-Twin/707f9d8bb1cf402e04644bff9c5c521ce0938087/3-feature-pipeline/models/__init__.py -------------------------------------------------------------------------------- /3-feature-pipeline/models/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class DataModel(BaseModel): 7 | """ 8 | Abstract class for all data models 9 | """ 10 | 11 | entry_id: str 12 | type: str 13 | 14 | 15 | class VectorDBDataModel(ABC, DataModel): 16 | """ 17 | Abstract class for all data models that need to be saved into a vector DB (e.g. 
Qdrant) 18 | """ 19 | 20 | entry_id: str 21 | type: str 22 | 23 | @abstractmethod 24 | def to_payload(self) -> tuple: 25 | pass 26 | -------------------------------------------------------------------------------- /3-feature-pipeline/models/chunk.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from models.base import DataModel 4 | 5 | 6 | class PostChunkModel(DataModel): 7 | entry_id: str 8 | platform: str 9 | chunk_id: str 10 | chunk_content: str 11 | author_id: str 12 | image: Optional[str] = None 13 | type: str 14 | 15 | 16 | class ArticleChunkModel(DataModel): 17 | entry_id: str 18 | platform: str 19 | link: str 20 | chunk_id: str 21 | chunk_content: str 22 | author_id: str 23 | type: str 24 | 25 | 26 | class RepositoryChunkModel(DataModel): 27 | entry_id: str 28 | name: str 29 | link: str 30 | chunk_id: str 31 | chunk_content: str 32 | owner_id: str 33 | type: str 34 | -------------------------------------------------------------------------------- /3-feature-pipeline/models/clean.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | from models.base import VectorDBDataModel 4 | 5 | 6 | class PostCleanedModel(VectorDBDataModel): 7 | entry_id: str 8 | platform: str 9 | cleaned_content: str 10 | author_id: str 11 | image: Optional[str] = None 12 | type: str 13 | 14 | def to_payload(self) -> Tuple[str, dict]: 15 | data = { 16 | "platform": self.platform, 17 | "author_id": self.author_id, 18 | "cleaned_content": self.cleaned_content, 19 | "image": self.image, 20 | "type": self.type, 21 | } 22 | 23 | return self.entry_id, data 24 | 25 | 26 | class ArticleCleanedModel(VectorDBDataModel): 27 | entry_id: str 28 | platform: str 29 | link: str 30 | cleaned_content: str 31 | author_id: str 32 | type: str 33 | 34 | def to_payload(self) -> Tuple[str, dict]: 35 | data = { 36 | "platform": self.platform, 37 | "link": self.link, 38 | "cleaned_content": self.cleaned_content, 39 | "author_id": self.author_id, 40 | "type": self.type, 41 | } 42 | 43 | return self.entry_id, data 44 | 45 | 46 | class RepositoryCleanedModel(VectorDBDataModel): 47 | entry_id: str 48 | name: str 49 | link: str 50 | cleaned_content: str 51 | owner_id: str 52 | type: str 53 | 54 | def to_payload(self) -> Tuple[str, dict]: 55 | data = { 56 | "name": self.name, 57 | "link": self.link, 58 | "cleaned_content": self.cleaned_content, 59 | "owner_id": self.owner_id, 60 | "type": self.type, 61 | } 62 | 63 | return self.entry_id, data 64 | -------------------------------------------------------------------------------- /3-feature-pipeline/models/embedded_chunk.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | 5 | from models.base import VectorDBDataModel 6 | 7 | 8 | class PostEmbeddedChunkModel(VectorDBDataModel): 9 | entry_id: str 10 | platform: str 11 | chunk_id: str 12 | chunk_content: str 13 | embedded_content: np.ndarray 14 | author_id: str 15 | type: str 16 | 17 | class Config: 18 | arbitrary_types_allowed = True 19 | 20 | def to_payload(self) -> Tuple[str, np.ndarray, dict]: 21 | data = { 22 | "id": self.entry_id, 23 | "platform": self.platform, 24 | "content": self.chunk_content, 25 | "author_id": self.author_id,  # must match the "author_id" filter key used by the retriever 26 | "type": self.type, 27 | } 28 | 29 | return self.chunk_id, self.embedded_content, data 30 | 31 | 32 | class ArticleEmbeddedChunkModel(VectorDBDataModel): 33 | entry_id:
str 34 | platform: str 35 | link: str 36 | chunk_id: str 37 | chunk_content: str 38 | embedded_content: np.ndarray 39 | author_id: str 40 | type: str 41 | 42 | class Config: 43 | arbitrary_types_allowed = True 44 | 45 | def to_payload(self) -> Tuple[str, np.ndarray, dict]: 46 | data = { 47 | "id": self.entry_id, 48 | "platform": self.platform, 49 | "content": self.chunk_content, 50 | "link": self.link, 51 | "author_id": self.author_id, 52 | "type": self.type, 53 | } 54 | 55 | return self.chunk_id, self.embedded_content, data 56 | 57 | 58 | class RepositoryEmbeddedChunkModel(VectorDBDataModel): 59 | entry_id: str 60 | name: str 61 | link: str 62 | chunk_id: str 63 | chunk_content: str 64 | embedded_content: np.ndarray 65 | owner_id: str 66 | type: str 67 | 68 | class Config: 69 | arbitrary_types_allowed = True 70 | 71 | def to_payload(self) -> Tuple[str, np.ndarray, dict]: 72 | data = { 73 | "id": self.entry_id, 74 | "name": self.name, 75 | "content": self.chunk_content, 76 | "link": self.link, 77 | "owner_id": self.owner_id, 78 | "type": self.type, 79 | } 80 | 81 | return self.chunk_id, self.embedded_content, data 82 | -------------------------------------------------------------------------------- /3-feature-pipeline/models/raw.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from models.base import DataModel 4 | 5 | 6 | class RepositoryRawModel(DataModel): 7 | name: str 8 | link: str 9 | content: dict 10 | owner_id: str 11 | 12 | 13 | class ArticleRawModel(DataModel): 14 | platform: str 15 | link: str 16 | content: dict 17 | author_id: str 18 | 19 | 20 | class PostsRawModel(DataModel): 21 | platform: str 22 | content: dict 23 | author_id: str | None = None 24 | image: Optional[str] = None 25 | -------------------------------------------------------------------------------- /3-feature-pipeline/mq.py: -------------------------------------------------------------------------------- 1 | import pika 2 | 3 | from utils.logging import get_logger 4 | from config import settings 5 | 6 | logger = get_logger(__name__) 7 | 8 | 9 | class RabbitMQConnection: 10 | _instance = None 11 | 12 | def __new__( 13 | cls, 14 | host: str | None = None, 15 | port: int | None = None, 16 | username: str | None = None, 17 | password: str | None = None, 18 | virtual_host: str = "/", 19 | ): 20 | if not cls._instance: 21 | cls._instance = super().__new__(cls) 22 | 23 | return cls._instance 24 | 25 | def __init__( 26 | self, 27 | host: str | None = None, 28 | port: int | None = None, 29 | username: str | None = None, 30 | password: str | None = None, 31 | virtual_host: str = "/", 32 | fail_silently: bool = False, 33 | **kwargs, 34 | ): 35 | self.host = host or settings.RABBITMQ_HOST 36 | self.port = port or settings.RABBITMQ_PORT 37 | self.username = username or settings.RABBITMQ_DEFAULT_USERNAME 38 | self.password = password or settings.RABBITMQ_DEFAULT_PASSWORD 39 | self.virtual_host = virtual_host 40 | self.fail_silently = fail_silently 41 | self._connection = None 42 | 43 | def __enter__(self): 44 | self.connect() 45 | return self 46 | 47 | def __exit__(self, exc_type, exc_val, exc_tb): 48 | self.close() 49 | 50 | def connect(self): 51 | try: 52 | credentials = pika.PlainCredentials(self.username, self.password) 53 | self._connection = pika.BlockingConnection( 54 | pika.ConnectionParameters( 55 | host=self.host, 56 | port=self.port, 57 | virtual_host=self.virtual_host, 58 | credentials=credentials, 59 | ) 60 | ) 61 | except 
pika.exceptions.AMQPConnectionError as e: 62 | logger.exception("Failed to connect to RabbitMQ.") 63 | 64 | if not self.fail_silently: 65 | raise e 66 | 67 | def publish_message(self, data: str, queue: str): 68 | channel = self.get_channel() 69 | channel.queue_declare( 70 | queue=queue, durable=True, exclusive=False, auto_delete=False 71 | ) 72 | channel.confirm_delivery() 73 | 74 | try: 75 | channel.basic_publish( 76 | exchange="", routing_key=queue, body=data, mandatory=True 77 | ) 78 | logger.info( 79 | "Sent message successfully.", queue_type="RabbitMQ", queue_name=queue 80 | ) 81 | except pika.exceptions.UnroutableError: 82 | logger.error( 83 | "Failed to send the message.", queue_type="RabbitMQ", queue_name=queue 84 | ) 85 | 86 | def is_connected(self) -> bool: 87 | return self._connection is not None and self._connection.is_open 88 | 89 | def get_channel(self): 90 | # Lazily (re)connect so callers never receive a None channel. 91 | if not self.is_connected(): 92 | self.connect() 93 | 94 | return self._connection.channel() 95 | 96 | def close(self): 97 | if self.is_connected(): 98 | self._connection.close() 99 | self._connection = None 100 | 101 | logger.info("Closed RabbitMQ connection.") 102 | -------------------------------------------------------------------------------- /3-feature-pipeline/rag/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/J-coder118/LLM-Twin/707f9d8bb1cf402e04644bff9c5c521ce0938087/3-feature-pipeline/rag/__init__.py -------------------------------------------------------------------------------- /3-feature-pipeline/rag/query_expanison.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | 3 | from llm.chain import GeneralChain 4 | from llm.prompt_templates import QueryExpansionTemplate 5 | from config import settings 6 | 7 | 8 | class QueryExpansion: 9 | @staticmethod 10 | def generate_response(query: str, to_expand_to_n: int) -> list[str]: 11 | query_expansion_template = QueryExpansionTemplate() 12 | prompt_template = query_expansion_template.create_template(to_expand_to_n) 13 | model = ChatOpenAI(model=settings.OPENAI_MODEL_ID, temperature=0) 14 | 15 | chain = GeneralChain().get_chain( 16 | llm=model, output_key="expanded_queries", template=prompt_template 17 | ) 18 | 19 | response = chain.invoke({"question": query}) 20 | result = response["expanded_queries"] 21 | 22 | queries = result.strip().split(query_expansion_template.separator) 23 | stripped_queries = [ 24 | stripped_item for item in queries if (stripped_item := item.strip()) 25 | ] 26 | 27 | return stripped_queries 28 | -------------------------------------------------------------------------------- /3-feature-pipeline/rag/reranking.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | 3 | from llm.chain import GeneralChain 4 | from llm.prompt_templates import RerankingTemplate 5 | from config import settings 6 | 7 | 8 | class Reranker: 9 | @staticmethod 10 | def generate_response( 11 | query: str, passages: list[str], keep_top_k: int 12 | ) -> list[str]: 13 | reranking_template = RerankingTemplate() 14 | prompt_template = reranking_template.create_template(keep_top_k=keep_top_k) 15 | 16 | model = ChatOpenAI(model=settings.OPENAI_MODEL_ID) 17 | chain = GeneralChain().get_chain( 18 | llm=model, output_key="rerank", template=prompt_template 19 | ) 20 | 21 | stripped_passages = [ 22 | stripped_item for item in passages if (stripped_item := item.strip()) 23 | ] 24 |
passages = reranking_template.separator.join(stripped_passages) 25 | response = chain.invoke({"question": query, "passages": passages}) 26 | 27 | result = response["rerank"] 28 | reranked_passages = result.strip().split(reranking_template.separator) 29 | stripped_passages = [ 30 | stripped_item 31 | for item in reranked_passages 32 | if (stripped_item := item.strip()) 33 | ] 34 | 35 | return stripped_passages 36 | -------------------------------------------------------------------------------- /3-feature-pipeline/rag/retriever.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | 3 | from utils.logging import get_logger 4 | import utils 5 | from db import QdrantDatabaseConnector 6 | from qdrant_client import models 7 | from rag.query_expanison import QueryExpansion 8 | from rag.reranking import Reranker 9 | from rag.self_query import SelfQuery 10 | from sentence_transformers.SentenceTransformer import SentenceTransformer 11 | from config import settings 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | class VectorRetriever: 17 | """ 18 | Class for retrieving vectors from a Vector store in a RAG system using query expansion and Multitenancy search. 19 | """ 20 | 21 | def __init__(self, query: str): 22 | self._client = QdrantDatabaseConnector() 23 | self.query = query 24 | self._embedder = SentenceTransformer(settings.EMBEDDING_MODEL_ID) 25 | self._query_expander = QueryExpansion() 26 | self._metadata_extractor = SelfQuery() 27 | self._reranker = Reranker() 28 | 29 | def _search_single_query( 30 | self, generated_query: str, metadata_filter_value: str | None, k: int 31 | ): 32 | assert k > 3, "k should be greater than 3" 33 | 34 | query_vector = self._embedder.encode(generated_query).tolist() 35 | vectors = [ 36 | self._client.search( 37 | collection_name="vector_posts", 38 | query_filter=models.Filter( 39 | must=[ 40 | models.FieldCondition( 41 | key="author_id", 42 | match=models.MatchValue( 43 | value=metadata_filter_value, 44 | ), 45 | ) 46 | ] 47 | ) if metadata_filter_value else None, 48 | query_vector=query_vector, 49 | limit=k // 3, 50 | ), 51 | self._client.search( 52 | collection_name="vector_articles", 53 | query_filter=models.Filter( 54 | must=[ 55 | models.FieldCondition( 56 | key="author_id", 57 | match=models.MatchValue( 58 | value=metadata_filter_value, 59 | ), 60 | ) 61 | ] 62 | ) if metadata_filter_value else None, 63 | query_vector=query_vector, 64 | limit=k // 3, 65 | ), 66 | self._client.search( 67 | collection_name="vector_repositories", 68 | query_filter=models.Filter( 69 | must=[ 70 | models.FieldCondition( 71 | key="owner_id", 72 | match=models.MatchValue( 73 | value=metadata_filter_value, 74 | ), 75 | ) 76 | ] 77 | ) if metadata_filter_value else None, 78 | query_vector=query_vector, 79 | limit=k // 3, 80 | ), 81 | ] 82 | 83 | return utils.flatten(vectors) 84 | 85 | def retrieve_top_k(self, k: int, to_expand_to_n_queries: int) -> list: 86 | generated_queries = self._query_expander.generate_response( 87 | self.query, to_expand_to_n=to_expand_to_n_queries 88 | ) 89 | logger.info( 90 | "Successfully generated queries for search.", 91 | num_queries=len(generated_queries), 92 | ) 93 | 94 | author_id = self._metadata_extractor.generate_response(self.query) 95 | if author_id: 96 | logger.info( 97 | "Successfully extracted the author_id from the query.", 98 | author_id=author_id, 99 | ) 100 | else: 101 | logger.info("Couldn't extract the author_id from the query.") 102 | 103 | with 
concurrent.futures.ThreadPoolExecutor() as executor: 104 | search_tasks = [ 105 | executor.submit(self._search_single_query, query, author_id, k) 106 | for query in generated_queries 107 | ] 108 | 109 | hits = [ 110 | task.result() for task in concurrent.futures.as_completed(search_tasks) 111 | ] 112 | hits = utils.flatten(hits) 113 | 114 | logger.info("All documents retrieved successfully.", num_documents=len(hits)) 115 | 116 | return hits 117 | 118 | def rerank(self, hits: list, keep_top_k: int) -> list[str]: 119 | content_list = [hit.payload["content"] for hit in hits] 120 | rerank_hits = self._reranker.generate_response( 121 | query=self.query, passages=content_list, keep_top_k=keep_top_k 122 | ) 123 | 124 | logger.info("Documents reranked successfully.", num_documents=len(rerank_hits)) 125 | 126 | return rerank_hits 127 | 128 | def set_query(self, query: str): 129 | self.query = query 130 | -------------------------------------------------------------------------------- /3-feature-pipeline/rag/self_query.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | from llm.chain import GeneralChain 3 | from llm.prompt_templates import SelfQueryTemplate 4 | from config import settings 5 | 6 | 7 | class SelfQuery: 8 | @staticmethod 9 | def generate_response(query: str) -> str | None: 10 | prompt = SelfQueryTemplate().create_template() 11 | model = ChatOpenAI(model=settings.OPENAI_MODEL_ID, temperature=0) 12 | 13 | chain = GeneralChain().get_chain( 14 | llm=model, output_key="metadata_filter_value", template=prompt 15 | ) 16 | 17 | response = chain.invoke({"question": query}) 18 | result = response.get("metadata_filter_value", "none") 19 | 20 | if result.lower() == "none": 21 | return None 22 | 23 | return result 24 | -------------------------------------------------------------------------------- /3-feature-pipeline/retriever.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from langchain.globals import set_verbose 3 | from rag.retriever import VectorRetriever 4 | 5 | from utils.logging import get_logger 6 | 7 | set_verbose(True) 8 | 9 | logger = get_logger(__name__) 10 | 11 | if __name__ == "__main__": 12 | load_dotenv() 13 | query = """ 14 | Could you please draft a LinkedIn post discussing RAG systems? 15 | I'm particularly interested in how RAG works and how it is integrated with vector DBs and large language models (LLMs). 16 | """ 17 | retriever = VectorRetriever(query=query) 18 | hits = retriever.retrieve_top_k(k=6, to_expand_to_n_queries=5) 19 | 20 | reranked_hits = retriever.rerank(hits=hits, keep_top_k=5) 21 | for rank, hit in enumerate(reranked_hits): 22 | logger.info(f"{rank}: {hit}") 23 | -------------------------------------------------------------------------------- /3-feature-pipeline/scripts/bytewax_entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$DEBUG" = true ] 4 | then 5 | python -m bytewax.run "tools.run_real_time:build_flow(debug=True)" 6 | else 7 | if [ "$BYTEWAX_PYTHON_FILE_PATH" = "" ] 8 | then 9 | echo 'BYTEWAX_PYTHON_FILE_PATH is not set. Exiting...' 10 | exit 1 11 | fi 12 | python -m bytewax.run $BYTEWAX_PYTHON_FILE_PATH 13 | fi 14 | 15 | 16 | echo 'Process ended.' 
17 | 18 | if [ "$BYTEWAX_KEEP_CONTAINER_ALIVE" = true ] 19 | then 20 | echo 'Keeping container alive...'; 21 | while :; do sleep 1; done 22 | fi -------------------------------------------------------------------------------- /3-feature-pipeline/utils/__init__.py: -------------------------------------------------------------------------------- 1 | def flatten(nested_list: list) -> list: 2 | """Flatten a list of lists into a single list.""" 3 | 4 | return [item for sublist in nested_list for item in sublist] 5 | -------------------------------------------------------------------------------- /3-feature-pipeline/utils/chunking.py: -------------------------------------------------------------------------------- 1 | from langchain.text_splitter import ( 2 | RecursiveCharacterTextSplitter, 3 | SentenceTransformersTokenTextSplitter, 4 | ) 5 | 6 | from config import settings 7 | 8 | 9 | def chunk_text(text: str) -> list[str]: 10 | character_splitter = RecursiveCharacterTextSplitter( 11 | separators=["\n\n"], chunk_size=500, chunk_overlap=0 12 | ) 13 | text_split = character_splitter.split_text(text) 14 | 15 | token_splitter = SentenceTransformersTokenTextSplitter( 16 | chunk_overlap=50, 17 | tokens_per_chunk=settings.EMBEDDING_MODEL_MAX_INPUT_LENGTH, 18 | model_name=settings.EMBEDDING_MODEL_ID, 19 | ) 20 | chunks = [] 21 | 22 | for section in text_split: 23 | chunks.extend(token_splitter.split_text(section)) 24 | 25 | return chunks 26 | -------------------------------------------------------------------------------- /3-feature-pipeline/utils/cleaning.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from unstructured.cleaners.core import ( 4 | clean, 5 | clean_non_ascii_chars, 6 | replace_unicode_quotes, 7 | ) 8 | 9 | 10 | def unbold_text(text): 11 | # Mapping of bold numbers to their regular equivalents 12 | bold_numbers = { 13 | "𝟬": "0", 14 | "𝟭": "1", 15 | "𝟮": "2", 16 | "𝟯": "3", 17 | "𝟰": "4", 18 | "𝟱": "5", 19 | "𝟲": "6", 20 | "𝟳": "7", 21 | "𝟴": "8", 22 | "𝟵": "9", 23 | } 24 | 25 | # Function to convert bold characters (letters and numbers) 26 | def convert_bold_char(match): 27 | char = match.group(0) 28 | # Convert bold numbers 29 | if char in bold_numbers: 30 | return bold_numbers[char] 31 | # Convert bold uppercase letters 32 | elif "\U0001d5d4" <= char <= "\U0001d5ed": 33 | return chr(ord(char) - 0x1D5D4 + ord("A")) 34 | # Convert bold lowercase letters 35 | elif "\U0001d5ee" <= char <= "\U0001d607": 36 | return chr(ord(char) - 0x1D5EE + ord("a")) 37 | else: 38 | return char # Return the character unchanged if it's not a bold number or letter 39 | 40 | # Regex for bold characters (numbers, uppercase, and lowercase letters) 41 | bold_pattern = re.compile( 42 | r"[\U0001D5D4-\U0001D5ED\U0001D5EE-\U0001D607\U0001D7CE-\U0001D7FF]" 43 | ) 44 | text = bold_pattern.sub(convert_bold_char, text) 45 | 46 | return text 47 | 48 | 49 | def unitalic_text(text): 50 | # Function to convert italic characters (both letters) 51 | def convert_italic_char(match): 52 | char = match.group(0) 53 | # Unicode ranges for italic characters 54 | if "\U0001d608" <= char <= "\U0001d621": # Italic uppercase A-Z 55 | return chr(ord(char) - 0x1D608 + ord("A")) 56 | elif "\U0001d622" <= char <= "\U0001d63b": # Italic lowercase a-z 57 | return chr(ord(char) - 0x1D622 + ord("a")) 58 | else: 59 | return char # Return the character unchanged if it's not an italic letter 60 | 61 | # Regex for italic characters (uppercase and lowercase letters) 62 | italic_pattern = 
re.compile(r"[\U0001D608-\U0001D621\U0001D622-\U0001D63B]") 63 | text = italic_pattern.sub(convert_italic_char, text) 64 | 65 | return text 66 | 67 | 68 | def remove_emojis_and_symbols(text): 69 | # Extended pattern to include specific symbols like ↓ (U+2193) or ↳ (U+21B3) 70 | emoji_and_symbol_pattern = re.compile( 71 | "[" 72 | "\U0001f600-\U0001f64f" # emoticons 73 | "\U0001f300-\U0001f5ff" # symbols & pictographs 74 | "\U0001f680-\U0001f6ff" # transport & map symbols 75 | "\U0001f1e0-\U0001f1ff" # flags (iOS) 76 | "\U00002193" # downwards arrow 77 | "\U000021b3" # downwards arrow with tip rightwards 78 | "\U00002192" # rightwards arrow 79 | "]+", 80 | flags=re.UNICODE, 81 | ) 82 | 83 | return emoji_and_symbol_pattern.sub(r" ", text) 84 | 85 | 86 | def replace_urls_with_placeholder(text, placeholder="[URL]"): 87 | # Regular expression pattern for matching URLs 88 | url_pattern = r"https?://\S+|www\.\S+" 89 | 90 | return re.sub(url_pattern, placeholder, text) 91 | 92 | 93 | def remove_non_ascii(text: str) -> str: 94 | text = text.encode("ascii", "ignore").decode("ascii") 95 | return text 96 | 97 | 98 | def clean_text(text_content: str) -> str: 99 | cleaned_text = unbold_text(text_content) 100 | cleaned_text = unitalic_text(cleaned_text) 101 | cleaned_text = remove_emojis_and_symbols(cleaned_text) 102 | cleaned_text = clean(cleaned_text) 103 | cleaned_text = replace_unicode_quotes(cleaned_text) 104 | cleaned_text = clean_non_ascii_chars(cleaned_text) 105 | cleaned_text = replace_urls_with_placeholder(cleaned_text) 106 | 107 | return cleaned_text 108 | -------------------------------------------------------------------------------- /3-feature-pipeline/utils/embeddings.py: -------------------------------------------------------------------------------- 1 | from InstructorEmbedding import INSTRUCTOR 2 | from sentence_transformers.SentenceTransformer import SentenceTransformer 3 | 4 | from config import settings 5 | 6 | 7 | def embedd_text(text: str): 8 | model = SentenceTransformer(settings.EMBEDDING_MODEL_ID) 9 | return model.encode(text) 10 | 11 | 12 | def embedd_repositories(text: str): 13 | model = INSTRUCTOR("hkunlp/instructor-xl") 14 | sentence = text 15 | instruction = "Represent the structure of the repository" 16 | return model.encode([instruction, sentence]) 17 | -------------------------------------------------------------------------------- /3-feature-pipeline/utils/logging.py: -------------------------------------------------------------------------------- 1 | import structlog 2 | 3 | 4 | def get_logger(cls: str): 5 | return structlog.get_logger().bind(cls=cls) -------------------------------------------------------------------------------- /4-finetuning/.env.example: -------------------------------------------------------------------------------- 1 | HUGGINGFACE_ACCESS_TOKEN = "str" 2 | COMET_API_KEY = "str" 3 | COMET_WORKSPACE = "str" 4 | COMET_PROJECT = "scrabble" -------------------------------------------------------------------------------- /4-finetuning/build_config.yaml: -------------------------------------------------------------------------------- 1 | build_env: 2 | docker: 3 | assumed_iam_role_arn: null 4 | base_image: public.ecr.aws/qwak-us-east-1/qwak-base:0.0.13-gpu 5 | cache: true 6 | env_vars: 7 | - HUGGINGFACE_ACCESS_TOKEN= 8 | - COMET_API_KEY= 9 | - COMET_WORKSPACE= 10 | - COMET_PROJECT=llm-twin-course 11 | no_cache: false 12 | params: [] 13 | push: true 14 | python_env: 15 | dependency_file_path: finetuning/requirements.txt 16 | git_credentials: null 17 | 
git_credentials_secret: null 18 | poetry: null 19 | virtualenv: null 20 | remote: 21 | is_remote: true 22 | resources: 23 | cpus: null 24 | gpu_amount: null 25 | gpu_type: null 26 | instance: gpu.a10.2xl 27 | memory: null 28 | build_properties: 29 | branch: finetuning 30 | build_id: null 31 | gpu_compatible: false 32 | model_id: llm_twin 33 | model_uri: 34 | dependency_required_folders: [] 35 | git_branch: master 36 | git_credentials: null 37 | git_credentials_secret: null 38 | git_secret_ssh: null 39 | main_dir: finetuning 40 | uri: . 41 | tags: [] 42 | deploy: false 43 | deployment_instance: null 44 | post_build: null 45 | pre_build: null 46 | purchase_option: null 47 | step: 48 | tests: true 49 | validate_build_artifact: true 50 | validate_build_artifact_timeout: 120 51 | verbose: 0 52 | 53 | -------------------------------------------------------------------------------- /4-finetuning/finetuning/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import CopywriterMistralModel 2 | 3 | 4 | def load_model(): 5 | return CopywriterMistralModel() 6 | -------------------------------------------------------------------------------- /4-finetuning/finetuning/config.yaml: -------------------------------------------------------------------------------- 1 | training_arguments: 2 | output_dir: "mistral_instruct_generation" 3 | max_steps: 10 4 | per_device_train_batch_size: 1 5 | logging_steps: 10 6 | save_strategy: "epoch" 7 | evaluation_strategy: "steps" 8 | eval_steps: 2 9 | learning_rate: 0.0002 10 | fp16: true 11 | remove_unused_columns: false 12 | lr_scheduler_type: "constant" 13 | -------------------------------------------------------------------------------- /4-finetuning/finetuning/dataset_client.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | 6 | from comet_ml import Experiment 7 | from finetuning.settings import settings 8 | from sklearn.model_selection import train_test_split 9 | 10 | 11 | class DatasetClient: 12 | def __init__(self, output_dir: str = "./finetuning"): 13 | self.project = settings.COMET_PROJECT 14 | self.api_key = settings.COMET_API_KEY 15 | self.experiment = Experiment(api_key=self.api_key, project_name=self.project) 16 | self.output_dir = output_dir 17 | 18 | def get_artifact(self, artifact_name: str): 19 | try: 20 | logged_artifact = self.experiment.get_artifact(artifact_name) 21 | logged_artifact.download(self.output_dir) 22 | self.experiment.end() 23 | logging.info( 24 | f"Successfully downloaded {artifact_name} at location {self.output_dir}" 25 | ) 26 | except Exception as e: 27 | logging.error(f"Error retrieving artifact: {str(e)}") 28 | 29 | def split_data(self, artifact_name: str) -> tuple: 30 | try: 31 | training_file_path = os.path.join(self.output_dir, "train.json") 32 | validation_file_path = os.path.join(self.output_dir, "validation.json") 33 | file_name = artifact_name + ".json" 34 | with open(os.path.join(self.output_dir, file_name), "r") as file: 35 | data = json.load(file) 36 | 37 | train_data, val_data = train_test_split( 38 | data, test_size=0.2, random_state=42 39 | ) 40 | 41 | with open(training_file_path, "w") as train_file: 42 | json.dump(train_data, train_file) 43 | 44 | with open(validation_file_path, "w") as val_file: 45 | json.dump(val_data, val_file) 46 | 47 | logging.info("Data split into train.json and validation.json successfully.") 48 | return training_file_path, validation_file_path 49 | 
except Exception as e: 50 | logging.error(f"Error splitting data: {str(e)}") 51 | 52 | raise 53 | 54 | def download_dataset(self, file_name: str): 55 | self.get_artifact(file_name) 56 | return self.split_data(file_name) 57 | -------------------------------------------------------------------------------- /4-finetuning/finetuning/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.5 2 | pandas==2.2.2 3 | peft==0.11.0 4 | datasets==2.19.1 5 | transformers==4.40.2 6 | safetensors==0.4.3 7 | comet_ml==3.42.0 8 | accelerate==0.30.1 9 | bitsandbytes==0.42.0 10 | pydantic_settings==2.2.1 11 | scikit-learn==1.4.2 12 | qwak-sdk==0.5.68 13 | -------------------------------------------------------------------------------- /4-finetuning/finetuning/settings.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings, SettingsConfigDict 2 | 3 | 4 | class AppSettings(BaseSettings): 5 | model_config = SettingsConfigDict() 6 | 7 | TOKENIZERS_PARALLELISM: str = "false" 8 | 9 | HUGGINGFACE_ACCESS_TOKEN: str = "" 10 | 11 | COMET_API_KEY: str = "" 12 | COMET_WORKSPACE: str = "" 13 | COMET_PROJECT: str = "" 14 | 15 | settings = AppSettings() -------------------------------------------------------------------------------- /4-finetuning/media/fine-tuning-workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/J-coder118/LLM-Twin/707f9d8bb1cf402e04644bff9c5c521ce0938087/4-finetuning/media/fine-tuning-workflow.png -------------------------------------------------------------------------------- /4-finetuning/test_local.py: -------------------------------------------------------------------------------- 1 | from pandas import DataFrame 2 | from qwak.model.tools import run_local 3 | 4 | from finetuning import CopywriterMistralModel 5 | 6 | if __name__ == '__main__': 7 | model = CopywriterMistralModel() 8 | input_vector = DataFrame( 9 | [{ 10 | "instruction": "Write me a Linkedin post about Data Science" 11 | }] 12 | ).to_json() 13 | 14 | prediction = run_local(model, input_vector) 15 | print(prediction) 16 | -------------------------------------------------------------------------------- /5-inference/.env.example: -------------------------------------------------------------------------------- 1 | OPENAI_API_KEY = "str" 2 | HUGGINGFACE_ACCESS_TOKEN = "str" 3 | 4 | COMET_API_KEY = "str" 5 | COMET_WORKSPACE = "str" 6 | COMET_PROJECT = "llm-twin-course" 7 | 8 | QWAK_DEPLOYMENT_MODEL_ID = "str" 9 | QWAK_DEPLOYMENT_MODEL_API = "str" 10 | 11 | QDRANT_CLOUD_URL = "str" 12 | QDRANT_APIKEY = "str" 13 | -------------------------------------------------------------------------------- /5-inference/Makefile: -------------------------------------------------------------------------------- 1 | help: 2 | @grep -E '^[a-zA-Z0-9 -]+:.*#' Makefile | sort | while read -r l; do printf "\033[1;32m$$(echo $$l | cut -f 1 -d':')\033[00m:$$(echo $$l | cut -f 2- -d'#')\n"; done 3 | 4 | call-inference-pipeline: # Test the inference pipeline.
5 | poetry run python main.py 6 | 7 | -------------------------------------------------------------------------------- /5-inference/README.txt: -------------------------------------------------------------------------------- 1 | Before running the inference pipeline: 2 | - run `qwak configure` and set the API key from your Qwak environment 3 | - when deploying the model on Qwak as realtime, under advanced settings you need to set the timeout field to a larger 4 | value (e.g. 50000 ms) -------------------------------------------------------------------------------- /5-inference/config.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings, SettingsConfigDict 2 | 3 | 4 | class AppSettings(BaseSettings): 5 | model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") 6 | 7 | # Embeddings config 8 | EMBEDDING_MODEL_ID: str = "sentence-transformers/all-MiniLM-L6-v2" 9 | EMBEDDING_MODEL_MAX_INPUT_LENGTH: int = 256 10 | EMBEDDING_SIZE: int = 384 11 | EMBEDDING_MODEL_DEVICE: str = "cpu" 12 | 13 | # OpenAI config 14 | OPENAI_MODEL_ID: str = "gpt-4-1106-preview" 15 | OPENAI_API_KEY: str | None = None 16 | 17 | # QdrantDB config 18 | QDRANT_DATABASE_HOST: str = "mq" # Or localhost if running outside Docker 19 | QDRANT_DATABASE_PORT: int = 6333 20 | 21 | USE_QDRANT_CLOUD: bool = ( 22 | False # if True, fill in QDRANT_CLOUD_URL and QDRANT_APIKEY 23 | ) 24 | QDRANT_CLOUD_URL: str = "str" 25 | QDRANT_APIKEY: str | None = None 26 | 27 | # RAG config 28 | TOP_K: int = 5 29 | KEEP_TOP_K: int = 5 30 | EXPAND_N_QUERY: int = 5 31 | 32 | # MQ config 33 | RABBITMQ_HOST: str = "mq" 34 | RABBITMQ_PORT: int = 5672 35 | RABBITMQ_DEFAULT_USERNAME: str = "guest" 36 | RABBITMQ_DEFAULT_PASSWORD: str = "guest" 37 | 38 | # CometML config 39 | COMET_API_KEY: str 40 | COMET_WORKSPACE: str 41 | COMET_PROJECT: str = "llm-twin-course" 42 | 43 | # LLM Model config 44 | TOKENIZERS_PARALLELISM: str = "false" 45 | HUGGINGFACE_ACCESS_TOKEN: str | None = None 46 | MODEL_TYPE: str = "mistralai/Mistral-7B-Instruct-v0.1" 47 | 48 | QWAK_DEPLOYMENT_MODEL_ID: str = "copywriter_model" 49 | QWAK_DEPLOYMENT_MODEL_API: str = ( 50 | "https://models.llm-twin.qwak.ai/v1/copywriter_model/default/predict" 51 | ) 52 | 53 | 54 | settings = AppSettings() 55 | -------------------------------------------------------------------------------- /5-inference/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import evaluate as evaluate_llm 2 | from .rag import evaluate as evaluate_rag 3 | 4 | __all__ = ["evaluate_llm", "evaluate_rag"] 5 | -------------------------------------------------------------------------------- /5-inference/evaluation/model.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | 3 | from llm.chain import GeneralChain 4 | from llm.prompt_templates import LLMEvaluationTemplate 5 | from config import settings 6 | 7 | 8 | def evaluate(query: str, output: str) -> str: 9 | evaluation_template = LLMEvaluationTemplate() 10 | prompt_template = evaluation_template.create_template() 11 | 12 | model = ChatOpenAI(model=settings.OPENAI_MODEL_ID, api_key=settings.OPENAI_API_KEY) 13 | chain = GeneralChain.get_chain( 14 | llm=model, output_key="evaluation", template=prompt_template 15 | ) 16 | 17 | response = chain.invoke({"query": query, "output": output}) 18 | 19 | return response["evaluation"] 20 |
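21 | 22 | # A minimal usage sketch, kept as a comment on purpose: the values below are 23 | # hypothetical, and running it assumes OPENAI_API_KEY (plus the Comet variables 24 | # required by config.py) are set in the environment. 25 | # 26 | # answer = "Vector databases index embeddings to support similarity search ..." 27 | # evaluation = evaluate( 28 | #     query="Draft a LinkedIn post about vector databases.", 29 | #     output=answer, 30 | # ) 31 | # print(evaluation)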
-------------------------------------------------------------------------------- /5-inference/evaluation/rag.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | 3 | from llm.chain import GeneralChain 4 | from llm.prompt_templates import RAGEvaluationTemplate 5 | from config import settings 6 | 7 | 8 | def evaluate(query: str, context: list[str], output: str) -> str: 9 | evaluation_template = RAGEvaluationTemplate() 10 | prompt_template = evaluation_template.create_template() 11 | 12 | model = ChatOpenAI(model=settings.OPENAI_MODEL_ID) 13 | chain = GeneralChain.get_chain( 14 | llm=model, output_key="rag_eval", template=prompt_template 15 | ) 16 | 17 | response = chain.invoke({"query": query, "context": context, "output": output}) 18 | 19 | return response["rag_eval"] 20 | -------------------------------------------------------------------------------- /5-inference/finetuning/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import CopywriterMistralModel 2 | 3 | 4 | def load_model(): 5 | return CopywriterMistralModel() 6 | -------------------------------------------------------------------------------- /5-inference/finetuning/config.yaml: -------------------------------------------------------------------------------- 1 | training_arguments: 2 | output_dir: "mistral_instruct_generation" 3 | max_steps: 10 4 | per_device_train_batch_size: 1 5 | logging_steps: 10 6 | save_strategy: "epoch" 7 | evaluation_strategy: "steps" 8 | eval_steps: 2 9 | learning_rate: 0.0002 10 | fp16: true 11 | remove_unused_columns: false 12 | lr_scheduler_type: "constant" 13 | -------------------------------------------------------------------------------- /5-inference/finetuning/dataset_client.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import logging 4 | from comet_ml import Experiment 5 | from sklearn.model_selection import train_test_split 6 | 7 | from settings import settings 8 | 9 | 10 | class DatasetClient: 11 | def __init__(self, output_dir: str = "./finetuning"): 12 | self.project = settings.COMET_PROJECT 13 | self.api_key = settings.COMET_API_KEY 14 | self.experiment = Experiment(api_key=self.api_key, project_name=self.project) 15 | self.output_dir = output_dir 16 | 17 | def get_artifact(self, artifact_name: str): 18 | try: 19 | logged_artifact = self.experiment.get_artifact(artifact_name) 20 | logged_artifact.download(self.output_dir) 21 | self.experiment.end() 22 | logging.info( 23 | f"Successfully downloaded {artifact_name} at location {self.output_dir}" 24 | ) 25 | except Exception as e: 26 | logging.error(f"Error retrieving artifact: {str(e)}") 27 | 28 | def split_data(self, artifact_name: str) -> tuple: 29 | try: 30 | training_file_path = os.path.join(self.output_dir, "train.json") 31 | validation_file_path = os.path.join(self.output_dir, "validation.json") 32 | file_name = artifact_name + ".json" 33 | with open(os.path.join(self.output_dir, file_name), "r") as file: 34 | data = json.load(file) 35 | 36 | train_data, val_data = train_test_split( 37 | data, test_size=0.2, random_state=42 38 | ) 39 | 40 | with open(training_file_path, "w") as train_file: 41 | json.dump(train_data, train_file) 42 | 43 | with open(validation_file_path, "w") as val_file: 44 | json.dump(val_data, val_file) 45 | 46 | logging.info("Data split into train.json and validation.json successfully.") 47 | return training_file_path,
validation_file_path 48 | except Exception as e: 49 | logging.error(f"Error splitting data: {str(e)}") 50 | 51 | raise 52 | 53 | def download_dataset(self, file_name: str) -> tuple: 54 | self.get_artifact(file_name) 55 | 56 | return self.split_data(file_name) 57 | -------------------------------------------------------------------------------- /5-inference/finetuning/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import pandas as pd 5 | import qwak 6 | import torch as th 7 | import yaml 8 | from comet_ml import Experiment 9 | from datasets import DatasetDict, load_dataset 10 | from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training 11 | from qwak.model.adapters import DefaultOutputAdapter 12 | from qwak.model.base import QwakModel 13 | from qwak.model.schema import ModelSchema 14 | from qwak.model.schema_entities import InferenceOutput, RequestInput 15 | from transformers import ( 16 | AutoModelForCausalLM, 17 | AutoTokenizer, 18 | BitsAndBytesConfig, 19 | PreTrainedModel, 20 | Trainer, 21 | TrainingArguments, 22 | ) 23 | 24 | from finetuning.dataset_client import DatasetClient 25 | from settings import settings 26 | 27 | 28 | class CopywriterMistralModel(QwakModel): 29 | def __init__( 30 | self, 31 | is_saved: bool = False, 32 | model_save_dir: str = "./model", 33 | model_type: str = "mistralai/Mistral-7B-Instruct-v0.1", 34 | comet_artifact_name: str = "cleaned_posts", 35 | config_file: str = "./finetuning/config.yaml", 36 | ) -> None: 37 | self._prep_environment() 38 | 39 | self.experiment = None 40 | self.model_save_dir = model_save_dir 41 | self.model_type = model_type 42 | self.comet_dataset_artifact = comet_artifact_name 43 | self.training_args_config_file = config_file 44 | if is_saved: 45 | self.experiment = Experiment( 46 | api_key=settings.COMET_API_KEY, 47 | project_name=settings.COMET_PROJECT, 48 | workspace=settings.COMET_WORKSPACE, 49 | ) 50 | 51 | def _prep_environment(self) -> None: 52 | os.environ["TOKENIZERS_PARALLELISM"] = settings.TOKENIZERS_PARALLELISM 53 | th.cuda.empty_cache() 54 | logging.info("Emptied cuda cache. Environment prepared successfully!") 55 | 56 | def init_model(self) -> None: 57 | self.model = AutoModelForCausalLM.from_pretrained( 58 | self.model_type, 59 | token=settings.HUGGINGFACE_ACCESS_TOKEN, 60 | device_map=th.cuda.current_device(), 61 | quantization_config=self.nf4_config, 62 | use_cache=False, 63 | torchscript=True, 64 | ) 65 | self.tokenizer = AutoTokenizer.from_pretrained( 66 | self.model_type, token=settings.HUGGINGFACE_ACCESS_TOKEN 67 | ) 68 | self.tokenizer.pad_token = self.tokenizer.eos_token 69 | self.tokenizer.padding_side = "right" 70 | logging.info(f"Initialized model {self.model_type} successfully") 71 | 72 | def _init_4bit_config(self): 73 | self.nf4_config = BitsAndBytesConfig( 74 | load_in_4bit=True, 75 | bnb_4bit_quant_type="nf4", 76 | bnb_4bit_use_double_quant=True, 77 | bnb_4bit_compute_dtype=th.bfloat16, 78 | ) 79 | if self.experiment: 80 | self.experiment.log_parameters(self.nf4_config) 81 | logging.info( 82 | "Initialized config for param representation on 4bits successfully!"
83 | ) 84 | 85 | def _initialize_qlora(self, model: PreTrainedModel) -> PeftModel: 86 | self.qlora_config = LoraConfig( 87 | lora_alpha=16, lora_dropout=0.1, r=64, bias="none", task_type="CAUSAL_LM" 88 | ) 89 | 90 | if self.experiment: 91 | self.experiment.log_parameters(self.qlora_config) 92 | 93 | model = prepare_model_for_kbit_training(model) 94 | model = get_peft_model(model, self.qlora_config) 95 | logging.info("Initialized qlora config successfully!") 96 | return model 97 | 98 | def _init_training_args(self): 99 | with open(self.training_args_config_file, "r") as file: 100 | config = yaml.safe_load(file) 101 | self.training_arguments = TrainingArguments(**config["training_arguments"]) 102 | if self.experiment: 103 | self.experiment.log_parameters(self.training_arguments) 104 | logging.info("Initialized training arguments successfully!") 105 | 106 | def _remove_model_class_attributes(self): 107 | # removal is needed in order to skip the default Pickle serialization done by Qwak 108 | del self.model 109 | del self.trainer 110 | del self.experiment 111 | 112 | def generate_prompt(self, sample: dict) -> dict: 113 | full_prompt = f"""[INST]{sample['instruction']} 114 | [/INST] {sample['content']}""" 115 | result = self.tokenize(full_prompt) 116 | return result 117 | 118 | def tokenize(self, prompt: str) -> dict: 119 | result = self.tokenizer( 120 | prompt, 121 | padding="max_length", 122 | max_length=100, 123 | truncation=True, 124 | ) 125 | result["labels"] = result["input_ids"].copy() 126 | return result 127 | 128 | def load_dataset(self) -> DatasetDict: 129 | dataset_handler = DatasetClient() 130 | train_data_file, validation_data_file = dataset_handler.download_dataset( 131 | self.comet_dataset_artifact 132 | ) 133 | data_files = {"train": train_data_file, "validation": validation_data_file} 134 | raw_datasets = load_dataset("json", data_files=data_files) 135 | train_dataset, val_dataset = self.preprocess_data_split(raw_datasets) 136 | return DatasetDict({"train": train_dataset, "validation": val_dataset}) 137 | 138 | def preprocess_data_split(self, raw_datasets: DatasetDict): 139 | train_data = raw_datasets["train"] 140 | val_data = raw_datasets["validation"] 141 | generated_train_dataset = train_data.map(self.generate_prompt) 142 | generated_train_dataset = generated_train_dataset.remove_columns( 143 | ["instruction", "content"] 144 | ) 145 | generated_val_dataset = val_data.map(self.generate_prompt) 146 | generated_val_dataset = generated_val_dataset.remove_columns( 147 | ["instruction", "content"] 148 | ) 149 | return generated_train_dataset, generated_val_dataset 150 | 151 | def build(self): 152 | self._init_4bit_config() 153 | self.init_model() 154 | if self.experiment: 155 | self.experiment.log_parameters(self.nf4_config) 156 | self.model = self._initialize_qlora(self.model) 157 | self._init_training_args() 158 | tokenized_datasets = self.load_dataset() 159 | self.device = th.device("cuda" if th.cuda.is_available() else "cpu") 160 | self.model = self.model.to(self.device) 161 | self.trainer = Trainer( 162 | model=self.model, 163 | args=self.training_arguments, 164 | train_dataset=tokenized_datasets["train"], 165 | eval_dataset=tokenized_datasets["validation"], 166 | tokenizer=self.tokenizer, 167 | ) 168 | 169 | logging.info("Initialized model trainer") 170 | self.trainer.train() 171 | logging.info(f"Finished training LLM: {self.model_type}") 172 | self.trainer.save_model(self.model_save_dir) 173 | logging.info(f"Finished saving model to {self.model_save_dir}") 174 |
if self.experiment: 175 | self.experiment.end() 176 | 177 | self._remove_model_class_attributes() 178 | 179 | def initialize_model(self) -> None: 180 | self.model = AutoModelForCausalLM.from_pretrained( 181 | self.model_save_dir, 182 | token=settings.HUGGINGFACE_ACCESS_TOKEN, 183 | quantization_config=self.nf4_config, 184 | ) 185 | logging.info(f"Successfully loaded model from {self.model_save_dir}") 186 | 187 | def schema(self) -> ModelSchema: 188 | return ModelSchema( 189 | inputs=[RequestInput(name="instruction", type=str)], 190 | outputs=[InferenceOutput(name="content", type=str)], 191 | ) 192 | 193 | @qwak.api(output_adapter=DefaultOutputAdapter()) 194 | def predict(self, df): 195 | input_text = list(df["instruction"].values) 196 | input_ids = self.tokenizer( 197 | input_text, return_tensors="pt", add_special_tokens=True 198 | ) 199 | input_ids = input_ids.to(self.device) 200 | 201 | generated_ids = self.model.generate( 202 | **input_ids, 203 | max_new_tokens=500, 204 | do_sample=True, 205 | pad_token_id=self.tokenizer.eos_token_id, 206 | ) 207 | 208 | # Decode only the newly generated tokens, not the echoed prompt tokens. 209 | input_token_count = input_ids["input_ids"].shape[1] 210 | decoded_output = self.tokenizer.batch_decode( 211 | generated_ids[:, input_token_count:] 212 | )[0] 213 | 214 | return pd.DataFrame([{"content": decoded_output}]) 215 | -------------------------------------------------------------------------------- /5-inference/finetuning/settings.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings, SettingsConfigDict 2 | 3 | 4 | class AppSettings(BaseSettings): 5 | model_config = SettingsConfigDict() 6 | 7 | TOKENIZERS_PARALLELISM: str = "false" 8 | HUGGINGFACE_ACCESS_TOKEN: str = "" 9 | COMET_API_KEY: str = "" 10 | COMET_WORKSPACE: str = "" 11 | COMET_PROJECT: str = "" 12 | 13 | 14 | settings = AppSettings() 15 | -------------------------------------------------------------------------------- /5-inference/inference_pipeline.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from evaluation import evaluate_llm 3 | from llm.prompt_templates import InferenceTemplate 4 | from monitoring import PromptMonitoringManager 5 | from qwak_inference import RealTimeClient 6 | from rag.retriever import VectorRetriever 7 | from config import settings 8 | 9 | 10 | class LLMTwin: 11 | def __init__(self) -> None: 12 | self.qwak_client = RealTimeClient( 13 | model_id=settings.QWAK_DEPLOYMENT_MODEL_ID, 14 | model_api=settings.QWAK_DEPLOYMENT_MODEL_API, 15 | ) 16 | self.template = InferenceTemplate() 17 | self.prompt_monitoring_manager = PromptMonitoringManager() 18 | 19 | def generate( 20 | self, 21 | query: str, 22 | enable_rag: bool = False, 23 | enable_evaluation: bool = False, 24 | enable_monitoring: bool = True, 25 | ) -> dict: 26 | prompt_template = self.template.create_template(enable_rag=enable_rag) 27 | prompt_template_variables = { 28 | "question": query, 29 | } 30 | 31 | if enable_rag is True: 32 | retriever = VectorRetriever(query=query) 33 | hits = retriever.retrieve_top_k( 34 | k=settings.TOP_K, to_expand_to_n_queries=settings.EXPAND_N_QUERY 35 | ) 36 | context = retriever.rerank(hits=hits, keep_top_k=settings.KEEP_TOP_K) 37 | prompt_template_variables["context"] = context 38 | 39 | prompt = prompt_template.format(question=query, context=context) 40 | else: 41 | prompt = prompt_template.format(question=query) 42 | 43 | input_ = pd.DataFrame([{"instruction": prompt}]).to_json() 44 | 45 | response: list[dict] = self.qwak_client.predict(input_) 46 | answer = response[0]["content"][0] 47 | 48 | if
enable_evaluation is True: 49 | evaluation_result = evaluate_llm(query=query, output=answer) 50 | else: 51 | evaluation_result = None 52 | 53 | if enable_monitoring is True: 54 | if evaluation_result is not None: 55 | metadata = {"llm_evaluation_result": evaluation_result} 56 | else: 57 | metadata = None 58 | 59 | self.prompt_monitoring_manager.log( 60 | prompt=prompt, 61 | prompt_template=prompt_template.template, 62 | prompt_template_variables=prompt_template_variables, 63 | output=answer, 64 | metadata=metadata, 65 | ) 66 | self.prompt_monitoring_manager.log_chain( 67 | query=query, response=answer, eval_output=evaluation_result 68 | ) 69 | 70 | return {"answer": answer, "llm_evaluation_result": evaluation_result} 71 | -------------------------------------------------------------------------------- /5-inference/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/J-coder118/LLM-Twin/707f9d8bb1cf402e04644bff9c5c521ce0938087/5-inference/llm/__init__.py -------------------------------------------------------------------------------- /5-inference/llm/chain.py: -------------------------------------------------------------------------------- 1 | from langchain.chains.llm import LLMChain 2 | from langchain.prompts import PromptTemplate 3 | 4 | 5 | class GeneralChain: 6 | @staticmethod 7 | def get_chain(llm, template: PromptTemplate, output_key: str, verbose=True): 8 | return LLMChain( 9 | llm=llm, prompt=template, output_key=output_key, verbose=verbose 10 | ) 11 | -------------------------------------------------------------------------------- /5-inference/llm/prompt_templates.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from langchain.prompts import PromptTemplate 4 | from pydantic import BaseModel 5 | 6 | 7 | class BasePromptTemplate(ABC, BaseModel): 8 | @abstractmethod 9 | def create_template(self, *args) -> PromptTemplate: 10 | pass 11 | 12 | 13 | class QueryExpansionTemplate(BasePromptTemplate): 14 | prompt: str = """You are an AI language model assistant. Your task is to generate {to_expand_to_n} 15 | different versions of the given user question to retrieve relevant documents from a vector 16 | database. By generating multiple perspectives on the user question, your goal is to help 17 | the user overcome some of the limitations of the distance-based similarity search. 18 | Provide these alternative questions separated by '{separator}'. 19 | Original question: {question}""" 20 | 21 | @property 22 | def separator(self) -> str: 23 | return "#next-question#" 24 | 25 | def create_template(self, to_expand_to_n: int) -> PromptTemplate: 26 | return PromptTemplate( 27 | template=self.prompt, 28 | input_variables=["question"], 29 | partial_variables={ 30 | "separator": self.separator, 31 | "to_expand_to_n": to_expand_to_n, 32 | }, 33 | ) 34 | 35 | 36 | class SelfQueryTemplate(BasePromptTemplate): 37 | prompt: str = """You are an AI language model assistant. Your task is to extract information from a user question. 38 | The required information that needs to be extracted is the user or author id. 39 | Your response should consist of only the extracted id (e.g. 1345256), nothing else.
40 | User question: {question}""" 41 | 42 | def create_template(self) -> PromptTemplate: 43 | return PromptTemplate(template=self.prompt, input_variables=["question"]) 44 | 45 | 46 | class RerankingTemplate(BasePromptTemplate): 47 | prompt: str = """You are an AI language model assistant. Your task is to rerank passages related to a query 48 | based on their relevance. 49 | The most relevant passages should be put at the beginning. 50 | You should only pick at most {keep_top_k} passages. 51 | The provided and reranked documents are separated by '{separator}'. 52 | 53 | The following are passages related to this query: {question}. 54 | 55 | Passages: 56 | {passages} 57 | """ 58 | 59 | def create_template(self, keep_top_k: int) -> PromptTemplate: 60 | return PromptTemplate( 61 | template=self.prompt, 62 | input_variables=["question", "passages"], 63 | partial_variables={"keep_top_k": keep_top_k, "separator": self.separator}, 64 | ) 65 | 66 | @property 67 | def separator(self) -> str: 68 | return "\n#next-document#\n" 69 | 70 | 71 | class InferenceTemplate(BasePromptTemplate): 72 | simple_prompt: str = """You are an AI language model assistant. Your task is to generate a cohesive and concise response to the user question. 73 | Question: {question} 74 | """ 75 | 76 | rag_prompt: str = """ You are a specialist in technical content writing. Your task is to create technical content based on a user query given a specific context 77 | with additional information consisting of the user's previous writings and their knowledge. 78 | 79 | Here is a list of steps that you need to follow in order to solve this task: 80 | Step 1: You need to analyze the user-provided query: {question} 81 | Step 2: You need to analyze the provided context and how the information in it relates to the user question: {context} 82 | Step 3: Generate the content keeping in mind that it needs to be as cohesive and concise as possible related to the subject presented in the query and similar to the user's writing style and knowledge presented in the context. 83 | """ 84 | 85 | def create_template(self, enable_rag: bool = True) -> PromptTemplate: 86 | if enable_rag is True: 87 | return PromptTemplate( 88 | template=self.rag_prompt, input_variables=["question", "context"] 89 | ) 90 | 91 | return PromptTemplate(template=self.simple_prompt, input_variables=["question"]) 92 | 93 | 94 | class LLMEvaluationTemplate(BasePromptTemplate): 95 | prompt: str = """ 96 | You are an AI assistant and your task is to evaluate the output generated by another LLM. 97 | You need to follow these steps: 98 | Step 1: Analyze the user query: {query} 99 | Step 2: Analyze the response: {output} 100 | Step 3: Evaluate the generated response based on the following criteria and provide a score from 1 to 5 along with a brief justification for each criterion: 101 | 102 | Evaluation: 103 | Relevance - [score] 104 | [1 sentence justification why relevance = score] 105 | Coherence - [score] 106 | [1 sentence justification why coherence = score] 107 | Conciseness - [score] 108 | [1 sentence justification why conciseness = score] 109 | """ 110 | 111 | def create_template(self) -> PromptTemplate: 112 | return PromptTemplate(template=self.prompt, input_variables=["query", "output"]) 113 | 114 | 115 | class RAGEvaluationTemplate(BasePromptTemplate): 116 | prompt: str = """You are an AI assistant and your task is to evaluate the output generated by another LLM. 117 | The other LLM generates written content based on a user query and a given context.
118 | The given context is composed of custom data produced by a user, consisting of posts, articles or code fragments. 119 | Here is a list of steps you need to follow in order to solve this task: 120 | Step 1: You need to analyze the user query: {query} 121 | Step 2: You need to analyze the given context: {context} 122 | Step 3: You need to analyze the generated output: {output} 123 | Step 4: Generate the evaluation 124 | When doing the evaluation step you need to take the following into consideration: 125 | -The evaluation needs to have some sort of metrics. 126 | -The generated content needs to be evaluated based on its writing-style similarity to the context. 127 | -The generated content needs to be evaluated based on its coherence and conciseness related to the given query and context. 128 | -The generated content needs to be evaluated based on how well it represents the user knowledge extracted from the context.""" 129 | 130 | def create_template(self) -> PromptTemplate: 131 | return PromptTemplate( 132 | template=self.prompt, input_variables=["query", "context", "output"] 133 | ) 134 | -------------------------------------------------------------------------------- /5-inference/main.py: -------------------------------------------------------------------------------- 1 | import core.logger_utils as logger_utils 2 | from inference_pipeline import LLMTwin 3 | 4 | logger = logger_utils.get_logger(__name__) 5 | 6 | 7 | if __name__ == "__main__": 8 | inference_endpoint = LLMTwin() 9 | 10 | query = """ 11 | Hello my author_id is 1. 12 | 13 | Could you please draft a LinkedIn post discussing Vector Databases? 14 | I'm particularly interested in how they work. 15 | """ 16 | 17 | response = inference_endpoint.generate( 18 | query=query, 19 | enable_rag=False, 20 | enable_evaluation=True, 21 | enable_monitoring=True, 22 | ) 23 | 24 | logger.info(f"Answer: {response['answer']}") 25 | logger.info("=" * 50) 26 | logger.info(f"LLM Evaluation Result: {response['llm_evaluation_result']}") 27 | -------------------------------------------------------------------------------- /5-inference/monitoring/__init__.py: -------------------------------------------------------------------------------- 1 | from .prompt_monitoring import PromptMonitoringManager 2 | 3 | __all__ = ["PromptMonitoringManager"] 4 | -------------------------------------------------------------------------------- /5-inference/monitoring/prompt_monitoring.py: -------------------------------------------------------------------------------- 1 | import comet_llm 2 | 3 | from config import settings 4 | 5 | 6 | class PromptMonitoringManager: 7 | @classmethod 8 | def log( 9 | cls, 10 | prompt: str, 11 | output: str, 12 | prompt_template: str | None = None, 13 | prompt_template_variables: dict | None = None, 14 | metadata: dict | None = None, 15 | ) -> None: 16 | comet_llm.init() 17 | 18 | metadata = metadata or {} 19 | metadata = { 20 | "model": settings.MODEL_TYPE, 21 | **metadata, 22 | } 23 | 24 | comet_llm.log_prompt( 25 | workspace=settings.COMET_WORKSPACE, 26 | project=f"{settings.COMET_PROJECT}-monitoring", 27 | api_key=settings.COMET_API_KEY, 28 | prompt=prompt, 29 | prompt_template=prompt_template, 30 | prompt_template_variables=prompt_template_variables, 31 | output=output, 32 | metadata=metadata, 33 | ) 34 | 35 | @classmethod 36 | def log_chain(cls, query: str, response: str, eval_output: str): 37 | comet_llm.init(project=f"{settings.COMET_PROJECT}-monitoring") 38 | comet_llm.start_chain( 39 | inputs={"user_query":
query}, 40 | project=f"{settings.COMET_PROJECT}-monitoring", 41 | api_key=settings.COMET_API_KEY, 42 | workspace=settings.COMET_WORKSPACE, 43 | ) 44 | with comet_llm.Span( 45 | category="twin_response", 46 | inputs={"user_query": query}, 47 | ) as span: 48 | span.set_outputs(outputs=response) 49 | 50 | with comet_llm.Span( 51 | category="gpt3.5-eval", 52 | inputs={"eval_result": eval_output}, 53 | ) as span: 54 | span.set_outputs(outputs=eval_output) 55 | comet_llm.end_chain(outputs={"response": response, "eval_output": eval_output}) 56 | -------------------------------------------------------------------------------- /5-inference/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "inference-pipeline" 3 | description = "" 4 | version = "0.1.0" 5 | authors = [ 6 | "vlad_adu ", 7 | "Paul Iusztin ", 8 | "Alex Vesa ", 9 | ] 10 | package-mode = false 11 | readme = "README.txt" 12 | 13 | 14 | [tool.poetry.dependencies] 15 | python = ">=3.10, <3.12" 16 | pydantic = "^2.6.3" 17 | pydantic-settings = "^2.1.0" 18 | bytewax = "0.18.2" 19 | pika = "^1.3.2" 20 | qdrant-client = "^1.8.0" 21 | unstructured = "^0.12.6" 22 | langchain = "^0.1.13" 23 | sentence-transformers = "^2.6.1" 24 | instructorembedding = "^1.0.1" 25 | numpy = "^1.26.4" 26 | langchain-openai = "^0.1.3" 27 | gdown = "^5.1.0" 28 | pymongo = "^4.7.1" 29 | structlog = "^24.1.0" 30 | rich = "^13.7.1" 31 | pip = "^24.0" 32 | install = "^1.3.5" 33 | comet-ml = "^3.41.0" 34 | ruff = "^0.4.3" 35 | comet-llm = "^2.2.4" 36 | qwak-sdk = "^0.5.69" 37 | pandas = "^2.2.2" 38 | datasets = "^2.19.1" 39 | peft = "^0.11.1" 40 | bitsandbytes = "^0.43.1" 41 | qwak-inference = "^0.1.17" 42 | 43 | 44 | [build-system] 45 | requires = ["poetry-core"] 46 | build-backend = "poetry.core.masonry.api" 47 | 48 | 49 | [tool.ruff] 50 | line-length = 88 51 | select = ["F401", "F403"] 52 | -------------------------------------------------------------------------------- /5-inference/rag/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/J-coder118/LLM-Twin/707f9d8bb1cf402e04644bff9c5c521ce0938087/5-inference/rag/__init__.py -------------------------------------------------------------------------------- /5-inference/rag/query_expanison.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | 3 | from llm.chain import GeneralChain 4 | from llm.prompt_templates import QueryExpansionTemplate 5 | from config import settings 6 | 7 | 8 | class QueryExpansion: 9 | @staticmethod 10 | def generate_response(query: str, to_expand_to_n: int) -> list[str]: 11 | query_expansion_template = QueryExpansionTemplate() 12 | prompt_template = query_expansion_template.create_template(to_expand_to_n) 13 | model = ChatOpenAI( 14 | model=settings.OPENAI_MODEL_ID, 15 | api_key=settings.OPENAI_API_KEY, 16 | temperature=0, 17 | ) 18 | 19 | chain = GeneralChain().get_chain( 20 | llm=model, output_key="expanded_queries", template=prompt_template 21 | ) 22 | 23 | response = chain.invoke({"question": query}) 24 | result = response["expanded_queries"] 25 | 26 | queries = result.strip().split(query_expansion_template.separator) 27 | stripped_queries = [ 28 | stripped_item for item in queries if (stripped_item := item.strip()) 29 | ] 30 | 31 | return stripped_queries 32 | --------------------------------------------------------------------------------
/5-inference/rag/reranking.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | from llm.chain import GeneralChain 3 | from llm.prompt_templates import RerankingTemplate 4 | 5 | from config import settings 6 | 7 | 8 | class Reranker: 9 | @staticmethod 10 | def generate_response( 11 | query: str, passages: list[str], keep_top_k: int 12 | ) -> list[str]: 13 | reranking_template = RerankingTemplate() 14 | prompt_template = reranking_template.create_template(keep_top_k=keep_top_k) 15 | 16 | model = ChatOpenAI(model=settings.OPENAI_MODEL_ID, api_key=settings.OPENAI_API_KEY) 17 | chain = GeneralChain().get_chain( 18 | llm=model, output_key="rerank", template=prompt_template 19 | ) 20 | 21 | stripped_passages = [ 22 | stripped_item for item in passages if (stripped_item := item.strip()) 23 | ] 24 | passages = reranking_template.separator.join(stripped_passages) 25 | response = chain.invoke({"question": query, "passages": passages}) 26 | 27 | result = response["rerank"] 28 | reranked_passages = result.strip().split(reranking_template.separator) 29 | stripped_passages = [ 30 | stripped_item 31 | for item in reranked_passages 32 | if (stripped_item := item.strip()) 33 | ] 34 | 35 | return stripped_passages 36 | -------------------------------------------------------------------------------- /5-inference/rag/retriever.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | 3 | import core.logger_utils as logger_utils 4 | from core.db.qdrant import QdrantDatabaseConnector 5 | from qdrant_client import models 6 | from sentence_transformers.SentenceTransformer import SentenceTransformer 7 | 8 | import utils 9 | from rag.query_expanison import QueryExpansion 10 | from rag.reranking import Reranker 11 | from rag.self_query import SelfQuery 12 | from config import settings 13 | 14 | logger = logger_utils.get_logger(__name__) 15 | 16 | 17 | class VectorRetriever: 18 | """ 19 | Class for retrieving vectors from a Vector store in a RAG system using query expansion and Multitenancy search.
20 | """ 21 | 22 | def __init__(self, query: str) -> None: 23 | self._client = QdrantDatabaseConnector() 24 | self.query = query 25 | self._embedder = SentenceTransformer(settings.EMBEDDING_MODEL_ID) 26 | self._query_expander = QueryExpansion() 27 | self._metadata_extractor = SelfQuery() 28 | self._reranker = Reranker() 29 | 30 | def _search_single_query( 31 | self, generated_query: str, metadata_filter_value: str, k: int 32 | ): 33 | assert k > 3, "k should be greater than 3" 34 | 35 | query_vector = self._embedder.encode(generated_query).tolist() 36 | 37 | vectors = [ 38 | self._client.search( 39 | collection_name="vector_posts", 40 | query_filter=models.Filter( 41 | must=[ 42 | models.FieldCondition( 43 | key="author_id", 44 | match=models.MatchValue( 45 | value=metadata_filter_value, 46 | ), 47 | ) 48 | ] 49 | ), 50 | query_vector=query_vector, 51 | limit=k // 3, 52 | ), 53 | self._client.search( 54 | collection_name="vector_articles", 55 | query_filter=models.Filter( 56 | must=[ 57 | models.FieldCondition( 58 | key="author_id", 59 | match=models.MatchValue( 60 | value=metadata_filter_value, 61 | ), 62 | ) 63 | ] 64 | ), 65 | query_vector=query_vector, 66 | limit=k // 3, 67 | ), 68 | self._client.search( 69 | collection_name="vector_repositories", 70 | query_filter=models.Filter( 71 | must=[ 72 | models.FieldCondition( 73 | key="owner_id", 74 | match=models.MatchValue( 75 | value=metadata_filter_value, 76 | ), 77 | ) 78 | ] 79 | ), 80 | query_vector=query_vector, 81 | limit=k // 3, 82 | ), 83 | ] 84 | 85 | return utils.flatten(vectors) 86 | 87 | def retrieve_top_k(self, k: int, to_expand_to_n_queries: int) -> list: 88 | generated_queries = self._query_expander.generate_response( 89 | self.query, to_expand_to_n=to_expand_to_n_queries 90 | ) 91 | logger.info( 92 | "Successfully generated queries for search.", 93 | num_queries=len(generated_queries), 94 | ) 95 | 96 | author_id = self._metadata_extractor.generate_response(self.query) 97 | logger.info( 98 | "Successfully extracted the author_id from the query.", 99 | author_id=author_id, 100 | ) 101 | 102 | with concurrent.futures.ThreadPoolExecutor() as executor: 103 | search_tasks = [ 104 | executor.submit(self._search_single_query, query, author_id, k) 105 | for query in generated_queries 106 | ] 107 | 108 | hits = [ 109 | task.result() for task in concurrent.futures.as_completed(search_tasks) 110 | ] 111 | hits = utils.flatten(hits) 112 | 113 | logger.info("All documents retrieved successfully.", num_documents=len(hits)) 114 | 115 | return hits 116 | 117 | def rerank(self, hits: list, keep_top_k: int) -> list[str]: 118 | content_list = [hit.payload["content"] for hit in hits] 119 | rerank_hits = self._reranker.generate_response( 120 | query=self.query, passages=content_list, keep_top_k=keep_top_k 121 | ) 122 | 123 | logger.info("Documents reranked successfully.", num_documents=len(rerank_hits)) 124 | 125 | return rerank_hits 126 | 127 | def set_query(self, query: str): 128 | self.query = query 129 | -------------------------------------------------------------------------------- /5-inference/rag/self_query.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | 3 | from llm_components.chain import GeneralChain 4 | from llm_components.prompt_templates import SelfQueryTemplate 5 | from config import settings 6 | 7 | 8 | class SelfQuery: 9 | @staticmethod 10 | def generate_response(query: str) -> str: 11 | prompt = SelfQueryTemplate().create_template() 12 | model 
= ChatOpenAI( 13 | model=settings.OPENAI_MODEL_ID, 14 | api_key=settings.OPENAI_API_KEY, 15 | temperature=0, 16 | ) 17 | 18 | chain = GeneralChain().get_chain( 19 | llm=model, output_key="metadata_filter_value", template=prompt 20 | ) 21 | 22 | response = chain.invoke({"question": query}) 23 | result = response["metadata_filter_value"] 24 | 25 | return result 26 | -------------------------------------------------------------------------------- /5-inference/utils/__init__.py: -------------------------------------------------------------------------------- 1 | def flatten(nested_list: list) -> list: 2 | """Flatten a list of lists into a single list.""" 3 | 4 | return [item for sublist in nested_list for item in sublist] 5 | -------------------------------------------------------------------------------- /5-inference/utils/chunking.py: -------------------------------------------------------------------------------- 1 | from langchain.text_splitter import ( 2 | RecursiveCharacterTextSplitter, 3 | SentenceTransformersTokenTextSplitter, 4 | ) 5 | 6 | from config import settings 7 | 8 | 9 | def chunk_text(text: str) -> list[str]: 10 | character_splitter = RecursiveCharacterTextSplitter( 11 | separators=["\n\n"], chunk_size=500, chunk_overlap=0 12 | ) 13 | text_split = character_splitter.split_text(text) 14 | 15 | token_splitter = SentenceTransformersTokenTextSplitter( 16 | chunk_overlap=50, 17 | tokens_per_chunk=settings.EMBEDDING_MODEL_MAX_INPUT_LENGTH, 18 | model_name=settings.EMBEDDING_MODEL_ID, 19 | ) 20 | chunks = [] 21 | 22 | for section in text_split: 23 | chunks.extend(token_splitter.split_text(section)) 24 | 25 | return chunks 26 | -------------------------------------------------------------------------------- /5-inference/utils/cleaning.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from unstructured.cleaners.core import ( 4 | clean, 5 | clean_non_ascii_chars, 6 | replace_unicode_quotes, 7 | ) 8 | 9 | 10 | def unbold_text(text): 11 | # Mapping of bold numbers to their regular equivalents 12 | bold_numbers = { 13 | "𝟬": "0", 14 | "𝟭": "1", 15 | "𝟮": "2", 16 | "𝟯": "3", 17 | "𝟰": "4", 18 | "𝟱": "5", 19 | "𝟲": "6", 20 | "𝟳": "7", 21 | "𝟴": "8", 22 | "𝟵": "9", 23 | } 24 | 25 | # Function to convert bold characters (letters and numbers) 26 | def convert_bold_char(match): 27 | char = match.group(0) 28 | # Convert bold numbers 29 | if char in bold_numbers: 30 | return bold_numbers[char] 31 | # Convert bold uppercase letters 32 | elif "\U0001d5d4" <= char <= "\U0001d5ed": 33 | return chr(ord(char) - 0x1D5D4 + ord("A")) 34 | # Convert bold lowercase letters 35 | elif "\U0001d5ee" <= char <= "\U0001d607": 36 | return chr(ord(char) - 0x1D5EE + ord("a")) 37 | else: 38 | return char # Return the character unchanged if it's not a bold number or letter 39 | 40 | # Regex for bold characters (numbers, uppercase, and lowercase letters) 41 | bold_pattern = re.compile( 42 | r"[\U0001D5D4-\U0001D5ED\U0001D5EE-\U0001D607\U0001D7CE-\U0001D7FF]" 43 | ) 44 | text = bold_pattern.sub(convert_bold_char, text) 45 | 46 | return text 47 | 48 | 49 | def unitalic_text(text): 50 | # Function to convert italic characters (both letters) 51 | def convert_italic_char(match): 52 | char = match.group(0) 53 | # Unicode ranges for italic characters 54 | if "\U0001d608" <= char <= "\U0001d621": # Italic uppercase A-Z 55 | return chr(ord(char) - 0x1D608 + ord("A")) 56 | elif "\U0001d622" <= char <= "\U0001d63b": # Italic lowercase a-z 57 | return chr(ord(char) - 0x1D622 
+ ord("a")) 58 | else: 59 | return char # Return the character unchanged if it's not an italic letter 60 | 61 | # Regex for italic characters (uppercase and lowercase letters) 62 | italic_pattern = re.compile(r"[\U0001D608-\U0001D621\U0001D622-\U0001D63B]") 63 | text = italic_pattern.sub(convert_italic_char, text) 64 | 65 | return text 66 | 67 | 68 | def remove_emojis_and_symbols(text): 69 | # Extended pattern to include specific symbols like ↓ (U+2193) or ↳ (U+21B3) 70 | emoji_and_symbol_pattern = re.compile( 71 | "[" 72 | "\U0001f600-\U0001f64f" # emoticons 73 | "\U0001f300-\U0001f5ff" # symbols & pictographs 74 | "\U0001f680-\U0001f6ff" # transport & map symbols 75 | "\U0001f1e0-\U0001f1ff" # flags (iOS) 76 | "\U00002193" # downwards arrow 77 | "\U000021b3" # downwards arrow with tip rightwards 78 | "\U00002192" # rightwards arrow 79 | "]+", 80 | flags=re.UNICODE, 81 | ) 82 | 83 | return emoji_and_symbol_pattern.sub(r" ", text) 84 | 85 | 86 | def replace_urls_with_placeholder(text, placeholder="[URL]"): 87 | # Regular expression pattern for matching URLs 88 | url_pattern = r"https?://\S+|www\.\S+" 89 | 90 | return re.sub(url_pattern, placeholder, text) 91 | 92 | 93 | def remove_non_ascii(text: str) -> str: 94 | text = text.encode("ascii", "ignore").decode("ascii") 95 | return text 96 | 97 | 98 | def clean_text(text_content: str) -> str: 99 | cleaned_text = unbold_text(text_content) 100 | cleaned_text = unitalic_text(cleaned_text) 101 | cleaned_text = remove_emojis_and_symbols(cleaned_text) 102 | cleaned_text = clean(cleaned_text) 103 | cleaned_text = replace_unicode_quotes(cleaned_text) 104 | cleaned_text = clean_non_ascii_chars(cleaned_text) 105 | cleaned_text = replace_urls_with_placeholder(cleaned_text) 106 | 107 | return cleaned_text 108 | -------------------------------------------------------------------------------- /5-inference/utils/embeddings.py: -------------------------------------------------------------------------------- 1 | from InstructorEmbedding import INSTRUCTOR 2 | from sentence_transformers.SentenceTransformer import SentenceTransformer 3 | 4 | from config import settings 5 | 6 | 7 | def embedd_text(text: str): 8 | model = SentenceTransformer(settings.EMBEDDING_MODEL_ID) 9 | return model.encode(text) 10 | 11 | 12 | def embedd_repositories(text: str): 13 | model = INSTRUCTOR("hkunlp/instructor-xl") 14 | sentence = text 15 | instruction = "Represent the structure of the repository" 16 | return model.encode([instruction, sentence]) 17 | -------------------------------------------------------------------------------- /GENERATE_INSTRUCT_DATASET.md: -------------------------------------------------------------------------------- 1 | # Generate Data for LLM finetuning task component 2 | 3 | ## Component Structure 4 | 5 | ### File Handling 6 | - `file_handler.py`: Manages file I/O operations, enabling reading and writing of JSON formatted data. 7 | 8 | ### LLM Communication 9 | - `llm_communication.py`: Handles communication with OpenAI's LLMs, sending prompts and processing responses. 10 | 11 | ### Data Generation 12 | - `generate_data.py`: Orchestrates the generation of training data by integrating file handling, LLM communication, and data formatting. 13 | 14 | 15 | ### Usage 16 | 17 | The project includes a `Makefile` for easy management of common tasks. Here are the main commands you can use: 18 | 19 | - `make help`: Displays help for each make command. 20 | - `make local-start`: Build and start mongodb, mq and qdrant. 
21 | - `make local-test-github`: Insert data into MongoDB.
22 | - `make generate-dataset`: Generate the dataset for fine-tuning and version it in Comet ML.
-------------------------------------------------------------------------------- /INSTALL_AND_USAGE.md: --------------------------------------------------------------------------------
1 | # Local Install
2 | 
3 | ## System dependencies
4 | 
5 | Before starting to install the LLM Twin project, make sure you have installed the following dependencies on your system:
6 | 
7 | - [Docker ">=v27.0.3"](https://www.docker.com/)
8 | - [GNU Make ">=3.81"](https://www.gnu.org/software/make/)
9 | 
10 | The whole LLM Twin application will be run locally using Docker.
11 | 
12 | ## Configure
13 | 
14 | All the sensitive credentials are placed in a `.env` file that will always sit on your hardware.
15 | 
16 | Go to the root of the repository, copy our `.env.example` file, and fill it with your credentials:
17 | ```shell
18 | cp .env.example .env
19 | ```
20 | 
21 | ## Supported commands
22 | 
23 | We will use `GNU Make` to install and run our application.
24 | 
25 | To see all our supported commands, run the following:
26 | ```shell
27 | make help
28 | ```
29 | 
30 | ## Set up the infrastructure
31 | 
32 | ### Spin up the infrastructure
33 | 
34 | Now, the whole infrastructure can be spun up using a simple Make command:
35 | 
36 | ```shell
37 | make local-start
38 | ```
39 | 
40 | Behind the scenes, it will build and run all the Docker images defined in the [docker-compose.yml](https://github.com/decodingml/llm-twin-course/blob/main/docker-compose.yml) file.
41 | 
42 | ## Read this before starting 🚨
43 | 
44 | > [!CAUTION]
45 | > For `Mongo` to work with multiple replicas (as we use it in our Docker setup) on `macOS` or `Linux` systems, you have to add the following lines of code to `/etc/hosts`:
46 | >
47 | > ```
48 | > 127.0.0.1 mongo1
49 | > 127.0.0.1 mongo2
50 | > 127.0.0.1 mongo3
51 | > ```
52 | >
53 | > From what we know, on `Windows`, it `works out-of-the-box`. For more details, check out this article: https://medium.com/workleap/the-only-local-mongodb-replica-set-with-docker-compose-guide-youll-ever-need-2f0b74dd8384
54 | 
55 | > [!WARNING]
56 | > For `arm` users (e.g., `M1/M2/M3 macOS devices`), go to your Docker desktop application and enable `Use Rosetta for x86_64/amd64 emulation on Apple Silicon` from the Settings. There is a checkbox you have to check.
57 | > Otherwise, your Docker containers will crash.
58 | 
59 | ### Tear down the infrastructure
60 | 
61 | Run the following `Make` command to tear down all your docker containers:
62 | 
63 | ```shell
64 | make local-stop
65 | ```
66 | 
67 | ## Run an end-to-end flow
68 | 
69 | Now that we have configured our credentials and started our infrastructure, let's look at how to run an end-to-end flow of the LLM Twin application.
70 | 
71 | > [!IMPORTANT]
72 | > Note that we won't go into the details of the system here. To fully understand it, check out our free article series, which explains everything step-by-step: [LLM Twin articles series](https://medium.com/decodingml/llm-twin-course/home).
73 | 
74 | ### Step 1: Crawlers
75 | 
76 | Trigger the crawler to collect data and add it to the MongoDB:
77 | 
78 | ```shell
79 | make local-test-github
80 | # or make local-test-medium
81 | ```
82 | 
83 | After the data is added to Mongo, the CDC component will be triggered, which will populate the RabbitMQ with the event.
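For reference, the CDC component serializes each inserted Mongo document, tags it with its source collection and entry id, and publishes it to RabbitMQ as JSON. A sketch of the resulting event is shown below (all field values are illustrative; see `data-ingestion/cdc.py` for the exact logic):

```python
# Illustrative shape of a CDC event published to RabbitMQ.
event = {
    "type": "posts",  # name of the Mongo collection the insert came from
    "entry_id": "65f1c9...",  # stringified Mongo ObjectId (made-up value)
    "platform": "linkedin",  # the remaining fields are the original document
    "content": {"Post_0": {"text": "..."}},
    "author_id": "8a4c...",
}
```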
84 | 
85 | ### Step 2: Feature engineering & Vector DB
86 | 
87 | Check that the feature pipeline works and the vector DB is successfully populated.
88 | 
89 | To check the `feature pipeline`, inspect the logs of the `llm-twin-bytewax` Docker container by running:
90 | ```shell
91 | docker logs llm-twin-bytewax
92 | ```
93 | You should see logs reflecting the cleaning, chunking, and embedding operations (without any errors, of course).
94 | 
95 | To check that the Qdrant `vector DB` is populated successfully, go to its dashboard at [localhost:6333/dashboard](http://localhost:6333/dashboard). There, you should see the repository and article collections created and populated.
96 | 
97 | > [!NOTE]
98 | > If using the cloud version of Qdrant, go to your Qdrant account and cluster to see the same thing as in the local dashboard.
99 | 
100 | ### Step 3: RAG retrieval step
101 | 
102 | Now that we have some data in our vector DB, let's test out the RAG retriever:
103 | ```shell
104 | make local-test-retriever
105 | ```
106 | 
107 | > [!IMPORTANT]
108 | > Before running this command, check [Qdrant's dashboard](http://localhost:6333/dashboard) to ensure that your vector DB is populated with data.
109 | 
110 | > [!NOTE]
111 | > For more details on the RAG component, please refer to the [RAG](https://github.com/decodingml/llm-twin-course/blob/main/RAG.md) document.
112 | 
113 | 
114 | ### Step 4: Generate the instruct dataset
115 | 
116 | The last step before fine-tuning is to generate an instruct dataset and track it as an artifact in Comet ML. To do so, run:
117 | ```shell
118 | make generate-dataset
119 | ```
120 | 
121 | > [!IMPORTANT]
122 | > Now open [Comet ML](https://www.comet.com/signup/?utm_source=decoding_ml&utm_medium=partner&utm_content=github), go to your workspace, and open the `Artifacts` tab. There, you should find three artifacts as follows:
123 | > - `articles-instruct-dataset`
124 | > - `posts-instruct-dataset`
125 | > - `repositories-instruct-dataset`
126 | 
127 | > [!NOTE]
128 | > For more details on generating the instruct dataset component, please refer to the [GENERATE_INSTRUCT_DATASET](https://github.com/decodingml/llm-twin-course/blob/main/GENERATE_INSTRUCT_DATASET.md) document.
129 | 
130 | 
131 | ### Step 5: Fine-tuning
132 | 
133 | For details on setting up the training pipeline on [Qwak](https://www.qwak.com/lp/end-to-end-mlops/?utm_source=github&utm_medium=referral&utm_campaign=decodingml) and running it, please refer to the [TRAINING](https://github.com/decodingml/llm-twin-course/blob/main/TRAINING.md) document.
134 | 
135 | ### Step 6: Inference
136 | 
137 | After you have fine-tuned your model, the first step is to deploy the inference pipeline to Qwak as a REST API service:
138 | ```shell
139 | make deploy-inference-pipeline
140 | ```
141 | 
142 | > [!NOTE]
143 | > You can check out the progress of the deployment on [Qwak](https://www.qwak.com/lp/end-to-end-mlops/?utm_source=github&utm_medium=referral&utm_campaign=decodingml).
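Under the hood, this Make target wraps the Qwak CLI (see the `deploy-inference-pipeline` target in the root `Makefile`):

```shell
poetry run qwak models deploy realtime --model-id "llm_twin" --instance "gpu.a10.2xl" --timeout 50000 --replicas 2 --server-workers 2
```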
144 | 
145 | After the deployment is finished (it will take a while), you can invoke it by running:
146 | ```shell
147 | make call-inference-pipeline
148 | ```
149 | 
150 | Finally, after you stop using it, make sure to delete the deployment by running:
151 | ```shell
152 | make undeploy-inference-pipeline
153 | ```
154 | 
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Decoding ML
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /Makefile: --------------------------------------------------------------------------------
1 | include .env
2 | 
3 | $(eval export $(shell sed -ne 's/ *#.*$$//; /./ s/=.*$$// p' .env))
4 | 
5 | AWS_CURRENT_REGION_ID := $(shell aws configure get region)
6 | AWS_CURRENT_ACCOUNT_ID := $(shell aws sts get-caller-identity --query "Account" --output text)
7 | 
8 | PYTHONPATH := $(shell pwd)
9 | 
10 | .PHONY: build-all env-var
11 | 
12 | RED := \033[0;31m
13 | BLUE := \033[0;34m
14 | GREEN := \033[0;32m
15 | YELLOW := \033[0;33m
16 | RESET := \033[0m
17 | 
18 | env-var:
19 | @echo "Environment variable VAR is: ${RABBITMQ_HOST}"
20 | 
21 | help:
22 | @grep -E '^[a-zA-Z0-9 -]+:.*#' Makefile | sort | while read -r l; do printf "\033[1;32m$$(echo $$l | cut -f 1 -d':')\033[00m:$$(echo $$l | cut -f 2- -d'#')\n"; done
23 | 
24 | 
25 | # ------ Infrastructure ------
26 | 
27 | push: # Build & push image to AWS ECR (e.g. make push IMAGE_TAG=latest)
28 | echo "Logging into AWS ECR..."
29 | aws ecr get-login-password --region $(AWS_CURRENT_REGION_ID) | docker login --username AWS --password-stdin $(AWS_CURRENT_ACCOUNT_ID).dkr.ecr.$(AWS_CURRENT_REGION_ID).amazonaws.com
30 | echo "Build & Push Docker image..."
31 | docker buildx build --platform linux/amd64 -t $(AWS_CURRENT_ACCOUNT_ID).dkr.ecr.$(AWS_CURRENT_REGION_ID).amazonaws.com/crawler:$(IMAGE_TAG) .
32 | echo "Push completed successfully."
33 | 
34 | local-start: # Build and start local infrastructure.
35 | docker compose -f docker-compose.yml up --build -d
36 | 
37 | local-stop: # Stop local infrastructure.
38 | docker compose -f docker-compose.yml down --remove-orphans
39 | 
40 | 
41 | # ------ Crawler ------
42 | 
43 | local-test-medium: # Send test command on local to test the lambda with a Medium article
44 | curl -X POST "http://localhost:9010/2015-03-31/functions/function/invocations" \
45 | -d '{"user": "Paul Iuztin", "link": "https://medium.com/decodingml/an-end-to-end-framework-for-production-ready-llm-systems-by-building-your-llm-twin-2cc6bb01141f"}'
46 | 
47 | local-test-github: # Send test command on local to test the lambda with a Github repository
48 | curl -X POST "http://localhost:9010/2015-03-31/functions/function/invocations" \
49 | -d '{"user": "Paul Iuztin", "link": "https://github.com/decodingml/llm-twin-course"}'
50 | 
51 | cloud-test-github: # Send command to the cloud lambda with a Github repository
52 | aws lambda invoke \
53 | --function-name crawler \
54 | --cli-binary-format raw-in-base64-out \
55 | --payload '{"user": "Paul Iuztin", "link": "https://github.com/decodingml/llm-twin-course"}' \
56 | response.json
57 | 
58 | # ------ RAG Feature Pipeline ------
59 | 
60 | local-feature-pipeline: # Run the RAG feature pipeline
61 | RUST_BACKTRACE=full poetry run python -m bytewax.run 3-feature-pipeline/main.py
62 | 
63 | generate-dataset: # Generate dataset for finetuning and version it in Comet ML
64 | docker exec -it llm-twin-bytewax python -m finetuning.generate_data
65 | 
66 | # ------ RAG ------
67 | 
68 | local-test-retriever: # Test retriever
69 | docker exec -it llm-twin-bytewax python -m retriever
70 | 
71 | # ------ Qwak: Training pipeline ------
72 | 
73 | create-qwak-project: # Create Qwak project for serving the model
74 | @echo "$(YELLOW)Creating Qwak project $(RESET)"
75 | qwak models create "llm_twin" --project "llm-twin-course"
76 | 
77 | local-test-training-pipeline: # Test Qwak model locally
78 | poetry run python test_local.py
79 | 
80 | deploy-training-pipeline: # Deploy the model to Qwak
81 | @echo "$(YELLOW)Dumping poetry env requirements to $(RESET) $(GREEN) requirements.txt $(RESET)"
82 | poetry export -f requirements.txt --output finetuning/requirements.txt --without-hashes
83 | @echo "$(GREEN)Triggering Qwak Model Build$(RESET)"
84 | poetry run qwak models build -f build_config.yaml .
85 | 
86 | 
87 | # ------ Qwak: Inference pipeline ------
88 | 
89 | deploy-inference-pipeline: # Deploy the inference pipeline to Qwak.
90 | poetry run qwak models deploy realtime --model-id "llm_twin" --instance "gpu.a10.2xl" --timeout 50000 --replicas 2 --server-workers 2
91 | 
92 | undeploy-inference-pipeline: # Remove the inference pipeline deployment from Qwak.
93 | poetry run qwak models undeploy --model-id "llm_twin"
94 | 
95 | call-inference-pipeline: # Call the inference pipeline.
96 | poetry run python main.py
97 | 
98 | # ------ Superlinked Bonus Series ------
99 | 
100 | local-start-superlinked: # Build and start local infrastructure used in the Superlinked series.
101 | docker compose -f docker-compose-superlinked.yml up --build -d
102 | 
103 | local-stop-superlinked: # Stop local infrastructure used in the Superlinked series.
104 | docker compose -f docker-compose-superlinked.yml down --remove-orphans 105 | 106 | test-superlinked-server: 107 | poetry run python 6-bonus-superlinked-rag/local_test.py 108 | 109 | local-bytewax-superlinked: # Run bytewax pipeline powered by superlinked 110 | RUST_BACKTRACE=full poetry run python -m bytewax.run 6-bonus-superlinked-rag/main.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | LLM Twin: Building Production-Ready AI Replica
3 | Build a production-ready LLM & RAG system by building your LLM Twin.
4 | From data gathering to productionizing LLMs using LLMOps good practices.
5 |
6 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings, SettingsConfigDict 2 | 3 | 4 | class Settings(BaseSettings): 5 | model_config = SettingsConfigDict(env_file="../.env", env_file_encoding="utf-8") 6 | 7 | MONGO_DATABASE_HOST: str = ( 8 | "mongodb://mongo1:30001,mongo2:30002,mongo3:30003/?replicaSet=my-replica-set" 9 | ) 10 | MONGO_DATABASE_NAME: str = "scrabble" 11 | 12 | # Optional LinkedIn credentials for scraping your profile 13 | LINKEDIN_USERNAME: str | None = None 14 | LINKEDIN_PASSWORD: str | None = None 15 | 16 | 17 | settings = Settings() 18 | -------------------------------------------------------------------------------- /crawlers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/J-coder118/LLM-Twin/707f9d8bb1cf402e04644bff9c5c521ce0938087/crawlers/__init__.py -------------------------------------------------------------------------------- /crawlers/base.py: -------------------------------------------------------------------------------- 1 | import time 2 | from abc import ABC, abstractmethod 3 | from tempfile import mkdtemp 4 | 5 | from db.documents import BaseDocument 6 | from selenium import webdriver 7 | from selenium.webdriver.chrome.options import Options 8 | 9 | 10 | class BaseCrawler(ABC): 11 | model: type[BaseDocument] 12 | 13 | @abstractmethod 14 | def extract(self, link: str, **kwargs) -> None: ... 15 | 16 | 17 | class BaseAbstractCrawler(BaseCrawler, ABC): 18 | def __init__(self, scroll_limit: int = 5) -> None: 19 | options = webdriver.ChromeOptions() 20 | options.binary_location = "/opt/chrome/chrome" 21 | options.add_argument("--no-sandbox") 22 | options.add_argument("--headless=new") 23 | options.add_argument("--single-process") 24 | options.add_argument("--disable-dev-shm-usage") 25 | options.add_argument("--disable-gpu") 26 | options.add_argument("--log-level=3") 27 | options.add_argument("--disable-popup-blocking") 28 | options.add_argument("--disable-notifications") 29 | options.add_argument("--disable-dev-tools") 30 | options.add_argument("--ignore-certificate-errors") 31 | options.add_argument("--no-zygote") 32 | options.add_argument(f"--user-data-dir={mkdtemp()}") 33 | options.add_argument(f"--data-path={mkdtemp()}") 34 | options.add_argument(f"--disk-cache-dir={mkdtemp()}") 35 | options.add_argument("--remote-debugging-port=9222") 36 | 37 | self.set_extra_driver_options(options) 38 | 39 | self.scroll_limit = scroll_limit 40 | self.driver = webdriver.Chrome( 41 | service=webdriver.ChromeService("/opt/chromedriver"), 42 | options=options, 43 | ) 44 | 45 | def set_extra_driver_options(self, options: Options) -> None: 46 | pass 47 | 48 | def login(self) -> None: 49 | pass 50 | 51 | def scroll_page(self) -> None: 52 | """Scroll through the LinkedIn page based on the scroll limit.""" 53 | current_scroll = 0 54 | last_height = self.driver.execute_script("return document.body.scrollHeight") 55 | while True: 56 | self.driver.execute_script( 57 | "window.scrollTo(0, document.body.scrollHeight);" 58 | ) 59 | time.sleep(5) 60 | new_height = self.driver.execute_script("return document.body.scrollHeight") 61 | if new_height == last_height or ( 62 | self.scroll_limit and current_scroll >= self.scroll_limit 63 | ): 64 | break 65 | last_height = new_height 66 | current_scroll += 1 67 | 
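To make the crawler hierarchy above concrete: a new source plugs in by subclassing one of the base classes and registering itself with the dispatcher (defined later in `dispatcher.py` and wired up in `main.py`). A minimal sketch, where `SubstackCrawler` is a hypothetical example and not part of the repository:

```python
from crawlers.base import BaseCrawler
from db.documents import ArticleDocument
from dispatcher import CrawlerDispatcher


class SubstackCrawler(BaseCrawler):
    """Hypothetical crawler, shown only to illustrate the plug-in pattern."""

    model = ArticleDocument

    def extract(self, link: str, **kwargs) -> None:
        # A real implementation would fetch and parse the page here,
        # then persist it with self.model(...).save().
        ...


dispatcher = CrawlerDispatcher()
dispatcher.register("substack", SubstackCrawler)
crawler = dispatcher.get_crawler("https://substack.com/p/some-post")  # -> SubstackCrawler()
```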
-------------------------------------------------------------------------------- /crawlers/github.py: --------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 | import tempfile
5 | 
6 | from aws_lambda_powertools import Logger
7 | 
8 | from crawlers.base import BaseCrawler
9 | from db.documents import RepositoryDocument
10 | 
11 | logger = Logger(service="llm-twin-course/crawler")
12 | 
13 | 
14 | class GithubCrawler(BaseCrawler):
15 | model = RepositoryDocument
16 | 
17 | def __init__(self, ignore=(".git", ".toml", ".lock", ".png")) -> None:
18 | super().__init__()
19 | self._ignore = ignore
20 | 
21 | def extract(self, link: str, **kwargs) -> None:
22 | logger.info(f"Starting scraping GitHub repository: {link}")
23 | 
24 | repo_name = link.rstrip("/").split("/")[-1]
25 | 
26 | local_temp = tempfile.mkdtemp()
27 | 
28 | try:
29 | os.chdir(local_temp)
30 | subprocess.run(["git", "clone", link])
31 | 
32 | repo_path = os.path.join(local_temp, os.listdir(local_temp)[0])
33 | 
34 | tree = {}
35 | for root, dirs, files in os.walk(repo_path):
36 | dir = root.replace(repo_path, "").lstrip("/")
37 | if dir.startswith(self._ignore):
38 | continue
39 | 
40 | for file in files:
41 | if file.endswith(self._ignore):
42 | continue
43 | file_path = os.path.join(dir, file)
44 | with open(os.path.join(root, file), "r", errors="ignore") as f:
45 | tree[file_path] = f.read().replace(" ", "")
46 | 
47 | instance = self.model(
48 | name=repo_name, link=link, content=tree, owner_id=kwargs.get("user")
49 | )
50 | instance.save()
51 | 
52 | except Exception:
53 | raise
54 | finally:
55 | shutil.rmtree(local_temp)
56 | 
57 | logger.info(f"Finished scraping GitHub repository: {link}")
58 | 
-------------------------------------------------------------------------------- /crawlers/linkedin.py: --------------------------------------------------------------------------------
1 | import time
2 | from typing import Dict, List
3 | 
4 | from aws_lambda_powertools import Logger
5 | from bs4 import BeautifulSoup
6 | from bs4.element import Tag
7 | from errors import ImproperlyConfigured
8 | from selenium.webdriver.common.by import By
9 | 
10 | from db.documents import PostDocument
11 | from crawlers.base import BaseAbstractCrawler
12 | from config import settings
13 | 
14 | logger = Logger(service="decodingml/crawler")
15 | 
16 | 
17 | class LinkedInCrawler(BaseAbstractCrawler):
18 | model = PostDocument
19 | 
20 | def set_extra_driver_options(self, options) -> None:
21 | options.add_experimental_option("detach", True)
22 | 
23 | def extract(self, link: str, **kwargs):
24 | logger.info(f"Starting scraping data for profile: {link}")
25 | 
26 | self.login()
27 | 
28 | soup = self._get_page_content(link)
29 | 
30 | data = {
31 | "Name": self._scrape_section(soup, "h1", class_="text-heading-xlarge"),
32 | "About": self._scrape_section(soup, "div", class_="display-flex ph5 pv3"),
33 | "Main Page": self._scrape_section(soup, "div", {"id": "main-content"}),
34 | "Experience": self._scrape_experience(link),
35 | "Education": self._scrape_education(link),
36 | }
37 | 
38 | self.driver.get(link)
39 | time.sleep(5)
40 | button = self.driver.find_element(
41 | By.CSS_SELECTOR,
42 | ".app-aware-link.profile-creator-shared-content-view__footer-action",
43 | )
44 | button.click()
45 | 
46 | # Scrolling and scraping posts
47 | self.scroll_page()
48 | soup = BeautifulSoup(self.driver.page_source, "html.parser")
49 | post_elements = soup.find_all(
50 | "div",
51 | class_="update-components-text relative update-components-update-v2__commentary",
52 | )
53 | buttons = soup.find_all("button", class_="update-components-image__image-link")
54 | post_images = self._extract_image_urls(buttons)
55 | 
56 | posts = self._extract_posts(post_elements, post_images)
57 | logger.info(f"Found {len(posts)} posts for profile: {link}")
58 | 
59 | self.driver.close()
60 | 
61 | self.model.bulk_insert(
62 | [
63 | PostDocument(
64 | platform="linkedin", content=post, author_id=kwargs.get("user")
65 | )
66 | for post in posts
67 | ]
68 | )
69 | 
70 | logger.info(f"Finished scraping data for profile: {link}")
71 | 
72 | def _scrape_section(self, soup: BeautifulSoup, *args, **kwargs) -> str:
73 | """Scrape a specific section of the LinkedIn profile."""
74 | # Example: Scrape the 'About' section
75 | parent_div = soup.find(*args, **kwargs)
76 | return parent_div.get_text(strip=True) if parent_div else ""
77 | 
78 | def _extract_image_urls(self, buttons: List[Tag]) -> Dict[str, str]:
79 | """
80 | Extracts image URLs from button elements.
81 | 
82 | Args:
83 | buttons (List[Tag]): A list of BeautifulSoup Tag objects representing buttons.
84 | 
85 | Returns:
86 | Dict[str, str]: A dictionary mapping post indexes to image URLs.
87 | """
88 | post_images = {}
89 | for i, button in enumerate(buttons):
90 | img_tag = button.find("img")
91 | if img_tag and "src" in img_tag.attrs:
92 | post_images[f"Post_{i}"] = img_tag["src"]
93 | else:
94 | logger.warning("No image found in this button")
95 | return post_images
96 | 
97 | def _get_page_content(self, url: str) -> BeautifulSoup:
98 | """Retrieve the page content of a given URL."""
99 | self.driver.get(url)
100 | time.sleep(5)
101 | return BeautifulSoup(self.driver.page_source, "html.parser")
102 | 
103 | def _extract_posts(
104 | self, post_elements: List[Tag], post_images: Dict[str, str]
105 | ) -> Dict[str, Dict[str, str]]:
106 | """
107 | Extracts post texts and combines them with their respective images.
108 | 
109 | Args:
110 | post_elements (List[Tag]): A list of BeautifulSoup Tag objects representing post elements.
111 | post_images (Dict[str, str]): A dictionary containing image URLs mapped by post index.
112 | 
113 | Returns:
114 | Dict[str, Dict[str, str]]: A dictionary containing post data with text and optional image URL.
115 | """
116 | posts_data = {}
117 | for i, post_element in enumerate(post_elements):
118 | post_text = post_element.get_text(strip=True, separator="\n")
119 | post_data = {"text": post_text}
120 | if f"Post_{i}" in post_images:
121 | post_data["image"] = post_images[f"Post_{i}"]
122 | posts_data[f"Post_{i}"] = post_data
123 | return posts_data
124 | 
125 | def _scrape_experience(self, profile_url: str) -> str:
126 | """Scrapes the Experience section of the LinkedIn profile."""
127 | self.driver.get(profile_url + "/details/experience/")
128 | time.sleep(5)
129 | soup = BeautifulSoup(self.driver.page_source, "html.parser")
130 | experience_content = soup.find("section", {"id": "experience-section"})
131 | return experience_content.get_text(strip=True) if experience_content else ""
132 | 
133 | def _scrape_education(self, profile_url: str) -> str:
134 | self.driver.get(profile_url + "/details/education/")
135 | time.sleep(5)
136 | soup = BeautifulSoup(self.driver.page_source, "html.parser")
137 | education_content = soup.find("section", {"id": "education-section"})
138 | return education_content.get_text(strip=True) if education_content else ""
139 | 
140 | def login(self):
141 | """Log in to LinkedIn."""
142 | self.driver.get("https://www.linkedin.com/login")
143 | if not settings.LINKEDIN_USERNAME or not settings.LINKEDIN_PASSWORD:
144 | raise ImproperlyConfigured(
145 | "LinkedIn scraper requires a valid account to perform extraction"
146 | )
147 | 
148 | self.driver.find_element(By.ID, "username").send_keys(
149 | settings.LINKEDIN_USERNAME
150 | )
151 | self.driver.find_element(By.ID, "password").send_keys(
152 | settings.LINKEDIN_PASSWORD
153 | )
154 | self.driver.find_element(
155 | By.CSS_SELECTOR, ".login__form_action_container button"
156 | ).click()
157 | 
-------------------------------------------------------------------------------- /crawlers/medium.py: --------------------------------------------------------------------------------
1 | from aws_lambda_powertools import Logger
2 | from bs4 import BeautifulSoup
3 | 
4 | from db.documents import ArticleDocument
5 | from crawlers.base import BaseAbstractCrawler
6 | 
7 | logger = Logger(service="decodingml/crawler")
8 | 
9 | 
10 | class MediumCrawler(BaseAbstractCrawler):
11 | model = ArticleDocument
12 | 
13 | def set_extra_driver_options(self, options) -> None:
14 | options.add_argument(r"--profile-directory=Profile 2")
15 | 
16 | def extract(self, link: str, **kwargs) -> None:
17 | logger.info(f"Starting scraping Medium article: {link}")
18 | 
19 | self.driver.get(link)
20 | self.scroll_page()
21 | 
22 | soup = BeautifulSoup(self.driver.page_source, "html.parser")
23 | # NOTE: these CSS classes are indicative of Medium's current DOM and may change.
24 | title = soup.find_all("h1", class_="pw-post-title")
25 | subtitle = soup.find_all("h2", class_="pw-subtitle-paragraph")
26 | 
27 | data = {
28 | "Title": title[0].string if title else None,
29 | "Subtitle": subtitle[0].string if subtitle else None,
30 | "Content": soup.get_text(),
31 | }
32 | 
33 | self.driver.close()
34 | 
35 | instance = self.model(
36 | platform="medium", content=data, link=link, author_id=kwargs.get("user")
37 | )
38 | instance.save()
39 | 
40 | logger.info(f"Finished scraping Medium article: {link}")
41 | 
-------------------------------------------------------------------------------- /data-ingestion/cdc.py: --------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | 
4 | from bson import json_util
5 | from mq import publish_to_rabbitmq
6 | 
7 | from config import settings
8 | from db import MongoDatabaseConnector
9 | 
10 | # Configure logging
11 | logging.basicConfig(
12 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
13 | )
14 | 
15 | 
16 | def stream_process():
17 | try:
18 | # Setup MongoDB connection
19 | client = MongoDatabaseConnector()
20 | db = client["scrabble"]
21 | logging.info("Connected to MongoDB.")
22 | 
23 | # Watch insert operations across the database's collections
24 | changes = db.watch([{"$match": {"operationType": {"$in": ["insert"]}}}])
25 | for change in changes:
26 | data_type = change["ns"]["coll"]
27 | entry_id = str(change["fullDocument"]["_id"])  # Convert ObjectId to string
28 | change["fullDocument"].pop("_id")
29 | change["fullDocument"]["type"] = data_type
30 | change["fullDocument"]["entry_id"] = entry_id
31 | 
32 | # Use json_util to serialize the document
33 | data = json.dumps(change["fullDocument"], default=json_util.default)
34 | logging.info(f"Change detected and serialized: {data}")
35 | 
36 | # Send data to RabbitMQ
37 | publish_to_rabbitmq(queue_name=settings.RABBITMQ_QUEUE_NAME, data=data)
38 | logging.info("Data published to RabbitMQ.")
39 | 
40 | except Exception as e:
41 | logging.error(f"An error occurred: {e}")
42 | 
43 | 
44 | if __name__ == "__main__":
45 | stream_process()
46 | 
-------------------------------------------------------------------------------- /data-ingestion/db.py:
--------------------------------------------------------------------------------
1 | from pymongo import MongoClient
2 | from pymongo.errors import ConnectionFailure
3 | 
4 | from config import settings
5 | 
6 | 
7 | class MongoDatabaseConnector:
8 | """Singleton class to connect to MongoDB database."""
9 | 
10 | _instance: MongoClient = None
11 | 
12 | def __new__(cls, *args, **kwargs):
13 | if cls._instance is None:
14 | try:
15 | cls._instance = MongoClient(settings.MONGO_DATABASE_HOST)
16 | except ConnectionFailure as e:
17 | print(f"Couldn't connect to the database: {str(e)}")
18 | raise
19 | 
20 | print(
21 | f"Connection to database with uri: {settings.MONGO_DATABASE_HOST} successful"
22 | )
23 | return cls._instance
24 | 
25 | def get_database(self):
26 | return self._instance[settings.MONGO_DATABASE_NAME]
27 | 
28 | def close(self):
29 | if self._instance:
30 | self._instance.close()
31 | print("Connection to the database has been closed.")
32 | 
33 | 
34 | connection = MongoDatabaseConnector()
35 | 
-------------------------------------------------------------------------------- /data-ingestion/mq.py: --------------------------------------------------------------------------------
1 | import pika
2 | 
3 | from config import settings
4 | 
5 | 
6 | class RabbitMQConnection:
7 | """Singleton class to manage RabbitMQ connection."""
8 | 
9 | _instance = None
10 | 
11 | def __new__(
12 | cls,
13 | host: str = None,
14 | port: int = None,
15 | username: str = None,
16 | password: str = None,
17 | virtual_host: str = "/",
18 | ):
19 | if not cls._instance:
20 | cls._instance = super().__new__(cls)
21 | return cls._instance
22 | 
23 | def __init__(
24 | self,
25 | host: str = None,
26 | port: int = None,
27 | username: str = None,
28 | password: str = None,
29 | virtual_host: str = "/",
30 | fail_silently: bool = False,
31 | **kwargs,
32 | ):
33 | self.host = host or settings.RABBITMQ_HOST
34 | self.port = port or settings.RABBITMQ_PORT
35 | self.username = username or settings.RABBITMQ_DEFAULT_USERNAME
36 | self.password = password or settings.RABBITMQ_DEFAULT_PASSWORD
37 | self.virtual_host = virtual_host
38 | self.fail_silently = fail_silently
39 | self._connection = None
40 | 
41 | def __enter__(self):
42 | self.connect()
43 | return self
44 | 
45 | def __exit__(self, exc_type, exc_val, exc_tb):
46 | self.close()
47 | 
48 | def connect(self):
49 | try:
50 | credentials = pika.PlainCredentials(self.username, self.password)
51 | self._connection = pika.BlockingConnection(
52 | pika.ConnectionParameters(
53 | host=self.host,
54 | port=self.port,
55 | virtual_host=self.virtual_host,
56 | credentials=credentials,
57 | )
58 | )
59 | except pika.exceptions.AMQPConnectionError as e:
60 | print("Failed to connect to RabbitMQ:", e)
61 | if not self.fail_silently:
62 | raise e
63 | 
64 | def is_connected(self) -> bool:
65 | return self._connection is not None and self._connection.is_open
66 | 
67 | def get_channel(self):
68 | if self.is_connected():
69 | return self._connection.channel()
70 | 
71 | def close(self):
72 | if self.is_connected():
73 | self._connection.close()
74 | self._connection = None
75 | print("Closed RabbitMQ connection")
76 | 
77 | 
78 | def publish_to_rabbitmq(queue_name: str, data: str):
79 | """Publish data to a RabbitMQ queue."""
80 | try:
81 | # Create an instance of RabbitMQConnection
82 | rabbitmq_conn = RabbitMQConnection()
83 | 
84 | # Establish connection
85 | with rabbitmq_conn:
86 | channel = rabbitmq_conn.get_channel()
87 | 
88 | # Ensure the queue exists
89 |
channel.queue_declare(queue=queue_name, durable=True) 90 | 91 | # Delivery confirmation 92 | channel.confirm_delivery() 93 | 94 | # Send data to the queue 95 | channel.basic_publish( 96 | exchange="", 97 | routing_key=queue_name, 98 | body=data, 99 | properties=pika.BasicProperties( 100 | delivery_mode=2, # make message persistent 101 | ), 102 | ) 103 | print("Sent data to RabbitMQ:", data) 104 | except pika.exceptions.UnroutableError: 105 | print("Message could not be routed") 106 | except Exception as e: 107 | print(f"Error publishing to RabbitMQ: {e}") 108 | 109 | 110 | if __name__ == "__main__": 111 | publish_to_rabbitmq("test_queue", "Hello, World!") 112 | -------------------------------------------------------------------------------- /db/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/J-coder118/LLM-Twin/707f9d8bb1cf402e04644bff9c5c521ce0938087/db/__init__.py -------------------------------------------------------------------------------- /db/documents.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import List, Optional 3 | 4 | from errors import ImproperlyConfigured 5 | from pydantic import UUID4, BaseModel, ConfigDict, Field 6 | from pymongo import errors 7 | from utils import get_logger 8 | 9 | from db.mongo import connection 10 | 11 | _database = connection.get_database("scrabble") 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | class BaseDocument(BaseModel): 17 | id: UUID4 = Field(default_factory=uuid.uuid4) 18 | 19 | model_config = ConfigDict(from_attributes=True, populate_by_name=True) 20 | 21 | @classmethod 22 | def from_mongo(cls, data: dict): 23 | """Convert "_id" (str object) into "id" (UUID object).""" 24 | if not data: 25 | return data 26 | 27 | id = data.pop("_id", None) 28 | return cls(**dict(data, id=id)) 29 | 30 | def to_mongo(self, **kwargs) -> dict: 31 | """Convert "id" (UUID object) into "_id" (str object).""" 32 | exclude_unset = kwargs.pop("exclude_unset", False) 33 | by_alias = kwargs.pop("by_alias", True) 34 | 35 | parsed = self.model_dump( 36 | exclude_unset=exclude_unset, by_alias=by_alias, **kwargs 37 | ) 38 | 39 | if "_id" not in parsed and "id" in parsed: 40 | parsed["_id"] = str(parsed.pop("id")) 41 | 42 | return parsed 43 | 44 | def save(self, **kwargs): 45 | collection = _database[self._get_collection_name()] 46 | 47 | try: 48 | result = collection.insert_one(self.to_mongo(**kwargs)) 49 | return result.inserted_id 50 | except errors.WriteError: 51 | logger.exception("Failed to insert document.") 52 | 53 | return None 54 | 55 | @classmethod 56 | def get_or_create(cls, **filter_options) -> Optional[str]: 57 | collection = _database[cls._get_collection_name()] 58 | try: 59 | instance = collection.find_one(filter_options) 60 | if instance: 61 | return str(cls.from_mongo(instance).id) 62 | new_instance = cls(**filter_options) 63 | new_instance = new_instance.save() 64 | return new_instance 65 | except errors.OperationFailure: 66 | logger.exception("Failed to retrieve or create document.") 67 | 68 | return None 69 | 70 | @classmethod 71 | def bulk_insert(cls, documents: List, **kwargs) -> Optional[List[str]]: 72 | collection = _database[cls._get_collection_name()] 73 | try: 74 | result = collection.insert_many( 75 | [doc.to_mongo(**kwargs) for doc in documents] 76 | ) 77 | return result.inserted_ids 78 | except errors.WriteError: 79 | logger.exception("Failed to insert documents.") 80 | 81 | return None 
82 | 
83 | @classmethod
84 | def _get_collection_name(cls):
85 | if not hasattr(cls, "Settings") or not hasattr(cls.Settings, "name"):
86 | raise ImproperlyConfigured(
87 | "Document should define a Settings configuration class with the name of the collection."
88 | )
89 | 
90 | return cls.Settings.name
91 | 
92 | 
93 | class UserDocument(BaseDocument):
94 | first_name: str
95 | last_name: str
96 | 
97 | class Settings:
98 | name = "users"
99 | 
100 | 
101 | class RepositoryDocument(BaseDocument):
102 | name: str
103 | link: str
104 | content: dict
105 | owner_id: str = Field(alias="owner_id")
106 | 
107 | class Settings:
108 | name = "repositories"
109 | 
110 | 
111 | class PostDocument(BaseDocument):
112 | platform: str
113 | content: dict
114 | author_id: str = Field(alias="author_id")
115 | 
116 | class Settings:
117 | name = "posts"
118 | 
119 | 
120 | class ArticleDocument(BaseDocument):
121 | platform: str
122 | link: str
123 | content: dict
124 | author_id: str = Field(alias="author_id")
125 | 
126 | class Settings:
127 | name = "articles"
128 | 
-------------------------------------------------------------------------------- /db/mongo.py: --------------------------------------------------------------------------------
1 | from pymongo import MongoClient
2 | from pymongo.errors import ConnectionFailure
3 | 
4 | from config import settings
5 | 
6 | 
7 | class MongoDatabaseConnector:
8 | """Singleton class to connect to MongoDB database."""
9 | 
10 | _instance: MongoClient = None
11 | 
12 | def __new__(cls, *args, **kwargs):
13 | if cls._instance is None:
14 | try:
15 | cls._instance = MongoClient(settings.MONGO_DATABASE_HOST)
16 | except ConnectionFailure as e:
17 | print(f"Couldn't connect to the database: {str(e)}")
18 | raise
19 | 
20 | print(
21 | f"Connection to database with uri: {settings.MONGO_DATABASE_HOST} successful"
22 | )
23 | return cls._instance
24 | 
25 | def get_database(self):
26 | return self._instance[settings.MONGO_DATABASE_NAME]
27 | 
28 | def close(self):
29 | if self._instance:
30 | self._instance.close()
31 | print("Connection to the database has been closed.")
32 | 
33 | 
34 | connection = MongoDatabaseConnector()
35 | 
-------------------------------------------------------------------------------- /dispatcher.py: --------------------------------------------------------------------------------
1 | import re
2 | 
3 | from crawlers.base import BaseCrawler
4 | 
5 | 
6 | class CrawlerDispatcher:
7 | def __init__(self) -> None:
8 | self._crawlers = {}
9 | 
10 | def register(self, domain: str, crawler: type[BaseCrawler]) -> None:
11 | self._crawlers[r"https://(www\.)?{}.com/*".format(re.escape(domain))] = crawler
12 | 
13 | def get_crawler(self, url: str) -> BaseCrawler:
14 | for pattern, crawler in self._crawlers.items():
15 | if re.match(pattern, url):
16 | return crawler()
17 | else:
18 | raise ValueError("No crawler found for the provided link")
19 | 
-------------------------------------------------------------------------------- /docker-bake.hcl: --------------------------------------------------------------------------------
1 | group "default" {
2 | targets = ["bytewax", "cdc"]
3 | }
4 | 
5 | target "bytewax" {
6 | context = "."
7 | dockerfile = ".docker/Dockerfile.bytewax"
8 | }
9 | 
10 | target "cdc" {
11 | context = "."
12 | dockerfile = ".docker/Dockerfile.cdc" 13 | } 14 | -------------------------------------------------------------------------------- /docker-compose-superlinked.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | mongo1: 5 | image: mongo:5 6 | container_name: llm-twin-mongo1 7 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30001"] 8 | volumes: 9 | - mongo-replica-1-data:/data/db 10 | ports: 11 | - "30001:30001" 12 | healthcheck: 13 | test: test $$(echo "rs.initiate({_id:'my-replica-set',members:[{_id:0,host:\"mongo1:30001\"},{_id:1,host:\"mongo2:30002\"},{_id:2,host:\"mongo3:30003\"}]}).ok || rs.status().ok" | mongo --port 30001 --quiet) -eq 1 14 | interval: 10s 15 | start_period: 30s 16 | restart: always 17 | networks: 18 | - server_default 19 | 20 | mongo2: 21 | image: mongo:5 22 | container_name: llm-twin-mongo2 23 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30002"] 24 | volumes: 25 | - mongo-replica-2-data:/data/db 26 | ports: 27 | - "30002:30002" 28 | restart: always 29 | networks: 30 | - server_default 31 | 32 | mongo3: 33 | image: mongo:5 34 | container_name: llm-twin-mongo3 35 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30003"] 36 | volumes: 37 | - mongo-replica-3-data:/data/db 38 | ports: 39 | - "30003:30003" 40 | restart: always 41 | networks: 42 | - server_default 43 | 44 | mq: 45 | image: rabbitmq:3-management-alpine 46 | container_name: llm-twin-mq 47 | ports: 48 | - "5672:5672" 49 | - "15672:15672" 50 | volumes: 51 | - ~/rabbitmq/data/:/var/lib/rabbitmq/ 52 | - ~/rabbitmq/log/:/var/log/rabbitmq 53 | healthcheck: 54 | test: ["CMD", "rabbitmqctl", "ping"] 55 | interval: 30s 56 | timeout: 10s 57 | retries: 5 58 | restart: always 59 | networks: 60 | - server_default 61 | 62 | crawler: 63 | image: "llm-twin-crawler" 64 | container_name: llm-twin-crawler 65 | platform: "linux/amd64" 66 | build: 67 | context: . 68 | dockerfile: .docker/Dockerfile.crawlers 69 | env_file: 70 | - .env 71 | ports: 72 | - "9010:8080" 73 | depends_on: 74 | - mongo1 75 | - mongo2 76 | - mongo3 77 | networks: 78 | - server_default 79 | 80 | cdc: 81 | image: "llm-twin-cdc" 82 | container_name: llm-twin-cdc 83 | build: 84 | context: . 85 | dockerfile: .docker/Dockerfile.cdc 86 | env_file: 87 | - .env 88 | depends_on: 89 | - mongo1 90 | - mongo2 91 | - mongo3 92 | - mq 93 | networks: 94 | - server_default 95 | 96 | bytewax: 97 | image: "llm-twin-bytewax-superlinked" 98 | container_name: llm-twin-bytewax-superlinked 99 | build: 100 | context: . 
101 | dockerfile: .docker/Dockerfile.bytewax.superlinked 102 | environment: 103 | BYTEWAX_PYTHON_FILE_PATH: "main:flow" 104 | DEBUG: "false" 105 | BYTEWAX_KEEP_CONTAINER_ALIVE: "false" 106 | env_file: 107 | - .env 108 | depends_on: 109 | - mq 110 | restart: on-failure 111 | networks: 112 | - server_default 113 | 114 | volumes: 115 | mongo-replica-1-data: 116 | mongo-replica-2-data: 117 | mongo-replica-3-data: 118 | 119 | networks: 120 | server_default: 121 | external: true 122 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | 3 | services: 4 | mongo1: 5 | image: mongo:5 6 | container_name: llm-twin-mongo1 7 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30001"] 8 | volumes: 9 | - mongo-replica-1-data:/data/db 10 | ports: 11 | - "30001:30001" 12 | healthcheck: 13 | test: test $$(echo "rs.initiate({_id:'my-replica-set',members:[{_id:0,host:\"mongo1:30001\"},{_id:1,host:\"mongo2:30002\"},{_id:2,host:\"mongo3:30003\"}]}).ok || rs.status().ok" | mongo --port 30001 --quiet) -eq 1 14 | interval: 10s 15 | start_period: 30s 16 | restart: always 17 | 18 | mongo2: 19 | image: mongo:5 20 | container_name: llm-twin-mongo2 21 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30002"] 22 | volumes: 23 | - mongo-replica-2-data:/data/db 24 | ports: 25 | - "30002:30002" 26 | restart: always 27 | 28 | mongo3: 29 | image: mongo:5 30 | container_name: llm-twin-mongo3 31 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30003"] 32 | volumes: 33 | - mongo-replica-3-data:/data/db 34 | ports: 35 | - "30003:30003" 36 | restart: always 37 | 38 | mq: 39 | image: rabbitmq:3-management-alpine 40 | container_name: llm-twin-mq 41 | ports: 42 | - "5673:5672" 43 | - "15673:15672" 44 | volumes: 45 | - ~/rabbitmq/data/:/var/lib/rabbitmq/ 46 | - ~/rabbitmq/log/:/var/log/rabbitmq 47 | restart: always 48 | 49 | qdrant: 50 | image: qdrant/qdrant:latest 51 | container_name: llm-twin-qdrant 52 | ports: 53 | - "6333:6333" 54 | - "6334:6334" 55 | expose: 56 | - "6333" 57 | - "6334" 58 | - "6335" 59 | volumes: 60 | - qdrant-data:/qdrant_data 61 | restart: always 62 | 63 | crawler: 64 | image: "llm-twin-crawler" 65 | container_name: llm-twin-crawler 66 | platform: "linux/amd64" 67 | build: 68 | context: . 69 | dockerfile: .docker/Dockerfile.crawlers 70 | env_file: 71 | - .env 72 | ports: 73 | - "9010:8080" 74 | depends_on: 75 | - mongo1 76 | - mongo2 77 | - mongo3 78 | 79 | cdc: 80 | image: "llm-twin-cdc" 81 | container_name: llm-twin-cdc 82 | build: 83 | context: . 84 | dockerfile: .docker/Dockerfile.cdc 85 | env_file: 86 | - .env 87 | depends_on: 88 | - mongo1 89 | - mongo2 90 | - mongo3 91 | - mq 92 | 93 | bytewax: 94 | image: "llm-twin-bytewax" 95 | container_name: llm-twin-bytewax 96 | build: 97 | context: . 
98 |       dockerfile: .docker/Dockerfile.bytewax
99 |     environment:
100 |       BYTEWAX_PYTHON_FILE_PATH: "main:flow"
101 |       DEBUG: "false"
102 |       BYTEWAX_KEEP_CONTAINER_ALIVE: "true"
103 |     env_file:
104 |       - .env
105 |     depends_on:
106 |       - mq
107 |       - qdrant
108 |     restart: on-failure
109 |
110 | volumes:
111 |   mongo-replica-1-data:
112 |   mongo-replica-2-data:
113 |   mongo-replica-3-data:
114 |   qdrant-data:
115 |
--------------------------------------------------------------------------------
/errors.py:
--------------------------------------------------------------------------------
1 | class ScrabbleException(Exception):
2 |     pass
3 |
4 |
5 | class ImproperlyConfigured(ScrabbleException):
6 |     pass
--------------------------------------------------------------------------------
/lib.py:
--------------------------------------------------------------------------------
1 | from errors import ImproperlyConfigured
2 |
3 |
4 | def user_to_names(user: str | None) -> tuple[str, str]:
5 |     if user is None:
6 |         raise ImproperlyConfigured("User name is empty")
7 |
8 |     name_tokens = user.split()  # split on any whitespace; "".split() -> [], so the empty check below can trigger
9 |     if len(name_tokens) == 0:
10 |         raise ImproperlyConfigured("User name is empty")
11 |     elif len(name_tokens) == 1:
12 |         first_name, last_name = name_tokens[0], name_tokens[0]
13 |     else:
14 |         first_name, last_name = " ".join(name_tokens[:-1]), name_tokens[-1]
15 |
16 |     return first_name, last_name
17 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from aws_lambda_powertools import Logger
4 | from aws_lambda_powertools.utilities.typing import LambdaContext
5 |
6 | import lib
7 | from crawlers import GithubCrawler, LinkedInCrawler, MediumCrawler
8 | from db.documents import UserDocument
9 | from dispatcher import CrawlerDispatcher
10 |
11 | logger = Logger(service="decodingml/crawler")
12 |
13 | _dispatcher = CrawlerDispatcher()
14 | _dispatcher.register("medium", MediumCrawler)
15 | _dispatcher.register("linkedin", LinkedInCrawler)
16 | _dispatcher.register("github", GithubCrawler)
17 |
18 |
19 | def handler(event, context: LambdaContext) -> dict[str, Any]:
20 |     first_name, last_name = lib.user_to_names(event.get("user"))
21 |
22 |     user = UserDocument.get_or_create(first_name=first_name, last_name=last_name)
23 |
24 |     link = event.get("link")
25 |     crawler = _dispatcher.get_crawler(link)
26 |
27 |     try:
28 |         crawler.extract(link=link, user=user)
29 |
30 |         return {"statusCode": 200, "body": "Link processed successfully"}
31 |     except Exception as e:
32 |         return {"statusCode": 500, "body": f"An error occurred: {str(e)}"}
33 |
34 |
35 | if __name__ == "__main__":
36 |     event = {
37 |         "user": "Paul Iusztin",
38 |         "link": "https://www.linkedin.com/in/vesaalexandru/",
39 |     }
40 |     handler(event, None)
41 |
--------------------------------------------------------------------------------
/ops/.gitignore:
--------------------------------------------------------------------------------
1 | /bin/
2 | /node_modules/
--------------------------------------------------------------------------------
/ops/Pulumi.yaml:
--------------------------------------------------------------------------------
1 | name: decodingml
2 | runtime: nodejs
3 | description: AWS Cloud Infrastructure for the LLM Twin Course
4 | config:
5 |   pulumi:tags:
6 |     value:
7 |       pulumi:template: ""
--------------------------------------------------------------------------------
/ops/components/cdc.ts:
--------------------------------------------------------------------------------
1 | //TBD
--------------------------------------------------------------------------------
/ops/components/config.ts:
--------------------------------------------------------------------------------
1 | export const SubnetCidrBlocks = {
2 |     Internet: '0.0.0.0/0',
3 |     VPC: '10.0.0.0/16',
4 |     PublicOne: '10.0.0.0/20',
5 |     PublicTwo: '10.0.16.0/20',
6 | } as const;
7 |
--------------------------------------------------------------------------------
/ops/components/crawler.ts:
--------------------------------------------------------------------------------
1 | import * as pulumi from "@pulumi/pulumi";
2 | import * as aws from "@pulumi/aws";
3 |
4 | export interface CrawlerProps {
5 |     vpcId: pulumi.Input<string>
6 |     timeout: pulumi.Input<number>
7 |     memory: pulumi.Input<number>
8 | }
9 |
10 | export class Crawler extends pulumi.ComponentResource {
11 |     public readonly arn: pulumi.Output<string>
12 |
13 |     constructor (
14 |         name: string,
15 |         props: CrawlerProps,
16 |         opts?: pulumi.ComponentResourceOptions,
17 |     ) {
18 |         super("decodingml:main:Crawler", name, {}, opts);
19 |
20 |         const accountId = pulumi.output(aws.getCallerIdentity()).accountId;
21 |         const region = pulumi.output(aws.getRegion()).name;
22 |
23 |         const lambdaExecutionRole = new aws.iam.Role(`${name}-role`, {
24 |             assumeRolePolicy: JSON.stringify({
25 |                 Version: "2012-10-17",
26 |                 Statement: [{
27 |                     Effect: "Allow",
28 |                     Principal: {
29 |                         Service: "lambda.amazonaws.com",
30 |                     },
31 |                     Action: "sts:AssumeRole",
32 |                 }],
33 |             }),
34 |             managedPolicyArns: [
35 |                 aws.iam.ManagedPolicy.AmazonS3FullAccess,
36 |                 aws.iam.ManagedPolicy.AmazonDocDBFullAccess,
37 |                 aws.iam.ManagedPolicy.AWSLambdaBasicExecutionRole,
38 |                 aws.iam.ManagedPolicy.AWSLambdaVPCAccessExecutionRole,
39 |                 aws.iam.ManagedPolicy.CloudWatchLambdaInsightsExecutionRolePolicy
40 |             ]
41 |         })
42 |
43 |         const sg = new aws.ec2.SecurityGroup(`${name}-security-group`, {
44 |             name: `${name}-sg`,
45 |             description: "Crawler Lambda Access",
46 |             vpcId: props.vpcId,
47 |             egress: [{
48 |                 protocol: "-1",
49 |                 description: "Allow all outbound traffic by default",
50 |                 fromPort: 0,
51 |                 toPort: 0,
52 |                 cidrBlocks: ["0.0.0.0/0"],
53 |             }],
54 |             tags: {
55 |                 Name: `${name}-sg`
56 |             }
57 |         })
58 |
59 |         const lambdaFunction = new aws.lambda.Function(`${name}-lambda-function`, {
60 |             name: `${name}`,
61 |             imageUri: pulumi.interpolate`${accountId}.dkr.ecr.${region}.amazonaws.com/crawler:latest`,
62 |             packageType: 'Image',
63 |             description: 'Crawler Lambda Function',
64 |             timeout: props.timeout,
65 |             memorySize: props.memory,
66 |             role: lambdaExecutionRole.arn,
67 |             vpcConfig: {
68 |                 subnetIds: pulumi.output(aws.ec2.getSubnets({tags: {Type: 'public'}})).ids,
69 |                 securityGroupIds: [sg.id],
70 |             }
71 |         }, {dependsOn: lambdaExecutionRole})
72 |
73 |         this.arn = lambdaFunction.arn
74 |
75 |         this.registerOutputs({
76 |             arn: this.arn
77 |         })
78 |     }
79 | }
80 |
--------------------------------------------------------------------------------
/ops/components/docdb.ts:
--------------------------------------------------------------------------------
1 | import * as pulumi from "@pulumi/pulumi";
2 | import * as aws from "@pulumi/aws";
3 |
4 |
5 | export interface DocumentDBClusterProps {
6 |     vpcId: pulumi.Input<string>
7 |     instanceClass?: pulumi.Input<string>
8 |     multiAZ?: pulumi.Input<boolean>
9 |     port?: pulumi.Input<number>
10 |
11 |     backupRetentionPeriod?: pulumi.Input<number>
12 | }
13 |
14 | export class DocumentDBCluster extends pulumi.ComponentResource {
15 |
16 |     constructor (
17 |         name: string,
18 |         props: DocumentDBClusterProps,
19 |         opts?: pulumi.ComponentResourceOptions,
20 |     ) {
21 |         super("decodingml:main:DocumentDBCluster", name, {}, opts);
22 |
23 |
24 |         const subnetGroup = new aws.docdb.SubnetGroup(`${name}-docdb-subnet-group`, {
25 |             name: `${name}-cluster-subnet-group`,
26 |             description: `VPC subnet group for the ${name}-cluster`,
27 |             subnetIds: pulumi.output(aws.ec2.getSubnets({tags: {Type: 'public'}})).ids,
28 |             tags: {
29 |                 Name: `${name}-cluster-subnet-group`
30 |             }
31 |         }, {parent: this})
32 |
33 |         const securityGroup = new aws.ec2.SecurityGroup(`${name}-docdb-sg`, {
34 |             name: `${name}-docdb-cluster-sg`,
35 |             description: "Database access",
36 |             vpcId: props.vpcId,
37 |             tags: {
38 |                 Name: `${name}-docdb-cluster-sg`
39 |             },
40 |             ingress: [
41 |                 {
42 |                     description: "Ingress from anywhere",
43 |                     fromPort: props.port || 27017,
44 |                     toPort: props.port || 27017,
45 |                     protocol: "tcp", cidrBlocks: ["0.0.0.0/0"],  // a source is required for the rule to apply
46 |                 },
47 |             ],
48 |             egress: [{
49 |                 protocol: "-1",
50 |                 description: "Allow all outbound traffic by default",
51 |                 fromPort: 0,
52 |                 toPort: 0,
53 |                 cidrBlocks: ["0.0.0.0/0"],
54 |             }],
55 |         }, {parent: this})
56 |
57 |         const cluster = new aws.docdb.Cluster(`${name}-docdb-cluster`, {
58 |             // availabilityZones: pulumi.output(aws.getAvailabilityZones({state: "available"}) if props.multiAZ else
59 |             backupRetentionPeriod: props.backupRetentionPeriod || 7,
60 |             clusterIdentifier: `${name}-cluster`,
61 |             masterUsername: pulumi.output(aws.ssm.getParameter({ name: `/${name}/cluster/master/username` })).value,
62 |             masterPassword: pulumi.output(aws.ssm.getParameter({ name: `/${name}/cluster/master/password` })).value,
63 |             engineVersion: "5.0.0",
64 |             port: props.port || 27017,
65 |             dbSubnetGroupName: subnetGroup.name,
66 |             storageEncrypted: true,
67 |             skipFinalSnapshot: true,
68 |             vpcSecurityGroupIds: [ securityGroup.id ],
69 |             tags: {
70 |                 Name: `${name}-cluster`
71 |             }
72 |         }, {parent: this})
73 |
74 |         const primaryInstance = new aws.docdb.ClusterInstance(`${name}-docdb-primary-instance`, {
75 |             clusterIdentifier: cluster.clusterIdentifier,
76 |             identifier: `${name}-primary-instance`,
77 |             instanceClass: props.instanceClass || "db.t3.medium",
78 |             tags: {
79 |                 Name: `${name}-primary-instance`
80 |             }
81 |         }, {parent: this})
82 |     }
83 | }
84 |
--------------------------------------------------------------------------------
/ops/components/ecs/cluster.ts:
--------------------------------------------------------------------------------
1 | import * as pulumi from "@pulumi/pulumi";
2 | import * as aws from "@pulumi/aws";
3 |
4 |
5 | export interface ECSClusterProps {
6 |     vpcId: pulumi.Input<string>
7 | }
8 |
9 | export class ECSCluster extends pulumi.ComponentResource {
10 |     name: pulumi.Output<string>
11 |
12 |     constructor (
13 |         name: string,
14 |         props: ECSClusterProps,
15 |         opts?: pulumi.ComponentResourceOptions,
16 |     ) {
17 |         super("decodingml:main:ECSCluster", name, {}, opts);
18 |
19 |         const cluster = new aws.ecs.Cluster(`${name}-cluster`, {
20 |             name: `${name}-cluster`,
21 |         }, {parent: this})
22 |
23 |         this.name = cluster.name
24 |
25 |         const securityGroup = new aws.ec2.SecurityGroup(`${name}-sg`, {
26 |             name: `${name}-ecs-host-sg`,
27 |             description: 'Access to the ECS hosts that run containers',
28 |             vpcId: props.vpcId,
29 |             ingress: [
30 |                 {
31 |                     description: "Ingress from other containers in the same security group",
32 |                     fromPort: 0,
33 |                     toPort: 0,
34 |                     protocol: "-1",
35 |                     self: true,
36 |                 }
37 |             ],
38 |             egress: [
39 |                 {
40 |                     cidrBlocks: ['0.0.0.0/0'],
41 |                     description: "Allow all outbound traffic by default",
42 |                     protocol: "-1",
43 |                     fromPort: 0,
44 |                     toPort: 0,
45 |                 },
46 |             ],
47 |             tags: {
48 |                 Name: `${name}-ecs-host-sg`
49 |             }
50 |         }, {parent: this})
51 |
52 |         new aws.servicediscovery.PrivateDnsNamespace(`${name}-private-dns-namespace`, {
53 |             name: `${name}.internal`,
54 |             vpc: props.vpcId,
55 |         }, {parent: this})
56 |
57 |         this.registerOutputs({
58 |             name: this.name
59 |         })
60 |
61 |     }
62 | }
--------------------------------------------------------------------------------
/ops/components/ecs/iam.ts:
--------------------------------------------------------------------------------
1 | import * as aws from "@pulumi/aws";
2 |
3 |
4 | export const ecsRole = new aws.iam.Role("ecs-role", {
5 |     name: `ecs-role`,
6 |     assumeRolePolicy: aws.iam.assumeRolePolicyForPrincipal({ Service: "ecs.amazonaws.com" }),
7 |     path: "/",
8 |     inlinePolicies: [{
9 |         name: "ecs-service",
10 |         policy: JSON.stringify({
11 |             Statement: [{
12 |                 Action: [
13 |                     'ec2:AttachNetworkInterface',
14 |                     'ec2:CreateNetworkInterface',
15 |                     'ec2:CreateNetworkInterfacePermission',
16 |                     'ec2:DeleteNetworkInterface',
17 |                     'ec2:DeleteNetworkInterfacePermission',
18 |                     'ec2:Describe*',
19 |                     'ec2:DetachNetworkInterface',
20 |                     'elasticloadbalancing:DeregisterInstancesFromLoadBalancer',
21 |                     'elasticloadbalancing:DeregisterTargets',
22 |                     'elasticloadbalancing:Describe*',
23 |                     'elasticloadbalancing:RegisterInstancesWithLoadBalancer',
24 |                     'elasticloadbalancing:RegisterTargets'
25 |                 ],
26 |                 Effect: 'Allow',
27 |                 Resource: '*'
28 |             }],
29 |             Version: '2012-10-17',
30 |         } as aws.iam.PolicyDocument)
31 |     }]
32 | })
33 |
34 |
35 | export const ecsTaskExecutionRole = new aws.iam.Role("ecs-task-execution-role", {
36 |     name: `ecs-task-execution-role`,
37 |     assumeRolePolicy: aws.iam.assumeRolePolicyForPrincipal({ Service: "ecs-tasks.amazonaws.com" }),
38 |     path: "/",
39 |     inlinePolicies: [
40 |         {
41 |             name: "ecs-logs",
42 |             policy: JSON.stringify({
43 |                 Statement: [{
44 |                     Action: [
45 |                         'logs:CreateLogGroup'
46 |                     ],
47 |                     Effect: 'Allow',
48 |                     Resource: '*'
49 |                 }]
50 |             } as aws.iam.PolicyDocument),
51 |         },
52 |         {
53 |             name: "ecs-ssm",
54 |             policy: JSON.stringify({
55 |                 Statement: [{
56 |                     Sid: "readEnvironmentParameters",
57 |                     Action: [
58 |                         'ssm:GetParameters'
59 |                     ],
60 |                     Effect: 'Allow',
61 |                     Resource: "*"
62 |                 }]
63 |             } as aws.iam.PolicyDocument),
64 |         }
65 |     ],
66 |     managedPolicyArns: [
67 |         'arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy'
68 |     ]
69 | })
--------------------------------------------------------------------------------
/ops/components/ecs/service.ts:
--------------------------------------------------------------------------------
1 | import * as pulumi from "@pulumi/pulumi";
2 | import * as aws from "@pulumi/aws";
3 |
4 | export interface ContainerSecrets {
5 |     name: pulumi.Input<string>;
6 |     parameter: pulumi.Input<string>;
7 | }
8 |
9 | export interface ServiceProps {
10 |     vpcId: pulumi.Input<string>
11 |
12 |     cluster: pulumi.Input<string>;
13 |     environment?: pulumi.Input<{ name: string; value: string }[]>;
14 |     secrets: ContainerSecrets[];
15 |
16 |     command?: pulumi.Input<string[]>;
17 |     imageTag?: pulumi.Input<string>;
18 |     containerPort: pulumi.Input<number>;
19 |     containerCpu?: pulumi.Input<string>;
20 |     containerMemory?: pulumi.Input<string>;
21 |
22 |     deploymentController?: pulumi.Input<string>;
23 |
24 |     desiredCount?: pulumi.Input<number>;
25 |     role?: pulumi.Input<string>;
26 | }
27 |
28 | export class Service extends pulumi.ComponentResource {
29 |     constructor (
30 |         name: string,
31 |         props: ServiceProps,
32 |         opts?: pulumi.ComponentResourceOptions,
33 |     ) {
34 |
35 |         super("decodingml:main:Service", name, {}, opts);
36 |
37 |         const accountId = pulumi.output(aws.getCallerIdentity()).accountId;
38 |         const region = pulumi.output(aws.getRegion()).name;
39 |
40 |         const imageUrl = pulumi.interpolate`${accountId}.dkr.ecr.${region}.amazonaws.com/chamberlain:latest`
41 |
42 |         const containerSecrets = props.secrets.map(secret => {
43 |             return {
44 |                 name: secret.name,
45 |                 valueFrom: pulumi.interpolate`arn:aws:ssm:${region}:${accountId}:parameter/${secret.parameter}`
46 |             } as aws.ecs.Secret;
47 |         })
48 |
49 |         const logGroup = new aws.cloudwatch.LogGroup(`log-group`, {
50 |             name: `/ecs/${props.cluster}/${name}`,
51 |             retentionInDays: 90,
52 |             tags: {
53 |                 Name: `${props.cluster}-${name}-cluster-log-group`
54 |             }
55 |         })
56 |
57 |         const taskDefinition = new aws.ecs.TaskDefinition(`${name}-ecs-task-definition`, {
58 |             family: name,
59 |             networkMode: 'awsvpc',
60 |             requiresCompatibilities: ["FARGATE"],
61 |             cpu: props.containerCpu || "512",
62 |             memory: props.containerMemory || "1024",
63 |             executionRoleArn: pulumi.output(aws.iam.getRole({name: `ecs-task-execution-role`})).arn,
64 |             taskRoleArn: props.role,
65 |             containerDefinitions: pulumi
66 |                 .all([imageUrl, logGroup.name, props.environment, containerSecrets])  // the log group created above
67 |                 .apply(([image, logGroupName, environment, secrets]) =>
68 |                     JSON.stringify([{
69 |                         name: name,
70 |                         image: image,
71 |                         portMappings: [{
72 |                             containerPort: props.containerPort,
73 |                         }],
74 |                         command: props.command,
75 |                         environment: environment,
76 |                         secrets: secrets,
77 |                         logConfiguration: {
78 |                             logDriver: "awslogs",
79 |                             options: {
80 |                                 "awslogs-group": logGroupName,
81 |                                 "awslogs-create-group": "true",
82 |                                 "awslogs-region": "eu-central-1",
83 |                                 "awslogs-stream-prefix": name,
84 |                             },
85 |                         },
86 |                     } as aws.ecs.ContainerDefinition])
87 |                 )
88 |         }, {parent: this})
89 |
90 |         const serviceDiscovery = new aws.servicediscovery.Service(`${name}-service-discovery`, {
91 |             name: name,
92 |             description: `Service discovery for ${name}`,
93 |             dnsConfig: {
94 |                 routingPolicy: "MULTIVALUE",
95 |                 dnsRecords: [{ type: "A", ttl: 60 }],
96 |                 namespaceId: pulumi.output(aws.servicediscovery.getDnsNamespace({
97 |                     name: `streaming.internal`,
98 |                     type: 'DNS_PRIVATE',
99 |                 })).id,
100 |             },
101 |             healthCheckCustomConfig: {
102 |                 failureThreshold: 1
103 |             },
104 |         }, {parent: this})
105 |
106 |         new aws.ecs.Service(`${name}-ecs-service`, {
107 |             name: `${name}-service`,
108 |             cluster: props.cluster,
109 |             launchType: 'FARGATE',
110 |             deploymentController: {
111 |                 type: props.deploymentController || "ECS",
112 |             },
113 |             desiredCount: props.desiredCount || 1,
114 |             taskDefinition: taskDefinition.arn,
115 |             serviceRegistries: {
116 |                 registryArn: serviceDiscovery.arn,
117 |                 containerName: `${name}`,
118 |             },
119 |             networkConfiguration: {
120 |                 assignPublicIp: false,
121 |                 securityGroups: pulumi.output(aws.ec2.getSecurityGroups({
122 |                     tags: {Name: `ecs-host-sg`}
123 |                 })).ids,
124 |                 subnets: pulumi.output(aws.ec2.getSubnets({tags: {Type: 'private'}})).ids
125 |             }
126 |         }, {parent: this})
127 |     }
128 | }
--------------------------------------------------------------------------------
/ops/components/mq.ts:
--------------------------------------------------------------------------------
1 | import * as pulumi from "@pulumi/pulumi";
2 | import * as aws from "@pulumi/aws";
3 |
4 |
5 | export interface MessageQueueBrokerProps {
6 |     vpcId: pulumi.Input<string>
7 |
8 |     engineVersion?: pulumi.Input<string>
9 |     instanceType?: pulumi.Input<string>
10 |
11 | }
12 |
13 | export class MessageQueueBroker extends pulumi.ComponentResource {
14 |
15 |     constructor(
16 |         name: string,
17 |         props: MessageQueueBrokerProps,
18 |         opts?: pulumi.ComponentResourceOptions,
19 |     ) {
20 |         super("decodingml:main:MessageQueueBroker", name, {}, opts);
21 |
22 |         const accountId = pulumi.output(aws.getCallerIdentity()).accountId;
23 |         const region = pulumi.output(aws.getRegion()).name;
24 |
25 |         const securityGroup = new aws.ec2.SecurityGroup(`${name}-mq-sg`, {
26 |             name: `${name}-mq-sg`,
27 |             description: "Message Queue broker access",
28 |             vpcId: props.vpcId,
29 |             ingress: [
30 |                 {
31 |                     description: "Ingress from AMQPS protocol",
32 |                     fromPort: 5671,
33 |                     toPort: 5671,
34 |                     protocol: "tcp", cidrBlocks: ["0.0.0.0/0"],  // a source is required; the broker is publicly accessible
35 |                 },
36 |                 {
37 |                     description: "Ingress from HTTPS protocol",
38 |                     fromPort: 443,
39 |                     toPort: 443,
40 |                     protocol: "tcp", cidrBlocks: ["0.0.0.0/0"],
41 |                 },
42 |             ],
43 |             egress: [{
44 |                 protocol: "-1",
45 |                 description: "Allow all outbound traffic by default",
46 |                 fromPort: 0,
47 |                 toPort: 0,
48 |                 cidrBlocks: ["0.0.0.0/0"],
49 |             }],
50 |             tags: {
51 |                 Name: `${name}-mq-sg`
52 |             },
53 |         }, {parent: this})
54 |
55 |         const broker = new aws.mq.Broker(`${name}-mq-broker`, {
56 |             brokerName: `${name}-mq-broker`,
57 |             engineType: "RabbitMQ",
58 |             engineVersion: props.engineVersion || "3.11.20",
59 |             hostInstanceType: props.instanceType || "mq.t3.micro",
60 |             securityGroups: [securityGroup.id],
61 |             deploymentMode: "SINGLE_INSTANCE",
62 |             logs: {
63 |                 general: true,
64 |             },
65 |             publiclyAccessible: true,
66 |             subnetIds: pulumi.output(aws.ec2.getSubnets({tags: {Type: 'public'}})).ids,
67 |             users: pulumi.all([
68 |                 this.getSecretValue(`arn:aws:secretsmanager:${region}:${accountId}:secret:/${name}/broker/admin`),
69 |                 this.getSecretValue(`arn:aws:secretsmanager:${region}:${accountId}:secret:/${name}/broker/replication-user`)
70 |             ]).apply(([adminSecret, replicationUserSecret]) => [
71 |                 {
72 |                     username: JSON.parse(adminSecret).username,
73 |                     password: JSON.parse(adminSecret).password,
74 |                     consoleAccess: true,
75 |                 },
76 |                 {
77 |                     username: JSON.parse(replicationUserSecret).username,
78 |                     password: JSON.parse(replicationUserSecret).password,
79 |                     consoleAccess: true,
80 |                     replicationUser: true
81 |                 }
82 |             ]),
83 |             tags: {
84 |                 Name: `${name}-mq-broker`
85 |             },
86 |         }, {parent: this})
87 |
88 |         const hostSSMParameter = new aws.ssm.Parameter(`${name}-mq-broker-host-ssm-parameter`, {
89 |             name: `/${name}/broker/host`,
90 |             type: aws.ssm.ParameterType.String,
91 |             description: `RabbitMQ cluster host for ${name}-mq-broker`,
92 |             value: broker.instances[0].endpoints[0].apply(endpoint => {
93 |                 return new URL(endpoint).hostname;  // endpoint looks like amqps://host:5671
94 |             }),
95 |         }, {parent: this})
96 |
97 |         const portSSMParameter = new aws.ssm.Parameter(`${name}-mq-broker-port-ssm-parameter`, {
98 |             name: `/${name}/broker/port`,
99 |             type: aws.ssm.ParameterType.String,
100 |             description: `RabbitMQ cluster port for ${name}-mq-broker`,
101 |             value: "5671",
102 |         }, {parent: this})
103 |     }
104 |
105 |     private async getSecretValue(secretName: string): Promise<pulumi.Output<string>> {
106 |         return pulumi.output(aws.secretsmanager.getSecretVersion({
107 |             secretId: secretName,
108 |         }, { async: true })).apply(secretVersion => {
109 |             if (!secretVersion.secretString) {
110 |                 throw new Error("Secret version contains no string data");
111 |             }
112 |             return secretVersion.secretString;
113 |         });
114 |     }
115 | }
--------------------------------------------------------------------------------
/ops/components/nat.ts:
--------------------------------------------------------------------------------
1 | import * as pulumi from "@pulumi/pulumi";
2 | import * as aws from "@pulumi/aws";
3 | import {SubnetCidrBlocks} from "./config";
4 |
5 | export interface NatGatewayProps {
6 |     env: pulumi.Input<string>
7 |     vpcId: pulumi.Input<string>
8 |     subnet: pulumi.Input<string>
9 |
10 |     instanceImageAmiId?: pulumi.Input<string>
11 | }
12 |
13 | export class NatGateway extends pulumi.ComponentResource {
14 |     public readonly id: pulumi.Output<string>
15 |
16 |     constructor(
17 |         name: string,
18 |         props: NatGatewayProps,
19 |         opts?: pulumi.ComponentResourceOptions,
20 |     ) {
21 |         super("decodingml:main:NatGateway", name, {}, opts);
22 |
23 |         const config = new pulumi.Config();
24 |
25 |         const sg = new aws.ec2.SecurityGroup(`${name}-security-group`, {
26 |             description: "Security Group for NAT Gateway",
27 |             ingress: [
28 |                 {
29 |                     cidrBlocks: [SubnetCidrBlocks.VPC],
30 |                     description: "Allow all inbound traffic from network",
31 |                     protocol: "-1",
32 |                     fromPort: 0,
33 |                     toPort: 0,
34 |                 },
35 |             ],
36 |             egress: [
37 |                 {
38 |                     cidrBlocks: ['0.0.0.0/0'],
39 |                     description: "Allow all outbound traffic by default",
40 |                     protocol: "-1",
41 |                     fromPort: 0,
42 |                     toPort: 0,
43 |                 },
44 |             ],
45 |             vpcId: props.vpcId,
46 |         }, {parent: this})
47 |
48 |         const iamRole = new aws.iam.Role(`${name}-role`, {
49 |             assumeRolePolicy: {
50 |                 Version: '2012-10-17',
51 |                 Statement: [
52 |                     {
53 |                         Action: ['sts:AssumeRole'],
54 |                         Effect: 'Allow',
55 |                         Principal: {
56 |                             Service: 'ec2.amazonaws.com',
57 |                         },
58 |                     },
59 |                 ],
60 |             },
61 |             managedPolicyArns: [
62 |                 `arn:aws:iam::aws:policy/AmazonSSMManagedInstanceCore`,
63 |             ],
64 |             inlinePolicies: [
65 |                 {
66 |                     name: 'for-nat',
67 |                     policy: JSON.stringify({
68 |                         Statement: [
69 |                             {
70 |                                 Action: [
71 |                                     'ec2:AttachNetworkInterface',
72 |                                     'ec2:ModifyNetworkInterfaceAttribute',
73 |                                     'ec2:AssociateAddress',
74 |                                     'ec2:DisassociateAddress',
75 |                                     'ec2:*',
76 |                                 ],
77 |                                 Effect: 'Allow',
78 |                                 Resource: '*',
79 |                             },
80 |                         ],
81 |                         Version: '2012-10-17',
82 |                     } as aws.iam.PolicyDocument),
83 |                 },
84 |             ],
85 |
86 |         }, {parent: this})
87 |
88 |         const eni = new aws.ec2.NetworkInterface(`${name}-eni`, {
89 |             subnetId: props.subnet,
90 |             securityGroups: [sg.id],
91 |             sourceDestCheck: false,
92 |         }, {parent: this})
93 |
94 |         this.id = eni.id
95 |
96 |         const instanceProfile = new aws.iam.InstanceProfile(`${name}-instance-profile`, {
97 |             role: iamRole
98 |         }, {parent: this})
99 |
100 |         const launchTemplate = new aws.ec2.LaunchTemplate(`${name}-launch-template`, {
101 |             name: `pi-${props.env}-nat-launch-template`,
102 |             imageId: config.require('natInstanceImageId'),
103 |             instanceType: 't4g.nano',
104 |             iamInstanceProfile: { arn: instanceProfile.arn },
105 |             vpcSecurityGroupIds: [ sg.id ],
106 |             userData: eni.id.apply(id =>
107 |                 Buffer.from(
108 |                     [
109 |                         '#!/bin/bash',
110 |                         `echo "eni_id=${id}" >> /etc/fck-nat.conf`,
111 |                         'service fck-nat restart',
112 |                     ].join('\n'),
113 |                 ).toString('base64'),
114 |             ),
115 |             tags: {
116 |                 Name: `pi-${props.env}-nat-launch-template`
117 |             },
118 |             tagSpecifications: [{
119 |                 tags: {
120 |                     Name: `pi-${props.env}-nat-launch-template`
121 |                 },
122 |                 resourceType: 'instance'
123 |             }]
124 |         }, {dependsOn: instanceProfile, parent: this})
125 |
126 |
127 |         new aws.autoscaling.Group(`${name}-autoscaling-group`, {
128 |             maxSize: 1,
129 |             minSize: 1,
130 |             desiredCapacity: 1,
131 |             launchTemplate: {
132 |                 id: launchTemplate.id,
133 |                 version: '$Latest',
134 |             },
135 |             vpcZoneIdentifiers: [ props.subnet ],
136 |             tags: [{ key: 'Name', value: `pi-${props.env}-nat-instance-launch-template`, propagateAtLaunch: true }]
137 |         }, {parent: this})
138 |
139 |         this.registerOutputs({
140 |             id: this.id,
141 |         })
142 |     }
143 | }
--------------------------------------------------------------------------------
/ops/components/repository.ts:
--------------------------------------------------------------------------------
1 | import * as pulumi from '@pulumi/pulumi'
2 | import * as aws from '@pulumi/aws'
3 |
4 | interface Props {}
5 |
6 | export class Repository extends pulumi.ComponentResource {
7 |     public name: pulumi.Output<string>
8 |     public arn: pulumi.Output<string>
9 |     public url: pulumi.Output<string>
10 |
11 |     public static dockerTags = {
12 |         github: 'github-crawler-latest',
13 |         linkedin: 'linkedin-crawler-latest',
14 |         medium: 'medium-crawler-latest',
15 |
16 |     } as const
17 |
18 |     private readonly tags = {
19 |         module: 'ai',
20 |         scope: 'ecr',
21 |     }
22 |
23 |     constructor(
24 |         name: string,
25 |         props: Props,
26 |         opts?: pulumi.ComponentResourceOptions,
27 |     ) {
28 |         super('decodingml:ai:ecr', name, {}, opts)
29 |
30 |         const ecr = new aws.ecr.Repository(
31 |             `${name}-repository`,
32 |             {
33 |                 name,
34 |                 tags: this.tags,
35 |                 imageTagMutability: 'MUTABLE',
36 |             },
37 |             { parent: this },
38 |         )
39 |
40 |         new aws.ecr.LifecyclePolicy(
41 |             `${name}-lifecycle-policy`,
42 |             {
43 |                 repository: ecr.name,
44 |                 policy: {
45 |                     rules: [
46 |                         {
47 |                             action: { type: 'expire' },
48 |                             selection: {
49 |                                 tagStatus: 'untagged',
50 |                                 countNumber: 30,
51 |                                 countUnit: 'days',
52 |                                 countType: 'sinceImagePushed',
53 |                             },
54 |                             rulePriority: 1,
55 |                             description: 'Delete untagged images older than 30 days.',
56 |                         },
57 |                     ],
58 |                 },
59 |             },
60 |             { parent: this },
61 |         )
62 |
63 |         this.arn = ecr.arn
64 |         this.name = ecr.name
65 |         this.url = ecr.repositoryUrl
66 |
67 |         this.registerOutputs()
68 |     }
69 | }
70 |
--------------------------------------------------------------------------------
/ops/components/vpc.ts:
--------------------------------------------------------------------------------
1 | import * as pulumi from '@pulumi/pulumi'
2 | import * as aws from '@pulumi/aws'
3 | import {SubnetCidrBlocks} from "./config";
4 |
5 | interface VpcProps {}
6 |
7 | export class Vpc extends pulumi.ComponentResource {
8 |     public readonly id: pulumi.Output<string>
9 |
10 |     constructor(
11 |         name: string,
12 |         props: VpcProps,
13 |         opts?: pulumi.ComponentResourceOptions,
14 |     ) {
15 |         super("decodingml:main:Vpc", name, {}, opts);
16 |
17 |         const vpc = new aws.ec2.Vpc(`${name}-vpc`, {
18 |             cidrBlock: SubnetCidrBlocks.VPC,
19 |             enableDnsSupport: true,
20 |             enableDnsHostnames: true,
21 |             tags: {
22 |                 Name: `${name}-vpc`,
23 |             },
24 |         }, { parent: this });
25 |
26 |         this.id = vpc.id
27 |
28 |         const azs = aws.getAvailabilityZones({
29 |             state: "available"
30 |         })
31 |
32 |         const publicSubnetOne = new aws.ec2.Subnet(`${name}-public-subnet-one`, {
33 |             vpcId: vpc.id,
34 |             availabilityZone: azs.then(azs => azs.names?.[0]),
35 |             cidrBlock: SubnetCidrBlocks.PublicOne,
36 |             mapPublicIpOnLaunch: true,
37 |             tags: {
38 |                 Name: `${name}-public-subnet-one`,
39 |                 Type: 'public',
40 |             }
41 |         }, {parent: this})
42 |
43 |         const publicSubnetTwo = new aws.ec2.Subnet(`${name}-public-subnet-two`, {
44 |             vpcId: vpc.id,
45 |             availabilityZone: azs.then(azs => azs.names?.[1]),
46 |             cidrBlock: SubnetCidrBlocks.PublicTwo,
47 |             mapPublicIpOnLaunch: true,
48 |             tags: {
49 |                 Name: `${name}-public-subnet-two`,
50 |                 Type: 'public',
51 |             }
52 |         }, {parent: this})
53 |
54 |         // Setup networking resources for the public subnets.
55 |         const internetGateway = new aws.ec2.InternetGateway(`${name}-internet-gateway`, {
56 |             vpcId: vpc.id,
57 |             tags: {
58 |                 Name: `${name}-internet-gateway`
59 |             }
60 |         }, {parent: this})
61 |
62 |         const publicRouteTable = new aws.ec2.RouteTable(`${name}-public-route-table`, {
63 |             vpcId: vpc.id,
64 |             tags: {
65 |                 Name: `${name}-public-route-table`
66 |             }
67 |         }, {parent: this})
68 |
69 |         new aws.ec2.Route(`${name}-public-route`, {
70 |             routeTableId: publicRouteTable.id,
71 |             destinationCidrBlock: "0.0.0.0/0",
72 |             gatewayId: internetGateway.id
73 |         }, {parent: this})
74 |
75 |         new aws.ec2.RouteTableAssociation(`${name}-public-subnet-one-rta`, {
76 |             subnetId: publicSubnetOne.id,
77 |             routeTableId: publicRouteTable.id
78 |         }, {parent: this})
79 |         new aws.ec2.RouteTableAssociation(`${name}-public-subnet-two-rta`, {
80 |             subnetId: publicSubnetTwo.id,
81 |             routeTableId: publicRouteTable.id,
82 |         }, {parent: this})
83 |
84 |     }
85 | }
86 |
--------------------------------------------------------------------------------
/ops/index.ts:
--------------------------------------------------------------------------------
1 | import * as pulumi from "@pulumi/pulumi";
2 | import {Vpc} from "./components/vpc";
3 | import {DocumentDBCluster} from "./components/docdb";
4 | import {Crawler} from "./components/crawler";
5 | import {ECSCluster} from "./components/ecs/cluster";
6 | import {Service} from "./components/ecs/service";
7 |
8 | const vpc = new Vpc("network-overlay", {})
9 |
10 | const docdb = new DocumentDBCluster("warehouse", {
11 |     vpcId: vpc.id,
12 |     instanceClass: "db.t3.medium",
13 | }, {dependsOn: vpc})
14 |
15 | const lambda = new Crawler("crawler", {
16 |     vpcId: vpc.id,
17 |     timeout: 900,
18 |     memory: 3008
19 | })
20 |
21 | const cluster = new ECSCluster("streaming", {
22 |     vpcId: vpc.id
23 | })
24 |
25 | const bytewaxWorker = new Service("bytewax-worker", {
26 |     vpcId: vpc.id,
27 |     cluster: cluster.name,
28 |     containerPort: 9000,
29 |     secrets: [
30 |         {
31 |             name: "MONGO_DATABASE_HOST",
32 |             parameter: "database/host",
33 |         },
34 |         {
35 |             name: "OPENAI_API_KEY",
36 |             parameter: "database/host",
37 |         },
38 |         {
39 |             name: "QDRANT_DATABASE_HOST",
40 |             parameter: "database/username",
41 |         },
42 |         {
43 |             name: "QDRANT_DATABASE_PORT",
44 |             parameter: "database/host",
45 |         },
46 |         {
47 |             name: "QDRANT_APIKEY",
48 |             parameter: "database/username",
49 |         },
50 |         {
51 |             name: "RABBITMQ_HOST",
52 |             parameter: "database/host",
53 |         },
54 |         {
55 |             name: "RABBITMQ_PORT",
56 |             parameter: "database/username",
57 |         },
58 |         {
59 |             name: "RABBITMQ_DEFAULT_USERNAME",
60 |             parameter: "database/host",
61 |         },
62 |         {
63 |             name: "RABBITMQ_DEFAULT_PASSWORD",
64 |             parameter: "database/username",
65 |         },
66 |     ]
67 | })
68 |
69 |
70 | export const VpcID: pulumi.Output<string> = vpc.id
71 |
--------------------------------------------------------------------------------
/ops/package.json:
--------------------------------------------------------------------------------
1 | {
2 |   "name": "decodingml",
3 |   "main": "index.ts",
4 |   "devDependencies": {
5 |     "@types/node": "^18"
6 |   },
7 |   "dependencies": {
8 |     "@pulumi/pulumi": "^3.0.0",
9 |     "@pulumi/aws": "^6.0.0",
10 |     "@pulumi/awsx": "^2.0.2"
11 |   }
12 | }
13 |
--------------------------------------------------------------------------------
/ops/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 |   "compilerOptions": {
3 |     "strict": true,
4 |     "outDir": "bin",
5 |     "target": "es2016",
6 |     "module": "commonjs",
"node", 8 | "sourceMap": true, 9 | "experimentalDecorators": true, 10 | "pretty": true, 11 | "noFallthroughCasesInSwitch": true, 12 | "noImplicitReturns": true, 13 | "forceConsistentCasingInFileNames": true 14 | }, 15 | "files": [ 16 | "index.ts" 17 | ] 18 | } 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "rag-system" 3 | description = "" 4 | version = "0.1.0" 5 | authors = [ 6 | "Vlad Adumitracesei ", 7 | "Paul Iusztin ", 8 | "Alex Vesa ", 9 | "Rares Istoc " 10 | ] 11 | readme = "README.md" 12 | 13 | [tool.ruff] 14 | line-length = 88 15 | select = [ 16 | "F401", 17 | "F403", 18 | ] 19 | 20 | 21 | [tool.poetry.dependencies] 22 | python = ">=3.10, <3.12" 23 | pydantic = "^2.6.3" 24 | pydantic-settings = "^2.1.0" 25 | pika = "^1.3.2" 26 | qdrant-client = "^1.8.0" 27 | langchain = "^0.1.13" 28 | aws-lambda-powertools = "^2.38.1" 29 | selenium = "4.21.0" 30 | instructorembedding = "^1.0.1" 31 | numpy = "^1.26.4" 32 | langchain-openai = "^0.1.3" 33 | gdown = "^5.1.0" 34 | pymongo = "^4.7.1" 35 | structlog = "^24.1.0" 36 | rich = "^13.7.1" 37 | pip = "^24.0" 38 | comet-ml = "^3.41.0" 39 | ruff = "^0.4.3" 40 | pandas = "^2.0.3" 41 | datasets = "^2.19.1" 42 | transformers = "^4.40.2" 43 | safetensors = "^0.4.3" 44 | bitsandbytes = "^0.42.0" 45 | scikit-learn = "^1.4.2" 46 | unstructured = "^0.14.2" 47 | 48 | [tool.poetry.group.3-feature-pipeline.dependencies] 49 | bytewax = "0.18.2" 50 | 51 | [tool.poetry.group.ml.dependencies] 52 | qwak-inference = "^0.1.17" 53 | comet-llm = "^2.2.4" 54 | qwak-sdk = "^0.5.69" 55 | peft = "^0.11.1" 56 | sentence-transformers = "^2.2.2" 57 | accelerate = "^0.30.1" 58 | 59 | [tool.poetry.group.6-superlinked-rag.dependencies] 60 | superlinked = "7.2.1" 61 | 62 | [build-system] 63 | requires = ["poetry-core"] 64 | build-backend = "poetry.core.masonry.api" 65 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings, SettingsConfigDict 2 | 3 | 4 | class AppSettings(BaseSettings): 5 | model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") 6 | 7 | # Embeddings config 8 | EMBEDDING_MODEL_ID: str = "sentence-transformers/all-MiniLM-L6-v2"#instruct-xl 9 | EMBEDDING_MODEL_MAX_INPUT_LENGTH: int = 256 10 | EMBEDDING_SIZE: int = 384 11 | EMBEDDING_MODEL_DEVICE: str = "cpu" 12 | 13 | OPENAI_MODEL_ID: str = "gpt-4-1106-preview" 14 | OPENAI_API_KEY: str | None = None 15 | 16 | # MongoDB configs 17 | MONGO_DATABASE_HOST: str = "mongodb://localhost:30001,localhost:30002,localhost:30003/?replicaSet=my-replica-set" 18 | MONGO_DATABASE_NAME: str = "scrabble" 19 | 20 | # QdrantDB config 21 | QDRANT_DATABASE_HOST: str = "localhost" 22 | QDRANT_DATABASE_PORT: int = 6333 23 | QDRANT_DATABASE_URL: str = "http://localhost:6333" 24 | QDRANT_CLOUD_URL: str = "str" 25 | USE_QDRANT_CLOUD: bool = False 26 | QDRANT_APIKEY: str | None = None 27 | 28 | # MQ config 29 | RABBITMQ_DEFAULT_USERNAME: str = "guest" 30 | RABBITMQ_DEFAULT_PASSWORD: str = "guest" 31 | RABBITMQ_HOST: str = "localhost" 32 | RABBITMQ_PORT: int = 5673 33 | 34 | # CometML config 35 | COMET_API_KEY: str | None = None 36 | COMET_WORKSPACE: str | None = None 37 | COMET_PROJECT: str | None = None 38 | 39 | # LinkedIn credentials 40 | LINKEDIN_USERNAME: str | None = None 41 | 
41 |     LINKEDIN_PASSWORD: str | None = None
42 |
43 |
44 | settings = AppSettings()
45 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import structlog
2 |
3 |
4 | def get_logger(cls: str):
5 |     return structlog.get_logger().bind(cls=cls)
6 |
--------------------------------------------------------------------------------