├── .docker ├── .dockerignore ├── Dockerfile.data_cdc ├── Dockerfile.data_crawlers ├── Dockerfile.feature_pipeline └── Dockerfile.feature_pipeline.superlinked ├── .env.example ├── .github ├── ISSUE_TEMPLATE │ ├── clarify-concept.md │ ├── questions-about-the-course-material.md │ └── technical-troubleshooting-or-bugs.md └── workflows │ └── crawler.yml ├── .gitignore ├── .python-version ├── INSTALL_AND_USAGE.md ├── LICENSE ├── Makefile ├── README.md ├── data └── links.txt ├── docker-compose-superlinked.yml ├── docker-compose.yml ├── media ├── architecture.png ├── cover.png ├── fine-tuning-workflow.png ├── llm_engineers_handbook_cover.png ├── qdrant-example.png ├── sponsors │ ├── bytewax.png │ ├── comet.png │ ├── opik.svg │ ├── qdrant.svg │ ├── qwak.png │ └── superlinked.svg └── ui-example.png ├── poetry.lock ├── pyproject.toml └── src ├── bonus_superlinked_rag ├── README.md ├── config.py ├── data_flow │ ├── __init__.py │ ├── stream_input.py │ └── stream_output.py ├── data_logic │ ├── __init__.py │ ├── cleaning_data_handlers.py │ ├── dispatchers.py │ └── splitters.py ├── llm │ ├── __init__.py │ ├── chain.py │ └── prompt_templates.py ├── local_test.py ├── main.py ├── models │ ├── __init__.py │ ├── documents.py │ ├── raw.py │ └── utils.py ├── mq.py ├── rag │ ├── __init__.py │ ├── query_expanison.py │ ├── reranking.py │ ├── retriever.py │ └── self_query.py ├── retriever.py ├── scripts │ └── bytewax_entrypoint.sh ├── server │ ├── .python-version │ ├── LICENSE │ ├── NOTICE │ ├── README.md │ ├── compose.yaml │ ├── config │ │ ├── aws_credentials.json │ │ ├── config.yaml │ │ └── gcp_credentials.json │ ├── docs │ │ ├── api.md │ │ ├── app.md │ │ ├── bucket.md │ │ ├── dummy_app.py │ │ ├── example │ │ │ ├── amazon_app.py │ │ │ └── app.py │ │ ├── mongodb │ │ │ ├── app_with_mongodb.py │ │ │ └── mongodb.md │ │ ├── redis │ │ │ ├── app_with_redis.py │ │ │ └── redis.md │ │ ├── vector_databases.md │ │ └── vm.md │ ├── runner │ │ ├── .python-version │ │ ├── executor │ │ │ ├── Dockerfile │ │ │ ├── __init__.py │ │ │ ├── app │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── app_config.py │ │ │ │ ├── dependency_register.py │ │ │ │ ├── exception │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── exception.py │ │ │ │ │ └── exception_handler.py │ │ │ │ ├── main.py │ │ │ │ ├── middleware │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── lifespan_event.py │ │ │ │ ├── router │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── management_router.py │ │ │ │ ├── service │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── data_loader.py │ │ │ │ │ ├── file_handler_service.py │ │ │ │ │ ├── file_object_serializer.py │ │ │ │ │ ├── persistence_service.py │ │ │ │ │ └── supervisor_service.py │ │ │ │ └── util │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── fast_api_handler.py │ │ │ │ │ ├── open_api_description_util.py │ │ │ │ │ └── registry_loader.py │ │ │ ├── openapi │ │ │ │ └── static_endpoint_descriptor.json │ │ │ └── supervisord.conf │ │ ├── poetry.lock │ │ ├── poller │ │ │ ├── Dockerfile │ │ │ ├── __init__.py │ │ │ ├── app │ │ │ │ ├── __init__.py │ │ │ │ ├── app_location_parser │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── app_location_parser.py │ │ │ │ ├── config │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── poller_config.py │ │ │ │ ├── main.py │ │ │ │ ├── poller │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── poller.py │ │ │ │ └── resource_handler │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── gcs │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── gcs_resource_handler.py │ │ │ │ │ ├── local │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── local_resource_handler.py │ │ │ │ │ ├── 
resource_handler.py │ │ │ │ │ ├── resource_handler_factory.py │ │ │ │ │ └── s3 │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── s3_resource_handler.py │ │ │ ├── logging_config.ini │ │ │ └── poller_config.ini │ │ └── pyproject.toml │ ├── src │ │ └── app.py │ └── tools │ │ ├── deploy.py │ │ └── init-venv.sh ├── singleton.py ├── superlinked_client.py └── utils │ ├── __init__.py │ ├── cleaning.py │ └── logging.py ├── core ├── __init__.py ├── aws │ ├── __init__.py │ ├── create_execution_role.py │ └── create_sagemaker_role.py ├── config.py ├── db │ ├── __init__.py │ ├── documents.py │ ├── mongo.py │ └── qdrant.py ├── errors.py ├── lib.py ├── logger_utils.py ├── mq.py ├── opik_utils.py └── rag │ ├── __init__.py │ ├── prompt_templates.py │ ├── query_expanison.py │ ├── reranking.py │ ├── retriever.py │ └── self_query.py ├── data_cdc ├── cdc.py ├── config.py └── test_cdc.py ├── data_crawling ├── config.py ├── crawlers │ ├── __init__.py │ ├── base.py │ ├── custom_article.py │ ├── github.py │ ├── linkedin.py │ └── medium.py ├── dispatcher.py ├── main.py └── utils.py ├── feature_pipeline ├── config.py ├── data_flow │ ├── __init__.py │ ├── stream_input.py │ └── stream_output.py ├── data_logic │ ├── __init__.py │ ├── chunking_data_handlers.py │ ├── cleaning_data_handlers.py │ ├── dispatchers.py │ └── embedding_data_handlers.py ├── generate_dataset │ ├── __init__.py │ ├── chunk_documents.py │ ├── exceptions.py │ ├── file_handler.py │ ├── generate.py │ └── llm_communication.py ├── main.py ├── models │ ├── __init__.py │ ├── base.py │ ├── chunk.py │ ├── clean.py │ ├── embedded_chunk.py │ └── raw.py ├── retriever.py ├── scripts │ └── bytewax_entrypoint.sh └── utils │ ├── __init__.py │ ├── chunking.py │ ├── cleaning.py │ └── embeddings.py ├── inference_pipeline ├── aws │ ├── __init__.py │ ├── delete_sagemaker_endpoint.py │ └── deploy_sagemaker_endpoint.py ├── config.py ├── evaluation │ ├── __init__.py │ ├── evaluate.py │ ├── evaluate_monitoring.py │ ├── evaluate_rag.py │ └── style.py ├── llm_twin.py ├── main.py ├── prompt_templates.py ├── ui.py └── utils.py └── training_pipeline ├── config.py ├── download_dataset.py ├── finetune.py ├── requirements.txt └── run_on_sagemaker.py /.docker/.dockerignore: -------------------------------------------------------------------------------- 1 | bonus_superlinked_rag/server 2 | -------------------------------------------------------------------------------- /.docker/Dockerfile.data_cdc: -------------------------------------------------------------------------------- 1 | # Use an official Python runtime as a parent image 2 | FROM python:3.11-slim 3 | 4 | ENV POETRY_VERSION=1.8.3 5 | 6 | # Install system dependencies 7 | RUN apt-get update && apt-get install -y \ 8 | gcc \ 9 | python3-dev \ 10 | curl \ 11 | build-essential \ 12 | && apt-get clean \ 13 | && rm -rf /var/lib/apt/lists/* 14 | 15 | # Install Poetry using pip and clear cache 16 | RUN pip install --no-cache-dir "poetry==$POETRY_VERSION" 17 | RUN poetry config installer.max-workers 20 18 | 19 | # Set the working directory 20 | WORKDIR /app 21 | 22 | # Copy the pyproject.toml and poetry.lock files from the root directory 23 | COPY ./pyproject.toml ./poetry.lock ./ 24 | 25 | # Install the dependencies and clear cache 26 | RUN poetry config virtualenvs.create false && \ 27 | poetry install --no-root --no-interaction --no-cache && \ 28 | rm -rf ~/.cache/pypoetry/cache/ && \ 29 | rm -rf ~/.cache/pypoetry/artifacts/ 30 | 31 | # Install dependencies 32 | RUN poetry install --no-root 33 | 34 | # Copy the data_cdc and core 
directories 35 | COPY ./src/data_cdc ./data_cdc 36 | COPY ./src/core ./core 37 | 38 | # Set the PYTHONPATH environment variable 39 | ENV PYTHONPATH=/app 40 | 41 | # Command to run the script 42 | CMD poetry run python /app/data_cdc/cdc.py && tail -f /dev/null -------------------------------------------------------------------------------- /.docker/Dockerfile.data_crawlers: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/lambda/python:3.11 as build 2 | 3 | # Install chrome driver and browser 4 | RUN yum install -y unzip && \ 5 | curl -Lo "/tmp/chromedriver.zip" "https://storage.googleapis.com/chrome-for-testing-public/126.0.6478.126/linux64/chromedriver-linux64.zip" && \ 6 | curl -Lo "/tmp/chrome-linux.zip" "https://storage.googleapis.com/chrome-for-testing-public/126.0.6478.126/linux64/chrome-linux64.zip" && \ 7 | unzip /tmp/chromedriver.zip -d /opt/ && \ 8 | unzip /tmp/chrome-linux.zip -d /opt/ 9 | 10 | FROM public.ecr.aws/lambda/python:3.11 11 | 12 | ENV POETRY_VERSION=1.8.3 13 | 14 | # Install the function's OS dependencies using yum 15 | RUN yum install -y \ 16 | atk \ 17 | wget \ 18 | git \ 19 | cups-libs \ 20 | gtk3 \ 21 | libXcomposite \ 22 | alsa-lib \ 23 | libXcursor \ 24 | libXdamage \ 25 | libXext \ 26 | libXi \ 27 | libXrandr \ 28 | libXScrnSaver \ 29 | libXtst \ 30 | pango \ 31 | at-spi2-atk \ 32 | libXt \ 33 | xorg-x11-server-Xvfb \ 34 | xorg-x11-xauth \ 35 | dbus-glib \ 36 | dbus-glib-devel \ 37 | nss \ 38 | mesa-libgbm \ 39 | ffmpeg \ 40 | libxext6 \ 41 | libssl-dev \ 42 | libcurl4-openssl-dev \ 43 | libpq-dev 44 | 45 | 46 | COPY --from=build /opt/chrome-linux64 /opt/chrome 47 | COPY --from=build /opt/chromedriver-linux64 /opt/ 48 | 49 | COPY ./pyproject.toml ./poetry.lock ./ 50 | 51 | # Install Poetry, export dependencies to requirements.txt, and install dependencies 52 | # in the Lambda task directory, finally cleanup manifest files. 
53 | RUN python -m pip install --upgrade pip && pip install --no-cache-dir "poetry==$POETRY_VERSION" 54 | RUN poetry export --without feature_pipeline -f requirements.txt > requirements.txt && \ 55 | pip install --no-cache-dir -r requirements.txt --target "${LAMBDA_TASK_ROOT}" && \ 56 | rm requirements.txt pyproject.toml poetry.lock 57 | 58 | # Optional TLS CA only if you plan to store the extracted data into Document DB 59 | RUN wget https://truststore.pki.rds.amazonaws.com/global/global-bundle.pem -P ${LAMBDA_TASK_ROOT} 60 | ENV PYTHONPATH="${LAMBDA_TASK_ROOT}/data_crawling:${LAMBDA_TASK_ROOT}" 61 | 62 | # Copy function code 63 | COPY ./src/data_crawling ${LAMBDA_TASK_ROOT}/data_crawling 64 | COPY ./src/core ${LAMBDA_TASK_ROOT}/core 65 | 66 | # Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile) 67 | CMD ["data_crawling.main.handler"] 68 | -------------------------------------------------------------------------------- /.docker/Dockerfile.feature_pipeline: -------------------------------------------------------------------------------- 1 | # Use an official Python runtime as a parent image 2 | FROM python:3.11-slim-bullseye 3 | 4 | ENV WORKSPACE_ROOT=/usr/src/app \ 5 | PYTHONDONTWRITEBYTECODE=1 \ 6 | PYTHONUNBUFFERED=1 \ 7 | POETRY_HOME="/opt/poetry" \ 8 | POETRY_NO_INTERACTION=1 \ 9 | POETRY_VERSION=1.8.3 10 | 11 | RUN mkdir -p $WORKSPACE_ROOT 12 | 13 | # Install system dependencies 14 | RUN apt-get update -y \ 15 | && apt-get install -y --no-install-recommends build-essential \ 16 | gcc \ 17 | python3-dev \ 18 | curl \ 19 | build-essential \ 20 | && apt-get clean 21 | 22 | # Install Poetry using pip and clear cache 23 | RUN pip install --no-cache-dir "poetry==$POETRY_VERSION" 24 | RUN poetry config installer.max-workers 20 25 | 26 | RUN apt-get remove -y curl 27 | 28 | # Copy the pyproject.toml and poetry.lock files from the root directory 29 | COPY ./pyproject.toml ./poetry.lock ./ 30 | 31 | # Install the dependencies and clear cache 32 | RUN poetry config virtualenvs.create false && \ 33 | poetry install --no-root --no-interaction --no-cache && \ 34 | rm -rf ~/.cache/pypoetry/cache/ && \ 35 | rm -rf ~/.cache/pypoetry/artifacts/ 36 | 37 | # Set the working directory 38 | WORKDIR $WORKSPACE_ROOT 39 | 40 | # Copy the feature pipeline and any other necessary directories 41 | COPY ./src/feature_pipeline . 
42 | COPY ./src/core ./core 43 | 44 | # Set the PYTHONPATH environment variable 45 | ENV PYTHONPATH=/usr/src/app 46 | 47 | RUN chmod +x /usr/src/app/scripts/bytewax_entrypoint.sh 48 | 49 | # Command to run the Bytewax pipeline script 50 | CMD ["/usr/src/app/scripts/bytewax_entrypoint.sh"] 51 | -------------------------------------------------------------------------------- /.docker/Dockerfile.feature_pipeline.superlinked: -------------------------------------------------------------------------------- 1 | # Use an official Python runtime as a parent image 2 | FROM python:3.11-slim-bullseye 3 | 4 | ENV WORKSPACE_ROOT=/usr/src/app \ 5 | PYTHONDONTWRITEBYTECODE=1 \ 6 | PYTHONUNBUFFERED=1 \ 7 | POETRY_HOME="/opt/poetry" \ 8 | POETRY_NO_INTERACTION=1 \ 9 | POETRY_VERSION=1.8.3 10 | 11 | RUN mkdir -p $WORKSPACE_ROOT 12 | 13 | # Install system dependencies 14 | RUN apt-get update -y \ 15 | && apt-get install -y --no-install-recommends build-essential \ 16 | gcc \ 17 | python3-dev \ 18 | curl \ 19 | build-essential \ 20 | && apt-get clean 21 | 22 | # Install Poetry using pip and clear cache 23 | RUN pip install --no-cache-dir "poetry==$POETRY_VERSION" 24 | RUN poetry config installer.max-workers 20 25 | 26 | RUN apt-get remove -y curl 27 | 28 | # Copy the pyproject.toml and poetry.lock files from the root directory 29 | COPY ./pyproject.toml ./poetry.lock ./ 30 | 31 | # Install the dependencies and clear cache 32 | RUN poetry config virtualenvs.create false && \ 33 | poetry install --no-root --no-interaction --no-cache && \ 34 | rm -rf ~/.cache/pypoetry/cache/ && \ 35 | rm -rf ~/.cache/pypoetry/artifacts/ 36 | 37 | # Set the working directory 38 | WORKDIR $WORKSPACE_ROOT 39 | 40 | # Copy the 3-feature-pipeline and any other necessary directories 41 | COPY ./src/bonus_superlinked_rag . 42 | COPY ./src/core ./core 43 | 44 | # Set the PYTHONPATH environment variable 45 | ENV PYTHONPATH=/usr/src/app 46 | 47 | RUN chmod +x /usr/src/app/scripts/bytewax_entrypoint.sh 48 | 49 | # Command to run the Bytewax pipeline script 50 | CMD ["/usr/src/app/scripts/bytewax_entrypoint.sh"] 51 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | 2 | # --- Required settings even when working locally. --- 3 | 4 | # OpenAI API config 5 | OPENAI_MODEL_ID=gpt-4o-mini 6 | OPENAI_API_KEY= 7 | 8 | # Huggingface API config 9 | HUGGINGFACE_ACCESS_TOKEN= 10 | 11 | # Comet ML (during training and inference) 12 | COMET_API_KEY= 13 | COMET_WORKSPACE= # such as your Comet username 14 | 15 | # --- Required settings ONLY when using Qdrant Cloud and AWS SageMaker --- 16 | 17 | # Qdrant cloud vector database connection config 18 | USE_QDRANT_CLOUD=false 19 | QDRANT_CLOUD_URL= 20 | QDRANT_APIKEY= 21 | 22 | # AWS authentication config 23 | AWS_ARN_ROLE= 24 | AWS_REGION=eu-central-1 25 | AWS_ACCESS_KEY= 26 | AWS_SECRET_KEY= 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/clarify-concept.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Clarify concept 3 | about: Clarify concept 4 | title: Clarify concept 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | Ask anything about the course. 
11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/questions-about-the-course-material.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Questions about the course material 3 | about: Questions about the course material 4 | title: Questions about the course material 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | Ask anything about the course material. 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/technical-troubleshooting-or-bugs.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Technical troubleshooting or bugs 3 | about: Technical troubleshooting or bugs 4 | title: Technical troubleshooting or bugs 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug or technical issue** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/workflows/crawler.yml: -------------------------------------------------------------------------------- 1 | name: Crawlers 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | name: Build & Push Docker Image 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout Code 14 | uses: actions/checkout@v2 15 | - name: Configure AWS credentials 16 | uses: aws-actions/configure-aws-credentials@v1 17 | with: 18 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 19 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 20 | aws-region: ${{ secrets.AWS_REGION }} 21 | - name: Login to Amazon ECR 22 | id: login-ecr 23 | uses: aws-actions/amazon-ecr-login@v1 24 | - name: Build images & push to ECR 25 | id: build-image 26 | uses: docker/build-push-action@v4 27 | with: 28 | context: ./course/module-1 29 | file: ./course/module-1/Dockerfile 30 | tags: | 31 | ${{ steps.login-ecr.outputs.registry }}/crawler:${{ github.sha }} 32 | ${{ steps.login-ecr.outputs.registry }}/crawler:latest 33 | push: true 34 | 35 | deploy: 36 | name: Deploy Crawler 37 | runs-on: ubuntu-latest 38 | needs: build 39 | steps: 40 | - name: Configure AWS credentials 41 | uses: aws-actions/configure-aws-credentials@v1 42 | with: 43 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 44 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 45 | aws-region: ${{ secrets.AWS_REGION }} 46 | - name: Deploy Lambda Image 47 | id: deploy-lambda 48 | run: | 49 | echo "Updating lambda with new image version $ECR_REPOSITORY/crawler:$PROJECT_VERSION..." 
50 | aws lambda update-function-code \ 51 | --function-name "arn:aws:lambda:$AWS_REGION:$AWS_ACCOUNT_ID:function:crawler" \ 52 | --image-uri $ECR_REPOSITORY/crawler:$PROJECT_VERSION 53 | echo "Successfully updated lambda" 54 | env: 55 | AWS_REGION: ${{ secrets.AWS_REGION }} 56 | ECR_REPOSITORY: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_REGION }}.amazonaws.com 57 | PROJECT_VERSION: ${{ github.sha }} 58 | AWS_ACCOUNT_ID: ${{ secrets.AWS_ACCOUNT_ID }} 59 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11.8 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Decoding ML by Crafted Intelligence S.R.L. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /docker-compose-superlinked.yml: -------------------------------------------------------------------------------- 1 | services: 2 | mongo1: 3 | image: mongo:5 4 | container_name: llm-twin-mongo1 5 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30001"] 6 | volumes: 7 | - mongo-replica-1-data:/data/db 8 | ports: 9 | - "30001:30001" 10 | healthcheck: 11 | test: test $$(echo "rs.initiate({_id:'my-replica-set',members:[{_id:0,host:\"mongo1:30001\"},{_id:1,host:\"mongo2:30002\"},{_id:2,host:\"mongo3:30003\"}]}).ok || rs.status().ok" | mongo --port 30001 --quiet) -eq 1 12 | interval: 10s 13 | start_period: 30s 14 | restart: always 15 | networks: 16 | - server_default 17 | 18 | mongo2: 19 | image: mongo:5 20 | container_name: llm-twin-mongo2 21 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30002"] 22 | volumes: 23 | - mongo-replica-2-data:/data/db 24 | ports: 25 | - "30002:30002" 26 | restart: always 27 | networks: 28 | - server_default 29 | 30 | mongo3: 31 | image: mongo:5 32 | container_name: llm-twin-mongo3 33 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30003"] 34 | volumes: 35 | - mongo-replica-3-data:/data/db 36 | ports: 37 | - "30003:30003" 38 | restart: always 39 | networks: 40 | - server_default 41 | 42 | mq: 43 | image: rabbitmq:3-management-alpine 44 | container_name: llm-twin-mq 45 | ports: 46 | - "5672:5672" 47 | - "15672:15672" 48 | volumes: 49 | - ./rabbitmq/data/:/var/lib/rabbitmq/ 50 | - ./rabbitmq/log/:/var/log/rabbitmq 51 | healthcheck: 52 | test: ["CMD", "rabbitmqctl", "ping"] 53 | interval: 30s 54 | timeout: 10s 55 | retries: 5 56 | restart: always 57 | networks: 58 | - server_default 59 | 60 | data-crawlers: 61 | image: "llm-twin-data-crawlers" 62 | container_name: llm-twin-data-crawlers 63 | platform: "linux/amd64" 64 | build: 65 | context: . 66 | dockerfile: .docker/Dockerfile.data_crawlers 67 | env_file: 68 | - .env 69 | ports: 70 | - "9010:8080" 71 | depends_on: 72 | - mongo1 73 | - mongo2 74 | - mongo3 75 | networks: 76 | - server_default 77 | 78 | data_cdc: 79 | image: "llm-twin-data-cdc" 80 | container_name: llm-twin-data-cdc 81 | build: 82 | context: . 83 | dockerfile: .docker/Dockerfile.data_cdc 84 | env_file: 85 | - .env 86 | depends_on: 87 | - mongo1 88 | - mongo2 89 | - mongo3 90 | - mq 91 | networks: 92 | - server_default 93 | 94 | feature_pipeline: 95 | image: "llm-twin-feature-pipeline" 96 | container_name: llm-twin-feature-pipeline 97 | build: 98 | context: . 
99 | dockerfile: .docker/Dockerfile.feature_pipeline 100 | environment: 101 | BYTEWAX_PYTHON_FILE_PATH: "main:flow" 102 | DEBUG: "false" 103 | BYTEWAX_KEEP_CONTAINER_ALIVE: "false" 104 | env_file: 105 | - .env 106 | depends_on: 107 | - mq 108 | restart: always 109 | networks: 110 | - server_default 111 | 112 | volumes: 113 | mongo-replica-1-data: 114 | mongo-replica-2-data: 115 | mongo-replica-3-data: 116 | 117 | networks: 118 | server_default: 119 | external: true 120 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | mongo1: 3 | image: mongo:5 4 | container_name: llm-twin-mongo1 5 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30001"] 6 | volumes: 7 | - mongo-replica-1-data:/data/db 8 | ports: 9 | - "30001:30001" 10 | healthcheck: 11 | test: test $$(echo "rs.initiate({_id:'my-replica-set',members:[{_id:0,host:\"mongo1:30001\"},{_id:1,host:\"mongo2:30002\"},{_id:2,host:\"mongo3:30003\"}]}).ok || rs.status().ok" | mongo --port 30001 --quiet) -eq 1 12 | interval: 10s 13 | start_period: 30s 14 | restart: always 15 | 16 | mongo2: 17 | image: mongo:5 18 | container_name: llm-twin-mongo2 19 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30002"] 20 | volumes: 21 | - mongo-replica-2-data:/data/db 22 | ports: 23 | - "30002:30002" 24 | restart: always 25 | 26 | mongo3: 27 | image: mongo:5 28 | container_name: llm-twin-mongo3 29 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30003"] 30 | volumes: 31 | - mongo-replica-3-data:/data/db 32 | ports: 33 | - "30003:30003" 34 | restart: always 35 | 36 | mq: 37 | image: rabbitmq:3-management-alpine 38 | container_name: llm-twin-mq 39 | ports: 40 | - "5673:5672" 41 | - "15673:15672" 42 | volumes: 43 | - ./rabbitmq/data/:/var/lib/rabbitmq/ 44 | - ./rabbitmq/log/:/var/log/rabbitmq 45 | restart: always 46 | 47 | qdrant: 48 | image: qdrant/qdrant:latest 49 | container_name: llm-twin-qdrant 50 | ports: 51 | - "6333:6333" 52 | - "6334:6334" 53 | expose: 54 | - "6333" 55 | - "6334" 56 | - "6335" 57 | volumes: 58 | - qdrant-data:/qdrant_data 59 | restart: always 60 | 61 | data-crawlers: 62 | image: "llm-twin-data-crawlers" 63 | container_name: llm-twin-data-crawlers 64 | platform: "linux/amd64" 65 | build: 66 | context: . 67 | dockerfile: .docker/Dockerfile.data_crawlers 68 | env_file: 69 | - .env 70 | ports: 71 | - "9010:8080" 72 | depends_on: 73 | - mongo1 74 | - mongo2 75 | - mongo3 76 | 77 | data_cdc: 78 | image: "llm-twin-data-cdc" 79 | container_name: llm-twin-data-cdc 80 | build: 81 | context: . 82 | dockerfile: .docker/Dockerfile.data_cdc 83 | env_file: 84 | - .env 85 | depends_on: 86 | - mongo1 87 | - mongo2 88 | - mongo3 89 | - mq 90 | 91 | feature_pipeline: 92 | image: "llm-twin-feature-pipeline" 93 | container_name: llm-twin-feature-pipeline 94 | build: 95 | context: . 
96 | dockerfile: .docker/Dockerfile.feature_pipeline 97 | environment: 98 | BYTEWAX_PYTHON_FILE_PATH: "main:flow" 99 | DEBUG: "false" 100 | BYTEWAX_KEEP_CONTAINER_ALIVE: "true" 101 | env_file: 102 | - .env 103 | depends_on: 104 | - mq 105 | - qdrant 106 | restart: always 107 | 108 | volumes: 109 | mongo-replica-1-data: 110 | mongo-replica-2-data: 111 | mongo-replica-3-data: 112 | qdrant-data: 113 | -------------------------------------------------------------------------------- /media/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/architecture.png -------------------------------------------------------------------------------- /media/cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/cover.png -------------------------------------------------------------------------------- /media/fine-tuning-workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/fine-tuning-workflow.png -------------------------------------------------------------------------------- /media/llm_engineers_handbook_cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/llm_engineers_handbook_cover.png -------------------------------------------------------------------------------- /media/qdrant-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/qdrant-example.png -------------------------------------------------------------------------------- /media/sponsors/bytewax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/sponsors/bytewax.png -------------------------------------------------------------------------------- /media/sponsors/comet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/sponsors/comet.png -------------------------------------------------------------------------------- /media/sponsors/qwak.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/sponsors/qwak.png -------------------------------------------------------------------------------- /media/ui-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/ui-example.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "llm-twin-course" 3 | description = "" 4 | version = "0.1.0" 5 | authors = [ 6 | "Paul 
Iusztin ", 7 | "Vesa Alexandru ", 8 | "Razvant Alexandru ", 9 | "Rares Istoc ", 10 | "Vlad Adumitracesei ", 11 | "Anca Muscalagiu " 12 | ] 13 | package-mode = false 14 | readme = "README.md" 15 | 16 | [tool.ruff] 17 | line-length = 88 18 | select = [ 19 | "F401", 20 | "F403", 21 | ] 22 | 23 | [tool.poetry.dependencies] 24 | python = "~3.11" 25 | pydantic = "^2.6.3" 26 | pydantic-settings = "^2.2.0" 27 | pika = "^1.3.2" 28 | qdrant-client = "^1.8.0" 29 | aws-lambda-powertools = "^2.38.1" 30 | selenium = "4.21.0" 31 | instructorembedding = "^1.0.1" 32 | numpy = "^1.26.4" 33 | gdown = "^5.1.0" 34 | pymongo = "^4.7.1" 35 | structlog = "^24.1.0" 36 | rich = "^13.7.1" 37 | comet-ml = "^3.41.0" 38 | opik = "1.0.1" 39 | ruff = "^0.4.3" 40 | pandas = "^2.0.3" 41 | datasets = "^2.19.1" 42 | scikit-learn = "^1.4.2" 43 | unstructured = "^0.14.2" 44 | litellm = "^1.50.4" 45 | langchain = "^0.2.11" 46 | langchain-openai = "^0.1.3" 47 | langchain-community = "^0.2.11" 48 | html2text = "^2024.2.26" 49 | huggingface-hub = "0.25.1" 50 | sagemaker = ">=2.232.2" 51 | sentence-transformers = "^2.2.2" 52 | gradio = "^5.5.0" 53 | 54 | [tool.poetry.group.feature_pipeline.dependencies] 55 | bytewax = "0.18.2" 56 | 57 | [tool.poetry.group.superlinked_rag.dependencies] 58 | superlinked = "^7.2.1" 59 | 60 | [build-system] 61 | requires = ["poetry-core"] 62 | build-backend = "poetry.core.masonry.api" 63 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/README.md: -------------------------------------------------------------------------------- 1 | # Dependencies 2 | 3 | - [Docker](https://www.docker.com/) 4 | - [Poetry](https://python-poetry.org/) 5 | - [PyEnv](https://github.com/pyenv/pyenv) 6 | - [Python](https://www.python.org/) 7 | - [GNU Make](https://www.gnu.org/software/make/) 8 | 9 | # Install 10 | 11 | ## 1. Start the Superlinked server 12 | 13 | Make sure you have `pyenv` installed. For example, on macOS you can install it with: 14 | ```shell 15 | brew update 16 | brew install pyenv 17 | ``` 18 | 19 | Now, let's start the Superlinked server by running the following commands: 20 | ```shell 21 | # Create a virtual environment and install all necessary dependencies to deploy the server. 22 | cd src/bonus_superlinked_rag/server 23 | ./tools/init-venv.sh 24 | cd runner 25 | source "$(poetry env info --path)/bin/activate" 26 | cd .. 27 | 28 | # Make sure your Docker engine is running (e.g., start the Docker Desktop app). 29 | ./tools/deploy.py up 30 | ``` 31 | 32 | > [!NOTE] 33 | > After the server has started, you can check that it works and explore its API at http://localhost:8080/docs/ 34 | 35 | You can test that the Superlinked server started successfully by running the following command from the `root directory` of the `llm-twin-course`: 36 | ``` 37 | make test-superlinked-server 38 | ``` 39 | You should see that some mock data has been sent to the Superlinked server and that it was queried successfully. 40 | 41 | ## 2. Start the rest of the infrastructure 42 | 43 | After your Superlinked server has started successfully, run the following from the root of the repository to start all the components needed to run the Superlinked-powered LLM Twin project locally: 44 | ```shell 45 | make local-start-superlinked 46 | ``` 47 | 48 | > [!IMPORTANT] 49 | > Before starting, ensure you have your `.env` file filled with everything required to run the system. 
50 | > 51 | > For more details on setting up the local infrastructure, you can check out the course's main [INSTALL_AND_USAGE](https://github.com/decodingml/llm-twin-course/blob/main/INSTALL_AND_USAGE.md) document. 52 | 53 | To stop the local infrastructure, run: 54 | ```shell 55 | make local-stop-superlinked 56 | ``` 57 | 58 | > [!NOTE] 59 | > After running the ingestion pipeline, you can visualize what's inside the Redis vector DB at http://localhost:8001/redis-stack/browser 60 | 61 | 62 | # Usage 63 | 64 | To trigger the ingestion, run: 65 | ```shell 66 | make local-test-github 67 | # #OR 68 | make local-test-medium 69 | ``` 70 | You can use other GitHub or Medium links to populate the vector DB with more data. 71 | 72 | To call the retrieval module and query the Superlinked server & vector DB, run: 73 | ```shell 74 | make local-test-retriever-superlinked 75 | ``` 76 | 77 | > [!IMPORTANT] 78 | > You can check out the main [INSTALL_AND_USAGE](https://github.com/decodingml/llm-twin-course/blob/main/INSTALL_AND_USAGE.md) document of the course for more details on an end-to-end flow. 79 | 80 | 81 | # Next steps 82 | 83 | If you enjoyed our [Superlinked](https://github.com/superlinked/superlinked?utm_source=community&utm_medium=github&utm_campaign=oscourse) bonus series, we recommend checking out their site for more examples. As Superlinked is not just a RAG tool but a general vector compute engine, you can build other awesome stuff with it, such as recommender systems. 84 | 85 | → 🔗 More on [Superlinked](https://github.com/superlinked/superlinked?utm_source=community&utm_medium=github&utm_campaign=oscourse) ← 86 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/config.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings 2 | 3 | 4 | class Settings(BaseSettings): 5 | # Embeddings config 6 | EMBEDDING_MODEL_ID: str = "BAAI/bge-small-en-v1.5" 7 | 8 | # MQ config 9 | RABBITMQ_DEFAULT_USERNAME: str = "guest" 10 | RABBITMQ_DEFAULT_PASSWORD: str = "guest" 11 | RABBITMQ_HOST: str = "mq" # or localhost if running outside Docker 12 | RABBITMQ_PORT: int = 5672 13 | RABBITMQ_QUEUE_NAME: str = "default" 14 | 15 | # Superlinked 16 | SUPERLINKED_SERVER_URL: str = ( 17 | "http://executor:8080" # # or http://localhost:8080 if running outside Docker 18 | ) 19 | 20 | # OpenAI 21 | OPENAI_MODEL_ID: str = "gpt-4o-mini" 22 | OPENAI_API_KEY: str | None = None 23 | 24 | 25 | settings = Settings() 26 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/data_flow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/data_flow/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/data_flow/stream_input.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | from datetime import datetime 4 | from typing import Generic, Iterable, List, Optional, TypeVar 5 | 6 | from bytewax.inputs import FixedPartitionedSource, StatefulSourcePartition 7 | from config import settings 8 | from mq import RabbitMQConnection 9 | from utils.logging import get_logger 10 | 11 | logger = get_logger(__name__) 12 | 13 | DataT = TypeVar("DataT") 14 | MessageT = 
TypeVar("MessageT") 15 | 16 | 17 | class RabbitMQSource(FixedPartitionedSource, Generic[DataT, MessageT]): 18 | def list_parts(self) -> List[str]: 19 | return ["single partition"] 20 | 21 | def build_part( 22 | self, now: datetime, for_part: str, resume_state: MessageT | None = None 23 | ) -> StatefulSourcePartition[DataT, MessageT]: 24 | return RabbitMQPartition(queue_name=settings.RABBITMQ_QUEUE_NAME, resume_state=resume_state) 25 | 26 | 27 | class RabbitMQPartition(StatefulSourcePartition, Generic[DataT, MessageT]): 28 | """ 29 | Class responsible for creating a connection between Bytewax and RabbitMQ that facilitates the transfer of data from the message queue to the Bytewax streaming pipeline. 30 | Inherits StatefulSourcePartition for snapshot functionality, which enables saving the state of the queue. 31 | """ 32 | 33 | def __init__(self, queue_name: str, resume_state: MessageT | None = None) -> None: 34 | self._in_flight_msg_ids = resume_state or set() 35 | self.queue_name = queue_name 36 | self.connection = RabbitMQConnection() 37 | 38 | try: 39 | self.connection.connect() 40 | self.channel = self.connection.get_channel() 41 | except Exception: 42 | logger.warning( 43 | f"Error while trying to connect to the queue and get the current channel {self.queue_name}", 44 | ) 45 | 46 | def next_batch(self, sched: Optional[datetime]) -> Iterable[DataT]: 47 | try: 48 | method_frame, header_frame, body = self.channel.basic_get( 49 | queue=self.queue_name, auto_ack=True 50 | ) 51 | except Exception: 52 | logger.warning( 53 | f"Error while fetching message from queue: {self.queue_name}", 54 | ) 55 | time.sleep(10) # Sleep for 10 seconds before retrying to access the queue. 56 | 57 | self.connection.connect() 58 | self.channel = self.connection.get_channel() 59 | 60 | return [] 61 | 62 | if method_frame: 63 | message_id = method_frame.delivery_tag 64 | self._in_flight_msg_ids.add(message_id) 65 | 66 | return [json.loads(body)] 67 | else: 68 | return [] 69 | 70 | def snapshot(self) -> MessageT: 71 | return self._in_flight_msg_ids 72 | 73 | def garbage_collect(self, state): 74 | closed_in_flight_msg_ids = state 75 | for msg_id in closed_in_flight_msg_ids: 76 | self.channel.basic_ack(delivery_tag=msg_id) 77 | self._in_flight_msg_ids.remove(msg_id) 78 | 79 | def close(self): 80 | self.channel.close() 81 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/data_flow/stream_output.py: -------------------------------------------------------------------------------- 1 | from bytewax.outputs import DynamicSink, StatelessSinkPartition 2 | from models.documents import Document 3 | from superlinked_client import SuperlinkedClient 4 | from tqdm import tqdm 5 | from utils.logging import get_logger 6 | 7 | logger = get_logger(__name__) 8 | 9 | 10 | class SuperlinkedOutputSink(DynamicSink): 11 | def __init__(self, client: SuperlinkedClient) -> None: 12 | self._client = client 13 | 14 | def build(self, worker_index: int, worker_count: int) -> StatelessSinkPartition: 15 | return SuperlinkedSinkPartition(client=self._client) 16 | 17 | 18 | class SuperlinkedSinkPartition(StatelessSinkPartition): 19 | def __init__(self, client: SuperlinkedClient): 20 | self._client = client 21 | 22 | def write_batch(self, items: list[Document]) -> None: 23 | for item in tqdm(items, desc="Sending items to Superlinked..."): 24 | match item.type: 25 | case "repositories": 26 | self._client.ingest_repository(item) 27 | case "posts": 28 | self._client.ingest_post(item) 29 | case "articles": 30 | self._client.ingest_article(item) 
31 | case _: 32 | logger.error(f"Unknown item type: {item.type}") 33 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/data_logic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/data_logic/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/data_logic/cleaning_data_handlers.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from models.documents import ArticleDocument, Document, PostDocument, RepositoryDocument 4 | from models.raw import ArticleRawModel, PostsRawModel, RawModel, RepositoryRawModel 5 | from utils.cleaning import clean_text 6 | 7 | from .splitters import split_text 8 | 9 | 10 | class CleaningDataHandler(ABC): 11 | """ 12 | Abstract class for all cleaning data handlers. 13 | All data transformations logic for the cleaning step is done here 14 | """ 15 | 16 | @abstractmethod 17 | def clean(self, data_model: RawModel) -> list[Document]: 18 | pass 19 | 20 | 21 | class PostCleaningHandler(CleaningDataHandler): 22 | def clean(self, data_model: PostsRawModel) -> list[PostDocument]: 23 | documents = [] 24 | cleaned_text = clean_text("".join(data_model.content.values())) 25 | for post_subsection in split_text(cleaned_text): 26 | documents.append( 27 | PostDocument( 28 | id=data_model.id, 29 | platform=data_model.platform, 30 | content=post_subsection, 31 | author_id=data_model.author_id, 32 | type=data_model.type, 33 | ) 34 | ) 35 | 36 | return documents 37 | 38 | 39 | class ArticleCleaningHandler(CleaningDataHandler): 40 | def clean(self, data_model: ArticleRawModel) -> list[ArticleDocument]: 41 | documents = [] 42 | cleaned_text = clean_text("".join(data_model.content.values())) 43 | for article_subsection in split_text(cleaned_text): 44 | documents.append( 45 | ArticleDocument( 46 | id=data_model.id, 47 | platform=data_model.platform, 48 | link=data_model.link, 49 | content=article_subsection, 50 | author_id=data_model.author_id, 51 | type=data_model.type, 52 | ) 53 | ) 54 | 55 | return documents 56 | 57 | 58 | class RepositoryCleaningHandler(CleaningDataHandler): 59 | def clean(self, data_model: RepositoryRawModel) -> list[RepositoryDocument]: 60 | documents = [] 61 | for file_name, file_content in data_model.content.items(): 62 | cleaned_file_content = clean_text(file_content) 63 | for file_subsection in split_text(cleaned_file_content): 64 | documents.append( 65 | RepositoryDocument( 66 | id=data_model.id, 67 | platform=data_model.platform, 68 | name=f"{data_model.name}:{file_name}", 69 | link=data_model.link, 70 | content=file_subsection, 71 | author_id=data_model.owner_id, 72 | type=data_model.type, 73 | ) 74 | ) 75 | 76 | return documents 77 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/data_logic/dispatchers.py: -------------------------------------------------------------------------------- 1 | from data_logic.cleaning_data_handlers import ( 2 | ArticleCleaningHandler, 3 | CleaningDataHandler, 4 | PostCleaningHandler, 5 | RepositoryCleaningHandler, 6 | ) 7 | from models.documents import Document 8 | from models.raw import ArticleRawModel, PostsRawModel, RawModel, RepositoryRawModel 9 | from utils.logging import get_logger 10 | 11 | 
logger = get_logger(__name__) 12 | 13 | 14 | class RawDispatcher: 15 | @staticmethod 16 | def handle_mq_message(message: dict) -> RawModel: 17 | data_type = message.get("type") 18 | 19 | logger.info("Received raw message.", data_type=data_type) 20 | 21 | if data_type == "posts": 22 | return PostsRawModel(**message) 23 | elif data_type == "articles": 24 | return ArticleRawModel(**message) 25 | elif data_type == "repositories": 26 | return RepositoryRawModel(**message) 27 | else: 28 | raise ValueError(f"Unsupported data type: {data_type}") 29 | 30 | 31 | class CleaningHandlerFactory: 32 | @staticmethod 33 | def create_handler(data_type: str) -> CleaningDataHandler: 34 | if data_type == "posts": 35 | return PostCleaningHandler() 36 | elif data_type == "articles": 37 | return ArticleCleaningHandler() 38 | elif data_type == "repositories": 39 | return RepositoryCleaningHandler() 40 | else: 41 | raise ValueError("Unsupported data type") 42 | 43 | 44 | class CleaningDispatcher: 45 | cleaning_factory = CleaningHandlerFactory() 46 | 47 | @classmethod 48 | def dispatch_cleaner(cls, data_model: RawModel) -> list[Document]: 49 | logger.info("Cleaning data.", data_type=data_model.type) 50 | 51 | data_type = data_model.type 52 | handler = cls.cleaning_factory.create_handler(data_type) 53 | cleaned_models = handler.clean(data_model) 54 | 55 | logger.info( 56 | "Data cleaned successfully.", 57 | data_type=data_type, 58 | len_cleaned_documents=len(cleaned_models), 59 | len_content=sum([len(doc.content) for doc in cleaned_models]), 60 | ) 61 | 62 | return cleaned_models 63 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/data_logic/splitters.py: -------------------------------------------------------------------------------- 1 | from langchain_text_splitters import RecursiveCharacterTextSplitter 2 | 3 | 4 | def split_text(text: str) -> list[str]: 5 | character_splitter = RecursiveCharacterTextSplitter( 6 | separators=["\n\n"], chunk_size=2000, chunk_overlap=0 7 | ) 8 | chunks = character_splitter.split_text(text) 9 | 10 | return chunks 11 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/llm/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/llm/chain.py: -------------------------------------------------------------------------------- 1 | from langchain.chains.llm import LLMChain 2 | from langchain.prompts import PromptTemplate 3 | 4 | 5 | class GeneralChain: 6 | @staticmethod 7 | def get_chain(llm, template: PromptTemplate, output_key: str, verbose=True): 8 | return LLMChain( 9 | llm=llm, prompt=template, output_key=output_key, verbose=verbose 10 | ) 11 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/llm/prompt_templates.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from langchain.prompts import PromptTemplate 4 | from pydantic import BaseModel 5 | 6 | 7 | class BasePromptTemplate(ABC, BaseModel): 8 | @abstractmethod 9 | def create_template(self) -> PromptTemplate: 10 | pass 11 | 12 | 13 | class QueryExpansionTemplate(BasePromptTemplate): 14 | prompt: 
str = """You are an AI language model assistant. Your task is to generate {to_expand_to_n} 15 | different versions of the given user question to retrieve relevant documents from a vector 16 | database. By generating multiple perspectives on the user question, your goal is to help 17 | the user overcome some of the limitations of the distance-based similarity search. 18 | Provide these alternative questions separated by '{separator}'. 19 | Original question: {question}""" 20 | 21 | @property 22 | def separator(self) -> str: 23 | return "#next-question#" 24 | 25 | def create_template(self, to_expand_to_n: int) -> PromptTemplate: 26 | return PromptTemplate( 27 | template=self.prompt, 28 | input_variables=["question"], 29 | partial_variables={ 30 | "separator": self.separator, 31 | "to_expand_to_n": to_expand_to_n, 32 | }, 33 | ) 34 | 35 | 36 | class SelfQueryTemplate(BasePromptTemplate): 37 | prompt: str = """You are an AI language model assistant. Your task is to extract information from a user question. 38 | The required information that needs to be extracted is the user or author id. 39 | Your response should consist of only the extracted id (e.g. 1345256), nothing else. 40 | If you cannot find the author id, return the string "None". 41 | User question: {question}""" 42 | 43 | def create_template(self) -> PromptTemplate: 44 | return PromptTemplate(template=self.prompt, input_variables=["question"]) 45 | 46 | 47 | class RerankingTemplate(BasePromptTemplate): 48 | prompt: str = """You are an AI language model assistant. Your task is to rerank passages related to a query 49 | based on their relevance. 50 | The most relevant passages should be put at the beginning. If no passages are relevant, keep the original order. Even if you find duplicates, return them all. 51 | You should pick at most {keep_top_k} passages. When no passages are relevant, still pick at most the top {keep_top_k} passages. 52 | The provided and reranked documents are separated by '{separator}'. 53 | 54 | The following are passages related to this query: {question}. 
55 | 56 | Passages: 57 | {passages} 58 | """ 59 | 60 | def create_template(self, keep_top_k: int) -> PromptTemplate: 61 | return PromptTemplate( 62 | template=self.prompt, 63 | input_variables=["question", "passages"], 64 | partial_variables={"keep_top_k": keep_top_k, "separator": self.separator}, 65 | ) 66 | 67 | @property 68 | def separator(self) -> str: 69 | return "\n#next-document#\n" 70 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/main.py: -------------------------------------------------------------------------------- 1 | import bytewax.operators as op 2 | from bytewax.dataflow import Dataflow 3 | 4 | from data_flow.stream_input import RabbitMQSource 5 | from data_flow.stream_output import SuperlinkedOutputSink 6 | from data_logic.dispatchers import ( 7 | CleaningDispatcher, 8 | RawDispatcher, 9 | ) 10 | from superlinked_client import SuperlinkedClient 11 | 12 | 13 | flow = Dataflow("Streaming RAG feature pipeline") 14 | stream = op.input("input", flow, RabbitMQSource()) 15 | stream = op.map("raw", stream, RawDispatcher.handle_mq_message) 16 | stream = op.map("clean", stream, CleaningDispatcher.dispatch_cleaner) 17 | stream = op.flatten("flatten_final_output", stream) 18 | op.output( 19 | "superlinked_output", 20 | stream, 21 | SuperlinkedOutputSink(client=SuperlinkedClient()), 22 | ) 23 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import documents, raw, utils 2 | 3 | __all__ = ["documents", "raw", "utils"] 4 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/models/documents.py: -------------------------------------------------------------------------------- 1 | from typing_extensions import Annotated 2 | from pydantic import BaseModel, BeforeValidator 3 | 4 | 5 | class PostDocument(BaseModel): 6 | id: str 7 | platform: Annotated[str, BeforeValidator(str.lower)] 8 | content: str 9 | author_id: str 10 | type: str 11 | 12 | 13 | class ArticleDocument(BaseModel): 14 | id: str 15 | platform: Annotated[str, BeforeValidator(str.lower)] 16 | link: str 17 | content: str 18 | author_id: str 19 | type: str 20 | 21 | 22 | class RepositoryDocument(BaseModel): 23 | id: str 24 | platform: Annotated[str, BeforeValidator(str.lower)] 25 | name: str 26 | link: str 27 | content: str 28 | author_id: str 29 | type: str 30 | 31 | 32 | Document = PostDocument | ArticleDocument | RepositoryDocument 33 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/models/raw.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class RepositoryRawModel(BaseModel): 7 | id: str = Field(alias="entry_id") 8 | type: str 9 | platform: str = "github" 10 | name: str 11 | link: str 12 | content: dict 13 | owner_id: str 14 | 15 | 16 | class ArticleRawModel(BaseModel): 17 | id: str = Field(alias="entry_id") 18 | type: str 19 | platform: str 20 | link: str 21 | content: dict 22 | author_id: str 23 | 24 | 25 | class PostsRawModel(BaseModel): 26 | id: str = Field(alias="entry_id") 27 | type: str 28 | platform: str 29 | content: dict 30 | author_id: str | None = None 31 | image: Optional[str] = None 32 | 33 | 34 | RawModel = RepositoryRawModel | ArticleRawModel 
| PostsRawModel 35 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/models/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional 2 | 3 | import pandas as pd 4 | from pydantic import BaseModel 5 | 6 | from models.documents import Document 7 | 8 | 9 | def pydantic_models_to_dataframe( 10 | models: List[BaseModel], index_column: Optional[str] = "id" 11 | ) -> pd.DataFrame: 12 | """ 13 | Converts a list of Pydantic models to a Pandas DataFrame. 14 | 15 | Args: 16 | models (List[BaseModel]): List of Pydantic models. 17 | 18 | Returns: 19 | pd.DataFrame: DataFrame containing the data from the Pydantic models. 20 | """ 21 | 22 | if not models: 23 | return pd.DataFrame() 24 | 25 | # Convert each model to a dictionary and create a list of dictionaries 26 | data = [model.model_dump() for model in models] 27 | 28 | # Create a DataFrame from the list of dictionaries 29 | df = pd.DataFrame(data) 30 | 31 | if index_column in df.columns: 32 | df["index"] = df[index_column] 33 | else: 34 | raise RuntimeError(f"Index column '{index_column}' not found in DataFrame.") 35 | 36 | return df 37 | 38 | 39 | def group_by_type(documents: list[Document]) -> Dict[str, list[Document]]: 40 | return _group_by(documents, selector=lambda doc: doc.type) 41 | 42 | 43 | def _group_by(documents: list[Document], selector: Callable) -> Dict[Any, list]: 44 | grouped = {} 45 | for doc in documents: 46 | key = selector(doc) 47 | 48 | if key not in grouped: 49 | grouped[key] = [] 50 | grouped[key].append(doc) 51 | 52 | return grouped 53 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/mq.py: -------------------------------------------------------------------------------- 1 | import pika 2 | 3 | from utils.logging import get_logger 4 | from config import settings 5 | 6 | logger = get_logger(__name__) 7 | 8 | 9 | class RabbitMQConnection: 10 | _instance = None 11 | 12 | def __new__( 13 | cls, 14 | host: str | None = None, 15 | port: int | None = None, 16 | username: str | None = None, 17 | password: str | None = None, 18 | virtual_host: str = "/", 19 | ): 20 | if not cls._instance: 21 | cls._instance = super().__new__(cls) 22 | 23 | return cls._instance 24 | 25 | def __init__( 26 | self, 27 | host: str | None = None, 28 | port: int | None = None, 29 | username: str | None = None, 30 | password: str | None = None, 31 | virtual_host: str = "/", 32 | fail_silently: bool = False, 33 | **kwargs, 34 | ): 35 | self.host = host or settings.RABBITMQ_HOST 36 | self.port = port or settings.RABBITMQ_PORT 37 | self.username = username or settings.RABBITMQ_DEFAULT_USERNAME 38 | self.password = password or settings.RABBITMQ_DEFAULT_PASSWORD 39 | self.virtual_host = virtual_host 40 | self.fail_silently = fail_silently 41 | self._connection = None 42 | 43 | def __enter__(self): 44 | self.connect() 45 | return self 46 | 47 | def __exit__(self, exc_type, exc_val, exc_tb): 48 | self.close() 49 | 50 | def connect(self) -> None: 51 | logger.info("Trying to connect to RabbitMQ.", host=self.host, port=self.port) 52 | 53 | try: 54 | credentials = pika.PlainCredentials(self.username, self.password) 55 | self._connection = pika.BlockingConnection( 56 | pika.ConnectionParameters( 57 | host=self.host, 58 | port=self.port, 59 | virtual_host=self.virtual_host, 60 | credentials=credentials, 61 | ) 62 | ) 63 | except pika.exceptions.AMQPConnectionError 
as e: 64 | logger.warning("Failed to connect to RabbitMQ.") 65 | 66 | if not self.fail_silently: 67 | raise e 68 | 69 | def publish_message(self, data: str, queue: str): 70 | channel = self.get_channel() 71 | channel.queue_declare( 72 | queue=queue, durable=True, exclusive=False, auto_delete=False 73 | ) 74 | channel.confirm_delivery() 75 | 76 | try: 77 | channel.basic_publish( 78 | exchange="", routing_key=queue, body=data, mandatory=True 79 | ) 80 | logger.info( 81 | "Sent message successfully.", queue_type="RabbitMQ", queue_name=queue 82 | ) 83 | except pika.exceptions.UnroutableError: 84 | logger.info( 85 | "Failed to send the message.", queue_type="RabbitMQ", queue_name=queue 86 | ) 87 | 88 | def is_connected(self) -> bool: 89 | return self._connection is not None and self._connection.is_open 90 | 91 | def get_channel(self): 92 | if self.is_connected(): 93 | return self._connection.channel() 94 | 95 | def close(self): 96 | if self.is_connected(): 97 | self._connection.close() 98 | self._connection = None 99 | 100 | logger.info("Closed RabbitMQ connection.") 101 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/rag/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/rag/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/rag/query_expanison.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | 3 | from llm.chain import GeneralChain 4 | from llm.prompt_templates import QueryExpansionTemplate 5 | from config import settings 6 | 7 | 8 | class QueryExpansion: 9 | @staticmethod 10 | def generate_response(query: str, to_expand_to_n: int) -> list[str]: 11 | query_expansion_template = QueryExpansionTemplate() 12 | prompt_template = query_expansion_template.create_template(to_expand_to_n) 13 | model = ChatOpenAI(model=settings.OPENAI_MODEL_ID, temperature=0) 14 | 15 | chain = GeneralChain().get_chain( 16 | llm=model, output_key="expanded_queries", template=prompt_template 17 | ) 18 | 19 | response = chain.invoke({"question": query}) 20 | result = response["expanded_queries"] 21 | 22 | queries = result.strip().split(query_expansion_template.separator) 23 | stripped_queries = [ 24 | stripped_item for item in queries if (stripped_item := item.strip()) 25 | ] 26 | 27 | return stripped_queries 28 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/rag/reranking.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | 3 | from llm.chain import GeneralChain 4 | from llm.prompt_templates import RerankingTemplate 5 | from config import settings 6 | 7 | 8 | class Reranker: 9 | @staticmethod 10 | def generate_response( 11 | query: str, passages: list[str], keep_top_k: int 12 | ) -> list[str]: 13 | reranking_template = RerankingTemplate() 14 | prompt_template = reranking_template.create_template(keep_top_k=keep_top_k) 15 | 16 | model = ChatOpenAI(model=settings.OPENAI_MODEL_ID) 17 | chain = GeneralChain().get_chain( 18 | llm=model, output_key="rerank", template=prompt_template 19 | ) 20 | 21 | stripped_passages = [ 22 | stripped_item for item in passages if (stripped_item := item.strip()) 23 | ] 24 | 
passages = reranking_template.separator.join(stripped_passages) 25 | response = chain.invoke({"question": query, "passages": passages}) 26 | 27 | result = response["rerank"] 28 | reranked_passages = result.strip().split(reranking_template.separator) 29 | stripped_passages = [ 30 | stripped_item 31 | for item in reranked_passages 32 | if (stripped_item := item.strip()) 33 | ] 34 | 35 | return stripped_passages 36 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/rag/self_query.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | from llm.chain import GeneralChain 3 | from llm.prompt_templates import SelfQueryTemplate 4 | from config import settings 5 | 6 | 7 | class SelfQuery: 8 | @staticmethod 9 | def generate_response(query: str) -> str | None: 10 | prompt = SelfQueryTemplate().create_template() 11 | model = ChatOpenAI(model=settings.OPENAI_MODEL_ID, temperature=0) 12 | 13 | chain = GeneralChain().get_chain( 14 | llm=model, output_key="metadata_filter_value", template=prompt 15 | ) 16 | 17 | response = chain.invoke({"question": query}) 18 | result = response.get("metadata_filter_value", "none") 19 | 20 | if result.lower() == "none": 21 | return None 22 | 23 | return result 24 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/retriever.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | from langchain.globals import set_verbose 3 | from rag.retriever import VectorRetriever 4 | 5 | from utils.logging import get_logger 6 | 7 | set_verbose(False) 8 | 9 | logger = get_logger(__name__) 10 | 11 | if __name__ == "__main__": 12 | load_dotenv() 13 | query = """ 14 | I am author_4. 15 | 16 | Could you please draft a LinkedIn post discussing RAG systems? 17 | I'm particularly interested in how RAG works and how it is integrated with vector DBs and large language models (LLMs). 18 | """ 19 | retriever = VectorRetriever(query=query) 20 | hits = retriever.retrieve_top_k(k=6, to_expand_to_n_queries=3) 21 | 22 | reranked_hits = retriever.rerank(documents=hits, keep_top_k=3) 23 | 24 | logger.info("Reranked hits:") 25 | for rank, hit in enumerate(reranked_hits): 26 | logger.info(f"{rank}: {hit[:100]}...") 27 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/scripts/bytewax_entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$DEBUG" = true ] 4 | then 5 | python -m bytewax.run "tools.run_real_time:build_flow(debug=True)" 6 | else 7 | if [ "$BYTEWAX_PYTHON_FILE_PATH" = "" ] 8 | then 9 | echo 'BYTEWAX_PYTHON_FILE_PATH is not set. Exiting...' 10 | exit 1 11 | fi 12 | python -m bytewax.run $BYTEWAX_PYTHON_FILE_PATH 13 | fi 14 | 15 | 16 | echo 'Process ended.' 
17 | 18 | if [ "$BYTEWAX_KEEP_CONTAINER_ALIVE" = true ] 19 | then 20 | echo 'Keeping container alive...'; 21 | while :; do sleep 1; done 22 | fi -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/compose.yaml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | poller: 5 | build: 6 | context: runner 7 | dockerfile: poller/Dockerfile 8 | mem_limit: 256m 9 | cpus: '0.5' 10 | volumes: 11 | - ./cache:/code/data 12 | - ./config:/code/config 13 | - ./src:/src 14 | 15 | executor: 16 | depends_on: 17 | - poller 18 | - redis 19 | build: 20 | context: runner 21 | dockerfile: executor/Dockerfile 22 | volumes: 23 | - ./cache:/code/data 24 | environment: 25 | - APP_MODULE_PATH=data.src.app 26 | ports: 27 | - 8080:8080 28 | env_file: 29 | - runner/executor/.env 30 | 31 | redis: 32 | image: redis/redis-stack:latest 33 | ports: 34 | - "6379:6379" 35 | - "8001:8001" 36 | volumes: 37 | - redis-data:/data 38 | 39 | volumes: 40 | redis-data: 41 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/config/aws_credentials.json: -------------------------------------------------------------------------------- 1 | { 2 | "accessKeyId": , 3 | "region": , 4 | "secretAccessKey": 5 | } -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/config/config.yaml: -------------------------------------------------------------------------------- 1 | app_location: local 2 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/config/gcp_credentials.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "service_account", 3 | "project_id": "PROJECT_ID", 4 | "private_key_id": "KEY_ID", 5 | "private_key": "-----BEGIN PRIVATE KEY-----\nPRIVATE_KEY\n-----END PRIVATE KEY-----\n", 6 | "client_email": "SERVICE_ACCOUNT_EMAIL", 7 | "client_id": "CLIENT_ID", 8 | "auth_uri": "https://accounts.google.com/o/oauth2/auth", 9 | "token_uri": "https://accounts.google.com/o/oauth2/token", 10 | "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", 11 | "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/SERVICE_ACCOUNT_EMAIL" 12 | } -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/docs/bucket.md: -------------------------------------------------------------------------------- 1 | ## What are Storage Buckets? 2 | 3 | Storage buckets are fundamental entities used to store data in cloud storage services. They are essentially containers for data where you can upload, download, and manage files. Two of the most popular cloud storage services are Amazon S3 (Simple Storage Service) and Google Cloud Storage (GCS). 4 | 5 | ### Amazon S3 6 | 7 | Amazon S3 buckets are used to store objects, which consist of data and its descriptive metadata. They are globally unique and addressable at the Uniform Resource Locator (URL) level. 8 | 9 | ### Google Cloud Storage (GCS) 10 | 11 | Google Cloud Storage buckets are similar to Amazon S3 buckets.
They are used to store data objects in Google Cloud. Each bucket is associated with a specific project, and you can choose to make the bucket data public or private. 12 | 13 | ## Purpose in Our System 14 | 15 | In our system, we use these storage buckets to store user code. When you write and save your code, it gets stored in either an S3 or GCS bucket. This allows us to securely store your code, retrieve it when needed, and even share it among multiple services or instances if required. 16 | 17 | ## How to Create Buckets 18 | 19 | ### Creating an Amazon S3 Bucket using AWS CLI 20 | 21 | To create an Amazon S3 bucket, you can use the AWS CLI (Command Line Interface). Here's how: 22 | 23 | 1. First, install the AWS CLI on your machine. You can find the installation instructions [here](https://aws.amazon.com/cli/). 24 | 2. Configure the AWS CLI with your credentials: 25 | 26 | ```bash 27 | aws configure 28 | ``` 29 | 30 | 3. Create a new S3 bucket: 31 | 32 | ```bash 33 | aws s3api create-bucket --bucket my-bucket-name --region us-west-2 34 | ``` 35 | 36 | Replace `my-bucket-name` with your desired bucket name and `us-west-2` with the AWS region you want to create your bucket in. For more information on S3 bucket creation, refer to the [AWS documentation](https://docs.aws.amazon.com/AmazonS3/latest/userguide/create-bucket-overview.html). 37 | 38 | ### Creating a Google Cloud Storage Bucket using `gsutil` 39 | 40 | To create a Google Cloud Storage bucket, you can use the `gsutil` tool. Here's how: 41 | 42 | 1. First, install the Google Cloud SDK, which includes the `gsutil` tool. You can find the installation instructions [here](https://cloud.google.com/sdk/docs/install). 43 | 2. Authenticate your account: 44 | 45 | ```bash 46 | gcloud auth login 47 | ``` 48 | 49 | 3. Create a new GCS bucket: 50 | 51 | ```bash 52 | gsutil mb -p project_id -c storage_class -l location gs://my-bucket-name 53 | ``` 54 | 55 | Replace `project_id` with the ID of your project, `storage_class` with the desired storage class for the bucket, `location` with the desired bucket location and `my-bucket-name` with your desired bucket name. For more information on GCS bucket creation, refer to the [Google Cloud Documentation](https://cloud.google.com/storage/docs/creating-buckets). 56 | 57 | Remember, the bucket names must be unique across all existing bucket names in Amazon S3 or Google Cloud Storage.
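For example, assuming the bucket `my-bucket-name` from the commands above and a Superlinked app file named `app.py` (both names are placeholders), uploading the code is a single copy with either CLI; the object path shown is only illustrative, so use whatever location your server configuration points the poller at:

```bash
# Upload the app file to the S3 bucket created above
aws s3 cp app.py s3://my-bucket-name/app.py

# Or copy it to the GCS bucket instead
gsutil cp app.py gs://my-bucket-name/app.py
```

On its next polling cycle, the poller should detect the new or modified object and download it for the executor.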
58 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/docs/dummy_app.py: -------------------------------------------------------------------------------- 1 | from superlinked.framework.common.schema.id_schema_object import IdField 2 | from superlinked.framework.common.schema.schema import schema 3 | from superlinked.framework.common.schema.schema_object import String 4 | from superlinked.framework.dsl.executor.rest.rest_configuration import RestQuery 5 | from superlinked.framework.dsl.executor.rest.rest_descriptor import RestDescriptor 6 | from superlinked.framework.dsl.executor.rest.rest_executor import RestExecutor 7 | from superlinked.framework.dsl.index.index import Index 8 | from superlinked.framework.dsl.query.param import Param 9 | from superlinked.framework.dsl.query.query import Query 10 | from superlinked.framework.dsl.registry.superlinked_registry import SuperlinkedRegistry 11 | from superlinked.framework.dsl.source.rest_source import RestSource 12 | from superlinked.framework.dsl.space.text_similarity_space import TextSimilaritySpace 13 | from superlinked.framework.dsl.storage.in_memory_vector_database import InMemoryVectorDatabase 14 | 15 | 16 | @schema 17 | class YourSchema: 18 | id: IdField 19 | attribute: String 20 | 21 | 22 | your_schema = YourSchema() 23 | 24 | text_space = TextSimilaritySpace(text=your_schema.attribute, model="model-name") 25 | 26 | index = Index(text_space) 27 | 28 | query = ( 29 | Query(index) 30 | .find(your_schema) 31 | .similar( 32 | text_space.text, 33 | Param("query_text"), 34 | ) 35 | ) 36 | 37 | your_source: RestSource = RestSource(your_schema) 38 | 39 | executor = RestExecutor( 40 | sources=[your_source], 41 | indices=[index], 42 | queries=[RestQuery(RestDescriptor("query"), query)], 43 | vector_database=InMemoryVectorDatabase(), 44 | ) 45 | 46 | SuperlinkedRegistry.register(executor) 47 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/docs/example/app.py: -------------------------------------------------------------------------------- 1 | from superlinked.framework.common.schema.id_schema_object import IdField 2 | from superlinked.framework.common.schema.schema import schema 3 | from superlinked.framework.common.schema.schema_object import String 4 | from superlinked.framework.dsl.executor.rest.rest_configuration import ( 5 | RestQuery, 6 | ) 7 | from superlinked.framework.dsl.executor.rest.rest_descriptor import RestDescriptor 8 | from superlinked.framework.dsl.executor.rest.rest_executor import RestExecutor 9 | from superlinked.framework.dsl.index.index import Index 10 | from superlinked.framework.dsl.query.param import Param 11 | from superlinked.framework.dsl.query.query import Query 12 | from superlinked.framework.dsl.registry.superlinked_registry import SuperlinkedRegistry 13 | from superlinked.framework.dsl.source.rest_source import RestSource 14 | from superlinked.framework.dsl.space.text_similarity_space import TextSimilaritySpace 15 | from superlinked.framework.dsl.storage.in_memory_vector_database import InMemoryVectorDatabase 16 | 17 | 18 | @schema 19 | class CarSchema: 20 | id: IdField 21 | make: String 22 | model: String 23 | 24 | 25 | car_schema = CarSchema() 26 | 27 | car_make_text_space = TextSimilaritySpace(text=car_schema.make, model="all-MiniLM-L6-v2") 28 | car_model_text_space = TextSimilaritySpace(text=car_schema.model, model="all-MiniLM-L6-v2") 29 | 30 | index = Index([car_make_text_space, 
car_model_text_space]) 31 | 32 | query = ( 33 | Query(index) 34 | .find(car_schema) 35 | .similar(car_make_text_space.text, Param("make")) 36 | .similar(car_model_text_space.text, Param("model")) 37 | .limit(Param("limit")) 38 | ) 39 | 40 | car_source: RestSource = RestSource(car_schema) 41 | 42 | executor = RestExecutor( 43 | sources=[car_source], 44 | indices=[index], 45 | queries=[RestQuery(RestDescriptor("query"), query)], 46 | vector_database=InMemoryVectorDatabase(), 47 | ) 48 | 49 | SuperlinkedRegistry.register(executor) 50 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/docs/mongodb/app_with_mongodb.py: -------------------------------------------------------------------------------- 1 | from superlinked.framework.common.schema.id_schema_object import IdField 2 | from superlinked.framework.common.schema.schema import schema 3 | from superlinked.framework.common.schema.schema_object import String 4 | from superlinked.framework.dsl.executor.rest.rest_configuration import ( 5 | RestQuery, 6 | ) 7 | from superlinked.framework.dsl.executor.rest.rest_descriptor import RestDescriptor 8 | from superlinked.framework.dsl.executor.rest.rest_executor import RestExecutor 9 | from superlinked.framework.dsl.index.index import Index 10 | from superlinked.framework.dsl.query.param import Param 11 | from superlinked.framework.dsl.query.query import Query 12 | from superlinked.framework.dsl.registry.superlinked_registry import SuperlinkedRegistry 13 | from superlinked.framework.dsl.source.rest_source import RestSource 14 | from superlinked.framework.dsl.space.text_similarity_space import TextSimilaritySpace 15 | from superlinked.framework.dsl.storage.mongo_vector_database import MongoVectorDatabase 16 | 17 | 18 | @schema 19 | class CarSchema: 20 | id: IdField 21 | make: String 22 | model: String 23 | 24 | 25 | car_schema = CarSchema() 26 | 27 | car_make_text_space = TextSimilaritySpace(text=car_schema.make, model="all-MiniLM-L6-v2") 28 | car_model_text_space = TextSimilaritySpace(text=car_schema.model, model="all-MiniLM-L6-v2") 29 | 30 | index = Index([car_make_text_space, car_model_text_space]) 31 | 32 | query = ( 33 | Query(index) 34 | .find(car_schema) 35 | .similar(car_make_text_space.text, Param("make")) 36 | .similar(car_model_text_space.text, Param("model")) 37 | .limit(Param("limit")) 38 | ) 39 | 40 | car_source: RestSource = RestSource(car_schema) 41 | 42 | mongo_vector_database = MongoVectorDatabase( 43 | ":@", 44 | "", 45 | "", 46 | "", 47 | "", 48 | "", 49 | ) 50 | 51 | executor = RestExecutor( 52 | sources=[car_source], 53 | indices=[index], 54 | queries=[RestQuery(RestDescriptor("query"), query)], 55 | vector_database=mongo_vector_database, 56 | ) 57 | 58 | SuperlinkedRegistry.register(executor) 59 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/docs/redis/app_with_redis.py: -------------------------------------------------------------------------------- 1 | from superlinked.framework.common.schema.id_schema_object import IdField 2 | from superlinked.framework.common.schema.schema import schema 3 | from superlinked.framework.common.schema.schema_object import String 4 | from superlinked.framework.dsl.executor.rest.rest_configuration import ( 5 | RestQuery, 6 | ) 7 | from superlinked.framework.dsl.executor.rest.rest_descriptor import RestDescriptor 8 | from superlinked.framework.dsl.executor.rest.rest_executor import RestExecutor 9 | from 
superlinked.framework.dsl.index.index import Index 10 | from superlinked.framework.dsl.query.param import Param 11 | from superlinked.framework.dsl.query.query import Query 12 | from superlinked.framework.dsl.registry.superlinked_registry import SuperlinkedRegistry 13 | from superlinked.framework.dsl.source.rest_source import RestSource 14 | from superlinked.framework.dsl.space.text_similarity_space import TextSimilaritySpace 15 | from superlinked.framework.dsl.storage.redis_vector_database import RedisVectorDatabase 16 | 17 | 18 | @schema 19 | class CarSchema: 20 | id: IdField 21 | make: String 22 | model: String 23 | 24 | 25 | car_schema = CarSchema() 26 | 27 | car_make_text_space = TextSimilaritySpace(text=car_schema.make, model="all-MiniLM-L6-v2") 28 | car_model_text_space = TextSimilaritySpace(text=car_schema.model, model="all-MiniLM-L6-v2") 29 | 30 | index = Index([car_make_text_space, car_model_text_space]) 31 | 32 | query = ( 33 | Query(index) 34 | .find(car_schema) 35 | .similar(car_make_text_space.text, Param("make")) 36 | .similar(car_model_text_space.text, Param("model")) 37 | .limit(Param("limit")) 38 | ) 39 | 40 | car_source: RestSource = RestSource(car_schema) 41 | 42 | redis_vector_database = RedisVectorDatabase("", 12345, username="default", password="") 43 | 44 | executor = RestExecutor( 45 | sources=[car_source], 46 | indices=[index], 47 | queries=[RestQuery(RestDescriptor("query"), query)], 48 | vector_database=redis_vector_database, 49 | ) 50 | 51 | SuperlinkedRegistry.register(executor) 52 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/docs/redis/redis.md: -------------------------------------------------------------------------------- 1 | # Redis 2 | 3 | This document provides clear steps on how to use and integrate Redis with Superlinked. 4 | 5 | ## Configuring your existing managed Redis 6 | 7 | To use Superlinked with Redis, you will need several Redis modules. The simplest approach is to use the official Redis Stack, which includes all the necessary modules. Installation instructions for the Redis Stack can be found in the [Redis official documentation](https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/). Alternatively, you can start a managed instance provided by Redis (a free-tier is available). For detailed steps on initiating a managed instance, refer to the [Start a Managed Redis Instance](#start-a-managed-redis-instance) section below. 8 | 9 | Once your Redis instance is up and running, ensure it is accessible from the server that will use it. Additionally, configure the necessary authentication settings as described below. 10 | 11 | ## Modifications in app.py 12 | 13 | To integrate Redis, you need to add the `RedisVectorDatabase` class and include it in the executor. Here’s how you can do it: 14 | 15 | To configure your Redis, the following code will help you: 16 | ```python 17 | from superlinked.framework.dsl.storage.redis_vector_database import RedisVectorDatabase 18 | 19 | vector_database = RedisVectorDatabase( 20 | "", # (Mandatory) This is your redis URL without any port or extra fields 21 | 12315, # (Mandatory) This is the port and it should be an integer 22 | # These params must be in a form of kwarg params. Here you can specify anything that the official python client 23 | # enables. The params can be found here: https://redis.readthedocs.io/en/stable/connections.html. Below you can see a very basic user-pass authentication as an example. 
24 | username="test", 25 | password="password" 26 | ) 27 | ``` 28 | 29 | Once you have configured the vector database, simply set it as the executor's vector database. 30 | ```python 31 | ... 32 | executor = RestExecutor( 33 | sources=[source], 34 | indices=[index], 35 | queries=[RestQuery(RestDescriptor("query"), query)], 36 | vector_database=vector_database, # Or whichever variable holds your `RedisVectorDatabase` 37 | ) 38 | ... 39 | ``` 40 | 41 | ## Start a Managed Redis Instance 42 | 43 | To initiate a managed Redis instance, navigate to [Redis Labs](https://app.redislabs.com/), sign in, and click the "New Database" button. On the ensuing page, locate the `Type` selector, which offers two options: `Redis Stack` and `Memcached`. By default, `Redis Stack` is pre-selected, which is the correct choice. If it is not selected, make sure to choose `Redis Stack`. For basic usage, no further configuration is necessary. Redis automatically generates a user called `default` and a password, both of which are shown below it. However, if you intend to use the instance for persistent data storage beyond sandbox purposes, consider configuring High Availability (HA), data persistence, and other relevant settings. 44 | 45 | ## Example app with Redis 46 | 47 | You can find an example that utilizes Redis [here](app_with_redis.py). 48 | 49 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/docs/vector_databases.md: -------------------------------------------------------------------------------- 1 | # Vector databases 2 | 3 | This document lists the supported vector databases and links to their detailed documentation: 4 | 5 | - [Redis](redis/redis.md) 6 | - [MongoDB](mongodb/mongodb.md) 7 | 8 | Missing your favorite one? [Let us know in GitHub Discussions](https://github.com/superlinked/superlinked/discussions/41) 9 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/docs/vm.md: -------------------------------------------------------------------------------- 1 | ## What are Virtual Machines? 2 | 3 | Virtual Machines (VMs) are emulations of computer systems. They can run applications and services on a virtual platform, as if they were on a physical machine, while sharing the underlying hardware resources. 4 | 5 | ### VM size recommendations 6 | 7 | The minimum VM size is `t3.small` in AWS or `e2-small` in GCP. Both configurations offer 2 vCPUs and 2GB of RAM. Note that if you are utilizing in-memory storage, the 2GB memory capacity can only accommodate a few tens of thousands of records. For larger datasets, opt for a machine with more RAM. The data loader functionality is CPU-intensive, so if you plan to use it, consider provisioning additional vCPUs. 8 | 9 | The system comprises two components: poller and executor. The poller is a basic service, constrained in the Docker Compose file to use only 256MB of memory and 0.5 vCPU. The executor, which runs Superlinked along with a few minor additions, will consume the majority of the resources. It has no set limits in the Docker Compose file, so any resources not utilized by the poller will be allocated to the executor. 10 | 11 | ### Amazon EC2 12 | 13 | Amazon Elastic Compute Cloud (EC2) is a part of Amazon's cloud-computing platform, Amazon Web Services (AWS). EC2 allows users to rent virtual computers on which to run their own computer applications.
14 | 15 | ### Google Cloud Compute Engine 16 | 17 | Google Cloud Compute Engine delivers virtual machines running in Google's innovative data centers and worldwide fiber network. Compute Engine's tooling and workflow support enable scaling from single instances to global, load-balanced cloud computing. 18 | 19 | ## Creating an Amazon EC2 Instance using AWS CLI 20 | 21 | To create an Amazon EC2 instance, you can use the AWS CLI. Here's how: 22 | 23 | 1. Install and configure the AWS CLI as described in the previous section. 24 | 2. Launch an EC2 instance: 25 | 26 | ```bash 27 | aws ec2 run-instances --image-id ami-0abcdef1234567890 --count 1 --instance-type t3.small --key-name MyKeyPair --security-group-ids sg-903004f8 --subnet-id subnet-6e7f829e 28 | ``` 29 | 30 | Replace the `image-id`, `key-name`, `security-group-ids`, and `subnet-id` with your own values. The `image-id` is the ID of the AMI (Amazon Machine Image), `key-name` is the name of the key pair for the instance, `security-group-ids` is the ID of the security group, and `subnet-id` is the ID of the subnet. 31 | 32 | For more information on creating key pairs, refer to the [AWS documentation](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-key-pairs.html#having-ec2-create-your-key-pair). For a list of available regions, refer to the [AWS Regional Services List](https://aws.amazon.com/about-aws/global-infrastructure/regional-product-services/). 33 | 34 | ## Creating a Google Cloud Compute Engine VM using gcloud 35 | 36 | To create a Google Cloud Compute Engine VM, you can use the `gcloud` command-line tool. Here's how: 37 | 38 | 1. Install and authenticate the Google Cloud SDK as described in the previous section. 39 | 2. Create a Compute Engine instance: 40 | 41 | ```bash 42 | gcloud compute instances create my-vm --machine-type=e2-small --image-project=debian-cloud --image-family=debian-9 --boot-disk-size=50GB 43 | ``` 44 | 45 | Replace `my-vm` with your desired instance name. The `machine-type` is the machine type of the VM, `image-project` is the project ID of the image, `image-family` is the family of the image, and `boot-disk-size` is the size of the boot disk. Please use a `boot-disk-size` value over 20 GB. 46 | 47 | Remember, you need to have the necessary permissions and quotas to create VM instances in both Amazon EC2 and Google Cloud Compute Engine. 
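As an optional sanity check, you can confirm the newly created VM is running from the same CLIs (`my-vm` refers to the example instance name used above):

```bash
# AWS: list instance IDs together with their current state
aws ec2 describe-instances --query "Reservations[].Instances[].[InstanceId,State.Name]" --output table

# GCP: list Compute Engine instances; the example `my-vm` should report STATUS RUNNING
gcloud compute instances list
```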
48 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/Dockerfile: -------------------------------------------------------------------------------- 1 | # --- Python base --- 2 | 3 | FROM python:3.11.9 AS python-base 4 | 5 | ENV PYTHONUNBUFFERED=1 \ 6 | PYTHONDONTWRITEBYTECODE=1 \ 7 | \ 8 | PIP_NO_CACHE_DIR=off \ 9 | PIP_DISABLE_PIP_VERSION_CHECK=on \ 10 | PIP_DEFAULT_TIMEOUT=100 \ 11 | \ 12 | POETRY_VERSION=1.8.2 \ 13 | POETRY_HOME="/opt/poetry" \ 14 | POETRY_VIRTUALENVS_IN_PROJECT=true \ 15 | POETRY_NO_INTERACTION=1 \ 16 | \ 17 | PYSETUP_PATH="/opt/pysetup" \ 18 | VENV_PATH="/opt/pysetup/.venv" 19 | 20 | ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH" 21 | # --- Builder base --- 22 | 23 | FROM python-base AS builder-base 24 | RUN apt-get update \ 25 | && apt-get install --no-install-recommends -y \ 26 | curl \ 27 | build-essential \ 28 | && rm -rf /var/lib/apt/lists/* 29 | 30 | SHELL ["/bin/bash", "-o", "pipefail", "-c"] 31 | RUN curl -sSL https://install.python-poetry.org | python - 32 | 33 | # --- Executor base --- 34 | 35 | FROM builder-base as executor-base 36 | 37 | WORKDIR $PYSETUP_PATH 38 | 39 | COPY poetry.lock pyproject.toml $PYSETUP_PATH/ 40 | RUN poetry install --no-root --only executor 41 | 42 | # --- Executor --- 43 | 44 | FROM python:3.11.5-slim AS executor 45 | 46 | COPY --from=executor-base /opt/pysetup/.venv/ /opt/pysetup/.venv/ 47 | 48 | WORKDIR /code 49 | 50 | RUN apt-get update && apt-get install --no-install-recommends -y supervisor && apt-get clean && rm -rf /var/lib/apt/lists/* 51 | 52 | COPY executor/supervisord.conf /etc/supervisor/conf.d/supervisord.conf 53 | COPY executor/.env /code/executor/ 54 | COPY executor/app /code/executor/app 55 | 56 | EXPOSE 8080 57 | ENTRYPOINT ["/usr/bin/supervisord", "-n"] 58 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/configuration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/configuration/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/configuration/app_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 
Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from pydantic_settings import BaseSettings, SettingsConfigDict 16 | 17 | 18 | class AppConfig(BaseSettings): 19 | SERVER_URL: str 20 | APP_MODULE_PATH: str 21 | LOG_LEVEL: str 22 | PERSISTENCE_FOLDER_PATH: str 23 | DISABLE_RECENCY_SPACE: bool 24 | 25 | model_config = SettingsConfigDict(env_file="executor/.env", extra="ignore") 26 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/dependency_register.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from xmlrpc.client import ServerProxy 16 | 17 | import inject 18 | 19 | from executor.app.configuration.app_config import AppConfig 20 | from executor.app.service.data_loader import DataLoader 21 | from executor.app.service.file_handler_service import FileHandlerService 22 | from executor.app.service.file_object_serializer import FileObjectSerializer 23 | from executor.app.service.persistence_service import PersistenceService 24 | from executor.app.service.supervisor_service import SupervisorService 25 | 26 | 27 | def register_dependencies() -> None: 28 | inject.configure(_configure) 29 | 30 | 31 | def _configure(binder: inject.Binder) -> None: 32 | app_config = AppConfig() 33 | file_handler_service = FileHandlerService(app_config) 34 | serializer = FileObjectSerializer(file_handler_service) 35 | server_proxy = ServerProxy(app_config.SERVER_URL) 36 | binder.bind(DataLoader, DataLoader(app_config)) 37 | binder.bind(PersistenceService, PersistenceService(serializer)) 38 | binder.bind(SupervisorService, SupervisorService(server_proxy)) 39 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/exception/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/exception/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/exception/exception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | class UnsupportedProtocolException(Exception): 17 | pass 18 | 19 | 20 | class FilesNotFoundException(Exception): 21 | pass 22 | 23 | 24 | class DataLoaderNotFoundException(Exception): 25 | pass 26 | 27 | 28 | class DataLoaderAlreadyRunningException(Exception): 29 | pass 30 | 31 | 32 | class DataLoaderTaskNotFoundException(Exception): 33 | pass 34 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/exception/exception_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | 17 | from fastapi import Request, status 18 | from fastapi.responses import JSONResponse 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | async def handle_bad_request(_: Request, exception: Exception) -> JSONResponse: 24 | logger.exception("Bad request") 25 | return JSONResponse( 26 | status_code=status.HTTP_400_BAD_REQUEST, 27 | content={ 28 | "exception": str(exception.__class__.__name__), 29 | "detail": str(exception), 30 | }, 31 | ) 32 | 33 | 34 | async def handle_generic_exception(_: Request, exception: Exception) -> JSONResponse: 35 | logger.exception("Unexpected exception happened") 36 | return JSONResponse( 37 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 38 | content={ 39 | "exception": str(exception.__class__.__name__), 40 | "detail": str(exception), 41 | }, 42 | ) 43 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import logging 16 | from json import JSONDecodeError 17 | 18 | import uvicorn 19 | from fastapi import FastAPI 20 | from fastapi_restful.timing import add_timing_middleware 21 | from superlinked.framework.common.parser.exception import MissingIdException 22 | from superlinked.framework.online.dag.exception import ValueNotProvidedException 23 | 24 | from executor.app.configuration.app_config import AppConfig 25 | from executor.app.dependency_register import register_dependencies 26 | from executor.app.exception.exception_handler import ( 27 | handle_bad_request, 28 | handle_generic_exception, 29 | ) 30 | from executor.app.middleware.lifespan_event import lifespan 31 | from executor.app.router.management_router import router as management_router 32 | 33 | app_config = AppConfig() 34 | 35 | logging.basicConfig(level=app_config.LOG_LEVEL) 36 | logger = logging.getLogger(__name__) 37 | logging.getLogger("sentence_transformers").setLevel(logging.WARNING) 38 | 39 | app = FastAPI(lifespan=lifespan) 40 | 41 | app.add_exception_handler(ValueNotProvidedException, handle_bad_request) 42 | app.add_exception_handler(MissingIdException, handle_bad_request) 43 | app.add_exception_handler(JSONDecodeError, handle_bad_request) 44 | app.add_exception_handler(Exception, handle_generic_exception) 45 | 46 | app.include_router(management_router) 47 | 48 | add_timing_middleware(app, record=logger.info) 49 | 50 | register_dependencies() 51 | 52 | if __name__ == "__main__": 53 | uvicorn.run(app, host="0.0.0.0", port=8080) # noqa: S104 hardcoded-bind-all-interfaces 54 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/middleware/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/middleware/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/router/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/router/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/service/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/service/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/service/file_handler_service.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | from enum import Enum 4 | 5 | from executor.app.configuration.app_config import AppConfig 6 | 7 | 8 | class HashType(Enum): 9 | MD5 = hashlib.md5 10 | 11 | 12 | class FileHandlerService: 13 | def __init__(self, app_config: AppConfig, hash_type: HashType | None = None) -> None: 14 | self.__hash_type = hash_type or HashType.MD5 15 | self.app_config = app_config 16 | 17 | def generate_filename(self, field_id: str, app_id: str) -> str: 18 | filename 
= self.__hash_type.value(f"{app_id}_{field_id}".encode()).hexdigest() 19 | return f"{self.app_config.PERSISTENCE_FOLDER_PATH}/{filename}.json" 20 | 21 | def ensure_folder(self) -> None: 22 | os.makedirs(self.app_config.PERSISTENCE_FOLDER_PATH, exist_ok=True) 23 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/service/file_object_serializer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | import logging 17 | import os 18 | 19 | from superlinked.framework.storage.in_memory.object_serializer import ObjectSerializer 20 | 21 | from executor.app.service.file_handler_service import FileHandlerService 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | EMPTY_JSON_OBJECT_SIZE = 4 # 4 characters: "{}" 27 | 28 | 29 | class FileObjectSerializer(ObjectSerializer): 30 | def __init__(self, file_handler_service: FileHandlerService) -> None: 31 | super().__init__() 32 | self.__file_handler_service = file_handler_service 33 | 34 | def write(self, field_identifier: str, serialized_object: str, app_identifier: str) -> None: 35 | self.__file_handler_service.ensure_folder() 36 | 37 | logger.info("Persisting database with field id: %s and app id: %s", field_identifier, app_identifier) 38 | file_with_path = self.__file_handler_service.generate_filename(field_identifier, app_identifier) 39 | with open(file_with_path, "w", encoding="utf-8") as file: 40 | logger.debug("Writing field: %s and app: %s file to: %s", field_identifier, app_identifier, file_with_path) 41 | json.dump(serialized_object, file) 42 | 43 | def read(self, field_identifier: str, app_identifier: str) -> str: 44 | logger.info("Restoring database using field id: %s and app id: %s", field_identifier, app_identifier) 45 | file_with_path = self.__file_handler_service.generate_filename(field_identifier, app_identifier) 46 | 47 | result = "{}" 48 | try: 49 | if os.path.isfile(file_with_path): 50 | with open(file_with_path, encoding="utf-8") as file: 51 | if os.stat(file_with_path).st_size >= EMPTY_JSON_OBJECT_SIZE: 52 | result = json.load(file) 53 | except json.JSONDecodeError: 54 | logger.exception("File is present but contains invalid data. File: %s", file_with_path) 55 | except Exception: # pylint: disable=broad-exception-caught 56 | logger.exception("An error occurred during the file read operation. File: %s", file_with_path) 57 | return result 58 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/service/persistence_service.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | 17 | from superlinked.framework.dsl.executor.rest.rest_executor import RestApp 18 | 19 | from executor.app.service.file_object_serializer import FileObjectSerializer 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class PersistenceService: 25 | def __init__(self, serializer: FileObjectSerializer) -> None: 26 | self._applications: list[RestApp] = [] 27 | self._serializer = serializer 28 | 29 | def register(self, rest_app: RestApp) -> None: 30 | if rest_app in self._applications: 31 | logger.warning("Application already exists: %s", rest_app) 32 | return 33 | logger.info("Rest app registered: %s", rest_app) 34 | self._applications.append(rest_app) 35 | 36 | def persist(self) -> None: 37 | for app in self._applications: 38 | app.online_app.persist(self._serializer) 39 | 40 | def restore(self) -> None: 41 | for app in self._applications: 42 | app.online_app.restore(self._serializer) 43 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/service/supervisor_service.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from xmlrpc.client import ServerProxy 16 | 17 | 18 | class SupervisorService: 19 | def __init__(self, server_proxy: ServerProxy) -> None: 20 | self.server = server_proxy 21 | 22 | def restart(self) -> str: 23 | """Restart the API via supervisor XML-RPC and return the result.""" 24 | return str(self.server.supervisor.restart()) 25 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/util/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/util/fast_api_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any 16 | 17 | from fastapi import Request, Response, status 18 | from fastapi.responses import JSONResponse 19 | from pydantic import Field 20 | from superlinked.framework.common.util.immutable_model import ImmutableBaseModel 21 | from superlinked.framework.dsl.executor.rest.rest_handler import RestHandler 22 | 23 | 24 | class QueryResponse(ImmutableBaseModel): 25 | schema_: str = Field(..., alias="schema") 26 | results: list[dict[str, Any]] 27 | 28 | 29 | class FastApiHandler: 30 | def __init__(self, rest_handler: RestHandler) -> None: 31 | self.__rest_handler = rest_handler 32 | 33 | async def ingest(self, request: Request) -> Response: 34 | payload = await request.json() 35 | self.__rest_handler._ingest_handler(payload, request.url.path) # noqa: SLF001 private-member-access 36 | return Response(status_code=status.HTTP_202_ACCEPTED) 37 | 38 | async def query(self, request: Request) -> Response: 39 | payload = await request.json() 40 | result = self.__rest_handler._query_handler(payload, request.url.path) # noqa: SLF001 private-member-access 41 | query_response = QueryResponse( 42 | schema=result.schema._schema_name, # noqa: SLF001 private-member-access 43 | results=[ 44 | { 45 | "entity": { 46 | "id": entry.entity.header.object_id, 47 | "origin": ( 48 | { 49 | "id": entry.entity.header.object_id, 50 | "schema": entry.entity.header.schema_id, 51 | } 52 | if entry.entity.header.origin_id 53 | else {} 54 | ), 55 | }, 56 | "obj": entry.stored_object, 57 | } 58 | for entry in result.entries 59 | ], 60 | ) 61 | return JSONResponse( 62 | content=query_response.model_dump(by_alias=True), 63 | status_code=status.HTTP_200_OK, 64 | ) 65 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/util/open_api_description_util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from typing import Any 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class OpenApiDescriptionUtil: 10 | @staticmethod 11 | def get_open_api_description_by_key(key: str, file_path: str | None = None) -> dict[str, Any]: 12 | if file_path is None: 13 | file_path = os.path.join(os.getcwd(), "executor/openapi/static_endpoint_descriptor.json") 14 | with open(file_path, encoding="utf-8") as file: 15 | data = json.load(file) 16 | open_api_description = data.get(key) 17 | if open_api_description is None: 18 | logger.warning("No OpenAPI description found for key: %s", key) 19 | return open_api_description 20 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/app/util/registry_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | from importlib import import_module 17 | 18 | from superlinked.framework.dsl.registry.superlinked_registry import SuperlinkedRegistry 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class RegistryLoader: 24 | @staticmethod 25 | def get_registry(app_module_path: str) -> SuperlinkedRegistry | None: 26 | try: 27 | return import_module(app_module_path).SuperlinkedRegistry 28 | except ImportError: 29 | logger.exception("Module not found at: %s", app_module_path) 30 | except AttributeError: 31 | logger.exception("SuperlinkedRegistry not found in module: %s", app_module_path) 32 | except Exception: # pylint: disable=broad-exception-caught 33 | logger.exception("An unexpected error occurred while loading the module at: %s", app_module_path) 34 | return None 35 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/executor/supervisord.conf: -------------------------------------------------------------------------------- 1 | [supervisord] 2 | 3 | [inet_http_server] 4 | port=127.0.0.1:9001 5 | 6 | [rpcinterface:supervisor] 7 | supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface 8 | 9 | [program:uvicorn_host] 10 | command=/opt/pysetup/.venv/bin/python -m uvicorn executor.app.main:app --host 0.0.0.0 --port 8080 11 | numprocs=1 12 | process_name=uvicorn_host 13 | stdout_logfile=/dev/stdout 14 | stdout_logfile_maxbytes=0 15 | stderr_logfile=/dev/stderr 16 | stderr_logfile_maxbytes=0 17 | autorestart=true 18 | 19 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/Dockerfile: -------------------------------------------------------------------------------- 1 | # --- Python base --- 2 | 3 | FROM python:3.11.9 AS python-base 4 | 5 | ENV PYTHONUNBUFFERED=1 \ 6 | PYTHONDONTWRITEBYTECODE=1 \ 7 | \ 8 | PIP_NO_CACHE_DIR=off \ 9 | PIP_DISABLE_PIP_VERSION_CHECK=on \ 10 | PIP_DEFAULT_TIMEOUT=100 \ 11 | \ 12 | POETRY_VERSION=1.8.2 \ 13 | POETRY_HOME="/opt/poetry" \ 14 | POETRY_VIRTUALENVS_IN_PROJECT=true \ 15 | POETRY_NO_INTERACTION=1 \ 16 | \ 17 | PYSETUP_PATH="/opt/pysetup" \ 18 | VENV_PATH="/opt/pysetup/.venv" 19 | 20 | ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH" 21 | 22 | # --- Builder base --- 23 | 24 | FROM python-base AS builder-base 25 | RUN apt-get update \ 26 | && apt-get install --no-install-recommends -y \ 27 | curl \ 28 | build-essential \ 29 | && apt-get clean && rm -rf /var/lib/apt/lists/* 30 | 31 | SHELL ["/bin/bash", "-o", "pipefail", "-c"] 32 | RUN curl -sSL https://install.python-poetry.org | python - 33 | 34 | WORKDIR $PYSETUP_PATH 35 | 36 | COPY poetry.lock pyproject.toml $PYSETUP_PATH/ 37 | RUN poetry install --no-root --only poller 38 | 39 | # --- Poller --- 40 | 41 | FROM python:3.11.5-alpine AS poller 42 | 43 | COPY --from=builder-base /opt/pysetup/.venv /opt/pysetup/.venv 44 | 45 | 
WORKDIR /code 46 | 47 | #COPY config /code/config 48 | COPY poller/app /code/poller/app 49 | COPY poller/logging_config.ini poller/poller_config.ini /code/poller/ 50 | 51 | ENTRYPOINT ["/opt/pysetup/.venv/bin/python", "-m", "poller.app.main", "--config_path", "/code/config/config.yaml"] 52 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/app_location_parser/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/app_location_parser/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/config/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/config/poller_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import configparser 16 | import logging 17 | import logging.config 18 | 19 | 20 | class PollerConfig: 21 | def __init__(self) -> None: 22 | poller_dir = "poller" 23 | poller_config_path = f"{poller_dir}/poller_config.ini" 24 | logging_config_path = f"{poller_dir}/logging_config.ini" 25 | 26 | config = configparser.ConfigParser() 27 | config.read(poller_config_path) 28 | 29 | self.poll_interval_seconds = config.getint("POLLER", "POLL_INTERVAL_SECONDS") 30 | self.executor_port = config.getint("POLLER", "EXECUTOR_PORT") 31 | self.executor_url = config.get("POLLER", "EXECUTOR_URL") 32 | self.aws_credentials = config.get("POLLER", "AWS_CREDENTIALS") 33 | self.gcp_credentials = config.get("POLLER", "GCP_CREDENTIALS") 34 | self.download_location = config.get("POLLER", "DOWNLOAD_LOCATION") 35 | self.logging_config = logging_config_path 36 | 37 | def setup_logger(self, name: str) -> logging.Logger: 38 | logging.config.fileConfig(self.logging_config) 39 | return logging.getLogger(name) 40 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Module for polling storage services and downloading updated files.""" 16 | 17 | import argparse 18 | 19 | from poller.app.poller.poller import Poller 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser(description="Start the poller with a given config file.") 23 | parser.add_argument( 24 | "--config_path", 25 | help="Path to the configuration file.", 26 | default="../config/config.yaml", 27 | ) 28 | args = parser.parse_args() 29 | 30 | poller = Poller(args.config_path) 31 | poller.run() 32 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/poller/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/poller/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/poller/poller.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import time 16 | from threading import Thread 17 | 18 | import yaml 19 | 20 | from poller.app.app_location_parser.app_location_parser import ( 21 | AppLocation, 22 | AppLocationParser, 23 | ) 24 | from poller.app.config.poller_config import PollerConfig 25 | from poller.app.resource_handler.resource_handler_factory import ResourceHandlerFactory 26 | 27 | 28 | class Poller(Thread): 29 | """ 30 | The Poller class is a thread that polls files from different types of storage: local, S3, 31 | and Google Cloud Storage at a regular interval. 32 | """ 33 | 34 | def __init__(self, config_path: str) -> None: 35 | Thread.__init__(self) 36 | self.poller_config = PollerConfig() 37 | self.app_location_config_path = config_path 38 | self.app_location_config = self.parse_app_location_config() 39 | 40 | def parse_app_location_config(self) -> AppLocation: 41 | """ 42 | Parse the configuration from the YAML file. 43 | """ 44 | with open(self.app_location_config_path, encoding="utf-8") as file: 45 | config_yaml = yaml.safe_load(file) 46 | app_location = config_yaml["app_location"] 47 | return AppLocationParser().parse(app_location) 48 | 49 | def run(self) -> None: 50 | """ 51 | Start the polling process. 52 | """ 53 | resource_handler = ResourceHandlerFactory.get_resource_handler(self.app_location_config) 54 | while True: 55 | if resource_handler.check_api_health(): 56 | resource_handler.poll() 57 | time.sleep(self.poller_config.poll_interval_seconds) 58 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/gcs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/gcs/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/gcs/gcs_resource_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from google.auth.exceptions import DefaultCredentialsError 16 | from google.cloud import storage # type: ignore[attr-defined] 17 | from google.cloud.exceptions import GoogleCloudError 18 | from google.cloud.storage.client import Client as GCSClient 19 | 20 | from poller.app.app_location_parser.app_location_parser import AppLocation 21 | from poller.app.resource_handler.resource_handler import ResourceHandler 22 | 23 | 24 | class GCSResourceHandler(ResourceHandler): 25 | def __init__(self, app_location: AppLocation, client: GCSClient | None = None) -> None: 26 | super().__init__(app_location) 27 | self.client = client or self.initialize_gcs_client() 28 | 29 | def initialize_gcs_client(self) -> GCSClient: 30 | """ 31 | Initialize the GCS client, with fallback to credentials file if necessary. 32 | """ 33 | try: 34 | # First, try to create a GCS client without explicit credentials 35 | client = storage.Client() 36 | client.get_bucket(self.app_location.bucket) # Test access 37 | except (GoogleCloudError, DefaultCredentialsError): 38 | # If the first method fails, try to use credentials from a JSON file 39 | try: 40 | return storage.Client.from_service_account_json( 41 | json_credentials_path=self.poller_config.gcp_credentials, 42 | ) 43 | except FileNotFoundError as e: 44 | msg = "Could not find GCP credentials file and no service account available." 45 | raise FileNotFoundError(msg) from e 46 | else: 47 | return client 48 | 49 | def poll(self) -> None: 50 | """ 51 | Poll files from a Google Cloud Storage bucket and download new or modified files. 52 | """ 53 | bucket = self.client.get_bucket(self.get_bucket()) 54 | blobs = bucket.list_blobs(prefix=self.app_location.path) 55 | for blob in blobs: 56 | self.check_and_download(blob.updated, blob.name) 57 | 58 | def download_file(self, bucket_name: str | None, object_name: str, download_path: str) -> None: 59 | """ 60 | Download a file from GCS to the specified path. 61 | """ 62 | bucket = self.client.get_bucket(bucket_name) 63 | blob = bucket.blob(object_name) 64 | blob.download_to_filename(download_path) 65 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/local/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/local/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/local/local_resource_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import shutil 17 | from datetime import datetime, timezone 18 | 19 | from poller.app.resource_handler.resource_handler import ResourceHandler 20 | 21 | 22 | class LocalResourceHandler(ResourceHandler): 23 | def get_bucket(self) -> str: 24 | return "local" 25 | 26 | def download_file(self, _: str | None, object_name: str, download_path: str) -> None: 27 | """ 28 | 'Download' a file from local storage to the specified path. 29 | In this case, it's just copying the file. 30 | """ 31 | self.logger.info("Copy file from %s to %s", object_name, download_path) 32 | shutil.copy2(object_name, download_path) 33 | 34 | def poll(self) -> None: 35 | """ 36 | Poll files from a local directory and notify about new or modified files. 37 | """ 38 | self.logger.info("Polling files from: %s", self.app_location.path) 39 | if not self._path_exists(): 40 | self.logger.error("Path does not exist: %s", self.app_location.path) 41 | return 42 | self._process_path() 43 | 44 | def _path_exists(self) -> bool: 45 | return os.path.exists(self.app_location.path) 46 | 47 | def _process_path(self) -> None: 48 | if os.path.isfile(self.app_location.path): 49 | self._process_file(self.app_location.path) 50 | else: 51 | self._process_directory() 52 | 53 | def _process_directory(self) -> None: 54 | for root, _, files in os.walk(self.app_location.path): 55 | for file in files: 56 | file_path = os.path.join(root, file) 57 | self._process_file(file_path) 58 | 59 | def _process_file(self, file_path: str) -> None: 60 | self.logger.info("Found file: %s", file_path) 61 | file_time = datetime.fromtimestamp(os.path.getmtime(file_path), tz=timezone.utc) 62 | try: 63 | self.check_and_download(file_time, self.app_location.path) 64 | except (FileNotFoundError, PermissionError): 65 | self.logger.exception("Failed to download and notify new version") 66 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/resource_handler_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
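# The factory defined below maps the storage type declared in the app location config
# (S3, GCS or LOCAL) to the matching ResourceHandler implementation, so the Poller
# does not need to know which storage backend it is polling.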
14 | 15 | from poller.app.app_location_parser.app_location_parser import AppLocation, StorageType 16 | from poller.app.resource_handler.gcs.gcs_resource_handler import GCSResourceHandler 17 | from poller.app.resource_handler.local.local_resource_handler import ( 18 | LocalResourceHandler, 19 | ) 20 | from poller.app.resource_handler.resource_handler import ResourceHandler 21 | from poller.app.resource_handler.s3.s3_resource_handler import S3ResourceHandler 22 | 23 | 24 | class ResourceHandlerFactory: 25 | @staticmethod 26 | def get_resource_handler(app_location: AppLocation) -> ResourceHandler: 27 | match app_location.type_: 28 | case StorageType.S3: 29 | return S3ResourceHandler(app_location) 30 | case StorageType.GCS: 31 | return GCSResourceHandler(app_location) 32 | case StorageType.LOCAL: 33 | return LocalResourceHandler(app_location) 34 | case _: 35 | msg = f"Invalid resource type in config: {app_location.type_}" 36 | raise ValueError(msg) 37 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/s3/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/s3/__init__.py -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/s3/s3_resource_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Superlinked, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | 17 | import boto3 18 | from botocore.client import Config 19 | from botocore.exceptions import ClientError 20 | from mypy_boto3_s3.client import S3Client 21 | 22 | from poller.app.app_location_parser.app_location_parser import AppLocation 23 | from poller.app.resource_handler.resource_handler import ResourceHandler 24 | 25 | 26 | class S3ResourceHandler(ResourceHandler): 27 | def __init__( 28 | self, 29 | app_location: AppLocation, 30 | client: S3Client | None = None, 31 | ) -> None: # client=None for easier testability 32 | super().__init__(app_location) 33 | self.client = client or self.initialize_s3_client() 34 | self.resource = boto3.resource("s3") 35 | 36 | def initialize_s3_client(self) -> S3Client: 37 | """ 38 | Initialize the S3 client, with fallback to credentials file if necessary. 
39 | """ 40 | try: 41 | # First, try to create an S3 resource without explicit credentials 42 | client = boto3.client("s3", config=Config(signature_version="s3v4")) 43 | if self.app_location.bucket is not None: 44 | client.head_bucket(Bucket=self.app_location.bucket) # Test access 45 | except ClientError: 46 | # If the first method fails, try to use credentials from a JSON file 47 | try: 48 | with open(self.poller_config.aws_credentials, encoding="utf-8") as aws_cred_file: 49 | aws_credentials = json.load(aws_cred_file) 50 | return boto3.client( 51 | "s3", 52 | aws_access_key_id=aws_credentials["aws_access_key_id"], 53 | aws_secret_access_key=aws_credentials["aws_secret_access_key"], 54 | region_name=aws_credentials["region"], 55 | ) 56 | except FileNotFoundError as e: 57 | msg = "Could not find AWS credentials file and no IAM role available." 58 | raise FileNotFoundError(msg) from e 59 | else: 60 | return client 61 | 62 | def poll(self) -> None: 63 | """ 64 | Poll files from an S3 bucket and download new or modified files. 65 | """ 66 | bucket = self.resource.Bucket(self.get_bucket()) 67 | for obj in bucket.objects.filter(Prefix=self.app_location.path): 68 | self.check_and_download(obj.last_modified, obj.key) 69 | 70 | def download_file(self, _: str | None, object_name: str, download_path: str) -> None: 71 | """ 72 | Download a file from S3 to the specified path. 73 | """ 74 | bucket = self.get_bucket() 75 | self.client.download_file(Bucket=bucket, Key=object_name, Filename=download_path) 76 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/logging_config.ini: -------------------------------------------------------------------------------- 1 | [loggers] 2 | keys=root,sampleLogger 3 | 4 | [handlers] 5 | keys=consoleHandler 6 | 7 | [formatters] 8 | keys=sampleFormatter 9 | 10 | [logger_root] 11 | level=INFO 12 | handlers=consoleHandler 13 | 14 | [logger_sampleLogger] 15 | level=INFO 16 | handlers=consoleHandler 17 | qualname=sampleLogger 18 | propagate=0 19 | 20 | [handler_consoleHandler] 21 | class=StreamHandler 22 | level=INFO 23 | formatter=sampleFormatter 24 | args=(sys.__stdout__,) 25 | 26 | [formatter_sampleFormatter] 27 | format=%(asctime)s - %(levelname)s - %(message)s 28 | datefmt= 29 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/runner/poller/poller_config.ini: -------------------------------------------------------------------------------- 1 | [POLLER] 2 | POLL_INTERVAL_SECONDS=10 3 | EXECUTOR_PORT=8080 4 | EXECUTOR_URL=http://executor 5 | DOWNLOAD_LOCATION=/code/data/src 6 | AWS_CREDENTIALS=/code/aws_credentials.json 7 | GCP_CREDENTIALS=/code/gcp_credentials.json 8 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/server/tools/init-venv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # please install pyenv, poetry on your system 3 | # on Mac 4 | # brew install pyenv 5 | # brew install poetry 6 | 7 | set -euxo pipefail 8 | 9 | script_path=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd -P) 10 | cd "${script_path}/../runner/" 11 | 12 | eval "$(pyenv init --path)" 13 | 14 | python_version=$(cat ../.python-version) 15 | 16 | if ! 
pyenv versions | grep -q "${python_version}"; then 17 | echo y | pyenv install "${python_version}" 18 | fi 19 | 20 | poetry env use $(pyenv local ${python_version} && pyenv which python) 21 | 22 | poetry cache list | awk '{print $1}' | xargs -I {} poetry cache clear {} --all 23 | 24 | poetry install --no-root --with poller,executor,dev 25 | 26 | source "$(poetry env info --path)/bin/activate" 27 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/singleton.py: -------------------------------------------------------------------------------- 1 | from threading import Lock 2 | from typing import ClassVar 3 | 4 | 5 | class SingletonMeta(type): 6 | """ 7 | This is a thread-safe implementation of Singleton. 8 | """ 9 | 10 | _instances: ClassVar = {} 11 | 12 | _lock: Lock = Lock() 13 | 14 | """ 15 | We now have a lock object that will be used to synchronize threads during 16 | first access to the Singleton. 17 | """ 18 | 19 | def __call__(cls, *args, **kwargs): 20 | """ 21 | Possible changes to the value of the `__init__` argument do not affect 22 | the returned instance. 23 | """ 24 | # Now, imagine that the program has just been launched. Since there's no 25 | # Singleton instance yet, multiple threads can simultaneously pass the 26 | # previous conditional and reach this point almost at the same time. The 27 | # first of them will acquire lock and will proceed further, while the 28 | # rest will wait here. 29 | with cls._lock: 30 | # The first thread to acquire the lock, reaches this conditional, 31 | # goes inside and creates the Singleton instance. Once it leaves the 32 | # lock block, a thread that might have been waiting for the lock 33 | # release may then enter this section. But since the Singleton field 34 | # is already initialized, the thread won't create a new object. 35 | if cls not in cls._instances: 36 | instance = super().__call__(*args, **kwargs) 37 | cls._instances[cls] = instance 38 | 39 | return cls._instances[cls] 40 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/utils/__init__.py: -------------------------------------------------------------------------------- 1 | def flatten(nested_list: list) -> list: 2 | """Flatten a list of lists into a single list.""" 3 | 4 | return [item for sublist in nested_list for item in sublist] 5 | -------------------------------------------------------------------------------- /src/bonus_superlinked_rag/utils/logging.py: -------------------------------------------------------------------------------- 1 | import structlog 2 | 3 | 4 | def get_logger(cls: str): 5 | return structlog.get_logger().bind(cls=cls) -------------------------------------------------------------------------------- /src/core/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import db, logger_utils, opik_utils 2 | from .logger_utils import get_logger 3 | 4 | logger = get_logger(__file__) 5 | 6 | try: 7 | from .opik_utils import configure_opik 8 | 9 | configure_opik() 10 | except: 11 | logger.warning("Could not configure Opik.") 12 | 13 | __all__ = ["get_logger", "logger_utils", "opik_utils", "db"] 14 | -------------------------------------------------------------------------------- /src/core/aws/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/core/aws/create_execution_role.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | from core.logger_utils import get_logger 5 | 6 | logger = get_logger(__file__) 7 | 8 | try: 9 | import boto3 10 | except ModuleNotFoundError: 11 | logger.warning( 12 | "Couldn't load AWS or SageMaker imports. Run 'poetry install --with aws' to support AWS." 13 | ) 14 | 15 | from core.config import settings 16 | 17 | 18 | def create_sagemaker_execution_role(role_name: str): 19 | assert settings.AWS_REGION, "AWS_REGION is not set." 20 | assert settings.AWS_ACCESS_KEY, "AWS_ACCESS_KEY is not set." 21 | assert settings.AWS_SECRET_KEY, "AWS_SECRET_KEY is not set." 22 | 23 | # Create IAM client 24 | iam = boto3.client( 25 | "iam", 26 | region_name=settings.AWS_REGION, 27 | aws_access_key_id=settings.AWS_ACCESS_KEY, 28 | aws_secret_access_key=settings.AWS_SECRET_KEY, 29 | ) 30 | 31 | # Define the trust relationship policy 32 | trust_relationship = { 33 | "Version": "2012-10-17", 34 | "Statement": [ 35 | { 36 | "Effect": "Allow", 37 | "Principal": {"Service": "sagemaker.amazonaws.com"}, 38 | "Action": "sts:AssumeRole", 39 | } 40 | ], 41 | } 42 | 43 | try: 44 | # Create the IAM role 45 | role = iam.create_role( 46 | RoleName=role_name, 47 | AssumeRolePolicyDocument=json.dumps(trust_relationship), 48 | Description="Execution role for SageMaker", 49 | ) 50 | 51 | # Attach necessary policies 52 | policies = [ 53 | "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess", 54 | "arn:aws:iam::aws:policy/AmazonS3FullAccess", 55 | "arn:aws:iam::aws:policy/CloudWatchLogsFullAccess", 56 | "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryFullAccess", 57 | ] 58 | 59 | for policy in policies: 60 | iam.attach_role_policy(RoleName=role_name, PolicyArn=policy) 61 | 62 | logger.info(f"Role '{role_name}' created successfully.") 63 | logger.info(f"Role ARN: {role['Role']['Arn']}") 64 | 65 | return role["Role"]["Arn"] 66 | 67 | except iam.exceptions.EntityAlreadyExistsException: 68 | logger.warning(f"Role '{role_name}' already exists. 
Fetching its ARN...") 69 | role = iam.get_role(RoleName=role_name) 70 | 71 | return role["Role"]["Arn"] 72 | 73 | 74 | if __name__ == "__main__": 75 | role_arn = create_sagemaker_execution_role("SageMakerExecutionRoleLLM") 76 | logger.info(role_arn) 77 | 78 | # Save the role ARN to a file 79 | with Path("sagemaker_execution_role.json").open("w") as f: 80 | json.dump({"RoleArn": role_arn}, f) 81 | 82 | logger.info("Role ARN saved to 'sagemaker_execution_role.json'") 83 | -------------------------------------------------------------------------------- /src/core/aws/create_sagemaker_role.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | from logger_utils import get_logger 5 | 6 | logger = get_logger(__file__) 7 | 8 | try: 9 | import boto3 10 | except ModuleNotFoundError: 11 | logger.warning( 12 | "Couldn't load AWS or SageMaker imports. Run 'poetry install --with aws' to support AWS." 13 | ) 14 | 15 | from config import settings 16 | 17 | 18 | def create_sagemaker_user(username: str): 19 | assert settings.AWS_REGION, "AWS_REGION is not set." 20 | assert settings.AWS_ACCESS_KEY, "AWS_ACCESS_KEY is not set." 21 | assert settings.AWS_SECRET_KEY, "AWS_SECRET_KEY is not set." 22 | 23 | # Create IAM client 24 | iam = boto3.client( 25 | "iam", 26 | region_name=settings.AWS_REGION, 27 | aws_access_key_id=settings.AWS_ACCESS_KEY, 28 | aws_secret_access_key=settings.AWS_SECRET_KEY, 29 | ) 30 | 31 | # Create user 32 | iam.create_user(UserName=username) 33 | 34 | # Attach necessary policies 35 | policies = [ 36 | "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess", 37 | "arn:aws:iam::aws:policy/AWSCloudFormationFullAccess", 38 | "arn:aws:iam::aws:policy/IAMFullAccess", 39 | "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryFullAccess", 40 | "arn:aws:iam::aws:policy/AmazonS3FullAccess", 41 | ] 42 | 43 | for policy in policies: 44 | iam.attach_user_policy(UserName=username, PolicyArn=policy) 45 | 46 | # Create access key 47 | response = iam.create_access_key(UserName=username) 48 | access_key = response["AccessKey"] 49 | 50 | logger.info(f"User '{username}' successfully created.") 51 | logger.info("Access Key ID and Secret Access Key successfully created.") 52 | 53 | return { 54 | "AccessKeyId": access_key["AccessKeyId"], 55 | "SecretAccessKey": access_key["SecretAccessKey"], 56 | } 57 | 58 | 59 | if __name__ == "__main__": 60 | new_user = create_sagemaker_user("sagemaker-deployer") 61 | 62 | with Path("sagemaker_user_credentials.json").open("w") as f: 63 | json.dump(new_user, f) 64 | 65 | logger.info("Credentials saved to 'sagemaker_user_credentials.json'") 66 | -------------------------------------------------------------------------------- /src/core/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from pydantic_settings import BaseSettings, SettingsConfigDict 4 | 5 | ROOT_DIR = str(Path(__file__).parent.parent.parent) 6 | 7 | 8 | class AppSettings(BaseSettings): 9 | model_config = SettingsConfigDict(env_file=ROOT_DIR, env_file_encoding="utf-8") 10 | 11 | # MongoDB configs 12 | MONGO_DATABASE_HOST: str = ( 13 | "mongodb://mongo1:30001,mongo2:30002,mongo3:30003/?replicaSet=my-replica-set" 14 | ) 15 | MONGO_DATABASE_NAME: str = "twin" 16 | 17 | # MQ config 18 | RABBITMQ_DEFAULT_USERNAME: str = "guest" 19 | RABBITMQ_DEFAULT_PASSWORD: str = "guest" 20 | RABBITMQ_HOST: str = "mq" 21 | RABBITMQ_PORT: int = 5673 22 | 23 | # QdrantDB config 24 | 
QDRANT_CLOUD_URL: str = "str" 25 | QDRANT_DATABASE_HOST: str = "qdrant" 26 | QDRANT_DATABASE_PORT: int = 6333 27 | USE_QDRANT_CLOUD: bool = False 28 | QDRANT_APIKEY: str | None = None 29 | 30 | # OpenAI config 31 | OPENAI_MODEL_ID: str = "gpt-4o-mini" 32 | OPENAI_API_KEY: str | None = None 33 | 34 | # CometML config 35 | COMET_API_KEY: str | None = None 36 | COMET_WORKSPACE: str | None = None 37 | COMET_PROJECT: str = "llm-twin" 38 | 39 | # AWS Authentication 40 | AWS_REGION: str = "eu-central-1" 41 | AWS_ACCESS_KEY: str | None = None 42 | AWS_SECRET_KEY: str | None = None 43 | AWS_ARN_ROLE: str | None = None 44 | 45 | # LLM Model config 46 | HUGGINGFACE_ACCESS_TOKEN: str | None = None 47 | MODEL_ID: str = "pauliusztin/LLMTwin-Llama-3.1-8B" 48 | DEPLOYMENT_ENDPOINT_NAME: str = "twin" 49 | 50 | MAX_INPUT_TOKENS: int = 1536 # Max length of input text. 51 | MAX_TOTAL_TOKENS: int = 2048 # Max length of the generation (including input text). 52 | MAX_BATCH_TOTAL_TOKENS: int = 2048 # Limits the number of tokens that can be processed in parallel during the generation. 53 | 54 | # Embeddings config 55 | EMBEDDING_MODEL_ID: str = "BAAI/bge-small-en-v1.5" 56 | EMBEDDING_MODEL_MAX_INPUT_LENGTH: int = 512 57 | EMBEDDING_SIZE: int = 384 58 | EMBEDDING_MODEL_DEVICE: str = "cpu" 59 | 60 | def patch_localhost(self) -> None: 61 | self.MONGO_DATABASE_HOST = "mongodb://localhost:30001,localhost:30002,localhost:30003/?replicaSet=my-replica-set" 62 | self.QDRANT_DATABASE_HOST = "localhost" 63 | self.RABBITMQ_HOST = "localhost" 64 | 65 | 66 | settings = AppSettings() 67 | -------------------------------------------------------------------------------- /src/core/db/__init__.py: -------------------------------------------------------------------------------- 1 | from . import documents, mongo, qdrant 2 | 3 | __all__ = ["documents", "mongo", "qdrant"] 4 | -------------------------------------------------------------------------------- /src/core/db/mongo.py: -------------------------------------------------------------------------------- 1 | from pymongo import MongoClient 2 | from pymongo.errors import ConnectionFailure 3 | 4 | from core.config import settings 5 | from core.logger_utils import get_logger 6 | 7 | logger = get_logger(__file__) 8 | 9 | 10 | class MongoDatabaseConnector: 11 | """Singleton class to connect to MongoDB database.""" 12 | 13 | _instance: MongoClient | None = None 14 | 15 | def __new__(cls, *args, **kwargs): 16 | if cls._instance is None: 17 | try: 18 | cls._instance = MongoClient(settings.MONGO_DATABASE_HOST) 19 | logger.info( 20 | f"Connection to database with uri: {settings.MONGO_DATABASE_HOST} successful" 21 | ) 22 | except ConnectionFailure: 23 | logger.error("Couldn't connect to the database.") 24 | 25 | raise 26 | 27 | return cls._instance 28 | 29 | def get_database(self): 30 | assert self._instance, "Database connection not initialized" 31 | 32 | return self._instance[settings.MONGO_DATABASE_NAME] 33 | 34 | def close(self): 35 | if self._instance: 36 | self._instance.close() 37 | logger.info("Connection to the database has been closed.") 38 | 39 | 40 | connection = MongoDatabaseConnector() 41 | -------------------------------------------------------------------------------- /src/core/db/qdrant.py: -------------------------------------------------------------------------------- 1 | from qdrant_client import QdrantClient, models 2 | from qdrant_client.http.models import Batch, Distance, VectorParams 3 | 4 | import core.logger_utils as logger_utils 5 | from core.config import settings
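# Illustrative usage sketch, kept as comments so nothing executes at import time.
# The collection name and query vector below are made-up placeholders, not values
# used elsewhere in this project:
#
#   connector = QdrantDatabaseConnector()
#   connector.create_vector_collection("vector_articles")
#   hits = connector.search(
#       collection_name="vector_articles",
#       query_vector=[0.0] * settings.EMBEDDING_SIZE,
#       limit=3,
#   )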
7 | logger = logger_utils.get_logger(__name__) 8 | 9 | 10 | class QdrantDatabaseConnector: 11 | _instance: QdrantClient | None = None 12 | 13 | def __init__(self) -> None: 14 | if self._instance is None: 15 | if settings.USE_QDRANT_CLOUD: 16 | self._instance = QdrantClient( 17 | url=settings.QDRANT_CLOUD_URL, 18 | api_key=settings.QDRANT_APIKEY, 19 | ) 20 | else: 21 | self._instance = QdrantClient( 22 | host=settings.QDRANT_DATABASE_HOST, 23 | port=settings.QDRANT_DATABASE_PORT, 24 | ) 25 | 26 | def get_collection(self, collection_name: str): 27 | return self._instance.get_collection(collection_name=collection_name) 28 | 29 | def create_non_vector_collection(self, collection_name: str): 30 | self._instance.create_collection( 31 | collection_name=collection_name, vectors_config={} 32 | ) 33 | 34 | def create_vector_collection(self, collection_name: str): 35 | self._instance.create_collection( 36 | collection_name=collection_name, 37 | vectors_config=VectorParams( 38 | size=settings.EMBEDDING_SIZE, distance=Distance.COSINE 39 | ), 40 | ) 41 | 42 | def write_data(self, collection_name: str, points: Batch): 43 | try: 44 | self._instance.upsert(collection_name=collection_name, points=points) 45 | except Exception: 46 | logger.exception("An error occurred while inserting data.") 47 | 48 | raise 49 | 50 | def search( 51 | self, 52 | collection_name: str, 53 | query_vector: list, 54 | query_filter: models.Filter | None = None, 55 | limit: int = 3, 56 | ) -> list: 57 | return self._instance.search( 58 | collection_name=collection_name, 59 | query_vector=query_vector, 60 | query_filter=query_filter, 61 | limit=limit, 62 | ) 63 | 64 | def scroll(self, collection_name: str, limit: int): 65 | return self._instance.scroll(collection_name=collection_name, limit=limit) 66 | 67 | def close(self): 68 | if self._instance: 69 | self._instance.close() 70 | 71 | logger.info("Connection to the database has been closed.") 72 | -------------------------------------------------------------------------------- /src/core/errors.py: -------------------------------------------------------------------------------- 1 | class TwinBaseException(Exception): 2 | pass 3 | 4 | 5 | class ImproperlyConfigured(TwinBaseException): 6 | pass 7 | -------------------------------------------------------------------------------- /src/core/lib.py: -------------------------------------------------------------------------------- 1 | from core.errors import ImproperlyConfigured 2 | 3 | 4 | def split_user_full_name(user: str | None) -> tuple[str, str]: 5 | if user is None: 6 | raise ImproperlyConfigured("User name is empty") 7 | 8 | name_tokens = user.split(" ") 9 | if len(name_tokens) == 0: 10 | raise ImproperlyConfigured("User name is empty") 11 | elif len(name_tokens) == 1: 12 | first_name, last_name = name_tokens[0], name_tokens[0] 13 | else: 14 | first_name, last_name = " ".join(name_tokens[:-1]), name_tokens[-1] 15 | 16 | return first_name, last_name 17 | 18 | 19 | def flatten(nested_list: list) -> list: 20 | """Flatten a list of lists into a single list.""" 21 | 22 | return [item for sublist in nested_list for item in sublist] 23 | -------------------------------------------------------------------------------- /src/core/logger_utils.py: -------------------------------------------------------------------------------- 1 | import structlog 2 | 3 | 4 | def get_logger(cls: str): 5 | return structlog.get_logger().bind(cls=cls) 6 | -------------------------------------------------------------------------------- /src/core/rag/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .query_expanison import QueryExpansion -------------------------------------------------------------------------------- /src/core/rag/prompt_templates.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from langchain.prompts import PromptTemplate 4 | from pydantic import BaseModel 5 | 6 | 7 | class BasePromptTemplate(ABC, BaseModel): 8 | @abstractmethod 9 | def create_template(self, *args) -> PromptTemplate: 10 | pass 11 | 12 | 13 | class QueryExpansionTemplate(BasePromptTemplate): 14 | prompt: str = """You are an AI language model assistant. Your task is to generate {to_expand_to_n} 15 | different versions of the given user question to retrieve relevant documents from a vector 16 | database. By generating multiple perspectives on the user question, your goal is to help 17 | the user overcome some of the limitations of the distance-based similarity search. 18 | Provide these alternative questions separated by '{separator}'. 19 | Original question: {question}""" 20 | 21 | @property 22 | def separator(self) -> str: 23 | return "#next-question#" 24 | 25 | def create_template(self, to_expand_to_n: int) -> PromptTemplate: 26 | return PromptTemplate( 27 | template=self.prompt, 28 | input_variables=["question"], 29 | partial_variables={ 30 | "separator": self.separator, 31 | "to_expand_to_n": to_expand_to_n, 32 | }, 33 | ) 34 | 35 | 36 | class SelfQueryTemplate(BasePromptTemplate): 37 | prompt: str = """You are an AI language model assistant. Your task is to extract information from a user question. 38 | The required information that needs to be extracted is the user name or user id. 39 | Your response should consists of only the extracted user name (e.g., John Doe) or id (e.g. 1345256), nothing else. 40 | If the user question does not contain any user name or id, you should return the following token: none. 41 | 42 | For example: 43 | QUESTION 1: 44 | My name is Paul Iusztin and I want a post about... 45 | RESPONSE 1: 46 | Paul Iusztin 47 | 48 | QUESTION 2: 49 | I want to write a post about... 50 | RESPONSE 2: 51 | none 52 | 53 | QUESTION 3: 54 | My user id is 1345256 and I want to write a post about... 55 | RESPONSE 3: 56 | 1345256 57 | 58 | User question: {question}""" 59 | 60 | def create_template(self) -> PromptTemplate: 61 | return PromptTemplate(template=self.prompt, input_variables=["question"]) 62 | 63 | 64 | class RerankingTemplate(BasePromptTemplate): 65 | prompt: str = """You are an AI language model assistant. Your task is to rerank passages related to a query 66 | based on their relevance. 67 | The most relevant passages should be put at the beginning. 68 | You should only pick at max {keep_top_k} passages. 69 | The provided and reranked documents are separated by '{separator}'. 70 | 71 | The following are passages related to this query: {question}. 
72 | 73 | Passages: 74 | {passages} 75 | """ 76 | 77 | def create_template(self, keep_top_k: int) -> PromptTemplate: 78 | return PromptTemplate( 79 | template=self.prompt, 80 | input_variables=["question", "passages"], 81 | partial_variables={"keep_top_k": keep_top_k, "separator": self.separator}, 82 | ) 83 | 84 | @property 85 | def separator(self) -> str: 86 | return "\n#next-document#\n" 87 | -------------------------------------------------------------------------------- /src/core/rag/query_expanison.py: -------------------------------------------------------------------------------- 1 | import opik 2 | from config import settings 3 | from langchain_openai import ChatOpenAI 4 | from opik.integrations.langchain import OpikTracer 5 | 6 | from core.rag.prompt_templates import QueryExpansionTemplate 7 | 8 | 9 | class QueryExpansion: 10 | opik_tracer = OpikTracer(tags=["QueryExpansion"]) 11 | 12 | @staticmethod 13 | @opik.track(name="QueryExpansion.generate_response") 14 | def generate_response(query: str, to_expand_to_n: int) -> list[str]: 15 | query_expansion_template = QueryExpansionTemplate() 16 | prompt = query_expansion_template.create_template(to_expand_to_n) 17 | model = ChatOpenAI( 18 | model=settings.OPENAI_MODEL_ID, 19 | api_key=settings.OPENAI_API_KEY, 20 | temperature=0, 21 | ) 22 | chain = prompt | model 23 | chain = chain.with_config({"callbacks": [QueryExpansion.opik_tracer]}) 24 | 25 | response = chain.invoke({"question": query}) 26 | response = response.content 27 | 28 | queries = response.strip().split(query_expansion_template.separator) 29 | stripped_queries = [ 30 | stripped_item for item in queries if (stripped_item := item.strip(" \\n")) 31 | ] 32 | 33 | return stripped_queries 34 | -------------------------------------------------------------------------------- /src/core/rag/reranking.py: -------------------------------------------------------------------------------- 1 | from config import settings 2 | from langchain_openai import ChatOpenAI 3 | 4 | from core.rag.prompt_templates import RerankingTemplate 5 | 6 | 7 | class Reranker: 8 | @staticmethod 9 | def generate_response( 10 | query: str, passages: list[str], keep_top_k: int 11 | ) -> list[str]: 12 | reranking_template = RerankingTemplate() 13 | prompt = reranking_template.create_template(keep_top_k=keep_top_k) 14 | model = ChatOpenAI( 15 | model=settings.OPENAI_MODEL_ID, api_key=settings.OPENAI_API_KEY 16 | ) 17 | chain = prompt | model 18 | 19 | stripped_passages = [ 20 | stripped_item for item in passages if (stripped_item := item.strip()) 21 | ] 22 | passages = reranking_template.separator.join(stripped_passages) 23 | response = chain.invoke({"question": query, "passages": passages}) 24 | response = response.content 25 | 26 | reranked_passages = response.strip().split(reranking_template.separator) 27 | stripped_passages = [ 28 | stripped_item 29 | for item in reranked_passages 30 | if (stripped_item := item.strip()) 31 | ] 32 | 33 | return stripped_passages 34 | -------------------------------------------------------------------------------- /src/core/rag/self_query.py: -------------------------------------------------------------------------------- 1 | import opik 2 | from config import settings 3 | from langchain_openai import ChatOpenAI 4 | from opik.integrations.langchain import OpikTracer 5 | 6 | import core.logger_utils as logger_utils 7 | from core import lib 8 | from core.db.documents import UserDocument 9 | from core.rag.prompt_templates import SelfQueryTemplate 10 | 11 | logger = 
logger_utils.get_logger(__name__) 12 | 13 | 14 | class SelfQuery: 15 | opik_tracer = OpikTracer(tags=["SelfQuery"]) 16 | 17 | @staticmethod 18 | @opik.track(name="SelfQuery.generate_response") 19 | def generate_response(query: str) -> str | None: 20 | prompt = SelfQueryTemplate().create_template() 21 | model = ChatOpenAI( 22 | model=settings.OPENAI_MODEL_ID, 23 | api_key=settings.OPENAI_API_KEY, 24 | temperature=0, 25 | ) 26 | chain = prompt | model 27 | chain = chain.with_config({"callbacks": [SelfQuery.opik_tracer]}) 28 | 29 | response = chain.invoke({"question": query}) 30 | response = response.content 31 | user_full_name = response.strip("\n ") 32 | 33 | if user_full_name == "none": 34 | return None 35 | 36 | logger.info( 37 | "Successfully extracted the user full name from the query.", 38 | user_full_name=user_full_name, 39 | ) 40 | first_name, last_name = lib.split_user_full_name(user_full_name) 41 | logger.info( 42 | "Successfully extracted the user first and last name from the query.", 43 | first_name=first_name, 44 | last_name=last_name, 45 | ) 46 | user_id = UserDocument.get_or_create(first_name=first_name, last_name=last_name) 47 | 48 | return user_id 49 | -------------------------------------------------------------------------------- /src/data_cdc/cdc.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | from bson import json_util 5 | from config import settings 6 | from core.db.mongo import MongoDatabaseConnector 7 | from core.logger_utils import get_logger 8 | from core.mq import publish_to_rabbitmq 9 | 10 | logger = get_logger(__file__) 11 | 12 | 13 | def stream_process(): 14 | try: 15 | client = MongoDatabaseConnector() 16 | db = client["twin"] 17 | logger.info("Connected to MongoDB.") 18 | 19 | # Watch changes in a specific collection 20 | changes = db.watch([{"$match": {"operationType": {"$in": ["insert"]}}}]) 21 | for change in changes: 22 | data_type = change["ns"]["coll"] 23 | entry_id = str(change["fullDocument"]["_id"]) # Convert ObjectId to string 24 | 25 | change["fullDocument"].pop("_id") 26 | change["fullDocument"]["type"] = data_type 27 | change["fullDocument"]["entry_id"] = entry_id 28 | 29 | if data_type not in ["articles", "posts", "repositories"]: 30 | logger.info(f"Unsupported data type: '{data_type}'") 31 | continue 32 | 33 | # Use json_util to serialize the document 34 | data = json.dumps(change["fullDocument"], default=json_util.default) 35 | logger.info( 36 | f"Change detected and serialized for a data sample of type {data_type}."
37 | ) 38 | 39 | # Send data to rabbitmq 40 | publish_to_rabbitmq(queue_name=settings.RABBITMQ_QUEUE_NAME, data=data) 41 | logger.info(f"Data of type '{data_type}' published to RabbitMQ.") 42 | 43 | except Exception as e: 44 | logger.error(f"An error occurred: {e}") 45 | 46 | 47 | if __name__ == "__main__": 48 | stream_process() 49 | -------------------------------------------------------------------------------- /src/data_cdc/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from pydantic_settings import BaseSettings, SettingsConfigDict 4 | 5 | ROOT_DIR = str(Path(__file__).parent.parent.parent) 6 | 7 | 8 | class Settings(BaseSettings): 9 | model_config = SettingsConfigDict(env_file=ROOT_DIR, env_file_encoding="utf-8") 10 | 11 | MONGO_DATABASE_HOST: str = ( 12 | "mongodb://mongo1:30001,mongo2:30002,mongo3:30003/?replicaSet=my-replica-set" 13 | ) 14 | MONGO_DATABASE_NAME: str = "twin" 15 | 16 | RABBITMQ_HOST: str = "mq" # or localhost if running outside Docker 17 | RABBITMQ_PORT: int = 5672 18 | RABBITMQ_DEFAULT_USERNAME: str = "guest" 19 | RABBITMQ_DEFAULT_PASSWORD: str = "guest" 20 | RABBITMQ_QUEUE_NAME: str = "default" 21 | 22 | 23 | settings = Settings() 24 | -------------------------------------------------------------------------------- /src/data_cdc/test_cdc.py: -------------------------------------------------------------------------------- 1 | from pymongo import MongoClient 2 | 3 | 4 | def insert_data_to_mongodb(uri, database_name, collection_name, data): 5 | """ 6 | Insert data into a MongoDB collection. 7 | 8 | :param uri: MongoDB URI 9 | :param database_name: Name of the database 10 | :param collection_name: Name of the collection 11 | :param data: Data to be inserted (dict) 12 | """ 13 | client = MongoClient(uri) 14 | db = client[database_name] 15 | collection = db[collection_name] 16 | 17 | try: 18 | result = collection.insert_one(data) 19 | print(f"Data inserted with _id: {result.inserted_id}") 20 | except Exception as e: 21 | print(f"An error occurred: {e}") 22 | finally: 23 | client.close() 24 | 25 | 26 | if __name__ == "__main__": 27 | insert_data_to_mongodb( 28 | "mongodb://localhost:30001,localhost:30002,localhost:30003/?replicaSet=my-replica-set", 29 | "twin", 30 | "posts", 31 | {"platform": "linkedin", "content": "Test content"} 32 | ) 33 | -------------------------------------------------------------------------------- /src/data_crawling/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from pydantic_settings import BaseSettings, SettingsConfigDict 4 | 5 | ROOT_DIR = str(Path(__file__).parent.parent.parent) 6 | 7 | 8 | class Settings(BaseSettings): 9 | model_config = SettingsConfigDict(env_file=ROOT_DIR, env_file_encoding="utf-8") 10 | 11 | MONGO_DATABASE_HOST: str = ( 12 | "mongodb://mongo1:30001,mongo2:30002,mongo3:30003/?replicaSet=my-replica-set" 13 | ) 14 | MONGO_DATABASE_NAME: str = "twin" 15 | 16 | # Optional LinkedIn credentials for scraping your profile 17 | LINKEDIN_USERNAME: str | None = None 18 | LINKEDIN_PASSWORD: str | None = None 19 | 20 | 21 | settings = Settings() 22 | -------------------------------------------------------------------------------- /src/data_crawling/crawlers/__init__.py: -------------------------------------------------------------------------------- 1 | from .custom_article import CustomArticleCrawler 2 | from .github import GithubCrawler 3 | from .linkedin import 
LinkedInCrawler 4 | from .medium import MediumCrawler 5 | 6 | __all__ = ["CustomArticleCrawler", "GithubCrawler", "LinkedInCrawler", "MediumCrawler"] 7 | -------------------------------------------------------------------------------- /src/data_crawling/crawlers/base.py: -------------------------------------------------------------------------------- 1 | import time 2 | from abc import ABC, abstractmethod 3 | from tempfile import mkdtemp 4 | 5 | from core.db.documents import BaseDocument 6 | from selenium import webdriver 7 | from selenium.webdriver.chrome.options import Options 8 | 9 | 10 | class BaseCrawler(ABC): 11 | model: type[BaseDocument] 12 | 13 | @abstractmethod 14 | def extract(self, link: str, **kwargs) -> None: ... 15 | 16 | 17 | class BaseAbstractCrawler(BaseCrawler, ABC): 18 | def __init__(self, scroll_limit: int = 5) -> None: 19 | options = webdriver.ChromeOptions() 20 | 21 | options.add_argument("--no-sandbox") 22 | options.add_argument("--headless=new") 23 | options.add_argument("--disable-dev-shm-usage") 24 | options.add_argument("--log-level=3") 25 | options.add_argument("--disable-popup-blocking") 26 | options.add_argument("--disable-notifications") 27 | options.add_argument("--disable-extensions") 28 | options.add_argument("--disable-background-networking") 29 | options.add_argument("--ignore-certificate-errors") 30 | options.add_argument(f"--user-data-dir={mkdtemp()}") 31 | options.add_argument(f"--data-path={mkdtemp()}") 32 | options.add_argument(f"--disk-cache-dir={mkdtemp()}") 33 | options.add_argument("--remote-debugging-port=9226") 34 | 35 | self.set_extra_driver_options(options) 36 | 37 | self.scroll_limit = scroll_limit 38 | self.driver = webdriver.Chrome( 39 | options=options, 40 | ) 41 | 42 | def set_extra_driver_options(self, options: Options) -> None: 43 | pass 44 | 45 | def login(self) -> None: 46 | pass 47 | 48 | def scroll_page(self) -> None: 49 | """Scroll through the page based on the scroll limit.""" 50 | current_scroll = 0 51 | last_height = self.driver.execute_script("return document.body.scrollHeight") 52 | while True: 53 | self.driver.execute_script( 54 | "window.scrollTo(0, document.body.scrollHeight);" 55 | ) 56 | time.sleep(5) 57 | new_height = self.driver.execute_script("return document.body.scrollHeight") 58 | if new_height == last_height or ( 59 | self.scroll_limit and current_scroll >= self.scroll_limit 60 | ): 61 | break 62 | last_height = new_height 63 | current_scroll += 1 64 | -------------------------------------------------------------------------------- /src/data_crawling/crawlers/custom_article.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urlparse 2 | 3 | from aws_lambda_powertools import Logger 4 | from core.db.documents import ArticleDocument 5 | from langchain_community.document_loaders import AsyncHtmlLoader 6 | from langchain_community.document_transformers.html2text import Html2TextTransformer 7 | 8 | from .base import BaseCrawler 9 | 10 | logger = Logger(service="llm-twin-course/crawler") 11 | 12 | 13 | class CustomArticleCrawler(BaseCrawler): 14 | model = ArticleDocument 15 | 16 | def __init__(self) -> None: 17 | super().__init__() 18 | 19 | def extract(self, link: str, **kwargs) -> None: 20 | old_model = self.model.find(link=link) 21 | if old_model is not None: 22 | logger.info(f"Article already exists in the database: {link}") 23 | 24 | return 25 | 26 | logger.info(f"Starting scraping article: {link}") 27 | 28 | loader = AsyncHtmlLoader([link])
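        # AsyncHtmlLoader fetches the raw HTML for the link; the Html2TextTransformer
        # applied a few lines below converts that HTML into plain text so the article
        # body can be stored together with its title, description and language metadata.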
29 | docs = loader.load() 30 | 31 | html2text = Html2TextTransformer() 32 | docs_transformed = html2text.transform_documents(docs) 33 | doc_transformed = docs_transformed[0] 34 | 35 | content = { 36 | "Title": doc_transformed.metadata.get("title"), 37 | "Subtitle": doc_transformed.metadata.get("description"), 38 | "Content": doc_transformed.page_content, 39 | "language": doc_transformed.metadata.get("language"), 40 | } 41 | 42 | parsed_url = urlparse(link) 43 | platform = parsed_url.netloc 44 | 45 | instance = self.model( 46 | content=content, 47 | link=link, 48 | platform=platform, 49 | author_id=kwargs.get("user"), 50 | ) 51 | instance.save() 52 | 53 | logger.info(f"Finished scraping custom article: {link}") 54 | -------------------------------------------------------------------------------- /src/data_crawling/crawlers/github.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | import tempfile 5 | 6 | from aws_lambda_powertools import Logger 7 | from core.db.documents import RepositoryDocument 8 | 9 | from crawlers.base import BaseCrawler 10 | 11 | logger = Logger(service="llm-twin-course/crawler") 12 | 13 | 14 | class GithubCrawler(BaseCrawler): 15 | model = RepositoryDocument 16 | 17 | def __init__(self, ignore=(".git", ".toml", ".lock", ".png")) -> None: 18 | super().__init__() 19 | self._ignore = ignore 20 | 21 | def extract(self, link: str, **kwargs) -> None: 22 | logger.info(f"Starting scraping GitHub repository: {link}") 23 | 24 | repo_name = link.rstrip("/").split("/")[-1] 25 | 26 | local_temp = tempfile.mkdtemp() 27 | 28 | try: 29 | os.chdir(local_temp) 30 | subprocess.run(["git", "clone", link]) 31 | 32 | repo_path = os.path.join(local_temp, os.listdir(local_temp)[0]) 33 | 34 | tree = {} 35 | for root, dirs, files in os.walk(repo_path): 36 | dir = root.replace(repo_path, "").lstrip("/") 37 | if dir.startswith(self._ignore): 38 | continue 39 | 40 | for file in files: 41 | if file.endswith(self._ignore): 42 | continue 43 | file_path = os.path.join(dir, file) 44 | with open(os.path.join(root, file), "r", errors="ignore") as f: 45 | tree[file_path] = f.read().replace(" ", "") 46 | 47 | instance = self.model( 48 | name=repo_name, link=link, content=tree, owner_id=kwargs.get("user") 49 | ) 50 | instance.save() 51 | 52 | except Exception: 53 | raise 54 | finally: 55 | shutil.rmtree(local_temp) 56 | 57 | logger.info(f"Finished scraping GitHub repository: {link}") 58 | -------------------------------------------------------------------------------- /src/data_crawling/crawlers/medium.py: -------------------------------------------------------------------------------- 1 | from aws_lambda_powertools import Logger 2 | from bs4 import BeautifulSoup 3 | from core.db.documents import ArticleDocument 4 | from selenium.webdriver.common.by import By 5 | 6 | from crawlers.base import BaseAbstractCrawler 7 | 8 | logger = Logger(service="llm-twin-course/crawler") 9 | 10 | 11 | class MediumCrawler(BaseAbstractCrawler): 12 | model = ArticleDocument 13 | 14 | def set_extra_driver_options(self, options) -> None: 15 | options.add_argument(r"--profile-directory=Profile 2") 16 | 17 | def extract(self, link: str, **kwargs) -> None: 18 | logger.info(f"Starting scraping Medium article: {link}") 19 | 20 | self.driver.get(link) 21 | self.scroll_page() 22 | 23 | soup = BeautifulSoup(self.driver.page_source, "html.parser") 24 | title = soup.find_all("h1", class_="pw-post-title") 25 | subtitle = soup.find_all("h2",
class_="pw-subtitle-paragraph") 26 | 27 | data = { 28 | "Title": title[0].string if title else None, 29 | "Subtitle": subtitle[0].string if subtitle else None, 30 | "Content": soup.get_text(), 31 | } 32 | 33 | logger.info(f"Successfully scraped and saved article: {link}") 34 | self.driver.close() 35 | instance = self.model( 36 | platform="medium", content=data, link=link, author_id=kwargs.get("user") 37 | ) 38 | instance.save() 39 | 40 | def login(self): 41 | """Log in to Medium with Google""" 42 | self.driver.get("https://medium.com/m/signin") 43 | self.driver.find_element(By.TAG_NAME, "a").click() 44 | -------------------------------------------------------------------------------- /src/data_crawling/dispatcher.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from aws_lambda_powertools import Logger 4 | from crawlers.base import BaseCrawler 5 | from crawlers.custom_article import CustomArticleCrawler 6 | 7 | logger = Logger(service="llm-twin-course/crawler") 8 | 9 | 10 | class CrawlerDispatcher: 11 | def __init__(self) -> None: 12 | self._crawlers = {} 13 | 14 | def register(self, domain: str, crawler: type[BaseCrawler]) -> None: 15 | self._crawlers[r"https://(www\.)?{}.com/*".format(re.escape(domain))] = crawler 16 | 17 | def get_crawler(self, url: str) -> BaseCrawler: 18 | for pattern, crawler in self._crawlers.items(): 19 | if re.match(pattern, url): 20 | return crawler() 21 | else: 22 | logger.warning( 23 | f"No crawler found for {url}. Defaulting to CustomArticleCrawler." 24 | ) 25 | 26 | return CustomArticleCrawler() 27 | -------------------------------------------------------------------------------- /src/data_crawling/main.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from aws_lambda_powertools import Logger 4 | from aws_lambda_powertools.utilities.typing import LambdaContext 5 | from core import lib 6 | from core.db.documents import UserDocument 7 | from crawlers import CustomArticleCrawler, GithubCrawler, LinkedInCrawler 8 | from dispatcher import CrawlerDispatcher 9 | 10 | logger = Logger(service="llm-twin-course/crawler") 11 | 12 | _dispatcher = CrawlerDispatcher() 13 | _dispatcher.register("medium", CustomArticleCrawler) 14 | _dispatcher.register("linkedin", LinkedInCrawler) 15 | _dispatcher.register("github", GithubCrawler) 16 | 17 | 18 | def handler(event, context: LambdaContext | None = None) -> dict[str, Any]: 19 | first_name, last_name = lib.split_user_full_name(event.get("user")) 20 | 21 | user_id = UserDocument.get_or_create(first_name=first_name, last_name=last_name) 22 | 23 | link = event.get("link") 24 | crawler = _dispatcher.get_crawler(link) 25 | 26 | try: 27 | crawler.extract(link=link, user=user_id) 28 | 29 | return {"statusCode": 200, "body": "Link processed successfully"} 30 | except Exception as e: 31 | return {"statusCode": 500, "body": f"An error occurred: {str(e)}"} 32 | 33 | 34 | if __name__ == "__main__": 35 | event = { 36 | "user": "Paul Iuztin", 37 | "link": "https://www.linkedin.com/in/vesaalexandru/", 38 | } 39 | handler(event, None) 40 | -------------------------------------------------------------------------------- /src/data_crawling/utils.py: -------------------------------------------------------------------------------- 1 | import structlog 2 | 3 | 4 | def get_logger(cls: str): 5 | return structlog.get_logger().bind(cls=cls) 6 | -------------------------------------------------------------------------------- 
/src/feature_pipeline/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from pydantic_settings import BaseSettings, SettingsConfigDict 4 | 5 | ROOT_DIR = str(Path(__file__).parent.parent.parent) 6 | 7 | 8 | class Settings(BaseSettings): 9 | model_config = SettingsConfigDict(env_file=ROOT_DIR, env_file_encoding="utf-8") 10 | 11 | # CometML config 12 | COMET_API_KEY: str | None = None 13 | COMET_WORKSPACE: str | None = None 14 | COMET_PROJECT: str = "llm-twin" 15 | 16 | # Embeddings config 17 | EMBEDDING_MODEL_ID: str = "BAAI/bge-small-en-v1.5" 18 | EMBEDDING_MODEL_MAX_INPUT_LENGTH: int = 512 19 | EMBEDDING_SIZE: int = 384 20 | EMBEDDING_MODEL_DEVICE: str = "cpu" 21 | 22 | # OpenAI 23 | OPENAI_MODEL_ID: str = "gpt-4o-mini" 24 | OPENAI_API_KEY: str | None = None 25 | 26 | # MQ config 27 | RABBITMQ_DEFAULT_USERNAME: str = "guest" 28 | RABBITMQ_DEFAULT_PASSWORD: str = "guest" 29 | RABBITMQ_HOST: str = "mq" # or localhost if running outside Docker 30 | RABBITMQ_PORT: int = 5672 31 | RABBITMQ_QUEUE_NAME: str = "default" 32 | 33 | # QdrantDB config 34 | QDRANT_DATABASE_HOST: str = "qdrant" # or localhost if running outside Docker 35 | QDRANT_DATABASE_PORT: int = 6333 36 | USE_QDRANT_CLOUD: bool = ( 37 | False # if True, fill in QDRANT_CLOUD_URL and QDRANT_APIKEY 38 | ) 39 | QDRANT_CLOUD_URL: str | None = None 40 | QDRANT_APIKEY: str | None = None 41 | 42 | 43 | settings = Settings() 44 | -------------------------------------------------------------------------------- /src/feature_pipeline/data_flow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/feature_pipeline/data_flow/__init__.py -------------------------------------------------------------------------------- /src/feature_pipeline/data_flow/stream_input.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | from datetime import datetime 4 | from typing import Generic, Iterable, List, Optional, TypeVar 5 | 6 | from bytewax.inputs import FixedPartitionedSource, StatefulSourcePartition 7 | from config import settings 8 | from core import get_logger 9 | from core.mq import RabbitMQConnection 10 | 11 | logger = get_logger(__name__) 12 | 13 | DataT = TypeVar("DataT") 14 | MessageT = TypeVar("MessageT") 15 | 16 | 17 | class RabbitMQPartition(StatefulSourcePartition, Generic[DataT, MessageT]): 18 | """ 19 | Class responsible for creating a connection between Bytewax and RabbitMQ that facilitates the transfer of data from the message queue to the Bytewax streaming pipeline.
20 | Inherits StatefulSourcePartition for snapshot functionality that enables saving the state of the queue 21 | """ 22 | 23 | def __init__(self, queue_name: str, resume_state: MessageT | None = None) -> None: 24 | self._in_flight_msg_ids = resume_state or set() 25 | self.queue_name = queue_name 26 | self.connection = RabbitMQConnection() 27 | self.connection.connect() 28 | self.channel = self.connection.get_channel() 29 | 30 | def next_batch(self, sched: Optional[datetime]) -> Iterable[DataT]: 31 | try: 32 | method_frame, header_frame, body = self.channel.basic_get( 33 | queue=self.queue_name, auto_ack=True 34 | ) 35 | except Exception: 36 | logger.error( 37 | f"Error while fetching message from queue.", queue_name=self.queue_name 38 | ) 39 | time.sleep(10) # Sleep for 10 seconds before retrying to access the queue. 40 | 41 | self.connection.connect() 42 | self.channel = self.connection.get_channel() 43 | 44 | return [] 45 | 46 | if method_frame: 47 | message_id = method_frame.delivery_tag 48 | self._in_flight_msg_ids.add(message_id) 49 | 50 | return [json.loads(body)] 51 | else: 52 | return [] 53 | 54 | def snapshot(self) -> MessageT: 55 | return self._in_flight_msg_ids 56 | 57 | def garbage_collect(self, state): 58 | closed_in_flight_msg_ids = state 59 | for msg_id in closed_in_flight_msg_ids: 60 | self.channel.basic_ack(delivery_tag=msg_id) 61 | self._in_flight_msg_ids.remove(msg_id) 62 | 63 | def close(self): 64 | self.channel.close() 65 | 66 | 67 | class RabbitMQSource(FixedPartitionedSource): 68 | def list_parts(self) -> List[str]: 69 | return ["single partition"] 70 | 71 | def build_part( 72 | self, now: datetime, for_part: str, resume_state: MessageT | None = None 73 | ) -> StatefulSourcePartition[DataT, MessageT]: 74 | return RabbitMQPartition(queue_name=settings.RABBITMQ_QUEUE_NAME) 75 | -------------------------------------------------------------------------------- /src/feature_pipeline/data_logic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/feature_pipeline/data_logic/__init__.py -------------------------------------------------------------------------------- /src/feature_pipeline/data_logic/chunking_data_handlers.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | from abc import ABC, abstractmethod 3 | 4 | from models.base import DataModel 5 | from models.chunk import ArticleChunkModel, PostChunkModel, RepositoryChunkModel 6 | from models.clean import ArticleCleanedModel, PostCleanedModel, RepositoryCleanedModel 7 | from utils.chunking import chunk_text 8 | 9 | 10 | class ChunkingDataHandler(ABC): 11 | """ 12 | Abstract class for all Chunking data handlers. 
13 | All data transformations logic for the chunking step is done here 14 | """ 15 | 16 | @abstractmethod 17 | def chunk(self, data_model: DataModel) -> list[DataModel]: 18 | pass 19 | 20 | 21 | class PostChunkingHandler(ChunkingDataHandler): 22 | def chunk(self, data_model: PostCleanedModel) -> list[PostChunkModel]: 23 | data_models_list = [] 24 | 25 | text_content = data_model.cleaned_content 26 | chunks = chunk_text(text_content) 27 | 28 | for chunk in chunks: 29 | model = PostChunkModel( 30 | entry_id=data_model.entry_id, 31 | platform=data_model.platform, 32 | chunk_id=hashlib.md5(chunk.encode()).hexdigest(), 33 | chunk_content=chunk, 34 | author_id=data_model.author_id, 35 | image=data_model.image if data_model.image else None, 36 | type=data_model.type, 37 | ) 38 | data_models_list.append(model) 39 | 40 | return data_models_list 41 | 42 | 43 | class ArticleChunkingHandler(ChunkingDataHandler): 44 | def chunk(self, data_model: ArticleCleanedModel) -> list[ArticleChunkModel]: 45 | data_models_list = [] 46 | 47 | text_content = data_model.cleaned_content 48 | chunks = chunk_text(text_content) 49 | 50 | for chunk in chunks: 51 | model = ArticleChunkModel( 52 | entry_id=data_model.entry_id, 53 | platform=data_model.platform, 54 | link=data_model.link, 55 | chunk_id=hashlib.md5(chunk.encode()).hexdigest(), 56 | chunk_content=chunk, 57 | author_id=data_model.author_id, 58 | type=data_model.type, 59 | ) 60 | data_models_list.append(model) 61 | 62 | return data_models_list 63 | 64 | 65 | class RepositoryChunkingHandler(ChunkingDataHandler): 66 | def chunk(self, data_model: RepositoryCleanedModel) -> list[RepositoryChunkModel]: 67 | data_models_list = [] 68 | 69 | text_content = data_model.cleaned_content 70 | chunks = chunk_text(text_content) 71 | 72 | for chunk in chunks: 73 | model = RepositoryChunkModel( 74 | entry_id=data_model.entry_id, 75 | name=data_model.name, 76 | link=data_model.link, 77 | chunk_id=hashlib.md5(chunk.encode()).hexdigest(), 78 | chunk_content=chunk, 79 | owner_id=data_model.owner_id, 80 | type=data_model.type, 81 | ) 82 | data_models_list.append(model) 83 | 84 | return data_models_list 85 | -------------------------------------------------------------------------------- /src/feature_pipeline/data_logic/cleaning_data_handlers.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from models.base import DataModel 4 | from models.clean import ArticleCleanedModel, PostCleanedModel, RepositoryCleanedModel 5 | from models.raw import ArticleRawModel, PostsRawModel, RepositoryRawModel 6 | from utils.cleaning import clean_text 7 | 8 | 9 | class CleaningDataHandler(ABC): 10 | """ 11 | Abstract class for all cleaning data handlers. 
12 | All data transformations logic for the cleaning step is done here 13 | """ 14 | 15 | @abstractmethod 16 | def clean(self, data_model: DataModel) -> DataModel: 17 | pass 18 | 19 | 20 | class PostCleaningHandler(CleaningDataHandler): 21 | def clean(self, data_model: PostsRawModel) -> PostCleanedModel: 22 | joined_text = ( 23 | "".join(data_model.content.values()) if data_model and data_model.content else None 24 | ) 25 | 26 | return PostCleanedModel( 27 | entry_id=data_model.entry_id, 28 | platform=data_model.platform, 29 | cleaned_content=clean_text(joined_text), 30 | author_id=data_model.author_id, 31 | image=data_model.image if data_model.image else None, 32 | type=data_model.type, 33 | ) 34 | 35 | 36 | class ArticleCleaningHandler(CleaningDataHandler): 37 | def clean(self, data_model: ArticleRawModel) -> ArticleCleanedModel: 38 | joined_text = ( 39 | "".join(data_model.content.values()) if data_model and data_model.content else None 40 | ) 41 | 42 | return ArticleCleanedModel( 43 | entry_id=data_model.entry_id, 44 | platform=data_model.platform, 45 | link=data_model.link, 46 | cleaned_content=clean_text(joined_text), 47 | author_id=data_model.author_id, 48 | type=data_model.type, 49 | ) 50 | 51 | 52 | class RepositoryCleaningHandler(CleaningDataHandler): 53 | def clean(self, data_model: RepositoryRawModel) -> RepositoryCleanedModel: 54 | joined_text = ( 55 | "".join(data_model.content.values()) if data_model and data_model.content else None 56 | ) 57 | 58 | return RepositoryCleanedModel( 59 | entry_id=data_model.entry_id, 60 | name=data_model.name, 61 | link=data_model.link, 62 | cleaned_content=clean_text(joined_text), 63 | owner_id=data_model.owner_id, 64 | type=data_model.type, 65 | ) 66 | -------------------------------------------------------------------------------- /src/feature_pipeline/data_logic/embedding_data_handlers.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from models.base import DataModel 4 | from models.chunk import ArticleChunkModel, PostChunkModel, RepositoryChunkModel 5 | from models.embedded_chunk import ( 6 | ArticleEmbeddedChunkModel, 7 | PostEmbeddedChunkModel, 8 | RepositoryEmbeddedChunkModel, 9 | ) 10 | from utils.embeddings import embedd_text 11 | 12 | 13 | class EmbeddingDataHandler(ABC): 14 | """ 15 | Abstract class for all embedding data handlers. 
16 | All data transformations logic for the embedding step is done here 17 | """ 18 | 19 | @abstractmethod 20 | def embedd(self, data_model: DataModel) -> DataModel: 21 | pass 22 | 23 | 24 | class PostEmbeddingHandler(EmbeddingDataHandler): 25 | def embedd(self, data_model: PostChunkModel) -> PostEmbeddedChunkModel: 26 | return PostEmbeddedChunkModel( 27 | entry_id=data_model.entry_id, 28 | platform=data_model.platform, 29 | chunk_id=data_model.chunk_id, 30 | chunk_content=data_model.chunk_content, 31 | embedded_content=embedd_text(data_model.chunk_content), 32 | author_id=data_model.author_id, 33 | type=data_model.type, 34 | ) 35 | 36 | 37 | class ArticleEmbeddingHandler(EmbeddingDataHandler): 38 | def embedd(self, data_model: ArticleChunkModel) -> ArticleEmbeddedChunkModel: 39 | return ArticleEmbeddedChunkModel( 40 | entry_id=data_model.entry_id, 41 | platform=data_model.platform, 42 | link=data_model.link, 43 | chunk_content=data_model.chunk_content, 44 | chunk_id=data_model.chunk_id, 45 | embedded_content=embedd_text(data_model.chunk_content), 46 | author_id=data_model.author_id, 47 | type=data_model.type, 48 | ) 49 | 50 | 51 | class RepositoryEmbeddingHandler(EmbeddingDataHandler): 52 | def embedd(self, data_model: RepositoryChunkModel) -> RepositoryEmbeddedChunkModel: 53 | return RepositoryEmbeddedChunkModel( 54 | entry_id=data_model.entry_id, 55 | name=data_model.name, 56 | link=data_model.link, 57 | chunk_id=data_model.chunk_id, 58 | chunk_content=data_model.chunk_content, 59 | embedded_content=embedd_text(data_model.chunk_content), 60 | owner_id=data_model.owner_id, 61 | type=data_model.type, 62 | ) 63 | -------------------------------------------------------------------------------- /src/feature_pipeline/generate_dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/feature_pipeline/generate_dataset/__init__.py -------------------------------------------------------------------------------- /src/feature_pipeline/generate_dataset/chunk_documents.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def chunk_documents(documents: list[str], min_length: int = 1000, max_length: int = 2000): 5 | chunked_documents = [] 6 | for document in documents: 7 | chunks = extract_substrings(document, min_length=min_length, max_length=max_length) 8 | chunked_documents.extend(chunks) 9 | 10 | return chunked_documents 11 | 12 | def extract_substrings( 13 | text: str, min_length: int = 1000, max_length: int = 2000 14 | ) -> list[str]: 15 | sentences = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", text) 16 | 17 | extracts = [] 18 | current_chunk = "" 19 | 20 | for sentence in sentences: 21 | sentence = sentence.strip() 22 | if not sentence: 23 | continue 24 | if len(current_chunk) + len(sentence) <= max_length: 25 | current_chunk += sentence + " " 26 | else: 27 | if len(current_chunk) >= min_length: 28 | extracts.append(current_chunk.strip()) 29 | current_chunk = sentence + " " 30 | 31 | if len(current_chunk) >= min_length: 32 | extracts.append(current_chunk.strip()) 33 | 34 | return extracts 35 | -------------------------------------------------------------------------------- /src/feature_pipeline/generate_dataset/exceptions.py: -------------------------------------------------------------------------------- 1 | class DatasetError(Exception): 2 | pass 3 | 4 | 5 | class FileNotFoundError(DatasetError): 6 | pass 7 | 8 | 9 | class JSONDecodeError(DatasetError): 10 | pass 11 | 12 | 13 | class APICommunicationError(DatasetError): 14 | pass 15 | -------------------------------------------------------------------------------- /src/feature_pipeline/generate_dataset/file_handler.py:
-------------------------------------------------------------------------------- 1 | import json 2 | 3 | from generate_dataset.exceptions import JSONDecodeError 4 | 5 | 6 | class FileHandler: 7 | def read_json(self, filename: str) -> list: 8 | try: 9 | with open(filename, "r") as file: 10 | return json.load(file) 11 | except FileNotFoundError: 12 | raise FileNotFoundError(f"The file '{filename}' does not exist.") 13 | except json.JSONDecodeError: 14 | raise JSONDecodeError( 15 | f"The file '{filename}' is not properly formatted as JSON." 16 | ) 17 | 18 | def write_json(self, filename: str, data: list): 19 | with open(filename, "w") as file: 20 | json.dump(data, file, indent=4) 21 | -------------------------------------------------------------------------------- /src/feature_pipeline/generate_dataset/llm_communication.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from config import settings 4 | from core import get_logger 5 | from openai import OpenAI 6 | 7 | MAX_LENGTH = 16384 8 | SYSTEM_PROMPT = ( 9 | "You are a technical writer handing someone's account to post about AI and MLOps." 10 | ) 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | class GptCommunicator: 16 | def __init__(self, gpt_model: str = settings.OPENAI_MODEL_ID): 17 | self.api_key = settings.OPENAI_API_KEY 18 | self.gpt_model = gpt_model 19 | 20 | def send_prompt(self, prompt: str) -> list: 21 | try: 22 | client = OpenAI(api_key=self.api_key) 23 | logger.info(f"Sending batch to GPT = '{settings.OPENAI_MODEL_ID}'.") 24 | 25 | chat_completion = client.chat.completions.create( 26 | messages=[ 27 | {"role": "system", "content": SYSTEM_PROMPT}, 28 | {"role": "user", "content": prompt[:MAX_LENGTH]}, 29 | ], 30 | model=self.gpt_model, 31 | ) 32 | response = chat_completion.choices[0].message.content 33 | return json.loads(self.clean_response(response)) 34 | except Exception: 35 | logger.exception( 36 | f"Skipping batch! An error occurred while communicating with API." 
" 37 | ) 38 | 39 | return [] 40 | 41 | @staticmethod 42 | def clean_response(response: str) -> str: 43 | start_index = response.find("[") 44 | end_index = response.rfind("]") 45 | return response[start_index : end_index + 1] 46 | -------------------------------------------------------------------------------- /src/feature_pipeline/main.py: -------------------------------------------------------------------------------- 1 | import bytewax.operators as op 2 | from bytewax.dataflow import Dataflow 3 | from core.db.qdrant import QdrantDatabaseConnector 4 | from data_flow.stream_input import RabbitMQSource 5 | from data_flow.stream_output import QdrantOutput 6 | from data_logic.dispatchers import ( 7 | ChunkingDispatcher, 8 | CleaningDispatcher, 9 | EmbeddingDispatcher, 10 | RawDispatcher, 11 | ) 12 | 13 | connection = QdrantDatabaseConnector() 14 | 15 | flow = Dataflow("Streaming ingestion pipeline") 16 | stream = op.input("input", flow, RabbitMQSource()) 17 | stream = op.map("raw dispatch", stream, RawDispatcher.handle_mq_message) 18 | stream = op.map("clean dispatch", stream, CleaningDispatcher.dispatch_cleaner) 19 | op.output( 20 | "cleaned data insert to qdrant", 21 | stream, 22 | QdrantOutput(connection=connection, sink_type="clean"), 23 | ) 24 | stream = op.flat_map("chunk dispatch", stream, ChunkingDispatcher.dispatch_chunker) 25 | stream = op.map( 26 | "embedded chunk dispatch", stream, EmbeddingDispatcher.dispatch_embedder 27 | ) 28 | op.output( 29 | "embedded data insert to qdrant", 30 | stream, 31 | QdrantOutput(connection=connection, sink_type="vector"), 32 | ) 33 | -------------------------------------------------------------------------------- /src/feature_pipeline/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/feature_pipeline/models/__init__.py -------------------------------------------------------------------------------- /src/feature_pipeline/models/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class DataModel(BaseModel): 7 | """ 8 | Abstract class for all data models 9 | """ 10 | 11 | entry_id: str 12 | type: str 13 | 14 | 15 | class VectorDBDataModel(ABC, DataModel): 16 | """ 17 | Abstract class for all data models that need to be saved into a vector DB (e.g.
Qdrant) 18 | """ 19 | 20 | entry_id: int 21 | type: str 22 | 23 | @abstractmethod 24 | def to_payload(self) -> tuple: 25 | pass 26 | -------------------------------------------------------------------------------- /src/feature_pipeline/models/chunk.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from models.base import DataModel 4 | 5 | 6 | class PostChunkModel(DataModel): 7 | entry_id: str 8 | platform: str 9 | chunk_id: str 10 | chunk_content: str 11 | author_id: str 12 | image: Optional[str] = None 13 | type: str 14 | 15 | 16 | class ArticleChunkModel(DataModel): 17 | entry_id: str 18 | platform: str 19 | link: str 20 | chunk_id: str 21 | chunk_content: str 22 | author_id: str 23 | type: str 24 | 25 | 26 | class RepositoryChunkModel(DataModel): 27 | entry_id: str 28 | name: str 29 | link: str 30 | chunk_id: str 31 | chunk_content: str 32 | owner_id: str 33 | type: str 34 | -------------------------------------------------------------------------------- /src/feature_pipeline/models/clean.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | from models.base import VectorDBDataModel 4 | 5 | 6 | class PostCleanedModel(VectorDBDataModel): 7 | entry_id: str 8 | platform: str 9 | cleaned_content: str 10 | author_id: str 11 | image: Optional[str] = None 12 | type: str 13 | 14 | def to_payload(self) -> Tuple[str, dict]: 15 | data = { 16 | "platform": self.platform, 17 | "author_id": self.author_id, 18 | "cleaned_content": self.cleaned_content, 19 | "image": self.image, 20 | "type": self.type, 21 | } 22 | 23 | return self.entry_id, data 24 | 25 | 26 | class ArticleCleanedModel(VectorDBDataModel): 27 | entry_id: str 28 | platform: str 29 | link: str 30 | cleaned_content: str 31 | author_id: str 32 | type: str 33 | 34 | def to_payload(self) -> Tuple[str, dict]: 35 | data = { 36 | "platform": self.platform, 37 | "link": self.link, 38 | "cleaned_content": self.cleaned_content, 39 | "author_id": self.author_id, 40 | "type": self.type, 41 | } 42 | 43 | return self.entry_id, data 44 | 45 | 46 | class RepositoryCleanedModel(VectorDBDataModel): 47 | entry_id: str 48 | name: str 49 | link: str 50 | cleaned_content: str 51 | owner_id: str 52 | type: str 53 | 54 | def to_payload(self) -> Tuple[str, dict]: 55 | data = { 56 | "name": self.name, 57 | "link": self.link, 58 | "cleaned_content": self.cleaned_content, 59 | "owner_id": self.owner_id, 60 | "type": self.type, 61 | } 62 | 63 | return self.entry_id, data 64 | -------------------------------------------------------------------------------- /src/feature_pipeline/models/embedded_chunk.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | 5 | from models.base import VectorDBDataModel 6 | 7 | 8 | class PostEmbeddedChunkModel(VectorDBDataModel): 9 | entry_id: str 10 | platform: str 11 | chunk_id: str 12 | chunk_content: str 13 | embedded_content: np.ndarray 14 | author_id: str 15 | type: str 16 | 17 | class Config: 18 | arbitrary_types_allowed = True 19 | 20 | def to_payload(self) -> Tuple[str, np.ndarray, dict]: 21 | data = { 22 | "id": self.entry_id, 23 | "platform": self.platform, 24 | "content": self.chunk_content, 25 | "owner_id": self.author_id, 26 | "type": self.type, 27 | } 28 | 29 | return self.chunk_id, self.embedded_content, data 30 | 31 | 32 | class ArticleEmbeddedChunkModel(VectorDBDataModel): 33 | 
entry_id: str 34 | platform: str 35 | link: str 36 | chunk_id: str 37 | chunk_content: str 38 | embedded_content: np.ndarray 39 | author_id: str 40 | type: str 41 | 42 | class Config: 43 | arbitrary_types_allowed = True 44 | 45 | def to_payload(self) -> Tuple[str, np.ndarray, dict]: 46 | data = { 47 | "id": self.entry_id, 48 | "platform": self.platform, 49 | "content": self.chunk_content, 50 | "link": self.link, 51 | "author_id": self.author_id, 52 | "type": self.type, 53 | } 54 | 55 | return self.chunk_id, self.embedded_content, data 56 | 57 | 58 | class RepositoryEmbeddedChunkModel(VectorDBDataModel): 59 | entry_id: str 60 | name: str 61 | link: str 62 | chunk_id: str 63 | chunk_content: str 64 | embedded_content: np.ndarray 65 | owner_id: str 66 | type: str 67 | 68 | class Config: 69 | arbitrary_types_allowed = True 70 | 71 | def to_payload(self) -> Tuple[str, np.ndarray, dict]: 72 | data = { 73 | "id": self.entry_id, 74 | "name": self.name, 75 | "content": self.chunk_content, 76 | "link": self.link, 77 | "owner_id": self.owner_id, 78 | "type": self.type, 79 | } 80 | 81 | return self.chunk_id, self.embedded_content, data 82 | -------------------------------------------------------------------------------- /src/feature_pipeline/models/raw.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from models.base import DataModel 4 | 5 | 6 | class RepositoryRawModel(DataModel): 7 | name: str 8 | link: str 9 | content: dict 10 | owner_id: str 11 | 12 | 13 | class ArticleRawModel(DataModel): 14 | platform: str 15 | link: str 16 | content: dict 17 | author_id: str 18 | 19 | 20 | class PostsRawModel(DataModel): 21 | platform: str 22 | content: dict 23 | author_id: str | None = None 24 | image: Optional[str] = None 25 | -------------------------------------------------------------------------------- /src/feature_pipeline/retriever.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | # To mimic using multiple Python modules, such as 'core' and 'feature_pipeline', 5 | # we will add the './src' directory to the PYTHONPATH. This is not intended for 6 | # production use cases but for development and educational purposes. 7 | ROOT_DIR = str(Path(__file__).parent.parent) 8 | sys.path.append(ROOT_DIR) 9 | 10 | 11 | from core import get_logger 12 | from core.config import settings 13 | from core.rag.retriever import VectorRetriever 14 | 15 | logger = get_logger(__name__) 16 | 17 | settings.patch_localhost() 18 | logger.warning( 19 | "Patched settings to work with 'localhost' URLs. \ 20 | Remove the 'settings.patch_localhost()' call from above when deploying or running inside Docker." 21 | ) 22 | 23 | if __name__ == "__main__": 24 | query = """ 25 | Hello I am Paul Iusztin. 26 | 27 | Could you draft an article paragraph discussing RAG? 28 | I'm particularly interested in how to design a RAG system. 
29 | """ 30 | 31 | retriever = VectorRetriever(query=query) 32 | hits = retriever.retrieve_top_k(k=6, to_expand_to_n_queries=5) 33 | reranked_hits = retriever.rerank(hits=hits, keep_top_k=5) 34 | 35 | logger.info("====== RETRIEVED DOCUMENTS ======") 36 | for rank, hit in enumerate(reranked_hits): 37 | logger.info(f"Rank = {rank} : {hit}") 38 | -------------------------------------------------------------------------------- /src/feature_pipeline/scripts/bytewax_entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$DEBUG" = true ] 4 | then 5 | python -m bytewax.run "tools.run_real_time:build_flow(debug=True)" 6 | else 7 | if [ "$BYTEWAX_PYTHON_FILE_PATH" = "" ] 8 | then 9 | echo 'BYTEWAX_PYTHON_FILE_PATH is not set. Exiting...' 10 | exit 1 11 | fi 12 | python -m bytewax.run $BYTEWAX_PYTHON_FILE_PATH 13 | fi 14 | 15 | 16 | echo 'Process ended.' 17 | 18 | if [ "$BYTEWAX_KEEP_CONTAINER_ALIVE" = true ] 19 | then 20 | echo 'Keeping container alive...'; 21 | while :; do sleep 1; done 22 | fi -------------------------------------------------------------------------------- /src/feature_pipeline/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/feature_pipeline/utils/chunking.py: -------------------------------------------------------------------------------- 1 | from langchain.text_splitter import ( 2 | RecursiveCharacterTextSplitter, 3 | SentenceTransformersTokenTextSplitter, 4 | ) 5 | 6 | from config import settings 7 | 8 | 9 | def chunk_text(text: str) -> list[str]: 10 | character_splitter = RecursiveCharacterTextSplitter( 11 | separators=["\n\n"], chunk_size=500, chunk_overlap=0 12 | ) 13 | text_split = character_splitter.split_text(text) 14 | 15 | token_splitter = SentenceTransformersTokenTextSplitter( 16 | chunk_overlap=50, 17 | tokens_per_chunk=settings.EMBEDDING_MODEL_MAX_INPUT_LENGTH, 18 | model_name=settings.EMBEDDING_MODEL_ID, 19 | ) 20 | chunks = [] 21 | 22 | for section in text_split: 23 | chunks.extend(token_splitter.split_text(section)) 24 | 25 | return chunks 26 | -------------------------------------------------------------------------------- /src/feature_pipeline/utils/embeddings.py: -------------------------------------------------------------------------------- 1 | from InstructorEmbedding import INSTRUCTOR 2 | from sentence_transformers.SentenceTransformer import SentenceTransformer 3 | 4 | from config import settings 5 | 6 | 7 | def embedd_text(text: str): 8 | model = SentenceTransformer(settings.EMBEDDING_MODEL_ID) 9 | return model.encode(text) 10 | 11 | 12 | def embedd_repositories(text: str): 13 | model = INSTRUCTOR("hkunlp/instructor-xl") 14 | sentence = text 15 | instruction = "Represent the structure of the repository" 16 | return model.encode([instruction, sentence]) 17 | -------------------------------------------------------------------------------- /src/inference_pipeline/aws/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/inference_pipeline/aws/__init__.py -------------------------------------------------------------------------------- /src/inference_pipeline/aws/delete_sagemaker_endpoint.py: -------------------------------------------------------------------------------- 1 | from core import get_logger 2 | 
3 | logger = get_logger(__file__) 4 | 5 | try: 6 | import boto3 7 | from botocore.exceptions import ClientError 8 | except ModuleNotFoundError: 9 | logger.warning( 10 | "Couldn't load AWS or SageMaker imports. Run 'poetry install --with aws' to support AWS." 11 | ) 12 | 13 | 14 | from config import settings 15 | 16 | 17 | def delete_endpoint_and_config(endpoint_name) -> None: 18 | """ 19 | Deletes an AWS SageMaker endpoint and its associated configuration. 20 | Args: 21 | endpoint_name (str): The name of the SageMaker endpoint to delete. 22 | Returns: 23 | None 24 | """ 25 | 26 | try: 27 | sagemaker_client = boto3.client( 28 | "sagemaker", 29 | region_name=settings.AWS_REGION, 30 | aws_access_key_id=settings.AWS_ACCESS_KEY, 31 | aws_secret_access_key=settings.AWS_SECRET_KEY, 32 | ) 33 | except Exception: 34 | logger.exception("Error creating SageMaker client") 35 | 36 | return 37 | 38 | # Get the endpoint configuration name 39 | try: 40 | response = sagemaker_client.describe_endpoint(EndpointName=endpoint_name) 41 | config_name = response["EndpointConfigName"] 42 | except ClientError: 43 | logger.error("Error getting endpoint configuration and modelname.") 44 | 45 | return 46 | 47 | # Delete the endpoint 48 | try: 49 | sagemaker_client.delete_endpoint(EndpointName=endpoint_name) 50 | logger.info(f"Endpoint '{endpoint_name}' deletion initiated.") 51 | except ClientError: 52 | logger.error("Error deleting endpoint") 53 | 54 | try: 55 | response = sagemaker_client.describe_endpoint_config( 56 | EndpointConfigName=endpoint_name 57 | ) 58 | model_name = response["ProductionVariants"][0]["ModelName"] 59 | except ClientError: 60 | logger.error("Error getting model name.") 61 | 62 | # Delete the endpoint configuration 63 | try: 64 | sagemaker_client.delete_endpoint_config(EndpointConfigName=config_name) 65 | logger.info(f"Endpoint configuration '{config_name}' deleted.") 66 | except ClientError: 67 | logger.error("Error deleting endpoint configuration.") 68 | 69 | # Delete models 70 | try: 71 | sagemaker_client.delete_model(ModelName=model_name) 72 | logger.info(f"Model '{model_name}' deleted.") 73 | except ClientError: 74 | logger.error("Error deleting model.") 75 | 76 | 77 | if __name__ == "__main__": 78 | endpoint_name = settings.DEPLOYMENT_ENDPOINT_NAME 79 | logger.info(f"Attempting to delete endpoint: {endpoint_name}") 80 | delete_endpoint_and_config(endpoint_name=endpoint_name) 81 | -------------------------------------------------------------------------------- /src/inference_pipeline/aws/deploy_sagemaker_endpoint.py: -------------------------------------------------------------------------------- 1 | from config import settings 2 | from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri 3 | 4 | 5 | def main() -> None: 6 | assert settings.HUGGINGFACE_ACCESS_TOKEN, "HUGGINGFACE_ACCESS_TOKEN is required." 7 | 8 | env_vars = { 9 | "HF_MODEL_ID": settings.MODEL_ID, 10 | "SM_NUM_GPUS": "1", # Number of GPU used per replica. 11 | "HUGGING_FACE_HUB_TOKEN": settings.HUGGINGFACE_ACCESS_TOKEN, 12 | "MAX_INPUT_TOKENS": str( 13 | settings.MAX_INPUT_TOKENS 14 | ), # Max length of input tokens. 15 | "MAX_TOTAL_TOKENS": str( 16 | settings.MAX_TOTAL_TOKENS 17 | ), # Max length of the generation (including input text). 18 | "MAX_BATCH_TOTAL_TOKENS": str( 19 | settings.MAX_BATCH_TOTAL_TOKENS 20 | ), # Limits the number of tokens that can be processed in parallel during the generation. 
21 | "MESSAGES_API_ENABLED": "true", # Enable/disable the messages API, following OpenAI's standard. 22 | "HF_MODEL_QUANTIZE": "bitsandbytes", 23 | } 24 | 25 | image_uri = get_huggingface_llm_image_uri("huggingface", version="2.2.0") 26 | 27 | model = HuggingFaceModel( 28 | env=env_vars, role=settings.AWS_ARN_ROLE, image_uri=image_uri 29 | ) 30 | 31 | model.deploy( 32 | initial_instance_count=1, 33 | instance_type="ml.g5.2xlarge", 34 | container_startup_health_check_timeout=900, 35 | endpoint_name=settings.DEPLOYMENT_ENDPOINT_NAME, 36 | ) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /src/inference_pipeline/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from pydantic_settings import BaseSettings, SettingsConfigDict 4 | 5 | ROOT_DIR = str(Path(__file__).parent.parent.parent) 6 | 7 | 8 | class Settings(BaseSettings): 9 | model_config = SettingsConfigDict(env_file=ROOT_DIR, env_file_encoding="utf-8") 10 | 11 | # Embeddings config 12 | EMBEDDING_MODEL_ID: str = "BAAI/bge-small-en-v1.5" 13 | EMBEDDING_MODEL_MAX_INPUT_LENGTH: int = 512 14 | EMBEDDING_SIZE: int = 384 15 | EMBEDDING_MODEL_DEVICE: str = "cpu" 16 | 17 | # OpenAI config 18 | OPENAI_MODEL_ID: str = "gpt-4o-mini" 19 | OPENAI_API_KEY: str | None = None 20 | 21 | # QdrantDB config 22 | QDRANT_DATABASE_HOST: str = "localhost" # Or 'qdrant' if running inside Docker 23 | QDRANT_DATABASE_PORT: int = 6333 24 | 25 | USE_QDRANT_CLOUD: bool = ( 26 | False # if True, fill in QDRANT_CLOUD_URL and QDRANT_APIKEY 27 | ) 28 | QDRANT_CLOUD_URL: str = "str" 29 | QDRANT_APIKEY: str | None = None 30 | 31 | # RAG config 32 | TOP_K: int = 5 33 | KEEP_TOP_K: int = 5 34 | EXPAND_N_QUERY: int = 5 35 | 36 | # CometML config 37 | COMET_API_KEY: str 38 | COMET_WORKSPACE: str 39 | COMET_PROJECT: str = "llm-twin" 40 | 41 | # LLM Model config 42 | HUGGINGFACE_ACCESS_TOKEN: str | None = None 43 | MODEL_ID: str = "pauliusztin/LLMTwin-Llama-3.1-8B" # Change this with your Hugging Face model ID to test out your fine-tuned LLM 44 | DEPLOYMENT_ENDPOINT_NAME: str = "twin" 45 | 46 | MAX_INPUT_TOKENS: int = 1536 # Max length of input text. 47 | MAX_TOTAL_TOKENS: int = 2048 # Max length of the generation (including input text). 48 | MAX_BATCH_TOTAL_TOKENS: int = 2048 # Limits the number of tokens that can be processed in parallel during the generation. 49 | 50 | # AWS Authentication 51 | AWS_REGION: str = "eu-central-1" 52 | AWS_ACCESS_KEY: str | None = None 53 | AWS_SECRET_KEY: str | None = None 54 | AWS_ARN_ROLE: str | None = None 55 | 56 | 57 | settings = Settings() 58 | -------------------------------------------------------------------------------- /src/inference_pipeline/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | # To mimic using multiple Python modules, such as 'core' and 'feature_pipeline', 5 | # we will add the './src' directory to the PYTHONPATH. This is not intended for 6 | # production use cases but for development and educational purposes. 
7 | ROOT_DIR = str(Path(__file__).parent.parent.parent) 8 | sys.path.append(ROOT_DIR) 9 | 10 | from core import logger_utils 11 | 12 | logger = logger_utils.get_logger(__name__) 13 | logger.info( 14 | f"Added the following directory to PYTHONPATH to simulate multiple modules: {ROOT_DIR}" 15 | ) 16 | -------------------------------------------------------------------------------- /src/inference_pipeline/evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from config import settings 4 | from core.logger_utils import get_logger 5 | from core.opik_utils import create_dataset_from_artifacts 6 | from llm_twin import LLMTwin 7 | from opik.evaluation import evaluate 8 | from opik.evaluation.metrics import Hallucination, LevenshteinRatio, Moderation 9 | 10 | from .style import Style 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | def evaluation_task(x: dict) -> dict: 16 | inference_pipeline = LLMTwin(mock=False) 17 | result = inference_pipeline.generate( 18 | query=x["instruction"], 19 | enable_rag=False, 20 | ) 21 | answer = result["answer"] 22 | 23 | return { 24 | "input": x["instruction"], 25 | "output": answer, 26 | "expected_output": x["content"], 27 | "reference": x["content"], 28 | } 29 | 30 | 31 | def main() -> None: 32 | parser = argparse.ArgumentParser(description="Evaluate monitoring script.") 33 | parser.add_argument( 34 | "--dataset_name", 35 | type=str, 36 | default="LLMTwinMonitoringDataset", 37 | help="Name of the dataset to evaluate", 38 | ) 39 | 40 | args = parser.parse_args() 41 | 42 | dataset_name = args.dataset_name 43 | 44 | logger.info(f"Evaluating Opik dataset: '{dataset_name}'") 45 | 46 | dataset = create_dataset_from_artifacts( 47 | dataset_name="LLMTwinArtifactTestDataset", 48 | artifact_names=[ 49 | "articles-instruct-dataset", 50 | "repositories-instruct-dataset", 51 | ], 52 | ) 53 | if dataset is None: 54 | logger.error("Dataset can't be created. 
Exiting.") 55 | exit(1) 56 | 57 | experiment_config = { 58 | "model_id": settings.MODEL_ID, 59 | } 60 | scoring_metrics = [ 61 | LevenshteinRatio(), 62 | Hallucination(), 63 | Moderation(), 64 | Style(), 65 | ] 66 | evaluate( 67 | dataset=dataset, 68 | task=evaluation_task, 69 | scoring_metrics=scoring_metrics, 70 | experiment_config=experiment_config, 71 | ) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /src/inference_pipeline/evaluation/evaluate_monitoring.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import opik 4 | from config import settings 5 | from core.logger_utils import get_logger 6 | from opik.evaluation import evaluate 7 | from opik.evaluation.metrics import AnswerRelevance, Hallucination, Moderation 8 | 9 | from .style import Style 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | def evaluation_task(x: dict) -> dict: 15 | return { 16 | "input": x["input"]["query"], 17 | "context": x["expected_output"]["context"], 18 | "output": x["expected_output"]["answer"], 19 | } 20 | 21 | 22 | def main() -> None: 23 | parser = argparse.ArgumentParser(description="Evaluate monitoring script.") 24 | parser.add_argument( 25 | "--dataset_name", 26 | type=str, 27 | default="LLMTwinMonitoringDataset", 28 | help="Name of the dataset to evaluate", 29 | ) 30 | 31 | args = parser.parse_args() 32 | 33 | dataset_name = args.dataset_name 34 | 35 | logger.info(f"Evaluating Opik dataset: '{dataset_name}'") 36 | 37 | client = opik.Opik() 38 | try: 39 | dataset = client.get_dataset(dataset_name) 40 | except Exception: 41 | logger.error(f"Monitoring dataset '{dataset_name}' not found in Opik. Exiting.") 42 | exit(1) 43 | 44 | experiment_config = { 45 | "model_id": settings.MODEL_ID, 46 | } 47 | 48 | scoring_metrics = [Hallucination(), Moderation(), AnswerRelevance(), Style()] 49 | evaluate( 50 | dataset=dataset, 51 | task=evaluation_task, 52 | scoring_metrics=scoring_metrics, 53 | experiment_config=experiment_config, 54 | ) 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /src/inference_pipeline/evaluation/evaluate_rag.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from core.config import settings 4 | from core.logger_utils import get_logger 5 | from core.opik_utils import create_dataset_from_artifacts 6 | from llm_twin import LLMTwin 7 | from opik.evaluation import evaluate 8 | from opik.evaluation.metrics import ( 9 | ContextPrecision, 10 | ContextRecall, 11 | Hallucination, 12 | ) 13 | 14 | settings.patch_localhost() 15 | 16 | logger = get_logger(__name__) 17 | logger.warning( 18 | "Patched settings to work with 'localhost' URLs. \ 19 | Remove the 'settings.patch_localhost()' call from above when deploying or running inside Docker." 
" 20 | ) 21 | 22 | 23 | def evaluation_task(x: dict) -> dict: 24 | inference_pipeline = LLMTwin(mock=False) 25 | result = inference_pipeline.generate( 26 | query=x["instruction"], 27 | enable_rag=True, 28 | ) 29 | answer = result["answer"] 30 | context = result["context"] 31 | 32 | return { 33 | "input": x["instruction"], 34 | "output": answer, 35 | "context": context, 36 | "expected_output": x["content"], 37 | "reference": x["content"], 38 | } 39 | 40 | 41 | def main() -> None: 42 | parser = argparse.ArgumentParser(description="Evaluate monitoring script.") 43 | parser.add_argument( 44 | "--dataset_name", 45 | type=str, 46 | default="LLMTwinMonitoringDataset", 47 | help="Name of the dataset to evaluate", 48 | ) 49 | 50 | args = parser.parse_args() 51 | 52 | dataset_name = args.dataset_name 53 | 54 | logger.info(f"Evaluating Opik dataset: '{dataset_name}'") 55 | 56 | dataset = create_dataset_from_artifacts( 57 | dataset_name="LLMTwinArtifactTestDataset", 58 | artifact_names=[ 59 | "articles-instruct-dataset", 60 | "posts-instruct-dataset", 61 | "repositories-instruct-dataset", 62 | ], 63 | ) 64 | if dataset is None: 65 | logger.error("Dataset can't be created. Exiting.") 66 | exit(1) 67 | 68 | experiment_config = { 69 | "model_id": settings.MODEL_ID, 70 | "embedding_model_id": settings.EMBEDDING_MODEL_ID, 71 | } 72 | scoring_metrics = [ 73 | Hallucination(), 74 | ContextRecall(), 75 | ContextPrecision(), 76 | ] 77 | evaluate( 78 | dataset=dataset, 79 | task=evaluation_task, 80 | scoring_metrics=scoring_metrics, 81 | experiment_config=experiment_config, 82 | ) 83 | 84 | 85 | if __name__ == "__main__": 86 | main() 87 | -------------------------------------------------------------------------------- /src/inference_pipeline/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | # To mimic using multiple Python modules, such as 'core' and 'feature_pipeline', 5 | # we will add the './src' directory to the PYTHONPATH. This is not intended for 6 | # production use cases but for development and educational purposes. 7 | ROOT_DIR = str(Path(__file__).parent.parent) 8 | sys.path.append(ROOT_DIR) 9 | 10 | from core import logger_utils 11 | from core.config import settings 12 | from llm_twin import LLMTwin 13 | 14 | settings.patch_localhost() 15 | 16 | logger = logger_utils.get_logger(__name__) 17 | logger.info( 18 | f"Added the following directory to PYTHONPATH to simulate multiple modules: {ROOT_DIR}" 19 | ) 20 | logger.warning( 21 | "Patched settings to work with 'localhost' URLs. \ 22 | Remove the 'settings.patch_localhost()' call from above when deploying or running inside Docker." 23 | ) 24 | 25 | 26 | if __name__ == "__main__": 27 | inference_endpoint = LLMTwin(mock=False) 28 | 29 | query = """ 30 | Hello I am Paul Iusztin. 31 | 32 | Could you draft an article paragraph discussing RAG? 33 | I'm particularly interested in how to design a RAG system.
34 | """ 35 | 36 | response = inference_endpoint.generate( 37 | query=query, enable_rag=True, sample_for_evaluation=True 38 | ) 39 | 40 | logger.info("=" * 50) 41 | logger.info(f"Query: {query}") 42 | logger.info("=" * 50) 43 | logger.info(f"Answer: {response['answer']}") 44 | logger.info("=" * 50) 45 | -------------------------------------------------------------------------------- /src/inference_pipeline/prompt_templates.py: -------------------------------------------------------------------------------- 1 | from core.rag.prompt_templates import BasePromptTemplate 2 | from langchain.prompts import PromptTemplate 3 | 4 | 5 | class InferenceTemplate(BasePromptTemplate): 6 | simple_system_prompt: str = """ 7 | You are an AI language model assistant. Your task is to generate a cohesive and concise response based on the user's instruction by using a similar writing style and voice. 8 | """ 9 | simple_prompt_template: str = """ 10 | ### Instruction: 11 | {question} 12 | """ 13 | 14 | rag_system_prompt: str = """ You are a specialist in technical content writing. Your task is to create technical content based on the user's instruction given a specific context 15 | with additional information consisting of the user's previous writings and his knowledge. 16 | 17 | Here is a list of steps that you need to follow in order to solve this task: 18 | 19 | Step 1: You need to analyze the user's instruction. 20 | Step 2: You need to analyze the provided context and how the information in it relates to the user instruction. 21 | Step 3: Generate the content keeping in mind that it needs to be as cohesive and concise as possible based on the query. You will use the users writing style and voice inferred from the user instruction and context. 22 | First try to answer based on the context. If the context is irrelevant answer with "I cannot answer your question, as I don't have enough context." 23 | """ 24 | rag_prompt_template: str = """ 25 | ### Instruction: 26 | {question} 27 | 28 | ### Context: 29 | {context} 30 | """ 31 | 32 | def create_template(self, enable_rag: bool = True) -> tuple[str, PromptTemplate]: 33 | if enable_rag is True: 34 | return self.rag_system_prompt, PromptTemplate( 35 | template=self.rag_prompt_template, 36 | input_variables=["question", "context"], 37 | ) 38 | 39 | return self.simple_system_prompt, PromptTemplate( 40 | template=self.simple_prompt_template, input_variables=["question"] 41 | ) 42 | -------------------------------------------------------------------------------- /src/inference_pipeline/ui.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | # To mimic using multiple Python modules, such as 'core' and 'feature_pipeline', 5 | # we will add the './src' directory to the PYTHONPATH. This is not intended for 6 | # production use cases but for development and educational purposes. 7 | ROOT_DIR = str(Path(__file__).parent.parent) 8 | sys.path.append(ROOT_DIR) 9 | 10 | from core.config import settings 11 | from llm_twin import LLMTwin 12 | 13 | settings.patch_localhost() 14 | 15 | 16 | import gradio as gr 17 | from inference_pipeline.llm_twin import LLMTwin 18 | 19 | llm_twin = LLMTwin(mock=False) 20 | 21 | 22 | def predict(message: str, history: list[list[str]], author: str) -> str: 23 | """ 24 | Generates a response using the LLM Twin, simulating a conversation with your digital twin. 25 | 26 | Args: 27 | message (str): The user's input message or question. 
28 | history (List[List[str]]): Previous conversation history between user and twin. 29 | about_me (str): Personal context about the user to help personalize responses. 30 | 31 | Returns: 32 | str: The LLM Twin's generated response. 33 | """ 34 | 35 | query = f"I am {author}. Write about: {message}" 36 | response = llm_twin.generate( 37 | query=query, enable_rag=True, sample_for_evaluation=False 38 | ) 39 | 40 | return response["answer"] 41 | 42 | 43 | demo = gr.ChatInterface( 44 | predict, 45 | textbox=gr.Textbox( 46 | placeholder="Chat with your LLM Twin", 47 | label="Message", 48 | container=False, 49 | scale=7, 50 | ), 51 | additional_inputs=[ 52 | gr.Textbox( 53 | "Paul Iusztin", 54 | label="Who are you?", 55 | ) 56 | ], 57 | title="Your LLM Twin", 58 | description=""" 59 | Chat with your personalized LLM Twin! This AI assistant will help you write content incorporating your style and voice. 60 | """, 61 | theme="soft", 62 | examples=[ 63 | [ 64 | "Draft a post about RAG systems.", 65 | "Paul Iusztin", 66 | ], 67 | [ 68 | "Draft an article paragraph about vector databases.", 69 | "Paul Iusztin", 70 | ], 71 | [ 72 | "Draft a post about LLM chatbots.", 73 | "Paul Iusztin", 74 | ], 75 | ], 76 | cache_examples=False, 77 | ) 78 | 79 | 80 | if __name__ == "__main__": 81 | demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True) 82 | -------------------------------------------------------------------------------- /src/inference_pipeline/utils.py: -------------------------------------------------------------------------------- 1 | from config import settings 2 | from transformers import AutoTokenizer 3 | 4 | 5 | def compute_num_tokens(text: str) -> int: 6 | tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_ID) 7 | 8 | return len(tokenizer.encode(text, add_special_tokens=False)) 9 | 10 | 11 | def truncate_text_to_max_tokens(text: str, max_tokens: int) -> tuple[str, int]: 12 | """Truncates text to not exceed max_tokens while trying to preserve complete sentences. 13 | 14 | Args: 15 | text: The text to truncate 16 | max_tokens: Maximum number of tokens allowed 17 | 18 | Returns: 19 | Truncated text that fits within max_tokens and the number of tokens in the truncated text. 
20 | """ 21 | 22 | current_tokens = compute_num_tokens(text) 23 | 24 | if current_tokens <= max_tokens: 25 | return text, current_tokens 26 | 27 | tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_ID) 28 | tokens = tokenizer.encode(text, add_special_tokens=False) 29 | 30 | # Take first max_tokens tokens and decode 31 | truncated_tokens = tokens[:max_tokens] 32 | truncated_text = tokenizer.decode(truncated_tokens) 33 | 34 | # Try to end at last complete sentence 35 | last_period = truncated_text.rfind(".") 36 | if last_period > 0: 37 | truncated_text = truncated_text[: last_period + 1] 38 | 39 | truncated_tokens = compute_num_tokens(truncated_text) 40 | 41 | return truncated_text, truncated_tokens 42 | -------------------------------------------------------------------------------- /src/training_pipeline/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from pydantic_settings import BaseSettings, SettingsConfigDict 4 | 5 | ROOT_DIR = str(Path(__file__).parent.parent.parent.parent) 6 | 7 | 8 | class Settings(BaseSettings): 9 | model_config = SettingsConfigDict(env_file=ROOT_DIR, env_file_encoding="utf-8") 10 | 11 | # Hugging Face config 12 | HUGGINGFACE_BASE_MODEL_ID: str = "meta-llama/Llama-3.1-8B" 13 | HUGGINGFACE_ACCESS_TOKEN: str | None = None 14 | 15 | # Comet config 16 | COMET_API_KEY: str | None = None 17 | COMET_WORKSPACE: str | None = None 18 | COMET_PROJECT: str = "llm-twin" 19 | 20 | DATASET_ID: str = "articles-instruct-dataset" # Comet artifact containing your fine-tuning dataset (available after generating the instruct dataset). 21 | 22 | # AWS config 23 | AWS_REGION: str = "eu-central-1" 24 | AWS_ACCESS_KEY: str | None = None 25 | AWS_SECRET_KEY: str | None = None 26 | AWS_ARN_ROLE: str | None = None 27 | 28 | 29 | settings = Settings() 30 | -------------------------------------------------------------------------------- /src/training_pipeline/download_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | from comet_ml import Artifact, Experiment 5 | from comet_ml.artifacts import ArtifactAsset 6 | from config import settings 7 | from core import get_logger 8 | from datasets import Dataset # noqa: E402 9 | 10 | logger = get_logger(__file__) 11 | 12 | 13 | class DatasetClient: 14 | def __init__( 15 | self, 16 | output_dir: Path = Path("./finetuning_dataset"), 17 | ) -> None: 18 | self.output_dir = output_dir 19 | self.output_dir.mkdir(parents=True, exist_ok=True) 20 | 21 | def download_dataset(self, dataset_id: str, split: str = "train") -> Dataset: 22 | assert split in ["train", "test"], "Split must be either 'train' or 'test'" 23 | 24 | if "/" in dataset_id: 25 | tokens = dataset_id.split("/") 26 | assert ( 27 | len(tokens) == 2 28 | ), f"Wrong format for the {dataset_id}. 
It should contain at most one '/' character, following the template: 'comet_ml_workspace/comet_ml_artifact_name'" 29 | workspace, artifact_name = tokens 30 | 31 | experiment = Experiment(workspace=workspace) 32 | else: 33 | artifact_name = dataset_id 34 | 35 | experiment = Experiment() 36 | 37 | artifact = self._download_artifact(artifact_name, experiment) 38 | asset = self._artifact_to_asset(artifact, split) 39 | dataset = self._load_data(asset) 40 | 41 | experiment.end() 42 | 43 | return dataset 44 | 45 | def _download_artifact(self, artifact_name: str, experiment) -> Artifact: 46 | try: 47 | logged_artifact = experiment.get_artifact(artifact_name) 48 | artifact = logged_artifact.download(self.output_dir) 49 | except Exception as e: 50 | print(f"Error retrieving artifact: {str(e)}") 51 | 52 | raise 53 | 54 | print(f"Successfully downloaded '{artifact_name}' at location '{self.output_dir}'") 55 | 56 | return artifact 57 | 58 | def _artifact_to_asset(self, artifact: Artifact, split: str) -> ArtifactAsset: 59 | if len(artifact.assets) == 0: 60 | raise RuntimeError("Artifact has no assets") 61 | elif len(artifact.assets) != 2: 62 | raise RuntimeError( 63 | f"Artifact has {len(artifact.assets)} assets, which is invalid. It should have exactly 2." 64 | ) 65 | 66 | print(f"Picking split = '{split}'") 67 | asset = [asset for asset in artifact.assets if split in asset.logical_path][0] 68 | 69 | return asset 70 | 71 | def _load_data(self, asset: ArtifactAsset) -> Dataset: 72 | data_file_path = asset.local_path_or_data 73 | with open(data_file_path, "r") as file: 74 | data = json.load(file) 75 | 76 | dataset_dict = {k: [str(d[k]) for d in data] for k in data[0].keys()} 77 | dataset = Dataset.from_dict(dataset_dict) 78 | 79 | print( 80 | f"Successfully loaded dataset from artifact, num_samples = {len(dataset)}", 81 | ) 82 | 83 | return dataset 84 | 85 | 86 | if __name__ == "__main__": 87 | dataset_client = DatasetClient() 88 | dataset_client.download_dataset(dataset_id=settings.DATASET_ID) 89 | 90 | logger.info(f"Data available at '{dataset_client.output_dir}'.") 91 | -------------------------------------------------------------------------------- /src/training_pipeline/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.33.0 2 | torch==2.4.0 3 | transformers==4.43.3 4 | datasets==2.20.0 5 | peft==0.12.0 6 | trl==0.9.6 7 | bitsandbytes==0.43.3 8 | comet-ml==3.44.3 9 | flash-attn==2.3.6 10 | unsloth==2024.9.post2 11 | --------------------------------------------------------------------------------