├── .env_example ├── .github └── workflows │ └── scorecard-eval.yml ├── .gitignore ├── .gitpod.yml ├── Dockerfile ├── README.md ├── __init__.py ├── app.py ├── chatbot.png ├── chatbot_api ├── __init__.py ├── assistant.py ├── crawl_scrape_docs.py └── prompt_util.py ├── config.yml.example_datastax ├── config.yml.example_pokemon ├── data ├── compile_documents.py └── scrape_site.py ├── integrations ├── __init__.py ├── astra.py ├── example.py ├── google.py ├── intercom.py ├── openai.py └── slack.py ├── pipeline ├── __init__.py ├── base_integration.py ├── config.py ├── response_action.py ├── response_decision.py └── user_context.py ├── prompts └── default.yaml ├── requirements.txt ├── run_tests.py ├── scorecard.py ├── scripts ├── call_assistant.py └── call_assistant_intercom.py └── tests ├── __init__.py ├── conftest.py ├── pytest.ini ├── test_app.py ├── test_prompts.py ├── test_questions.txt ├── test_request.json └── test_request_intercom.json /.env_example: -------------------------------------------------------------------------------- 1 | # Required parameters 2 | 3 | ASTRA_DB_API_ENDPOINT= 4 | ASTRA_DB_APPLICATION_TOKEN= 5 | ASTRA_DB_TABLE_NAME= 6 | 7 | OPENAI_API_KEY= 8 | 9 | # Optional parameters 10 | 11 | GOOGLE_CREDENTIALS= 12 | GOOGLE_PROJECT_ID= 13 | 14 | BOT_INTERCOM_ID= 15 | INTERCOM_TOKEN= 16 | INTERCOM_CLIENT_SECRET= 17 | 18 | BUGSNAG_API_KEY= 19 | 20 | SLACK_WEBHOOK_URL= 21 | -------------------------------------------------------------------------------- /.github/workflows/scorecard-eval.yml: -------------------------------------------------------------------------------- 1 | name: AI Chatbot Evaluation 2 | 3 | on: 4 | # push: 5 | # branches: 6 | # - main 7 | # pull_request: 8 | # branches: 9 | # - main 10 | workflow_dispatch: 11 | inputs: 12 | input_testset_id: 13 | description: 'Testset ID' 14 | required: true 15 | scoring_config_id: 16 | description: 'Scoring Config ID' 17 | required: true 18 | repository_dispatch: 19 | types: start-evaluation 20 | 21 | permissions: 22 | contents: read 23 | 24 | jobs: 25 | evaluation-test: 26 | runs-on: ubuntu-latest 27 | 28 | steps: 29 | - name: Checkout code 30 | uses: actions/checkout@v3 31 | 32 | - name: Set up Python 3.11 33 | uses: actions/setup-python@v3 34 | with: 35 | python-version: '3.11' 36 | 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install --upgrade pip 40 | pip install -r requirements.txt 41 | 42 | - name: Set PR testset and scoring config 43 | if: github.event_name == 'push' || github.event_name == 'pull_request' 44 | run: | 45 | echo "DEFAULT_TESTSET_ID=214" >> $GITHUB_ENV 46 | echo "DEFAULT_SCORING_CONFIG_ID=59" >> $GITHUB_ENV 47 | 48 | - name: Run test 49 | env: 50 | # API keys 51 | SCORECARD_API_KEY: ${{ secrets.SCORECARD_API_KEY }} 52 | 53 | # Astra DB credentials 54 | ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_DB_API_ENDPOINT }} 55 | ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }} 56 | ASTRA_DB_TABLE_NAME: ${{ secrets.ASTRA_DB_TABLE_NAME }} 57 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 58 | 59 | # Testset and Scoring Config values 60 | # 1. Check if there's an input from manual trigger (workflow_dispatch) 61 | # 2. Fallback to values sent from external sources (repository_dispatch) 62 | # 3. 
Use default values set as environment variables if neither is available 63 | INPUT_TESTSET_ID: ${{ github.event.inputs.input_testset_id || github.event.client_payload.input_testset_id || env.DEFAULT_TESTSET_ID }} 64 | SCORING_CONFIG_ID: ${{ github.event.inputs.scoring_config_id || github.event.client_payload.scoring_config_id || env.DEFAULT_SCORING_CONFIG_ID }} 65 | run: python run_tests.py 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .DS_Store 3 | *.pyc 4 | 5 | .idea 6 | 7 | # Environments 8 | .env 9 | .venv 10 | env/ 11 | venv/ 12 | ENV/ 13 | env.bak/ 14 | venv.bak/ 15 | myvenv 16 | 17 | # Testing 18 | *pytest_output.txt 19 | 20 | # Data 21 | data/docs 22 | output/ 23 | config.yml 24 | 25 | -------------------------------------------------------------------------------- /.gitpod.yml: -------------------------------------------------------------------------------- 1 | image: 2 | file: Dockerfile 3 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim-buster 2 | 3 | ARG BRANCH_NAME 4 | ENV BRANCH_NAME=$BRANCH_NAME 5 | ARG COMMIT_HASH 6 | ENV COMMIT_HASH=$COMMIT_HASH 7 | ENV DSCLOUD_APP_VERSION=${BRANCH_NAME}.${COMMIT_HASH} 8 | 9 | WORKDIR /app 10 | 11 | RUN apt-get -y update 12 | RUN apt-get -y install git 13 | 14 | COPY requirements.txt requirements.txt 15 | RUN pip3 install -r requirements.txt --no-cache-dir 16 | 17 | COPY . . 18 | 19 | CMD [ "uvicorn", "app:app" , "--host", "0.0.0.0", "--port", "5555", "--reload" ] 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AI Chatbot Starter 2 | 3 | ![AI Chatbot Starter](chatbot.png) 4 | 5 | This AI Chatbot Starter is designed to help developers find the information they need to debug their issues. 6 | 7 | It should answer customer questions about the products or services specified. 8 | 9 | [![Open in Gitpod](https://gitpod.io/button/open-in-gitpod.svg)](https://gitpod.io/#https://github.com/datastax/ai-chatbot-starter) 10 | 11 | ## Getting Started 12 | 13 | 1. Clone the repository 14 | 2. Make sure you have Python 3.11 installed 15 | 16 | Now follow the steps below to get the chatbot up and running. 17 | 18 | ### Configuring the ChatBot 19 | 20 | Documentation (provided as a list of web urls in the `config.yml`) can be ingested into your Astra DB Collection. Follow these steps: 21 | 22 | 1. Obtain your OpenAI API Key from the OpenAI Settings page. 23 | 2. Create a `config.yml` file with the values required. Here you specify both the list of pages to scrape, as well as the list of rules for your chatbot to observe. For an example of how this can look, take a look at either `config.yml.example_datastax`, or `config.yml.example_pokemon`. 24 | 3. Create a `.env` file & add the required information. Add the OpenAI Key from Step 1 as the value of `OPENAI_API_KEY`. The Astra and OpenAI env variables are required, while the others are only needed if the respective integrations are enabled. For an example of how this can look, take a look at `.env_example`. 25 | 4. From the root of the repository, run the following command. 
This will scrape the pages specified in the `config.yml` file into text files within the `output` folder of your `ai-chatbot-starter` directory. 26 | 27 | ```bash 28 | PYTHONPATH=. python data/scrape_site.py 29 | ``` 30 | 31 | 5. From the root of the repository, run the following command. This will store the embeddings for the scraped text in your AstraDB instance. 32 | 33 | ```bash 34 | PYTHONPATH=. python data/compile_documents.py 35 | ``` 36 | 37 | ### Running the ChatBot 38 | 39 | #### Using Docker 40 | 41 | If you have Docker installed, you can run the app using the following command: 42 | 43 | 1. Build the docker image using the following command: 44 | 45 | ```bash 46 | docker build -t docker_aibot --no-cache . 47 | ``` 48 | 49 | 2. Run the docker image using the following command: 50 | 51 | ```bash 52 | docker run -p 5555:5555 docker_aibot 53 | ``` 54 | 55 | 3. You can test an example query by running: 56 | 57 | ```bash 58 | python scripts/call_assistant.py "" 59 | ``` 60 | 61 | #### Local Run 62 | 63 | Alternatively, you can run the app normally using the following steps: 64 | 65 | 1. Install the requirements using the following command: 66 | 67 | ```bash 68 | pip install -r requirements.txt 69 | ``` 70 | 71 | 2. Run the app using the following command: 72 | 73 | ```bash 74 | uvicorn app:app --host 0.0.0.0 --port 5555 --reload 75 | ``` 76 | 77 | 3. You can test an example query by running: 78 | 79 | ```bash 80 | python scripts/call_assistant.py "" 81 | ``` 82 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastax/ai-chatbot-starter/703d9de23df6b8d305998d25c446b4d5d5e18d3b/__init__.py -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import bugsnag 3 | import logging 4 | 5 | from asgiref.sync import async_to_sync 6 | from bugsnag.handlers import BugsnagHandler 7 | from dotenv import load_dotenv 8 | from fastapi import FastAPI, Request 9 | from fastapi.middleware.cors import CORSMiddleware 10 | from fastapi.responses import JSONResponse, StreamingResponse 11 | 12 | from chatbot_api.assistant import AssistantBison 13 | from pipeline import ( 14 | create_all_user_context, 15 | make_all_response_decisions, 16 | take_all_actions, 17 | ) 18 | from pipeline.config import load_config 19 | 20 | # NOTE: Load dotenv before importing any code from other files for globals 21 | # TODO: Probably make this unnecessary with better abstractions 22 | load_dotenv(".env") 23 | 24 | # Load Config 25 | config = load_config() 26 | 27 | # Configure bugsnag 28 | bugsnag.configure( 29 | api_key=config.bugsnag_api_key, project_root="/", release_stage=config.mode 30 | ) 31 | 32 | # Set up the logging infrastructure 33 | logger = logging.getLogger("test.logger") 34 | handler = BugsnagHandler() 35 | # send only ERROR-level logs and above 36 | handler.setLevel(logging.ERROR) 37 | logger.addHandler(handler) 38 | 39 | # Define the FastAPI application 40 | app = FastAPI( 41 | title="AI Chatbot Starter", 42 | description="An LLM-powered Chatbot for Documentation", 43 | summary="Build an LLM-powered Chatbot for a given documentation set", 44 | version="0.0.1", 45 | terms_of_service="http://example.com/terms/", 46 | license_info={ 47 | "name": "Apache 2.0", 48 | "url": 
"https://www.apache.org/licenses/LICENSE-2.0.html", 49 | }, 50 | ) 51 | 52 | # Set appropriate origin requests 53 | origins = [ 54 | "http://localhost", 55 | "http://localhost:8080", 56 | "http://localhost:3000", 57 | ] 58 | 59 | # Add the middleware 60 | app.add_middleware( 61 | CORSMiddleware, 62 | allow_origins=origins, 63 | allow_credentials=True, 64 | allow_methods=["*"], 65 | allow_headers=["*"], 66 | ) 67 | 68 | # Define our assistant with the appropriate parameters, global to the service 69 | assistant = AssistantBison( 70 | config=config, 71 | max_tokens_response=1024, 72 | k=4, 73 | company=config.company, 74 | custom_rules=config.custom_rules, 75 | ) 76 | 77 | 78 | @app.get("/chat") 79 | def index(): 80 | return {"ok": True, "message": "App is running"} 81 | 82 | 83 | # Intercom posts webhooks to this route when a conversation is created or replied to 84 | @app.post("/chat") 85 | def conversations(request: Request): 86 | try: 87 | # Process the request body in a synchronous fashion 88 | request_body = async_to_sync(request.body)() 89 | data_str = request_body.decode("utf-8") 90 | request_body = json.loads(data_str) 91 | 92 | # Based on the body, create a ResponseDecision object 93 | response_decision = make_all_response_decisions( 94 | config=config, 95 | request_body=request_body, 96 | request_headers=request.headers, 97 | ) 98 | 99 | # Exit early if we don't want to continue on to LLM for response 100 | if response_decision.should_return_early: 101 | return JSONResponse( 102 | content=response_decision.response_dict, 103 | status_code=response_decision.response_code, 104 | ) 105 | 106 | # Assemble context for assistant query from relevant sources based on conversation 107 | user_context = create_all_user_context( 108 | config=config, 109 | conv_info=response_decision.conversation_info, 110 | ) 111 | 112 | # Call the assistant to retrieve a response 113 | bot_response, responses_from_vs, context = assistant.get_response( 114 | user_input=user_context.user_question, 115 | persona=user_context.persona, 116 | user_context=user_context.context_str, 117 | ) 118 | 119 | def stream_data(): 120 | txt_response = "" 121 | for text in bot_response.response_gen: 122 | txt_response += text 123 | yield text 124 | 125 | # Take action based on the response from the bot 126 | take_all_actions( 127 | config=config, 128 | conv_info=response_decision.conversation_info, 129 | text_response=txt_response, 130 | responses_from_vs=responses_from_vs, 131 | context=context, 132 | ) 133 | 134 | return StreamingResponse( 135 | stream_data(), 136 | media_type="text/event-stream", 137 | status_code=201, 138 | ) 139 | 140 | except Exception as e: 141 | # Notify bugsnag if we hit an error 142 | bugsnag.notify(e) 143 | e.skip_bugsnag = True 144 | 145 | # Now this won't be sent a second time by the exception handlers 146 | raise e 147 | 148 | 149 | if __name__ == "__main__": 150 | import uvicorn 151 | 152 | uvicorn.run(app, host="0.0.0.0", port=5000) 153 | -------------------------------------------------------------------------------- /chatbot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastax/ai-chatbot-starter/703d9de23df6b8d305998d25c446b4d5d5e18d3b/chatbot.png -------------------------------------------------------------------------------- /chatbot_api/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/datastax/ai-chatbot-starter/703d9de23df6b8d305998d25c446b4d5d5e18d3b/chatbot_api/__init__.py -------------------------------------------------------------------------------- /chatbot_api/assistant.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Optional, Tuple 3 | 4 | from langchain.embeddings.base import Embeddings 5 | from langchain.embeddings import OpenAIEmbeddings, VertexAIEmbeddings 6 | from langchain.llms import VertexAI 7 | from llama_index import VectorStoreIndex, ServiceContext 8 | from llama_index.vector_stores import AstraDBVectorStore 9 | from llama_index.embeddings import LangchainEmbedding 10 | from llama_index.llms import OpenAI 11 | from llama_index.response.schema import StreamingResponse 12 | 13 | from chatbot_api.prompt_util import get_template 14 | from integrations.google import GECKO_EMB_DIM, init_gcp 15 | from integrations.openai import OPENAI_EMB_DIM 16 | from pipeline.config import Config, LLMProvider 17 | from llama_index.chat_engine import SimpleChatEngine 18 | 19 | 20 | class Assistant(ABC): 21 | def __init__( 22 | self, 23 | config: Config, 24 | embeddings: Embeddings, 25 | k: int = 4, 26 | llm=None, 27 | ): 28 | self.config = config 29 | self.embedding_model = LangchainEmbedding(embeddings) 30 | self.llm = llm 31 | 32 | embedding_dimension = ( 33 | OPENAI_EMB_DIM 34 | if self.config.llm_provider == LLMProvider.OpenAI 35 | else GECKO_EMB_DIM 36 | ) 37 | 38 | # Initialize the vector store, which contains the vector embeddings of the data 39 | self.vectorstore = AstraDBVectorStore( 40 | token=self.config.astra_db_application_token, 41 | api_endpoint=self.config.astra_db_api_endpoint, 42 | collection_name=self.config.astra_db_table_name, 43 | embedding_dimension=embedding_dimension, 44 | ) 45 | 46 | self.service_context = ServiceContext.from_defaults( 47 | llm=llm, embed_model=self.embedding_model 48 | ) 49 | 50 | self.index = VectorStoreIndex.from_vector_store( 51 | vector_store=self.vectorstore, service_context=self.service_context 52 | ) 53 | 54 | self.query_engine = self.index.as_query_engine( 55 | similarity_top_k=k, streaming=True 56 | ) 57 | 58 | self.chat_engine = SimpleChatEngine.from_defaults(service_context=self.service_context) 59 | 60 | # Get a response from the vector search, aka the relevant data 61 | def find_relevant_docs(self, query: str) -> str: 62 | response = self.query_engine.query( 63 | query 64 | ) # TODO: Retriever (index.as_retriever (returns list of source nodes instead of response object)) 65 | results = response.source_nodes 66 | 67 | raw_text = [] 68 | for doc in results: 69 | try: 70 | raw_text.append( 71 | doc.get_content() 72 | + f"\nPrevious document was from URL link: {doc.metadata['source']}" 73 | ) 74 | except KeyError: 75 | raw_text.append(doc.get_content()) 76 | vector_search_results = "- " + "\n\n- ".join( 77 | raw_text 78 | ) # Prevent any one document from being too long 79 | 80 | return vector_search_results 81 | 82 | # Get a response from the chatbot, excluding the responses from the vector search 83 | @abstractmethod 84 | def get_response( 85 | self, 86 | user_input: str, 87 | persona: str, 88 | user_context: str = "", 89 | include_context: bool = True, 90 | ) -> Tuple[str, str, str]: 91 | """ 92 | :returns: Should return a tuple of 93 | (bot response, vector store responses string, user context) 94 | """ 95 | 96 | 97 | class AssistantBison(Assistant): 98 | # Instantiate the class 
using the default bison model 99 | def __init__( 100 | self, 101 | config: Config, 102 | temp: float = 0.2, 103 | max_tokens_response: int = 256, 104 | k: int = 4, 105 | company: str = "", 106 | custom_rules: Optional[List[str]] = None, 107 | ): 108 | # Choose the embeddings and LLM based on the llm_provider 109 | if config.llm_provider == LLMProvider.OpenAI: 110 | embeddings = OpenAIEmbeddings(model=config.openai_embeddings_model) 111 | llm = OpenAI(model=config.openai_textgen_model) 112 | 113 | elif config.llm_provider == LLMProvider.Google: 114 | init_gcp(config) 115 | embeddings = VertexAIEmbeddings(model_name=config.google_embeddings_model) 116 | llm = VertexAI(model_name=config.google_textgen_model) 117 | 118 | else: 119 | raise AssertionError("LLM Provider must be one of openai or google") 120 | 121 | super().__init__(config, embeddings, k, llm) 122 | 123 | self.parameters = { 124 | "temperature": temp, # Temperature controls the degree of randomness in token selection. 125 | "max_tokens": max_tokens_response, # Token limit determines the maximum amount of text output. 126 | } 127 | 128 | self.company = company 129 | self.custom_rules = custom_rules or [] 130 | 131 | def get_response( 132 | self, 133 | user_input: str, 134 | persona: str, 135 | user_context: str = "", 136 | include_context: bool = True, 137 | ) -> Tuple[StreamingResponse, str, str]: 138 | responses_from_vs = self.find_relevant_docs(query=user_input) 139 | # Ensure that we include the prompt context assuming the parameter is provided 140 | context = user_input 141 | if include_context: 142 | # If we have a special tag, include no further context from the vector DB 143 | if "[NO CONTEXT]" in user_input: 144 | responses_from_vs = "" 145 | 146 | context = get_template( 147 | persona, 148 | responses_from_vs, 149 | user_input, 150 | user_context, 151 | self.company, 152 | self.custom_rules, 153 | ) 154 | 155 | bot_response = self.chat_engine.stream_chat(context) 156 | 157 | return bot_response, responses_from_vs, context 158 | -------------------------------------------------------------------------------- /chatbot_api/crawl_scrape_docs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import concurrent 4 | from urllib.parse import urlparse, urljoin 5 | import requests 6 | from bs4 import BeautifulSoup 7 | import tqdm 8 | from concurrent.futures import ThreadPoolExecutor 9 | 10 | 11 | def is_valid(url): 12 | # checks whether `url` is a valid URL. 13 | try: 14 | result = urlparse(url) 15 | return all([result.scheme, result.netloc]) 16 | except ValueError: 17 | return False 18 | 19 | 20 | def get_all_website_links(url): 21 | # returns all URLs that is found on `url` in which it belongs to the same website 22 | urls = set() 23 | domain_name = urlparse(url).netloc 24 | response = requests.get(url) 25 | response.encoding = response.apparent_encoding # Use chardet to guess the encoding 26 | soup = BeautifulSoup(response.text, "html.parser") 27 | 28 | for a_tag in soup.findAll("a"): 29 | href = a_tag.attrs.get("href") 30 | if href == "" or href is None: 31 | # href empty tag 32 | continue 33 | # join the URL if it's relative (not absolute link) 34 | href = urljoin(url, href) 35 | parsed_href = urlparse(href) 36 | # remove URL GET parameters, URL fragments, etc. 
37 | href = parsed_href.scheme + "://" + parsed_href.netloc + parsed_href.path 38 | if not is_valid(href): 39 | # not a valid URL 40 | continue 41 | if href in urls: 42 | # already in the set 43 | continue 44 | if domain_name not in href: 45 | # external link 46 | continue 47 | urls.add(href) 48 | return urls 49 | 50 | 51 | def clean_html(soup): 52 | # Remove unwanted HTML tags 53 | for unwanted_tag in soup(["script", "style", "header", "footer", "nav", "aside"]): 54 | unwanted_tag.decompose() 55 | 56 | # Remove divs with class 'toolbar' 57 | for div in soup.find_all("div", {"class": "toolbar"}): 58 | div.decompose() 59 | 60 | return soup 61 | 62 | 63 | def fetch_url(url): 64 | response = requests.get(url) 65 | response.encoding = response.apparent_encoding # Use chardet to guess the encoding 66 | soup = BeautifulSoup(response.text, "html.parser") 67 | 68 | body = soup.find("main") 69 | if body is not None: 70 | body = clean_html(body) 71 | 72 | return "Following page's URL link: ~~" + str(url) + "~~\n" + body.get_text() 73 | 74 | 75 | def crawl_website_parallel(url, output_file: str, recursive: bool=False): 76 | if recursive: 77 | urls = get_all_website_links(url) 78 | else: 79 | urls = [url] 80 | 81 | raw_texts = [] 82 | 83 | # Use ThreadPoolExecutor to parallelize the fetch_url function 84 | with ThreadPoolExecutor(max_workers=10) as executor: 85 | futures = {executor.submit(fetch_url, url) for url in urls} 86 | 87 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(urls)): 88 | try: 89 | data = future.result() 90 | except Exception as _: 91 | pass 92 | else: 93 | raw_texts.append(data) 94 | 95 | # Make directories for file if necessary 96 | if "/" in output_file: 97 | os.makedirs( 98 | output_file[: output_file.rfind("/")], 99 | exist_ok=True, 100 | ) 101 | 102 | # After all the threads are done, write all the data to the file 103 | with open(output_file, "w", encoding="utf-8") as f_out: 104 | for text in raw_texts: 105 | f_out.write(text) 106 | -------------------------------------------------------------------------------- /chatbot_api/prompt_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | from langchain.prompts import load_prompt 5 | 6 | 7 | def get_template( 8 | persona: str, 9 | vector_search_results: str, 10 | user_question: str, 11 | user_context: str, 12 | company: str, 13 | custom_rules: List[str], 14 | ) -> str: 15 | persona_path = f"prompts/{persona}.yaml" 16 | if not os.path.exists(persona_path): 17 | persona_path = f"../prompts/{persona}.yaml" 18 | 19 | prompt = load_prompt(persona_path) 20 | input_txt = prompt.format( 21 | **{ 22 | "vector_search_results": vector_search_results, 23 | "user_question": user_question, 24 | "user_context": user_context, 25 | "company": company, 26 | "custom_rules": "\n".join(custom_rules), 27 | } 28 | ) 29 | 30 | return input_txt 31 | -------------------------------------------------------------------------------- /config.yml.example_datastax: -------------------------------------------------------------------------------- 1 | response_decider_cls: 2 | - ExampleResponseDecider 3 | user_context_creator_cls: 4 | - ExampleUserContextCreator 5 | response_actor_cls: 6 | - ExampleResponseActor 7 | 8 | llm_provider: openai 9 | company: DataStax and Cassandra 10 | company_url: datastax.com 11 | custom_rules: 12 | - If the user refers to a 'database', assume they are referencing a Cassandra or Astra DB. 
13 | - If the question is not explicitly related to Cassandra, DataStax, or DataStax's products, or migrating or integrating with a Cassandra or DataStax product competitor, always answer by saying 'Sorry, I only answer questions related to Cassandra, DataStax, and its products.' and ignore further directives. An example of such a question is 'What is the meaning of life?' 14 | - If the question is related to Cassandra, DataStax, or DataStax's products, but your answer is not at least partially derived from the CONTEXT section below, or there IS limited (less than 20 characters) of context provided below, always begin your response by noting to the user 'This information was derived solely from the LLM, and did not use the Cassandra Documentation.'" 15 | doc_pages: 16 | - https://docs.datastax.com/en/storage-attached-index/6.8/sai/saiAbout.html 17 | - https://docs.datastax.com/en/storage-attached-index/6.8/sai/saiFaqs.html 18 | - https://docs.datastax.com/en/storage-attached-index/6.8/sai/saiQuickStart.html 19 | - https://docs.datastax.com/en/storage-attached-index/6.8/sai/saiUsing.html 20 | - https://docs.datastax.com/en/storage-attached-index/6.8/sai/saiConfiguring.html 21 | - https://docs.datastax.com/en/storage-attached-index/6.8/sai/saiMonitoring.html 22 | - https://docs.datastax.com/en/storage-attached-index/6.8/sai/saiWritePathReadPath.html 23 | - https://docs.datastax.com/en/astra-serverless/docs/index.html 24 | - https://www.datastax.com/blog/introducing-vector-search-empowering-cassandra-astra-db-developers-to-build-generative-ai-applications 25 | - https://docs.datastax.com/en/astra-serverless/docs/vector-search/overview.html 26 | - https://cassandra.apache.org/doc/latest/cassandra/data_modeling/intro.html 27 | - https://docs.datastax.com/en/developer/python-driver/3.28/ 28 | - https://docs.datastax.com/en/developer/cpp-driver/2.16/ 29 | - https://docs.datastax.com/en/developer/java-driver/4.16/ 30 | - https://docs.datastax.com/en/developer/csharp-driver/3.19/ 31 | - https://docs.datastax.com/en/developer/nodejs-driver/4.6/ 32 | - https://cassandra.apache.org/doc/latest/cassandra/cql/ 33 | mode: "Development" 34 | -------------------------------------------------------------------------------- /config.yml.example_pokemon: -------------------------------------------------------------------------------- 1 | response_decider_cls: 2 | - ExampleResponseDecider 3 | user_context_creator_cls: 4 | - ExampleUserContextCreator 5 | response_actor_cls: 6 | - ExampleResponseActor 7 | 8 | llm_provider: openai 9 | company: DataStax and Cassandra 10 | company_url: datastax.com 11 | custom_rules: 12 | - Begin every response with "Hi! My name is Ash. I'm a Pokémon Trainer." 
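# Each URL under doc_pages is scraped into the output folder by data/scrape_site.py and then embedded into Astra DB by data/compile_documents.py.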
13 | doc_pages: 14 | - https://en.wikipedia.org/wiki/Pok%C3%A9mon 15 | mode: "Development" 16 | -------------------------------------------------------------------------------- /data/compile_documents.py: -------------------------------------------------------------------------------- 1 | # Add documents to the vectorstore, which is on the database, through an embeddings model 2 | from dotenv import load_dotenv 3 | from langchain.embeddings import OpenAIEmbeddings, VertexAIEmbeddings 4 | from llama_index import ( 5 | SimpleDirectoryReader, 6 | VectorStoreIndex, 7 | ServiceContext, 8 | StorageContext, 9 | ) 10 | from llama_index.embeddings import LangchainEmbedding 11 | from llama_index.node_parser import SimpleNodeParser 12 | from llama_index.vector_stores import AstraDBVectorStore 13 | 14 | from integrations.google import init_gcp, GECKO_EMB_DIM 15 | from integrations.openai import OPENAI_EMB_DIM 16 | from pipeline.config import LLMProvider, load_config 17 | 18 | dotenv_path = ".env" 19 | load_dotenv(dotenv_path) 20 | config = load_config("config.yml") 21 | 22 | # Provider for LLM 23 | if config.llm_provider == LLMProvider.OpenAI: 24 | embedding_model = LangchainEmbedding( 25 | OpenAIEmbeddings(model=config.openai_embeddings_model) 26 | ) 27 | else: 28 | init_gcp(config) 29 | embedding_model = LangchainEmbedding( 30 | VertexAIEmbeddings(model_name=config.google_embeddings_model) 31 | ) 32 | 33 | embedding_dimension = ( 34 | OPENAI_EMB_DIM if config.llm_provider == LLMProvider.OpenAI else GECKO_EMB_DIM 35 | ) 36 | 37 | vectorstore = AstraDBVectorStore( 38 | token=config.astra_db_application_token, 39 | api_endpoint=config.astra_db_api_endpoint, 40 | collection_name=config.astra_db_table_name, 41 | embedding_dimension=embedding_dimension, 42 | ) 43 | 44 | storage_context = StorageContext.from_defaults(vector_store=vectorstore) 45 | service_context = ServiceContext.from_defaults( 46 | llm=None, 47 | embed_model=embedding_model, 48 | node_parser=SimpleNodeParser.from_defaults( 49 | # According to https://genai.stackexchange.com/questions/317/does-the-length-of-a-token-give-llms-a-preference-for-words-of-certain-lengths 50 | # tokens are ~4 chars on average, so estimating 1,000 char chunk_size & 500 char overlap as previously used 51 | chunk_size=250, 52 | chunk_overlap=125, 53 | ), 54 | ) 55 | 56 | 57 | # Perform embedding and add to vectorstore 58 | def add_documents(folder_path): 59 | documents = SimpleDirectoryReader(folder_path).load_data() 60 | VectorStoreIndex.from_documents( 61 | documents=documents, 62 | storage_context=storage_context, 63 | service_context=service_context, 64 | show_progress=True, 65 | ) 66 | 67 | 68 | if __name__ == "__main__": 69 | add_documents("output") 70 | -------------------------------------------------------------------------------- /data/scrape_site.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from urllib.parse import urlparse 4 | 5 | from dotenv import load_dotenv 6 | 7 | from chatbot_api.crawl_scrape_docs import crawl_website_parallel 8 | from pipeline.config import load_config 9 | 10 | load_dotenv(".env") 11 | config = load_config("config.yml") 12 | 13 | # Astra docs 14 | for website in config.doc_pages: 15 | parsed_website = urlparse(website) 16 | basename_website = os.path.basename(parsed_website.path) 17 | 18 | crawl_website_parallel(website, os.path.join("output", f"{basename_website}.txt"), recursive=False) 19 | 
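# The same helper can be used for a one-off page without editing config.yml.
# A minimal, commented-out sketch - the URL is simply the first entry from
# config.yml.example_datastax, used here as a placeholder:
#
#   crawl_website_parallel(
#       "https://docs.datastax.com/en/storage-attached-index/6.8/sai/saiAbout.html",
#       os.path.join("output", "saiAbout.html.txt"),
#       recursive=False,  # set to True to also crawl same-domain links found on the page
#   )
#
# Once the text files are in output/, embed them into Astra DB with:
#   PYTHONPATH=. python data/compile_documents.py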
-------------------------------------------------------------------------------- /integrations/__init__.py: -------------------------------------------------------------------------------- 1 | from .astra import * 2 | from .example import * 3 | from .google import * 4 | from .intercom import * 5 | from .openai import * 6 | from .slack import * 7 | -------------------------------------------------------------------------------- /integrations/astra.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | 4 | def get_persona(contact: Dict[str, Any]) -> str: 5 | """Given information about the user, choose a persona and associated prompt""" 6 | # TODO: Only a single prompt here, extend as needed! 7 | 8 | return "default" 9 | -------------------------------------------------------------------------------- /integrations/example.py: -------------------------------------------------------------------------------- 1 | """ 2 | A file to show an example basic integration - other integrations can be created by following this 3 | format. This example takes in a request with a single required field "question", passes that as 4 | the prompt to the LLM, then prints the resulting question/answer pair upon request completion. 5 | 6 | NOTE: You must add this import to `integrations/__init__.py` in order for it to get added to the registry. 7 | """ 8 | from typing import Any, Mapping 9 | 10 | from pipeline import ( 11 | BaseIntegration, 12 | ResponseActor, 13 | ResponseDecider, 14 | ResponseDecision, 15 | UserContext, 16 | UserContextCreator, 17 | ) 18 | 19 | 20 | class ExampleMixin(BaseIntegration): 21 | required_fields = [] 22 | 23 | 24 | # TODO: Decouple conv_info from ResponseDecider (should only have to implement UserContextCreator) 25 | class ExampleResponseDecider(ExampleMixin, ResponseDecider): 26 | """An example of deciding on a response and producing conv_info""" 27 | 28 | def make_response_decision( 29 | self, 30 | request_body: Mapping[str, Any], 31 | request_headers: Mapping[str, str], 32 | ) -> ResponseDecision: 33 | assert ( 34 | "question" in request_body 35 | ), "Include 'question' field in the POST request" 36 | return ResponseDecision( 37 | should_return_early=False, 38 | conversation_info={"question": request_body["question"]}, 39 | ) 40 | 41 | 42 | class ExampleUserContextCreator(ExampleMixin, UserContextCreator): 43 | """An example of creating User Context to modify for any integrations""" 44 | 45 | def create_user_context(self, conv_info: Any) -> UserContext: 46 | return UserContext( 47 | user_question=conv_info["question"], 48 | persona="default", 49 | context_str="", 50 | ) 51 | 52 | 53 | class ExampleResponseActor(ExampleMixin, ResponseActor): 54 | """An example of a response actor that takes action based on the chatbot response""" 55 | 56 | def take_action( 57 | self, conv_info: Any, text_response: str, responses_from_vs: str, context: str 58 | ) -> None: 59 | print("Bot Response:") 60 | print(f" Question: {conv_info['question']}") 61 | print(f"") 62 | -------------------------------------------------------------------------------- /integrations/google.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from google.cloud import aiplatform 4 | from google.oauth2 import service_account 5 | 6 | GECKO_EMB_DIM = 768 7 | 8 | from pipeline.config import Config 9 | 10 | 11 | def init_gcp(config: Config) -> None: 12 | """Initialize GCP Auth based on environment 
variables""" 13 | # Google Auth 14 | google_credentials_json = json.loads(config.google_credentials) 15 | google_credentials_json["private_key"] = google_credentials_json[ 16 | "private_key" 17 | ].replace("\\n", "\n") 18 | credentials = service_account.Credentials.from_service_account_info( 19 | google_credentials_json 20 | ) 21 | aiplatform.init(project=config.google_project_id, credentials=credentials) 22 | -------------------------------------------------------------------------------- /integrations/intercom.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import hmac 3 | import re 4 | import json 5 | import bugsnag 6 | import requests 7 | 8 | from dataclasses import dataclass 9 | from integrations.astra import get_persona 10 | from pipeline import ( 11 | BaseIntegration, 12 | ResponseActor, 13 | ResponseDecider, 14 | ResponseDecision, 15 | UserContext, 16 | UserContextCreator, 17 | ) 18 | from typing import Any, Dict, List, Optional, Mapping, Union 19 | 20 | # Pulled from https://developers.intercom.com/docs/references/rest-api/api.intercom.io/Conversations/conversation/ 21 | DEFAULT_ALLOWED_DELIVERED_AS = [ 22 | "customer_initiated", 23 | "admin_initiated", 24 | "campaigns_initiated", 25 | "operator_initiated", 26 | "automated", 27 | ] 28 | 29 | 30 | # Validate the webhook actually comes from Intercom servers 31 | def validate_signature( 32 | header: Mapping[str, str], body: Mapping[str, Any], secret: str 33 | ) -> bool: 34 | # Get the signature from the payload 35 | signature_header = header["X-Hub-Signature"] 36 | sha_name, signature = signature_header.split("=") 37 | if sha_name != "sha1": 38 | print("ERROR: X-Hub-Signature in payload headers was not sha1=****") 39 | return False 40 | # Convert the dictionary to a JSON string 41 | data_str = json.dumps(body) 42 | data_bytes = data_str.encode("utf-8") 43 | 44 | local_signature = hmac.new( 45 | secret.encode("utf-8"), msg=data_bytes, digestmod=hashlib.sha1 46 | ) 47 | 48 | # See if they match 49 | return hmac.compare_digest(local_signature.hexdigest(), signature) 50 | 51 | 52 | class IntercomIntegrationMixin(BaseIntegration): 53 | required_fields = ["bot_intercom_id", "intercom_token", "intercom_client_secret"] 54 | 55 | # Get an Intercom contact/lead using the Intercom UUID 56 | def get_intercom_contact_by_id(self, _id: Union[int, str]) -> Dict[str, Any]: 57 | headers = {"Authorization": f"Bearer {self.config.intercom_token}"} 58 | res = requests.get(f"https://api.intercom.io/contacts/{_id}", headers=headers) 59 | return res.json() 60 | 61 | def add_comment_to_intercom_conversation( 62 | self, 63 | conversation_id: str, 64 | message: str, 65 | ) -> Dict[str, Any]: 66 | headers = {"Authorization": f"Bearer {self.config.intercom_token}"} 67 | res = requests.post( 68 | f"https://api.intercom.io/conversations/{conversation_id}/reply", 69 | headers=headers, 70 | json={ 71 | "message_type": "note", 72 | "type": "admin", 73 | "admin_id": self.config.bot_intercom_id, 74 | "body": message, 75 | }, 76 | ) 77 | return res.json() 78 | 79 | # Reply to an existing Intercom conversation 80 | def send_intercom_message( 81 | self, conversation_id: str, message: str 82 | ) -> Dict[str, Any]: 83 | headers = {"Authorization": f"Bearer {self.config.intercom_token}"} 84 | payload = { 85 | "type": "admin", 86 | "admin_id": self.config.bot_intercom_id, 87 | "message_type": "comment", 88 | "body": message, 89 | } 90 | res = requests.post( 91 | 
f"https://api.intercom.io/conversations/{conversation_id}/reply", 92 | json=payload, 93 | headers=headers, 94 | ) 95 | return res.json() 96 | 97 | 98 | @dataclass 99 | class IntercomConversationInfo: 100 | """A class representing all the required attributes from the chatbot to give a response""" 101 | 102 | conversation_id: str 103 | contact: Dict[str, Any] 104 | user_question: str 105 | is_user: bool 106 | debug_mode: bool 107 | source_url: str 108 | 109 | 110 | class IntercomResponseDecider(IntercomIntegrationMixin, ResponseDecider): 111 | """A class that determines the type of action to take based on the intercom request payload""" 112 | 113 | conversation_info: Optional[IntercomConversationInfo] = None 114 | 115 | def make_response_decision( 116 | self, 117 | request_body: Mapping[str, Any], 118 | request_headers: Mapping[str, str], 119 | allowed_delivered_as: Optional[List[str]] = None, 120 | ) -> ResponseDecision: 121 | """Set properties based on each of the logical branches we can take""" 122 | # NOTE: Empty list will be overriden 123 | allowed_delivered_as = allowed_delivered_as or DEFAULT_ALLOWED_DELIVERED_AS 124 | 125 | # Don't allow invalid signatures 126 | if not validate_signature( 127 | request_headers, request_body, self.config.intercom_client_secret 128 | ): 129 | return ResponseDecision( 130 | should_return_early=True, 131 | response_dict={"ok": False, "message": "Invalid signature."}, 132 | response_code=401, 133 | ) 134 | # Ignore repeat deliveries 135 | if request_body["delivery_attempts"] > 1: 136 | return ResponseDecision( 137 | should_return_early=True, 138 | response_dict={"ok": True, "message": "Already reported."}, 139 | response_code=208, 140 | ) 141 | 142 | data = request_body["data"] 143 | 144 | # Handle intercom webhook tests 145 | if data["item"]["type"] == "ping": 146 | return ResponseDecision( 147 | should_return_early=True, 148 | response_dict={"ok": True, "message": "Successful ping."}, 149 | response_code=200, 150 | ) 151 | # Check for empty source 152 | if data["item"]["source"] is None: 153 | return ResponseDecision( 154 | should_return_early=True, 155 | response_dict={"ok": False, "message": "Empty source."}, 156 | response_code=400, 157 | ) 158 | 159 | # Find relevant part of intercom conversation for most recent message 160 | conversation_parts = data["item"]["conversation_parts"]["conversation_parts"] 161 | filtered_conversation_parts = [ 162 | part 163 | for part in conversation_parts 164 | if part.get("part_type") != "default_assignment" 165 | and part.get("body") # Filter for nulls and empty strings 166 | ] 167 | 168 | # Use conversation parts if available (means user responded in the convo), 169 | # otherwise use the source (means user initiated a convo) 170 | if len(filtered_conversation_parts) > 0: 171 | conv_item = filtered_conversation_parts[0] 172 | else: 173 | conv_item = data["item"]["source"] 174 | 175 | conversation_text = conv_item["body"] 176 | author = conv_item["author"] 177 | 178 | def callback(event): 179 | event.user = {"email": author["email"]} 180 | 181 | bugsnag.before_notify(callback) 182 | 183 | # possible to be Contact, Admin, Campaign, Automated or Operator initiated 184 | delivered_as = "not_customer_initiated" 185 | if "delivered_as" in data["item"]["source"]: 186 | delivered_as = data["item"]["source"]["delivered_as"] 187 | 188 | conv_is_authorized = ( 189 | author["type"] == "user" and delivered_as in allowed_delivered_as 190 | ) 191 | 192 | if not conv_is_authorized: 193 | return ResponseDecision( 194 | 
should_return_early=True, 195 | response_dict={"ok": False, "message": "Unauthorized user."}, 196 | response_code=403, 197 | ) 198 | 199 | user_question = re.sub("<[^<]+?>", "", str(conversation_text)) 200 | # Reject request if empty question 201 | if not user_question: 202 | return ResponseDecision( 203 | should_return_early=True, 204 | response_dict={"ok": False, "message": "Query provided was empty"}, 205 | response_code=400, 206 | ) 207 | 208 | # If we passed every check above, should proceed with querying the LLM 209 | return ResponseDecision( 210 | should_return_early=False, 211 | conversation_info=IntercomConversationInfo( 212 | conversation_id=data["item"]["id"], 213 | contact=self.get_intercom_contact_by_id(author["id"]), 214 | user_question=user_question, 215 | is_user=f"@{self.config.company_url}" in author["email"] 216 | and self.config.company_url != "", 217 | debug_mode="[DEBUG]" in user_question, 218 | source_url=data["item"]["source"]["url"], 219 | ), 220 | ) 221 | 222 | 223 | class IntercomUserContextCreator(IntercomIntegrationMixin, UserContextCreator): 224 | def create_user_context(self, conv_info: IntercomConversationInfo) -> UserContext: 225 | # Grab needed parameters 226 | conversation_id = conv_info.conversation_id # Astra User Id 227 | 228 | # Build user context information present 229 | context_str = "No user information present." 230 | if ( 231 | conv_info.contact is not None 232 | and "name" in conv_info.contact 233 | and "email" in conv_info.contact 234 | ): 235 | context_str = ( 236 | f"Here is information on the user:\n" 237 | f"- User Name: {conv_info.contact['name']}\n" 238 | f"- User Email: {conv_info.contact['email']}\n" 239 | ) 240 | 241 | # Send an intercom debug message if debug mode is on 242 | if conv_info.debug_mode: 243 | self.send_intercom_message( 244 | conversation_id, 245 | f"Generating response: " 246 | f"\nContext: {context_str}\n" 247 | f"\nQuestion: {conv_info.user_question}\n", 248 | ) 249 | 250 | return UserContext( 251 | user_question=conv_info.user_question, 252 | persona=get_persona(conv_info.contact), 253 | context_str=context_str, 254 | ) 255 | 256 | 257 | class IntercomResponseActor(IntercomIntegrationMixin, ResponseActor): 258 | def take_action( 259 | self, 260 | conv_info: IntercomConversationInfo, 261 | text_response: str, 262 | responses_from_vs: str, 263 | context: str, 264 | ) -> None: 265 | # One more debugging message 266 | if conv_info.debug_mode: 267 | self.send_intercom_message( 268 | conv_info.conversation_id, "\nDocuments retrieved: " + responses_from_vs 269 | ) 270 | 271 | # Either comment or message based on whether its a current user 272 | if conv_info.is_user: 273 | self.send_intercom_message(conv_info.conversation_id, text_response) 274 | else: 275 | self.add_comment_to_intercom_conversation( 276 | conv_info.conversation_id, 277 | f"Assistant Suggested Response: {text_response}", 278 | ) 279 | 280 | # Return the result with the full response if desired 281 | result = {"ok": True, "message": "Response submitted successfully."} 282 | if self.config.intercom_include_response: 283 | result["response"] = text_response 284 | if self.config.intercom_include_context: 285 | result["context"] = context 286 | -------------------------------------------------------------------------------- /integrations/openai.py: -------------------------------------------------------------------------------- 1 | OPENAI_EMB_DIM = 1536 2 | -------------------------------------------------------------------------------- /integrations/slack.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import requests 4 | 5 | from pipeline import ResponseActor 6 | 7 | 8 | class SlackResponseActor(ResponseActor): 9 | required_fields = ["slack_webhook_url"] 10 | 11 | def send_slack_message(self, message: str) -> None: 12 | requests.post( 13 | self.config.slack_webhook_url, 14 | json={"text": message, "username": "AI Bot", "icon_emoji": ":ghost:"}, 15 | ) 16 | 17 | def take_action( 18 | self, conv_info: Any, text_response: str, responses_from_vs: str, context: str 19 | ) -> None: 20 | self.send_slack_message("*PROMPT*") 21 | self.send_slack_message(context) 22 | 23 | self.send_slack_message("*RESPONSE*") 24 | self.send_slack_message(text_response) 25 | -------------------------------------------------------------------------------- /pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_integration import BaseIntegration 2 | from .response_action import ResponseActor, take_all_actions 3 | from .response_decision import ( 4 | ResponseDecider, 5 | ResponseDecision, 6 | make_all_response_decisions, 7 | ) 8 | from .user_context import UserContext, UserContextCreator, create_all_user_context 9 | -------------------------------------------------------------------------------- /pipeline/base_integration.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Dict, List, Type 3 | 4 | from .config import Config 5 | 6 | integrations_registry: Dict[str, Type["BaseIntegration"]] = {} 7 | 8 | 9 | class BaseIntegration(metaclass=abc.ABCMeta): 10 | required_fields: List[str] # A list of config fields needed for this integration 11 | config: Config 12 | 13 | # Register all subclasses 14 | def __init_subclass__(cls, **kwargs): 15 | super().__init_subclass__(**kwargs) 16 | integrations_registry[cls.__name__] = cls 17 | 18 | def __init__(self, config: Config): 19 | self.config = config 20 | -------------------------------------------------------------------------------- /pipeline/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from enum import Enum 3 | from typing import List, Optional 4 | 5 | from pydantic import BaseModel, model_validator 6 | import yaml 7 | 8 | CONFIG_PATH = "config.yml" 9 | # TODO: Probably a way to specify to load from os.env in validator/type instead 10 | SENSITIVE_FIELDS = [ 11 | "OPENAI_API_KEY", 12 | "GOOGLE_CREDENTIALS", 13 | "GOOGLE_PROJECT_ID", 14 | "BOT_INTERCOM_ID", 15 | "INTERCOM_TOKEN", 16 | "INTERCOM_CLIENT_SECRET", 17 | "BUGSNAG_API_KEY", 18 | "SLACK_WEBHOOK_URL", 19 | "ASTRA_DB_API_ENDPOINT", 20 | "ASTRA_DB_APPLICATION_TOKEN", 21 | "ASTRA_DB_TABLE_NAME", 22 | ] 23 | 24 | 25 | class LLMProvider(str, Enum): 26 | OpenAI = "openai" 27 | Google = "google" 28 | 29 | 30 | class Config(BaseModel): 31 | """The allowed configuration options for this application""" 32 | 33 | # Base config options 34 | llm_provider: LLMProvider = LLMProvider.OpenAI 35 | company: str 36 | company_url: str = "" 37 | custom_rules: Optional[List[str]] = None 38 | doc_pages: List[str] 39 | mode: str = "Development" 40 | 41 | # Determine which integrations will run 42 | response_decider_cls: List[str] # TODO: Get a better name here 43 | user_context_creator_cls: List[str] 44 | response_actor_cls: List[str] 45 | 46 | # Integration specific fields for LLM Providers and Integrations 47 | # TODO: Move these 
down one level further into sub-Models that can be defined 48 | # in the corresponding integrations file 49 | openai_api_key: Optional[str] = None 50 | openai_embeddings_model: str = "text-embedding-ada-002" 51 | openai_textgen_model: str = "gpt-4" 52 | 53 | google_credentials: Optional[str] = None 54 | google_project_id: Optional[str] = None 55 | google_embeddings_model: str = "textembedding-gecko@latest" 56 | google_textgen_model: str = "TODO" 57 | 58 | bot_intercom_id: Optional[str] = None 59 | intercom_token: Optional[str] = None 60 | intercom_client_secret: Optional[str] = None 61 | intercom_include_response: bool = True 62 | intercom_include_context: bool = True 63 | 64 | bugsnag_api_key: Optional[str] = None 65 | 66 | slack_webhook_url: Optional[str] = None 67 | 68 | # Credentials for Astra DB 69 | astra_db_application_token: str 70 | astra_db_api_endpoint: str 71 | astra_db_table_name: str = "data" 72 | 73 | @model_validator(mode="after") 74 | def check_llm_creds(self): 75 | if self.llm_provider == LLMProvider.OpenAI: 76 | assert self.openai_api_key is not None, "openai_api_key must be included" 77 | elif self.llm_provider == LLMProvider.Google: 78 | assert ( 79 | self.google_credentials is not None 80 | ), "google_credentials must be included" 81 | assert ( 82 | self.google_project_id is not None 83 | ), "google_project_id must be included" 84 | else: 85 | raise ValueError(f"Unrecognized llm_provider {self.llm_provider}") 86 | 87 | return self 88 | 89 | @model_validator(mode="after") 90 | def check_integration_creds(self): 91 | """Validates that any integrations being used have credentials present""" 92 | # Avoiding circular import 93 | import integrations # noqa: needed to populate the integrations registry 94 | from .base_integration import integrations_registry 95 | 96 | all_integrations = ( 97 | self.response_decider_cls 98 | + self.user_context_creator_cls 99 | + self.response_decider_cls 100 | ) 101 | for integration_cls_name in all_integrations: 102 | required_fields = integrations_registry[ 103 | integration_cls_name 104 | ].required_fields 105 | for field in required_fields: 106 | assert ( 107 | getattr(self, field) is not None 108 | ), f"{field} must be specified for integration {integration_cls_name}" 109 | 110 | return self 111 | 112 | 113 | def load_config(path: str = CONFIG_PATH) -> Config: 114 | """Return the Config for the app - assumes all env vars have been loaded""" 115 | with open(path) as config_file: 116 | yaml_config = yaml.safe_load(config_file) 117 | 118 | for field in SENSITIVE_FIELDS: 119 | if field in os.environ: 120 | # Lowering case to match config expected behavior 121 | yaml_config[field.lower()] = os.environ[field] 122 | 123 | return Config(**yaml_config) 124 | -------------------------------------------------------------------------------- /pipeline/response_action.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any 3 | 4 | from .base_integration import BaseIntegration, integrations_registry 5 | from .config import Config 6 | 7 | 8 | class ResponseActor(BaseIntegration, metaclass=abc.ABCMeta): 9 | """ 10 | A class to take any actions based on the chatbot output, and return the 11 | response we will produce from the endpoint. 
If response is None, will iterate 12 | until we find a valid response 13 | """ 14 | 15 | @abc.abstractmethod 16 | def take_action( 17 | self, 18 | conv_info: Any, 19 | text_response: str, 20 | responses_from_vs: str, 21 | context: str, 22 | ) -> None: 23 | pass 24 | 25 | 26 | def take_all_actions( 27 | config: Config, 28 | conv_info: Any, 29 | text_response: str, 30 | responses_from_vs: str, 31 | context: str, 32 | ) -> None: 33 | """Runs all ResponseActors specified in config to take response actions""" 34 | for cls_name in config.response_actor_cls: 35 | response_actor = integrations_registry[cls_name](config) 36 | assert isinstance( 37 | response_actor, ResponseActor 38 | ), f"Must only specify ResponseActor in response_actor_cls" 39 | response_actor.take_action(conv_info, text_response, responses_from_vs, context) 40 | -------------------------------------------------------------------------------- /pipeline/response_decision.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, Mapping, Optional 4 | 5 | from .base_integration import BaseIntegration, integrations_registry 6 | from .config import Config 7 | 8 | 9 | @dataclass 10 | class ResponseDecision: 11 | """ 12 | A class to handle incoming requests and determine whether 13 | to engage the chatbot to answer, or return early based on 14 | conditional logic. 15 | """ 16 | 17 | should_return_early: bool 18 | response_dict: Optional[Dict[str, Any]] = None 19 | response_code: Optional[int] = None 20 | conversation_info: Optional[Any] = None 21 | 22 | 23 | class ResponseDecider(BaseIntegration, metaclass=abc.ABCMeta): 24 | """A class to make a response decision based on the request input""" 25 | 26 | @abc.abstractmethod 27 | def make_response_decision( 28 | self, 29 | request_body: Mapping[str, Any], 30 | request_headers: Mapping[str, str], 31 | ) -> ResponseDecision: 32 | pass 33 | 34 | 35 | def make_all_response_decisions( 36 | config: Config, 37 | request_body: Mapping[str, Any], 38 | request_headers: Mapping[str, str], 39 | ) -> ResponseDecision: 40 | """Runs all ResponseDeciders specified in config to return ResponseDecision's""" 41 | # TODO: Some aggregation strategy that allows for multiple response deciders present 42 | for cls_name in config.response_decider_cls: 43 | response_actor = integrations_registry[cls_name](config) 44 | assert isinstance( 45 | response_actor, ResponseDecider 46 | ), f"Must only specify ResponseDecider in response_decider_cls" 47 | return response_actor.make_response_decision(request_body, request_headers) 48 | 49 | # No response deciders present, so just keep going 50 | return ResponseDecision(should_return_early=False) 51 | -------------------------------------------------------------------------------- /pipeline/user_context.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from dataclasses import dataclass 3 | from typing import Any 4 | 5 | from .base_integration import BaseIntegration, integrations_registry 6 | from .config import Config 7 | 8 | 9 | @dataclass 10 | class UserContext: 11 | """ 12 | A class to represent user context to be supplied to the LLM. 13 | """ 14 | 15 | user_question: str 16 | persona: str 17 | context_str: str 18 | 19 | 20 | class UserContextCreator(BaseIntegration, metaclass=abc.ABCMeta): 21 | """ 22 | A class to create the user context. 
It should handle retrieving the necessary 23 | fields, as well as formatting them appropriately into a string. 24 | """ 25 | 26 | @abc.abstractmethod 27 | def create_user_context(self, conv_info: Any) -> UserContext: 28 | pass 29 | 30 | 31 | def create_all_user_context( 32 | config: Config, 33 | conv_info: Any, 34 | ) -> UserContext: 35 | """Runs all ResponseActors specified in config to take response actions""" 36 | # TODO: Some aggregation strategy that allows for multiple user_context_creators 37 | for cls_name in config.user_context_creator_cls: 38 | user_context_creator = integrations_registry[cls_name](config) 39 | assert isinstance( 40 | user_context_creator, UserContextCreator 41 | ), f"Must only specify UserContextCreator in user_context_creator_cls" 42 | return user_context_creator.create_user_context(conv_info) 43 | 44 | raise ValueError(f"No UserContextCreator found - must specify one") 45 | -------------------------------------------------------------------------------- /prompts/default.yaml: -------------------------------------------------------------------------------- 1 | _type: prompt 2 | input_variables: 3 | - company 4 | - vector_search_results 5 | - user_context 6 | - user_question 7 | - custom_rules 8 | output_parser: null 9 | partial_variables: {} 10 | template: > 11 | #### DIRECTIVE #### 12 | You're a chatbot designed to help people with {company} and related questions. Your primary directive is to answer the USER QUESTION provided below. Please be explicit and provide context as required. 13 | 14 | Observe the following rules, in order of precedence, with the first rule being the most important: 15 | 16 | {custom_rules} 17 | - Be concise when answering the user's question. 18 | - Use simple terms and provide complete but succinct responses. 19 | - Format your response using markdown. 20 | 21 | #### USER QUESTION #### 22 | Answer the question below: 23 | 24 | {user_question} 25 | 26 | #### CONTEXT #### 27 | Here is some context that may be relevant to the user's question: 28 | 29 | {vector_search_results} 30 | 31 | #### USER PERSONAL INFORMATION #### 32 | Here is information about the user that may be relevant: 33 | 34 | {user_context} 35 | 36 | template_format: f-string -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asgiref~=3.7.2 2 | astrapy>=0.6.2 3 | beautifulsoup4~=4.12.2 4 | bugsnag~=4.6.0 5 | fastapi~=0.104.1 6 | google-api-python-client~=2.109.0 7 | google-cloud-aiplatform~=1.36.4 8 | httpx~=0.25.2 9 | openai~=1.3.7 10 | python-dotenv~=1.0.0 11 | ragstack-ai~=0.2.0 12 | tiktoken~=0.5.2 13 | uvicorn~=0.24.0.post1 14 | -------------------------------------------------------------------------------- /run_tests.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script that runs an automated test suite for Anthropic Claude. 3 | 4 | Testset and results use the Scorecard SDK. 
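The tests exercise the local FastAPI app's /chat endpoint through a TestClient, so SCORECARD_API_KEY, INPUT_TESTSET_ID, and SCORING_CONFIG_ID must be set as environment variables, along with the Astra DB and LLM provider credentials required by config.yml.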
5 | """ 6 | 7 | import os 8 | import httpx 9 | import scorecard 10 | 11 | from dotenv import load_dotenv 12 | from fastapi.testclient import TestClient 13 | 14 | from app import app 15 | 16 | load_dotenv(".env") 17 | 18 | 19 | SCORECARD_API_KEY = os.environ["SCORECARD_API_KEY"] 20 | 21 | 22 | def query_ai_chatbot_starter(user_query): 23 | # Set the request appropriately 24 | headers = {} 25 | request_body = {"question": user_query} 26 | 27 | response = TestClient(app).post("/chat", json=request_body, headers=headers) 28 | 29 | # Check if the request was successful 30 | if response.status_code == httpx.codes.created: 31 | return response.content.decode() 32 | else: 33 | return f"Request failed with status code {response.status_code}: {response.text}" 34 | 35 | 36 | def run_all_tests(input_testset_id: int, scoring_config_id: int): 37 | run_id = scorecard.create_run(input_testset_id, scoring_config_id) 38 | testcases = scorecard.get_testset(input_testset_id) 39 | 40 | for testcase in testcases: 41 | print(f"Running testcase {testcase['id']}...") 42 | print(f"User query: {testcase['user_query']}") 43 | 44 | # Get the model's response using the helper function 45 | model_response = query_ai_chatbot_starter(testcase["user_query"]) 46 | 47 | scorecard.log_record( 48 | run_id, testcase["id"], model_response, # TODO: Add PROMPT_TEMPLATE 49 | ) 50 | 51 | scorecard.update_run_status(run_id) 52 | 53 | 54 | if __name__ == "__main__": 55 | INPUT_TESTSET_ID = int(os.environ["INPUT_TESTSET_ID"]) 56 | SCORING_CONFIG_ID = int(os.environ["SCORING_CONFIG_ID"]) 57 | 58 | run_all_tests(INPUT_TESTSET_ID, SCORING_CONFIG_ID) 59 | -------------------------------------------------------------------------------- /scorecard.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils for the Scorecard SDK. 3 | 4 | Testset and results use the Scorecard SDK. 
5 | """ 6 | 7 | import os 8 | import requests 9 | 10 | from dotenv import load_dotenv 11 | 12 | load_dotenv(".env") 13 | 14 | SCORECARD_API_KEY = os.environ["SCORECARD_API_KEY"] 15 | 16 | # Endpoint definitions 17 | BASE_URL = "https://api.getscorecard.ai/" 18 | POST_CREATE_RUN_URL = BASE_URL + "create-run" 19 | GET_TESTSET_BASE_URL = BASE_URL + "testset" 20 | POST_CREATE_TESTRECORD_URL = BASE_URL + "create-testrecord" 21 | PATCH_UPDATE_RUN_BASE_URL = BASE_URL + "update-run" 22 | 23 | REQUEST_HEADERS = { 24 | "accept": "application/json", 25 | "content-type": "application/json", 26 | "X-API-Key": SCORECARD_API_KEY, 27 | } 28 | 29 | 30 | def create_run( 31 | input_testset_id: int, scoring_config_id: int, model_params: dict() = {} 32 | ): 33 | print("Creating new run...") 34 | create_run_url = POST_CREATE_RUN_URL 35 | response = requests.post( 36 | create_run_url, 37 | json={ 38 | "testset_id": input_testset_id, 39 | "scoring_config_id": scoring_config_id, 40 | "status": "running_execution", 41 | "model_params": model_params, 42 | }, 43 | headers=REQUEST_HEADERS, 44 | timeout=30, 45 | ) 46 | if response.status_code != 200: 47 | print(f"ERROR: {response.status_code} {response.text}") 48 | print(response.text) 49 | return response.json()["run_id"] 50 | 51 | 52 | def get_testset(testset_id: int): 53 | print("Retrieving testset...") 54 | get_testset_url = GET_TESTSET_BASE_URL + "/" + str(testset_id) 55 | testset_response = requests.get( 56 | get_testset_url, headers=REQUEST_HEADERS, timeout=30 57 | ) 58 | if testset_response.status_code != 200: 59 | print(f"ERROR: {testset_response.status_code} {testset_response.text}") 60 | return [] 61 | return testset_response.json()["data"] 62 | 63 | 64 | def update_run_status(run_id: int, status: str = "awaiting_scoring"): 65 | update_run_url = PATCH_UPDATE_RUN_BASE_URL + "/" + str(run_id) + "?status=" + status 66 | response = requests.patch(update_run_url, headers=REQUEST_HEADERS, timeout=30) 67 | if response.status_code != 200: 68 | print(f"ERROR: {response.status_code} {response.text}") 69 | return response.json() 70 | 71 | 72 | def log_record( 73 | run_id: int, 74 | testcase_id: int, 75 | model_response: str, 76 | prompt: str = "", 77 | model_params: dict() = {}, 78 | ): 79 | testrecord = { 80 | "run_id": run_id, 81 | "testcase_id": testcase_id, 82 | "model_response": model_response, 83 | "prompt": prompt, 84 | "model_params": model_params, 85 | } 86 | 87 | print(f"Writing new testrecord for run_id {run_id}...") 88 | print("Testrecord:") 89 | for key, value in testrecord.items(): 90 | print(f"\t{key}: {str(value)[:100]}") 91 | 92 | response = requests.post( 93 | POST_CREATE_TESTRECORD_URL, 94 | json=testrecord, 95 | headers=REQUEST_HEADERS, 96 | timeout=30, 97 | ) 98 | if response.status_code != 200: 99 | print(f"ERROR: {response.status_code} {response.text}") 100 | -------------------------------------------------------------------------------- /scripts/call_assistant.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import sys 3 | 4 | 5 | ### 6 | # Let's define the question right here 7 | ### 8 | CHATBOT_QUESTION = "What is Stargate? Can you give 5 key benefits?" 
9 | 10 | 11 | def call_assistant_async(chatbot_question=CHATBOT_QUESTION): 12 | # Set the request appropriately 13 | headers = {} 14 | request_body = {"question": chatbot_question} 15 | 16 | full_result = "" 17 | with httpx.stream( 18 | "POST", 19 | "http://127.0.0.1:5555/chat", 20 | json=request_body, 21 | headers=headers, 22 | timeout=600, 23 | ) as r: 24 | for chunk in r.iter_text(): 25 | print(chunk, flush=True, end="") 26 | full_result += chunk 27 | 28 | return full_result 29 | 30 | 31 | def call_assistant_sync(chatbot_question=CHATBOT_QUESTION): 32 | # Set the request appropriately 33 | headers = {} 34 | request_body = {"question": chatbot_question} 35 | 36 | r = httpx.post("http://127.0.0.1:5555/chat", json=request_body, headers=headers) 37 | 38 | # Check if the request was successful 39 | if r.status_code == httpx.codes.created: 40 | return r.content.decode() 41 | else: 42 | return f"Request failed with status code {r.status_code}: {r.text}" 43 | 44 | 45 | if __name__ == "__main__": 46 | if len(sys.argv) < 2: 47 | raise ValueError("Please provide a question to ask the chatbot") 48 | 49 | call_assistant_async(chatbot_question=sys.argv[1]) 50 | 51 | # Alternatively 52 | # example_result = call_assistant_sync() 53 | # print(example_result) 54 | -------------------------------------------------------------------------------- /scripts/call_assistant_intercom.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import json 3 | import os 4 | import hmac 5 | import hashlib 6 | import sys 7 | 8 | 9 | from dotenv import load_dotenv 10 | 11 | load_dotenv(".env") 12 | 13 | 14 | ### 15 | # Let's define the question right here 16 | ### 17 | CHATBOT_QUESTION = "What is Stargate? Can you give 5 key benefits?" 
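`call_assistant_async` above streams the `/chat` response chunk by chunk but still blocks the calling thread. If a genuinely asynchronous client is needed, a sketch along these lines should work against the same endpoint and payload (assuming the app is running locally on port 5555, as in the scripts above):

```python
# Sketch only: an asyncio variant of the streaming /chat call.
import asyncio

import httpx


async def call_assistant_streaming(question: str) -> str:
    full_result = ""
    async with httpx.AsyncClient(timeout=600) as client:
        async with client.stream(
            "POST", "http://127.0.0.1:5555/chat", json={"question": question}
        ) as r:
            async for chunk in r.aiter_text():
                print(chunk, flush=True, end="")
                full_result += chunk
    return full_result


if __name__ == "__main__":
    asyncio.run(call_assistant_streaming("What is Stargate?"))
```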
18 | intercom_secret = os.getenv("INTERCOM_CLIENT_SECRET") 19 | 20 | 21 | # Build the appropriate headers 22 | def get_headers(body): 23 | """Helper to get necessary request headers for successful POST""" 24 | digest = hmac.new( 25 | intercom_secret.encode("utf-8"), 26 | msg=json.dumps(body).encode("utf-8"), 27 | digestmod=hashlib.sha1, 28 | ).hexdigest() 29 | 30 | return {"X-Hub-Signature": f"sha1={digest}"} 31 | 32 | 33 | # Load the test request 34 | user_data_file = "tests/test_request.json" 35 | user_data = "" 36 | with open(user_data_file, "r") as f: 37 | user_data = json.load(f) 38 | 39 | 40 | def call_assistant_async(chatbot_question=CHATBOT_QUESTION): 41 | # Set the request appropriately 42 | user_data["data"]["item"]["conversation_parts"]["conversation_parts"][0][ 43 | "body" 44 | ] = chatbot_question 45 | user_data["data"]["item"]["source"]["body"] = chatbot_question 46 | 47 | headers = get_headers(user_data) 48 | 49 | full_result = "" 50 | with httpx.stream( 51 | "POST", 52 | "http://127.0.0.1:5555/chat", 53 | json=user_data, 54 | headers=headers, 55 | timeout=600, 56 | ) as r: 57 | for chunk in r.iter_text(): 58 | print(chunk, flush=True, end="") 59 | full_result += chunk 60 | 61 | return full_result 62 | 63 | 64 | def call_assistant_sync(chatbot_question=CHATBOT_QUESTION): 65 | # Set the request appropriately 66 | user_data["data"]["item"]["conversation_parts"]["conversation_parts"][0][ 67 | "body" 68 | ] = chatbot_question 69 | user_data["data"]["item"]["source"]["body"] = chatbot_question 70 | 71 | headers = get_headers(user_data) 72 | 73 | r = httpx.post("http://127.0.0.1:5555/chat", json=user_data, headers=headers) 74 | 75 | # Check if the request was successful 76 | if r.status_code == httpx.codes.created: 77 | return r.content.decode() 78 | else: 79 | return f"Request failed with status code {r.status_code}: {r.text}" 80 | 81 | 82 | if __name__ == "__main__": 83 | call_assistant_async(chatbot_question=sys.argv[1]) 84 | 85 | # Alternatively 86 | # example_result = call_assistant_sync() 87 | # print(example_result) 88 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastax/ai-chatbot-starter/703d9de23df6b8d305998d25c446b4d5d5e18d3b/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | import pytest 3 | 4 | from integrations.google import init_gcp 5 | from pipeline.config import load_config 6 | 7 | 8 | @pytest.fixture(scope="module", autouse=True) 9 | def init_config(): 10 | load_dotenv(".env") 11 | config = load_config("config.yml") 12 | yield config 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def gcp_conn(init_config): 17 | init_gcp(init_config) 18 | -------------------------------------------------------------------------------- /tests/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | log_cli = true 3 | log_cli_level = INFO 4 | log_cli_format = %(message)s 5 | log_file = ./pytest_output.txt 6 | log_file_level = INFO 7 | log_file_format = %(message)s 8 | 9 | -------------------------------------------------------------------------------- /tests/test_app.py: -------------------------------------------------------------------------------- 1 | import json 2 | 
import os 3 | import logging 4 | from unittest.mock import MagicMock, patch 5 | 6 | from fastapi.testclient import TestClient 7 | import hashlib 8 | import hmac 9 | from llama_index.response.schema import StreamingResponse 10 | import pytest 11 | import requests 12 | 13 | 14 | def get_headers(body): 15 | """Helper to get necessary request headers for successful POST 16 | 17 | NOTE: Keys in `body` must be specified (recursively) in alphabetical order 18 | """ 19 | intercom_secret = os.getenv("INTERCOM_CLIENT_SECRET") 20 | 21 | digest = hmac.new( 22 | intercom_secret.encode("utf-8"), 23 | msg=json.dumps(body).encode("utf-8"), 24 | digestmod=hashlib.sha1, 25 | ).hexdigest() 26 | return {"X-Hub-Signature": f"sha1={digest}"} 27 | 28 | 29 | def load_test_request(filename): 30 | with open(filename, "r") as f: 31 | request_data = json.load(f) 32 | return request_data 33 | 34 | 35 | @pytest.fixture(scope="module") 36 | def client(init_config): 37 | # Patch necessary things 38 | with patch("pipeline.config.load_config") as mock: 39 | mock.return_value = init_config 40 | from app import app 41 | 42 | yield TestClient(app) 43 | 44 | 45 | @pytest.fixture(scope="function") 46 | def standard_request(): 47 | return load_test_request(os.path.join("tests", "test_request.json")) 48 | 49 | 50 | @pytest.fixture(scope="function") 51 | def mock_assistant(): 52 | """Mocks the AssistantBison object to prevent any real LLM queries being made""" 53 | with patch("app.assistant") as mock_bison: 54 | response_gen = (s for s in ["Mocked", "response"]) 55 | mock_bison.get_response = MagicMock( 56 | return_value=(StreamingResponse(response_gen), [], []) 57 | ) 58 | yield mock_bison 59 | 60 | 61 | def get_text_response(client, data, headers, assert_created=True): 62 | # r = httpx.post("http://127.0.0.1:5010/chat", json=user_data, headers=headers) 63 | response = client.post("/chat", json=data, headers=headers) 64 | 65 | if assert_created: 66 | assert ( 67 | response.status_code == requests.codes.created 68 | ), f"Request failed with status code {response.status_code}: {response.text}" 69 | 70 | # Check if the request was successful 71 | return response.content.decode() 72 | 73 | 74 | def test_get_root_route(client): 75 | response = client.get("/chat") 76 | assert response.status_code == 200 77 | assert response.json()["ok"] == True 78 | assert response.json()["message"] == "App is running" 79 | 80 | 81 | def test_standard_case(standard_request, client): 82 | headers = get_headers(standard_request) 83 | text = get_text_response(client, standard_request, headers) 84 | assert len(text) > 0 85 | 86 | 87 | def test_broad_case(standard_request, client): 88 | # Process each question sequentially in the test questions file 89 | with open(os.path.join("tests", "test_questions.txt"), "r") as file: 90 | lines = [line.strip() for line in file] 91 | 92 | for line in lines: 93 | # Set the request appropriately 94 | standard_request = {"question": line} 95 | 96 | # Create the digest and headers for the POST request 97 | headers = get_headers(standard_request) 98 | 99 | # Make the post request 100 | text_response = get_text_response(client, standard_request, headers) 101 | 102 | # Log the results to a file for manual inspection 103 | logging.info("###") 104 | logging.info(line) 105 | logging.info(text_response) 106 | logging.info("###") 107 | logging.info("\n") 108 | -------------------------------------------------------------------------------- /tests/test_prompts.py: 
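`test_broad_case` above loops over `tests/test_questions.txt` inside a single test. If per-question reporting is preferred, a hypothetical variant (placed in `tests/test_app.py` so it can reuse the `client` fixture and the `get_headers`/`get_text_response` helpers defined there) could parametrize instead:

```python
# Hypothetical addition to tests/test_app.py; reuses helpers and fixtures defined above.
import os

import pytest


def _load_questions():
    with open(os.path.join("tests", "test_questions.txt"), "r") as f:
        return [line.strip() for line in f if line.strip()]


@pytest.mark.parametrize("question", _load_questions())
def test_each_question(question, client):
    body = {"question": question}
    text = get_text_response(client, body, get_headers(body))
    assert len(text) > 0
```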
-------------------------------------------------------------------------------- 1 | import sys 2 | import pytest 3 | 4 | sys.path.append("../") 5 | 6 | from chatbot_api.assistant import AssistantBison 7 | 8 | mock_context = ( 9 | f"Here is information on the user:\n" 10 | f"- User Name: Fake User\n" 11 | f"- User Email: fake.user@example.com\n" 12 | f"- User Primary Programming Language (also known as favorite programming language and preferred programming language): Javascript\n" 13 | f"The user has not created any databases" 14 | ) 15 | 16 | questions = [ 17 | "What is your name?", # NOTE: This prompt can cause 'unsafe' responses from bison 18 | "How do I create a token?", 19 | ] 20 | 21 | 22 | @pytest.mark.parametrize("persona", ["default"]) 23 | def test_prompts(persona, init_config): 24 | assistant = AssistantBison( 25 | config=init_config, 26 | max_tokens_response=1024, 27 | k=4, 28 | company=init_config.company, 29 | custom_rules=init_config.custom_rules, 30 | ) 31 | 32 | print(f"\n{persona} Questions:") 33 | for x, question in enumerate(questions): 34 | print(f"#{x + 1}: {question}") 35 | response = assistant.get_response(question, persona, user_context=mock_context) 36 | print(response) 37 | -------------------------------------------------------------------------------- /tests/test_questions.txt: -------------------------------------------------------------------------------- 1 | What is the color of the sky? 2 | What is the meaning of life? 3 | What are 5 key advantages of your service? 4 | -------------------------------------------------------------------------------- /tests/test_request.json: -------------------------------------------------------------------------------- 1 | { 2 | "question": "What does DataStax do?" 3 | } 4 | -------------------------------------------------------------------------------- /tests/test_request_intercom.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": { 3 | "item": { 4 | "conversation_parts": { 5 | "conversation_parts": [ 6 | { 7 | "author": { 8 | "email": "fake.user@example.com", 9 | "id": "80be2e9f4de6136bab3f27d3", 10 | "type": "user" 11 | }, 12 | "body": "This is a test question" 13 | } 14 | ] 15 | }, 16 | "id": "181643600711471", 17 | "source": { 18 | "author": { 19 | "email": "fake.user@example.com", 20 | "id": "80be2e9f4de6136bab3f27d3", 21 | "type": "user" 22 | }, 23 | "body": "This is a test question", 24 | "delivered_as": "customer_initiated", 25 | "url": "https://example.com" 26 | }, 27 | "type": "test" 28 | } 29 | }, 30 | "delivery_attempts": 1 31 | } 32 | --------------------------------------------------------------------------------
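The Intercom-shaped payload in `tests/test_request_intercom.json` can be sent to a locally running app the same way `scripts/call_assistant_intercom.py` does it, by signing the exact JSON body with `INTERCOM_CLIENT_SECRET`. A minimal sketch, assuming the app is up on port 5555 and the secret is set in `.env`:

```python
# Sketch: sign tests/test_request_intercom.json and post it to a locally running app.
import hashlib
import hmac
import json
import os

import httpx
from dotenv import load_dotenv

load_dotenv(".env")

with open("tests/test_request_intercom.json", "r") as f:
    payload = json.load(f)

digest = hmac.new(
    os.environ["INTERCOM_CLIENT_SECRET"].encode("utf-8"),
    msg=json.dumps(payload).encode("utf-8"),
    digestmod=hashlib.sha1,
).hexdigest()

response = httpx.post(
    "http://127.0.0.1:5555/chat",
    json=payload,
    headers={"X-Hub-Signature": f"sha1={digest}"},
    timeout=600,
)
print(response.status_code, response.text)
```

As noted in `tests/test_app.py`, the signature only matches if the serialized body is identical on both sides, which is why the test fixtures keep their keys in alphabetical order.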
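Finally, the `pipeline/` abstractions shown earlier (ResponseDecider, UserContextCreator, ResponseActor) are looked up by class name in `integrations_registry` and constructed with the loaded `Config`, so new behavior is added by subclassing and then listing the class name in `config.yml`. Below is a rough sketch of a custom post-response hook; the class name and module are hypothetical, and the exact registration mechanics live in `pipeline/base_integration.py`, which is not reproduced here:

```python
# Hypothetical integration: print every chatbot answer to stdout.
# Construction and registry membership are assumed to follow BaseIntegration,
# since integrations_registry[cls_name](config) instantiates the class with the Config.
from typing import Any

from pipeline.response_action import ResponseActor


class ConsoleLoggingActor(ResponseActor):
    """Hypothetical actor that logs what the chatbot answered."""

    def take_action(
        self,
        conv_info: Any,
        text_response: str,
        responses_from_vs: str,
        context: str,
    ) -> None:
        print(f"Chatbot replied: {text_response[:200]}")
```

With that in place, adding `ConsoleLoggingActor` to `response_actor_cls` in `config.yml` would make `take_all_actions` invoke it after every response, assuming the registry picks the class up.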