├── src
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── research_assistant.py
│   ├── config
│   │   ├── __init__.py
│   │   └── settings.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── state.py
│   │   ├── messages.py
│   │   └── graph.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── rag.py
│   │   └── paper_lookup.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── paper_id_extractor.py
│   ├── components
│   │   ├── __init__.py
│   │   ├── paper
│   │   │   ├── __init__.py
│   │   │   ├── models.py
│   │   │   └── tool.py
│   │   ├── rag
│   │   │   ├── __init__.py
│   │   │   ├── embeddings.py
│   │   │   ├── indexing.py
│   │   │   └── tool.py
│   │   ├── database
│   │   │   ├── __init__.py
│   │   │   ├── neo4j_client.py
│   │   │   ├── vector_store.py
│   │   │   ├── ingest.py
│   │   │   └── neo4j_ingestion.py
│   │   └── evaluation
│   │       ├── __init__.py
│   │       ├── experiment_tracker.py
│   │       ├── custom_metric.py
│   │       └── opik_evaluator.py
│   ├── orchestrator
│   │   ├── __init__.py
│   │   └── coordinator.py
│   ├── streamlit
│   │   ├── __init__.py
│   │   ├── session.py
│   │   ├── message.py
│   │   ├── predefined_questions.py
│   │   ├── main.py
│   │   ├── ui_component.py
│   │   └── layout.py
│   └── main.py
├── .gitignore
├── .env.example
├── .idea
│   ├── vcs.xml
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   ├── misc.xml
│   └── CometML.iml
├── requirements.txt
├── pyproject.toml
├── scripts
│   ├── neo4j_cleaner.py
│   └── preprocess.py
└── readme.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/agents/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/config/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/components/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/orchestrator/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/streamlit/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/components/paper/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/components/rag/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/components/database/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/components/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | !.env.example
3 | neo4j.conf
4 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | from src.streamlit.main import main
2 |
3 | if __name__ == "__main__":
4 | main()
5 |
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | #OPENAI
2 | OPENAI_API_KEY=openai-api-key
3 |
4 | #NEO4J
5 | NEO4J_URI=uri
6 | NEO4J_USER=username
7 | NEO4J_PASSWORD=password
8 |
9 | #COMETML
10 | COMETML_API_KEY=cometml_api_key
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/src/core/state.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated, TypedDict, Dict, List, Any
2 | from langchain_core.messages import AnyMessage
3 | from langgraph.graph.message import add_messages
4 |
5 | class ConversationState(TypedDict):
6 | messages: Annotated[list[AnyMessage], add_messages]
7 | metrics: dict
8 | conversation_history: List[Dict[str, Any]]
9 |
--------------------------------------------------------------------------------
/src/core/messages.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional
3 | from datetime import datetime
4 |
5 | @dataclass
6 | class ConversationMessage:
7 | content: str
8 | type: str
9 | timestamp: datetime
10 | query: Optional[str] = None
11 | paper_id: Optional[str] = None
12 | response: Optional[str] = None
13 |
--------------------------------------------------------------------------------
/src/streamlit/session.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import List, Dict, Any
3 |
4 | @dataclass
5 | class SessionState:
6 | messages: List[Any] = field(default_factory=list)
7 | metrics: Dict[str, Any] = field(default_factory=dict)
8 | conversation_history: List[str] = field(default_factory=list)
9 | show_predefined: bool = True
10 | session_active: bool = True
11 |
--------------------------------------------------------------------------------
/src/streamlit/message.py:
--------------------------------------------------------------------------------
1 | from langchain.schema import AIMessage, HumanMessage
2 |
3 | class Message:
4 | def __init__(self, content: str, is_human: bool = True):
5 | self.content = content
6 | self.message = HumanMessage(content=content) if is_human else AIMessage(content=content)
7 |
8 | def format_for_display(self) -> str:
9 | prefix = "You" if isinstance(self.message, HumanMessage) else "Assistant"
10 | return f"**{prefix}:** {self.content}"
11 |
--------------------------------------------------------------------------------
/src/agents/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from src.core.state import ConversationState
3 | from typing import Dict, Any
4 |
5 | class BaseAgent(ABC):
6 | @abstractmethod
7 | def process_message(self, state: ConversationState) -> Dict[str, Any]:
8 | """Process a message and return updated state."""
9 | pass
10 |
11 | @abstractmethod
12 | def handle_error(self, error: Exception) -> Dict[str, Any]:
13 | """Handle errors during message processing."""
14 | pass
--------------------------------------------------------------------------------
/src/config/settings.py:
--------------------------------------------------------------------------------
1 | from dotenv import load_dotenv
2 | import os
3 | 
4 | load_dotenv()
5 | 
6 | class Settings:
7 |     def __init__(self):
8 |         self.cometml_api_key = os.getenv("COMETML_API_KEY")
9 |         self.project_name = "research-paper-rag"
10 |         self.openai_api_key = os.getenv("OPENAI_API_KEY")
11 |         self.neo4j_uri = os.getenv("NEO4J_URI")
12 |         self.neo4j_user = os.getenv("NEO4J_USER")  # matches .env.example
13 |         self.neo4j_password = os.getenv("NEO4J_PASSWORD")
14 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | comet-ml==3.47.0
2 | langchain==0.3.3
3 | pandas==2.2.3
4 | langchain-community==0.3.2
5 | langchain-experimental==0.3.3
6 | langchain-openai==0.2.2
7 | graphdatascience==1.12.0
8 | tiktoken==0.8.0
9 | retry==0.9.2
10 | neo4j==5.21.0
11 | sentence-transformers==3.2.0
12 | faiss-cpu==1.9.0
13 | python-dotenv==1.0.1
14 | langgraph==0.2.39
15 | langgraph-checkpoint-sqlite==2.0.0
16 | postgres==4.0.0
17 | streamlit==1.39.0
18 | opik==1.3.3
19 | 
20 | # Additional dependencies
21 | numpy>=1.24.0
22 | torch>=2.0.0
23 | transformers>=4.30.0
24 | openai>=1.0.0
25 | pydantic>=2.0.0
26 | typing-extensions>=4.5.0
27 | tqdm>=4.65.0
28 | requests>=2.31.0
29 | aiohttp>=3.10.11
--------------------------------------------------------------------------------
/src/streamlit/predefined_questions.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Callable
3 |
4 | @dataclass
5 | class PredefinedQuestion:
6 | text: str
7 | callback: Callable[[str], None]
8 |
9 | class PredefinedQuestionsManager:
10 | def __init__(self):
11 | self.questions = [
12 | "Explain the concept of transformer models in NLP.",
13 | "What is the significance of paper ID 0704.2002?",
14 | "How does reinforcement learning apply to robotics?",
15 | "Summarize recent advances in computer vision."
16 | ]
17 |
18 | def get_questions(self) -> List[str]:
19 | return self.questions
20 |
--------------------------------------------------------------------------------
/src/utils/paper_id_extractor.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Optional
3 |
4 | class PaperIdExtractor:
5 |     _PATTERNS = [
6 |         r'paper\s*(?:id|ID|Id)?\s*[:#]?\s*(\d{4}\.\d{4})',
7 |         r'id\s*[:#]?\s*(\d{4}\.\d{4})',
8 |         r'(?<!\d)(\d{4}\.\d{4})(?!\d)',  # bare arXiv-style ID, e.g. 0704.2002
9 |     ]
10 | 
11 |     @classmethod
12 |     def extract(cls, text: str) -> Optional[str]:
13 |         """Extract paper ID from text using various patterns."""
14 |         for pattern in cls._PATTERNS:
15 |             match = re.search(pattern, text, re.IGNORECASE)
16 |             if match:
17 |                 return match.group(1)
18 |         return None
--------------------------------------------------------------------------------
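A minimal usage sketch for the extractor above (the query string is illustrative):

```python
from src.utils.paper_id_extractor import PaperIdExtractor

# extract() returns the first arXiv-style ID matching one of the patterns, or None.
query = "What is the significance of paper ID 0704.2002?"
print(PaperIdExtractor().extract(query))  # -> 0704.2002
```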
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "cometml"
3 | version = "0.1.0"
4 | description = ""
5 | authors = ["stefan"]
6 | readme = "README.md"
7 |
8 | [tool.poetry.dependencies]
9 | python = "^3.12"
10 | comet-ml = "^3.47.0"
11 | langchain = "^0.3.3"
12 | pandas = "^2.2.3"
13 | langchain-community = "^0.3.2"
14 | langchain-experimental = "^0.3.2"
15 | langchain-openai = "^0.2.2"
16 | graphdatascience = "^1.12"
17 | tiktoken = "^0.8.0"
18 | retry = "^0.9.2"
19 | neo4j = "5.21.0"
20 | sentence-transformers = "^3.2.0"
21 | faiss-cpu = "^1.9.0"
22 | python-dotenv = "^1.0.1"
23 | langgraph = "^0.2.39"
24 | langgraph-checkpoint-sqlite = "^2.0.0"
25 | postgres = "^4.0"
26 | streamlit = "^1.39.0"
27 | opik = "^1.3.3"
28 |
29 |
30 | [build-system]
31 | requires = ["poetry-core"]
32 | build-backend = "poetry.core.masonry.api"
33 |
--------------------------------------------------------------------------------
/src/components/database/neo4j_client.py:
--------------------------------------------------------------------------------
1 | from neo4j import GraphDatabase
2 | from contextlib import contextmanager
3 |
4 | class Neo4jClient:
5 | def __init__(self, uri: str, user: str, password: str):
6 | self.uri = uri
7 | self.user = user
8 | self.password = password
9 | self._driver = None
10 |
11 | @property
12 | def driver(self):
13 | if self._driver is None:
14 | self._driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))
15 | return self._driver
16 |
17 | @contextmanager
18 | def session(self):
19 | session = self.driver.session()
20 | try:
21 | yield session
22 | finally:
23 | session.close()
24 |
25 | def close(self):
26 | if self._driver:
27 | self._driver.close()
28 | self._driver = None
--------------------------------------------------------------------------------
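A minimal sketch of the session context manager, assuming a reachable Neo4j instance and the credentials from `.env`:

```python
from src.config.settings import Settings
from src.components.database.neo4j_client import Neo4jClient

settings = Settings()
client = Neo4jClient(settings.neo4j_uri, settings.neo4j_user, settings.neo4j_password)
try:
    # The context manager guarantees the session is closed even on errors.
    with client.session() as session:
        count = session.run("MATCH (p:Paper) RETURN count(p) AS n").single()["n"]
        print(f"{count} papers in the graph")
finally:
    client.close()
```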
/src/components/rag/embeddings.py:
--------------------------------------------------------------------------------
1 | from langchain_openai import OpenAIEmbeddings
2 | from typing import List, Optional
3 | import os
4 |
5 | class Embedding:
6 | def __init__(self, api_key: Optional[str] = None):
7 | if not api_key and not os.getenv("OPENAI_API_KEY"):
8 | raise ValueError("OpenAI API key must be provided either directly or through environment variable")
9 | self.api_key = api_key or os.getenv("OPENAI_API_KEY")
10 | if not self.api_key.startswith("sk-"):
11 | raise ValueError("Invalid OpenAI API key format")
12 | self.model = OpenAIEmbeddings(openai_api_key=self.api_key)
13 |
14 |
15 | def embed_documents(self, texts: List[str]) -> List[List[float]]:
16 | try:
17 | return self.model.embed_documents(texts)
18 | except Exception as e:
19 | raise ValueError(f"Error generating embeddings: {str(e)}")
20 |
--------------------------------------------------------------------------------
/src/streamlit/main.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from src.orchestrator.coordinator import Coordinator
3 | from src.streamlit.layout import ResearchAssistantUI
4 |
5 | def main():
6 | if "app" not in st.session_state:
7 | st.session_state.app = Coordinator()
8 |
9 | # Initialize session state variables if they don't exist
10 | if "messages" not in st.session_state:
11 | st.session_state.messages = []
12 | if "metrics" not in st.session_state:
13 | st.session_state.metrics = {}
14 | if "conversation_history" not in st.session_state:
15 | st.session_state.conversation_history = []
16 | if "show_predefined" not in st.session_state:
17 | st.session_state.show_predefined = True
18 | if "session_active" not in st.session_state:
19 | st.session_state.session_active = True
20 |
21 | ui = ResearchAssistantUI(st.session_state.app, st.session_state)
22 | ui.render()
--------------------------------------------------------------------------------
/src/components/paper/models.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional
3 | from datetime import datetime
4 |
5 | @dataclass
6 | class Paper:
7 | id: str
8 | title: str
9 | abstract: str
10 | authors: List[str]
11 | categories: List[str]
12 | submit_date: datetime
13 | update_date: Optional[datetime] = None
14 |
15 | @classmethod
16 | def from_db_record(cls, record: dict) -> 'Paper':
17 | return cls(
18 | id=record.get('id'),
19 | title=record.get('title'),
20 | abstract=record.get('abstract'),
21 | authors=record.get('authors', []),
22 | categories=record.get('categories', []),
23 | submit_date=record.get('submit_date'),
24 | update_date=record.get('update_date')
25 | )
26 |
27 | def to_string(self) -> str:
28 | return f"""Title: {self.title}
29 | Abstract: {self.abstract}
30 | Authors: {', '.join(self.authors) if self.authors else 'No authors listed'}
31 | Categories: {', '.join(self.categories) if self.categories else 'No categories listed'}
32 | Submitted on: {self.submit_date}
33 | Updated on: {self.update_date if self.update_date else 'N/A'}"""
34 |
--------------------------------------------------------------------------------
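An illustrative record-to-`Paper` round trip (all field values are made up):

```python
from datetime import datetime
from src.components.paper.models import Paper

record = {
    "id": "0704.2002",
    "title": "An Example Title",
    "abstract": "A short abstract.",
    "authors": ["A. Author", "B. Author"],
    "categories": ["cs.CL"],
    "submit_date": datetime(2007, 4, 16),
}
paper = Paper.from_db_record(record)  # update_date falls back to None
print(paper.to_string())
```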
/src/core/graph.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 | from langgraph.graph import START
3 | from langgraph.graph.state import StateGraph
4 | from langgraph.checkpoint.memory import MemorySaver
5 | from langgraph.prebuilt import tools_condition
6 | from src.core.state import ConversationState
7 | from src.tools.rag import RAG
8 | from src.tools.paper_lookup import PaperLookupTool
9 |
10 | def create_research_graph(assistant: Any, rag_tool: RAG, paper_lookup_tool: PaperLookupTool) -> Any:
11 | """Creates and returns the research graph with the specified components."""
12 |
13 | # Initialize graph builder
14 | builder = StateGraph(ConversationState)
15 |
16 | # Add nodes
17 | builder.add_node("assistant", assistant)
18 | builder.add_node("rag_tool", rag_tool)
19 | builder.add_node("paper_lookup_tool", paper_lookup_tool)
20 |
21 | # Define edges
22 | builder.add_edge(START, "assistant")
23 | builder.add_conditional_edges(
24 | "assistant",
25 | tools_condition
26 | )
27 | builder.add_edge("rag_tool", "assistant")
28 | builder.add_edge("paper_lookup_tool", "assistant")
29 |
30 | # Set up memory and compile
31 | memory = MemorySaver()
32 | return builder.compile(checkpointer=memory)
33 |
--------------------------------------------------------------------------------
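A hedged sketch of driving the compiled graph; `assistant`, `rag_tool`, and `paper_lookup_tool` are assumed to be wired elsewhere (e.g. in the coordinator). The `MemorySaver` checkpointer requires a `thread_id` in the run config:

```python
from langchain_core.messages import HumanMessage

graph = create_research_graph(assistant, rag_tool, paper_lookup_tool)
config = {"configurable": {"thread_id": "session-1"}}
result = graph.invoke(
    {"messages": [HumanMessage(content="Summarize paper 0704.2002")]},
    config=config,
)
print(result["messages"][-1].content)  # latest assistant reply
```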
/src/components/database/vector_store.py:
--------------------------------------------------------------------------------
1 | from langchain_community.vectorstores import Neo4jVector
2 | from src.components.database.neo4j_client import Neo4jClient
3 | from typing import List, Tuple, Optional
4 |
5 | class VectorStore:
6 | def __init__(self, neo4j_client: Neo4jClient, embedding_model, index_name: str):
7 | self.client = neo4j_client
8 | self.embedding_model = embedding_model
9 | self.index_name = index_name
10 | self.vector_store = self._initialize_vector_store()
11 |
12 | def _initialize_vector_store(self) -> Neo4jVector:
13 | return Neo4jVector(
14 | embedding=self.embedding_model,
15 | url=self.client.uri,
16 | username=self.client.user,
17 | password=self.client.password,
18 | index_name=self.index_name,
19 | node_label="Paper",
20 | text_node_property="abstract"
21 | )
22 |
23 | def similarity_search(self, query: str, k: int = 3) -> List[Tuple[str, float]]:
24 | try:
25 | results = self.vector_store.similarity_search_with_score(query, k=k)
26 | return [(doc.page_content, score) for doc, score in results]
27 | except Exception as e:
28 | raise ValueError(f"Error performing similarity search: {str(e)}")
29 |
--------------------------------------------------------------------------------
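A minimal search sketch, assuming the `Paper` nodes already carry embeddings and the vector index exists (the index name here is a placeholder; see `IndexingService` in `components/rag/indexing.py`):

```python
from src.config.settings import Settings
from src.components.database.neo4j_client import Neo4jClient
from src.components.database.vector_store import VectorStore
from src.components.rag.embeddings import Embedding

settings = Settings()
client = Neo4jClient(settings.neo4j_uri, settings.neo4j_user, settings.neo4j_password)
store = VectorStore(client, Embedding(settings.openai_api_key).model, index_name="paper_index")
for text, score in store.similarity_search("transformer models in NLP", k=3):
    print(f"{score:.3f}  {text[:80]}")
```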
/scripts/neo4j_cleaner.py:
--------------------------------------------------------------------------------
1 | from neo4j import GraphDatabase
2 | from src.config.settings import Settings
3 |
4 | settings = Settings()
5 |
6 | class Neo4jCleaner:
7 | def __init__(self, uri, user, password):
8 | self.driver = GraphDatabase.driver(uri, auth=(user, password))
9 |
10 | def close(self):
11 | self.driver.close()
12 |
13 | def delete_all_data(self):
14 | with self.driver.session() as session:
15 | session.run("MATCH (n) DETACH DELETE n")
16 | print("All nodes and relationships have been deleted.")
17 |
18 | def delete_constraints_and_indexes(self):
19 | with self.driver.session() as session:
20 | # Drop all constraints
21 | for constraint in session.run("SHOW CONSTRAINTS"):
22 | session.run(f"DROP CONSTRAINT {constraint['name']}")
23 |
24 | # Drop all indexes
25 | for index in session.run("SHOW INDEXES"):
26 | session.run(f"DROP INDEX {index['name']}")
27 |
28 | print("All constraints and indexes have been dropped.")
29 |
30 |
31 | def main():
32 | cleaner = Neo4jCleaner(uri=settings.neo4j_uri, user=settings.neo4j_user, password=settings.neo4j_password)
33 |
34 | try:
35 | cleaner.delete_all_data()
36 | cleaner.delete_constraints_and_indexes()
37 | finally:
38 | cleaner.close()
39 |
40 |
41 | if __name__ == "__main__":
42 | main()
--------------------------------------------------------------------------------
/src/components/rag/indexing.py:
--------------------------------------------------------------------------------
1 | from src.components.database.neo4j_client import Neo4jClient
2 | from src.components.rag.embeddings import Embedding
3 | import logging
4 |
5 | class IndexingService:
6 | def __init__(
7 | self,
8 | db_client: Neo4jClient,
9 | embedding_service: Embedding,
10 | batch_size: int = 100
11 | ):
12 | self.db_client = db_client
13 | self.embedding_service = embedding_service
14 | self.batch_size = batch_size
15 | self.logger = logging.getLogger(__name__)
16 |
17 | def ensure_vector_index(self, index_name: str):
18 | with self.db_client.session() as session:
19 | if not self._vector_index_exists(session, index_name):
20 | self._create_vector_index(session, index_name)
21 |
22 | def _vector_index_exists(self, session, index_name: str) -> bool:
23 | query = """
24 | SHOW INDEXES
25 | YIELD name, type
26 | WHERE name = $index_name AND type = 'VECTOR'
27 | RETURN count(*) > 0 AS exists
28 | """
29 | result = session.run(query, index_name=index_name)
30 | return result.single()['exists']
31 |
32 | def _create_vector_index(self, session, index_name: str):
33 | session.run("""
34 | CALL db.index.vector.createNodeIndex(
35 | $index_name,
36 | 'Paper',
37 | 'embedding',
38 | 1536,
39 | 'cosine'
40 | )
41 | """, index_name=index_name)
42 | self.logger.info(f"Vector index '{index_name}' created.")
43 |
--------------------------------------------------------------------------------
/src/streamlit/ui_component.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from typing import List, Callable
3 | from langchain.schema import HumanMessage, AIMessage
4 |
5 | class ChatDisplay:
6 | @staticmethod
7 | def display_messages(messages: List):
8 | for message in messages:
9 | if hasattr(message, 'format_for_display'):
10 | formatted_message = message.format_for_display()
11 | elif isinstance(message, HumanMessage):
12 | formatted_message = f"**You:** {message.content}"
13 | elif isinstance(message, AIMessage):
14 | formatted_message = f"**Assistant:** {message.content}"
15 | else:
16 | formatted_message = f"**Unknown:** {message.content}"
17 |
18 | st.write(formatted_message)
19 | st.markdown("
" * 2, unsafe_allow_html=True)
20 |
21 |
22 | class InputArea:
23 | def __init__(self, on_submit: Callable[[str], None]):
24 | self.on_submit = on_submit
25 |
26 | def render(self):
27 | st.text_input(
28 | "Enter your question here:",
29 | key="user_input",
30 | on_change=self._handle_input
31 | )
32 |
33 | def _handle_input(self):
34 | if st.session_state.user_input:
35 | self.on_submit(st.session_state.user_input)
36 | st.session_state.user_input = ""
37 |
38 | class SessionControls:
39 | def __init__(self, on_end_session: Callable[[], None]):
40 | self.on_end_session = on_end_session
41 |
42 | def render(self):
43 | col1, col2, col3 = st.columns([1, 3, 1])
44 | with col1:
45 | if st.button("End Session"):
46 | self.on_end_session()
--------------------------------------------------------------------------------
/src/components/database/ingest.py:
--------------------------------------------------------------------------------
1 | import json
2 | import multiprocessing
3 | import time
4 | from dotenv import load_dotenv
5 | from src.config.settings import Settings
6 | from src.components.database.neo4j_ingestion import OptimizedNeo4jIngestor, worker
7 |
8 | load_dotenv()
9 |
10 | def ingest_data_parallel(uri: str, user: str, password: str, data, batch_size: int = 1000, num_processes: int = 4):
11 | total = len(data)
12 | batches = [data[i:i + batch_size] for i in range(0, total, batch_size)]
13 |
14 | with multiprocessing.Pool(num_processes) as pool:
15 | results = []
16 | for batch in batches:
17 | result = pool.apply_async(worker, (uri, user, password, batch))
18 | results.append(result)
19 |
20 | for i, result in enumerate(results):
21 | result.get() # Wait for the batch to complete
22 | print(f"Ingested {min((i + 1) * batch_size, total)}/{total} papers")
23 |
24 |
25 | def main():
26 | settings = Settings()
27 | uri = settings.neo4j_uri
28 | user = settings.neo4j_user
29 | password = settings.neo4j_password
30 |
31 | # Initialize and create constraints
32 | ingestor = OptimizedNeo4jIngestor(uri, user, password)
33 | ingestor.create_constraints()
34 | ingestor.close()
35 |
36 | # Load data
37 | with open('processed_data.json', 'r') as f:
38 | processed_data = json.load(f)
39 |
40 | # Ingest in parallel
41 | start_time = time.time()
42 | ingest_data_parallel(uri, user, password, processed_data, batch_size=1000, num_processes=4)
43 | end_time = time.time()
44 |
45 | print(f"Total ingestion time: {end_time - start_time:.2f} seconds")
46 |
47 |
48 | if __name__ == "__main__":
49 | main()
50 |
--------------------------------------------------------------------------------
/src/components/database/neo4j_ingestion.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Any
2 | from neo4j import GraphDatabase
3 |
4 | class OptimizedNeo4jIngestor:
5 | def __init__(self, uri: str, user: str, password: str):
6 | self.driver = GraphDatabase.driver(uri, auth=(user, password))
7 |
8 | def close(self):
9 | self.driver.close()
10 |
11 | def create_constraints(self):
12 | with self.driver.session() as session:
13 | session.run("CREATE CONSTRAINT paper_id IF NOT EXISTS FOR (p:Paper) REQUIRE p.id IS UNIQUE")
14 | session.run("CREATE CONSTRAINT author_name IF NOT EXISTS FOR (a:Author) REQUIRE a.name IS UNIQUE")
15 | session.run("CREATE CONSTRAINT category_name IF NOT EXISTS FOR (c:Category) REQUIRE c.name IS UNIQUE")
16 |
17 | def ingest_batch(self, batch: List[Dict[str, Any]]):
18 | with self.driver.session() as session:
19 | session.execute_write(self._create_and_link_batch, batch)
20 |
21 | def _create_and_link_batch(self, tx, batch: List[Dict[str, Any]]):
22 | query = """
23 | UNWIND $batch AS paper
24 | MERGE (p:Paper {id: paper.id})
25 | SET p.title = paper.title, p.abstract = paper.abstract,
26 | p.submit_date = paper.submit_date, p.update_date = paper.update_date
27 | WITH p, paper
28 | UNWIND paper.authors AS author_name
29 | MERGE (a:Author {name: author_name})
30 | MERGE (p)-[:AUTHORED_BY]->(a)
31 | WITH p, paper
32 | UNWIND paper.categories AS category_name
33 | MERGE (c:Category {name: category_name})
34 | MERGE (p)-[:BELONGS_TO]->(c)
35 | """
36 | tx.run(query, batch=batch)
37 |
38 |
39 | def worker(uri: str, user: str, password: str, batch: List[Dict[str, Any]]):
40 | ingestor = OptimizedNeo4jIngestor(uri, user, password)
41 | try:
42 | ingestor.ingest_batch(batch)
43 | finally:
44 | ingestor.close()
45 |
--------------------------------------------------------------------------------
/src/tools/rag.py:
--------------------------------------------------------------------------------
1 | from langchain.tools import BaseTool
2 | from pydantic import PrivateAttr
3 | import time
4 |
5 | from src.components.rag.tool import RAG
6 | from src.components.evaluation.experiment_tracker import ExperimentTracker, MetricsCollector, MetricsData
7 |
8 | class RAGTool(BaseTool):
9 | name: str = "RAG"
10 | description: str = "Use this tool to retrieve research papers and generate answers to general queries."
11 | _rag_service: RAG = PrivateAttr()
12 | _experiment_tracker: ExperimentTracker = PrivateAttr()
13 | _metrics_collector: MetricsCollector = PrivateAttr()
14 |
15 | def __init__(
16 | self,
17 | rag_service: RAG,
18 | experiment_tracker: ExperimentTracker,
19 | metrics_collector: MetricsCollector
20 | ):
21 | super().__init__()
22 | self._rag_service = rag_service
23 | self._experiment_tracker = experiment_tracker
24 | self._metrics_collector = metrics_collector
25 |
26 | def _run(self, query: str) -> str:
27 | """Execute the RAG functionality."""
28 | # Start timing
29 | start_time = time.time()
30 |
31 | response_text = ""
32 | success = False
33 | error_msg = None
34 |
35 | try:
36 | response = self._rag_service.answer_question(query)
37 | response_text = response["response"]
38 | success = True
39 | except Exception as e:
40 | error_msg = str(e)
41 | response_text = f"Error processing RAG query: {error_msg}"
42 |
43 | # End timing
44 | end_time = time.time()
45 | processing_time = end_time - start_time
46 |
47 | # Build MetricsData
48 | metrics_data = MetricsData(
49 | processing_time=processing_time,
50 | query_length=len(query),
51 | response_length=len(response_text),
52 | success=success,
53 | token_count=self._metrics_collector.count_tokens(response_text),
54 | error=error_msg
55 | )
56 |
57 | # Log to CometML
58 | self._experiment_tracker.log_rag_query(metrics_data)
59 | return response_text
60 |
--------------------------------------------------------------------------------
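A hedged wiring sketch for `RAGTool`; `rag` stands in for a `RAG` instance built as in `src/components/rag/tool.py`:

```python
from src.config.settings import Settings
from src.tools.rag import RAGTool
from src.components.evaluation.experiment_tracker import ExperimentTracker, MetricsCollector

settings = Settings()
tracker = ExperimentTracker(api_key=settings.cometml_api_key,
                            project_name=settings.project_name)
tool = RAGTool(rag_service=rag,  # assumed: a RAG instance (see components/rag/tool.py)
               experiment_tracker=tracker,
               metrics_collector=MetricsCollector())
print(tool.run("How does reinforcement learning apply to robotics?"))
```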
/scripts/preprocess.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from typing import List, Dict, Any
4 |
5 |
6 | def preprocess_data(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
7 | """
8 | Preprocess the arXiv papers data.
9 |
10 | Args:
11 | data: List of dictionaries containing paper information
12 |
13 | Returns:
14 | List of processed paper dictionaries
15 | """
16 | processed_data = []
17 | for paper in data:
18 | processed_paper = {
19 | 'id': paper['id'],
20 | 'title': paper['title'],
21 | 'abstract': paper['abstract'],
22 | 'categories': paper['categories'].split(),
23 | 'authors': [' '.join(author).strip()
24 | for author in paper['authors_parsed']],
25 | 'submit_date': paper['versions'][0]['created'],
26 | 'update_date': paper['update_date']
27 | }
28 | processed_data.append(processed_paper)
29 | return processed_data
30 |
31 |
32 | def main():
33 | # Set up argument parser
34 | parser = argparse.ArgumentParser(description='Preprocess arXiv papers data')
35 | parser.add_argument('--input', type=str, required=True,
36 | help='Input JSON file path')
37 | parser.add_argument('--output', type=str, required=True,
38 | help='Output JSON file path')
39 |
40 | # Parse arguments
41 | args = parser.parse_args()
42 |
43 | # Read input file
44 | print(f"Reading data from {args.input}...")
45 | try:
46 | with open(args.input, 'r') as f:
47 | raw_data = json.load(f)
48 | except FileNotFoundError:
49 | print(f"Error: Input file {args.input} not found")
50 | return
51 | except json.JSONDecodeError:
52 | print(f"Error: Input file {args.input} is not valid JSON")
53 | return
54 |
55 | # Process the data
56 | print("Processing data...")
57 | processed_data = preprocess_data(raw_data)
58 |
59 | # Save processed data
60 | print(f"Saving processed data to {args.output}...")
61 | try:
62 | with open(args.output, 'w') as f:
63 | json.dump(processed_data, f, indent=2)
64 | print("Processing completed successfully!")
65 | except Exception as e:
66 | print(f"Error saving output file: {str(e)}")
67 | return
68 |
69 |
70 | if __name__ == "__main__":
71 | main()
72 |
--------------------------------------------------------------------------------
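For reference, a single record before and after `preprocess_data` (values are made up, following the Kaggle arXiv metadata schema):

```python
from scripts.preprocess import preprocess_data  # assumes running from the repo root

raw = [{
    "id": "0704.2002",
    "title": "An Example Title",
    "abstract": "A short abstract.",
    "categories": "cs.CL cs.LG",  # space-separated string in the raw data
    "authors_parsed": [["Author", "A.", ""], ["Author", "B.", ""]],
    "versions": [{"created": "Mon, 16 Apr 2007 00:00:00 GMT"}],
    "update_date": "2008-01-01",
}]
processed = preprocess_data(raw)
assert processed[0]["categories"] == ["cs.CL", "cs.LG"]  # split into a list
assert processed[0]["authors"] == ["Author A.", "Author B."]
```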
/src/streamlit/layout.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from src.streamlit.ui_component import ChatDisplay, InputArea, SessionControls
3 | from src.streamlit.predefined_questions import PredefinedQuestionsManager
4 |
5 |
6 | class ResearchAssistantUI:
7 | def __init__(self, coordinator, session_state):
8 | self.coordinator = coordinator
9 | self.session_state = session_state
10 | self.chat_display = ChatDisplay()
11 | self.input_area = InputArea(self._handle_user_input)
12 | self.session_controls = SessionControls(self._handle_session_end)
13 | self.predefined_questions = PredefinedQuestionsManager()
14 |
15 | def _get_unique_messages(self, messages):
16 | # Use content to identify unique messages
17 | seen_contents = set()
18 | unique_messages = []
19 | for message in messages:
20 | if message.content not in seen_contents:
21 | seen_contents.add(message.content)
22 | unique_messages.append(message)
23 | return unique_messages
24 |
25 | def render(self):
26 | st.set_page_config(page_title="Research Paper Assistant", layout="wide")
27 | st.title("Research Paper Assistant")
28 |
29 | if not self.session_state.session_active:
30 | st.info("Session has ended. Please refresh the page to start a new session.")
31 | return
32 |
33 | chat_container = st.container()
34 | input_container = st.container()
35 |
36 | with chat_container:
37 | unique_messages = self._get_unique_messages(self.session_state.messages)
38 | self.chat_display.display_messages(unique_messages)
39 |
40 | with input_container:
41 | self.session_controls.render()
42 | self._render_predefined_questions()
43 | self.input_area.render()
44 |
45 | def _render_predefined_questions(self):
46 | if self.session_state.show_predefined:
47 | st.subheader("Predefined Questions")
48 | questions = self.predefined_questions.get_questions()
49 | col1, col2 = st.columns(2)
50 |
51 | for i, question in enumerate(questions):
52 | with col1 if i < len(questions) // 2 else col2:
53 | st.button(
54 | question,
55 | key=f"pred_q_{i}",
56 | on_click=self._handle_predefined_question,
57 | args=(question,)
58 | )
59 |
60 | def _handle_user_input(self, user_input: str):
61 | self.session_state.show_predefined = False
62 | self.coordinator.process_message(user_input, self.session_state)
63 |
64 | def _handle_predefined_question(self, question: str):
65 | self.session_state.show_predefined = False
66 | self.coordinator.process_message(question, self.session_state)
67 |
68 | def _handle_session_end(self):
69 | self.coordinator.cleanup()
70 | self.session_state.messages.clear()
71 | self.session_state.session_active = False
72 | st.info("Session has ended. Please refresh the page to start a new session.")
73 | st.stop()
74 |
75 |
--------------------------------------------------------------------------------
/src/tools/paper_lookup.py:
--------------------------------------------------------------------------------
1 | from src.components.paper.tool import PaperTool
2 | from src.utils.paper_id_extractor import PaperIdExtractor
3 | from langchain.tools.base import BaseTool
4 | from pydantic import PrivateAttr
5 | import time
6 | import json
7 |
8 | from src.components.evaluation.experiment_tracker import ExperimentTracker
9 | from src.components.evaluation.experiment_tracker import MetricsCollector, MetricsData
10 |
11 |
12 | class PaperLookupTool(BaseTool):
13 | name: str = "Paper Lookup"
14 | description: str = "Use this tool to retrieve details about a specific paper by its ID."
15 | _paper_service: PaperTool = PrivateAttr()
16 | _paper_id_extractor: PaperIdExtractor = PrivateAttr()
17 | _experiment_tracker: ExperimentTracker = PrivateAttr()
18 | _metrics_collector: MetricsCollector = PrivateAttr()
19 |
20 | def __init__(
21 | self,
22 | paper_service: PaperTool,
23 | experiment_tracker: ExperimentTracker,
24 | metrics_collector: MetricsCollector
25 | ):
26 | super().__init__()
27 | self._paper_service = paper_service
28 | self._paper_id_extractor = PaperIdExtractor()
29 | self._experiment_tracker = experiment_tracker
30 | self._metrics_collector = metrics_collector
31 |
32 | def _run(self, query: str) -> str:
33 | """Execute the paper lookup functionality."""
34 | # Start timing
35 | start_time = time.time()
36 |
37 | paper_id = None
38 | response_text = ""
39 | success = False
40 | error_msg = None
41 |
42 | try:
43 | paper_id = self._paper_id_extractor.extract(query)
44 | if not paper_id:
45 | response_text = "No valid paper ID found in the message."
46 | return json.dumps({
47 | "ground_truth": "",
48 | "tool_answer": response_text
49 | })
50 |
51 | # Get paper info with metrics
52 | result = self._paper_service.find_paper_by_id(paper_id)
53 | if result["success"]:
54 | # The “official text” from the DB
55 | ground_truth = result["response"]
56 | success = True
57 | tool_answer = (
58 | f"Here is the paper with ID {paper_id}:\n\n"
59 | f"{ground_truth}"
60 | )
61 | else:
62 | ground_truth = ""
63 | tool_answer = f"Paper with ID {paper_id} not found."
64 | except Exception as e:
65 | ground_truth = ""
66 | error_msg = str(e)
67 | tool_answer = f"Error looking up paper: {error_msg}"
68 | success = False
69 |
70 | # End timing
71 | end_time = time.time()
72 | processing_time = end_time - start_time
73 |
74 | # Build MetricsData
75 |         metrics_data = MetricsData(
76 |             processing_time=processing_time,
77 |             query_length=len(query),
78 |             response_length=len(tool_answer),
79 |             success=success,
80 |             token_count=self._metrics_collector.count_tokens(tool_answer),
81 |             error=error_msg
82 |         )
83 |
84 | # Log to CometML
85 | if paper_id:
86 | self._experiment_tracker.log_paper_lookup(paper_id, metrics_data)
87 |
88 | return json.dumps({
89 | "ground_truth": ground_truth,
90 | "tool_answer": tool_answer
91 | })
92 |
--------------------------------------------------------------------------------
/src/agents/research_assistant.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any, List
2 | import time
3 | import json
4 |
5 | from langchain_community.chat_models import ChatOpenAI
6 | from langchain.schema import AIMessage
7 |
8 | from src.components.evaluation.experiment_tracker import ExperimentTracker
9 | from src.core.state import ConversationState
10 | from langchain.memory import ConversationBufferMemory
11 | from dotenv import load_dotenv
12 | from langchain.agents import initialize_agent, AgentType
13 | from langchain.tools.base import BaseTool
14 | from src.agents.base import BaseAgent
15 |
16 | load_dotenv()
17 |
18 | class ResearchAssistant(BaseAgent):
19 | def __init__(
20 | self,
21 | experiment_tracker: ExperimentTracker,
22 | tools: List[BaseTool],
23 | llm: ChatOpenAI,
24 | ):
25 | self.experiment_tracker = experiment_tracker
26 | self.llm = llm
27 | self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
28 | # Initialize the agent with the tools and LLM
29 | self.agent_executor = initialize_agent(
30 | tools,
31 | self.llm,
32 | agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
33 | verbose=True,
34 | memory=self.memory,
35 | handle_parsing_errors=True
36 | )
37 |
38 | def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
39 | last_message = state["messages"][-1]
40 | query_content = last_message.content
41 |
42 | start_time = time.time()
43 | self.experiment_tracker.experiment.log_parameter("input_query", query_content)
44 |
45 | # Add user message to memory
46 | self.memory.chat_memory.add_user_message(query_content)
47 |
48 | # 1) Run the agent
49 | response_text = self.agent_executor.run(input=query_content)
50 |
51 | processing_time = time.time() - start_time
52 | self.experiment_tracker.experiment.log_metrics({
53 | "processing_time": processing_time,
54 | "response_length": len(response_text),
55 | "query_length": len(query_content)
56 | })
57 |
58 | if "metrics" in state:
59 | self.experiment_tracker.experiment.log_metrics(state["metrics"])
60 |
61 | ground_truth = ""
62 | tool_answer = response_text # fallback is the entire text
63 |
64 | try:
65 | parsed = json.loads(response_text)
66 | if isinstance(parsed, dict) and "ground_truth" in parsed:
67 | ground_truth = parsed["ground_truth"]
68 | tool_answer = parsed["tool_answer"]
69 |         except (json.JSONDecodeError, TypeError):
70 |             pass
71 | final_response = tool_answer
72 |
73 | # Add final response to memory and state
74 | state["messages"].append(AIMessage(content=final_response))
75 | self.memory.chat_memory.add_ai_message(final_response)
76 |
77 | return {
78 | "messages": [AIMessage(content=final_response)],
79 | # Put the ground truth somewhere so we can pick it up in coordinator
80 | "tool_output": {"paper_ground_truth": ground_truth}
81 | }
82 |
83 | def process_message(self, state: ConversationState) -> Dict[str, Any]:
84 | """Process a message. This method is required by BaseAgent but is not used."""
85 | # Since we're handling message processing in __call__, we can leave this empty just to not get an error.
86 | pass
87 |
88 | def handle_error(self, error: Exception) -> Dict[str, Any]:
89 | """Generic error handler which is required by BaseAgent but is not used."""
90 | error_message = f"An error occurred: {str(error)}"
91 | return {"messages": [AIMessage(content=error_message)]}
92 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Uplifted RAG Systems with CometML Opik
2 | ## Building an Agentic Knowledge Graph for Enhanced Information Retrieval
3 |
4 | This repository demonstrates an advanced implementation of a Retrieval-Augmented Generation (RAG) system that combines graph-based knowledge representation with agentic capabilities, monitored and optimized through CometML Opik. The system transforms traditional RAG approaches by enabling context-aware, multi-hop reasoning while maintaining comprehensive observability.
5 |
6 | ## Table of Contents
7 | - [Overview](#overview)
8 | - [Key Features](#key-features)
9 | - [System Architecture](#system-architecture)
10 | - [Prerequisites](#prerequisites)
11 | - [Installation](#installation)
12 | - [Usage](#usage)
13 |
14 | ## Overview
15 |
16 | Traditional RAG systems often struggle with linear document retrieval, context blindness, and limited reasoning capabilities. This implementation addresses these challenges by:
17 | - Implementing a graph-based knowledge structure for complex relationship modeling
18 | - Integrating autonomous agents for dynamic query processing
19 | - Providing comprehensive monitoring through CometML Opik
20 | - Enabling multi-hop reasoning across interconnected documents
21 |
22 | ## Key Features
23 |
24 | ### Advanced RAG Capabilities
25 | - Graph-based knowledge representation using Neo4j
26 | - Semantic search with vector embeddings
27 | - Multi-hop reasoning across document relationships
28 | - Context-aware query processing
29 |
30 | ### Agent Integration
31 | - Dynamic tool selection based on query context
32 | - Adaptive exploration strategies
33 | - Comprehensive conversation memory
34 | - Fallback mechanisms for robust operation
35 |
36 | ### Monitoring & Optimization
37 | - Real-time performance tracking with CometML Opik
38 | - Detailed metrics collection and visualization
39 | - Hyperparameter optimization
40 | - Model versioning and experiment tracking
41 |
42 | ### User Interface
43 | - Streamlit-based interactive interface
44 | - Component-based architecture
45 | - Session state management
46 | - Metric visualization
47 |
48 | ## System Architecture
49 |
50 |
51 | The system consists of several key components:
52 |
53 | 1. **Data Pipeline**
54 | - Raw data preprocessing
55 | - Neo4j graph database integration
56 | - Parallel data ingestion system
57 | - Vector store bridging
58 |
59 | 2. **Core RAG Components**
60 | - Question-answering pipeline
61 | - Paper lookup service
62 | - Agent-based tool orchestration
63 | - Coordinator pattern implementation
64 |
65 | 3. **Monitoring Infrastructure**
66 | - Real-time metric collection
67 | - Performance visualization
68 | - Error tracking
69 | - Session analytics
70 |
71 | ## Prerequisites
72 |
73 | - Neo4j Database
74 | - CometML Account
75 | - OpenAI API Key
76 | - [arXiv Dataset from Kaggle](https://www.kaggle.com/datasets/Cornell-University/arxiv)
77 |
78 | ## Installation
79 |
80 | 1. **Clone the Repository**
81 | ```bash
82 | git clone https://github.com/mlvanguards/agentic-graph-rag-evaluation-cometml.git
83 | ```
84 |
85 | 2. **Install Dependencies**
86 | ```bash
87 | pip install -r requirements.txt
88 | ```
89 |
90 | 3. **Environment Setup**
91 | Create a `.env` file with the following credentials:
92 | ```
93 | NEO4J_URI=your_neo4j_uri
94 | NEO4J_USER=your_username
95 | NEO4J_PASSWORD=your_password
96 | COMETML_API_KEY=your_cometml_key
97 | OPENAI_API_KEY=your_openai_key
98 | ```
99 |
100 | ## Usage
101 |
102 | 1. **Data Preprocessing**
103 | ```bash
104 | python scripts/preprocess.py --input arxiv_data.json --output processed_data.json
105 | ```
106 | 
107 | 2. **Database Ingestion**
108 | ```bash
109 | python -m src.components.database.ingest
110 | ```
111 | 
112 | 3. **Start the Application**
113 | ```bash
114 | streamlit run src/main.py
115 | ```
116 | 
117 |
--------------------------------------------------------------------------------
/src/components/rag/tool.py:
--------------------------------------------------------------------------------
1 | from src.components.database.vector_store import VectorStore
2 | from langchain_core.prompts import PromptTemplate
3 | from langchain_openai import OpenAI
4 | from langchain.chains.llm import LLMChain
5 | from typing import Any, Dict, Optional
6 | from src.components.evaluation.experiment_tracker import MetricsCollector
7 | import time
8 |
9 | class RAG:
10 | def __init__(
11 | self,
12 | vector_store: VectorStore,
13 | openai_api_key: str,
14 | prompt_template: Optional[str] = None
15 | ):
16 | if not openai_api_key:
17 | raise ValueError("OpenAI API key must be provided")
18 | self.vector_store = vector_store
19 | self.llm = OpenAI(openai_api_key=openai_api_key)
20 | self.metrics_collector = MetricsCollector()
21 |
22 | self.prompt_template = PromptTemplate(
23 | input_variables=['context', 'question'],
24 | template=prompt_template or self._default_prompt_template()
25 | )
26 | self.llm_chain = LLMChain(llm=self.llm, prompt=self.prompt_template)
27 |
28 | def _default_prompt_template(self) -> str:
29 | return """
30 | You are an AI assistant knowledgeable about computer science research.
31 |
32 | Context:
33 | {context}
34 |
35 | Question:
36 | {question}
37 |
38 | Provide a detailed and accurate answer based on the context provided.
39 | """
40 |
41 |     def answer_question(self, question: str, k: int = 3) -> Dict[str, Any]:
42 | start_time = time.time()
43 | try:
44 | # Get context with metrics
45 | context_result = self.get_context(question, k)
46 | context = context_result["context"]
47 | metrics = context_result["metrics"]
48 |
49 | # Generate response
50 | response = self.llm_chain.run(context=context, question=question)
51 |
52 | # Add response metrics
53 | response_stats = self.metrics_collector.get_text_stats(response)
54 | metrics.update({
55 | "response_length": len(response),
56 | "response_tokens": response_stats["token_count"],
57 | "total_processing_time": time.time() - start_time,
58 | "success": True
59 | })
60 |
61 | return {
62 | "response": response,
63 | "metrics": metrics
64 | }
65 | except Exception as e:
66 | return {
67 | "response": f"Error generating answer: {str(e)}",
68 | "metrics": {
69 | "error": str(e),
70 | "success": False,
71 | "total_processing_time": time.time() - start_time
72 | }
73 | }
74 |
75 |     def get_context(self, question: str, k: int = 3) -> Dict[str, Any]:
76 | start_time = time.time()
77 | try:
78 | relevant_docs = self.vector_store.similarity_search(question, k=k)
79 | context = "\n\n".join([doc for doc, _ in relevant_docs])
80 |
81 | # Collect metrics
82 | context_stats = self.metrics_collector.get_text_stats(context)
83 | question_stats = self.metrics_collector.get_text_stats(question)
84 |
85 | metrics = {
86 | "context_length": len(context),
87 | "context_tokens": context_stats["token_count"],
88 | "context_chunks": len(relevant_docs),
89 | "question_length": len(question),
90 | "question_tokens": question_stats["token_count"],
91 | "retrieval_time": time.time() - start_time,
92 | "success": True
93 | }
94 |
95 | return {
96 | "context": context,
97 | "metrics": metrics
98 | }
99 | except Exception as e:
100 | return {
101 | "context": "",
102 | "metrics": {
103 | "error": str(e),
104 | "success": False,
105 | "retrieval_time": time.time() - start_time
106 | }
107 | }
108 |
109 |
--------------------------------------------------------------------------------
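A hedged end-to-end sketch, reusing `settings` and a `VectorStore` (`store`) built as in the sketch under `components/database/vector_store.py`:

```python
from src.components.rag.tool import RAG

rag = RAG(vector_store=store, openai_api_key=settings.openai_api_key)
result = rag.answer_question("Explain transformer models in NLP.", k=3)
print(result["response"])
print(result["metrics"].get("total_processing_time"))  # metrics travel with the answer
```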
/src/components/evaluation/experiment_tracker.py:
--------------------------------------------------------------------------------
1 | from comet_ml import Experiment
2 | from typing import Dict, Any, Optional
3 | import time
4 | from dataclasses import dataclass
5 | import tiktoken
6 | from opik.evaluation.metrics import Hallucination
7 | @dataclass
8 | class MetricsData:
9 | processing_time: float
10 | query_length: int
11 | response_length: int
12 | success: bool
13 | context_length: int = 0 # Added for RAG context size
14 | token_count: int = 0 # Added for response tokens
15 | error: Optional[str] = None
16 |
17 |
18 | class MetricsCollector:
19 | def __init__(self):
20 | self.tokenizer = tiktoken.encoding_for_model("gpt-4o")
21 |
22 | def count_tokens(self, text: str) -> int:
23 | return len(self.tokenizer.encode(text))
24 |
25 | def get_text_stats(self, text: str) -> Dict[str, int]:
26 | return {
27 | "char_length": len(text),
28 | "token_count": self.count_tokens(text),
29 | "word_count": len(text.split()),
30 | "line_count": len(text.splitlines())
31 | }
32 |
33 | class ExperimentTracker:
34 | def __init__(self, api_key: str, project_name: str):
35 | self.experiment = Experiment(
36 | api_key=api_key,
37 | project_name=project_name
38 | )
39 | self.start_time = time.time()
40 | self.query_count = 0
41 | self.error_count = 0
42 |
43 | def log_paper_lookup(self, paper_id: str, metrics: MetricsData):
44 | """Log metrics for paper lookups."""
45 | self.query_count += 1
46 | if not metrics.success:
47 | self.error_count += 1
48 |
49 | self.experiment.log_metrics({
50 | "paper_lookup_latency": metrics.processing_time,
51 | "paper_id_length": len(paper_id),
52 | "paper_lookup_year": int(paper_id.split('.')[0]) if '.' in paper_id else None,
53 | "paper_lookup_success": int(metrics.success),
54 | "paper_response_length": metrics.response_length,
55 | "paper_response_tokens": metrics.token_count,
56 | "cumulative_queries": self.query_count,
57 | "error_rate": self.error_count / self.query_count if self.query_count > 0 else 0
58 | })
59 | if metrics.error:
60 | self.experiment.log_parameter("error", metrics.error)
61 |
62 | def log_rag_query(self, metrics: MetricsData):
63 | """Log metrics for RAG queries."""
64 | self.query_count += 1
65 | if not metrics.success:
66 | self.error_count += 1
67 |
68 | self.experiment.log_metrics({
69 | "rag_query_length": metrics.query_length,
70 | "rag_response_length": metrics.response_length,
71 | "rag_processing_time": metrics.processing_time,
72 | "rag_success": int(metrics.success),
73 | "rag_context_length": metrics.context_length,
74 | "rag_response_tokens": metrics.token_count,
75 | "cumulative_queries": self.query_count,
76 | "error_rate": self.error_count / self.query_count if self.query_count > 0 else 0
77 | })
78 |
79 | # Log hourly aggregates
80 | hour = time.strftime("%Y-%m-%d-%H")
81 | self.experiment.log_metrics({
82 | f"hourly_queries_{hour}": 1,
83 | f"hourly_errors_{hour}": 0 if metrics.success else 1,
84 | f"hourly_avg_latency_{hour}": metrics.processing_time
85 | })
86 |
87 | def log_session_metrics(self):
88 | """Log overall session metrics."""
89 | session_duration = time.time() - self.start_time
90 | self.experiment.log_metrics({
91 | "session_duration": session_duration,
92 | "total_queries": self.query_count,
93 | "total_errors": self.error_count,
94 | "session_error_rate": self.error_count / self.query_count if self.query_count > 0 else 0,
95 | "queries_per_minute": (self.query_count * 60) / session_duration if session_duration > 0 else 0
96 | })
97 |
98 | def end_session(self, session_metrics: Dict[str, Any]):
99 | """Log final session metrics and end experiment."""
100 | self.log_session_metrics()
101 | for key, value in session_metrics.items():
102 | self.experiment.log_metric(f"session_{key}", value)
103 | self.experiment.end()
--------------------------------------------------------------------------------
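A minimal logging sketch, assuming a valid `COMETML_API_KEY` in `.env` (metric values are illustrative):

```python
from src.config.settings import Settings
from src.components.evaluation.experiment_tracker import (
    ExperimentTracker,
    MetricsCollector,
    MetricsData,
)

settings = Settings()
collector = MetricsCollector()
tracker = ExperimentTracker(api_key=settings.cometml_api_key,
                            project_name=settings.project_name)

response = "Transformer models process tokens with self-attention..."
tracker.log_rag_query(MetricsData(
    processing_time=1.2,  # illustrative values
    query_length=42,
    response_length=len(response),
    success=True,
    token_count=collector.count_tokens(response),
))
tracker.end_session({"total_messages": 1})
```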
/src/components/paper/tool.py:
--------------------------------------------------------------------------------
1 | from src.components.database.neo4j_client import Neo4jClient
2 | from src.components.paper.models import Paper
3 | from typing import Dict, Any
4 | from src.components.evaluation.experiment_tracker import MetricsCollector
5 | import logging
6 | from neo4j.exceptions import AuthError, ServiceUnavailable
7 | import time
8 |
9 | class PaperTool:
10 | def __init__(self, db_client: Neo4jClient):
11 | self.db_client = db_client
12 | self.metrics_collector = MetricsCollector()
13 | self.logger = logging.getLogger(__name__)
14 |
15 | def find_paper_by_id(self, paper_id: str) -> Dict[str, Any]:
16 | start_time = time.time()
17 |
18 | try:
19 | # Test connection before query
20 | self.logger.info(f"Testing connection before paper lookup for {paper_id}")
21 | with self.db_client.session() as test_session:
22 | test_session.run("RETURN 1").single()
23 |
24 | # Perform actual query
25 | self.logger.info(f"Executing paper lookup query for {paper_id}")
26 | with self.db_client.session() as session:
27 | result = session.run(self._get_paper_query(), paper_id=paper_id)
28 | record = result.single()
29 |
30 | if not record:
31 | self.logger.info(f"No paper found with ID {paper_id}")
32 | return {
33 | "response": f"Paper with ID {paper_id} not found.",
34 | "success": False,
35 | "metrics": {
36 | "error": "Paper not found",
37 | "success": False,
38 | "processing_time": time.time() - start_time,
39 | "question_length": len(paper_id)
40 | }
41 | }
42 |
43 | # Create paper object
44 | paper = Paper.from_db_record(record)
45 | paper_text = paper.to_string()
46 |
47 | # Collect metrics
48 | paper_stats = self.metrics_collector.get_text_stats(paper_text)
49 | metrics = {
50 | "success": True,
51 | "processing_time": time.time() - start_time,
52 | "response_length": len(paper_text),
53 | "response_tokens": paper_stats["token_count"],
54 | "word_count": paper_stats["word_count"],
55 | "authors_count": len(record["authors"]),
56 | "categories_count": len(record["categories"]),
57 | "question_length": len(paper_id)
58 | }
59 |
60 | return {
61 | "response": paper_text,
62 | "success": True,
63 | "metrics": metrics
64 | }
65 |
66 | except AuthError as e:
67 | error_msg = f"Authentication failed: {str(e)}"
68 | self.logger.error(f"Authentication error during paper lookup: {str(e)}")
69 | return self._create_error_response(error_msg, start_time, paper_id)
70 | except ServiceUnavailable as e:
71 | error_msg = f"Database service unavailable: {str(e)}"
72 | self.logger.error(f"Service unavailable during paper lookup: {str(e)}")
73 | return self._create_error_response(error_msg, start_time, paper_id)
74 | except Exception as e:
75 | error_msg = f"Error retrieving paper: {str(e)}"
76 | self.logger.error(f"Error during paper lookup: {str(e)}")
77 | return self._create_error_response(error_msg, start_time, paper_id)
78 |
79 | def _create_error_response(self, error_msg: str, start_time: float, paper_id: str) -> Dict[str, Any]:
80 | return {
81 | "response": error_msg,
82 | "success": False,
83 | "metrics": {
84 | "error": error_msg,
85 | "success": False,
86 | "processing_time": time.time() - start_time,
87 | "question_length": len(paper_id)
88 | }
89 | }
90 |
91 | def _get_paper_query(self) -> str:
92 | """Return the paper lookup query."""
93 | return """
94 | MATCH (p:Paper {id: $paper_id})
95 | OPTIONAL MATCH (p)-[:AUTHORED_BY]->(a:Author)
96 | OPTIONAL MATCH (p)-[:BELONGS_TO]->(c:Category)
 97 |         RETURN p.id AS id, p.title AS title, p.abstract AS abstract,
98 | p.submit_date AS submit_date, p.update_date AS update_date,
99 | collect(DISTINCT a.name) AS authors,
100 | collect(DISTINCT c.name) AS categories
101 | """
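102 | 
103 | 
104 | if __name__ == "__main__":
105 |     # Minimal usage sketch: assumes a local Neo4j instance; the URI and
106 |     # credentials below are placeholders, and "0704.0001" is the sample
107 |     # arXiv ID used elsewhere in this project.
108 |     logging.basicConfig(level=logging.INFO)
109 |     client = Neo4jClient(uri="bolt://localhost:7687", user="neo4j", password="password")
110 |     tool = PaperTool(db_client=client)
111 |     result = tool.find_paper_by_id("0704.0001")
112 |     print(result["response"])
113 |     client.close()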
--------------------------------------------------------------------------------
/src/components/evaluation/custom_metric.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Any, List, Optional, Union
3 | import pydantic
4 | from opik.evaluation.metrics import base_metric, score_result
5 | from opik.evaluation.models import litellm_chat_model
6 | from opik.evaluation.models import base_model
7 | import logging
8 | LOGGER = logging.getLogger(__name__)
9 |
10 |
11 | class AnswerCompletenessResponseFormat(pydantic.BaseModel):
12 | answer_completeness_score: float
13 | reason: str
14 |
15 |
16 | class AnswerCompleteness(base_metric.BaseMetric):
17 | """
18 | A metric that evaluates the completeness of an answer relative to the user's input.
19 |
20 | This metric uses an LLM to assess whether the generated answer fully addresses all aspects of the user's query.
21 |
22 | Args:
23 | model: The language model to use for evaluation. Defaults to "gpt-4o".
24 | name: The name of the metric. Defaults to "answer_completeness_metric".
25 | few_shot_examples: Optional list of few-shot examples to guide the LLM's evaluation.
26 | track: Whether to track the metric. Defaults to True.
27 | """
28 |
29 | def __init__(
30 | self,
31 | model: Optional[Union[str, base_model.OpikBaseModel]] = None,
32 | name: str = "answer_completeness_metric",
33 | few_shot_examples: Optional[List[Any]] = None, # Replace `Any` with actual type if defined
34 | track: bool = True,
35 | ):
36 | super().__init__(
37 | name=name,
38 | track=track,
39 | )
40 | self._init_model(model)
41 | if few_shot_examples is None:
42 | self._few_shot_examples = [] # Define default few-shot examples if available
43 | else:
44 | self._few_shot_examples = few_shot_examples
45 |
46 | def _init_model(
47 | self, model: Optional[Union[str, base_model.OpikBaseModel]]
48 | ) -> None:
49 | if isinstance(model, base_model.OpikBaseModel):
50 | self._model = model
51 | else:
52 | self._model = litellm_chat_model.LiteLLMChatModel(model_name=model or "gpt-4o")
53 |
54 | def score(
55 | self, input: str, output: str, context: List[str], **ignored_kwargs: Any
56 | ) -> score_result.ScoreResult:
57 | """
58 | Calculate the answer completeness score for the given input-output pair.
59 |
60 | Args:
61 | input: The user's question or prompt.
62 | output: The LLM-generated answer.
63 | context: A list of context strings relevant to the input.
64 |
65 | Returns:
66 | score_result.ScoreResult: Contains the completeness score and reason.
67 | """
68 | llm_query = self._generate_prompt(input, output, context)
69 | model_output = self._model.generate_string(
70 | input=llm_query, response_format=AnswerCompletenessResponseFormat
71 | )
72 | return self._parse_model_output(model_output)
73 |
74 | async def ascore(
75 | self, input: str, output: str, context: List[str], **ignored_kwargs: Any
76 | ) -> score_result.ScoreResult:
77 | """
78 | Asynchronously calculate the answer completeness score for the given input-output pair.
79 |
80 | Args:
81 | input: The user's question or prompt.
82 | output: The LLM-generated answer.
83 | context: A list of context strings relevant to the input.
84 |
85 | Returns:
86 | score_result.ScoreResult: Contains the completeness score and reason.
87 | """
88 | llm_query = self._generate_prompt(input, output, context)
89 | model_output = await self._model.agenerate_string(
90 | input=llm_query, response_format=AnswerCompletenessResponseFormat
91 | )
92 | return self._parse_model_output(model_output)
93 |
94 | def _generate_prompt(self, input_text: str, output_text: str, context: List[str]) -> str:
95 | """
96 | Generate the prompt to send to the LLM for evaluation.
97 |
98 | Args:
99 | input_text: The user's question or prompt.
100 | output_text: The LLM-generated answer.
101 | context: Relevant context information.
102 |
103 | Returns:
104 | str: The complete prompt for the LLM.
105 | """
106 | context_combined = " ".join(context)
107 | prompt = f"""
108 | YOU ARE AN EXPERT IN NLP EVALUATION METRICS, TRAINED TO ASSESS THE COMPLETENESS OF ANSWERS PROVIDED BY LANGUAGE MODELS.
109 |
110 | ###INSTRUCTIONS###
111 | - ANALYZE THE GIVEN USER INPUT AND THE GENERATED ANSWER.
112 | - DETERMINE IF THE ANSWER FULLY ADDRESSES ALL ASPECTS OF THE USER'S QUERY.
113 | - ASSIGN A COMPLETENESS SCORE BETWEEN 0.0 (COMPLETELY INCOMPLETE) AND 1.0 (FULLY COMPLETE).
114 | - PROVIDE A BRIEF REASON FOR THE SCORE, HIGHLIGHTING WHICH PARTS WERE ADDRESSED OR MISSING.
115 |
116 | ###EXAMPLE###
117 | Input: "Explain the concept of transformer models in NLP and provide examples."
118 | Answer: "Transformer models are neural network architectures used in NLP that employ self-attention mechanisms..."
119 | Context: "Transformer models are neural network architectures used in NLP that employ self-attention mechanisms to process entire sentences simultaneously, capturing long-range dependencies and context."
120 | ---
121 | {{
122 |     "answer_completeness_score": 0.9,
123 |     "reason": "The answer thoroughly explains transformer models and provides examples, fully addressing the user's request."
124 | }}
125 |
126 | ###INPUTS:###
127 | ***
128 | User input:
129 | {input_text}
130 | Answer:
131 | {output_text}
132 | Contexts:
133 | {context_combined}
134 | ***
135 | """
136 | return prompt
137 |
138 | def _parse_model_output(self, content: str) -> score_result.ScoreResult:
139 | """
140 | Parse the LLM's JSON response into a ScoreResult object.
141 |
142 | Args:
143 | content: The JSON string returned by the LLM.
144 |
145 | Returns:
146 | score_result.ScoreResult: The parsed score and reason.
147 | """
148 |         try:
149 |             dict_content = json.loads(content)
150 |             score: float = float(dict_content.get("answer_completeness_score", 0.5))
151 |             reason: str = dict_content.get("reason", "No reason provided.")
152 |         except (json.JSONDecodeError, TypeError, ValueError):
153 |             LOGGER.warning("Could not parse model output as JSON; defaulting to 0.5.")
154 |             score, reason = 0.5, "Model output could not be parsed."
155 |         # Validate score range
156 |         if not (0.0 <= score <= 1.0):
157 |             LOGGER.warning("Received score %s out of bounds; defaulting to 0.5.", score)
158 |             score = 0.5
159 |         return score_result.ScoreResult(name=self.name, value=score, reason=reason)
160 |
161 |
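162 | if __name__ == "__main__":
163 |     # Minimal usage sketch: the default judge is LiteLLM's "gpt-4o", so an
164 |     # OPENAI_API_KEY is assumed to be set; the strings below are illustrative.
165 |     metric = AnswerCompleteness(track=False)
166 |     result = metric.score(
167 |         input="Explain transformer models and give an example.",
168 |         output="Transformers use self-attention; BERT is one example.",
169 |         context=["Transformers process sequences with self-attention."],
170 |     )
171 |     print(result.value, result.reason)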
--------------------------------------------------------------------------------
/src/components/evaluation/opik_evaluator.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import logging
3 |
4 | from opik.evaluation.metrics import (
5 | Contains,
6 | Equals,
7 | LevenshteinRatio,
8 | Hallucination,
9 | Moderation,
10 | AnswerRelevance,
11 | GEval
12 | )
13 | from opik.evaluation.metrics.score_result import ScoreResult
14 |
15 | from src.components.evaluation.custom_metric import AnswerCompleteness
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | TASK_INTRODUCTION = (
20 | "You are an expert judge tasked with evaluating the faithfulness of an AI-generated answer to the given context."
21 | )
22 |
23 | EVALUATION_CRITERIA = """
24 | - The OUTPUT must not introduce new information beyond what's provided in the CONTEXT.
25 | - The OUTPUT must not contradict any information given in the CONTEXT.
26 | - The OUTPUT should be logically consistent and coherent.
27 | - The OUTPUT should comprehensively address all aspects of the user's query.
28 | """
29 |
30 | class LlmEvaluator:
31 | def __init__(self):
32 | """
 33 |         Hard-codes static references and context for arXiv paper 0704.0001,
 34 |         relying on the default Opik metric classes directly.
35 | """
36 | self.metrics = {
37 | "contains_diphoton": Contains(name="contains_diphoton", case_sensitive=False),
38 | "contains_berger": Contains(name="contains_berger", case_sensitive=False),
39 | "equals_title": Equals(name="equals_title"),
40 | "lev_ratio_abstract": LevenshteinRatio(name="lev_ratio_abstract"),
41 | }
42 |
43 | self.hallucination_metric = Hallucination()
44 | self.moderation_metric = Moderation()
45 | self.answer_relevance_metric = AnswerRelevance()
46 | self.g_eval_metric = GEval(task_introduction=TASK_INTRODUCTION, evaluation_criteria=EVALUATION_CRITERIA)
47 |
48 | # Custom metric
49 | self.answer_completeness_metric = AnswerCompleteness()
50 |
51 | self.abstract_0704_0001 = (
52 | "A fully differential calculation in perturbative quantum chromodynamics is\n"
53 | "presented for the production of massive photon pairs at hadron colliders. All\n"
54 | "next-to-leading order perturbative contributions from quark-antiquark,\n"
55 | "gluon-(anti)quark, and gluon-gluon subprocesses are included, as well as\n"
56 | "all-orders resummation of initial-state gluon radiation valid at\n"
57 | "next-to-next-to-leading logarithmic accuracy. The region of phase space is\n"
58 | "specified in which the calculation is most reliable. Good agreement is\n"
59 | "demonstrated with data from the Fermilab Tevatron, and predictions are made for\n"
60 | "more detailed tests with CDF and DO data. Predictions are shown for\n"
61 | "distributions of diphoton pairs produced at the energy of the Large Hadron\n"
62 | "Collider (LHC). Distributions of the diphoton pairs from the decay of a Higgs\n"
63 | "boson are contrasted with those produced from QCD processes at the LHC, showing\n"
64 | "that enhanced sensitivity to the signal can be obtained with judicious\n"
65 | "selection of events.\n"
66 | )
67 | self.context_0704_0001 = [
68 | "Title: Calculation of prompt diphoton production cross sections at Tevatron and\n LHC energies",
69 | f"Abstract: {self.abstract_0704_0001}",
70 | "Authors: Balázs C., Berger E. L., Nadolsky P. M., Yuan C. -P.",
71 | "Submitted on: Mon, 2 Apr 2007 19:18:42 GMT",
72 | "Updated on: 2008-11-26"
73 | ]
74 |
75 | self.static_references = {
76 | "contains_diphoton": "diphoton",
77 | "contains_berger": "Berger",
78 | "equals_title": "Calculation of prompt diphoton production cross sections at Tevatron and\n LHC energies",
79 | "lev_ratio_abstract": self.abstract_0704_0001,
80 | }
81 |
82 | self.answer_context = (
83 | "Transformer models are neural network architectures used in NLP "
84 | "that employ self-attention mechanisms to process entire sentences "
85 | "simultaneously, capturing long-range dependencies and context. They "
86 | "use positional encoding to differentiate between words in different "
87 | "positions within a sentence."
88 | )
89 |
90 |
91 | def evaluate(self, output: str) -> Dict[str, ScoreResult]:
92 | """
 93 |         Evaluate an LLM output against the static references for paper 0704.0001.
94 | Return a dict of {metric_name -> ScoreResult} objects.
95 | """
96 | results = {}
97 | for metric_name, metric_obj in self.metrics.items():
98 | # e.g. "contains_diphoton" => reference "diphoton"
99 | ref = self.static_references[metric_name]
100 | score_res = metric_obj.score(output=output, reference=ref)
101 | results[metric_name] = score_res
102 | return results
103 |
104 | def check_hallucination(self, input_text: str, output_text: str) -> ScoreResult:
105 | """
106 | 0 = no hallucination, 1 = hallucination found.
107 | """
108 | return self.hallucination_metric.score(
109 | input=input_text,
110 | output=output_text,
111 | context=self.context_0704_0001
112 | )
113 |
114 | def check_moderation(self, output_text: str) -> ScoreResult:
115 | """
116 | 0.0 => safe, up to 1.0 => extremely unsafe
117 | """
118 | return self.moderation_metric.score(output=output_text)
119 |
120 | def check_answer_relevance(
121 | self,
122 | input_text: str,
123 | output_text: str,
124 | ) -> float:
125 | """
126 |         Return a float in [0..1], measuring how relevant `output_text` is
127 |         to `input_text` given the hard-coded `self.answer_context`.
128 | """
129 | score_result = self.answer_relevance_metric.score(
130 | input=input_text,
131 | output=output_text,
132 | context=[self.answer_context]
133 | )
134 | return score_result.value
135 |
136 | def check_g_eval(
137 | self, output_text: str
138 | ) -> float:
139 | """
140 | Evaluate the LLM's output using GEval metric.
141 | """
142 | score_result = self.g_eval_metric.score(
143 | output=output_text
144 | )
145 | return score_result.value
146 |
147 | def check_answer_completeness(
148 | self, input_text: str, output_text: str
149 | ) -> float:
150 | """
151 | Check how complete the LLM's answer is relative to the user's input.
152 |
153 | Args:
154 | input_text: The user's question or prompt.
155 | output_text: The LLM-generated answer.
156 |
157 | Returns:
158 | float: Completeness score between 0.0 and 1.0.
159 | """
160 | score_result = self.answer_completeness_metric.score(
161 | input=input_text,
162 | output=output_text,
163 | context=[self.answer_context]
164 | )
165 | return score_result.value
166 |
167 |
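168 | if __name__ == "__main__":
169 |     # Minimal usage sketch: evaluate() runs only the offline string-matching
170 |     # metrics; the judge metrics (check_moderation etc.) call an LLM, so an
171 |     # OPENAI_API_KEY is assumed for that step. The answer below is illustrative.
172 |     evaluator = LlmEvaluator()
173 |     answer = "The paper computes prompt diphoton production cross sections."
174 |     for name, res in evaluator.evaluate(answer).items():
175 |         print(f"{name}: {res.value}")
176 |     print("moderation:", evaluator.check_moderation(answer).value)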
--------------------------------------------------------------------------------
/src/orchestrator/coordinator.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | import time
3 | import logging
4 | from typing import Dict, Any
5 |
6 | from langchain_community.chat_models import ChatOpenAI
7 | from langchain.schema import HumanMessage, AIMessage
8 |
9 | from src.config.settings import Settings
10 | from src.core.graph import create_research_graph
11 | from src.components.database.neo4j_client import Neo4jClient
12 | from src.components.paper.tool import PaperTool
13 | from src.components.rag.tool import RAG
14 | from src.components.rag.embeddings import Embedding
15 | from src.components.database.vector_store import VectorStore
16 | from src.components.evaluation.experiment_tracker import ExperimentTracker, MetricsCollector
17 | from src.components.evaluation.opik_evaluator import LlmEvaluator
18 | from src.tools.paper_lookup import PaperLookupTool
19 | from src.tools.rag import RAGTool
20 | from src.agents.research_assistant import ResearchAssistant
21 | from dotenv import load_dotenv
22 |
23 | load_dotenv()
24 |
25 | # Configure logging
26 | logging.basicConfig(
27 | level=logging.INFO,
28 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
29 | )
30 | logging.getLogger("httpx").setLevel(logging.WARNING)
31 | logging.getLogger("urllib3").setLevel(logging.WARNING)
32 |
33 | logger = logging.getLogger(__name__)
34 |
35 |
36 | class Coordinator:
37 | def __init__(self):
38 | self.settings = Settings()
39 | self.experiment_tracker = self.setup_experiment_tracker()
40 | self.metrics_collector = MetricsCollector()
41 | self.services = self.initialize_services()
42 | self.tools = self.initialize_tools()
43 | self.assistant = self.initialize_assistant()
44 | self.graph = self.setup_graph()
45 |
46 | # Instantiate our new evaluator
47 | self.llm_evaluator = LlmEvaluator()
48 |
49 | def setup_experiment_tracker(self) -> ExperimentTracker:
50 | tracker = ExperimentTracker(
51 | api_key=self.settings.cometml_api_key,
52 | project_name=self.settings.project_name
53 | )
54 | tracker.experiment.add_tags(['v1', 'graph-rag', 'research-papers'])
55 | tracker.experiment.log_parameter("session_id", str(uuid.uuid4()))
56 | tracker.experiment.log_parameter("session_start", time.strftime("%Y-%m-%d %H:%M:%S"))
57 | return tracker
58 |
59 | def initialize_services(self) -> Dict[str, Any]:
60 | db_client = Neo4jClient(
61 | uri=self.settings.neo4j_uri,
62 | user=self.settings.neo4j_user,
63 | password=self.settings.neo4j_password
64 | )
65 | with db_client.session() as session:
66 | result = session.run("RETURN 1 as num").single()
67 | print(f"Initial connection test result: {result['num']}")
68 | embedding_service = Embedding(api_key=self.settings.openai_api_key)
69 | vector_store = VectorStore(
70 | neo4j_client=db_client,
71 | embedding_model=embedding_service.model,
72 | index_name="paper_vector_index"
73 | )
74 | paper_service = PaperTool(db_client=db_client)
75 | rag_service = RAG(
76 | vector_store=vector_store,
77 | openai_api_key=self.settings.openai_api_key
78 | )
79 |
80 | return {
81 | "db_client": db_client,
82 | "embedding_service": embedding_service,
83 | "vector_store": vector_store,
84 | "paper_service": paper_service,
85 | "rag_service": rag_service
86 | }
87 |
88 | def initialize_tools(self) -> Dict[str, Any]:
89 | paper_lookup_tool = PaperLookupTool(
90 | paper_service=self.services["paper_service"],
91 | experiment_tracker=self.experiment_tracker,
92 | metrics_collector=self.metrics_collector
93 | )
94 | rag_tool = RAGTool(
95 | rag_service=self.services["rag_service"],
96 | experiment_tracker=self.experiment_tracker,
97 | metrics_collector=self.metrics_collector
98 | )
99 | return {
100 | "paper_lookup": paper_lookup_tool,
101 | "rag": rag_tool
102 | }
103 |
104 | def initialize_assistant(self) -> ResearchAssistant:
105 | llm = ChatOpenAI(
106 | temperature=0,
107 | openai_api_key=self.settings.openai_api_key
108 | )
109 | tools = [self.tools["paper_lookup"], self.tools["rag"]]
110 | return ResearchAssistant(
111 | experiment_tracker=self.experiment_tracker,
112 | tools=tools,
113 | llm=llm
114 | )
115 |
116 | def setup_graph(self) -> Any:
117 | return create_research_graph(
118 | assistant=self.assistant,
119 | rag_tool=self.tools["rag"],
120 | paper_lookup_tool=self.tools["paper_lookup"]
121 | )
122 |
123 | def process_message(self, message: str, state: Dict[str, Any]) -> None:
124 | try:
125 | # Store user message
126 | state["messages"].append(HumanMessage(content=message))
127 |
128 | # Log conversation metrics
129 | self.experiment_tracker.experiment.log_metrics({
130 | "conversation_turn": len(state["messages"]),
131 | "message_length": len(message)
132 | })
133 |
134 | # Assistant response (LangChain chain)
135 | response = self.assistant(state)
136 | ai_messages = response["messages"]
137 | state["messages"].extend(ai_messages)
138 |
139 | # Grab the final user-facing output
140 | ai_text = ai_messages[-1].content if ai_messages else ""
141 |
142 |             # Hallucination score (ScoreResult: log the numeric value, not the object)
143 |             hallucination_score = self.llm_evaluator.check_hallucination(message, ai_text)
144 |             self.experiment_tracker.experiment.log_metric("hallucination_score", hallucination_score.value)
145 | 
146 |             # Moderation score (ScoreResult: log the numeric value, not the object)
147 |             moderation_score = self.llm_evaluator.check_moderation(ai_text)
148 |             self.experiment_tracker.experiment.log_metric("moderation_score", moderation_score.value)
149 |
150 | # Evaluate references (Contains, Equals, LevenshteinRatio)
151 | metric_scores = self.llm_evaluator.evaluate(ai_text)
152 | for metric_name, score_result in metric_scores.items():
153 | self.experiment_tracker.experiment.log_metric(metric_name, score_result.value)
154 |
155 |             # Answer relevance
156 |             relevance_score = self.llm_evaluator.check_answer_relevance(
157 |                 input_text=message,
158 |                 output_text=ai_text,
159 |             )
160 |             self.experiment_tracker.experiment.log_metric("answer_relevance_score", relevance_score)
161 |             logger.info(f"Answer Relevance score: {relevance_score}")
162 | 
163 |             # GEval metric
164 |             g_eval_score = self.llm_evaluator.check_g_eval(
165 |                 output_text=ai_text
166 |             )
167 |             self.experiment_tracker.experiment.log_metric("g_eval_score", g_eval_score)
168 | 
169 |
170 | # Print final answer
171 | for msg in ai_messages:
172 | if isinstance(msg, AIMessage):
173 | self.experiment_tracker.experiment.log_metrics({"response_length": len(msg.content)})
174 | print(f"Assistant: {msg.content}")
175 |
176 | except Exception as e:
177 | self.experiment_tracker.experiment.log_metric("errors", 1)
178 | print(f"Error in process_message: {str(e)}")
179 |
180 | def run(self):
181 | try:
182 | print("Research Paper Assistant initialized. Type 'exit' to quit.")
183 | state = {
184 | "messages": [],
185 | "metrics": {},
186 | "conversation_history": []
187 | }
188 | while True:
189 | user_input = input("\nYour question: ").strip()
190 | if user_input.lower() in ['exit', 'quit', 'bye']:
191 | self.cleanup()
192 | break
193 | self.process_message(user_input, state)
194 | except KeyboardInterrupt:
195 | print("\n\nSession interrupted by user.")
196 | self.cleanup()
197 | except Exception as e:
198 | logger.error(f"Unexpected error: {str(e)}")
199 | print(f"\n\nAn error occurred: {str(e)}")
200 | self.cleanup()
201 |
202 | def cleanup(self):
203 | try:
204 | if hasattr(self, 'experiment_tracker'):
205 | session_end_time = time.strftime("%Y-%m-%d %H:%M:%S")
206 | self.experiment_tracker.experiment.log_parameter("session_end", session_end_time)
207 | final_metrics = {
208 | "total_messages": len(self.assistant.memory.chat_memory.messages),
209 | "total_user_messages": len(
210 | [m for m in self.assistant.memory.chat_memory.messages if isinstance(m, HumanMessage)]
211 | ),
212 | "total_ai_messages": len(
213 | [m for m in self.assistant.memory.chat_memory.messages if isinstance(m, AIMessage)]
214 | )
215 | }
216 | self.experiment_tracker.experiment.log_metrics(final_metrics)
217 | self.experiment_tracker.experiment.end()
218 | self.services["db_client"].close()
219 | print("\nSession ended. Thank you for using the Research Paper Assistant!")
220 | except Exception as e:
221 | logger.error(f"Error during cleanup: {str(e)}")
222 |
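223 | 
224 | 
225 | if __name__ == "__main__":
226 |     # Entry-point sketch: assumes the Neo4j, OpenAI, and Comet ML credentials
227 |     # consumed by Settings are available in the environment (or a .env file).
228 |     Coordinator().run()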
--------------------------------------------------------------------------------