├── tests
│   ├── __init__.py
│   ├── test_sql_engine.py
│   ├── test_workflow.py
│   ├── test_document_management.py
│   ├── test_audio.py
│   ├── test_models.py
│   └── test_utils.py
├── tools
│   ├── test_import.py
│   ├── cli
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── models.py
│   │   ├── screens
│   │   │   ├── __init__.py
│   │   │   ├── embedding_providers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── azure.py
│   │   │   │   ├── cohere.py
│   │   │   │   ├── gemini.py
│   │   │   │   ├── openai.py
│   │   │   │   ├── huggingface.py
│   │   │   │   └── bedrock.py
│   │   │   ├── initial.py
│   │   │   ├── base.py
│   │   │   └── embedding_provider.py
│   │   ├── stylesheets
│   │   │   └── base.tcss
│   │   └── embedding_app.py
│   ├── create_llama_extract_agent.py
│   └── create_llama_cloud_index.py
├── .python-version
├── src
│   └── notebookllama
│       ├── __init__.py
│       ├── models.py
│       ├── querying.py
│       ├── server.py
│       ├── verifying.py
│       ├── workflow.py
│       ├── mindmap.py
│       ├── pages
│       │   ├── 1_Document_Management_UI.py
│       │   ├── 4_Observability_Dashboard.py
│       │   ├── 3_Interactive_Table_and_Plot_Visualization.py
│       │   └── 2_Document_Chat.py
│       ├── documents.py
│       ├── utils.py
│       ├── instrumentation.py
│       ├── processing.py
│       ├── Home.py
│       └── audio.py
├── .DS_Store
├── data
│   └── test
│       ├── images
│       │   └── image.png
│       ├── brain_for_kids.pdf
│       └── md_sample.md
├── .env.example
├── .gitignore
├── .github
│   ├── workflows
│   │   ├── linting.yaml
│   │   ├── test.yaml
│   │   ├── release.yaml
│   │   └── type_checking.yaml
│   └── ISSUE_TEMPLATE
│       ├── feature_request.md
│       └── bug-or-help-request-report.md
├── compose.yaml
├── LICENSE
├── pyproject.toml
├── CONTRIBUTING.md
├── .pre-commit-config.yaml
└── README.md

/tests/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/tools/test_import.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
 1 | 3.13
 2 | 
--------------------------------------------------------------------------------
/src/notebookllama/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/run-llama/notebookllama/HEAD/.DS_Store
--------------------------------------------------------------------------------
/tools/cli/config/__init__.py:
--------------------------------------------------------------------------------
 1 | from .models import EmbeddingConfig
 2 | 
 3 | __all__ = ["EmbeddingConfig"]
 4 | 
--------------------------------------------------------------------------------
/data/test/images/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/run-llama/notebookllama/HEAD/data/test/images/image.png
--------------------------------------------------------------------------------
/data/test/brain_for_kids.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/run-llama/notebookllama/HEAD/data/test/brain_for_kids.pdf
--------------------------------------------------------------------------------
/tools/cli/screens/__init__.py:
--------------------------------------------------------------------------------
 1 | from .base import BaseScreen, ConfigurationScreen
 2 | from .initial import InitialScreen
 3 | from .embedding_provider import ProviderSelectScreen
 4 | 
 5 | __all__ = ["BaseScreen", "ConfigurationScreen", "InitialScreen", "ProviderSelectScreen"]
 6 | 
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
 1 | OPENAI_API_KEY="sk-***"
 2 | LLAMACLOUD_API_KEY="llx-***"
 3 | 
 4 | # Regional Endpoint Configuration (Uncomment the appropriate line for your region)
 5 | # LLAMACLOUD_REGION="eu" # Europe
 6 | 
 7 | ELEVENLABS_API_KEY="sk_***"
 8 | pgql_db="postgres"
 9 | pgql_user="postgres"
10 | pgql_psw="admin"
11 | 
--------------------------------------------------------------------------------
/tools/cli/config/models.py:
--------------------------------------------------------------------------------
 1 | from dataclasses import dataclass
 2 | from typing import Optional
 3 | 
 4 | 
 5 | @dataclass
 6 | class EmbeddingConfig:
 7 |     provider: str
 8 |     api_key: Optional[str] = None
 9 |     model: Optional[str] = None
10 |     region: Optional[str] = None
11 |     key_id: Optional[str] = None
12 |     embedding_config: Optional[object] = None
13 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Python-generated files
 2 | __pycache__/
 3 | *.py[oc]
 4 | build/
 5 | dist/
 6 | lib/
 7 | wheels/
 8 | *.egg-info
 9 | 
10 | # Virtual environments
11 | .venv
12 | 
13 | # caches
14 | .*_cache/
15 | 
16 | # env files
17 | .env
18 | 
19 | # image files
20 | static/
21 | 
22 | # audio files
23 | *.mp3
24 | 
25 | # auto-generated images
26 | data/extracted_tables/
27 | 
--------------------------------------------------------------------------------
/.github/workflows/linting.yaml:
--------------------------------------------------------------------------------
 1 | name: Linting
 2 | 
 3 | on:
 4 |   pull_request:
 5 | 
 6 | jobs:
 7 |   lint:
 8 |     runs-on: ubuntu-latest
 9 |     steps:
10 |       - uses: actions/checkout@v4
11 | 
12 |       - name: Install uv
13 |         uses: astral-sh/setup-uv@v6
14 | 
15 |       - name: Set up Python
16 |         run: uv python install 3.12
17 | 
18 |       - name: Run linter
19 |         shell: bash
20 |         run: uv run -- pre-commit run -a
21 | 
--------------------------------------------------------------------------------
/tools/cli/stylesheets/base.tcss:
--------------------------------------------------------------------------------
 1 | Screen {
 2 |     align: center top;
 3 |     padding: 1 4;
 4 | }
 5 | 
 6 | .form-title {
 7 |     align: center top;
 8 |     content-align: center middle;
 9 |     width: 100%;
10 |     color: rgb(83, 200, 0);
11 |     padding: 1;
12 | }
13 | 
14 | .form-control {
15 |     align: center middle;
16 |     width: 100%;
17 |     height: auto;
18 |     margin: 1;
19 | }
20 | 
21 | Input, Select {
22 |     width: 80;
23 |     margin: 1;
24 | }
25 | 
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
 1 | name: CI Tests
 2 | 
 3 | on:
 4 |   pull_request:
 5 | 
 6 | jobs:
 7 |   testing:
 8 |     runs-on: ubuntu-latest
 9 |     steps:
10 |       - uses: actions/checkout@v4
11 |         with:
12 |           fetch-depth: 1
13 | 
14 |       - name: Install uv
15 |         uses: astral-sh/setup-uv@v6
16 | 
17 |       - name: Set up Python
18 |         run: uv python install
19 | 
20 |       - name: Run Tests
21 |         run: uv run -- pytest tests/test_*.py
22 | 
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
 1 | name: GitHub Release
 2 | 
 3 | on:
 4 |   push:
 5 |     tags:
 6 |       - "v[0-9].[0-9]+.[0-9]+*"
 7 | 
 8 | jobs:
 9 |   release:
10 |     runs-on: ubuntu-latest
11 |     permissions:
12 |       contents: write
13 | 
14 |     steps:
15 |       - name: Checkout
16 |         uses: actions/checkout@v4
17 | 
18 |       - name: Create GitHub Release
19 |         uses: ncipollo/release-action@v1
20 |         with:
21 |           generateReleaseNotes: true
22 | 
--------------------------------------------------------------------------------
/.github/workflows/type_checking.yaml:
--------------------------------------------------------------------------------
 1 | name: Typecheck
 2 | 
 3 | on:
 4 |   pull_request:
 5 | 
 6 | jobs:
 7 |   core-typecheck:
 8 |     runs-on: ubuntu-latest
 9 |     steps:
10 |       - uses: actions/checkout@v4
11 |         with:
12 |           fetch-depth: 1
13 | 
14 |       - name: Install uv
15 |         uses: astral-sh/setup-uv@v6
16 | 
17 |       - name: Set up Python
18 |         run: uv python install
19 | 
20 |       - name: Run Mypy
21 |         working-directory: src
22 |         run: uv run -- mypy notebookllama
23 | 
--------------------------------------------------------------------------------
/tools/cli/screens/embedding_providers/__init__.py:
--------------------------------------------------------------------------------
 1 | from .openai import OpenAIEmbeddingScreen
 2 | from .bedrock import BedrockEmbeddingScreen
 3 | from .azure import AzureEmbeddingScreen
 4 | from .gemini import GeminiEmbeddingScreen
 5 | from .cohere import CohereEmbeddingScreen
 6 | from .huggingface import HuggingFaceEmbeddingScreen
 7 | 
 8 | __all__ = [
 9 |     "OpenAIEmbeddingScreen",
10 |     "BedrockEmbeddingScreen",
11 |     "AzureEmbeddingScreen",
12 |     "GeminiEmbeddingScreen",
13 |     "CohereEmbeddingScreen",
14 |     "HuggingFaceEmbeddingScreen",
15 | ]
16 | 
--------------------------------------------------------------------------------
/tools/create_llama_extract_agent.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | import os
 3 | 
 4 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 5 | 
 6 | from src.notebookllama.models import Notebook
 7 | from src.notebookllama.utils import create_llama_extract_client
 8 | from dotenv import load_dotenv
 9 | 
10 | load_dotenv()
11 | 
12 | 
13 | def main() -> int:
14 |     conn = create_llama_extract_client()
15 |     agent = conn.create_agent(name="q_and_a_agent", data_schema=Notebook)
16 |     _id = agent.id
17 |     with open(".env", "a") as f:
18 |         f.write(f'\nEXTRACT_AGENT_ID="{_id}"')
19 |     return 0
20 | 
21 | 
22 | if __name__ == "__main__":
23 |     # Propagate the return value as the process exit status.
24 |     sys.exit(main())
25 | 
--------------------------------------------------------------------------------
/compose.yaml:
--------------------------------------------------------------------------------
 1 | name: instrumentation
 2 | 
 3 | services:
 4 |   jaeger:
 5 |     image: jaegertracing/all-in-one:latest
 6 |     ports:
 7 |       - 16686:16686
 8 |       - 4317:4317
 9 |       - 4318:4318
10 |       - 9411:9411
11 |     environment:
12 |       - COLLECTOR_ZIPKIN_HOST_PORT=:9411
13 | 
14 |   postgres:
15 |     image: postgres
16 |     ports:
17 |       - 5432:5432
18 |     environment:
19 |       POSTGRES_DB: $pgql_db
20 |       POSTGRES_USER: $pgql_user
21 |       POSTGRES_PASSWORD: $pgql_psw
22 |     volumes:
23 |       - pgdata:/var/lib/postgresql/data
24 | 
25 |   adminer:
26 |     image: adminer
27 |     ports:
28 |       - "8080:8080"
29 | 
30 | volumes:
31 |   pgdata:
32 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | name: Feature request
 3 | about: Suggest an idea for this project
 4 | title: ""
 5 | labels: enhancement, triage
 6 | assignees: AstraBert
 7 | ---
 8 | 
 9 | **Is your feature request related to a problem? Please describe.**
10 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
11 | 
12 | **Describe the solution you'd like**
13 | A clear and concise description of what you want to happen.
14 | 
15 | **Describe alternatives you've considered**
16 | A clear and concise description of any alternative solutions or features you've considered.
17 | 
18 | **Additional context**
19 | Add any other context or screenshots about the feature request here.
20 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | The MIT License
 2 | 
 3 | Copyright (c) Jerry Liu
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/tools/cli/embedding_app.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from textual.app import App
 3 | 
 4 | from .config import EmbeddingConfig
 5 | from .screens import InitialScreen
 6 | 
 7 | 
 8 | class EmbeddingSetupApp(App):
 9 |     """Main application for embedding configuration setup."""
10 | 
11 |     CSS_PATH = "stylesheets/base.tcss"
12 | 
13 |     def __init__(self):
14 |         super().__init__()
15 |         self.config = EmbeddingConfig(provider="")
16 | 
17 |     def on_mount(self) -> None:
18 |         self.push_screen(InitialScreen())
19 | 
20 |     def handle_completion(self, config: EmbeddingConfig) -> None:
21 |         self.exit(config)
22 | 
23 |     def handle_default_setup(self) -> None:
24 |         from llama_index.embeddings.openai import OpenAIEmbedding
25 |         from llama_cloud import PipelineCreateEmbeddingConfig_OpenaiEmbedding
26 | 
27 |         self.config.provider = "OpenAI"
28 |         self.config.api_key = os.getenv("OPENAI_API_KEY")
29 |         self.config.model = "text-embedding-3-small"
30 | 
31 |         embed_model = OpenAIEmbedding(
32 |             model=self.config.model, api_key=self.config.api_key
33 |         )
34 |         embedding_config = PipelineCreateEmbeddingConfig_OpenaiEmbedding(
35 |             type="OPENAI_EMBEDDING",
36 |             component=embed_model,
37 |         )
38 |         # The app exits with the pipeline embedding config, so replace the
39 |         # CLI-level EmbeddingConfig with it before completing.
40 |         self.config = embedding_config
41 | 
42 |         self.handle_completion(self.config)
43 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [project]
 2 | name = "notebookllama"
 3 | version = "0.4.0"
 4 | description = "An OSS and LlamaCloud-backed alternative to NotebookLM"
 5 | readme = "README.md"
 6 | requires-python = ">=3.13"
 7 | dependencies = [
 8 |     "audioop-lts>=0.2.1",
 9 |     "elevenlabs>=2.5.0",
10 |     "fastmcp>=2.9.2",
11 |     "ffprobe>=0.5",
12 |     "llama-cloud>=0.1.29",
13 |     "llama-cloud-services>=0.6.38",
14 |     "llama-index-core>=0.12.44",
15 |     "llama-index-embeddings-azure-inference>=0.3.0",
16 |     "llama-index-embeddings-bedrock>=0.5.2",
17 |     "llama-index-embeddings-cohere>=0.5.1",
18 |     "llama-index-embeddings-gemini>=0.3.2",
19 |     "llama-index-embeddings-huggingface-api>=0.3.1",
20 |     "llama-index-embeddings-openai>=0.3.1",
21 |     "llama-index-indices-managed-llama-cloud>=0.6.11",
22 |     "llama-index-llms-openai>=0.4.7",
23 |     "llama-index-observability-otel>=0.1.1",
24 |     "llama-index-tools-mcp>=0.2.5",
25 |     "llama-index-workflows>=1.0.1",
26 |     "markdown-analysis>=0.1.5",
27 |     "mypy>=1.16.1",
28 |     "opentelemetry-exporter-otlp-proto-http>=1.34.1",
29 |     "plotly>=6.2.0",
30 |     "pre-commit>=4.2.0",
31 |     "psycopg2-binary>=2.9.10",
32 |     "pydub>=0.25.1",
33 |     "pytest>=8.4.1",
34 |     "pytest-asyncio>=1.0.0",
35 |     "python-dotenv>=1.1.1",
36 |     "pyvis>=0.3.2",
37 |     "randomname>=0.2.1",
38 |     "streamlit>=1.46.1",
39 |     "textual>=3.7.1"
40 | ]
41 | 
42 | [tool.mypy]
43 | disable_error_code = ["import-not-found", "import-untyped"]
44 | 
45 | [tool.pytest.ini_options]
46 | pythonpath = ["src"]
47 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-or-help-request-report.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | name: Bug or Help Request Report
 3 | about: This template is intended for bug or help request reports
 4 | title: "[REPORT]"
 5 | labels: bug, help wanted
 6 | assignees: AstraBert
 7 | ---
 8 | 
 9 | ### Checklist
10 | 
11 | Before reporting bugs or help requests, please fill out the following checklist:
12 | 
13 | - [ ] I set the `.env` file with the necessary keys, including OPENAI_API_KEY, LLAMACLOUD_API_KEY and ELEVENLABS_API_KEY
14 | - [ ] My API keys are all functioning and have at least some credit balance
15 | - [ ] I ran both the scripts in the `tools` directory
16 | - [ ] I checked that my LlamaCloud Extract Agent and my LlamaCloud Index have been correctly created and are functioning (using the playground functionality in the [UI](https://cloud.llamaindex.ai))
17 | - [ ] I activated the virtual environment and ensured that all the dependencies are installed properly
18 | - [ ] As a first step, I launched the Docker services through `docker compose up -d`
19 | - [ ] As a second step, I launched the MCP server through `uv run src/notebookllama/server.py`
20 | - [ ] As a third and last step, I launched the Streamlit app through `streamlit run src/notebookllama/Home.py`
21 | 
22 | ### Issue Description
23 | 
24 | My issue is...
25 | 
26 | ### Relevant Traceback
27 | 
28 | ```text
29 | 
30 | ```
31 | 
32 | ### Other details
33 | 
34 | **OS**:
35 | **uv version**:
36 | **streamlit version**:
37 | **fastmcp version**:
38 | **llama-index-workflows version**:
39 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # Contributing to `notebookllama`
 2 | 
 3 | Do you want to contribute to this project? Make sure to read these guidelines first :)
 4 | 
 5 | ## Issue
 6 | 
 7 | **When to do it**:
 8 | 
 9 | - You found bugs but you don't know how to solve them, or don't have the time to solve them
10 | - You want new features but you don't know how to implement them, or don't have the time to implement them
11 | 
12 | > ⚠️ _Always check open and closed issues before you submit yours to avoid duplicates_
13 | 
14 | ## Traditional contribution
15 | 
16 | **When to do it**:
17 | 
18 | - You found bugs and corrected them
19 | - You optimized/improved the code
20 | - You added new features that you think could be useful to others
21 | 
22 | **How to do it**:
23 | 
24 | 1. Fork this repository
25 | 2. Install `pre-commit` and make sure it is registered in the Git hooks for your fork:
26 | 
27 | ```bash
28 | pip install pre-commit
29 | pre-commit install
30 | ```
31 | 
32 | 3. Change the things you want, and make sure tests still pass or add new ones:
33 | 
34 | ```bash
35 | pytest tests/test_*.py
36 | ```
37 | 
38 | 4. Commit your changes
39 | 5. Make sure your changes pass the pre-commit linting/type checking; if not, modify them so that they pass
40 | 6. Submit a pull request (make sure to provide a thorough description of the changes)
41 | 
42 | ### Thanks for contributing!
43 | 
--------------------------------------------------------------------------------
/tools/cli/screens/initial.py:
--------------------------------------------------------------------------------
 1 | from textual import on
 2 | from textual.widgets import Select
 3 | 
 4 | from .base import BaseScreen
 5 | 
 6 | 
 7 | class InitialScreen(BaseScreen):
 8 |     """Initial screen for choosing between default or custom settings."""
 9 | 
10 |     def get_title(self) -> str:
11 |         return "How do you wish to proceed?"
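12 | 
13 |     # The value selected below is read back in handle_next() to route between
14 |     # the OpenAI quick setup and the custom provider picker.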
15 | 
16 |     def get_form_elements(self) -> list:
17 |         return [
18 |             Select(
19 |                 options=[
20 |                     ("With Default Settings", "default_settings"),
21 |                     ("With Custom Settings", "custom_settings"),
22 |                 ],
23 |                 prompt="Please select one of the following",
24 |                 id="setup_type",
25 |                 classes="form-control",
26 |             )
27 |         ]
28 | 
29 |     @on(Select.Changed, "#setup_type")
30 |     def handle_selection(self, event: Select.Changed) -> None:
31 |         from ..embedding_app import EmbeddingSetupApp
32 | 
33 |         app = self.app
34 |         if isinstance(app, EmbeddingSetupApp):
35 |             app.config.setup_type = event.value
36 |             self.handle_next()
37 | 
38 |     def handle_next(self) -> None:
39 |         from ..embedding_app import EmbeddingSetupApp
40 |         from .embedding_provider import ProviderSelectScreen
41 | 
42 |         app = self.app
43 |         if isinstance(app, EmbeddingSetupApp):
44 |             if app.config.setup_type == "default_settings":
45 |                 app.handle_default_setup()
46 |             else:
47 |                 app.push_screen(ProviderSelectScreen())
48 | 
--------------------------------------------------------------------------------
/src/notebookllama/models.py:
--------------------------------------------------------------------------------
 1 | from pydantic import BaseModel, Field, model_validator
 2 | from typing import List
 3 | from typing_extensions import Self
 4 | 
 5 | 
 6 | class Notebook(BaseModel):
 7 |     summary: str = Field(
 8 |         description="Summary of the document.",
 9 |     )
10 |     highlights: List[str] = Field(
11 |         description="Highlights of the document: 3 to 10 bullet points that capture its most important ideas.",
12 |         min_length=3,
13 |         max_length=10,
14 |     )
15 |     questions: List[str] = Field(
16 |         description="5 to 15 questions based on the topic of the document.",
17 |         examples=[
18 |             [
19 |                 "What is the capital of Spain?",
20 |                 "What is the capital of France?",
21 |                 "What is the capital of Italy?",
22 |                 "What is the capital of Portugal?",
23 |                 "What is the capital of Germany?",
24 |             ]
25 |         ],
26 |         min_length=5,
27 |         max_length=15,
28 |     )
29 |     answers: List[str] = Field(
30 |         description="Answers to the questions reported in the 'questions' field, in the same order.",
31 |         examples=[
32 |             [
33 |                 "Madrid",
34 |                 "Paris",
35 |                 "Rome",
36 |                 "Lisbon",
37 |                 "Berlin",
38 |             ]
39 |         ],
40 |         min_length=5,
41 |         max_length=15,
42 |     )
43 | 
44 |     @model_validator(mode="after")
45 |     def validate_q_and_a(self) -> Self:
46 |         if len(self.questions) != len(self.answers):
47 |             raise ValueError("Questions and Answers must be of the same length.")
48 |         return self
49 | 
--------------------------------------------------------------------------------
/tools/cli/screens/embedding_providers/azure.py:
--------------------------------------------------------------------------------
 1 | from textual.app import ComposeResult
 2 | from textual.widgets import Input
 3 | 
 4 | from llama_index.embeddings.azure_inference import AzureAIEmbeddingsModel
 5 | from llama_cloud import PipelineCreateEmbeddingConfig_AzureEmbedding
 6 | 
 7 | from ..base import ConfigurationScreen
 8 | 
 9 | 
10 | class AzureEmbeddingScreen(ConfigurationScreen):
11 |     """Configuration screen for Azure embeddings."""
12 | 
13 |     def get_title(self) -> str:
14 |         return "Azure Embedding Configuration"
15 | 
16 |     def get_form_elements(self) -> list[ComposeResult]:
17 |         return [
18 |             Input(
19 |                 placeholder="API Key",
20 |                 password=True,
21 |                 id="api_key",
22 |                 classes="form-control",
23 |             ),
24 |             Input(placeholder="Endpoint URL", id="endpoint", classes="form-control"),
25 |         ]
26 | 
27 |     def process_submission(self) -> None:
28 |         api_key = self.query_one("#api_key", Input).value
29 |         endpoint = self.query_one("#endpoint", Input).value
30 | 
31 |         if not all([api_key, endpoint]):
32 |             self.notify("All fields are required.", severity="error")
33 |             return
34 | 
35 |         try:
36 |             embed_model = AzureAIEmbeddingsModel(credential=api_key, endpoint=endpoint)
37 |             embedding_config = PipelineCreateEmbeddingConfig_AzureEmbedding(
38 |                 type="AZURE_EMBEDDING", component=embed_model
39 |             )
40 | 
41 |             self.app.config = embedding_config
42 |             self.app.handle_completion(self.app.config)
43 |         except Exception as e:
44 |             self.notify(f"Error: {e}", severity="error")
45 | 
--------------------------------------------------------------------------------
/src/notebookllama/querying.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | from dotenv import load_dotenv
 4 | 
 5 | from llama_index.core.query_engine import CitationQueryEngine
 6 | from llama_index.core.base.response.schema import Response
 7 | from llama_index.llms.openai import OpenAIResponses
 8 | from typing import Union, cast
 9 | 
10 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
11 | 
12 | from notebookllama.utils import create_llamacloud_index
13 | 
14 | load_dotenv()
15 | 
16 | if (
17 |     os.getenv("LLAMACLOUD_API_KEY", None)
18 |     and os.getenv("LLAMACLOUD_PIPELINE_ID", None)
19 |     and os.getenv("OPENAI_API_KEY", None)
20 | ):
21 |     LLM = OpenAIResponses(model="gpt-4.1", api_key=os.getenv("OPENAI_API_KEY"))
22 |     PIPELINE_ID = os.getenv("LLAMACLOUD_PIPELINE_ID")
23 |     API_KEY = os.getenv("LLAMACLOUD_API_KEY")
24 | 
25 |     if API_KEY is None or PIPELINE_ID is None:
26 |         raise ValueError("LLAMACLOUD_API_KEY and LLAMACLOUD_PIPELINE_ID must be set")
27 | 
28 |     index = create_llamacloud_index(api_key=API_KEY, pipeline_id=PIPELINE_ID)
29 |     RETR = index.as_retriever()
30 |     QE = CitationQueryEngine(
31 |         retriever=RETR,
32 |         llm=LLM,
33 |         citation_chunk_size=256,
34 |         citation_chunk_overlap=50,
35 |     )
36 | 
37 | 
38 | async def query_index(question: str) -> Union[str, None]:
39 |     response = await QE.aquery(question)
40 |     response = cast(Response, response)
41 |     sources = []
42 |     if not response.response:
43 |         return None
44 |     if response.source_nodes is not None:
45 |         sources = [node.text for node in response.source_nodes]
46 |     return (
47 |         "## Answer\n\n"
48 |         + response.response
49 |         + "\n\n## Sources\n\n- "
50 |         + "\n- ".join(sources)
51 |     )
52 | 
--------------------------------------------------------------------------------
/src/notebookllama/server.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | from querying import query_index
 4 | from processing import process_file
 5 | from mindmap import get_mind_map
 6 | from fastmcp import FastMCP
 7 | from typing import List, Union, Literal
 8 | 
 9 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
10 | 
11 | 
12 | mcp: FastMCP = FastMCP(name="MCP For NotebookLM")
13 | 
14 | 
15 | @mcp.tool(
16 |     name="process_file_tool",
17 |     description="This tool is useful to process files and produce summaries, question-answers and highlights.",
18 | )
19 | async def process_file_tool(
20 |     filename: str,
21 | ) -> Union[str, Literal["Sorry, your file could not be processed."]]:
22 |     notebook_model, text = await process_file(filename=filename)
23 |     if notebook_model is None:
24 |         return "Sorry, your file could not be processed."
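25 |     # The notebook payload and the raw document text are joined with a
26 |     # "%separator%" marker so downstream callers can split them apart.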
27 |     if text is None:
28 |         text = ""
29 |     return notebook_model + "\n%separator%\n" + text
30 | 
31 | 
32 | @mcp.tool(name="get_mind_map_tool", description="This tool is useful to get a mind map from a summary and its highlights.")
33 | async def get_mind_map_tool(
34 |     summary: str, highlights: List[str]
35 | ) -> Union[str, Literal["Sorry, mind map creation failed."]]:
36 |     mind_map_fl = await get_mind_map(summary=summary, highlights=highlights)
37 |     if mind_map_fl is None:
38 |         return "Sorry, mind map creation failed."
39 |     return mind_map_fl
40 | 
41 | 
42 | @mcp.tool(name="query_index_tool", description="Query a LlamaCloud index.")
43 | async def query_index_tool(question: str) -> str:
44 |     response = await query_index(question=question)
45 |     if response is None:
46 |         return "Sorry, I was unable to find an answer to your question."
47 |     return response
48 | 
49 | 
50 | if __name__ == "__main__":
51 |     mcp.run(transport="streamable-http")
52 | 
--------------------------------------------------------------------------------
/tools/cli/screens/embedding_providers/cohere.py:
--------------------------------------------------------------------------------
 1 | from textual.app import ComposeResult
 2 | from textual.widgets import Input
 3 | 
 4 | from llama_index.embeddings.cohere import CohereEmbedding
 5 | from llama_cloud import PipelineCreateEmbeddingConfig_CohereEmbedding
 6 | 
 7 | from ..base import ConfigurationScreen
 8 | 
 9 | 
10 | class CohereEmbeddingScreen(ConfigurationScreen):
11 |     """Configuration screen for Cohere embeddings."""
12 | 
13 |     def get_title(self) -> str:
14 |         return "Cohere Embedding Configuration"
15 | 
16 |     def get_form_elements(self) -> list[ComposeResult]:
17 |         return [
18 |             Input(
19 |                 placeholder="API Key",
20 |                 password=True,
21 |                 id="api_key",
22 |                 classes="form-control",
23 |             ),
24 |             Input(placeholder="Model", id="model", classes="form-control"),
25 |         ]
26 | 
27 |     def process_submission(self) -> None:
28 |         api_key = self.query_one("#api_key", Input).value
29 |         model = self.query_one("#model", Input).value
30 | 
31 |         if not all([api_key, model]):
32 |             self.notify("All fields are required", severity="error")
33 |             return
34 | 
35 |         try:
36 |             embed_model = CohereEmbedding(
37 |                 model_name=model,
38 |                 api_key=api_key,
39 |                 input_type="search_document",
40 |                 embedding_type="float",
41 |             )
42 |             embedding_config = PipelineCreateEmbeddingConfig_CohereEmbedding(
43 |                 type="COHERE_EMBEDDING", component=embed_model
44 |             )
45 | 
46 |             self.app.config = embedding_config
47 |             self.app.handle_completion(self.app.config)
48 |         except Exception as e:
49 |             self.notify(f"Error: {e}", severity="error")
50 | 
--------------------------------------------------------------------------------
/tools/cli/screens/embedding_providers/gemini.py:
--------------------------------------------------------------------------------
 1 | from textual.app import ComposeResult
 2 | from textual.widgets import Input
 3 | 
 4 | from llama_index.embeddings.gemini import GeminiEmbedding
 5 | from llama_cloud import PipelineCreateEmbeddingConfig_GeminiEmbedding
 6 | 
 7 | from ..base import ConfigurationScreen
 8 | 
 9 | 
10 | class GeminiEmbeddingScreen(ConfigurationScreen):
11 |     """Configuration screen for Gemini embeddings."""
12 | 
13 |     # This is the only model currently available in the API
14 |     DEFAULT_MODEL = "models/embedding-001"
15 | 
16 |     def get_title(self) -> str:
17 |         return "Gemini Embedding Configuration"
18 | 
19 |     def get_form_elements(self) -> list[ComposeResult]:
20 |         return [
21 |             Input(
22 |                 value=self.DEFAULT_MODEL,
23 |                 id="model",
24 |                 classes="form-control",
25 |                 disabled=True,
26 |             ),
27 |             Input(
28 |                 placeholder="API Key",
29 |                 password=True,
30 |                 id="api_key",
31 |                 classes="form-control",
32 |             ),
33 |         ]
34 | 
35 |     def process_submission(self) -> None:
36 |         api_key = self.query_one("#api_key", Input).value
37 | 
38 |         if not api_key:
39 |             self.notify("API Key is required", severity="error")
40 |             return
41 | 
42 |         try:
43 |             embed_model = GeminiEmbedding(
44 |                 api_key=api_key, model_name=self.DEFAULT_MODEL
45 |             )
46 |             embedding_config = PipelineCreateEmbeddingConfig_GeminiEmbedding(
47 |                 type="GEMINI_EMBEDDING", component=embed_model
48 |             )
49 | 
50 |             self.app.config = embedding_config
51 |             self.app.handle_completion(self.app.config)
52 |         except Exception as e:
53 |             self.notify(f"Error: {e}", severity="error")
54 | 
--------------------------------------------------------------------------------
/src/notebookllama/verifying.py:
--------------------------------------------------------------------------------
 1 | from dotenv import load_dotenv
 2 | import json
 3 | import os
 4 | 
 5 | from pydantic import BaseModel, Field, model_validator
 6 | from llama_index.core.llms import ChatMessage
 7 | from llama_index.llms.openai import OpenAIResponses
 8 | from typing import List, Tuple, Optional
 9 | from typing_extensions import Self
10 | 
11 | load_dotenv()
12 | 
13 | 
14 | class ClaimVerification(BaseModel):
15 |     claim_is_true: bool = Field(
16 |         description="Based on the provided sources information, the claim passes or not."
17 |     )
18 |     supporting_citations: Optional[List[str]] = Field(
19 |         description="A minimum of one and a maximum of three citations from the sources supporting the claim. If the claim is not supported, please leave empty",
20 |         default=None,
21 |         min_length=1,
22 |         max_length=3,
23 |     )
24 | 
25 |     @model_validator(mode="after")
26 |     def validate_claim_ver(self) -> Self:
27 |         if not self.claim_is_true and self.supporting_citations is not None:
28 |             self.supporting_citations = ["The claim was deemed false."]
29 |         return self
30 | 
31 | 
32 | if os.getenv("OPENAI_API_KEY", None):
33 |     LLM = OpenAIResponses(model="gpt-4.1", api_key=os.getenv("OPENAI_API_KEY"))
34 |     LLM_VERIFIER = LLM.as_structured_llm(ClaimVerification)
35 | 
36 | 
37 | def verify_claim(
38 |     claim: str,
39 |     sources: str,
40 | ) -> Tuple[bool, Optional[List[str]]]:
41 |     response = LLM_VERIFIER.chat(
42 |         [
43 |             ChatMessage(
44 |                 role="user",
45 |                 content=f"I have this claim: {claim} that is allegedly supported by these sources:\n\n'''\n{sources}\n'''\n\nCan you please tell me whether or not this claim is truthful and, if it is, identify one to three passages in the sources specifically supporting the claim?",
46 |             )
47 |         ]
48 |     )
49 |     response_json = json.loads(response.message.content)
50 |     return response_json["claim_is_true"], response_json["supporting_citations"]
51 | 
--------------------------------------------------------------------------------
/tools/cli/screens/base.py:
--------------------------------------------------------------------------------
 1 | from textual.app import ComposeResult
 2 | from textual.containers import Container
 3 | from textual.screen import Screen
 4 | from textual.widgets import Label, Footer
 5 | from textual.binding import Binding
 6 | from textual.widgets import Input
 7 | 
 8 | 
 9 | class BaseScreen(Screen):
10 |     """Base screen with common functionality for all screens."""
11 | 
12 |     BINDINGS = [
13 |         Binding("ctrl+q", "quit", "Exit", key_display="ctrl+q"),
14 |         Binding("ctrl+d", "toggle_dark", "Toggle Dark Theme", key_display="ctrl+d"),
15 |     ]
16 | 
17 |     def action_toggle_dark(self) -> None:
18 |         self.app.theme = (
19 |             "textual-dark" if self.app.theme == "textual-light" else "textual-light"
20 |         )
21 | 
22 |     def action_quit(self) -> None:
23 |         self.app.exit()
24 | 
25 |     def compose(self) -> ComposeResult:
26 |         yield Container(
27 |             Label(self.get_title(), classes="form-title"),
28 |             *self.get_form_elements(),
29 |             classes="form-container",
30 |         )
31 |         yield Footer()
32 | 
33 |     def get_title(self) -> str:
34 |         return "Base Screen"
35 | 
36 |     def get_form_elements(self) -> list[ComposeResult]:
37 |         return []
38 | 
39 | 
40 | class ConfigurationScreen(BaseScreen):
41 |     """Base screen for provider configuration with submit functionality."""
42 | 
43 |     BINDINGS = BaseScreen.BINDINGS + [
44 |         Binding("shift+enter", "submit", "Submit"),
45 |     ]
46 | 
47 |     def on_input_submitted(self, event: Input.Submitted) -> None:
48 |         """Catches the Enter key press and delegates the work."""
49 |         self.process_submission()
50 | 
51 |     def process_submission(self) -> None:
52 |         """
53 |         To be implemented by each specific provider screen.
54 |         This method contains the unique logic for validating and creating
55 |         the embedding configuration.
56 |         """
57 | 
58 |         raise NotImplementedError(
59 |             "Each configuration screen must implement 'process_submission'"
60 |         )
61 | 
--------------------------------------------------------------------------------
/tools/create_llama_cloud_index.py:
--------------------------------------------------------------------------------
 1 | import asyncio
 2 | import os
 3 | import sys
 4 | 
 5 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 6 | 
 7 | 
 8 | from dotenv import load_dotenv
 9 | from cli.embedding_app import EmbeddingSetupApp
10 | from src.notebookllama.utils import create_llamacloud_client
11 | 
12 | from llama_cloud import (
13 |     PipelineTransformConfig_Advanced,
14 |     AdvancedModeTransformConfigChunkingConfig_Sentence,
15 |     AdvancedModeTransformConfigSegmentationConfig_Page,
16 |     PipelineCreate,
17 | )
18 | 
19 | 
20 | def main():
21 |     """
22 |     Create a new Llama Cloud index with the given embedding configuration.
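23 | 
24 |     Flow: run the Textual embedding setup app to collect an embedding
25 |     config, build an advanced transform config (sentence chunking with page
26 |     segmentation), upsert the "notebooklm_pipeline" pipeline, and append
27 |     LLAMACLOUD_PIPELINE_ID to .env.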
28 |     """
29 |     load_dotenv()
30 |     client = create_llamacloud_client()
31 | 
32 |     app = EmbeddingSetupApp()
33 |     embedding_config = app.run()
34 | 
35 |     if embedding_config:
36 |         segm_config = AdvancedModeTransformConfigSegmentationConfig_Page(mode="page")
37 |         chunk_config = AdvancedModeTransformConfigChunkingConfig_Sentence(
38 |             chunk_size=1024,
39 |             chunk_overlap=200,
40 |             separator="",
41 |             paragraph_separator="\n\n\n",
42 |             mode="sentence",
43 |         )
44 | 
45 |         transform_config = PipelineTransformConfig_Advanced(
46 |             segmentation_config=segm_config,
47 |             chunking_config=chunk_config,
48 |             mode="advanced",
49 |         )
50 | 
51 |         pipeline_request = PipelineCreate(
52 |             name="notebooklm_pipeline",
53 |             embedding_config=embedding_config,
54 |             transform_config=transform_config,
55 |         )
56 | 
57 |         pipeline = asyncio.run(
58 |             client.pipelines.upsert_pipeline(request=pipeline_request)
59 |         )
60 | 
61 |         with open(".env", "a") as f:
62 |             f.write(f'\nLLAMACLOUD_PIPELINE_ID="{pipeline.id}"')
63 | 
64 |         return 0
65 |     else:
66 |         print("No embedding configuration provided")
67 |         return 1
68 | 
69 | 
70 | if __name__ == "__main__":
71 |     # Propagate the return value as the process exit status.
72 |     sys.exit(main())
73 | 
--------------------------------------------------------------------------------
/tools/cli/screens/embedding_providers/openai.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from textual.app import ComposeResult
 3 | from textual.widgets import Input
 4 | 
 5 | from llama_index.embeddings.openai import OpenAIEmbedding
 6 | from llama_cloud import PipelineCreateEmbeddingConfig_OpenaiEmbedding
 7 | 
 8 | from ..base import ConfigurationScreen
 9 | 
10 | 
11 | class OpenAIEmbeddingScreen(ConfigurationScreen):
12 |     """Configuration screen for OpenAI embeddings."""
13 | 
14 |     def get_title(self) -> str:
15 |         return "OpenAI Embedding Configuration"
16 | 
17 |     def get_form_elements(self) -> list[ComposeResult]:
18 |         return [
19 |             Input(
20 |                 placeholder="API Key (optional if OPENAI_API_KEY set)",
21 |                 id="api_key",
22 |                 password=True,
23 |                 classes="form-control",
24 |             ),
25 |             Input(
26 |                 placeholder="Model",
27 |                 id="model",
28 |                 classes="form-control",
29 |             ),
30 |         ]
31 | 
32 |     def process_submission(self) -> None:
33 |         """Handle form submission by creating OpenAI embedding configuration."""
34 |         api_key = self.query_one("#api_key", Input).value or os.getenv("OPENAI_API_KEY")
35 |         model = self.query_one("#model", Input).value
36 | 
37 |         if not api_key:
38 |             self.notify(
39 |                 "No API key provided and OPENAI_API_KEY not set", severity="error"
40 |             )
41 |             return
42 |         if not model:
43 |             self.notify("Model name is required", severity="error")
44 |             return
45 | 
46 |         try:
47 |             embed_model = OpenAIEmbedding(model=model, api_key=api_key)
48 |             embedding_config = PipelineCreateEmbeddingConfig_OpenaiEmbedding(
49 |                 type="OPENAI_EMBEDDING",
50 |                 component=embed_model,
51 |             )
52 |             self.app.config = embedding_config
53 |             self.app.handle_completion(self.app.config)
54 |         except Exception as e:
55 |             self.notify(f"Error: {e}", severity="error")
56 | 
--------------------------------------------------------------------------------
/tools/cli/screens/embedding_providers/huggingface.py:
--------------------------------------------------------------------------------
 1 | from textual.app import ComposeResult
 2 | from textual.widgets import Input
 3 | 
 4 | from llama_index.embeddings.huggingface_api import HuggingFaceInferenceAPIEmbedding
 5 | from llama_cloud import PipelineCreateEmbeddingConfig_HuggingfaceApiEmbedding
 6 | 
 7 | from ..base import ConfigurationScreen
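 8 | # Note: HuggingFaceInferenceAPIEmbedding calls the hosted Inference API; the
 9 | # token entered below is assumed to have read access to the chosen model.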
10 | 
11 | 
12 | class HuggingFaceEmbeddingScreen(ConfigurationScreen):
13 |     """Configuration screen for HuggingFace embeddings."""
14 | 
15 |     def get_title(self) -> str:
16 |         return "HuggingFace Embedding Configuration"
17 | 
18 |     def get_form_elements(self) -> list[ComposeResult]:
19 |         return [
20 |             Input(
21 |                 placeholder="HuggingFace API Token",
22 |                 password=True,
23 |                 id="api_key",
24 |                 classes="form-control",
25 |             ),
26 |             Input(
27 |                 placeholder="Model name",
28 |                 id="model",
29 |                 classes="form-control",
30 |             ),
31 |         ]
32 | 
33 |     def process_submission(self) -> None:
34 |         """Handle form submission by creating HuggingFace embedding configuration."""
35 |         api_key = self.query_one("#api_key", Input).value
36 |         model = self.query_one("#model", Input).value
37 | 
38 |         if not api_key:
39 |             self.notify("HuggingFace API Token is required", severity="error")
40 |             return
41 | 
42 |         if not model:
43 |             self.notify("Model name is required", severity="error")
44 |             return
45 | 
46 |         try:
47 |             embed_model = HuggingFaceInferenceAPIEmbedding(
48 |                 token=api_key, model_name=model
49 |             )
50 |             embedding_config = PipelineCreateEmbeddingConfig_HuggingfaceApiEmbedding(
51 |                 type="HUGGINGFACE_API_EMBEDDING",
52 |                 component=embed_model,
53 |             )
54 | 
55 |             self.app.config = embedding_config
56 |             self.app.handle_completion(self.app.config)
57 |         except Exception as e:
58 |             self.notify(f"Error: {e}", severity="error")
59 | 
--------------------------------------------------------------------------------
/tools/cli/screens/embedding_provider.py:
--------------------------------------------------------------------------------
 1 | from textual import on
 2 | from textual.widgets import Select
 3 | 
 4 | from .base import BaseScreen
 5 | from .embedding_providers import (
 6 |     OpenAIEmbeddingScreen,
 7 |     BedrockEmbeddingScreen,
 8 |     AzureEmbeddingScreen,
 9 |     GeminiEmbeddingScreen,
10 |     CohereEmbeddingScreen,
11 |     HuggingFaceEmbeddingScreen,
12 | )
13 | 
14 | 
15 | class ProviderSelectScreen(BaseScreen):
16 |     """Screen for selecting embedding provider."""
17 | 
18 |     def get_title(self) -> str:
19 |         return "Select an embedding provider"
20 | 
21 |     def get_form_elements(self) -> list:
22 |         return [
23 |             Select(
24 |                 options=[
25 |                     ("OpenAI", "OpenAI"),
26 |                     ("Cohere", "Cohere"),
27 |                     ("Bedrock", "Bedrock"),
28 |                     ("HuggingFace", "HuggingFace"),
29 |                     ("Azure", "Azure"),
30 |                     ("Gemini", "Gemini"),
31 |                 ],
32 |                 prompt="Please select an embedding provider",
33 |                 classes="form-control",
34 |                 id="provider_select",
35 |             )
36 |         ]
37 | 
38 |     @on(Select.Changed, "#provider_select")
39 |     def handle_selection(self, event: Select.Changed) -> None:
40 |         from ..embedding_app import EmbeddingSetupApp
41 | 
42 |         app = self.app
43 |         if isinstance(app, EmbeddingSetupApp):
44 |             app.config.provider = event.value
45 |             self.handle_next()
46 | 
47 |     def handle_next(self) -> None:
48 |         from ..embedding_app import EmbeddingSetupApp
49 | 
50 |         app = self.app
51 |         if isinstance(app, EmbeddingSetupApp):
52 |             provider_screens = {
53 |                 "OpenAI": OpenAIEmbeddingScreen,
54 |                 "Bedrock": BedrockEmbeddingScreen,
55 |                 "Azure": AzureEmbeddingScreen,
56 |                 "Gemini": GeminiEmbeddingScreen,
57 |                 "Cohere": CohereEmbeddingScreen,
58 |                 "HuggingFace": HuggingFaceEmbeddingScreen,
59 |             }
60 |             screen_class = provider_screens.get(app.config.provider)
61 |             if screen_class:
62 |                 app.push_screen(screen_class())
63 | 
--------------------------------------------------------------------------------
/data/test/md_sample.md:
--------------------------------------------------------------------------------
 1 | # Project Status Report
 2 | 
 3 | This document provides an overview of our current project status and key metrics.
 4 | 
 5 | ## Introduction
 6 | 
 7 | Our development team has been working on multiple initiatives this quarter. The following sections outline our progress and upcoming milestones.
 8 | 
 9 | ## Team Performance Metrics
10 | 
11 | The table below shows our team's performance across different projects:
12 | 
13 | | Project Name       | Status      | Completion % | Assigned Developer | Due Date   |
14 | | ------------------ | ----------- | ------------ | ------------------ | ---------- |
15 | | User Dashboard     | In Progress | 75%          | Alice Johnson      | 2025-07-15 |
16 | | API Integration    | Completed   | 100%         | Bob Smith          | 2025-06-30 |
17 | | Mobile App         | Planning    | 25%          | Carol Davis        | 2025-08-20 |
18 | | Database Migration | In Progress | 60%          | David Wilson       | 2025-07-10 |
19 | | Security Audit     | Not Started | 0%           | Eve Brown          | 2025-08-01 |
20 | 
21 | | Project Name       | Status      | Completion % | Assigned Developer | Due Date   |
22 | | ------------------ | ----------- | ------------ | ------------------ | ---------- |
23 | | User Dashboard     | In Progress | 75%          | Alice Johnson      | 2025-07-15 |
24 | | API Integration    | Completed   | 100%         | Bob Smith          | 2025-06-30 |
25 | | Mobile App         | Planning    | 25%          | Carol Davis        | 2025-08-20 |
26 | | Database Migration | In Progress | 60%          | David Wilson       | 2025-07-10 |
27 | | Security Audit     | Not Started | 0%           | Eve Brown          | 2025-08-01 |
28 | 
29 | ## Key Observations
30 | 
31 | Based on the data above, we can see that:
32 | 
33 | - The API Integration project was completed on schedule
34 | - The User Dashboard is progressing well and should meet its deadline
35 | - The Database Migration needs attention to stay on track
36 | - We need to begin the Security Audit soon to meet the August deadline
37 | 
38 | ## Next Steps
39 | 
40 | 1. **Prioritize Database Migration** - Assign additional resources if needed
41 | 2. **Begin Security Audit planning** - Schedule kickoff meeting with Eve
42 | 3. **Continue monitoring** User Dashboard progress
43 | 4. **Celebrate** the successful API Integration completion
44 | 
45 | ## Conclusion
46 | 
47 | Overall, our team is performing well with most projects on track. The key focus areas for the coming weeks are the database migration and security audit preparation.
48 | 49 | --- 50 | 51 | _Last updated: July 6, 2025_ 52 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | default_language_version: 3 | python: python3 4 | 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v4.5.0 8 | hooks: 9 | - id: check-merge-conflict 10 | - id: check-symlinks 11 | - id: check-toml 12 | - id: check-yaml 13 | - id: detect-private-key 14 | - id: end-of-file-fixer 15 | - id: mixed-line-ending 16 | - id: trailing-whitespace 17 | 18 | - repo: https://github.com/charliermarsh/ruff-pre-commit 19 | rev: v0.11.8 20 | hooks: 21 | - id: ruff 22 | args: [--exit-non-zero-on-fix, --fix] 23 | - id: ruff-format 24 | exclude: ".*poetry.lock|.*_static|.*uv.lock|.*ipynb|.*docs.*" 25 | 26 | - repo: https://github.com/pre-commit/mirrors-mypy 27 | rev: v1.0.1 28 | hooks: 29 | - id: mypy 30 | additional_dependencies: 31 | [ 32 | "types-requests", 33 | "types-Deprecated", 34 | "types-redis", 35 | "types-setuptools", 36 | "types-PyYAML", 37 | "types-protobuf==4.24.0.4", 38 | ] 39 | args: 40 | [ 41 | --namespace-packages, 42 | --explicit-package-bases, 43 | --disallow-untyped-defs, 44 | --ignore-missing-imports, 45 | --python-version=3.9, 46 | ] 47 | entry: bash -c "export MYPYPATH=src/notebookllama" 48 | 49 | - repo: https://github.com/psf/black-pre-commit-mirror 50 | rev: 23.10.1 51 | hooks: 52 | - id: black-jupyter 53 | name: black-docs-py 54 | alias: black 55 | files: ^(README.md|CONTRIBUTING.md) 56 | # Using PEP 8's line length in docs prevents excess left/right scrolling 57 | args: [--line-length=79] 58 | 59 | - repo: https://github.com/adamchainz/blacken-docs 60 | rev: 1.16.0 61 | hooks: 62 | - id: blacken-docs 63 | name: black-docs-text 64 | alias: black 65 | types_or: [rst, markdown, tex] 66 | additional_dependencies: [black==23.10.1] 67 | # Using PEP 8's line length in docs prevents excess left/right scrolling 68 | args: [--line-length=79] 69 | 70 | - repo: https://github.com/pre-commit/mirrors-prettier 71 | rev: v3.0.3 72 | hooks: 73 | - id: prettier 74 | 75 | - repo: https://github.com/srstevenson/nb-clean 76 | rev: 3.1.0 77 | hooks: 78 | - id: nb-clean 79 | args: [--preserve-cell-outputs, --remove-empty-cells] 80 | 81 | - repo: https://github.com/pappasam/toml-sort 82 | rev: v0.23.1 83 | hooks: 84 | - id: toml-sort-fix 85 | -------------------------------------------------------------------------------- /tools/cli/screens/embedding_providers/bedrock.py: -------------------------------------------------------------------------------- 1 | from textual.app import ComposeResult 2 | from textual.widgets import Input, Select 3 | 4 | from llama_index.embeddings.bedrock import BedrockEmbedding 5 | from llama_cloud import PipelineCreateEmbeddingConfig_BedrockEmbedding 6 | 7 | from ..base import ConfigurationScreen 8 | 9 | 10 | class BedrockEmbeddingScreen(ConfigurationScreen): 11 | """Configuration screen for Bedrock embeddings.""" 12 | 13 | def get_title(self) -> str: 14 | return "Bedrock Embedding Configuration" 15 | 16 | def get_form_elements(self) -> list[ComposeResult]: 17 | model_options = [] 18 | try: 19 | supported_models = BedrockEmbedding.list_supported_models() 20 | model_options = [ 21 | (f"{provider.title()}: {model_id.split('.')[-1]}", model_id) 22 | for provider, models in supported_models.items() 23 | for model_id in models 24 | ] 25 | except Exception as e: 26 | self.notify( 27 | f"Could not fetch 
Bedrock models: {e}", severity="error", timeout=10 28 | ) 29 | 30 | return [ 31 | Select( 32 | options=model_options, 33 | prompt="Select Bedrock Model", 34 | id="model", 35 | classes="form-control", 36 | ), 37 | Input( 38 | placeholder="Region (e.g., us-east-1)", 39 | id="region", 40 | classes="form-control", 41 | ), 42 | Input( 43 | placeholder="Access Key ID (Optional)", 44 | id="access_key_id", 45 | classes="form-control", 46 | ), 47 | Input( 48 | placeholder="Secret Access Key (Optional)", 49 | password=True, 50 | id="secret_access_key", 51 | classes="form-control", 52 | ), 53 | ] 54 | 55 | def process_submission(self) -> None: 56 | model = self.query_one("#model", Select).value 57 | region = self.query_one("#region", Input).value 58 | access_key_id = self.query_one("#access_key_id", Input).value 59 | secret_access_key = self.query_one("#secret_access_key", Input).value 60 | 61 | if not all([model, region]): 62 | self.notify("All fields are required.", severity="error") 63 | return 64 | 65 | try: 66 | embed_model = BedrockEmbedding( 67 | model_name=model, 68 | region_name=region, 69 | aws_access_key_id=access_key_id, 70 | aws_secret_access_key=secret_access_key, 71 | ) 72 | embedding_config = PipelineCreateEmbeddingConfig_BedrockEmbedding( 73 | type="BEDROCK_EMBEDDING", component=embed_model 74 | ) 75 | 76 | self.app.config = embedding_config 77 | self.app.handle_completion(self.app.config) 78 | except Exception as e: 79 | self.notify(f"Error: {e}", severity="error") 80 | -------------------------------------------------------------------------------- /tests/test_sql_engine.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import pytest 3 | import pandas as pd 4 | import os 5 | from dotenv import load_dotenv 6 | 7 | from src.notebookllama.instrumentation import OtelTracesSqlEngine 8 | from sqlalchemy import text 9 | 10 | ENV = load_dotenv() 11 | 12 | 13 | def is_port_open(host: str, port: int, timeout: float = 2.0) -> bool: 14 | """Check if a TCP port is open on a given host.""" 15 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: 16 | sock.settimeout(timeout) 17 | result = sock.connect_ex((host, port)) 18 | return result == 0 19 | 20 | 21 | @pytest.fixture() 22 | def otel_data() -> pd.DataFrame: 23 | return pd.DataFrame( 24 | { 25 | "trace_id": ["abc123", "abc123", "def456"], 26 | "span_id": ["span1", "span2", "span3"], 27 | "parent_span_id": [None, "span1", "span2"], 28 | "operation_name": [ 29 | "ServiceA.handle_request", 30 | "ServiceA.query_db", 31 | "ServiceB.send_email", 32 | ], 33 | "start_time": [1750618321000000, 1750618321000100, 1750618321000200], 34 | "duration": [150, 300, 500], 35 | "status_code": ["OK", "OK", "ERROR"], 36 | "service_name": ["service-a", "service-a", "service-b"], 37 | } 38 | ) 39 | 40 | 41 | @pytest.mark.skipif( 42 | condition=not is_port_open(host="localhost", port=5432) and not ENV, 43 | reason="Either Postgres is currently unavailable or you did not set any env variables in a .env file", 44 | ) 45 | def test_engine(otel_data: pd.DataFrame) -> None: 46 | engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}" 47 | sql_engine = OtelTracesSqlEngine(engine_url=engine_url, table_name="test") 48 | res = sql_engine.execute(text("DROP TABLE IF EXISTS test;")) 49 | res.close() 50 | sql_engine._to_sql(dataframe=otel_data) 51 | res1 = sql_engine.execute( 52 | text( 53 | "SELECT span_id, operation_name, duration FROM test 
WHERE status_code = 'ERROR'" 54 | ) 55 | ) 56 | res2 = sql_engine.execute( 57 | text( 58 | "SELECT service_name, AVG(duration) AS avg_duration FROM test GROUP BY service_name;" 59 | ) 60 | ) 61 | res1_data = res1.fetchall() 62 | res2_data = res2.fetchall() 63 | 64 | # Compare just the values 65 | assert len(res1_data) == 1 66 | assert res1_data[0].span_id == "span3" 67 | assert res1_data[0].operation_name == "ServiceB.send_email" 68 | assert res1_data[0].duration == 500 69 | 70 | assert len(res2_data) == 2 71 | # Sort by service_name for consistent comparison 72 | res2_sorted = sorted(res2_data, key=lambda x: x.service_name) 73 | assert res2_sorted[0].service_name == "service-a" 74 | assert res2_sorted[0].avg_duration == 225.0 75 | assert res2_sorted[1].service_name == "service-b" 76 | assert res2_sorted[1].avg_duration == 500.0 77 | assert isinstance(sql_engine.to_pandas(), pd.DataFrame) 78 | res3 = sql_engine.execute( 79 | text( 80 | "SELECT service_name, AVG(duration) AS avg_duration FROM test GROUP BY service_name;" 81 | ), 82 | return_pandas=True, 83 | ) 84 | assert isinstance(res3, pd.DataFrame) 85 | -------------------------------------------------------------------------------- /tests/test_workflow.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import json 3 | 4 | from pydantic import ValidationError 5 | from src.notebookllama.workflow import ( 6 | NotebookLMWorkflow, 7 | MindMapCreationEvent, 8 | NotebookOutputEvent, 9 | ) 10 | from src.notebookllama.models import Notebook 11 | from workflows import Workflow 12 | 13 | 14 | @pytest.fixture 15 | def notebook_to_process() -> Notebook: 16 | return Notebook( 17 | summary="""The Human Brain: 18 | The human brain is a complex organ responsible for thought, memory, emotion, and coordination. It contains about 86 billion neurons and operates through electrical and chemical signals. Divided into major parts like the cerebrum, cerebellum, and brainstem, it controls everything from basic survival functions to advanced reasoning. Despite its size, it consumes around 20% of the body’s energy. 
Neuroscience continues to explore its mysteries, including consciousness and neuroplasticity—its ability to adapt and reorganize.""", 19 | questions=[ 20 | "How many neurons are in the human brain?", 21 | "What are the main parts of the human brain?", 22 | "What percentage of the body's energy does the brain use?", 23 | "What is neuroplasticity?", 24 | "What functions is the human brain responsible for?", 25 | ], 26 | answers=[ 27 | "About 86 billion neurons.", 28 | "The cerebrum, cerebellum, and brainstem.", 29 | "Around 20%.", 30 | "The brain's ability to adapt and reorganize itself.", 31 | "Thought, memory, emotion, and coordination.", 32 | ], 33 | highlights=[ 34 | "The human brain has about 86 billion neurons.", 35 | "It controls thought, memory, emotion, and coordination.", 36 | "Major brain parts include the cerebrum, cerebellum, and brainstem.", 37 | "The brain uses approximately 20% of the body's energy.", 38 | "Neuroplasticity allows the brain to adapt and reorganize.", 39 | ], 40 | ) 41 | 42 | 43 | def test_init() -> None: 44 | wf = NotebookLMWorkflow(disable_validation=True) 45 | assert isinstance(wf, Workflow) 46 | assert list(wf._get_steps().keys()) == [ 47 | "_done", 48 | "extract_file_data", 49 | "generate_mind_map", 50 | ] 51 | 52 | 53 | def test_mind_map_event_ser(notebook_to_process: Notebook) -> None: 54 | json_str = notebook_to_process.model_dump_json() 55 | json_rep = json.loads(json_str) 56 | try: 57 | model = MindMapCreationEvent( 58 | md_content="Hello", 59 | **json_rep, 60 | ) 61 | except ValidationError: 62 | model = None 63 | assert isinstance(model, MindMapCreationEvent) 64 | assert model.md_content == "Hello" 65 | assert model.model_dump(exclude="md_content") == notebook_to_process.model_dump() 66 | try: 67 | model1 = NotebookOutputEvent( 68 | mind_map="map.html", 69 | **model.model_dump( 70 | include={ 71 | "summary", 72 | "highlights", 73 | "questions", 74 | "answers", 75 | "md_content", 76 | } 77 | ), 78 | ) 79 | except ValidationError: 80 | model1 = None 81 | assert isinstance(model1, NotebookOutputEvent) 82 | assert model1.mind_map == "map.html" 83 | assert model1.model_dump(exclude="mind_map") == model.model_dump() 84 | -------------------------------------------------------------------------------- /tests/test_document_management.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | import socket 4 | from dotenv import load_dotenv 5 | from typing import List 6 | 7 | from src.notebookllama.documents import DocumentManager, ManagedDocument 8 | from sqlalchemy import text, Table 9 | 10 | ENV = load_dotenv() 11 | 12 | 13 | def is_port_open(host: str, port: int, timeout: float = 2.0) -> bool: 14 | """Check if a TCP port is open on a given host.""" 15 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: 16 | sock.settimeout(timeout) 17 | result = sock.connect_ex((host, port)) 18 | return result == 0 19 | 20 | 21 | @pytest.fixture 22 | def documents() -> List[ManagedDocument]: 23 | return [ 24 | ManagedDocument( 25 | document_name="Project Plan", 26 | content="This is the full content of the project plan document.", 27 | summary="A summary of the project plan.", 28 | q_and_a="Q: What is the goal? 
A: To deliver the project.", 29 | mindmap="Project -> Tasks -> Timeline", 30 | bullet_points="• Define scope\n• Assign tasks\n• Set deadlines", 31 | ), 32 | ManagedDocument( 33 | document_name="Meeting Notes", 34 | content="Notes from the weekly team meeting.", 35 | summary="Summary of meeting discussions.", 36 | q_and_a="Q: Who attended? A: All team members.", 37 | mindmap="Meeting -> Topics -> Decisions", 38 | bullet_points="• Discussed progress\n• Identified blockers\n• Planned next steps", 39 | ), 40 | ManagedDocument( 41 | document_name="Research Article", 42 | content="Content of the research article goes here.", 43 | summary="Key findings from the research.", 44 | q_and_a="Q: What was discovered? A: New insights into the topic.", 45 | mindmap="Research -> Methods -> Results", 46 | bullet_points="• Literature review\n• Data analysis\n• Conclusions", 47 | ), 48 | ManagedDocument( 49 | document_name="User Guide", 50 | content="Instructions for using the application.", 51 | summary="Overview of user guide contents.", 52 | q_and_a="Q: How to start? A: Follow the setup instructions.", 53 | mindmap="Guide -> Sections -> Steps", 54 | bullet_points="• Installation\n• Configuration\n• Usage tips", 55 | ), 56 | ] 57 | 58 | 59 | @pytest.mark.skipif( 60 | condition=not is_port_open(host="localhost", port=5432) or not ENV, 61 | reason="Either Postgres is currently unavailable or you did not set any env variables in a .env file", 62 | ) 63 | def test_document_manager(documents: List[ManagedDocument]) -> None: 64 | engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}" 65 | manager = DocumentManager(engine_url=engine_url, table_name="test_documents") 66 | assert not manager.table 67 | manager.connection.execute(text("DROP TABLE IF EXISTS test_documents;")) 68 | manager.connection.commit() 69 | manager._create_table() 70 | assert isinstance(manager.table, Table) 71 | manager.put_documents(documents=documents) 72 | names = manager.get_names() 73 | assert names == [doc.document_name for doc in documents] 74 | docs = manager.get_documents() 75 | assert docs == documents 76 | docs1 = manager.get_documents(names=["Project Plan", "Meeting Notes"]) 77 | assert len(docs1) == 2 78 | -------------------------------------------------------------------------------- /src/notebookllama/workflow.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from workflows import Workflow, step, Context 4 | from workflows.events import StartEvent, StopEvent, Event 5 | from workflows.resource import Resource 6 | from llama_index.tools.mcp import BasicMCPClient 7 | from typing import Annotated, List, Union 8 | 9 | MCP_CLIENT = BasicMCPClient(command_or_url="http://localhost:8000/mcp", timeout=120) 10 | 11 | 12 | class FileInputEvent(StartEvent): 13 | file: str 14 | 15 | 16 | class NotebookOutputEvent(StopEvent): 17 | mind_map: str 18 | md_content: str 19 | summary: str 20 | highlights: List[str] 21 | questions: List[str] 22 | answers: List[str] 23 | 24 | 25 | class MindMapCreationEvent(Event): 26 | summary: str 27 | highlights: List[str] 28 | questions: List[str] 29 | answers: List[str] 30 | md_content: str 31 | 32 | 33 | def get_mcp_client(*args, **kwargs) -> BasicMCPClient: 34 | return MCP_CLIENT 35 | 36 | 37 | class NotebookLMWorkflow(Workflow): 38 | @step 39 | async def extract_file_data( 40 | self, 41 | ev: FileInputEvent, 42 | mcp_client: Annotated[BasicMCPClient, Resource(get_mcp_client)],
ctx: Context, 44 | ) -> Union[MindMapCreationEvent, NotebookOutputEvent]: 45 | ctx.write_event_to_stream( 46 | ev=ev, 47 | ) 48 | result = await mcp_client.call_tool( 49 | tool_name="process_file_tool", arguments={"filename": ev.file} 50 | ) 51 | split_result = result.content[0].text.split("\n%separator%\n") 52 | json_data = split_result[0] 53 | md_text = split_result[1] 54 | if json_data == "Sorry, your file could not be processed.": 55 | return NotebookOutputEvent( 56 | mind_map="Unprocessable file, sorry😭", 57 | md_content="", 58 | summary="", 59 | highlights=[], 60 | questions=[], 61 | answers=[], 62 | ) 63 | json_rep = json.loads(json_data) 64 | return MindMapCreationEvent( 65 | md_content=md_text, 66 | **json_rep, 67 | ) 68 | 69 | @step 70 | async def generate_mind_map( 71 | self, 72 | ev: MindMapCreationEvent, 73 | mcp_client: Annotated[BasicMCPClient, Resource(get_mcp_client)], 74 | ctx: Context, 75 | ) -> NotebookOutputEvent: 76 | ctx.write_event_to_stream( 77 | ev=ev, 78 | ) 79 | result = await mcp_client.call_tool( 80 | tool_name="get_mind_map_tool", 81 | arguments={"summary": ev.summary, "highlights": ev.highlights}, 82 | ) 83 | if result is not None: 84 | return NotebookOutputEvent( 85 | mind_map=result.content[0].text, 86 | **ev.model_dump( 87 | include={ 88 | "summary", 89 | "highlights", 90 | "questions", 91 | "answers", 92 | "md_content", 93 | } 94 | ), 95 | ) 96 | return NotebookOutputEvent( 97 | mind_map="Sorry, mind map creation failed😭", 98 | **ev.model_dump( 99 | include={ 100 | "summary", 101 | "highlights", 102 | "questions", 103 | "answers", 104 | "md_content", 105 | } 106 | ), 107 | ) 108 | -------------------------------------------------------------------------------- /src/notebookllama/mindmap.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import os 3 | import warnings 4 | import json 5 | from pydantic import BaseModel, Field, model_validator 6 | from typing_extensions import Self 7 | from typing import List, Union 8 | 9 | from pyvis.network import Network 10 | from llama_index.core.llms import ChatMessage 11 | from llama_index.llms.openai import OpenAIResponses 12 | 13 | 14 | class Node(BaseModel): 15 | id: str 16 | content: str 17 | 18 | 19 | class Edge(BaseModel): 20 | from_id: str 21 | to_id: str 22 | 23 | 24 | class MindMap(BaseModel): 25 | nodes: List[Node] = Field( 26 | description="List of nodes in the mind map, each represented as a Node object with an 'id' and concise 'content' (no more than 5 words).", 27 | examples=[ 28 | [ 29 | Node(id="A", content="Fall of the Roman Empire"), 30 | Node(id="B", content="476 AD"), 31 | Node(id="C", content="Barbarian invasions"), 32 | ], 33 | [ 34 | Node(id="A", content="Auxin is released"), 35 | Node(id="B", content="Travels to the roots"), 36 | Node(id="C", content="Root cells grow"), 37 | ], 38 | ], 39 | ) 40 | edges: List[Edge] = Field( 41 | description="The edges connecting the nodes of the mind map, as a list of Edge objects with from_id and to_id fields representing the source and target node IDs.", 42 | examples=[ 43 | [ 44 | Edge(from_id="A", to_id="B"), 45 | Edge(from_id="A", to_id="C"), 46 | Edge(from_id="B", to_id="C"), 47 | ], 48 | [ 49 | Edge(from_id="C", to_id="A"), 50 | Edge(from_id="B", to_id="C"), 51 | Edge(from_id="A", to_id="B"), 52 | ], 53 | ], 54 | ) 55 | 56 | @model_validator(mode="after") 57 | def validate_mind_map(self) -> Self: 58 | all_nodes = [el.id for el in self.nodes] 59 | all_edges = [el.from_id for el in self.edges] + 
[el.to_id for el in self.edges] 60 | if not set(all_edges).issubset(set(all_nodes)): 61 | raise ValueError( 62 | "There are non-existing nodes listed as source or target in the edges" 63 | ) 64 | return self 65 | 66 | 67 | class MindMapCreationFailedWarning(Warning): 68 | """A warning returned if the mind map creation failed""" 69 | 70 | 71 | if os.getenv("OPENAI_API_KEY", None): 72 | LLM = OpenAIResponses(model="gpt-4.1", api_key=os.getenv("OPENAI_API_KEY")) 73 | LLM_STRUCT = LLM.as_structured_llm(MindMap) 74 | 75 | 76 | async def get_mind_map(summary: str, highlights: List[str]) -> Union[str, None]: 77 | try: 78 | keypoints = "\n- ".join(highlights) 79 | messages = [ 80 | ChatMessage( 81 | role="user", 82 | content=f"This is the summary for my document: {summary}\n\nAnd these are the key points:\n- {keypoints}", 83 | ) 84 | ] 85 | response = await LLM_STRUCT.achat(messages=messages) 86 | response_json = json.loads(response.message.content) 87 | net = Network(directed=True, height="750px", width="100%") 88 | net.set_options(""" 89 | var options = { 90 | "physics": { 91 | "enabled": false 92 | } 93 | } 94 | """) 95 | nodes = response_json["nodes"] 96 | edges = response_json["edges"] 97 | for node in nodes: 98 | net.add_node(n_id=node["id"], label=node["content"]) 99 | for edge in edges: 100 | net.add_edge(source=edge["from_id"], to=edge["to_id"]) 101 | name = str(uuid.uuid4()) 102 | net.save_graph(name + ".html") 103 | return name + ".html" 104 | except Exception as e: 105 | warnings.warn( 106 | message=f"An error occurred during the creation of the mind map: {e}", 107 | category=MindMapCreationFailedWarning, 108 | ) 109 | return None 110 | -------------------------------------------------------------------------------- /src/notebookllama/pages/1_Document_Management_UI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import streamlit as st 3 | import streamlit.components.v1 as components 4 | from dotenv import load_dotenv 5 | from typing import List, Optional 6 | 7 | from documents import DocumentManager, ManagedDocument 8 | 9 | # Load environment variables 10 | load_dotenv() 11 | 12 | # Initialize the document manager 13 | engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}" 14 | document_manager = DocumentManager(engine_url=engine_url) 15 | 16 | 17 | def fetch_documents(names: Optional[List[str]]) -> List[ManagedDocument]: 18 | """Retrieve documents from the database""" 19 | return document_manager.get_documents(names=names) 20 | 21 | 22 | def fetch_document_names() -> List[str]: 23 | return document_manager.get_names() 24 | 25 | 26 | def display_document(document: ManagedDocument) -> None: 27 | """Display a single document in an expandable format""" 28 | with st.expander(f"📄 {document.document_name}"): 29 | # Summary section 30 | st.markdown("## Summary") 31 | st.markdown(document.summary) 32 | 33 | # Bullet Points section 34 | st.markdown(document.bullet_points) 35 | 36 | # FAQ section (nested expander) 37 | with st.expander("FAQ"): 38 | st.markdown(document.q_and_a) 39 | 40 | # Mind Map section 41 | if document.mindmap: 42 | st.markdown("## Mind Map") 43 | components.html(document.mindmap, height=800, scrolling=True) 44 | 45 | 46 | def main(): 47 | # Configure the Streamlit page 48 | st.set_page_config( 49 | page_title="NotebookLlaMa - Document Management", 50 | page_icon="📚", 51 | layout="wide", 52 | menu_items={ 53 | "Get Help":
"https://github.com/run-llama/notebooklm-clone/discussions/categories/general", 54 | "Report a bug": "https://github.com/run-llama/notebooklm-clone/issues/", 55 | "About": "An OSS alternative to NotebookLM that runs with the power of a flully Llama!", 56 | }, 57 | ) 58 | st.sidebar.header("Document Management📚") 59 | st.sidebar.info("To switch to the other pages, select it from above!🔺") 60 | st.markdown("---") 61 | st.markdown("## NotebookLlaMa - Document Management📚") 62 | 63 | # Slider for number of documents 64 | names = st.multiselect( 65 | options=fetch_document_names(), 66 | default=None, 67 | label="Select the Documents you want to display", 68 | ) 69 | 70 | # Button to load documents 71 | if st.button("Load Documents", type="primary"): 72 | with st.spinner("Loading documents..."): 73 | try: 74 | documents = fetch_documents(names) 75 | 76 | if documents: 77 | st.success(f"Successfully loaded {len(documents)} document(s)") 78 | st.session_state.documents = documents 79 | else: 80 | st.warning("No documents found in the database.") 81 | st.session_state.documents = [] 82 | 83 | except Exception as e: 84 | st.error(f"Error loading documents: {str(e)}") 85 | st.session_state.documents = [] 86 | 87 | # Display documents if they exist in session state 88 | if "documents" in st.session_state and st.session_state.documents: 89 | st.markdown("## Documents") 90 | 91 | # Display each document 92 | for i, document in enumerate(st.session_state.documents): 93 | display_document(document) 94 | 95 | # Add some spacing between documents 96 | if i < len(st.session_state.documents) - 1: 97 | st.markdown("---") 98 | 99 | elif "documents" in st.session_state: 100 | st.info( 101 | "No documents to display. Try adjusting the limit and clicking 'Load Documents'." 102 | ) 103 | 104 | else: 105 | st.info("Click 'Load Documents' to view your processed documents.") 106 | 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NotebookLlaMa🦙 2 | 3 | ## A fluffy and open-source alternative to NotebookLM! 4 | 5 | https://github.com/user-attachments/assets/7e9cca45-8a4c-4dfa-98d2-2cef147422f2 6 | 7 |

8 | A fully open-source alternative to NotebookLM, backed by LlamaCloud. 9 |

10 | 11 |


18 | 19 | ### Prerequisites 20 | 21 | This project uses `uv` to manage dependencies. Before you begin, make sure you have `uv` installed. 22 | 23 | On macOS and Linux: 24 | 25 | ```bash 26 | curl -LsSf https://astral.sh/uv/install.sh | sh 27 | ``` 28 | 29 | On Windows: 30 | 31 | ```powershell 32 | powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex" 33 | ``` 34 | 35 | For more install options, see `uv`'s [official documentation](https://docs.astral.sh/uv/getting-started/installation/). 36 | 37 | --- 38 | 39 | ### Get it up and running! 40 | 41 | **1. Clone the Repository** 42 | 43 | ```bash 44 | git clone https://github.com/run-llama/notebookllama 45 | cd notebookllama/ 46 | ``` 47 | 48 | **2. Install Dependencies** 49 | 50 | ```bash 51 | uv sync 52 | ``` 53 | 54 | **3. Configure API Keys** 55 | 56 | First, create your `.env` file by renaming the example file: 57 | 58 | ```bash 59 | mv .env.example .env 60 | ``` 61 | 62 | Next, open the `.env` file and add your API keys: 63 | 64 | - `OPENAI_API_KEY`: find it [on OpenAI Platform](https://platform.openai.com/api-keys) 65 | - `ELEVENLABS_API_KEY`: find it [on ElevenLabs Settings](https://elevenlabs.io/app/settings/api-keys) 66 | - `LLAMACLOUD_API_KEY`: find it [on LlamaCloud Dashboard](https://cloud.llamaindex.ai?utm_source=demo&utm_medium=notebookLM) 67 | 68 | > **🌍 Regional Support**: LlamaCloud operates in multiple regions. If you're using a European region, configure it in your `.env` file: 69 | > 70 | > - For **North America**: This is the default region - no configuration necessary. 71 | > - For **Europe (EU)**: Uncomment and set `LLAMACLOUD_REGION="eu"` 72 | 73 | **4. Activate the Virtual Environment** 74 | 75 | (on macOS/Unix): 76 | 77 | ```bash 78 | source .venv/bin/activate 79 | ``` 80 | 81 | (on Windows): 82 | 83 | ```powershell 84 | .\.venv\Scripts\activate 85 | ``` 86 | 87 | **5. Create LlamaCloud Agent & Pipeline** 88 | 89 | You will now execute two scripts to configure your backend agents and pipelines. 90 | 91 | First, create the data extraction agent: 92 | 93 | ```bash 94 | uv run tools/create_llama_extract_agent.py 95 | ``` 96 | 97 | Next, run the interactive setup wizard to configure your index pipeline. 98 | 99 | > **⚡ Quick Start (Default OpenAI):** 100 | > For the fastest setup, select **"With Default Settings"** when prompted. This will automatically create a pipeline using OpenAI's `text-embedding-3-small` embedding model. 101 | 102 | > **🧠 Advanced (Custom Embedding Models):** 103 | > To use a different embedding model, select **"With Custom Settings"** and follow the on-screen instructions. 104 | 105 | Run the wizard with the following command: 106 | 107 | ```bash
 108 | uv run tools/create_llama_cloud_index.py 109 | ``` 110 | 111 | **6. Launch Backend Services** 112 | 113 | This command will start the required Postgres and Jaeger containers. 114 | 115 | ```bash 116 | docker compose up -d 117 | ``` 118 | 119 | **7. Run the Application** 120 | 121 | First, run the **MCP** server: 122 | 123 | ```bash 124 | uv run src/notebookllama/server.py 125 | ``` 126 | 127 | Then, in a **new terminal window**, launch the Streamlit app: 128 | 129 | ```bash 130 | streamlit run src/notebookllama/Home.py 131 | ``` 132 | 133 | > [!IMPORTANT] 134 | > 135 | > _You might need to install `ffmpeg` if you do not have it installed already_ 136 | 137 | And start exploring the app at `http://localhost:8501/`.
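If you want to confirm which LlamaCloud endpoint your `.env` resolves to (for example, after setting `LLAMACLOUD_REGION`), the one-liner below is a quick sanity check; it is only a sketch, and it assumes you run it from the repository root with the virtual environment active, relying on the `get_llamacloud_base_url()` helper from `src/notebookllama/utils.py`:

```bash
uv run python -c "from src.notebookllama.utils import get_llamacloud_base_url; print(get_llamacloud_base_url())"
```

It should print `https://api.cloud.llamaindex.ai` for the default (North America) region and `https://api.cloud.eu.llamaindex.ai` when `LLAMACLOUD_REGION="eu"` is set.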
138 | 139 | --- 140 | 141 | ### Contributing 142 | 143 | Contribute to this project following the [guidelines](./CONTRIBUTING.md). 144 | 145 | ### License 146 | 147 | This project is provided under an [MIT License](./LICENSE). 148 | -------------------------------------------------------------------------------- /src/notebookllama/documents.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from sqlalchemy import ( 3 | Table, 4 | MetaData, 5 | Column, 6 | Text, 7 | Integer, 8 | create_engine, 9 | Engine, 10 | Connection, 11 | insert, 12 | select, 13 | ) 14 | from typing import Optional, List, cast, Union 15 | 16 | 17 | def apply_string_correction(string: str) -> str: 18 | return string.replace("''", "'").replace('""', '"') 19 | 20 | 21 | @dataclass 22 | class ManagedDocument: 23 | document_name: str 24 | content: str 25 | summary: str 26 | q_and_a: str 27 | mindmap: str 28 | bullet_points: str 29 | 30 | 31 | class DocumentManager: 32 | def __init__( 33 | self, 34 | engine: Optional[Engine] = None, 35 | engine_url: Optional[str] = None, 36 | table_name: Optional[str] = None, 37 | table_metadata: Optional[MetaData] = None, 38 | ): 39 | self.table_name: str = table_name or "documents" 40 | self._table: Optional[Table] = None 41 | self._connection: Optional[Connection] = None 42 | self.metadata: MetaData = cast(MetaData, table_metadata or MetaData()) 43 | if engine or engine_url: 44 | self._engine: Union[Engine, str] = cast( 45 | Union[Engine, str], engine or engine_url 46 | ) 47 | else: 48 | raise ValueError("One of engine or engine_url must be set") 49 | 50 | @property 51 | def connection(self) -> Connection: 52 | if not self._connection: 53 | self._connect() 54 | return cast(Connection, self._connection) 55 | 56 | @property 57 | def table(self) -> Table: 58 | if self._table is None: 59 | self._create_table() 60 | return cast(Table, self._table) 61 | 62 | def _connect(self) -> None: 63 | # move network calls outside of constructor 64 | if isinstance(self._engine, str): 65 | self._engine = create_engine(self._engine) 66 | self._connection = self._engine.connect() 67 | 68 | def _create_table(self) -> None: 69 | self._table = Table( 70 | self.table_name, 71 | self.metadata, 72 | Column("id", Integer, primary_key=True, autoincrement=True), 73 | Column("document_name", Text), 74 | Column("content", Text), 75 | Column("summary", Text), 76 | Column("q_and_a", Text), 77 | Column("mindmap", Text), 78 | Column("bullet_points", Text), 79 | ) 80 | self._table.create(self.connection, checkfirst=True) 81 | 82 | def put_documents(self, documents: List[ManagedDocument]) -> None: 83 | for document in documents: 84 | stmt = insert(self.table).values( 85 | document_name=document.document_name, 86 | content=document.content, 87 | summary=document.summary, 88 | q_and_a=document.q_and_a, 89 | mindmap=document.mindmap, 90 | bullet_points=document.bullet_points, 91 | ) 92 | self.connection.execute(stmt) 93 | self.connection.commit() 94 | 95 | def get_documents(self, names: Optional[List[str]] = None) -> List[ManagedDocument]: 96 | if self.table is None: 97 | self._create_table() 98 | if not names: 99 | stmt = select(self.table).order_by(self.table.c.id) 100 | else: 101 | stmt = ( 102 | select(self.table) 103 | .where(self.table.c.document_name.in_(names)) 104 | .order_by(self.table.c.id) 105 | ) 106 | result = self.connection.execute(stmt) 107 | rows = result.fetchall() 108 | documents = [] 109 | for row in rows: 110 |
documents.append( 111 | ManagedDocument( 112 | document_name=row.document_name, 113 | content=row.content, 114 | summary=row.summary, 115 | q_and_a=row.q_and_a, 116 | mindmap=row.mindmap, 117 | bullet_points=row.bullet_points, 118 | ) 119 | ) 120 | return documents 121 | 122 | def get_names(self) -> List[str]: 123 | if self.table is None: 124 | self._create_table() 125 | stmt = select(self.table) 126 | result = self.connection.execute(stmt) 127 | rows = result.fetchall() 128 | return [row.document_name for row in rows] 129 | 130 | def disconnect(self) -> None: 131 | if not self._connection: 132 | raise ValueError("Engine was never connected!") 133 | if isinstance(self._engine, str): 134 | pass 135 | else: 136 | self._engine.dispose(close=True) 137 | -------------------------------------------------------------------------------- /src/notebookllama/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Dict, Any 3 | from llama_cloud.client import AsyncLlamaCloud 4 | from llama_cloud_services import LlamaExtract, LlamaParse 5 | from llama_index.indices.managed.llama_cloud import LlamaCloudIndex 6 | 7 | # LlamaCloud regional endpoints 8 | LLAMACLOUD_REGIONS = { 9 | "default": "https://api.cloud.llamaindex.ai", # North America (default) 10 | "eu": "https://api.cloud.eu.llamaindex.ai", # Europe 11 | } 12 | 13 | 14 | class LlamaCloudConfigError(Exception): 15 | """Raised when LlamaCloud configuration is invalid.""" 16 | 17 | pass 18 | 19 | 20 | def get_llamacloud_base_url() -> Optional[str]: 21 | """ 22 | Get the appropriate LlamaCloud base URL based on region configuration. 23 | 24 | Defaults to North America region and returns the North America endpoint URL 25 | when no region is specified. 26 | 27 | Returns: 28 | str: The base URL for LlamaCloud API 29 | 30 | Raises: 31 | LlamaCloudConfigError: If an invalid region is specified 32 | """ 33 | # Direct base URL override takes precedence 34 | base_url = os.getenv("LLAMACLOUD_BASE_URL") 35 | if base_url: 36 | return base_url 37 | 38 | region = os.getenv("LLAMACLOUD_REGION", "default").lower().strip() 39 | 40 | if region not in LLAMACLOUD_REGIONS: 41 | valid_regions = ", ".join(LLAMACLOUD_REGIONS.keys()) 42 | raise LlamaCloudConfigError( 43 | f"Invalid LLAMACLOUD_REGION '{region}'. Supported regions: {valid_regions}" 44 | ) 45 | 46 | return LLAMACLOUD_REGIONS[region] 47 | 48 | 49 | def get_llamacloud_config() -> Dict[str, Any]: 50 | """ 51 | Get LlamaCloud configuration including base URL. 52 | 53 | Returns: 54 | dict: Configuration dictionary with token and optional base_url 55 | 56 | Raises: 57 | LlamaCloudConfigError: If API key is missing or region is invalid 58 | """ 59 | token = os.getenv("LLAMACLOUD_API_KEY") 60 | if not token: 61 | raise LlamaCloudConfigError( 62 | "LLAMACLOUD_API_KEY environment variable is required" 63 | ) 64 | 65 | config = {"token": token} 66 | 67 | base_url = get_llamacloud_base_url() 68 | if base_url: 69 | config["base_url"] = base_url 70 | 71 | return config 72 | 73 | 74 | def create_llamacloud_client() -> AsyncLlamaCloud: 75 | """ 76 | Create a configured AsyncLlamaCloud client with regional support. 
77 | 78 | Returns: 79 | AsyncLlamaCloud: Configured client instance 80 | 81 | Raises: 82 | LlamaCloudConfigError: If API key is missing or region is invalid 83 | """ 84 | config = get_llamacloud_config() 85 | return AsyncLlamaCloud(**config) 86 | 87 | 88 | def create_llama_extract_client() -> LlamaExtract: 89 | """ 90 | Create a configured LlamaExtract client with regional support. 91 | 92 | Returns: 93 | LlamaExtract: Configured client instance 94 | 95 | Raises: 96 | LlamaCloudConfigError: If API key is missing or region is invalid 97 | """ 98 | api_key = os.getenv("LLAMACLOUD_API_KEY") 99 | if not api_key: 100 | raise LlamaCloudConfigError( 101 | "LLAMACLOUD_API_KEY environment variable is required" 102 | ) 103 | 104 | base_url = get_llamacloud_base_url() 105 | return LlamaExtract(api_key=api_key, base_url=base_url) 106 | 107 | 108 | def create_llama_parse_client(result_type: str = "markdown") -> LlamaParse: 109 | """ 110 | Create a configured LlamaParse client with regional support. 111 | 112 | Args: 113 | result_type: The result type for parsing (default: "markdown") 114 | 115 | Returns: 116 | LlamaParse: Configured client instance 117 | 118 | Raises: 119 | LlamaCloudConfigError: If API key is missing or region is invalid 120 | """ 121 | api_key = os.getenv("LLAMACLOUD_API_KEY") 122 | if not api_key: 123 | raise LlamaCloudConfigError( 124 | "LLAMACLOUD_API_KEY environment variable is required" 125 | ) 126 | 127 | base_url = get_llamacloud_base_url() 128 | return LlamaParse(api_key=api_key, result_type=result_type, base_url=base_url) 129 | 130 | 131 | def create_llamacloud_index(api_key: str, pipeline_id: str) -> LlamaCloudIndex: 132 | """ 133 | Create a configured LlamaCloudIndex with regional support. 134 | 135 | Args: 136 | api_key: The API key for authentication 137 | pipeline_id: The pipeline ID to use 138 | 139 | Returns: 140 | LlamaCloudIndex: Configured index instance 141 | 142 | Raises: 143 | LlamaCloudConfigError: If API key or pipeline_id is missing, or region is invalid 144 | """ 145 | if not api_key: 146 | raise LlamaCloudConfigError("API key is required") 147 | 148 | if not pipeline_id: 149 | raise LlamaCloudConfigError("Pipeline ID is required") 150 | 151 | base_url = get_llamacloud_base_url() 152 | return LlamaCloudIndex(api_key=api_key, pipeline_id=pipeline_id, base_url=base_url) 153 | -------------------------------------------------------------------------------- /src/notebookllama/pages/4_Observability_Dashboard.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pandas as pd 4 | import streamlit as st 5 | import plotly.express as px 6 | import plotly.graph_objects as go 7 | 8 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 9 | 10 | from dotenv import load_dotenv 11 | from instrumentation import OtelTracesSqlEngine 12 | from sqlalchemy import text 13 | 14 | load_dotenv() 15 | 16 | sql_engine = OtelTracesSqlEngine( 17 | engine_url=f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}", 18 | table_name="agent_traces", 19 | service_name="agent.traces", 20 | ) 21 | 22 | 23 | def display_sql() -> pd.DataFrame: 24 | query = """CREATE TABLE IF NOT EXISTS agent_traces ( 25 | trace_id TEXT NOT NULL, 26 | span_id TEXT NOT NULL, 27 | parent_span_id TEXT NULL, 28 | operation_name TEXT NOT NULL, 29 | start_time BIGINT NOT NULL, 30 | duration INTEGER NOT NULL, 31 | status_code TEXT NOT NULL, 32 | service_name TEXT 
NOT NULL 33 | );""" 34 | sql_engine.execute(text(query)) 35 | return sql_engine.to_pandas() 36 | 37 | 38 | def filter_traces(sql_query: str): 39 | df = sql_engine.execute(text(sql_query), return_pandas=True) 40 | return df 41 | 42 | 43 | def create_latency_chart(df: pd.DataFrame): 44 | """Create a line chart showing latency (duration) per span, in trace order""" 45 | if df.empty: 46 | st.warning("No data available for latency chart") 47 | return 48 | 49 | # Plot spans in arrival order using a progressive index 50 | df_chart = df.copy() 51 | progressive_count = list(range(0, len(df_chart["start_time"]))) 52 | df_chart["progressive_count"] = progressive_count 53 | 54 | fig = px.line( 55 | df_chart, 56 | x="progressive_count", 57 | y="duration", 58 | title="Latency Overview", 59 | labels={"duration": "Duration (ns)", "progressive_count": "Span sequence"}, 60 | hover_data=["operation_name", "status_code"], 61 | ) 62 | 63 | fig.update_layout( 64 | xaxis_title="Span sequence", yaxis_title="Duration (nanoseconds)", hovermode="x unified" 65 | ) 66 | 67 | st.plotly_chart(fig) 68 | 69 | 70 | def create_status_pie_chart(df: pd.DataFrame): 71 | """Create a pie chart showing status code distribution""" 72 | if df.empty: 73 | st.warning("No data available for status code chart") 74 | return 75 | 76 | # Count status codes 77 | status_counts = df["status_code"].value_counts() 78 | 79 | # Map common status codes to more readable labels 80 | status_labels = { 81 | "OK": "OK", 82 | "ERROR": "ERROR", 83 | "UNSET": "UNSET", 84 | "200": "OK (200)", 85 | "500": "ERROR (500)", 86 | "404": "ERROR (404)", 87 | } 88 | 89 | # Create labels and values for the pie chart 90 | labels = [status_labels.get(status, status) for status in status_counts.index] 91 | values = status_counts.values 92 | 93 | # Define colors 94 | colors = [] 95 | for status in status_counts.index: 96 | if status in ["OK", "200"]: 97 | colors.append("#28a745")  # Green for OK 98 | elif status in ["ERROR", "500", "404"]: 99 | colors.append("#dc3545")  # Red for ERROR 100 | else: 101 | colors.append("#6c757d")  # Gray for others 102 | 103 | fig = go.Figure( 104 | data=[go.Pie(labels=labels, values=values, hole=0.3, marker_colors=colors)] 105 | ) 106 | 107 | fig.update_layout( 108 | title="Status Code Distribution", 109 | annotations=[dict(text="Status", x=0.5, y=0.5, font_size=20, showarrow=False)], 110 | ) 111 | 112 | st.plotly_chart(fig) 113 | 114 | 115 | # Streamlit UI 116 | st.set_page_config(page_title="NotebookLlaMa - Observability Dashboard", page_icon="🔍") 117 | 118 | st.sidebar.header("Observability Dashboard🔍") 119 | st.sidebar.info("To switch to the other pages, select them from above!🔺") 120 | st.markdown("---") 121 | st.markdown("## NotebookLlaMa - Observability Dashboard🔍") 122 | 123 | # Get the data 124 | df_data = display_sql() 125 | 126 | # Charts section 127 | st.markdown("## 📊 Analytics Overview") 128 | 129 | if not df_data.empty: 130 | col1, col2 = st.columns(2) 131 | 132 | with col1: 133 | create_latency_chart(df_data) 134 | 135 | with col2: 136 | create_status_pie_chart(df_data) 137 | else: 138 | st.info("No trace data available yet.
Charts will appear once data is collected.") 139 | 140 | st.markdown("---") 141 | 142 | # SQL Query section 143 | st.markdown("### SQL Query") 144 | sql_query = st.text_input(label="") 145 | 146 | st.markdown("## Traces Table") 147 | dataframe = st.dataframe(data=df_data) 148 | 149 | if st.button("Run SQL query", type="primary"): 150 | if sql_query.strip(): 151 | try: 152 | filtered_df = filter_traces(sql_query=sql_query) 153 | st.markdown("### Query Results") 154 | dataframe = st.dataframe(data=filtered_df) 155 | 156 | # Update charts with filtered data 157 | if not filtered_df.empty: 158 | st.markdown("### Updated Charts") 159 | col1, col2 = st.columns(2) 160 | 161 | with col1: 162 | create_latency_chart(filtered_df) 163 | 164 | with col2: 165 | create_status_pie_chart(filtered_df) 166 | except Exception as e: 167 | st.error(f"Error executing query: {str(e)}") 168 | else: 169 | st.warning("Please enter a SQL query") 170 | -------------------------------------------------------------------------------- /src/notebookllama/instrumentation.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import time 3 | import pandas as pd 4 | 5 | from sqlalchemy import Engine, create_engine, Connection, Result 6 | from typing import Optional, Dict, Any, List, Literal, Union, cast 7 | 8 | 9 | class OtelTracesSqlEngine: 10 | def __init__( 11 | self, 12 | engine: Optional[Engine] = None, 13 | engine_url: Optional[str] = None, 14 | table_name: Optional[str] = None, 15 | service_name: Optional[str] = None, 16 | ): 17 | self.service_name: str = service_name or "service" 18 | self.table_name: str = table_name or "otel_traces" 19 | self._connection: Optional[Connection] = None 20 | if engine: 21 | self._engine: Engine = engine 22 | elif engine_url: 23 | self._engine = create_engine(url=engine_url) 24 | else: 25 | raise ValueError("One of engine or engine_url must be set") 26 | 27 | def _connect(self) -> None: 28 | self._connection = self._engine.connect() 29 | 30 | def _export( 31 | self, 32 | start_time: Optional[int] = None, 33 | end_time: Optional[int] = None, 34 | limit: Optional[int] = None, 35 | ) -> Dict[str, Any]: 36 | url = "http://localhost:16686/api/traces" 37 | params: Dict[str, Union[str, int]] = { 38 | "service": self.service_name, 39 | "start": start_time 40 | or int(time.time() * 1000000) - (24 * 60 * 60 * 1000000), 41 | "end": end_time or int(time.time() * 1000000), 42 | "limit": limit or 1000, 43 | } 44 | response = requests.get(url, params=params) 45 | 46 | return response.json() 47 | 48 | def _to_pandas(self, data: Dict[str, Any]) -> pd.DataFrame: 49 | rows: List[Dict[str, Any]] = [] 50 | 51 | # Loop over each trace 52 | for trace in data.get("data", []): 53 | trace_id = trace.get("traceID") 54 | service_map = { 55 | pid: proc.get("serviceName") 56 | for pid, proc in trace.get("processes", {}).items() 57 | } 58 | 59 | for span in trace.get("spans", []): 60 | span_id = span.get("spanID") 61 | operation = span.get("operationName") 62 | start = span.get("startTime") 63 | duration = span.get("duration") 64 | process_id = span.get("processID") 65 | service = service_map.get(process_id, "") 66 | status = next( 67 | ( 68 | tag.get("value") 69 | for tag in span.get("tags", []) 70 | if tag.get("key") == "otel.status_code" 71 | ), 72 | "", 73 | ) 74 | parent_span_id = None 75 | if span.get("references"): 76 | parent_span_id = span["references"][0].get("spanID") 77 | 78 | rows.append( 79 | { 80 | "trace_id":
trace_id, 81 | "span_id": span_id, 82 | "parent_span_id": parent_span_id, 83 | "operation_name": operation, 84 | "start_time": start, 85 | "duration": duration, 86 | "status_code": status, 87 | "service_name": service, 88 | } 89 | ) 90 | 91 | return pd.DataFrame(rows) 92 | 93 | def _to_sql( 94 | self, 95 | dataframe: pd.DataFrame, 96 | if_exists_policy: Optional[Literal["fail", "replace", "append"]] = None, 97 | ) -> None: 98 | if not self._connection: 99 | self._connect() 100 | dataframe.to_sql( 101 | name=self.table_name, 102 | con=self._connection, 103 | if_exists=if_exists_policy or "append", 104 | ) 105 | 106 | def to_sql_database( 107 | self, 108 | start_time: Optional[int] = None, 109 | end_time: Optional[int] = None, 110 | limit: Optional[int] = None, 111 | if_exists_policy: Optional[Literal["fail", "replace", "append"]] = None, 112 | ) -> None: 113 | data = self._export(start_time=start_time, end_time=end_time, limit=limit) 114 | df = self._to_pandas(data=data) 115 | self._to_sql(dataframe=df, if_exists_policy=if_exists_policy) 116 | 117 | def execute( 118 | self, 119 | statement: Any, 120 | parameters: Optional[Any] = None, 121 | execution_options: Optional[Any] = None, 122 | return_pandas: bool = False, 123 | ) -> Union[Result, pd.DataFrame]: 124 | if not self._connection: 125 | self._connect() 126 | if not return_pandas: 127 | self._connection = cast(Connection, self._connection) 128 | return self._connection.execute( 129 | statement=statement, 130 | parameters=parameters, 131 | execution_options=execution_options, 132 | ) 133 | return pd.read_sql(sql=statement, con=self._connection) 134 | 135 | def to_pandas( 136 | self, 137 | ) -> pd.DataFrame: 138 | if not self._connection: 139 | self._connect() 140 | return pd.read_sql_table(table_name=self.table_name, con=self._connection) 141 | 142 | def disconnect(self) -> None: 143 | if not self._connection: 144 | raise ValueError("Engine was never connected!") 145 | self._engine.dispose(close=True) 146 | -------------------------------------------------------------------------------- /src/notebookllama/processing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from dotenv import load_dotenv 4 | import pandas as pd 5 | import json 6 | import warnings 7 | from datetime import datetime 8 | 9 | from mrkdwn_analysis import MarkdownAnalyzer 10 | from mrkdwn_analysis.markdown_analyzer import InlineParser, MarkdownParser 11 | from llama_cloud_services.extract import SourceText 12 | from typing_extensions import override 13 | from typing import List, Tuple, Union, Optional, Dict 14 | 15 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 16 | 17 | from notebookllama.utils import ( 18 | create_llamacloud_client, 19 | create_llama_extract_client, 20 | create_llama_parse_client, 21 | ) 22 | 23 | load_dotenv() 24 | 25 | if ( 26 | os.getenv("LLAMACLOUD_API_KEY", None) 27 | and os.getenv("EXTRACT_AGENT_ID", None) 28 | and os.getenv("LLAMACLOUD_PIPELINE_ID", None) 29 | ): 30 | CLIENT = create_llamacloud_client() 31 | llama_extract_client = create_llama_extract_client() 32 | EXTRACT_AGENT = llama_extract_client.get_agent(id=os.getenv("EXTRACT_AGENT_ID")) 33 | PARSER = create_llama_parse_client(result_type="markdown") 34 | PIPELINE_ID = os.getenv("LLAMACLOUD_PIPELINE_ID") 35 | 36 | 37 | class MarkdownTextAnalyzer(MarkdownAnalyzer): 38 | @override 39 | def __init__(self, text: str): 40 | self.text = text 41 | parser = MarkdownParser(self.text) 42 | 
self.tokens = parser.parse() 43 | self.references = parser.references 44 | self.footnotes = parser.footnotes 45 | self.inline_parser = InlineParser( 46 | references=self.references, footnotes=self.footnotes 47 | ) 48 | self._parse_inline_tokens() 49 | 50 | 51 | def md_table_to_pd_dataframe(md_table: Dict[str, list]) -> Optional[pd.DataFrame]: 52 | try: 53 | df = pd.DataFrame() 54 | for i in range(len(md_table["header"])): 55 | ls = [row[i] for row in md_table["rows"]] 56 | df[md_table["header"][i]] = ls 57 | return df 58 | except Exception as e: 59 | warnings.warn(f"Skipping table as an error occurred: {e}") 60 | return None 61 | 62 | 63 | def rename_and_remove_past_images(path: str = "static/") -> List[str]: 64 | renamed = [] 65 | if os.path.exists(path) and len(os.listdir(path)) >= 0: 66 | for image_file in os.listdir(path): 67 | image_path = os.path.join(path, image_file) 68 | if os.path.isfile(image_path) and "_at_" not in image_path: 69 | with open(image_path, "rb") as img: 70 | bts = img.read() 71 | new_path = ( 72 | os.path.splitext(image_path)[0].replace("_current", "") 73 | + f"_at_{datetime.now().strftime('%Y_%d_%m_%H_%M_%S_%f')[:-3]}.png" 74 | ) 75 | with open( 76 | new_path, 77 | "wb", 78 | ) as img_tw: 79 | img_tw.write(bts) 80 | renamed.append(new_path) 81 | os.remove(image_path) 82 | return renamed 83 | 84 | 85 | def rename_and_remove_current_images(images: List[str]) -> List[str]: 86 | imgs = [] 87 | for image in images: 88 | with open(image, "rb") as rb: 89 | bts = rb.read() 90 | with open(os.path.splitext(image)[0] + "_current.png", "wb") as wb: 91 | wb.write(bts) 92 | imgs.append(os.path.splitext(image)[0] + "_current.png") 93 | os.remove(image) 94 | return imgs 95 | 96 | 97 | async def parse_file( 98 | file_path: str, with_images: bool = False, with_tables: bool = False 99 | ) -> Union[Tuple[Optional[str], Optional[List[str]], Optional[List[pd.DataFrame]]]]: 100 | images: Optional[List[str]] = None 101 | text: Optional[str] = None 102 | tables: Optional[List[pd.DataFrame]] = None 103 | document = await PARSER.aparse(file_path=file_path) 104 | md_content = await document.aget_markdown_documents() 105 | if len(md_content) != 0: 106 | text = "\n\n---\n\n".join([doc.text for doc in md_content]) 107 | if with_images: 108 | rename_and_remove_past_images() 109 | imgs = await document.asave_all_images("static/") 110 | images = rename_and_remove_current_images(imgs) 111 | if with_tables: 112 | if text is not None: 113 | analyzer = MarkdownTextAnalyzer(text) 114 | md_tables = analyzer.identify_tables()["Table"] 115 | tables = [] 116 | for md_table in md_tables: 117 | table = md_table_to_pd_dataframe(md_table=md_table) 118 | if table is not None: 119 | tables.append(table) 120 | os.makedirs("data/extracted_tables/", exist_ok=True) 121 | table.to_csv( 122 | f"data/extracted_tables/table_{datetime.now().strftime('%Y_%d_%m_%H_%M_%S_%f')[:-3]}.csv", 123 | index=False, 124 | ) 125 | return text, images, tables 126 | 127 | 128 | async def process_file( 129 | filename: str, 130 | ) -> Union[Tuple[str, None], Tuple[None, None], Tuple[str, str]]: 131 | with open(filename, "rb") as f: 132 | file = await CLIENT.files.upload_file(upload_file=f) 133 | files = [{"file_id": file.id}] 134 | await CLIENT.pipelines.add_files_to_pipeline_api( 135 | pipeline_id=PIPELINE_ID, request=files 136 | ) 137 | text, _, _ = await parse_file(file_path=filename) 138 | if text is None: 139 | return None, None 140 | extraction_output = await EXTRACT_AGENT.aextract( 141 | files=SourceText(text_content=text, 
filename=file.name) 142 | ) 143 | if extraction_output: 144 | return json.dumps(extraction_output.data, indent=4), text 145 | return None, None 146 | 147 | 148 | async def get_plots_and_tables( 149 | file_path: str, 150 | ) -> Union[Tuple[Optional[List[str]], Optional[List[pd.DataFrame]]]]: 151 | _, images, tables = await parse_file( 152 | file_path=file_path, with_images=True, with_tables=True 153 | ) 154 | return images, tables 155 | -------------------------------------------------------------------------------- /src/notebookllama/pages/3_Interactive_Table_and_Plot_Visualization.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import io 4 | import tempfile as tmp 5 | 6 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 7 | 8 | import asyncio 9 | from processing import get_plots_and_tables 10 | import streamlit as st 11 | from PIL import Image 12 | 13 | 14 | def get_plots_and_tables_sync(file: io.BytesIO): 15 | fl = tmp.NamedTemporaryFile(suffix=".pdf", delete=False, delete_on_close=False) 16 | with open(fl.name, "wb") as f: 17 | f.write(file.getvalue()) 18 | 19 | # Try to get existing event loop, if not create a new one 20 | try: 21 | loop = asyncio.get_event_loop() 22 | if loop.is_closed(): 23 | raise RuntimeError("Event loop is closed") 24 | except RuntimeError: 25 | # No event loop exists or it's closed, create a new one 26 | loop = asyncio.new_event_loop() 27 | asyncio.set_event_loop(loop) 28 | 29 | # Run the async function in the current event loop 30 | if loop.is_running(): 31 | # If loop is already running (e.g., in Jupyter/Streamlit), 32 | # we need to use a different approach 33 | import concurrent.futures 34 | 35 | with concurrent.futures.ThreadPoolExecutor() as executor: 36 | future = executor.submit( 37 | asyncio.run, get_plots_and_tables(file_path=fl.name) 38 | ) 39 | return future.result() 40 | else: 41 | # If loop is not running, we can use it directly 42 | return loop.run_until_complete(get_plots_and_tables(file_path=fl.name)) 43 | 44 | 45 | # Direct Streamlit execution 46 | # Chat Interface 47 | st.set_page_config(page_title="NotebookLlaMa - Images and Tables", page_icon="📊") 48 | 49 | st.sidebar.header("Images and Tables📊") 50 | st.sidebar.info("To switch to the Home page, select it from above!🔺") 51 | st.markdown("---") 52 | st.markdown("## NotebookLlaMa - Images and Tables📊") 53 | st.markdown("### Upload a PDF file to extract plots and tables") 54 | 55 | # File uploader 56 | uploaded_file = st.file_uploader("Choose a PDF file", type="pdf") 57 | 58 | if uploaded_file is not None: 59 | # Process the file 60 | with st.spinner("Processing PDF... This may take a moment."): 61 | try: 62 | # Convert uploaded file to BytesIO 63 | file_bytes = io.BytesIO(uploaded_file.getvalue()) 64 | 65 | # Extract plots and tables 66 | image_paths, dataframes = get_plots_and_tables_sync(file_bytes) 67 | 68 | # Display results summary 69 | st.success("✅ Processing complete!") 70 | st.info( 71 | f"Found {len(image_paths)} images and {len(dataframes)} tables.\nImages can be seen in the `static/` folder, tables in the `data/extracted_tables` folder." 
72 | ) 73 | 74 | # Create tabs for better organization 75 | tab1, tab2 = st.tabs(["📊 Tables", "📈 Plots/Images"]) 76 | 77 | with tab1: 78 | st.header("Extracted Tables") 79 | if dataframes: 80 | for i, df in enumerate(dataframes): 81 | st.subheader(f"Table {i + 1}") 82 | 83 | # Display table with options 84 | col1, col2 = st.columns([3, 1]) 85 | 86 | with col1: 87 | st.dataframe(df, use_container_width=True) 88 | 89 | with col2: 90 | st.write("**Table Info:**") 91 | st.write(f"Rows: {len(df)}") 92 | st.write(f"Columns: {len(df.columns)}") 93 | 94 | # Show additional table statistics 95 | with st.expander(f"Table {i + 1} Details"): 96 | st.write("**Column Information:**") 97 | st.write(df.dtypes) 98 | st.write("**First few rows:**") 99 | st.write(df.head()) 100 | if len(df) > 5: 101 | st.write("**Summary statistics:**") 102 | st.write(df.describe(include="all")) 103 | 104 | st.divider() 105 | else: 106 | st.warning("No tables found in the PDF") 107 | 108 | with tab2: 109 | st.header("Extracted Plots and Images") 110 | if image_paths: 111 | for i, img_path in enumerate(image_paths): 112 | st.subheader(f"Image {i + 1}") 113 | 114 | try: 115 | # Display image 116 | if os.path.exists(img_path): 117 | image = Image.open(img_path) 118 | 119 | # Create columns for image and info 120 | col1, col2 = st.columns([3, 1]) 121 | 122 | with col1: 123 | st.image( 124 | image, 125 | caption=f"Image {i + 1}", 126 | use_container_width=True, 127 | ) 128 | 129 | with col2: 130 | st.write("**Image Info:**") 131 | st.write(f"Size: {image.size}") 132 | st.write(f"Format: {image.format}") 133 | st.write(f"Mode: {image.mode}") 134 | 135 | st.divider() 136 | else: 137 | st.error(f"Image file not found: {img_path}") 138 | 139 | except Exception as e: 140 | st.error(f"Error loading image {i + 1}: {str(e)}") 141 | else: 142 | st.warning("No images found in the PDF") 143 | 144 | except Exception as e: 145 | st.error(f"Error processing file: {str(e)}") 146 | st.exception(e) # Show full traceback in development 147 | -------------------------------------------------------------------------------- /tests/test_audio.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from elevenlabs import AsyncElevenLabs 4 | from src.notebookllama.audio import ( 5 | PodcastGenerator, 6 | MultiTurnConversation, 7 | PodcastConfig, 8 | AudioGenerationError, 9 | ConversationGenerationError, 10 | PodcastGeneratorError, 11 | ) 12 | from llama_index.core.llms.structured_llm import StructuredLLM 13 | from llama_index.core.llms import MockLLM 14 | from pydantic import BaseModel, ValidationError 15 | 16 | 17 | class MockElevenLabs(AsyncElevenLabs): 18 | def __init__(self, test_api_key: str) -> None: 19 | self.test_api_key = test_api_key 20 | 21 | 22 | class DataModel(BaseModel): 23 | test: str 24 | 25 | 26 | @pytest.fixture() 27 | def correct_structured_llm() -> StructuredLLM: 28 | return MockLLM().as_structured_llm(MultiTurnConversation) 29 | 30 | 31 | @pytest.fixture() 32 | def wrong_structured_llm() -> StructuredLLM: 33 | return MockLLM().as_structured_llm(DataModel) 34 | 35 | 36 | def test_podcast_generator_model( 37 | correct_structured_llm: StructuredLLM, wrong_structured_llm: StructuredLLM 38 | ) -> None: 39 | n = PodcastGenerator( 40 | client=MockElevenLabs(test_api_key="a"), llm=correct_structured_llm 41 | ) 42 | assert isinstance(n.client, AsyncElevenLabs) 43 | assert isinstance(n.llm, StructuredLLM) 44 | assert n.llm.output_cls == MultiTurnConversation 45 | with 
pytest.raises(ValidationError): 46 | PodcastGenerator( 47 | client=MockElevenLabs(test_api_key="a"), llm=wrong_structured_llm 48 | ) 49 | 50 | 51 | # Test Error Classes 52 | def test_podcast_generator_error_hierarchy(): 53 | """Test custom exception hierarchy""" 54 | assert issubclass(AudioGenerationError, PodcastGeneratorError) 55 | assert issubclass(ConversationGenerationError, PodcastGeneratorError) 56 | assert issubclass(PodcastGeneratorError, Exception) 57 | 58 | 59 | def test_custom_exceptions(): 60 | """Test custom exception instantiation""" 61 | base_error = PodcastGeneratorError("Base error") 62 | assert str(base_error) == "Base error" 63 | 64 | audio_error = AudioGenerationError("Audio error") 65 | assert str(audio_error) == "Audio error" 66 | assert isinstance(audio_error, PodcastGeneratorError) 67 | 68 | conversation_error = ConversationGenerationError("Conversation error") 69 | assert str(conversation_error) == "Conversation error" 70 | assert isinstance(conversation_error, PodcastGeneratorError) 71 | 72 | 73 | def test_build_conversation_prompt_basic(correct_structured_llm: StructuredLLM): 74 | """Test basic prompt building functionality""" 75 | generator = PodcastGenerator( 76 | client=MockElevenLabs(test_api_key="test"), llm=correct_structured_llm 77 | ) 78 | 79 | config = PodcastConfig() 80 | transcript = "This is a test transcript about artificial intelligence." 81 | 82 | prompt = generator._build_conversation_prompt(transcript, config) 83 | 84 | assert "conversational podcast conversation" in prompt 85 | assert "CONVERSATION STYLE: conversational" in prompt 86 | assert "TONE: friendly" in prompt 87 | assert "TARGET AUDIENCE: general" in prompt 88 | assert "Speaker 1: host" in prompt 89 | assert "Speaker 2: guest" in prompt 90 | assert transcript in prompt 91 | assert "Balance accessibility with depth" in prompt 92 | 93 | 94 | def test_build_conversation_prompt_with_focus_topics( 95 | correct_structured_llm: StructuredLLM, 96 | ): 97 | """Test prompt building with focus topics""" 98 | generator = PodcastGenerator( 99 | client=MockElevenLabs(test_api_key="test"), llm=correct_structured_llm 100 | ) 101 | 102 | focus_topics = ["Machine Learning", "Ethics in AI", "Future Applications"] 103 | config = PodcastConfig(focus_topics=focus_topics) 104 | transcript = "AI transcript content" 105 | 106 | prompt = generator._build_conversation_prompt(transcript, config) 107 | 108 | assert "FOCUS TOPICS: Make sure to discuss these topics in detail:" in prompt 109 | for topic in focus_topics: 110 | assert f"- {topic}" in prompt 111 | 112 | 113 | def test_build_conversation_prompt_technical_audience( 114 | correct_structured_llm: StructuredLLM, 115 | ): 116 | """Test prompt building for technical audience""" 117 | generator = PodcastGenerator( 118 | client=MockElevenLabs(test_api_key="test"), llm=correct_structured_llm 119 | ) 120 | 121 | config = PodcastConfig(target_audience="technical", style="educational") 122 | transcript = "Technical content about neural networks" 123 | 124 | prompt = generator._build_conversation_prompt(transcript, config) 125 | 126 | assert "educational podcast conversation" in prompt 127 | assert "TARGET AUDIENCE: technical" in prompt 128 | assert "Use technical terminology appropriately" in prompt 129 | 130 | 131 | def test_build_conversation_prompt_beginner_audience( 132 | correct_structured_llm: StructuredLLM, 133 | ): 134 | """Test prompt building for beginner audience""" 135 | generator = PodcastGenerator( 136 | client=MockElevenLabs(test_api_key="test"), 
llm=correct_structured_llm 137 | ) 138 | 139 | config = PodcastConfig(target_audience="beginner", tone="casual") 140 | transcript = "Complex topic simplified" 141 | 142 | prompt = generator._build_conversation_prompt(transcript, config) 143 | 144 | assert "TONE: casual" in prompt 145 | assert "TARGET AUDIENCE: beginner" in prompt 146 | assert "Explain concepts clearly and avoid jargon" in prompt 147 | 148 | 149 | def test_build_conversation_prompt_with_custom_prompt( 150 | correct_structured_llm: StructuredLLM, 151 | ): 152 | """Test prompt building with custom instructions""" 153 | generator = PodcastGenerator( 154 | client=MockElevenLabs(test_api_key="test"), llm=correct_structured_llm 155 | ) 156 | 157 | custom_prompt = ( 158 | "Make sure to include practical examples and keep it under 10 minutes" 159 | ) 160 | config = PodcastConfig(custom_prompt=custom_prompt) 161 | transcript = "Sample content" 162 | 163 | prompt = generator._build_conversation_prompt(transcript, config) 164 | 165 | assert f"ADDITIONAL INSTRUCTIONS: {custom_prompt}" in prompt 166 | 167 | 168 | def test_build_conversation_prompt_custom_speaker_roles( 169 | correct_structured_llm: StructuredLLM, 170 | ): 171 | """Test prompt building with custom speaker roles""" 172 | generator = PodcastGenerator( 173 | client=MockElevenLabs(test_api_key="test"), llm=correct_structured_llm 174 | ) 175 | 176 | config = PodcastConfig( 177 | speaker1_role="professor", speaker2_role="student", style="interview" 178 | ) 179 | transcript = "Educational content" 180 | 181 | prompt = generator._build_conversation_prompt(transcript, config) 182 | 183 | assert "Speaker 1: professor" in prompt 184 | assert "Speaker 2: student" in prompt 185 | assert "interview podcast conversation" in prompt 186 | 187 | 188 | def test_build_conversation_prompt_all_audiences(correct_structured_llm: StructuredLLM): 189 | """Test that all audience types have appropriate instructions""" 190 | generator = PodcastGenerator( 191 | client=MockElevenLabs(test_api_key="test"), llm=correct_structured_llm 192 | ) 193 | 194 | audience_expectations = { 195 | "technical": "Use technical terminology appropriately", 196 | "beginner": "Explain concepts clearly and avoid jargon", 197 | "expert": "Assume advanced knowledge and discuss nuanced aspects", 198 | "business": "Focus on practical applications, ROI, and strategic implications", 199 | "general": "Balance accessibility with depth", 200 | } 201 | 202 | for audience, expected_instruction in audience_expectations.items(): 203 | config = PodcastConfig(target_audience=audience) 204 | prompt = generator._build_conversation_prompt("test content", config) 205 | assert expected_instruction in prompt 206 | 207 | 208 | @pytest.fixture() 209 | def sample_podcast_generator(correct_structured_llm: StructuredLLM) -> PodcastGenerator: 210 | """Fixture providing a configured PodcastGenerator for testing""" 211 | return PodcastGenerator( 212 | client=MockElevenLabs(test_api_key="test"), llm=correct_structured_llm 213 | ) 214 | -------------------------------------------------------------------------------- /src/notebookllama/pages/2_Document_Chat.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import asyncio 3 | import sys 4 | import os 5 | 6 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 7 | 8 | from verifying import verify_claim as sync_verify_claim 9 | from llama_index.tools.mcp import BasicMCPClient 10 | 11 | MCP_CLIENT = 
BasicMCPClient(command_or_url="http://localhost:8000/mcp") 12 | 13 | 14 | async def chat(inpt: str): 15 | result = await MCP_CLIENT.call_tool( 16 | tool_name="query_index_tool", arguments={"question": inpt} 17 | ) 18 | return result.content[0].text 19 | 20 | 21 | def sync_chat(inpt: str): 22 | return asyncio.run(chat(inpt=inpt)) 23 | 24 | 25 | # Chat Interface 26 | st.set_page_config(page_title="NotebookLlaMa - Document Chat", page_icon="🗣") 27 | 28 | st.sidebar.header("Document Chat🗣") 29 | st.sidebar.info("To switch to the Home page, select it from above!🔺") 30 | st.markdown("---") 31 | st.markdown("## NotebookLlaMa - Document Chat🗣") 32 | 33 | # Initialize chat history 34 | if "messages" not in st.session_state: 35 | st.session_state.messages = [] 36 | 37 | # Display chat messages from history on app rerun 38 | for i, message in enumerate(st.session_state.messages): 39 | with st.chat_message(message["role"]): 40 | if message["role"] == "assistant" and "sources" in message: 41 | # Display the main response 42 | st.markdown(message["content"]) 43 | # Add toggle for sources 44 | with st.expander("Sources"): 45 | st.markdown(message["sources"]) 46 | elif message["role"] == "assistant" and "verification" in message: 47 | # Display verification results 48 | st.markdown(message["content"]) 49 | verification = message["verification"] 50 | 51 | # Show verification status with appropriate styling 52 | if verification["is_true"]: 53 | st.success("✅ Claim verified as TRUE") 54 | else: 55 | st.error("❌ Claim verified as FALSE") 56 | 57 | # Show citations if available 58 | if verification["citations"]: 59 | with st.expander("Supporting Citations"): 60 | for i, citation in enumerate(verification["citations"], 1): 61 | st.markdown(f"**Citation {i}:**") 62 | st.markdown(f"*{citation}*") 63 | st.markdown("---") 64 | else: 65 | st.info("No supporting citations found") 66 | else: 67 | st.markdown(message["content"]) 68 | 69 | # React to user input 70 | if prompt := st.chat_input("Ask a question about your document"): 71 | # Display user message in chat message container 72 | st.chat_message("user").markdown(prompt) 73 | # Add user message to chat history 74 | st.session_state.messages.append({"role": "user", "content": prompt}) 75 | 76 | # Get bot response 77 | with st.chat_message("assistant"): 78 | with st.spinner("Thinking..."): 79 | try: 80 | response = sync_chat(prompt) 81 | 82 | # Split response and sources if they exist 83 | # Assuming your response format includes sources somehow 84 | # You might need to modify this based on your actual response format 85 | if "## Sources" in response: 86 | parts = response.split("## Sources", 1) 87 | main_response = parts[0].strip() 88 | sources = "## Sources" + parts[1].strip() 89 | else: 90 | main_response = response 91 | sources = None 92 | 93 | st.markdown(main_response) 94 | 95 | # Add toggle for sources if they exist 96 | if sources: 97 | with st.expander("Sources"): 98 | st.markdown(sources) 99 | # Add to history with sources 100 | st.session_state.messages.append( 101 | { 102 | "role": "assistant", 103 | "content": main_response, 104 | "sources": sources, 105 | } 106 | ) 107 | else: 108 | # Add to history without sources 109 | st.session_state.messages.append( 110 | {"role": "assistant", "content": main_response} 111 | ) 112 | 113 | except Exception as e: 114 | error_msg = f"Error: {str(e)}" 115 | st.markdown(error_msg) 116 | st.session_state.messages.append( 117 | {"role": "assistant", "content": error_msg} 118 | ) 119 | 120 | # Claim Verification Section 
121 | st.markdown("---") 122 | st.markdown("### 🔍 Claim Verification") 123 | 124 | # Check if there are any assistant messages with sources to verify 125 | assistant_messages_with_sources = [ 126 | msg 127 | for msg in st.session_state.messages 128 | if msg["role"] == "assistant" and "sources" in msg 129 | ] 130 | 131 | if assistant_messages_with_sources: 132 | st.markdown("Select a response to verify its claims against the sources:") 133 | 134 | # Create a selectbox with assistant responses 135 | response_options = [] 136 | for i, msg in enumerate(assistant_messages_with_sources): 137 | # Truncate long responses for display 138 | content_preview = ( 139 | msg["content"][:100] + "..." 140 | if len(msg["content"]) > 100 141 | else msg["content"] 142 | ) 143 | response_options.append(f"Response {i + 1}: {content_preview}") 144 | 145 | selected_response_idx = st.selectbox( 146 | "Choose response to verify:", 147 | range(len(response_options)), 148 | format_func=lambda x: response_options[x], 149 | ) 150 | 151 | # Create columns for the verify button 152 | col1, col2 = st.columns([4, 1]) 153 | 154 | with col2: 155 | verify_button = st.button("🔍 Verify Claims", type="primary") 156 | 157 | # Handle claim verification 158 | if verify_button: 159 | selected_message = assistant_messages_with_sources[selected_response_idx] 160 | claim_text = selected_message["content"] 161 | sources_text = selected_message["sources"] 162 | 163 | with st.spinner("Verifying claims against sources..."): 164 | try: 165 | # Call the verify_claim function with claim and sources 166 | is_true, citations = sync_verify_claim(claim_text, sources_text) 167 | 168 | # Create verification result message 169 | verification_result = {"is_true": is_true, "citations": citations} 170 | 171 | # Display verification result 172 | st.markdown("### Verification Result:") 173 | 174 | # Show the claim being verified (truncated for display) 175 | claim_preview = ( 176 | claim_text[:200] + "..." if len(claim_text) > 200 else claim_text 177 | ) 178 | st.markdown(f"**Claim:** {claim_preview}") 179 | 180 | # Show verification status with appropriate styling 181 | if is_true: 182 | st.success("✅ Claims verified as TRUE") 183 | else: 184 | st.error("❌ Claims verified as FALSE") 185 | 186 | # Show citations if available 187 | if citations: 188 | with st.expander("Supporting Citations", expanded=True): 189 | for i, citation in enumerate(citations, 1): 190 | st.markdown(f"**Citation {i}:**") 191 | st.markdown(f"*{citation}*") 192 | st.markdown("---") 193 | else: 194 | st.info("No supporting citations found") 195 | 196 | # Add verification to chat history 197 | verification_message = f"**Claim Verification:** {claim_preview}\n\n" 198 | verification_message += "**Result:** " + ( 199 | "✅ TRUE" if is_true else "❌ FALSE" 200 | ) 201 | 202 | st.session_state.messages.append( 203 | { 204 | "role": "assistant", 205 | "content": verification_message, 206 | "verification": verification_result, 207 | } 208 | ) 209 | 210 | except Exception as e: 211 | error_msg = f"Error verifying claims: {str(e)}" 212 | st.error(error_msg) 213 | st.session_state.messages.append( 214 | {"role": "assistant", "content": error_msg} 215 | ) 216 | 217 | else: 218 | st.info( 219 | "No assistant responses with sources available to verify. Chat with the document first to generate responses with sources." 
220 | ) 221 | 222 | # Optional: Add some helpful information 223 | with st.expander("ℹ️ About Claim Verification"): 224 | st.markdown(""" 225 | **How it works:** 226 | - Select an assistant response that includes sources 227 | - The system will verify the claims made in that response against the provided sources 228 | - You'll receive a TRUE/FALSE result with supporting citations if available 229 | 230 | **What gets verified:** 231 | - The assistant's response content is treated as the claim 232 | - The verification is done against the sources that were provided with that response 233 | - This helps ensure the assistant's answers are backed by the actual document content 234 | 235 | **Note:** 236 | - Only responses that include sources can be verified 237 | - The verification checks if the assistant's claims are supported by the cited sources 238 | """) 239 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from src.notebookllama.models import ( 4 | Notebook, 5 | ) 6 | from src.notebookllama.verifying import ClaimVerification 7 | from src.notebookllama.mindmap import MindMap, Node, Edge 8 | from src.notebookllama.audio import MultiTurnConversation, ConversationTurn 9 | from src.notebookllama.documents import ManagedDocument 10 | from src.notebookllama.audio import ( 11 | PodcastConfig, 12 | VoiceConfig, 13 | AudioQuality, 14 | ) 15 | from pydantic import ValidationError 16 | 17 | 18 | def test_notebook() -> None: 19 | n1 = Notebook( 20 | summary="This is a summary", 21 | questions=[ 22 | "What is the capital of Spain?", 23 | "What is the capital of France?", 24 | "What is the capital of Italy?", 25 | "What is the capital of Portugal?", 26 | "What is the capital of Germany?", 27 | ], 28 | answers=[ 29 | "Madrid", 30 | "Paris", 31 | "Rome", 32 | "Lisbon", 33 | "Berlin", 34 | ], 35 | highlights=["This", "is", "a", "summary"], 36 | ) 37 | assert n1.summary == "This is a summary" 38 | assert n1.questions[0] == "What is the capital of Spain?" 
39 | assert n1.answers[0] == "Madrid" 40 | assert n1.highlights[0] == "This" 41 | # Fewer answers than questions 42 | with pytest.raises(ValidationError): 43 | Notebook( 44 | summary="This is a summary", 45 | questions=[ 46 | "What is the capital of France?", 47 | "What is the capital of Italy?", 48 | "What is the capital of Portugal?", 49 | "What is the capital of Germany?", 50 | ], 51 | answers=[ 52 | "Paris", 53 | "Rome", 54 | "Lisbon", 55 | ], 56 | highlights=["This", "is", "a", "summary"], 57 | ) 58 | # Fewer highlights than required 59 | with pytest.raises(ValidationError): 60 | Notebook( 61 | summary="This is a summary", 62 | questions=[ 63 | "What is the capital of Spain?", 64 | "What is the capital of France?", 65 | "What is the capital of Italy?", 66 | "What is the capital of Portugal?", 67 | "What is the capital of Germany?", 68 | ], 69 | answers=[ 70 | "Madrid", 71 | "Paris", 72 | "Rome", 73 | "Lisbon", 74 | "Berlin", 75 | ], 76 | highlights=["This", "is"], 77 | ) 78 | 79 | 80 | def test_mind_map() -> None: 81 | m1 = MindMap( 82 | nodes=[ 83 | Node(id="A", content="Auxin is released"), 84 | Node(id="B", content="Travels to the roots"), 85 | Node(id="C", content="Root cells grow"), 86 | ], 87 | edges=[ 88 | Edge(from_id="A", to_id="B"), 89 | Edge(from_id="A", to_id="C"), 90 | Edge(from_id="B", to_id="C"), 91 | ], 92 | ) 93 | assert m1.nodes[0].id == "A" 94 | assert m1.nodes[0].content == "Auxin is released" 95 | assert m1.edges[0].from_id == "A" 96 | assert m1.edges[0].to_id == "B" 97 | 98 | with pytest.raises(ValidationError): 99 | MindMap( 100 | nodes=[ 101 | Node(id="A", content="Auxin is released"), 102 | Node(id="B", content="Travels to the roots"), 103 | Node(id="C", content="Root cells grow"), 104 | ], 105 | edges=[ 106 | Edge(from_id="A", to_id="B"), 107 | Edge(from_id="A", to_id="D"), # "D" does not exist 108 | Edge(from_id="B", to_id="C"), 109 | ], 110 | ) 111 | 112 | 113 | def test_multi_turn_conversation() -> None: 114 | turns = [ 115 | ConversationTurn(speaker="speaker1", content="Hello, who are you?"), 116 | ConversationTurn(speaker="speaker2", content="I am very well, how about you?"), 117 | ConversationTurn(speaker="speaker1", content="I am well too, thanks!"), 118 | ] 119 | assert turns[0].speaker == "speaker1" 120 | assert turns[0].content == "Hello, who are you?" 121 | conversation = MultiTurnConversation( 122 | conversation=turns, 123 | ) 124 | assert isinstance(conversation.conversation, list) 125 | assert isinstance(conversation.conversation[0], ConversationTurn) 126 | wrong_turns = [ 127 | ConversationTurn(speaker="speaker1", content="Hello, who are you?"), 128 | ConversationTurn(speaker="speaker2", content="I am very well, how about you?"), 129 | ] 130 | wrong_turns1 = [ 131 | ConversationTurn(speaker="speaker2", content="Hello, who are you?"), 132 | ConversationTurn(speaker="speaker1", content="I am very well, how about you?"), 133 | ConversationTurn(speaker="speaker2", content="I am well too!"), 134 | ] 135 | wrong_turns2 = [ 136 | ConversationTurn(speaker="speaker1", content="Hello, who are you?"), 137 | ConversationTurn(speaker="speaker1", content="How is your life going?"), 138 | ConversationTurn( 139 | speaker="speaker2", 140 | content="What is all this interest in me all of a sudden?!", 141 | ), 142 | ] 143 | wrong_turns3 = [ 144 | ConversationTurn(speaker="speaker1", content="Hello, who are you?"), 145 | ConversationTurn(speaker="speaker2", content="I'm well! 
But..."), 146 | ConversationTurn( 147 | speaker="speaker2", 148 | content="...What is all this interest in me all of a sudden?!", 149 | ), 150 | ] 151 | with pytest.raises(ValidationError): 152 | MultiTurnConversation(conversation=wrong_turns) 153 | with pytest.raises(ValidationError): 154 | MultiTurnConversation(conversation=wrong_turns1) 155 | with pytest.raises(ValidationError): 156 | MultiTurnConversation(conversation=wrong_turns2) 157 | with pytest.raises(ValidationError): 158 | MultiTurnConversation(conversation=wrong_turns3) 159 | 160 | 161 | def test_claim_verification() -> None: 162 | cl1 = ClaimVerification( 163 | claim_is_true=True, supporting_citations=["Support 1", "Support 2"] 164 | ) 165 | assert cl1.claim_is_true 166 | assert cl1.supporting_citations == ["Support 1", "Support 2"] 167 | cl2 = ClaimVerification( 168 | claim_is_true=False, supporting_citations=["Support 1", "Support 2"] 169 | ) 170 | assert cl2.supporting_citations == ["The claim was deemed false."] 171 | cl3 = ClaimVerification( 172 | claim_is_true=False, 173 | ) 174 | assert cl3.supporting_citations is None 175 | with pytest.raises(ValidationError): 176 | ClaimVerification( 177 | claim_is_true=True, 178 | supporting_citations=["Support 1", "Support 2", "Support 3", "Support 4"], 179 | ) 180 | 181 | 182 | def test_managed_documents() -> None: 183 | d1 = ManagedDocument( 184 | document_name="Hello World", 185 | content="This is a test", 186 | summary="Test", 187 | q_and_a="Hello? World.", 188 | mindmap="Hello -> World", 189 | bullet_points=". Hello, . World", 190 | ) 191 | assert d1.document_name == "Hello World" 192 | assert d1.content == "This is a test" 193 | assert d1.summary == "Test" 194 | assert d1.q_and_a == "Hello? World." 195 | assert d1.mindmap == "Hello -> World" 196 | assert d1.bullet_points == ". Hello, . World" 197 | d2 = ManagedDocument( 198 | document_name="Hello World", 199 | content="This is a test", 200 | summary="Test's child", 201 | q_and_a="Hello? World.", 202 | mindmap="Hello -> World", 203 | bullet_points=". Hello, . 
World", 204 | ) 205 | assert d2.summary == "Test's child" 206 | 207 | 208 | # Test Audio Configuration Models 209 | def test_voice_config_defaults(): 210 | """Test VoiceConfig default values""" 211 | config = VoiceConfig() 212 | assert config.speaker1_voice_id == "nPczCjzI2devNBz1zQrb" 213 | assert config.speaker2_voice_id == "Xb7hH8MSUJpSbSDYk0k2" 214 | assert config.model_id == "eleven_turbo_v2_5" 215 | assert config.output_format == "mp3_22050_32" 216 | 217 | 218 | def test_voice_config_custom_values(): 219 | """Test VoiceConfig with custom values""" 220 | config = VoiceConfig( 221 | speaker1_voice_id="custom_voice_1", 222 | speaker2_voice_id="custom_voice_2", 223 | model_id="custom_model", 224 | output_format="wav_44100_16", 225 | ) 226 | assert config.speaker1_voice_id == "custom_voice_1" 227 | assert config.speaker2_voice_id == "custom_voice_2" 228 | assert config.model_id == "custom_model" 229 | assert config.output_format == "wav_44100_16" 230 | 231 | 232 | def test_audio_quality_defaults(): 233 | """Test AudioQuality default values""" 234 | config = AudioQuality() 235 | assert config.bitrate == "320k" 236 | assert config.quality_params == ["-q:a", "0"] 237 | 238 | 239 | def test_audio_quality_custom_values(): 240 | """Test AudioQuality with custom values""" 241 | custom_params = ["-q:a", "2", "-compression_level", "5"] 242 | config = AudioQuality(bitrate="256k", quality_params=custom_params) 243 | assert config.bitrate == "256k" 244 | assert config.quality_params == custom_params 245 | 246 | 247 | def test_podcast_config_defaults(): 248 | """Test that PodcastConfig creates with proper defaults""" 249 | config = PodcastConfig() 250 | 251 | assert config.style == "conversational" 252 | assert config.tone == "friendly" 253 | assert config.focus_topics is None 254 | assert config.target_audience == "general" 255 | assert config.custom_prompt is None 256 | assert config.speaker1_role == "host" 257 | assert config.speaker2_role == "guest" 258 | assert isinstance(config.voice_config, VoiceConfig) 259 | assert isinstance(config.audio_quality, AudioQuality) 260 | 261 | 262 | def test_podcast_config_custom_values(): 263 | """Test that PodcastConfig accepts custom values""" 264 | focus_topics = ["AI Ethics", "Machine Learning", "Future Tech"] 265 | custom_prompt = "Make it engaging and technical" 266 | 267 | config = PodcastConfig( 268 | style="interview", 269 | tone="professional", 270 | focus_topics=focus_topics, 271 | target_audience="expert", 272 | custom_prompt=custom_prompt, 273 | speaker1_role="interviewer", 274 | speaker2_role="technical_expert", 275 | ) 276 | 277 | assert config.style == "interview" 278 | assert config.tone == "professional" 279 | assert config.focus_topics == focus_topics 280 | assert config.target_audience == "expert" 281 | assert config.custom_prompt == custom_prompt 282 | assert config.speaker1_role == "interviewer" 283 | assert config.speaker2_role == "technical_expert" 284 | 285 | 286 | def test_podcast_config_validation(): 287 | """Test that PodcastConfig validates input values""" 288 | # Test invalid style 289 | with pytest.raises(ValidationError): 290 | PodcastConfig(style="invalid_style") 291 | 292 | # Test invalid tone 293 | with pytest.raises(ValidationError): 294 | PodcastConfig(tone="invalid_tone") 295 | 296 | # Test invalid target_audience 297 | with pytest.raises(ValidationError): 298 | PodcastConfig(target_audience="invalid_audience") 299 | 300 | 301 | def test_conversation_turn(): 302 | """Test ConversationTurn model""" 303 | turn = 
ConversationTurn(speaker="speaker1", content="Hello world") 304 | assert turn.speaker == "speaker1" 305 | assert turn.content == "Hello world" 306 | -------------------------------------------------------------------------------- /src/notebookllama/Home.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import io 3 | import os 4 | import asyncio 5 | import tempfile as temp 6 | from dotenv import load_dotenv 7 | import sys 8 | import time 9 | import randomname 10 | import streamlit.components.v1 as components 11 | 12 | from pathlib import Path 13 | from documents import ManagedDocument, DocumentManager 14 | from audio import PODCAST_GEN, PodcastConfig 15 | from typing import Tuple 16 | from workflow import NotebookLMWorkflow, FileInputEvent, NotebookOutputEvent 17 | from instrumentation import OtelTracesSqlEngine 18 | from llama_index.observability.otel import LlamaIndexOpenTelemetry 19 | from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( 20 | OTLPSpanExporter, 21 | ) 22 | 23 | load_dotenv() 24 | 25 | # define a custom span exporter 26 | span_exporter = OTLPSpanExporter("http://localhost:4318/v1/traces") 27 | 28 | # initialize the instrumentation object 29 | instrumentor = LlamaIndexOpenTelemetry( 30 | service_name_or_resource="agent.traces", 31 | span_exporter=span_exporter, 32 | debug=True, 33 | ) 34 | engine_url = f"postgresql+psycopg2://{os.getenv('pgql_user')}:{os.getenv('pgql_psw')}@localhost:5432/{os.getenv('pgql_db')}" 35 | sql_engine = OtelTracesSqlEngine( 36 | engine_url=engine_url, 37 | table_name="agent_traces", 38 | service_name="agent.traces", 39 | ) 40 | document_manager = DocumentManager(engine_url=engine_url) 41 | 42 | WF = NotebookLMWorkflow(timeout=600) 43 | 44 | 45 | # Read the HTML file 46 | def read_html_file(file_path: str) -> str: 47 | with open(file_path, "r", encoding="utf-8") as f: 48 | return f.read() 49 | 50 | 51 | async def run_workflow( 52 | file: io.BytesIO, document_title: str 53 | ) -> Tuple[str, str, str, str, str]: 54 | # Create temp file with proper Windows handling 55 | with temp.NamedTemporaryFile(suffix=".pdf", delete=False) as fl: 56 | content = file.getvalue() 57 | fl.write(content) 58 | fl.flush() # Ensure data is written 59 | temp_path = fl.name 60 | 61 | try: 62 | st_time = int(time.time() * 1000000) 63 | ev = FileInputEvent(file=temp_path) 64 | result: NotebookOutputEvent = await WF.run(start_event=ev) 65 | 66 | q_and_a = "" 67 | for q, a in zip(result.questions, result.answers): 68 | q_and_a += f"**{q}**\n\n{a}\n\n" 69 | bullet_points = "## Bullet Points\n\n- " + "\n- ".join(result.highlights) 70 | 71 | mind_map = result.mind_map 72 | if Path(mind_map).is_file(): 73 | mind_map = read_html_file(mind_map) 74 | try: 75 | os.remove(result.mind_map) 76 | except OSError: 77 | pass # File might be locked on Windows 78 | 79 | end_time = int(time.time() * 1000000) 80 | sql_engine.to_sql_database(start_time=st_time, end_time=end_time) 81 | document_manager.put_documents( 82 | [ 83 | ManagedDocument( 84 | document_name=document_title, 85 | content=result.md_content, 86 | summary=result.summary, 87 | q_and_a=q_and_a, 88 | mindmap=mind_map, 89 | bullet_points=bullet_points, 90 | ) 91 | ] 92 | ) 93 | return result.md_content, result.summary, q_and_a, bullet_points, mind_map 94 | 95 | finally: 96 | try: 97 | os.remove(temp_path) 98 | except OSError: 99 | await asyncio.sleep(0.1) 100 | try: 101 | os.remove(temp_path) 102 | except OSError: 103 | pass # Give up if still locked 104 
| 105 | 106 | def sync_run_workflow(file: io.BytesIO, document_title: str): 107 | try: 108 | # Try to use existing event loop 109 | loop = asyncio.get_event_loop() 110 | if loop.is_running(): 111 | # If loop is already running, schedule the coroutine 112 | import concurrent.futures 113 | 114 | with concurrent.futures.ThreadPoolExecutor() as executor: 115 | future = executor.submit( 116 | asyncio.run, run_workflow(file, document_title) 117 | ) 118 | return future.result() 119 | else: 120 | return loop.run_until_complete(run_workflow(file, document_title)) 121 | except RuntimeError: 122 | # No event loop exists, create one 123 | if sys.platform == "win32": 124 | asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) 125 | return asyncio.run(run_workflow(file, document_title)) 126 | 127 | 128 | async def create_podcast(file_content: str, config: PodcastConfig = None): 129 | audio_fl = await PODCAST_GEN.create_conversation( 130 | file_transcript=file_content, config=config 131 | ) 132 | return audio_fl 133 | 134 | 135 | def sync_create_podcast(file_content: str, config: PodcastConfig = None): 136 | return asyncio.run(create_podcast(file_content=file_content, config=config)) 137 | 138 | 139 | # Configure the Streamlit page 140 | st.set_page_config( 141 | page_title="NotebookLlaMa - Home", 142 | page_icon="🏠", 143 | layout="wide", 144 | menu_items={ 145 | "Get Help": "https://github.com/run-llama/notebooklm-clone/discussions/categories/general", 146 | "Report a bug": "https://github.com/run-llama/notebooklm-clone/issues/", 147 | "About": "An OSS alternative to NotebookLM that runs with the power of a fluffy Llama!", 148 | }, 149 | ) 150 | st.sidebar.header("Home🏠") 151 | st.sidebar.info("To switch to the Document Chat, select it from above!🔺") 152 | st.markdown("---") 153 | st.markdown("## NotebookLlaMa - Home🦙") 154 | 155 | # Initialize session state BEFORE creating the text input 156 | if "workflow_results" not in st.session_state: 157 | st.session_state.workflow_results = None 158 | if "document_title" not in st.session_state: 159 | st.session_state.document_title = randomname.get_name( 160 | adj=("music_theory", "geometry", "emotions"), noun=("cats", "food") 161 | ) 162 | 163 | # Use session_state as the value and update it when changed 164 | document_title = st.text_input( 165 | label="Document Title", 166 | value=st.session_state.document_title, 167 | key="document_title_input", 168 | ) 169 | 170 | # Update session state when the input changes 171 | if document_title != st.session_state.document_title: 172 | st.session_state.document_title = document_title 173 | 174 | file_input = st.file_uploader( 175 | label="Upload your source PDF file!", accept_multiple_files=False 176 | ) 177 | 178 | if file_input is not None: 179 | # First button: Process Document 180 | if st.button("Process Document", type="primary"): 181 | with st.spinner("Processing document...
This may take a few minutes."): 182 | try: 183 | md_content, summary, q_and_a, bullet_points, mind_map = ( 184 | sync_run_workflow(file_input, st.session_state.document_title) 185 | ) 186 | st.session_state.workflow_results = { 187 | "md_content": md_content, 188 | "summary": summary, 189 | "q_and_a": q_and_a, 190 | "bullet_points": bullet_points, 191 | "mind_map": mind_map, 192 | } 193 | st.success("Document processed successfully!") 194 | except Exception as e: 195 | st.error(f"Error processing document: {str(e)}") 196 | 197 | # Display results if available 198 | if st.session_state.workflow_results: 199 | results = st.session_state.workflow_results 200 | 201 | # Summary 202 | st.markdown("## Summary") 203 | st.markdown(results["summary"]) 204 | 205 | # Bullet Points 206 | st.markdown(results["bullet_points"]) 207 | 208 | # FAQ (toggled) 209 | with st.expander("FAQ"): 210 | st.markdown(results["q_and_a"]) 211 | 212 | # Mind Map 213 | if results["mind_map"]: 214 | st.markdown("## Mind Map") 215 | components.html(results["mind_map"], height=800, scrolling=True) 216 | 217 | # Podcast Configuration Panel 218 | st.markdown("---") 219 | st.markdown("## Podcast Configuration") 220 | 221 | with st.expander("Customize Your Podcast", expanded=False): 222 | col1, col2 = st.columns(2) 223 | 224 | with col1: 225 | style = st.selectbox( 226 | "Conversation Style", 227 | ["conversational", "interview", "debate", "educational"], 228 | help="The overall style of the podcast conversation", 229 | ) 230 | 231 | tone = st.selectbox( 232 | "Tone", 233 | ["friendly", "professional", "casual", "energetic"], 234 | help="The tone of voice for the conversation", 235 | ) 236 | 237 | target_audience = st.selectbox( 238 | "Target Audience", 239 | ["general", "technical", "business", "expert", "beginner"], 240 | help="Who is the intended audience for this podcast?", 241 | ) 242 | 243 | with col2: 244 | speaker1_role = st.text_input( 245 | "Speaker 1 Role", 246 | value="host", 247 | help="The role or persona of the first speaker", 248 | ) 249 | 250 | speaker2_role = st.text_input( 251 | "Speaker 2 Role", 252 | value="guest", 253 | help="The role or persona of the second speaker", 254 | ) 255 | 256 | # Focus Topics 257 | st.markdown("**Focus Topics** (optional)") 258 | focus_topics_input = st.text_area( 259 | "Enter topics to emphasize (one per line)", 260 | help="List specific topics you want the podcast to focus on. 
Leave empty for general coverage.", 261 | placeholder="How can this be applied to Machine Learning applications?\nUnderstand the historical context\nFuture Implications", 262 | ) 263 | 264 | # Parse focus topics 265 | focus_topics = None 266 | if focus_topics_input.strip(): 267 | focus_topics = [ 268 | topic.strip() 269 | for topic in focus_topics_input.split("\n") 270 | if topic.strip() 271 | ] 272 | 273 | # Custom Prompt 274 | custom_prompt = st.text_area( 275 | "Custom Instructions (optional)", 276 | help="Add any additional instructions for the podcast generation", 277 | placeholder="Make sure to explain technical concepts simply and include real-world examples...", 278 | ) 279 | 280 | # Create config object 281 | podcast_config = PodcastConfig( 282 | style=style, 283 | tone=tone, 284 | focus_topics=focus_topics, 285 | target_audience=target_audience, 286 | custom_prompt=custom_prompt if custom_prompt.strip() else None, 287 | speaker1_role=speaker1_role, 288 | speaker2_role=speaker2_role, 289 | ) 290 | 291 | # Second button: Generate Podcast 292 | if st.button("Generate In-Depth Conversation", type="secondary"): 293 | with st.spinner("Generating podcast... This may take several minutes."): 294 | try: 295 | audio_file = sync_create_podcast( 296 | results["md_content"], config=podcast_config 297 | ) 298 | st.success("Podcast generated successfully!") 299 | 300 | # Display audio player 301 | st.markdown("## Generated Podcast") 302 | if os.path.exists(audio_file): 303 | with open(audio_file, "rb") as f: 304 | audio_bytes = f.read() 305 | os.remove(audio_file) 306 | st.audio(audio_bytes, format="audio/mp3") 307 | else: 308 | st.error("Audio file not found.") 309 | 310 | except Exception as e: 311 | st.error(f"Error generating podcast: {str(e)}") 312 | 313 | else: 314 | st.info("Please upload a PDF file to get started.") 315 | 316 | if __name__ == "__main__": 317 | instrumentor.start_registering() 318 | -------------------------------------------------------------------------------- /src/notebookllama/audio.py: -------------------------------------------------------------------------------- 1 | import tempfile as temp 2 | import os 3 | import uuid 4 | from dotenv import load_dotenv 5 | import logging 6 | from contextlib import asynccontextmanager 7 | 8 | from pydub import AudioSegment 9 | from elevenlabs import AsyncElevenLabs 10 | from llama_index.core.llms.structured_llm import StructuredLLM 11 | from typing_extensions import Self 12 | from typing import List, Literal, Optional, AsyncIterator 13 | from pydantic import BaseModel, ConfigDict, model_validator, Field 14 | from llama_index.core.llms import ChatMessage 15 | from llama_index.llms.openai import OpenAIResponses 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class ConversationTurn(BaseModel): 21 | speaker: Literal["speaker1", "speaker2"] = Field( 22 | description="The person who is speaking", 23 | ) 24 | content: str = Field( 25 | description="The content of the speech", 26 | ) 27 | 28 | 29 | class MultiTurnConversation(BaseModel): 30 | conversation: List[ConversationTurn] = Field( 31 | description="List of conversation turns. Conversation must start with speaker1, and continue with an alternation of speaker1 and speaker2", 32 | min_length=3, 33 | max_length=50, 34 | examples=[ 35 | [ 36 | ConversationTurn(speaker="speaker1", content="Hello, who are you?"), 37 | ConversationTurn( 38 | speaker="speaker2", content="I am very well, how about you?"
39 | ), 40 | ConversationTurn(speaker="speaker1", content="I am well too, thanks!"), 41 | ] 42 | ], 43 | ) 44 | 45 | @model_validator(mode="after") 46 | def validate_conversation(self) -> Self: 47 | speakers = [turn.speaker for turn in self.conversation] 48 | if speakers[0] != "speaker1": 49 | raise ValueError("Conversation must start with speaker1") 50 | for i, speaker in enumerate(speakers): 51 | if i % 2 == 0 and speaker != "speaker1": 52 | raise ValueError( 53 | "Conversation must be an alternation between speaker1 and speaker2" 54 | ) 55 | elif i % 2 != 0 and speaker != "speaker2": 56 | raise ValueError( 57 | "Conversation must be an alternation between speaker1 and speaker2" 58 | ) 59 | 60 | return self 61 | 62 | 63 | class VoiceConfig(BaseModel): 64 | """Configuration for voice settings""" 65 | 66 | speaker1_voice_id: str = Field( 67 | default="nPczCjzI2devNBz1zQrb", description="Voice ID for speaker 1" 68 | ) 69 | speaker2_voice_id: str = Field( 70 | default="Xb7hH8MSUJpSbSDYk0k2", description="Voice ID for speaker 2" 71 | ) 72 | model_id: str = Field( 73 | default="eleven_turbo_v2_5", description="ElevenLabs model ID" 74 | ) 75 | output_format: str = Field( 76 | default="mp3_22050_32", description="Audio output format" 77 | ) 78 | 79 | 80 | class AudioQuality(BaseModel): 81 | """Configuration for audio quality settings""" 82 | 83 | bitrate: str = Field(default="320k", description="Audio bitrate") 84 | quality_params: List[str] = Field( 85 | default=["-q:a", "0"], description="Additional quality parameters" 86 | ) 87 | 88 | 89 | class PodcastConfig(BaseModel): 90 | """Configuration for podcast generation""" 91 | 92 | # Basic style options 93 | style: Literal["conversational", "interview", "debate", "educational"] = Field( 94 | default="conversational", 95 | description="The style of the conversation", 96 | ) 97 | tone: Literal["friendly", "professional", "casual", "energetic"] = Field( 98 | default="friendly", 99 | description="The tone of the conversation", 100 | ) 101 | 102 | # Content focus 103 | focus_topics: Optional[List[str]] = Field( 104 | default=None, description="Specific topics to focus on during the conversation" 105 | ) 106 | 107 | # Audience targeting 108 | target_audience: Literal[ 109 | "general", "technical", "business", "expert", "beginner" 110 | ] = Field( 111 | default="general", 112 | description="The target audience for the conversation", 113 | ) 114 | 115 | # Custom user prompt 116 | custom_prompt: Optional[str] = Field( 117 | default=None, description="Additional instructions for the conversation" 118 | ) 119 | 120 | # Speaker customization 121 | speaker1_role: str = Field( 122 | default="host", description="The role of the first speaker" 123 | ) 124 | speaker2_role: str = Field( 125 | default="guest", description="The role of the second speaker" 126 | ) 127 | 128 | # Voice and audio configuration 129 | voice_config: VoiceConfig = Field( 130 | default_factory=VoiceConfig, description="Voice configuration for speakers" 131 | ) 132 | audio_quality: AudioQuality = Field( 133 | default_factory=AudioQuality, description="Audio quality settings" 134 | ) 135 | 136 | 137 | class PodcastGeneratorError(Exception): 138 | """Base exception for podcast generator errors""" 139 | 140 | pass 141 | 142 | 143 | class AudioGenerationError(PodcastGeneratorError): 144 | """Raised when audio generation fails""" 145 | 146 | pass 147 | 148 | 149 | class ConversationGenerationError(PodcastGeneratorError): 150 | """Raised when conversation generation fails""" 151 | 152 | pass
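# Hedged usage sketch (illustrative only, not called anywhere in this module):
# callers of the generator below can catch the narrow failure first and fall
# back to the shared base class. The helper name and log messages are
# assumptions made for the example.
async def _handle_podcast_errors_demo(generator: "PodcastGenerator", text: str) -> Optional[str]:
    try:
        # create_conversation returns the path of the finished mp3 file
        return await generator.create_conversation(file_transcript=text)
    except AudioGenerationError as exc:
        logger.error(f"TTS or audio-export step failed: {exc}")
    except PodcastGeneratorError as exc:
        logger.error(f"Podcast generation failed: {exc}")
    return None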
153 | 154 | 155 | class PodcastGenerator(BaseModel): 156 | llm: StructuredLLM 157 | client: AsyncElevenLabs 158 | 159 | model_config = ConfigDict(arbitrary_types_allowed=True) 160 | 161 | @model_validator(mode="after") 162 | def validate_podcast(self) -> Self: 163 | try: 164 | assert self.llm.output_cls == MultiTurnConversation 165 | except AssertionError: 166 | raise ValueError( 167 | f"The output class of the structured LLM must be {MultiTurnConversation.__qualname__}, your LLM has output class: {self.llm.output_cls.__qualname__}" 168 | ) 169 | return self 170 | 171 | def _build_conversation_prompt( 172 | self, file_transcript: str, config: PodcastConfig 173 | ) -> str: 174 | """Build a customized prompt based on the configuration""" 175 | 176 | # Base prompt with style and tone 177 | prompt = f"""Create a {config.style} podcast conversation with two speakers from this transcript. 178 | 179 | CONVERSATION STYLE: {config.style} 180 | TONE: {config.tone} 181 | TARGET AUDIENCE: {config.target_audience} 182 | 183 | SPEAKER ROLES: 184 | - Speaker 1: {config.speaker1_role} 185 | - Speaker 2: {config.speaker2_role} 186 | """ 187 | 188 | if config.focus_topics: 189 | prompt += "\nFOCUS TOPICS: Make sure to discuss these topics in detail:\n" 190 | for topic in config.focus_topics: 191 | prompt += f"- {topic}\n" 192 | 193 | # Add audience-specific instructions 194 | audience_instructions = { 195 | "technical": "Use technical terminology appropriately and dive deep into technical details.", 196 | "beginner": "Explain concepts clearly and avoid jargon. Define technical terms when used.", 197 | "expert": "Assume advanced knowledge and discuss nuanced aspects and implications.", 198 | "business": "Focus on practical applications, ROI, and strategic implications.", 199 | "general": "Balance accessibility with depth, explaining key concepts clearly.", 200 | } 201 | 202 | prompt += ( 203 | f"\nAUDIENCE APPROACH: {audience_instructions[config.target_audience]}\n" 204 | ) 205 | 206 | # Add custom prompt if provided 207 | if config.custom_prompt: 208 | prompt += f"\nADDITIONAL INSTRUCTIONS: {config.custom_prompt}\n" 209 | 210 | # Add the source material 211 | prompt += f"\nSOURCE MATERIAL:\n'''\n{file_transcript}\n'''\n" 212 | 213 | # Add final instructions 214 | prompt += """ 215 | IMPORTANT: Create an engaging, natural conversation that flows well between the two speakers. 216 | The conversation should feel authentic and provide value to the target audience. 
217 | """ 218 | 219 | return prompt 220 | 221 | async def _conversation_script( 222 | self, file_transcript: str, config: PodcastConfig 223 | ) -> MultiTurnConversation: 224 | """Generate conversation script with customization""" 225 | logger.info("Generating conversation script...") 226 | prompt = self._build_conversation_prompt(file_transcript, config) 227 | 228 | response = await self.llm.achat( 229 | messages=[ 230 | ChatMessage( 231 | role="user", 232 | content=prompt, 233 | ) 234 | ] 235 | ) 236 | 237 | conversation = MultiTurnConversation.model_validate_json( 238 | response.message.content 239 | ) 240 | logger.info( 241 | f"Generated conversation with {len(conversation.conversation)} turns" 242 | ) 243 | return conversation 244 | 245 | @asynccontextmanager 246 | async def _cleanup_files(self, files: List[str]) -> AsyncIterator[None]: 247 | """Context manager to ensure temporary files are cleaned up""" 248 | try: 249 | yield 250 | finally: 251 | for file_path in files: 252 | try: 253 | if os.path.exists(file_path): 254 | os.remove(file_path) 255 | logger.debug(f"Cleaned up temporary file: {file_path}") 256 | except OSError as e: 257 | logger.warning(f"Failed to clean up file {file_path}: {str(e)}") 258 | 259 | async def _generate_speech_file( 260 | self, text: str, voice_id: str, config: PodcastConfig 261 | ) -> str: 262 | """Generate speech file for a single turn""" 263 | try: 264 | speech_iterator = self.client.text_to_speech.convert( 265 | voice_id=voice_id, 266 | text=text, 267 | output_format=config.voice_config.output_format, 268 | model_id=config.voice_config.model_id, 269 | ) 270 | 271 | temp_file = temp.NamedTemporaryFile( 272 | suffix=".mp3", delete=False, delete_on_close=False 273 | ) 274 | 275 | with open(temp_file.name, "wb") as f: 276 | async for chunk in speech_iterator: 277 | if chunk: 278 | f.write(chunk) 279 | 280 | return temp_file.name 281 | except Exception as e: 282 | logger.error(f"Failed to generate speech for text: {text[:50]}") 283 | raise AudioGenerationError( 284 | f"Failed to generate speech for text: {text[:50]}" 285 | ) from e 286 | 287 | async def _conversation_audio( 288 | self, conversation: MultiTurnConversation, config: PodcastConfig 289 | ) -> str: 290 | """Generate audio for the conversation""" 291 | files: List[str] = [] 292 | 293 | async with self._cleanup_files(files): 294 | try: 295 | logger.info("Generating audio for conversation") 296 | 297 | for i, turn in enumerate(conversation.conversation): 298 | voice_id = ( 299 | config.voice_config.speaker1_voice_id 300 | if turn.speaker == "speaker1" 301 | else config.voice_config.speaker2_voice_id 302 | ) 303 | file_path = await self._generate_speech_file( 304 | turn.content, voice_id, config 305 | ) 306 | files.append(file_path) 307 | 308 | logger.info("Combining audio files...") 309 | output_path = f"conversation_{str(uuid.uuid4())}.mp3" 310 | combined_audio: AudioSegment = AudioSegment.empty() 311 | 312 | for file_path in files: 313 | audio = AudioSegment.from_file(file_path) 314 | combined_audio += audio 315 | 316 | combined_audio.export( 317 | output_path, 318 | format="mp3", 319 | bitrate=config.audio_quality.bitrate, 320 | parameters=config.audio_quality.quality_params, 321 | ) 322 | 323 | logger.info(f"Successfully created podcast audio: {output_path}") 324 | return output_path 325 | except Exception as e: 326 | logger.error(f"Failed to generate conversation audio: {str(e)}") 327 | raise AudioGenerationError( 328 | f"Failed to generate conversation audio: {str(e)}" 329 | ) from e 330 | 331 
| async def create_conversation( 332 | self, file_transcript: str, config: Optional[PodcastConfig] = None 333 | ): 334 | """Main method to create a customized podcast conversation""" 335 | if config is None: 336 | config = PodcastConfig() 337 | 338 | try: 339 | logger.info("Starting podcast generation...") 340 | 341 | conversation = await self._conversation_script( 342 | file_transcript=file_transcript, config=config 343 | ) 344 | podcast_file = await self._conversation_audio( 345 | conversation=conversation, config=config 346 | ) 347 | 348 | logger.info("Podcast generation completed successfully") 349 | return podcast_file 350 | 351 | except (ConversationGenerationError, AudioGenerationError) as e: 352 | logger.error(f"Failed to generate podcast: {str(e)}") 353 | raise 354 | except Exception as e: 355 | logger.error(f"Unexpected error in podcast generation: {str(e)}") 356 | raise PodcastGeneratorError( 357 | f"Unexpected error in podcast generation: {str(e)}" 358 | ) from e 359 | 360 | 361 | load_dotenv() 362 | 363 | PODCAST_GEN: Optional[PodcastGenerator] 364 | 365 | if os.getenv("ELEVENLABS_API_KEY", None) and os.getenv("OPENAI_API_KEY", None): 366 | SLLM = OpenAIResponses( 367 | model="gpt-4.1", api_key=os.getenv("OPENAI_API_KEY") 368 | ).as_structured_llm(MultiTurnConversation) 369 | EL_CLIENT = AsyncElevenLabs(api_key=os.getenv("ELEVENLABS_API_KEY")) 370 | PODCAST_GEN = PodcastGenerator(llm=SLLM, client=EL_CLIENT) 371 | else: 372 | logger.warning("Missing API keys - PODCAST_GEN not initialized") 373 | PODCAST_GEN = None 374 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | import pandas as pd 4 | from pathlib import Path 5 | from dotenv import load_dotenv 6 | from unittest.mock import patch, MagicMock 7 | 8 | from typing import Callable 9 | from pydantic import ValidationError 10 | from src.notebookllama.processing import ( 11 | process_file, 12 | md_table_to_pd_dataframe, 13 | rename_and_remove_current_images, 14 | rename_and_remove_past_images, 15 | MarkdownTextAnalyzer, 16 | ) 17 | from src.notebookllama.mindmap import get_mind_map 18 | from src.notebookllama.models import Notebook 19 | from src.notebookllama.utils import ( 20 | get_llamacloud_base_url, 21 | get_llamacloud_config, 22 | create_llamacloud_client, 23 | create_llama_extract_client, 24 | create_llama_parse_client, 25 | create_llamacloud_index, 26 | LlamaCloudConfigError, 27 | LLAMACLOUD_REGIONS, 28 | ) 29 | 30 | load_dotenv() 31 | 32 | skip_condition = not ( 33 | os.getenv("LLAMACLOUD_API_KEY", None) 34 | and os.getenv("EXTRACT_AGENT_ID", None) 35 | and os.getenv("LLAMACLOUD_PIPELINE_ID", None) 36 | and os.getenv("OPENAI_API_KEY", None) 37 | ) 38 | 39 | 40 | @pytest.fixture() 41 | def input_file() -> str: 42 | return "data/test/brain_for_kids.pdf" 43 | 44 | 45 | @pytest.fixture() 46 | def markdown_file() -> str: 47 | return "data/test/md_sample.md" 48 | 49 | 50 | @pytest.fixture() 51 | def images_dir() -> str: 52 | return "data/test/images/" 53 | 54 | 55 | @pytest.fixture() 56 | def dataframe_from_tables() -> pd.DataFrame: 57 | project_data = { 58 | "Project Name": [ 59 | "User Dashboard", 60 | "API Integration", 61 | "Mobile App", 62 | "Database Migration", 63 | "Security Audit", 64 | ], 65 | "Status": [ 66 | "In Progress", 67 | "Completed", 68 | "Planning", 69 | "In Progress", 70 | "Not Started", 71 | ], 72 | "Completion %": ["75%", "100%", "25%", 
"60%", "0%"], 73 | "Assigned Developer": [ 74 | "Alice Johnson", 75 | "Bob Smith", 76 | "Carol Davis", 77 | "David Wilson", 78 | "Eve Brown", 79 | ], 80 | "Due Date": [ 81 | "2025-07-15", 82 | "2025-06-30", 83 | "2025-08-20", 84 | "2025-07-10", 85 | "2025-08-01", 86 | ], 87 | } 88 | 89 | df = pd.DataFrame(project_data) 90 | return df 91 | 92 | 93 | @pytest.fixture() 94 | def file_exists_fn() -> Callable[[str], bool]: 95 | def file_exists(file_path: str) -> bool: 96 | return Path(file_path).exists() 97 | 98 | return file_exists 99 | 100 | 101 | @pytest.fixture() 102 | def is_not_empty_fn() -> Callable[[str], bool]: 103 | def is_not_empty(file_path: str) -> bool: 104 | return Path(file_path).stat().st_size > 0 105 | 106 | return is_not_empty 107 | 108 | 109 | @pytest.fixture 110 | def notebook_to_process() -> Notebook: 111 | return Notebook( 112 | summary="""The Human Brain: 113 | The human brain is a complex organ responsible for thought, memory, emotion, and coordination. It contains about 86 billion neurons and operates through electrical and chemical signals. Divided into major parts like the cerebrum, cerebellum, and brainstem, it controls everything from basic survival functions to advanced reasoning. Despite its size, it consumes around 20% of the body’s energy. Neuroscience continues to explore its mysteries, including consciousness and neuroplasticity—its ability to adapt and reorganize.""", 114 | questions=[ 115 | "How many neurons are in the human brain?", 116 | "What are the main parts of the human brain?", 117 | "What percentage of the body's energy does the brain use?", 118 | "What is neuroplasticity?", 119 | "What functions is the human brain responsible for?", 120 | ], 121 | answers=[ 122 | "About 86 billion neurons.", 123 | "The cerebrum, cerebellum, and brainstem.", 124 | "Around 20%.", 125 | "The brain's ability to adapt and reorganize itself.", 126 | "Thought, memory, emotion, and coordination.", 127 | ], 128 | highlights=[ 129 | "The human brain has about 86 billion neurons.", 130 | "It controls thought, memory, emotion, and coordination.", 131 | "Major brain parts include the cerebrum, cerebellum, and brainstem.", 132 | "The brain uses approximately 20% of the body's energy.", 133 | "Neuroplasticity allows the brain to adapt and reorganize.", 134 | ], 135 | ) 136 | 137 | 138 | @pytest.mark.skipif( 139 | condition=skip_condition, 140 | reason="You do not have the necessary env variables to run this test.", 141 | ) 142 | @pytest.mark.asyncio 143 | async def test_mind_map_creation( 144 | notebook_to_process: Notebook, 145 | file_exists_fn: Callable[[str], bool], 146 | is_not_empty_fn: Callable[[str], bool], 147 | ): 148 | test_mindmap = await get_mind_map( 149 | summary=notebook_to_process.summary, highlights=notebook_to_process.highlights 150 | ) 151 | assert test_mindmap is not None 152 | assert file_exists_fn(test_mindmap) 153 | assert is_not_empty_fn(test_mindmap) 154 | os.remove(test_mindmap) 155 | 156 | 157 | @pytest.mark.skipif( 158 | condition=skip_condition, 159 | reason="You do not have the necessary env variables to run this test.", 160 | ) 161 | @pytest.mark.asyncio 162 | async def test_file_processing(input_file: str) -> None: 163 | notebook, text = await process_file(filename=input_file) 164 | print(notebook) 165 | assert notebook is not None 166 | assert isinstance(text, str) 167 | try: 168 | notebook_model = Notebook.model_validate_json(json_data=notebook) 169 | except ValidationError: 170 | notebook_model = None 171 | assert isinstance(notebook_model, 
Notebook) 172 | 173 | 174 | def test_table_to_dataframe( 175 | markdown_file: str, dataframe_from_tables: pd.DataFrame 176 | ) -> None: 177 | with open(markdown_file, "r") as f: 178 | text = f.read() 179 | analyzer = MarkdownTextAnalyzer(text) 180 | md_tables = analyzer.identify_tables()["Table"] 181 | assert len(md_tables) == 2 182 | for md_table in md_tables: 183 | df = md_table_to_pd_dataframe(md_table) 184 | assert df is not None 185 | assert df.equals(dataframe_from_tables) 186 | 187 | 188 | def test_images_renaming(images_dir: str): 189 | images = [os.path.join(images_dir, f) for f in os.listdir(images_dir)] 190 | imgs = rename_and_remove_current_images(images) 191 | assert all("_current" in img for img in imgs) 192 | assert all(os.path.exists(img) for img in imgs) 193 | renamed = rename_and_remove_past_images(images_dir) 194 | assert all("_at_" in img for img in renamed) 195 | assert all("_current" not in img for img in renamed) 196 | assert all(os.path.exists(img) for img in renamed) 197 | for image in renamed: 198 | with open(image, "rb") as rb: 199 | bts = rb.read() 200 | with open(images_dir + "image.png", "wb") as wb: 201 | wb.write(bts) 202 | os.remove(image) 203 | 204 | 205 | # ============================================================================= 206 | # Regional LlamaCloud Utilities Tests 207 | # ============================================================================= 208 | 209 | 210 | class TestLlamaCloudRegionalUtils: 211 | """Test suite for regional LlamaCloud utility functions.""" 212 | 213 | def test_llamacloud_regions_constant(self): 214 | """Test that LLAMACLOUD_REGIONS contains expected regions.""" 215 | assert "default" in LLAMACLOUD_REGIONS 216 | assert "eu" in LLAMACLOUD_REGIONS 217 | assert LLAMACLOUD_REGIONS["default"] == "https://api.cloud.llamaindex.ai" 218 | assert LLAMACLOUD_REGIONS["eu"] == "https://api.cloud.eu.llamaindex.ai" 219 | 220 | @patch.dict(os.environ, {}, clear=True) 221 | def test_get_llamacloud_base_url_no_region(self): 222 | """Test get_llamacloud_base_url with no region set (defaults to North America).""" 223 | result = get_llamacloud_base_url() 224 | assert result == "https://api.cloud.llamaindex.ai" 225 | 226 | @patch.dict(os.environ, {"LLAMACLOUD_REGION": "eu"}) 227 | def test_get_llamacloud_base_url_eu_region(self): 228 | """Test get_llamacloud_base_url with EU region.""" 229 | result = get_llamacloud_base_url() 230 | assert result == "https://api.cloud.eu.llamaindex.ai" 231 | 232 | @patch.dict(os.environ, {"LLAMACLOUD_REGION": "default"}) 233 | def test_get_llamacloud_base_url_default_region(self): 234 | """Test get_llamacloud_base_url with default region.""" 235 | result = get_llamacloud_base_url() 236 | assert result == "https://api.cloud.llamaindex.ai" 237 | 238 | @patch.dict(os.environ, {"LLAMACLOUD_REGION": "DEFAULT"}) 239 | def test_get_llamacloud_base_url_case_insensitive(self): 240 | """Test get_llamacloud_base_url with case insensitive region.""" 241 | result = get_llamacloud_base_url() 242 | assert result == "https://api.cloud.llamaindex.ai" 243 | 244 | @patch.dict(os.environ, {"LLAMACLOUD_REGION": " default "}) 245 | def test_get_llamacloud_base_url_strips_whitespace(self): 246 | """Test get_llamacloud_base_url strips whitespace from region.""" 247 | result = get_llamacloud_base_url() 248 | assert result == "https://api.cloud.llamaindex.ai" 249 | 250 | @patch.dict(os.environ, {"LLAMACLOUD_BASE_URL": "https://custom.api.com"}) 251 | def test_get_llamacloud_base_url_custom_override(self): 252 | """Test 
get_llamacloud_base_url with custom base URL override.""" 253 | result = get_llamacloud_base_url() 254 | assert result == "https://custom.api.com" 255 | 256 | @patch.dict( 257 | os.environ, 258 | {"LLAMACLOUD_BASE_URL": "https://custom.api.com", "LLAMACLOUD_REGION": "eu"}, 259 | ) 260 | def test_get_llamacloud_base_url_custom_override_precedence(self): 261 | """Test that custom base URL takes precedence over region.""" 262 | result = get_llamacloud_base_url() 263 | assert result == "https://custom.api.com" 264 | 265 | @patch.dict(os.environ, {"LLAMACLOUD_REGION": "invalid"}) 266 | def test_get_llamacloud_base_url_invalid_region(self): 267 | """Test get_llamacloud_base_url with invalid region raises error.""" 268 | with pytest.raises(LlamaCloudConfigError) as exc_info: 269 | get_llamacloud_base_url() 270 | assert "Invalid LLAMACLOUD_REGION 'invalid'" in str(exc_info.value) 271 | assert "default, eu" in str(exc_info.value) 272 | 273 | @patch.dict(os.environ, {"LLAMACLOUD_API_KEY": "test-key"}, clear=True) 274 | def test_get_llamacloud_config_valid(self): 275 | """Test get_llamacloud_config with valid API key (defaults to North America).""" 276 | result = get_llamacloud_config() 277 | expected = {"token": "test-key", "base_url": "https://api.cloud.llamaindex.ai"} 278 | assert result == expected 279 | 280 | @patch.dict( 281 | os.environ, 282 | {"LLAMACLOUD_API_KEY": "test-key", "LLAMACLOUD_REGION": "eu"}, 283 | clear=True, 284 | ) 285 | def test_get_llamacloud_config_with_region(self): 286 | """Test get_llamacloud_config with region.""" 287 | result = get_llamacloud_config() 288 | expected = { 289 | "token": "test-key", 290 | "base_url": "https://api.cloud.eu.llamaindex.ai", 291 | } 292 | assert result == expected 293 | 294 | @patch.dict(os.environ, {}, clear=True) 295 | def test_get_llamacloud_config_missing_api_key(self): 296 | """Test get_llamacloud_config with missing API key raises error.""" 297 | with pytest.raises(LlamaCloudConfigError): 298 | get_llamacloud_config() 299 | 300 | @patch.dict( 301 | os.environ, 302 | {"LLAMACLOUD_API_KEY": "test-key", "LLAMACLOUD_REGION": "invalid"}, 303 | clear=True, 304 | ) 305 | def test_get_llamacloud_config_invalid_region(self): 306 | """Test get_llamacloud_config with invalid region raises error.""" 307 | with pytest.raises(LlamaCloudConfigError): 308 | get_llamacloud_config() 309 | 310 | @patch("src.notebookllama.utils.AsyncLlamaCloud") 311 | @patch.dict(os.environ, {"LLAMACLOUD_API_KEY": "test-key"}, clear=True) 312 | def test_create_llamacloud_client_valid(self, mock_client_class): 313 | """Test create_llamacloud_client with valid configuration (defaults to North America).""" 314 | mock_instance = MagicMock() 315 | mock_client_class.return_value = mock_instance 316 | 317 | result = create_llamacloud_client() 318 | 319 | mock_client_class.assert_called_once_with( 320 | token="test-key", base_url="https://api.cloud.llamaindex.ai" 321 | ) 322 | assert result == mock_instance 323 | 324 | @patch("src.notebookllama.utils.AsyncLlamaCloud") 325 | @patch.dict( 326 | os.environ, 327 | {"LLAMACLOUD_API_KEY": "test-key", "LLAMACLOUD_REGION": "eu"}, 328 | clear=True, 329 | ) 330 | def test_create_llamacloud_client_with_region(self, mock_client_class): 331 | """Test create_llamacloud_client with region.""" 332 | mock_instance = MagicMock() 333 | mock_client_class.return_value = mock_instance 334 | 335 | result = create_llamacloud_client() 336 | 337 | mock_client_class.assert_called_once_with( 338 | token="test-key", 
base_url="https://api.cloud.eu.llamaindex.ai" 339 | ) 340 | assert result == mock_instance 341 | 342 | @patch.dict(os.environ, {}, clear=True) 343 | def test_create_llamacloud_client_missing_api_key(self): 344 | """Test create_llamacloud_client with missing API key raises error.""" 345 | with pytest.raises(LlamaCloudConfigError): 346 | create_llamacloud_client() 347 | 348 | @patch("src.notebookllama.utils.LlamaExtract") 349 | @patch.dict(os.environ, {"LLAMACLOUD_API_KEY": "test-key"}, clear=True) 350 | def test_create_llama_extract_client_valid(self, mock_extract_class): 351 | """Test create_llama_extract_client with valid configuration (defaults to North America).""" 352 | mock_instance = MagicMock() 353 | mock_extract_class.return_value = mock_instance 354 | 355 | result = create_llama_extract_client() 356 | 357 | mock_extract_class.assert_called_once_with( 358 | api_key="test-key", base_url="https://api.cloud.llamaindex.ai" 359 | ) 360 | assert result == mock_instance 361 | 362 | @patch("src.notebookllama.utils.LlamaExtract") 363 | @patch.dict( 364 | os.environ, 365 | {"LLAMACLOUD_API_KEY": "test-key", "LLAMACLOUD_REGION": "eu"}, 366 | clear=True, 367 | ) 368 | def test_create_llama_extract_client_with_region(self, mock_extract_class): 369 | """Test create_llama_extract_client with region.""" 370 | mock_instance = MagicMock() 371 | mock_extract_class.return_value = mock_instance 372 | 373 | result = create_llama_extract_client() 374 | 375 | mock_extract_class.assert_called_once_with( 376 | api_key="test-key", base_url="https://api.cloud.eu.llamaindex.ai" 377 | ) 378 | assert result == mock_instance 379 | 380 | @patch.dict(os.environ, {}, clear=True) 381 | def test_create_llama_extract_client_missing_api_key(self): 382 | """Test create_llama_extract_client with missing API key raises error.""" 383 | with pytest.raises(LlamaCloudConfigError): 384 | create_llama_extract_client() 385 | 386 | @patch("src.notebookllama.utils.LlamaParse") 387 | @patch.dict(os.environ, {"LLAMACLOUD_API_KEY": "test-key"}, clear=True) 388 | def test_create_llama_parse_client_default(self, mock_parse_class): 389 | """Test create_llama_parse_client with default parameters (defaults to North America).""" 390 | mock_instance = MagicMock() 391 | mock_parse_class.return_value = mock_instance 392 | 393 | result = create_llama_parse_client() 394 | 395 | mock_parse_class.assert_called_once_with( 396 | api_key="test-key", 397 | result_type="markdown", 398 | base_url="https://api.cloud.llamaindex.ai", 399 | ) 400 | assert result == mock_instance 401 | 402 | @patch("src.notebookllama.utils.LlamaParse") 403 | @patch.dict(os.environ, {"LLAMACLOUD_API_KEY": "test-key"}, clear=True) 404 | def test_create_llama_parse_client_custom_result_type(self, mock_parse_class): 405 | """Test create_llama_parse_client with custom result type (defaults to North America).""" 406 | mock_instance = MagicMock() 407 | mock_parse_class.return_value = mock_instance 408 | 409 | result = create_llama_parse_client(result_type="text") 410 | 411 | mock_parse_class.assert_called_once_with( 412 | api_key="test-key", 413 | result_type="text", 414 | base_url="https://api.cloud.llamaindex.ai", 415 | ) 416 | assert result == mock_instance 417 | 418 | @patch("src.notebookllama.utils.LlamaParse") 419 | @patch.dict( 420 | os.environ, 421 | {"LLAMACLOUD_API_KEY": "test-key", "LLAMACLOUD_REGION": "eu"}, 422 | clear=True, 423 | ) 424 | def test_create_llama_parse_client_with_region(self, mock_parse_class): 425 | """Test create_llama_parse_client with region.""" 426 
| mock_instance = MagicMock() 427 | mock_parse_class.return_value = mock_instance 428 | 429 | result = create_llama_parse_client() 430 | 431 | mock_parse_class.assert_called_once_with( 432 | api_key="test-key", 433 | result_type="markdown", 434 | base_url="https://api.cloud.eu.llamaindex.ai", 435 | ) 436 | assert result == mock_instance 437 | 438 | @patch.dict(os.environ, {}, clear=True) 439 | def test_create_llama_parse_client_missing_api_key(self): 440 | """Test create_llama_parse_client with missing API key raises error.""" 441 | with pytest.raises(LlamaCloudConfigError): 442 | create_llama_parse_client() 443 | 444 | @patch("src.notebookllama.utils.LlamaCloudIndex") 445 | @patch.dict(os.environ, {}, clear=True) 446 | def test_create_llamacloud_index_valid(self, mock_index_class): 447 | """Test create_llamacloud_index with valid parameters (defaults to North America).""" 448 | mock_instance = MagicMock() 449 | mock_index_class.return_value = mock_instance 450 | 451 | result = create_llamacloud_index("test-key", "test-pipeline") 452 | 453 | mock_index_class.assert_called_once_with( 454 | api_key="test-key", 455 | pipeline_id="test-pipeline", 456 | base_url="https://api.cloud.llamaindex.ai", 457 | ) 458 | assert result == mock_instance 459 | 460 | @patch("src.notebookllama.utils.LlamaCloudIndex") 461 | @patch.dict(os.environ, {"LLAMACLOUD_REGION": "eu"}, clear=True) 462 | def test_create_llamacloud_index_with_region(self, mock_index_class): 463 | """Test create_llamacloud_index with region.""" 464 | mock_instance = MagicMock() 465 | mock_index_class.return_value = mock_instance 466 | 467 | result = create_llamacloud_index("test-key", "test-pipeline") 468 | 469 | mock_index_class.assert_called_once_with( 470 | api_key="test-key", 471 | pipeline_id="test-pipeline", 472 | base_url="https://api.cloud.eu.llamaindex.ai", 473 | ) 474 | assert result == mock_instance 475 | 476 | @patch.dict(os.environ, {}, clear=True) 477 | def test_create_llamacloud_index_missing_api_key(self): 478 | """Test create_llamacloud_index with missing API key raises error.""" 479 | with pytest.raises(LlamaCloudConfigError) as exc_info: 480 | create_llamacloud_index("", "test-pipeline") 481 | assert "API key is required" in str(exc_info.value) 482 | 483 | @patch.dict(os.environ, {}, clear=True) 484 | def test_create_llamacloud_index_missing_pipeline_id(self): 485 | """Test create_llamacloud_index with missing pipeline ID raises error.""" 486 | with pytest.raises(LlamaCloudConfigError) as exc_info: 487 | create_llamacloud_index("test-key", "") 488 | assert "Pipeline ID is required" in str(exc_info.value) 489 | 490 | @patch.dict(os.environ, {"LLAMACLOUD_REGION": "invalid"}, clear=True) 491 | def test_create_llamacloud_index_invalid_region(self): 492 | """Test create_llamacloud_index with invalid region raises error.""" 493 | with pytest.raises(LlamaCloudConfigError): 494 | create_llamacloud_index("test-key", "test-pipeline") 495 | --------------------------------------------------------------------------------
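The tests in tests/test_utils.py above pin down the resolution order for the regional helpers in src/notebookllama/utils.py: an explicit LLAMACLOUD_BASE_URL wins, otherwise LLAMACLOUD_REGION is normalized (case-insensitive, whitespace-stripped) and mapped, and unknown regions raise LlamaCloudConfigError. A minimal sketch that satisfies exactly that contract (the shipped module may differ in details the tests leave open):

import os

LLAMACLOUD_REGIONS = {
    "default": "https://api.cloud.llamaindex.ai",
    "eu": "https://api.cloud.eu.llamaindex.ai",
}


class LlamaCloudConfigError(Exception):
    """Raised when the LlamaCloud environment configuration is invalid."""


def get_llamacloud_base_url() -> str:
    # An explicit base URL always takes precedence over the region setting
    custom = os.getenv("LLAMACLOUD_BASE_URL")
    if custom:
        return custom
    # Region lookup is case-insensitive and whitespace-tolerant, defaulting
    # to the North America endpoint when the variable is unset
    region = os.getenv("LLAMACLOUD_REGION", "default").strip().lower()
    if region not in LLAMACLOUD_REGIONS:
        raise LlamaCloudConfigError(
            f"Invalid LLAMACLOUD_REGION '{region}'. "
            f"Valid options: {', '.join(LLAMACLOUD_REGIONS)}"
        )
    return LLAMACLOUD_REGIONS[region]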