├── examples ├── 02_rest_api_simple_usage │ ├── src │ │ ├── __init__.py │ │ └── utils │ │ │ ├── jwtutils.py │ │ │ └── queryutils.py │ ├── requirements.txt │ ├── simple_query.py │ ├── interactive_query.py │ ├── README.md │ └── notebook_query.ipynb ├── 01_python_simple_usage │ ├── requirements.txt │ ├── example_connections.toml │ ├── using_connections_config.py │ ├── simple_python_connection.py │ └── README.md ├── 04_python_concurrent_queries │ ├── requirements.txt │ ├── queries │ ├── README.md │ └── concurrent-example.py ├── 09_multihop_rag │ ├── assets │ │ └── vector_db_graph_db.png │ ├── environment.yml │ ├── README.md │ └── streamlit_chatbot_multihop_rag.py ├── 05_streamlit_ai_search_app │ ├── .streamlit │ │ └── example_secrets.toml │ ├── conda_env.yml │ ├── README.md │ └── cortex_search.py ├── 07_streamlit_search_evaluation_app │ ├── Makefile │ ├── pyproject.toml │ ├── README.md │ └── eval_test.py ├── 08_multimodal_rag │ ├── README.md │ ├── streamlit_chatbot_multimodal_rag.py │ └── cortex_search_multimodal.ipynb ├── 10_reranker_finetuning │ └── README.md ├── 06_streamlit_chatbot_app │ ├── README.md │ ├── setup.sql │ └── chat.py └── 03_batch_querying_from_sql │ └── batch_sproc.sql ├── .gitignore ├── .devcontainer └── devcontainer.json ├── README.md ├── projects └── improving-catalog-search │ ├── README.md │ ├── soft_query_boost_example.ipynb │ ├── cortex_search_setup_ecommerce_example.ipynb │ └── llm_judge_ecommerce_ranking_example.ipynb └── LICENSE /examples/02_rest_api_simple_usage/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/01_python_simple_usage/requirements.txt: -------------------------------------------------------------------------------- 1 | snowflake==0.8.0 2 | backoff>=2.2.1 3 | python-dotenv>=1.0 4 | -------------------------------------------------------------------------------- /examples/04_python_concurrent_queries/requirements.txt: -------------------------------------------------------------------------------- 1 | snowflake==0.8.0 2 | backoff>=2.2.1 3 | python-dotenv>=1.0 4 | -------------------------------------------------------------------------------- /examples/04_python_concurrent_queries/queries: -------------------------------------------------------------------------------- 1 | riding shotgun 2 | hello 3 | what's going on 4 | pet sounds 5 | blue 6 | songs in the key of life -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | */.ipynb_checkpoints 3 | .ipynb_checkpoints 4 | */.DS_Store 5 | .DS_Store 6 | **/venv/ 7 | **/connections.toml 8 | **/.streamlit/secrets.toml -------------------------------------------------------------------------------- /examples/09_multihop_rag/assets/vector_db_graph_db.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Snowflake-Labs/cortex-search/HEAD/examples/09_multihop_rag/assets/vector_db_graph_db.png -------------------------------------------------------------------------------- /examples/01_python_simple_usage/example_connections.toml: -------------------------------------------------------------------------------- 1 | [example_connection] 2 | account = "ABC12345" 3 | user = "my_username" 4 | password = "..." 
5 | database = "my_db" 6 | schema = "my_schema" 7 | role = "my_role" 8 | -------------------------------------------------------------------------------- /examples/05_streamlit_ai_search_app/.streamlit/example_secrets.toml: -------------------------------------------------------------------------------- 1 | [connections.conn] 2 | account = "" 3 | user = "" 4 | password = "" 5 | warehouse = "" 6 | database = "" 7 | schema = "" 8 | client_session_keep_alive = true -------------------------------------------------------------------------------- /examples/09_multihop_rag/environment.yml: -------------------------------------------------------------------------------- 1 | name: multihop_rag_app_env 2 | channels: 3 | - snowflake 4 | dependencies: 5 | - streamlit 6 | - snowflake 7 | - snowflake.core 8 | - snowflake-ml-python 9 | - langchain-core 10 | -------------------------------------------------------------------------------- /examples/05_streamlit_ai_search_app/conda_env.yml: -------------------------------------------------------------------------------- 1 | name: cortex-search-env 2 | dependencies: 3 | - python=3.9 4 | - pip 5 | - pip: 6 | - snowflake-snowpark-python 7 | - streamlit 8 | - snowflake 9 | - snowflake-ml-python 10 | - snowflake-connector-python -------------------------------------------------------------------------------- /examples/02_rest_api_simple_usage/requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2024.2.2 2 | cffi==1.16.0 3 | charset-normalizer==3.3.2 4 | cryptography==42.0.5 5 | idna==3.6 6 | pandas==2.2.1 7 | pycparser==2.21 8 | PyJWT==2.8.0 9 | python-dotenv==1.0.1 10 | requests==2.31.0 11 | urllib3==2.2.1 12 | 13 | -------------------------------------------------------------------------------- /examples/07_streamlit_search_evaluation_app/Makefile: -------------------------------------------------------------------------------- 1 | # sets up a local Python development environment 2 | setup: 3 | @if ! 
command -v uv >/dev/null 2>&1; then \ 4 | echo "Installing uv via curl..."; \ 5 | curl -LsSf https://astral.sh/uv/install.sh | sh; \ 6 | fi 7 | uv sync 8 | 9 | check: 10 | uv run ruff check 11 | 12 | format: 13 | uv run ruff check --fix 14 | uv run ruff format 15 | 16 | test: 17 | uv run pytest eval_test.py 18 | 19 | run: 20 | uv run streamlit run eval.py --server.headless True 21 | -------------------------------------------------------------------------------- /examples/07_streamlit_search_evaluation_app/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "cortex-search-eval-tool" 3 | version = "0" 4 | description = "Tool to evaluate the quality of Cortex Search Service against a set of queries and goldens" 5 | readme = "README.md" 6 | requires-python = ">=3.11" 7 | dependencies = [ 8 | "snowflake-connector-python>=3.12.2", 9 | "snowflake-snowpark-python>=1.23.0", 10 | "streamlit>=1.39.0", 11 | "matplotlib>=3.7.2", 12 | "pandas>=2.0.3", 13 | "scikit-learn>=1.5.1", 14 | "scikit-optimize>=0.9.0", 15 | "scipy>=1.13.1", 16 | "snowflake>=0.13.0", 17 | "ruff>=0.7.1", 18 | "pytest", 19 | "snowflake-ml-python", 20 | ] 21 | 22 | [tool.uv] 23 | dev-dependencies = [ 24 | "watchdog>=5.0.3", 25 | ] 26 | -------------------------------------------------------------------------------- /examples/01_python_simple_usage/using_connections_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from snowflake.connector import connect 5 | from snowflake.connector.config_manager import CONFIG_MANAGER 6 | from snowflake.core import Root 7 | 8 | # Load the Cortex Search Service name from your environment 9 | svc = os.environ["SNOWFLAKE_CORTEX_SEARCH_SERVICE"] 10 | 11 | # Replace with your search parameters 12 | query = "riding shotgun" 13 | columns = ["LYRIC","ALBUM_NAME","TRACK_TITLE","TRACK_N","LINE"] 14 | limit = 5 15 | 16 | with connect( 17 | connection_name="example_connection", 18 | ) as conn: 19 | try: 20 | # create a root as the entry point for all objects 21 | root = Root(conn) 22 | 23 | response = ( 24 | root.databases[conn.database] 25 | .schemas[conn.schema] 26 | .cortex_search_services[svc] 27 | .search( 28 | query, 29 | columns, 30 | limit=limit 31 | ) 32 | ) 33 | 34 | print(f"Received response with `request_id`: {response.request_id}") 35 | print(json.dumps(response.results,indent=4)) 36 | finally: 37 | conn.close() -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Python 3", 3 | // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile 4 | "image": "mcr.microsoft.com/devcontainers/python:1-3.11-bullseye", 5 | "customizations": { 6 | "codespaces": { 7 | "openFiles": [ 8 | "README.md", 9 | "examples/streamlit-ai-search/cortex_search.py" 10 | ] 11 | }, 12 | "vscode": { 13 | "settings": {}, 14 | "extensions": [ 15 | "ms-python.python", 16 | "ms-python.vscode-pylance" 17 | ] 18 | } 19 | }, 20 | "updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y str: 13 | """ 14 | Generate a valid JWT token from snowflake account name, user name, private key and private key passphrase. 
15 | """ 16 | 17 | # Prompt for private key passphrase 18 | def get_private_key_passphrase(): 19 | return getpass("Private Key Passphrase: ") 20 | 21 | # Generate encoded private key 22 | with open(private_key_path, "rb") as pem_in: 23 | pemlines = pem_in.read() 24 | try: 25 | private_key = load_pem_private_key(pemlines, None, default_backend()) 26 | except TypeError: 27 | private_key = load_pem_private_key( 28 | pemlines, get_private_key_passphrase().encode(), default_backend() 29 | ) 30 | public_key_raw = private_key.public_key().public_bytes( 31 | Encoding.DER, PublicFormat.SubjectPublicKeyInfo 32 | ) 33 | sha256hash = hashlib.sha256() 34 | sha256hash.update(public_key_raw) 35 | public_key_fp = "SHA256:" + base64.b64encode(sha256hash.digest()).decode("utf-8") 36 | 37 | # Generate JWT payload 38 | qualified_username = account + "." + user 39 | now = datetime.now(timezone.utc) 40 | lifetime = timedelta(minutes=60) 41 | payload = { 42 | "iss": qualified_username + "." + public_key_fp, 43 | "sub": qualified_username, 44 | "iat": now, 45 | "exp": now + lifetime, 46 | } 47 | 48 | # Return the encoded JWT token 49 | return jwt.encode(payload, key=private_key, algorithm="RS256") 50 | -------------------------------------------------------------------------------- /examples/10_reranker_finetuning/README.md: -------------------------------------------------------------------------------- 1 | # Reranker Finetuning 2 | 3 | This tutorial notebook guides you through finetuning a reranker model on your collection of documents, creating a service, and using the service in your search workflow. 4 | 5 | ## When do you need to finetune a reranker? 6 | 7 | Generally speaking, you can benefit from a finetuned reranker when you frequently observe irrelevant documents ranked in top positions by Cortex Search. This might happen when: 8 | - Your search task is quite different from the standard "short query, long document" search format. 9 | - Your search task requires an understanding of technical, proprietary terms, concepts, and jargon that are rarely found on the open web. 10 | - Your search task involves languages that Cortex Search is not optimized for. 11 | 12 | ## Prerequisites 13 | 14 | - Snowflake account with SPCS enabled 15 | - Appropriate Snowflake privileges 16 | - `CREATE DATABASE`, `CREATE SCHEMA` 17 | - `CREATE COMPUTE POOL` 18 | - `CREATE IMAGE REPOSITORY`, `CREATE SERVICE` 19 | 20 | **This tutorial also relies on a preview feature.** Before you can proceed with this tutorial, reach out to your account team to ask to enable this feature for your account: 21 | - [Batch Cortex Search](https://docs.snowflake.com/LIMITEDACCESS/cortex-search/batch-cortex-search) 22 | 23 | 24 | ## Usage 25 | - Create a compute pool with GPU resources following the [instructions here](https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool) (a sketch is shown after this list) 26 | - Upload the [attached notebook](../10_reranker_finetuning/reranker_finetuning.ipynb) to Snowflake using the [instructions here](https://docs.snowflake.com/en/user-guide/ui-snowsight/notebooks-create#create-a-new-notebook). Make sure to use `Snowflake ML Runtime GPU 1.0`, select the compute pool created, and turn on [all external access options](https://docs.snowflake.com/en/user-guide/ui-snowsight/notebooks-external-access) for downloading Python packages and base models. 27 | - Follow the instructions in the notebook for data processing, prompt generation, synthetic data generation, training and deployment.
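As a rough reference for the first step above, creating a small GPU compute pool typically looks like the following. This is a minimal sketch, not part of the tutorial notebook: the pool name is a placeholder and the instance family is an assumption, so pick values that match the compute pool docs linked above and the GPU availability in your region.

```sql
-- Hypothetical example: a single-node GPU pool for running the finetuning notebook.
CREATE COMPUTE POOL reranker_finetuning_pool
  MIN_NODES = 1
  MAX_NODES = 1
  INSTANCE_FAMILY = GPU_NV_S;  -- placeholder; choose a GPU instance family available to your account
```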
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cortex Search 2 | 3 | This repository contains example usage of Cortex Search, currently in Private Preview. The official preview documentation can be [found here](https://docs.snowflake.com/LIMITEDACCESS/cortex-search/cortex-search-overview). 4 | 5 | ## Examples 6 | 7 | The `examples` directory showcases several Cortex Search usage patterns. Navigate to each of the following subdirectories for installation information and sample usage for the method of choice: 8 | 9 | - [01_python_simple_usage](examples/01_python_simple_usage): Simple querying of a Cortex Search Service via the `snowflake` [python package](https://pypi.org/project/snowflake/). 10 | - [02_rest_api_simple_usage](examples/02_rest_api_simple_usage): Simple querying of a Cortex Search Service via the REST API. 11 | - [03_batch_querying_from_sql](examples/03_batch_querying_from_sql): Querying a Cortex Search Service in a "batch" fashion using SQL. 12 | - [04_python_concurrent_queries](examples/04_python_concurrent_queries): Querying a Cortex Search Service with concurrency using the Python SDK. 13 | - [05_streamlit_ai_search_app](examples/05_streamlit_ai_search_app): Sample Streamlit app using Cortex Search to power a search bar. 14 | - [06_streamlit_chatbot_app](examples/06_streamlit_chatbot_app): Sample Streamlit app using Cortex Search and Cortex LLM Functions to power a document chatbot. 15 | - [07_streamlit_search_evaluation_app](examples/07_streamlit_search_evaluation_app): Streamlit app guiding users through evaluation of the quality of a Cortex Search Service. 16 | - [08_multimodal_rag](examples/08_multimodal_rag): Sample notebook and Streamlit app using Cortex Search and Cortex LLM Functions for multimodal RAG on PDFs. 17 | - [09_multihop_rag](examples/09_multihop_rag): Sample notebook and Streamlit app combining Cortex Search for document retrieval with document graph traversal for multi-hop multimodal RAG on complex PDFs. 18 | - [10_reranker_finetuning](examples/10_reranker_finetuning): Sample notebook demonstrating how to finetune a reranker model in a Snowflake SPCS notebook to drastically improve search quality. 19 | 20 | ## License 21 | 22 | Apache Version 2.0 23 | -------------------------------------------------------------------------------- /examples/04_python_concurrent_queries/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Executing concurrent Cortex Search queries with throttling 3 | 4 | This directory contains example usage for the Cortex Search REST API 5 | via the `snowflake.core` python package. `snowflake.core` supports 6 | authentication and connection to a Snowflake account through several 7 | different mechanisms, a few of which are outlined in the examples. 8 | 9 | Notably, the Cortex Search API is only available in versions of 10 | `snowflake.core >= 0.8.0`. 11 | 12 | ## Prerequisites 13 | 14 | Before you can run the examples, ensure you have the following 15 | prerequisites installed: 16 | 17 | - Python 3.11 18 | - pip (Python package installer) 19 | 20 | Additionally, you must have access to a Snowflake account and the 21 | required permissions to query the Cortex Search Service at the 22 | specified database and schema.
23 | 24 | ## Installation 25 | 26 | First, clone this repository to your local machine using git and navigate to this directory: 27 | 28 | ``` 29 | git clone https://github.com/snowflake-labs/cortex-search.git 30 | cd cortex-search/examples/04_python_concurrent_queries 31 | ``` 32 | 33 | Install the necessary Python dependencies by running: 34 | 35 | ``` 36 | pip install -r requirements.txt 37 | ``` 38 | 39 | ## Usage 40 | 41 | `concurrent-example.py` is an example for running concurrent search 42 | requests on your cortex search service along with throttling the 43 | request rate when the server is busy. 44 | 45 | This example collects default connection parameters from a file. Add the 46 | following lines to a file named `.env` in the current directory: 47 | 48 | ``` 49 | SNOWFLAKE_ACCOUNT=AID123456 50 | SNOWFLAKE_USER=myself 51 | SNOWFLAKE_AUTHENTICATOR= 52 | SNOWFLAKE_PASSWORD= 53 | SNOWFLAKE_ROLE=my_role 54 | SNOWFLAKE_DATABASE=my_db 55 | SNOWFLAKE_SCHEMA=my_schema 56 | SNOWFLAKE_CORTEX_SEARCH_SERVICE=my_service 57 | ``` 58 | 59 | Add the queries you want to run concurrently to a file named 60 | `queries`, also in the current directory, one query per line: 61 | 62 | ``` 63 | riding shotgun 64 | hello 65 | what's going on 66 | pet sounds 67 | blue 68 | songs in the key of life 69 | ``` 70 | 71 | You can then run `concurrent-example.py` file to print search results: 72 | 73 | ``` 74 | python concurrent-example.py --columns col1 col2 75 | ``` 76 | 77 | You can override the connection parameters specified in `.env` file with command line parameters. Try: 78 | 79 | 80 | ``` 81 | python concurrent-example.py -h 82 | ``` 83 | 84 | to enumerate the list of command line options. 85 | -------------------------------------------------------------------------------- /examples/02_rest_api_simple_usage/simple_query.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import argparse 3 | 4 | import sys 5 | from pathlib import Path 6 | import json 7 | 8 | # Add the project root to sys.path 9 | sys.path.append(str(Path(__file__).parent.parent)) 10 | from src.utils import jwtutils, queryutils 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser(description="Cortex Search Service API Query CLI") 14 | parser.add_argument( 15 | "-u", 16 | "--url", 17 | help="Snowflake account URL that uses the account locator", 18 | required=True, 19 | ) 20 | parser.add_argument( 21 | "-a", 22 | "--account", 23 | help="Snowflake account name, eg my_org-my_account", 24 | required=True, 25 | ) 26 | parser.add_argument( 27 | "-n", 28 | "--user-name", 29 | help="Snowflake user name", 30 | required=True, 31 | ) 32 | parser.add_argument( 33 | "-s", 34 | "--qualified-service-name", 35 | help="Qualified name of the Cortex Search Service, eg `MY_DB.MY_SCHEMA.MY_SERVICE`. Case insensitive.", 36 | required=True, 37 | ) 38 | parser.add_argument( 39 | "-k", 40 | "--private-key-path", 41 | help="Absolute local path to RSA private key", 42 | required=True, 43 | ) 44 | parser.add_argument( 45 | "-c", 46 | "--columns", 47 | help="Comma-separated list of columns to return", 48 | required=True, 49 | ) 50 | parser.add_argument( 51 | "-q", 52 | "--query", 53 | help="Query string", 54 | required=True, 55 | ) 56 | parser.add_argument( 57 | "-l", 58 | "--limit", 59 | help="Max number of results to return", 60 | required=True, 61 | ) 62 | parser.add_argument( 63 | "-r", 64 | "--role", 65 | help="User role to use for queries. 
If provided, a session token scoped to this role will be generated for authenticating to the API.", 66 | required=False, 67 | ) 68 | 69 | args = parser.parse_args() 70 | 71 | request_body = { 72 | "columns": args.columns.split(","), 73 | "query": args.query, 74 | "limit": args.limit, 75 | } 76 | 77 | search_service = queryutils.CortexSearchService( 78 | args.private_key_path, 79 | args.url, 80 | args.account, 81 | args.user_name, 82 | args.qualified_service_name, 83 | args.role, 84 | ) 85 | response = search_service.search(request_body=request_body) 86 | 87 | if response is not None: 88 | print(json.dumps(response.json(), indent=4)) 89 | else: 90 | print("Failed to fetch data.") 91 | 92 | 93 | if __name__ == "__main__": 94 | main() 95 | -------------------------------------------------------------------------------- /examples/01_python_simple_usage/README.md: -------------------------------------------------------------------------------- 1 | # Cortex Search SDK usage via the `snowflake.core` library 2 | 3 | This directory contains example usage for the Cortex Search REST API 4 | via the `snowflake.core` python package. `snowflake.core` supports 5 | authentication and connection to a Snowflake account through several 6 | different mechanisms, a few of which are outlined in the examples. 7 | 8 | Notably, the Cortex Search API is only available in versions of 9 | `snowflake.core >= 0.8.0`. 10 | 11 | ## Prerequisites 12 | 13 | Before you can run the examples, ensure you have the following 14 | prerequisites installed: 15 | 16 | - Python 3.11 17 | - pip (Python package installer) 18 | 19 | Additionally, you must have access to a Snowflake account and the 20 | required permissions to query the Cortex Search Service at the 21 | specified database and schema. 22 | 23 | ## Installation 24 | 25 | First, clone this repository to your local machine using git and navigate to this directory: 26 | 27 | ``` 28 | git clone https://github.com/snowflake-labs/cortex-search.git 29 | cd cortex-search/examples/01_python_simple_usage 30 | ``` 31 | 32 | Install the necessary Python dependencies by running: 33 | 34 | ``` 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | ## Usage 39 | 40 | ### 1. Passing connection parameters explicitly or via environment variables 41 | 42 | `simple.py` collects connection parameters from your shell 43 | environment. These can be set like so (note: we recommend setting the 44 | `SNOWFLAKE_ACCOUNT` using the account locator, rather than the 45 | `org-account` format): 46 | 47 | ``` 48 | export SNOWFLAKE_ACCOUNT=AID123456 49 | export SNOWFLAKE_USER=myself 50 | export SNOWFLAKE_PASSWORD=pass123... 51 | export SNOWFLAKE_ROLE=my_role 52 | export SNOWFLAKE_DATABASE=my_db 53 | export SNOWFLAKE_SCHEMA=my_schema 54 | export SNOWFLAKE_CORTEX_SEARCH_SERVICE=my_service 55 | ``` 56 | 57 | However, you may also simply replace each `os.environ[".."]` with hardcoded values, if you wish. 58 | 59 | Then, after modifying the search parameters to your liking, run the example to generate results: 60 | 61 | ``` 62 | python simple.py 63 | ``` 64 | 65 | ### 2. 
With a connections.toml file 66 | 67 | To set up a `connections.toml` file to store aliased Snowflake 68 | connections and all the parameters needed to connect, see the 69 | [snowflake connector docs]( 70 | https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-connect#connecting-using-the-connections-toml-file) 71 | 72 | An example is found under `example_connections.toml`, but make sure to 73 | name yours `connections.toml` and ensure it is located at a valid 74 | system-dependent path as specified in the Snowflake connector docs. 75 | 76 | Then, export your Cortex Search Service name (assuming the database 77 | and schema are already in your connection parameters): 78 | 79 | ``` 80 | export SNOWFLAKE_CORTEX_SEARCH_SERVICE=my_service 81 | ``` 82 | 83 | You can then run the `using_connections_config.py` file to print search results: 84 | 85 | ``` 86 | python using_connections_config.py 87 | ``` -------------------------------------------------------------------------------- /projects/improving-catalog-search/README.md: -------------------------------------------------------------------------------- 1 | # Cortex Search E-commerce Examples 2 | 3 | This repository contains companion notebooks for the blog post "Improving E-commerce Search: Intelligent Ranking Made Simple with Snowflake Cortex Search" (link coming soon). These notebooks demonstrate how to set up, evaluate, and optimize Cortex Search for e-commerce use cases. 4 | 5 | ## Notebooks Overview 6 | 7 | ### 1. `wands_ingest.ipynb` 8 | This notebook shows how to ingest and prepare the WANDS (Wayfair Product Search) dataset for use with Cortex Search: 9 | - Downloads the WANDS dataset 10 | - Processes product features and creates a unified text column for search 11 | - Creates and configures a Cortex Search service with appropriate attributes 12 | 13 | ### 2. `soft_query_boost_example.ipynb` 14 | Demonstrates the implementation of smart query boosting in Cortex Search with example feature extraction strategy: 15 | - Shows how to extract main terms from search queries 16 | - Implements soft boost logic for queries with supplemental signals 17 | - Provides examples of query processing for products with multiple aspects 18 | 19 | ### 3. `llm_judge_product_search.ipynb` 20 | Illustrates the query label extraction process using LLM-based relevance judgments: 21 | - Implements structured prompts for LLM-based relevance evaluation 22 | - Shows how to process and evaluate search results 23 | - Demonstrates integration with Snowflake's Cortex Search service 24 | 25 | ## Prerequisites 26 | 27 | - Snowflake account with access to Cortex Search 28 | - Python 3.11+ 29 | - Required Python packages: 30 | - snowflake-snowpark-python 31 | - pandas 32 | - numpy 33 | 34 | ## Setup Instructions 35 | 36 | 1. Clone this repository 37 | 2. Set up your Snowflake connection: 38 | ```python 39 | from snowflake.snowpark.context import get_active_session 40 | session = get_active_session() 41 | ``` 42 | 3. 
Follow the notebooks in this recommended order: 43 | - Start with `wands_ingest.ipynb` to set up your data 44 | - Explore `soft_query_boost_example.ipynb` for query optimization 45 | - Use `llm_judge_product_search.ipynb` for automated relevance tuning 46 | 47 | ## Data Requirements 48 | 49 | The notebooks work with several e-commerce datasets: 50 | - [WANDS (Wayfair Product Search Dataset)](https://github.com/wayfair/WANDS) 51 | - [TREC Product Search 2023](https://arxiv.org/abs/2311.07861) 52 | - [TREC Product Search 2024](https://trec-product-search.github.io/index.html) 53 | - [Amazon ESCI](https://github.com/amazon-research/esci-data) 54 | 55 | Note: Some datasets may require separate download and setup. Please refer to the individual dataset documentation for access instructions. 56 | 57 | ## Additional Resources 58 | 59 | - [Cortex Search Documentation](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/overview-tutorials) 60 | - [Cortex Search Technical Blog](https://www.snowflake.com/engineering-blog/cortex-search-and-retrieval-enterprise-ai/) 61 | - [Cortex Search Optimization App](../../examples/streamlit-evaluation/) 62 | 63 | ## License 64 | 65 | Please refer to the license terms for individual datasets and Snowflake's terms of service for Cortex Search usage. 66 | -------------------------------------------------------------------------------- /examples/02_rest_api_simple_usage/interactive_query.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import argparse 3 | 4 | import sys 5 | from pathlib import Path 6 | import json 7 | 8 | # Add the project root to sys.path 9 | sys.path.append(str(Path(__file__).parent.parent)) 10 | from src.utils import jwtutils, queryutils 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser( 15 | description="Interactive Cortex Search Service API Query CLI" 16 | ) 17 | parser.add_argument( 18 | "-u", 19 | "--url", 20 | help="Snowflake account URL", 21 | required=True, 22 | ) 23 | parser.add_argument( 24 | "-a", 25 | "--account", 26 | help="Snowflake account name, eg my_org-my_account", 27 | required=True, 28 | ) 29 | parser.add_argument( 30 | "-n", 31 | "--user-name", 32 | help="Snowflake user name", 33 | required=True, 34 | ) 35 | parser.add_argument( 36 | "-s", 37 | "--qualified-service-name", 38 | help="Qualified name of the Cortex Search Service, eg `MY_DB.MY_SCHEMA.MY_SERVICE`. Case insensitive.", 39 | required=True, 40 | ) 41 | parser.add_argument( 42 | "-k", 43 | "--private-key-path", 44 | help="Absolute local path to RSA private key", 45 | required=True, 46 | ) 47 | parser.add_argument( 48 | "-c", 49 | "--columns", 50 | help="Comma-separated list of columns to return", 51 | required=True, 52 | ) 53 | parser.add_argument( 54 | "-q", 55 | "--query", 56 | help="Query string", 57 | required=True, 58 | ) 59 | parser.add_argument( 60 | "-r", 61 | "--role", 62 | help="User role to use for queries. 
If provided, a session token scoped to this role will be generated for authenticating to the API.", 63 | required=False, 64 | ) 65 | 66 | args = parser.parse_args() 67 | 68 | search_service = queryutils.CortexSearchService( 69 | args.private_key_path, 70 | args.url, 71 | args.account, 72 | args.user_name, 73 | args.qualified_service_name, 74 | args.role, 75 | ) 76 | 77 | print("\n\nWelcome to the interactive Cortex Search Service query CLI!\n\n") 78 | 79 | while True: 80 | query_input = input("Enter your search query or type 'exit' to quit: ").strip() 81 | if query_input.lower() == "exit": 82 | print("Exiting the program.") 83 | break 84 | 85 | request_body = { 86 | "columns": args.columns.split(","), 87 | "query": query_input, # use the query entered at the prompt above 88 | "limit": 5, 89 | } 90 | 91 | response = search_service.search(request_body=request_body) 92 | 93 | if response is not None: 94 | if response.status_code != 200: 95 | print(f"Failed response with status code {response.status_code}") 96 | print(json.dumps(response.json(), indent=4)) 97 | else: 98 | print("Failed to fetch data.") 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /examples/06_streamlit_chatbot_app/README.md: -------------------------------------------------------------------------------- 1 | # RAG Chatbot with Cortex Search in Streamlit-in-Snowflake 2 | 3 | This repository walks you through creating a [Cortex Search Service](https://docs.snowflake.com/LIMITEDACCESS/cortex-search/cortex-search-overview) and chatting with it in a Streamlit-in-Snowflake interface. This application runs entirely in Snowflake. 4 | 5 | **Note**: Cortex Search is a private preview feature. To enable your account with the feature, please reach out to your account team. 6 | 7 | ## Instructions: 8 | - Open a Snowflake SQL environment and run `Section 1` of [`setup.sql`](./setup.sql). 9 | - Upload the files from [this folder](https://drive.google.com/drive/folders/1_erdfr7ZR49Ub2Sw-oGMJLs3KvJSap0d?usp=sharing) to the stage you created in Section 1. These files are recent FOMC minutes from the US Federal Reserve. 10 | - You can use this script with any set of files, but this example provides a specific set. You can follow the instructions to [upload files using Snowsight](https://docs.snowflake.com/en/user-guide/data-load-local-file-system-stage-ui) or, alternatively, [via SnowSQL or drivers](https://docs.snowflake.com/en/user-guide/data-load-local-file-system-stage). 11 | - Note: The PyPDF2 library is not well optimized for PDFs that contain more complex structures such as tables, charts, images, etc. This parsing approach thus works best on simple, text-heavy PDFs. 12 | - Run Sections 2 - 4 of `setup.sql` to parse the files and create the Cortex Search Service. 13 | - In Snowsight, create a Streamlit-in-Snowflake app in the same schema in which you created the Cortex Search Service (`demo_cortex_search.fomc`). In this Streamlit app, add the following Anaconda packages: 14 | - `snowflake==0.8.0` 15 | - `snowflake-ml-python==1.5.1` 16 | - Copy and paste the [`chat.py`](./chat.py) script into your Streamlit-in-Snowflake application. Select your Cortex Search Service from the sidebar, then start chatting!
17 | - You can configure advanced chat settings in the sidebar of the Streamlit app: 18 | - In the `Advanced Options` container, you can select the model used for answer generation, the number of context chunks used in each answer, and the depth of chat history messages to use in generation of a new response. 19 | - Toggling the `Debug` slider enables the printing of all context documents used in the model's answer in the sidebar. 20 | - The `Session State` container shows the session state, including chat messages and currently-selected service. 21 | 22 | ## Sample queries: 23 | - **Example session 1**: multi-turn with point lookups 24 | - `how was gpd growth in q4 23?` 25 | - `how was unemployment in the same quarter?` 26 | - **Example 2**: summarizing multiple documents 27 | - `how has the fed's view of the market change over the course of 2024?` 28 | - **Example 3**: abstaining when the documents don't contain the right answer 29 | - `What was janet yellen's opinion about 2024 q1?` 30 | 31 | ## Components 32 | 1. [`setup.sql`](./setup.sql): this is a script that shows how to create a Cortex Search Service from a set of PDFs residing in a Snowflake stage, including parsing, chunking, and service creation. 33 | 2. [`chat.py`](./chat.py): this is a generic template for a Streamlit RAG chatbot that uses Cortex Search for retrieval and Cortex LLM Functions for generation. This script is meant to be used in [Streamlit-in-Snowflake](https://docs.snowflake.com/en/developer-guide/streamlit/about-streamlit). The script shows off basic RAG orchestration techniques, with chat history summarization and prompt engineering. Feel free to customize it to your specific needs. You can use this script with any Cortex Search Service, not just the one we create in this tutorial. 34 | 35 | 36 | ## Improvements 37 | - Add in-line citations with links to source chunks 38 | - Add abstaining logic to the orchestration when no relevant documents were retrieved for the user's query 39 | - Add query classification to respond quickly to pleasantries -------------------------------------------------------------------------------- /projects/improving-catalog-search/soft_query_boost_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "7b3e0c9f-539c-4684-af76-35c7ef55de39", 7 | "metadata": { 8 | "scrolled": true 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "import pandas as pd\n", 13 | "from collections import defaultdict" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "5822d193-7cd5-4360-a809-dca5109e4332", 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "name": "stdout", 24 | "output_type": "stream", 25 | "text": [ 26 | "query: plantar fasciiti brace with ball\n", 27 | "mainterm: plantar fasciiti brace\n", 28 | "\n", 29 | "query: refill ink kit for printer\n", 30 | "mainterm: refill ink kit\n", 31 | "\n", 32 | "query: key fob cover for 4 runner\n", 33 | "mainterm: key fob cover\n", 34 | "\n", 35 | "query: phone cover for iphone 8 plus\n", 36 | "mainterm: phone cover\n", 37 | "\n", 38 | "query: wireless game controller for ipad\n", 39 | "mainterm: wireless game controller\n", 40 | "\n", 41 | "query: replacement cushion for headphone\n", 42 | "mainterm: replacement cushion\n", 43 | "\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "# 1. 
Load TREC24 product search dataset with QUERY field\n", 49 | "goldens_df = pd.read_parquet(\"/notebook/trec24_golden.parquet\")\n", 50 | "\n", 51 | "# 2. Extract mainterm from QUERY field using simple heuristic\n", 52 | "# - If the query contains a signal word (e.g., \"for\", \"with\", \"without\"), the mainterm is the substring before the signal word\n", 53 | "query_to_mainterm = defaultdict(dict)\n", 54 | "supplemental_signals = {'for', 'with', 'without'}\n", 55 | "\n", 56 | "for i in range(goldens_df.shape[0]):\n", 57 | " query_id = str(goldens_df.QUERY_ID[i])\n", 58 | " query = str(goldens_df.QUERY[i])\n", 59 | " tokens = query.split()\n", 60 | " mainterm = None\n", 61 | " for i, t in enumerate(tokens):\n", 62 | " if t.lower().strip(',').strip('.') in supplemental_signals:\n", 63 | " mainterm = ' '.join(tokens[:i])\n", 64 | " break\n", 65 | " query_to_mainterm[query_id][\"query\"] = query\n", 66 | " query_to_mainterm[query_id][\"mainterm\"] = mainterm\n", 67 | "\n", 68 | "# 3. Add SOFTBOOST field to goldens_df for Cortex Search\n", 69 | "goldens_df[\"SOFTBOOST\"] = goldens_df.QUERY_ID.apply(\n", 70 | " lambda query_id: (\n", 71 | " [{\"phrase\": query_to_mainterm[str(query_id)][\"mainterm\"]}]\n", 72 | " if query_to_mainterm[str(query_id)][\"mainterm\"] else []\n", 73 | " )\n", 74 | ")\n", 75 | "\n", 76 | "# 4. Sanity check with a few sample values.\n", 77 | "for i, data in enumerate(query_to_mainterm.values()):\n", 78 | " if i < 20 and data[\"mainterm\"]:\n", 79 | " print('query: ' + data[\"query\"])\n", 80 | " print('mainterm: ' + data[\"mainterm\"])\n", 81 | " print()" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "id": "2a865604", 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [] 91 | } 92 | ], 93 | "metadata": { 94 | "kernelspec": { 95 | "display_name": ".venv", 96 | "language": "python", 97 | "name": "python3" 98 | }, 99 | "language_info": { 100 | "codemirror_mode": { 101 | "name": "ipython", 102 | "version": 3 103 | }, 104 | "file_extension": ".py", 105 | "mimetype": "text/x-python", 106 | "name": "python", 107 | "nbconvert_exporter": "python", 108 | "pygments_lexer": "ipython3", 109 | "version": "3.13.1" 110 | } 111 | }, 112 | "nbformat": 4, 113 | "nbformat_minor": 5 114 | } 115 | -------------------------------------------------------------------------------- /examples/04_python_concurrent_queries/concurrent-example.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import backoff 5 | 6 | from concurrent.futures import ThreadPoolExecutor, as_completed 7 | from snowflake.snowpark import Session 8 | from snowflake.core.exceptions import APIError 9 | from snowflake.core import Root 10 | from dotenv import load_dotenv 11 | 12 | # Setup exponentioal backoff on "too many requests" errors (status: 429), with 13 | # up to 10 retries. 14 | @backoff.on_exception(backoff.expo, 15 | APIError, 16 | max_time=60, # seconds 17 | max_tries=10, 18 | giveup=lambda x : x.status != 429) 19 | def runQuery(service, query, columns): 20 | return service.search(query=query, columns=columns, limit=1) 21 | 22 | def main(): 23 | # Set environment variables from the .env file. 
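# (These values act only as defaults for the argparse options below; any command-line flag overrides them.)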
24 | load_dotenv("./.env") 25 | 26 | # Parse command line arguments 27 | parser = argparse.ArgumentParser(description="Concurrent example") 28 | parser.add_argument("--host", default=os.getenv("SNOWFLAKE_HOST"), help="Snowflake host") 29 | parser.add_argument("-a", "--account", default=os.getenv("SNOWFLAKE_ACCOUNT", "CORTEXSEARCH"), help="Snowflake account") 30 | parser.add_argument("-u", "--user", default=os.getenv("SNOWFLAKE_USER", os.getenv("USER")), help="Snowflake user") 31 | parser.add_argument("-p", "--password", default=os.getenv("SNOWFLAKE_PASSWORD"), help="Snowflake password") 32 | parser.add_argument("-d", "--database", default=os.getenv("SNOWFLAKE_DATABASE"), help="Snowflake database") 33 | parser.add_argument("-s", "--schema", default=os.getenv("SNOWFLAKE_SCHEMA"), help="Snowflake schema") 34 | parser.add_argument("-c", "--cortex_search_service", default=os.getenv("SNOWFLAKE_CORTEX_SEARCH_SERVICE"), help="Cortex search service") 35 | parser.add_argument("-q", "--queries", default="queries", help="File containing queries to be executed") 36 | parser.add_argument("--authenticator", default=os.getenv("SNOWFLAKE_AUTHENTICATOR"), help="Snowflake authenticator, e.g. externalbrowser. Must not be used with --password") 37 | parser.add_argument("--columns", nargs="+", default=[], help="Columns to be returned") 38 | args = parser.parse_args() 39 | 40 | # Create a session and a root object 41 | config = { 42 | "account": args.account, 43 | "user": args.user, 44 | } 45 | if args.host: 46 | config["host"] = args.host 47 | 48 | if args.authenticator: 49 | config["authenticator"] = args.authenticator 50 | elif args.password: 51 | config["password"] = args.password 52 | else: 53 | raise ValueError("Either password or authenticator must be provided") 54 | 55 | session = Session.builder.configs(config).create() 56 | root = Root(session) 57 | search_service = root.databases[args.database].schemas[args.schema].cortex_search_services[args.cortex_search_service] 58 | 59 | # Read a set of queries to be executed on the service above. 60 | queries = [] 61 | with open(args.queries, "r") as f: 62 | queries = f.readlines() 63 | 64 | # Execute queries in parallel with up to 20 concurrent threads 65 | with ThreadPoolExecutor(max_workers=20) as executor: 66 | fq = {executor.submit(runQuery, search_service, query, args.columns): query for query in queries if len(query) > 0} 67 | for f in as_completed(fq): 68 | query = fq[f] 69 | if len(query) > 40: 70 | query = query[:40] + "..." 71 | try: 72 | result = f.result() 73 | except APIError as ae: 74 | body = json.loads(ae.body) 75 | print("Failed to execute query %s (status: %s, code: %s, message: %s)" % (query, ae.status, body["code"], body["message"])) 76 | else: 77 | if (len(result.results) > 0): 78 | print("Top result for query %s : " % query, end=" ") 79 | for (key, value) in result.results[0].items(): 80 | print(f"{key}={value}") 81 | else: 82 | print("No results for query %s" % query) 83 | 84 | if __name__=="__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /examples/06_streamlit_chatbot_app/setup.sql: -------------------------------------------------------------------------------- 1 | -- *************************************** 2 | -- * Section 1: Setup * 3 | -- *************************************** 4 | 5 | -- Create the database, schema, stage, and warehouse for the demo.
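-- The names below (demo_cortex_search, fomc, minutes, demo_cortex_search_wh) are reused verbatim in Sections 2-4 and in the grants, so keep them consistent if you rename anything.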
6 | CREATE DATABASE demo_cortex_search; 7 | CREATE SCHEMA fomc; 8 | CREATE STAGE minutes 9 | DIRECTORY = ( ENABLE = true ) 10 | ENCRYPTION = ( TYPE = 'SNOWFLAKE_SSE' ); 11 | CREATE WAREHOUSE demo_cortex_search_wh; 12 | 13 | -- Upload pdfs to demo_cortex_search.fomc.minutes from https://drive.google.com/drive/folders/1_erdfr7ZR49Ub2Sw-oGMJLs3KvJSap0d?usp=sharing 14 | 15 | 16 | -- *************************************** 17 | -- * Section 2: UDTF Creation * 18 | -- *************************************** 19 | 20 | -- This section creates a user-defined table function (UDTF) to parse PDFs and chunk the extracted text. 21 | CREATE OR REPLACE FUNCTION pypdf_extract_and_chunk(file_url VARCHAR, chunk_size INTEGER, overlap INTEGER) 22 | RETURNS TABLE (chunk VARCHAR) 23 | LANGUAGE PYTHON 24 | RUNTIME_VERSION = '3.9' 25 | HANDLER = 'pdf_text_chunker' 26 | PACKAGES = ('snowflake-snowpark-python','PyPDF2', 'langchain') 27 | AS 28 | $$ 29 | from snowflake.snowpark.types import StringType, StructField, StructType 30 | from langchain.text_splitter import RecursiveCharacterTextSplitter 31 | from snowflake.snowpark.files import SnowflakeFile 32 | import PyPDF2, io 33 | import logging 34 | import pandas as pd 35 | 36 | class pdf_text_chunker: 37 | 38 | def read_pdf(self, file_url: str) -> str: 39 | 40 | logger = logging.getLogger("udf_logger") 41 | logger.info(f"Opening file {file_url}") 42 | 43 | with SnowflakeFile.open(file_url, 'rb') as f: 44 | buffer = io.BytesIO(f.readall()) 45 | 46 | reader = PyPDF2.PdfReader(buffer) 47 | text = "" 48 | for page in reader.pages: 49 | try: 50 | text += page.extract_text().replace('\n', ' ').replace('\0', ' ') 51 | except: 52 | text = "Unable to Extract" 53 | logger.warn(f"Unable to extract from file {file_url}, page {page}") 54 | 55 | return text 56 | 57 | 58 | def process(self,file_url: str, chunk_size: int, chunk_overlap: int): 59 | 60 | text = self.read_pdf(file_url) 61 | 62 | text_splitter = RecursiveCharacterTextSplitter( 63 | chunk_size = chunk_size, 64 | chunk_overlap = chunk_overlap, 65 | length_function = len 66 | ) 67 | 68 | chunks = text_splitter.split_text(text) 69 | df = pd.DataFrame(chunks, columns=['CHUNK']) 70 | 71 | yield from df.itertuples(index=False, name=None) 72 | $$; 73 | 74 | -- *************************************** 75 | -- * Section 3: Parse PDFs with UDTF * 76 | -- *************************************** 77 | 78 | -- This section parses the PDFs, chunks the output, and inserts the chunked documents into a table. 79 | CREATE OR REPLACE TABLE parsed_doc_chunks ( 80 | relative_path VARCHAR, -- Relative path to the PDF file 81 | chunk VARCHAR 82 | ) AS ( 83 | SELECT 84 | relative_path, 85 | chunks.chunk as chunk 86 | FROM 87 | directory(@DEMO_CORTEX_SEARCH.FOMC.MINUTES) 88 | , TABLE(pypdf_extract_and_chunk( 89 | build_scoped_file_url(@MINUTES, relative_path), 90 | 2000, 91 | 500 92 | )) as chunks 93 | ); 94 | 95 | 96 | -- *************************************** 97 | -- * Section 4: Create Cortex Search Service * 98 | -- *************************************** 99 | 100 | -- This section creates a Cortex search service with the parsed pdf data. 
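-- The service below searches over the MINUTES column (meeting date prepended to each chunk) and exposes RELATIVE_PATH as a filterable attribute.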
101 | 102 | SELECT * FROM parsed_doc_chunks; -- preview data 103 | 104 | CREATE OR REPLACE CORTEX SEARCH SERVICE fomc_minutes_search_service 105 | ON minutes 106 | ATTRIBUTES relative_path 107 | WAREHOUSE = demo_cortex_search_wh 108 | TARGET_LAG = '1 hour' 109 | AS ( 110 | SELECT 111 | LEFT(RIGHT(relative_path, 12), 8) as meeting_date, 112 | CONCAT('Meeting date: ', meeting_date, ' \nMinutes: ', chunk) as minutes, 113 | relative_path 114 | FROM parsed_doc_chunks 115 | ); 116 | 117 | -- grant usage to public role 118 | GRANT USAGE ON CORTEX SEARCH SERVICE fomc_minutes_search_service TO ROLE public; 119 | GRANT USAGE ON DATABASE demo_cortex_search to role public; 120 | GRANT USAGE ON SCHEMA demo_cortex_search.fomc to role public; 121 | GRANT READ ON STAGE demo_cortex_search.fomc.minutes to role public; -------------------------------------------------------------------------------- /examples/02_rest_api_simple_usage/README.md: -------------------------------------------------------------------------------- 1 | # Cortex Search REST usage 2 | 3 | This directory contains example usage for the Cortex Search REST API using pure python (and libraries such as `requests`). Authentication for this method requires a JWT as described below. 4 | 5 | ## Prerequisites 6 | 7 | Before you can run the examples, ensure you have the following prerequisites installed: 8 | 9 | - Python 3.x 10 | - pip (Python package installer) 11 | Additionally, you must have access to a Snowflake account and the required permissions to query the Cortex Search Service at the specified database and schema. 12 | 13 | ## Installation 14 | 15 | First, clone this repository to your local machine using git and navigate to this directory: 16 | 17 | ``` 18 | git clone https://github.com/snowflake-labs/cortex-search.git 19 | cd cortex-search/examples/02_rest_api_simple_usage 20 | ``` 21 | 22 | Install the necessary Python dependencies by running: 23 | 24 | ``` 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ## Key-pair auth configuration 29 | 30 | Additionally, you must generate a private key for JWT auth with Snowflake as described in [this document](https://docs.snowflake.com/user-guide/key-pair-auth#configuring-key-pair-authentication). 31 | 32 | **Note**: take note of the path to your generated RSA private key, e.g., `/path/to/my/rsa_key.p8` -- you will need to supply this as the `--private-key-path` parameter to query the service later from the command line, or list the path to the file from within a notebook. 33 | 34 | ## Usage 35 | 36 | ### 1. Notebook usage 37 | 38 | The [examples/notebook_query.ipynb file](https://github.com/Snowflake-Labs/cortex-search/blob/main/examples/notebook_query.ipynb) shows an example of querying the service from within a Jupyter Notebook. 39 | 40 | ### 2. Python script 41 | 42 | The `simple_query.py` example script can be executed from the command line. For instance: 43 | 44 | ``` 45 | python3 examples/simple_query.py -u https://my_org-my_account.us-west-2.aws.snowflakecomputing.com -s MY_DB.MY_SCHEMA.MY_SERVICE_NAME -q "the sky is blue" -c "description,text" -l 10 -a my_account -k /path/to/my/rsa_key.p8 -n my_name 46 | ``` 47 | 48 | **Arguments:** 49 | 50 | - `-u`, `--url`: URL of the Snowflake instance. 
See [this guide](https://docs.snowflake.com/en/user-guide/admin-account-identifier#finding-the-organization-and-account-name-for-an-account) for finding your Account URL 51 | - `-s`, `--qualified-service-name`: The fully-qualified Cortex Search Service name, in the format DATABASE.SCHEMA.SERVICE 52 | - `-q`, `--query`: The search query string 53 | - `-c`, `--columns`: Comma-separated list of columns to return in the results 54 | - `-l`, `--limit`: The max number of results to return 55 | - `-a`, `--account`: Snowflake account name. See [this guide](https://docs.snowflake.com/en/user-guide/admin-account-identifier#finding-the-organization-and-account-name-for-an-account) for finding your Account name 56 | - `-k`, `--private-key-path`: Path to the RSA private key file for authentication. 57 | - `-n`, `--user-name`: Username for the Snowflake account 58 | - `-r`, `--role`: Role to use for the query. If provided, a session token scoped to this role will be created and used for authentication to the API. 59 | 60 | The `interactive_query.py` example provides an interactive CLI that demonstrates caching the JWT used for authentication between requests for better performance and implements retries when the JWT has expired. You can run it like the following: 61 | 62 | ``` 63 | python3 examples/interactive_query.py -u https://my_org-my_account.us-west-2.aws.snowflakecomputing.com -s DB.SCHEMA.SERVICE_NAME -c "description,text" -a my_account -k /path/to/my/rsa_key.p8 -n my_name 64 | ``` 65 | 66 | This will launch an interactive session, where you will be prompted repeatedly for search queries to your Cortex Search Service. 67 | 68 | ### 3. Command line usage (cURL) 69 | 70 | First, generate a JWT. For instance, if you have a private RSA key at the relative path `rsa_key.p8`, you can run the following from a shell (passing your account and user): 71 | 72 | `snowsql --private-key-path rsa_key.p8 --generate-jwt -a my_org-my_account -u my_name` 73 | 74 | Then, export the following variables in your shell session: 75 | 76 | ``` 77 | export CORTEX_SEARCH_JWT=<JWT generated above> 78 | export CORTEX_SEARCH_DATABASE=MY_DB 79 | export CORTEX_SEARCH_SCHEMA=MY_SCHEMA 80 | export CORTEX_SEARCH_SERVICE_NAME=MY_SERVICE_NAME 81 | export CORTEX_SEARCH_BASE_URL='https://my_org-my_account.us-west-2.aws.snowflakecomputing.com' 82 | ``` 83 | 84 | Then, you can run the following cURL command (modifying the `data` passed as needed): 85 | 86 | ``` 87 | curl --location "$CORTEX_SEARCH_BASE_URL/api/v2/databases/$CORTEX_SEARCH_DATABASE/schemas/$CORTEX_SEARCH_SCHEMA/cortex-search-services/$CORTEX_SEARCH_SERVICE_NAME:query" \ 88 | --header 'X-Snowflake-Authorization-Token-Type: KEYPAIR_JWT' \ 89 | --header 'Content-Type: application/json' \ 90 | --header 'Accept: application/json' \ 91 | --header "Authorization: Bearer $CORTEX_SEARCH_JWT" \ 92 | --data '{ 93 | "query": "the sky is blue", 94 | "columns": ["description", "text"], 95 | "limit": 10 96 | }' 97 | ``` 98 | 99 | ## License 100 | 101 | Apache Version 2.0 102 | -------------------------------------------------------------------------------- /examples/02_rest_api_simple_usage/src/utils/queryutils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import argparse 3 | 4 | import sys 5 | from pathlib import Path 6 | import json 7 | 8 | # Add the project root to sys.path 9 | sys.path.append(str(Path(__file__).parent.parent)) 10 | from src.utils import jwtutils 11 | 12 | 13 | class CortexSearchService: 14 | 15 | private_key_path = None 16 |
17 | base_url = None 18 | account = None 19 | user = None 20 | role = None 21 | 22 | database = None 23 | schema = None 24 | service = None 25 | 26 | cached_jwt = None 27 | session_token = None 28 | 29 | def __init__( 30 | self, 31 | private_key_path: str, 32 | account_url: str, 33 | account: str, 34 | user: str, 35 | qualified_service_name: str, 36 | role: str | None = None, 37 | ): 38 | self.private_key_path = private_key_path 39 | self.base_url = account_url 40 | self.account = account 41 | self.user = user 42 | service_name_parts = qualified_service_name.split(".") 43 | if len(service_name_parts) != 3: 44 | raise ValueError( 45 | f"Expected qualified name to have DB, schema, and name components; got {qualified_service_name}" 46 | ) 47 | self.database = service_name_parts[0] 48 | self.schema = service_name_parts[1] 49 | self.service = service_name_parts[2] 50 | self.role = role 51 | 52 | def search(self, request_body: dict[str, any]) -> requests.Response: 53 | """ 54 | Perform a POST request to the Cortex Search query API. 55 | 56 | :param request_body: The query body. 57 | :return: The requests.Response object. 58 | """ 59 | url = self._make_query_url() 60 | if self.role is None: 61 | return self._make_request_with_jwt(url, request_body) 62 | 63 | return self._make_request_with_session_token(url, request_body) 64 | 65 | def _make_headers(self, use_jwt=True) -> dict[str, str]: 66 | if not use_jwt: 67 | return self._make_headers_with_session_token() 68 | if self.cached_jwt is None: 69 | self.cached_jwt = jwtutils.generate_JWT_token( 70 | self.private_key_path, self.account, self.user 71 | ) 72 | headers = { 73 | "X-Snowflake-Authorization-Token-Type": "KEYPAIR_JWT", 74 | "Content-Type": "application/json", 75 | "Accept": "application/json", 76 | "Authorization": f"Bearer {self.cached_jwt}", 77 | } 78 | return headers 79 | 80 | def _make_headers_with_session_token(self) -> dict[str, str]: 81 | if self.session_token is None: 82 | self.session_token = self._create_session_token() 83 | headers = { 84 | "Content-Type": "application/json", 85 | "Accept": "application/json", 86 | "Authorization": f'Snowflake Token="{self.session_token}"', 87 | } 88 | return headers 89 | 90 | def _create_session_token(self) -> str: 91 | url = self._make_sessions_url() 92 | request_body = {"roleName": self.role} 93 | res = self._make_request_with_jwt(url, request_body).json() 94 | return res["token"] 95 | 96 | def _make_query_url(self) -> str: 97 | return f"{self.base_url}/api/v2/databases/{self.database}/schemas/{self.schema}/cortex-search-services/{self.service}:query" 98 | 99 | def _make_sessions_url(self) -> str: 100 | return f"{self.base_url}/api/v2/sessions" 101 | 102 | def _make_request_with_session_token( 103 | self, url, request_body 104 | ) -> requests.Response | None: 105 | try: 106 | headers = self._make_headers(use_jwt=False) 107 | response = requests.post(url, headers=headers, json=request_body) 108 | response.raise_for_status() 109 | return response 110 | except requests.exceptions.HTTPError as http_err: 111 | print(f"HTTP error occurred: {http_err}") 112 | return response 113 | 114 | def _make_request_with_jwt( 115 | self, url, request_body, retry_for_invalid_jwt=True 116 | ) -> requests.Response | None: 117 | try: 118 | headers = self._make_headers() 119 | response = requests.post(url, headers=headers, json=request_body) 120 | response.raise_for_status() 121 | return response 122 | except requests.exceptions.HTTPError as http_err: 123 | # An invalid JWT may be due to expiration of the cached JWT. 
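# (jwtutils.generate_JWT_token issues tokens with a 60-minute lifetime, so long-lived sessions will eventually hit this path.)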
124 | if retry_for_invalid_jwt and self._is_invalid_jwt_response(response): 125 | # Only retry with a fresh JWT once. 126 | print("JWT invalid, trying with a fresh JWT...") 127 | self.cached_jwt = None 128 | return self._make_request_with_jwt( 129 | url, request_body, retry_for_invalid_jwt=False 130 | ) 131 | print(f"HTTP error occurred: {http_err}") 132 | return response 133 | except Exception as err: 134 | print(f"An error occurred: {err}") 135 | return None 136 | 137 | def _is_invalid_jwt_response(self, response: requests.Response) -> bool: 138 | return response.status_code == 401 and response.json()["code"] == "390144" 139 | -------------------------------------------------------------------------------- /examples/03_batch_querying_from_sql/batch_sproc.sql: -------------------------------------------------------------------------------- 1 | /************************************************************************************ 2 | Name: batch_sproc.sql 3 | 4 | Purpose: This script creates a 'batch_cortex_search' Stored Procedure that calls a 5 | Cortex Search Service with parallelism on an array of supplied query values. The script 6 | shows the creation of the SProc and sample invocation on queries from a table column. 7 | 8 | This script assumes you have an existing Cortex Search Service and table containing 9 | queries you'd like to issue against it. 10 | /************************************************************************************ 11 | 12 | /************************************************************************************ 13 | 1. Create the SProc 14 | ************************************************************************************/ 15 | 16 | CREATE OR REPLACE PROCEDURE batch_cortex_search(db_name STRING, schema_name STRING, service_name STRING, queries ARRAY, filters ARRAY, columns ARRAY, n_jobs INTEGER DEFAULT -1) 17 | RETURNS VARIANT 18 | LANGUAGE PYTHON 19 | PACKAGES = ('snowflake-snowpark-python==1.9.0', 'joblib==1.4.2', 'backoff==2.2.1') 20 | RUNTIME_VERSION = '3.10' 21 | HANDLER = 'main' 22 | as 23 | $$ 24 | import _snowflake 25 | import json 26 | import time 27 | from joblib import Parallel, delayed 28 | import backoff 29 | 30 | @backoff.on_exception(backoff.expo, Exception, max_tries=5, giveup=lambda e: not (isinstance(e, Exception) and hasattr(e, "args") and len(e.args) > 0 and isinstance(e.args[0], dict) and e.args[0].get("status") == 429)) 31 | def call_api(db_name, schema_name, service_name, request_body): 32 | """Calls the Cortex Search REST API with retry logic for rate limiting.""" 33 | resp = _snowflake.send_snow_api_request( 34 | "POST", 35 | f"/api/v2/databases/{db_name}/schemas/{schema_name}/cortex-search-services/{service_name}:query", 36 | {}, 37 | {}, 38 | request_body, 39 | {}, 40 | 30000, 41 | ) 42 | if resp["status"] == 429: 43 | raise Exception({"status": resp["status"], "content": resp["content"]}) 44 | return resp 45 | 46 | def search(db_name, schema_name, service_name, query, columns, filter): 47 | """Calls the Cortex Search REST API and returns the response.""" 48 | 49 | request_body = { 50 | "query": query, 51 | "columns": columns, 52 | "filter": filter, 53 | "limit": 5 54 | } 55 | try: 56 | resp = call_api(db_name, schema_name, service_name, request_body) 57 | if resp["status"] < 400: 58 | response_content = json.loads(resp["content"]) 59 | results = response_content.get("results", []) 60 | return {"query": query, "filter": filter, "results": results} 61 | else: 62 | return {"query": query, "filter": filter, "results": f"Failed request 
with status {resp['status']}: {resp}"} 63 | except Exception as e: 64 | return {"query": query, "filter": filter, "results": f"API Error: {e}"} 65 | 66 | def concurrent_searches(db_name, schema_name, service_name, queries, filters, columns, n_jobs): 67 | """Calls the Cortex Search REST API for multiple queries and returns the responses.""" 68 | 69 | results = Parallel(n_jobs=n_jobs, backend='threading')( 70 | delayed(search)(db_name, schema_name, service_name, q, columns, f) for q, f in zip(queries, filters) 71 | ) 72 | return results 73 | 74 | 75 | def main(db_name, schema_name, service_name, queries, filters, columns, n_jobs): 76 | if isinstance(queries, list) and isinstance(filters, list): 77 | if len(queries) == len(filters): 78 | if len(queries) >= 1: 79 | return concurrent_searches(db_name, schema_name, service_name, queries, filters, columns, n_jobs) 80 | else: 81 | raise ValueError("Queries must be a non-empty array of query text") 82 | else: 83 | raise ValueError("Queries and filters must have the same length") 84 | else: 85 | raise ValueError("Queries and filters must each be an array") 86 | $$; 87 | 88 | /************************************************************************************ 89 | 2. Call the SProc and materialize results in a table 90 | 91 | Variables: 92 | - <query_table>: the table containing queries in a <query_col> column 93 | - <query_col>: the name of the query column in the <query_table> 94 | - <filter_col>: the name of the filter column in the <query_table> 95 | - <col_to_return_N>: a desired column to return from the Cortex Search Service 96 | 97 | Note: if you do not have filters to apply to queries, you can specify an empty object 98 | for the FILTERS argument like: 99 | 100 | `(SELECT ARRAY_AGG({}) FROM <query_table>)` 101 | 102 | ************************************************************************************/ 103 | CALL batch_cortex_search( 104 | '<db_name>', 105 | '<schema_name>', 106 | '<service_name>', 107 | (SELECT ARRAY_AGG(<query_col>) FROM <query_table>), 108 | (SELECT ARRAY_AGG(<filter_col>) FROM <query_table>), 109 | ARRAY_CONSTRUCT('<col_to_return_1>', '<col_to_return_2>', ...), 110 | -1 111 | ); 112 | 113 | CREATE OR REPLACE TEMP TABLE RESULTS AS 114 | SELECT * FROM TABLE(RESULT_SCAN(LAST_QUERY_ID())); 115 | 116 | /************************************************************************************ 117 | 3. View results 118 | ************************************************************************************/ 119 | SELECT 120 | value['query'] as query, 121 | value['results'][0]['col_to_return_1'] as col_1, 122 | value['results'][0]['col_to_return_2'] as col_2, 123 | -- .. 124 | value['filter'] as filter 125 | FROM RESULTS r, LATERAL FLATTEN(r.batch_cortex_search) -------------------------------------------------------------------------------- /examples/07_streamlit_search_evaluation_app/README.md: -------------------------------------------------------------------------------- 1 | # Tutorial: Evaluating Cortex Search Quality with Streamlit 2 | 3 | _Last updated: Jan 15, 2024_ 4 | 5 | This tutorial walks you through Cortex Search Quality evaluation in a Streamlit-in-Snowflake app. By the end of this tutorial, you will have generated and run a quantitative evaluation of search quality on your use case for a given Cortex Search Service.
6 | 7 | ## Prerequisites 8 | 9 | There are three objects you’ll need before beginning the evaluation process: 10 | 11 | * A [Cortex Search Service](https://docs.snowflake.com/user-guide/snowflake-cortex/cortex-search/cortex-search-overview) 12 | * A query table ([example](https://docs.google.com/spreadsheets/d/1q4RMplovT5lyt-zC4Y-ncf_sl4f8qEn6ydIVwkCSqP8/edit?gid=214438211#gid=214438211), description below) 13 | * A relevancy table ([example](https://docs.google.com/spreadsheets/d/1q4RMplovT5lyt-zC4Y-ncf_sl4f8qEn6ydIVwkCSqP8/edit?gid=0#gid=0), description below) 14 | 15 | ### Building a query table 16 | 17 | First, you’ll need a query table. Queries are the basic “inputs” in a search system. A query table has the set of questions, representative of your production workload, that you would like to retrieve answers for using Cortex Search Service. The table consists of just one column, “QUERY”, and looks like the following: 18 | 19 | | QUERY | 20 | | :---- | 21 | | <query text> | 22 | 23 | See [here](https://docs.google.com/spreadsheets/d/1q4RMplovT5lyt-zC4Y-ncf_sl4f8qEn6ydIVwkCSqP8/edit?gid=214438211#gid=214438211) for an example of a small QUERY table. 24 | 25 | The best source of a query table is real-life customer queries. However, if you don’t have a list of real customer queries, you could consider generating them manually by reading through your document corpus and generating queries that should hit random documents, or synthetically with an LLM. One way to generate synthetic queries is to query an LLM for a given document using a prompt “*Given this text: {insert text here}, generate 5 natural search queries users might input to find this.*” 26 | 27 | ### Building a relevancy table 28 | 29 | Then, you’ll need a relevancy table. You can think of this table as the labeled “outputs” for your query “inputs”, generated above. The relevancy table contains a set of (query, doc, relevancy) triples, where relevancy is the perceived relevance of the document to the given query. Relevancy is represented by a *relevance score*, defined on a scale from **0 to 3**, where: 30 | 31 | * **0 (Irrelevant):** The text does not address or relate to the query in any meaningful way. 32 | * **1 (Slightly Relevant):** The text contains minor or tangential information that may only loosely relate to the query. 33 | * **2 (Somewhat Relevant):** The text provides partial or incomplete information related to the query but does not fully satisfy the intent. 34 | * **3 (Perfectly Relevant):** The text fully and comprehensively addresses the query, answering it effectively and directly. 35 | 36 | The relevancy table should have three columns: QUERY, the search column of your Cortex Search Service (containing the document text), and RELEVANCY, the relevance score of that text with respect to the query. The relevancy table should look like the following: 37 | 38 | | QUERY | *`<search_col>`* | RELEVANCY | 39 | | :---- | :---- | :---- | 40 | | <query> | <document text> | <relevance score> | 41 | 42 | Note: *`<search_col>`* is the name of the search column in your Cortex Search Service. 43 | 44 | See [here](https://docs.google.com/spreadsheets/d/1q4RMplovT5lyt-zC4Y-ncf_sl4f8qEn6ydIVwkCSqP8/edit?gid=0#gid=0) for an example of a small RELEVANCY table. 45 | 46 | Now that you have these three objects, you’re ready to get started with the evaluation process. 47 | 48 | ## Step 1: Create the Evaluation Streamlit-in-Snowflake App 49 | 50 | * Create a new Streamlit-in-Snowflake App. For help on creating a Streamlit in Snowflake application, see [here](https://docs.snowflake.com/en/developer-guide/streamlit/create-streamlit-ui).
51 | * Install the following packages in the Streamlit app: 52 | Python: \>= 3.11 53 | * `snowflake` 54 | * `streamlit` 55 | * `snowflake-snowpark-python` 56 | * `scikit-optimize` \>= 0.9.0 57 | * Copy and paste the contents of the [eval.py file here](https://github.com/Snowflake-Labs/cortex-search/blob/main/examples/streamlit-evaluation/eval.py) into your new application and click `Run`. 58 | 59 | ## Step 2: Run the app 60 | 61 | Now, open the app and follow the in-app instructions to run the evaluation. 62 | 63 | ## Step 3: Interpret metrics 64 | 65 | 🎉 **Evaluation Completed\!** 66 | You have successfully run an evaluation of the Cortex Search Service against your dataset. Here’s a summary to help you interpret the results: 67 | 68 | 1. **Section: Static results:** Key metrics summarizing the retrieval performance of the Cortex Search Service, averaged across all queries in the dataset, along with run metadata. 69 | 2. **Section: Interactive Results (at a chosen k)**. This section computes metrics at the chosen result limit, `k`. Choose a value of k that suits your downstream use case (commonly 5 or 10). 70 | 1. Aggregate Summary: High-level retrieval metrics calculated at k, with short explanations. 71 | 2. Per Query Metrics: Lets you drill down into the performance of individual queries, for example those that failed or scored poorly. 72 | 1. Sort results: Click on column headers to sort. For example, if you are curious to know which queries had low NDCG, you can click on the NDCG@k column to sort by it. 73 | 2. Inspect Specific Queries: Click the index of any query row to explore details. View all top-k retrieved documents for the selected query, along with their Cortex Search Service scores. 74 | 3. Download all queries and scores: Export all top-k results, including scores for all queries, by clicking the Download button. 75 | 76 | ## Step 4 (optional): Modify data, re-run 77 | 78 | If you update your Cortex Search Service definition or query/relevancy tables, you can then rerun the evaluation to check the updated results.
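If you still need to create the query and relevancy tables described in the Prerequisites section, the sketch below shows one way to do so with Snowpark from a Snowflake notebook or Streamlit-in-Snowflake app. The database, schema, and table names, the example rows, and the `CHUNK_TEXT` column are illustrative assumptions, not part of this tutorial; substitute your own names and use your service's actual search column.

```python
# Minimal sketch (assumed names throughout): create the QUERY and RELEVANCY tables
# used by the evaluation app, running inside Snowflake where an active session exists.
from snowflake.snowpark.context import get_active_session

session = get_active_session()

# Query table: a single QUERY column holding representative user queries.
session.sql(
    "CREATE OR REPLACE TABLE MY_DB.MY_SCHEMA.EVAL_QUERIES (QUERY VARCHAR)"
).collect()
session.sql(
    """INSERT INTO MY_DB.MY_SCHEMA.EVAL_QUERIES (QUERY) VALUES
         ('how do I rotate my account keys?'),
         ('what warehouse sizes are available?')"""
).collect()

# Relevancy table: (query, document text, 0-3 relevance score) rows. CHUNK_TEXT stands
# in for the search column name of your Cortex Search Service.
session.sql(
    """CREATE OR REPLACE TABLE MY_DB.MY_SCHEMA.EVAL_RELEVANCY (
           QUERY VARCHAR,
           CHUNK_TEXT VARCHAR,
           RELEVANCY NUMBER
       )"""
).collect()
```

With these tables in place, you can follow the in-app instructions from Step 2 against your own data.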
79 | -------------------------------------------------------------------------------- /examples/05_streamlit_ai_search_app/cortex_search.py: -------------------------------------------------------------------------------- 1 | import streamlit as st # Import python packages 2 | from snowflake.core import Root 3 | 4 | MODELS = [ 5 | "mistral-large", 6 | "snowflake-arctic", 7 | "llama3-70b", 8 | "llama3-8b", 9 | ] 10 | 11 | @st.cache_resource 12 | def make_session(): 13 | """ 14 | Initialize the connection to snowflake using the `conn` connection stored in `.streamlit/secrets.toml` 15 | """ 16 | conn = st.connection("conn", type="snowflake") 17 | return conn.session() 18 | 19 | def get_available_search_services(): 20 | """ 21 | Returns list of cortex search services available in the current schema 22 | """ 23 | search_service_results = session.sql(f"SHOW CORTEX SEARCH SERVICES IN SCHEMA {db}.{schema}").collect() 24 | return [svc.name for svc in search_service_results] 25 | 26 | def get_search_column(svc): 27 | """ 28 | Returns the name of the search column for the provided cortex search service 29 | """ 30 | search_service_result = session.sql(f"DESC CORTEX SEARCH SERVICE {svc}").collect()[0] 31 | return search_service_result.search_column 32 | 33 | def init_layout(): 34 | st.title("Cortex AI Search and Summary") 35 | st.sidebar.markdown(f"Current database and schema: `{db}.{schema}`".replace('"', '')) 36 | 37 | def init_config_options(): 38 | """ 39 | Initialize sidebar configuration options 40 | """ 41 | st.text_area("Search:", value="", key="query", height=100) 42 | st.sidebar.selectbox("Cortex Search Service", get_available_search_services(), key="cortex_search_service") 43 | st.sidebar.number_input("Results", value=5, key="limit", min_value=3, max_value=10) 44 | st.sidebar.selectbox("Summarization model", MODELS, key="model") 45 | st.sidebar.toggle("Summarize", key="summarize", value = False) 46 | 47 | def query_cortex_search_service(query): 48 | """ 49 | Queries the cortex search service in the session state and returns a list of results 50 | """ 51 | cortex_search_service = ( 52 | root 53 | .databases[db] 54 | .schemas[schema] 55 | .cortex_search_services[st.session_state.cortex_search_service] 56 | ) 57 | context_documents = cortex_search_service.search(query, [], limit=st.session_state.limit) 58 | return context_documents.results 59 | 60 | def complete(model, prompt): 61 | """ 62 | Queries the cortex COMPLETE LLM function with the provided model and prompt 63 | """ 64 | try: 65 | resp = session.sql("select snowflake.cortex.complete(?, ?)", params=(model, prompt)).collect()[0][0] 66 | except Exception as e: 67 | resp = f"COMPLETE error: {e}" 68 | return resp 69 | 70 | def summarize_search_results(results, query, search_col): 71 | """ 72 | Returns an AI summary of the search results based on the user's query 73 | """ 74 | search_result_str = "" 75 | for i, r in enumerate(results): 76 | search_result_str += f"Result {i+1}: {r[search_col]} \n" 77 | 78 | prompt = f""" 79 | [INST] 80 | You are a helpful AI Assistant embedded in a search application. You will be provided a user's search query and a set of search result documents. 81 | Your task is to provide a concise, summarized answer to the user's query with the help of the provided the search results. 82 | 83 | The user's query will be included in the tag. 84 | The search results will be provided in JSON format in the tag. 
85 | 86 | Here are the critical rules that you MUST abide by: 87 | - You must only use the provided search result documents to generate your summary. Do not fabricate any answers or use your prior knowledge. 88 | - You must only summarize the search result documents that are relevant to the user's query. Do not reference any search results that aren't related to the user's query. If none of the provided search results are relevant to the user's query, reply with "My apologies, I am unable to answer that question with the provided search results". 89 | - You must keep your summary to less than 10 sentences. You are encouraged to use bulleted lists, where stylistically appropriate. 90 | - Only respond with the summary without any extra explantion. Do not include any sentences like 'Sure, here is an explanation...'. 91 | 92 | 93 | {query} 94 | 95 | 96 | 97 | {search_result_str} 98 | 99 | 100 | [/INST] 101 | """ 102 | 103 | resp = complete(st.session_state.model, prompt) 104 | return resp 105 | 106 | def display_summary(summary): 107 | """ 108 | Display the AI summary in the UI 109 | """ 110 | st.subheader("AI summary") 111 | container = st.container(border=True) 112 | container.markdown(summary) 113 | 114 | def display_search_results(results, search_col): 115 | """ 116 | Display the search results in the UI 117 | """ 118 | st.subheader("Search results") 119 | for i, result in enumerate(results): 120 | container = st.expander(f"Result {i+1}", expanded=True) 121 | container.markdown(result[search_col]) 122 | 123 | def main(): 124 | init_layout() 125 | init_config_options() 126 | 127 | # run chat engine 128 | if not st.session_state.query: 129 | return 130 | results = query_cortex_search_service(st.session_state.query) 131 | search_col = get_search_column(st.session_state.cortex_search_service) 132 | if st.session_state.summarize: 133 | with st.spinner("Summarizing results..."): 134 | summary = summarize_search_results(results, st.session_state.query, search_col) 135 | display_summary(summary) 136 | display_search_results(results, search_col) 137 | 138 | 139 | if __name__ == "__main__": 140 | st.set_page_config(page_title="Cortex AI Search and Summary", layout="wide") 141 | 142 | session = make_session() 143 | root = Root(session) 144 | db, schema = session.get_current_database(), session.get_current_schema() 145 | 146 | main() -------------------------------------------------------------------------------- /projects/improving-catalog-search/cortex_search_setup_ecommerce_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# install required packages\n", 10 | "! pip install snowflake-snowpark-python==1.26.0\n", 11 | "! 
pip install snowflake.core==1.0.2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# Import python packages\n", 21 | "from snowflake.snowpark import Session\n", 22 | "from snowflake.core import Root\n", 23 | "import pandas as pd\n", 24 | "\n", 25 | "connection_parameters = {\n", 26 | " \"account\": \"your_account_name\",\n", 27 | " \"user\": \"your_username\",\n", 28 | " \"host\": \"your_host\",\n", 29 | " \"password\": \"your_password\",\n", 30 | " \"role\": \"your_role\",\n", 31 | " \"warehouse\": \"your_warehouse\",\n", 32 | " \"database\": \"your_database\",\n", 33 | " \"schema\": \"your_schema\"\n", 34 | "}\n", 35 | "\n", 36 | "session = Session.builder.configs(connection_parameters).create()\n", 37 | "root = Root(session)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "We will first download the wands data" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "! git clone https://github.com/wayfair/WANDS.git # Clone the WANDS repository" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "Before creating the cortex search service. We need to create a text column which has all the information we want to search upon. Hence we will create a text column with all the columns." 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "# parsing product features\n", 70 | "def get_features(product_features: str):\n", 71 | " features = \"\"\n", 72 | " if product_features:\n", 73 | " for feature in product_features.split(\"|\"):\n", 74 | " pair = feature.split(\":\")\n", 75 | " if len(pair) >= 2 and pair[0] and pair[1]:\n", 76 | " key = pair[0].strip()\n", 77 | " value = pair[1].strip()\n", 78 | " features += f\"{key}: {value} \"\n", 79 | " return features\n", 80 | "\n", 81 | "# Function to create a single text column from multiple columns\n", 82 | "# this will be used to create the text column for the search index\n", 83 | "def wands_text(row):\n", 84 | " text = \"\"\n", 85 | " if row[\"product_name\"]:\n", 86 | " text += f\"Name: {str(row['product_name']).strip()} \"\n", 87 | " if row[\"product_class\"]:\n", 88 | " text += f\"Class: {str(row['product_class']).strip()} \"\n", 89 | " if row[\"product_description\"]:\n", 90 | " text += f\"Description: {str(row['product_description']).strip()} \"\n", 91 | " if row[\"category hierarchy\"]:\n", 92 | " text += f\"Hierarchy: {str(row['category hierarchy']).strip()} \"\n", 93 | " if row[\"features\"]:\n", 94 | " text += row['features']\n", 95 | " return text\n", 96 | "\n", 97 | "\n", 98 | "product_df = pd.read_csv(\"WANDS/dataset/product.csv\", sep=\"\\t\")\n", 99 | "product_df[\"features\"] = product_df[\"product_features\"].apply(get_features)\n", 100 | "product_df[\"TEXT\"] = product_df.apply(wands_text, axis=1)\n", 101 | "upload_df = product_df.rename(\n", 102 | " columns={\n", 103 | " \"product_id\": \"ID\",\n", 104 | " \"product_name\": \"NAME\",\n", 105 | " \"product_class\": \"CLASS\",\n", 106 | " \"rating_count\": \"RATING_COUNT\",\n", 107 | " \"average_rating\": \"RATING\",\n", 108 | " \"review_count\": \"REVIEW_COUNT\",\n", 109 | " }\n", 110 | ")\n", 111 | "upload_df = upload_df[\n", 112 | " [\n", 113 | " \"ID\",\n", 114 | " \"NAME\",\n", 115 | " \"CLASS\",\n", 116 | " 
\"RATING_COUNT\",\n", 117 | " \"RATING\",\n", 118 | " \"REVIEW_COUNT\",\n", 119 | " \"TEXT\",\n", 120 | " ]\n", 121 | "]" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "Now we can upload the data to snowflake and create a cortex search service" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "session.write_pandas(\n", 138 | " df=upload_df,\n", 139 | " table_name=\"WANDS_PRODUCT_DATASET\",\n", 140 | " schema=\"DATASETS\",\n", 141 | " database=\"CORTEX_SEARCH_DB\",\n", 142 | " overwrite=True,\n", 143 | " auto_create_table=True,\n", 144 | ")\n", 145 | "\n", 146 | "session.sql(\"\"\"CREATE OR REPLACE CORTEX SEARCH SERVICE CORTEX_SEARCH_DB.SERVICES.WANDS\n", 147 | "ON TEXT\n", 148 | "ATTRIBUTES CLASS \n", 149 | "WAREHOUSE = WH_TEST\n", 150 | "TARGET_LAG = '60 minute'\n", 151 | "AS (\n", 152 | " SELECT\n", 153 | " TEXT, ID, CLASS, RATING_COUNT, RATING, REVIEW_COUNT\n", 154 | " FROM CORTEX_SEARCH_DB.DATASETS.WANDS_PRODUCT_DATASET\n", 155 | ")\"\"\").collect()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "Now we can query the service. Note that we add softboost here. Softboost can be used to boost on a specific phrase, so that results which are similar to the phrase get ranked higher. In this case we want the results to be from the category `Furniture, Office Furniture, Desks` hence we add it as a softboost." 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "# fetch service\n", 172 | "my_service = (root\n", 173 | " .databases[\"CORTEX_SEARCH_DB\"]\n", 174 | " .schemas[\"SERVICES\"]\n", 175 | " .cortex_search_services[\"WANDS\"]\n", 176 | ")\n", 177 | "\n", 178 | "my_service.search(\n", 179 | " query=\"hulmeville writing desk with hutch\",\n", 180 | " experimental={\n", 181 | " \"softBoosts\": [\n", 182 | " {\"phrase\": \"Furniture, Office Furniture, Desks\"}\n", 183 | " ]\n", 184 | " },\n", 185 | " columns=[\"TEXT\"]\n", 186 | ")" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [] 195 | } 196 | ], 197 | "metadata": { 198 | "kernelspec": { 199 | "display_name": ".venv", 200 | "language": "python", 201 | "name": "python3" 202 | }, 203 | "language_info": { 204 | "codemirror_mode": { 205 | "name": "ipython", 206 | "version": 3 207 | }, 208 | "file_extension": ".py", 209 | "mimetype": "text/x-python", 210 | "name": "python", 211 | "nbconvert_exporter": "python", 212 | "pygments_lexer": "ipython3", 213 | "version": "3.11.7" 214 | } 215 | }, 216 | "nbformat": 4, 217 | "nbformat_minor": 2 218 | } 219 | -------------------------------------------------------------------------------- /examples/08_multimodal_rag/streamlit_chatbot_multimodal_rag.py: -------------------------------------------------------------------------------- 1 | import streamlit as st # Import python packages 2 | from snowflake.snowpark.context import get_active_session 3 | 4 | from snowflake.cortex import Complete 5 | from snowflake.core import Root 6 | 7 | import pandas as pd 8 | import json 9 | 10 | pd.set_option("max_colwidth",None) 11 | 12 | # Cortex Search Service and Source Docs Stage parameters 13 | CORTEX_SEARCH_DATABASE = "CORTEX_SEARCH_DB" 14 | CORTEX_SEARCH_SCHEMA = "PYU" 15 | CORTEX_SEARCH_SERVICE = 
"DEMO_SEC_CORTEX_SEARCH_SERVICE" 16 | SOURCE_DOCS_STAGE = "@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL" 17 | SOURCE_DOCS_PATH = "raw_pdf" 18 | ###### 19 | 20 | ### Default Values 21 | # Default value for NUM_CHUNKS 22 | DEFAULT_NUM_CHUNKS = 3 23 | # Available options for NUM_CHUNKS 24 | CHUNK_OPTIONS = [1, 2, 3, 4, 5] 25 | 26 | NUM_CHUNKS = st.session_state.get("num_chunks", DEFAULT_NUM_CHUNKS) 27 | SLIDE_WINDOW = 5 # how many last conversations to remember. This is the slide window. 28 | 29 | # Define columns for the multimodal service 30 | COLUMNS = [ 31 | "text", 32 | "page_number", 33 | "image_filepath" 34 | ] 35 | 36 | session = get_active_session() 37 | root = Root(session) 38 | svc = None 39 | 40 | # Initialize the service based on the selected service 41 | def init_service(): 42 | global svc # Make svc a global variable 43 | svc = root.databases[CORTEX_SEARCH_DATABASE].schemas[CORTEX_SEARCH_SCHEMA].cortex_search_services[CORTEX_SEARCH_SERVICE] 44 | return svc 45 | 46 | # Initialize the service 47 | init_service() 48 | 49 | ### Functions 50 | 51 | def config_options(): 52 | 53 | st.sidebar.selectbox('Select your model:', 54 | ('claude-3-5-sonnet', 'pixtral-large'), 55 | key="model_name") 56 | 57 | st.sidebar.selectbox('Select number of docs to retrieve:', 58 | CHUNK_OPTIONS, key="num_chunks", 59 | index=CHUNK_OPTIONS.index(NUM_CHUNKS)) 60 | 61 | st.sidebar.checkbox('Chat history', key="use_chat_history", value = True) 62 | st.sidebar.checkbox('Debug mode', key="debug", value = True) 63 | 64 | st.sidebar.button("Start Over", key="clear_conversation", on_click=init_messages) 65 | st.sidebar.expander("Session State").write(st.session_state) 66 | 67 | def init_messages(): 68 | 69 | # Initialize chat history 70 | if st.session_state.clear_conversation or "messages" not in st.session_state: 71 | st.session_state.messages = [] 72 | 73 | def get_similar_chunks_search_service(query): 74 | response = svc.search( 75 | multi_index_query={ 76 | "text": [{"text": query}], 77 | "vector_main": [{"text": query}] 78 | }, 79 | columns=COLUMNS, 80 | limit=NUM_CHUNKS 81 | ) 82 | 83 | st.sidebar.json(response.json()) 84 | 85 | return response.json() 86 | 87 | def get_chat_history(): 88 | #Get the history from the st.session_stage.messages according to the slide window parameter 89 | 90 | chat_history = [] 91 | 92 | start_index = max(0, len(st.session_state.messages) - SLIDE_WINDOW) 93 | for i in range (start_index , len(st.session_state.messages) -1): 94 | chat_history.append(st.session_state.messages[i]) 95 | 96 | return chat_history 97 | 98 | def summarize_question_with_history(chat_history, question): 99 | # To get the right context, use the LLM to first summarize the previous conversation 100 | # This will be used to get embeddings and find similar chunks in the docs for context 101 | 102 | prompt = f""" 103 | Based on the chat history below and the question, generate a query that extend the question 104 | with the chat history provided. The query should be in natual language. 105 | Answer with only the query. Do not add any explanation. 
106 | 107 | 108 | {chat_history} 109 | 110 | 111 | {question} 112 | 113 | """ 114 | 115 | sumary = Complete(st.session_state.model_name, prompt) 116 | 117 | if st.session_state.debug: 118 | st.sidebar.text("Summary to be used to find similar chunks in the docs:") 119 | st.sidebar.caption(sumary) 120 | 121 | sumary = sumary.replace("'", "") 122 | 123 | return sumary 124 | 125 | def create_prompt_multimodal(myquestion): 126 | # Get chat history and question summary 127 | chat_history = get_chat_history() if st.session_state.use_chat_history else [] 128 | question_summary = summarize_question_with_history(chat_history, myquestion) if chat_history else myquestion 129 | 130 | # Get prompt context 131 | prompt_context = get_similar_chunks_search_service(question_summary) 132 | 133 | # Convert chat history to string format 134 | chat_history_str = "\n".join([f"{msg['role'].upper()}: {msg['content']}" for msg in chat_history]) if chat_history else "" 135 | 136 | # Generate images placeholder string for the prompt 137 | if NUM_CHUNKS == 1: 138 | image_placeholders_str = "{0}" 139 | else: 140 | image_placeholders = [f"{{{i}}}" for i in range(NUM_CHUNKS)] 141 | image_placeholders_str = ", ".join(image_placeholders[:-1]) + f", and {image_placeholders[-1]}" 142 | 143 | prompt = f""" 144 | You are an expert AI assistant that extracts information from the document image(s) {image_placeholders_str}. 145 | You are specialized in accurately extracting screenshots, diagrams and structured data 146 | from tables presented within images, paying close attention to merged cells in tables. 147 | 148 | You also offer a chat experience considering the information included in the CHAT HISTORY 149 | provided between and tags.. 150 | 151 | When answering the question contained between and tags 152 | be concise and do not hallucinate. If you don´t have the information just say so. 153 | 154 | Do not mention the IMAGES used in your answer, but do reference the content including any diagrams or text. 155 | Do not mention the CHAT HISTORY used in your answer. 156 | 157 | Only answer the question if you can extract it from the IMAGES provided. 
158 | 159 | 160 | {chat_history_str} 161 | 162 | 163 | {myquestion} 164 | 165 | Answer: 166 | """ 167 | 168 | json_data = json.loads(prompt_context) 169 | 170 | relative_paths = [item['image_filepath'] for item in json_data['results']] 171 | 172 | return prompt, relative_paths 173 | 174 | 175 | def answer_question(myquestion): 176 | prompt, relative_paths = create_prompt_multimodal(myquestion) 177 | 178 | image_files = [] 179 | for path in relative_paths: 180 | image_files.append(f"TO_FILE('{SOURCE_DOCS_STAGE}', '{path}')") 181 | image_files_str = ",\n".join(image_files) 182 | 183 | query = f""" 184 | SELECT SNOWFLAKE.CORTEX.COMPLETE('claude-3-5-sonnet', 185 | PROMPT('{prompt}', 186 | {image_files_str}) 187 | ); 188 | """ 189 | 190 | sql_output = session.sql(query).collect() 191 | 192 | response = list(sql_output[0].asDict().values())[0] 193 | 194 | return response, relative_paths 195 | 196 | def main(): 197 | 198 | st.title(f":robot_face: :mag_right: Multimodal RAG Assistant") 199 | st.write("This is the list of documents you already have and that will be used to answer your questions:") 200 | docs_available = session.sql(f"ls {SOURCE_DOCS_STAGE}/{SOURCE_DOCS_PATH}").collect() 201 | list_docs = [doc["name"] for doc in docs_available] 202 | for doc in list_docs: 203 | st.markdown(f"- {doc}") 204 | 205 | config_options() 206 | init_messages() 207 | 208 | # Display chat messages from history on app rerun 209 | for message in st.session_state.messages: 210 | with st.chat_message(message["role"]): 211 | st.markdown(message["content"]) 212 | 213 | # Accept user input 214 | if question := st.chat_input("What do you want to know about your products?"): 215 | # Add user message to chat history 216 | st.session_state.messages.append({"role": "user", "content": question}) 217 | # Display user message in chat message container 218 | with st.chat_message("user"): 219 | st.markdown(question) 220 | # Display assistant response in chat message container 221 | with st.chat_message("assistant"): 222 | message_placeholder = st.empty() 223 | 224 | question = question.replace("'","") 225 | 226 | with st.spinner(f"{st.session_state.model_name} thinking..."): 227 | response, relative_paths = answer_question(question) 228 | response = response.replace("'", "") 229 | message_placeholder.markdown(response) 230 | 231 | # Display images inline 232 | if relative_paths: 233 | st.write("Related Images:") 234 | cols = st.columns(3) # Create 3 columns for images 235 | for i, path in enumerate(relative_paths): 236 | cmd2 = f"select GET_PRESIGNED_URL('{SOURCE_DOCS_STAGE}', '{path}', 360) as URL_LINK;" 237 | df_url_link = session.sql(cmd2).to_pandas() 238 | url_link = df_url_link._get_value(0,'URL_LINK') 239 | 240 | # Display image in the appropriate column 241 | with cols[i % 3]: 242 | st.image(url_link, caption=f"Image {i+1}", use_container_width=True) 243 | st.markdown(f"Doc: [{path}]({url_link})") 244 | 245 | st.session_state.messages.append({"role": "assistant", "content": response}) 246 | 247 | 248 | if __name__ == "__main__": 249 | main() -------------------------------------------------------------------------------- /examples/06_streamlit_chatbot_app/chat.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from snowflake.core import Root # requires snowflake>=0.8.0 3 | from snowflake.cortex import Complete 4 | from snowflake.snowpark.context import get_active_session 5 | 6 | MODELS = [ 7 | "mistral-large", 8 | "snowflake-arctic", 9 | "llama3-70b", 10 | 
"llama3-8b", 11 | ] 12 | 13 | 14 | def init_messages(): 15 | """ 16 | Initialize the session state for chat messages. If the session state indicates that the 17 | conversation should be cleared or if the "messages" key is not in the session state, 18 | initialize it as an empty list. 19 | """ 20 | if st.session_state.clear_conversation or "messages" not in st.session_state: 21 | st.session_state.messages = [] 22 | 23 | 24 | def init_service_metadata(): 25 | """ 26 | Initialize the session state for cortex search service metadata. Query the available 27 | cortex search services from the Snowflake session and store their names and search 28 | columns in the session state. 29 | """ 30 | if "service_metadata" not in st.session_state: 31 | services = session.sql("SHOW CORTEX SEARCH SERVICES;").collect() 32 | service_metadata = [] 33 | if services: 34 | # TODO: remove loop once changes land to add the column metadata in SHOW 35 | for s in services: 36 | svc_name = s["name"] 37 | svc_search_col = session.sql( 38 | f"DESC CORTEX SEARCH SERVICE {svc_name};" 39 | ).collect()[0]["search_column"] 40 | service_metadata.append( 41 | {"name": svc_name, "search_column": svc_search_col} 42 | ) 43 | 44 | st.session_state.service_metadata = service_metadata 45 | 46 | 47 | def init_config_options(): 48 | """ 49 | Initialize the configuration options in the Streamlit sidebar. Allow the user to select 50 | a cortex search service, clear the conversation, toggle debug mode, and toggle the use of 51 | chat history. Also provide advanced options to select a model, the number of context chunks, 52 | and the number of chat messages to use in the chat history. 53 | """ 54 | st.sidebar.selectbox( 55 | "Select cortex search service:", 56 | [s["name"] for s in st.session_state.service_metadata], 57 | key="selected_cortex_search_service", 58 | ) 59 | 60 | st.sidebar.button("Clear conversation", key="clear_conversation") 61 | st.sidebar.toggle("Debug", key="debug", value=False) 62 | st.sidebar.toggle("Use chat history", key="use_chat_history", value=True) 63 | 64 | with st.sidebar.expander("Advanced options"): 65 | st.selectbox("Select model:", MODELS, key="model_name") 66 | st.number_input( 67 | "Select number of context chunks", 68 | value=5, 69 | key="num_retrieved_chunks", 70 | min_value=1, 71 | max_value=10, 72 | ) 73 | st.number_input( 74 | "Select number of messages to use in chat history", 75 | value=5, 76 | key="num_chat_messages", 77 | min_value=1, 78 | max_value=10, 79 | ) 80 | 81 | st.sidebar.expander("Session State").write(st.session_state) 82 | 83 | 84 | def query_cortex_search_service(query): 85 | """ 86 | Query the selected cortex search service with the given query and retrieve context documents. 87 | Display the retrieved context documents in the sidebar if debug mode is enabled. Return the 88 | context documents as a string. 89 | 90 | Args: 91 | query (str): The query to search the cortex search service with. 92 | 93 | Returns: 94 | str: The concatenated string of context documents. 
95 | """ 96 | db, schema = session.get_current_database(), session.get_current_schema() 97 | 98 | cortex_search_service = ( 99 | root.databases[db] 100 | .schemas[schema] 101 | .cortex_search_services[st.session_state.selected_cortex_search_service] 102 | ) 103 | 104 | context_documents = cortex_search_service.search( 105 | query, columns=[], limit=st.session_state.num_retrieved_chunks 106 | ) 107 | results = context_documents.results 108 | 109 | service_metadata = st.session_state.service_metadata 110 | search_col = [s["search_column"] for s in service_metadata 111 | if s["name"] == st.session_state.selected_cortex_search_service][0] 112 | 113 | context_str = "" 114 | for i, r in enumerate(results): 115 | context_str += f"Context document {i+1}: {r[search_col]} \n" + "\n" 116 | 117 | if st.session_state.debug: 118 | st.sidebar.text_area("Context documents", context_str, height=500) 119 | 120 | return context_str 121 | 122 | 123 | def get_chat_history(): 124 | """ 125 | Retrieve the chat history from the session state limited to the number of messages specified 126 | by the user in the sidebar options. 127 | 128 | Returns: 129 | list: The list of chat messages from the session state. 130 | """ 131 | start_index = max( 132 | 0, len(st.session_state.messages) - st.session_state.num_chat_messages 133 | ) 134 | return st.session_state.messages[start_index : len(st.session_state.messages) - 1] 135 | 136 | 137 | def complete(model, prompt): 138 | """ 139 | Generate a completion for the given prompt using the specified model. 140 | 141 | Args: 142 | model (str): The name of the model to use for completion. 143 | prompt (str): The prompt to generate a completion for. 144 | 145 | Returns: 146 | str: The generated completion. 147 | """ 148 | return Complete(model, prompt).replace("$", "\$") 149 | 150 | 151 | def make_chat_history_summary(chat_history, question): 152 | """ 153 | Generate a summary of the chat history combined with the current question to extend the query 154 | context. Use the language model to generate this summary. 155 | 156 | Args: 157 | chat_history (str): The chat history to include in the summary. 158 | question (str): The current user question to extend with the chat history. 159 | 160 | Returns: 161 | str: The generated summary of the chat history and question. 162 | """ 163 | prompt = f""" 164 | [INST] 165 | Based on the chat history below and the question, generate a query that extend the question 166 | with the chat history provided. The query should be in natural language. 167 | Answer with only the query. Do not add any explanation. 168 | 169 | 170 | {chat_history} 171 | 172 | 173 | {question} 174 | 175 | [/INST] 176 | """ 177 | 178 | summary = complete(st.session_state.model_name, prompt) 179 | 180 | if st.session_state.debug: 181 | st.sidebar.text_area( 182 | "Chat history summary", summary.replace("$", "\$"), height=150 183 | ) 184 | 185 | return summary 186 | 187 | 188 | def create_prompt(user_question): 189 | """ 190 | Create a prompt for the language model by combining the user question with context retrieved 191 | from the cortex search service and chat history (if enabled). Format the prompt according to 192 | the expected input format of the model. 193 | 194 | Args: 195 | user_question (str): The user's question to generate a prompt for. 196 | 197 | Returns: 198 | str: The generated prompt for the language model. 
199 | """ 200 | if st.session_state.use_chat_history: 201 | chat_history = get_chat_history() 202 | if chat_history != []: 203 | question_summary = make_chat_history_summary(chat_history, user_question) 204 | prompt_context = query_cortex_search_service(question_summary) 205 | else: 206 | prompt_context = query_cortex_search_service(user_question) 207 | else: 208 | prompt_context = query_cortex_search_service(user_question) 209 | chat_history = "" 210 | 211 | prompt = f""" 212 | [INST] 213 | You are a helpful AI chat assistant with RAG capabilities. When a user asks you a question, 214 | you will also be given context provided between and tags. Use that context 215 | with the user's chat history provided in the between and tags 216 | to provide a summary that addresses the user's question. Ensure the answer is coherent, concise, 217 | and directly relevant to the user's question. 218 | 219 | If the user asks a generic question which cannot be answered with the given context or chat_history, 220 | just say "I don't know the answer to that question. 221 | 222 | Don't saying things like "according to the provided context". 223 | 224 | 225 | {chat_history} 226 | 227 | 228 | {prompt_context} 229 | 230 | 231 | {user_question} 232 | 233 | [/INST] 234 | Answer: 235 | """ 236 | return prompt 237 | 238 | 239 | def main(): 240 | st.title(f":speech_balloon: Chatbot with Snowflake Cortex") 241 | 242 | init_service_metadata() 243 | init_config_options() 244 | init_messages() 245 | 246 | icons = {"assistant": "❄️", "user": "👤"} 247 | 248 | # Display chat messages from history on app rerun 249 | for message in st.session_state.messages: 250 | with st.chat_message(message["role"], avatar=icons[message["role"]]): 251 | st.markdown(message["content"]) 252 | 253 | disable_chat = ( 254 | "service_metadata" not in st.session_state 255 | or len(st.session_state.service_metadata) == 0 256 | ) 257 | if question := st.chat_input("Ask a question...", disabled=disable_chat): 258 | # Add user message to chat history 259 | st.session_state.messages.append({"role": "user", "content": question}) 260 | # Display user message in chat message container 261 | with st.chat_message("user", avatar=icons["user"]): 262 | st.markdown(question.replace("$", "\$")) 263 | 264 | # Display assistant response in chat message container 265 | with st.chat_message("assistant", avatar=icons["assistant"]): 266 | message_placeholder = st.empty() 267 | question = question.replace("'", "") 268 | with st.spinner("Thinking..."): 269 | generated_response = complete( 270 | st.session_state.model_name, create_prompt(question) 271 | ) 272 | message_placeholder.markdown(generated_response) 273 | 274 | st.session_state.messages.append( 275 | {"role": "assistant", "content": generated_response} 276 | ) 277 | 278 | 279 | if __name__ == "__main__": 280 | session = get_active_session() 281 | root = Root(session) 282 | main() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /projects/improving-catalog-search/llm_judge_ecommerce_ranking_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "3775908f-ca36-4846-8f38-5adca39217f2", 7 | "metadata": { 8 | "language": "python", 9 | "name": "cell1", 10 | "resultHeight": 0 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "# Import python packages\n", 15 | "from snowflake.snowpark.functions import col\n", 16 | "from snowflake.snowpark.functions import col, udf\n", 17 | "from snowflake.snowpark.types import StringType, ArrayType, StructType, StructField, FloatType\n", 18 | "from snowflake.snowpark.context import get_active_session\n", 19 | "\n", 20 | "from copy import deepcopy\n", 21 | "\n", 22 | "session = get_active_session()\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "6e877fca-d23e-4286-865e-55637bca904f", 29 | "metadata": { 30 | "language": "python", 31 | "name": "cell9", 32 | "resultHeight": 0 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "STRUCTURED_QUERY_PROMPT = \"\"\"You are an assitant helping a user searching for a product on amazon, given the user query and product information you have to give a rating. The user is flexible with product requirements which doesn't match the exact query. Specifically you should give a three point rating which is defined as:\n", 37 | "0 - Product doesn't seem related to the use query. \n", 38 | "1 - Product is related to the query and might vaguely satisfy the intent behind the query even though it might not match all the requirements present in the query and/or there are significant differences in sepcifications provided in the query (while still being related to the query).\n", 39 | "2 - Product is broadly what the user is searching for. However there might still be some minor differences in specifications, such as brand, or exact details of the product.\n", 40 | "3 - Product is completely what the user is searching for, matching all description in query. \n", 41 | "\n", 42 | "You should think step by step about the user query and the search result and rate the search result. You should also provide a reasoning for your rating.\n", 43 | " \n", 44 | "Use the following format:\n", 45 | "Reasoning: Example Reasoning\n", 46 | "Rating: Example Rating\n", 47 | "\n", 48 | "### Examples\n", 49 | "Example 1:\n", 50 | "INPUT:\n", 51 | "Query: 4.2 Lt Freezerless Mini Fridge\n", 52 | "PRODUCT: \n", 53 | "TITLE: 4 Lt. Refrigerator Description: Energy Star Apartment Freezerless Fridge, Stainless Steel, E-Star with LED Lighting, Reversible Door, Adjustable Temperature, Quiet, for Dorm, Office, Home Kitchen\n", 54 | "OUTPUT:\n", 55 | "Reasoning: In this case the product is a 4 Lt. Refrigerator which is close to the 4.2 Lt Freezerless Mini Fridge and all other requirements match. Hence the product is a close match to the user query and is broadly what the user is looking for. 
Therefore it is rated 2.\n", 56 | "Rating: 2\n", 57 | "\n", 58 | "Example 2:\n", 59 | "INPUT:\n", 60 | "Query: Lead pencil without plastic grip\n", 61 | "PRODUCT: \n", 62 | "TITLE: Pencil Description: A pack of 10 lead pencils. The pencils are 2B and are perfect for writing and drawing. Erasers are provided for free with the 10 pencil pack.\n", 63 | "OUTPUT:\n", 64 | "Reasoning: The query mentions a lead pencil without a plastic grip. The product is a pack of 10 lead pencils which don't mention a plastic grip, so we can assume they don't have a plastic grip. Therefore the product is exactly what the user is looking for. Hence the rating is 2.\n", 65 | "Rating: 2\n", 66 | "\n", 67 | "Example 3:\n", 68 | "INPUT:\n", 69 | "Query: US immigration test pass gift cup\n", 70 | "PRODUCT: \n", 71 | "TITLE: Immigration Gift Jacket Description: A perfect gift for someone who recently pased the immigration test. US Flag is on the back of the jacket. \n", 72 | "OUTPUT:\n", 73 | "Reasoning: Although the query doesn't request for a jacket, it is about a gift for someone who passed immigration test. Hence the product vaguely satisfies the intent behind the query. Therefore the rating is 1.\n", 74 | "Rating: 1\n", 75 | "\n", 76 | "Example 4:\n", 77 | "INPUT:\n", 78 | "Query: Bomber Jacket with chinese collar\n", 79 | "PRODUCT: \n", 80 | "TITLE: Bomber Jacket Description: A stylish dark brown bomber jacket with a zipper and pockets with a fur collar. Will keep you warm in the winter while looking stylish. Needs dry cleaning.\n", 81 | "OUTPUT: \n", 82 | "Reasoning: The query is looking for a bomber jacket with chinese collars. The product is a bomber jacket with a fur collar. Which is a similar product to what the query is searching for. Therefore the rating is 1.\n", 83 | "Rating: 1\n", 84 | "\n", 85 | "Example 5:\n", 86 | "{example2_0}\n", 87 | "\n", 88 | "Example 6:\n", 89 | "{example2_1}\n", 90 | "\n", 91 | "Example 7:\n", 92 | "{example1_0}\n", 93 | "\n", 94 | "Example 8:\n", 95 | "{example1_1}\n", 96 | "\n", 97 | "Example 9:\n", 98 | "{example0_0}\n", 99 | "\n", 100 | "Example 10:\n", 101 | "{example3_0}\n", 102 | "###\n", 103 | "\n", 104 | "Now given the user query and search result below, rate the search result based on its relevance to the user query and provide a reasoning for your rating.\n", 105 | "INPUT:\n", 106 | "User Query: {query}\n", 107 | "Search Result: {passage}\n", 108 | "OUTPUT:\n", 109 | "\"\"\"\n", 110 | "\n", 111 | "RATINGS_TO_DEFAULT_EXAMPLES = {\n", 112 | " \"0\": [\n", 113 | "\"\"\"Query: Daiwa Liberty Club Short Swing\n", 114 | "PRODUCT: YONEX AC1025P Tennis Badminton Grip. DESCRIPTION: Product Description Muscle Power locates the string on rounded archways that eliminate stress-load and fatigue through contact friction. \n", 115 | "OUTPUT:\n", 116 | "Rating: 0\n", 117 | "Reasoning: The query is looking for a fishing club short swing from Daiwa. But the product is a tennis badminton grip. It's completely unrelated.\n", 118 | "\"\"\",\n", 119 | " ],\n", 120 | " \"1\": [\n", 121 | "\"\"\"INPUT:\n", 122 | "Query: YONEX AC1025P Tennis Badminton Grip\n", 123 | "PRODUCT: TITLE: Yonex Badminton Racquet Voltric 200 Taufik Series - 80Gms DESCRIPTION: Product Description YONEX's head-light series, NANORAY provides a fast and controlled swing with enhanced repulsion via the New Aero Frame. \n", 124 | "OUTPUT:\n", 125 | "Rating: 1\n", 126 | "Reasoning: The query is looking for badminton grip, but product is badminton racquet instead. 
The product isn't what the user is looking for.\n", 127 | "\"\"\",\n", 128 | "\"\"\"INPUT:\n", 129 | "User Query: 2010 dodge nitro crossbar\n", 130 | "PRODUCT: BLACK HORSE Armour Roll Bar Compatible with 2000 to 2022 Ram Chevrolet Ford GMC Toyota 3500 2500 Silverado F-150 Sierra Tundra 1500 2500 3500 Black Steel RB-AR1B\n", 131 | "OUTPUT:\n", 132 | "Rating: 1\n", 133 | "Reasoning: The query is looking for a crossbar, but the product is a roll bar, which isn't the same product. Even though the 2010 Dodge is mentioned, it's not a complete match given that the products don't align.\n", 134 | "\"\"\",\n", 135 | " ],\n", 136 | " \"2\": [\n", 137 | "\"\"\"INPUT:\n", 138 | "Query: Dogfish 500GB Msata Internal SSD\n", 139 | "PRODUCT: Crucial MX500 500GB 3D NAND SATA M.2 Internal SSD, up to 560MB/s & Seagate Barracuda 2TB Internal Hard Drive HDD – 3.5 Inch SATA 6Gb/s 7200 RPM 256MB Cache 3.5-Inch – Frustration Free Packaging\n", 140 | "OUTPUT:\n", 141 | "Rating: 2\n", 142 | "Reasoning: The query is looking for a 500GB SSD, which the product satisfies. However, the brand mentioned in the query, Dogfish, doesn't match the brand of the product.\n", 143 | "\"\"\",\n", 144 | "\"\"\"INPUT:\n", 145 | "Query: YONEX AC1025P Tennis Badminton Grip\n", 146 | "PRODUCT: WILSON Pro Overgrip-Comfort DESCRIPTION: Product Description Will fit tennis, racquetball, badminton, and squash handles. Product Description Will fit tennis, racquetball, badminton, and squash handles.\n", 147 | "OUTPUT:\n", 148 | "Rating: 2\n", 149 | "Reasoning: The query is looking for a badminton grip, which is aligned with the product. But the product has minor details, such as brand, that don't fit the query description.\n", 150 | "\"\"\",\n", 151 | " ],\n", 152 | " \"3\": [\n", 153 | "\"\"\"INPUT:\n", 154 | "Query: center console organizer\n", 155 | "PRODUCT: MX Auto Center Console Organizer| Compatible with Ford Trucks & SUVs – Accessories for F150, F250, F350, Raptor, Expedition|2015, 16, 17, 18, 19, 20, 21| Must-Have Bucket Seats|SEE COMPATIBILITY BELOW\n", 156 | "OUTPUT:\n", 157 | "Rating: 3\n", 158 | "Reasoning: The query and product are both a center console organizer, a complete match.\n", 159 | "\"\"\",\n", 160 | " ],\n", 161 | "}\n", 162 | "\n", 163 | "def generate_llm_label(\n", 164 | " input_query: str,\n", 165 | " intermediate_columns: list[str],\n", 166 | " output_select_columns: list[str],\n", 167 | " output_table: str,\n", 168 | ") -> None:\n", 169 | " table_raw = session.sql(input_query)\n", 170 | " def generate_prompt(query: str, passage: str, golden_docs: list):\n", 171 | " query_ratings = deepcopy(RATINGS_TO_DEFAULT_EXAMPLES)\n", 172 | " ptrs = {\n", 173 | " \"0\": len(RATINGS_TO_DEFAULT_EXAMPLES[\"0\"]), \n", 174 | " \"1\": len(RATINGS_TO_DEFAULT_EXAMPLES[\"1\"]), \n", 175 | " \"2\": len(RATINGS_TO_DEFAULT_EXAMPLES[\"2\"]), \n", 176 | " \"3\": len(RATINGS_TO_DEFAULT_EXAMPLES[\"3\"]),\n", 177 | " }\n", 178 | " for gd in golden_docs:\n", 179 | " gd_score = str(gd[\"score\"])\n", 180 | " if gd_score in ptrs and ptrs[gd_score] > 0: # an example slot for this rating still needs filling\n", 181 | " # reasoning is omitted here because the golden data does not include one, which may reduce example quality\n", 182 | " query_ratings[gd_score][ptrs[gd_score]-1] = f\"\"\"INPUT:\n", 183 | "Query: {query}\n", 184 | "PRODUCT: {gd[\"doc_text\"]}\n", 185 | "OUTPUT:\n", 186 | "Rating: {gd_score}\n", 187 | "\"\"\"\n", 188 | " return STRUCTURED_QUERY_PROMPT.format(\n", 189 | " query=query,\n", 190 | " passage=passage,\n", 191 | " example2_0=query_ratings[\"2\"][0],\n", 192 | " example2_1=query_ratings[\"2\"][1],\n", 193 | " example1_0=query_ratings[\"1\"][0],\n", 194 | " 
example1_1=query_ratings[\"1\"][1],\n", 195 | " example0_0=query_ratings[\"0\"][0],\n", 196 | " example3_0=query_ratings[\"3\"][0],\n", 197 | " )\n", 198 | " \n", 199 | " # Register the generate_prompt function as a UDF\n", 200 | " golden_doc_struct = StructType([\n", 201 | " StructField(\"score\", FloatType()),\n", 202 | " StructField(\"doc_text\", StringType()),\n", 203 | " StructField(\"doc_id\", StringType())\n", 204 | " ])\n", 205 | " input_cols = [StringType() for _ in range(len(intermediate_columns))]\n", 206 | " input_cols.append(ArrayType(golden_doc_struct))\n", 207 | " generate_prompt_udf = udf(\n", 208 | " generate_prompt,\n", 209 | " return_type=StringType(),\n", 210 | " input_types=input_cols,\n", 211 | " packages=['snowflake-snowpark-python'],\n", 212 | " max_batch_size=100,\n", 213 | " )\n", 214 | " \n", 215 | " # Apply the UDF to generate the 'generated_question' column\n", 216 | " table_with_prompt = table_raw.with_column(\n", 217 | " \"PROMPT\",\n", 218 | " generate_prompt_udf(\n", 219 | " *[col(colname) for colname in intermediate_columns],\n", 220 | " col(\"GOLDEN_DOCS\"),\n", 221 | " ),\n", 222 | " )\n", 223 | " \n", 224 | " # Filter and limit the rows, then show them\n", 225 | " table_with_prompt = table_with_prompt.select(output_select_columns)\n", 226 | " # table_with_prompt.limit(1).show()\n", 227 | " \n", 228 | " # Save the DataFrame to a Snowflake table\n", 229 | " table_with_prompt.write.save_as_table(f\"{output_table}_INTERMEDIATE\", mode=\"overwrite\")\n", 230 | "\n", 231 | " session.sql(f\"\"\"CREATE OR REPLACE TABLE {output_table} AS\n", 232 | "SELECT\n", 233 | " *,\n", 234 | " SNOWFLAKE.CORTEX.COMPLETE(\n", 235 | " 'llama3.1-405b',\n", 236 | " [{{'role': 'user', 'content': prompt}}],\n", 237 | " {{'temperature': 0,'top_p': 1}}\n", 238 | " )['choices'][0]['messages']::VARCHAR AS LLM_JUDGE,\n", 239 | " REGEXP_SUBSTR(LLM_JUDGE, 'Rating: ([0-9])', 1, 1, 'e', 1) AS LLM_RELEVANCE\n", 240 | "FROM {output_table}_INTERMEDIATE\"\"\").collect()\n", 241 | " \n" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "id": "a51b320e-6df1-4b29-8741-7f5813cd2aab", 248 | "metadata": { 249 | "codeCollapsed": false, 250 | "language": "python", 251 | "name": "cell5", 252 | "resultHeight": 0 253 | }, 254 | "outputs": [], 255 | "source": [ 256 | "input_query = \"\"\"SELECT QUERY_ID, TEXT, QUERY, RANK FROM CORTEX_SEARCH_DB.GOLDEN.TREC23_WITH_SUBTABLE_BASE_RESULTS\"\"\"\n", 257 | "\n", 258 | "generate_llm_label(\n", 259 | " input_query=input_query,\n", 260 | " intermediate_columns=[\"QUERY\", \"TEXT\"],\n", 261 | " output_select_columns=[\"QUERY_ID\", \"QUERY\", \"TEXT\", \"PROMPT\", \"RANK\"],\n", 262 | " output_table=\"CORTEX_SEARCH_DB.GOLDEN.TREC23_WITH_SUBTABLE_BASE_RESULTS_LLM_JUDGE\",\n", 263 | ")" 264 | ] 265 | } 266 | ], 267 | "metadata": { 268 | "kernelspec": { 269 | "display_name": "Streamlit Notebook", 270 | "name": "streamlit" 271 | } 272 | }, 273 | "nbformat": 4, 274 | "nbformat_minor": 5 275 | } 276 | -------------------------------------------------------------------------------- /examples/02_rest_api_simple_usage/notebook_query.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 30, 6 | "id": "8549cdb7-3b82-416b-b8d0-900732227212", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from datetime import timedelta, timezone, datetime\n", 11 | "import jwt\n", 12 | "from cryptography.hazmat.primitives.serialization 
import load_pem_private_key\n", 13 | "from cryptography.hazmat.primitives.serialization import Encoding\n", 14 | "from cryptography.hazmat.primitives.serialization import PublicFormat\n", 15 | "from cryptography.hazmat.backends import default_backend\n", 16 | "import base64\n", 17 | "from getpass import getpass\n", 18 | "import hashlib\n", 19 | "import requests\n", 20 | "import json\n", 21 | "import os\n", 22 | "import pandas as pd\n", 23 | "pd.options.display.max_colwidth = 1000\n", 24 | "\n", 25 | "# account parameters\n", 26 | "SNOWFLAKE_ACCOUNT = \"\" # must be capitalized\n", 27 | "SNOWFLAKE_USER = \"\" # must be capitalized\n", 28 | "SNOWFLAKE_URL = \"https://org-acc.snowflakecomputing.com\"\n", 29 | "PRIVATE_KEY_PATH = \"/path/to/your/rsa_key.p8\"\n", 30 | "\n", 31 | "# service parameters\n", 32 | "CORTEX_SEARCH_DATABASE = \"\"\n", 33 | "CORTEX_SEARCH_SCHEMA = \"\"\n", 34 | "CORTEX_SEARCH_SERVICE = \"\"\n", 35 | "\n", 36 | "# columns to query in the service\n", 37 | "COLUMNS = [\n", 38 | " \"COL1\",\n", 39 | " \"COL2\",\n", 40 | "]" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 31, 46 | "id": "4ab6c048-b34c-47f9-a032-aa20fadb91d6", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "def generate_JWT_token():\n", 51 | " \"\"\"\n", 52 | " https://docs.snowflake.com/en/developer-guide/sql-api/authenticating#generating-a-jwt-in-python\n", 53 | " Generate a valid JWT token from snowflake account name, user name, private key and private key passphrase.\n", 54 | " \"\"\"\n", 55 | " # Prompt for private key passphrase\n", 56 | " def get_private_key_passphrase():\n", 57 | " return getpass('Private Key Passphrase: ')\n", 58 | "\n", 59 | " # Generate encoded public key\n", 60 | " with open(PRIVATE_KEY_PATH, 'rb') as pem_in:\n", 61 | " pemlines = pem_in.read()\n", 62 | " try:\n", 63 | " private_key = load_pem_private_key(pemlines, None, default_backend())\n", 64 | " except TypeError:\n", 65 | " private_key = load_pem_private_key(pemlines, get_private_key_passphrase().encode(), default_backend())\n", 66 | " public_key_raw = private_key.public_key().public_bytes(Encoding.DER, PublicFormat.SubjectPublicKeyInfo)\n", 67 | " sha256hash = hashlib.sha256()\n", 68 | " sha256hash.update(public_key_raw)\n", 69 | " public_key_fp = 'SHA256:' + base64.b64encode(sha256hash.digest()).decode('utf-8')\n", 70 | "\n", 71 | " # Generate JWT payload\n", 72 | " qualified_username = SNOWFLAKE_ACCOUNT + \".\" + SNOWFLAKE_USER\n", 73 | " now = datetime.now(timezone.utc)\n", 74 | " lifetime = timedelta(minutes=60)\n", 75 | " payload = {\n", 76 | " \"iss\": qualified_username + '.' 
+ public_key_fp,\n", 77 | " \"sub\": qualified_username,\n", 78 | " \"iat\": now,\n", 79 | " \"exp\": now + lifetime\n", 80 | " }\n", 81 | " return jwt.encode(payload, key=private_key, algorithm=\"RS256\")\n", 82 | " \n", 83 | "jwt_token = generate_JWT_token()\n", 84 | "\n", 85 | "headers = {\n", 86 | " 'X-Snowflake-Authorization-Token-Type': 'KEYPAIR_JWT',\n", 87 | " 'Content-Type': 'application/json',\n", 88 | " 'Accept': 'application/json',\n", 89 | " 'Authorization': f'Bearer {jwt_token}',\n", 90 | "}" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 34, 96 | "id": "1ae0e607-a9aa-400e-8bb7-a9b6945cccaf", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "def query_service(query):\n", 101 | " \"\"\"\n", 102 | " Query the specified Snowflake service with the given query string.\n", 103 | " \"\"\"\n", 104 | " url = f\"{SNOWFLAKE_URL}/api/v2/databases/{CORTEX_SEARCH_DATABASE}/schemas/{CORTEX_SEARCH_SCHEMA}/cortex-search-services/{CORTEX_SEARCH_SERVICE}:query\"\n", 105 | " data = {\n", 106 | " \"query\": query,\n", 107 | " \"columns\": COLUMNS,\n", 108 | " \"filter\": \"\",\n", 109 | " \"limit\": 10\n", 110 | " }\n", 111 | " \n", 112 | " jwt_token = generate_JWT_token()\n", 113 | " headers = {\n", 114 | " 'X-Snowflake-Authorization-Token-Type': 'KEYPAIR_JWT',\n", 115 | " 'Content-Type': 'application/json',\n", 116 | " 'Accept': 'application/json',\n", 117 | " 'Authorization': f'Bearer {jwt_token}',\n", 118 | " }\n", 119 | " \n", 120 | " try:\n", 121 | " response = requests.post(url, headers=headers, json=data)\n", 122 | " response.raise_for_status()\n", 123 | " except requests.exceptions.HTTPError as http_err:\n", 124 | " print(f\"HTTP error occurred: {http_err} - Status code: {response.status_code}\")\n", 125 | " except Exception as err:\n", 126 | " print(f\"An error occurred: {err}\")\n", 127 | " else:\n", 128 | " return response.json()" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 37, 134 | "id": "00202c33-97e0-4860-b34a-fe599d3948b2", 135 | "metadata": {}, 136 | "outputs": [ 137 | { 138 | "data": { 139 | "text/html": [ 140 | "
\n", 141 | "\n", 154 | "\n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | "
TALK_IDSPEAKER_1TITLEDESCRIPTIONURL
0220Joseph LekutonA parable for KenyaJoseph Lekuton, a member of parliament in Kenya, starts with the story of his remarkable education, then offers a parable of how Africa can grow. His message of hope has never been more relevant.https://www.ted.com/talks/joseph_lekuton_a_parable_for_kenya/
12221Boniface MwangiThe day I stood up alonePhotographer Boniface Mwangi wanted to protest against corruption in his home country of Kenya. So he made a plan: He and some friends would stand up and heckle during a public mass meeting. But when the moment came ... he stood alone. What happened next, he says, showed him who he truly was. As he says, \"There are two most powerful days in your life. The day you are born, and the day you discover why.\" Graphic images.https://www.ted.com/talks/boniface_mwangi_the_day_i_stood_up_alone/
2523Erik HersmanReporting crisis via textingAt TEDU 2009, Erik Hersman presents the remarkable story of Ushahidi, a GoogleMap mashup that allowed Kenyans to report and track violence via cell phone texts following the 2008 elections, and has evolved to continue saving lives in other countries.https://www.ted.com/talks/erik_hersman_reporting_crisis_via_texting/
32653Charity WayuaA few ways to fix a governmentCharity Wayua put her skills as a cancer researcher to use on an unlikely patient: the government of her native Kenya. She shares how she helped her government drastically improve its process for opening up new businesses, a crucial part of economic health and growth, leading to new investments and a World Bank recognition as a top reformer.https://www.ted.com/talks/charity_wayua_a_few_ways_to_fix_a_government/
421033Mary MakerWhy I fight for the education of refugee girls (like me)After fleeing war-torn South Sudan as a child, Mary Maker found security and hope in the school at Kenya's Kakuma Refugee Camp. Now a teacher of young refugees herself, she sees education as an essential tool for rebuilding lives -- and empowering a generation of girls who are too often denied entrance into the classroom. \"For the child of war, an education can turn their tears of loss into a passion for peace,\" Maker says.https://www.ted.com/talks/mary_maker_why_i_fight_for_the_education_of_refugee_girls_like_me/
\n", 208 | "
" 209 | ], 210 | "text/plain": [ 211 | " TALK_ID SPEAKER_1 \\\n", 212 | "0 220 Joseph Lekuton \n", 213 | "1 2221 Boniface Mwangi \n", 214 | "2 523 Erik Hersman \n", 215 | "3 2653 Charity Wayua \n", 216 | "4 21033 Mary Maker \n", 217 | "\n", 218 | " TITLE \\\n", 219 | "0 A parable for Kenya \n", 220 | "1 The day I stood up alone \n", 221 | "2 Reporting crisis via texting \n", 222 | "3 A few ways to fix a government \n", 223 | "4 Why I fight for the education of refugee girls (like me) \n", 224 | "\n", 225 | " DESCRIPTION \\\n", 226 | "0 Joseph Lekuton, a member of parliament in Kenya, starts with the story of his remarkable education, then offers a parable of how Africa can grow. His message of hope has never been more relevant. \n", 227 | "1 Photographer Boniface Mwangi wanted to protest against corruption in his home country of Kenya. So he made a plan: He and some friends would stand up and heckle during a public mass meeting. But when the moment came ... he stood alone. What happened next, he says, showed him who he truly was. As he says, \"There are two most powerful days in your life. The day you are born, and the day you discover why.\" Graphic images. \n", 228 | "2 At TEDU 2009, Erik Hersman presents the remarkable story of Ushahidi, a GoogleMap mashup that allowed Kenyans to report and track violence via cell phone texts following the 2008 elections, and has evolved to continue saving lives in other countries. \n", 229 | "3 Charity Wayua put her skills as a cancer researcher to use on an unlikely patient: the government of her native Kenya. She shares how she helped her government drastically improve its process for opening up new businesses, a crucial part of economic health and growth, leading to new investments and a World Bank recognition as a top reformer. \n", 230 | "4 After fleeing war-torn South Sudan as a child, Mary Maker found security and hope in the school at Kenya's Kakuma Refugee Camp. Now a teacher of young refugees herself, she sees education as an essential tool for rebuilding lives -- and empowering a generation of girls who are too often denied entrance into the classroom. \"For the child of war, an education can turn their tears of loss into a passion for peace,\" Maker says. 
\n", 231 | "\n", 232 | " URL \n", 233 | "0 https://www.ted.com/talks/joseph_lekuton_a_parable_for_kenya/ \n", 234 | "1 https://www.ted.com/talks/boniface_mwangi_the_day_i_stood_up_alone/ \n", 235 | "2 https://www.ted.com/talks/erik_hersman_reporting_crisis_via_texting/ \n", 236 | "3 https://www.ted.com/talks/charity_wayua_a_few_ways_to_fix_a_government/ \n", 237 | "4 https://www.ted.com/talks/mary_maker_why_i_fight_for_the_education_of_refugee_girls_like_me/ " 238 | ] 239 | }, 240 | "execution_count": 37, 241 | "metadata": {}, 242 | "output_type": "execute_result" 243 | } 244 | ], 245 | "source": [ 246 | "querystr = \"\"\n", 247 | "\n", 248 | "df = pd.DataFrame(query_service(querystr)[\"results\"])\n", 249 | "df[COLUMNS].head()" 250 | ] 251 | } 252 | ], 253 | "metadata": { 254 | "kernelspec": { 255 | "display_name": "Python3 (cs-quickstart)", 256 | "language": "python", 257 | "name": "cs-quickstart" 258 | }, 259 | "language_info": { 260 | "codemirror_mode": { 261 | "name": "ipython", 262 | "version": 3 263 | }, 264 | "file_extension": ".py", 265 | "mimetype": "text/x-python", 266 | "name": "python", 267 | "nbconvert_exporter": "python", 268 | "pygments_lexer": "ipython3", 269 | "version": "3.12.2" 270 | } 271 | }, 272 | "nbformat": 4, 273 | "nbformat_minor": 5 274 | } 275 | -------------------------------------------------------------------------------- /examples/09_multihop_rag/streamlit_chatbot_multihop_rag.py: -------------------------------------------------------------------------------- 1 | import streamlit as st # Import python packages 2 | from snowflake.snowpark.context import get_active_session 3 | from snowflake.cortex import Complete 4 | from snowflake.core import Root 5 | import os 6 | 7 | # LangChain imports for structured prompt and tool management 8 | from langchain_core.prompts import ChatPromptTemplate, PromptTemplate 9 | from langchain_core.tools import tool 10 | 11 | import pandas as pd 12 | import json 13 | 14 | pd.set_option("max_colwidth",None) 15 | 16 | # service parameters 17 | SOURCE_DOCS_STAGE = "@CORTEX_SEARCH_DOCS.DATA.DOCS" 18 | SOURCE_DOCS_PATH = "raw_pdf" 19 | CORTEX_SEARCH_DATABASE = "CORTEX_SEARCH_DOCS" 20 | CORTEX_SEARCH_SCHEMA = "DATA" 21 | CORTEX_SEARCH_SERVICE = "SEARCH_SERVICE_MULTIMODAL" 22 | GRAPH_DB = "DOCS_EDGES" 23 | 24 | ### Default Values 25 | DEFAULT_NUM_CHUNKS = 3 26 | DEFAULT_NUM_HOPS = 3 27 | CHUNK_OPTIONS = [1, 2, 3, 4, 5] 28 | 29 | NUM_CHUNKS = st.session_state.get("num_chunks", DEFAULT_NUM_CHUNKS) 30 | NUM_HOPS = st.session_state.get("num_hops", DEFAULT_NUM_HOPS) 31 | SLIDE_WINDOW = 5 # how many last conversations to remember 32 | 33 | # Define columns for the multimodal service 34 | COLUMNS = [ 35 | "text", 36 | "page_number", 37 | "image_filepath" 38 | ] 39 | 40 | def init_session(): 41 | """Initialize the Snowflake session""" 42 | try: 43 | # Try to get active session (works in SiS) 44 | return get_active_session() 45 | except: 46 | # For local development, use st.connection 47 | conn = st.connection("snowflake", type="snowflake") 48 | return conn.session() 49 | 50 | session = init_session() 51 | root = Root(session) 52 | svc = None 53 | 54 | # Initialize the service based on the selected service 55 | def init_service(): 56 | global svc # Make svc a global variable 57 | svc = root.databases[CORTEX_SEARCH_DATABASE].schemas[CORTEX_SEARCH_SCHEMA].cortex_search_services[CORTEX_SEARCH_SERVICE] 58 | return svc 59 | 60 | # Initialize the service 61 | init_service() 62 | 63 | @tool 64 | def search_similar_documents(query: str) -> list: 65 | 
""" 66 | Search for documents similar to the given query using multimodal search. 67 | 68 | Uses Snowflake's multimodal embedding model to find relevant documents 69 | based on both text and image content. 70 | 71 | Args: 72 | query: The search query text to find similar documents 73 | 74 | Returns: 75 | List of similar documents with their metadata including image paths 76 | """ 77 | # Embed the query using the same multimodal model used for image embeddings 78 | sql_output = session.sql(f""" 79 | SELECT AI_EMBED('voyage-multimodal-3', 80 | '{query.replace("'", "")}') 81 | """).collect() 82 | query_vector = list(sql_output[0].asDict().values())[0] 83 | 84 | response = svc.search(query, COLUMNS, limit=NUM_CHUNKS, 85 | experimental={'queryEmbedding': query_vector}) 86 | 87 | if st.session_state.get('debug', False): 88 | page_numbers = [doc.get('page_number', '') for doc in response.results if doc.get('page_number')] 89 | st.sidebar.write(f" 📄 Found {len(response.results)} similar docs: {page_numbers}") 90 | 91 | return response.results 92 | 93 | @tool 94 | def search_connected_documents(source_paths: list, num_hops: int = None) -> list: 95 | """ 96 | Search for documents that are logically connected to the given source documents. 97 | 98 | Uses Snowflake's graph traversal capabilities to find related documents 99 | through citations, references, and other logical connections. 100 | 101 | Args: 102 | source_paths: List of document paths from search_similar_documents 103 | num_hops: Number of connection hops to traverse (default: user setting) 104 | 105 | Returns: 106 | List of connected documents with explanations of their relationships 107 | """ 108 | if not source_paths: 109 | if st.session_state.get('debug', False): 110 | st.sidebar.write(f" 🔗 No source paths to search") 111 | return [] 112 | 113 | if num_hops is None: 114 | num_hops = st.session_state.get("num_hops", DEFAULT_NUM_HOPS) 115 | 116 | # Build the recursive query to find connected images 117 | query = f""" 118 | SELECT 119 | DEST_PATH AS image_filepath, 120 | DEST_PAGE AS page_number, 121 | ARRAY_AGG(exp.value) AS explanations 122 | FROM TABLE(FIND_CONNECTED_PAGES( 123 | ARRAY_CONSTRUCT({','.join([f"'{path}'" for path in source_paths])}), 124 | {num_hops} 125 | )), 126 | LATERAL FLATTEN(input => EXPLANATIONS) exp 127 | GROUP BY 1, 2 128 | """ 129 | 130 | try: 131 | results = session.sql(query).collect() 132 | # Convert Row objects to dictionaries with lowercase keys for consistent handling 133 | results_dict = [{k.lower(): v for k, v in doc.asDict().items()} for doc in results] 134 | 135 | if st.session_state.get('debug', False): 136 | if results_dict: 137 | page_numbers = [doc.get('page_number', '') for doc in results_dict if doc.get('page_number')] 138 | st.sidebar.write(f" 🔗 Found {len(results_dict)} connected docs: {page_numbers}") 139 | else: 140 | st.sidebar.write(f" 🔗 No connected documents found") 141 | return results_dict 142 | except Exception as e: 143 | st.error(f"Error querying connected documents: {str(e)}") 144 | return [] 145 | 146 | @tool 147 | def summarize_chat_history(chat_history: list, current_question: str) -> str: 148 | """ 149 | Summarize the chat history with the current question to create a better search query. 150 | 151 | Creates a contextual search query that incorporates previous conversation 152 | context to improve document retrieval relevance. 
153 | 154 | Args: 155 | chat_history: List of previous chat messages 156 | current_question: The current user question 157 | 158 | Returns: 159 | A summarized query that incorporates chat history context 160 | """ 161 | if not chat_history: 162 | return current_question 163 | 164 | # Format chat history for prompt 165 | history_str = "\n".join([f"{msg['role'].upper()}: {msg['content']}" for msg in chat_history]) 166 | 167 | prompt = f"""Based on the chat history and the question, generate a query that extends the question 168 | with the chat history provided. The query should be in natural language. 169 | Answer with only the query. Do not add any explanation. 170 | 171 | Chat history: {history_str} 172 | 173 | Question: {current_question}""" 174 | 175 | # Use native Cortex Complete 176 | summary = Complete(st.session_state.model_name, prompt) 177 | 178 | if st.session_state.get('debug', False): 179 | st.sidebar.write("🔍 Enhanced query with chat history") 180 | st.sidebar.caption(summary) 181 | 182 | return summary.replace("'", "") 183 | 184 | # Prompt template for consistent multimodal responses 185 | @st.cache_resource 186 | def get_response_prompt_template(): 187 | """Get the prompt template for generating final responses""" 188 | return PromptTemplate.from_template(""" 189 | You are an expert technical assistant. Answer the user's question using the provided documents. 190 | Provide a comprehensive, step-by-step answer based on the document content and context. 191 | Be thorough and include all necessary details and step-by-step instructions. 192 | Never reference page numbers or document sections - include the actual content. 193 | 194 | DOCUMENT CONTEXT: 195 | The following describes what each document covers: 196 | {context} 197 | 198 | CHAT HISTORY: 199 | {chat_history} 200 | 201 | USER QUESTION: {question} 202 | 203 | DOCUMENT IMAGES TO ANALYZE: 204 | Please analyze the following document images: {image_refs} 205 | 206 | ANSWER USER QUESTION: 207 | """) 208 | 209 | # Main declarative workflow 210 | def execute_declarative_workflow(question: str, chat_history: list = None) -> tuple: 211 | """ 212 | Execute a clean, predictable RAG workflow using structured tools. 213 | 214 | This declarative approach follows a proven optimal sequence: 215 | 1. Enhance query with chat history (if available) 216 | 2. Search for similar documents using multimodal search 217 | 3. Find connected documents through graph traversal 218 | 4. 
Generate comprehensive multimodal response 219 | 220 | Args: 221 | question: User's question 222 | chat_history: Previous conversation context 223 | 224 | Returns: 225 | Tuple of (response_text, document_paths) 226 | """ 227 | 228 | if st.session_state.get('debug', False): 229 | st.sidebar.write("📋 **Reasoning RAG Started**") 230 | 231 | try: 232 | # Enhance query with chat history if available 233 | if chat_history: 234 | enhanced_query = summarize_chat_history.invoke({ 235 | "chat_history": chat_history, 236 | "current_question": question 237 | }) 238 | else: 239 | enhanced_query = question 240 | 241 | # Search for similar documents 242 | st.sidebar.write(f"✅ Step 1: Semantic similarity search") 243 | similar_docs = search_similar_documents.invoke({"query": enhanced_query}) 244 | 245 | 246 | # Search for connected documents 247 | st.sidebar.write(f"✅ Step 2: Relevancy graph search") 248 | similar_paths = [doc.get('image_filepath', '') for doc in similar_docs if doc.get('image_filepath')] 249 | 250 | connected_docs = [] 251 | if similar_paths: 252 | connected_docs = search_connected_documents.invoke({ 253 | "source_paths": similar_paths, 254 | "num_hops": st.session_state.get("num_hops", DEFAULT_NUM_HOPS) 255 | }) 256 | 257 | # Combine all unique document paths 258 | all_paths = [] 259 | connected_paths = [doc.get('image_filepath', '') for doc in connected_docs if doc.get('image_filepath')] 260 | all_paths.extend(connected_paths) 261 | 262 | # Add similar paths that aren't already included 263 | for path in similar_paths: 264 | if path and path not in all_paths: 265 | all_paths.append(path) 266 | 267 | # Create context information 268 | explanations = [] 269 | for i, doc in enumerate(connected_docs): 270 | try: 271 | doc_explanations = json.loads(doc.get('explanations', '[]') or '[]') 272 | except (json.JSONDecodeError, TypeError): 273 | doc_explanations = doc.get('explanations', []) if isinstance(doc.get('explanations'), list) else [] 274 | 275 | if doc_explanations: 276 | explanation_text = "\n".join([f"- {exp}" for exp in doc_explanations]) 277 | doc_path = doc.get('image_filepath', '') or f"connected_doc_{i+1}" 278 | doc_name = doc_path.split('/')[-1] if '/' in doc_path else doc_path 279 | explanations.append(f"Document '{doc_name}' covers:\n{explanation_text}") 280 | 281 | context_info = "\n\n".join(explanations) if explanations else "No specific context available." 282 | 283 | st.sidebar.write(f"✅ Step 3: Generating response with {len(all_paths)} docs") 284 | 285 | # Step 6: Generate final multimodal response 286 | if all_paths: 287 | return generate_multimodal_response(question, all_paths, context_info, chat_history) 288 | else: 289 | return "I couldn't find any relevant documents for your question.", [] 290 | 291 | except Exception as e: 292 | st.error(f"Workflow error: {str(e)}") 293 | return "I encountered an error while processing your request.", [] 294 | 295 | def generate_multimodal_response(question: str, document_paths: list, context: str, chat_history: list = None) -> tuple: 296 | """ 297 | Generate final response using Snowflake's native multimodal capabilities. 298 | 299 | Uses structured prompt template for consistency and Snowflake's PROMPT() function 300 | with TO_FILE() for full multimodal support. 
301 | 302 | Args: 303 | question: User's question 304 | document_paths: List of document paths to include 305 | context: Context information from connected documents 306 | chat_history: Previous conversation context 307 | 308 | Returns: 309 | Tuple of (response_text, document_paths) 310 | """ 311 | 312 | # Format chat history 313 | history_str = "" 314 | if chat_history: 315 | history_str = "\n".join([f"{msg['role'].upper()}: {msg['content']}" for msg in chat_history]) 316 | 317 | # Create image file references for PROMPT function 318 | image_files = [] 319 | for path in document_paths: 320 | image_files.append(f"TO_FILE('{SOURCE_DOCS_STAGE}', '{path}')") 321 | image_files_str = ",\n".join(image_files) 322 | 323 | # Create positional references for images in the prompt 324 | image_refs = " ".join([f"{{{i}}}" for i in range(len(document_paths))]) 325 | 326 | # Use structured prompt template with positional references 327 | prompt_template = get_response_prompt_template() 328 | formatted_prompt = prompt_template.format( 329 | context=context, 330 | chat_history=history_str if history_str else "No previous conversation", 331 | question=question, 332 | image_refs=image_refs 333 | ) 334 | 335 | # Escape single quotes in the prompt for SQL 336 | escaped_prompt = formatted_prompt.replace("'", "''") 337 | 338 | # Use Snowflake's native multimodal PROMPT function 339 | query = f""" 340 | SELECT AI_COMPLETE('{st.session_state.model_name}', 341 | PROMPT('{escaped_prompt}', 342 | {image_files_str}) 343 | ); 344 | """ 345 | 346 | sql_output = session.sql(query).collect() 347 | response = list(sql_output[0].asDict().values())[0] 348 | 349 | return response, document_paths 350 | 351 | ### UI Configuration Functions 352 | 353 | def config_options(): 354 | st.sidebar.selectbox( 355 | 'Select your model:', 356 | ( 357 | 'claude-3-7-sonnet', 'claude-3-5-sonnet', 'claude-4-opus', 'claude-4-sonnet', 358 | 'openai-gpt-4.1', 'openai-o4-mini', 359 | 'llama4-maverick', 'llama4-scout', 'pixtral-large' 360 | ), key="model_name") 361 | 362 | # Create two columns for the selectboxes to be side by side with equal width 363 | col1, col2 = st.sidebar.columns([1, 1]) 364 | 365 | with col1: 366 | st.selectbox('Initial docs:', 367 | CHUNK_OPTIONS, key="num_chunks", 368 | index=CHUNK_OPTIONS.index(NUM_CHUNKS), 369 | help="Number of documents to retrieve in **initial semantic similarity search**") 370 | 371 | with col2: 372 | st.selectbox('Hops:', 373 | CHUNK_OPTIONS, key="num_hops", 374 | index=CHUNK_OPTIONS.index(NUM_HOPS), 375 | help="Number of connection hops to traverse in **secondary relevancy graph search**") 376 | 377 | st.sidebar.checkbox('Chat history', key="use_chat_history", value = False) 378 | st.sidebar.checkbox('Debug mode', key="debug", value = True) 379 | 380 | st.sidebar.button("Start Over", key="clear_conversation", on_click=init_messages) 381 | st.sidebar.expander("Session State").write(st.session_state) 382 | 383 | def init_messages(): 384 | # Initialize chat history 385 | if st.session_state.clear_conversation or "messages" not in st.session_state: 386 | st.session_state.messages = [] 387 | 388 | def get_chat_history(): 389 | """Get the history from the st.session_state.messages according to the slide window parameter""" 390 | chat_history = [] 391 | start_index = max(0, len(st.session_state.messages) - SLIDE_WINDOW) 392 | for i in range(start_index, len(st.session_state.messages) - 1): 393 | chat_history.append(st.session_state.messages[i]) 394 | return chat_history 395 | 396 | def main(): 397 | 
st.title(f":robot_face: :mag_right: Multi-hop RAG Assistant") 398 | st.write("Predictable, optimized multi-hop retrieval with full multimodal support") 399 | st.write("List of pre-processed documents that will be used to answer your questions:") 400 | docs_available = session.sql(f"ls {SOURCE_DOCS_STAGE}/{SOURCE_DOCS_PATH}").collect() 401 | list_docs = [doc["name"] for doc in docs_available] 402 | for doc in list_docs: 403 | st.markdown(f"- {doc}") 404 | 405 | config_options() 406 | init_messages() 407 | 408 | # Display chat messages from history on app rerun 409 | for message in st.session_state.messages: 410 | with st.chat_message(message["role"]): 411 | st.markdown(message["content"]) 412 | 413 | # Accept user input 414 | if question := st.chat_input("What do you want to know about your products?"): 415 | # Add user message to chat history 416 | st.session_state.messages.append({"role": "user", "content": question}) 417 | # Display user message in chat message container 418 | with st.chat_message("user"): 419 | st.markdown(question) 420 | # Display assistant response in chat message container 421 | with st.chat_message("assistant"): 422 | message_placeholder = st.empty() 423 | 424 | question = question.replace("'","") 425 | 426 | with st.spinner(f"Multimodal RAG + {st.session_state.model_name} thinking..."): 427 | # Get chat history for context 428 | chat_history = get_chat_history() if st.session_state.use_chat_history else [] 429 | 430 | # Execute the declarative RAG workflow 431 | response, relative_paths = execute_declarative_workflow(question, chat_history) 432 | 433 | response = response.replace("'", "") 434 | message_placeholder.markdown(response) 435 | 436 | # Display images inline 437 | if relative_paths: 438 | st.write("**Related Documents:**") 439 | cols = st.columns(3) # Create 3 columns for images 440 | for i, path in enumerate(relative_paths): 441 | cmd2 = f"select GET_PRESIGNED_URL('{SOURCE_DOCS_STAGE}', '{path}', 360) as URL_LINK;" 442 | df_url_link = session.sql(cmd2).to_pandas() 443 | url_link = df_url_link._get_value(0,'URL_LINK') 444 | 445 | # Display image in the appropriate column 446 | with cols[i % 3]: 447 | st.image(url_link, caption=f"Document {i+1}", use_container_width=True) 448 | filename = path.split('/')[-1] if '/' in path else path 449 | st.markdown(f"**Source:** [{filename}]({url_link})") 450 | 451 | st.session_state.messages.append({"role": "assistant", "content": response}) 452 | 453 | if __name__ == "__main__": 454 | main() -------------------------------------------------------------------------------- /examples/07_streamlit_search_evaluation_app/eval_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import TestCase 3 | from unittest.mock import patch, MagicMock 4 | import math 5 | from eval import ( 6 | hit_rate, 7 | sdcg, 8 | precision, 9 | _dcg, 10 | generate_and_store_scrape, 11 | prepare_query_df, 12 | perform_scrape_for_eval, 13 | perform_scrape_for_autotune, 14 | store_scrape_results, 15 | validate_scrape, 16 | get_result_limit, 17 | prepare_relevancy_df, 18 | extract_and_dedupe_goldens, 19 | prepare_golden_scores, 20 | evaluate_queries, 21 | compute_fusion_score_from_service, 22 | QUERY_ID, 23 | QUERY, 24 | HIT_RATE, 25 | RELEVANCY, 26 | RUN_ID, 27 | DOC_ID, 28 | SDCG, 29 | PRECISION, 30 | DEBUG_PER_RESULT, 31 | RESPONSE_RESULTS, 32 | DEBUG_SIGNALS, 33 | DEBUG, 34 | SLOWMODE, 35 | EmbeddingMultiplier, 36 | RerankingMultiplier, 37 | TopicalityMultiplier, 38 | 
RERANK_WEIGHTS, 39 | DEFAULT_EMBEDDING_MULTIPLIER, 40 | ) 41 | from datetime import datetime 42 | from snowflake.snowpark import Session 43 | from snowflake.snowpark import Table, DataFrame 44 | import hashlib 45 | 46 | 47 | TEXT = "text" 48 | 49 | 50 | class TestMetrics(TestCase): 51 | def setUp(self): 52 | # Setup some reusable test data 53 | self.results = ["doc1", "doc2", "doc3", "doc4"] 54 | self.golden_to_score = { 55 | "doc1": {"score": 3, "rank": 1}, 56 | "doc2": {"score": 2, "rank": 2}, 57 | "doc5": {"score": 1, "rank": 3}, 58 | "doc6": {"score": 0, "rank": 4}, 59 | } 60 | self.empty_results = [] 61 | self.empty_golden = {} 62 | 63 | def test_hit_rate(self): 64 | # Expected to return 1 because doc1 is in golden_to_score and has score > 0 65 | self.assertEqual(hit_rate(self.results, self.golden_to_score), 1) 66 | 67 | # Expected to return 0 because no documents in results have score > 0 68 | self.assertEqual(hit_rate(["doc4"], self.golden_to_score), 0) 69 | 70 | # Test with empty results 71 | self.assertEqual(hit_rate(self.empty_results, self.golden_to_score), 0) 72 | 73 | def test_sdcg(self): 74 | # Calculate DCG and SDCG values for given results and golden data 75 | expected_sdcg = _dcg(self.results, self.golden_to_score) / ( 76 | 3.0 * sum(1.0 / math.log2(i + 2.0) for i in range(4)) 77 | ) 78 | self.assertAlmostEqual( 79 | sdcg(3.0, self.results, self.golden_to_score), expected_sdcg, places=5 80 | ) 81 | 82 | # Test SDCG with empty results 83 | self.assertEqual(sdcg(3.0, self.empty_results, self.golden_to_score), 0.0) 84 | 85 | # Test SDCG with empty golden set 86 | self.assertEqual(sdcg(3.0, self.results, self.empty_golden), 0.0) 87 | 88 | def test_precision(self): 89 | # 2 relevant documents (doc1, doc2), out of 4 total -> precision = 2/4 = 0.5 90 | self.assertAlmostEqual(precision(self.results, self.golden_to_score), 0.5) 91 | 92 | # No relevant documents, precision should be 0 93 | self.assertEqual(precision(["doc4"], self.golden_to_score), 0.0) 94 | 95 | # Test precision with empty results 96 | self.assertEqual(precision(self.empty_results, self.golden_to_score), 0.0) 97 | 98 | def test_dcg(self): 99 | # Test DCG for correct calculation 100 | expected_dcg = 3 / math.log2(2) + 2 / math.log2(3) # score of doc1 and doc2 101 | self.assertAlmostEqual( 102 | _dcg(self.results, self.golden_to_score), expected_dcg, places=5 103 | ) 104 | 105 | # Test DCG with empty results 106 | self.assertEqual(_dcg(self.empty_results, self.golden_to_score), 0.0) 107 | 108 | 109 | class TestScrapeFlow(TestCase): 110 | @patch("eval.get_active_session") 111 | @patch("eval.Root") 112 | @patch("streamlit.session_state", new_callable=MagicMock) 113 | @patch("eval.prepare_query_df") 114 | @patch("eval.perform_scrape_for_eval") 115 | @patch("eval.store_scrape_results") 116 | def test_generate_and_store_scrape( 117 | self, 118 | mock_store_scrape_results, 119 | mock_perform_scrape_for_eval, 120 | mock_prepare_query_df, 121 | mock_session_state, 122 | mock_root, 123 | mock_get_active_session, 124 | ): 125 | # Mock session state properties 126 | mock_session_state.start_time = datetime(2024, 1, 1) 127 | 128 | # Mock return values 129 | mock_query_table = MagicMock(name="query_table") 130 | mock_prepare_query_df.return_value = mock_query_table 131 | mock_scrape_out = [{"example_key": "example_value"}] 132 | mock_perform_scrape_for_eval.return_value = mock_scrape_out 133 | 134 | # Mock session and root 135 | mock_session = mock_get_active_session.return_value 136 | mock_root_instance = mock_root.return_value 
137 | 138 | # Call function 139 | generate_and_store_scrape(mock_session, mock_root_instance) 140 | 141 | # Assertions 142 | mock_prepare_query_df.assert_called_once() 143 | mock_perform_scrape_for_eval.assert_called_once_with( 144 | mock_session, mock_query_table, mock_root_instance 145 | ) 146 | mock_store_scrape_results.assert_called_once_with(mock_scrape_out) 147 | 148 | @patch("eval.get_active_session") # Mock get_active_session 149 | @patch("eval.get_session") # Mock get_session 150 | @patch("eval.Root") # Mock Root 151 | @patch("streamlit.session_state", new_callable=MagicMock) # Mock st.session_state 152 | def test_prepare_query_df_creates_query_id_if_missing( 153 | self, mock_session_state, mock_root, mock_get_session, mock_get_active_session 154 | ): 155 | # Define mock properties for session state 156 | mock_session_state.queryset_fqn = "mock_fqn" 157 | mock_session_state.md5_hash.return_value = "mock_query_id" 158 | 159 | # Mock the session object returned by get_session 160 | mock_session = MagicMock(spec=Session) 161 | mock_get_session.return_value = mock_session 162 | 163 | # Mock the table method on the session 164 | mock_query_table = MagicMock() 165 | mock_query_table.columns = [QUERY] # Simulate missing QUERY_ID initially 166 | mock_query_table.with_column.return_value = mock_query_table 167 | 168 | # Setup the session.table mock to return mock_query_table 169 | mock_session.table.return_value = mock_query_table 170 | 171 | # Call the function 172 | result = prepare_query_df(mock_session) # Pass the mock session 173 | 174 | # Assertions 175 | mock_session.table.assert_called_once_with( 176 | "mock_fqn" 177 | ) # Check if session.table was called with the correct fqn 178 | mock_query_table.with_column.assert_called_once_with( 179 | QUERY_ID, mock_session_state.md5_hash(mock_query_table[QUERY]) 180 | ) # Ensure with_column was called to add QUERY_ID using md5_hash 181 | self.assertEqual( 182 | result, mock_query_table 183 | ) # Ensure the result is the mock query_table 184 | 185 | @patch( 186 | "streamlit.session_state", new_callable=MagicMock 187 | ) # Mock streamlit session_state 188 | @patch("eval.Session", new_callable=MagicMock) # Mock eval.Session 189 | @patch("eval.perform_scrape", new_callable=MagicMock) # Mock perform_scrape 190 | def test_perform_scrape_for_eval( 191 | self, mock_perform_scrape, mock_session, mock_session_state 192 | ): 193 | # Mock session state properties 194 | mock_session_state.css_id_given = True # Simulate that CSS ID is given 195 | mock_session_state.css_id_col = DOC_ID 196 | mock_session_state.css_text_col = "text" # Adjust based on your code 197 | mock_session_state.css_fqn = "mock_db.mock_schema.mock_service" 198 | 199 | # Mock input query DataFrame 200 | mock_query_df = MagicMock() 201 | mock_query_df.collect.return_value = [{QUERY: "test query", "QUERY_ID": "1"}] 202 | 203 | # Mock perform_scrape output 204 | mock_perform_scrape.return_value = { 205 | "1": { 206 | QUERY: "test query", 207 | RUN_ID: "mock_run_id", 208 | RESPONSE_RESULTS: [ 209 | { 210 | mock_session_state.css_text_col: "sample text", 211 | DEBUG_PER_RESULT: {"score": 0.9}, 212 | } 213 | ], 214 | } 215 | } 216 | 217 | # Expected output 218 | expected_output = [ 219 | { 220 | QUERY: "test query", 221 | RUN_ID: "mock_run_id", 222 | QUERY_ID: "1", 223 | DOC_ID: hashlib.md5("sample text".encode("utf-8")).hexdigest(), 224 | "RANK": 1, 225 | DEBUG_SIGNALS: {"score": 0.9}, 226 | "TEXT": "sample text", 227 | } 228 | ] 229 | 230 | # Mock the DataFrame creation 231 | 
mock_scrape_df = MagicMock() 232 | mock_session.create_dataframe.return_value = mock_scrape_df 233 | 234 | # Call the function under test 235 | result_df = perform_scrape_for_eval( 236 | mock_session, mock_query_df, MagicMock(), run_id="mock_run_id" 237 | ) 238 | 239 | mock_session.create_dataframe.assert_called_once_with( 240 | expected_output 241 | ) # Check DataFrame creation 242 | 243 | # Ensure returned DataFrame matches the mocked DataFrame 244 | self.assertEqual(result_df, mock_scrape_df) 245 | 246 | @patch("streamlit.session_state", new_callable=MagicMock) 247 | def test_store_scrape_results(self, mock_session_state): 248 | # Mock session state properties 249 | mock_session_state.scrape_fqn = "mock_db.mock_schema.mock_table" 250 | mock_session_state.start_time = datetime(2024, 1, 1, 12, 0, 0) 251 | 252 | # Mock the scrape DataFrame 253 | mock_scrape_df = MagicMock(name="scrape_df") 254 | 255 | # Mock datetime to control timing 256 | mock_now = datetime(2024, 1, 1, 12, 0, 10) 257 | with patch("eval.datetime") as mock_datetime: 258 | mock_datetime.now.return_value = mock_now 259 | 260 | # Call the function under test 261 | store_scrape_results(mock_scrape_df) 262 | 263 | # Assertions for the DataFrame write operation 264 | mock_scrape_df.write.mode.assert_called_once_with("append") 265 | mock_scrape_df.write.mode().save_as_table.assert_called_once_with( 266 | "mock_db.mock_schema.mock_table" 267 | ) 268 | 269 | # Assertions for session state and success message 270 | duration = mock_now - mock_session_state.start_time 271 | assert duration.total_seconds() == 10.0 272 | 273 | 274 | class TestPrepareRelevancyTable(TestCase): 275 | @patch("streamlit.session_state", new_callable=MagicMock) 276 | @patch("eval.Session") 277 | def test_prepare_relevancy_df(self, mock_session, mock_session_state): 278 | # Mock session state properties 279 | mock_session_state.md5_hash = MagicMock(side_effect=lambda x: f"hash_{x}") 280 | mock_session_state.css_text_col = TEXT 281 | 282 | # Mock relevancy table with QUERY_ID and DOC_ID present 283 | mock_table_with_ids = MagicMock() 284 | mock_table_with_ids.columns = [QUERY_ID, DOC_ID] 285 | mock_table_with_ids.withColumn.return_value = ( 286 | mock_table_with_ids # Mock chainable behavior 287 | ) 288 | 289 | mock_session.table.return_value = mock_table_with_ids 290 | 291 | # Call the function under test 292 | result_table = prepare_relevancy_df("relevancy_fqn", mock_session) 293 | 294 | # Assert QUERY_ID and DOC_ID were cast to string 295 | mock_table_with_ids.withColumn.assert_any_call( 296 | QUERY_ID, mock_table_with_ids[QUERY_ID].cast("string") 297 | ) 298 | mock_table_with_ids.withColumn.assert_any_call( 299 | DOC_ID, mock_table_with_ids[DOC_ID].cast("string") 300 | ) 301 | 302 | # Assert no new columns were added 303 | self.assertEqual(result_table, mock_table_with_ids) 304 | mock_session_state.md5_hash.assert_not_called() 305 | 306 | @patch("streamlit.session_state", new_callable=MagicMock) 307 | @patch("eval.Session") 308 | def test_prepare_relevancy_df_missing_columns( 309 | self, mock_session, mock_session_state 310 | ): 311 | # Mock session state properties 312 | mock_session_state.md5_hash = MagicMock(side_effect=lambda x: f"hash_{x}") 313 | mock_session_state.css_text_col = TEXT 314 | 315 | # Mock relevancy table missing QUERY_ID and DOC_ID 316 | mock_table_missing_ids = MagicMock() 317 | mock_table_missing_ids.columns = [QUERY] 318 | mock_table_missing_ids.withColumn.return_value = ( 319 | mock_table_missing_ids # Mock chainable behavior 320 | ) 
321 | 322 | mock_session.table.return_value = mock_table_missing_ids 323 | 324 | # Call the function under test 325 | result_table = prepare_relevancy_df("relevancy_fqn", mock_session) 326 | 327 | # Assert QUERY_ID was added using md5_hash on QUERY 328 | mock_table_missing_ids.withColumn.assert_any_call( 329 | QUERY_ID, mock_session_state.md5_hash(mock_table_missing_ids[QUERY]) 330 | ) 331 | 332 | # Assert DOC_ID was added using md5_hash on css_text_col 333 | mock_table_missing_ids.withColumn.assert_any_call( 334 | DOC_ID, 335 | mock_session_state.md5_hash( 336 | mock_table_missing_ids[mock_session_state.css_text_col] 337 | ), 338 | ) 339 | 340 | # Assert result_table matches the transformed table 341 | self.assertEqual(result_table, mock_table_missing_ids) 342 | 343 | 344 | class TestEvalFlow(TestCase): 345 | @patch("streamlit.session_state", new_callable=MagicMock) 346 | @patch("eval.Session") 347 | def test_prepare_query_df(self, mock_session, mock_session_state): 348 | # Mock session and input table 349 | mock_query_table = MagicMock() 350 | mock_session.table.return_value = mock_query_table 351 | mock_query_table.columns = ["query_column"] # QUERY_ID missing initially 352 | 353 | # Mock behavior of session_state's md5_hash 354 | mock_session_state.queryset_fqn = "mock_queryset_fqn" 355 | mock_session_state.md5_hash.return_value = "mock_hash" 356 | 357 | # Mock the with_column behavior to add QUERY_ID 358 | mock_query_table.with_column.return_value = mock_query_table 359 | 360 | # Call prepare_query_df 361 | query_df = prepare_query_df(mock_session) 362 | 363 | # Verify the correct calls 364 | mock_session.table.assert_called_once_with("mock_queryset_fqn") 365 | mock_query_table.with_column.assert_called_once_with( 366 | "QUERY_ID", mock_session_state.md5_hash(mock_query_table["query_column"]) 367 | ) 368 | self.assertEqual(query_df, mock_query_table) 369 | 370 | @patch("streamlit.session_state", new_callable=MagicMock) 371 | @patch("eval.Session") 372 | def test_prepare_query_df_with_query_id(self, mock_session, mock_session_state): 373 | # Mock session and input table 374 | mock_query_table = MagicMock() 375 | mock_session.table.return_value = mock_query_table 376 | mock_query_table.columns = [ 377 | "query_column", 378 | "QUERY_ID", 379 | ] # QUERY_ID already exists 380 | 381 | # Mock behavior of session_state's md5_hash 382 | mock_session_state.queryset_fqn = "mock_queryset_fqn" 383 | 384 | # Call prepare_query_df 385 | query_df = prepare_query_df(mock_session) 386 | 387 | # Verify the correct calls 388 | mock_session.table.assert_called_once_with("mock_queryset_fqn") 389 | mock_query_table.with_column.assert_not_called() # QUERY_ID already exists 390 | self.assertEqual(query_df, mock_query_table) 391 | 392 | @patch("streamlit.session_state", new_callable=MagicMock) 393 | def test_validate_scrape(self, mock_session_state): 394 | mock_scrape_df = MagicMock(spec=DataFrame) 395 | mock_scrape_df.count.return_value = 1 396 | 397 | validate_scrape(mock_scrape_df) 398 | 399 | mock_scrape_df.count.return_value = 0 400 | with self.assertRaises( 401 | AssertionError, 402 | msg="Scrape is empty! 
Recheck the Run ID or the Scrape table.", 403 | ): 404 | validate_scrape(mock_scrape_df) 405 | 406 | @patch("eval.spmax") 407 | def test_get_result_limit(self, mock_spmax): 408 | mock_scrape_df = MagicMock(spec=DataFrame) 409 | mock_spmax.return_value = MagicMock() 410 | mock_scrape_df.select.return_value.collect.return_value = [[5]] 411 | 412 | result = get_result_limit(mock_scrape_df) 413 | self.assertEqual(result, 5) 414 | 415 | @patch("streamlit.session_state", new_callable=MagicMock) 416 | def test_extract_and_dedupe_goldens(self, mock_session_state): 417 | self.maxDiff = None 418 | mock_session_state.rel_scores = {} 419 | mock_session_state.colors = {} 420 | relevance_color_mapping = { 421 | 0: "lightcoral", 422 | 1: "lightyellow", 423 | 2: "lightgreen", 424 | 3: "lightgreen", 425 | } 426 | 427 | # Setup mock rows in relevancy_table with duplicate (query_id, doc_id) pairs 428 | mock_rows = [ 429 | {QUERY_ID: "q1", DOC_ID: "d1", RELEVANCY: 1}, 430 | { 431 | QUERY_ID: "q1", 432 | DOC_ID: "d1", 433 | RELEVANCY: 2, 434 | }, # Higher score, should replace 435 | {QUERY_ID: "q1", DOC_ID: "d2", RELEVANCY: 1}, 436 | {QUERY_ID: "q2", DOC_ID: "d1", RELEVANCY: 3}, 437 | { 438 | QUERY_ID: "q2", 439 | DOC_ID: "d1", 440 | RELEVANCY: 2, 441 | }, # Lower score, should be ignored 442 | ] 443 | 444 | # Run function 445 | mock_relevancy_table = MagicMock() 446 | mock_relevancy_table.collect.return_value = mock_rows 447 | raw_goldens = extract_and_dedupe_goldens(mock_relevancy_table) 448 | 449 | # Assertions on deduplicated results 450 | expected_raw_goldens = { 451 | "q1": [("d1", 2), ("d2", 1)], # d1 has the updated score of 2 452 | "q2": [("d1", 3)], # d1 has the highest score of 3 453 | } 454 | 455 | self.assertEqual(raw_goldens, expected_raw_goldens) 456 | 457 | # Verify session_state updates 458 | self.assertEqual(mock_session_state.rel_scores["q1"]["d1"], 2) 459 | self.assertEqual(mock_session_state.rel_scores["q2"]["d1"], 3) 460 | self.assertEqual( 461 | mock_session_state.colors["q1"]["d1"], relevance_color_mapping[2] 462 | ) 463 | self.assertEqual( 464 | mock_session_state.colors["q2"]["d1"], relevance_color_mapping[3] 465 | ) 466 | 467 | @patch("streamlit.session_state", new_callable=MagicMock) 468 | def test_prepare_golden_scores(self, mock_session_state): 469 | mock_session_state.relevancy_provided = True 470 | raw_goldens = { 471 | "q1": [("d1", 3), ("d2", 2)], 472 | "q2": [("d3", 1)], 473 | } 474 | result = prepare_golden_scores(raw_goldens) 475 | 476 | self.assertIn("q1", result) 477 | self.assertEqual(result["q1"]["d1"]["rank"], 0) 478 | self.assertEqual(result["q1"]["d1"]["score"], 3) 479 | self.assertEqual(result["q2"]["d3"]["rank"], 0) 480 | self.assertEqual(result["q2"]["d3"]["score"], 1) 481 | 482 | @patch("eval.calculate_metrics") 483 | @patch("streamlit.progress") 484 | @patch("streamlit.empty") 485 | @patch("streamlit.session_state", new_callable=MagicMock) 486 | def test_evaluate_queries( 487 | self, mock_session_state, mock_empty, mock_progress, mock_calculate_metrics 488 | ): 489 | mock_query_table = MagicMock(spec=Table) 490 | mock_query_table.collect.return_value = [{QUERY_ID: "q1"}, {QUERY_ID: "q2"}] 491 | mock_scrape_df = MagicMock(spec=DataFrame) 492 | mock_scrape_df.collect.side_effect = [ 493 | [{DOC_ID: "d1"}, {DOC_ID: "d2"}], # For q1 494 | [{DOC_ID: "d3"}], # For q2 495 | ] 496 | goldens = { 497 | "q1": {"d1": {"rank": 0, "score": 3}, "d2": {"rank": 1, "score": 2}}, 498 | "q2": {"d3": {"rank": 0, "score": 1}}, 499 | } 500 | mock_calculate_metrics.return_value = { 501 
| HIT_RATE: 0.8, 502 | SDCG: 0.7, 503 | PRECISION: 0.6, 504 | } 505 | mock_session_state.idcg_factor = 3.0 506 | 507 | result = evaluate_queries(mock_query_table, mock_scrape_df, goldens) 508 | 509 | self.assertEqual(len(result), 2) 510 | self.assertEqual(result[0][QUERY_ID], "q1") 511 | self.assertEqual(result[0][HIT_RATE], 0.8) 512 | self.assertEqual(result[1][QUERY_ID], "q2") 513 | 514 | 515 | class TestAutotuneFlow(TestCase): 516 | @patch("streamlit.session_state", new_callable=MagicMock) # Mock st.session_state 517 | @patch("eval.generate_docid") 518 | @patch("eval.perform_scrape") 519 | @patch("eval.Root") 520 | def test_perform_scrape_for_autotune( 521 | self, 522 | mock_root, 523 | mock_perform_scrape, 524 | mock_generate_docid, 525 | mock_session_state, 526 | ): 527 | # Mock parameters and results 528 | mock_experimental_params = { 529 | DEBUG: True, 530 | SLOWMODE: True, 531 | RERANK_WEIGHTS: { 532 | RerankingMultiplier: 1.4, 533 | EmbeddingMultiplier: DEFAULT_EMBEDDING_MULTIPLIER, 534 | TopicalityMultiplier: 1.0, 535 | }, 536 | } 537 | 538 | # Mock return values 539 | mock_query_df = MagicMock(spec=DataFrame, name="query_df") 540 | mock_scrape_out = { 541 | "123": { 542 | QUERY: "abc", 543 | RUN_ID: "xyz", 544 | RESPONSE_RESULTS: [{"TEXT": "text"}], 545 | } 546 | } 547 | mock_perform_scrape.return_value = mock_scrape_out 548 | mock_generate_docid.return_value = "mnp" 549 | 550 | # Mock session and root 551 | mock_session_state.css_text_col = "TEXT" 552 | mock_root_instance = mock_root.return_value 553 | 554 | # Expected output 555 | expected_output = {"123": ["mnp"]} 556 | 557 | # Call function 558 | result = perform_scrape_for_autotune( 559 | mock_query_df, mock_root_instance, mock_experimental_params 560 | ) 561 | 562 | # Assertions 563 | mock_perform_scrape.assert_called_once_with( 564 | mock_query_df, 565 | mock_root_instance, 566 | autotune=True, 567 | experimental_params=mock_experimental_params, 568 | run_id="", 569 | ) 570 | mock_generate_docid.assert_called_once() 571 | self.assertEqual(result, expected_output) 572 | 573 | @patch("streamlit.session_state", new_callable=MagicMock) # Mock st.session_state 574 | @patch("eval.perform_scrape_for_autotune") 575 | @patch("eval.sdcg") 576 | @patch("eval.Root") 577 | def test_compute_fusion_score_from_service( 578 | self, 579 | mock_root, 580 | mock_sdcg, 581 | mock_perform_scrape_for_autotune, 582 | mock_session_state, 583 | ): 584 | # Mock parameters and results 585 | mock_params = { 586 | RerankingMultiplier: 1.4, 587 | EmbeddingMultiplier: DEFAULT_EMBEDDING_MULTIPLIER, 588 | TopicalityMultiplier: 1.0, 589 | } 590 | mock_experimental_params = { 591 | DEBUG: True, 592 | SLOWMODE: True, 593 | RERANK_WEIGHTS: mock_params, 594 | } 595 | mock_doc_list = ["mnp"] 596 | mock_golden_set = { 597 | "abc": {"score": 3}, 598 | "mnp": {"score": 8}, 599 | } 600 | mock_query_to_doc_list = {"123": mock_doc_list} 601 | mock_goldens = {"123": mock_golden_set} 602 | 603 | # Mock return values 604 | mock_query_df = MagicMock(spec=DataFrame, name="query_df") 605 | mock_query_df.collect.return_value = [{QUERY_ID: "123"}] 606 | mock_perform_scrape_for_autotune.return_value = mock_query_to_doc_list 607 | mock_sdcg.return_value = 0.7 608 | 609 | # Mock session and root 610 | mock_session_state.idcg_factor = 3.0 611 | mock_root_instance = mock_root.return_value 612 | 613 | # Expected output 614 | expected_output = 0.7 615 | 616 | # Call function 617 | result = compute_fusion_score_from_service( 618 | mock_root_instance, mock_query_df, mock_goldens, 
mock_params 619 | ) 620 | 621 | # Assertions 622 | mock_perform_scrape_for_autotune.assert_called_once_with( 623 | mock_query_df, 624 | mock_root_instance, 625 | experimental_params=mock_experimental_params, 626 | ) 627 | mock_sdcg.assert_called_once_with( 628 | idcg_factor=3.0, 629 | results=["mnp"], 630 | golden_to_score={"abc": {"score": 3}, "mnp": {"score": 8}}, 631 | ) 632 | self.assertEqual(result, expected_output) 633 | 634 | 635 | if __name__ == "__main__": 636 | unittest.main() 637 | -------------------------------------------------------------------------------- /examples/08_multimodal_rag/cortex_search_multimodal.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c02b90c5-e06b-4082-ad0c-7e3a5e31a457", 6 | "metadata": { 7 | "collapsed": false, 8 | "jupyter": { 9 | "outputs_hidden": false 10 | }, 11 | "name": "cell13" 12 | }, 13 | "source": [ 14 | "# Implementing multimodal retrieval using Cortex Search Service\n", 15 | "\n", 16 | "Welcome! This tutorial shows a lightweight example where a customer has two long PDFs and wants to search them and ask natural-language questions about them. At a high level, this tutorial demonstrates how to:\n", 17 | "\n", 18 | "- Convert long PDF files to document screenshots (images).\n", 19 | "- (Optional but highly recommended) Run parse_document on PDFs for auxiliary text retrieval to further improve quality.\n", 20 | "- Embed document screenshots using EMBED_IMAGE_1024 (PrPr), which runs `voyage-multimodal-3` under the hood.\n", 21 | "- Create a Cortex Search Service using multimodal embeddings and OCR text.\n", 22 | "- Retrieve top pages using Cortex Search.\n", 23 | "- Get a natural-language answer with multimodal RAG!" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "id": "30728acc-9579-4149-bcc5-b3241e5eac65", 29 | "metadata": { 30 | "collapsed": false, 31 | "jupyter": { 32 | "outputs_hidden": false 33 | }, 34 | "name": "cell15" 35 | }, 36 | "source": [ 37 | "To start with, make sure you have PDFs stored under a stage. The two PDF files used in this demo can be found at https://drive.google.com/drive/folders/1bExhPiJlF9aNushnXeLLBR4m9EMaShHw?usp=sharing" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "505524cb-873d-425d-83e9-de8ec433b5e4", 44 | "metadata": { 45 | "language": "sql", 46 | "name": "cell4" 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "-- CREATE SCHEMA IF NOT EXISTS CORTEX_SEARCH_DB.PYU;\n", 51 | "-- CREATE OR REPLACE STAGE CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO\n", 52 | "-- STORAGE_INTEGRATION = ML_DEV\n", 53 | "-- URL = 's3://ml-dev-sfc-or-dev-misc1-k8s/cortexsearch/pyu/multimodal/demo/'\n", 54 | "-- DIRECTORY = (ENABLE = TRUE);\n", 55 | "\n", 56 | "-- CREATE OR REPLACE STAGE CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL DIRECTORY = (ENABLE = TRUE) ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE');\n", 57 | "\n", 58 | "-- COPY FILES INTO @CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/raw_pdf/\n", 59 | "-- FROM @CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO/raw_pdf/;\n", 60 | "\n", 61 | "LS @CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/raw_pdf/;" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "id": "ab710021-6fc0-4358-88eb-4bac717184c3", 67 | "metadata": { 68 | "collapsed": false, 69 | "jupyter": { 70 | "outputs_hidden": false 71 | }, 72 | "name": "cell5" 73 | }, 74 | "source": [ 75 | "Now let's run some Python code:\n", 76 | "\n", 77 | "The purpose is to split each raw PDF into individual pages -- in both image and PDF format. 
Images are for multimodal retrieval, while PDFs are for better OCR quality (optional). As long as you configure the config correctly, you are good to go!\n", 78 | "\n", 79 | "```\n", 80 | "class Config:\n", 81 | " input_stage: str = \"@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/raw_pdf/\"\n", 82 | " output_stage: str = \"@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/\"\n", 83 | " input_path: str = \"raw_pdf\"\n", 84 | " output_pdf_path: str = \"paged_pdf\"\n", 85 | " output_image_path: str = \"paged_image\"\n", 86 | " allowed_extensions: List[str] = None\n", 87 | " max_dimension: int = 1500 # Maximum dimension in pixels before scaling\n", 88 | " dpi: int = 300 # Default DPI for image conversion\n", 89 | "\n", 90 | " def __post_init__(self):\n", 91 | " if self.allowed_extensions is None:\n", 92 | " self.allowed_extensions = [\".pdf\"]\n", 93 | "```\n", 94 | "\n", 95 | "**Make sure the output_stage is an internal stage**, because `embed_image_1024` only works with internal stages at the moment." 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "id": "3775908f-ca36-4846-8f38-5adca39217f2", 102 | "metadata": { 103 | "language": "python", 104 | "name": "cell1" 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "# Import python packages\n", 109 | "import os\n", 110 | "import sys\n", 111 | "import tempfile\n", 112 | "from contextlib import contextmanager\n", 113 | "from dataclasses import dataclass\n", 114 | "from typing import List\n", 115 | "from typing import Tuple\n", 116 | "\n", 117 | "import pdfplumber\n", 118 | "import PyPDF2\n", 119 | "import snowflake.snowpark.session as session\n", 120 | "import streamlit as st\n", 121 | "\n", 122 | "\n", 123 | "def print_info(msg: str) -> None:\n", 124 | " \"\"\"Print info message\"\"\"\n", 125 | " print(f\"INFO: {msg}\", file=sys.stderr)\n", 126 | "\n", 127 | "\n", 128 | "def print_error(msg: str) -> None:\n", 129 | " \"\"\"Print error message\"\"\"\n", 130 | " print(f\"ERROR: {msg}\", file=sys.stderr)\n", 131 | " if hasattr(st, \"error\"):\n", 132 | " st.error(msg)\n", 133 | "\n", 134 | "\n", 135 | "def print_warning(msg: str) -> None:\n", 136 | " \"\"\"Print warning message\"\"\"\n", 137 | " print(f\"WARNING: {msg}\", file=sys.stderr)\n", 138 | "\n", 139 | "\n", 140 | "@dataclass\n", 141 | "class Config:\n", 142 | " input_stage: str = \"@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/raw_pdf/\"\n", 143 | " output_stage: str = (\n", 144 | " \"@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/\" # Base output stage without subdirectories\n", 145 | " )\n", 146 | " input_path: str = \"raw_pdf\"\n", 147 | " output_pdf_path: str = \"paged_pdf\"\n", 148 | " output_image_path: str = \"paged_image\"\n", 149 | " allowed_extensions: List[str] = None\n", 150 | " max_dimension: int = 1500 # Maximum dimension in pixels before scaling\n", 151 | " dpi: int = 300 # Default DPI for image conversion\n", 152 | "\n", 153 | " def __post_init__(self):\n", 154 | " if self.allowed_extensions is None:\n", 155 | " self.allowed_extensions = [\".pdf\"]\n", 156 | "\n", 157 | "\n", 158 | "class PDFProcessingError(Exception):\n", 159 | " \"\"\"Base exception for PDF processing errors\"\"\"\n", 160 | "\n", 161 | "\n", 162 | "class FileDownloadError(PDFProcessingError):\n", 163 | " \"\"\"Raised when file download fails\"\"\"\n", 164 | "\n", 165 | "\n", 166 | "class PDFConversionError(PDFProcessingError):\n", 167 | " \"\"\"Raised when PDF conversion fails\"\"\"\n", 168 | "\n", 169 | "\n", 170 | "@contextmanager\n", 171 | "def 
managed_temp_file(suffix: str = None) -> str:\n", 172 | " \"\"\"Context manager for temporary file handling\"\"\"\n", 173 | " temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)\n", 174 | " try:\n", 175 | " yield temp_file.name\n", 176 | " finally:\n", 177 | " # Don't delete the file immediately, let the caller handle cleanup\n", 178 | " pass\n", 179 | "\n", 180 | "\n", 181 | "def cleanup_temp_file(file_path: str) -> None:\n", 182 | " \"\"\"Clean up a temporary file\"\"\"\n", 183 | " try:\n", 184 | " if os.path.exists(file_path):\n", 185 | " os.unlink(file_path)\n", 186 | " except OSError as e:\n", 187 | " print_warning(f\"Failed to delete temporary file {file_path}: {e}\")\n", 188 | "\n", 189 | "\n", 190 | "def list_pdf_files(session: session.Session, config: Config) -> List[dict]:\n", 191 | " \"\"\"List all PDF files in the source stage\"\"\"\n", 192 | " try:\n", 193 | " # Use LIST command instead of DIRECTORY function\n", 194 | " query = f\"\"\"\n", 195 | " LIST {config.input_stage}\n", 196 | " \"\"\"\n", 197 | "\n", 198 | " file_list = session.sql(query).collect()\n", 199 | "\n", 200 | " # Filter for PDF files\n", 201 | " pdf_files = []\n", 202 | " for file_info in file_list:\n", 203 | " full_path = file_info[\"name\"]\n", 204 | " # Extract just the filename from the full path\n", 205 | " file_name = os.path.basename(full_path)\n", 206 | "\n", 207 | " if any(\n", 208 | " file_name.lower().endswith(ext) for ext in config.allowed_extensions\n", 209 | " ):\n", 210 | " pdf_files.append(\n", 211 | " {\n", 212 | " \"RELATIVE_PATH\": file_name, # Use just the filename\n", 213 | " \"FULL_STAGE_PATH\": full_path, # Use full path for download\n", 214 | " \"SIZE\": file_info[\"size\"] if \"size\" in file_info else 0,\n", 215 | " }\n", 216 | " )\n", 217 | "\n", 218 | " print_info(f\"Found {len(pdf_files)} PDF files in the stage\")\n", 219 | " return pdf_files\n", 220 | " except Exception as e:\n", 221 | " print_error(f\"Failed to list files: {e}\")\n", 222 | " raise\n", 223 | "\n", 224 | "\n", 225 | "def download_file_from_stage(\n", 226 | " session: session.Session, file_path: str, config: Config\n", 227 | ") -> str:\n", 228 | " \"\"\"Download a file from stage using session.file.get\"\"\"\n", 229 | " # Create a temporary directory\n", 230 | " temp_dir = tempfile.mkdtemp()\n", 231 | " try:\n", 232 | " # Ensure there are no double slashes in the path\n", 233 | " stage_path = f\"{config.input_stage.rstrip('/')}/{file_path.lstrip('/')}\"\n", 234 | "\n", 235 | " # Get the file from stage\n", 236 | " get_result = session.file.get(stage_path, temp_dir)\n", 237 | " if not get_result or get_result[0].status != \"DOWNLOADED\":\n", 238 | " raise FileDownloadError(f\"Failed to download file: {file_path}\")\n", 239 | "\n", 240 | " # Construct the local path where the file was downloaded\n", 241 | " local_path = os.path.join(temp_dir, os.path.basename(file_path))\n", 242 | " if not os.path.exists(local_path):\n", 243 | " raise FileDownloadError(f\"Downloaded file not found at: {local_path}\")\n", 244 | "\n", 245 | " return local_path\n", 246 | " except Exception as e:\n", 247 | " print_error(f\"Error downloading {file_path}: {e}\")\n", 248 | " # Clean up the temporary directory\n", 249 | " try:\n", 250 | " import shutil\n", 251 | "\n", 252 | " shutil.rmtree(temp_dir)\n", 253 | " except Exception as cleanup_error:\n", 254 | " print_warning(f\"Failed to clean up temporary directory: {cleanup_error}\")\n", 255 | " raise FileDownloadError(f\"Failed to download file: {e}\")\n", 256 | "\n", 257 | 
"\n", 258 | "def upload_file_to_stage(\n", 259 | " session: session.Session, file_path: str, output_path: str, config: Config\n", 260 | ") -> str:\n", 261 | " \"\"\"Upload file to the output stage\"\"\"\n", 262 | " try:\n", 263 | " # Get the directory and filename from the output path\n", 264 | " output_dir = os.path.dirname(output_path)\n", 265 | " base_name = os.path.basename(output_path)\n", 266 | "\n", 267 | " # Create the full stage path with subdirectory\n", 268 | " stage_path = f\"{config.output_stage.rstrip('/')}/{output_dir.lstrip('/')}\"\n", 269 | "\n", 270 | " # Read the content of the original file\n", 271 | " with open(file_path, \"rb\") as f:\n", 272 | " file_content = f.read()\n", 273 | "\n", 274 | " # Create a new file with the correct name\n", 275 | " temp_dir = tempfile.gettempdir()\n", 276 | " temp_file_path = os.path.join(temp_dir, base_name)\n", 277 | "\n", 278 | " # Write the content to the new file\n", 279 | " with open(temp_file_path, \"wb\") as f:\n", 280 | " f.write(file_content)\n", 281 | "\n", 282 | " # Upload the file using session.file.put with compression disabled\n", 283 | " put_result = session.file.put(\n", 284 | " temp_file_path, stage_path, auto_compress=False, overwrite=True\n", 285 | " )\n", 286 | "\n", 287 | " # Check upload status\n", 288 | " if not put_result or len(put_result) == 0:\n", 289 | " raise Exception(f\"Failed to upload file: {base_name}\")\n", 290 | "\n", 291 | " if put_result[0].status not in [\"UPLOADED\", \"SKIPPED\"]:\n", 292 | " raise Exception(f\"Upload failed with status: {put_result[0].status}\")\n", 293 | "\n", 294 | " # Clean up the temporary file\n", 295 | " if os.path.exists(temp_file_path):\n", 296 | " os.remove(temp_file_path)\n", 297 | "\n", 298 | " return f\"Successfully uploaded {base_name} to {stage_path}\"\n", 299 | " except Exception as e:\n", 300 | " print_error(f\"Error uploading file: {e}\")\n", 301 | " raise\n", 302 | "\n", 303 | "\n", 304 | "def process_pdf_files(config: Config) -> None:\n", 305 | " \"\"\"Main process to orchestrate the PDF splitting\"\"\"\n", 306 | " try:\n", 307 | " session = get_active_session()\n", 308 | " pdf_files = list_pdf_files(session, config)\n", 309 | "\n", 310 | " for file_info in pdf_files:\n", 311 | " file_path = file_info[\"RELATIVE_PATH\"]\n", 312 | " print_info(f\"Processing: {file_path}\")\n", 313 | "\n", 314 | " try:\n", 315 | " # Download the PDF file\n", 316 | " local_pdf_path = download_file_from_stage(session, file_path, config)\n", 317 | "\n", 318 | " # Get base filename without extension\n", 319 | " base_name = os.path.splitext(os.path.basename(file_path))[0]\n", 320 | "\n", 321 | " # Extract individual PDF pages\n", 322 | " with open(local_pdf_path, \"rb\") as file:\n", 323 | " pdf_reader = PyPDF2.PdfReader(file)\n", 324 | " num_pages = len(pdf_reader.pages)\n", 325 | " print_info(f\"Converting PDF to {num_pages} pages of PDFs\")\n", 326 | "\n", 327 | " for i in range(num_pages):\n", 328 | " page_num = i + 1\n", 329 | " s3_pdf_output_path = (\n", 330 | " f\"{config.output_pdf_path}/{base_name}_page_{page_num}.pdf\"\n", 331 | " )\n", 332 | " pdf_writer = PyPDF2.PdfWriter()\n", 333 | " pdf_writer.add_page(pdf_reader.pages[i])\n", 334 | " temp_file = tempfile.NamedTemporaryFile(\n", 335 | " delete=False, suffix=\".pdf\"\n", 336 | " )\n", 337 | " local_pdf_tmp_file_name = temp_file.name\n", 338 | " with open(local_pdf_tmp_file_name, \"wb\") as output_file:\n", 339 | " pdf_writer.write(output_file)\n", 340 | " \n", 341 | " upload_file_to_stage(\n", 342 | " session, 
local_pdf_tmp_file_name, s3_pdf_output_path, config\n", 343 | " )\n", 344 | " cleanup_temp_file(local_pdf_tmp_file_name)\n", 345 | " \n", 346 | " # Convert PDF to images \n", 347 | " with pdfplumber.open(local_pdf_path) as pdf:\n", 348 | " print_info(f\"Converting PDF to {len(pdf.pages)} images\")\n", 349 | " for i, page in enumerate(pdf.pages):\n", 350 | " page_num = i + 1\n", 351 | " # Get page dimensions\n", 352 | " width = page.width\n", 353 | " height = page.height\n", 354 | "\n", 355 | " # Determine if scaling is needed\n", 356 | " max_dim = max(width, height)\n", 357 | " if max_dim > config.max_dimension:\n", 358 | " # Calculate scale factor to fit within max_dimension\n", 359 | " scale_factor = config.max_dimension / max_dim\n", 360 | " width = int(width * scale_factor)\n", 361 | " height = int(height * scale_factor)\n", 362 | "\n", 363 | " img = page.to_image(resolution=config.dpi)\n", 364 | " temp_file = tempfile.NamedTemporaryFile(\n", 365 | " delete=False, suffix=\".png\"\n", 366 | " )\n", 367 | " local_image_tmp_file_name = temp_file.name\n", 368 | " img.save(local_image_tmp_file_name)\n", 369 | "\n", 370 | " s3_image_output_path = (\n", 371 | " f\"{config.output_image_path}/{base_name}_page_{page_num}.png\"\n", 372 | " )\n", 373 | " \n", 374 | " upload_file_to_stage(\n", 375 | " session, local_image_tmp_file_name, s3_image_output_path, config\n", 376 | " )\n", 377 | " cleanup_temp_file(local_image_tmp_file_name)\n", 378 | " \n", 379 | " # Clean up the original downloaded file\n", 380 | " cleanup_temp_file(local_pdf_path)\n", 381 | "\n", 382 | " except Exception as e:\n", 383 | " print_error(f\"Error processing {file_path}: {e}\")\n", 384 | " continue\n", 385 | "\n", 386 | " except Exception as e:\n", 387 | " print_error(f\"Fatal error in process_pdf_files: {e}\")\n", 388 | " raise\n", 389 | "\n", 390 | "\n", 391 | "config = Config(dpi=200)\n", 392 | "process_pdf_files(config)" 393 | ] 394 | }, 395 | { 396 | "cell_type": "markdown", 397 | "id": "df7440dc-dcff-4219-8093-8d389b29b599", 398 | "metadata": {}, 399 | "source": [ 400 | "Check out one image and see if it's clear. If you can't read clearly, neural models won't either!" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": null, 406 | "id": "9ba9c845-6dff-4f18-8df5-a2a01c482d4f", 407 | "metadata": {}, 408 | "outputs": [], 409 | "source": [ 410 | "session = get_active_session()\n", 411 | "\n", 412 | "image=session.file.get_stream(\n", 413 | " f\"@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/paged_image/abbott-laboratories-10-q-2024-10-31_page_1.png\", # change to one image on your stage\n", 414 | " decompress=False).read()\n", 415 | "st.image(image)" 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "id": "c8857644-dc84-4c19-b0a4-594d5e19586c", 421 | "metadata": { 422 | "collapsed": false, 423 | "jupyter": { 424 | "outputs_hidden": false 425 | }, 426 | "name": "cell2" 427 | }, 428 | "source": [ 429 | "Now let's start the multimodal embedding part! We first create an intermediate table that holds relative file names of images, and then call `SNOWFLAKE.CORTEX.embed_image_1024` to turn them into vectors!" 
430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "id": "bd6854a8-8f72-45e3-936a-8b6b6f792bf8", 436 | "metadata": { 437 | "language": "sql", 438 | "name": "cell9" 439 | }, 440 | "outputs": [], 441 | "source": [ 442 | "CREATE OR REPLACE TABLE CORTEX_SEARCH_DB.PYU.DEMO_SEC_IMAGE_CORPUS AS\n", 443 | "SELECT\n", 444 | " CONCAT('paged_image/', split_part(metadata$filename, '/', -1)) AS FILE_NAME,\n", 445 | " REGEXP_SUBSTR(metadata$filename, '_page_(\\\\d+)\\.', 1, 1, 'e')::INTEGER as PAGE_NUMBER,\n", 446 | " '@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL' AS STAGE_PREFIX\n", 447 | "FROM\n", 448 | " @CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/paged_image/\n", 449 | "GROUP BY 1, 2, 3\n", 450 | ";\n", 451 | "\n", 452 | "SELECT * FROM CORTEX_SEARCH_DB.PYU.DEMO_SEC_IMAGE_CORPUS LIMIT 5;" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": null, 458 | "id": "fcc56ebb-45a4-4ded-b9ed-2d11b6054af1", 459 | "metadata": { 460 | "language": "sql", 461 | "name": "cell8" 462 | }, 463 | "outputs": [], 464 | "source": [ 465 | "CREATE OR REPLACE TABLE CORTEX_SEARCH_DB.PYU.DEMO_SEC_VM3_VECTORS AS\n", 466 | "SELECT\n", 467 | " FILE_NAME,\n", 468 | " PAGE_NUMBER,\n", 469 | " STAGE_PREFIX,\n", 470 | " AI_EMBED(\n", 471 | " 'voyage-multimodal-3',\n", 472 | " TO_FILE('@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL', FILE_NAME)\n", 473 | " ) AS IMAGE_VECTOR\n", 474 | "FROM CORTEX_SEARCH_DB.PYU.DEMO_SEC_IMAGE_CORPUS;\n", 475 | "\n", 476 | "\n", 477 | "SELECT * FROM CORTEX_SEARCH_DB.PYU.DEMO_SEC_VM3_VECTORS LIMIT 5;" 478 | ] 479 | }, 480 | { 481 | "cell_type": "markdown", 482 | "id": "90ce0184-ea75-415c-9911-c70028a9a20b", 483 | "metadata": { 484 | "collapsed": false, 485 | "jupyter": { 486 | "outputs_hidden": false 487 | }, 488 | "name": "cell14" 489 | }, 490 | "source": [ 491 | "Similarly, we call `SNOWFLAKE.CORTEX.PARSE_DOCUMENT` to extract text from PDF pages. We discover that, although multimodal retrieval is powerful, augmenting it with text retrieval for keyword matching can bring quality improvement on certain types of search tasks/queries." 
492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": null, 497 | "id": "b002eaac-2163-4f96-8a77-9c052e4381db", 498 | "metadata": { 499 | "language": "sql", 500 | "name": "cell7" 501 | }, 502 | "outputs": [], 503 | "source": [ 504 | "CREATE OR REPLACE TABLE CORTEX_SEARCH_DB.PYU.DEMO_SEC_PDF_CORPUS AS\n", 505 | "SELECT\n", 506 | " CONCAT('paged_pdf/', split_part(metadata$filename, '/', -1)) AS FILE_NAME,\n", 507 | " REGEXP_SUBSTR(metadata$filename, '_page_(\\\\d+)\\.', 1, 1, 'e')::INTEGER as PAGE_NUMBER,\n", 508 | " '@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL' AS STAGE_PREFIX\n", 509 | "FROM\n", 510 | " @CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/paged_pdf/\n", 511 | "GROUP BY 1, 2, 3\n", 512 | ";\n", 513 | "\n", 514 | "CREATE OR REPLACE TABLE CORTEX_SEARCH_DB.PYU.DEMO_SEC_PARSE_DOC AS\n", 515 | " SELECT\n", 516 | " FILE_NAME,\n", 517 | " PAGE_NUMBER,\n", 518 | " STAGE_PREFIX,\n", 519 | " PARSE_JSON(TO_VARCHAR(SNOWFLAKE.CORTEX.PARSE_DOCUMENT(\n", 520 | " '@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL',\n", 521 | " FILE_NAME,\n", 522 | " {'mode': 'LAYOUT'}\n", 523 | " ))):content AS PARSE_DOC_OUTPUT\n", 524 | " FROM CORTEX_SEARCH_DB.PYU.DEMO_SEC_PDF_CORPUS\n", 525 | ";\n", 526 | "\n", 527 | "SELECT * FROM CORTEX_SEARCH_DB.PYU.DEMO_SEC_PARSE_DOC LIMIT 5;" 528 | ] 529 | }, 530 | { 531 | "cell_type": "markdown", 532 | "id": "41e715c5-ccef-40dc-958c-c655d780f457", 533 | "metadata": { 534 | "collapsed": false, 535 | "jupyter": { 536 | "outputs_hidden": false 537 | }, 538 | "name": "cell16" 539 | }, 540 | "source": [ 541 | "Now we join image vectors and texts into a single table, and create a Cortex Search service!" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": null, 547 | "id": "fd6901c9-6563-46ff-8369-9234169a799a", 548 | "metadata": { 549 | "language": "sql", 550 | "name": "cell10" 551 | }, 552 | "outputs": [], 553 | "source": [ 554 | "CREATE OR REPLACE TABLE CORTEX_SEARCH_DB.PYU.DEMO_SEC_JOINED_DATA AS\n", 555 | "SELECT\n", 556 | " v.FILE_NAME,\n", 557 | " v.PAGE_NUMBER,\n", 558 | " v.IMAGE_VECTOR AS VECTOR_MAIN,\n", 559 | " p.PARSE_DOC_OUTPUT AS TEXT,\n", 560 | " v.FILE_NAME AS IMAGE_FILEPATH\n", 561 | "FROM\n", 562 | " CORTEX_SEARCH_DB.PYU.DEMO_SEC_VM3_VECTORS v\n", 563 | "JOIN\n", 564 | " CORTEX_SEARCH_DB.PYU.DEMO_SEC_PARSE_DOC p\n", 565 | "ON\n", 566 | " REGEXP_SUBSTR(v.FILE_NAME, 'paged_image/(.*)\\\\.png$', 1, 1, 'e', 1) = REGEXP_SUBSTR(p.FILE_NAME, 'paged_pdf/(.*)\\\\.pdf$', 1, 1, 'e', 1);\n", 567 | "\n", 568 | "\n", 569 | "CREATE OR REPLACE CORTEX SEARCH SERVICE CORTEX_SEARCH_DB.PYU.DEMO_SEC_CORTEX_SEARCH_SERVICE\n", 570 | " TEXT INDEXES TEXT\n", 571 | " VECTOR INDEXES VECTOR_MAIN(query_model='voyage-multimodal-3')\n", 572 | " WAREHOUSE='SEARCH_L'\n", 573 | " TARGET_LAG='1 day'\n", 574 | "AS (\n", 575 | " SELECT \n", 576 | " TO_VARCHAR(TEXT) AS TEXT, \n", 577 | " PAGE_NUMBER, \n", 578 | " VECTOR_MAIN,\n", 579 | " IMAGE_FILEPATH\n", 580 | " FROM CORTEX_SEARCH_DB.PYU.DEMO_SEC_JOINED_DATA\n", 581 | ");" 582 | ] 583 | }, 584 | { 585 | "cell_type": "markdown", 586 | "id": "78771232-88b1-4096-83fb-b4e233e548d8", 587 | "metadata": { 588 | "collapsed": false, 589 | "jupyter": { 590 | "outputs_hidden": false 591 | }, 592 | "name": "cell17" 593 | }, 594 | "source": [ 595 | "We have created a multi-index Cortex Search Service with both text and vector indexes. This allows us to perform hybrid search combining keyword matching on text content and semantic similarity on vector embeddings. 
We'll use the new multi-index query syntax to search across both index types.\n", 596 | "\n", 597 | "**Note:** The multi-index query syntax requires Snowflake Python API version 1.6.0 or later." 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": null, 603 | "id": "f5611514-3ef0-45e9-a921-28eabe710893", 604 | "metadata": { 605 | "language": "python", 606 | "name": "cell6" 607 | }, 608 | "outputs": [], 609 | "source": [ 610 | "\n", 611 | "demo_query_text = \"What was the overall operational cost incurred by Abbott Laboratories in 2023, and how much of this amount was allocated to research and development?\"\n", 612 | "print(demo_query_text)" 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "execution_count": null, 618 | "id": "b7e908c4-aa91-40d2-943b-5f2934741373", 619 | "metadata": { 620 | "codeCollapsed": false, 621 | "language": "python", 622 | "name": "cell11" 623 | }, 624 | "outputs": [], 625 | "source": [ 626 | "from snowflake.core import Root\n", 627 | "\n", 628 | "root = Root(session)\n", 629 | "# fetch service\n", 630 | "my_service = (root\n", 631 | " .databases[\"CORTEX_SEARCH_DB\"]\n", 632 | " .schemas[\"PYU\"]\n", 633 | " .cortex_search_services[\"DEMO_SEC_CORTEX_SEARCH_SERVICE\"]\n", 634 | ")\n", 635 | "\n", 636 | "# query service using multi-index query syntax\n", 637 | "resp = my_service.search(\n", 638 | " multi_index_query={\n", 639 | " \"TEXT\": [{\"text\": demo_query_text}],\n", 640 | " \"VECTOR_MAIN\": [{\"text\": demo_query_text}]\n", 641 | " },\n", 642 | " columns=[\"TEXT\", \"PAGE_NUMBER\", \"IMAGE_FILEPATH\"],\n", 643 | " limit=5\n", 644 | ")\n", 645 | "\n", 646 | "for i in range(5):\n", 647 | " print(f\"rank {i + 1}: {resp.to_dict()['results'][i]['PAGE_NUMBER']}\")\n", 648 | "\n", 649 | "top_page_id = resp.to_dict()['results'][0]['PAGE_NUMBER']\n", 650 | "\n", 651 | "top_page_path = resp.to_dict()['results'][0]['IMAGE_FILEPATH']" 652 | ] 653 | }, 654 | { 655 | "cell_type": "markdown", 656 | "id": "f448324b-c09d-4ed2-b449-94b0615b124b", 657 | "metadata": { 658 | "collapsed": false, 659 | "jupyter": { 660 | "outputs_hidden": false 661 | }, 662 | "name": "cell18" 663 | }, 664 | "source": [ 665 | "Let's see the top ranked page we found!" 666 | ] 667 | }, 668 | { 669 | "cell_type": "code", 670 | "execution_count": null, 671 | "id": "93407e70-4657-472e-a6a8-1703f4282787", 672 | "metadata": { 673 | "language": "python", 674 | "name": "cell12" 675 | }, 676 | "outputs": [], 677 | "source": [ 678 | "session = get_active_session()\n", 679 | "image=session.file.get_stream(\n", 680 | " f\"@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/{top_page_path}\",\n", 681 | " decompress=False).read()\n", 682 | "st.image(image)" 683 | ] 684 | }, 685 | { 686 | "cell_type": "markdown", 687 | "id": "f1ae6bea-b20c-42d5-9c8b-de1a63c0f859", 688 | "metadata": { 689 | "collapsed": false, 690 | "jupyter": { 691 | "outputs_hidden": false 692 | }, 693 | "name": "cell19" 694 | }, 695 | "source": [ 696 | "Finally, we can also perform multimodal retrieval augmented generation (mRAG) by sending the query and the top page image to a multimodal LLM served on snowflake cortex and get a natural language answer to our question." 
697 | ] 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": null, 702 | "id": "c695373e-ac74-4b62-a1f1-08206cbd5c81", 703 | "metadata": { 704 | "codeCollapsed": false, 705 | "language": "python", 706 | "name": "cell3" 707 | }, 708 | "outputs": [], 709 | "source": [ 710 | "prompt = \"Answer the following question by referencing the document image {0}: What was the overall operational cost incurred by Abbott Laboratories in 2023, and how much of this amount was allocated to research and development?\"\n", 711 | "query = f\"\"\"\n", 712 | " SELECT SNOWFLAKE.CORTEX.COMPLETE('pixtral-large',\n", 713 | " PROMPT('{prompt}',\n", 714 | " TO_FILE('@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/{top_page_path}'))\n", 715 | " );\n", 716 | "\"\"\"\n", 717 | "sql_output = session.sql(query).collect()\n", 718 | "response = list(sql_output[0].asDict().values())[0]\n", 719 | "print(response)" 720 | ] 721 | } 722 | ], 723 | "metadata": { 724 | "kernelspec": { 725 | "display_name": "Python 3 (ipykernel)", 726 | "language": "python", 727 | "name": "python3" 728 | }, 729 | "language_info": { 730 | "codemirror_mode": { 731 | "name": "ipython", 732 | "version": 3 733 | }, 734 | "file_extension": ".py", 735 | "mimetype": "text/x-python", 736 | "name": "python", 737 | "nbconvert_exporter": "python", 738 | "pygments_lexer": "ipython3", 739 | "version": "3.11.9" 740 | }, 741 | "lastEditStatus": { 742 | "authorEmail": "puxuan.yu@snowflake.com", 743 | "authorId": "7316233358469", 744 | "authorName": "PYU", 745 | "lastEditTime": 1743811030949, 746 | "notebookId": "xw7ke2yiifgrxgxsyy2f", 747 | "sessionId": "e3312463-a075-4b08-a68c-4eb1627ec394" 748 | } 749 | }, 750 | "nbformat": 4, 751 | "nbformat_minor": 5 752 | } 753 | --------------------------------------------------------------------------------
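The notebook above runs inside Snowflake; the sketch below shows the same retrieve-then-generate flow as a standalone Python script, for readers who want to query the service from outside a Snowflake Notebook. It is a minimal sketch, not a file from the repository: it assumes the `DEMO_SEC_CORTEX_SEARCH_SERVICE` service and `MULTIMODAL_DEMO_INTERNAL` stage created in the notebook already exist, that the Snowflake Python API is 1.6.0 or later (required for `multi_index_query`, as noted in the notebook), and that a local connections.toml profile named `example_connection` (a hypothetical name) is configured.

```python
# Minimal sketch under the assumptions above: hybrid retrieval against the
# multi-index Cortex Search service, then multimodal RAG via
# SNOWFLAKE.CORTEX.COMPLETE. Object names come from the notebook; the
# "example_connection" profile is a hypothetical connections.toml entry.
from snowflake.core import Root
from snowflake.snowpark import Session

session = Session.builder.config("connection_name", "example_connection").create()
root = Root(session)

service = (
    root.databases["CORTEX_SEARCH_DB"]
    .schemas["PYU"]
    .cortex_search_services["DEMO_SEC_CORTEX_SEARCH_SERVICE"]
)

question = (
    "How much did Abbott Laboratories spend on research and development in 2023?"
)

# Query both indexes: keyword matching on TEXT, semantic similarity on VECTOR_MAIN.
resp = service.search(
    multi_index_query={
        "TEXT": [{"text": question}],
        "VECTOR_MAIN": [{"text": question}],
    },
    columns=["PAGE_NUMBER", "IMAGE_FILEPATH"],
    limit=1,
)
top_page_path = resp.to_dict()["results"][0]["IMAGE_FILEPATH"]

# Answer the question from the top-ranked page image with a multimodal model.
answer = session.sql(
    f"""
    SELECT SNOWFLAKE.CORTEX.COMPLETE(
        'pixtral-large',
        PROMPT(
            'Answer the following question by referencing the document image {{0}}: {question}',
            TO_FILE('@CORTEX_SEARCH_DB.PYU.MULTIMODAL_DEMO_INTERNAL/{top_page_path}')
        )
    )
    """
).collect()
print(answer[0][0])
```

Keeping `limit=1` mirrors the notebook's final step; raise the limit and loop over `results` if you want to pass several candidate pages to the model instead of just the top one.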