├── data
│   └── document.pdf
├── pytest.ini
├── requirements.txt
├── Makefile
├── pyproject.toml
├── terraform
│   ├── secrets.tf
│   ├── bedrock_opensearch_access.tf
│   ├── variables.tf
│   └── opensearch_domain.tf
├── src
│   ├── retrieve_endpoint.py
│   ├── check_index_content.py
│   ├── retrieve_secret.py
│   ├── delete_index.py
│   ├── create_index.py
│   ├── ingest_docs_with_embeddings.py
│   ├── generate_embeddings.py
│   └── app.py
├── README.md
├── tests
│   ├── test_index.py
│   ├── test_json_processing.py
│   └── test_secret_endpoint.py
└── .gitignore
/data/document.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zer0-A1/Scalable-RAG-in-AWS-with-Fargate/HEAD/data/document.pdf -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | integration: marks tests as integration tests (deselect with '-m "not integration"') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | PyMuPDF 2 | boto3 3 | langchain_aws 4 | langchain-community 5 | loguru 6 | ruff 7 | pytest 8 | opensearch-py -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Makefile 2 | 3 | .PHONY: all req lint test clean 4 | 5 | # Variables 6 | PIP := pip 7 | RUFF := ruff 8 | PYTHONPATH := $(shell pwd)/src 9 | 10 | all: req lint test clean ## Run all tasks 11 | 12 | req: ## Install the requirements 13 | $(PIP) install -r requirements.txt 14 | 15 | lint: ## Run linter and code formatter (ruff) 16 | $(RUFF) check . 
--fix 17 | 18 | test: ## Run tests using pytest 19 | PYTHONPATH=$(PYTHONPATH) pytest tests/ 20 | 21 | clean: ## Clean up generated files 22 | rm -rf __pycache__ 23 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | # Maximum line length 3 | line-length = 88 4 | 5 | [tool.ruff.lint] 6 | # Rule families to enforce: pycodestyle, Pyflakes, mccabe, isort, and pydocstyle 7 | select = [ 8 | "E", # pycodestyle errors 9 | "F", # Pyflakes rules 10 | "W", # pycodestyle warnings 11 | "C90", # McCabe complexity 12 | "I", # Import sorting (isort) 13 | "D", # Docstring conventions (pydocstyle) 14 | ] 15 | 16 | # Specific rules to ignore 17 | ignore = ["E501", "D211", "D212", "W291", "D205", "D401", "W293"] 18 | 19 | -------------------------------------------------------------------------------- /terraform/secrets.tf: -------------------------------------------------------------------------------- 1 | resource "random_password" "master_password" { 2 | length = 32 3 | special = true 4 | override_special = "!#$%&*()-_=+[]{}<>:?" 
5 | min_numeric = 1 6 | min_special = 1 7 | min_upper = 1 8 | } 9 | 10 | resource "aws_secretsmanager_secret" "opensearch_master_password" { 11 | name_prefix = var.name 12 | description = "Master password for OpenSearch domain" 13 | } 14 | 15 | resource "aws_secretsmanager_secret_version" "opensearch_master_password" { 16 | secret_id = aws_secretsmanager_secret.opensearch_master_password.id 17 | secret_string = random_password.master_password.result 18 | } 19 | 20 | output "secret_name" { 21 | value = aws_secretsmanager_secret.opensearch_master_password.name 22 | } -------------------------------------------------------------------------------- /terraform/bedrock_opensearch_access.tf: -------------------------------------------------------------------------------- 1 | data "aws_iam_policy_document" "bedrock" { 2 | statement { 3 | sid = 1 4 | actions = ["bedrock:InvokeModel"] 5 | resources = ["arn:aws:bedrock:*::foundation-model/*"] 6 | } 7 | statement { 8 | sid = 2 9 | actions = ["sts:AssumeRole"] 10 | resources = ["*"] 11 | } 12 | } 13 | 14 | resource "aws_iam_policy" "bedrock" { 15 | name = "bedrock" 16 | policy = data.aws_iam_policy_document.bedrock.json 17 | } 18 | 19 | # Using :root allows any user in the account to assume the role 20 | resource "aws_iam_role" "bedrock" { 21 | name = "bedrock" 22 | assume_role_policy = < 4 | cover_gke_medium 5 |

 6 | 7 | This repository contains a full RAG application that uses Terraform for IaC, LangChain as the application framework, AWS Bedrock for the LLM and embedding models, and AWS OpenSearch as both the vector database and the query endpoint. 8 | 9 | Main Steps 10 | 11 | - **Data Ingestion**: Load data into an OpenSearch index 12 | - **Embedding and Model**: Bedrock Titan 13 | - **Vector Store and Endpoint**: OpenSearch 14 | - **IaC**: Terraform 15 | - **data**: original PDF document and the generated JSON file with embeddings 16 | 17 | Feel free to ⭐ and clone this repo 😉 18 | 19 | ## Tech Stack 20 | 21 | ![Visual Studio Code](https://img.shields.io/badge/Visual%20Studio%20Code-0078d7.svg?style=for-the-badge&logo=visual-studio-code&logoColor=white) 22 | ![Python](https://img.shields.io/badge/python-3670A0?style=for-the-badge&logo=python&logoColor=ffdd54) 23 | ![Anaconda](https://img.shields.io/badge/Anaconda-%2344A833.svg?style=for-the-badge&logo=anaconda&logoColor=white) 24 | ![Linux](https://img.shields.io/badge/Linux-FCC624?style=for-the-badge&logo=linux&logoColor=white) 25 | ![Ubuntu](https://img.shields.io/badge/Ubuntu-E95420?style=for-the-badge&logo=ubuntu&logoColor=white) 26 | ![Git](https://img.shields.io/badge/git-%23F05033.svg?style=for-the-badge&logo=git&logoColor=white) 27 | ![AWS](https://img.shields.io/badge/AWS-%23FF9900.svg?style=for-the-badge&logo=amazon-aws&logoColor=white) 28 | 29 | 30 | ## Project Structure 31 | 32 | The project is structured with the following files: 33 | 34 | - `terraform`: IaC 35 | - `tests`: unit and mock tests 36 | - `src`: scripts with the app logic 37 | - `requirements.txt`: project requirements 38 | - `Makefile`: commands for testing, linting, and formatting 39 | - `pyproject.toml`: linting/formatting configuration 40 | 41 | 42 | 43 | ## Project Set Up 44 | 45 | The Python version used for this project is Python 3.11. 46 | 47 | 1. 
Clone the repo (or download it as a zip file): 48 | 49 | ```bash 50 | git clone https://github.com/T-AIMaven/aws-bedrock-opensearch-langchain.git 51 | ``` 52 | 53 | 2. Create the virtual environment named `main-env` using Conda with Python version 3.11: 54 | 55 | ```bash 56 | conda create -n main-env python=3.11 57 | conda activate main-env 58 | ``` 59 | 60 | 3. Install the requirements: 61 | 62 | ```bash 63 | pip install -r requirements.txt 64 | 65 | or 66 | 67 | make req 68 | ``` 69 | 70 | 4. Create the infrastructure from the `terraform` folder (this can take up to 30 minutes): 71 | 72 | ```bash 73 | conda install conda-forge::terraform 74 | terraform init 75 | terraform plan 76 | terraform apply 77 | ``` 78 | 79 | 5. Generate embeddings from the documents: 80 | 81 | ```bash 82 | python src/generate_embeddings.py 83 | ``` 84 | 85 | 6. Create the index: 86 | 87 | ```bash 88 | python src/create_index.py 89 | ``` 90 | 91 | 7. Ingest the documents into the index: 92 | 93 | ```bash 94 | python src/ingest_docs_with_embeddings.py 95 | ``` 96 | 97 | 8. Test the app to get a reply: 98 | 99 | ```bash 100 | python src/app.py 101 | ``` 102 | 103 | The app contains a hard-coded question; change it to test other scenarios. 104 | -------------------------------------------------------------------------------- /src/check_index_content.py: -------------------------------------------------------------------------------- 1 | """File to check the content of the OpenSearch index.""" 2 | 3 | import json 4 | import os 5 | import sys 6 | 7 | from loguru import logger 8 | from opensearchpy import OpenSearch, RequestsHttpConnection 9 | 10 | from retrieve_endpoint import get_opensearch_endpoint 11 | from retrieve_secret import get_secret 12 | 13 | # Loguru logger 14 | logger.remove() 15 | logger.add(sys.stdout, level=os.getenv("LOG_LEVEL", "INFO")) 16 | 17 | def get_opensearch_client(host, username, password): 18 | """ 19 | Create and return an OpenSearch client. 
20 | 21 | Args: 22 | host (str): The host URL of the OpenSearch cluster. 23 | username (str): The username for authentication. 24 | password (str): The password for authentication. 25 | 26 | Returns: 27 | OpenSearch: An instance of the OpenSearch client. 28 | 29 | """ 30 | logger.info(f"Creating OpenSearch client for host: {host}") 31 | return OpenSearch( 32 | hosts = [{'host': host, 'port': 443}], 33 | http_auth = (username, password), 34 | use_ssl = True, 35 | verify_certs = True, 36 | connection_class = RequestsHttpConnection, 37 | timeout=30 38 | ) 39 | 40 | def check_index_content(client, index_name): 41 | """ 42 | Check and display content of the specified index. 43 | 44 | This function retrieves and displays index statistics, mapping, and sample documents. 45 | 46 | Args: 47 | client (OpenSearch): The OpenSearch client instance. 48 | index_name (str): The name of the index to check. 49 | 50 | Raises: 51 | Exception: If there's an error retrieving index information. 52 | 53 | """ 54 | try: 55 | # Get index stats 56 | stats = client.indices.stats(index=index_name) 57 | logger.info(f"Index stats for '{index_name}':") 58 | logger.info(f" Total documents: {stats['indices'][index_name]['total']['docs']['count']}") 59 | logger.info(f" Total size: {stats['indices'][index_name]['total']['store']['size_in_bytes']} bytes") 60 | 61 | # Get mapping 62 | mapping = client.indices.get_mapping(index=index_name) 63 | logger.info("\nIndex mapping:") 64 | logger.info(json.dumps(mapping, indent=2)) 65 | 66 | # Sample documents 67 | search_results = client.search(index=index_name, body={"query": {"match_all": {}}, "size": 5}) 68 | logger.info("\nSample documents (up to 5):") 69 | for hit in search_results['hits']['hits']: 70 | logger.info(json.dumps(hit['_source'], indent=2)) 71 | except Exception as e: 72 | logger.error(f"Error checking index content: {str(e)}") 73 | raise 74 | 75 | if __name__ == "__main__": 76 | """ 77 | Main execution block of the script. 
78 | 79 | This block sets up the OpenSearch client using configuration from imported modules, 80 | then proceeds to check and display the content of the specified index. 81 | """ 82 | # Domain Values 83 | region = "eu-central-1" 84 | index_name = "rag" 85 | username = "rag" 86 | domain_name = "rag" 87 | 88 | try: 89 | host = get_opensearch_endpoint(domain_name, region) 90 | logger.info(f"Retrieved OpenSearch endpoint: {host}") 91 | 92 | password = get_secret() 93 | logger.info("Retrieved secret successfully") 94 | 95 | client = get_opensearch_client(host, username, password) 96 | check_index_content(client, index_name) 97 | 98 | logger.info("Script execution completed successfully") 99 | except Exception as e: 100 | logger.error(f"An error occurred during script execution: {str(e)}") 101 | sys.exit(1) 102 | -------------------------------------------------------------------------------- /tests/test_index.py: -------------------------------------------------------------------------------- 1 | """Unit tests for the Index Creation.""" 2 | 3 | from unittest.mock import MagicMock, patch 4 | 5 | import pytest 6 | from opensearchpy import OpenSearch, RequestsHttpConnection 7 | 8 | from create_index import create_index 9 | from retrieve_endpoint import get_opensearch_endpoint 10 | from retrieve_secret import get_secret 11 | 12 | 13 | @pytest.fixture 14 | def mock_opensearch_client(): 15 | """ 16 | Create a mock Opensearch client for testing. 17 | 18 | Returns: 19 | MagicMock: A mock object representing the Opensearch client 20 | 21 | """ 22 | with patch('create_index.OpenSearch') as mock_opensearch: 23 | mock_client = MagicMock() 24 | mock_opensearch.return_value = mock_client 25 | yield mock_client 26 | 27 | def test_create_index(mock_opensearch_client): 28 | """ 29 | Test the create_index function to ensure it creates an index successfully. 
30 | 31 | This test mocks the OpenSearch client and verifies that the index creation 32 | method is called with the correct parameters. 33 | 34 | Args: 35 | mock_opensearch_client (MagicMock): A mock object for the OpenSearch client. 36 | 37 | Raises: 38 | AssertionError: If any of the assertions fail. 39 | 40 | """ 41 | # Arrange 42 | host = 'test-host' 43 | index_name = 'test-index' 44 | username = 'test-user' 45 | password = 'test-password' 46 | # Set up the mock response 47 | mock_response = {'acknowledged': True, 'shards_acknowledged': True, 'index': index_name} 48 | mock_opensearch_client.indices.create.return_value = mock_response 49 | 50 | # Act 51 | create_index(host, index_name, username, password) 52 | 53 | # Assert 54 | mock_opensearch_client.indices.create.assert_called_once() 55 | 56 | # Check if the index creation method was called with the correct parameters 57 | call_args = mock_opensearch_client.indices.create.call_args 58 | assert call_args[0][0] == index_name 59 | assert 'body' in call_args[1] 60 | 61 | # Check if the index body contains the expected settings and mappings 62 | index_body = call_args[1]['body'] 63 | assert 'settings' in index_body 64 | assert 'mappings' in index_body 65 | assert index_body['settings']['index']['number_of_shards'] == 3 66 | assert index_body['settings']['index']['number_of_replicas'] == 2 67 | assert index_body['settings']['index']['knn'] is True 68 | assert index_body['mappings']['properties']['vector_field']['type'] == 'knn_vector' 69 | assert index_body['mappings']['properties']['vector_field']['dimension'] == 1536 70 | 71 | @pytest.mark.integration 72 | def test_index_exists_after_creation(): 73 | """ 74 | Integration test to verify that the index exists after creation. 75 | 76 | This test creates an actual connection to OpenSearch and verifies 77 | that the index exists after calling the create_index function. 
 78 | 79 | Note: This test requires actual OpenSearch credentials and will make 80 | a real connection. Use with caution. 81 | 82 | Raises: 83 | AssertionError: If the index does not exist after creation. 84 | 85 | """ 86 | # Arrange 87 | domain_name = "rag" 88 | region = "eu-central-1" 89 | host = get_opensearch_endpoint(domain_name, region) 90 | index_name = "rag-test" 91 | username = "rag" 92 | password = get_secret() 93 | 94 | 95 | # Act 96 | create_index(host, index_name, username, password) 97 | 98 | # Assert 99 | client = OpenSearch( 100 | hosts = [{'host': host, 'port': 443}], 101 | http_auth = (username, password), 102 | use_ssl = True, 103 | verify_certs = True, 104 | connection_class = RequestsHttpConnection 105 | ) 106 | 107 | assert client.indices.exists(index=index_name), f"Index '{index_name}' does not exist after creation" 108 | 109 | # Clean up - delete the test index 110 | client.indices.delete(index=index_name) 111 | -------------------------------------------------------------------------------- /src/retrieve_secret.py: -------------------------------------------------------------------------------- 1 | """File to retrieve the AWS secret.""" 2 | 3 | import json 4 | import os 5 | import sys 6 | 7 | import boto3 8 | from botocore.exceptions import ClientError 9 | from loguru import logger 10 | 11 | # logger configuration 12 | logger.remove() 13 | logger.add(sys.stdout, level=os.getenv("LOG_LEVEL", "INFO")) 14 | 15 | def get_secret_name(region_name, domain_name): 16 | """ 17 | Retrieve the secret name from AWS Secrets Manager based on a domain name pattern. 18 | 19 | List all secrets in the specified region and search for a secret whose name 20 | matches the domain name pattern. 21 | 22 | Args: 23 | region_name (str): The AWS region where the secrets are stored. 24 | domain_name (str): The pattern to match against secret names. 25 | 26 | Returns: 27 | str or None: The name of the secret if a match is found, otherwise None. 
28 | 29 | Raises: 30 | ClientError: If there is an error with the AWS Secrets Manager client. 31 | 32 | """ 33 | client = boto3.client('secretsmanager', region_name=region_name) 34 | try: 35 | logger.info(f"Attempting to list secrets in region {region_name}") 36 | response = client.list_secrets() 37 | secrets = response['SecretList'] 38 | 39 | # The secret name pattern for OpenSearch domains 40 | secret_pattern = f"{domain_name}" 41 | 42 | for secret in secrets: 43 | if secret_pattern in secret['Name']: 44 | logger.info(f"Found matching secret: {secret['Name']}") 45 | return secret['Name'] 46 | 47 | logger.warning(f"No secret found matching pattern: {secret_pattern}") 48 | return None 49 | except ClientError as e: 50 | logger.error(f"An error occurred while listing secrets: {e}") 51 | raise e 52 | 53 | def get_secret(): 54 | """ 55 | Retrieve and display a secret from AWS Secrets Manager. 56 | 57 | Uses the domain name pattern to find the appropriate secret, retrieves it, 58 | and logs the contents. Masks secret values in the output. 59 | 60 | Returns: 61 | str: The retrieved secret as a string. 62 | 63 | Raises: 64 | ClientError: If there is an error with the AWS Secrets Manager client. 
65 | 66 | """ 67 | region_name = "eu-central-1" 68 | domain_name = "rag" 69 | secret_name = get_secret_name(region_name, domain_name) 70 | 71 | # Create a Secrets Manager client 72 | session = boto3.session.Session() 73 | client = session.client( 74 | service_name='secretsmanager', 75 | region_name=region_name 76 | ) 77 | 78 | try: 79 | logger.info(f"Attempting to retrieve secret: {secret_name}") 80 | get_secret_value_response = client.get_secret_value( 81 | SecretId=secret_name 82 | ) 83 | except ClientError as e: 84 | logger.error(f"An error occurred: {e}") 85 | raise e 86 | else: 87 | # Decrypt the secret using the associated KMS key 88 | secret = get_secret_value_response['SecretString'] 89 | 90 | logger.info("Secret retrieved successfully") 91 | 92 | try: 93 | # Attempt to parse the secret as JSON 94 | secret_dict = json.loads(secret) 95 | logger.info("Secret contents (key-value pairs):") 96 | for key, value in secret_dict.items(): 97 | logger.info(f" {key}: {'*' * len(value)}") # Mask the actual values 98 | except json.JSONDecodeError: 99 | logger.info("Secret is not in JSON format. Raw secret (masked):") 100 | logger.info('*' * len(secret)) 101 | 102 | return secret 103 | 104 | if __name__ == "__main__": 105 | """ 106 | Execute the script to retrieve and display a secret from AWS Secrets Manager. 107 | 108 | Calls the function to get the secret and handles exceptions that may occur during execution. 
 109 | """ 110 | try: 111 | retrieved_secret = get_secret() 112 | logger.info("Script executed successfully") 113 | except Exception as e: 114 | logger.error(f"An error occurred during script execution: {e}") 115 | -------------------------------------------------------------------------------- /src/delete_index.py: -------------------------------------------------------------------------------- 1 | """File to delete the OpenSearch index.""" 2 | 3 | import os 4 | import sys 5 | 6 | from loguru import logger 7 | from opensearchpy import OpenSearch, RequestsHttpConnection 8 | 9 | from retrieve_endpoint import get_opensearch_endpoint 10 | from retrieve_secret import get_secret 11 | 12 | # logger configuration 13 | logger.remove() 14 | logger.add(sys.stdout, level=os.getenv("LOG_LEVEL", "INFO")) 15 | 16 | def get_opensearch_cluster_client(domain_name, host, username, password): 17 | """ 18 | Create and return an OpenSearch client for the specified cluster. 19 | 20 | Args: 21 | domain_name (str): The name of the OpenSearch domain (currently unused by this function). 22 | host (str): The host URL of the OpenSearch cluster. 23 | username (str): The username for authentication. 24 | password (str): The password for authentication. 25 | 26 | Returns: 27 | OpenSearch: An instance of the OpenSearch client. 28 | 29 | """ 30 | opensearch_client = OpenSearch( 31 | hosts = [{'host': host, 'port': 443}], 32 | http_auth = (username, password), 33 | use_ssl = True, 34 | verify_certs = True, 35 | connection_class = RequestsHttpConnection, 36 | timeout=30 37 | ) 38 | return opensearch_client 39 | 40 | def delete_opensearch_index(opensearch_client, index_name): 41 | """ 42 | Delete a specific index from the OpenSearch cluster. 43 | 44 | Args: 45 | opensearch_client (OpenSearch): The OpenSearch client instance. 46 | index_name (str): The name of the index to delete. 47 | 48 | Returns: 49 | bool: True if the index was successfully deleted, False otherwise. 
50 | 51 | """ 52 | logger.info(f"Trying to delete index {index_name}") 53 | try: 54 | response = opensearch_client.indices.delete(index=index_name) 55 | logger.info(f"Index {index_name} deleted") 56 | return response['acknowledged'] 57 | except Exception as e: 58 | logger.error(f"Error deleting index {index_name}: {str(e)}") 59 | return False 60 | 61 | def list_all_indices(opensearch_client): 62 | """ 63 | List all indices in the OpenSearch cluster. 64 | 65 | Args: 66 | opensearch_client (OpenSearch): The OpenSearch client instance. 67 | 68 | Returns: 69 | list: A list of all index names in the cluster. 70 | 71 | """ 72 | try: 73 | indices = opensearch_client.indices.get_alias("*") 74 | return list(indices.keys()) 75 | except Exception as e: 76 | logger.error(f"Error listing indices: {str(e)}") 77 | return [] 78 | 79 | def delete_all_indices(opensearch_client): 80 | """ 81 | Delete all non-system indices in the OpenSearch cluster. 82 | 83 | This function lists all indices, skips system indices (those starting with '.'), 84 | and attempts to delete each non-system index. 85 | 86 | Args: 87 | opensearch_client (OpenSearch): The OpenSearch client instance. 88 | 89 | """ 90 | indices = list_all_indices(opensearch_client) 91 | logger.info(f"Found {len(indices)} indices") 92 | 93 | for index in indices: 94 | if index.startswith('.'): 95 | logger.info(f"Skipping system index: {index}") 96 | continue 97 | success = delete_opensearch_index(opensearch_client, index) 98 | if not success: 99 | logger.warning(f"Failed to delete index: {index}") 100 | 101 | remaining_indices = list_all_indices(opensearch_client) 102 | logger.info(f"Remaining indices after deletion: {remaining_indices}") 103 | 104 | if __name__ == "__main__": 105 | """ 106 | Main execution block of the script. 
107 | 108 | This block sets up the OpenSearch client using environment variables and 109 | configuration from imported modules, then proceeds to delete all non-system 110 | indices in the specified OpenSearch cluster. 111 | """ 112 | domain_name = "rag" 113 | region = "eu-central-1" 114 | username = "rag" 115 | password = get_secret() 116 | host = get_opensearch_endpoint(domain_name, region) 117 | 118 | client = get_opensearch_cluster_client(domain_name, host, username, password) 119 | 120 | delete_all_indices(client) 121 | 122 | logger.info("Script execution completed") 123 | -------------------------------------------------------------------------------- /src/create_index.py: -------------------------------------------------------------------------------- 1 | """File to create the Opensearch Index.""" 2 | 3 | import os 4 | import sys 5 | 6 | import boto3 7 | from loguru import logger 8 | from opensearchpy import OpenSearch, RequestsHttpConnection 9 | 10 | from retrieve_secret import get_secret, get_secret_name 11 | 12 | # logger 13 | logger.remove() 14 | logger.add(sys.stdout, level=os.getenv("LOG_LEVEL", "INFO")) 15 | 16 | def get_opensearch_endpoint(domain_name, region_name): 17 | """ 18 | Retrieve the endpoint for an OpenSearch domain. 19 | 20 | Args: 21 | domain_name (str): The name of the OpenSearch domain. 22 | region_name (str): The AWS region where the domain is located. 23 | 24 | Returns: 25 | str: The endpoint URL for the OpenSearch domain. 26 | 27 | Raises: 28 | Exception: If there's an error retrieving the OpenSearch endpoint. 
29 | 30 | """ 31 | client = boto3.client('es', region_name=region_name) 32 | try: 33 | response = client.describe_elasticsearch_domain(DomainName=domain_name) 34 | return response['DomainStatus']['Endpoint'] 35 | except Exception as e: 36 | logger.error(f"Error retrieving OpenSearch endpoint: {e}") 37 | raise 38 | 39 | def create_index(host, index_name, username, password): 40 | """ 41 | Create an index in OpenSearch with specified settings and mappings. 42 | 43 | This function creates an OpenSearch client and uses it to create an index 44 | with predefined settings for shards, replicas, and KNN, as well as 45 | mappings for text and vector fields. 46 | 47 | Args: 48 | host (str): The OpenSearch host URL. 49 | index_name (str): The name of the index to create. 50 | username (str): The username for OpenSearch authentication. 51 | password (str): The password for OpenSearch authentication. 52 | 53 | Raises: 54 | Exception: If there's an error creating the index. 55 | 56 | """ 57 | # Create the OpenSearch client 58 | client = OpenSearch( 59 | hosts = [{'host': host, 'port': 443}], 60 | http_auth = (username, password), 61 | use_ssl = True, 62 | verify_certs = True, 63 | connection_class = RequestsHttpConnection, 64 | timeout=30 65 | ) 66 | 67 | # Define the index settings and mappings 68 | index_body = { 69 | "settings": { 70 | "index": { 71 | "number_of_shards": 3, 72 | "number_of_replicas": 2, 73 | "knn": True, 74 | "knn.space_type": "cosinesimil" 75 | } 76 | }, 77 | "mappings": { 78 | "properties": { 79 | "text": { 80 | "type": "text", 81 | "analyzer": "standard" 82 | }, 83 | "vector_field": { 84 | "type": "knn_vector", 85 | "dimension": 1536, 86 | } 87 | } 88 | } 89 | } 90 | 91 | # Create the index 92 | try: 93 | response = client.indices.create(index_name, body=index_body) 94 | logger.info(f"Index '{index_name}' created successfully: {response}") 95 | except Exception as e: 96 | logger.error(f"Error creating index: {e}") 97 | raise 98 | 99 | if __name__ == 
"__main__": 100 | """ 101 | Main execution block of the script. 102 | 103 | This block retrieves necessary configuration information, gets the OpenSearch 104 | endpoint and password, and creates an index in OpenSearch. It handles exceptions 105 | and logs relevant information during the process. 106 | """ 107 | 108 | region_name = "eu-central-1" 109 | domain_name = "rag" 110 | secret_name = get_secret_name(region_name, domain_name) 111 | index_name = "rag" 112 | username = "rag" 113 | 114 | try: 115 | host = get_opensearch_endpoint(domain_name, region_name) 116 | logger.info(f"OpenSearch endpoint: {host}") 117 | 118 | password = get_secret() 119 | logger.info("Retrieved secret successfully") 120 | 121 | # Create the index 122 | create_index(host, index_name, username, password) 123 | 124 | logger.info("Script executed successfully") 125 | except Exception as e: 126 | logger.error(f"An error occurred during script execution: {e}") 127 | -------------------------------------------------------------------------------- /src/ingest_docs_with_embeddings.py: -------------------------------------------------------------------------------- 1 | """File to ingest documents in the Opensearch Index.""" 2 | 3 | import json 4 | import os 5 | import sys 6 | 7 | import boto3 8 | from loguru import logger 9 | from opensearchpy import OpenSearch, RequestsHttpConnection 10 | 11 | from retrieve_secret import get_secret 12 | 13 | # logger configuration 14 | logger.remove() 15 | logger.add(sys.stdout, level=os.getenv("LOG_LEVEL", "INFO")) 16 | 17 | def get_opensearch_endpoint(domain_name, region_name): 18 | """ 19 | Retrieve the endpoint for an OpenSearch domain. 20 | 21 | Args: 22 | domain_name (str): The name of the OpenSearch domain. 23 | region_name (str): The AWS region where the domain is located. 24 | 25 | Returns: 26 | str: The endpoint URL for the OpenSearch domain. 27 | 28 | Raises: 29 | Exception: If there's an error retrieving the OpenSearch endpoint. 
30 | 31 | """ 32 | client = boto3.client('es', region_name=region_name) 33 | try: 34 | response = client.describe_elasticsearch_domain(DomainName=domain_name) 35 | return response['DomainStatus']['Endpoint'] 36 | except Exception as e: 37 | logger.error(f"Error retrieving OpenSearch endpoint: {e}") 38 | raise 39 | 40 | def ingest_data(client, index_name, data): 41 | """ 42 | Ingest data into the specified OpenSearch index. 43 | 44 | Args: 45 | client (OpenSearch): The OpenSearch client instance. 46 | index_name (str): The name of the index to ingest data into. 47 | data (list): A list of dictionaries containing the data to be ingested. 48 | 49 | Raises: 50 | ValueError: If the input data is not in the expected format. 51 | 52 | """ 53 | if not isinstance(data, list): 54 | logger.error(f"Error: Expected a list of dictionaries, but got {type(data)}") 55 | logger.error(f"Data sample: {str(data)[:200]}...") # Log first 200 characters of data 56 | raise ValueError("Invalid data format") 57 | 58 | for page in data: 59 | try: 60 | document = { 61 | "text": page.get('text'), 62 | "vector_field": page.get('vector_field') 63 | } 64 | 65 | response = client.index( 66 | index=index_name, 67 | body=document, 68 | ) 69 | logger.info(f"Indexed page {page.get('page_number')}: {response['result']}") 70 | except Exception as e: 71 | logger.error(f"Error indexing page {page.get('page_number')}: {e}") 72 | logger.error(f"Problematic page data: {page}") 73 | 74 | def main(): 75 | """ 76 | Main function to orchestrate the data ingestion process. 77 | 78 | This function retrieves necessary configuration, sets up the OpenSearch client, 79 | loads data from a JSON file, and ingests it into the specified OpenSearch index. 
80 | """ 81 | region_name = "eu-central-1" 82 | domain_name = "rag" 83 | index_name = "rag" 84 | username = "rag" 85 | 86 | try: 87 | # Get the OpenSearch endpoint 88 | host = get_opensearch_endpoint(domain_name, region_name) 89 | logger.info(f"OpenSearch endpoint: {host}") 90 | 91 | # Get the master password from Secrets Manager 92 | password = get_secret() 93 | 94 | client = OpenSearch( 95 | hosts = [{'host': host, 'port': 443}], 96 | http_auth = (username, password), 97 | use_ssl = True, 98 | verify_certs = True, 99 | connection_class = RequestsHttpConnection, 100 | timeout=30 101 | ) 102 | 103 | # Load the data from the JSON file 104 | with open('../data/text_with_embeddings.json', 'r') as f: 105 | data = json.load(f) 106 | 107 | logger.info(f"Loaded data type: {type(data)}") 108 | if isinstance(data, list) and len(data) > 0: 109 | logger.info(f"First item keys: {', '.join(data[0].keys())}") 110 | logger.info(f"Number of pages: {len(data)}") 111 | else: 112 | logger.warning("Data is empty or not in the expected format") 113 | 114 | # Ingest the data into OpenSearch 115 | ingest_data(client, index_name, data) 116 | 117 | logger.info("Data ingestion completed successfully") 118 | except Exception as e: 119 | logger.error(f"An error occurred during script execution: {e}") 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written 
by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | 165 | # Local .terraform directories 166 | **/.terraform/* 167 | terraform/.terraform.lock.hcl 168 | # .tfstate files 169 | *.tfstate 170 | *.tfstate.* 171 | 172 | # Crash log files 173 | crash.log 174 | crash.*.log 175 | 176 | # Exclude all .tfvars files, which are likely to contain sensitive data, such as 177 | # password, private keys, and other secrets. These should not be part of version 178 | # control as they are data points which are potentially sensitive and subject 179 | # to change depending on the environment. 180 | *.tfvars 181 | *.tfvars.json 182 | 183 | # Ignore override files as they are usually used to override resources locally and so 184 | # are not checked in 185 | override.tf 186 | override.tf.json 187 | *_override.tf 188 | *_override.tf.json 189 | 190 | # Ignore transient lock info files created by terraform apply 191 | .terraform.tfstate.lock.info 192 | 193 | # Include override files you do wish to add to version control using negated pattern 194 | # !example_override.tf 195 | 196 | # Include tfplan files to ignore the plan output of command: terraform plan -out=tfplan 197 | # example: *tfplan* 198 | 199 | # Ignore CLI configuration files 200 | .terraformrc 201 | terraform.rc -------------------------------------------------------------------------------- /src/generate_embeddings.py: -------------------------------------------------------------------------------- 1 | 2 | """File to generate embeddings from PDF.""" 3 | 4 | import json 5 | import os 6 | import sys 7 | 8 | import boto3 9 | import fitz 10 | from langchain_aws import BedrockEmbeddings 11 | from loguru import logger 12 | 13 | # logger 14 | logger.remove() 15 | logger.add(sys.stdout, level=os.getenv("LOG_LEVEL", "INFO")) 16 | 17 | 18 | def extract_text_from_pdf(pdf_path): 19 | """ 20 | Extracts text from a PDF file and returns it as a list of dictionaries. 21 | 22 | Each dictionary contains the page number and the extracted text for that page. 
23 | 24 | Args: 25 | pdf_path (str): The file path to the PDF document. 26 | 27 | Returns: 28 | list: A list of dictionaries where each dictionary contains 'page_number' and 'text'. 29 | 30 | Raises: 31 | Exception: If the PDF file cannot be opened or read. 32 | 33 | """ 34 | logger.info(f"Opening PDF file: {pdf_path}") 35 | 36 | try: 37 | doc = fitz.open(pdf_path) 38 | logger.info("PDF file opened successfully.") 39 | except Exception as e: 40 | logger.error(f"Failed to open PDF: {e}") 41 | raise 42 | 43 | pages = [] 44 | 45 | logger.info(f"Extracting text from {len(doc)} pages.") 46 | 47 | for page_num in range(len(doc)): 48 | page = doc.load_page(page_num) 49 | text = page.get_text() 50 | 51 | pages.append({ 52 | "page_number": page_num + 1, 53 | "text": text.strip() 54 | }) 55 | logger.debug(f"Extracted text from page {page_num + 1}: {text[:100]}...") 56 | 57 | doc.close() 58 | logger.info(f"Text extraction completed for {len(pages)} pages.") 59 | 60 | return pages 61 | 62 | def create_embeddings(pages, bedrock_client, model_id): 63 | """ 64 | Creates embeddings for each page of text using a Bedrock model. 65 | 66 | Args: 67 | pages (list): A list of dictionaries where each dictionary contains 'page_number' and 'text'. 68 | bedrock_client (boto3.client): A Bedrock client for creating embeddings. 69 | model_id (str): The ID of the Bedrock model to use for embeddings. 70 | 71 | Returns: 72 | list: The input list of dictionaries with an additional 'vector_field' key containing the embeddings. 73 | 74 | Raises: 75 | Exception: If an error occurs while generating embeddings. 
76 | 77 | """ 78 | logger.info(f"Creating embeddings using model: {model_id}") 79 | 80 | bedrock_embeddings = BedrockEmbeddings( 81 | client=bedrock_client, 82 | model_id=model_id 83 | ) 84 | 85 | for idx, page in enumerate(pages): 86 | logger.debug(f"Generating embedding for page {page['page_number']}") 87 | 88 | try: 89 | embedding = bedrock_embeddings.embed_query(page['text']) 90 | page['vector_field'] = embedding 91 | logger.debug(f"Embedding created for page {page['page_number']}: {embedding[:5]}...") 92 | except Exception as e: 93 | logger.error(f"Error creating embedding for page {page['page_number']}: {e}") 94 | 95 | logger.info("Embeddings creation completed.") 96 | 97 | return pages 98 | 99 | def main(): 100 | """ 101 | Run the main process to extract text from a PDF, create embeddings, 102 | and save the results to a JSON file. 103 | """ 104 | pdf_path = "../data/document.pdf" 105 | region_name = "eu-central-1" 106 | bedrock_embedding_model_id = "amazon.titan-embed-text-v1" 107 | 108 | logger.info("Starting the main process.") 109 | 110 | # Create Bedrock client 111 | try: 112 | bedrock_client = boto3.client("bedrock-runtime", region_name=region_name) 113 | logger.info(f"Created Bedrock client in region {region_name}.") 114 | except Exception as e: 115 | logger.error(f"Failed to create Bedrock client: {e}") 116 | raise 117 | 118 | # Extract text from PDF 119 | extracted_pages = extract_text_from_pdf(pdf_path) 120 | 121 | # Create embeddings for each page 122 | pages_with_embeddings = create_embeddings(extracted_pages, bedrock_client, bedrock_embedding_model_id) 123 | 124 | # Save the extracted text and embeddings to a JSON file 125 | output_file = "../data/text_with_embeddings.json" 126 | logger.info(f"Saving extracted text and embeddings to {output_file}.") 127 | 128 | try: 129 | with open(output_file, "w", encoding="utf-8") as f: 130 | json.dump(pages_with_embeddings, f, ensure_ascii=False, indent=2) 131 | logger.info(f"Data successfully saved to 
{output_file}.") 132 | except Exception as e: 133 | logger.error(f"Failed to save data to {output_file}: {e}") 134 | raise 135 | 136 | # Print a sample of the first page 137 | if pages_with_embeddings: 138 | first_page = pages_with_embeddings[0] 139 | logger.info("Displaying sample of the first page.") 140 | logger.info(f"Page Number: {first_page['page_number']}") 141 | logger.info(f"Content (first 200 characters): {first_page['text'][:200]}...") 142 | logger.info(f"Embedding (first 5 values): {first_page['vector_field'][:5]}...") 143 | 144 | logger.info("Main process completed.") 145 | 146 | if __name__ == "__main__": 147 | logger.info("Starting the script.") 148 | main() 149 | logger.info("Script execution finished.") 150 | -------------------------------------------------------------------------------- /src/app.py: -------------------------------------------------------------------------------- 1 | """LangChain and Bedrock Q&A App.""" 2 | 3 | import os 4 | import sys 5 | 6 | import boto3 7 | from langchain.chains import create_retrieval_chain 8 | from langchain.chains.combine_documents import create_stuff_documents_chain 9 | from langchain_aws import BedrockEmbeddings, ChatBedrock 10 | from langchain_community.vectorstores import OpenSearchVectorSearch 11 | from langchain_core.prompts import ChatPromptTemplate 12 | from loguru import logger 13 | 14 | from retrieve_endpoint import get_opensearch_endpoint 15 | from retrieve_secret import get_secret 16 | 17 | # logger configuration 18 | logger.remove() 19 | logger.add(sys.stdout, level=os.getenv("LOG_LEVEL", "INFO")) 20 | 21 | 22 | def bedrock_embeddings(bedrock_client, bedrock_embedding_model_id): 23 | """ 24 | Create a LangChain vector embedding using Bedrock. 25 | 26 | Args: 27 | bedrock_client (boto3.client): The Bedrock client. 28 | bedrock_embedding_model_id (str): The ID of the Bedrock embedding model. 29 | 30 | Returns: 31 | BedrockEmbeddings: A LangChain Bedrock embeddings client.
32 | 33 | """ 34 | logger.info(f"Creating LangChain vector embedding using Bedrock model: {bedrock_embedding_model_id}") 35 | return BedrockEmbeddings( 36 | client=bedrock_client, 37 | model_id=bedrock_embedding_model_id) 38 | 39 | def opensearch_vectorstore(index_name, opensearch_password, bedrock_embeddings_client, opensearch_endpoint, _is_aoss=False): 40 | """ 41 | Create an OpenSearch vector search client. 42 | 43 | Args: 44 | index_name (str): The name of the OpenSearch index. 45 | opensearch_password (str): The password for OpenSearch authentication. 46 | bedrock_embeddings_client (BedrockEmbeddings): The Bedrock embeddings client. 47 | opensearch_endpoint (str): The OpenSearch endpoint URL. 48 | _is_aoss (bool, optional): Whether it's Amazon OpenSearch Serverless. Defaults to False. 49 | 50 | Returns: 51 | OpenSearchVectorSearch: An OpenSearch vector search client. 52 | 53 | """ 54 | logger.info(f"Creating OpenSearch vector search client for index: {index_name}") 55 | return OpenSearchVectorSearch( 56 | index_name=index_name, 57 | embedding_function=bedrock_embeddings_client, 58 | opensearch_url=f"https://{opensearch_endpoint}", 59 | http_auth=(index_name, opensearch_password), 60 | is_aoss=_is_aoss, 61 | timeout=30, 62 | retry_on_timeout=True, 63 | max_retries=3, 64 | ) 65 | 66 | def bedrock_llm(bedrock_client, bedrock_model_id): 67 | """ 68 | Create a Bedrock language model client. 69 | 70 | Args: 71 | bedrock_client (boto3.client): The Bedrock client. 72 | bedrock_model_id (str): The ID of the Bedrock model. 73 | 74 | Returns: 75 | ChatBedrock: A LangChain Bedrock chat model. 
76 | 77 | """ 78 | logger.info(f"Creating Bedrock LLM with model: {bedrock_model_id}") 79 | 80 | model_kwargs = { 81 | # "maxTokenCount": 4096, 82 | "temperature": 0, 83 | "topP": 0.3, 84 | } 85 | 86 | return ChatBedrock( 87 | model_id=bedrock_model_id, 88 | client=bedrock_client, 89 | model_kwargs=model_kwargs 90 | ) 91 | 92 | def main(): 93 | """ 94 | Run the LangChain with Bedrock and OpenSearch workflow. 95 | 96 | This function sets up the necessary clients, creates the LangChain components, 97 | and executes a query using the retrieval chain. 98 | """ 99 | logger.info("Starting the LangChain with Bedrock and OpenSearch workflow...") 100 | 101 | bedrock_model_id = "amazon.titan-text-lite-v1" 102 | bedrock_embedding_model_id = "amazon.titan-embed-text-v1" 103 | region = "eu-central-1" 104 | index_name = "rag" 105 | domain_name = "rag" 106 | question = "Can you describe the React approach?" 107 | 108 | 109 | logger.info(f"Creating Bedrock client with model {bedrock_model_id}, and embeddings with {bedrock_embedding_model_id}") 110 | 111 | # Creating all clients for chain 112 | bedrock_client = boto3.client("bedrock-runtime", region_name=region) 113 | llm = bedrock_llm(bedrock_client, bedrock_model_id) 114 | 115 | embeddings = bedrock_embeddings(bedrock_client, bedrock_embedding_model_id) 116 | host = get_opensearch_endpoint(domain_name, region) 117 | password = get_secret() 118 | 119 | vectorstore = opensearch_vectorstore(index_name, password, embeddings, host) 120 | 121 | 122 | prompt = ChatPromptTemplate.from_template("""You are an assistant for question-answering tasks. 123 | Use the following pieces of retrieved context to answer the question. 124 | If you don't know the answer, just say that you don't know. 125 | Use five sentences maximum.
126 | 127 | {context} 128 | 129 | Question: {input} 130 | Answer:""") 131 | 132 | chain = create_stuff_documents_chain(llm, prompt) 133 | 134 | retrieval_chain = create_retrieval_chain( 135 | retriever=vectorstore.as_retriever(), 136 | combine_docs_chain = chain 137 | ) 138 | 139 | response = retrieval_chain.invoke({"input": question}) 140 | 141 | logger.info(f"The answer from Bedrock {bedrock_model_id} is: {response.get('answer')}") 142 | 143 | if __name__ == "__main__": 144 | main() 145 | -------------------------------------------------------------------------------- /tests/test_json_processing.py: -------------------------------------------------------------------------------- 1 | """Unit tests for the JSON document.""" 2 | 3 | import json 4 | import os 5 | from unittest.mock import MagicMock, patch 6 | 7 | import boto3 8 | import pytest 9 | from botocore.exceptions import NoCredentialsError 10 | 11 | from generate_embeddings import create_embeddings 12 | 13 | # Path to the document 14 | DOCUMENT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data', 'text_with_embeddings.json') 15 | 16 | @pytest.fixture 17 | def sample_json_data(): 18 | """ 19 | Create a sample JSON data for testing. 20 | Returns: 21 | list: A list of dictionaries representing pages with text 22 | 23 | """ 24 | return [ 25 | {"page_number": 1, "text": "This is a test JSON document."}, 26 | {"page_number": 2, "text": "It has multiple pages."} 27 | ] 28 | 29 | @pytest.fixture 30 | def mock_bedrock_client(): 31 | """ 32 | Create a mock Bedrock client for testing. 
33 | 34 | Returns: 35 | MagicMock: A mock object representing the Bedrock client 36 | 37 | """ 38 | with patch('boto3.client') as mock_client: 39 | mock_bedrock = MagicMock() 40 | mock_response = { 41 | 'body': MagicMock(), 42 | 'contentType': 'application/json', 43 | } 44 | mock_response['body'].read.return_value = b'{"embedding": [0.1, 0.2, 0.3]}' 45 | mock_bedrock.invoke_model.return_value = mock_response 46 | mock_client.return_value = mock_bedrock 47 | yield mock_bedrock 48 | 49 | def test_create_embeddings(sample_json_data, mock_bedrock_client): 50 | """ 51 | Test the create_embeddings function with sample JSON data. 52 | 53 | Args: 54 | sample_json_data (list): Sample JSON data 55 | mock_bedrock_client (MagicMock): Mock Bedrock client 56 | 57 | Raises: 58 | AssertionError: If any of the assertions fail 59 | 60 | """ 61 | model_id = "amazon.titan-embed-text-v1" 62 | pages_with_embeddings = create_embeddings(sample_json_data, mock_bedrock_client, model_id) 63 | 64 | assert len(pages_with_embeddings) == len(sample_json_data) 65 | for page in pages_with_embeddings: 66 | assert "vector_field" in page, "vector_field is missing from page with embeddings" 67 | assert isinstance(page["vector_field"], list), "vector_field should be a list" 68 | assert page["vector_field"] == [0.1, 0.2, 0.3], "Unexpected embedding values" 69 | 70 | assert mock_bedrock_client.invoke_model.call_count == len(sample_json_data), "Bedrock client not called for each page" 71 | 72 | for call_args in mock_bedrock_client.invoke_model.call_args_list: 73 | kwargs = call_args[1] 74 | assert kwargs['modelId'] == model_id 75 | assert 'body' in kwargs 76 | assert kwargs['contentType'] == 'application/json' 77 | 78 | def test_process_real_json_document(): 79 | """ 80 | Test processing the real JSON document. 
81 | 82 | Raises: 83 | AssertionError: If any of the assertions fail 84 | 85 | """ 86 | assert os.path.exists(DOCUMENT_PATH), f"The file {DOCUMENT_PATH} does not exist" 87 | 88 | with open(DOCUMENT_PATH, 'r') as file: 89 | data = json.load(file) 90 | 91 | assert isinstance(data, list), "JSON data should be a list" 92 | assert len(data) > 0, "JSON data is empty" 93 | 94 | for item in data: 95 | assert "page_number" in item, "page_number is missing from JSON item" 96 | assert "text" in item, "text is missing from JSON item" 97 | assert isinstance(item["page_number"], int), "page_number should be an integer" 98 | assert isinstance(item["text"], str), "text should be a string" 99 | assert len(item["text"]) > 0, f"Text is empty for page {item['page_number']}" 100 | 101 | # Add specific content checks based on your document 102 | assert any("react" in item["text"].lower() for item in data), "Expected phrase not found in document" 103 | 104 | @pytest.mark.integration 105 | def test_create_embeddings_with_real_document(mock_bedrock_client): 106 | """ 107 | Integration test for create_embeddings function with the real JSON document. 
108 | 109 | Args: 110 | mock_bedrock_client (MagicMock): Mock Bedrock client 111 | 112 | Raises: 113 | AssertionError: If any of the assertions fail 114 | 115 | """ 116 | assert os.path.exists(DOCUMENT_PATH), f"The file {DOCUMENT_PATH} does not exist" 117 | 118 | with open(DOCUMENT_PATH, 'r') as file: 119 | data = json.load(file) 120 | 121 | model_id = "amazon.titan-embed-text-v1" 122 | pages_with_embeddings = create_embeddings(data, mock_bedrock_client, model_id) 123 | 124 | assert len(pages_with_embeddings) == len(data), "Number of pages with embeddings doesn't match original data count" 125 | 126 | for page in pages_with_embeddings: 127 | assert "vector_field" in page, "vector_field is missing from page with embeddings" 128 | assert isinstance(page["vector_field"], list), "vector_field should be a list" 129 | assert page["vector_field"] == [0.1, 0.2, 0.3], "Unexpected embedding values" 130 | 131 | assert mock_bedrock_client.invoke_model.call_count == len(data), "Bedrock client not called for each page" 132 | 133 | for call_args in mock_bedrock_client.invoke_model.call_args_list: 134 | kwargs = call_args[1] 135 | assert kwargs['modelId'] == model_id 136 | assert 'body' in kwargs 137 | assert kwargs['contentType'] == 'application/json' 138 | 139 | @pytest.mark.integration 140 | def test_create_embeddings_with_real_document_and_aws(): 141 | """ 142 | Full integration test for create_embeddings function with the real JSON document and AWS. 
143 | 144 | Raises: 145 | pytest.skip: If no AWS credentials are found 146 | AssertionError: If any of the assertions fail 147 | 148 | """ 149 | try: 150 | client = boto3.client("bedrock-runtime", region_name="eu-central-1") 151 | except NoCredentialsError: 152 | pytest.skip("No AWS credentials found for Bedrock client") 153 | 154 | assert os.path.exists(DOCUMENT_PATH), f"The file {DOCUMENT_PATH} does not exist" 155 | 156 | with open(DOCUMENT_PATH, 'r') as file: 157 | data = json.load(file) 158 | 159 | model_id = "amazon.titan-embed-text-v1" 160 | pages_with_embeddings = create_embeddings(data, client, model_id) 161 | 162 | assert len(pages_with_embeddings) == len(data), "Number of pages with embeddings doesn't match original data count" 163 | 164 | for page in pages_with_embeddings: 165 | assert "vector_field" in page, "vector_field is missing from page with embeddings" 166 | assert isinstance(page["vector_field"], list), "vector_field should be a list" 167 | assert len(page["vector_field"]) > 0, "vector_field is empty" 168 | assert all(isinstance(value, float) for value in page["vector_field"]), "Embedding values should be floats" 169 | 170 | print(f"Successfully created embeddings for {len(data)} pages from the real document.") 171 | -------------------------------------------------------------------------------- /tests/test_secret_endpoint.py: -------------------------------------------------------------------------------- 1 | """Unit tests for checking the endpoint and secret.""" 2 | 3 | from unittest.mock import MagicMock, patch 4 | 5 | import pytest 6 | from botocore.exceptions import ClientError 7 | 8 | from retrieve_endpoint import get_opensearch_endpoint, main 9 | 10 | 11 | @pytest.fixture 12 | def mock_boto3_client(): 13 | """ 14 | Fixture to mock the boto3 client. 15 | Returns: 16 | MagicMock: A mock object representing the boto3 client. 
17 | 18 | Raises: 19 | None 20 | 21 | """ 22 | with patch('boto3.client') as mock_client: 23 | yield mock_client 24 | 25 | def test_get_opensearch_endpoint_success(mock_boto3_client, capsys): 26 | """ 27 | Test successful retrieval of OpenSearch endpoint for the 'rag' domain. 28 | 29 | Args: 30 | mock_boto3_client (MagicMock): A pytest fixture that mocks the boto3 client. 31 | capsys (pytest.CaptureFixture): A pytest fixture for capturing stdout and stderr. 32 | 33 | Returns: 34 | None 35 | 36 | Raises: 37 | AssertionError: If any of the assertions fail. 38 | 39 | """ 40 | # Arrange 41 | domain_name = 'rag' 42 | region = 'eu-central-1' 43 | expected_endpoint = f'search-{domain_name}-abcdef123456.{region}.es.amazonaws.com' 44 | mock_es_client = MagicMock() 45 | mock_boto3_client.return_value = mock_es_client 46 | mock_es_client.describe_elasticsearch_domain.return_value = { 47 | 'DomainStatus': {'Endpoint': expected_endpoint} 48 | } 49 | 50 | # Act 51 | result = get_opensearch_endpoint(domain_name, region) 52 | 53 | # Assert 54 | assert result == expected_endpoint 55 | assert result.startswith(f'search-{domain_name}') 56 | assert result.endswith(f'{region}.es.amazonaws.com') 57 | captured = capsys.readouterr() 58 | assert f"Attempting to describe OpenSearch domain: {domain_name}" in captured.out 59 | assert "Domain description retrieved successfully" in captured.out 60 | assert f"Endpoint: {expected_endpoint}" in captured.out 61 | 62 | def test_get_opensearch_endpoint_no_endpoint(mock_boto3_client, capsys): 63 | """ 64 | Test behavior when no endpoint is found for the 'rag' domain. 65 | 66 | Args: 67 | mock_boto3_client (MagicMock): A pytest fixture that mocks the boto3 client. 68 | capsys (pytest.CaptureFixture): A pytest fixture for capturing stdout and stderr. 69 | 70 | Returns: 71 | None 72 | 73 | Raises: 74 | AssertionError: If any of the assertions fail. 
75 | 76 | """ 77 | # Arrange 78 | domain_name = 'rag' 79 | region = 'eu-central-1' 80 | mock_es_client = MagicMock() 81 | mock_boto3_client.return_value = mock_es_client 82 | mock_es_client.describe_elasticsearch_domain.return_value = { 83 | 'DomainStatus': {} # No endpoint in response 84 | } 85 | 86 | # Act 87 | result = get_opensearch_endpoint(domain_name, region) 88 | 89 | # Assert 90 | assert result is None 91 | captured = capsys.readouterr() 92 | assert "No endpoint found for the domain." in captured.err 93 | 94 | def test_get_opensearch_endpoint_client_error(mock_boto3_client, capsys): 95 | """ 96 | Test handling of ClientError exception for the 'rag' domain. 97 | 98 | Args: 99 | mock_boto3_client (MagicMock): A pytest fixture that mocks the boto3 client. 100 | capsys (pytest.CaptureFixture): A pytest fixture for capturing stdout and stderr. 101 | 102 | Returns: 103 | None 104 | 105 | Raises: 106 | AssertionError: If any of the assertions fail. 107 | 108 | """ 109 | # Arrange 110 | domain_name = 'rag' 111 | region = 'eu-central-1' 112 | mock_es_client = MagicMock() 113 | mock_boto3_client.return_value = mock_es_client 114 | mock_es_client.describe_elasticsearch_domain.side_effect = ClientError( 115 | {'Error': {'Code': 'ResourceNotFoundException', 'Message': 'Domain not found'}}, 116 | 'DescribeElasticsearchDomain' 117 | ) 118 | 119 | # Act 120 | result = get_opensearch_endpoint(domain_name, region) 121 | 122 | # Assert 123 | assert result is None 124 | captured = capsys.readouterr() 125 | assert "ClientError:" in captured.err 126 | 127 | def test_get_opensearch_endpoint_general_exception(mock_boto3_client, capsys): 128 | """ 129 | Test handling of general exceptions for the 'rag' domain. 130 | 131 | Args: 132 | mock_boto3_client (MagicMock): A pytest fixture that mocks the boto3 client. 133 | capsys (pytest.CaptureFixture): A pytest fixture for capturing stdout and stderr. 
134 | 135 | Returns: 136 | None 137 | 138 | Raises: 139 | AssertionError: If any of the assertions fail. 140 | 141 | """ 142 | # Arrange 143 | domain_name = 'rag' 144 | region = 'eu-central-1' 145 | mock_es_client = MagicMock() 146 | mock_boto3_client.return_value = mock_es_client 147 | mock_es_client.describe_elasticsearch_domain.side_effect = Exception("Unexpected error") 148 | 149 | # Act 150 | result = get_opensearch_endpoint(domain_name, region) 151 | 152 | # Assert 153 | assert result is None 154 | captured = capsys.readouterr() 155 | assert "An error occurred during script execution:" in captured.err 156 | 157 | @pytest.mark.parametrize("secret_name,region,endpoint,expected", [ 158 | ("rag-test", "eu-central-1", "search-rag-abcdef123456.eu-central-1.es.amazonaws.com", True), 159 | ("rag-non-existent", "eu-central-1", None, False), 160 | ]) 161 | def test_main_function(mock_boto3_client, capsys, secret_name, region, endpoint, expected): 162 | """ 163 | Test the main function with various scenarios for the 'rag' domain. 164 | 165 | This parameterized test verifies the behavior of the main function 166 | under different conditions: 167 | 1. Successful endpoint retrieval for different secret names 168 | 2. Failed endpoint retrieval for a non-existent secret 169 | 170 | It checks that the function returns the expected boolean result, 171 | prints appropriate messages to stdout or stderr, and ensures the 172 | endpoint format is correct when successful. 173 | 174 | Args: 175 | mock_boto3_client (MagicMock): A pytest fixture that mocks the boto3 client. 176 | capsys (pytest.CaptureFixture): A pytest fixture for capturing stdout and stderr. 177 | secret_name (str): Name of the secret being tested (starts with 'rag'). 178 | region (str): AWS region being tested. 179 | endpoint (str or None): Mocked endpoint (or None for failure case). 180 | expected (bool): Expected boolean return value of the main function. 
181 | 182 | Returns: 183 | None 184 | 185 | Raises: 186 | AssertionError: If any of the assertions fail. 187 | 188 | """ 189 | # Arrange 190 | domain_name = 'rag' 191 | mock_es_client = MagicMock() 192 | mock_boto3_client.return_value = mock_es_client 193 | if endpoint: 194 | mock_es_client.describe_elasticsearch_domain.return_value = { 195 | 'DomainStatus': {'Endpoint': endpoint} 196 | } 197 | else: 198 | mock_es_client.describe_elasticsearch_domain.side_effect = ClientError( 199 | {'Error': {'Code': 'ResourceNotFoundException', 'Message': 'Domain not found'}}, 200 | 'DescribeElasticsearchDomain' 201 | ) 202 | 203 | # Act 204 | result = main(domain_name, region) 205 | 206 | # Assert 207 | assert result == expected 208 | captured = capsys.readouterr() 209 | if expected: 210 | assert endpoint.startswith(f'search-{domain_name}') 211 | assert endpoint.endswith(f'{region}.es.amazonaws.com') 212 | assert f"Attempting to describe OpenSearch domain: {domain_name}" in captured.out 213 | assert "Domain description retrieved successfully" in captured.out 214 | assert f"Endpoint: {endpoint}" in captured.out 215 | assert "Script executed successfully" in captured.out 216 | else: 217 | if endpoint is None: 218 | assert "ClientError:" in captured.err 219 | assert "No endpoint returned." in captured.err or "An error occurred during script execution:" in captured.err 220 | 221 | # Additional check for secret name 222 | assert secret_name.startswith('rag') 223 | --------------------------------------------------------------------------------
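As a quick sanity check outside the repo, the mock-Bedrock pattern used in the test fixtures can be exercised standalone. The `embed_text` helper below is hypothetical (it is not part of the repo's source); it mirrors what `BedrockEmbeddings.embed_query` is assumed to do under the hood — call `invoke_model` with a JSON body and parse the `embedding` field out of the JSON response — so it pairs with the same `MagicMock` wiring the tests use.

```python
import json
from unittest.mock import MagicMock


def embed_text(bedrock_client, model_id, text):
    """Hypothetical helper: invoke a Titan embedding model and return the vector."""
    response = bedrock_client.invoke_model(
        modelId=model_id,
        body=json.dumps({"inputText": text}),
        contentType="application/json",
    )
    # The response body is a stream; read it and parse the JSON payload.
    payload = json.loads(response["body"].read())
    return payload["embedding"]


# Wire up a mock client exactly like the tests' mock_bedrock_client fixture.
client = MagicMock()
mock_response = {"body": MagicMock(), "contentType": "application/json"}
mock_response["body"].read.return_value = b'{"embedding": [0.1, 0.2, 0.3]}'
client.invoke_model.return_value = mock_response

print(embed_text(client, "amazon.titan-embed-text-v1", "hello"))
# → [0.1, 0.2, 0.3]
```

This is why the tests can assert both the returned vector (`[0.1, 0.2, 0.3]`) and the keyword arguments (`modelId`, `body`, `contentType`) recorded on each `invoke_model` call.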