├── hn_search
│   ├── __init__.py
│   ├── rag
│   │   ├── __init__.py
│   │   ├── state.py
│   │   ├── graph.py
│   │   ├── cli.py
│   │   ├── web_ui.py
│   │   └── nodes.py
│   ├── common.py
│   ├── db_config.py
│   ├── logging_config.py
│   ├── query.py
│   ├── cache_config.py
│   ├── init_db_pgvector.py
│   └── job_manager.py
├── Procfile
├── example.png
├── docker-compose.yml
├── Makefile
├── .env.example
├── misc
│   ├── sync_embeddings.sh
│   ├── optimize_and_create_index.sh
│   ├── generate_embeddings_gpu.py
│   └── fetch_and_embed_new_comments.py
├── pyproject.toml
├── .gitignore
└── README.md

/hn_search/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/hn_search/rag/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/Procfile:
--------------------------------------------------------------------------------
1 | web: uv run python -m hn_search.rag.web_ui
2 | 
--------------------------------------------------------------------------------
/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/afiodorov/hn-search/HEAD/example.png
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 |   redis:
3 |     image: redis:7-alpine
4 |     container_name: hn-search-redis
5 |     ports:
6 |       - "6380:6379"
7 |     volumes:
8 |       - redis_data:/data
9 |     command: redis-server --appendonly yes
10 |     restart: unless-stopped
11 | 
12 | volumes:
13 |   redis_data:
--------------------------------------------------------------------------------
/hn_search/rag/state.py:
--------------------------------------------------------------------------------
1 | from typing import List, TypedDict
2 | 
3 | 
4 | class SearchResult(TypedDict):
5 |     id: str
6 |     author: str
7 |     type: str
8 |     text: str
9 |     timestamp: str
10 |     distance: float
11 | 
12 | 
13 | class RAGState(TypedDict):
14 |     query: str
15 |     search_results: List[SearchResult]
16 |     context: str
17 |     answer: str
18 |     error_message: str
19 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: format lint clean
2 | 
3 | # Default target - run both formatting and linting
4 | format:
5 | 	uv run ruff check --select I --fix .
6 | 	uv run ruff format .
7 | 
8 | # Just run import sorting
9 | imports:
10 | 	uv run ruff check --select I --fix .
11 | 
12 | # Just run code formatting
13 | fmt:
14 | 	uv run ruff format .
15 | 
16 | # Run linting (without fixes)
17 | lint:
18 | 	uv run ruff check .
19 | 
20 | # Clean Redis container and volumes
21 | clean:
22 | 	docker compose down -v
23 | 
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | # Database Configuration
2 | # PostgreSQL with pgvector extension
3 | DATABASE_URL=postgres://user:password@host:port/database
4 | 
5 | # LLM API Keys
6 | # Get from: https://platform.deepseek.com/
7 | DEEPSEEK_API_KEY=sk-your-api-key-here
8 | 
9 | # Redis Cache (optional but recommended)
10 | # Improves query performance by 50-100x for repeated queries
11 | REDIS_URL=redis://localhost:6379
12 | 
13 | # Google Cloud Platform (for BigQuery data fetching)
14 | # Only needed for incremental updates
15 | GOOGLE_CLOUD_PROJECT=your-gcp-project-id
16 | 
17 | # Tokenizers Configuration
18 | # Set to false to avoid warnings in multi-threaded environments
19 | TOKENIZERS_PARALLELISM=false
20 | 
21 | # Port Configuration (optional)
22 | # Default port for Gradio web UI
23 | PORT=7860
24 | 
--------------------------------------------------------------------------------
/hn_search/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sentence_transformers import SentenceTransformer
3 | 
4 | from hn_search.logging_config import get_logger
5 | 
6 | MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
7 | 
8 | # Singleton embedding model
9 | _model = None
10 | 
11 | logger = get_logger(__name__)
12 | 
13 | 
14 | def get_device():
15 |     return "mps" if torch.backends.mps.is_available() else "cpu"
16 | 
17 | 
18 | def get_model(device=None):
19 |     """Get or create singleton embedding model."""
20 |     global _model
21 |     if _model is None:
22 |         if device is None:
23 |             device = get_device()
24 |         logger.info(f"🔧 Loading embedding model on {device}...")
25 |         _model = SentenceTransformer(MODEL_NAME, device=device)
26 |         logger.info("✅ Embedding model loaded")
27 |     return _model
28 | 
--------------------------------------------------------------------------------
/hn_search/rag/graph.py:
--------------------------------------------------------------------------------
1 | from langgraph.graph import END, StateGraph
2 | 
3 | from hn_search.logging_config import get_logger
4 | 
5 | from .nodes import answer_node, retrieve_node
6 | from .state import RAGState
7 | 
8 | logger = get_logger(__name__)
9 | 
10 | # Singleton compiled workflow
11 | _compiled_workflow = None
12 | 
13 | 
14 | def create_rag_workflow():
15 |     """Get or create singleton compiled RAG workflow."""
16 |     global _compiled_workflow
17 |     if _compiled_workflow is None:
18 |         logger.info("🔧 Compiling RAG workflow...")
19 |         workflow = StateGraph(RAGState)
20 | 
21 |         workflow.add_node("retrieve", retrieve_node)
22 |         workflow.add_node("answer", answer_node)
23 | 
24 |         workflow.set_entry_point("retrieve")
25 |         workflow.add_edge("retrieve", "answer")
26 |         workflow.add_edge("answer", END)
27 | 
28 |         _compiled_workflow = workflow.compile()
29 |         logger.info("✅ RAG workflow compiled")
30 |     return _compiled_workflow
31 | 
--------------------------------------------------------------------------------
/hn_search/db_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from urllib.parse import urlparse
3 | 
4 | 
5 | def get_db_config():
6 |     """
7 |     Get database configuration from environment variables.
8 |     Supports both Railway environment variables and DATABASE_URL.
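
    Example with hypothetical values, showing how a DATABASE_URL is split
    into connection parameters:

        postgres://hn_user:secret@db.internal:5432/hn_search
        -> {"host": "db.internal", "port": 5432, "dbname": "hn_search",
            "user": "hn_user", "password": "secret"}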
9 | """ 10 | # First try DATABASE_URL (Railway standard) 11 | database_url = os.getenv("DATABASE_URL") 12 | if database_url: 13 | parsed = urlparse(database_url) 14 | return { 15 | "host": parsed.hostname, 16 | "port": parsed.port or 5432, 17 | "dbname": parsed.path[1:] if parsed.path else "postgres", 18 | "user": parsed.username, 19 | "password": parsed.password, 20 | } 21 | 22 | # Fallback to individual Railway environment variables 23 | return { 24 | "host": os.getenv("PGHOST", os.getenv("PGHOST_PRIVATE", "localhost")), 25 | "port": int(os.getenv("PGPORT", os.getenv("PGPORT_PRIVATE", "5432"))), 26 | "dbname": os.getenv("PGDATABASE", os.getenv("POSTGRES_DB", "hn_search")), 27 | "user": os.getenv("PGUSER", os.getenv("POSTGRES_USER", "postgres")), 28 | "password": os.getenv("PGPASSWORD", os.getenv("POSTGRES_PASSWORD", "postgres")), 29 | } 30 | -------------------------------------------------------------------------------- /hn_search/rag/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | from .graph import create_rag_workflow 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser( 9 | description="Ask questions about Hacker News discussions using RAG" 10 | ) 11 | parser.add_argument("query", type=str, help="Your question about HN discussions") 12 | args = parser.parse_args() 13 | 14 | print("\n" + "=" * 70) 15 | print("🔎 Hacker News RAG Search") 16 | print("=" * 70 + "\n") 17 | 18 | app = create_rag_workflow() 19 | 20 | initial_state = {"query": args.query} 21 | 22 | try: 23 | final_state = app.invoke(initial_state) 24 | 25 | if final_state.get("error_message"): 26 | print(f"\n❌ Error: {final_state['error_message']}\n") 27 | sys.exit(1) 28 | 29 | print("\n" + "-" * 70) 30 | print("💬 Answer:") 31 | print("-" * 70 + "\n") 32 | print(final_state["answer"]) 33 | 34 | print("\n" + "-" * 70) 35 | print(f"📚 Based on {len(final_state['search_results'])} HN comments/articles") 36 | print("-" * 70 + "\n") 37 | 38 | except Exception as e: 39 | print(f"\n❌ Error: {e}\n") 40 | sys.exit(1) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /misc/sync_embeddings.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | REMOTE_HOST="root@86.127.249.120" 4 | REMOTE_PORT="21604" 5 | REMOTE_DIR="/hn/embeddings" 6 | LOCAL_DIR="/Users/artiomfiodorov/code/hn-search/embeddings" 7 | 8 | mkdir -p "$LOCAL_DIR" 9 | 10 | while true; do 11 | echo "$(date): Checking for completed files..." 12 | 13 | completed_files=$(ssh -p "$REMOTE_PORT" "$REMOTE_HOST" ' 14 | cd '"$REMOTE_DIR"' 2>/dev/null || exit 0 15 | for f in *.parquet; do 16 | [ -f "$f" ] || continue 17 | size1=$(stat -c%s "$f" 2>/dev/null || stat -f%z "$f" 2>/dev/null) 18 | sleep 2 19 | size2=$(stat -c%s "$f" 2>/dev/null || stat -f%z "$f" 2>/dev/null) 20 | if [ "$size1" = "$size2" ]; then 21 | echo "$f" 22 | fi 23 | done 24 | ') 25 | 26 | if [ -n "$completed_files" ]; then 27 | echo "$(date): Found completed files, syncing..." 
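        # --remove-source-files tells rsync to delete each remote file once it
        # has been copied, so finished embedding files are moved off the remote
        # host rather than re-synced on every pass of this loop.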
28 | for file in $completed_files; do 29 | rsync -avz --remove-source-files -e "ssh -p $REMOTE_PORT" \ 30 | "$REMOTE_HOST:$REMOTE_DIR/$file" "$LOCAL_DIR/" 31 | echo "$(date): Synced $file" 32 | done 33 | else 34 | echo "$(date): No completed files to sync" 35 | fi 36 | 37 | sleep 60 38 | done 39 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "hn-search" 3 | version = "0.1.0" 4 | description = "Hacker News search using vector embeddings" 5 | requires-python = ">=3.13" 6 | dependencies = [ 7 | "sentence-transformers", 8 | "torch>=2.0.0", 9 | "torchvision>=0.15.0", 10 | "torchaudio>=2.0.0", 11 | "langgraph", 12 | "langchain", 13 | "langchain-openai", 14 | "python-dotenv", 15 | "gradio", 16 | "pyarrow", 17 | "pandas", 18 | "psycopg[binary,pool]>=3.0", 19 | "pgvector", 20 | "redis>=6.4.0", 21 | "langchain-community>=0.3.30", 22 | ] 23 | 24 | [[tool.uv.index]] 25 | name = "pytorch-cpu" 26 | url = "https://download.pytorch.org/whl/cpu" 27 | explicit = true 28 | 29 | [tool.uv.sources] 30 | torch = { index = "pytorch-cpu", marker = "sys_platform == 'linux'" } 31 | torchvision = { index = "pytorch-cpu", marker = "sys_platform == 'linux'" } 32 | torchaudio = { index = "pytorch-cpu", marker = "sys_platform == 'linux'" } 33 | 34 | [project.optional-dependencies] 35 | dev = [ 36 | "db-dtypes>=1.4.3", 37 | "google-cloud-bigquery>=3.38.0", 38 | "google-cloud-bigquery-storage>=2.0.0", 39 | "html2text>=2025.4.15", 40 | "ruff", 41 | ] 42 | prod = [ 43 | "gunicorn>=21.0.0", 44 | "uvicorn[standard]>=0.24.0", 45 | ] 46 | 47 | [tool.ruff] 48 | line-length = 88 49 | target-version = "py313" 50 | 51 | [build-system] 52 | requires = ["hatchling"] 53 | build-backend = "hatchling.build" 54 | 55 | [tool.hatch.build.targets.wheel] 56 | packages = ["hn_search"] 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /misc/optimize_and_create_index.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Optimizing PostgreSQL settings for HNSW index creation..." 4 | echo "" 5 | 6 | # First, cancel any existing index creation 7 | echo "Canceling any existing index creation..." 8 | docker exec hn-search-postgres-1 psql -U postgres -d hn_search -c \ 9 | "SELECT pg_cancel_backend(pid) FROM pg_stat_activity WHERE query LIKE '%CREATE INDEX%hn_documents_embedding_idx%' AND state = 'active';" 10 | 11 | # Increase maintenance_work_mem for this session (4GB for better performance) 12 | echo "Setting maintenance_work_mem to 4GB for faster index creation..." 13 | docker exec hn-search-postgres-1 psql -U postgres -d hn_search -c \ 14 | "SET maintenance_work_mem = '4GB'; CREATE INDEX IF NOT EXISTS hn_documents_embedding_idx ON hn_documents USING hnsw (embedding vector_cosine_ops);" & 15 | 16 | PID=$! 17 | echo "" 18 | echo "Index creation started with PID: $PID" 19 | echo "Using 4GB of maintenance memory for faster building" 20 | echo "" 21 | echo "Note: The warning about memory after 16,758 tuples is expected for large datasets." 22 | echo "With 4GB memory, it should handle much more before needing to spill to disk." 
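# The CREATE INDEX psql call above was started in the background (note the
# trailing "&" and captured PID), so the monitoring commands below can be run
# while the index is still building.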
23 | echo "" 24 | echo "You can check progress with:" 25 | echo " docker exec hn-search-postgres-1 psql -U postgres -d hn_search -c \"SELECT * FROM pg_stat_progress_create_index;\"" 26 | echo "" 27 | echo "Or check active queries:" 28 | echo " docker exec hn-search-postgres-1 psql -U postgres -d hn_search -c \"SELECT pid, now() - query_start AS duration, query FROM pg_stat_activity WHERE state = 'active' AND query LIKE '%INDEX%';\"" -------------------------------------------------------------------------------- /hn_search/logging_config.py: -------------------------------------------------------------------------------- 1 | """Centralized logging configuration for HN Search.""" 2 | 3 | import logging 4 | import sys 5 | import time 6 | from contextlib import contextmanager 7 | from functools import wraps 8 | from typing import Any, Callable 9 | 10 | 11 | def setup_logging(level: str = "INFO"): 12 | """ 13 | Configure logging for the application. 14 | 15 | Args: 16 | level: Logging level (DEBUG, INFO, WARNING, ERROR) 17 | """ 18 | logging.basicConfig( 19 | level=getattr(logging, level.upper()), 20 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 21 | datefmt="%Y-%m-%d %H:%M:%S", 22 | handlers=[logging.StreamHandler(sys.stdout)], 23 | ) 24 | 25 | 26 | def get_logger(name: str) -> logging.Logger: 27 | """Get a logger instance for a module.""" 28 | return logging.getLogger(name) 29 | 30 | 31 | @contextmanager 32 | def log_time(logger: logging.Logger, operation: str, level: str = "INFO"): 33 | """ 34 | Context manager to log execution time of an operation. 35 | 36 | Usage: 37 | with log_time(logger, "vector search"): 38 | # code here 39 | """ 40 | start = time.time() 41 | logger.log(getattr(logging, level.upper()), f"⏱️ {operation} - starting") 42 | try: 43 | yield 44 | finally: 45 | elapsed = time.time() - start 46 | logger.log( 47 | getattr(logging, level.upper()), 48 | f"⏱️ {operation} - completed in {elapsed:.2f}s", 49 | ) 50 | 51 | 52 | def log_time_decorator(operation: str = None, level: str = "INFO"): 53 | """ 54 | Decorator to log execution time of a function. 
55 | 56 | Usage: 57 | @log_time_decorator("process query") 58 | def my_function(): 59 | pass 60 | """ 61 | 62 | def decorator(func: Callable) -> Callable: 63 | op_name = operation or f"{func.__module__}.{func.__name__}" 64 | 65 | @wraps(func) 66 | def wrapper(*args, **kwargs) -> Any: 67 | logger = logging.getLogger(func.__module__) 68 | start = time.time() 69 | logger.log(getattr(logging, level.upper()), f"⏱️ {op_name} - starting") 70 | try: 71 | result = func(*args, **kwargs) 72 | return result 73 | finally: 74 | elapsed = time.time() - start 75 | logger.log( 76 | getattr(logging, level.upper()), 77 | f"⏱️ {op_name} - completed in {elapsed:.2f}s", 78 | ) 79 | 80 | return wrapper 81 | 82 | return decorator 83 | 84 | 85 | # Initialize logging on import 86 | import os 87 | 88 | log_level = os.environ.get("LOG_LEVEL", "INFO") 89 | setup_logging(log_level) 90 | -------------------------------------------------------------------------------- /misc/generate_embeddings_gpu.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from pathlib import Path 3 | 4 | import html2text 5 | import numpy as np 6 | import pandas as pd 7 | import pyarrow as pa 8 | import pyarrow.parquet as pq 9 | import torch 10 | from sentence_transformers import SentenceTransformer 11 | 12 | 13 | def strip_html(text: str) -> str: 14 | h = html2text.HTML2Text() 15 | h.ignore_links = True 16 | h.ignore_images = True 17 | h.ignore_emphasis = True 18 | return h.handle(text).strip() 19 | 20 | 21 | def generate_embeddings(data_dir: str = "data", output_dir: str = "embeddings"): 22 | device = "cuda" if torch.cuda.is_available() else "cpu" 23 | print(f"Using device: {device}") 24 | if device == "cuda": 25 | print(f"GPU: {torch.cuda.get_device_name(0)}") 26 | print(f"CUDA version: {torch.version.cuda}") 27 | 28 | model = SentenceTransformer( 29 | "sentence-transformers/all-mpnet-base-v2", device=device 30 | ) 31 | 32 | data_path = Path(data_dir) 33 | output_path = Path(output_dir) 34 | output_path.mkdir(exist_ok=True) 35 | 36 | parquet_files = sorted(glob(str(data_path / "since_2023_*"))) 37 | if not parquet_files: 38 | raise FileNotFoundError(f"No parquet files found in {data_dir}") 39 | 40 | print(f"Found {len(parquet_files)} parquet files") 41 | 42 | total_processed = 0 43 | batch_size = 5_000 44 | encode_batch_size = 128 45 | 46 | for file_idx, parquet_file in enumerate(parquet_files, 1): 47 | file_name = Path(parquet_file).name 48 | print(f"\n[{file_idx}/{len(parquet_files)}] Processing {file_name}") 49 | 50 | df = pd.read_parquet(parquet_file) 51 | df = df[df["text"].notna() & (df["text"] != "")] 52 | 53 | print(f" Loaded {len(df)} rows") 54 | 55 | all_embeddings = [] 56 | all_indices = [] 57 | 58 | for batch_start in range(0, len(df), batch_size): 59 | batch_df = df.iloc[batch_start : batch_start + batch_size].copy() 60 | 61 | batch_df["clean_text"] = batch_df["text"].astype(str).apply(strip_html) 62 | batch_df = batch_df[batch_df["clean_text"].str.len() > 0] 63 | 64 | if len(batch_df) == 0: 65 | continue 66 | 67 | documents = batch_df["clean_text"].tolist() 68 | 69 | print(f" Encoding batch of {len(documents)} documents...") 70 | embeddings = model.encode( 71 | documents, 72 | batch_size=encode_batch_size, 73 | show_progress_bar=False, 74 | convert_to_numpy=True, 75 | ) 76 | 77 | all_embeddings.extend(embeddings) 78 | all_indices.extend(batch_df.index.tolist()) 79 | 80 | total_processed += len(documents) 81 | print(f" Total processed: {total_processed}") 82 | 83 | if 
all_embeddings: 84 | result_df = df.loc[all_indices].copy() 85 | result_df["clean_text"] = result_df["text"].astype(str).apply(strip_html) 86 | 87 | embeddings_array = np.array([emb for emb in all_embeddings]) 88 | 89 | table = pa.Table.from_pandas(result_df) 90 | embeddings_list = pa.array( 91 | [emb.tolist() for emb in embeddings_array], type=pa.list_(pa.float32()) 92 | ) 93 | table = table.append_column("embedding", embeddings_list) 94 | 95 | output_file = output_path / f"{file_name}.parquet" 96 | pq.write_table(table, output_file, compression="snappy") 97 | print(f" Saved to {output_file}") 98 | 99 | print(f"\n✅ Successfully processed {total_processed} documents") 100 | 101 | 102 | if __name__ == "__main__": 103 | generate_embeddings() 104 | -------------------------------------------------------------------------------- /hn_search/query.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import psycopg 4 | from dotenv import load_dotenv 5 | from pgvector.psycopg import register_vector 6 | 7 | from hn_search.cache_config import ( 8 | cache_vector_search, 9 | get_cached_vector_search, 10 | ) 11 | from hn_search.common import get_model 12 | from hn_search.db_config import get_db_config 13 | 14 | load_dotenv() 15 | 16 | 17 | def query( 18 | query_text: str, 19 | n_results: int = 10, 20 | db_host: str = None, 21 | db_port: int = None, 22 | db_name: str = None, 23 | db_user: str = None, 24 | db_password: str = None, 25 | ): 26 | # Check cache first 27 | cached_results = get_cached_vector_search(query_text, n_results) 28 | if cached_results: 29 | print("Using cached results...", file=sys.stderr) 30 | for i, result in enumerate(cached_results, 1): 31 | print(f"=== Result {i} (distance: {result['distance']:.4f}) ===") 32 | print(f"ID: {result['id']}") 33 | print(f"Author: {result['author']}") 34 | print(f"Date: {result['timestamp']}") 35 | print(f"Type: {result['type']}") 36 | print(f"Text: {result['text']}") 37 | print() 38 | sys.stdout.flush() 39 | return 40 | 41 | print("Loading embedding model...", file=sys.stderr) 42 | model = get_model() 43 | 44 | print("Encoding query...", file=sys.stderr) 45 | query_embedding = model.encode([query_text])[0] 46 | 47 | print("Querying PostgreSQL...", file=sys.stderr) 48 | # Use Railway/environment variables if individual params not provided 49 | if not all([db_host, db_port, db_name, db_user, db_password]): 50 | db_config = get_db_config() 51 | else: 52 | db_config = { 53 | "host": db_host, 54 | "port": db_port, 55 | "dbname": db_name, 56 | "user": db_user, 57 | "password": db_password, 58 | } 59 | 60 | conn = psycopg.connect(**db_config) 61 | register_vector(conn) 62 | 63 | with conn.cursor() as cur: 64 | cur.execute( 65 | """ 66 | SELECT id, clean_text, author, timestamp, type, 67 | embedding <=> %s::vector AS distance 68 | FROM hn_documents 69 | ORDER BY embedding <=> %s::vector 70 | LIMIT %s 71 | """, 72 | (query_embedding.tolist(), query_embedding.tolist(), n_results), 73 | ) 74 | 75 | print("Fetching results...\n", file=sys.stderr) 76 | sys.stdout.flush() 77 | 78 | results = cur.fetchall() 79 | 80 | # Prepare results for caching 81 | cache_data = [] 82 | 83 | for i, (doc_id, document, author, timestamp, doc_type, distance) in enumerate( 84 | results, 1 85 | ): 86 | print(f"=== Result {i} (distance: {distance:.4f}) ===") 87 | print(f"ID: {doc_id}") 88 | print(f"Author: {author}") 89 | print(f"Date: {timestamp}") 90 | print(f"Type: {doc_type}") 91 | print(f"Text: {document}") 92 | print() 93 | 
sys.stdout.flush() 94 | 95 | # Add to cache data 96 | cache_data.append( 97 | { 98 | "id": doc_id, 99 | "text": document, 100 | "author": author, 101 | "timestamp": timestamp.isoformat() 102 | if hasattr(timestamp, "isoformat") 103 | else str(timestamp), 104 | "type": doc_type, 105 | "distance": float(distance), 106 | } 107 | ) 108 | 109 | # Cache the results 110 | if cache_data: 111 | cache_vector_search(query_text, cache_data, n_results) 112 | 113 | conn.close() 114 | 115 | 116 | if __name__ == "__main__": 117 | if len(sys.argv) < 2: 118 | print("Usage: python -m hn_search.query [n_results]") 119 | sys.exit(1) 120 | 121 | query_text = sys.argv[1] 122 | n_results = int(sys.argv[2]) if len(sys.argv) > 2 else 10 123 | 124 | query(query_text, n_results) 125 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[codz] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | # Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | # poetry.lock 109 | # poetry.toml 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 114 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control 115 | # pdm.lock 116 | # pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # pixi 121 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. 122 | # pixi.lock 123 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one 124 | # in the .venv directory. It is recommended not to include this directory in version control. 125 | .pixi 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # Redis 135 | *.rdb 136 | *.aof 137 | *.pid 138 | 139 | # RabbitMQ 140 | mnesia/ 141 | rabbitmq/ 142 | rabbitmq-data/ 143 | 144 | # ActiveMQ 145 | activemq-data/ 146 | 147 | # SageMath parsed files 148 | *.sage.py 149 | 150 | # Environments 151 | .env 152 | .envrc 153 | .venv 154 | env/ 155 | venv/ 156 | ENV/ 157 | env.bak/ 158 | venv.bak/ 159 | 160 | # Spyder project settings 161 | .spyderproject 162 | .spyproject 163 | 164 | # Rope project settings 165 | .ropeproject 166 | 167 | # mkdocs documentation 168 | /site 169 | 170 | # mypy 171 | .mypy_cache/ 172 | .dmypy.json 173 | dmypy.json 174 | 175 | # Pyre type checker 176 | .pyre/ 177 | 178 | # pytype static type analyzer 179 | .pytype/ 180 | 181 | # Cython debug symbols 182 | cython_debug/ 183 | 184 | # PyCharm 185 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 186 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 187 | # and can be added to the global gitignore or merged into this file. For a more nuclear 188 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 189 | # .idea/ 190 | 191 | # Abstra 192 | # Abstra is an AI-powered process automation framework. 193 | # Ignore directories containing user credentials, local state, and settings. 194 | # Learn more at https://abstra.io/docs 195 | .abstra/ 196 | 197 | # Visual Studio Code 198 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 199 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 200 | # and can be added to the global gitignore or merged into this file. 
However, if you prefer, 201 | # you could uncomment the following to ignore the entire vscode folder 202 | # .vscode/ 203 | 204 | # Ruff stuff: 205 | .ruff_cache/ 206 | 207 | # PyPI configuration file 208 | .pypirc 209 | 210 | # Marimo 211 | marimo/_static/ 212 | marimo/_lsp/ 213 | __marimo__/ 214 | 215 | # Streamlit 216 | .streamlit/secrets.toml 217 | .claude 218 | .gradio 219 | 220 | data/ 221 | -------------------------------------------------------------------------------- /hn_search/cache_config.py: -------------------------------------------------------------------------------- 1 | """Redis cache configuration for HN search.""" 2 | 3 | import hashlib 4 | import json 5 | import os 6 | from typing import Any, Dict, List, Optional 7 | from urllib.parse import urlparse 8 | 9 | import redis 10 | from dotenv import load_dotenv 11 | from langchain.globals import set_llm_cache 12 | from langchain_community.cache import RedisCache 13 | 14 | from hn_search.logging_config import get_logger 15 | 16 | # Load environment variables 17 | load_dotenv() 18 | 19 | logger = get_logger(__name__) 20 | 21 | 22 | def sanitize_url(url: str) -> str: 23 | """Sanitize URL to hide credentials.""" 24 | try: 25 | parsed = urlparse(url) 26 | if parsed.password: 27 | # Replace password with asterisks 28 | sanitized = parsed._replace( 29 | netloc=f"{parsed.username}:***@{parsed.hostname}:{parsed.port}" 30 | if parsed.port 31 | else f"{parsed.username}:***@{parsed.hostname}" 32 | ) 33 | return sanitized.geturl() 34 | return url 35 | except Exception: 36 | return "redis://***" 37 | 38 | 39 | # Redis configuration 40 | REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0") 41 | 42 | # Initialize Redis client 43 | try: 44 | redis_client = redis.from_url(REDIS_URL) 45 | # Test connection 46 | redis_client.ping() 47 | 48 | # Set up LangChain Redis cache 49 | cache = RedisCache(redis_client) 50 | set_llm_cache(cache) 51 | 52 | logger.info(f"✅ Redis cache initialized at {sanitize_url(REDIS_URL)}") 53 | except Exception as e: 54 | logger.warning(f"⚠️ Redis cache not available: {e}") 55 | logger.warning("🔄 Running without cache") 56 | redis_client = None 57 | 58 | 59 | # PostgreSQL query cache functions 60 | def get_pg_cache_key(query: str, params: tuple = ()) -> str: 61 | """Generate a cache key for PostgreSQL queries.""" 62 | # Combine query and params for unique key 63 | cache_str = f"{query}:{str(params)}" 64 | return f"pg:{hashlib.md5(cache_str.encode()).hexdigest()}" 65 | 66 | 67 | def get_cached_pg_results( 68 | query: str, params: tuple = () 69 | ) -> Optional[List[Dict[str, Any]]]: 70 | """Get cached PostgreSQL query results.""" 71 | if not redis_client: 72 | return None 73 | try: 74 | cache_key = get_pg_cache_key(query, params) 75 | cached = redis_client.get(cache_key) 76 | if cached: 77 | return json.loads(cached) 78 | except Exception: 79 | pass 80 | return None 81 | 82 | 83 | def cache_pg_results( 84 | query: str, results: List[Dict[str, Any]], params: tuple = (), ttl: int = 3600 85 | ): 86 | """Cache PostgreSQL query results with TTL (default 1 hour).""" 87 | if not redis_client: 88 | return 89 | try: 90 | cache_key = get_pg_cache_key(query, params) 91 | redis_client.setex(cache_key, ttl, json.dumps(results)) 92 | except Exception: 93 | pass 94 | 95 | 96 | # Vector search cache functions 97 | def get_vector_cache_key(query: str, k: int = 10) -> str: 98 | """Generate a cache key for vector search queries.""" 99 | return f"vector:{hashlib.md5(f'{query}:{k}'.encode()).hexdigest()}" 100 | 101 | 102 | def 
get_cached_vector_search(query: str, k: int = 10) -> Optional[List[Dict[str, Any]]]: 103 | """Get cached vector search results.""" 104 | if not redis_client: 105 | return None 106 | try: 107 | cache_key = get_vector_cache_key(query, k) 108 | cached = redis_client.get(cache_key) 109 | if cached: 110 | return json.loads(cached) 111 | except Exception: 112 | pass 113 | return None 114 | 115 | 116 | def cache_vector_search( 117 | query: str, results: List[Dict[str, Any]], k: int = 10, ttl: int = 21600 118 | ): 119 | """Cache vector search results with TTL (default 6 hours).""" 120 | if not redis_client: 121 | return 122 | try: 123 | cache_key = get_vector_cache_key(query, k) 124 | redis_client.setex(cache_key, ttl, json.dumps(results)) 125 | except Exception: 126 | pass 127 | 128 | 129 | # LangChain answer cache functions 130 | def get_answer_cache_key(query: str, context: str) -> str: 131 | """Generate a cache key for LLM answers.""" 132 | # Hash context to keep key manageable 133 | context_hash = hashlib.md5(context.encode()).hexdigest() 134 | return f"answer:{hashlib.md5(f'{query}:{context_hash}'.encode()).hexdigest()}" 135 | 136 | 137 | def get_cached_answer(query: str, context: str) -> Optional[str]: 138 | """Get cached LLM answer.""" 139 | if not redis_client: 140 | return None 141 | try: 142 | cache_key = get_answer_cache_key(query, context) 143 | cached = redis_client.get(cache_key) 144 | if cached: 145 | return cached.decode("utf-8") 146 | except Exception: 147 | pass 148 | return None 149 | 150 | 151 | def cache_answer(query: str, context: str, answer: str, ttl: int = 21600): 152 | """Cache LLM answer with TTL (default 6 hours).""" 153 | if not redis_client: 154 | return 155 | try: 156 | cache_key = get_answer_cache_key(query, context) 157 | redis_client.setex(cache_key, ttl, answer) 158 | except Exception: 159 | pass 160 | 161 | 162 | # Clear cache utility 163 | def clear_cache(pattern: str = "*"): 164 | """Clear cache entries matching pattern.""" 165 | if not redis_client: 166 | return 167 | try: 168 | keys = redis_client.keys(pattern) 169 | if keys: 170 | redis_client.delete(*keys) 171 | logger.info(f"🗑️ Cleared {len(keys)} cache entries") 172 | except Exception as e: 173 | logger.exception(f"❌ Error clearing cache: {e}") 174 | -------------------------------------------------------------------------------- /hn_search/init_db_pgvector.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from glob import glob 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | import psycopg 7 | from dotenv import load_dotenv 8 | from pgvector.psycopg import register_vector 9 | 10 | from hn_search.db_config import get_db_config 11 | 12 | load_dotenv() 13 | 14 | 15 | def get_connection(): 16 | """Get a fresh database connection with vector registration""" 17 | db_config = get_db_config() 18 | conn = psycopg.connect(**db_config) 19 | register_vector(conn) 20 | return conn 21 | 22 | 23 | def init_db_from_precomputed( 24 | embeddings_dir: str = "embeddings", 25 | test_mode: bool = False, 26 | ): 27 | db_config = get_db_config() 28 | print( 29 | f"Connecting to database at {db_config['host']}:{db_config['port']}/{db_config['dbname']}" 30 | ) 31 | 32 | # Initial connection for schema setup 33 | conn = get_connection() 34 | 35 | with conn.cursor() as cur: 36 | cur.execute("CREATE EXTENSION IF NOT EXISTS vector") 37 | 38 | cur.execute(""" 39 | CREATE TABLE IF NOT EXISTS hn_documents ( 40 | id TEXT PRIMARY KEY, 41 | clean_text TEXT NOT NULL, 42 | 
author TEXT, 43 | timestamp TEXT, 44 | type TEXT, 45 | embedding vector(768) 46 | ) 47 | """) 48 | conn.commit() 49 | 50 | cur.execute(""" 51 | CREATE TABLE IF NOT EXISTS processed_files ( 52 | filename TEXT PRIMARY KEY, 53 | processed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP 54 | ) 55 | """) 56 | conn.commit() 57 | 58 | conn.close() 59 | 60 | embeddings_path = Path(embeddings_dir) 61 | if not embeddings_path.is_absolute(): 62 | embeddings_path = Path.cwd() / embeddings_path 63 | 64 | parquet_files = sorted(glob(str(embeddings_path / "*.parquet"))) 65 | if not parquet_files: 66 | raise FileNotFoundError(f"No parquet files found in {embeddings_dir}") 67 | 68 | if test_mode: 69 | parquet_files = parquet_files[:1] 70 | print(f"TEST MODE: Processing only first file") 71 | 72 | print(f"Found {len(parquet_files)} parquet files with precomputed embeddings") 73 | 74 | total_loaded = 0 75 | 76 | for file_idx, parquet_file in enumerate(parquet_files, 1): 77 | file_name = Path(parquet_file).name 78 | 79 | # Fresh connection for each file 80 | print(f"\n[{file_idx}/{len(parquet_files)}] Processing {file_name}") 81 | 82 | max_retries = 3 83 | for retry in range(max_retries): 84 | try: 85 | conn = get_connection() 86 | 87 | # Check if already processed 88 | with conn.cursor() as cur: 89 | cur.execute( 90 | "SELECT 1 FROM processed_files WHERE filename = %s", 91 | (file_name,), 92 | ) 93 | if cur.fetchone(): 94 | print(f" Skipping {file_name} (already done)") 95 | conn.close() 96 | break 97 | 98 | print(f" Loading parquet file...") 99 | df = pd.read_parquet(parquet_file) 100 | 101 | if test_mode: 102 | df = df.head(100) 103 | print(f" TEST MODE: Limited to {len(df)} rows") 104 | 105 | print(f" Processing {len(df)} rows") 106 | 107 | # Start transaction for this file 108 | conn.execute("BEGIN") 109 | 110 | batch_size = 1000 111 | file_loaded = 0 112 | for batch_start in range(0, len(df), batch_size): 113 | batch_df = df.iloc[batch_start : batch_start + batch_size] 114 | 115 | records = [ 116 | ( 117 | str(row["id"]), 118 | row["clean_text"], 119 | str(row["author"]), 120 | str(row["timestamp"]), 121 | str(row["type"]), 122 | row["embedding"].tolist(), 123 | ) 124 | for _, row in batch_df.iterrows() 125 | ] 126 | 127 | with conn.cursor() as cur: 128 | cur.executemany( 129 | """ 130 | INSERT INTO hn_documents (id, clean_text, author, timestamp, type, embedding) 131 | VALUES (%s, %s, %s, %s, %s, %s) 132 | ON CONFLICT (id) DO NOTHING 133 | """, 134 | records, 135 | ) 136 | 137 | file_loaded += len(records) 138 | total_loaded += len(records) 139 | print(f" Batch: {file_loaded}/{len(df)} | Total: {total_loaded}") 140 | 141 | # Mark file as processed 142 | with conn.cursor() as cur: 143 | cur.execute( 144 | "INSERT INTO processed_files (filename) VALUES (%s) ON CONFLICT DO NOTHING", 145 | (file_name,), 146 | ) 147 | 148 | # Commit entire file at once 149 | conn.commit() 150 | print(f" ✅ Committed {file_name} ({file_loaded} records)") 151 | conn.close() 152 | break # Success - exit retry loop 153 | 154 | except Exception as e: 155 | print(f" ❌ Error (attempt {retry + 1}/{max_retries}): {e}") 156 | try: 157 | conn.rollback() 158 | conn.close() 159 | except: 160 | pass # Connection might already be closed 161 | 162 | if retry < max_retries - 1: 163 | print(f" 🔄 Retrying in 5 seconds...") 164 | import time 165 | 166 | time.sleep(5) 167 | else: 168 | print( 169 | f" ❌ Failed after {max_retries} attempts - skipping {file_name}" 170 | ) 171 | continue 172 | 173 | print(f"\n✅ Successfully loaded {total_loaded} documents 
into PostgreSQL") 174 | 175 | # Final count with fresh connection 176 | conn = get_connection() 177 | with conn.cursor() as cur: 178 | cur.execute("SELECT COUNT(*) FROM hn_documents") 179 | count = cur.fetchone()[0] 180 | print(f"Total documents in database: {count}") 181 | conn.close() 182 | 183 | 184 | if __name__ == "__main__": 185 | test_mode = "--test" in sys.argv 186 | init_db_from_precomputed(test_mode=test_mode) 187 | -------------------------------------------------------------------------------- /hn_search/job_manager.py: -------------------------------------------------------------------------------- 1 | """Job manager for deduplicating concurrent RAG queries using Redis.""" 2 | 3 | import hashlib 4 | import json 5 | import time 6 | from typing import Any, Dict, Optional, Tuple 7 | 8 | from hn_search.logging_config import get_logger 9 | 10 | logger = get_logger(__name__) 11 | 12 | 13 | class JobManager: 14 | """Manages RAG query jobs to prevent duplicate processing.""" 15 | 16 | def __init__(self, redis_client=None): 17 | self.redis = redis_client 18 | self.job_timeout = 180 # 3 minutes (reduced from 10) 19 | self.result_ttl = 21600 # 6 hours 20 | self.poll_interval = 0.5 # 500ms 21 | self.max_poll_time = 120 # 2 minutes (reduced from 5) 22 | 23 | def get_job_id(self, query: str) -> str: 24 | """Generate a unique job ID from query text.""" 25 | return hashlib.md5(query.strip().encode()).hexdigest() 26 | 27 | def try_claim_job(self, query: str) -> Tuple[bool, str]: 28 | """ 29 | Try to claim a job for processing. 30 | 31 | Returns: 32 | (claimed, job_id): claimed=True if this caller should process the job 33 | """ 34 | if not self.redis: 35 | # No Redis - always claim (fallback to original behavior) 36 | return True, self.get_job_id(query) 37 | 38 | job_id = self.get_job_id(query) 39 | status_key = f"job:{job_id}:status" 40 | 41 | try: 42 | # Check current status 43 | status = self.redis.get(status_key) 44 | 45 | if status == b"completed": 46 | # Already completed - don't claim 47 | return False, job_id 48 | 49 | if status == b"failed": 50 | # Failed job - allow immediate retry by deleting it 51 | self.redis.delete(status_key) 52 | self.redis.delete(f"job:{job_id}:error") 53 | logger.info(f"🔄 Clearing failed job {job_id[:8]} for retry") 54 | 55 | # Try to claim with atomic SET NX (only if not exists) 56 | claimed = self.redis.set( 57 | status_key, 58 | "processing", 59 | nx=True, # Only set if key doesn't exist 60 | ex=self.job_timeout, # Auto-expire after timeout 61 | ) 62 | 63 | return bool(claimed), job_id 64 | 65 | except Exception as e: 66 | logger.exception(f"⚠️ Job claim error: {e}") 67 | # Fallback: allow processing if Redis fails 68 | return True, job_id 69 | 70 | def wait_for_job( 71 | self, job_id: str, timeout: Optional[int] = None 72 | ) -> Optional[Dict]: 73 | """ 74 | Wait for another process to complete the job. 75 | 76 | Returns: 77 | Result dict if job completes, None if timeout or error 78 | """ 79 | if not self.redis: 80 | return None 81 | 82 | timeout = timeout or self.max_poll_time 83 | start_time = time.time() 84 | result_key = f"job:{job_id}:result" 85 | status_key = f"job:{job_id}:status" 86 | error_key = f"job:{job_id}:error" 87 | 88 | logger.info( 89 | f"⏳ Waiting for job {job_id[:8]}... 
(another request is processing)" 90 | ) 91 | 92 | while time.time() - start_time < timeout: 93 | try: 94 | status = self.redis.get(status_key) 95 | 96 | if status == b"completed": 97 | result = self.redis.get(result_key) 98 | if result: 99 | elapsed = time.time() - start_time 100 | logger.info( 101 | f"✅ Job {job_id[:8]} completed by another request (waited {elapsed:.2f}s)" 102 | ) 103 | return json.loads(result) 104 | 105 | elif status == b"failed": 106 | error = self.redis.get(error_key) 107 | error_msg = error.decode() if error else "Unknown error" 108 | logger.error(f"❌ Job {job_id[:8]} failed: {error_msg}") 109 | return None 110 | 111 | elif status is None: 112 | # Job disappeared (expired or deleted) 113 | logger.warning(f"⚠️ Job {job_id[:8]} disappeared") 114 | return None 115 | 116 | # Still processing - wait and retry 117 | time.sleep(self.poll_interval) 118 | 119 | except Exception as e: 120 | logger.exception(f"⚠️ Error polling job {job_id[:8]}: {e}") 121 | return None 122 | 123 | logger.warning(f"⏱️ Timeout waiting for job {job_id[:8]} after {timeout}s") 124 | # Force-clear the stuck job to allow retry 125 | self.clear_job(job_id) 126 | return None 127 | 128 | def clear_job(self, job_id: str): 129 | """Force-clear a stuck job to allow retry.""" 130 | if not self.redis: 131 | return 132 | 133 | try: 134 | status_key = f"job:{job_id}:status" 135 | result_key = f"job:{job_id}:result" 136 | error_key = f"job:{job_id}:error" 137 | progress_key = f"job:{job_id}:progress" 138 | 139 | self.redis.delete(status_key, result_key, error_key, progress_key) 140 | logger.info(f"🗑️ Cleared stuck job {job_id[:8]}") 141 | except Exception as e: 142 | logger.exception(f"⚠️ Error clearing job {job_id[:8]}: {e}") 143 | 144 | def store_result(self, job_id: str, result: Dict[str, Any]): 145 | """Store job result and mark as completed.""" 146 | if not self.redis: 147 | return 148 | 149 | try: 150 | result_key = f"job:{job_id}:result" 151 | status_key = f"job:{job_id}:status" 152 | 153 | # Store result with TTL 154 | self.redis.setex(result_key, self.result_ttl, json.dumps(result)) 155 | 156 | # Update status to completed 157 | self.redis.setex(status_key, self.result_ttl, "completed") 158 | 159 | logger.info(f"💾 Stored result for job {job_id[:8]}") 160 | 161 | except Exception as e: 162 | logger.exception(f"⚠️ Error storing result for job {job_id[:8]}: {e}") 163 | 164 | def store_error(self, job_id: str, error_message: str): 165 | """Store job error and mark as failed.""" 166 | if not self.redis: 167 | return 168 | 169 | try: 170 | error_key = f"job:{job_id}:error" 171 | status_key = f"job:{job_id}:status" 172 | 173 | self.redis.setex(error_key, 3600, error_message) # 1 hour TTL 174 | self.redis.setex(status_key, 3600, "failed") 175 | 176 | logger.info(f"💾 Stored error for job {job_id[:8]}") 177 | 178 | except Exception as e: 179 | logger.exception(f"⚠️ Error storing error for job {job_id[:8]}: {e}") 180 | 181 | def update_progress(self, job_id: str, progress: str): 182 | """Update job progress for streaming updates.""" 183 | if not self.redis: 184 | return 185 | 186 | try: 187 | progress_key = f"job:{job_id}:progress" 188 | self.redis.setex(progress_key, 300, progress) # 5 min TTL 189 | except Exception: 190 | pass # Silent fail for progress updates 191 | 192 | def get_progress(self, job_id: str) -> Optional[str]: 193 | """Get current job progress.""" 194 | if not self.redis: 195 | return None 196 | 197 | try: 198 | progress_key = f"job:{job_id}:progress" 199 | progress = self.redis.get(progress_key) 
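            # redis-py returns bytes from GET by default, so decode to str
            # before handing the progress text back to the caller.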
200 | return progress.decode() if progress else None 201 | except Exception: 202 | return None 203 | 204 | def get_result(self, job_id: str) -> Optional[Dict]: 205 | """Get completed job result.""" 206 | if not self.redis: 207 | return None 208 | 209 | try: 210 | result_key = f"job:{job_id}:result" 211 | result = self.redis.get(result_key) 212 | return json.loads(result) if result else None 213 | except Exception: 214 | return None 215 | 216 | def track_recent_query(self, query: str): 217 | """Track query in recent queries sorted set.""" 218 | if not self.redis: 219 | logger.warning("⚠️ Redis not available, cannot track recent query") 220 | return 221 | 222 | try: 223 | # Use sorted set with timestamp as score 224 | timestamp = time.time() 225 | self.redis.zadd("recent_queries", {query: timestamp}) 226 | 227 | # Keep only last 100 queries (trim older ones) 228 | self.redis.zremrangebyrank("recent_queries", 0, -101) 229 | 230 | logger.debug(f"✅ Tracked recent query: {query[:50]}...") 231 | 232 | except Exception as e: 233 | logger.exception(f"⚠️ Error tracking recent query: {e}") 234 | 235 | def get_recent_queries(self, limit: int = 10) -> list: 236 | """Get most recent queries with timestamps.""" 237 | if not self.redis: 238 | return [] 239 | 240 | try: 241 | # Get top queries in reverse order (most recent first) 242 | queries = self.redis.zrevrange( 243 | "recent_queries", 0, limit - 1, withscores=True 244 | ) 245 | 246 | # Format results with human-readable timestamps 247 | from datetime import datetime 248 | 249 | result = [] 250 | for query_bytes, timestamp in queries: 251 | query = query_bytes.decode("utf-8") 252 | dt = datetime.fromtimestamp(timestamp) 253 | time_ago = self._format_time_ago(timestamp) 254 | result.append( 255 | {"query": query, "timestamp": dt.isoformat(), "time_ago": time_ago} 256 | ) 257 | 258 | return result 259 | 260 | except Exception as e: 261 | logger.exception(f"⚠️ Error getting recent queries: {e}") 262 | return [] 263 | 264 | def _format_time_ago(self, timestamp: float) -> str: 265 | """Format timestamp as human-readable 'time ago' string.""" 266 | seconds_ago = time.time() - timestamp 267 | 268 | if seconds_ago < 60: 269 | return "just now" 270 | elif seconds_ago < 3600: 271 | minutes = int(seconds_ago / 60) 272 | return f"{minutes}m ago" 273 | elif seconds_ago < 86400: 274 | hours = int(seconds_ago / 3600) 275 | return f"{hours}h ago" 276 | else: 277 | days = int(seconds_ago / 86400) 278 | return f"{days}d ago" 279 | -------------------------------------------------------------------------------- /hn_search/rag/web_ui.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import gradio as gr 5 | 6 | from hn_search.cache_config import redis_client 7 | from hn_search.job_manager import JobManager 8 | from hn_search.logging_config import get_logger 9 | 10 | from .graph import create_rag_workflow 11 | 12 | # Initialize job manager 13 | job_manager = JobManager(redis_client) 14 | 15 | logger = get_logger(__name__) 16 | 17 | 18 | def hn_search_rag(query: str): 19 | if not query.strip(): 20 | yield "Please enter a question.", "", "", "" 21 | return 22 | 23 | # Try to claim this job 24 | claimed, job_id = job_manager.try_claim_job(query) 25 | 26 | # Track query immediately so it shows up in recent queries 27 | job_manager.track_recent_query(query) 28 | 29 | progress_log = [] 30 | current_sources = "" 31 | current_answer = "" 32 | current_recent = _format_recent_queries() 33 | 
logger.debug(f"Recent queries at start: {current_recent[:100]}...") 34 | 35 | if not claimed: 36 | # Another request is processing this query - poll for progress 37 | progress_log.append(f"🔍 Searching for: {query}") 38 | yield "\n".join(progress_log), "", "", current_recent 39 | 40 | # Poll for progress updates from the processing job 41 | timeout = job_manager.max_poll_time 42 | start_time = time.time() 43 | last_progress = "" 44 | 45 | while time.time() - start_time < timeout: 46 | # Check progress 47 | current_progress = job_manager.get_progress(job_id) 48 | if current_progress and current_progress != last_progress: 49 | # Mirror the progress from the processing job 50 | last_progress = current_progress 51 | yield current_progress, "", "", current_recent 52 | 53 | # Check if complete 54 | result = job_manager.get_result(job_id) 55 | if result: 56 | # Job completed - show final state 57 | final_progress = f"🔍 Searching for: {query}\n📚 Retrieving relevant HN comments...\n🤖 Generating answer with DeepSeek...\n✅ Complete!" 58 | # Refresh recent queries after completion 59 | current_recent = _format_recent_queries() 60 | yield ( 61 | final_progress, 62 | result.get("answer", ""), 63 | result.get("sources", ""), 64 | current_recent, 65 | ) 66 | return 67 | 68 | time.sleep(0.5) 69 | 70 | # Timeout - try to claim and process ourselves 71 | progress_log = [f"🔍 Searching for: {query}"] 72 | progress_log.append( 73 | "⚠️ Timeout waiting for other request, processing query now..." 74 | ) 75 | yield "\n".join(progress_log), "", "", current_recent 76 | claimed, job_id = job_manager.try_claim_job(query) 77 | if not claimed: 78 | yield ( 79 | "\n".join(progress_log) + "\n❌ Unable to process query", 80 | "", 81 | "", 82 | current_recent, 83 | ) 84 | return 85 | 86 | # We claimed the job - process it 87 | app = create_rag_workflow() 88 | initial_state = {"query": query} 89 | 90 | try: 91 | progress_log.append(f"🔍 Searching for: {query}") 92 | job_manager.update_progress(job_id, "\n".join(progress_log)) 93 | yield "\n".join(progress_log), "", "", current_recent 94 | 95 | for event in app.stream(initial_state): 96 | node_name = list(event.keys())[0] 97 | result_state = event[node_name] 98 | 99 | if result_state is None: 100 | continue 101 | 102 | node_messages = { 103 | "retrieve": "📚 Retrieving relevant HN comments...", 104 | "answer": "🤖 Generating answer with DeepSeek...", 105 | } 106 | 107 | if node_name in node_messages: 108 | progress_log.append(node_messages[node_name]) 109 | job_manager.update_progress(job_id, "\n".join(progress_log)) 110 | yield ( 111 | "\n".join(progress_log), 112 | current_answer, 113 | current_sources, 114 | current_recent, 115 | ) 116 | 117 | if "search_results" in result_state and result_state["search_results"]: 118 | sources = [] 119 | for i, r in enumerate(result_state["search_results"], 1): 120 | hn_link = f"https://news.ycombinator.com/item?id={r['id']}" 121 | sources.append( 122 | f"**[{i}]** [{r['author']}]({hn_link}) ({r['timestamp']})\n\n{r['text'][:200]}..." 
123 | ) 124 | current_sources = "\n\n---\n\n".join(sources) 125 | yield ( 126 | "\n".join(progress_log), 127 | current_answer, 128 | current_sources, 129 | current_recent, 130 | ) 131 | 132 | if "answer" in result_state: 133 | current_answer = result_state["answer"] 134 | yield ( 135 | "\n".join(progress_log), 136 | current_answer, 137 | current_sources, 138 | current_recent, 139 | ) 140 | 141 | if "error_message" in result_state: 142 | error_msg = f"❌ Error: {result_state['error_message']}" 143 | progress_log.append(error_msg) 144 | job_manager.store_error(job_id, result_state["error_message"]) 145 | yield "\n".join(progress_log), "", "", current_recent 146 | return 147 | 148 | # Store successful result for other waiting requests 149 | job_manager.store_result( 150 | job_id, {"answer": current_answer, "sources": current_sources} 151 | ) 152 | 153 | # Refresh recent queries after storing result 154 | current_recent = _format_recent_queries() 155 | 156 | progress_log.append("✅ Complete!") 157 | yield "\n".join(progress_log), current_answer, current_sources, current_recent 158 | 159 | except Exception as e: 160 | error_msg = f"❌ Error: {str(e)}" 161 | progress_log.append(error_msg) 162 | job_manager.store_error(job_id, str(e)) 163 | yield "\n".join(progress_log), current_answer, current_sources, current_recent 164 | return 165 | 166 | 167 | def _format_recent_queries() -> str: 168 | """Format recent queries as markdown.""" 169 | recent = job_manager.get_recent_queries(limit=10) 170 | 171 | if not recent: 172 | return "*No recent queries yet*" 173 | 174 | lines = [] 175 | for item in recent: 176 | query = item["query"] 177 | time_ago = item["time_ago"] 178 | # Make queries clickable (URL-encoded) 179 | import urllib.parse 180 | 181 | encoded_query = urllib.parse.quote(query) 182 | lines.append(f"- **{time_ago}**: [{query}](?q={encoded_query})") 183 | 184 | return "\n".join(lines) 185 | 186 | 187 | def create_interface(): 188 | with gr.Blocks( 189 | title="🔎 Hacker News RAG Search", 190 | head=""" 191 | 209 | """, 210 | ) as demo: 211 | gr.Markdown( 212 | """ 213 | # 🔎 Hacker News RAG Search 214 | 215 | Ask questions about Hacker News discussions and get AI-powered answers! 216 | """ 217 | ) 218 | 219 | query_input = gr.Textbox( 220 | label="Your Question", 221 | placeholder="What do people think about Rust vs Go?", 222 | lines=2, 223 | elem_id="query_input", 224 | ) 225 | 226 | search_button = gr.Button("🔍 Search", variant="primary") 227 | 228 | progress_output = gr.Textbox( 229 | label="Progress", lines=5, interactive=False, value="Ready to search..." 
230 | ) 231 | 232 | answer_output = gr.Markdown(label="💬 Answer", value="") 233 | 234 | with gr.Accordion("📚 Source Comments", open=False): 235 | sources_output = gr.Markdown(value="") 236 | 237 | with gr.Accordion("🕒 Recent Queries", open=False): 238 | recent_output = gr.Markdown(value="*Loading recent queries...*") 239 | 240 | # Auto-refresh recent queries every 3 seconds 241 | timer = gr.Timer(value=3, active=True) 242 | timer.tick(fn=_format_recent_queries, inputs=[], outputs=[recent_output]) 243 | 244 | # Hidden HTML component for JavaScript execution 245 | html_output = gr.HTML(visible=False) 246 | 247 | def search_and_update_url(query: str): 248 | """Search and update URL in browser.""" 249 | query = query or "" 250 | for result in hn_search_rag(query): 251 | yield result + ("",) # Add empty string for HTML output 252 | 253 | # Set up search action 254 | search_button.click( 255 | fn=search_and_update_url, 256 | inputs=[query_input], 257 | outputs=[ 258 | progress_output, 259 | answer_output, 260 | sources_output, 261 | recent_output, 262 | html_output, 263 | ], 264 | show_progress="full", 265 | ) 266 | 267 | # Add JavaScript click handler to update URL 268 | search_button.click( 269 | fn=None, 270 | inputs=[query_input], 271 | outputs=[], 272 | js="(query) => { console.log('Updating URL:', query); window.updateUrlWithSearch(query); }", 273 | ) 274 | 275 | # Handle URL parameters and auto-search on load 276 | def load_and_search_from_url(request: gr.Request): 277 | """Load query parameters from URL and auto-search if present.""" 278 | recent_queries = _format_recent_queries() 279 | 280 | if request: 281 | query = request.query_params.get("q", "") 282 | logger.info(f"📎 Loading from URL: q='{query}'") 283 | 284 | if query: 285 | logger.info(f"🔍 Auto-searching for: {query}") 286 | # Yield initial state with query and recent queries immediately 287 | yield query, "Ready to search...", "", "", recent_queries, "" 288 | 289 | # Then stream search results 290 | for progress, answer, sources, recent in hn_search_rag(query): 291 | yield query, progress, answer, sources, recent, "" 292 | return 293 | 294 | # No query - yield once to update recent queries immediately without loading state 295 | yield "", "Ready to search...", "", "", recent_queries, "" 296 | 297 | # Set up load handler to populate fields and auto-search from URL 298 | demo.load( 299 | fn=load_and_search_from_url, 300 | inputs=[], 301 | outputs=[ 302 | query_input, 303 | progress_output, 304 | answer_output, 305 | sources_output, 306 | recent_output, 307 | html_output, 308 | ], 309 | ) 310 | 311 | return demo 312 | 313 | 314 | demo = create_interface() 315 | 316 | if __name__ == "__main__": 317 | logger.info("🔎 Starting HN RAG Search Web Interface...") 318 | logger.info("✨ Features:") 319 | logger.info(" • URL parameter support: ?q=query") 320 | logger.info(" • Auto-search from URL parameters") 321 | demo.launch( 322 | server_name="0.0.0.0", 323 | server_port=int(os.environ.get("PORT", 7860)), 324 | share=False, 325 | ) 326 | -------------------------------------------------------------------------------- /hn_search/rag/nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | 5 | import psycopg 6 | from dotenv import load_dotenv 7 | from langchain_openai import ChatOpenAI 8 | from pgvector.psycopg import register_vector 9 | from psycopg_pool import ConnectionPool 10 | from sentence_transformers 
import SentenceTransformer 11 | 12 | from hn_search.cache_config import ( 13 | cache_answer, 14 | cache_vector_search, 15 | get_cached_answer, 16 | get_cached_vector_search, 17 | ) 18 | from hn_search.common import get_device, get_model 19 | from hn_search.db_config import get_db_config 20 | from hn_search.logging_config import get_logger, log_time 21 | 22 | from .state import RAGState, SearchResult 23 | 24 | load_dotenv() 25 | 26 | logger = get_logger(__name__) 27 | 28 | # Initialize connection pool (singleton) 29 | _connection_pool = None 30 | 31 | # Cache partition names (refreshed periodically) 32 | _partition_cache = None 33 | _partition_cache_timestamp = 0 34 | _PARTITION_CACHE_TTL = 3600 # 1 hour 35 | 36 | 37 | def get_connection_pool(): 38 | global _connection_pool 39 | if _connection_pool is None: 40 | db_config = get_db_config() 41 | # Create connection string from config 42 | conn_string = f"host={db_config['host']} port={db_config['port']} dbname={db_config['dbname']} user={db_config['user']} password={db_config['password']}" 43 | _connection_pool = ConnectionPool( 44 | conn_string, 45 | min_size=2, 46 | max_size=20, 47 | max_idle=300, # Close connections idle for 5 minutes 48 | max_lifetime=3600, # Replace connections after 1 hour 49 | kwargs={ 50 | "prepare_threshold": None, # Disable prepared statements for pgvector 51 | "keepalives": 1, # Enable TCP keepalive 52 | "keepalives_idle": 30, # Send keepalive after 30s idle 53 | "keepalives_interval": 10, # Interval between keepalives 54 | "keepalives_count": 5, # Failed keepalives before declaring dead 55 | }, 56 | configure=lambda conn: register_vector( 57 | conn 58 | ), # Register vector on each connection 59 | check=ConnectionPool.check_connection, # Check connection health before use 60 | ) 61 | return _connection_pool 62 | 63 | 64 | def _query_partition(pool, partition_name: str, query_embedding, n_results: int): 65 | """Query a single partition for nearest neighbors.""" 66 | try: 67 | with pool.connection() as conn: 68 | with conn.cursor() as cur: 69 | cur.execute( 70 | f""" 71 | SELECT id, clean_text, author, timestamp, type, 72 | embedding <=> %s::vector AS distance 73 | FROM {partition_name} 74 | ORDER BY embedding <=> %s::vector 75 | LIMIT %s 76 | """, 77 | (query_embedding.tolist(), query_embedding.tolist(), n_results), 78 | ) 79 | return cur.fetchall() 80 | except Exception as e: 81 | logger.warning(f"Error querying partition {partition_name}: {e}") 82 | return [] 83 | 84 | 85 | def _get_partitions(pool): 86 | """ 87 | Get list of all partition table names. 88 | 89 | Cached for 1 hour since partitions rarely change (only monthly). 
90 | """ 91 | global _partition_cache, _partition_cache_timestamp 92 | 93 | current_time = time.time() 94 | 95 | # Return cached value if still fresh 96 | if ( 97 | _partition_cache 98 | and (current_time - _partition_cache_timestamp) < _PARTITION_CACHE_TTL 99 | ): 100 | logger.debug( 101 | f"Using cached partition list ({len(_partition_cache)} partitions)" 102 | ) 103 | return _partition_cache 104 | 105 | # Fetch fresh partition list 106 | logger.info("Fetching partition list from database") 107 | with pool.connection() as conn: 108 | with conn.cursor() as cur: 109 | cur.execute(""" 110 | SELECT tablename 111 | FROM pg_tables 112 | WHERE schemaname = 'public' 113 | AND tablename LIKE 'hn_documents_%' 114 | ORDER BY tablename 115 | """) 116 | partitions = [row[0] for row in cur.fetchall()] 117 | 118 | # Update cache 119 | _partition_cache = partitions 120 | _partition_cache_timestamp = current_time 121 | logger.info(f"Cached {len(partitions)} partitions") 122 | 123 | return partitions 124 | 125 | 126 | def _parallel_partition_search(pool, query_embedding, n_results: int): 127 | """ 128 | Query all partitions in parallel and return combined results. 129 | 130 | Uses ThreadPoolExecutor to query multiple partitions concurrently, 131 | then merges results. Much faster than sequential partition scanning. 132 | """ 133 | partitions = _get_partitions(pool) 134 | logger.info(f"Querying {len(partitions)} partitions in parallel") 135 | 136 | all_results = [] 137 | 138 | # Use max_workers based on connection pool size (leave some headroom) 139 | max_workers = min(15, len(partitions)) # Use up to 15 connections 140 | 141 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 142 | # Submit all partition queries 143 | future_to_partition = { 144 | executor.submit( 145 | _query_partition, pool, partition, query_embedding, n_results 146 | ): partition 147 | for partition in partitions 148 | } 149 | 150 | # Collect results as they complete 151 | for future in as_completed(future_to_partition): 152 | partition = future_to_partition[future] 153 | try: 154 | results = future.result() 155 | all_results.extend(results) 156 | logger.debug(f"Partition {partition}: {len(results)} results") 157 | except Exception as e: 158 | logger.exception(f"Exception querying partition {partition}: {e}") 159 | 160 | logger.info( 161 | f"Retrieved {len(all_results)} total results from {len(partitions)} partitions" 162 | ) 163 | return all_results 164 | 165 | 166 | def retrieve_node(state: RAGState) -> RAGState: 167 | query = state["query"] 168 | n_results = 10 169 | 170 | try: 171 | # Check cache first 172 | with log_time(logger, "cache lookup"): 173 | cached_results = get_cached_vector_search(query, n_results) 174 | 175 | if cached_results: 176 | logger.info(f"🔍 Using cached results for: {query}") 177 | search_results = [] 178 | for result in cached_results: 179 | search_results.append( 180 | SearchResult( 181 | id=result["id"], 182 | author=result["author"], 183 | type=result["type"], 184 | text=result["text"], 185 | timestamp=result["timestamp"], 186 | distance=result["distance"], 187 | ) 188 | ) 189 | 190 | context = "\n\n---\n\n".join( 191 | [ 192 | f"[{i + 1}] Author: {r['author']} ({r['timestamp']})\nLink: https://news.ycombinator.com/item?id={r['id']}\n{r['text']}" 193 | for i, r in enumerate(search_results) 194 | ] 195 | ) 196 | 197 | logger.info( 198 | f"✅ Found {len(search_results)} relevant comments/articles (cached)" 199 | ) 200 | 201 | return { 202 | **state, 203 | "search_results": search_results, 204 | 
"context": context, 205 | } 206 | 207 | logger.info(f"🔍 Searching for: {query}") 208 | 209 | # Use singleton embedding model 210 | with log_time(logger, "query embedding generation"): 211 | model = get_model() 212 | query_embedding = model.encode([query])[0] 213 | 214 | # Get connection from pool 215 | pool = get_connection_pool() 216 | 217 | # Query partitions in parallel 218 | with log_time(logger, "parallel vector search across partitions"): 219 | all_results = _parallel_partition_search(pool, query_embedding, n_results) 220 | 221 | # Merge and sort results from all partitions 222 | with log_time(logger, "merging partition results"): 223 | # Sort by distance and take top n_results 224 | all_results.sort(key=lambda x: x[5]) # Sort by distance (index 5) 225 | results = all_results[:n_results] 226 | 227 | search_results = [] 228 | cache_data = [] 229 | 230 | for ( 231 | doc_id, 232 | document, 233 | author, 234 | timestamp, 235 | doc_type, 236 | distance, 237 | ) in results: 238 | search_results.append( 239 | SearchResult( 240 | id=doc_id, 241 | author=author, 242 | type=doc_type, 243 | text=document, 244 | timestamp=timestamp, 245 | distance=distance, 246 | ) 247 | ) 248 | # Prepare for caching 249 | cache_data.append( 250 | { 251 | "id": doc_id, 252 | "text": document, 253 | "author": author, 254 | "timestamp": timestamp.isoformat() 255 | if hasattr(timestamp, "isoformat") 256 | else str(timestamp), 257 | "type": doc_type, 258 | "distance": float(distance), 259 | } 260 | ) 261 | 262 | # Cache the results 263 | if cache_data: 264 | with log_time(logger, "caching search results"): 265 | cache_vector_search(query, cache_data, n_results) 266 | 267 | context = "\n\n---\n\n".join( 268 | [ 269 | f"[{i + 1}] Author: {r['author']} ({r['timestamp']})\nLink: https://news.ycombinator.com/item?id={r['id']}\n{r['text']}" 270 | for i, r in enumerate(search_results) 271 | ] 272 | ) 273 | 274 | logger.info(f"✅ Found {len(search_results)} relevant comments/articles") 275 | 276 | return { 277 | **state, 278 | "search_results": search_results, 279 | "context": context, 280 | } 281 | except Exception as e: 282 | error_msg = f"vector type not found in the database" 283 | logger.exception(f"Database error: {str(e)}") 284 | return { 285 | **state, 286 | "error_message": error_msg, 287 | "search_results": [], 288 | "context": "", 289 | } 290 | 291 | 292 | def answer_node(state: RAGState) -> RAGState: 293 | query = state["query"] 294 | context = state["context"] 295 | 296 | # Check cache first 297 | with log_time(logger, "answer cache lookup"): 298 | cached_answer = get_cached_answer(query, context) 299 | 300 | if cached_answer: 301 | logger.info("🤖 Using cached answer") 302 | return { 303 | **state, 304 | "answer": cached_answer, 305 | } 306 | 307 | logger.info("🤖 Generating answer with DeepSeek...") 308 | 309 | with log_time(logger, "LLM answer generation"): 310 | llm = ChatOpenAI( 311 | model="deepseek-chat", 312 | api_key=os.getenv("DEEPSEEK_API_KEY"), 313 | base_url="https://api.deepseek.com", 314 | temperature=0.7, 315 | ) 316 | 317 | prompt = f"""You are a helpful assistant answering questions about Hacker News discussions. 318 | 319 | User Question: {query} 320 | 321 | Here are relevant comments and articles from Hacker News: 322 | 323 | {context} 324 | 325 | Please provide a comprehensive answer to the user's question based on the context above. 326 | If the context doesn't contain enough information, say so. 
327 | 328 | When citing comments, use this format: 329 | - For quotes: As user AuthorName puts it, "quote here" [[1]](link) 330 | - For paraphrasing: User AuthorName explains that... [[2]](link) 331 | - For multiple references: Several users [[3]](link1) [[4]](link2) discuss... 332 | 333 | The [number] should match the source number from the context above, and should be a clickable link to the HN comment. 334 | 335 | Example response format: 336 | The community has mixed views on this topic. As user john_doe explains, "Python is great for prototyping" [[1]](https://news.ycombinator.com/item?id=12345). Meanwhile, user jane_smith argues that performance can be an issue [[2]](https://news.ycombinator.com/item?id=67890).""" 337 | 338 | response = llm.invoke(prompt) 339 | answer = response.content 340 | 341 | # Cache the answer 342 | with log_time(logger, "caching answer"): 343 | cache_answer(query, context, answer) 344 | 345 | logger.info("✅ Answer generated") 346 | 347 | return { 348 | **state, 349 | "answer": answer, 350 | } 351 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🔎 HN Search: Semantic Search & RAG for Hacker News 2 | 3 | A production-ready semantic search engine and RAG (Retrieval-Augmented Generation) system for Hacker News comments, built with vector embeddings, PostgreSQL with pgvector, and LangGraph. 4 | 5 | ## 🎯 Project Overview 6 | 7 | This project implements a full-stack semantic search system over **millions of Hacker News comments** (2023-2025), enabling natural language queries and AI-powered question answering. Unlike traditional keyword search, it uses dense vector embeddings to understand semantic meaning, finding relevant discussions even when exact keywords don't match. 
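
To make that distinction concrete, here is a minimal, self-contained sketch (illustrative only, not code from this repository) using the same `all-mpnet-base-v2` embedding model the project relies on: candidates are ranked by meaning, so a question about on-call pain still surfaces a comment about being paged at 3am even though the words never overlap.

```python
from sentence_transformers import SentenceTransformer, util

# Same embedding model the project uses (768-dim vectors).
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

query = "How do engineers feel about on-call rotations?"
comments = [
    "Being paged at 3am every other week burned me out completely.",
    "Our team rotates pager duty weekly and offers comp time afterwards.",
    "I prefer Rust over Go for systems programming.",
]

# Encode the query and candidate comments, then rank by cosine similarity.
query_emb = model.encode(query, convert_to_tensor=True)
comment_embs = model.encode(comments, convert_to_tensor=True)
scores = util.cos_sim(query_emb, comment_embs)[0]

for score, text in sorted(zip(scores.tolist(), comments), reverse=True):
    print(f"{score:.3f}  {text}")
# The pager/on-call comments score highest despite sharing no keywords with the query.
```
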
8 | 9 | **Live Demo**: 10 | 11 | Hosted at [hn.fiodorov.es](https://hn.fiodorov.es) 12 | 13 | **Key Features:** 14 | - 🔍 **Semantic Search**: Natural language queries with cosine similarity ranking 15 | - 🤖 **RAG System**: AI-powered Q&A using LangGraph workflows and DeepSeek LLM 16 | - 📊 **Scalable Architecture**: PostgreSQL with pgvector for production-grade vector search 17 | - ⚡ **Redis Caching**: Sub-second response times for repeated queries 18 | - 🔄 **Incremental Updates**: Idempotent data pipeline for fetching new HN comments 19 | - 🎨 **Web Interface**: Gradio-based UI with URL parameter support 20 | - 🏗️ **Partitioned Tables**: Time-based partitioning for efficient query performance 21 | 22 | ## 🏗️ Architecture 23 | 24 | ``` 25 | ┌─────────────────────────────────────────────────────────────────┐ 26 | │ Data Pipeline │ 27 | ├─────────────────────────────────────────────────────────────────┤ 28 | │ BigQuery (HN Public Dataset) │ 29 | │ ↓ │ 30 | │ Fetch New Comments (idempotent, resumable) │ 31 | │ ↓ │ 32 | │ Generate Embeddings (sentence-transformers, MPS/CUDA) │ 33 | │ ↓ │ 34 | │ PostgreSQL + pgvector (partitioned by month) │ 35 | └─────────────────────────────────────────────────────────────────┘ 36 | 37 | ┌─────────────────────────────────────────────────────────────────┐ 38 | │ Query System │ 39 | ├─────────────────────────────────────────────────────────────────┤ 40 | │ User Query │ 41 | │ ↓ │ 42 | │ Encode with sentence-transformers/all-mpnet-base-v2 │ 43 | │ ↓ │ 44 | │ Redis Cache Check ────────────────┐ │ 45 | │ ↓ │ (cache hit) │ 46 | │ PostgreSQL Vector Search (cosine) │ │ 47 | │ ↓ │ │ 48 | │ Cache Results ─────────────────────┘ │ 49 | │ ↓ │ 50 | │ Return Top K Documents │ 51 | └─────────────────────────────────────────────────────────────────┘ 52 | 53 | ┌─────────────────────────────────────────────────────────────────┐ 54 | │ RAG System │ 55 | ├─────────────────────────────────────────────────────────────────┤ 56 | │ LangGraph Workflow (StateGraph) │ 57 | │ ↓ │ 58 | │ [Retrieve] → Vector Search → Top 10 Comments │ 59 | │ ↓ │ 60 | │ [Answer] → DeepSeek LLM → Generated Response │ 61 | │ ↓ │ 62 | │ Gradio Web UI (with sources & citations) │ 63 | └─────────────────────────────────────────────────────────────────┘ 64 | ``` 65 | 66 | ## 🛠️ Technical Stack 67 | 68 | ### Core Technologies 69 | - **Language**: Python 3.13 70 | - **Package Manager**: [uv](https://github.com/astral-sh/uv) (fast Python package installer) 71 | - **Database**: PostgreSQL 16 + [pgvector](https://github.com/pgvector/pgvector) extension 72 | - **Vector Model**: [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) (768-dim embeddings) 73 | - **LLM**: DeepSeek via OpenAI-compatible API 74 | - **Cache**: Redis 6.x 75 | - **Data Source**: [BigQuery HN Public Dataset](https://console.cloud.google.com/marketplace/product/y-combinator/hacker-news) 76 | 77 | ### Key Libraries 78 | - **sentence-transformers**: High-quality sentence embeddings 79 | - **psycopg3**: Modern PostgreSQL adapter with async support 80 | - **pgvector**: PostgreSQL extension for vector similarity search 81 | - **LangGraph**: Orchestration framework for LLM workflows 82 | - **LangChain**: LLM abstractions and prompting 83 | - **Gradio**: Web UI framework 84 | - **torch**: PyTorch for model inference (MPS support for Apple Silicon) 85 | - **pandas + pyarrow**: Data processing pipeline 86 | 87 | ### Infrastructure 88 | - **Compute**: Railway (PostgreSQL + Redis) / Local development 89 | - **Deployment**: 
Docker-ready with docker-compose.yml 90 | 91 | ## 📊 Dataset 92 | 93 | **Source**: Hacker News comments from BigQuery public dataset (`bigquery-public-data.hacker_news.full`) 94 | 95 | **Scope**: 96 | - Time range: January 2023 - September 2025 97 | - Comment count: ~9.4M comments 98 | - Filters: Non-deleted, non-dead, non-null text 99 | 100 | **Partitioning Strategy**: 101 | - Tables partitioned by month (e.g., `hn_documents_2023_01`, `hn_documents_2023_02`, ...) 102 | - Enables efficient querying and index management 103 | - Simplifies incremental updates 104 | 105 | ## 🚀 Getting Started 106 | 107 | ### Prerequisites 108 | 109 | ```bash 110 | # Install uv (recommended) 111 | curl -LsSf https://astral.sh/uv/install.sh | sh 112 | 113 | # Or use pip 114 | pip install uv 115 | 116 | # Install PostgreSQL with pgvector 117 | # See: https://github.com/pgvector/pgvector#installation 118 | 119 | # Install Redis (optional, for caching) 120 | brew install redis # macOS 121 | sudo apt install redis # Ubuntu 122 | ``` 123 | 124 | ### Installation 125 | 126 | ```bash 127 | # Clone the repository 128 | git clone https://github.com/yourusername/hn-search.git 129 | cd hn-search 130 | 131 | # Install dependencies 132 | uv sync 133 | 134 | # For development (includes BigQuery tools) 135 | uv sync --extra dev 136 | 137 | # Set up environment variables 138 | cp .env.example .env 139 | # Edit .env with your credentials: 140 | # DATABASE_URL=postgres://user:pass@host:port/dbname 141 | # DEEPSEEK_API_KEY=sk-... 142 | # REDIS_URL=redis://localhost:6379 (optional) 143 | ``` 144 | 145 | ### Database Setup 146 | 147 | ```bash 148 | # Initialize database with precomputed embeddings 149 | uv run python -m hn_search.init_db_pgvector 150 | 151 | # Or initialize with test mode (100 docs only) 152 | uv run python -m hn_search.init_db_pgvector --test 153 | ``` 154 | 155 | ## 💡 Usage 156 | 157 | ### 1. Semantic Search (CLI) 158 | 159 | ```bash 160 | # Basic search 161 | uv run python -m hn_search.query "What do people think about Rust vs Go?" 162 | 163 | # Get more results 164 | uv run python -m hn_search.query "best practices for system design" 20 165 | 166 | # Output includes: 167 | # - Comment ID (with HN link) 168 | # - Author 169 | # - Timestamp 170 | # - Cosine distance score 171 | # - Full comment text 172 | ``` 173 | 174 | ### 2. RAG System (Web UI) 175 | 176 | ```bash 177 | # Start the Gradio web interface 178 | uv run python -m hn_search.rag.web_ui 179 | 180 | # Open http://localhost:7860 181 | # Ask questions like: 182 | # "What are the main criticisms of microservices?" 183 | # "How do people debug production issues?" 184 | # "What do HN users think about AI coding assistants?" 185 | ``` 186 | 187 | **Features**: 188 | - Real-time streaming responses 189 | - Source citations with HN links 190 | - URL parameter support: `?q=your+question` 191 | - Auto-search from URL parameters 192 | 193 | ### 3. 
Incremental Data Updates 194 | 195 | ```bash 196 | # Fetch new comments, generate embeddings, and upsert to DB 197 | uv run --extra dev python misc/fetch_and_embed_new_comments.py 198 | 199 | # Options: 200 | # --project # Specify BigQuery billing project 201 | # --skip-fetch # Skip BigQuery download 202 | # --skip-embed # Skip embedding generation 203 | # --skip-upsert # Skip database insertion 204 | # --reset # Clear state and start fresh 205 | 206 | # Resume interrupted runs (automatic) 207 | uv run --extra dev python misc/fetch_and_embed_new_comments.py 208 | 209 | # The script is fully idempotent and resumable: 210 | # - Saves state to data/raw/fetch_state.json 211 | # - Checks for existing files before re-downloading 212 | # - Incremental embedding generation with checkpoints 213 | # - Tracks processed IDs to avoid duplicate inserts 214 | ``` 215 | 216 | ### 4. Generate Embeddings (Batch) 217 | 218 | A single 5090 Nvidia was rented from vast.ai to compute all historical 219 | embeddings for a few dollars. 220 | 221 | ```bash 222 | # Process raw parquet files and generate embeddings 223 | uv run python misc/generate_embeddings_gpu.py 224 | 225 | # Uses MPS (Apple Silicon) or CUDA automatically 226 | # Processes in batches to avoid OOM 227 | # Saves to embeddings/*.parquet 228 | ``` 229 | 230 | ## 🔍 How It Works 231 | 232 | ### Vector Search 233 | 234 | The system uses cosine distance for similarity search: 235 | 236 | ```sql 237 | SELECT id, clean_text, author, timestamp, type, 238 | embedding <=> query_vector AS distance 239 | FROM hn_documents 240 | ORDER BY embedding <=> query_vector 241 | LIMIT 10 242 | ``` 243 | 244 | **Performance Optimizations**: 245 | - HNSW index for approximate nearest neighbor search 246 | - Redis caching layer reduces repeated queries to <100ms 247 | - Connection pooling with psycopg3 248 | - Partitioned tables for efficient index scans 249 | 250 | ### RAG Pipeline 251 | 252 | The RAG system uses LangGraph to orchestrate a two-node workflow: 253 | 254 | 1. **Retrieve Node**: 255 | - Encodes user query with sentence-transformers 256 | - Performs vector search in PostgreSQL 257 | - Returns top 10 most relevant comments 258 | 259 | 2. **Answer Node**: 260 | - Formats retrieved comments as context 261 | - Prompts DeepSeek LLM with query + context 262 | - Streams response back to user 263 | 264 | **Prompt Engineering**: 265 | ```python 266 | system_prompt = """You are a helpful assistant that answers questions 267 | based on Hacker News discussions. Use the provided comments to give 268 | accurate, well-sourced answers. Cite comment numbers [1], [2], etc.""" 269 | 270 | user_prompt = f"""Question: {query} 271 | 272 | Context from HN comments: 273 | {formatted_comments} 274 | 275 | Answer:""" 276 | ``` 277 | 278 | ### Embedding Model 279 | 280 | **Model**: `sentence-transformers/all-mpnet-base-v2` 281 | - Dimensions: 768 282 | - Max sequence length: 384 tokens 283 | - Training: MS MARCO + Natural Questions + other datasets 284 | - Performance: SOTA for semantic similarity tasks 285 | 286 | **Why this model?** 287 | - Excellent balance of quality vs. 
speed 288 | - Pre-trained on diverse Q&A datasets 289 | - Good generalization to HN comment domain 290 | - Efficient inference on CPU/MPS/CUDA 291 | 292 | ## 📈 Performance & Scale 293 | 294 | ### Current Scale 295 | - **Documents**: ~9.4M Hacker News comments 296 | - **Storage**: ~40 GB (including embeddings) 297 | - **Query Latency**: 298 | - Cold query: ~30s (embedding + search + LLM) 299 | - Cached query: <1s (Redis cache hit) 300 | - Concurrent duplicate query: <1s (job deduplication) 301 | - RAG end-to-end: ~30s (including LLM generation) 302 | 303 | ### Production Optimizations ⚡ 304 | 305 | **Implemented**: 306 | 1. **Singleton Embedding Model**: Model loaded once and reused (3-5x throughput) 307 | 2. **Job Deduplication**: Concurrent duplicate queries share processing (saves 90%+ compute) 308 | 3. **Multi-layer Caching**: Redis cache for vector search, LLM answers, and job results 309 | 4. **Connection Pooling**: PostgreSQL connection pool (min: 2, max: 20) 310 | 5. **Partitioned Tables**: Monthly partitions for efficient indexing 311 | 6. **Incremental Updates**: Only process new comments since last run 312 | 313 | ### Capacity 314 | 315 | **Single Instance**: 316 | - 20-30 concurrent users (unique queries) 317 | - 100+ concurrent users (with 80% cache hit rate) 318 | 319 | **Horizontal Scaling** (Railway/Cloud): 320 | - 2 replicas: 40-60 concurrent users 321 | - 4 replicas: 80-120 concurrent users 322 | - 8 replicas: 160-240 concurrent users 323 | 324 | See [RAILWAY.md](RAILWAY.md) for deployment guide. 325 | 326 | ### Resource Requirements 327 | - **RAM**: 2-3 GB per instance (1.5GB for embedding model) 328 | - **Storage**: ~40 GB for database (9.4M documents + embeddings) 329 | - **CPU**: 0.5-1.0 cores per instance 330 | - **PostgreSQL**: 100+ connections (20 per instance) 331 | - **Redis**: 512MB-1GB for cache 332 | 333 | ## 🔧 Configuration 334 | 335 | ### Environment Variables 336 | 337 | ```bash 338 | # Required 339 | DATABASE_URL=postgres://user:pass@host:port/dbname 340 | DEEPSEEK_API_KEY=sk-... 
341 | 342 | # Optional 343 | REDIS_URL=redis://localhost:6379 344 | GOOGLE_CLOUD_PROJECT=your-gcp-project 345 | TOKENIZERS_PARALLELISM=false # Disable for multi-threaded use 346 | ``` 347 | 348 | ### PostgreSQL Settings 349 | 350 | For optimal performance on Railway/cloud instances: 351 | ```sql 352 | -- Default settings (already applied) 353 | shared_buffers = 128MB 354 | maintenance_work_mem = 64MB 355 | work_mem = 4MB 356 | max_parallel_workers = 8 357 | effective_cache_size = 4GB 358 | ``` 359 | 360 | ## 🧪 Development 361 | 362 | ### Code Quality 363 | 364 | ```bash 365 | # Format code 366 | make format 367 | 368 | # Run linter 369 | make lint 370 | 371 | # Sort imports 372 | make imports 373 | ``` 374 | 375 | ### Project Structure 376 | 377 | ``` 378 | hn-search/ 379 | ├── hn_search/ # Main package 380 | │ ├── query.py # Vector search interface 381 | │ ├── init_db_pgvector.py # Database initialization 382 | │ ├── db_config.py # Database connection config 383 | │ ├── cache_config.py # Redis caching layer 384 | │ ├── common.py # Shared utilities 385 | │ └── rag/ # RAG system 386 | │ ├── graph.py # LangGraph workflow 387 | │ ├── nodes.py # Retrieve & Answer nodes 388 | │ ├── state.py # State management 389 | │ ├── cli.py # CLI interface 390 | │ └── web_ui.py # Gradio web interface 391 | ├── misc/ # Utility scripts 392 | │ ├── generate_embeddings_gpu.py # Batch embedding generation 393 | │ └── fetch_and_embed_new_comments.py # Incremental updates 394 | ├── data/ # Data directory 395 | │ └── raw/ # Raw parquet files 396 | ├── pyproject.toml # Project dependencies 397 | └── Makefile # Development shortcuts 398 | ``` 399 | 400 | ## 🎓 Learning Outcomes 401 | 402 | This project demonstrates: 403 | 404 | 1. **Vector Search at Scale**: Implementing semantic search with pgvector on millions of documents 405 | 2. **Production ML Pipelines**: Idempotent, resumable data processing with checkpointing 406 | 3. **RAG Architecture**: Building retrieval-augmented generation with LangGraph 407 | 4. **Database Optimization**: Partitioning strategies, connection pooling, caching 408 | 5. **Modern Python Tooling**: uv, ruff, type hints, async patterns 409 | 6. **Cloud Integration**: BigQuery public datasets, Railway deployment, Redis caching 410 | 7. **GPU Optimization**: MPS/CUDA support for efficient embedding generation 411 | 412 | ## 📚 References 413 | 414 | - [pgvector: Open-source vector similarity search for Postgres](https://github.com/pgvector/pgvector) 415 | - [sentence-transformers: State-of-the-art sentence embeddings](https://www.sbert.net/) 416 | - [LangGraph: Building stateful, multi-actor LLM applications](https://langchain-ai.github.io/langgraph/) 417 | - [BigQuery HN Dataset](https://console.cloud.google.com/marketplace/product/y-combinator/hacker-news) 418 | - [Retrieval-Augmented Generation (Lewis et al., 2020)](https://arxiv.org/abs/2005.11401) 419 | 420 | ## 📄 License 421 | 422 | MIT License - see LICENSE file for details 423 | 424 | ## 🤝 Contributing 425 | 426 | Contributions welcome! Please open an issue or PR. 
427 | 428 | ### Development Setup 429 | 430 | ```bash 431 | # Fork and clone 432 | git clone https://github.com/yourusername/hn-search.git 433 | 434 | # Create a branch 435 | git checkout -b feature/your-feature 436 | 437 | # Make changes and test 438 | uv run python -m hn_search.query "test query" 439 | 440 | # Format and lint 441 | make format 442 | 443 | # Commit and push 444 | git commit -m "Add your feature" 445 | git push origin feature/your-feature 446 | ``` 447 | 448 | ## 🙏 Acknowledgments 449 | 450 | - Y Combinator for open-sourcing Hacker News data 451 | - The pgvector team for excellent PostgreSQL integration 452 | - sentence-transformers community for pre-trained models 453 | - LangChain team for RAG tooling 454 | 455 | --- 456 | 457 | **⭐ If you find this project useful, please consider giving it a star!** 458 | 459 | ![example](./example.png) 460 | -------------------------------------------------------------------------------- /misc/fetch_and_embed_new_comments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Script to fetch new HN comments from BigQuery, generate embeddings using MPS (Mac GPU), 4 | and upsert them into partitioned PostgreSQL tables. 5 | 6 | Idempotent and resumable - can be interrupted and restarted at any point. 7 | 8 | Usage: 9 | uv run python misc/fetch_and_embed_new_comments.py [--resume] [--skip-fetch] [--skip-embed] [--skip-upsert] 10 | """ 11 | 12 | import argparse 13 | import json 14 | import os 15 | from datetime import datetime 16 | from pathlib import Path 17 | from typing import Optional, Tuple 18 | 19 | import html2text 20 | import numpy as np 21 | import pandas as pd 22 | import psycopg 23 | import torch 24 | from dotenv import load_dotenv 25 | from google.cloud import bigquery 26 | from pgvector.psycopg import register_vector 27 | from sentence_transformers import SentenceTransformer 28 | 29 | from hn_search.db_config import get_db_config 30 | 31 | load_dotenv() 32 | 33 | STATE_FILE = "data/raw/fetch_state.json" 34 | 35 | 36 | def load_state() -> dict: 37 | """Load state from file""" 38 | state_path = Path(STATE_FILE) 39 | if state_path.exists(): 40 | with open(state_path, "r") as f: 41 | return json.load(f) 42 | return {} 43 | 44 | 45 | def save_state(state: dict): 46 | """Save state to file""" 47 | state_path = Path(STATE_FILE) 48 | state_path.parent.mkdir(parents=True, exist_ok=True) 49 | with open(state_path, "w") as f: 50 | json.dump(state, f, indent=2) 51 | print(f"💾 State saved to {state_path}") 52 | 53 | 54 | def strip_html(text: str) -> str: 55 | """Clean HTML from text""" 56 | h = html2text.HTML2Text() 57 | h.ignore_links = True 58 | h.ignore_images = True 59 | h.ignore_emphasis = True 60 | return h.handle(text).strip() 61 | 62 | 63 | def get_connection(): 64 | """Get a fresh database connection with vector registration""" 65 | db_config = get_db_config() 66 | conn = psycopg.connect(**db_config) 67 | register_vector(conn) 68 | return conn 69 | 70 | 71 | def find_latest_nonempty_partition(): 72 | """Find the latest non-empty partitioned table""" 73 | conn = get_connection() 74 | with conn.cursor() as cur: 75 | cur.execute(""" 76 | SELECT tablename 77 | FROM pg_tables 78 | WHERE schemaname = 'public' 79 | AND tablename LIKE 'hn_documents_____%%' 80 | ORDER BY tablename DESC 81 | """) 82 | tables = [row[0] for row in cur.fetchall()] 83 | 84 | # Check each table for rows 85 | for table in tables: 86 | cur.execute(f"SELECT COUNT(*) FROM {table}") 87 | count = 
cur.fetchone()[0] 88 | if count > 0: 89 | conn.close() 90 | print(f"Found latest non-empty partition: {table} ({count:,} rows)") 91 | return table 92 | 93 | conn.close() 94 | return None 95 | 96 | 97 | def get_max_id_from_partition(table_name): 98 | """Get the maximum id from the specified partition""" 99 | conn = get_connection() 100 | with conn.cursor() as cur: 101 | # IDs are stored as text, so we need to cast to bigint for proper comparison 102 | cur.execute(f"SELECT MAX(CAST(id AS BIGINT)) FROM {table_name}") 103 | max_id = cur.fetchone()[0] 104 | conn.close() 105 | print(f"Max ID in {table_name}: {max_id}") 106 | return max_id 107 | 108 | 109 | def fetch_from_bigquery( 110 | min_id, output_dir="data/raw", state=None, project=None 111 | ) -> Optional[Path]: 112 | """Fetch new comments from BigQuery starting after min_id""" 113 | output_path = Path(output_dir) 114 | output_path.mkdir(parents=True, exist_ok=True) 115 | 116 | # Check if we already have a raw file from a previous run 117 | if state and state.get("raw_file"): 118 | raw_file = Path(state["raw_file"]) 119 | if raw_file.exists(): 120 | print(f"✅ Found existing raw file: {raw_file}") 121 | return raw_file 122 | 123 | print(f"\n📥 Fetching comments from BigQuery with id > {min_id}") 124 | 125 | # Initialize client - uses Application Default Credentials by default 126 | # Specify project to control billing 127 | project = project or os.getenv("GOOGLE_CLOUD_PROJECT") 128 | 129 | # Check for credentials JSON in environment (for Railway/cloud deployments) 130 | creds_json = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON") 131 | if creds_json: 132 | import tempfile 133 | from google.oauth2 import service_account 134 | 135 | # Write credentials to a temp file 136 | with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json') as f: 137 | f.write(creds_json) 138 | creds_file = f.name 139 | 140 | credentials = service_account.Credentials.from_service_account_file(creds_file) 141 | client = bigquery.Client(project=project, credentials=credentials) 142 | 143 | # Clean up temp file 144 | os.unlink(creds_file) 145 | else: 146 | client = bigquery.Client(project=project) 147 | 148 | print(f"💰 Using GCP project: {client.project}") 149 | print(f" Account: Run 'gcloud auth list' to see active account") 150 | print(f" Tip: Set billing alerts at https://console.cloud.google.com/billing/") 151 | 152 | query = f""" 153 | SELECT 154 | id, 155 | `by` AS author, 156 | `type`, 157 | text, 158 | timestamp 159 | FROM 160 | `bigquery-public-data.hacker_news.full` 161 | WHERE 162 | dead IS NOT TRUE 163 | AND deleted IS NOT TRUE 164 | AND text IS NOT NULL 165 | AND type = 'comment' 166 | AND id > {min_id} 167 | ORDER BY id 168 | """ 169 | 170 | print("Running BigQuery query...") 171 | query_job = client.query(query) 172 | df = query_job.to_dataframe() 173 | 174 | print(f"Fetched {len(df):,} new comments") 175 | 176 | if len(df) == 0: 177 | print("No new comments to process") 178 | return None 179 | 180 | # Use consistent filename for resumability 181 | filename = f"new_comments_from_{min_id}.parquet" 182 | filepath = output_path / filename 183 | 184 | # Save atomically with temp file 185 | temp_filepath = filepath.with_suffix(".parquet.tmp") 186 | df.to_parquet(temp_filepath, index=False) 187 | temp_filepath.rename(filepath) 188 | 189 | print(f"💾 Saved to {filepath}") 190 | 191 | return filepath 192 | 193 | 194 | def generate_embeddings_mps( 195 | parquet_file, state=None 196 | ) -> Tuple[Optional[Path], Optional[pd.DataFrame]]: 197 | """Generate 
embeddings using MPS (Mac GPU) - saves progress incrementally""" 198 | output_file = Path(str(parquet_file).replace(".parquet", "_embedded.parquet")) 199 | 200 | # Check if embeddings already exist 201 | if output_file.exists(): 202 | print(f"✅ Found existing embedded file: {output_file}") 203 | df = pd.read_parquet(output_file) 204 | return output_file, df 205 | 206 | # Check for MPS availability 207 | if torch.backends.mps.is_available(): 208 | device = "mps" 209 | print("🚀 Using MPS (Apple Silicon GPU)") 210 | elif torch.cuda.is_available(): 211 | device = "cuda" 212 | print(f"🚀 Using CUDA GPU: {torch.cuda.get_device_name(0)}") 213 | else: 214 | device = "cpu" 215 | print("⚠️ Using CPU (no GPU available)") 216 | 217 | model = SentenceTransformer( 218 | "sentence-transformers/all-mpnet-base-v2", device=device 219 | ) 220 | 221 | print(f"\n🔄 Loading {parquet_file}") 222 | df = pd.read_parquet(parquet_file) 223 | df = df[df["text"].notna() & (df["text"] != "")] 224 | 225 | print(f"Processing {len(df):,} documents") 226 | 227 | # Clean text 228 | df["clean_text"] = df["text"].astype(str).apply(strip_html) 229 | df = df[df["clean_text"].str.len() > 0] 230 | 231 | if len(df) == 0: 232 | print("No valid documents after cleaning") 233 | return None, None 234 | 235 | # Generate embeddings in chunks and save incrementally 236 | chunk_size = 1000 # Save every 1000 documents 237 | encode_batch_size = 128 238 | 239 | documents = df["clean_text"].tolist() 240 | all_embeddings = [] 241 | 242 | temp_output = output_file.with_suffix(".parquet.tmp") 243 | 244 | for chunk_start in range(0, len(documents), chunk_size): 245 | chunk_end = min(chunk_start + chunk_size, len(documents)) 246 | chunk_docs = documents[chunk_start:chunk_end] 247 | 248 | print( 249 | f"\n📦 Processing chunk {chunk_start // chunk_size + 1}/{(len(documents) - 1) // chunk_size + 1}" 250 | ) 251 | 252 | chunk_embeddings = [] 253 | for batch_start in range(0, len(chunk_docs), encode_batch_size): 254 | batch = chunk_docs[batch_start : batch_start + encode_batch_size] 255 | print( 256 | f" Encoding batch {batch_start // encode_batch_size + 1}/{(len(chunk_docs) - 1) // encode_batch_size + 1}" 257 | ) 258 | 259 | embeddings = model.encode( 260 | batch, 261 | batch_size=encode_batch_size, 262 | show_progress_bar=False, 263 | convert_to_numpy=True, 264 | ) 265 | chunk_embeddings.extend(embeddings) 266 | 267 | all_embeddings.extend(chunk_embeddings) 268 | 269 | # Save intermediate progress 270 | partial_df = df.iloc[:chunk_end].copy() 271 | partial_df["embedding"] = [emb.tolist() for emb in all_embeddings] 272 | partial_df.to_parquet(temp_output, index=False) 273 | print(f" 💾 Saved progress: {chunk_end}/{len(documents)} documents") 274 | 275 | # Final save - rename temp to final 276 | df["embedding"] = [emb.tolist() for emb in all_embeddings] 277 | temp_output.rename(output_file) 278 | print(f"\n✅ Saved all embeddings to {output_file}") 279 | 280 | return output_file, df 281 | 282 | 283 | def get_partition_name_for_timestamp(timestamp): 284 | """Get partition table name for a given timestamp""" 285 | # Convert timestamp to datetime if it's a string 286 | if isinstance(timestamp, str): 287 | dt = pd.to_datetime(timestamp) 288 | else: 289 | dt = timestamp 290 | 291 | return f"hn_documents_{dt.year:04d}_{dt.month:02d}" 292 | 293 | 294 | def upsert_to_db(embedded_parquet_file, df, state=None): 295 | """Upsert embeddings into partitioned tables with batch processing""" 296 | print(f"\n🔄 Upserting {len(df):,} documents to database") 297 | 298 | conn = 
get_connection() 299 | 300 | # Group by partition 301 | df["partition"] = df["timestamp"].apply(get_partition_name_for_timestamp) 302 | 303 | # Track processed IDs if resuming 304 | processed_ids = set(state.get("upserted_ids", [])) if state else set() 305 | 306 | total_inserted = 0 307 | batch_size = 500 # Insert in batches for better performance 308 | 309 | try: 310 | for partition_name, partition_df in df.groupby("partition"): 311 | print( 312 | f"\n 📋 Processing partition {partition_name} ({len(partition_df):,} rows)" 313 | ) 314 | 315 | # Get existing IDs from the database to avoid duplicates 316 | with conn.cursor() as cur: 317 | partition_ids = partition_df["id"].astype(str).tolist() 318 | cur.execute( 319 | f"SELECT id FROM {partition_name} WHERE id = ANY(%s)", 320 | (partition_ids,), 321 | ) 322 | existing_ids = set(row[0] for row in cur.fetchall()) 323 | print(f" Found {len(existing_ids):,} existing IDs in {partition_name}") 324 | 325 | # Filter out already processed rows and existing IDs 326 | partition_df = partition_df[ 327 | ~partition_df["id"].astype(str).isin(processed_ids | existing_ids) 328 | ] 329 | 330 | if len(partition_df) == 0: 331 | print(f" ⏭️ All rows already exist in {partition_name}") 332 | continue 333 | 334 | print(f" Inserting {len(partition_df):,} new rows") 335 | 336 | records = [ 337 | ( 338 | str(row["id"]), 339 | row["clean_text"], 340 | str(row["author"]), 341 | str(row["timestamp"]), 342 | str(row["type"]), 343 | row["embedding"], 344 | ) 345 | for _, row in partition_df.iterrows() 346 | ] 347 | 348 | # Process in batches 349 | for batch_start in range(0, len(records), batch_size): 350 | batch = records[batch_start : batch_start + batch_size] 351 | 352 | with conn.cursor() as cur: 353 | cur.executemany( 354 | f""" 355 | INSERT INTO {partition_name} (id, clean_text, author, timestamp, type, embedding) 356 | VALUES (%s, %s, %s, %s, %s, %s) 357 | """, 358 | batch, 359 | ) 360 | 361 | conn.commit() 362 | 363 | # Track progress 364 | batch_ids = [r[0] for r in batch] 365 | processed_ids.update(batch_ids) 366 | 367 | total_inserted += len(batch) 368 | print( 369 | f" Progress: {batch_start + len(batch)}/{len(records)} | Total: {total_inserted:,}" 370 | ) 371 | 372 | # Save state periodically 373 | if state is not None: 374 | state["upserted_ids"] = list(processed_ids) 375 | save_state(state) 376 | 377 | print(f" ✅ Inserted {len(records):,} rows into {partition_name}") 378 | 379 | finally: 380 | conn.close() 381 | 382 | print(f"\n✅ Total inserted: {total_inserted:,} documents") 383 | 384 | 385 | def main(): 386 | """Main execution flow with resumability""" 387 | parser = argparse.ArgumentParser( 388 | description="Fetch, embed, and upsert HN comments (idempotent & resumable)" 389 | ) 390 | parser.add_argument( 391 | "--resume", action="store_true", help="Resume from previous state" 392 | ) 393 | parser.add_argument( 394 | "--skip-fetch", action="store_true", help="Skip BigQuery fetch step" 395 | ) 396 | parser.add_argument( 397 | "--skip-embed", action="store_true", help="Skip embedding generation step" 398 | ) 399 | parser.add_argument( 400 | "--skip-upsert", action="store_true", help="Skip database upsert step" 401 | ) 402 | parser.add_argument( 403 | "--reset", action="store_true", help="Reset state and start fresh" 404 | ) 405 | parser.add_argument( 406 | "--project", 407 | type=str, 408 | help="GCP project ID for BigQuery billing (defaults to GOOGLE_CLOUD_PROJECT env var or gcloud default)", 409 | ) 410 | 411 | args = parser.parse_args() 412 | 413 | 
print("=" * 80) 414 | print("HN Comments Fetcher, Embedder, and Upserter (Idempotent)") 415 | print("=" * 80) 416 | 417 | # Load or initialize state 418 | if args.reset: 419 | print("🔄 Resetting state...") 420 | state = {} 421 | else: 422 | state = load_state() 423 | if state: 424 | print(f"📂 Loaded state from {STATE_FILE}") 425 | print(f" Raw file: {state.get('raw_file', 'None')}") 426 | print(f" Embedded file: {state.get('embedded_file', 'None')}") 427 | print(f" Upserted IDs: {len(state.get('upserted_ids', []))} documents") 428 | 429 | # Step 1: Find latest non-empty partition 430 | latest_partition = find_latest_nonempty_partition() 431 | if not latest_partition: 432 | print("❌ No partitioned tables found!") 433 | return 434 | 435 | # Step 2: Get max ID from that partition 436 | max_id = get_max_id_from_partition(latest_partition) 437 | if max_id is None: 438 | print("❌ Could not determine max ID") 439 | return 440 | 441 | state["max_id"] = max_id 442 | save_state(state) 443 | 444 | # Step 3: Fetch new comments from BigQuery 445 | parquet_file = None 446 | if not args.skip_fetch: 447 | parquet_file = fetch_from_bigquery(max_id, state=state, project=args.project) 448 | if not parquet_file: 449 | print("No new data to process") 450 | return 451 | state["raw_file"] = str(parquet_file) 452 | save_state(state) 453 | else: 454 | print("⏭️ Skipping fetch step") 455 | if state.get("raw_file"): 456 | parquet_file = Path(state["raw_file"]) 457 | else: 458 | print("❌ No raw file in state - cannot skip fetch") 459 | return 460 | 461 | # Step 4: Generate embeddings 462 | embedded_file = None 463 | df = None 464 | if not args.skip_embed: 465 | embedded_file, df = generate_embeddings_mps(parquet_file, state=state) 466 | if not embedded_file: 467 | print("❌ Embedding generation failed") 468 | return 469 | state["embedded_file"] = str(embedded_file) 470 | save_state(state) 471 | else: 472 | print("⏭️ Skipping embed step") 473 | if state.get("embedded_file"): 474 | embedded_file = Path(state["embedded_file"]) 475 | df = pd.read_parquet(embedded_file) 476 | else: 477 | print("❌ No embedded file in state - cannot skip embed") 478 | return 479 | 480 | # Step 5: Upsert to database 481 | if not args.skip_upsert: 482 | upsert_to_db(embedded_file, df, state=state) 483 | state["completed"] = True 484 | state["completed_at"] = datetime.now().isoformat() 485 | save_state(state) 486 | else: 487 | print("⏭️ Skipping upsert step") 488 | 489 | print("\n" + "=" * 80) 490 | print("✅ All done!") 491 | print("=" * 80) 492 | 493 | 494 | if __name__ == "__main__": 495 | main() 496 | --------------------------------------------------------------------------------