├── src
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── research_assistant.py
│   ├── config
│   │   ├── __init__.py
│   │   └── settings.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── state.py
│   │   ├── messages.py
│   │   └── graph.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── rag.py
│   │   └── paper_lookup.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── paper_id_extractor.py
│   ├── components
│   │   ├── __init__.py
│   │   ├── paper
│   │   │   ├── __init__.py
│   │   │   ├── models.py
│   │   │   └── tool.py
│   │   ├── rag
│   │   │   ├── __init__.py
│   │   │   ├── embeddings.py
│   │   │   ├── indexing.py
│   │   │   └── tool.py
│   │   ├── database
│   │   │   ├── __init__.py
│   │   │   ├── neo4j_client.py
│   │   │   ├── vector_store.py
│   │   │   ├── ingest.py
│   │   │   └── neo4j_ingestion.py
│   │   └── evaluation
│   │       ├── __init__.py
│   │       ├── experiment_tracker.py
│   │       ├── custom_metric.py
│   │       └── opik_evaluator.py
│   ├── orchestrator
│   │   ├── __init__.py
│   │   └── coordinator.py
│   ├── streamlit
│   │   ├── __init__.py
│   │   ├── session.py
│   │   ├── message.py
│   │   ├── predefined_questions.py
│   │   ├── main.py
│   │   ├── ui_component.py
│   │   └── layout.py
│   └── main.py
├── .gitignore
├── .env.example
├── requirements.txt
├── pyproject.toml
├── scripts
│   ├── neo4j_cleaner.py
│   └── preprocess.py
└── readme.md

/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/agents/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/config/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/core/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/tools/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/components/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/orchestrator/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/streamlit/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/components/paper/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/components/rag/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/components/database/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/components/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | !.env.example
3 | neo4j.conf
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | from src.streamlit.main import main
2 | 
3 | if __name__ == "__main__":
4 |     main()
5 | 
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | #OPENAI
2 | OPENAI_API_KEY=openai-api-key
3 | 
4 | #NEO4J
5 | NEO4J_URI=uri
6 | NEO4J_USERNAME=username
7 | NEO4J_PASSWORD=password
8 | 
9 | #COMETML
10 | COMETML_API_KEY=cometml_api_key
--------------------------------------------------------------------------------
/src/core/state.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated, TypedDict, Dict, List, Any
2 | from langchain_core.messages import AnyMessage
3 | from langgraph.graph.message import add_messages
4 | 
5 | class ConversationState(TypedDict):
6 |     messages: Annotated[list[AnyMessage], add_messages]
7 |     metrics: dict
8 |     conversation_history: List[Dict[str, Any]]
9 | 
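`add_messages` is LangGraph's reducer for the `messages` channel: updates returned by graph nodes are appended to the existing list instead of replacing it. A minimal, standalone illustration using the libraries this repo already pins:

```python
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph.message import add_messages

history = [HumanMessage(content="What is paper 0704.2002 about?")]
merged = add_messages(history, [AIMessage(content="Looking it up...")])

assert len(merged) == 2  # appended, not overwritten
```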
--------------------------------------------------------------------------------
/src/core/messages.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 | from datetime import datetime
4 | 
5 | @dataclass
6 | class ConversationMessage:
7 |     content: str
8 |     type: str
9 |     timestamp: datetime
10 |     query: Optional[str] = None
11 |     paper_id: Optional[str] = None
12 |     response: Optional[str] = None
13 | 
--------------------------------------------------------------------------------
/src/streamlit/session.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import List, Dict, Any
3 | 
4 | @dataclass
5 | class SessionState:
6 |     messages: List[Any] = field(default_factory=list)
7 |     metrics: Dict[str, Any] = field(default_factory=dict)
8 |     conversation_history: List[str] = field(default_factory=list)
9 |     show_predefined: bool = True
10 |     session_active: bool = True
11 | 
--------------------------------------------------------------------------------
/src/streamlit/message.py:
--------------------------------------------------------------------------------
1 | from langchain.schema import AIMessage, HumanMessage
2 | 
3 | class Message:
4 |     def __init__(self, content: str, is_human: bool = True):
5 |         self.content = content
6 |         self.message = HumanMessage(content=content) if is_human else AIMessage(content=content)
7 | 
8 |     def format_for_display(self) -> str:
9 |         prefix = "You" if isinstance(self.message, HumanMessage) else "Assistant"
10 |         return f"**{prefix}:** {self.content}"
11 | 
--------------------------------------------------------------------------------
/src/agents/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from src.core.state import ConversationState
3 | from typing import Dict, Any
4 | 
5 | class BaseAgent(ABC):
6 |     @abstractmethod
7 |     def process_message(self, state: ConversationState) -> Dict[str, Any]:
8 |         """Process a message and return updated state."""
9 |         pass
10 | 
11 |     @abstractmethod
12 |     def handle_error(self, error: Exception) -> Dict[str, Any]:
13 |         """Handle errors during message processing."""
14 |         pass
--------------------------------------------------------------------------------
/src/config/settings.py:
--------------------------------------------------------------------------------
1 | from dotenv import load_dotenv
2 | import os
3 | 
4 | load_dotenv()
5 | 
6 | class Settings:
7 |     def __init__(self):
8 |         load_dotenv()
9 |         self.cometml_api_key = os.getenv("COMETML_API_KEY")
10 |         self.project_name = "research-paper-rag"
11 |         self.openai_api_key = os.getenv("OPENAI_API_KEY")
12 |         self.neo4j_uri = os.getenv("NEO4J_URI")
13 |         self.neo4j_user = os.getenv("NEO4J_USERNAME")
14 |         self.neo4j_password = os.getenv("NEO4J_PASSWORD")
15 | 
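`Settings` quietly returns `None` for any variable missing from the environment, which tends to surface later as an opaque driver or authentication error. A fail-fast check over the same keys (illustrative sketch, not part of the repo):

```python
import os

from dotenv import load_dotenv

load_dotenv()

REQUIRED = ("OPENAI_API_KEY", "NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD", "COMETML_API_KEY")
missing = [key for key in REQUIRED if not os.getenv(key)]
if missing:
    raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}")
```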
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | comet-ml==3.47.0
2 | langchain==0.3.3
3 | pandas==2.2.3
4 | langchain-community==0.3.2
5 | langchain-experimental==0.3.3
6 | langchain-openai==0.2.2
7 | graphdatascience==1.12.0
8 | tiktoken==0.8.0
9 | retry==0.9.2
10 | neo4j==5.21.0
11 | sentence-transformers==3.2.0
12 | faiss-cpu==1.9.0
13 | python-dotenv==1.0.1
14 | langgraph==0.2.39
15 | langgraph-checkpoint-sqlite==2.0.0
16 | postgres==4.0.0
17 | streamlit==1.39.0
18 | opik==1.3.3
19 | 
20 | # Additional dependencies
21 | numpy>=1.24.0
22 | torch>=2.0.0
23 | transformers>=4.30.0
24 | openai>=1.0.0
25 | pydantic>=2.0.0
26 | typing-extensions>=4.5.0
27 | tqdm>=4.65.0
28 | requests>=2.31.0
29 | aiohttp>=3.10.11
--------------------------------------------------------------------------------
/src/streamlit/predefined_questions.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Callable
3 | 
4 | @dataclass
5 | class PredefinedQuestion:
6 |     text: str
7 |     callback: Callable[[str], None]
8 | 
9 | class PredefinedQuestionsManager:
10 |     def __init__(self):
11 |         self.questions = [
12 |             "Explain the concept of transformer models in NLP.",
13 |             "What is the significance of paper ID 0704.2002?",
14 |             "How does reinforcement learning apply to robotics?",
15 |             "Summarize recent advances in computer vision."
16 |         ]
17 | 
18 |     def get_questions(self) -> List[str]:
19 |         return self.questions
20 | 
--------------------------------------------------------------------------------
/src/utils/paper_id_extractor.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Optional
3 | 
4 | class PaperIdExtractor:
5 |     _PATTERNS = [
6 |         r'paper\s*(?:id|ID|Id)?\s*[:#]?\s*(\d{4}\.\d{4})',
7 |         r'id\s*[:#]?\s*(\d{4}\.\d{4})',
8 |         r'(?<!\d)(\d{4}\.\d{4})(?!\d)'
9 |     ]
10 | 
11 |     @classmethod
12 |     def extract(cls, text: str) -> Optional[str]:
13 |         """Extract paper ID from text using various patterns."""
14 |         for pattern in cls._PATTERNS:
15 |             match = re.search(pattern, text, re.IGNORECASE)
16 |             if match:
17 |                 return match.group(1)
18 |         return None
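A quick check of the patterns above (assuming the repo root is on `PYTHONPATH`; the last, bare pattern matches arXiv's `YYMM.NNNN` scheme):

```python
from src.utils.paper_id_extractor import PaperIdExtractor

assert PaperIdExtractor.extract("What is the significance of paper ID 0704.2002?") == "0704.2002"
assert PaperIdExtractor.extract("See 0704.2002 for details.") == "0704.2002"
assert PaperIdExtractor.extract("No identifier in this sentence.") is None
```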
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "cometml"
3 | version = "0.1.0"
4 | description = ""
5 | authors = ["stefan"]
6 | readme = "readme.md"
7 | 
8 | [tool.poetry.dependencies]
9 | python = "^3.12"
10 | comet-ml = "^3.47.0"
11 | langchain = "^0.3.3"
12 | pandas = "^2.2.3"
13 | langchain-community = "^0.3.2"
14 | langchain-experimental = "^0.3.2"
15 | langchain-openai = "^0.2.2"
16 | graphdatascience = "^1.12"
17 | tiktoken = "^0.8.0"
18 | retry = "^0.9.2"
19 | neo4j = "5.21.0"
20 | sentence-transformers = "^3.2.0"
21 | faiss-cpu = "^1.9.0"
22 | python-dotenv = "^1.0.1"
23 | langgraph = "^0.2.39"
24 | langgraph-checkpoint-sqlite = "^2.0.0"
25 | postgres = "^4.0"
26 | streamlit = "^1.39.0"
27 | opik = "^1.3.3"
28 | 
29 | 
30 | [build-system]
31 | requires = ["poetry-core"]
32 | build-backend = "poetry.core.masonry.api"
33 | 
--------------------------------------------------------------------------------
/src/components/database/neo4j_client.py:
--------------------------------------------------------------------------------
1 | from neo4j import GraphDatabase
2 | from contextlib import contextmanager
3 | 
4 | class Neo4jClient:
5 |     def __init__(self, uri: str, user: str, password: str):
6 |         self.uri = uri
7 |         self.user = user
8 |         self.password = password
9 |         self._driver = None
10 | 
11 |     @property
12 |     def driver(self):
13 |         if self._driver is None:
14 |             self._driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))
15 |         return self._driver
16 | 
17 |     @contextmanager
18 |     def session(self):
19 |         session = self.driver.session()
20 |         try:
21 |             yield session
22 |         finally:
23 |             session.close()
24 | 
25 |     def close(self):
26 |         if self._driver:
27 |             self._driver.close()
28 |             self._driver = None
--------------------------------------------------------------------------------
/src/components/rag/embeddings.py:
--------------------------------------------------------------------------------
1 | from langchain_openai import OpenAIEmbeddings
2 | from 
typing import List, Optional 3 | import os 4 | 5 | class Embedding: 6 | def __init__(self, api_key: Optional[str] = None): 7 | if not api_key and not os.getenv("OPENAI_API_KEY"): 8 | raise ValueError("OpenAI API key must be provided either directly or through environment variable") 9 | self.api_key = api_key or os.getenv("OPENAI_API_KEY") 10 | if not self.api_key.startswith("sk-"): 11 | raise ValueError("Invalid OpenAI API key format") 12 | self.model = OpenAIEmbeddings(openai_api_key=self.api_key) 13 | 14 | 15 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 16 | try: 17 | return self.model.embed_documents(texts) 18 | except Exception as e: 19 | raise ValueError(f"Error generating embeddings: {str(e)}") 20 | -------------------------------------------------------------------------------- /src/streamlit/main.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from src.orchestrator.coordinator import Coordinator 3 | from src.streamlit.layout import ResearchAssistantUI 4 | 5 | def main(): 6 | if "app" not in st.session_state: 7 | st.session_state.app = Coordinator() 8 | 9 | # Initialize session state variables if they don't exist 10 | if "messages" not in st.session_state: 11 | st.session_state.messages = [] 12 | if "metrics" not in st.session_state: 13 | st.session_state.metrics = {} 14 | if "conversation_history" not in st.session_state: 15 | st.session_state.conversation_history = [] 16 | if "show_predefined" not in st.session_state: 17 | st.session_state.show_predefined = True 18 | if "session_active" not in st.session_state: 19 | st.session_state.session_active = True 20 | 21 | ui = ResearchAssistantUI(st.session_state.app, st.session_state) 22 | ui.render() -------------------------------------------------------------------------------- /src/components/paper/models.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional 3 | from datetime import datetime 4 | 5 | @dataclass 6 | class Paper: 7 | id: str 8 | title: str 9 | abstract: str 10 | authors: List[str] 11 | categories: List[str] 12 | submit_date: datetime 13 | update_date: Optional[datetime] = None 14 | 15 | @classmethod 16 | def from_db_record(cls, record: dict) -> 'Paper': 17 | return cls( 18 | id=record.get('id'), 19 | title=record.get('title'), 20 | abstract=record.get('abstract'), 21 | authors=record.get('authors', []), 22 | categories=record.get('categories', []), 23 | submit_date=record.get('submit_date'), 24 | update_date=record.get('update_date') 25 | ) 26 | 27 | def to_string(self) -> str: 28 | return f"""Title: {self.title} 29 | Abstract: {self.abstract} 30 | Authors: {', '.join(self.authors) if self.authors else 'No authors listed'} 31 | Categories: {', '.join(self.categories) if self.categories else 'No categories listed'} 32 | Submitted on: {self.submit_date} 33 | Updated on: {self.update_date if self.update_date else 'N/A'}""" 34 | -------------------------------------------------------------------------------- /src/core/graph.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from langgraph.graph import START 3 | from langgraph.graph.state import StateGraph 4 | from langgraph.checkpoint.memory import MemorySaver 5 | from langgraph.prebuilt import tools_condition 6 | from src.core.state import ConversationState 7 | from src.tools.rag import RAG 8 | from 
src.tools.paper_lookup import PaperLookupTool 9 | 10 | def create_research_graph(assistant: Any, rag_tool: RAG, paper_lookup_tool: PaperLookupTool) -> Any: 11 | """Creates and returns the research graph with the specified components.""" 12 | 13 | # Initialize graph builder 14 | builder = StateGraph(ConversationState) 15 | 16 | # Add nodes 17 | builder.add_node("assistant", assistant) 18 | builder.add_node("rag_tool", rag_tool) 19 | builder.add_node("paper_lookup_tool", paper_lookup_tool) 20 | 21 | # Define edges 22 | builder.add_edge(START, "assistant") 23 | builder.add_conditional_edges( 24 | "assistant", 25 | tools_condition 26 | ) 27 | builder.add_edge("rag_tool", "assistant") 28 | builder.add_edge("paper_lookup_tool", "assistant") 29 | 30 | # Set up memory and compile 31 | memory = MemorySaver() 32 | return builder.compile(checkpointer=memory) 33 | -------------------------------------------------------------------------------- /src/components/database/vector_store.py: -------------------------------------------------------------------------------- 1 | from langchain_community.vectorstores import Neo4jVector 2 | from src.components.database.neo4j_client import Neo4jClient 3 | from typing import List, Tuple, Optional 4 | 5 | class VectorStore: 6 | def __init__(self, neo4j_client: Neo4jClient, embedding_model, index_name: str): 7 | self.client = neo4j_client 8 | self.embedding_model = embedding_model 9 | self.index_name = index_name 10 | self.vector_store = self._initialize_vector_store() 11 | 12 | def _initialize_vector_store(self) -> Neo4jVector: 13 | return Neo4jVector( 14 | embedding=self.embedding_model, 15 | url=self.client.uri, 16 | username=self.client.user, 17 | password=self.client.password, 18 | index_name=self.index_name, 19 | node_label="Paper", 20 | text_node_property="abstract" 21 | ) 22 | 23 | def similarity_search(self, query: str, k: int = 3) -> List[Tuple[str, float]]: 24 | try: 25 | results = self.vector_store.similarity_search_with_score(query, k=k) 26 | return [(doc.page_content, score) for doc, score in results] 27 | except Exception as e: 28 | raise ValueError(f"Error performing similarity search: {str(e)}") 29 | -------------------------------------------------------------------------------- /scripts/neo4j_cleaner.py: -------------------------------------------------------------------------------- 1 | from neo4j import GraphDatabase 2 | from src.config.settings import Settings 3 | 4 | settings = Settings() 5 | 6 | class Neo4jCleaner: 7 | def __init__(self, uri, user, password): 8 | self.driver = GraphDatabase.driver(uri, auth=(user, password)) 9 | 10 | def close(self): 11 | self.driver.close() 12 | 13 | def delete_all_data(self): 14 | with self.driver.session() as session: 15 | session.run("MATCH (n) DETACH DELETE n") 16 | print("All nodes and relationships have been deleted.") 17 | 18 | def delete_constraints_and_indexes(self): 19 | with self.driver.session() as session: 20 | # Drop all constraints 21 | for constraint in session.run("SHOW CONSTRAINTS"): 22 | session.run(f"DROP CONSTRAINT {constraint['name']}") 23 | 24 | # Drop all indexes 25 | for index in session.run("SHOW INDEXES"): 26 | session.run(f"DROP INDEX {index['name']}") 27 | 28 | print("All constraints and indexes have been dropped.") 29 | 30 | 31 | def main(): 32 | cleaner = Neo4jCleaner(uri=settings.neo4j_uri, user=settings.neo4j_user, password=settings.neo4j_password) 33 | 34 | try: 35 | cleaner.delete_all_data() 36 | cleaner.delete_constraints_and_indexes() 37 | finally: 38 | cleaner.close() 39 | 
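# Caution: main() is destructive. It removes every node, relationship,
# constraint, and index in the target database. A confirmation guard is a
# sensible addition here (illustrative, not part of the original script):
#
#     if input("Type 'WIPE' to confirm deleting ALL data: ") != "WIPE":
#         raise SystemExit("Aborted.")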
40 | 
41 | if __name__ == "__main__":
42 |     main()
--------------------------------------------------------------------------------
/src/components/rag/indexing.py:
--------------------------------------------------------------------------------
1 | from src.components.database.neo4j_client import Neo4jClient
2 | from src.components.rag.embeddings import Embedding
3 | import logging
4 | 
5 | class IndexingService:
6 |     def __init__(
7 |         self,
8 |         db_client: Neo4jClient,
9 |         embedding_service: Embedding,
10 |         batch_size: int = 100
11 |     ):
12 |         self.db_client = db_client
13 |         self.embedding_service = embedding_service
14 |         self.batch_size = batch_size
15 |         self.logger = logging.getLogger(__name__)
16 | 
17 |     def ensure_vector_index(self, index_name: str):
18 |         with self.db_client.session() as session:
19 |             if not self._vector_index_exists(session, index_name):
20 |                 self._create_vector_index(session, index_name)
21 | 
22 |     def _vector_index_exists(self, session, index_name: str) -> bool:
23 |         query = """
24 |             SHOW INDEXES
25 |             YIELD name, type
26 |             WHERE name = $index_name AND type = 'VECTOR'
27 |             RETURN count(*) > 0 AS exists
28 |         """
29 |         result = session.run(query, index_name=index_name)
30 |         return result.single()['exists']
31 | 
32 |     def _create_vector_index(self, session, index_name: str):
33 |         session.run("""
34 |             CALL db.index.vector.createNodeIndex(
35 |                 $index_name,
36 |                 'Paper',
37 |                 'embedding',
38 |                 1536,
39 |                 'cosine'
40 |             )
41 |         """, index_name=index_name)
42 |         self.logger.info(f"Vector index '{index_name}' created.")
--------------------------------------------------------------------------------
/src/streamlit/ui_component.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from typing import List, Callable
3 | from langchain.schema import HumanMessage, AIMessage
4 | 
5 | class ChatDisplay:
6 |     @staticmethod
7 |     def display_messages(messages: List):
8 |         for message in messages:
9 |             if hasattr(message, 'format_for_display'):
10 |                 formatted_message = message.format_for_display()
11 |             elif isinstance(message, HumanMessage):
12 |                 formatted_message = f"**You:** {message.content}"
13 |             elif isinstance(message, AIMessage):
14 |                 formatted_message = f"**Assistant:** {message.content}"
15 |             else:
16 |                 formatted_message = f"**Unknown:** {message.content}"
17 | 
18 |             st.write(formatted_message)
19 |             st.markdown("<br>
" * 2, unsafe_allow_html=True) 20 | 21 | 22 | class InputArea: 23 | def __init__(self, on_submit: Callable[[str], None]): 24 | self.on_submit = on_submit 25 | 26 | def render(self): 27 | st.text_input( 28 | "Enter your question here:", 29 | key="user_input", 30 | on_change=self._handle_input 31 | ) 32 | 33 | def _handle_input(self): 34 | if st.session_state.user_input: 35 | self.on_submit(st.session_state.user_input) 36 | st.session_state.user_input = "" 37 | 38 | class SessionControls: 39 | def __init__(self, on_end_session: Callable[[], None]): 40 | self.on_end_session = on_end_session 41 | 42 | def render(self): 43 | col1, col2, col3 = st.columns([1, 3, 1]) 44 | with col1: 45 | if st.button("End Session"): 46 | self.on_end_session() -------------------------------------------------------------------------------- /src/components/database/ingest.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | import time 4 | from dotenv import load_dotenv 5 | from src.config.settings import Settings 6 | from src.components.database.neo4j_ingestion import OptimizedNeo4jIngestor, worker 7 | 8 | load_dotenv() 9 | 10 | def ingest_data_parallel(uri: str, user: str, password: str, data, batch_size: int = 1000, num_processes: int = 4): 11 | total = len(data) 12 | batches = [data[i:i + batch_size] for i in range(0, total, batch_size)] 13 | 14 | with multiprocessing.Pool(num_processes) as pool: 15 | results = [] 16 | for batch in batches: 17 | result = pool.apply_async(worker, (uri, user, password, batch)) 18 | results.append(result) 19 | 20 | for i, result in enumerate(results): 21 | result.get() # Wait for the batch to complete 22 | print(f"Ingested {min((i + 1) * batch_size, total)}/{total} papers") 23 | 24 | 25 | def main(): 26 | settings = Settings() 27 | uri = settings.neo4j_uri 28 | user = settings.neo4j_user 29 | password = settings.neo4j_password 30 | 31 | # Initialize and create constraints 32 | ingestor = OptimizedNeo4jIngestor(uri, user, password) 33 | ingestor.create_constraints() 34 | ingestor.close() 35 | 36 | # Load data 37 | with open('processed_data.json', 'r') as f: 38 | processed_data = json.load(f) 39 | 40 | # Ingest in parallel 41 | start_time = time.time() 42 | ingest_data_parallel(uri, user, password, processed_data, batch_size=1000, num_processes=4) 43 | end_time = time.time() 44 | 45 | print(f"Total ingestion time: {end_time - start_time:.2f} seconds") 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /src/components/database/neo4j_ingestion.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any 2 | from neo4j import GraphDatabase 3 | 4 | class OptimizedNeo4jIngestor: 5 | def __init__(self, uri: str, user: str, password: str): 6 | self.driver = GraphDatabase.driver(uri, auth=(user, password)) 7 | 8 | def close(self): 9 | self.driver.close() 10 | 11 | def create_constraints(self): 12 | with self.driver.session() as session: 13 | session.run("CREATE CONSTRAINT paper_id IF NOT EXISTS FOR (p:Paper) REQUIRE p.id IS UNIQUE") 14 | session.run("CREATE CONSTRAINT author_name IF NOT EXISTS FOR (a:Author) REQUIRE a.name IS UNIQUE") 15 | session.run("CREATE CONSTRAINT category_name IF NOT EXISTS FOR (c:Category) REQUIRE c.name IS UNIQUE") 16 | 17 | def ingest_batch(self, batch: List[Dict[str, Any]]): 18 | with self.driver.session() as session: 19 | 
session.execute_write(self._create_and_link_batch, batch) 20 | 21 | def _create_and_link_batch(self, tx, batch: List[Dict[str, Any]]): 22 | query = """ 23 | UNWIND $batch AS paper 24 | MERGE (p:Paper {id: paper.id}) 25 | SET p.title = paper.title, p.abstract = paper.abstract, 26 | p.submit_date = paper.submit_date, p.update_date = paper.update_date 27 | WITH p, paper 28 | UNWIND paper.authors AS author_name 29 | MERGE (a:Author {name: author_name}) 30 | MERGE (p)-[:AUTHORED_BY]->(a) 31 | WITH p, paper 32 | UNWIND paper.categories AS category_name 33 | MERGE (c:Category {name: category_name}) 34 | MERGE (p)-[:BELONGS_TO]->(c) 35 | """ 36 | tx.run(query, batch=batch) 37 | 38 | 39 | def worker(uri: str, user: str, password: str, batch: List[Dict[str, Any]]): 40 | ingestor = OptimizedNeo4jIngestor(uri, user, password) 41 | try: 42 | ingestor.ingest_batch(batch) 43 | finally: 44 | ingestor.close() 45 | -------------------------------------------------------------------------------- /src/tools/rag.py: -------------------------------------------------------------------------------- 1 | from langchain.tools import BaseTool 2 | from pydantic import PrivateAttr 3 | import time 4 | 5 | from src.components.rag.tool import RAG 6 | from src.components.evaluation.experiment_tracker import ExperimentTracker, MetricsCollector, MetricsData 7 | 8 | class RAGTool(BaseTool): 9 | name: str = "RAG" 10 | description: str = "Use this tool to retrieve research papers and generate answers to general queries." 11 | _rag_service: RAG = PrivateAttr() 12 | _experiment_tracker: ExperimentTracker = PrivateAttr() 13 | _metrics_collector: MetricsCollector = PrivateAttr() 14 | 15 | def __init__( 16 | self, 17 | rag_service: RAG, 18 | experiment_tracker: ExperimentTracker, 19 | metrics_collector: MetricsCollector 20 | ): 21 | super().__init__() 22 | self._rag_service = rag_service 23 | self._experiment_tracker = experiment_tracker 24 | self._metrics_collector = metrics_collector 25 | 26 | def _run(self, query: str) -> str: 27 | """Execute the RAG functionality.""" 28 | # Start timing 29 | start_time = time.time() 30 | 31 | response_text = "" 32 | success = False 33 | error_msg = None 34 | 35 | try: 36 | response = self._rag_service.answer_question(query) 37 | response_text = response["response"] 38 | success = True 39 | except Exception as e: 40 | error_msg = str(e) 41 | response_text = f"Error processing RAG query: {error_msg}" 42 | 43 | # End timing 44 | end_time = time.time() 45 | processing_time = end_time - start_time 46 | 47 | # Build MetricsData 48 | metrics_data = MetricsData( 49 | processing_time=processing_time, 50 | query_length=len(query), 51 | response_length=len(response_text), 52 | success=success, 53 | token_count=self._metrics_collector.count_tokens(response_text), 54 | error=error_msg 55 | ) 56 | 57 | # Log to CometML 58 | self._experiment_tracker.log_rag_query(metrics_data) 59 | return response_text 60 | -------------------------------------------------------------------------------- /scripts/preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from typing import List, Dict, Any 4 | 5 | 6 | def preprocess_data(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 7 | """ 8 | Preprocess the arXiv papers data. 
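Each raw record is expected to carry the Kaggle arXiv metadata fields used below: 'id', 'title', 'abstract', 'categories' (a space-separated string), 'authors_parsed', 'versions', and 'update_date'.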
9 | 10 | Args: 11 | data: List of dictionaries containing paper information 12 | 13 | Returns: 14 | List of processed paper dictionaries 15 | """ 16 | processed_data = [] 17 | for paper in data: 18 | processed_paper = { 19 | 'id': paper['id'], 20 | 'title': paper['title'], 21 | 'abstract': paper['abstract'], 22 | 'categories': paper['categories'].split(), 23 | 'authors': [' '.join(author).strip() 24 | for author in paper['authors_parsed']], 25 | 'submit_date': paper['versions'][0]['created'], 26 | 'update_date': paper['update_date'] 27 | } 28 | processed_data.append(processed_paper) 29 | return processed_data 30 | 31 | 32 | def main(): 33 | # Set up argument parser 34 | parser = argparse.ArgumentParser(description='Preprocess arXiv papers data') 35 | parser.add_argument('--input', type=str, required=True, 36 | help='Input JSON file path') 37 | parser.add_argument('--output', type=str, required=True, 38 | help='Output JSON file path') 39 | 40 | # Parse arguments 41 | args = parser.parse_args() 42 | 43 | # Read input file 44 | print(f"Reading data from {args.input}...") 45 | try: 46 | with open(args.input, 'r') as f: 47 | raw_data = json.load(f) 48 | except FileNotFoundError: 49 | print(f"Error: Input file {args.input} not found") 50 | return 51 | except json.JSONDecodeError: 52 | print(f"Error: Input file {args.input} is not valid JSON") 53 | return 54 | 55 | # Process the data 56 | print("Processing data...") 57 | processed_data = preprocess_data(raw_data) 58 | 59 | # Save processed data 60 | print(f"Saving processed data to {args.output}...") 61 | try: 62 | with open(args.output, 'w') as f: 63 | json.dump(processed_data, f, indent=2) 64 | print("Processing completed successfully!") 65 | except Exception as e: 66 | print(f"Error saving output file: {str(e)}") 67 | return 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /src/streamlit/layout.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from src.streamlit.ui_component import ChatDisplay, InputArea, SessionControls 3 | from src.streamlit.predefined_questions import PredefinedQuestionsManager 4 | 5 | 6 | class ResearchAssistantUI: 7 | def __init__(self, coordinator, session_state): 8 | self.coordinator = coordinator 9 | self.session_state = session_state 10 | self.chat_display = ChatDisplay() 11 | self.input_area = InputArea(self._handle_user_input) 12 | self.session_controls = SessionControls(self._handle_session_end) 13 | self.predefined_questions = PredefinedQuestionsManager() 14 | 15 | def _get_unique_messages(self, messages): 16 | # Use content to identify unique messages 17 | seen_contents = set() 18 | unique_messages = [] 19 | for message in messages: 20 | if message.content not in seen_contents: 21 | seen_contents.add(message.content) 22 | unique_messages.append(message) 23 | return unique_messages 24 | 25 | def render(self): 26 | st.set_page_config(page_title="Research Paper Assistant", layout="wide") 27 | st.title("Research Paper Assistant") 28 | 29 | if not self.session_state.session_active: 30 | st.info("Session has ended. 
Please refresh the page to start a new session.") 31 | return 32 | 33 | chat_container = st.container() 34 | input_container = st.container() 35 | 36 | with chat_container: 37 | unique_messages = self._get_unique_messages(self.session_state.messages) 38 | self.chat_display.display_messages(unique_messages) 39 | 40 | with input_container: 41 | self.session_controls.render() 42 | self._render_predefined_questions() 43 | self.input_area.render() 44 | 45 | def _render_predefined_questions(self): 46 | if self.session_state.show_predefined: 47 | st.subheader("Predefined Questions") 48 | questions = self.predefined_questions.get_questions() 49 | col1, col2 = st.columns(2) 50 | 51 | for i, question in enumerate(questions): 52 | with col1 if i < len(questions) // 2 else col2: 53 | st.button( 54 | question, 55 | key=f"pred_q_{i}", 56 | on_click=self._handle_predefined_question, 57 | args=(question,) 58 | ) 59 | 60 | def _handle_user_input(self, user_input: str): 61 | self.session_state.show_predefined = False 62 | self.coordinator.process_message(user_input, self.session_state) 63 | 64 | def _handle_predefined_question(self, question: str): 65 | self.session_state.show_predefined = False 66 | self.coordinator.process_message(question, self.session_state) 67 | 68 | def _handle_session_end(self): 69 | self.coordinator.cleanup() 70 | self.session_state.messages.clear() 71 | self.session_state.session_active = False 72 | st.info("Session has ended. Please refresh the page to start a new session.") 73 | st.stop() 74 | 75 | -------------------------------------------------------------------------------- /src/tools/paper_lookup.py: -------------------------------------------------------------------------------- 1 | from src.components.paper.tool import PaperTool 2 | from src.utils.paper_id_extractor import PaperIdExtractor 3 | from langchain.tools.base import BaseTool 4 | from pydantic import PrivateAttr 5 | import time 6 | import json 7 | 8 | from src.components.evaluation.experiment_tracker import ExperimentTracker 9 | from src.components.evaluation.experiment_tracker import MetricsCollector, MetricsData 10 | 11 | 12 | class PaperLookupTool(BaseTool): 13 | name: str = "Paper Lookup" 14 | description: str = "Use this tool to retrieve details about a specific paper by its ID." 15 | _paper_service: PaperTool = PrivateAttr() 16 | _paper_id_extractor: PaperIdExtractor = PrivateAttr() 17 | _experiment_tracker: ExperimentTracker = PrivateAttr() 18 | _metrics_collector: MetricsCollector = PrivateAttr() 19 | 20 | def __init__( 21 | self, 22 | paper_service: PaperTool, 23 | experiment_tracker: ExperimentTracker, 24 | metrics_collector: MetricsCollector 25 | ): 26 | super().__init__() 27 | self._paper_service = paper_service 28 | self._paper_id_extractor = PaperIdExtractor() 29 | self._experiment_tracker = experiment_tracker 30 | self._metrics_collector = metrics_collector 31 | 32 | def _run(self, query: str) -> str: 33 | """Execute the paper lookup functionality.""" 34 | # Start timing 35 | start_time = time.time() 36 | 37 | paper_id = None 38 | response_text = "" 39 | success = False 40 | error_msg = None 41 | 42 | try: 43 | paper_id = self._paper_id_extractor.extract(query) 44 | if not paper_id: 45 | response_text = "No valid paper ID found in the message." 
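# No ID could be extracted: return early. Note that no lookup metrics are
# logged on this path, since there is no paper_id to attribute them to.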
46 |                 return json.dumps({
47 |                     "ground_truth": "",
48 |                     "tool_answer": response_text
49 |                 })
50 | 
51 |             # Get paper info with metrics
52 |             result = self._paper_service.find_paper_by_id(paper_id)
53 |             if result["success"]:
54 |                 # The “official text” from the DB
55 |                 ground_truth = result["response"]
56 |                 success = True
57 |                 tool_answer = (
58 |                     f"Here is the paper with ID {paper_id}:\n\n"
59 |                     f"{ground_truth}"
60 |                 )
61 |             else:
62 |                 ground_truth = ""
63 |                 tool_answer = f"Paper with ID {paper_id} not found."
64 |         except Exception as e:
65 |             ground_truth = ""
66 |             error_msg = str(e)
67 |             tool_answer = f"Error looking up paper: {error_msg}"
68 |             success = False
69 | 
70 |         # End timing
71 |         end_time = time.time()
72 |         processing_time = end_time - start_time
73 | 
74 |         # Build MetricsData (measure the answer actually produced above)
75 |         metrics_data = MetricsData(
76 |             processing_time=processing_time,
77 |             query_length=len(query),
78 |             response_length=len(tool_answer),
79 |             success=success,
80 |             token_count=self._metrics_collector.count_tokens(tool_answer),
81 |             error=error_msg
82 |         )
83 | 
84 |         # Log to CometML
85 |         if paper_id:
86 |             self._experiment_tracker.log_paper_lookup(paper_id, metrics_data)
87 | 
88 |         return json.dumps({
89 |             "ground_truth": ground_truth,
90 |             "tool_answer": tool_answer
91 |         })
92 | 
--------------------------------------------------------------------------------
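The JSON envelope above is the contract between the lookup tool and the agent; `research_assistant.py` (next) parses it and falls back to the raw string when the output is not valid JSON. A condensed, standalone mirror of that fallback (illustrative):

```python
import json

def unpack(response_text: str) -> tuple[str, str]:
    """Return (ground_truth, tool_answer), falling back to the raw text."""
    try:
        parsed = json.loads(response_text)
        if isinstance(parsed, dict) and "ground_truth" in parsed:
            return parsed["ground_truth"], parsed["tool_answer"]
    except (json.JSONDecodeError, TypeError):
        pass
    return "", response_text

assert unpack("plain text answer") == ("", "plain text answer")
assert unpack('{"ground_truth": "Title: X", "tool_answer": "Here is the paper..."}') == ("Title: X", "Here is the paper...")
```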
/src/agents/research_assistant.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any, List
2 | import time
3 | import json
4 | 
5 | from langchain_openai import ChatOpenAI
6 | from langchain.schema import AIMessage
7 | 
8 | from src.components.evaluation.experiment_tracker import ExperimentTracker
9 | from src.core.state import ConversationState
10 | from langchain.memory import ConversationBufferMemory
11 | from dotenv import load_dotenv
12 | from langchain.agents import initialize_agent, AgentType
13 | from langchain.tools.base import BaseTool
14 | from src.agents.base import BaseAgent
15 | 
16 | load_dotenv()
17 | 
18 | class ResearchAssistant(BaseAgent):
19 |     def __init__(
20 |         self,
21 |         experiment_tracker: ExperimentTracker,
22 |         tools: List[BaseTool],
23 |         llm: ChatOpenAI,
24 |     ):
25 |         self.experiment_tracker = experiment_tracker
26 |         self.llm = llm
27 |         self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
28 |         # Initialize the agent with the tools and LLM
29 |         self.agent_executor = initialize_agent(
30 |             tools,
31 |             self.llm,
32 |             agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
33 |             verbose=True,
34 |             memory=self.memory,
35 |             handle_parsing_errors=True
36 |         )
37 | 
38 |     def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
39 |         last_message = state["messages"][-1]
40 |         query_content = last_message.content
41 | 
42 |         start_time = time.time()
43 |         self.experiment_tracker.experiment.log_parameter("input_query", query_content)
44 | 
45 |         # Add user message to memory
46 |         self.memory.chat_memory.add_user_message(query_content)
47 | 
48 |         # 1) Run the agent
49 |         response_text = self.agent_executor.run(input=query_content)
50 | 
51 |         processing_time = time.time() - start_time
52 |         self.experiment_tracker.experiment.log_metrics({
53 |             "processing_time": processing_time,
54 |             "response_length": len(response_text),
55 |             "query_length": len(query_content)
56 |         })
57 | 
58 |         if "metrics" in state:
59 |             self.experiment_tracker.experiment.log_metrics(state["metrics"])
60 | 
61 |         ground_truth = ""
62 |         tool_answer = response_text  # fallback is the entire text
63 | 
64 |         try:
65 |             parsed = json.loads(response_text)
66 |             if isinstance(parsed, dict) and "ground_truth" in parsed:
67 |                 ground_truth = parsed["ground_truth"]
68 |                 tool_answer = parsed["tool_answer"]
69 |         except (json.JSONDecodeError, TypeError):
70 |             pass
71 |         final_response = tool_answer
72 | 
73 |         # Add final response to memory and state
74 |         state["messages"].append(AIMessage(content=final_response))
75 |         self.memory.chat_memory.add_ai_message(final_response)
76 | 
77 |         return {
78 |             "messages": [AIMessage(content=final_response)],
79 |             # Put the ground truth somewhere so we can pick it up in coordinator
80 |             "tool_output": {"paper_ground_truth": ground_truth}
81 |         }
82 | 
83 |     def process_message(self, state: ConversationState) -> Dict[str, Any]:
84 |         """Process a message. This method is required by BaseAgent but is not used."""
85 |         # Message processing is handled in __call__; this stub only satisfies the interface.
86 |         pass
87 | 
88 |     def handle_error(self, error: Exception) -> Dict[str, Any]:
89 |         """Generic error handler, required by BaseAgent but not used."""
90 |         error_message = f"An error occurred: {str(error)}"
91 |         return {"messages": [AIMessage(content=error_message)]}
92 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Uplifted RAG Systems with CometML Opik
2 | ## Building an Agentic Knowledge Graph for Enhanced Information Retrieval
3 | 
4 | This repository demonstrates an advanced implementation of a Retrieval-Augmented Generation (RAG) system that combines graph-based knowledge representation with agentic capabilities, monitored and optimized through CometML Opik. The system transforms traditional RAG approaches by enabling context-aware, multi-hop reasoning while maintaining comprehensive observability.
5 | 
6 | ## Table of Contents
7 | - [Overview](#overview)
8 | - [Key Features](#key-features)
9 | - [System Architecture](#system-architecture)
10 | - [Prerequisites](#prerequisites)
11 | - [Installation](#installation)
12 | - [Usage](#usage)
13 | 
14 | ## Overview
15 | 
16 | Traditional RAG systems often struggle with linear document retrieval, context blindness, and limited reasoning capabilities. 
This implementation addresses these challenges by:
17 | - Implementing a graph-based knowledge structure for complex relationship modeling
18 | - Integrating autonomous agents for dynamic query processing
19 | - Providing comprehensive monitoring through CometML Opik
20 | - Enabling multi-hop reasoning across interconnected documents
21 | 
22 | ## Key Features
23 | 
24 | ### Advanced RAG Capabilities
25 | - Graph-based knowledge representation using Neo4j
26 | - Semantic search with vector embeddings
27 | - Multi-hop reasoning across document relationships
28 | - Context-aware query processing
29 | 
30 | ### Agent Integration
31 | - Dynamic tool selection based on query context
32 | - Adaptive exploration strategies
33 | - Comprehensive conversation memory
34 | - Fallback mechanisms for robust operation
35 | 
36 | ### Monitoring & Optimization
37 | - Real-time performance tracking with CometML Opik
38 | - Detailed metrics collection and visualization
39 | - Hyperparameter optimization
40 | - Model versioning and experiment tracking
41 | 
42 | ### User Interface
43 | - Streamlit-based interactive interface
44 | - Component-based architecture
45 | - Session state management
46 | - Metric visualization
47 | 
48 | ## System Architecture
49 | ![System architecture diagram](https://github.com/user-attachments/assets/79ac7fc8-03f1-41f7-b5f7-37e55955ad11)
50 | 
51 | The system consists of several key components (a condensed routing sketch follows the list):
52 | 
53 | 1. **Data Pipeline**
54 |    - Raw data preprocessing
55 |    - Neo4j graph database integration
56 |    - Parallel data ingestion system
57 |    - Vector store bridging
58 | 
59 | 2. **Core RAG Components**
60 |    - Question-answering pipeline
61 |    - Paper lookup service
62 |    - Agent-based tool orchestration
63 |    - Coordinator pattern implementation
64 | 
65 | 3. **Monitoring Infrastructure**
66 |    - Real-time metric collection
67 |    - Performance visualization
68 |    - Error tracking
69 |    - Session analytics
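An illustrative sketch of the division of labor between the two tools (simplified stand-ins, not the actual coordinator or agent code):

```python
import re

def route(query: str) -> str:
    """Mimic the agent's tool choice: ID-specific questions go to paper lookup."""
    if re.search(r"\d{4}\.\d{4}", query):  # looks like an arXiv identifier
        return "paper_lookup_tool"
    return "rag_tool"

assert route("What is the significance of paper ID 0704.2002?") == "paper_lookup_tool"
assert route("Explain transformer models in NLP.") == "rag_tool"
```

In the repository itself this decision is made by the LLM agent (see `src/agents/research_assistant.py`), not by a regex; the sketch only shows the intended routing behavior.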
70 | 
71 | ## Prerequisites
72 | 
73 | - Neo4j Database
74 | - CometML Account
75 | - OpenAI API Key
76 | - [arXiv Dataset from Kaggle](https://www.kaggle.com/datasets/Cornell-University/arxiv)
77 | 
78 | ## Installation
79 | 
80 | 1. **Clone the Repository**
81 | ```bash
82 | git clone https://github.com/mlvanguards/agentic-graph-rag-evaluation-cometml.git
83 | ```
84 | 
85 | 2. **Install Dependencies**
86 | ```bash
87 | pip install -r requirements.txt
88 | ```
89 | 
90 | 3. **Environment Setup**
91 | Create a `.env` file with the following credentials:
92 | ```
93 | NEO4J_URI=your_neo4j_uri
94 | NEO4J_USERNAME=your_username
95 | NEO4J_PASSWORD=your_password
96 | COMETML_API_KEY=your_cometml_key
97 | OPENAI_API_KEY=your_openai_key
98 | ```
99 | 
100 | ## Usage
101 | 
102 | 1. **Data Preprocessing**
103 | ```bash
104 | python scripts/preprocess.py --input arxiv_data.json --output processed_data.json
105 | ```
106 | 
107 | 2. **Database Ingestion** (run from the repository root)
108 | ```bash
109 | python -m src.components.database.ingest
110 | ```
111 | 
112 | 3. **Start the Application**
113 | ```bash
114 | python -m streamlit run src/main.py
115 | ```
116 | 
--------------------------------------------------------------------------------
/src/components/rag/tool.py:
--------------------------------------------------------------------------------
1 | from src.components.database.vector_store import VectorStore
2 | from langchain_core.prompts import PromptTemplate
3 | from langchain_openai import OpenAI
4 | from langchain.chains.llm import LLMChain
5 | from typing import Any, Dict, Optional
6 | from src.components.evaluation.experiment_tracker import MetricsCollector
7 | import time
8 | 
9 | class RAG:
10 |     def __init__(
11 |         self,
12 |         vector_store: VectorStore,
13 |         openai_api_key: str,
14 |         prompt_template: Optional[str] = None
15 |     ):
16 |         if not openai_api_key:
17 |             raise ValueError("OpenAI API key must be provided")
18 |         self.vector_store = vector_store
19 |         self.llm = OpenAI(openai_api_key=openai_api_key)
20 |         self.metrics_collector = MetricsCollector()
21 | 
22 |         self.prompt_template = PromptTemplate(
23 |             input_variables=['context', 'question'],
24 |             template=prompt_template or self._default_prompt_template()
25 |         )
26 |         self.llm_chain = LLMChain(llm=self.llm, prompt=self.prompt_template)
27 | 
28 |     def _default_prompt_template(self) -> str:
29 |         return """
30 |         You are an AI assistant knowledgeable about computer science research.
31 | 
32 |         Context:
33 |         {context}
34 | 
35 |         Question:
36 |         {question}
37 | 
38 |         Provide a detailed and accurate answer based on the context provided.
39 |         """
40 | 
41 |     def answer_question(self, question: str, k: int = 3) -> Dict[str, Any]:
42 |         start_time = time.time()
43 |         try:
44 |             # Get context with metrics
45 |             context_result = self.get_context(question, k)
46 |             context = context_result["context"]
47 |             metrics = context_result["metrics"]
48 | 
49 |             # Generate response
50 |             response = self.llm_chain.run(context=context, question=question)
51 | 
52 |             # Add response metrics
53 |             response_stats = self.metrics_collector.get_text_stats(response)
54 |             metrics.update({
55 |                 "response_length": len(response),
56 |                 "response_tokens": response_stats["token_count"],
57 |                 "total_processing_time": time.time() - start_time,
58 |                 "success": True
59 |             })
60 | 
61 |             return {
62 |                 "response": response,
63 |                 "metrics": metrics
64 |             }
65 |         except Exception as e:
66 |             return {
67 |                 "response": f"Error generating answer: {str(e)}",
68 |                 "metrics": {
69 |                     "error": str(e),
70 |                     "success": False,
71 |                     "total_processing_time": time.time() - start_time
72 |                 }
73 |             }
74 | 
75 |     def get_context(self, question: str, k: int = 3) -> Dict[str, Any]:
76 |         start_time = time.time()
77 |         try:
78 |             relevant_docs = self.vector_store.similarity_search(question, k=k)
79 |             context = "\n\n".join([doc for doc, _ in relevant_docs])
80 | 
81 |             # Collect metrics
82 |             context_stats = self.metrics_collector.get_text_stats(context)
83 |             question_stats = self.metrics_collector.get_text_stats(question)
84 | 
85 |             metrics = {
86 |                 "context_length": len(context),
87 |                 "context_tokens": context_stats["token_count"],
88 |                 "context_chunks": len(relevant_docs),
89 |                 "question_length": len(question),
90 |                 "question_tokens": question_stats["token_count"],
91 |                 "retrieval_time": time.time() - start_time,
92 |                 "success": True
93 |             }
94 | 
95 |             return {
96 |                 "context": context,
97 |                 "metrics": metrics
98 |             }
99 |         except Exception as e:
100 |             return {
101 |                 "context": "",
102 |                 "metrics": {
103 |                     "error": str(e),
104 |                     "success": False,
105 |                     "retrieval_time": time.time() - start_time
106
| } 107 | } 108 | 109 | -------------------------------------------------------------------------------- /src/components/evaluation/experiment_tracker.py: -------------------------------------------------------------------------------- 1 | from comet_ml import Experiment 2 | from typing import Dict, Any, Optional 3 | import time 4 | from dataclasses import dataclass 5 | import tiktoken 6 | from opik.evaluation.metrics import Hallucination 7 | @dataclass 8 | class MetricsData: 9 | processing_time: float 10 | query_length: int 11 | response_length: int 12 | success: bool 13 | context_length: int = 0 # Added for RAG context size 14 | token_count: int = 0 # Added for response tokens 15 | error: Optional[str] = None 16 | 17 | 18 | class MetricsCollector: 19 | def __init__(self): 20 | self.tokenizer = tiktoken.encoding_for_model("gpt-4o") 21 | 22 | def count_tokens(self, text: str) -> int: 23 | return len(self.tokenizer.encode(text)) 24 | 25 | def get_text_stats(self, text: str) -> Dict[str, int]: 26 | return { 27 | "char_length": len(text), 28 | "token_count": self.count_tokens(text), 29 | "word_count": len(text.split()), 30 | "line_count": len(text.splitlines()) 31 | } 32 | 33 | class ExperimentTracker: 34 | def __init__(self, api_key: str, project_name: str): 35 | self.experiment = Experiment( 36 | api_key=api_key, 37 | project_name=project_name 38 | ) 39 | self.start_time = time.time() 40 | self.query_count = 0 41 | self.error_count = 0 42 | 43 | def log_paper_lookup(self, paper_id: str, metrics: MetricsData): 44 | """Log metrics for paper lookups.""" 45 | self.query_count += 1 46 | if not metrics.success: 47 | self.error_count += 1 48 | 49 | self.experiment.log_metrics({ 50 | "paper_lookup_latency": metrics.processing_time, 51 | "paper_id_length": len(paper_id), 52 | "paper_lookup_year": int(paper_id.split('.')[0]) if '.' 
in paper_id else None, 53 | "paper_lookup_success": int(metrics.success), 54 | "paper_response_length": metrics.response_length, 55 | "paper_response_tokens": metrics.token_count, 56 | "cumulative_queries": self.query_count, 57 | "error_rate": self.error_count / self.query_count if self.query_count > 0 else 0 58 | }) 59 | if metrics.error: 60 | self.experiment.log_parameter("error", metrics.error) 61 | 62 | def log_rag_query(self, metrics: MetricsData): 63 | """Log metrics for RAG queries.""" 64 | self.query_count += 1 65 | if not metrics.success: 66 | self.error_count += 1 67 | 68 | self.experiment.log_metrics({ 69 | "rag_query_length": metrics.query_length, 70 | "rag_response_length": metrics.response_length, 71 | "rag_processing_time": metrics.processing_time, 72 | "rag_success": int(metrics.success), 73 | "rag_context_length": metrics.context_length, 74 | "rag_response_tokens": metrics.token_count, 75 | "cumulative_queries": self.query_count, 76 | "error_rate": self.error_count / self.query_count if self.query_count > 0 else 0 77 | }) 78 | 79 | # Log hourly aggregates 80 | hour = time.strftime("%Y-%m-%d-%H") 81 | self.experiment.log_metrics({ 82 | f"hourly_queries_{hour}": 1, 83 | f"hourly_errors_{hour}": 0 if metrics.success else 1, 84 | f"hourly_avg_latency_{hour}": metrics.processing_time 85 | }) 86 | 87 | def log_session_metrics(self): 88 | """Log overall session metrics.""" 89 | session_duration = time.time() - self.start_time 90 | self.experiment.log_metrics({ 91 | "session_duration": session_duration, 92 | "total_queries": self.query_count, 93 | "total_errors": self.error_count, 94 | "session_error_rate": self.error_count / self.query_count if self.query_count > 0 else 0, 95 | "queries_per_minute": (self.query_count * 60) / session_duration if session_duration > 0 else 0 96 | }) 97 | 98 | def end_session(self, session_metrics: Dict[str, Any]): 99 | """Log final session metrics and end experiment.""" 100 | self.log_session_metrics() 101 | for key, value in session_metrics.items(): 102 | self.experiment.log_metric(f"session_{key}", value) 103 | self.experiment.end() -------------------------------------------------------------------------------- /src/components/paper/tool.py: -------------------------------------------------------------------------------- 1 | from src.components.database.neo4j_client import Neo4jClient 2 | from src.components.paper.models import Paper 3 | from typing import Dict, Any 4 | from src.components.evaluation.experiment_tracker import MetricsCollector 5 | import logging 6 | from neo4j.exceptions import AuthError, ServiceUnavailable 7 | import time 8 | 9 | class PaperTool: 10 | def __init__(self, db_client: Neo4jClient): 11 | self.db_client = db_client 12 | self.metrics_collector = MetricsCollector() 13 | self.logger = logging.getLogger(__name__) 14 | 15 | def find_paper_by_id(self, paper_id: str) -> Dict[str, Any]: 16 | start_time = time.time() 17 | 18 | try: 19 | # Test connection before query 20 | self.logger.info(f"Testing connection before paper lookup for {paper_id}") 21 | with self.db_client.session() as test_session: 22 | test_session.run("RETURN 1").single() 23 | 24 | # Perform actual query 25 | self.logger.info(f"Executing paper lookup query for {paper_id}") 26 | with self.db_client.session() as session: 27 | result = session.run(self._get_paper_query(), paper_id=paper_id) 28 | record = result.single() 29 | 30 | if not record: 31 | self.logger.info(f"No paper found with ID {paper_id}") 32 | return { 33 | "response": f"Paper with ID {paper_id} not 
found.", 34 | "success": False, 35 | "metrics": { 36 | "error": "Paper not found", 37 | "success": False, 38 | "processing_time": time.time() - start_time, 39 | "question_length": len(paper_id) 40 | } 41 | } 42 | 43 | # Create paper object 44 | paper = Paper.from_db_record(record) 45 | paper_text = paper.to_string() 46 | 47 | # Collect metrics 48 | paper_stats = self.metrics_collector.get_text_stats(paper_text) 49 | metrics = { 50 | "success": True, 51 | "processing_time": time.time() - start_time, 52 | "response_length": len(paper_text), 53 | "response_tokens": paper_stats["token_count"], 54 | "word_count": paper_stats["word_count"], 55 | "authors_count": len(record["authors"]), 56 | "categories_count": len(record["categories"]), 57 | "question_length": len(paper_id) 58 | } 59 | 60 | return { 61 | "response": paper_text, 62 | "success": True, 63 | "metrics": metrics 64 | } 65 | 66 | except AuthError as e: 67 | error_msg = f"Authentication failed: {str(e)}" 68 | self.logger.error(f"Authentication error during paper lookup: {str(e)}") 69 | return self._create_error_response(error_msg, start_time, paper_id) 70 | except ServiceUnavailable as e: 71 | error_msg = f"Database service unavailable: {str(e)}" 72 | self.logger.error(f"Service unavailable during paper lookup: {str(e)}") 73 | return self._create_error_response(error_msg, start_time, paper_id) 74 | except Exception as e: 75 | error_msg = f"Error retrieving paper: {str(e)}" 76 | self.logger.error(f"Error during paper lookup: {str(e)}") 77 | return self._create_error_response(error_msg, start_time, paper_id) 78 | 79 | def _create_error_response(self, error_msg: str, start_time: float, paper_id: str) -> Dict[str, Any]: 80 | return { 81 | "response": error_msg, 82 | "success": False, 83 | "metrics": { 84 | "error": error_msg, 85 | "success": False, 86 | "processing_time": time.time() - start_time, 87 | "question_length": len(paper_id) 88 | } 89 | } 90 | 91 | def _get_paper_query(self) -> str: 92 | """Return the paper lookup query.""" 93 | return """ 94 | MATCH (p:Paper {id: $paper_id}) 95 | OPTIONAL MATCH (p)-[:AUTHORED_BY]->(a:Author) 96 | OPTIONAL MATCH (p)-[:BELONGS_TO]->(c:Category) 97 | RETURN p.id as id, p.title AS title, p.abstract AS abstract, 98 | p.submit_date AS submit_date, p.update_date AS update_date, 99 | collect(DISTINCT a.name) AS authors, 100 | collect(DISTINCT c.name) AS categories 101 | """ -------------------------------------------------------------------------------- /src/components/evaluation/custom_metric.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, List, Optional, Union 3 | import pydantic 4 | from opik.evaluation.metrics import base_metric, score_result 5 | from opik.evaluation.models import litellm_chat_model 6 | from opik.evaluation.models import base_model 7 | import logging 8 | LOGGER = logging.getLogger(__name__) 9 | 10 | 11 | class AnswerCompletenessResponseFormat(pydantic.BaseModel): 12 | answer_completeness_score: float 13 | reason: str 14 | 15 | 16 | class AnswerCompleteness(base_metric.BaseMetric): 17 | """ 18 | A metric that evaluates the completeness of an answer relative to the user's input. 19 | 20 | This metric uses an LLM to assess whether the generated answer fully addresses all aspects of the user's query. 21 | 22 | Args: 23 | model: The language model to use for evaluation. Defaults to "gpt-4o". 24 | name: The name of the metric. Defaults to "answer_completeness_metric". 
/src/components/evaluation/custom_metric.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Any, List, Optional, Union
3 | import pydantic
4 | from opik.evaluation.metrics import base_metric, score_result
5 | from opik.evaluation.models import litellm_chat_model
6 | from opik.evaluation.models import base_model
7 | import logging
8 | LOGGER = logging.getLogger(__name__)
9 | 
10 | 
11 | class AnswerCompletenessResponseFormat(pydantic.BaseModel):
12 |     answer_completeness_score: float
13 |     reason: str
14 | 
15 | 
16 | class AnswerCompleteness(base_metric.BaseMetric):
17 |     """
18 |     A metric that evaluates the completeness of an answer relative to the user's input.
19 | 
20 |     This metric uses an LLM to assess whether the generated answer fully addresses all aspects of the user's query.
21 | 
22 |     Args:
23 |         model: The language model to use for evaluation. Defaults to "gpt-4o".
24 |         name: The name of the metric. Defaults to "answer_completeness_metric".
25 |         few_shot_examples: Optional list of few-shot examples to guide the LLM's evaluation.
26 |         track: Whether to track the metric. Defaults to True.
27 |     """
28 | 
29 |     def __init__(
30 |         self,
31 |         model: Optional[Union[str, base_model.OpikBaseModel]] = None,
32 |         name: str = "answer_completeness_metric",
33 |         few_shot_examples: Optional[List[Any]] = None,  # Replace Any with a concrete example type if one is defined
34 |         track: bool = True,
35 |     ):
36 |         super().__init__(
37 |             name=name,
38 |             track=track,
39 |         )
40 |         self._init_model(model)
41 |         if few_shot_examples is None:
42 |             self._few_shot_examples = []  # Default few-shot examples could be defined here
43 |         else:
44 |             self._few_shot_examples = few_shot_examples
45 | 
46 |     def _init_model(
47 |         self, model: Optional[Union[str, base_model.OpikBaseModel]]
48 |     ) -> None:
49 |         if isinstance(model, base_model.OpikBaseModel):
50 |             self._model = model
51 |         else:
52 |             self._model = litellm_chat_model.LiteLLMChatModel(model_name=model or "gpt-4o")
53 | 
54 |     def score(
55 |         self, input: str, output: str, context: List[str], **ignored_kwargs: Any
56 |     ) -> score_result.ScoreResult:
57 |         """
58 |         Calculate the answer completeness score for the given input-output pair.
59 | 
60 |         Args:
61 |             input: The user's question or prompt.
62 |             output: The LLM-generated answer.
63 |             context: A list of context strings relevant to the input.
64 | 
65 |         Returns:
66 |             score_result.ScoreResult: Contains the completeness score and reason.
67 |         """
68 |         llm_query = self._generate_prompt(input, output, context)
69 |         model_output = self._model.generate_string(
70 |             input=llm_query, response_format=AnswerCompletenessResponseFormat
71 |         )
72 |         return self._parse_model_output(model_output)
73 | 
74 |     async def ascore(
75 |         self, input: str, output: str, context: List[str], **ignored_kwargs: Any
76 |     ) -> score_result.ScoreResult:
77 |         """
78 |         Asynchronously calculate the answer completeness score for the given input-output pair.
79 | 
80 |         Args:
81 |             input: The user's question or prompt.
82 |             output: The LLM-generated answer.
83 |             context: A list of context strings relevant to the input.
84 | 
85 |         Returns:
86 |             score_result.ScoreResult: Contains the completeness score and reason.
87 |         """
88 |         llm_query = self._generate_prompt(input, output, context)
89 |         model_output = await self._model.agenerate_string(
90 |             input=llm_query, response_format=AnswerCompletenessResponseFormat
91 |         )
92 |         return self._parse_model_output(model_output)
93 | 
94 |     def _generate_prompt(self, input_text: str, output_text: str, context: List[str]) -> str:
95 |         """
96 |         Generate the prompt to send to the LLM for evaluation.
97 | 
98 |         Args:
99 |             input_text: The user's question or prompt.
100 |             output_text: The LLM-generated answer.
101 |             context: Relevant context information.
102 | 
103 |         Returns:
104 |             str: The complete prompt for the LLM.
105 |         """
106 |         context_combined = " ".join(context)
107 |         # Note: the literal JSON braces in the example below are doubled ({{ }})
108 |         # so the f-string does not treat them as interpolation fields.
109 |         prompt = f"""
110 | YOU ARE AN EXPERT IN NLP EVALUATION METRICS, TRAINED TO ASSESS THE COMPLETENESS OF ANSWERS PROVIDED BY LANGUAGE MODELS.
111 | 
112 | ###INSTRUCTIONS###
113 | - ANALYZE THE GIVEN USER INPUT AND THE GENERATED ANSWER.
114 | - DETERMINE IF THE ANSWER FULLY ADDRESSES ALL ASPECTS OF THE USER'S QUERY.
115 | - ASSIGN A COMPLETENESS SCORE BETWEEN 0.0 (COMPLETELY INCOMPLETE) AND 1.0 (FULLY COMPLETE).
116 | - PROVIDE A BRIEF REASON FOR THE SCORE, HIGHLIGHTING WHICH PARTS WERE ADDRESSED OR MISSING.
117 | 
118 | ###EXAMPLE###
119 | Input: "Explain the concept of transformer models in NLP and provide examples."
120 | Answer: "Transformer models are neural network architectures used in NLP that employ self-attention mechanisms..."
121 | Context: "Transformer models are neural network architectures used in NLP that employ self-attention mechanisms to process entire sentences simultaneously, capturing long-range dependencies and context."
122 | ---
123 | {{
124 |     "answer_completeness_score": 0.9,
125 |     "reason": "The answer thoroughly explains transformer models and provides examples, fully addressing the user's request."
126 | }}
127 | 
128 | ###INPUTS###
129 | ***
130 | User input:
131 | {input_text}
132 | Answer:
133 | {output_text}
134 | Contexts:
135 | {context_combined}
136 | ***
137 | """
138 |         return prompt
139 | 
140 |     def _parse_model_output(self, content: str) -> score_result.ScoreResult:
141 |         """
142 |         Parse the LLM's JSON response into a ScoreResult object.
143 | 
144 |         Args:
145 |             content: The JSON string returned by the LLM.
146 | 
147 |         Returns:
148 |             score_result.ScoreResult: The parsed score and reason.
149 |         """
150 |         try:
151 |             dict_content = json.loads(content)
152 |         except json.JSONDecodeError:
153 |             LOGGER.warning("Model returned malformed JSON. Defaulting to 0.5.")
154 |             dict_content = {}
155 |         score: float = dict_content.get("answer_completeness_score", 0.5)
156 |         reason: str = dict_content.get("reason", "No reason provided.")
157 | 
158 |         # Validate score range
159 |         if not (0.0 <= score <= 1.0):
160 |             LOGGER.warning("Received score out of bounds. Defaulting to 0.5.")
161 |             score = 0.5
162 | 
163 |         return score_result.ScoreResult(
164 |             name=self.name, value=score, reason=reason
165 |         )
166 | 
--------------------------------------------------------------------------------
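A minimal sketch of scoring with the custom metric above. It calls OpenAI through LiteLLM, so OPENAI_API_KEY must be set; the sample strings are illustrative.

from src.components.evaluation.custom_metric import AnswerCompleteness

metric = AnswerCompleteness()  # defaults to "gpt-4o" via LiteLLM; requires OPENAI_API_KEY
result = metric.score(
    input="Explain the concept of transformer models in NLP and provide examples.",
    output="Transformer models are attention-based architectures; BERT and GPT are examples.",
    context=["Transformer models are neural network architectures used in NLP ..."],
)
print(result.value)   # completeness score in [0.0, 1.0]
print(result.reason)  # short LLM-provided justification
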
/src/components/evaluation/opik_evaluator.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import logging
3 | 
4 | from opik.evaluation.metrics import (
5 |     Contains,
6 |     Equals,
7 |     LevenshteinRatio,
8 |     Hallucination,
9 |     Moderation,
10 |     AnswerRelevance,
11 |     GEval
12 | )
13 | from opik.evaluation.metrics.score_result import ScoreResult
14 | 
15 | from src.components.evaluation.custom_metric import AnswerCompleteness
16 | 
17 | logger = logging.getLogger(__name__)
18 | 
19 | TASK_INTRODUCTION = (
20 |     "You are an expert judge tasked with evaluating the faithfulness of an AI-generated answer to the given context."
21 | )
22 | 
23 | EVALUATION_CRITERIA = """
24 | - The OUTPUT must not introduce new information beyond what's provided in the CONTEXT.
25 | - The OUTPUT must not contradict any information given in the CONTEXT.
26 | - The OUTPUT should be logically consistent and coherent.
27 | - The OUTPUT should comprehensively address all aspects of the user's query.
28 | """
29 | 
30 | class LlmEvaluator:
31 |     def __init__(self):
32 |         """
33 |         Hard-codes reference strings and context for paper 0704.0001 and
34 |         relies on the stock opik metric classes directly.
35 |         """
36 |         self.metrics = {
37 |             "contains_diphoton": Contains(name="contains_diphoton", case_sensitive=False),
38 |             "contains_berger": Contains(name="contains_berger", case_sensitive=False),
39 |             "equals_title": Equals(name="equals_title"),
40 |             "lev_ratio_abstract": LevenshteinRatio(name="lev_ratio_abstract"),
41 |         }
42 | 
43 |         self.hallucination_metric = Hallucination()
44 |         self.moderation_metric = Moderation()
45 |         self.answer_relevance_metric = AnswerRelevance()
46 |         self.g_eval_metric = GEval(task_introduction=TASK_INTRODUCTION, evaluation_criteria=EVALUATION_CRITERIA)
47 | 
48 |         # Custom metric
49 |         self.answer_completeness_metric = AnswerCompleteness()
50 | 
51 |         self.abstract_0704_0001 = (
52 |             "A fully differential calculation in perturbative quantum chromodynamics is\n"
53 |             "presented for the production of massive photon pairs at hadron colliders. All\n"
54 |             "next-to-leading order perturbative contributions from quark-antiquark,\n"
55 |             "gluon-(anti)quark, and gluon-gluon subprocesses are included, as well as\n"
56 |             "all-orders resummation of initial-state gluon radiation valid at\n"
57 |             "next-to-next-to-leading logarithmic accuracy. The region of phase space is\n"
58 |             "specified in which the calculation is most reliable. Good agreement is\n"
59 |             "demonstrated with data from the Fermilab Tevatron, and predictions are made for\n"
60 |             "more detailed tests with CDF and DO data. Predictions are shown for\n"
61 |             "distributions of diphoton pairs produced at the energy of the Large Hadron\n"
62 |             "Collider (LHC). Distributions of the diphoton pairs from the decay of a Higgs\n"
63 |             "boson are contrasted with those produced from QCD processes at the LHC, showing\n"
64 |             "that enhanced sensitivity to the signal can be obtained with judicious\n"
65 |             "selection of events.\n"
66 |         )
67 |         self.context_0704_0001 = [
68 |             "Title: Calculation of prompt diphoton production cross sections at Tevatron and\n LHC energies",
69 |             f"Abstract: {self.abstract_0704_0001}",
70 |             "Authors: Balázs C., Berger E. L., Nadolsky P. M., Yuan C. -P.",
71 |             "Submitted on: Mon, 2 Apr 2007 19:18:42 GMT",
72 |             "Updated on: 2008-11-26"
73 |         ]
74 | 
75 |         self.static_references = {
76 |             "contains_diphoton": "diphoton",
77 |             "contains_berger": "Berger",
78 |             "equals_title": "Calculation of prompt diphoton production cross sections at Tevatron and\n LHC energies",
79 |             "lev_ratio_abstract": self.abstract_0704_0001,
80 |         }
81 | 
82 |         self.answer_context = (
83 |             "Transformer models are neural network architectures used in NLP "
84 |             "that employ self-attention mechanisms to process entire sentences "
85 |             "simultaneously, capturing long-range dependencies and context. They "
86 |             "use positional encoding to differentiate between words in different "
87 |             "positions within a sentence."
88 |         )
89 | 
90 | 
91 |     def evaluate(self, output: str) -> Dict[str, ScoreResult]:
92 |         """
93 |         Evaluate an LLM output against the static references for 0704.0001.
94 |         Return a dict of {metric_name -> ScoreResult} objects.
95 |         """
96 |         results = {}
97 |         for metric_name, metric_obj in self.metrics.items():
98 |             # e.g. "contains_diphoton" => reference "diphoton"
99 |             ref = self.static_references[metric_name]
100 |             score_res = metric_obj.score(output=output, reference=ref)
101 |             results[metric_name] = score_res
102 |         return results
103 | 
104 |     def check_hallucination(self, input_text: str, output_text: str) -> ScoreResult:
105 |         """
106 |         0 = no hallucination, 1 = hallucination found.
107 |         """
108 |         return self.hallucination_metric.score(
109 |             input=input_text,
110 |             output=output_text,
111 |             context=self.context_0704_0001
112 |         )
113 | 
114 |     def check_moderation(self, output_text: str) -> ScoreResult:
115 |         """
116 |         0.0 => safe, up to 1.0 => extremely unsafe.
117 |         """
118 |         return self.moderation_metric.score(output=output_text)
119 | 
120 |     def check_answer_relevance(
121 |         self,
122 |         input_text: str,
123 |         output_text: str,
124 |     ) -> float:
125 |         """
126 |         Return a float in [0..1], measuring how relevant `output_text` is
127 |         to `input_text` given the stored answer context.
128 |         """
129 |         score_result = self.answer_relevance_metric.score(
130 |             input=input_text,
131 |             output=output_text,
132 |             context=[self.answer_context]
133 |         )
134 |         return score_result.value
135 | 
136 |     def check_g_eval(
137 |         self, output_text: str
138 |     ) -> float:
139 |         """
140 |         Evaluate the LLM's output using the GEval metric.
141 |         """
142 |         score_result = self.g_eval_metric.score(
143 |             output=output_text
144 |         )
145 |         return score_result.value
146 | 
147 |     def check_answer_completeness(
148 |         self, input_text: str, output_text: str
149 |     ) -> float:
150 |         """
151 |         Check how complete the LLM's answer is relative to the user's input.
152 | 
153 |         Args:
154 |             input_text: The user's question or prompt.
155 |             output_text: The LLM-generated answer.
156 | 
157 |         Returns:
158 |             float: Completeness score between 0.0 and 1.0.
159 |         """
160 |         score_result = self.answer_completeness_metric.score(
161 |             input=input_text,
162 |             output=output_text,
163 |             context=[self.answer_context]
164 |         )
165 |         return score_result.value
166 | 
167 | 
--------------------------------------------------------------------------------
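A hedged sketch of the evaluator above. The heuristic metrics (Contains, Equals, LevenshteinRatio) run locally; Hallucination, Moderation, AnswerRelevance, and GEval call an LLM, so an API key is required. The sample answer string is illustrative.

from src.components.evaluation.opik_evaluator import LlmEvaluator

evaluator = LlmEvaluator()
answer = "The paper presents NLO QCD predictions for diphoton production at the Tevatron and LHC."

for name, score_res in evaluator.evaluate(answer).items():
    print(name, score_res.value)  # string-matching scores against the 0704.0001 references

print(evaluator.check_hallucination("What is paper 0704.0001 about?", answer).value)  # 0 = grounded, 1 = hallucinated
print(evaluator.check_moderation(answer).value)                                       # 0.0 safe .. 1.0 unsafe
print(evaluator.check_answer_relevance("What is paper 0704.0001 about?", answer))     # float in [0, 1]
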
35 | """ 36 | self.metrics = { 37 | "contains_diphoton": Contains(name="contains_diphoton", case_sensitive=False), 38 | "contains_berger": Contains(name="contains_berger", case_sensitive=False), 39 | "equals_title": Equals(name="equals_title"), 40 | "lev_ratio_abstract": LevenshteinRatio(name="lev_ratio_abstract"), 41 | } 42 | 43 | self.hallucination_metric = Hallucination() 44 | self.moderation_metric = Moderation() 45 | self.answer_relevance_metric = AnswerRelevance() 46 | self.g_eval_metric = GEval(task_introduction=TASK_INTRODUCTION, evaluation_criteria=EVALUATION_CRITERIA) 47 | 48 | # Custom metric 49 | self.answer_completeness_metric = AnswerCompleteness() 50 | 51 | self.abstract_0704_0001 = ( 52 | "A fully differential calculation in perturbative quantum chromodynamics is\n" 53 | "presented for the production of massive photon pairs at hadron colliders. All\n" 54 | "next-to-leading order perturbative contributions from quark-antiquark,\n" 55 | "gluon-(anti)quark, and gluon-gluon subprocesses are included, as well as\n" 56 | "all-orders resummation of initial-state gluon radiation valid at\n" 57 | "next-to-next-to-leading logarithmic accuracy. The region of phase space is\n" 58 | "specified in which the calculation is most reliable. Good agreement is\n" 59 | "demonstrated with data from the Fermilab Tevatron, and predictions are made for\n" 60 | "more detailed tests with CDF and DO data. Predictions are shown for\n" 61 | "distributions of diphoton pairs produced at the energy of the Large Hadron\n" 62 | "Collider (LHC). Distributions of the diphoton pairs from the decay of a Higgs\n" 63 | "boson are contrasted with those produced from QCD processes at the LHC, showing\n" 64 | "that enhanced sensitivity to the signal can be obtained with judicious\n" 65 | "selection of events.\n" 66 | ) 67 | self.context_0704_0001 = [ 68 | "Title: Calculation of prompt diphoton production cross sections at Tevatron and\n LHC energies", 69 | f"Abstract: {self.abstract_0704_0001}", 70 | "Authors: Balázs C., Berger E. L., Nadolsky P. M., Yuan C. -P.", 71 | "Submitted on: Mon, 2 Apr 2007 19:18:42 GMT", 72 | "Updated on: 2008-11-26" 73 | ] 74 | 75 | self.static_references = { 76 | "contains_diphoton": "diphoton", 77 | "contains_berger": "Berger", 78 | "equals_title": "Calculation of prompt diphoton production cross sections at Tevatron and\n LHC energies", 79 | "lev_ratio_abstract": self.abstract_0704_0001, 80 | } 81 | 82 | self.answer_context = ( 83 | "Transformer models are neural network architectures used in NLP " 84 | "that employ self-attention mechanisms to process entire sentences " 85 | "simultaneously, capturing long-range dependencies and context. They " 86 | "use positional encoding to differentiate between words in different " 87 | "positions within a sentence." 88 | ) 89 | 90 | 91 | def evaluate(self, output: str) -> Dict[str, ScoreResult]: 92 | """ 93 | Evaluate an LLM output with your *static* references for 0704.0001. 94 | Return a dict of {metric_name -> ScoreResult} objects. 95 | """ 96 | results = {} 97 | for metric_name, metric_obj in self.metrics.items(): 98 | # e.g. "contains_diphoton" => reference "diphoton" 99 | ref = self.static_references[metric_name] 100 | score_res = metric_obj.score(output=output, reference=ref) 101 | results[metric_name] = score_res 102 | return results 103 | 104 | def check_hallucination(self, input_text: str, output_text: str) -> ScoreResult: 105 | """ 106 | 0 = no hallucination, 1 = hallucination found. 
107 | """ 108 | return self.hallucination_metric.score( 109 | input=input_text, 110 | output=output_text, 111 | context=self.context_0704_0001 112 | ) 113 | 114 | def check_moderation(self, output_text: str) -> ScoreResult: 115 | """ 116 | 0.0 => safe, up to 1.0 => extremely unsafe 117 | """ 118 | return self.moderation_metric.score(output=output_text) 119 | 120 | def check_answer_relevance( 121 | self, 122 | input_text: str, 123 | output_text: str, 124 | ) -> float: 125 | """ 126 | Return a float in [0..1], measuring how relevant `output_text` is 127 | to `input_text` given `context_snippet`. 128 | """ 129 | score_result = self.answer_relevance_metric.score( 130 | input=input_text, 131 | output=output_text, 132 | context=[self.answer_context] 133 | ) 134 | return score_result.value 135 | 136 | def check_g_eval( 137 | self, output_text: str 138 | ) -> float: 139 | """ 140 | Evaluate the LLM's output using GEval metric. 141 | """ 142 | score_result = self.g_eval_metric.score( 143 | output=output_text 144 | ) 145 | return score_result.value 146 | 147 | def check_answer_completeness( 148 | self, input_text: str, output_text: str 149 | ) -> float: 150 | """ 151 | Check how complete the LLM's answer is relative to the user's input. 152 | 153 | Args: 154 | input_text: The user's question or prompt. 155 | output_text: The LLM-generated answer. 156 | 157 | Returns: 158 | float: Completeness score between 0.0 and 1.0. 159 | """ 160 | score_result = self.answer_completeness_metric.score( 161 | input=input_text, 162 | output=output_text, 163 | context=[self.answer_context] 164 | ) 165 | return score_result.value 166 | 167 | -------------------------------------------------------------------------------- /src/orchestrator/coordinator.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import time 3 | import logging 4 | from typing import Dict, Any 5 | 6 | from langchain_community.chat_models import ChatOpenAI 7 | from langchain.schema import HumanMessage, AIMessage 8 | 9 | from src.config.settings import Settings 10 | from src.core.graph import create_research_graph 11 | from src.components.database.neo4j_client import Neo4jClient 12 | from src.components.paper.tool import PaperTool 13 | from src.components.rag.tool import RAG 14 | from src.components.rag.embeddings import Embedding 15 | from src.components.database.vector_store import VectorStore 16 | from src.components.evaluation.experiment_tracker import ExperimentTracker, MetricsCollector 17 | from src.components.evaluation.opik_evaluator import LlmEvaluator 18 | from src.tools.paper_lookup import PaperLookupTool 19 | from src.tools.rag import RAGTool 20 | from src.agents.research_assistant import ResearchAssistant 21 | from dotenv import load_dotenv 22 | 23 | load_dotenv() 24 | 25 | # Configure logging 26 | logging.basicConfig( 27 | level=logging.INFO, 28 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' 29 | ) 30 | logging.getLogger("httpx").setLevel(logging.WARNING) 31 | logging.getLogger("urllib3").setLevel(logging.WARNING) 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | class Coordinator: 37 | def __init__(self): 38 | self.settings = Settings() 39 | self.experiment_tracker = self.setup_experiment_tracker() 40 | self.metrics_collector = MetricsCollector() 41 | self.services = self.initialize_services() 42 | self.tools = self.initialize_tools() 43 | self.assistant = self.initialize_assistant() 44 | self.graph = self.setup_graph() 45 | 46 | # Instantiate our new 
104 |     def initialize_assistant(self) -> ResearchAssistant:
105 |         llm = ChatOpenAI(
106 |             temperature=0,
107 |             openai_api_key=self.settings.openai_api_key
108 |         )
109 |         tools = [self.tools["paper_lookup"], self.tools["rag"]]
110 |         return ResearchAssistant(
111 |             experiment_tracker=self.experiment_tracker,
112 |             tools=tools,
113 |             llm=llm
114 |         )
115 | 
116 |     def setup_graph(self) -> Any:
117 |         return create_research_graph(
118 |             assistant=self.assistant,
119 |             rag_tool=self.tools["rag"],
120 |             paper_lookup_tool=self.tools["paper_lookup"]
121 |         )
122 | 
123 |     def process_message(self, message: str, state: Dict[str, Any]) -> None:
124 |         try:
125 |             # Store user message
126 |             state["messages"].append(HumanMessage(content=message))
127 | 
128 |             # Log conversation metrics
129 |             self.experiment_tracker.experiment.log_metrics({
130 |                 "conversation_turn": len(state["messages"]),
131 |                 "message_length": len(message)
132 |             })
133 | 
134 |             # Assistant response (LangChain chain)
135 |             response = self.assistant(state)
136 |             ai_messages = response["messages"]
137 |             state["messages"].extend(ai_messages)
138 | 
139 |             # Grab the final user-facing output
140 |             ai_text = ai_messages[-1].content if ai_messages else ""
141 | 
142 |             # Hallucination score (a ScoreResult; log its numeric .value)
143 |             hallucination_score = self.llm_evaluator.check_hallucination(message, ai_text)
144 |             self.experiment_tracker.experiment.log_metric("hallucination_score", hallucination_score.value)
145 | 
146 |             # Moderation score (a ScoreResult; log its numeric .value)
147 |             moderation_score = self.llm_evaluator.check_moderation(ai_text)
148 |             self.experiment_tracker.experiment.log_metric("moderation_score", moderation_score.value)
149 | 
150 |             # Evaluate references (Contains, Equals, LevenshteinRatio)
151 |             metric_scores = self.llm_evaluator.evaluate(ai_text)
152 |             for metric_name, score_result in metric_scores.items():
153 |                 self.experiment_tracker.experiment.log_metric(metric_name, score_result.value)
154 | 
155 |             # Answer relevance
156 |             relevance_score = self.llm_evaluator.check_answer_relevance(
157 |                 input_text=message,
158 |                 output_text=ai_text,
159 |             )
160 | 
161 |             # GEval metric
162 |             g_eval_score = self.llm_evaluator.check_g_eval(
163 |                 output_text=ai_text
164 |             )
165 |             self.experiment_tracker.experiment.log_metric("g_eval_score", g_eval_score)
166 | 
167 |             self.experiment_tracker.experiment.log_metric("answer_relevance_score", relevance_score)
168 |             logger.info(f"Answer Relevance score: {relevance_score}")
169 | 
170 |             # Print final answer
171 |             for msg in ai_messages:
172 |                 if isinstance(msg, AIMessage):
173 |                     self.experiment_tracker.experiment.log_metrics({"response_length": len(msg.content)})
174 |                     print(f"Assistant: {msg.content}")
175 | 
176 |         except Exception as e:
177 |             self.experiment_tracker.experiment.log_metric("errors", 1)
178 |             print(f"Error in process_message: {str(e)}")
179 | 
self.experiment_tracker.experiment.log_metric("moderation_score", moderation_score) 149 | 150 | # Evaluate references (Contains, Equals, LevenshteinRatio) 151 | metric_scores = self.llm_evaluator.evaluate(ai_text) 152 | for metric_name, score_result in metric_scores.items(): 153 | self.experiment_tracker.experiment.log_metric(metric_name, score_result.value) 154 | 155 | # Answer relevance 156 | relevance_score = self.llm_evaluator.check_answer_relevance( 157 | input_text=message, 158 | output_text=ai_text, 159 | ) 160 | 161 | # GEval metric 162 | g_eval_score = self.llm_evaluator.check_g_eval( 163 | output_text=ai_text 164 | ) 165 | self.experiment_tracker.experiment.log_metric("g_eval_score", g_eval_score) 166 | 167 | self.experiment_tracker.experiment.log_metric("answer_relevance_score", relevance_score) 168 | logger.info(f"Answer Relevance score: {relevance_score}") 169 | 170 | # Print final answer 171 | for msg in ai_messages: 172 | if isinstance(msg, AIMessage): 173 | self.experiment_tracker.experiment.log_metrics({"response_length": len(msg.content)}) 174 | print(f"Assistant: {msg.content}") 175 | 176 | except Exception as e: 177 | self.experiment_tracker.experiment.log_metric("errors", 1) 178 | print(f"Error in process_message: {str(e)}") 179 | 180 | def run(self): 181 | try: 182 | print("Research Paper Assistant initialized. Type 'exit' to quit.") 183 | state = { 184 | "messages": [], 185 | "metrics": {}, 186 | "conversation_history": [] 187 | } 188 | while True: 189 | user_input = input("\nYour question: ").strip() 190 | if user_input.lower() in ['exit', 'quit', 'bye']: 191 | self.cleanup() 192 | break 193 | self.process_message(user_input, state) 194 | except KeyboardInterrupt: 195 | print("\n\nSession interrupted by user.") 196 | self.cleanup() 197 | except Exception as e: 198 | logger.error(f"Unexpected error: {str(e)}") 199 | print(f"\n\nAn error occurred: {str(e)}") 200 | self.cleanup() 201 | 202 | def cleanup(self): 203 | try: 204 | if hasattr(self, 'experiment_tracker'): 205 | session_end_time = time.strftime("%Y-%m-%d %H:%M:%S") 206 | self.experiment_tracker.experiment.log_parameter("session_end", session_end_time) 207 | final_metrics = { 208 | "total_messages": len(self.assistant.memory.chat_memory.messages), 209 | "total_user_messages": len( 210 | [m for m in self.assistant.memory.chat_memory.messages if isinstance(m, HumanMessage)] 211 | ), 212 | "total_ai_messages": len( 213 | [m for m in self.assistant.memory.chat_memory.messages if isinstance(m, AIMessage)] 214 | ) 215 | } 216 | self.experiment_tracker.experiment.log_metrics(final_metrics) 217 | self.experiment_tracker.experiment.end() 218 | self.services["db_client"].close() 219 | print("\nSession ended. Thank you for using the Research Paper Assistant!") 220 | except Exception as e: 221 | logger.error(f"Error during cleanup: {str(e)}") 222 | --------------------------------------------------------------------------------