├── .docker
├── .dockerignore
├── Dockerfile.data_cdc
├── Dockerfile.data_crawlers
├── Dockerfile.feature_pipeline
└── Dockerfile.feature_pipeline.superlinked
├── .env.example
├── .github
├── ISSUE_TEMPLATE
│ ├── clarify-concept.md
│ ├── questions-about-the-course-material.md
│ └── technical-troubleshooting-or-bugs.md
└── workflows
│ └── crawler.yml
├── .gitignore
├── .python-version
├── INSTALL_AND_USAGE.md
├── LICENSE
├── Makefile
├── README.md
├── data
└── links.txt
├── docker-compose-superlinked.yml
├── docker-compose.yml
├── media
├── architecture.png
├── cover.png
├── fine-tuning-workflow.png
├── llm_engineers_handbook_cover.png
├── qdrant-example.png
├── sponsors
│ ├── bytewax.png
│ ├── comet.png
│ ├── opik.svg
│ ├── qdrant.svg
│ ├── qwak.png
│ └── superlinked.svg
└── ui-example.png
├── poetry.lock
├── pyproject.toml
└── src
├── bonus_superlinked_rag
├── README.md
├── config.py
├── data_flow
│ ├── __init__.py
│ ├── stream_input.py
│ └── stream_output.py
├── data_logic
│ ├── __init__.py
│ ├── cleaning_data_handlers.py
│ ├── dispatchers.py
│ └── splitters.py
├── llm
│ ├── __init__.py
│ ├── chain.py
│ └── prompt_templates.py
├── local_test.py
├── main.py
├── models
│ ├── __init__.py
│ ├── documents.py
│ ├── raw.py
│ └── utils.py
├── mq.py
├── rag
│ ├── __init__.py
│ ├── query_expanison.py
│ ├── reranking.py
│ ├── retriever.py
│ └── self_query.py
├── retriever.py
├── scripts
│ └── bytewax_entrypoint.sh
├── server
│ ├── .python-version
│ ├── LICENSE
│ ├── NOTICE
│ ├── README.md
│ ├── compose.yaml
│ ├── config
│ │ ├── aws_credentials.json
│ │ ├── config.yaml
│ │ └── gcp_credentials.json
│ ├── docs
│ │ ├── api.md
│ │ ├── app.md
│ │ ├── bucket.md
│ │ ├── dummy_app.py
│ │ ├── example
│ │ │ ├── amazon_app.py
│ │ │ └── app.py
│ │ ├── mongodb
│ │ │ ├── app_with_mongodb.py
│ │ │ └── mongodb.md
│ │ ├── redis
│ │ │ ├── app_with_redis.py
│ │ │ └── redis.md
│ │ ├── vector_databases.md
│ │ └── vm.md
│ ├── runner
│ │ ├── .python-version
│ │ ├── executor
│ │ │ ├── Dockerfile
│ │ │ ├── __init__.py
│ │ │ ├── app
│ │ │ │ ├── __init__.py
│ │ │ │ ├── configuration
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── app_config.py
│ │ │ │ ├── dependency_register.py
│ │ │ │ ├── exception
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── exception.py
│ │ │ │ │ └── exception_handler.py
│ │ │ │ ├── main.py
│ │ │ │ ├── middleware
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── lifespan_event.py
│ │ │ │ ├── router
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── management_router.py
│ │ │ │ ├── service
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── data_loader.py
│ │ │ │ │ ├── file_handler_service.py
│ │ │ │ │ ├── file_object_serializer.py
│ │ │ │ │ ├── persistence_service.py
│ │ │ │ │ └── supervisor_service.py
│ │ │ │ └── util
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── fast_api_handler.py
│ │ │ │ │ ├── open_api_description_util.py
│ │ │ │ │ └── registry_loader.py
│ │ │ ├── openapi
│ │ │ │ └── static_endpoint_descriptor.json
│ │ │ └── supervisord.conf
│ │ ├── poetry.lock
│ │ ├── poller
│ │ │ ├── Dockerfile
│ │ │ ├── __init__.py
│ │ │ ├── app
│ │ │ │ ├── __init__.py
│ │ │ │ ├── app_location_parser
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── app_location_parser.py
│ │ │ │ ├── config
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── poller_config.py
│ │ │ │ ├── main.py
│ │ │ │ ├── poller
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── poller.py
│ │ │ │ └── resource_handler
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── gcs
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── gcs_resource_handler.py
│ │ │ │ │ ├── local
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── local_resource_handler.py
│ │ │ │ │ ├── resource_handler.py
│ │ │ │ │ ├── resource_handler_factory.py
│ │ │ │ │ └── s3
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── s3_resource_handler.py
│ │ │ ├── logging_config.ini
│ │ │ └── poller_config.ini
│ │ └── pyproject.toml
│ ├── src
│ │ └── app.py
│ └── tools
│ │ ├── deploy.py
│ │ └── init-venv.sh
├── singleton.py
├── superlinked_client.py
└── utils
│ ├── __init__.py
│ ├── cleaning.py
│ └── logging.py
├── core
├── __init__.py
├── aws
│ ├── __init__.py
│ ├── create_execution_role.py
│ └── create_sagemaker_role.py
├── config.py
├── db
│ ├── __init__.py
│ ├── documents.py
│ ├── mongo.py
│ └── qdrant.py
├── errors.py
├── lib.py
├── logger_utils.py
├── mq.py
├── opik_utils.py
└── rag
│ ├── __init__.py
│ ├── prompt_templates.py
│ ├── query_expanison.py
│ ├── reranking.py
│ ├── retriever.py
│ └── self_query.py
├── data_cdc
├── cdc.py
├── config.py
└── test_cdc.py
├── data_crawling
├── config.py
├── crawlers
│ ├── __init__.py
│ ├── base.py
│ ├── custom_article.py
│ ├── github.py
│ ├── linkedin.py
│ └── medium.py
├── dispatcher.py
├── main.py
└── utils.py
├── feature_pipeline
├── config.py
├── data_flow
│ ├── __init__.py
│ ├── stream_input.py
│ └── stream_output.py
├── data_logic
│ ├── __init__.py
│ ├── chunking_data_handlers.py
│ ├── cleaning_data_handlers.py
│ ├── dispatchers.py
│ └── embedding_data_handlers.py
├── generate_dataset
│ ├── __init__.py
│ ├── chunk_documents.py
│ ├── exceptions.py
│ ├── file_handler.py
│ ├── generate.py
│ └── llm_communication.py
├── main.py
├── models
│ ├── __init__.py
│ ├── base.py
│ ├── chunk.py
│ ├── clean.py
│ ├── embedded_chunk.py
│ └── raw.py
├── retriever.py
├── scripts
│ └── bytewax_entrypoint.sh
└── utils
│ ├── __init__.py
│ ├── chunking.py
│ ├── cleaning.py
│ └── embeddings.py
├── inference_pipeline
├── aws
│ ├── __init__.py
│ ├── delete_sagemaker_endpoint.py
│ └── deploy_sagemaker_endpoint.py
├── config.py
├── evaluation
│ ├── __init__.py
│ ├── evaluate.py
│ ├── evaluate_monitoring.py
│ ├── evaluate_rag.py
│ └── style.py
├── llm_twin.py
├── main.py
├── prompt_templates.py
├── ui.py
└── utils.py
└── training_pipeline
├── config.py
├── download_dataset.py
├── finetune.py
├── requirements.txt
└── run_on_sagemaker.py
/.docker/.dockerignore:
--------------------------------------------------------------------------------
1 | bonus_superlinked_rag/server
2 |
--------------------------------------------------------------------------------
/.docker/Dockerfile.data_cdc:
--------------------------------------------------------------------------------
1 | # Use an official Python runtime as a parent image
2 | FROM python:3.11-slim
3 |
4 | ENV POETRY_VERSION=1.8.3
5 |
6 | # Install system dependencies
7 | RUN apt-get update && apt-get install -y \
8 | gcc \
9 | python3-dev \
10 | curl \
11 | build-essential \
12 | && apt-get clean \
13 | && rm -rf /var/lib/apt/lists/*
14 |
15 | # Install Poetry using pip and clear cache
16 | RUN pip install --no-cache-dir "poetry==$POETRY_VERSION"
17 | RUN poetry config installer.max-workers 20
18 |
19 | # Set the working directory
20 | WORKDIR /app
21 |
22 | # Copy the pyproject.toml and poetry.lock files from the root directory
23 | COPY ./pyproject.toml ./poetry.lock ./
24 |
25 | # Install the dependencies and clear cache
26 | RUN poetry config virtualenvs.create false && \
27 | poetry install --no-root --no-interaction --no-cache && \
28 | rm -rf ~/.cache/pypoetry/cache/ && \
29 | rm -rf ~/.cache/pypoetry/artifacts/
30 |
31 | # Install dependencies
32 | RUN poetry install --no-root
33 |
34 | # Copy the data_cdc and core directories
35 | COPY ./src/data_cdc ./data_cdc
36 | COPY ./src/core ./core
37 |
38 | # Set the PYTHONPATH environment variable
39 | ENV PYTHONPATH=/app
40 |
41 | # Command to run the script
42 | CMD poetry run python /app/data_cdc/cdc.py && tail -f /dev/null
--------------------------------------------------------------------------------
/.docker/Dockerfile.data_crawlers:
--------------------------------------------------------------------------------
1 | FROM public.ecr.aws/lambda/python:3.11 as build
2 |
3 | # Install chrome driver and browser
4 | RUN yum install -y unzip && \
5 | curl -Lo "/tmp/chromedriver.zip" "https://storage.googleapis.com/chrome-for-testing-public/126.0.6478.126/linux64/chromedriver-linux64.zip" && \
6 | curl -Lo "/tmp/chrome-linux.zip" "https://storage.googleapis.com/chrome-for-testing-public/126.0.6478.126/linux64/chrome-linux64.zip" && \
7 | unzip /tmp/chromedriver.zip -d /opt/ && \
8 | unzip /tmp/chrome-linux.zip -d /opt/
9 |
10 | FROM public.ecr.aws/lambda/python:3.11
11 |
12 | ENV POETRY_VERSION=1.8.3
13 |
14 | # Install the function's OS dependencies using yum
15 | RUN yum install -y \
16 | atk \
17 | wget \
18 | git \
19 | cups-libs \
20 | gtk3 \
21 | libXcomposite \
22 | alsa-lib \
23 | libXcursor \
24 | libXdamage \
25 | libXext \
26 | libXi \
27 | libXrandr \
28 | libXScrnSaver \
29 | libXtst \
30 | pango \
31 | at-spi2-atk \
32 | libXt \
33 | xorg-x11-server-Xvfb \
34 | xorg-x11-xauth \
35 | dbus-glib \
36 | dbus-glib-devel \
37 | nss \
38 | mesa-libgbm \
39 | ffmpeg \
40 | libxext6 \
41 | libssl-dev \
42 | libcurl4-openssl-dev \
43 | libpq-dev
44 |
45 |
46 | COPY --from=build /opt/chrome-linux64 /opt/chrome
47 | COPY --from=build /opt/chromedriver-linux64 /opt/
48 |
49 | COPY ./pyproject.toml ./poetry.lock ./
50 |
51 | # Install Poetry, export dependencies to requirements.txt, and install dependencies
52 | # in the Lambda task directory, finally cleanup manifest files.
53 | RUN python -m pip install --upgrade pip && pip install --no-cache-dir "poetry==$POETRY_VERSION"
54 | RUN poetry export --without feature_pipeline -f requirements.txt > requirements.txt && \
55 | pip install --no-cache-dir -r requirements.txt --target "${LAMBDA_TASK_ROOT}" && \
56 | rm requirements.txt pyproject.toml poetry.lock
57 |
58 | # Optional TLS CA only if you plan to store the extracted data into Document DB
59 | RUN wget https://truststore.pki.rds.amazonaws.com/global/global-bundle.pem -P ${LAMBDA_TASK_ROOT}
60 | ENV PYTHONPATH="${LAMBDA_TASK_ROOT}/data_crawling:${LAMBDA_TASK_ROOT}"
61 |
62 | # Copy function code
63 | COPY ./src/data_crawling ${LAMBDA_TASK_ROOT}/data_crawling
64 | COPY ./src/core ${LAMBDA_TASK_ROOT}/core
65 |
66 | # Set the CMD to your handler (could also be done as a parameter override outside of the Dockerfile)
67 | CMD ["data_crawling.main.handler"]
68 |
--------------------------------------------------------------------------------
/.docker/Dockerfile.feature_pipeline:
--------------------------------------------------------------------------------
1 | # Use an official Python runtime as a parent image
2 | FROM python:3.11-slim-bullseye
3 |
4 | ENV WORKSPACE_ROOT=/usr/src/app \
5 | PYTHONDONTWRITEBYTECODE=1 \
6 | PYTHONUNBUFFERED=1 \
7 | POETRY_HOME="/opt/poetry" \
8 | POETRY_NO_INTERACTION=1 \
9 | POETRY_VERSION=1.8.3
10 |
11 | RUN mkdir -p $WORKSPACE_ROOT
12 |
13 | # Install system dependencies
14 | RUN apt-get update -y \
15 | && apt-get install -y --no-install-recommends build-essential \
16 | gcc \
17 | python3-dev \
18 | curl \
19 | build-essential \
20 | && apt-get clean
21 |
22 | # Install Poetry using pip and clear cache
23 | RUN pip install --no-cache-dir "poetry==$POETRY_VERSION"
24 | RUN poetry config installer.max-workers 20
25 |
26 | RUN apt-get remove -y curl
27 |
28 | # Copy the pyproject.toml and poetry.lock files from the root directory
29 | COPY ./pyproject.toml ./poetry.lock ./
30 |
31 | # Install the dependencies and clear cache
32 | RUN poetry config virtualenvs.create false && \
33 | poetry install --no-root --no-interaction --no-cache && \
34 | rm -rf ~/.cache/pypoetry/cache/ && \
35 | rm -rf ~/.cache/pypoetry/artifacts/
36 |
37 | # Set the working directory
38 | WORKDIR $WORKSPACE_ROOT
39 |
40 | # Copy the feature pipeline and any other necessary directories
41 | COPY ./src/feature_pipeline .
42 | COPY ./src/core ./core
43 |
44 | # Set the PYTHONPATH environment variable
45 | ENV PYTHONPATH=/usr/src/app
46 |
47 | RUN chmod +x /usr/src/app/scripts/bytewax_entrypoint.sh
48 |
49 | # Command to run the Bytewax pipeline script
50 | CMD ["/usr/src/app/scripts/bytewax_entrypoint.sh"]
51 |
--------------------------------------------------------------------------------
/.docker/Dockerfile.feature_pipeline.superlinked:
--------------------------------------------------------------------------------
1 | # Use an official Python runtime as a parent image
2 | FROM python:3.11-slim-bullseye
3 |
4 | ENV WORKSPACE_ROOT=/usr/src/app \
5 | PYTHONDONTWRITEBYTECODE=1 \
6 | PYTHONUNBUFFERED=1 \
7 | POETRY_HOME="/opt/poetry" \
8 | POETRY_NO_INTERACTION=1 \
9 | POETRY_VERSION=1.8.3
10 |
11 | RUN mkdir -p $WORKSPACE_ROOT
12 |
13 | # Install system dependencies
14 | RUN apt-get update -y \
15 | && apt-get install -y --no-install-recommends build-essential \
16 | gcc \
17 | python3-dev \
18 | curl \
19 | build-essential \
20 | && apt-get clean
21 |
22 | # Install Poetry using pip and clear cache
23 | RUN pip install --no-cache-dir "poetry==$POETRY_VERSION"
24 | RUN poetry config installer.max-workers 20
25 |
26 | RUN apt-get remove -y curl
27 |
28 | # Copy the pyproject.toml and poetry.lock files from the root directory
29 | COPY ./pyproject.toml ./poetry.lock ./
30 |
31 | # Install the dependencies and clear cache
32 | RUN poetry config virtualenvs.create false && \
33 | poetry install --no-root --no-interaction --no-cache && \
34 | rm -rf ~/.cache/pypoetry/cache/ && \
35 | rm -rf ~/.cache/pypoetry/artifacts/
36 |
37 | # Set the working directory
38 | WORKDIR $WORKSPACE_ROOT
39 |
40 | # Copy the 3-feature-pipeline and any other necessary directories
41 | COPY ./src/bonus_superlinked_rag .
42 | COPY ./src/core ./core
43 |
44 | # Set the PYTHONPATH environment variable
45 | ENV PYTHONPATH=/usr/src/app
46 |
47 | RUN chmod +x /usr/src/app/scripts/bytewax_entrypoint.sh
48 |
49 | # Command to run the Bytewax pipeline script
50 | CMD ["/usr/src/app/scripts/bytewax_entrypoint.sh"]
51 |
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 |
2 | # --- Required settings even when working locally. ---
3 |
4 | # OpenAI API config
5 | OPENAI_MODEL_ID=gpt-4o-mini
6 | OPENAI_API_KEY=
7 |
8 | # Huggingface API config
9 | HUGGINGFACE_ACCESS_TOKEN=
10 |
11 | # Comet ML (during training and inference)
12 | COMET_API_KEY=
13 | COMET_WORKSPACE= # such as your Comet username
14 |
15 | # --- Required settings ONLY when using Qdrant Cloud and AWS SageMaker ---
16 |
17 | # Qdrant cloud vector database connection config
18 | USE_QDRANT_CLOUD=false
19 | QDRANT_CLOUD_URL=
20 | QDRANT_APIKEY=
21 |
22 | # AWS authentication config
23 | AWS_ARN_ROLE=
24 | AWS_REGION=eu-central-1
25 | AWS_ACCESS_KEY=
26 | AWS_SECRET_KEY=
27 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/clarify-concept.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Clarify concept
3 | about: Clarify concept
4 | title: Clarify concept
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | Ask anything about the course.
11 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/questions-about-the-course-material.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Questions about the course material
3 | about: Questions about the course material
4 | title: Questions about the course material
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | Ask anything about the course material.
11 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/technical-troubleshooting-or-bugs.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Technical troubleshooting or bugs
3 | about: Technical troubleshooting or bugs
4 | title: Technical troubleshooting or bugs
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug or technical issue**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g. iOS]
28 | - Browser [e.g. chrome, safari]
29 | - Version [e.g. 22]
30 |
31 | **Smartphone (please complete the following information):**
32 | - Device: [e.g. iPhone6]
33 | - OS: [e.g. iOS8.1]
34 | - Browser [e.g. stock browser, safari]
35 | - Version [e.g. 22]
36 |
37 | **Additional context**
38 | Add any other context about the problem here.
39 |
--------------------------------------------------------------------------------
/.github/workflows/crawler.yml:
--------------------------------------------------------------------------------
1 | name: Crawlers
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | build:
10 | name: Build & Push Docker Image
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Checkout Code
14 | uses: actions/checkout@v2
15 | - name: Configure AWS credentials
16 | uses: aws-actions/configure-aws-credentials@v1
17 | with:
18 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
19 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
20 | aws-region: ${{ secrets.AWS_REGION }}
21 | - name: Login to Amazon ECR
22 | id: login-ecr
23 | uses: aws-actions/amazon-ecr-login@v1
24 | - name: Build images & push to ECR
25 | id: build-image
26 | uses: docker/build-push-action@v4
27 | with:
28 | context: ./course/module-1
29 | file: ./course/module-1/Dockerfile
30 | tags: |
31 | ${{ steps.login-ecr.outputs.registry }}/crawler:${{ github.sha }}
32 | ${{ steps.login-ecr.outputs.registry }}/crawler:latest
33 | push: true
34 |
35 | deploy:
36 | name: Deploy Crawler
37 | runs-on: ubuntu-latest
38 | needs: build
39 | steps:
40 | - name: Configure AWS credentials
41 | uses: aws-actions/configure-aws-credentials@v1
42 | with:
43 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
44 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
45 | aws-region: ${{ secrets.AWS_REGION }}
46 | - name: Deploy Lambda Image
47 | id: deploy-lambda
48 | run: |
49 | echo "Updating lambda with new image version $ECR_REPOSITORY/crawler:$PROJECT_VERSION..."
50 | aws lambda update-function-code \
51 | --function-name "arn:aws:lambda:$AWS_REGION:$AWS_ACCOUNT_ID:function:crawler" \
52 | --image-uri $ECR_REPOSITORY/crawler:$PROJECT_VERSION
53 | echo "Successfully updated lambda"
54 | env:
55 | AWS_REGION: ${{ secrets.AWS_REGION }}
56 | ECR_REPOSITORY: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_REGION }}.amazonaws.com
57 | PROJECT_VERSION: ${{ github.sha }}
58 | AWS_ACCOUNT_ID: ${{ secrets.AWS_ACCOUNT_ID }}
59 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.11.8
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Decoding ML by Crafted Intelligence S.R.L.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docker-compose-superlinked.yml:
--------------------------------------------------------------------------------
1 | services:
2 | mongo1:
3 | image: mongo:5
4 | container_name: llm-twin-mongo1
5 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30001"]
6 | volumes:
7 | - mongo-replica-1-data:/data/db
8 | ports:
9 | - "30001:30001"
10 | healthcheck:
11 | test: test $$(echo "rs.initiate({_id:'my-replica-set',members:[{_id:0,host:\"mongo1:30001\"},{_id:1,host:\"mongo2:30002\"},{_id:2,host:\"mongo3:30003\"}]}).ok || rs.status().ok" | mongo --port 30001 --quiet) -eq 1
12 | interval: 10s
13 | start_period: 30s
14 | restart: always
15 | networks:
16 | - server_default
17 |
18 | mongo2:
19 | image: mongo:5
20 | container_name: llm-twin-mongo2
21 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30002"]
22 | volumes:
23 | - mongo-replica-2-data:/data/db
24 | ports:
25 | - "30002:30002"
26 | restart: always
27 | networks:
28 | - server_default
29 |
30 | mongo3:
31 | image: mongo:5
32 | container_name: llm-twin-mongo3
33 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30003"]
34 | volumes:
35 | - mongo-replica-3-data:/data/db
36 | ports:
37 | - "30003:30003"
38 | restart: always
39 | networks:
40 | - server_default
41 |
42 | mq:
43 | image: rabbitmq:3-management-alpine
44 | container_name: llm-twin-mq
45 | ports:
46 | - "5672:5672"
47 | - "15672:15672"
48 | volumes:
49 | - ./rabbitmq/data/:/var/lib/rabbitmq/
50 | - ./rabbitmq/log/:/var/log/rabbitmq
51 | healthcheck:
52 | test: ["CMD", "rabbitmqctl", "ping"]
53 | interval: 30s
54 | timeout: 10s
55 | retries: 5
56 | restart: always
57 | networks:
58 | - server_default
59 |
60 | data-crawlers:
61 | image: "llm-twin-data-crawlers"
62 | container_name: llm-twin-data-crawlers
63 | platform: "linux/amd64"
64 | build:
65 | context: .
66 | dockerfile: .docker/Dockerfile.data_crawlers
67 | env_file:
68 | - .env
69 | ports:
70 | - "9010:8080"
71 | depends_on:
72 | - mongo1
73 | - mongo2
74 | - mongo3
75 | networks:
76 | - server_default
77 |
78 | data_cdc:
79 | image: "llm-twin-data-cdc"
80 | container_name: llm-twin-data-cdc
81 | build:
82 | context: .
83 | dockerfile: .docker/Dockerfile.data_cdc
84 | env_file:
85 | - .env
86 | depends_on:
87 | - mongo1
88 | - mongo2
89 | - mongo3
90 | - mq
91 | networks:
92 | - server_default
93 |
94 | feature_pipeline:
95 | image: "llm-twin-feature-pipeline"
96 | container_name: llm-twin-feature-pipeline
97 | build:
98 | context: .
99 | dockerfile: .docker/Dockerfile.feature_pipeline
100 | environment:
101 | BYTEWAX_PYTHON_FILE_PATH: "main:flow"
102 | DEBUG: "false"
103 | BYTEWAX_KEEP_CONTAINER_ALIVE: "false"
104 | env_file:
105 | - .env
106 | depends_on:
107 | - mq
108 | restart: always
109 | networks:
110 | - server_default
111 |
112 | volumes:
113 | mongo-replica-1-data:
114 | mongo-replica-2-data:
115 | mongo-replica-3-data:
116 |
117 | networks:
118 | server_default:
119 | external: true
120 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 | mongo1:
3 | image: mongo:5
4 | container_name: llm-twin-mongo1
5 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30001"]
6 | volumes:
7 | - mongo-replica-1-data:/data/db
8 | ports:
9 | - "30001:30001"
10 | healthcheck:
11 | test: test $$(echo "rs.initiate({_id:'my-replica-set',members:[{_id:0,host:\"mongo1:30001\"},{_id:1,host:\"mongo2:30002\"},{_id:2,host:\"mongo3:30003\"}]}).ok || rs.status().ok" | mongo --port 30001 --quiet) -eq 1
12 | interval: 10s
13 | start_period: 30s
14 | restart: always
15 |
16 | mongo2:
17 | image: mongo:5
18 | container_name: llm-twin-mongo2
19 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30002"]
20 | volumes:
21 | - mongo-replica-2-data:/data/db
22 | ports:
23 | - "30002:30002"
24 | restart: always
25 |
26 | mongo3:
27 | image: mongo:5
28 | container_name: llm-twin-mongo3
29 | command: ["--replSet", "my-replica-set", "--bind_ip_all", "--port", "30003"]
30 | volumes:
31 | - mongo-replica-3-data:/data/db
32 | ports:
33 | - "30003:30003"
34 | restart: always
35 |
36 | mq:
37 | image: rabbitmq:3-management-alpine
38 | container_name: llm-twin-mq
39 | ports:
40 | - "5673:5672"
41 | - "15673:15672"
42 | volumes:
43 | - ./rabbitmq/data/:/var/lib/rabbitmq/
44 | - ./rabbitmq/log/:/var/log/rabbitmq
45 | restart: always
46 |
47 | qdrant:
48 | image: qdrant/qdrant:latest
49 | container_name: llm-twin-qdrant
50 | ports:
51 | - "6333:6333"
52 | - "6334:6334"
53 | expose:
54 | - "6333"
55 | - "6334"
56 | - "6335"
57 | volumes:
58 | - qdrant-data:/qdrant_data
59 | restart: always
60 |
61 | data-crawlers:
62 | image: "llm-twin-data-crawlers"
63 | container_name: llm-twin-data-crawlers
64 | platform: "linux/amd64"
65 | build:
66 | context: .
67 | dockerfile: .docker/Dockerfile.data_crawlers
68 | env_file:
69 | - .env
70 | ports:
71 | - "9010:8080"
72 | depends_on:
73 | - mongo1
74 | - mongo2
75 | - mongo3
76 |
77 | data_cdc:
78 | image: "llm-twin-data-cdc"
79 | container_name: llm-twin-data-cdc
80 | build:
81 | context: .
82 | dockerfile: .docker/Dockerfile.data_cdc
83 | env_file:
84 | - .env
85 | depends_on:
86 | - mongo1
87 | - mongo2
88 | - mongo3
89 | - mq
90 |
91 | feature_pipeline:
92 | image: "llm-twin-feature-pipeline"
93 | container_name: llm-twin-feature-pipeline
94 | build:
95 | context: .
96 | dockerfile: .docker/Dockerfile.feature_pipeline
97 | environment:
98 | BYTEWAX_PYTHON_FILE_PATH: "main:flow"
99 | DEBUG: "false"
100 | BYTEWAX_KEEP_CONTAINER_ALIVE: "true"
101 | env_file:
102 | - .env
103 | depends_on:
104 | - mq
105 | - qdrant
106 | restart: always
107 |
108 | volumes:
109 | mongo-replica-1-data:
110 | mongo-replica-2-data:
111 | mongo-replica-3-data:
112 | qdrant-data:
113 |
--------------------------------------------------------------------------------
/media/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/architecture.png
--------------------------------------------------------------------------------
/media/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/cover.png
--------------------------------------------------------------------------------
/media/fine-tuning-workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/fine-tuning-workflow.png
--------------------------------------------------------------------------------
/media/llm_engineers_handbook_cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/llm_engineers_handbook_cover.png
--------------------------------------------------------------------------------
/media/qdrant-example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/qdrant-example.png
--------------------------------------------------------------------------------
/media/sponsors/bytewax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/sponsors/bytewax.png
--------------------------------------------------------------------------------
/media/sponsors/comet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/sponsors/comet.png
--------------------------------------------------------------------------------
/media/sponsors/qwak.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/sponsors/qwak.png
--------------------------------------------------------------------------------
/media/ui-example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/media/ui-example.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "llm-twin-course"
3 | description = ""
4 | version = "0.1.0"
5 | authors = [
6 | "Paul Iusztin
",
7 | "Vesa Alexandru ",
8 | "Razvant Alexandru ",
9 | "Rares Istoc ",
10 | "Vlad Adumitracesei ",
11 | "Anca Muscalagiu "
12 | ]
13 | package-mode = false
14 | readme = "README.md"
15 |
16 | [tool.ruff]
17 | line-length = 88
18 | select = [
19 | "F401",
20 | "F403",
21 | ]
22 |
23 | [tool.poetry.dependencies]
24 | python = "~3.11"
25 | pydantic = "^2.6.3"
26 | pydantic-settings = "^2.2.0"
27 | pika = "^1.3.2"
28 | qdrant-client = "^1.8.0"
29 | aws-lambda-powertools = "^2.38.1"
30 | selenium = "4.21.0"
31 | instructorembedding = "^1.0.1"
32 | numpy = "^1.26.4"
33 | gdown = "^5.1.0"
34 | pymongo = "^4.7.1"
35 | structlog = "^24.1.0"
36 | rich = "^13.7.1"
37 | comet-ml = "^3.41.0"
38 | opik = "1.0.1"
39 | ruff = "^0.4.3"
40 | pandas = "^2.0.3"
41 | datasets = "^2.19.1"
42 | scikit-learn = "^1.4.2"
43 | unstructured = "^0.14.2"
44 | litellm = "^1.50.4"
45 | langchain = "^0.2.11"
46 | langchain-openai = "^0.1.3"
47 | langchain-community = "^0.2.11"
48 | html2text = "^2024.2.26"
49 | huggingface-hub = "0.25.1"
50 | sagemaker = ">=2.232.2"
51 | sentence-transformers = "^2.2.2"
52 | gradio = "^5.5.0"
53 |
54 | [tool.poetry.group.feature_pipeline.dependencies]
55 | bytewax = "0.18.2"
56 |
57 | [tool.poetry.group.superlinked_rag.dependencies]
58 | superlinked = "^7.2.1"
59 |
60 | [build-system]
61 | requires = ["poetry-core"]
62 | build-backend = "poetry.core.masonry.api"
63 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/README.md:
--------------------------------------------------------------------------------
1 | # Dependencies
2 |
3 | - [Docker](https://www.docker.com/)
4 | - [Poetry](https://python-poetry.org/)
5 | - [PyEnv](https://github.com/pyenv/pyenv)
6 | - [Python](https://www.python.org/)
7 | - [GNU Make](https://www.gnu.org/software/make/)
8 |
9 | # Install
10 |
11 | ## 1. Start the Superlinked server
12 |
13 | Make sure you have `pyenv` installed. For example, on macOS you can install it as follows:
14 | ```shell
15 | brew update
16 | brew install pyenv
17 | ```
18 |
19 | Now, let's start the Superlinked server by running the following commands:
20 | ```shell
21 | # Create a virtual environment and install all necessary dependencies to deploy the server.
22 | cd src/bonus_superlinked_rag/server
23 | ./tools/init-venv.sh
24 | cd runner
25 | source "$(poetry env info --path)/bin/activate"
26 | cd ..
27 |
28 | # Make sure you have your docker engine running (e.g. start the docker desktop app).
29 | ./tools/deploy.py up
30 | ```
31 |
32 | > [!NOTE]
33 | > After the server has started, you can check that it works and explore its API at http://localhost:8080/docs/
34 |
35 | You can test that the Superlinked server started successfully by running the following command from the `root directory` of the `llm-twin-course`:
36 | ```
37 | make test-superlinked-server
38 | ```
39 | You should see that some mock data has been sent to the Superlinked server and queried back successfully.
40 |
41 | ## 2. Start the rest of the infrastructure
42 |
43 | After the Superlinked server has started successfully, run the following from the root of the repository to start all the components needed to run the Superlinked-powered LLM twin project locally:
44 | ```shell
45 | make local-start-superlinked
46 | ```
47 |
48 | > [!IMPORTANT]
49 | > Before starting, ensure you have your `.env` file filled with everything required to run the system.
50 | >
51 | > For more details on setting up the local infrastructure, you can check out the course's main [INSTALL_AND_USAGE](https://github.com/decodingml/llm-twin-course/blob/main/INSTALL_AND_USAGE.md) document.
52 |
53 | To stop the local infrastructure, run:
54 | ```shell
55 | make local-stop-superlinked
56 | ```
57 |
58 | > [!NOTE]
59 | > After running the ingestion pipeline, you can visualize what's inside the Redis vector DB at http://localhost:8001/redis-stack/browser
60 |
61 |
62 | # Usage
63 |
64 | To trigger the ingestion, run:
65 | ```shell
66 | make local-test-github
67 | # OR
68 | make local-test-medium
69 | ```
70 | You can use other GitHub or Medium links to populate the vector DB with more data.
71 |
72 | To call the retrieval module and query the Superlinked server & vector DB, run:
73 | ```shell
74 | make local-test-retriever-superlinked
75 | ```
76 |
77 | > [!IMPORTANT]
78 | > You can check out the main [INSTALL_AND_USAGE](https://github.com/decodingml/llm-twin-course/blob/main/INSTALL_AND_USAGE.md) document of the course for more details on an end-to-end flow.
79 |
80 |
81 | # Next steps
82 |
83 | If you enjoyed our [Superlinked](https://github.com/superlinked/superlinked?utm_source=community&utm_medium=github&utm_campaign=oscourse) bonus series, we recommend checking out their site for more examples. As Superlinked is not just a RAG tool but a general vector compute engine, you can build other awesome stuff with it, such as recommender systems.
84 |
85 | → 🔗 More on [Superlinked](https://github.com/superlinked/superlinked?utm_source=community&utm_medium=github&utm_campaign=oscourse) ←
86 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/config.py:
--------------------------------------------------------------------------------
1 | from pydantic_settings import BaseSettings
2 |
3 |
4 | class Settings(BaseSettings):
5 | # Embeddings config
6 | EMBEDDING_MODEL_ID: str = "BAAI/bge-small-en-v1.5"
7 |
8 | # MQ config
9 | RABBITMQ_DEFAULT_USERNAME: str = "guest"
10 | RABBITMQ_DEFAULT_PASSWORD: str = "guest"
11 | RABBITMQ_HOST: str = "mq" # or localhost if running outside Docker
12 | RABBITMQ_PORT: int = 5672
13 | RABBITMQ_QUEUE_NAME: str = "default"
14 |
15 | # Superlinked
16 | SUPERLINKED_SERVER_URL: str = (
17 | "http://executor:8080" # # or http://localhost:8080 if running outside Docker
18 | )
19 |
20 | # OpenAI
21 | OPENAI_MODEL_ID: str = "gpt-4o-mini"
22 | OPENAI_API_KEY: str | None = None
23 |
24 |
25 | settings = Settings()
26 |
--------------------------------------------------------------------------------
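Illustrative usage sketch for the settings module above (not part of the repository): because `Settings` extends pydantic's `BaseSettings`, every field can be overridden through an environment variable with the same name before `config` is imported, which is how the same code runs both inside and outside Docker.

```python
# Hedged sketch: assumes src/bonus_superlinked_rag is on PYTHONPATH.
# BaseSettings reads environment variables with matching field names,
# so the defaults above can be overridden without touching the code.
import os

os.environ["RABBITMQ_HOST"] = "localhost"  # run against a local RabbitMQ
os.environ["SUPERLINKED_SERVER_URL"] = "http://localhost:8080"

from config import settings  # imported after the env vars are set

print(settings.RABBITMQ_HOST)           # -> localhost
print(settings.SUPERLINKED_SERVER_URL)  # -> http://localhost:8080
```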
/src/bonus_superlinked_rag/data_flow/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/data_flow/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/data_flow/stream_input.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 | from datetime import datetime
4 | from typing import Generic, Iterable, List, Optional, TypeVar
5 |
6 | from bytewax.inputs import FixedPartitionedSource, StatefulSourcePartition
7 | from config import settings
8 | from mq import RabbitMQConnection
9 | from utils.logging import get_logger
10 |
11 | logger = get_logger(__name__)
12 |
13 | DataT = TypeVar("DataT")
14 | MessageT = TypeVar("MessageT")
15 |
16 |
17 | class RabbitMQSource(FixedPartitionedSource, Generic[DataT, MessageT]):
18 | def list_parts(self) -> List[str]:
19 | return ["single partition"]
20 |
21 | def build_part(
22 | self, now: datetime, for_part: str, resume_state: MessageT | None = None
23 | ) -> StatefulSourcePartition[DataT, MessageT]:
24 | return RabbitMQPartition(queue_name=settings.RABBITMQ_QUEUE_NAME)
25 |
26 |
27 | class RabbitMQPartition(StatefulSourcePartition, Generic[DataT, MessageT]):
28 | """
29 |     Class responsible for creating a connection between Bytewax and RabbitMQ that facilitates the transfer of data from the message queue to the Bytewax streaming pipeline.
30 |     Inherits StatefulSourcePartition for snapshot functionality that enables saving the state of the queue.
31 | """
32 |
33 | def __init__(self, queue_name: str, resume_state: MessageT | None = None) -> None:
34 | self._in_flight_msg_ids = resume_state or set()
35 | self.queue_name = queue_name
36 | self.connection = RabbitMQConnection()
37 |
38 | try:
39 | self.connection.connect()
40 | self.channel = self.connection.get_channel()
41 | except Exception:
42 | logger.warning(
43 | f"Error while trying to connect to the queue and get the current channel {self.queue_name}",
44 | )
45 |
46 | def next_batch(self, sched: Optional[datetime]) -> Iterable[DataT]:
47 | try:
48 | method_frame, header_frame, body = self.channel.basic_get(
49 | queue=self.queue_name, auto_ack=True
50 | )
51 | except Exception:
52 | logger.warning(
53 | f"Error while fetching message from queue: {self.queue_name}",
54 | )
55 | time.sleep(10) # Sleep for 10 seconds before retrying to access the queue.
56 |
57 | self.connection.connect()
58 | self.channel = self.connection.get_channel()
59 |
60 | return []
61 |
62 | if method_frame:
63 | message_id = method_frame.delivery_tag
64 | self._in_flight_msg_ids.add(message_id)
65 |
66 | return [json.loads(body)]
67 | else:
68 | return []
69 |
70 | def snapshot(self) -> MessageT:
71 | return self._in_flight_msg_ids
72 |
73 | def garbage_collect(self, state):
74 | closed_in_flight_msg_ids = state
75 | for msg_id in closed_in_flight_msg_ids:
76 | self.channel.basic_ack(delivery_tag=msg_id)
77 | self._in_flight_msg_ids.remove(msg_id)
78 |
79 | def close(self):
80 | self.channel.close()
81 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/data_flow/stream_output.py:
--------------------------------------------------------------------------------
1 | from bytewax.outputs import DynamicSink, StatelessSinkPartition
2 | from models.documents import Document
3 | from superlinked_client import SuperlinkedClient
4 | from tqdm import tqdm
5 | from utils.logging import get_logger
6 |
7 | logger = get_logger(__name__)
8 |
9 |
10 | class SuperlinkedOutputSink(DynamicSink):
11 | def __init__(self, client: SuperlinkedClient) -> None:
12 | self._client = client
13 |
14 | def build(self, worker_index: int, worker_count: int) -> StatelessSinkPartition:
15 | return SuperlinkedSinkPartition(client=self._client)
16 |
17 |
18 | class SuperlinkedSinkPartition(StatelessSinkPartition):
19 | def __init__(self, client: SuperlinkedClient):
20 | self._client = client
21 |
22 | def write_batch(self, items: list[Document]) -> None:
23 | for item in tqdm(items, desc="Sending items to Superlinked..."):
24 | match item.type:
25 | case "repositories":
26 | self._client.ingest_repository(item)
27 | case "posts":
28 | self._client.ingest_post(item)
29 | case "articles":
30 | self._client.ingest_article(item)
31 | case _:
32 | logger.error(f"Unknown item type: {item.type}")
33 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/data_logic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/data_logic/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/data_logic/cleaning_data_handlers.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from models.documents import ArticleDocument, Document, PostDocument, RepositoryDocument
4 | from models.raw import ArticleRawModel, PostsRawModel, RawModel, RepositoryRawModel
5 | from utils.cleaning import clean_text
6 |
7 | from .splitters import split_text
8 |
9 |
10 | class CleaningDataHandler(ABC):
11 | """
12 | Abstract class for all cleaning data handlers.
13 | All data transformations logic for the cleaning step is done here
14 | """
15 |
16 | @abstractmethod
17 | def clean(self, data_model: RawModel) -> list[Document]:
18 | pass
19 |
20 |
21 | class PostCleaningHandler(CleaningDataHandler):
22 | def clean(self, data_model: PostsRawModel) -> list[PostDocument]:
23 | documents = []
24 | cleaned_text = clean_text("".join(data_model.content.values()))
25 | for post_subsection in split_text(cleaned_text):
26 | documents.append(
27 | PostDocument(
28 | id=data_model.id,
29 | platform=data_model.platform,
30 | content=post_subsection,
31 | author_id=data_model.author_id,
32 | type=data_model.type,
33 | )
34 | )
35 |
36 | return documents
37 |
38 |
39 | class ArticleCleaningHandler(CleaningDataHandler):
40 | def clean(self, data_model: ArticleRawModel) -> list[ArticleDocument]:
41 | documents = []
42 | cleaned_text = clean_text("".join(data_model.content.values()))
43 | for article_subsection in split_text(cleaned_text):
44 | documents.append(
45 | ArticleDocument(
46 | id=data_model.id,
47 | platform=data_model.platform,
48 | link=data_model.link,
49 | content=article_subsection,
50 | author_id=data_model.author_id,
51 | type=data_model.type,
52 | )
53 | )
54 |
55 | return documents
56 |
57 |
58 | class RepositoryCleaningHandler(CleaningDataHandler):
59 | def clean(self, data_model: RepositoryRawModel) -> list[RepositoryDocument]:
60 | documents = []
61 | for file_name, file_content in data_model.content.items():
62 | cleaned_file_content = clean_text(file_content)
63 | for file_subsection in split_text(cleaned_file_content):
64 | documents.append(
65 | RepositoryDocument(
66 | id=data_model.id,
67 | platform=data_model.platform,
68 | name=f"{data_model.name}:{file_name}",
69 | link=data_model.link,
70 | content=file_subsection,
71 | author_id=data_model.owner_id,
72 | type=data_model.type,
73 | )
74 | )
75 |
76 | return documents
77 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/data_logic/dispatchers.py:
--------------------------------------------------------------------------------
1 | from data_logic.cleaning_data_handlers import (
2 | ArticleCleaningHandler,
3 | CleaningDataHandler,
4 | PostCleaningHandler,
5 | RepositoryCleaningHandler,
6 | )
7 | from models.documents import Document
8 | from models.raw import ArticleRawModel, PostsRawModel, RawModel, RepositoryRawModel
9 | from utils.logging import get_logger
10 |
11 | logger = get_logger(__name__)
12 |
13 |
14 | class RawDispatcher:
15 | @staticmethod
16 | def handle_mq_message(message: dict) -> RawModel:
17 | data_type = message.get("type")
18 |
19 | logger.info("Received raw message.", data_type=data_type)
20 |
21 | if data_type == "posts":
22 | return PostsRawModel(**message)
23 | elif data_type == "articles":
24 | return ArticleRawModel(**message)
25 | elif data_type == "repositories":
26 | return RepositoryRawModel(**message)
27 | else:
28 | raise ValueError(f"Unsupported data type: {data_type}")
29 |
30 |
31 | class CleaningHandlerFactory:
32 | @staticmethod
33 | def create_handler(data_type: str) -> CleaningDataHandler:
34 | if data_type == "posts":
35 | return PostCleaningHandler()
36 | elif data_type == "articles":
37 | return ArticleCleaningHandler()
38 | elif data_type == "repositories":
39 | return RepositoryCleaningHandler()
40 | else:
41 | raise ValueError("Unsupported data type")
42 |
43 |
44 | class CleaningDispatcher:
45 | cleaning_factory = CleaningHandlerFactory()
46 |
47 | @classmethod
48 | def dispatch_cleaner(cls, data_model: RawModel) -> list[Document]:
49 | logger.info("Cleaning data.", data_type=data_model.type)
50 |
51 | data_type = data_model.type
52 | handler = cls.cleaning_factory.create_handler(data_type)
53 | cleaned_models = handler.clean(data_model)
54 |
55 | logger.info(
56 | "Data cleaned successfully.",
57 | data_type=data_type,
58 | len_cleaned_documents=len(cleaned_models),
59 | len_content=sum([len(doc.content) for doc in cleaned_models]),
60 | )
61 |
62 | return cleaned_models
63 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/data_logic/splitters.py:
--------------------------------------------------------------------------------
1 | from langchain_text_splitters import RecursiveCharacterTextSplitter
2 |
3 |
4 | def split_text(text: str) -> list[str]:
5 | character_splitter = RecursiveCharacterTextSplitter(
6 | separators=["\n\n"], chunk_size=2000, chunk_overlap=0
7 | )
8 | chunks = character_splitter.split_text(text)
9 |
10 | return chunks
11 |
--------------------------------------------------------------------------------
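Illustrative sketch of `split_text` (not part of the repository): with `separators=["\n\n"]`, a chunk size of 2000 characters, and no overlap, the text is split on blank lines and re-merged into chunks of at most roughly 2000 characters.

```python
# Hedged sketch: assumes the bonus_superlinked_rag directory is on PYTHONPATH.
from data_logic.splitters import split_text

text = (
    "First paragraph about the feature pipeline.\n\n"
    "Second paragraph about the vector database.\n\n"
    "Third paragraph about retrieval."
)

chunks = split_text(text)
# Short inputs usually fit into a single ~2000-character chunk.
print(len(chunks), chunks[0][:60])
```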
/src/bonus_superlinked_rag/llm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/llm/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/llm/chain.py:
--------------------------------------------------------------------------------
1 | from langchain.chains.llm import LLMChain
2 | from langchain.prompts import PromptTemplate
3 |
4 |
5 | class GeneralChain:
6 | @staticmethod
7 | def get_chain(llm, template: PromptTemplate, output_key: str, verbose=True):
8 | return LLMChain(
9 | llm=llm, prompt=template, output_key=output_key, verbose=verbose
10 | )
11 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/llm/prompt_templates.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from langchain.prompts import PromptTemplate
4 | from pydantic import BaseModel
5 |
6 |
7 | class BasePromptTemplate(ABC, BaseModel):
8 | @abstractmethod
9 | def create_template(self) -> PromptTemplate:
10 | pass
11 |
12 |
13 | class QueryExpansionTemplate(BasePromptTemplate):
14 | prompt: str = """You are an AI language model assistant. Your task is to generate {to_expand_to_n}
15 | different versions of the given user question to retrieve relevant documents from a vector
16 | database. By generating multiple perspectives on the user question, your goal is to help
17 | the user overcome some of the limitations of the distance-based similarity search.
18 | Provide these alternative questions separated by '{separator}'.
19 | Original question: {question}"""
20 |
21 | @property
22 | def separator(self) -> str:
23 | return "#next-question#"
24 |
25 | def create_template(self, to_expand_to_n: int) -> PromptTemplate:
26 | return PromptTemplate(
27 | template=self.prompt,
28 | input_variables=["question"],
29 | partial_variables={
30 | "separator": self.separator,
31 | "to_expand_to_n": to_expand_to_n,
32 | },
33 | )
34 |
35 |
36 | class SelfQueryTemplate(BasePromptTemplate):
37 | prompt: str = """You are an AI language model assistant. Your task is to extract information from a user question.
38 | The required information that needs to be extracted is the user or author id.
39 | Your response should consist of only the extracted id (e.g. 1345256), nothing else.
40 | If you cannot find the author id, return the string "None".
41 | User question: {question}"""
42 |
43 | def create_template(self) -> PromptTemplate:
44 | return PromptTemplate(template=self.prompt, input_variables=["question"])
45 |
46 |
47 | class RerankingTemplate(BasePromptTemplate):
48 | prompt: str = """You are an AI language model assistant. Your task is to rerank passages related to a query
49 | based on their relevance.
50 | The most relevant passages should be put at the beginning. If no passages are relevant, keep the original order. Even if you find duplicates, return them all.
51 | You should only pick at max {keep_top_k} passages. When no passages are relevant, pick at max top {keep_top_k} passages.
52 | The provided and reranked documents are separated by '{separator}'.
53 |
54 | The following are passages related to this query: {question}.
55 |
56 | Passages:
57 | {passages}
58 | """
59 |
60 | def create_template(self, keep_top_k: int) -> PromptTemplate:
61 | return PromptTemplate(
62 | template=self.prompt,
63 | input_variables=["question", "passages"],
64 | partial_variables={"keep_top_k": keep_top_k, "separator": self.separator},
65 | )
66 |
67 | @property
68 | def separator(self) -> str:
69 | return "\n#next-document#\n"
70 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/main.py:
--------------------------------------------------------------------------------
1 | import bytewax.operators as op
2 | from bytewax.dataflow import Dataflow
3 |
4 | from data_flow.stream_input import RabbitMQSource
5 | from data_flow.stream_output import SuperlinkedOutputSink
6 | from data_logic.dispatchers import (
7 | CleaningDispatcher,
8 | RawDispatcher,
9 | )
10 | from superlinked_client import SuperlinkedClient
11 |
12 |
13 | flow = Dataflow("Streaming RAG feature pipeline")
14 | stream = op.input("input", flow, RabbitMQSource())
15 | stream = op.map("raw", stream, RawDispatcher.handle_mq_message)
16 | stream = op.map("clean", stream, CleaningDispatcher.dispatch_cleaner)
17 | stream = op.flatten("flatten_final_output", stream)
18 | op.output(
19 | "superlinked_output",
20 | stream,
21 | SuperlinkedOutputSink(client=SuperlinkedClient()),
22 | )
23 |
--------------------------------------------------------------------------------
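For reference, a hedged sketch (not part of the repository) of running the dataflow defined above outside Docker. The compose files point Bytewax at it through `BYTEWAX_PYTHON_FILE_PATH: "main:flow"`; locally, the same flow can be executed on a single worker.

```python
# Hedged sketch: assumes bytewax 0.18.x, the bonus_superlinked_rag directory on
# PYTHONPATH, and RabbitMQ plus the Superlinked server reachable per config.py.
from bytewax.testing import run_main

from main import flow

# Runs the dataflow on a single worker in the current process; the containers
# instead launch the Bytewax runner with "main:flow".
run_main(flow)
```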
/src/bonus_superlinked_rag/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import documents, raw, utils
2 |
3 | __all__ = ["documents", "raw", "utils"]
4 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/models/documents.py:
--------------------------------------------------------------------------------
1 | from typing_extensions import Annotated
2 | from pydantic import BaseModel, BeforeValidator
3 |
4 |
5 | class PostDocument(BaseModel):
6 | id: str
7 | platform: Annotated[str, BeforeValidator(str.lower)]
8 | content: str
9 | author_id: str
10 | type: str
11 |
12 |
13 | class ArticleDocument(BaseModel):
14 | id: str
15 | platform: Annotated[str, BeforeValidator(str.lower)]
16 | link: str
17 | content: str
18 | author_id: str
19 | type: str
20 |
21 |
22 | class RepositoryDocument(BaseModel):
23 | id: str
24 | platform: Annotated[str, BeforeValidator(str.lower)]
25 | name: str
26 | link: str
27 | content: str
28 | author_id: str
29 | type: str
30 |
31 |
32 | Document = PostDocument | ArticleDocument | RepositoryDocument
33 |
--------------------------------------------------------------------------------
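A small illustrative sketch (not part of the repository) of the `BeforeValidator(str.lower)` annotation used above: the `platform` value is lower-cased before validation, so mixed-case input is normalized at parse time.

```python
# Hedged sketch: assumes the bonus_superlinked_rag directory is on PYTHONPATH.
from models.documents import PostDocument

doc = PostDocument(
    id="1",
    platform="LinkedIn",  # lower-cased by the BeforeValidator on `platform`
    content="Hello world",
    author_id="42",
    type="posts",
)
assert doc.platform == "linkedin"
```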
/src/bonus_superlinked_rag/models/raw.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from pydantic import BaseModel, Field
4 |
5 |
6 | class RepositoryRawModel(BaseModel):
7 | id: str = Field(alias="entry_id")
8 | type: str
9 | platform: str = "github"
10 | name: str
11 | link: str
12 | content: dict
13 | owner_id: str
14 |
15 |
16 | class ArticleRawModel(BaseModel):
17 | id: str = Field(alias="entry_id")
18 | type: str
19 | platform: str
20 | link: str
21 | content: dict
22 | author_id: str
23 |
24 |
25 | class PostsRawModel(BaseModel):
26 | id: str = Field(alias="entry_id")
27 | type: str
28 | platform: str
29 | content: dict
30 | author_id: str | None = None
31 | image: Optional[str] = None
32 |
33 |
34 | RawModel = RepositoryRawModel | ArticleRawModel | PostsRawModel
35 |
--------------------------------------------------------------------------------
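An illustrative sketch (not part of the repository) of how a raw queue payload maps onto these models: `Field(alias="entry_id")` means incoming messages carry `entry_id`, which Pydantic exposes as `id`, matching how `RawDispatcher.handle_mq_message` constructs the models.

```python
# Hedged sketch: assumes the bonus_superlinked_rag directory is on PYTHONPATH.
from models.raw import PostsRawModel

message = {  # hypothetical payload shaped like a CDC/MQ message
    "entry_id": "66b2f5c7a9d1",
    "type": "posts",
    "platform": "linkedin",
    "content": {"0": "First paragraph.", "1": "Second paragraph."},
    "author_id": "42",
}

post = PostsRawModel(**message)
assert post.id == "66b2f5c7a9d1"  # populated through the "entry_id" alias
assert post.image is None
```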
/src/bonus_superlinked_rag/models/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, List, Optional
2 |
3 | import pandas as pd
4 | from pydantic import BaseModel
5 |
6 | from models.documents import Document
7 |
8 |
9 | def pydantic_models_to_dataframe(
10 | models: List[BaseModel], index_column: Optional[str] = "id"
11 | ) -> pd.DataFrame:
12 | """
13 | Converts a list of Pydantic models to a Pandas DataFrame.
14 |
15 | Args:
16 | models (List[BaseModel]): List of Pydantic models.
17 |
18 | Returns:
19 | pd.DataFrame: DataFrame containing the data from the Pydantic models.
20 | """
21 |
22 | if not models:
23 | return pd.DataFrame()
24 |
25 | # Convert each model to a dictionary and create a list of dictionaries
26 | data = [model.model_dump() for model in models]
27 |
28 | # Create a DataFrame from the list of dictionaries
29 | df = pd.DataFrame(data)
30 |
31 | if index_column in df.columns:
32 | df["index"] = df[index_column]
33 | else:
34 | raise RuntimeError(f"Index column '{index_column}' not found in DataFrame.")
35 |
36 | return df
37 |
38 |
39 | def group_by_type(documents: list[Document]) -> Dict[str, list[Document]]:
40 | return _group_by(documents, selector=lambda doc: doc.type)
41 |
42 |
43 | def _group_by(documents: list[Document], selector: Callable) -> Dict[Any, list]:
44 | grouped = {}
45 | for doc in documents:
46 | key = selector(doc)
47 |
48 | if key not in grouped:
49 | grouped[key] = []
50 | grouped[key].append(doc)
51 |
52 | return grouped
53 |
--------------------------------------------------------------------------------
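An illustrative sketch (not part of the repository) exercising the two helpers above: `group_by_type` buckets documents by their `type` field, and `pydantic_models_to_dataframe` flattens a list of models into a DataFrame with an extra `index` column copied from `id`.

```python
# Hedged sketch: assumes the bonus_superlinked_rag directory is on PYTHONPATH.
from models.documents import PostDocument
from models.utils import group_by_type, pydantic_models_to_dataframe

docs = [
    PostDocument(id="1", platform="linkedin", content="a", author_id="42", type="posts"),
    PostDocument(id="2", platform="linkedin", content="b", author_id="42", type="posts"),
]

grouped = group_by_type(docs)
print(list(grouped.keys()))  # -> ['posts']

df = pydantic_models_to_dataframe(docs)
print(df[["id", "index", "content"]])
```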
/src/bonus_superlinked_rag/mq.py:
--------------------------------------------------------------------------------
1 | import pika
2 |
3 | from utils.logging import get_logger
4 | from config import settings
5 |
6 | logger = get_logger(__name__)
7 |
8 |
9 | class RabbitMQConnection:
10 | _instance = None
11 |
12 | def __new__(
13 | cls,
14 | host: str | None = None,
15 | port: int | None = None,
16 | username: str | None = None,
17 | password: str | None = None,
18 | virtual_host: str = "/",
19 | ):
20 | if not cls._instance:
21 | cls._instance = super().__new__(cls)
22 |
23 | return cls._instance
24 |
25 | def __init__(
26 | self,
27 | host: str | None = None,
28 | port: int | None = None,
29 | username: str | None = None,
30 | password: str | None = None,
31 | virtual_host: str = "/",
32 | fail_silently: bool = False,
33 | **kwargs,
34 | ):
35 | self.host = host or settings.RABBITMQ_HOST
36 | self.port = port or settings.RABBITMQ_PORT
37 | self.username = username or settings.RABBITMQ_DEFAULT_USERNAME
38 | self.password = password or settings.RABBITMQ_DEFAULT_PASSWORD
39 | self.virtual_host = virtual_host
40 | self.fail_silently = fail_silently
41 | self._connection = None
42 |
43 | def __enter__(self):
44 | self.connect()
45 | return self
46 |
47 | def __exit__(self, exc_type, exc_val, exc_tb):
48 | self.close()
49 |
50 | def connect(self) -> None:
51 | logger.info("Trying to connect to RabbitMQ.", host=self.host, port=self.port)
52 |
53 | try:
54 | credentials = pika.PlainCredentials(self.username, self.password)
55 | self._connection = pika.BlockingConnection(
56 | pika.ConnectionParameters(
57 | host=self.host,
58 | port=self.port,
59 | virtual_host=self.virtual_host,
60 | credentials=credentials,
61 | )
62 | )
63 | except pika.exceptions.AMQPConnectionError as e:
64 | logger.warning("Failed to connect to RabbitMQ.")
65 |
66 | if not self.fail_silently:
67 | raise e
68 |
69 | def publish_message(self, data: str, queue: str):
70 | channel = self.get_channel()
71 | channel.queue_declare(
72 | queue=queue, durable=True, exclusive=False, auto_delete=False
73 | )
74 | channel.confirm_delivery()
75 |
76 | try:
77 | channel.basic_publish(
78 | exchange="", routing_key=queue, body=data, mandatory=True
79 | )
80 | logger.info(
81 | "Sent message successfully.", queue_type="RabbitMQ", queue_name=queue
82 | )
83 | except pika.exceptions.UnroutableError:
84 | logger.info(
85 | "Failed to send the message.", queue_type="RabbitMQ", queue_name=queue
86 | )
87 |
88 | def is_connected(self) -> bool:
89 | return self._connection is not None and self._connection.is_open
90 |
91 | def get_channel(self):
92 | if self.is_connected():
93 | return self._connection.channel()
94 |
95 | def close(self):
96 | if self.is_connected():
97 | self._connection.close()
98 | self._connection = None
99 |
100 | logger.info("Closed RabbitMQ connection.")
101 |
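A minimal publishing sketch, assuming the RabbitMQ host, port, and credentials are resolvable from `settings` (the queue name and payload below are illustrative):

```python
import json

# RabbitMQConnection is a singleton, so repeated instantiations reuse the same object.
with RabbitMQConnection(fail_silently=False) as mq:
    payload = json.dumps({"entry_id": "1", "type": "posts", "content": "..."})  # illustrative payload
    mq.publish_message(data=payload, queue="test_queue")  # illustrative queue name
```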
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/rag/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/rag/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/rag/query_expanison.py:
--------------------------------------------------------------------------------
1 | from langchain_openai import ChatOpenAI
2 |
3 | from llm.chain import GeneralChain
4 | from llm.prompt_templates import QueryExpansionTemplate
5 | from config import settings
6 |
7 |
8 | class QueryExpansion:
9 | @staticmethod
10 | def generate_response(query: str, to_expand_to_n: int) -> list[str]:
11 | query_expansion_template = QueryExpansionTemplate()
12 | prompt_template = query_expansion_template.create_template(to_expand_to_n)
13 | model = ChatOpenAI(model=settings.OPENAI_MODEL_ID, temperature=0)
14 |
15 | chain = GeneralChain().get_chain(
16 | llm=model, output_key="expanded_queries", template=prompt_template
17 | )
18 |
19 | response = chain.invoke({"question": query})
20 | result = response["expanded_queries"]
21 |
22 | queries = result.strip().split(query_expansion_template.separator)
23 | stripped_queries = [
24 | stripped_item for item in queries if (stripped_item := item.strip())
25 | ]
26 |
27 | return stripped_queries
28 |
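A minimal usage sketch, assuming `settings.OPENAI_MODEL_ID` and a valid OpenAI API key are configured:

```python
expanded_queries = QueryExpansion.generate_response(
    "How does RAG work with vector DBs and LLMs?", to_expand_to_n=3
)
# The LLM output is split on the template separator, yielding a list of reformulated queries.
for expanded_query in expanded_queries:
    print(expanded_query)
```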
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/rag/reranking.py:
--------------------------------------------------------------------------------
1 | from langchain_openai import ChatOpenAI
2 |
3 | from llm.chain import GeneralChain
4 | from llm.prompt_templates import RerankingTemplate
5 | from config import settings
6 |
7 |
8 | class Reranker:
9 | @staticmethod
10 | def generate_response(
11 | query: str, passages: list[str], keep_top_k: int
12 | ) -> list[str]:
13 | reranking_template = RerankingTemplate()
14 | prompt_template = reranking_template.create_template(keep_top_k=keep_top_k)
15 |
16 | model = ChatOpenAI(model=settings.OPENAI_MODEL_ID)
17 | chain = GeneralChain().get_chain(
18 | llm=model, output_key="rerank", template=prompt_template
19 | )
20 |
21 | stripped_passages = [
22 | stripped_item for item in passages if (stripped_item := item.strip())
23 | ]
24 | passages = reranking_template.separator.join(stripped_passages)
25 | response = chain.invoke({"question": query, "passages": passages})
26 |
27 | result = response["rerank"]
28 | reranked_passages = result.strip().split(reranking_template.separator)
29 | stripped_passages = [
30 | stripped_item
31 | for item in reranked_passages
32 | if (stripped_item := item.strip())
33 | ]
34 |
35 | return stripped_passages
36 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/rag/self_query.py:
--------------------------------------------------------------------------------
1 | from langchain_openai import ChatOpenAI
2 | from llm.chain import GeneralChain
3 | from llm.prompt_templates import SelfQueryTemplate
4 | from config import settings
5 |
6 |
7 | class SelfQuery:
8 | @staticmethod
9 | def generate_response(query: str) -> str | None:
10 | prompt = SelfQueryTemplate().create_template()
11 | model = ChatOpenAI(model=settings.OPENAI_MODEL_ID, temperature=0)
12 |
13 | chain = GeneralChain().get_chain(
14 | llm=model, output_key="metadata_filter_value", template=prompt
15 | )
16 |
17 | response = chain.invoke({"question": query})
18 | result = response.get("metadata_filter_value", "none")
19 |
20 | if result.lower() == "none":
21 | return None
22 |
23 | return result
24 |
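A minimal usage sketch; the exact filter value depends on the LLM output, so the result shown is only indicative:

```python
# When the query names an author, the chain is expected to return that value as the
# metadata filter; otherwise it returns None.
author_filter = SelfQuery.generate_response(
    "I am author_4. Could you draft a LinkedIn post about RAG systems?"
)
print(author_filter)  # e.g. "author_4", or None if no author is mentioned
```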
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/retriever.py:
--------------------------------------------------------------------------------
1 | from dotenv import load_dotenv
2 | from langchain.globals import set_verbose
3 | from rag.retriever import VectorRetriever
4 |
5 | from utils.logging import get_logger
6 |
7 | set_verbose(False)
8 |
9 | logger = get_logger(__name__)
10 |
11 | if __name__ == "__main__":
12 | load_dotenv()
13 | query = """
14 | I am author_4.
15 |
16 | Could you please draft a LinkedIn post discussing RAG systems?
17 | I'm particularly interested in how RAG works and how it is integrated with vector DBs and large language models (LLMs).
18 | """
19 | retriever = VectorRetriever(query=query)
20 | hits = retriever.retrieve_top_k(k=6, to_expand_to_n_queries=3)
21 |
22 | reranked_hits = retriever.rerank(documents=hits, keep_top_k=3)
23 |
24 | logger.info("Reranked hits:")
25 | for rank, hit in enumerate(reranked_hits):
26 | logger.info(f"{rank}: {hit[:100]}...")
27 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/scripts/bytewax_entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | if [ "$DEBUG" = true ]
4 | then
5 | python -m bytewax.run "tools.run_real_time:build_flow(debug=True)"
6 | else
7 | if [ "$BYTEWAX_PYTHON_FILE_PATH" = "" ]
8 | then
9 | echo 'BYTEWAX_PYTHON_FILE_PATH is not set. Exiting...'
10 | exit 1
11 | fi
12 |     python -m bytewax.run "$BYTEWAX_PYTHON_FILE_PATH"
13 | fi
14 |
15 |
16 | echo 'Process ended.'
17 |
18 | if [ "$BYTEWAX_KEEP_CONTAINER_ALIVE" = true ]
19 | then
20 | echo 'Keeping container alive...';
21 | while :; do sleep 1; done
22 | fi
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/.python-version:
--------------------------------------------------------------------------------
1 | 3.11
2 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/compose.yaml:
--------------------------------------------------------------------------------
1 | version: "3"
2 |
3 | services:
4 | poller:
5 | build:
6 | context: runner
7 | dockerfile: poller/Dockerfile
8 | mem_limit: 256m
9 | cpus: '0.5'
10 | volumes:
11 | - ./cache:/code/data
12 | - ./config:/code/config
13 | - ./src:/src
14 |
15 | executor:
16 | depends_on:
17 | - poller
18 | - redis
19 | build:
20 | context: runner
21 | dockerfile: executor/Dockerfile
22 | volumes:
23 | - ./cache:/code/data
24 | environment:
25 | - APP_MODULE_PATH=data.src.app
26 | ports:
27 | - 8080:8080
28 | env_file:
29 | - runner/executor/.env
30 |
31 | redis:
32 | image: redis/redis-stack:latest
33 | ports:
34 | - "6379:6379"
35 | - "8001:8001"
36 | volumes:
37 | - redis-data:/data
38 |
39 | volumes:
40 | redis-data:
41 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/config/aws_credentials.json:
--------------------------------------------------------------------------------
1 | {
 2 |     "accessKeyId": "",
 3 |     "region": "",
 4 |     "secretAccessKey": ""
5 | }
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/config/config.yaml:
--------------------------------------------------------------------------------
1 | app_location: local
2 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/config/gcp_credentials.json:
--------------------------------------------------------------------------------
1 | {
2 | "type": "service_account",
3 | "project_id": "PROJECT_ID",
4 | "private_key_id": "KEY_ID",
5 | "private_key": "-----BEGIN PRIVATE KEY-----\nPRIVATE_KEY\n-----END PRIVATE KEY-----\n",
6 | "client_email": "SERVICE_ACCOUNT_EMAIL",
7 | "client_id": "CLIENT_ID",
8 | "auth_uri": "https://accounts.google.com/o/oauth2/auth",
9 | "token_uri": "https://accounts.google.com/o/oauth2/token",
10 | "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
11 | "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/SERVICE_ACCOUNT_EMAIL"
12 | }
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/docs/bucket.md:
--------------------------------------------------------------------------------
1 | ## What are Storage Buckets?
2 |
3 | Storage buckets are fundamental entities used to store data in cloud storage services. They are essentially containers for data where you can upload, download, and manage files. Two of the most popular cloud storage services are Amazon S3 (Simple Storage Service) and Google Cloud Storage (GCS).
4 |
5 | ### Amazon S3
6 |
 7 | Amazon S3 buckets are used to store objects, which consist of data and its descriptive metadata. Bucket names are globally unique and are addressed at the Uniform Resource Locator (URL) level.
8 |
9 | ### Google Cloud Storage (GCS)
10 |
11 | Google Cloud Storage buckets are similar to Amazon S3 buckets. They are used to store data objects in Google Cloud. Each bucket is associated with a specific project, and you can choose to make the bucket data public or private.
12 |
13 | ## Purpose in Our System
14 |
15 | In our system, we use these storage buckets to store user code. When you write and save your code, it gets stored in either an S3 or GCS bucket. This allows us to securely store your code, retrieve it when needed, and even share it among multiple services or instances if required.
16 |
17 | ## How to Create Buckets
18 |
19 | ### Creating an Amazon S3 Bucket using AWS CLI
20 |
21 | To create an Amazon S3 bucket, you can use the AWS CLI (Command Line Interface). Here's how:
22 |
23 | 1. First, install the AWS CLI on your machine. You can find the installation instructions [here](https://aws.amazon.com/cli/).
24 | 2. Configure the AWS CLI with your credentials:
25 |
26 | ```bash
27 | aws configure
28 | ```
29 |
30 | 3. Create a new S3 bucket:
31 |
32 | ```bash
33 | aws s3api create-bucket --bucket my-bucket-name --region us-west-2
34 | ```
35 |
36 | Replace `my-bucket-name` with your desired bucket name and `us-west-2` with the AWS region you want to create your bucket in. For more information on S3 bucket creation, refer to the [AWS documentation](https://docs.aws.amazon.com/AmazonS3/latest/userguide/create-bucket-overview.html).
37 |
38 | ### Creating a Google Cloud Storage Bucket using `gsutil`
39 |
40 | To create a Google Cloud Storage bucket, you can use the `gsutil` tool. Here's how:
41 |
42 | 1. First, install the Google Cloud SDK, which includes the `gsutil` tool. You can find the installation instructions [here](https://cloud.google.com/sdk/docs/install).
43 | 2. Authenticate your account:
44 |
45 | ```bash
46 | gcloud auth login
47 | ```
48 |
49 | 3. Create a new GCS bucket:
50 |
51 | ```bash
52 | gsutil mb -p project_id -c storage_class -l location gs://my-bucket-name
53 | ```
54 |
55 | Replace `project_id` with the ID of your project, `storage_class` with the desired storage class for the bucket, `location` with the desired bucket location and `my-bucket-name` with your desired bucket name. For more information on GCS bucket creation, refer to the [Google Cloud Documentation](https://cloud.google.com/storage/docs/creating-buckets).
56 |
57 | Remember, the bucket names must be unique across all existing bucket names in Amazon S3 or Google Cloud Storage.
58 |
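Once the bucket exists, you can upload your Superlinked app code to it. A minimal sketch using the Python SDKs (assuming `boto3` or `google-cloud-storage` is installed and credentials are already configured; `my-bucket-name` is illustrative):

```python
# Upload app.py to an S3 bucket with boto3
import boto3

boto3.client("s3").upload_file("app.py", "my-bucket-name", "app.py")

# Or upload it to a GCS bucket with google-cloud-storage
from google.cloud import storage

storage.Client().bucket("my-bucket-name").blob("app.py").upload_from_filename("app.py")
```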
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/docs/dummy_app.py:
--------------------------------------------------------------------------------
1 | from superlinked.framework.common.schema.id_schema_object import IdField
2 | from superlinked.framework.common.schema.schema import schema
3 | from superlinked.framework.common.schema.schema_object import String
4 | from superlinked.framework.dsl.executor.rest.rest_configuration import RestQuery
5 | from superlinked.framework.dsl.executor.rest.rest_descriptor import RestDescriptor
6 | from superlinked.framework.dsl.executor.rest.rest_executor import RestExecutor
7 | from superlinked.framework.dsl.index.index import Index
8 | from superlinked.framework.dsl.query.param import Param
9 | from superlinked.framework.dsl.query.query import Query
10 | from superlinked.framework.dsl.registry.superlinked_registry import SuperlinkedRegistry
11 | from superlinked.framework.dsl.source.rest_source import RestSource
12 | from superlinked.framework.dsl.space.text_similarity_space import TextSimilaritySpace
13 | from superlinked.framework.dsl.storage.in_memory_vector_database import InMemoryVectorDatabase
14 |
15 |
16 | @schema
17 | class YourSchema:
18 | id: IdField
19 | attribute: String
20 |
21 |
22 | your_schema = YourSchema()
23 |
24 | text_space = TextSimilaritySpace(text=your_schema.attribute, model="model-name")
25 |
26 | index = Index(text_space)
27 |
28 | query = (
29 | Query(index)
30 | .find(your_schema)
31 | .similar(
32 | text_space.text,
33 | Param("query_text"),
34 | )
35 | )
36 |
37 | your_source: RestSource = RestSource(your_schema)
38 |
39 | executor = RestExecutor(
40 | sources=[your_source],
41 | indices=[index],
42 | queries=[RestQuery(RestDescriptor("query"), query)],
43 | vector_database=InMemoryVectorDatabase(),
44 | )
45 |
46 | SuperlinkedRegistry.register(executor)
47 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/docs/example/app.py:
--------------------------------------------------------------------------------
1 | from superlinked.framework.common.schema.id_schema_object import IdField
2 | from superlinked.framework.common.schema.schema import schema
3 | from superlinked.framework.common.schema.schema_object import String
4 | from superlinked.framework.dsl.executor.rest.rest_configuration import (
5 | RestQuery,
6 | )
7 | from superlinked.framework.dsl.executor.rest.rest_descriptor import RestDescriptor
8 | from superlinked.framework.dsl.executor.rest.rest_executor import RestExecutor
9 | from superlinked.framework.dsl.index.index import Index
10 | from superlinked.framework.dsl.query.param import Param
11 | from superlinked.framework.dsl.query.query import Query
12 | from superlinked.framework.dsl.registry.superlinked_registry import SuperlinkedRegistry
13 | from superlinked.framework.dsl.source.rest_source import RestSource
14 | from superlinked.framework.dsl.space.text_similarity_space import TextSimilaritySpace
15 | from superlinked.framework.dsl.storage.in_memory_vector_database import InMemoryVectorDatabase
16 |
17 |
18 | @schema
19 | class CarSchema:
20 | id: IdField
21 | make: String
22 | model: String
23 |
24 |
25 | car_schema = CarSchema()
26 |
27 | car_make_text_space = TextSimilaritySpace(text=car_schema.make, model="all-MiniLM-L6-v2")
28 | car_model_text_space = TextSimilaritySpace(text=car_schema.model, model="all-MiniLM-L6-v2")
29 |
30 | index = Index([car_make_text_space, car_model_text_space])
31 |
32 | query = (
33 | Query(index)
34 | .find(car_schema)
35 | .similar(car_make_text_space.text, Param("make"))
36 | .similar(car_model_text_space.text, Param("model"))
37 | .limit(Param("limit"))
38 | )
39 |
40 | car_source: RestSource = RestSource(car_schema)
41 |
42 | executor = RestExecutor(
43 | sources=[car_source],
44 | indices=[index],
45 | queries=[RestQuery(RestDescriptor("query"), query)],
46 | vector_database=InMemoryVectorDatabase(),
47 | )
48 |
49 | SuperlinkedRegistry.register(executor)
50 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/docs/mongodb/app_with_mongodb.py:
--------------------------------------------------------------------------------
1 | from superlinked.framework.common.schema.id_schema_object import IdField
2 | from superlinked.framework.common.schema.schema import schema
3 | from superlinked.framework.common.schema.schema_object import String
4 | from superlinked.framework.dsl.executor.rest.rest_configuration import (
5 | RestQuery,
6 | )
7 | from superlinked.framework.dsl.executor.rest.rest_descriptor import RestDescriptor
8 | from superlinked.framework.dsl.executor.rest.rest_executor import RestExecutor
9 | from superlinked.framework.dsl.index.index import Index
10 | from superlinked.framework.dsl.query.param import Param
11 | from superlinked.framework.dsl.query.query import Query
12 | from superlinked.framework.dsl.registry.superlinked_registry import SuperlinkedRegistry
13 | from superlinked.framework.dsl.source.rest_source import RestSource
14 | from superlinked.framework.dsl.space.text_similarity_space import TextSimilaritySpace
15 | from superlinked.framework.dsl.storage.mongo_vector_database import MongoVectorDatabase
16 |
17 |
18 | @schema
19 | class CarSchema:
20 | id: IdField
21 | make: String
22 | model: String
23 |
24 |
25 | car_schema = CarSchema()
26 |
27 | car_make_text_space = TextSimilaritySpace(text=car_schema.make, model="all-MiniLM-L6-v2")
28 | car_model_text_space = TextSimilaritySpace(text=car_schema.model, model="all-MiniLM-L6-v2")
29 |
30 | index = Index([car_make_text_space, car_model_text_space])
31 |
32 | query = (
33 | Query(index)
34 | .find(car_schema)
35 | .similar(car_make_text_space.text, Param("make"))
36 | .similar(car_model_text_space.text, Param("model"))
37 | .limit(Param("limit"))
38 | )
39 |
40 | car_source: RestSource = RestSource(car_schema)
41 |
42 | mongo_vector_database = MongoVectorDatabase(
43 | ":@",
44 | "",
45 | "",
46 | "",
47 | "",
48 | "",
49 | )
50 |
51 | executor = RestExecutor(
52 | sources=[car_source],
53 | indices=[index],
54 | queries=[RestQuery(RestDescriptor("query"), query)],
55 | vector_database=mongo_vector_database,
56 | )
57 |
58 | SuperlinkedRegistry.register(executor)
59 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/docs/redis/app_with_redis.py:
--------------------------------------------------------------------------------
1 | from superlinked.framework.common.schema.id_schema_object import IdField
2 | from superlinked.framework.common.schema.schema import schema
3 | from superlinked.framework.common.schema.schema_object import String
4 | from superlinked.framework.dsl.executor.rest.rest_configuration import (
5 | RestQuery,
6 | )
7 | from superlinked.framework.dsl.executor.rest.rest_descriptor import RestDescriptor
8 | from superlinked.framework.dsl.executor.rest.rest_executor import RestExecutor
9 | from superlinked.framework.dsl.index.index import Index
10 | from superlinked.framework.dsl.query.param import Param
11 | from superlinked.framework.dsl.query.query import Query
12 | from superlinked.framework.dsl.registry.superlinked_registry import SuperlinkedRegistry
13 | from superlinked.framework.dsl.source.rest_source import RestSource
14 | from superlinked.framework.dsl.space.text_similarity_space import TextSimilaritySpace
15 | from superlinked.framework.dsl.storage.redis_vector_database import RedisVectorDatabase
16 |
17 |
18 | @schema
19 | class CarSchema:
20 | id: IdField
21 | make: String
22 | model: String
23 |
24 |
25 | car_schema = CarSchema()
26 |
27 | car_make_text_space = TextSimilaritySpace(text=car_schema.make, model="all-MiniLM-L6-v2")
28 | car_model_text_space = TextSimilaritySpace(text=car_schema.model, model="all-MiniLM-L6-v2")
29 |
30 | index = Index([car_make_text_space, car_model_text_space])
31 |
32 | query = (
33 | Query(index)
34 | .find(car_schema)
35 | .similar(car_make_text_space.text, Param("make"))
36 | .similar(car_model_text_space.text, Param("model"))
37 | .limit(Param("limit"))
38 | )
39 |
40 | car_source: RestSource = RestSource(car_schema)
41 |
42 | redis_vector_database = RedisVectorDatabase("", 12345, username="default", password="")
43 |
44 | executor = RestExecutor(
45 | sources=[car_source],
46 | indices=[index],
47 | queries=[RestQuery(RestDescriptor("query"), query)],
48 | vector_database=redis_vector_database,
49 | )
50 |
51 | SuperlinkedRegistry.register(executor)
52 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/docs/redis/redis.md:
--------------------------------------------------------------------------------
1 | # Redis
2 |
3 | This document provides clear steps on how to use and integrate Redis with Superlinked.
4 |
5 | ## Configuring your existing managed Redis
6 |
7 | To use Superlinked with Redis, you will need several Redis modules. The simplest approach is to use the official Redis Stack, which includes all the necessary modules. Installation instructions for the Redis Stack can be found in the [Redis official documentation](https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/). Alternatively, you can start a managed instance provided by Redis (a free-tier is available). For detailed steps on initiating a managed instance, refer to the [Start a Managed Redis Instance](#start-a-managed-redis-instance) section below.
8 |
9 | Once your Redis instance is up and running, ensure it is accessible from the server that will use it. Additionally, configure the necessary authentication settings as described below.
10 |
11 | ## Modifications in app.py
12 |
13 | To integrate Redis, you need to add the `RedisVectorDatabase` class and include it in the executor. Here’s how you can do it:
14 |
15 | To configure the connection, use the following snippet:
16 | ```python
17 | from superlinked.framework.dsl.storage.redis_vector_database import RedisVectorDatabase
18 |
19 | vector_database = RedisVectorDatabase(
20 |     "", # (Mandatory) Your Redis host/URL, without the port or any extra fields
21 |     12315, # (Mandatory) The port, as an integer
22 |     # Any further parameters must be passed as keyword arguments. You can specify anything the official Python
23 |     # client supports; see https://redis.readthedocs.io/en/stable/connections.html. Below is a basic username/password authentication example.
24 | username="test",
25 | password="password"
26 | )
27 | ```
28 |
29 | Once configured, simply set it as the executor's vector database.
30 | ```python
31 | ...
32 | executor = RestExecutor(
33 | sources=[source],
34 | indices=[index],
35 | queries=[RestQuery(RestDescriptor("query"), query)],
36 |     vector_database=vector_database, # or whatever variable holds your `RedisVectorDatabase`
37 | )
38 | ...
39 | ```
40 |
41 | ## Start a Managed Redis Instance
42 |
43 | To initiate a managed Redis instance, navigate to [Redis Labs](https://app.redislabs.com/), sign in, and click the "New Database" button. On the ensuing page, locate the `Type` selector, which offers two options: `Redis Stack` and `Memcached`. By default, `Redis Stack` is pre-selected, which is the correct choice; if it is not, make sure to select it. For basic usage, no further configuration is necessary. Redis automatically generates a user called `default` and a password that you can see below it. However, if you intend to use the instance for persistent data storage beyond sandbox purposes, consider configuring High Availability (HA), data persistence, and other relevant settings.
44 |
45 | ## Example app with Redis
46 |
47 | You can find an example app that uses Redis [here](app_with_redis.py).
48 |
49 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/docs/vector_databases.md:
--------------------------------------------------------------------------------
1 | # Vector databases
2 |
3 | This document will list and point to the detailed documentation of the supported vector databases:
4 |
5 | - [Redis](redis/redis.md)
6 | - [MongoDB](mongodb/mongodb.md)
7 |
 8 | Missing your favorite one? [Let us know in GitHub discussions](https://github.com/superlinked/superlinked/discussions/41).
9 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/docs/vm.md:
--------------------------------------------------------------------------------
1 | ## What are Virtual Machines?
2 |
3 | Virtual Machines (VMs) are emulations of computer systems. They can run applications and services on a virtual platform, as if they were on a physical machine, while sharing the underlying hardware resources.
4 |
5 | ### VM size recommendations
6 |
7 | The minimum VM size is `t3.small` in AWS or `e2-small` in GCP. Both configurations offer 2 vCPUs and 2GB of RAM. Note that if you are utilizing in-memory storage, the 2GB memory capacity can only accommodate a few tens of thousands of records. For larger datasets, opt for a machine with more RAM. The data loader functionality is CPU-intensive, so if you plan to use it, consider provisioning additional vCPUs.
8 |
9 | The system comprises two components: poller and executor. The poller is a basic service, constrained in the Docker Compose file to use only 256MB of memory and 0.5 vCPU. The executor, which runs Superlinked along with a few minor additions, will consume the majority of the resources. It has no set limits in the Docker Compose file, so any resources not utilized by the poller will be allocated to the executor.
10 |
11 | ### Amazon EC2
12 |
13 | Amazon Elastic Compute Cloud (EC2) is a part of Amazon's cloud-computing platform, Amazon Web Services (AWS). EC2 allows users to rent virtual computers on which to run their own computer applications.
14 |
15 | ### Google Cloud Compute Engine
16 |
17 | Google Cloud Compute Engine delivers virtual machines running in Google's innovative data centers and worldwide fiber network. Compute Engine's tooling and workflow support enable scaling from single instances to global, load-balanced cloud computing.
18 |
19 | ## Creating an Amazon EC2 Instance using AWS CLI
20 |
21 | To create an Amazon EC2 instance, you can use the AWS CLI. Here's how:
22 |
23 | 1. Install and configure the AWS CLI as described in the previous section.
24 | 2. Launch an EC2 instance:
25 |
26 | ```bash
27 | aws ec2 run-instances --image-id ami-0abcdef1234567890 --count 1 --instance-type t3.small --key-name MyKeyPair --security-group-ids sg-903004f8 --subnet-id subnet-6e7f829e
28 | ```
29 |
30 | Replace the `image-id`, `key-name`, `security-group-ids`, and `subnet-id` with your own values. The `image-id` is the ID of the AMI (Amazon Machine Image), `key-name` is the name of the key pair for the instance, `security-group-ids` is the ID of the security group, and `subnet-id` is the ID of the subnet.
31 |
32 | For more information on creating key pairs, refer to the [AWS documentation](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-key-pairs.html#having-ec2-create-your-key-pair). For a list of available regions, refer to the [AWS Regional Services List](https://aws.amazon.com/about-aws/global-infrastructure/regional-product-services/).
33 |
34 | ## Creating a Google Cloud Compute Engine VM using gcloud
35 |
36 | To create a Google Cloud Compute Engine VM, you can use the `gcloud` command-line tool. Here's how:
37 |
38 | 1. Install and authenticate the Google Cloud SDK as described in the previous section.
39 | 2. Create a Compute Engine instance:
40 |
41 | ```bash
42 | gcloud compute instances create my-vm --machine-type=e2-small --image-project=debian-cloud --image-family=debian-9 --boot-disk-size=50GB
43 | ```
44 |
45 | Replace `my-vm` with your desired instance name. The `machine-type` is the machine type of the VM, `image-project` is the project ID of the image, `image-family` is the family of the image, and `boot-disk-size` is the size of the boot disk. Please use a `boot-disk-size` value over 20 GB.
46 |
47 | Remember, you need to have the necessary permissions and quotas to create VM instances in both Amazon EC2 and Google Cloud Compute Engine.
48 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/.python-version:
--------------------------------------------------------------------------------
1 | 3.11
2 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/Dockerfile:
--------------------------------------------------------------------------------
1 | # --- Python base ---
2 |
3 | FROM python:3.11.9 AS python-base
4 |
5 | ENV PYTHONUNBUFFERED=1 \
6 | PYTHONDONTWRITEBYTECODE=1 \
7 | \
8 | PIP_NO_CACHE_DIR=off \
9 | PIP_DISABLE_PIP_VERSION_CHECK=on \
10 | PIP_DEFAULT_TIMEOUT=100 \
11 | \
12 | POETRY_VERSION=1.8.2 \
13 | POETRY_HOME="/opt/poetry" \
14 | POETRY_VIRTUALENVS_IN_PROJECT=true \
15 | POETRY_NO_INTERACTION=1 \
16 | \
17 | PYSETUP_PATH="/opt/pysetup" \
18 | VENV_PATH="/opt/pysetup/.venv"
19 |
20 | ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH"
21 | # --- Builder base ---
22 |
23 | FROM python-base AS builder-base
24 | RUN apt-get update \
25 | && apt-get install --no-install-recommends -y \
26 | curl \
27 | build-essential \
28 | && rm -rf /var/lib/apt/lists/*
29 |
30 | SHELL ["/bin/bash", "-o", "pipefail", "-c"]
31 | RUN curl -sSL https://install.python-poetry.org | python -
32 |
33 | # --- Executor base ---
34 |
35 | FROM builder-base AS executor-base
36 |
37 | WORKDIR $PYSETUP_PATH
38 |
39 | COPY poetry.lock pyproject.toml $PYSETUP_PATH/
40 | RUN poetry install --no-root --only executor
41 |
42 | # --- Executor ---
43 |
44 | FROM python:3.11.5-slim AS executor
45 |
46 | COPY --from=executor-base /opt/pysetup/.venv/ /opt/pysetup/.venv/
47 |
48 | WORKDIR /code
49 |
50 | RUN apt-get update && apt-get install --no-install-recommends -y supervisor && apt-get clean && rm -rf /var/lib/apt/lists/*
51 |
52 | COPY executor/supervisord.conf /etc/supervisor/conf.d/supervisord.conf
53 | COPY executor/.env /code/executor/
54 | COPY executor/app /code/executor/app
55 |
56 | EXPOSE 8080
57 | ENTRYPOINT ["/usr/bin/supervisord", "-n"]
58 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/configuration/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/configuration/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/configuration/app_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pydantic_settings import BaseSettings, SettingsConfigDict
16 |
17 |
18 | class AppConfig(BaseSettings):
19 | SERVER_URL: str
20 | APP_MODULE_PATH: str
21 | LOG_LEVEL: str
22 | PERSISTENCE_FOLDER_PATH: str
23 | DISABLE_RECENCY_SPACE: bool
24 |
25 | model_config = SettingsConfigDict(env_file="executor/.env", extra="ignore")
26 |
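A minimal sketch of the settings the executor expects; in the container they come from `executor/.env`, and every value below is illustrative only:

```python
config = AppConfig(
    SERVER_URL="http://localhost:9001/RPC2",  # supervisord XML-RPC endpoint (illustrative)
    APP_MODULE_PATH="data.src.app",  # module that registers the Superlinked executor
    LOG_LEVEL="INFO",
    PERSISTENCE_FOLDER_PATH="/code/data/persist",  # illustrative path
    DISABLE_RECENCY_SPACE=True,
)
```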
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/dependency_register.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from xmlrpc.client import ServerProxy
16 |
17 | import inject
18 |
19 | from executor.app.configuration.app_config import AppConfig
20 | from executor.app.service.data_loader import DataLoader
21 | from executor.app.service.file_handler_service import FileHandlerService
22 | from executor.app.service.file_object_serializer import FileObjectSerializer
23 | from executor.app.service.persistence_service import PersistenceService
24 | from executor.app.service.supervisor_service import SupervisorService
25 |
26 |
27 | def register_dependencies() -> None:
28 | inject.configure(_configure)
29 |
30 |
31 | def _configure(binder: inject.Binder) -> None:
32 | app_config = AppConfig()
33 | file_handler_service = FileHandlerService(app_config)
34 | serializer = FileObjectSerializer(file_handler_service)
35 | server_proxy = ServerProxy(app_config.SERVER_URL)
36 | binder.bind(DataLoader, DataLoader(app_config))
37 | binder.bind(PersistenceService, PersistenceService(serializer))
38 | binder.bind(SupervisorService, SupervisorService(server_proxy))
39 |
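A minimal sketch of how a bound dependency is later retrieved with `inject`, assuming `register_dependencies()` has already run:

```python
import inject

from executor.app.service.data_loader import DataLoader

register_dependencies()
data_loader = inject.instance(DataLoader)  # returns the DataLoader instance bound in _configure()
```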
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/exception/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/exception/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/exception/exception.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | class UnsupportedProtocolException(Exception):
17 | pass
18 |
19 |
20 | class FilesNotFoundException(Exception):
21 | pass
22 |
23 |
24 | class DataLoaderNotFoundException(Exception):
25 | pass
26 |
27 |
28 | class DataLoaderAlreadyRunningException(Exception):
29 | pass
30 |
31 |
32 | class DataLoaderTaskNotFoundException(Exception):
33 | pass
34 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/exception/exception_handler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 |
17 | from fastapi import Request, status
18 | from fastapi.responses import JSONResponse
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | async def handle_bad_request(_: Request, exception: Exception) -> JSONResponse:
24 | logger.exception("Bad request")
25 | return JSONResponse(
26 | status_code=status.HTTP_400_BAD_REQUEST,
27 | content={
28 | "exception": str(exception.__class__.__name__),
29 | "detail": str(exception),
30 | },
31 | )
32 |
33 |
34 | async def handle_generic_exception(_: Request, exception: Exception) -> JSONResponse:
35 | logger.exception("Unexpected exception happened")
36 | return JSONResponse(
37 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
38 | content={
39 | "exception": str(exception.__class__.__name__),
40 | "detail": str(exception),
41 | },
42 | )
43 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | from json import JSONDecodeError
17 |
18 | import uvicorn
19 | from fastapi import FastAPI
20 | from fastapi_restful.timing import add_timing_middleware
21 | from superlinked.framework.common.parser.exception import MissingIdException
22 | from superlinked.framework.online.dag.exception import ValueNotProvidedException
23 |
24 | from executor.app.configuration.app_config import AppConfig
25 | from executor.app.dependency_register import register_dependencies
26 | from executor.app.exception.exception_handler import (
27 | handle_bad_request,
28 | handle_generic_exception,
29 | )
30 | from executor.app.middleware.lifespan_event import lifespan
31 | from executor.app.router.management_router import router as management_router
32 |
33 | app_config = AppConfig()
34 |
35 | logging.basicConfig(level=app_config.LOG_LEVEL)
36 | logger = logging.getLogger(__name__)
37 | logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
38 |
39 | app = FastAPI(lifespan=lifespan)
40 |
41 | app.add_exception_handler(ValueNotProvidedException, handle_bad_request)
42 | app.add_exception_handler(MissingIdException, handle_bad_request)
43 | app.add_exception_handler(JSONDecodeError, handle_bad_request)
44 | app.add_exception_handler(Exception, handle_generic_exception)
45 |
46 | app.include_router(management_router)
47 |
48 | add_timing_middleware(app, record=logger.info)
49 |
50 | register_dependencies()
51 |
52 | if __name__ == "__main__":
53 | uvicorn.run(app, host="0.0.0.0", port=8080) # noqa: S104 hardcoded-bind-all-interfaces
54 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/middleware/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/middleware/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/router/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/router/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/service/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/service/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/service/file_handler_service.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | from enum import Enum
4 |
5 | from executor.app.configuration.app_config import AppConfig
6 |
7 |
8 | class HashType(Enum):
9 | MD5 = hashlib.md5
10 |
11 |
12 | class FileHandlerService:
13 | def __init__(self, app_config: AppConfig, hash_type: HashType | None = None) -> None:
14 | self.__hash_type = hash_type or HashType.MD5
15 | self.app_config = app_config
16 |
17 | def generate_filename(self, field_id: str, app_id: str) -> str:
18 | filename = self.__hash_type.value(f"{app_id}_{field_id}".encode()).hexdigest()
19 | return f"{self.app_config.PERSISTENCE_FOLDER_PATH}/{filename}.json"
20 |
21 | def ensure_folder(self) -> None:
22 | os.makedirs(self.app_config.PERSISTENCE_FOLDER_PATH, exist_ok=True)
23 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/service/file_object_serializer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import json
16 | import logging
17 | import os
18 |
19 | from superlinked.framework.storage.in_memory.object_serializer import ObjectSerializer
20 |
21 | from executor.app.service.file_handler_service import FileHandlerService
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | EMPTY_JSON_OBJECT_SIZE = 4  # an empty object persisted as the JSON string '"{}"' is 4 characters
27 |
28 |
29 | class FileObjectSerializer(ObjectSerializer):
30 | def __init__(self, file_handler_service: FileHandlerService) -> None:
31 | super().__init__()
32 | self.__file_handler_service = file_handler_service
33 |
34 | def write(self, field_identifier: str, serialized_object: str, app_identifier: str) -> None:
35 | self.__file_handler_service.ensure_folder()
36 |
37 | logger.info("Persisting database with field id: %s and app id: %s", field_identifier, app_identifier)
38 | file_with_path = self.__file_handler_service.generate_filename(field_identifier, app_identifier)
39 | with open(file_with_path, "w", encoding="utf-8") as file:
40 | logger.debug("Writing field: %s and app: %s file to: %s", field_identifier, app_identifier, file_with_path)
41 | json.dump(serialized_object, file)
42 |
43 | def read(self, field_identifier: str, app_identifier: str) -> str:
44 | logger.info("Restoring database using field id: %s and app id: %s", field_identifier, app_identifier)
45 | file_with_path = self.__file_handler_service.generate_filename(field_identifier, app_identifier)
46 |
47 | result = "{}"
48 | try:
49 | if os.path.isfile(file_with_path):
50 | with open(file_with_path, encoding="utf-8") as file:
51 | if os.stat(file_with_path).st_size >= EMPTY_JSON_OBJECT_SIZE:
52 | result = json.load(file)
53 | except json.JSONDecodeError:
54 | logger.exception("File is present but contains invalid data. File: %s", file_with_path)
55 | except Exception: # pylint: disable=broad-exception-caught
56 | logger.exception("An error occurred during the file read operation. File: %s", file_with_path)
57 | return result
58 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/service/persistence_service.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 |
17 | from superlinked.framework.dsl.executor.rest.rest_executor import RestApp
18 |
19 | from executor.app.service.file_object_serializer import FileObjectSerializer
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | class PersistenceService:
25 | def __init__(self, serializer: FileObjectSerializer) -> None:
26 | self._applications: list[RestApp] = []
27 | self._serializer = serializer
28 |
29 | def register(self, rest_app: RestApp) -> None:
30 | if rest_app in self._applications:
31 | logger.warning("Application already exists: %s", rest_app)
32 | return
33 | logger.info("Rest app registered: %s", rest_app)
34 | self._applications.append(rest_app)
35 |
36 | def persist(self) -> None:
37 | for app in self._applications:
38 | app.online_app.persist(self._serializer)
39 |
40 | def restore(self) -> None:
41 | for app in self._applications:
42 | app.online_app.restore(self._serializer)
43 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/service/supervisor_service.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from xmlrpc.client import ServerProxy
16 |
17 |
18 | class SupervisorService:
19 | def __init__(self, server_proxy: ServerProxy) -> None:
20 | self.server = server_proxy
21 |
22 | def restart(self) -> str:
23 | """Restart the API via supervisor XML-RPC and return the result."""
24 | return str(self.server.supervisor.restart())
25 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/executor/app/util/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/util/fast_api_handler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any
16 |
17 | from fastapi import Request, Response, status
18 | from fastapi.responses import JSONResponse
19 | from pydantic import Field
20 | from superlinked.framework.common.util.immutable_model import ImmutableBaseModel
21 | from superlinked.framework.dsl.executor.rest.rest_handler import RestHandler
22 |
23 |
24 | class QueryResponse(ImmutableBaseModel):
25 | schema_: str = Field(..., alias="schema")
26 | results: list[dict[str, Any]]
27 |
28 |
29 | class FastApiHandler:
30 | def __init__(self, rest_handler: RestHandler) -> None:
31 | self.__rest_handler = rest_handler
32 |
33 | async def ingest(self, request: Request) -> Response:
34 | payload = await request.json()
35 | self.__rest_handler._ingest_handler(payload, request.url.path) # noqa: SLF001 private-member-access
36 | return Response(status_code=status.HTTP_202_ACCEPTED)
37 |
38 | async def query(self, request: Request) -> Response:
39 | payload = await request.json()
40 | result = self.__rest_handler._query_handler(payload, request.url.path) # noqa: SLF001 private-member-access
41 | query_response = QueryResponse(
42 | schema=result.schema._schema_name, # noqa: SLF001 private-member-access
43 | results=[
44 | {
45 | "entity": {
46 | "id": entry.entity.header.object_id,
47 | "origin": (
48 | {
49 | "id": entry.entity.header.object_id,
50 | "schema": entry.entity.header.schema_id,
51 | }
52 | if entry.entity.header.origin_id
53 | else {}
54 | ),
55 | },
56 | "obj": entry.stored_object,
57 | }
58 | for entry in result.entries
59 | ],
60 | )
61 | return JSONResponse(
62 | content=query_response.model_dump(by_alias=True),
63 | status_code=status.HTTP_200_OK,
64 | )
65 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/util/open_api_description_util.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | from typing import Any
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class OpenApiDescriptionUtil:
10 | @staticmethod
11 |     def get_open_api_description_by_key(key: str, file_path: str | None = None) -> dict[str, Any] | None:
12 | if file_path is None:
13 | file_path = os.path.join(os.getcwd(), "executor/openapi/static_endpoint_descriptor.json")
14 | with open(file_path, encoding="utf-8") as file:
15 | data = json.load(file)
16 | open_api_description = data.get(key)
17 | if open_api_description is None:
18 | logger.warning("No OpenAPI description found for key: %s", key)
19 | return open_api_description
20 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/app/util/registry_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | from importlib import import_module
17 |
18 | from superlinked.framework.dsl.registry.superlinked_registry import SuperlinkedRegistry
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | class RegistryLoader:
24 | @staticmethod
25 | def get_registry(app_module_path: str) -> SuperlinkedRegistry | None:
26 | try:
27 | return import_module(app_module_path).SuperlinkedRegistry
28 | except ImportError:
29 | logger.exception("Module not found at: %s", app_module_path)
30 | except AttributeError:
31 | logger.exception("SuperlinkedRegistry not found in module: %s", app_module_path)
32 | except Exception: # pylint: disable=broad-exception-caught
33 | logger.exception("An unexpected error occurred while loading the module at: %s", app_module_path)
34 | return None
35 |
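A minimal usage sketch, assuming the user app has been synced to the module path used in the Docker Compose setup (`data.src.app`) and calls `SuperlinkedRegistry.register(...)` at import time:

```python
registry = RegistryLoader.get_registry("data.src.app")
if registry is None:
    raise RuntimeError("No Superlinked registry could be loaded from the app module.")
```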
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/executor/supervisord.conf:
--------------------------------------------------------------------------------
1 | [supervisord]
2 |
3 | [inet_http_server]
4 | port=127.0.0.1:9001
5 |
6 | [rpcinterface:supervisor]
7 | supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface
8 |
9 | [program:uvicorn_host]
10 | command=/opt/pysetup/.venv/bin/python -m uvicorn executor.app.main:app --host 0.0.0.0 --port 8080
11 | numprocs=1
12 | process_name=uvicorn_host
13 | stdout_logfile=/dev/stdout
14 | stdout_logfile_maxbytes=0
15 | stderr_logfile=/dev/stderr
16 | stderr_logfile_maxbytes=0
17 | autorestart=true
18 |
19 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/Dockerfile:
--------------------------------------------------------------------------------
1 | # --- Python base ---
2 |
3 | FROM python:3.11.9 AS python-base
4 |
5 | ENV PYTHONUNBUFFERED=1 \
6 | PYTHONDONTWRITEBYTECODE=1 \
7 | \
8 | PIP_NO_CACHE_DIR=off \
9 | PIP_DISABLE_PIP_VERSION_CHECK=on \
10 | PIP_DEFAULT_TIMEOUT=100 \
11 | \
12 | POETRY_VERSION=1.8.2 \
13 | POETRY_HOME="/opt/poetry" \
14 | POETRY_VIRTUALENVS_IN_PROJECT=true \
15 | POETRY_NO_INTERACTION=1 \
16 | \
17 | PYSETUP_PATH="/opt/pysetup" \
18 | VENV_PATH="/opt/pysetup/.venv"
19 |
20 | ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH"
21 |
22 | # --- Builder base ---
23 |
24 | FROM python-base AS builder-base
25 | RUN apt-get update \
26 | && apt-get install --no-install-recommends -y \
27 | curl \
28 | build-essential \
29 | && apt-get clean && rm -rf /var/lib/apt/lists/*
30 |
31 | SHELL ["/bin/bash", "-o", "pipefail", "-c"]
32 | RUN curl -sSL https://install.python-poetry.org | python -
33 |
34 | WORKDIR $PYSETUP_PATH
35 |
36 | COPY poetry.lock pyproject.toml $PYSETUP_PATH/
37 | RUN poetry install --no-root --only poller
38 |
39 | # --- Poller ---
40 |
41 | FROM python:3.11.5-alpine AS poller
42 |
43 | COPY --from=builder-base /opt/pysetup/.venv /opt/pysetup/.venv
44 |
45 | WORKDIR /code
46 |
47 | #COPY config /code/config
48 | COPY poller/app /code/poller/app
49 | COPY poller/logging_config.ini poller/poller_config.ini /code/poller/
50 |
51 | ENTRYPOINT ["/opt/pysetup/.venv/bin/python", "-m", "poller.app.main", "--config_path", "/code/config/config.yaml"]
52 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/app_location_parser/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/app_location_parser/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/config/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/config/poller_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import configparser
16 | import logging
17 | import logging.config
18 |
19 |
20 | class PollerConfig:
21 | def __init__(self) -> None:
22 | poller_dir = "poller"
23 | poller_config_path = f"{poller_dir}/poller_config.ini"
24 | logging_config_path = f"{poller_dir}/logging_config.ini"
25 |
26 | config = configparser.ConfigParser()
27 | config.read(poller_config_path)
28 |
29 | self.poll_interval_seconds = config.getint("POLLER", "POLL_INTERVAL_SECONDS")
30 | self.executor_port = config.getint("POLLER", "EXECUTOR_PORT")
31 | self.executor_url = config.get("POLLER", "EXECUTOR_URL")
32 | self.aws_credentials = config.get("POLLER", "AWS_CREDENTIALS")
33 | self.gcp_credentials = config.get("POLLER", "GCP_CREDENTIALS")
34 | self.download_location = config.get("POLLER", "DOWNLOAD_LOCATION")
35 | self.logging_config = logging_config_path
36 |
37 | def setup_logger(self, name: str) -> logging.Logger:
38 | logging.config.fileConfig(self.logging_config)
39 | return logging.getLogger(name)
40 |
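A minimal usage sketch, assuming the working directory contains the poller/ package (e.g. /code inside the Docker image) so the relative .ini paths resolve:

    config = PollerConfig()
    logger = config.setup_logger("sampleLogger")  # logger defined in logging_config.ini
    logger.info("Polling every %s seconds", config.poll_interval_seconds)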
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Module for polling storage services and downloading updated files."""
16 |
17 | import argparse
18 |
19 | from poller.app.poller.poller import Poller
20 |
21 | if __name__ == "__main__":
22 | parser = argparse.ArgumentParser(description="Start the poller with a given config file.")
23 | parser.add_argument(
24 | "--config_path",
25 | help="Path to the configuration file.",
26 | default="../config/config.yaml",
27 | )
28 | args = parser.parse_args()
29 |
30 | poller = Poller(args.config_path)
31 | poller.run()
32 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/poller/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/poller/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/poller/poller.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import time
16 | from threading import Thread
17 |
18 | import yaml
19 |
20 | from poller.app.app_location_parser.app_location_parser import (
21 | AppLocation,
22 | AppLocationParser,
23 | )
24 | from poller.app.config.poller_config import PollerConfig
25 | from poller.app.resource_handler.resource_handler_factory import ResourceHandlerFactory
26 |
27 |
28 | class Poller(Thread):
29 | """
30 | The Poller class is a thread that polls files from different storage backends (local, S3,
31 | and Google Cloud Storage) at a regular interval.
32 | """
33 |
34 | def __init__(self, config_path: str) -> None:
35 | Thread.__init__(self)
36 | self.poller_config = PollerConfig()
37 | self.app_location_config_path = config_path
38 | self.app_location_config = self.parse_app_location_config()
39 |
40 | def parse_app_location_config(self) -> AppLocation:
41 | """
42 | Parse the configuration from the YAML file.
43 | """
44 | with open(self.app_location_config_path, encoding="utf-8") as file:
45 | config_yaml = yaml.safe_load(file)
46 | app_location = config_yaml["app_location"]
47 | return AppLocationParser().parse(app_location)
48 |
49 | def run(self) -> None:
50 | """
51 | Start the polling process.
52 | """
53 | resource_handler = ResourceHandlerFactory.get_resource_handler(self.app_location_config)
54 | while True:
55 | if resource_handler.check_api_health():
56 | resource_handler.poll()
57 | time.sleep(self.poller_config.poll_interval_seconds)
58 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/gcs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/gcs/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/gcs/gcs_resource_handler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from google.auth.exceptions import DefaultCredentialsError
16 | from google.cloud import storage # type: ignore[attr-defined]
17 | from google.cloud.exceptions import GoogleCloudError
18 | from google.cloud.storage.client import Client as GCSClient
19 |
20 | from poller.app.app_location_parser.app_location_parser import AppLocation
21 | from poller.app.resource_handler.resource_handler import ResourceHandler
22 |
23 |
24 | class GCSResourceHandler(ResourceHandler):
25 | def __init__(self, app_location: AppLocation, client: GCSClient | None = None) -> None:
26 | super().__init__(app_location)
27 | self.client = client or self.initialize_gcs_client()
28 |
29 | def initialize_gcs_client(self) -> GCSClient:
30 | """
31 | Initialize the GCS client, with fallback to credentials file if necessary.
32 | """
33 | try:
34 | # First, try to create a GCS client without explicit credentials
35 | client = storage.Client()
36 | client.get_bucket(self.app_location.bucket) # Test access
37 | except (GoogleCloudError, DefaultCredentialsError):
38 | # If the first method fails, try to use credentials from a JSON file
39 | try:
40 | return storage.Client.from_service_account_json(
41 | json_credentials_path=self.poller_config.gcp_credentials,
42 | )
43 | except FileNotFoundError as e:
44 | msg = "Could not find GCP credentials file and no service account available."
45 | raise FileNotFoundError(msg) from e
46 | else:
47 | return client
48 |
49 | def poll(self) -> None:
50 | """
51 | Poll files from a Google Cloud Storage bucket and download new or modified files.
52 | """
53 | bucket = self.client.get_bucket(self.get_bucket())
54 | blobs = bucket.list_blobs(prefix=self.app_location.path)
55 | for blob in blobs:
56 | self.check_and_download(blob.updated, blob.name)
57 |
58 | def download_file(self, bucket_name: str | None, object_name: str, download_path: str) -> None:
59 | """
60 | Download a file from GCS to the specified path.
61 | """
62 | bucket = self.client.get_bucket(bucket_name)
63 | blob = bucket.blob(object_name)
64 | blob.download_to_filename(download_path)
65 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/local/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/local/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/local/local_resource_handler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import shutil
17 | from datetime import datetime, timezone
18 |
19 | from poller.app.resource_handler.resource_handler import ResourceHandler
20 |
21 |
22 | class LocalResourceHandler(ResourceHandler):
23 | def get_bucket(self) -> str:
24 | return "local"
25 |
26 | def download_file(self, _: str | None, object_name: str, download_path: str) -> None:
27 | """
28 | 'Download' a file from local storage to the specified path.
29 | In this case, it's just copying the file.
30 | """
31 | self.logger.info("Copy file from %s to %s", object_name, download_path)
32 | shutil.copy2(object_name, download_path)
33 |
34 | def poll(self) -> None:
35 | """
36 | Poll files from a local directory and notify about new or modified files.
37 | """
38 | self.logger.info("Polling files from: %s", self.app_location.path)
39 | if not self._path_exists():
40 | self.logger.error("Path does not exist: %s", self.app_location.path)
41 | return
42 | self._process_path()
43 |
44 | def _path_exists(self) -> bool:
45 | return os.path.exists(self.app_location.path)
46 |
47 | def _process_path(self) -> None:
48 | if os.path.isfile(self.app_location.path):
49 | self._process_file(self.app_location.path)
50 | else:
51 | self._process_directory()
52 |
53 | def _process_directory(self) -> None:
54 | for root, _, files in os.walk(self.app_location.path):
55 | for file in files:
56 | file_path = os.path.join(root, file)
57 | self._process_file(file_path)
58 |
59 | def _process_file(self, file_path: str) -> None:
60 | self.logger.info("Found file: %s", file_path)
61 | file_time = datetime.fromtimestamp(os.path.getmtime(file_path), tz=timezone.utc)
62 | try:
63 | self.check_and_download(file_time, file_path)
64 | except (FileNotFoundError, PermissionError):
65 | self.logger.exception("Failed to download and notify new version")
66 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/resource_handler_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from poller.app.app_location_parser.app_location_parser import AppLocation, StorageType
16 | from poller.app.resource_handler.gcs.gcs_resource_handler import GCSResourceHandler
17 | from poller.app.resource_handler.local.local_resource_handler import (
18 | LocalResourceHandler,
19 | )
20 | from poller.app.resource_handler.resource_handler import ResourceHandler
21 | from poller.app.resource_handler.s3.s3_resource_handler import S3ResourceHandler
22 |
23 |
24 | class ResourceHandlerFactory:
25 | @staticmethod
26 | def get_resource_handler(app_location: AppLocation) -> ResourceHandler:
27 | match app_location.type_:
28 | case StorageType.S3:
29 | return S3ResourceHandler(app_location)
30 | case StorageType.GCS:
31 | return GCSResourceHandler(app_location)
32 | case StorageType.LOCAL:
33 | return LocalResourceHandler(app_location)
34 | case _:
35 | msg = f"Invalid resource type in config: {app_location.type_}"
36 | raise ValueError(msg)
37 |
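A sketch of how the factory dispatches on the storage type. AppLocation normally comes from AppLocationParser, so constructing it directly with these keyword arguments is an assumption about its fields:

    # Assumed constructor; only type_, bucket and path are referenced by the handlers.
    app_location = AppLocation(type_=StorageType.LOCAL, bucket=None, path="/code/data/src")
    handler = ResourceHandlerFactory.get_resource_handler(app_location)  # -> LocalResourceHandler
    handler.poll()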
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/s3/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/s3/__init__.py
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/app/resource_handler/s3/s3_resource_handler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Superlinked, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import json
16 |
17 | import boto3
18 | from botocore.client import Config
19 | from botocore.exceptions import ClientError
20 | from mypy_boto3_s3.client import S3Client
21 |
22 | from poller.app.app_location_parser.app_location_parser import AppLocation
23 | from poller.app.resource_handler.resource_handler import ResourceHandler
24 |
25 |
26 | class S3ResourceHandler(ResourceHandler):
27 | def __init__(
28 | self,
29 | app_location: AppLocation,
30 | client: S3Client | None = None,
31 | ) -> None: # client=None for easier testability
32 | super().__init__(app_location)
33 | self.client = client or self.initialize_s3_client()
34 | self.resource = boto3.resource("s3")
35 |
36 | def initialize_s3_client(self) -> S3Client:
37 | """
38 | Initialize the S3 client, with fallback to credentials file if necessary.
39 | """
40 | try:
41 | # First, try to create an S3 resource without explicit credentials
42 | client = boto3.client("s3", config=Config(signature_version="s3v4"))
43 | if self.app_location.bucket is not None:
44 | client.head_bucket(Bucket=self.app_location.bucket) # Test access
45 | except ClientError:
46 | # If the first method fails, try to use credentials from a JSON file
47 | try:
48 | with open(self.poller_config.aws_credentials, encoding="utf-8") as aws_cred_file:
49 | aws_credentials = json.load(aws_cred_file)
50 | return boto3.client(
51 | "s3",
52 | aws_access_key_id=aws_credentials["aws_access_key_id"],
53 | aws_secret_access_key=aws_credentials["aws_secret_access_key"],
54 | region_name=aws_credentials["region"],
55 | )
56 | except FileNotFoundError as e:
57 | msg = "Could not find AWS credentials file and no IAM role available."
58 | raise FileNotFoundError(msg) from e
59 | else:
60 | return client
61 |
62 | def poll(self) -> None:
63 | """
64 | Poll files from an S3 bucket and download new or modified files.
65 | """
66 | bucket = self.resource.Bucket(self.get_bucket())
67 | for obj in bucket.objects.filter(Prefix=self.app_location.path):
68 | self.check_and_download(obj.last_modified, obj.key)
69 |
70 | def download_file(self, _: str | None, object_name: str, download_path: str) -> None:
71 | """
72 | Download a file from S3 to the specified path.
73 | """
74 | bucket = self.get_bucket()
75 | self.client.download_file(Bucket=bucket, Key=object_name, Filename=download_path)
76 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/logging_config.ini:
--------------------------------------------------------------------------------
1 | [loggers]
2 | keys=root,sampleLogger
3 |
4 | [handlers]
5 | keys=consoleHandler
6 |
7 | [formatters]
8 | keys=sampleFormatter
9 |
10 | [logger_root]
11 | level=INFO
12 | handlers=consoleHandler
13 |
14 | [logger_sampleLogger]
15 | level=INFO
16 | handlers=consoleHandler
17 | qualname=sampleLogger
18 | propagate=0
19 |
20 | [handler_consoleHandler]
21 | class=StreamHandler
22 | level=INFO
23 | formatter=sampleFormatter
24 | args=(sys.__stdout__,)
25 |
26 | [formatter_sampleFormatter]
27 | format=%(asctime)s - %(levelname)s - %(message)s
28 | datefmt=
29 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/runner/poller/poller_config.ini:
--------------------------------------------------------------------------------
1 | [POLLER]
2 | POLL_INTERVAL_SECONDS=10
3 | EXECUTOR_PORT=8080
4 | EXECUTOR_URL=http://executor
5 | DOWNLOAD_LOCATION=/code/data/src
6 | AWS_CREDENTIALS=/code/aws_credentials.json
7 | GCP_CREDENTIALS=/code/gcp_credentials.json
8 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/server/tools/init-venv.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # please install pyenv, poetry on your system
3 | # on Mac
4 | # brew install pyenv
5 | # brew install poetry
6 |
7 | set -euxo pipefail
8 |
9 | script_path=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd -P)
10 | cd "${script_path}/../runner/"
11 |
12 | eval "$(pyenv init --path)"
13 |
14 | python_version=$(cat ../.python-version)
15 |
16 | if ! pyenv versions | grep -q "${python_version}"; then
17 | echo y | pyenv install "${python_version}"
18 | fi
19 |
20 | poetry env use $(pyenv local ${python_version} && pyenv which python)
21 |
22 | poetry cache list | awk '{print $1}' | xargs -I {} poetry cache clear {} --all
23 |
24 | poetry install --no-root --with poller,executor,dev
25 |
26 | source "$(poetry env info --path)/bin/activate"
27 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/singleton.py:
--------------------------------------------------------------------------------
1 | from threading import Lock
2 | from typing import ClassVar
3 |
4 |
5 | class SingletonMeta(type):
6 | """
7 | This is a thread-safe implementation of Singleton.
8 | """
9 |
10 | _instances: ClassVar = {}
11 |
12 | _lock: Lock = Lock()
13 |
14 | """
15 | We now have a lock object that will be used to synchronize threads during
16 | first access to the Singleton.
17 | """
18 |
19 | def __call__(cls, *args, **kwargs):
20 | """
21 | Possible changes to the value of the `__init__` argument do not affect
22 | the returned instance.
23 | """
24 | # Now, imagine that the program has just been launched. Since there's no
25 | # Singleton instance yet, multiple threads can simultaneously pass the
26 | # previous conditional and reach this point almost at the same time. The
27 | # first of them will acquire lock and will proceed further, while the
28 | # rest will wait here.
29 | with cls._lock:
30 | # The first thread to acquire the lock, reaches this conditional,
31 | # goes inside and creates the Singleton instance. Once it leaves the
32 | # lock block, a thread that might have been waiting for the lock
33 | # release may then enter this section. But since the Singleton field
34 | # is already initialized, the thread won't create a new object.
35 | if cls not in cls._instances:
36 | instance = super().__call__(*args, **kwargs)
37 | cls._instances[cls] = instance
38 |
39 | return cls._instances[cls]
40 |
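A small illustration of the metaclass above; the Database class is hypothetical:

    class Database(metaclass=SingletonMeta):
        def __init__(self, dsn: str) -> None:
            self.dsn = dsn

    a = Database("mongodb://localhost:27017")
    b = Database("mongodb://other-host:27017")
    assert a is b  # the second call returns the cached instance
    assert b.dsn == "mongodb://localhost:27017"  # later __init__ arguments are ignored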
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/utils/__init__.py:
--------------------------------------------------------------------------------
1 | def flatten(nested_list: list) -> list:
2 | """Flatten a list of lists into a single list."""
3 |
4 | return [item for sublist in nested_list for item in sublist]
5 |
--------------------------------------------------------------------------------
/src/bonus_superlinked_rag/utils/logging.py:
--------------------------------------------------------------------------------
1 | import structlog
2 |
3 |
4 | def get_logger(cls: str):
5 | return structlog.get_logger().bind(cls=cls)
--------------------------------------------------------------------------------
/src/core/__init__.py:
--------------------------------------------------------------------------------
1 | from . import db, logger_utils, opik_utils
2 | from .logger_utils import get_logger
3 |
4 | logger = get_logger(__file__)
5 |
6 | try:
7 | from .opik_utils import configure_opik
8 |
9 | configure_opik()
10 | except Exception:
11 | logger.warning("Could not configure Opik.")
12 |
13 | __all__ = ["get_logger", "logger_utils", "opik_utils", "db"]
14 |
--------------------------------------------------------------------------------
/src/core/aws/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/core/aws/create_execution_role.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | from core.logger_utils import get_logger
5 |
6 | logger = get_logger(__file__)
7 |
8 | try:
9 | import boto3
10 | except ModuleNotFoundError:
11 | logger.warning(
12 | "Couldn't load AWS or SageMaker imports. Run 'poetry install --with aws' to support AWS."
13 | )
14 |
15 | from core.config import settings
16 |
17 |
18 | def create_sagemaker_execution_role(role_name: str):
19 | assert settings.AWS_REGION, "AWS_REGION is not set."
20 | assert settings.AWS_ACCESS_KEY, "AWS_ACCESS_KEY is not set."
21 | assert settings.AWS_SECRET_KEY, "AWS_SECRET_KEY is not set."
22 |
23 | # Create IAM client
24 | iam = boto3.client(
25 | "iam",
26 | region_name=settings.AWS_REGION,
27 | aws_access_key_id=settings.AWS_ACCESS_KEY,
28 | aws_secret_access_key=settings.AWS_SECRET_KEY,
29 | )
30 |
31 | # Define the trust relationship policy
32 | trust_relationship = {
33 | "Version": "2012-10-17",
34 | "Statement": [
35 | {
36 | "Effect": "Allow",
37 | "Principal": {"Service": "sagemaker.amazonaws.com"},
38 | "Action": "sts:AssumeRole",
39 | }
40 | ],
41 | }
42 |
43 | try:
44 | # Create the IAM role
45 | role = iam.create_role(
46 | RoleName=role_name,
47 | AssumeRolePolicyDocument=json.dumps(trust_relationship),
48 | Description="Execution role for SageMaker",
49 | )
50 |
51 | # Attach necessary policies
52 | policies = [
53 | "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
54 | "arn:aws:iam::aws:policy/AmazonS3FullAccess",
55 | "arn:aws:iam::aws:policy/CloudWatchLogsFullAccess",
56 | "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryFullAccess",
57 | ]
58 |
59 | for policy in policies:
60 | iam.attach_role_policy(RoleName=role_name, PolicyArn=policy)
61 |
62 | logger.info(f"Role '{role_name}' created successfully.")
63 | logger.info(f"Role ARN: {role['Role']['Arn']}")
64 |
65 | return role["Role"]["Arn"]
66 |
67 | except iam.exceptions.EntityAlreadyExistsException:
68 | logger.warning(f"Role '{role_name}' already exists. Fetching its ARN...")
69 | role = iam.get_role(RoleName=role_name)
70 |
71 | return role["Role"]["Arn"]
72 |
73 |
74 | if __name__ == "__main__":
75 | role_arn = create_sagemaker_execution_role("SageMakerExecutionRoleLLM")
76 | logger.info(role_arn)
77 |
78 | # Save the role ARN to a file
79 | with Path("sagemaker_execution_role.json").open("w") as f:
80 | json.dump({"RoleArn": role_arn}, f)
81 |
82 | logger.info("Role ARN saved to 'sagemaker_execution_role.json'")
83 |
--------------------------------------------------------------------------------
/src/core/aws/create_sagemaker_role.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | from logger_utils import get_logger
5 |
6 | logger = get_logger(__file__)
7 |
8 | try:
9 | import boto3
10 | except ModuleNotFoundError:
11 | logger.warning(
12 | "Couldn't load AWS or SageMaker imports. Run 'poetry install --with aws' to support AWS."
13 | )
14 |
15 | from config import settings
16 |
17 |
18 | def create_sagemaker_user(username: str):
19 | assert settings.AWS_REGION, "AWS_REGION is not set."
20 | assert settings.AWS_ACCESS_KEY, "AWS_ACCESS_KEY is not set."
21 | assert settings.AWS_SECRET_KEY, "AWS_SECRET_KEY is not set."
22 |
23 | # Create IAM client
24 | iam = boto3.client(
25 | "iam",
26 | region_name=settings.AWS_REGION,
27 | aws_access_key_id=settings.AWS_ACCESS_KEY,
28 | aws_secret_access_key=settings.AWS_SECRET_KEY,
29 | )
30 |
31 | # Create user
32 | iam.create_user(UserName=username)
33 |
34 | # Attach necessary policies
35 | policies = [
36 | "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
37 | "arn:aws:iam::aws:policy/AWSCloudFormationFullAccess",
38 | "arn:aws:iam::aws:policy/IAMFullAccess",
39 | "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryFullAccess",
40 | "arn:aws:iam::aws:policy/AmazonS3FullAccess",
41 | ]
42 |
43 | for policy in policies:
44 | iam.attach_user_policy(UserName=username, PolicyArn=policy)
45 |
46 | # Create access key
47 | response = iam.create_access_key(UserName=username)
48 | access_key = response["AccessKey"]
49 |
50 | logger.info(f"User '{username}' successfully created.")
51 | logger.info("Access Key ID and Secret Access Key successfully created.")
52 |
53 | return {
54 | "AccessKeyId": access_key["AccessKeyId"],
55 | "SecretAccessKey": access_key["SecretAccessKey"],
56 | }
57 |
58 |
59 | if __name__ == "__main__":
60 | new_user = create_sagemaker_user("sagemaker-deployer")
61 |
62 | with Path("sagemaker_user_credentials.json").open("w") as f:
63 | json.dump(new_user, f)
64 |
65 | logger.info("Credentials saved to 'sagemaker_user_credentials.json'")
66 |
--------------------------------------------------------------------------------
/src/core/config.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from pydantic_settings import BaseSettings, SettingsConfigDict
4 |
5 | ROOT_DIR = str(Path(__file__).parent.parent.parent)
6 |
7 |
8 | class AppSettings(BaseSettings):
9 | model_config = SettingsConfigDict(env_file=ROOT_DIR, env_file_encoding="utf-8")
10 |
11 | # MongoDB configs
12 | MONGO_DATABASE_HOST: str = (
13 | "mongodb://mongo1:30001,mongo2:30002,mongo3:30003/?replicaSet=my-replica-set"
14 | )
15 | MONGO_DATABASE_NAME: str = "twin"
16 |
17 | # MQ config
18 | RABBITMQ_DEFAULT_USERNAME: str = "guest"
19 | RABBITMQ_DEFAULT_PASSWORD: str = "guest"
20 | RABBITMQ_HOST: str = "mq"
21 | RABBITMQ_PORT: int = 5673
22 |
23 | # QdrantDB config
24 | QDRANT_CLOUD_URL: str = "str"
25 | QDRANT_DATABASE_HOST: str = "qdrant"
26 | QDRANT_DATABASE_PORT: int = 6333
27 | USE_QDRANT_CLOUD: bool = False
28 | QDRANT_APIKEY: str | None = None
29 |
30 | # OpenAI config
31 | OPENAI_MODEL_ID: str = "gpt-4o-mini"
32 | OPENAI_API_KEY: str | None = None
33 |
34 | # CometML config
35 | COMET_API_KEY: str | None = None
36 | COMET_WORKSPACE: str | None = None
37 | COMET_PROJECT: str = "llm-twin"
38 |
39 | # AWS Authentication
40 | AWS_REGION: str = "eu-central-1"
41 | AWS_ACCESS_KEY: str | None = None
42 | AWS_SECRET_KEY: str | None = None
43 | AWS_ARN_ROLE: str | None = None
44 |
45 | # LLM Model config
46 | HUGGINGFACE_ACCESS_TOKEN: str | None = None
47 | MODEL_ID: str = "pauliusztin/LLMTwin-Llama-3.1-8B"
48 | DEPLOYMENT_ENDPOINT_NAME: str = "twin"
49 |
50 | MAX_INPUT_TOKENS: int = 1536 # Max length of input text.
51 | MAX_TOTAL_TOKENS: int = 2048 # Max length of the generation (including input text).
52 | MAX_BATCH_TOTAL_TOKENS: int = 2048 # Limits the number of tokens that can be processed in parallel during the generation.
53 |
54 | # Embeddings config
55 | EMBEDDING_MODEL_ID: str = "BAAI/bge-small-en-v1.5"
56 | EMBEDDING_MODEL_MAX_INPUT_LENGTH: int = 512
57 | EMBEDDING_SIZE: int = 384
58 | EMBEDDING_MODEL_DEVICE: str = "cpu"
59 |
60 | def patch_localhost(self) -> None:
61 | self.MONGO_DATABASE_HOST = "mongodb://localhost:30001,localhost:30002,localhost:30003/?replicaSet=my-replica-set"
62 | self.QDRANT_DATABASE_HOST = "localhost"
63 | self.RABBITMQ_HOST = "localhost"
64 |
65 |
66 | settings = AppSettings()
67 |
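When the pipelines run outside Docker Compose, patch_localhost() repoints the service hosts; a minimal sketch:

    from core.config import settings

    settings.patch_localhost()
    assert settings.QDRANT_DATABASE_HOST == "localhost"
    assert settings.RABBITMQ_HOST == "localhost"
    # MONGO_DATABASE_HOST now targets the replica set on localhost:30001-30003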
--------------------------------------------------------------------------------
/src/core/db/__init__.py:
--------------------------------------------------------------------------------
1 | from . import documents, mongo, qdrant
2 |
3 | __all__ = ["documents", "mongo", "qdrant"]
4 |
--------------------------------------------------------------------------------
/src/core/db/mongo.py:
--------------------------------------------------------------------------------
1 | from pymongo import MongoClient
2 | from pymongo.errors import ConnectionFailure
3 |
4 | from core.config import settings
5 | from core.logger_utils import get_logger
6 |
7 | logger = get_logger(__file__)
8 |
9 |
10 | class MongoDatabaseConnector:
11 | """Singleton class to connect to MongoDB database."""
12 |
13 | _instance: MongoClient | None = None
14 |
15 | def __new__(cls, *args, **kwargs):
16 | if cls._instance is None:
17 | try:
18 | cls._instance = MongoClient(settings.MONGO_DATABASE_HOST)
19 | logger.info(
20 | f"Connection to database with uri: {settings.MONGO_DATABASE_HOST} successful"
21 | )
22 | except ConnectionFailure:
23 | logger.error("Couldn't connect to the database.")
24 |
25 | raise
26 |
27 | return cls._instance
28 |
29 | def get_database(self):
30 | assert self._instance, "Database connection not initialized"
31 |
32 | return self._instance[settings.MONGO_DATABASE_NAME]
33 |
34 | def close(self):
35 | if self._instance:
36 | self._instance.close()
37 | logger.info("Connection to the database has been closed.")
38 |
39 |
40 | connection = MongoDatabaseConnector()
41 |
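Because __new__ returns the shared MongoClient itself, callers index the connector by database name, as src/data_cdc/cdc.py does further below; a minimal sketch:

    client = MongoDatabaseConnector()          # same MongoClient on every call
    db = client[settings.MONGO_DATABASE_NAME]  # the "twin" database
    print(db.list_collection_names())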
--------------------------------------------------------------------------------
/src/core/db/qdrant.py:
--------------------------------------------------------------------------------
1 | from qdrant_client import QdrantClient, models
2 | from qdrant_client.http.models import Batch, Distance, VectorParams
3 |
4 | import core.logger_utils as logger_utils
5 | from core.config import settings
6 |
7 | logger = logger_utils.get_logger(__name__)
8 |
9 |
10 | class QdrantDatabaseConnector:
11 | _instance: QdrantClient | None = None
12 |
13 | def __init__(self) -> None:
14 | if self._instance is None:
15 | if settings.USE_QDRANT_CLOUD:
16 | self._instance = QdrantClient(
17 | url=settings.QDRANT_CLOUD_URL,
18 | api_key=settings.QDRANT_APIKEY,
19 | )
20 | else:
21 | self._instance = QdrantClient(
22 | host=settings.QDRANT_DATABASE_HOST,
23 | port=settings.QDRANT_DATABASE_PORT,
24 | )
25 |
26 | def get_collection(self, collection_name: str):
27 | return self._instance.get_collection(collection_name=collection_name)
28 |
29 | def create_non_vector_collection(self, collection_name: str):
30 | self._instance.create_collection(
31 | collection_name=collection_name, vectors_config={}
32 | )
33 |
34 | def create_vector_collection(self, collection_name: str):
35 | self._instance.create_collection(
36 | collection_name=collection_name,
37 | vectors_config=VectorParams(
38 | size=settings.EMBEDDING_SIZE, distance=Distance.COSINE
39 | ),
40 | )
41 |
42 | def write_data(self, collection_name: str, points: Batch):
43 | try:
44 | self._instance.upsert(collection_name=collection_name, points=points)
45 | except Exception:
46 | logger.exception("An error occurred while inserting data.")
47 |
48 | raise
49 |
50 | def search(
51 | self,
52 | collection_name: str,
53 | query_vector: list,
54 | query_filter: models.Filter | None = None,
55 | limit: int = 3,
56 | ) -> list:
57 | return self._instance.search(
58 | collection_name=collection_name,
59 | query_vector=query_vector,
60 | query_filter=query_filter,
61 | limit=limit,
62 | )
63 |
64 | def scroll(self, collection_name: str, limit: int):
65 | return self._instance.scroll(collection_name=collection_name, limit=limit)
66 |
67 | def close(self):
68 | if self._instance:
69 | self._instance.close()
70 |
71 | logger.info("Connection to the database has been closed.")
72 |
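A sketch of creating a vector collection and upserting a point through the connector; the collection name and point values are illustrative only:

    client = QdrantDatabaseConnector()
    client.create_vector_collection("vector_posts")
    client.write_data(
        collection_name="vector_posts",
        points=Batch(
            ids=["9a3b1c1e-0000-0000-0000-000000000001"],
            vectors=[[0.0] * settings.EMBEDDING_SIZE],
            payloads=[{"author_id": "1345256", "content": "Test content"}],
        ),
    )
    hits = client.search("vector_posts", query_vector=[0.0] * settings.EMBEDDING_SIZE, limit=3)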
--------------------------------------------------------------------------------
/src/core/errors.py:
--------------------------------------------------------------------------------
1 | class TwinBaseException(Exception):
2 | pass
3 |
4 |
5 | class ImproperlyConfigured(TwinBaseException):
6 | pass
7 |
--------------------------------------------------------------------------------
/src/core/lib.py:
--------------------------------------------------------------------------------
1 | from core.errors import ImproperlyConfigured
2 |
3 |
4 | def split_user_full_name(user: str | None) -> tuple[str, str]:
5 | if user is None:
6 | raise ImproperlyConfigured("User name is empty")
7 |
8 | name_tokens = user.split(" ")
9 | if len(name_tokens) == 0:
10 | raise ImproperlyConfigured("User name is empty")
11 | elif len(name_tokens) == 1:
12 | first_name, last_name = name_tokens[0], name_tokens[0]
13 | else:
14 | first_name, last_name = " ".join(name_tokens[:-1]), name_tokens[-1]
15 |
16 | return first_name, last_name
17 |
18 |
19 | def flatten(nested_list: list) -> list:
20 | """Flatten a list of lists into a single list."""
21 |
22 | return [item for sublist in nested_list for item in sublist]
23 |
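The behaviour of the two helpers above, in short:

    assert split_user_full_name("Paul Iusztin") == ("Paul", "Iusztin")
    assert split_user_full_name("Madonna") == ("Madonna", "Madonna")  # a single token fills both names
    assert split_user_full_name("Jean Claude Van Damme") == ("Jean Claude Van", "Damme")
    assert flatten([[1, 2], [3]]) == [1, 2, 3]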
--------------------------------------------------------------------------------
/src/core/logger_utils.py:
--------------------------------------------------------------------------------
1 | import structlog
2 |
3 |
4 | def get_logger(cls: str):
5 | return structlog.get_logger().bind(cls=cls)
6 |
--------------------------------------------------------------------------------
/src/core/rag/__init__.py:
--------------------------------------------------------------------------------
1 | from .query_expanison import QueryExpansion
--------------------------------------------------------------------------------
/src/core/rag/prompt_templates.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from langchain.prompts import PromptTemplate
4 | from pydantic import BaseModel
5 |
6 |
7 | class BasePromptTemplate(ABC, BaseModel):
8 | @abstractmethod
9 | def create_template(self, *args) -> PromptTemplate:
10 | pass
11 |
12 |
13 | class QueryExpansionTemplate(BasePromptTemplate):
14 | prompt: str = """You are an AI language model assistant. Your task is to generate {to_expand_to_n}
15 | different versions of the given user question to retrieve relevant documents from a vector
16 | database. By generating multiple perspectives on the user question, your goal is to help
17 | the user overcome some of the limitations of the distance-based similarity search.
18 | Provide these alternative questions separated by '{separator}'.
19 | Original question: {question}"""
20 |
21 | @property
22 | def separator(self) -> str:
23 | return "#next-question#"
24 |
25 | def create_template(self, to_expand_to_n: int) -> PromptTemplate:
26 | return PromptTemplate(
27 | template=self.prompt,
28 | input_variables=["question"],
29 | partial_variables={
30 | "separator": self.separator,
31 | "to_expand_to_n": to_expand_to_n,
32 | },
33 | )
34 |
35 |
36 | class SelfQueryTemplate(BasePromptTemplate):
37 | prompt: str = """You are an AI language model assistant. Your task is to extract information from a user question.
38 | The required information that needs to be extracted is the user name or user id.
39 | Your response should consist of only the extracted user name (e.g., John Doe) or id (e.g. 1345256), nothing else.
40 | If the user question does not contain any user name or id, you should return the following token: none.
41 |
42 | For example:
43 | QUESTION 1:
44 | My name is Paul Iusztin and I want a post about...
45 | RESPONSE 1:
46 | Paul Iusztin
47 |
48 | QUESTION 2:
49 | I want to write a post about...
50 | RESPONSE 2:
51 | none
52 |
53 | QUESTION 3:
54 | My user id is 1345256 and I want to write a post about...
55 | RESPONSE 3:
56 | 1345256
57 |
58 | User question: {question}"""
59 |
60 | def create_template(self) -> PromptTemplate:
61 | return PromptTemplate(template=self.prompt, input_variables=["question"])
62 |
63 |
64 | class RerankingTemplate(BasePromptTemplate):
65 | prompt: str = """You are an AI language model assistant. Your task is to rerank passages related to a query
66 | based on their relevance.
67 | The most relevant passages should be put at the beginning.
68 | You should only pick at max {keep_top_k} passages.
69 | The provided and reranked documents are separated by '{separator}'.
70 |
71 | The following are passages related to this query: {question}.
72 |
73 | Passages:
74 | {passages}
75 | """
76 |
77 | def create_template(self, keep_top_k: int) -> PromptTemplate:
78 | return PromptTemplate(
79 | template=self.prompt,
80 | input_variables=["question", "passages"],
81 | partial_variables={"keep_top_k": keep_top_k, "separator": self.separator},
82 | )
83 |
84 | @property
85 | def separator(self) -> str:
86 | return "\n#next-document#\n"
87 |
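A minimal sketch of rendering the query-expansion template; the question is illustrative:

    template = QueryExpansionTemplate()
    prompt = template.create_template(to_expand_to_n=3)
    print(prompt.format(question="How do I deploy the LLM twin to SageMaker?"))
    # The partial variables fill in separator='#next-question#' and to_expand_to_n=3.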
--------------------------------------------------------------------------------
/src/core/rag/query_expanison.py:
--------------------------------------------------------------------------------
1 | import opik
2 | from config import settings
3 | from langchain_openai import ChatOpenAI
4 | from opik.integrations.langchain import OpikTracer
5 |
6 | from core.rag.prompt_templates import QueryExpansionTemplate
7 |
8 |
9 | class QueryExpansion:
10 | opik_tracer = OpikTracer(tags=["QueryExpansion"])
11 |
12 | @staticmethod
13 | @opik.track(name="QueryExpansion.generate_response")
14 | def generate_response(query: str, to_expand_to_n: int) -> list[str]:
15 | query_expansion_template = QueryExpansionTemplate()
16 | prompt = query_expansion_template.create_template(to_expand_to_n)
17 | model = ChatOpenAI(
18 | model=settings.OPENAI_MODEL_ID,
19 | api_key=settings.OPENAI_API_KEY,
20 | temperature=0,
21 | )
22 | chain = prompt | model
23 | chain = chain.with_config({"callbacks": [QueryExpansion.opik_tracer]})
24 |
25 | response = chain.invoke({"question": query})
26 | response = response.content
27 |
28 | queries = response.strip().split(query_expansion_template.separator)
29 | stripped_queries = [
30 | stripped_item for item in queries if (stripped_item := item.strip(" \\n"))
31 | ]
32 |
33 | return stripped_queries
34 |
--------------------------------------------------------------------------------
/src/core/rag/reranking.py:
--------------------------------------------------------------------------------
1 | from config import settings
2 | from langchain_openai import ChatOpenAI
3 |
4 | from core.rag.prompt_templates import RerankingTemplate
5 |
6 |
7 | class Reranker:
8 | @staticmethod
9 | def generate_response(
10 | query: str, passages: list[str], keep_top_k: int
11 | ) -> list[str]:
12 | reranking_template = RerankingTemplate()
13 | prompt = reranking_template.create_template(keep_top_k=keep_top_k)
14 | model = ChatOpenAI(
15 | model=settings.OPENAI_MODEL_ID, api_key=settings.OPENAI_API_KEY
16 | )
17 | chain = prompt | model
18 |
19 | stripped_passages = [
20 | stripped_item for item in passages if (stripped_item := item.strip())
21 | ]
22 | passages = reranking_template.separator.join(stripped_passages)
23 | response = chain.invoke({"question": query, "passages": passages})
24 | response = response.content
25 |
26 | reranked_passages = response.strip().split(reranking_template.separator)
27 | stripped_passages = [
28 | stripped_item
29 | for item in reranked_passages
30 | if (stripped_item := item.strip())
31 | ]
32 |
33 | return stripped_passages
34 |
--------------------------------------------------------------------------------
/src/core/rag/self_query.py:
--------------------------------------------------------------------------------
1 | import opik
2 | from config import settings
3 | from langchain_openai import ChatOpenAI
4 | from opik.integrations.langchain import OpikTracer
5 |
6 | import core.logger_utils as logger_utils
7 | from core import lib
8 | from core.db.documents import UserDocument
9 | from core.rag.prompt_templates import SelfQueryTemplate
10 |
11 | logger = logger_utils.get_logger(__name__)
12 |
13 |
14 | class SelfQuery:
15 | opik_tracer = OpikTracer(tags=["SelfQuery"])
16 |
17 | @staticmethod
18 | @opik.track(name="SelfQuery.generate_response")
19 | def generate_response(query: str) -> str | None:
20 | prompt = SelfQueryTemplate().create_template()
21 | model = ChatOpenAI(
22 | model=settings.OPENAI_MODEL_ID,
23 | api_key=settings.OPENAI_API_KEY,
24 | temperature=0,
25 | )
26 | chain = prompt | model
27 | chain = chain.with_config({"callbacks": [SelfQuery.opik_tracer]})
28 |
29 | response = chain.invoke({"question": query})
30 | response = response.content
31 | user_full_name = response.strip("\n ")
32 |
33 | if user_full_name == "none":
34 | return None
35 |
36 | logger.info(
37 | f"Successfully extracted the user full name from the query.",
38 | user_full_name=user_full_name,
39 | )
40 | first_name, last_name = lib.split_user_full_name(user_full_name)
41 | logger.info(
42 | f"Successfully extracted the user first and last name from the query.",
43 | first_name=first_name,
44 | last_name=last_name,
45 | )
46 | user_id = UserDocument.get_or_create(first_name=first_name, last_name=last_name)
47 |
48 | return user_id
49 |
--------------------------------------------------------------------------------
/src/data_cdc/cdc.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 |
4 | from bson import json_util
5 | from config import settings
6 | from core.db.mongo import MongoDatabaseConnector
7 | from core.logger_utils import get_logger
8 | from core.mq import publish_to_rabbitmq
9 |
10 | logger = get_logger(__file__)
11 |
12 |
13 | def stream_process():
14 | try:
15 | client = MongoDatabaseConnector()
16 | db = client["twin"]
17 | logger.info("Connected to MongoDB.")
18 |
19 | # Watch changes in a specific collection
20 | changes = db.watch([{"$match": {"operationType": {"$in": ["insert"]}}}])
21 | for change in changes:
22 | data_type = change["ns"]["coll"]
23 | entry_id = str(change["fullDocument"]["_id"]) # Convert ObjectId to string
24 |
25 | change["fullDocument"].pop("_id")
26 | change["fullDocument"]["type"] = data_type
27 | change["fullDocument"]["entry_id"] = entry_id
28 |
29 | if data_type not in ["articles", "posts", "repositories"]:
30 | logger.info(f"Unsupported data type: '{data_type}'")
31 | continue
32 |
33 | # Use json_util to serialize the document
34 | data = json.dumps(change["fullDocument"], default=json_util.default)
35 | logger.info(
36 | f"Change detected and serialized for a data sample of type {data_type}."
37 | )
38 |
39 | # Send data to rabbitmq
40 | publish_to_rabbitmq(queue_name=settings.RABBITMQ_QUEUE_NAME, data=data)
41 | logger.info(f"Data of type '{data_type}' published to RabbitMQ.")
42 |
43 | except Exception as e:
44 | logger.error(f"An error occurred: {e}")
45 |
46 |
47 | if __name__ == "__main__":
48 | stream_process()
49 |
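For reference, the payload published to RabbitMQ for a "posts" insert such as the one issued by test_cdc.py further below looks roughly like this (the entry_id value is illustrative):

    message = {
        "platform": "linkedin",
        "content": "Test content",
        "type": "posts",  # the source collection name
        "entry_id": "66b0f2c59e1d2a0c8a4f9b12",  # the original _id, serialized as a string
    }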
--------------------------------------------------------------------------------
/src/data_cdc/config.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from pydantic_settings import BaseSettings, SettingsConfigDict
4 |
5 | ROOT_DIR = str(Path(__file__).parent.parent.parent)
6 |
7 |
8 | class Settings(BaseSettings):
9 | model_config = SettingsConfigDict(env_file=ROOT_DIR, env_file_encoding="utf-8")
10 |
11 | MONGO_DATABASE_HOST: str = (
12 | "mongodb://mongo1:30001,mongo2:30002,mongo3:30003/?replicaSet=my-replica-set"
13 | )
14 | MONGO_DATABASE_NAME: str = "twin"
15 |
16 | RABBITMQ_HOST: str = "mq" # or localhost if running outside Docker
17 | RABBITMQ_PORT: int = 5672
18 | RABBITMQ_DEFAULT_USERNAME: str = "guest"
19 | RABBITMQ_DEFAULT_PASSWORD: str = "guest"
20 | RABBITMQ_QUEUE_NAME: str = "default"
21 |
22 |
23 | settings = Settings()
24 |
--------------------------------------------------------------------------------
/src/data_cdc/test_cdc.py:
--------------------------------------------------------------------------------
1 | from pymongo import MongoClient
2 |
3 |
4 | def insert_data_to_mongodb(uri, database_name, collection_name, data):
5 | """
6 | Insert data into a MongoDB collection.
7 |
8 | :param uri: MongoDB URI
9 | :param database_name: Name of the database
10 | :param collection_name: Name of the collection
11 | :param data: Data to be inserted (dict)
12 | """
13 | client = MongoClient(uri)
14 | db = client[database_name]
15 | collection = db[collection_name]
16 |
17 | try:
18 | result = collection.insert_one(data)
19 | print(f"Data inserted with _id: {result.inserted_id}")
20 | except Exception as e:
21 | print(f"An error occurred: {e}")
22 | finally:
23 | client.close()
24 |
25 |
26 | if __name__ == "__main__":
27 | insert_data_to_mongodb(
28 | "mongodb://localhost:30001,localhost:30002,localhost:30003/?replicaSet=my-replica-set",
29 | "twin",
30 | "posts",
31 | {"platform": "linkedin", "content": "Test content"}
32 | )
33 |
--------------------------------------------------------------------------------
/src/data_crawling/config.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from pydantic_settings import BaseSettings, SettingsConfigDict
4 |
5 | ROOT_DIR = str(Path(__file__).parent.parent.parent)
6 |
7 |
8 | class Settings(BaseSettings):
9 | model_config = SettingsConfigDict(env_file=ROOT_DIR, env_file_encoding="utf-8")
10 |
11 | MONGO_DATABASE_HOST: str = (
12 | "mongodb://mongo1:30001,mongo2:30002,mongo3:30003/?replicaSet=my-replica-set"
13 | )
14 | MONGO_DATABASE_NAME: str = "twin"
15 |
16 | # Optional LinkedIn credentials for scraping your profile
17 | LINKEDIN_USERNAME: str | None = None
18 | LINKEDIN_PASSWORD: str | None = None
19 |
20 |
21 | settings = Settings()
22 |
--------------------------------------------------------------------------------
/src/data_crawling/crawlers/__init__.py:
--------------------------------------------------------------------------------
1 | from .custom_article import CustomArticleCrawler
2 | from .github import GithubCrawler
3 | from .linkedin import LinkedInCrawler
4 | from .medium import MediumCrawler
5 |
6 | __all__ = ["CustomArticleCrawler", "GithubCrawler", "LinkedInCrawler", "MediumCrawler"]
7 |
--------------------------------------------------------------------------------
/src/data_crawling/crawlers/base.py:
--------------------------------------------------------------------------------
1 | import time
2 | from abc import ABC, abstractmethod
3 | from tempfile import mkdtemp
4 |
5 | from core.db.documents import BaseDocument
6 | from selenium import webdriver
7 | from selenium.webdriver.chrome.options import Options
8 |
9 |
10 | class BaseCrawler(ABC):
11 | model: type[BaseDocument]
12 |
13 | @abstractmethod
14 | def extract(self, link: str, **kwargs) -> None: ...
15 |
16 |
17 | class BaseAbstractCrawler(BaseCrawler, ABC):
18 | def __init__(self, scroll_limit: int = 5) -> None:
19 | options = webdriver.ChromeOptions()
20 |
21 | options.add_argument("--no-sandbox")
22 | options.add_argument("--headless=new")
23 | options.add_argument("--disable-dev-shm-usage")
24 | options.add_argument("--log-level=3")
25 | options.add_argument("--disable-popup-blocking")
26 | options.add_argument("--disable-notifications")
27 | options.add_argument("--disable-extensions")
28 | options.add_argument("--disable-background-networking")
29 | options.add_argument("--ignore-certificate-errors")
30 | options.add_argument(f"--user-data-dir={mkdtemp()}")
31 | options.add_argument(f"--data-path={mkdtemp()}")
32 | options.add_argument(f"--disk-cache-dir={mkdtemp()}")
33 | options.add_argument("--remote-debugging-port=9226")
34 |
35 | self.set_extra_driver_options(options)
36 |
37 | self.scroll_limit = scroll_limit
38 | self.driver = webdriver.Chrome(
39 | options=options,
40 | )
41 |
42 | def set_extra_driver_options(self, options: Options) -> None:
43 | pass
44 |
45 | def login(self) -> None:
46 | pass
47 |
48 | def scroll_page(self) -> None:
49 | """Scroll through the LinkedIn page based on the scroll limit."""
50 | current_scroll = 0
51 | last_height = self.driver.execute_script("return document.body.scrollHeight")
52 | while True:
53 | self.driver.execute_script(
54 | "window.scrollTo(0, document.body.scrollHeight);"
55 | )
56 | time.sleep(5)
57 | new_height = self.driver.execute_script("return document.body.scrollHeight")
58 | if new_height == last_height or (
59 | self.scroll_limit and current_scroll >= self.scroll_limit
60 | ):
61 | break
62 | last_height = new_height
63 | current_scroll += 1
64 |
--------------------------------------------------------------------------------
/src/data_crawling/crawlers/custom_article.py:
--------------------------------------------------------------------------------
1 | from urllib.parse import urlparse
2 |
3 | from aws_lambda_powertools import Logger
4 | from core.db.documents import ArticleDocument
5 | from langchain_community.document_loaders import AsyncHtmlLoader
6 | from langchain_community.document_transformers.html2text import Html2TextTransformer
7 |
8 | from .base import BaseCrawler
9 |
10 | logger = Logger(service="llm-twin-course/crawler")
11 |
12 |
13 | class CustomArticleCrawler(BaseCrawler):
14 | model = ArticleDocument
15 |
16 | def __init__(self) -> None:
17 | super().__init__()
18 |
19 | def extract(self, link: str, **kwargs) -> None:
20 | old_model = self.model.find(link=link)
21 | if old_model is not None:
22 | logger.info(f"Article already exists in the database: {link}")
23 |
24 | return
25 |
26 | logger.info(f"Starting scrapping article: {link}")
27 |
28 | loader = AsyncHtmlLoader([link])
29 | docs = loader.load()
30 |
31 | html2text = Html2TextTransformer()
32 | docs_transformed = html2text.transform_documents(docs)
33 | doc_transformed = docs_transformed[0]
34 |
35 | content = {
36 | "Title": doc_transformed.metadata.get("title"),
37 | "Subtitle": doc_transformed.metadata.get("description"),
38 | "Content": doc_transformed.page_content,
39 | "language": doc_transformed.metadata.get("language"),
40 | }
41 |
42 | parsed_url = urlparse(link)
43 | platform = parsed_url.netloc
44 |
45 | instance = self.model(
46 | content=content,
47 | link=link,
48 | platform=platform,
49 | author_id=kwargs.get("user"),
50 | )
51 | instance.save()
52 |
53 | logger.info(f"Finished scrapping custom article: {link}")
54 |
--------------------------------------------------------------------------------
/src/data_crawling/crawlers/github.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 | import tempfile
5 |
6 | from aws_lambda_powertools import Logger
7 | from core.db.documents import RepositoryDocument
8 |
9 | from crawlers.base import BaseCrawler
10 |
11 | logger = Logger(service="llm-twin-course/crawler")
12 |
13 |
14 | class GithubCrawler(BaseCrawler):
15 | model = RepositoryDocument
16 |
17 | def __init__(self, ignore=(".git", ".toml", ".lock", ".png")) -> None:
18 | super().__init__()
19 | self._ignore = ignore
20 |
21 | def extract(self, link: str, **kwargs) -> None:
22 | logger.info(f"Starting scrapping GitHub repository: {link}")
23 |
24 | repo_name = link.rstrip("/").split("/")[-1]
25 |
26 | local_temp = tempfile.mkdtemp()
27 |
28 | try:
29 | os.chdir(local_temp)
30 | subprocess.run(["git", "clone", link])
31 |
32 | repo_path = os.path.join(local_temp, os.listdir(local_temp)[0])
33 |
34 | tree = {}
35 | for root, dirs, files in os.walk(repo_path):
36 | dir = root.replace(repo_path, "").lstrip("/")
37 | if dir.startswith(self._ignore):
38 | continue
39 |
40 | for file in files:
41 | if file.endswith(self._ignore):
42 | continue
43 | file_path = os.path.join(dir, file)
44 | with open(os.path.join(root, file), "r", errors="ignore") as f:
45 | tree[file_path] = f.read().replace(" ", "")
46 |
47 | instance = self.model(
48 | name=repo_name, link=link, content=tree, owner_id=kwargs.get("user")
49 | )
50 | instance.save()
51 |
52 | except Exception:
53 | raise
54 | finally:
55 | shutil.rmtree(local_temp)
56 |
57 | logger.info(f"Finished scrapping GitHub repository: {link}")
58 |
--------------------------------------------------------------------------------
/src/data_crawling/crawlers/medium.py:
--------------------------------------------------------------------------------
1 | from aws_lambda_powertools import Logger
2 | from bs4 import BeautifulSoup
3 | from core.db.documents import ArticleDocument
4 | from selenium.webdriver.common.by import By
5 |
6 | from crawlers.base import BaseAbstractCrawler
7 |
8 | logger = Logger(service="llm-twin-course/crawler")
9 |
10 |
11 | class MediumCrawler(BaseAbstractCrawler):
12 | model = ArticleDocument
13 |
14 | def set_extra_driver_options(self, options) -> None:
15 | options.add_argument(r"--profile-directory=Profile 2")
16 |
17 | def extract(self, link: str, **kwargs) -> None:
18 |         logger.info(f"Starting scraping Medium article: {link}")
19 |
20 | self.driver.get(link)
21 | self.scroll_page()
22 |
23 | soup = BeautifulSoup(self.driver.page_source, "html.parser")
24 | title = soup.find_all("h1", class_="pw-post-title")
25 | subtitle = soup.find_all("h2", class_="pw-subtitle-paragraph")
26 |
27 | data = {
28 | "Title": title[0].string if title else None,
29 | "Subtitle": subtitle[0].string if subtitle else None,
30 | "Content": soup.get_text(),
31 | }
32 |
33 | logger.info(f"Successfully scraped and saved article: {link}")
34 | self.driver.close()
35 | instance = self.model(
36 | platform="medium", content=data, link=link, author_id=kwargs.get("user")
37 | )
38 | instance.save()
39 |
40 | def login(self):
41 | """Log in to Medium with Google"""
42 | self.driver.get("https://medium.com/m/signin")
43 | self.driver.find_element(By.TAG_NAME, "a").click()
44 |
--------------------------------------------------------------------------------
/src/data_crawling/dispatcher.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from aws_lambda_powertools import Logger
4 | from crawlers.base import BaseCrawler
5 | from crawlers.custom_article import CustomArticleCrawler
6 |
7 | logger = Logger(service="llm-twin-course/crawler")
8 |
9 |
10 | class CrawlerDispatcher:
11 | def __init__(self) -> None:
12 | self._crawlers = {}
13 |
14 | def register(self, domain: str, crawler: type[BaseCrawler]) -> None:
15 | self._crawlers[r"https://(www\.)?{}.com/*".format(re.escape(domain))] = crawler
16 |
17 | def get_crawler(self, url: str) -> BaseCrawler:
18 | for pattern, crawler in self._crawlers.items():
19 | if re.match(pattern, url):
20 | return crawler()
21 | else:
22 | logger.warning(
23 | f"No crawler found for {url}. Defaulting to CustomArticleCrawler."
24 | )
25 |
26 | return CustomArticleCrawler()
27 |
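The registry above keys each crawler class by a URL regex built from its domain. A minimal local sketch of how dispatching resolves (the URLs are illustrative, and it assumes the `data_crawling` package and its `core` dependencies are importable):

```python
from crawlers import CustomArticleCrawler, GithubCrawler
from dispatcher import CrawlerDispatcher

dispatcher = CrawlerDispatcher()
dispatcher.register("github", GithubCrawler)

# "github" is stored under the pattern r"https://(www\.)?github.com/*".
crawler = dispatcher.get_crawler("https://github.com/decodingml/llm-twin-course")
assert isinstance(crawler, GithubCrawler)

# Unregistered domains fall back to CustomArticleCrawler (with a warning).
fallback = dispatcher.get_crawler("https://example.com/some-article")
assert isinstance(fallback, CustomArticleCrawler)
```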
--------------------------------------------------------------------------------
/src/data_crawling/main.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from aws_lambda_powertools import Logger
4 | from aws_lambda_powertools.utilities.typing import LambdaContext
5 | from core import lib
6 | from core.db.documents import UserDocument
7 | from crawlers import CustomArticleCrawler, GithubCrawler, LinkedInCrawler
8 | from dispatcher import CrawlerDispatcher
9 |
10 | logger = Logger(service="llm-twin-course/crawler")
11 |
12 | _dispatcher = CrawlerDispatcher()
13 | _dispatcher.register("medium", CustomArticleCrawler)
14 | _dispatcher.register("linkedin", LinkedInCrawler)
15 | _dispatcher.register("github", GithubCrawler)
16 |
17 |
18 | def handler(event, context: LambdaContext | None = None) -> dict[str, Any]:
19 | first_name, last_name = lib.split_user_full_name(event.get("user"))
20 |
21 | user_id = UserDocument.get_or_create(first_name=first_name, last_name=last_name)
22 |
23 | link = event.get("link")
24 | crawler = _dispatcher.get_crawler(link)
25 |
26 | try:
27 | crawler.extract(link=link, user=user_id)
28 |
29 | return {"statusCode": 200, "body": "Link processed successfully"}
30 | except Exception as e:
31 | return {"statusCode": 500, "body": f"An error occurred: {str(e)}"}
32 |
33 |
34 | if __name__ == "__main__":
35 | event = {
36 |         "user": "Paul Iusztin",
37 | "link": "https://www.linkedin.com/in/vesaalexandru/",
38 | }
39 | handler(event, None)
40 |
--------------------------------------------------------------------------------
/src/data_crawling/utils.py:
--------------------------------------------------------------------------------
1 | import structlog
2 |
3 |
4 | def get_logger(cls: str):
5 | return structlog.get_logger().bind(cls=cls)
6 |
--------------------------------------------------------------------------------
/src/feature_pipeline/config.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from pydantic_settings import BaseSettings, SettingsConfigDict
4 |
5 | ROOT_DIR = str(Path(__file__).parent.parent.parent)
6 |
7 |
8 | class Settings(BaseSettings):
9 | model_config = SettingsConfigDict(env_file=ROOT_DIR, env_file_encoding="utf-8")
10 |
11 | # CometML config
12 | COMET_API_KEY: str | None = None
13 | COMET_WORKSPACE: str | None = None
14 | COMET_PROJECT: str = "llm-twin"
15 |
16 | # Embeddings config
17 | EMBEDDING_MODEL_ID: str = "BAAI/bge-small-en-v1.5"
18 | EMBEDDING_MODEL_MAX_INPUT_LENGTH: int = 512
19 | EMBEDDING_SIZE: int = 384
20 | EMBEDDING_MODEL_DEVICE: str = "cpu"
21 |
22 | # OpenAI
23 | OPENAI_MODEL_ID: str = "gpt-4o-mini"
24 | OPENAI_API_KEY: str | None = None
25 |
26 | # MQ config
27 | RABBITMQ_DEFAULT_USERNAME: str = "guest"
28 | RABBITMQ_DEFAULT_PASSWORD: str = "guest"
29 | RABBITMQ_HOST: str = "mq" # or localhost if running outside Docker
30 | RABBITMQ_PORT: int = 5672
31 | RABBITMQ_QUEUE_NAME: str = "default"
32 |
33 | # QdrantDB config
34 | QDRANT_DATABASE_HOST: str = "qdrant" # or localhost if running outside Docker
35 | QDRANT_DATABASE_PORT: int = 6333
36 | USE_QDRANT_CLOUD: bool = (
37 | False # if True, fill in QDRANT_CLOUD_URL and QDRANT_APIKEY
38 | )
39 | QDRANT_CLOUD_URL: str | None = None
40 | QDRANT_APIKEY: str | None = None
41 |
42 |
43 | settings = Settings()
44 |
--------------------------------------------------------------------------------
/src/feature_pipeline/data_flow/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/feature_pipeline/data_flow/__init__.py
--------------------------------------------------------------------------------
/src/feature_pipeline/data_flow/stream_input.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 | from datetime import datetime
4 | from typing import Generic, Iterable, List, Optional, TypeVar
5 |
6 | from bytewax.inputs import FixedPartitionedSource, StatefulSourcePartition
7 | from config import settings
8 | from core import get_logger
9 | from core.mq import RabbitMQConnection
10 |
11 | logger = get_logger(__name__)
12 |
13 | DataT = TypeVar("DataT")
14 | MessageT = TypeVar("MessageT")
15 |
16 |
17 | class RabbitMQPartition(StatefulSourcePartition, Generic[DataT, MessageT]):
18 | """
19 |     Class responsible for creating a connection between Bytewax and RabbitMQ that facilitates the transfer of data from the message queue to the Bytewax streaming pipeline.
20 |     Inherits StatefulSourcePartition for snapshot functionality, which enables saving the state of the queue.
21 | """
22 |
23 | def __init__(self, queue_name: str, resume_state: MessageT | None = None) -> None:
24 | self._in_flight_msg_ids = resume_state or set()
25 | self.queue_name = queue_name
26 | self.connection = RabbitMQConnection()
27 | self.connection.connect()
28 | self.channel = self.connection.get_channel()
29 |
30 | def next_batch(self, sched: Optional[datetime]) -> Iterable[DataT]:
31 | try:
32 | method_frame, header_frame, body = self.channel.basic_get(
33 | queue=self.queue_name, auto_ack=True
34 | )
35 | except Exception:
36 | logger.error(
37 |                 "Error while fetching message from queue.", queue_name=self.queue_name
38 | )
39 | time.sleep(10) # Sleep for 10 seconds before retrying to access the queue.
40 |
41 | self.connection.connect()
42 | self.channel = self.connection.get_channel()
43 |
44 | return []
45 |
46 | if method_frame:
47 | message_id = method_frame.delivery_tag
48 | self._in_flight_msg_ids.add(message_id)
49 |
50 | return [json.loads(body)]
51 | else:
52 | return []
53 |
54 | def snapshot(self) -> MessageT:
55 | return self._in_flight_msg_ids
56 |
57 | def garbage_collect(self, state):
58 | closed_in_flight_msg_ids = state
59 | for msg_id in closed_in_flight_msg_ids:
60 | self.channel.basic_ack(delivery_tag=msg_id)
61 | self._in_flight_msg_ids.remove(msg_id)
62 |
63 | def close(self):
64 | self.channel.close()
65 |
66 |
67 | class RabbitMQSource(FixedPartitionedSource):
68 | def list_parts(self) -> List[str]:
69 | return ["single partition"]
70 |
71 | def build_part(
72 | self, now: datetime, for_part: str, resume_state: MessageT | None = None
73 | ) -> StatefulSourcePartition[DataT, MessageT]:
74 | return RabbitMQPartition(queue_name=settings.RABBITMQ_QUEUE_NAME)
75 |
--------------------------------------------------------------------------------
/src/feature_pipeline/data_logic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/feature_pipeline/data_logic/__init__.py
--------------------------------------------------------------------------------
/src/feature_pipeline/data_logic/chunking_data_handlers.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | from abc import ABC, abstractmethod
3 |
4 | from models.base import DataModel
5 | from models.chunk import ArticleChunkModel, PostChunkModel, RepositoryChunkModel
6 | from models.clean import ArticleCleanedModel, PostCleanedModel, RepositoryCleanedModel
7 | from utils.chunking import chunk_text
8 |
9 |
10 | class ChunkingDataHandler(ABC):
11 | """
12 | Abstract class for all Chunking data handlers.
13 | All data transformations logic for the chunking step is done here
14 | """
15 |
16 | @abstractmethod
17 | def chunk(self, data_model: DataModel) -> list[DataModel]:
18 | pass
19 |
20 |
21 | class PostChunkingHandler(ChunkingDataHandler):
22 | def chunk(self, data_model: PostCleanedModel) -> list[PostChunkModel]:
23 | data_models_list = []
24 |
25 | text_content = data_model.cleaned_content
26 | chunks = chunk_text(text_content)
27 |
28 | for chunk in chunks:
29 | model = PostChunkModel(
30 | entry_id=data_model.entry_id,
31 | platform=data_model.platform,
32 | chunk_id=hashlib.md5(chunk.encode()).hexdigest(),
33 | chunk_content=chunk,
34 | author_id=data_model.author_id,
35 | image=data_model.image if data_model.image else None,
36 | type=data_model.type,
37 | )
38 | data_models_list.append(model)
39 |
40 | return data_models_list
41 |
42 |
43 | class ArticleChunkingHandler(ChunkingDataHandler):
44 | def chunk(self, data_model: ArticleCleanedModel) -> list[ArticleChunkModel]:
45 | data_models_list = []
46 |
47 | text_content = data_model.cleaned_content
48 | chunks = chunk_text(text_content)
49 |
50 | for chunk in chunks:
51 | model = ArticleChunkModel(
52 | entry_id=data_model.entry_id,
53 | platform=data_model.platform,
54 | link=data_model.link,
55 | chunk_id=hashlib.md5(chunk.encode()).hexdigest(),
56 | chunk_content=chunk,
57 | author_id=data_model.author_id,
58 | type=data_model.type,
59 | )
60 | data_models_list.append(model)
61 |
62 | return data_models_list
63 |
64 |
65 | class RepositoryChunkingHandler(ChunkingDataHandler):
66 | def chunk(self, data_model: RepositoryCleanedModel) -> list[RepositoryChunkModel]:
67 | data_models_list = []
68 |
69 | text_content = data_model.cleaned_content
70 | chunks = chunk_text(text_content)
71 |
72 | for chunk in chunks:
73 | model = RepositoryChunkModel(
74 | entry_id=data_model.entry_id,
75 | name=data_model.name,
76 | link=data_model.link,
77 | chunk_id=hashlib.md5(chunk.encode()).hexdigest(),
78 | chunk_content=chunk,
79 | owner_id=data_model.owner_id,
80 | type=data_model.type,
81 | )
82 | data_models_list.append(model)
83 |
84 | return data_models_list
85 |
--------------------------------------------------------------------------------
/src/feature_pipeline/data_logic/cleaning_data_handlers.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from models.base import DataModel
4 | from models.clean import ArticleCleanedModel, PostCleanedModel, RepositoryCleanedModel
5 | from models.raw import ArticleRawModel, PostsRawModel, RepositoryRawModel
6 | from utils.cleaning import clean_text
7 |
8 |
9 | class CleaningDataHandler(ABC):
10 | """
11 | Abstract class for all cleaning data handlers.
12 | All data transformations logic for the cleaning step is done here
13 | """
14 |
15 | @abstractmethod
16 | def clean(self, data_model: DataModel) -> DataModel:
17 | pass
18 |
19 |
20 | class PostCleaningHandler(CleaningDataHandler):
21 | def clean(self, data_model: PostsRawModel) -> PostCleanedModel:
22 | joined_text = (
23 | "".join(data_model.content.values()) if data_model and data_model.content else None
24 | )
25 |
26 | return PostCleanedModel(
27 | entry_id=data_model.entry_id,
28 | platform=data_model.platform,
29 | cleaned_content=clean_text(joined_text),
30 | author_id=data_model.author_id,
31 | image=data_model.image if data_model.image else None,
32 | type=data_model.type,
33 | )
34 |
35 |
36 | class ArticleCleaningHandler(CleaningDataHandler):
37 | def clean(self, data_model: ArticleRawModel) -> ArticleCleanedModel:
38 | joined_text = (
39 | "".join(data_model.content.values()) if data_model and data_model.content else None
40 | )
41 |
42 | return ArticleCleanedModel(
43 | entry_id=data_model.entry_id,
44 | platform=data_model.platform,
45 | link=data_model.link,
46 | cleaned_content=clean_text(joined_text),
47 | author_id=data_model.author_id,
48 | type=data_model.type,
49 | )
50 |
51 |
52 | class RepositoryCleaningHandler(CleaningDataHandler):
53 | def clean(self, data_model: RepositoryRawModel) -> RepositoryCleanedModel:
54 | joined_text = (
55 | "".join(data_model.content.values()) if data_model and data_model.content else None
56 | )
57 |
58 | return RepositoryCleanedModel(
59 | entry_id=data_model.entry_id,
60 | name=data_model.name,
61 | link=data_model.link,
62 | cleaned_content=clean_text(joined_text),
63 | owner_id=data_model.owner_id,
64 | type=data_model.type,
65 | )
66 |
--------------------------------------------------------------------------------
/src/feature_pipeline/data_logic/embedding_data_handlers.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from models.base import DataModel
4 | from models.chunk import ArticleChunkModel, PostChunkModel, RepositoryChunkModel
5 | from models.embedded_chunk import (
6 | ArticleEmbeddedChunkModel,
7 | PostEmbeddedChunkModel,
8 | RepositoryEmbeddedChunkModel,
9 | )
10 | from utils.embeddings import embedd_text
11 |
12 |
13 | class EmbeddingDataHandler(ABC):
14 | """
15 | Abstract class for all embedding data handlers.
16 | All data transformations logic for the embedding step is done here
17 | """
18 |
19 | @abstractmethod
20 | def embedd(self, data_model: DataModel) -> DataModel:
21 | pass
22 |
23 |
24 | class PostEmbeddingHandler(EmbeddingDataHandler):
25 | def embedd(self, data_model: PostChunkModel) -> PostEmbeddedChunkModel:
26 | return PostEmbeddedChunkModel(
27 | entry_id=data_model.entry_id,
28 | platform=data_model.platform,
29 | chunk_id=data_model.chunk_id,
30 | chunk_content=data_model.chunk_content,
31 | embedded_content=embedd_text(data_model.chunk_content),
32 | author_id=data_model.author_id,
33 | type=data_model.type,
34 | )
35 |
36 |
37 | class ArticleEmbeddingHandler(EmbeddingDataHandler):
38 | def embedd(self, data_model: ArticleChunkModel) -> ArticleEmbeddedChunkModel:
39 | return ArticleEmbeddedChunkModel(
40 | entry_id=data_model.entry_id,
41 | platform=data_model.platform,
42 | link=data_model.link,
43 | chunk_content=data_model.chunk_content,
44 | chunk_id=data_model.chunk_id,
45 | embedded_content=embedd_text(data_model.chunk_content),
46 | author_id=data_model.author_id,
47 | type=data_model.type,
48 | )
49 |
50 |
51 | class RepositoryEmbeddingHandler(EmbeddingDataHandler):
52 | def embedd(self, data_model: RepositoryChunkModel) -> RepositoryEmbeddedChunkModel:
53 | return RepositoryEmbeddedChunkModel(
54 | entry_id=data_model.entry_id,
55 | name=data_model.name,
56 | link=data_model.link,
57 | chunk_id=data_model.chunk_id,
58 | chunk_content=data_model.chunk_content,
59 | embedded_content=embedd_text(data_model.chunk_content),
60 | owner_id=data_model.owner_id,
61 | type=data_model.type,
62 | )
63 |
--------------------------------------------------------------------------------
/src/feature_pipeline/generate_dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/feature_pipeline/generate_dataset/__init__.py
--------------------------------------------------------------------------------
/src/feature_pipeline/generate_dataset/chunk_documents.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
4 | def chunk_documents(documents: list[str], min_length: int = 1000, max_length: int = 2000):
5 | chunked_documents = []
6 | for document in documents:
7 | chunks = extract_substrings(document, min_length=min_length, max_length=max_length)
8 | chunked_documents.extend(chunks)
9 |
10 | return chunked_documents
11 |
12 | def extract_substrings(
13 | text: str, min_length: int = 1000, max_length: int = 2000
14 | ) -> list[str]:
15 |     sentences = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", text)
16 | 
17 |     extracts = []
18 |     current_chunk = ""
19 |     for sentence in sentences:
20 |         sentence = sentence.strip()
21 |         if not sentence:
22 |             continue
23 | 
24 |         if len(current_chunk) + len(sentence) <= max_length:
25 |             current_chunk += sentence + " "
26 |         else:
27 |             if len(current_chunk) >= min_length:
28 | extracts.append(current_chunk.strip())
29 | current_chunk = sentence + " "
30 |
31 | if len(current_chunk) >= min_length:
32 | extracts.append(current_chunk.strip())
33 |
34 | return extracts
35 |
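A quick sketch of how `extract_substrings` groups whole sentences into extracts bounded by `min_length`/`max_length`; the thresholds are lowered here only to make the behaviour visible on a tiny, made-up text:

```python
from generate_dataset.chunk_documents import extract_substrings

text = (
    "Retrieval-augmented generation combines search with an LLM. "
    "The retriever finds relevant chunks. "
    "The generator writes the answer using those chunks."
)

# The first two sentences fit together under max_length; the third starts a new extract.
chunks = extract_substrings(text, min_length=40, max_length=120)
print([len(chunk) for chunk in chunks])  # two extracts, each at least 40 characters
```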
--------------------------------------------------------------------------------
/src/feature_pipeline/generate_dataset/exceptions.py:
--------------------------------------------------------------------------------
1 | class DatasetError(Exception):
2 | pass
3 |
4 |
5 | class FileNotFoundError(DatasetError):
6 | pass
7 |
8 |
9 | class JSONDecodeError(DatasetError):
10 | pass
11 |
12 |
13 | class APICommunicationError(DatasetError):
14 | pass
15 |
--------------------------------------------------------------------------------
/src/feature_pipeline/generate_dataset/file_handler.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from generate_dataset.exceptions import JSONDecodeError
4 |
5 |
6 | class FileHandler:
7 | def read_json(self, filename: str) -> list:
8 | try:
9 | with open(filename, "r") as file:
10 | return json.load(file)
11 | except FileNotFoundError:
12 | raise FileNotFoundError(f"The file '{filename}' does not exist.")
13 | except json.JSONDecodeError:
14 | raise JSONDecodeError(
15 | f"The file '{filename}' is not properly formatted as JSON."
16 | )
17 |
18 | def write_json(self, filename: str, data: list):
19 | with open(filename, "w") as file:
20 | json.dump(data, file, indent=4)
21 |
--------------------------------------------------------------------------------
/src/feature_pipeline/generate_dataset/llm_communication.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from config import settings
4 | from core import get_logger
5 | from openai import OpenAI
6 |
7 | MAX_LENGTH = 16384
8 | SYSTEM_PROMPT = (
9 |     "You are a technical writer handling someone's account to post about AI and MLOps."
10 | )
11 |
12 | logger = get_logger(__name__)
13 |
14 |
15 | class GptCommunicator:
16 | def __init__(self, gpt_model: str = settings.OPENAI_MODEL_ID):
17 | self.api_key = settings.OPENAI_API_KEY
18 | self.gpt_model = gpt_model
19 |
20 | def send_prompt(self, prompt: str) -> list:
21 | try:
22 | client = OpenAI(api_key=self.api_key)
23 | logger.info(f"Sending batch to GPT = '{settings.OPENAI_MODEL_ID}'.")
24 |
25 | chat_completion = client.chat.completions.create(
26 | messages=[
27 | {"role": "system", "content": SYSTEM_PROMPT},
28 | {"role": "user", "content": prompt[:MAX_LENGTH]},
29 | ],
30 | model=self.gpt_model,
31 | )
32 | response = chat_completion.choices[0].message.content
33 | return json.loads(self.clean_response(response))
34 | except Exception:
35 | logger.exception(
36 |                 "Skipping batch! An error occurred while communicating with the API."
37 | )
38 |
39 | return []
40 |
41 | @staticmethod
42 | def clean_response(response: str) -> str:
43 | start_index = response.find("[")
44 | end_index = response.rfind("]")
45 | return response[start_index : end_index + 1]
46 |
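`clean_response` keeps only the outermost JSON array of the model's reply before it is parsed. A small sketch of that round trip on a hypothetical completion (no API call is made):

```python
import json

from generate_dataset.llm_communication import GptCommunicator

# Hypothetical raw completion that wraps the JSON array in extra prose.
raw_reply = (
    "Sure, here are the generated samples:\n"
    '[{"instruction": "Write a post about RAG.", "content": "..."}]\n'
    "Let me know if you need more!"
)

cleaned = GptCommunicator.clean_response(raw_reply)
samples = json.loads(cleaned)  # -> a list with one {"instruction": ..., "content": ...} dict
```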
--------------------------------------------------------------------------------
/src/feature_pipeline/main.py:
--------------------------------------------------------------------------------
1 | import bytewax.operators as op
2 | from bytewax.dataflow import Dataflow
3 | from core.db.qdrant import QdrantDatabaseConnector
4 | from data_flow.stream_input import RabbitMQSource
5 | from data_flow.stream_output import QdrantOutput
6 | from data_logic.dispatchers import (
7 | ChunkingDispatcher,
8 | CleaningDispatcher,
9 | EmbeddingDispatcher,
10 | RawDispatcher,
11 | )
12 |
13 | connection = QdrantDatabaseConnector()
14 |
15 | flow = Dataflow("Streaming ingestion pipeline")
16 | stream = op.input("input", flow, RabbitMQSource())
17 | stream = op.map("raw dispatch", stream, RawDispatcher.handle_mq_message)
18 | stream = op.map("clean dispatch", stream, CleaningDispatcher.dispatch_cleaner)
19 | op.output(
20 | "cleaned data insert to qdrant",
21 | stream,
22 | QdrantOutput(connection=connection, sink_type="clean"),
23 | )
24 | stream = op.flat_map("chunk dispatch", stream, ChunkingDispatcher.dispatch_chunker)
25 | stream = op.map(
26 | "embedded chunk dispatch", stream, EmbeddingDispatcher.dispatch_embedder
27 | )
28 | op.output(
29 | "embedded data insert to qdrant",
30 | stream,
31 | QdrantOutput(connection=connection, sink_type="vector"),
32 | )
33 |
--------------------------------------------------------------------------------
/src/feature_pipeline/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/feature_pipeline/models/__init__.py
--------------------------------------------------------------------------------
/src/feature_pipeline/models/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from pydantic import BaseModel
4 |
5 |
6 | class DataModel(BaseModel):
7 | """
8 | Abstract class for all data models
9 | """
10 |
11 | entry_id: str
12 | type: str
13 |
14 |
15 | class VectorDBDataModel(ABC, DataModel):
16 | """
17 | Abstract class for all data models that need to be saved into a vector DB (e.g. Qdrant)
18 | """
19 |
20 | entry_id: int
21 | type: str
22 |
23 | @abstractmethod
24 | def to_payload(self) -> tuple:
25 | pass
26 |
--------------------------------------------------------------------------------
/src/feature_pipeline/models/chunk.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from models.base import DataModel
4 |
5 |
6 | class PostChunkModel(DataModel):
7 | entry_id: str
8 | platform: str
9 | chunk_id: str
10 | chunk_content: str
11 | author_id: str
12 | image: Optional[str] = None
13 | type: str
14 |
15 |
16 | class ArticleChunkModel(DataModel):
17 | entry_id: str
18 | platform: str
19 | link: str
20 | chunk_id: str
21 | chunk_content: str
22 | author_id: str
23 | type: str
24 |
25 |
26 | class RepositoryChunkModel(DataModel):
27 | entry_id: str
28 | name: str
29 | link: str
30 | chunk_id: str
31 | chunk_content: str
32 | owner_id: str
33 | type: str
34 |
--------------------------------------------------------------------------------
/src/feature_pipeline/models/clean.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | from models.base import VectorDBDataModel
4 |
5 |
6 | class PostCleanedModel(VectorDBDataModel):
7 | entry_id: str
8 | platform: str
9 | cleaned_content: str
10 | author_id: str
11 | image: Optional[str] = None
12 | type: str
13 |
14 | def to_payload(self) -> Tuple[str, dict]:
15 | data = {
16 | "platform": self.platform,
17 | "author_id": self.author_id,
18 | "cleaned_content": self.cleaned_content,
19 | "image": self.image,
20 | "type": self.type,
21 | }
22 |
23 | return self.entry_id, data
24 |
25 |
26 | class ArticleCleanedModel(VectorDBDataModel):
27 | entry_id: str
28 | platform: str
29 | link: str
30 | cleaned_content: str
31 | author_id: str
32 | type: str
33 |
34 | def to_payload(self) -> Tuple[str, dict]:
35 | data = {
36 | "platform": self.platform,
37 | "link": self.link,
38 | "cleaned_content": self.cleaned_content,
39 | "author_id": self.author_id,
40 | "type": self.type,
41 | }
42 |
43 | return self.entry_id, data
44 |
45 |
46 | class RepositoryCleanedModel(VectorDBDataModel):
47 | entry_id: str
48 | name: str
49 | link: str
50 | cleaned_content: str
51 | owner_id: str
52 | type: str
53 |
54 | def to_payload(self) -> Tuple[str, dict]:
55 | data = {
56 | "name": self.name,
57 | "link": self.link,
58 | "cleaned_content": self.cleaned_content,
59 | "owner_id": self.owner_id,
60 | "type": self.type,
61 | }
62 |
63 | return self.entry_id, data
64 |
--------------------------------------------------------------------------------
/src/feature_pipeline/models/embedded_chunk.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 |
5 | from models.base import VectorDBDataModel
6 |
7 |
8 | class PostEmbeddedChunkModel(VectorDBDataModel):
9 | entry_id: str
10 | platform: str
11 | chunk_id: str
12 | chunk_content: str
13 | embedded_content: np.ndarray
14 | author_id: str
15 | type: str
16 |
17 | class Config:
18 | arbitrary_types_allowed = True
19 |
20 | def to_payload(self) -> Tuple[str, np.ndarray, dict]:
21 | data = {
22 | "id": self.entry_id,
23 | "platform": self.platform,
24 | "content": self.chunk_content,
25 | "owner_id": self.author_id,
26 | "type": self.type,
27 | }
28 |
29 | return self.chunk_id, self.embedded_content, data
30 |
31 |
32 | class ArticleEmbeddedChunkModel(VectorDBDataModel):
33 | entry_id: str
34 | platform: str
35 | link: str
36 | chunk_id: str
37 | chunk_content: str
38 | embedded_content: np.ndarray
39 | author_id: str
40 | type: str
41 |
42 | class Config:
43 | arbitrary_types_allowed = True
44 |
45 | def to_payload(self) -> Tuple[str, np.ndarray, dict]:
46 | data = {
47 | "id": self.entry_id,
48 | "platform": self.platform,
49 | "content": self.chunk_content,
50 | "link": self.link,
51 | "author_id": self.author_id,
52 | "type": self.type,
53 | }
54 |
55 | return self.chunk_id, self.embedded_content, data
56 |
57 |
58 | class RepositoryEmbeddedChunkModel(VectorDBDataModel):
59 | entry_id: str
60 | name: str
61 | link: str
62 | chunk_id: str
63 | chunk_content: str
64 | embedded_content: np.ndarray
65 | owner_id: str
66 | type: str
67 |
68 | class Config:
69 | arbitrary_types_allowed = True
70 |
71 | def to_payload(self) -> Tuple[str, np.ndarray, dict]:
72 | data = {
73 | "id": self.entry_id,
74 | "name": self.name,
75 | "content": self.chunk_content,
76 | "link": self.link,
77 | "owner_id": self.owner_id,
78 | "type": self.type,
79 | }
80 |
81 | return self.chunk_id, self.embedded_content, data
82 |
--------------------------------------------------------------------------------
/src/feature_pipeline/models/raw.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from models.base import DataModel
4 |
5 |
6 | class RepositoryRawModel(DataModel):
7 | name: str
8 | link: str
9 | content: dict
10 | owner_id: str
11 |
12 |
13 | class ArticleRawModel(DataModel):
14 | platform: str
15 | link: str
16 | content: dict
17 | author_id: str
18 |
19 |
20 | class PostsRawModel(DataModel):
21 | platform: str
22 | content: dict
23 | author_id: str | None = None
24 | image: Optional[str] = None
25 |
--------------------------------------------------------------------------------
/src/feature_pipeline/retriever.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | # To mimic using multiple Python modules, such as 'core' and 'feature_pipeline',
5 | # we will add the './src' directory to the PYTHONPATH. This is not intended for
6 | # production use cases but for development and educational purposes.
7 | ROOT_DIR = str(Path(__file__).parent.parent)
8 | sys.path.append(ROOT_DIR)
9 |
10 |
11 | from core import get_logger
12 | from core.config import settings
13 | from core.rag.retriever import VectorRetriever
14 |
15 | logger = get_logger(__name__)
16 |
17 | settings.patch_localhost()
18 | logger.warning(
19 | "Patched settings to work with 'localhost' URLs. \
20 | Remove the 'settings.patch_localhost()' call from above when deploying or running inside Docker."
21 | )
22 |
23 | if __name__ == "__main__":
24 | query = """
25 | Hello I am Paul Iusztin.
26 |
27 | Could you draft an article paragraph discussing RAG?
28 | I'm particularly interested in how to design a RAG system.
29 | """
30 |
31 | retriever = VectorRetriever(query=query)
32 | hits = retriever.retrieve_top_k(k=6, to_expand_to_n_queries=5)
33 | reranked_hits = retriever.rerank(hits=hits, keep_top_k=5)
34 |
35 | logger.info("====== RETRIEVED DOCUMENTS ======")
36 | for rank, hit in enumerate(reranked_hits):
37 | logger.info(f"Rank = {rank} : {hit}")
38 |
--------------------------------------------------------------------------------
/src/feature_pipeline/scripts/bytewax_entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | if [ "$DEBUG" = true ]
4 | then
5 | python -m bytewax.run "tools.run_real_time:build_flow(debug=True)"
6 | else
7 | if [ "$BYTEWAX_PYTHON_FILE_PATH" = "" ]
8 | then
9 | echo 'BYTEWAX_PYTHON_FILE_PATH is not set. Exiting...'
10 | exit 1
11 | fi
12 |     python -m bytewax.run "$BYTEWAX_PYTHON_FILE_PATH"
13 | fi
14 |
15 |
16 | echo 'Process ended.'
17 |
18 | if [ "$BYTEWAX_KEEP_CONTAINER_ALIVE" = true ]
19 | then
20 | echo 'Keeping container alive...';
21 | while :; do sleep 1; done
22 | fi
--------------------------------------------------------------------------------
/src/feature_pipeline/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/feature_pipeline/utils/chunking.py:
--------------------------------------------------------------------------------
1 | from langchain.text_splitter import (
2 | RecursiveCharacterTextSplitter,
3 | SentenceTransformersTokenTextSplitter,
4 | )
5 |
6 | from config import settings
7 |
8 |
9 | def chunk_text(text: str) -> list[str]:
10 | character_splitter = RecursiveCharacterTextSplitter(
11 | separators=["\n\n"], chunk_size=500, chunk_overlap=0
12 | )
13 | text_split = character_splitter.split_text(text)
14 |
15 | token_splitter = SentenceTransformersTokenTextSplitter(
16 | chunk_overlap=50,
17 | tokens_per_chunk=settings.EMBEDDING_MODEL_MAX_INPUT_LENGTH,
18 | model_name=settings.EMBEDDING_MODEL_ID,
19 | )
20 | chunks = []
21 |
22 | for section in text_split:
23 | chunks.extend(token_splitter.split_text(section))
24 |
25 | return chunks
26 |
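`chunk_text` splits on paragraphs first and then re-splits by tokens so every chunk fits the embedding model's input window. A small sanity-check sketch (it downloads the `BAAI/bge-small-en-v1.5` tokenizer on first use; the sample text is made up):

```python
from utils.chunking import chunk_text

# A toy document: twenty short paragraphs separated by blank lines.
sample = "\n\n".join(
    ["Streaming ingestion cleans, chunks and embeds every document."] * 20
)

chunks = chunk_text(sample)
print(f"{len(chunks)} chunks; longest has {max(len(c) for c in chunks)} characters.")
```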
--------------------------------------------------------------------------------
/src/feature_pipeline/utils/embeddings.py:
--------------------------------------------------------------------------------
1 | from InstructorEmbedding import INSTRUCTOR
2 | from sentence_transformers.SentenceTransformer import SentenceTransformer
3 |
4 | from config import settings
5 |
6 |
7 | def embedd_text(text: str):
8 | model = SentenceTransformer(settings.EMBEDDING_MODEL_ID)
9 | return model.encode(text)
10 |
11 |
12 | def embedd_repositories(text: str):
13 | model = INSTRUCTOR("hkunlp/instructor-xl")
14 | sentence = text
15 | instruction = "Represent the structure of the repository"
16 | return model.encode([instruction, sentence])
17 |
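A quick check that `embedd_text` returns a single dense vector whose size matches `EMBEDDING_SIZE` from the config (the call downloads the embedding model on first use; the sample sentence is made up):

```python
from utils.embeddings import embedd_text

vector = embedd_text("An LLM twin mirrors your writing style and voice.")
print(vector.shape)  # (384,) for the default BAAI/bge-small-en-v1.5 model
```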
--------------------------------------------------------------------------------
/src/inference_pipeline/aws/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decodingml/llm-twin-course/457e550cb0d273d0c803ce1e2cb3ec0050c33e76/src/inference_pipeline/aws/__init__.py
--------------------------------------------------------------------------------
/src/inference_pipeline/aws/delete_sagemaker_endpoint.py:
--------------------------------------------------------------------------------
1 | from core import get_logger
2 |
3 | logger = get_logger(__file__)
4 |
5 | try:
6 | import boto3
7 | from botocore.exceptions import ClientError
8 | except ModuleNotFoundError:
9 | logger.warning(
10 | "Couldn't load AWS or SageMaker imports. Run 'poetry install --with aws' to support AWS."
11 | )
12 |
13 |
14 | from config import settings
15 |
16 |
17 | def delete_endpoint_and_config(endpoint_name) -> None:
18 | """
19 | Deletes an AWS SageMaker endpoint and its associated configuration.
20 | Args:
21 | endpoint_name (str): The name of the SageMaker endpoint to delete.
22 | Returns:
23 | None
24 | """
25 |
26 | try:
27 | sagemaker_client = boto3.client(
28 | "sagemaker",
29 | region_name=settings.AWS_REGION,
30 | aws_access_key_id=settings.AWS_ACCESS_KEY,
31 | aws_secret_access_key=settings.AWS_SECRET_KEY,
32 | )
33 | except Exception:
34 | logger.exception("Error creating SageMaker client")
35 |
36 | return
37 |
38 | # Get the endpoint configuration name
39 | try:
40 | response = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
41 | config_name = response["EndpointConfigName"]
42 | except ClientError:
43 | logger.error("Error getting endpoint configuration and modelname.")
44 |
45 | return
46 |
47 | # Delete the endpoint
48 | try:
49 | sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
50 | logger.info(f"Endpoint '{endpoint_name}' deletion initiated.")
51 | except ClientError:
52 | logger.error("Error deleting endpoint")
53 |
54 | try:
55 | response = sagemaker_client.describe_endpoint_config(
56 | EndpointConfigName=endpoint_name
57 | )
58 | model_name = response["ProductionVariants"][0]["ModelName"]
59 | except ClientError:
60 | logger.error("Error getting model name.")
61 |
62 | # Delete the endpoint configuration
63 | try:
64 | sagemaker_client.delete_endpoint_config(EndpointConfigName=config_name)
65 | logger.info(f"Endpoint configuration '{config_name}' deleted.")
66 | except ClientError:
67 | logger.error("Error deleting endpoint configuration.")
68 |
69 | # Delete models
70 | try:
71 | sagemaker_client.delete_model(ModelName=model_name)
72 | logger.info(f"Model '{model_name}' deleted.")
73 | except ClientError:
74 | logger.error("Error deleting model.")
75 |
76 |
77 | if __name__ == "__main__":
78 | endpoint_name = settings.DEPLOYMENT_ENDPOINT_NAME
79 | logger.info(f"Attempting to delete endpoint: {endpoint_name}")
80 | delete_endpoint_and_config(endpoint_name=endpoint_name)
81 |
--------------------------------------------------------------------------------
/src/inference_pipeline/aws/deploy_sagemaker_endpoint.py:
--------------------------------------------------------------------------------
1 | from config import settings
2 | from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
3 |
4 |
5 | def main() -> None:
6 | assert settings.HUGGINGFACE_ACCESS_TOKEN, "HUGGINGFACE_ACCESS_TOKEN is required."
7 |
8 | env_vars = {
9 | "HF_MODEL_ID": settings.MODEL_ID,
10 | "SM_NUM_GPUS": "1", # Number of GPU used per replica.
11 | "HUGGING_FACE_HUB_TOKEN": settings.HUGGINGFACE_ACCESS_TOKEN,
12 | "MAX_INPUT_TOKENS": str(
13 | settings.MAX_INPUT_TOKENS
14 | ), # Max length of input tokens.
15 | "MAX_TOTAL_TOKENS": str(
16 | settings.MAX_TOTAL_TOKENS
17 | ), # Max length of the generation (including input text).
18 | "MAX_BATCH_TOTAL_TOKENS": str(
19 | settings.MAX_BATCH_TOTAL_TOKENS
20 | ), # Limits the number of tokens that can be processed in parallel during the generation.
21 | "MESSAGES_API_ENABLED": "true", # Enable/disable the messages API, following OpenAI's standard.
22 | "HF_MODEL_QUANTIZE": "bitsandbytes",
23 | }
24 |
25 | image_uri = get_huggingface_llm_image_uri("huggingface", version="2.2.0")
26 |
27 | model = HuggingFaceModel(
28 | env=env_vars, role=settings.AWS_ARN_ROLE, image_uri=image_uri
29 | )
30 |
31 | model.deploy(
32 | initial_instance_count=1,
33 | instance_type="ml.g5.2xlarge",
34 | container_startup_health_check_timeout=900,
35 | endpoint_name=settings.DEPLOYMENT_ENDPOINT_NAME,
36 | )
37 |
38 |
39 | if __name__ == "__main__":
40 | main()
41 |
--------------------------------------------------------------------------------
/src/inference_pipeline/config.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from pydantic_settings import BaseSettings, SettingsConfigDict
4 |
5 | ROOT_DIR = str(Path(__file__).parent.parent.parent)
6 |
7 |
8 | class Settings(BaseSettings):
9 | model_config = SettingsConfigDict(env_file=ROOT_DIR, env_file_encoding="utf-8")
10 |
11 | # Embeddings config
12 | EMBEDDING_MODEL_ID: str = "BAAI/bge-small-en-v1.5"
13 | EMBEDDING_MODEL_MAX_INPUT_LENGTH: int = 512
14 | EMBEDDING_SIZE: int = 384
15 | EMBEDDING_MODEL_DEVICE: str = "cpu"
16 |
17 | # OpenAI config
18 | OPENAI_MODEL_ID: str = "gpt-4o-mini"
19 | OPENAI_API_KEY: str | None = None
20 |
21 | # QdrantDB config
22 | QDRANT_DATABASE_HOST: str = "localhost" # Or 'qdrant' if running inside Docker
23 | QDRANT_DATABASE_PORT: int = 6333
24 |
25 | USE_QDRANT_CLOUD: bool = (
26 | False # if True, fill in QDRANT_CLOUD_URL and QDRANT_APIKEY
27 | )
28 |     QDRANT_CLOUD_URL: str | None = None
29 | QDRANT_APIKEY: str | None = None
30 |
31 | # RAG config
32 | TOP_K: int = 5
33 | KEEP_TOP_K: int = 5
34 | EXPAND_N_QUERY: int = 5
35 |
36 | # CometML config
37 | COMET_API_KEY: str
38 | COMET_WORKSPACE: str
39 | COMET_PROJECT: str = "llm-twin"
40 |
41 | # LLM Model config
42 | HUGGINGFACE_ACCESS_TOKEN: str | None = None
43 | MODEL_ID: str = "pauliusztin/LLMTwin-Llama-3.1-8B" # Change this with your Hugging Face model ID to test out your fine-tuned LLM
44 | DEPLOYMENT_ENDPOINT_NAME: str = "twin"
45 |
46 | MAX_INPUT_TOKENS: int = 1536 # Max length of input text.
47 | MAX_TOTAL_TOKENS: int = 2048 # Max length of the generation (including input text).
48 | MAX_BATCH_TOTAL_TOKENS: int = 2048 # Limits the number of tokens that can be processed in parallel during the generation.
49 |
50 | # AWS Authentication
51 | AWS_REGION: str = "eu-central-1"
52 | AWS_ACCESS_KEY: str | None = None
53 | AWS_SECRET_KEY: str | None = None
54 | AWS_ARN_ROLE: str | None = None
55 |
56 |
57 | settings = Settings()
58 |
--------------------------------------------------------------------------------
/src/inference_pipeline/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | # To mimic using multiple Python modules, such as 'core' and 'feature_pipeline',
5 | # we will add the './src' directory to the PYTHONPATH. This is not intended for
6 | # production use cases but for development and educational purposes.
7 | ROOT_DIR = str(Path(__file__).parent.parent.parent)
8 | sys.path.append(ROOT_DIR)
9 |
10 | from core import logger_utils
11 |
12 | logger = logger_utils.get_logger(__name__)
13 | logger.info(
14 | f"Added the following directory to PYTHONPATH to simulate multiple modules: {ROOT_DIR}"
15 | )
16 |
--------------------------------------------------------------------------------
/src/inference_pipeline/evaluation/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from config import settings
4 | from core.logger_utils import get_logger
5 | from core.opik_utils import create_dataset_from_artifacts
6 | from llm_twin import LLMTwin
7 | from opik.evaluation import evaluate
8 | from opik.evaluation.metrics import Hallucination, LevenshteinRatio, Moderation
9 |
10 | from .style import Style
11 |
12 | logger = get_logger(__name__)
13 |
14 |
15 | def evaluation_task(x: dict) -> dict:
16 | inference_pipeline = LLMTwin(mock=False)
17 | result = inference_pipeline.generate(
18 | query=x["instruction"],
19 | enable_rag=False,
20 | )
21 | answer = result["answer"]
22 |
23 | return {
24 | "input": x["instruction"],
25 | "output": answer,
26 | "expected_output": x["content"],
27 | "reference": x["content"],
28 | }
29 |
30 |
31 | def main() -> None:
32 |     parser = argparse.ArgumentParser(description="Evaluation script.")
33 | parser.add_argument(
34 | "--dataset_name",
35 | type=str,
36 | default="LLMTwinMonitoringDataset",
37 | help="Name of the dataset to evaluate",
38 | )
39 |
40 | args = parser.parse_args()
41 |
42 | dataset_name = args.dataset_name
43 |
44 | logger.info(f"Evaluating Opik dataset: '{dataset_name}'")
45 |
46 | dataset = create_dataset_from_artifacts(
47 | dataset_name="LLMTwinArtifactTestDataset",
48 | artifact_names=[
49 | "articles-instruct-dataset",
50 | "repositories-instruct-dataset",
51 | ],
52 | )
53 | if dataset is None:
54 | logger.error("Dataset can't be created. Exiting.")
55 | exit(1)
56 |
57 | experiment_config = {
58 | "model_id": settings.MODEL_ID,
59 | }
60 | scoring_metrics = [
61 | LevenshteinRatio(),
62 | Hallucination(),
63 | Moderation(),
64 | Style(),
65 | ]
66 | evaluate(
67 | dataset=dataset,
68 | task=evaluation_task,
69 | scoring_metrics=scoring_metrics,
70 | experiment_config=experiment_config,
71 | )
72 |
73 |
74 | if __name__ == "__main__":
75 | main()
76 |
--------------------------------------------------------------------------------
/src/inference_pipeline/evaluation/evaluate_monitoring.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import opik
4 | from config import settings
5 | from core.logger_utils import get_logger
6 | from opik.evaluation import evaluate
7 | from opik.evaluation.metrics import AnswerRelevance, Hallucination, Moderation
8 |
9 | from .style import Style
10 |
11 | logger = get_logger(__name__)
12 |
13 |
14 | def evaluation_task(x: dict) -> dict:
15 | return {
16 | "input": x["input"]["query"],
17 | "context": x["expected_output"]["context"],
18 | "output": x["expected_output"]["answer"],
19 | }
20 |
21 |
22 | def main() -> None:
23 | parser = argparse.ArgumentParser(description="Evaluate monitoring script.")
24 | parser.add_argument(
25 | "--dataset_name",
26 | type=str,
27 | default="LLMTwinMonitoringDataset",
28 | help="Name of the dataset to evaluate",
29 | )
30 |
31 | args = parser.parse_args()
32 |
33 | dataset_name = args.dataset_name
34 |
35 | logger.info(f"Evaluating Opik dataset: '{dataset_name}'")
36 |
37 | client = opik.Opik()
38 | try:
39 | dataset = client.get_dataset(dataset_name)
40 | except Exception:
41 | logger.error(f"Monitoring dataset '{dataset_name}' not found in Opik. Exiting.")
42 | exit(1)
43 |
44 | experiment_config = {
45 | "model_id": settings.MODEL_ID,
46 | }
47 |
48 | scoring_metrics = [Hallucination(), Moderation(), AnswerRelevance(), Style()]
49 | evaluate(
50 | dataset=dataset,
51 | task=evaluation_task,
52 | scoring_metrics=scoring_metrics,
53 | experiment_config=experiment_config,
54 | )
55 |
56 |
57 | if __name__ == "__main__":
58 | main()
59 |
--------------------------------------------------------------------------------
/src/inference_pipeline/evaluation/evaluate_rag.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from core.config import settings
4 | from core.logger_utils import get_logger
5 | from core.opik_utils import create_dataset_from_artifacts
6 | from llm_twin import LLMTwin
7 | from opik.evaluation import evaluate
8 | from opik.evaluation.metrics import (
9 | ContextPrecision,
10 | ContextRecall,
11 | Hallucination,
12 | )
13 |
14 | settings.patch_localhost()
15 |
16 | logger = get_logger(__name__)
17 | logger.warning(
18 | "Patched settings to work with 'localhost' URLs. \
19 | Remove the 'settings.patch_localhost()' call from above when deploying or running inside Docker."
20 | )
21 |
22 |
23 | def evaluation_task(x: dict) -> dict:
24 | inference_pipeline = LLMTwin(mock=False)
25 | result = inference_pipeline.generate(
26 | query=x["instruction"],
27 | enable_rag=True,
28 | )
29 | answer = result["answer"]
30 | context = result["context"]
31 |
32 | return {
33 | "input": x["instruction"],
34 | "output": answer,
35 | "context": context,
36 | "expected_output": x["content"],
37 | "reference": x["content"],
38 | }
39 |
40 |
41 | def main() -> None:
42 |     parser = argparse.ArgumentParser(description="RAG evaluation script.")
43 | parser.add_argument(
44 | "--dataset_name",
45 | type=str,
46 | default="LLMTwinMonitoringDataset",
47 | help="Name of the dataset to evaluate",
48 | )
49 |
50 | args = parser.parse_args()
51 |
52 | dataset_name = args.dataset_name
53 |
54 | logger.info(f"Evaluating Opik dataset: '{dataset_name}'")
55 |
56 | dataset = create_dataset_from_artifacts(
57 | dataset_name="LLMTwinArtifactTestDataset",
58 | artifact_names=[
59 | "articles-instruct-dataset",
60 | "posts-instruct-dataset",
61 | "repositories-instruct-dataset",
62 | ],
63 | )
64 | if dataset is None:
65 | logger.error("Dataset can't be created. Exiting.")
66 | exit(1)
67 |
68 | experiment_config = {
69 | "model_id": settings.MODEL_ID,
70 | "embedding_model_id": settings.EMBEDDING_MODEL_ID,
71 | }
72 | scoring_metrics = [
73 | Hallucination(),
74 | ContextRecall(),
75 | ContextPrecision(),
76 | ]
77 | evaluate(
78 | dataset=dataset,
79 | task=evaluation_task,
80 | scoring_metrics=scoring_metrics,
81 | experiment_config=experiment_config,
82 | )
83 |
84 |
85 | if __name__ == "__main__":
86 | main()
87 |
--------------------------------------------------------------------------------
/src/inference_pipeline/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | # To mimic using multiple Python modules, such as 'core' and 'feature_pipeline',
5 | # we will add the './src' directory to the PYTHONPATH. This is not intended for
6 | # production use cases but for development and educational purposes.
7 | ROOT_DIR = str(Path(__file__).parent.parent)
8 | sys.path.append(ROOT_DIR)
9 |
10 | from core import logger_utils
11 | from core.config import settings
12 | from llm_twin import LLMTwin
13 |
14 | settings.patch_localhost()
15 |
16 | logger = logger_utils.get_logger(__name__)
17 | logger.info(
18 | f"Added the following directory to PYTHONPATH to simulate multiple modules: {ROOT_DIR}"
19 | )
20 | logger.warning(
21 | "Patched settings to work with 'localhost' URLs. \
22 | Remove the 'settings.patch_localhost()' call from above when deploying or running inside Docker."
23 | )
24 |
25 |
26 | if __name__ == "__main__":
27 | inference_endpoint = LLMTwin(mock=False)
28 |
29 | query = """
30 | Hello I am Paul Iusztin.
31 |
32 | Could you draft an article paragraph discussing RAG?
33 | I'm particularly interested in how to design a RAG system.
34 | """
35 |
36 | response = inference_endpoint.generate(
37 | query=query, enable_rag=True, sample_for_evaluation=True
38 | )
39 |
40 | logger.info("=" * 50)
41 | logger.info(f"Query: {query}")
42 | logger.info("=" * 50)
43 | logger.info(f"Answer: {response['answer']}")
44 | logger.info("=" * 50)
45 |
--------------------------------------------------------------------------------
/src/inference_pipeline/prompt_templates.py:
--------------------------------------------------------------------------------
1 | from core.rag.prompt_templates import BasePromptTemplate
2 | from langchain.prompts import PromptTemplate
3 |
4 |
5 | class InferenceTemplate(BasePromptTemplate):
6 | simple_system_prompt: str = """
7 | You are an AI language model assistant. Your task is to generate a cohesive and concise response based on the user's instruction by using a similar writing style and voice.
8 | """
9 | simple_prompt_template: str = """
10 | ### Instruction:
11 | {question}
12 | """
13 |
14 | rag_system_prompt: str = """ You are a specialist in technical content writing. Your task is to create technical content based on the user's instruction given a specific context
15 | with additional information consisting of the user's previous writings and his knowledge.
16 |
17 | Here is a list of steps that you need to follow in order to solve this task:
18 |
19 | Step 1: You need to analyze the user's instruction.
20 | Step 2: You need to analyze the provided context and how the information in it relates to the user instruction.
21 |     Step 3: Generate the content keeping in mind that it needs to be as cohesive and concise as possible based on the query. You will use the user's writing style and voice inferred from the user instruction and context.
22 |     First try to answer based on the context. If the context is irrelevant, answer with "I cannot answer your question, as I don't have enough context."
23 | """
24 | rag_prompt_template: str = """
25 | ### Instruction:
26 | {question}
27 |
28 | ### Context:
29 | {context}
30 | """
31 |
32 | def create_template(self, enable_rag: bool = True) -> tuple[str, PromptTemplate]:
33 | if enable_rag is True:
34 | return self.rag_system_prompt, PromptTemplate(
35 | template=self.rag_prompt_template,
36 | input_variables=["question", "context"],
37 | )
38 |
39 | return self.simple_system_prompt, PromptTemplate(
40 | template=self.simple_prompt_template, input_variables=["question"]
41 | )
42 |
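A short sketch of how `InferenceTemplate` is meant to be consumed: pick the RAG variant and render the final user prompt from a question plus retrieved context. The question and the placeholder context are illustrative, and the flat `from prompt_templates import ...` import assumes the script runs from inside `src/inference_pipeline`:

```python
from prompt_templates import InferenceTemplate

template = InferenceTemplate()
system_prompt, prompt_template = template.create_template(enable_rag=True)

# Render the user prompt from the question and the (placeholder) retrieved context.
prompt = prompt_template.format(
    question="Draft a paragraph about designing a RAG system.",
    context="<top-k reranked chunks retrieved from Qdrant>",
)
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": prompt},
]
```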
--------------------------------------------------------------------------------
/src/inference_pipeline/ui.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | # To mimic using multiple Python modules, such as 'core' and 'feature_pipeline',
5 | # we will add the './src' directory to the PYTHONPATH. This is not intended for
6 | # production use cases but for development and educational purposes.
7 | ROOT_DIR = str(Path(__file__).parent.parent)
8 | sys.path.append(ROOT_DIR)
9 |
10 | from core.config import settings
11 | from llm_twin import LLMTwin
12 |
13 | settings.patch_localhost()
14 |
15 |
16 | import gradio as gr
17 | from inference_pipeline.llm_twin import LLMTwin
18 |
19 | llm_twin = LLMTwin(mock=False)
20 |
21 |
22 | def predict(message: str, history: list[list[str]], author: str) -> str:
23 | """
24 | Generates a response using the LLM Twin, simulating a conversation with your digital twin.
25 |
26 | Args:
27 | message (str): The user's input message or question.
28 | history (List[List[str]]): Previous conversation history between user and twin.
29 |         author (str): The author's name, used to personalize the generated response to their style and voice.
30 |
31 | Returns:
32 | str: The LLM Twin's generated response.
33 | """
34 |
35 | query = f"I am {author}. Write about: {message}"
36 | response = llm_twin.generate(
37 | query=query, enable_rag=True, sample_for_evaluation=False
38 | )
39 |
40 | return response["answer"]
41 |
42 |
43 | demo = gr.ChatInterface(
44 | predict,
45 | textbox=gr.Textbox(
46 | placeholder="Chat with your LLM Twin",
47 | label="Message",
48 | container=False,
49 | scale=7,
50 | ),
51 | additional_inputs=[
52 | gr.Textbox(
53 | "Paul Iusztin",
54 | label="Who are you?",
55 | )
56 | ],
57 | title="Your LLM Twin",
58 | description="""
59 | Chat with your personalized LLM Twin! This AI assistant will help you write content incorporating your style and voice.
60 | """,
61 | theme="soft",
62 | examples=[
63 | [
64 | "Draft a post about RAG systems.",
65 | "Paul Iusztin",
66 | ],
67 | [
68 | "Draft an article paragraph about vector databases.",
69 | "Paul Iusztin",
70 | ],
71 | [
72 | "Draft a post about LLM chatbots.",
73 | "Paul Iusztin",
74 | ],
75 | ],
76 | cache_examples=False,
77 | )
78 |
79 |
80 | if __name__ == "__main__":
81 | demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
82 |
--------------------------------------------------------------------------------
/src/inference_pipeline/utils.py:
--------------------------------------------------------------------------------
1 | from config import settings
2 | from transformers import AutoTokenizer
3 |
4 |
5 | def compute_num_tokens(text: str) -> int:
6 | tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_ID)
7 |
8 | return len(tokenizer.encode(text, add_special_tokens=False))
9 |
10 |
11 | def truncate_text_to_max_tokens(text: str, max_tokens: int) -> tuple[str, int]:
12 | """Truncates text to not exceed max_tokens while trying to preserve complete sentences.
13 |
14 | Args:
15 | text: The text to truncate
16 | max_tokens: Maximum number of tokens allowed
17 |
18 | Returns:
19 | Truncated text that fits within max_tokens and the number of tokens in the truncated text.
20 | """
21 |
22 | current_tokens = compute_num_tokens(text)
23 |
24 | if current_tokens <= max_tokens:
25 | return text, current_tokens
26 |
27 | tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_ID)
28 | tokens = tokenizer.encode(text, add_special_tokens=False)
29 |
30 | # Take first max_tokens tokens and decode
31 | truncated_tokens = tokens[:max_tokens]
32 | truncated_text = tokenizer.decode(truncated_tokens)
33 |
34 | # Try to end at last complete sentence
35 | last_period = truncated_text.rfind(".")
36 | if last_period > 0:
37 | truncated_text = truncated_text[: last_period + 1]
38 |
39 | truncated_tokens = compute_num_tokens(truncated_text)
40 |
41 | return truncated_text, truncated_tokens
42 |
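A usage sketch for the token helpers. The flat `from utils import ...` import assumes the script runs from `src/inference_pipeline`, loading the tokenizer requires access to the configured `MODEL_ID` on Hugging Face, and the sample text is made up:

```python
from utils import compute_num_tokens, truncate_text_to_max_tokens

long_context = " ".join(["RAG systems pair a retriever with a generator."] * 200)

truncated, kept_tokens = truncate_text_to_max_tokens(long_context, max_tokens=512)
print(f"Kept {kept_tokens} of {compute_num_tokens(long_context)} tokens.")
```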
--------------------------------------------------------------------------------
/src/training_pipeline/config.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from pydantic_settings import BaseSettings, SettingsConfigDict
4 |
5 | ROOT_DIR = str(Path(__file__).parent.parent.parent)
6 |
7 |
8 | class Settings(BaseSettings):
9 | model_config = SettingsConfigDict(env_file=ROOT_DIR, env_file_encoding="utf-8")
10 |
11 | # Hugging Face config
12 | HUGGINGFACE_BASE_MODEL_ID: str = "meta-llama/Llama-3.1-8B"
13 | HUGGINGFACE_ACCESS_TOKEN: str | None = None
14 |
15 | # Comet config
16 | COMET_API_KEY: str | None = None
17 | COMET_WORKSPACE: str | None = None
18 | COMET_PROJECT: str = "llm-twin"
19 |
20 | DATASET_ID: str = "articles-instruct-dataset" # Comet artifact containing your fine-tuning dataset (available after generating the instruct dataset).
21 |
22 | # AWS config
23 | AWS_REGION: str = "eu-central-1"
24 | AWS_ACCESS_KEY: str | None = None
25 | AWS_SECRET_KEY: str | None = None
26 | AWS_ARN_ROLE: str | None = None
27 |
28 |
29 | settings = Settings()
30 |
--------------------------------------------------------------------------------
/src/training_pipeline/download_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | from comet_ml import Artifact, Experiment
5 | from comet_ml.artifacts import ArtifactAsset
6 | from config import settings
7 | from core import get_logger
8 | from datasets import Dataset # noqa: E402
9 |
10 | logger = get_logger(__file__)
11 |
12 |
13 | class DatasetClient:
14 | def __init__(
15 | self,
16 | output_dir: Path = Path("./finetuning_dataset"),
17 | ) -> None:
18 | self.output_dir = output_dir
19 | self.output_dir.mkdir(parents=True, exist_ok=True)
20 |
21 | def download_dataset(self, dataset_id: str, split: str = "train") -> Dataset:
22 | assert split in ["train", "test"], "Split must be either 'train' or 'test'"
23 |
24 | if "/" in dataset_id:
25 | tokens = dataset_id.split("/")
26 | assert (
27 | len(tokens) == 2
28 |             ), f"Wrong format for the dataset_id '{dataset_id}'. It should contain at most one '/' character, following the template: 'comet_ml_workspace/comet_ml_artifact_name'"
29 | workspace, artifact_name = tokens
30 |
31 | experiment = Experiment(workspace=workspace)
32 | else:
33 | artifact_name = dataset_id
34 |
35 | experiment = Experiment()
36 |
37 | artifact = self._download_artifact(artifact_name, experiment)
38 | asset = self._artifact_to_asset(artifact, split)
39 | dataset = self._load_data(asset)
40 |
41 | experiment.end()
42 |
43 | return dataset
44 |
45 | def _download_artifact(self, artifact_name: str, experiment) -> Artifact:
46 | try:
47 | logged_artifact = experiment.get_artifact(artifact_name)
48 | artifact = logged_artifact.download(self.output_dir)
49 | except Exception as e:
50 | print(f"Error retrieving artifact: {str(e)}")
51 |
52 | raise
53 |
54 | print(f"Successfully downloaded '{artifact_name}' at location '{self.output_dir}'")
55 |
56 | return artifact
57 |
58 | def _artifact_to_asset(self, artifact: Artifact, split: str) -> ArtifactAsset:
59 | if len(artifact.assets) == 0:
60 | raise RuntimeError("Artifact has no assets")
61 | elif len(artifact.assets) != 2:
62 | raise RuntimeError(
63 |                 f"Artifact has {len(artifact.assets)} assets, which is invalid. It should have exactly 2 (one per split)."
64 | )
65 |
66 | print(f"Picking split = '{split}'")
67 | asset = [asset for asset in artifact.assets if split in asset.logical_path][0]
68 |
69 | return asset
70 |
71 | def _load_data(self, asset: ArtifactAsset) -> Dataset:
72 | data_file_path = asset.local_path_or_data
73 | with open(data_file_path, "r") as file:
74 | data = json.load(file)
75 |
76 | dataset_dict = {k: [str(d[k]) for d in data] for k in data[0].keys()}
77 | dataset = Dataset.from_dict(dataset_dict)
78 |
79 | print(
80 | f"Successfully loaded dataset from artifact, num_samples = {len(dataset)}",
81 | )
82 |
83 | return dataset
84 |
85 |
86 | if __name__ == "__main__":
87 | dataset_client = DatasetClient()
88 | dataset_client.download_dataset(dataset_id=settings.DATASET_ID)
89 |
90 | logger.info(f"Data available at '{dataset_client.output_dir}'.")
91 |
--------------------------------------------------------------------------------
/src/training_pipeline/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.33.0
2 | torch==2.4.0
3 | transformers==4.43.3
4 | datasets==2.20.0
5 | peft==0.12.0
6 | trl==0.9.6
7 | bitsandbytes==0.43.3
8 | comet-ml==3.44.3
9 | flash-attn==2.3.6
10 | unsloth==2024.9.post2
11 |
--------------------------------------------------------------------------------