├── src
│   └── langgraph_engineer
│       ├── __init__.py
│       ├── verify_branch.py
│       ├── configuration.py
│       ├── check.py
│       ├── model.py
│       ├── loader.py
│       ├── summarize.py
│       ├── summarize_node.py
│       ├── test_run.py
│       ├── post_critique_router.py
│       ├── setup_node.py
│       ├── gather_requirements.py
│       ├── route_message.py
│       ├── git_push_node.py
│       ├── critique.py
│       ├── diff_node.py
│       ├── state.py
│       ├── aider_node.py
│       ├── agent.py
│       ├── interactive_aider.py
│       └── tools.py
├── pyproject.toml
├── langgraph.json
├── Dockerfile
├── lmsystems.py
├── TODO.md
└── README.md

--------------------------------------------------------------------------------
/src/langgraph_engineer/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "langgraph_eng_package"
3 | version = "0.0.1"
4 | dependencies = [
5 |     "langgraph",
6 |     "langchain_anthropic",
7 |     "langchain_core",
8 |     "langchain_openai",
9 |     "gitpython",
10 |     "python-dotenv",
11 |     "pydantic",
12 |     "aider-chat[browser]",
13 |     "semantic-router",
14 |     "libtmux"
15 | ]
16 | 
17 | [build-system]
18 | requires = ["setuptools >= 61.0"]
19 | build-backend = "setuptools.build_meta"

--------------------------------------------------------------------------------
/langgraph.json:
--------------------------------------------------------------------------------
1 | {
2 |     "dockerfile_lines": [
3 |         "RUN apt-get update && apt-get install -y git bsdutils",
4 |         "RUN mkdir -p /repos",
5 |         "RUN pip install aider-chat[browser]",
6 |         "ENV TERM=xterm-256color",
7 |         "ENV COLUMNS=80",
8 |         "ENV LINES=24",
9 |         "RUN chmod 777 /repos",
10 |         "WORKDIR /repos"
11 |     ],
12 |     "graphs": {
13 |         "engineer": "./src/langgraph_engineer/agent.py:graph"
14 |     },
15 |     "env": ".env",
16 |     "python_version": "3.11",
17 |     "dependencies": [
18 |         "."
19 |     ]
20 | }

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM langchain/langgraph-api:3.11
2 | 
3 | # Install git and other necessary tools including script utility
4 | RUN apt-get update && apt-get install -y git bsdutils
5 | 
6 | RUN mkdir -p /repos
7 | 
8 | ADD . 
/deps/LM-Systems 9 | 10 | RUN PYTHONDONTWRITEBYTECODE=1 pip install --no-cache-dir -c /api/constraints.txt -e /deps/* 11 | 12 | ENV LANGSERVE_GRAPHS='{"engineer": "/deps/LM-Systems/src/langgraph_engineer/agent.py:graph"}' 13 | 14 | # Set terminal environment variables 15 | ENV TERM=xterm-256color 16 | ENV COLUMNS=80 17 | ENV LINES=24 18 | 19 | WORKDIR /repos 20 | 21 | RUN chmod 777 /repos -------------------------------------------------------------------------------- /src/langgraph_engineer/verify_branch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from git import Repo 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | def verify_branch(repo_path: str, expected_branch: str) -> bool: 7 | """Verify we're on the correct branch before operations.""" 8 | try: 9 | repo = Repo(repo_path) 10 | current_branch = repo.active_branch.name 11 | if current_branch != expected_branch: 12 | logger.error(f"Branch mismatch: expected {expected_branch}, got {current_branch}") 13 | return False 14 | return True 15 | except Exception as e: 16 | logger.error(f"Branch verification failed: {str(e)}") 17 | return False 18 | -------------------------------------------------------------------------------- /src/langgraph_engineer/configuration.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, fields 2 | from typing import Any, Optional 3 | from langchain_core.runnables import RunnableConfig 4 | import os 5 | 6 | @dataclass(kw_only=True) 7 | class Configuration: 8 | """The configurable fields for the LangGraph Engineer.""" 9 | openai_api_key: Optional[str] = None 10 | anthropic_api_key: Optional[str] = None 11 | 12 | @classmethod 13 | def from_runnable_config( 14 | cls, config: Optional[RunnableConfig] = None 15 | ) -> "Configuration": 16 | """Create a Configuration instance from a RunnableConfig.""" 17 | configurable = ( 18 | config["configurable"] if config and "configurable" in config else {} 19 | ) 20 | 21 | # First check config, then environment variables 22 | values: dict[str, Any] = { 23 | f.name: configurable.get(f.name, os.environ.get(f.name.upper())) 24 | for f in fields(cls) 25 | if f.init 26 | } 27 | 28 | return cls(**{k: v for k, v in values.items() if v is not None}) 29 | -------------------------------------------------------------------------------- /src/langgraph_engineer/check.py: -------------------------------------------------------------------------------- 1 | import re 2 | from langgraph_engineer.state import AgentState 3 | 4 | 5 | def extract_python_code(text): 6 | pattern = r'```python\s*(.*?)\s*(```|$)' 7 | matches = re.findall(pattern, text, re.DOTALL) 8 | return matches 9 | 10 | 11 | error_parsing = """Make sure your response contains a code block in the following format: 12 | 13 | ```python 14 | ... 
15 | ``` 16 | 17 | When trying to parse out that code block, got this error: {error}""" 18 | 19 | 20 | def check(state: AgentState): 21 | last_answer = state['messages'][-1] 22 | try: 23 | code_blocks = extract_python_code(last_answer.content) 24 | except Exception as e: 25 | return {"messages": [{"role": "user", "content": error_parsing.format(error=str(e))}]} 26 | if len(code_blocks) == 0: 27 | return {"messages": [{"role": "user", "content": error_parsing.format(error="Did not find a code block!")}]} 28 | if len(code_blocks) > 1: 29 | return {"messages": [{"role": "user", "content": error_parsing.format(error="Found multiple code blocks!")}]} 30 | return {"code": f"```python\n{code_blocks[0][0]}\n```"} 31 | -------------------------------------------------------------------------------- /src/langgraph_engineer/model.py: -------------------------------------------------------------------------------- 1 | from langchain_openai import ChatOpenAI 2 | from langchain_anthropic import ChatAnthropic 3 | from .configuration import Configuration 4 | from langchain_core.runnables import RunnableConfig 5 | from typing import Optional 6 | 7 | 8 | def _get_model(config: Optional[RunnableConfig], default: str, key: str): 9 | # Get configuration instance 10 | config_instance = Configuration.from_runnable_config(config) 11 | 12 | # Get model type from config, defaulting to 'anthropic' 13 | model_type = config['configurable'].get('model', 'anthropic') if config and 'configurable' in config else 'anthropic' 14 | 15 | if model_type == "openai": 16 | return ChatOpenAI( 17 | temperature=0, 18 | model_name="gpt-4o", 19 | api_key=config_instance.openai_api_key 20 | ) 21 | elif model_type == "anthropic": 22 | return ChatAnthropic( 23 | temperature=0, 24 | model_name="claude-3-haiku-20240307", 25 | api_key=config_instance.anthropic_api_key 26 | ) 27 | else: 28 | raise ValueError(f"Unknown model type: {model_type}") 29 | -------------------------------------------------------------------------------- /lmsystems.py: -------------------------------------------------------------------------------- 1 | from lmsystems.client import LmsystemsClient 2 | from dotenv import load_dotenv 3 | import os 4 | import asyncio 5 | 6 | # Load environment variables 7 | load_dotenv() 8 | 9 | # Async usage 10 | async def main(): 11 | # Simple initialization with just graph name and API key 12 | client = await LmsystemsClient.create( 13 | graph_name="github-agent-48", 14 | api_key=os.environ["LMSYSTEMS_API_KEY"] 15 | ) 16 | 17 | # Create thread and run with error handling 18 | try: 19 | thread = await client.create_thread() 20 | 21 | run = await client.create_run( 22 | thread, 23 | input={"messages": [{"role": "user", "content": "What's this repo about?"}], 24 | "repo_url": "", 25 | "repo_path": "", 26 | "branch_name": "", 27 | "github_token": "", 28 | "accepted": False, 29 | "model_name": "sonnet" 30 | } 31 | ) 32 | 33 | # Stream response 34 | async for chunk in client.stream_run(thread, run): 35 | print(chunk) 36 | 37 | except Exception as e: 38 | print(f"Error: {str(e)}") 39 | 40 | if __name__ == "__main__": 41 | asyncio.run(main()) 42 | -------------------------------------------------------------------------------- /src/langgraph_engineer/loader.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from functools import lru_cache 3 | import time 4 | import requests 5 | 6 | 7 | CACHE_DURATION = 24 * 60 * 60 8 | 9 | 10 | def time_based_cache(seconds): 11 | def 
wrapper_cache(func): 12 | func = lru_cache(maxsize=None)(func) 13 | func.lifetime = seconds 14 | func.expiration = time.time() + func.lifetime 15 | 16 | @functools.wraps(func) 17 | def wrapped_func(*args, **kwargs): 18 | if time.time() >= func.expiration: 19 | func.cache_clear() 20 | func.expiration = time.time() + func.lifetime 21 | return func(*args, **kwargs) 22 | 23 | return wrapped_func 24 | 25 | return wrapper_cache 26 | 27 | 28 | @time_based_cache(CACHE_DURATION) 29 | def load_github_file(url): 30 | # Convert GitHub URL to raw content URL 31 | raw_url = url.replace("github.com", "raw.githubusercontent.com").replace("/blob/", "/") 32 | 33 | # Send a GET request to the raw URL 34 | response = requests.get(raw_url) 35 | 36 | # Check if the request was successful 37 | if response.status_code == 200: 38 | return response.text 39 | else: 40 | return f"Failed to load file. Status code: {response.status_code}"

--------------------------------------------------------------------------------
/TODO.md:
--------------------------------------------------------------------------------
1 | # Bugs that need to be fixed:
2 | 
3 | 1. The aider node's output is streamed to the terminal, but not to the LangGraph API endpoint in real time. Instead, the entire response is sent only when the aider node has finished.
4 | 
5 | **The goal:** the aider node's output should be streamed to the LangGraph API endpoint in real time.
6 | **Replicate the bug:** Run this graph [locally](https://langchain-ai.github.io/langgraph/tutorials/langgraph-platform/local-server/#create-a-env-file), then take the local graph's API endpoint (```http://127.0.0.1:2024```) and enter it in the ```connecting.ipynb``` file. Then run the LangGraph client SDK from that notebook. You'll see the output streamed in the terminal, but not at the LangGraph API endpoint.
7 | 
8 | This aider node is my attempt at replicating this [Streamlit app from Aider](https://github.com/Aider-AI/aider/blob/main/aider/gui.py), so it might be beneficial to look at that code.
9 | 
10 | 
11 | # Features I'd like to see:
12 | 
13 | #### More Agentic Approaches
14 | I'd like to see more agentic approaches where we collect more information from the repo by first asking a duplicate aider node a question about it, then re-running the original prompt with the new information.
15 | 
16 | To take it a step further, we could run several of these 'ask' nodes in [parallel](https://langchain-ai.github.io/langgraph/how-tos/branching/), where they query the same repo or other repos (likely open-source ones), or use a web search tool (Exa, Perplexity, Google Deep Research, etc.) to research a particular package.
17 | 
18 | 
19 | #### Integrations with other tools to close the feedback loop
20 | 
21 | Allowing a user to connect to Vercel in some way would be hugely beneficial for users building Next.js sites. Integrations like this would be great!
22 | 
23 | 
24 | 

--------------------------------------------------------------------------------
/src/langgraph_engineer/summarize.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any 2 | from langgraph_engineer.model import _get_model 3 | from langgraph_engineer.state import AgentState 4 | from langchain_core.messages import SystemMessage, AIMessage 5 | 6 | summarize_prompt = """You are a helpful AI assistant that summarizes technical information clearly and concisely. 
7 | 8 | Review the provided context and generate a clear, helpful response to the user's original question. 9 | 10 | Focus on: 11 | - Directly answering the user's question 12 | - Being concise but thorough 13 | - Using clear, simple language 14 | - Including relevant technical details when necessary 15 | 16 | IMPORTANT: DO NOT MAKE ANYTHING UP. ONLY USE THE INFORMATION PROVIDED IN THE CONTEXT. 17 | 18 | Context from previous analysis: 19 | {context} 20 | 21 | Original question: 22 | {requirements} 23 | 24 | Provide your summarized response from the context above:""" 25 | 26 | def summarize_response(state: AgentState, config: Dict[str, Any]) -> AgentState: 27 | """Summarize the aider output and generate a clear response""" 28 | # Get the model 29 | model = _get_model(config, "openai", "summarize_model") 30 | 31 | # Get requirements from state following router_agent.py pattern 32 | requirements = state.get('requirements', '') 33 | if not requirements and state.get('messages'): 34 | # If no requirements but we have messages, use the last human message 35 | for msg in reversed(state['messages']): 36 | if msg.type == 'human': 37 | requirements = msg.content 38 | break 39 | 40 | # Get the aider output from the last step result 41 | aider_output = "" 42 | if state.get("step_results"): 43 | last_result = list(state["step_results"].values())[-1] 44 | aider_output = last_result.get("output", "") 45 | 46 | # Format the prompt 47 | formatted_prompt = summarize_prompt.format( 48 | context=aider_output, 49 | requirements=requirements # Use requirements here instead of original_question 50 | ) 51 | 52 | # Create message sequence 53 | messages = [ 54 | SystemMessage(content=formatted_prompt) 55 | ] 56 | 57 | # Get response 58 | response = model.invoke(messages) 59 | 60 | # Update state with summary 61 | state["summary"] = response.content 62 | 63 | return state -------------------------------------------------------------------------------- /src/langgraph_engineer/summarize_node.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from langgraph_engineer.model import _get_model 3 | from langgraph_engineer.state import AgentState 4 | from langchain_core.messages import SystemMessage, AIMessage 5 | import logging 6 | import os 7 | logger = logging.getLogger(__name__) 8 | 9 | summarize_prompt = """Please Communicate the latest message from the coding AI agent (as if you were the coding AI Agent) to the user given their prompt. Do not make anything up that was not presented in the coding AI Agent response. 
10 | 11 | User's message: 12 | {user_message} 13 | 14 | Coding AI Agent response: 15 | {last_response} 16 | 17 | Provide a concise summary to the user on behalf of the coding AI Agent:""" 18 | 19 | def summarize_response(state: AgentState) -> AgentState: 20 | """Summarize the aider output and store it in state""" 21 | try: 22 | logger.debug("=== Summarize Node Starting ===") 23 | 24 | # Get the last AI message 25 | messages = state.get('messages', []) 26 | if not messages: 27 | logger.debug("No messages found in state") 28 | return state 29 | 30 | last_ai_message = None 31 | for msg in reversed(messages): 32 | if isinstance(msg, AIMessage): 33 | last_ai_message = msg 34 | break 35 | 36 | if not last_ai_message: 37 | logger.debug("No AI message found to summarize") 38 | return state 39 | 40 | # Create proper config structure 41 | config = { 42 | "configurable": { 43 | "model": "openai", # Specify model type 44 | "openai_api_key": os.environ.get("OPENAI_API_KEY"), 45 | "model_name": state.get("model_name", "4o") 46 | } 47 | } 48 | 49 | # Get the model with proper config 50 | model = _get_model(config, "anthropic", "summarize_model") 51 | 52 | # Get the last user message 53 | last_user_message = "No user message found" 54 | for msg in reversed(messages): 55 | if msg.type == "human": 56 | last_user_message = msg.content 57 | break 58 | 59 | # Format the prompt 60 | formatted_prompt = summarize_prompt.format( 61 | user_message=last_user_message, 62 | last_response=last_ai_message.content 63 | ) 64 | 65 | # Create message sequence 66 | messages = [SystemMessage(content=formatted_prompt)] 67 | 68 | # Get response 69 | response = model.invoke(messages) 70 | 71 | # Store summary in state 72 | state["aider_summary"] = response.content 73 | 74 | logger.debug("=== Summarize Node Completed ===") 75 | return state 76 | 77 | except Exception as e: 78 | logger.error(f"Error in summarize_node: {str(e)}") 79 | logger.error("Full traceback:", exc_info=True) 80 | raise -------------------------------------------------------------------------------- /src/langgraph_engineer/test_run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import asyncio 4 | import sys 5 | from pathlib import Path 6 | from dotenv import load_dotenv 7 | from langchain_core.messages import HumanMessage 8 | 9 | # Add the project root directory to Python path 10 | project_root = str(Path(__file__).parent.parent.parent) 11 | if project_root not in sys.path: 12 | sys.path.insert(0, project_root) 13 | 14 | from langgraph_engineer.agent import Engineer 15 | from langgraph_engineer.model import _get_model 16 | 17 | # Set up logging with more detailed configuration 18 | logging.basicConfig( 19 | level=logging.INFO, 20 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 21 | handlers=[ 22 | logging.StreamHandler() 23 | ] 24 | ) 25 | logger = logging.getLogger(__name__) 26 | 27 | async def main(): 28 | # Load environment variables 29 | load_dotenv() 30 | 31 | # Check for required environment variables 32 | required_env_vars = ['GITHUB_TOKEN', 'ANTHROPIC_API_KEY'] 33 | missing_vars = [var for var in required_env_vars if not os.getenv(var)] 34 | if missing_vars: 35 | logger.error(f"Missing required environment variables: {', '.join(missing_vars)}") 36 | return 37 | 38 | logger.info(f"Python path: {sys.path}") 39 | logger.info("Starting the application...") 40 | 41 | # Initialize the Engineer class 42 | engineer = Engineer() 43 | 44 | # Test the model configuration 
45 | config = { 46 | "configurable": { 47 | "gather_model": "anthropic", 48 | "draft_model": "anthropic", 49 | "critique_model": "anthropic" 50 | } 51 | } 52 | 53 | try: 54 | # Test model initialization 55 | model = _get_model(config, "anthropic", "draft_model") 56 | logger.info("Successfully initialized Anthropic model") 57 | except Exception as e: 58 | logger.error(f"Failed to initialize model: {str(e)}") 59 | return 60 | 61 | # Create test input state 62 | input_state = { 63 | "query": "enhance the readme file given the content that's already there", 64 | "repo_url": "https://github.com/RVCA212/LM-Systems", 65 | "github_token": os.getenv("GITHUB_TOKEN"), 66 | "repo_path": os.path.join(os.path.expanduser("~"), "clone"), 67 | "configurable": config["configurable"] 68 | } 69 | 70 | # Log environment check 71 | logger.info(f"GitHub Token present: {bool(input_state['github_token'])}") 72 | logger.info(f"Repository path: {input_state['repo_path']}") 73 | 74 | # Process the request 75 | try: 76 | logger.info("Processing request...") 77 | result = await engineer.process_request(input_state) 78 | logger.info(f"Process completed successfully: {result}") 79 | except Exception as e: 80 | logger.error(f"An error occurred: {str(e)}") 81 | raise 82 | 83 | if __name__ == "__main__": 84 | # Run the async main function 85 | asyncio.run(main()) -------------------------------------------------------------------------------- /src/langgraph_engineer/post_critique_router.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from langgraph_engineer.model import _get_model 3 | from langgraph_engineer.state import AgentState 4 | from langchain_core.messages import SystemMessage, AIMessage 5 | from langchain_core.pydantic_v1 import BaseModel 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | class PostCritiqueResponse(BaseModel): 11 | """Schema for post-critique router response""" 12 | reasoning: str 13 | decision: str # Must be either 'Y' or 'N' 14 | 15 | post_critique_prompt = """You are a specialized validation agent that analyzes critique results to determine if code changes are complete and acceptable. 16 | 17 | RULES: 18 | 1. Analyze the critique results carefully 19 | 2. Output 'Y' if: 20 | - All requested changes were successfully implemented 21 | - Requirements were fully met 22 | - Implementation is correct and of good quality 23 | - Changes stayed within the required scope 24 | - No unauthorized or unexpected changes were made 25 | 26 | 3. Output 'N' if: 27 | - Changes are missing or incomplete 28 | - Requirements were not fully met 29 | - Implementation has quality issues 30 | - Changes exceeded scope or contain unauthorized modifications 31 | - Additional work or revisions are needed 32 | 33 | EXAMPLES: 34 | 35 | Critique that should return Y: 36 | "The changes were successfully implemented. All requirements were met, and the code quality is good. The implementation follows the specifications exactly." 37 | 38 | Critique that should return N: 39 | "Some required changes are missing. The implementation is incomplete and needs additional work. There are quality issues that need to be addressed." 
40 | 41 | Analyze the following critique results and determine if the changes are complete and acceptable: 42 | {critique_logic}""" 43 | 44 | def post_critique_route(state: AgentState, config: Dict[str, Any]) -> AgentState: 45 | """Determine if the critique results warrant proceeding or require more work""" 46 | try: 47 | # Get the model with structured output 48 | model = _get_model(config, "openai", "post_critique_model").with_structured_output(PostCritiqueResponse) 49 | 50 | # Get critique logic from state 51 | critique_results = state.get('step_results', {}).get('critique', {}) 52 | critique_logic = critique_results.get('args', {}).get('logic', '') 53 | 54 | # Format the prompt 55 | formatted_prompt = post_critique_prompt.format(critique_logic=critique_logic) 56 | 57 | # Create message sequence 58 | messages = [ 59 | SystemMessage(content=formatted_prompt), 60 | AIMessage(content="Please analyze the critique results and determine if changes are complete.") 61 | ] 62 | 63 | # Get response 64 | response = model.invoke(messages) 65 | 66 | # Update state with the routing decision 67 | state['step_results']['post_critique_router'] = { 68 | 'args': { 69 | 'decision': response.decision, 70 | 'reasoning': response.reasoning 71 | } 72 | } 73 | 74 | return state 75 | 76 | except Exception as e: 77 | logger.error(f"Error in post_critique_route: {str(e)}") 78 | # Default to N on error for safety 79 | state['step_results']['post_critique_router'] = { 80 | 'args': { 81 | 'decision': 'N', 82 | 'reasoning': f'Error occurred: {str(e)}' 83 | } 84 | } 85 | return state -------------------------------------------------------------------------------- /src/langgraph_engineer/setup_node.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Dict 2 | from langchain_core.messages import AIMessage 3 | from langgraph.graph import END 4 | import os 5 | import time 6 | import logging 7 | 8 | from langgraph_engineer.state import AgentState, initialize_state 9 | from langgraph_engineer.tools import setup_node, force_clone, force_branch, ForceCloneInput, ForceBranchInput 10 | from git import Repo 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | def validate_branch_creation(state: dict) -> bool: 15 | """Validate that we're on the correct branch.""" 16 | try: 17 | repo_path = state.get('repo_path') 18 | repo = Repo(repo_path) 19 | current_branch = repo.active_branch.name 20 | expected_branch = state.get('branch_name') 21 | 22 | if current_branch != expected_branch: 23 | logger.error(f"Branch mismatch: expected {expected_branch}, got {current_branch}") 24 | return False 25 | 26 | return True 27 | except Exception as e: 28 | logger.error(f"Branch validation failed: {str(e)}") 29 | return False 30 | 31 | async def setup_repository(state: AgentState) -> AgentState: 32 | """Handle repository setup including cloning and branch creation.""" 33 | try: 34 | # Validate required values 35 | required_keys = ['repo_url', 'github_token', 'repo_path'] 36 | missing_values = [key for key in required_keys if not state.get(key)] 37 | if missing_values: 38 | raise ValueError(f"Missing required state values: {', '.join(missing_values)}") 39 | 40 | # Log state for debugging 41 | logger.debug(f"Setup node received state keys: {list(state.keys())}") 42 | logger.debug(f"Setup node repo_path: {state.get('repo_path')}") 43 | 44 | # Clone repository 45 | clone_result = await force_clone.coroutine( 46 | ForceCloneInput( 47 | url=state['repo_url'], 48 | path=state['repo_path'], 
49 | state=state, 50 | config=None 51 | ) 52 | ) 53 | 54 | if not isinstance(clone_result, dict) or clone_result.get("status") != "success": 55 | raise ValueError(f"Failed to clone repository: {clone_result}") 56 | 57 | # Ensure repo_path is preserved in state 58 | state['repo_path'] = clone_result['repo_path'] 59 | 60 | # Create and checkout new branch 61 | branch_result = await force_branch.coroutine( 62 | ForceBranchInput( 63 | branch_name=state['branch_name'], 64 | state=state, 65 | config={"callbacks": None} 66 | ) 67 | ) 68 | 69 | if not isinstance(branch_result, dict) or branch_result.get("status") != "success": 70 | raise ValueError(f"Failed to create branch: {branch_result}") 71 | 72 | return state 73 | 74 | except Exception as e: 75 | logger.error(f"Error in setup_repository: {str(e)}", exc_info=True) 76 | raise 77 | # async def route_setup(state: AgentState) -> Literal["setup_node", "router_agent"]: 78 | # """Route to setup node if repository is not initialized, otherwise to router_agent.""" 79 | # if not state.get('repo_path') or not os.path.exists(state.get('repo_path', '')): 80 | # return "setup_node" 81 | # return "router_agent" 82 | 83 | async def validate_setup(state: AgentState) -> bool: 84 | """Validate that repository setup was successful.""" 85 | repo_path = state.get('repo_path', '') 86 | branch_name = state.get('branch_name', '') 87 | 88 | is_valid = ( 89 | os.path.exists(repo_path) and 90 | os.path.isdir(os.path.join(repo_path, '.git')) and 91 | branch_name != '' 92 | ) 93 | 94 | if not is_valid: 95 | logger.error(f"Setup validation failed: repo_path={repo_path}, branch_name={branch_name}") 96 | 97 | return is_valid
98 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GitHub Agent
2 | 
3 | GitHub Agent clones a given GitHub repository, then modifies it with a human in the loop before pushing any code.
4 | 
5 | Deployed API: [LMSystems.ai](https://www.lmsystems.ai/graphs/github-agent-48/test)
6 | 
7 | iOS app: [Ship App](https://apps.apple.com/us/app/ship/id6738367546)
8 | 
9 | # Quick Start
10 | 
11 | ```pip install -e .```
12 | 
13 | ```pip install --upgrade "langgraph-cli[inmem]"```
14 | 
15 | ```langgraph dev```
16 | 
17 | Then use LangGraph Studio or connecting.ipynb to interact with the graph.
18 | 
19 | Use LangGraph's [quickstart guide](https://langchain-ai.github.io/langgraph/tutorials/langgraph-platform/local-server/) for more in-depth quickstart instructions.
20 | 
21 | 
22 | ### Why we made this project:
23 | 
24 | Coding seems to be the killer app for LLMs, but we see a much brighter future in "cloud coding" than in coding with traditional IDEs. Cloud coding refers to having your coding environment run on a server rather than directly on your computer. We hope to build an AI application that you can hand a "task" to and have it go off and complete that task, no matter how long it takes. This only becomes possible in cloud coding environments, for these reasons:
25 | - You can create a scalable environment that closes the feedback loop for coding (e.g., you can run the code, or use computer use to give feedback on the locally running app).
26 | - Cloud environments allow for collaboration across many different AI apps and APIs.
27 | - Lastly, a big bonus is that it's always on and can scale, meaning you could in theory run 100 coding agents in parallel, which opens a world of possibilities. 
28 | 
29 | This project is the first step towards that world.
30 | 
31 | ## Project Overview
32 | 
33 | Our graph currently works like this:
34 | 
35 | User Query => Setup Node => Aider Node <==> Human Interaction => Git Push Changes
36 | 
37 | Here's each node with its purpose and corresponding files.
38 | 
39 | [Main Graph File](src/langgraph_engineer/agent.py)
40 | 
41 | - **[Setup Node](src/langgraph_engineer/setup_node.py)**: clones the given repo using the repo URL, GitHub access token, and selected branch name.
42 | - **[Aider Node](src/langgraph_engineer/aider_node.py)**: uses [Aider](https://aider.chat/) to do the heavy lifting of making code changes to the repo. Aider is a CLI tool operating on the cloned repo. We've tried to emulate what they've done with [Aider in the browser](https://aider.chat/docs/usage/browser.html) in order to capture the LLM's stream of tokens, but we have yet to capture it. See [TODO.md](TODO.md) for more issues.
43 | [secondary file](src/langgraph_engineer/interactive_aider.py)
44 | - **Human Interaction**: a human-in-the-loop step that lets the user and the aider node have a back-and-forth conversation for as many turns as they'd like before the human decides to push the local changes to GitHub. To push the changes, set the `accepted` state value to `True` (see the SDK sketch below). *Located in the [main file](src/langgraph_engineer/agent.py).*
45 | - **[Git Push Node](src/langgraph_engineer/git_push_node.py)**: pushes the local changes to the selected branch.
46 | 
47 | 
48 | ## Contributing
49 | 
50 | We welcome contributions from the community! Here's how you can help:
51 | 
52 | 1. **Report Issues**: Submit bugs and feature requests through our issue tracker
53 | 2. **Submit Pull Requests**: Improve documentation, fix bugs, or add new features
54 | 3. **Follow Standards**:
55 |    - Write clear commit messages
56 |    - Follow PEP 8 style guide for Python code
57 |    - Include tests for new features
58 |    - Update documentation as needed
59 | 4. **Accomplish TODOs**: refer to the [TODO.md](TODO.md) file for a list of features that need to be implemented.
60 | 
61 | 
62 | ## License
63 | 
64 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 
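
## Example: Driving the Graph from the LangGraph SDK

For reference, here is a minimal sketch of the accept/push flow described above, driven from the LangGraph client SDK against a local `langgraph dev` server. It assumes the default local endpoint (`http://127.0.0.1:2024`, as noted in TODO.md), the `engineer` graph name from `langgraph.json`, and a hypothetical repository URL; the input fields mirror `lmsystems.py` and the `InputState` schema in `state.py`, and the exact resume step for the human-in-the-loop interrupt may differ depending on how the graph is wired.

```python
import asyncio
import os

from langgraph_sdk import get_client


async def main():
    # Assumes `langgraph dev` is running locally on the default port.
    client = get_client(url="http://127.0.0.1:2024")
    thread = await client.threads.create()

    # Input fields mirror InputState in src/langgraph_engineer/state.py.
    input_state = {
        "messages": [{"role": "user", "content": "Enhance the README"}],
        "repo_url": "https://github.com/your-org/your-repo",  # hypothetical repo
        "github_token": os.environ["GITHUB_TOKEN"],
        "accepted": False,  # keep False until you approve the changes
        "model_name": "sonnet",
    }

    # Stream the first run; note that the aider output currently arrives in
    # one batch rather than token by token (see TODO.md).
    async for chunk in client.runs.stream(
        thread["thread_id"], "engineer", input=input_state, stream_mode="messages"
    ):
        print(chunk.event, chunk.data)

    # Once you are happy with the proposed changes, flip the accepted flag so
    # the graph proceeds to the git push node, then resume the run.
    await client.threads.update_state(thread["thread_id"], {"accepted": True})
    async for chunk in client.runs.stream(
        thread["thread_id"], "engineer", input=None, stream_mode="messages"
    ):
        print(chunk.event, chunk.data)


asyncio.run(main())
```

The two-phase pattern here matches the Human Interaction step above: the first run stops at the human-in-the-loop checkpoint, and the second run resumes after `accepted` has been set to `True`.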
65 | 
66 | ## Support
67 | 
68 | Contact me with any questions at sean@lmsystems.ai
69 | 
70 | ---
71 | Made by LM Systems - Building the future of Shareable Graphs
72 | 

--------------------------------------------------------------------------------
/src/langgraph_engineer/gather_requirements.py:
--------------------------------------------------------------------------------
1 | from langgraph_engineer.model import _get_model 2 | from langgraph_engineer.state import AgentState 3 | from typing import TypedDict 4 | from langchain_core.messages import RemoveMessage, HumanMessage 5 | import logging 6 | import os 7 | from pathlib import Path 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | def get_directory_tree(start_path: str, indent: str = "", prefix: str = "") -> str: 12 | """Generate a directory tree string starting from the given path.""" 13 | if not os.path.exists(start_path): 14 | logger.error(f"Repository path does not exist: {start_path}") 15 | return "Error: Path does not exist" 16 | 17 | try: 18 | # Convert to Path object for better path handling 19 | start_path = Path(start_path).resolve() 20 | base_name = start_path.name 21 | tree = [f"{prefix}{base_name}/"] 22 | 23 | # List and sort directory contents 24 | items = sorted(item for item in start_path.iterdir() 25 | if item.name not in {'.git', '__pycache__', '.DS_Store'}) 26 | 27 | for i, item in enumerate(items): 28 | is_last = i == len(items) - 1 29 | next_indent = indent + (" " if is_last else "│ ") 30 | next_prefix = indent + ("└── " if is_last else "├── ") 31 | 32 | if item.is_dir(): 33 | # Recursively process directories 34 | subtree = get_directory_tree(str(item), next_indent, next_prefix) 35 | tree.append(subtree) 36 | else: 37 | # Add files with relative path 38 | rel_path = item.relative_to(start_path) 39 | tree.append(f"{next_prefix}{item.name} ({rel_path})") 40 | 41 | return "\n".join(tree) 42 | except Exception as e: 43 | logger.error(f"Error generating directory tree: {str(e)}") 44 | return f"Error reading directory: {str(e)}" 45 | 46 | gather_prompt = """You are the requirements gathering component of an AI software developer system. \ 47 | Your role is to either clarify user requests or pass them directly to implementation when clear enough. 48 | 49 | Current Repository Structure: 50 | {directory_structure} 51 | 52 | Your task is to: 53 | 1. Quickly assess if the user's request is clear enough to implement 54 | 2. If clear: Call the `Build` tool with the requirements 55 | 3. If unclear: Ask a clarifying follow-up question 56 | 57 | Please err on the side of passing the task off to the ReAct agent, which can read files and execute code. 58 | 59 | Only ask for clarification when: 60 | - The request is fundamentally ambiguous 61 | - Critical technical details are missing 62 | - Security implications need clarification 63 | 64 | Most requests should pass directly to implementation. 
When in doubt, proceed rather than ask.""" 65 | 66 | 67 | class Build(TypedDict): 68 | requirements: str 69 | 70 | 71 | def gather_requirements(state: AgentState, config): 72 | # Get directory tree 73 | repo_path = state.get('repo_path') 74 | directory_structure = get_directory_tree(repo_path) if repo_path else "Repository not yet cloned" 75 | 76 | messages = [ 77 | {"role": "system", "content": gather_prompt.format(directory_structure=directory_structure)} 78 | ] + state['messages'] 79 | model = _get_model(config, "openai", "gather_model").bind_tools([Build]) 80 | response = model.invoke(messages) 81 | if len(response.tool_calls) == 0: 82 | return {"messages": [response]} 83 | else: 84 | requirements = response.tool_calls[0]['args']['requirements'] 85 | delete_messages = [RemoveMessage(id=m.id) for m in state['messages']] 86 | requirements_message = HumanMessage(content=f"Here are the gathered requirements:\n\n{requirements}\n\nPlease proceed with implementing these requirements.") 87 | return { 88 | "requirements": requirements, 89 | "messages": delete_messages + [requirements_message] 90 | } 91 | -------------------------------------------------------------------------------- /src/langgraph_engineer/route_message.py: -------------------------------------------------------------------------------- 1 | from semantic_router import Route, RouteLayer 2 | from semantic_router.encoders import OpenAIEncoder 3 | from typing import Dict, Any 4 | from langgraph_engineer.state import AgentState, PlanStep 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | # Define routes for different types of queries 10 | chat_route = Route( 11 | name="chat", 12 | utterances=[ 13 | "give me an overview of this repository", 14 | "describe the project structure", 15 | "list the main contributors", 16 | "what technologies are used here", 17 | "show me the project dependencies", 18 | "explain the system architecture", 19 | "what's the development status", 20 | "describe the main components", 21 | "what are the key features", 22 | "show documentation for this project", 23 | ], 24 | ) 25 | 26 | easy_route = Route( 27 | name="easy", 28 | utterances=[ 29 | "change variable name from X to Y", 30 | "add type hints to this function", 31 | "fix the indentation in this file", 32 | "add a docstring to explain this code", 33 | "remove this unused import statement", 34 | "add error handling here", 35 | "update this log message", 36 | "fix this spelling mistake", 37 | "add these missing parameters", 38 | "update this function's return type", 39 | "move this code block to a new function", 40 | ], 41 | ) 42 | 43 | hard_route = Route( 44 | name="hard", 45 | utterances=[ 46 | "create a new feature that...", 47 | "refactor this module to use...", 48 | "implement caching for...", 49 | "add authentication to...", 50 | "optimize the performance of...", 51 | "create unit tests for...", 52 | "integrate this new library...", 53 | "implement a new API endpoint...", 54 | "fix these security issues...", 55 | "redesign this component to...", 56 | "add support for async operations", 57 | ], 58 | ) 59 | 60 | # Initialize the route layer 61 | encoder = OpenAIEncoder() 62 | route_layer = RouteLayer(encoder=encoder, routes=[chat_route, easy_route, hard_route]) 63 | 64 | def route_message(state: AgentState) -> AgentState: 65 | """Routes the user's message to either chat, easy, or hard paths""" 66 | try: 67 | # Get requirements from state 68 | requirements = state.get('requirements', '') 69 | if not requirements and 
state.get('messages'): 70 | # If no requirements but we have messages, use the last human message 71 | for msg in reversed(state['messages']): 72 | if msg.type == 'human': 73 | requirements = msg.content 74 | break 75 | 76 | # Use semantic router to determine the route 77 | route_choice = route_layer(requirements) 78 | route_type = route_choice.name if route_choice else "hard" # Default to hard if no match 79 | 80 | # Update state with routing decision 81 | new_state = { 82 | **state, 83 | "router_analysis": { 84 | "route_type": route_type, 85 | "changes_req": route_type != "chat" 86 | } 87 | } 88 | 89 | # Prepare step information for both chat and easy routes 90 | if route_type in ["chat", "easy"]: 91 | new_state["steps"] = [ 92 | PlanStep( 93 | reasoning="Direct execution of request", 94 | step_id="S1", 95 | tool_name="aider_shell", 96 | tool_args={ 97 | "message": requirements, 98 | "files": "." 99 | } 100 | ) 101 | ] 102 | new_state["current_step"] = 0 103 | new_state["execution_status"] = "executing" 104 | 105 | return new_state 106 | 107 | except Exception as e: 108 | logger.error(f"Error in route_message: {str(e)}") 109 | # Default to hard on error for safety 110 | return { 111 | **state, 112 | "router_analysis": { 113 | "route_type": "hard", 114 | "changes_req": True 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /src/langgraph_engineer/git_push_node.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import subprocess 4 | from langgraph_engineer.state import AgentState 5 | from langgraph_engineer.tools import git_status, git_add, git_commit, git_push 6 | from langchain_core.messages import AIMessage 7 | from langgraph_engineer.verify_branch import verify_branch 8 | from git import Repo 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def git_push_changes(state: AgentState, config) -> dict: 14 | """Handle git operations and push changes after critique acceptance.""" 15 | try: 16 | # Enhanced validation 17 | if not isinstance(state, dict): 18 | raise ValueError(f"Invalid state type: {type(state)}") 19 | 20 | # Add github_token to required keys 21 | required_keys = ['repo_path', 'branch_name', 'repo_url', 'github_token'] 22 | missing_keys = [key for key in required_keys if not state.get(key)] 23 | if missing_keys: 24 | raise ValueError(f"Missing or empty required state keys: {missing_keys}") 25 | 26 | # Validate GitHub token 27 | github_token = state.get('github_token') 28 | if not github_token: 29 | raise ValueError("Missing GitHub Personal Access Token") 30 | 31 | # Validate paths and URLs 32 | repo_path = state['repo_path'] 33 | if not os.path.exists(repo_path): 34 | raise ValueError(f"Repository path does not exist: {repo_path}") 35 | 36 | repo_url = state['repo_url'] 37 | if not repo_url.startswith('https://github.com/'): 38 | raise ValueError(f"Invalid GitHub URL format: {repo_url}") 39 | 40 | # Verify branch before operations 41 | if not verify_branch(state['repo_path'], state['branch_name']): 42 | raise ValueError(f"Not on the expected branch: {state['branch_name']}") 43 | 44 | # Optional dry-run mode for safer operations 45 | dry_run = state.get('dry_run', False) 46 | logger.info(f"Git push operation {'(DRY RUN)' if dry_run else ''}") 47 | 48 | # Log current state with sensitive info masked 49 | masked_state = {k: ('*****' if k == 'github_token' else v) for k, v in state.items()} 50 | logger.info(f"Starting git operations with state: 
{masked_state}") 51 | 52 | # Check git status 53 | status_result = git_status.invoke({ 54 | "config": {"tool_choice": "git_status"}, 55 | "state": state 56 | }) 57 | logger.info(f"Git status: {status_result}") 58 | 59 | # Check if there are changes to commit 60 | if "nothing to commit" in status_result.lower(): 61 | logger.info("No changes to commit") 62 | return { 63 | "messages": [AIMessage(content="No changes to push - repository is up to date")], 64 | "accepted": True 65 | } 66 | 67 | # Stage all changes 68 | add_result = git_add.invoke({ 69 | "file_path": ".", 70 | "config": {"tool_choice": "git_add"}, 71 | "state": state 72 | }) 73 | logger.info(f"Git add result: {add_result}") 74 | 75 | # Create commit 76 | requirements = state.get('requirements', 'No requirements specified') 77 | commit_message = f"Implemented changes for:\n{requirements}" 78 | commit_result = git_commit.invoke({ 79 | "message": commit_message, 80 | "config": {"tool_choice": "git_commit"}, 81 | "state": state 82 | }) 83 | logger.info(f"Git commit result: {commit_result}") 84 | 85 | # Push changes using the tool 86 | if not dry_run: 87 | push_result = git_push.invoke({ 88 | "config": {"tool_choice": "git_push"}, 89 | "state": state 90 | }) 91 | logger.info(f"Git push result: {push_result}") 92 | 93 | if any(error_term in push_result.lower() for error_term in ["error", "failed", "fatal"]): 94 | raise ValueError(push_result) 95 | else: 96 | push_result = "Dry run - no changes pushed" 97 | logger.warning("Dry run mode: No changes were pushed to the repository") 98 | 99 | return { 100 | "messages": [AIMessage( 101 | content=f"{'Dry run: ' if dry_run else ''}Changes {'would be ' if dry_run else ''}pushed successfully to branch: {state['branch_name']}\n" 102 | f"Repository: {state['repo_url']}\n" 103 | f"Requirements implemented:\n{requirements}" 104 | )], 105 | "accepted": True 106 | } 107 | 108 | except Exception as e: 109 | logger.error(f"Error in git_push_changes: {str(e)}", exc_info=True) 110 | return { 111 | "messages": [AIMessage( 112 | content=f"Error during git operations: {str(e)}" 113 | )], 114 | "accepted": False 115 | } 116 | -------------------------------------------------------------------------------- /src/langgraph_engineer/critique.py: -------------------------------------------------------------------------------- 1 | from langgraph_engineer.model import _get_model 2 | from langgraph_engineer.state import AgentState 3 | from langchain_core.messages import AIMessage, SystemMessage 4 | from langchain_core.pydantic_v1 import BaseModel 5 | from git import Repo 6 | import logging 7 | import json 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | # Simplified critique prompt focused on natural language response 12 | critique_prompt = """Review if the Task was fully completed given the user's requirements and the steps/changes made. 13 | 14 | Original User Requirements: 15 | {requirements} 16 | 17 | Here was our agent's original plan: 18 | Plan: 19 | {plan} 20 | 21 | and here's the results from that plan: 22 | Step Results: 23 | {step_results} 24 | 25 | Please provide a detailed analysis that covers: 26 | 1. What specific changes were found and implemented 27 | 2. How well each requirement was addressed 28 | 3. The quality of the implementation 29 | 4. 
Whether changes stayed within the expected scope 30 | 31 | Use clear, standardized phrases like: 32 | - "changes were successfully implemented" or "changes are missing" 33 | - "requirements were fully met" or "requirements were not fully met" 34 | - "implementation was correct" or "implementation needs revision" 35 | - "changes stayed within scope" or "unauthorized changes detected" 36 | """ 37 | 38 | 39 | class Accept(BaseModel): 40 | """Schema for critique response""" 41 | logic: str 42 | accept: bool 43 | completion_status: str 44 | 45 | 46 | def _swap_messages(messages): 47 | new_messages = [] 48 | for m in messages: 49 | if isinstance(m, AIMessage): 50 | new_messages.append({"role": "user", "content": m.content}) 51 | else: 52 | new_messages.append({"role": "assistant", "content": m.content}) 53 | return new_messages 54 | 55 | 56 | def get_git_diff(state: AgentState) -> str: 57 | """Get git diff directly without using the tool interface""" 58 | try: 59 | repo = Repo(state['repo_path']) 60 | diff_output = [] 61 | 62 | # Check for staged changes 63 | staged_diff = repo.git.diff('--cached') 64 | if staged_diff: 65 | diff_output.append("=== Staged Changes ===") 66 | diff_output.append(staged_diff) 67 | 68 | # Check for unstaged changes 69 | unstaged_diff = repo.git.diff() 70 | if unstaged_diff: 71 | diff_output.append("\n=== Unstaged Changes ===") 72 | diff_output.append(unstaged_diff) 73 | 74 | # If no current changes, show the last commit diff 75 | if not diff_output: 76 | if len(repo.heads) > 0: 77 | last_commit = repo.head.commit 78 | if last_commit.parents: 79 | diff_output.append("=== Last Commit Diff ===") 80 | diff_output.append(repo.git.diff(f'{last_commit.parents[0].hexsha}..{last_commit.hexsha}')) 81 | else: 82 | diff_output.append("=== Initial Commit Diff ===") 83 | diff_output.append(repo.git.diff(last_commit.hexsha)) 84 | else: 85 | return "No commits in the repository yet." 
86 | 87 | return "\n".join(diff_output) if diff_output else "No changes detected" 88 | 89 | except Exception as e: 90 | logger.error(f"Error getting diff: {str(e)}") 91 | return f"Error getting diff: {str(e)}" 92 | 93 | 94 | def critique(state: AgentState, config) -> AgentState: 95 | """Modified critique to provide natural language analysis""" 96 | try: 97 | # Format the prompt with required information 98 | formatted_prompt = critique_prompt.format( 99 | requirements=state.get('requirements', ''), 100 | plan=state.get('plan_string', ''), 101 | step_results=json.dumps(state.get('step_results', {}), indent=2) 102 | ) 103 | 104 | model = _get_model(config, "openai", "critique_model") 105 | 106 | # Create message sequence 107 | message_sequence = [ 108 | SystemMessage(content=formatted_prompt), 109 | AIMessage(content="Please analyze the implementation.") 110 | ] 111 | 112 | # Get unstructured response from the model 113 | response = model.invoke(message_sequence) 114 | critique_logic = response.content 115 | 116 | # Update state with critique results 117 | new_state = { 118 | **state, 119 | "step_results": { 120 | **(state.get("step_results", {})), 121 | "critique": { 122 | "args": { 123 | "logic": critique_logic 124 | } 125 | } 126 | } 127 | } 128 | 129 | return new_state 130 | 131 | except Exception as e: 132 | logger.error(f"Error in critique: {str(e)}") 133 | return { 134 | **state, 135 | "step_results": { 136 | **(state.get("step_results", {})), 137 | "critique": { 138 | "args": { 139 | "logic": f"Error occurred during critique: {str(e)}" 140 | } 141 | } 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /src/langgraph_engineer/diff_node.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import os 3 | import difflib 4 | import logging 5 | from git import Repo, GitCommandError 6 | from langchain_core.messages import AIMessage 7 | from langgraph_engineer.state import AgentState 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | def show_file_diffs(state: AgentState) -> AgentState: 12 | """Show diffs between current state and last commit for all changed files.""" 13 | try: 14 | # Validate required state 15 | repo_path = state.get('repo_path') 16 | branch_name = state.get('branch_name') 17 | if not repo_path or not branch_name: 18 | raise ValueError("Repository path and branch name are required in state") 19 | 20 | repo = Repo(repo_path) 21 | 22 | # Add debug logging 23 | logger.debug(f"Current branch: {repo.active_branch.name}") 24 | logger.debug(f"Expected branch: {branch_name}") 25 | logger.debug(f"Git status:\n{repo.git.status()}") 26 | logger.debug(f"Git diff:\n{repo.git.diff()}") 27 | 28 | # Ensure we're on the correct branch 29 | current_branch = repo.active_branch.name 30 | if current_branch != branch_name: 31 | logger.warning(f"Not on expected branch. 
Expected: {branch_name}, Current: {current_branch}") 32 | try: 33 | # Try to switch to correct branch 34 | repo.git.checkout(branch_name) 35 | logger.info(f"Switched to branch: {branch_name}") 36 | except GitCommandError as e: 37 | raise ValueError(f"Failed to switch to branch {branch_name}: {str(e)}") 38 | 39 | # Get all changes using multiple git methods 40 | changed_files = [] 41 | 42 | # Get unstaged changes 43 | unstaged = [item.a_path for item in repo.index.diff(None)] 44 | changed_files.extend(unstaged) 45 | 46 | # Get staged changes 47 | staged = [item.a_path for item in repo.index.diff('HEAD')] 48 | changed_files.extend(staged) 49 | 50 | # Get untracked files 51 | untracked = repo.untracked_files 52 | changed_files.extend(untracked) 53 | 54 | # Remove duplicates 55 | changed_files = list(set(changed_files)) 56 | 57 | logger.debug(f"Detected changed files: {changed_files}") 58 | 59 | all_diffs = [] 60 | 61 | for file_path in changed_files: 62 | try: 63 | file_full_path = os.path.abspath(os.path.join(repo_path, file_path)) 64 | logger.debug(f"Processing file: {file_path}") 65 | logger.debug(f"Full path: {file_full_path}") 66 | logger.debug(f"File exists: {os.path.exists(file_full_path)}") 67 | 68 | # Get old content 69 | try: 70 | old_content = repo.git.show(f'HEAD:{file_path}') 71 | except Exception: 72 | old_content = '' 73 | 74 | # Get new content 75 | try: 76 | with open(file_full_path, 'r') as f: 77 | new_content = f.read() 78 | except Exception: 79 | new_content = '' 80 | 81 | # Generate diff 82 | diff = difflib.unified_diff( 83 | old_content.splitlines(keepends=True), 84 | new_content.splitlines(keepends=True), 85 | fromfile=f'a/{file_path}', 86 | tofile=f'b/{file_path}' 87 | ) 88 | all_diffs.extend(diff) 89 | 90 | except Exception as e: 91 | logger.error(f"Error processing file {file_path}: {str(e)}") 92 | continue 93 | 94 | # Handle untracked files 95 | untracked = repo.untracked_files 96 | for file_path in untracked: 97 | try: 98 | with open(os.path.join(repo_path, file_path), 'r') as f: 99 | new_content = f.read() 100 | 101 | # Show new file content as addition 102 | diff = difflib.unified_diff( 103 | [], 104 | new_content.splitlines(keepends=True), 105 | fromfile=f'/dev/null', 106 | tofile=f'b/{file_path}' 107 | ) 108 | all_diffs.extend(diff) 109 | except Exception as e: 110 | logger.error(f"Error processing untracked file {file_path}: {str(e)}") 111 | continue 112 | 113 | # Add diff results to state messages 114 | diff_text = ''.join(all_diffs) 115 | if diff_text: 116 | state['messages'].append( 117 | AIMessage(content=f"Here are the changes made on branch '{branch_name}':\n```diff\n{diff_text}\n```") 118 | ) 119 | else: 120 | state['messages'].append( 121 | AIMessage(content=f"No changes detected in the repository on branch '{branch_name}'.") 122 | ) 123 | 124 | return state 125 | 126 | except Exception as e: 127 | logger.error(f"Error showing diffs: {str(e)}") 128 | state['messages'].append( 129 | AIMessage(content=f"Error showing diffs: {str(e)}") 130 | ) 131 | return state -------------------------------------------------------------------------------- /src/langgraph_engineer/state.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, List, Dict, Optional, Literal, Union, Any 2 | from typing_extensions import TypedDict 3 | from langchain_core.messages import BaseMessage, HumanMessage, AIMessage 4 | from langgraph.graph.message import add_messages 5 | from pydantic import BaseModel, Field, ConfigDict 6 
| import json 7 | import logging 8 | import os 9 | from pathlib import Path 10 | from datetime import datetime, timezone 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | # Input/Output schemas 15 | class InputState(TypedDict, total=False): 16 | repo_url: str 17 | query: str 18 | user_id: str 19 | github_token: str 20 | branch_name: Optional[str] 21 | anthropic_api_key: Optional[str] 22 | openai_api_key: Optional[str] 23 | chat_mode: Optional[str] 24 | model_name: Optional[str] 25 | 26 | class OutputState(TypedDict): 27 | code: str 28 | 29 | # Aider-specific schemas 30 | class AiderState(BaseModel): 31 | """Track the state of the aider chat session""" 32 | initialized: bool = False 33 | last_prompt: Optional[str] = None 34 | waiting_for_input: bool = False 35 | input_type: Optional[str] = None 36 | setup_complete: bool = False 37 | model_name: str = Field(default='openai') 38 | last_files: List[str] = Field(default_factory=list) 39 | conversation_history: List[dict] = Field(default_factory=list) 40 | 41 | # Update to use Pydantic v2 config 42 | model_config = ConfigDict( 43 | arbitrary_types_allowed=True, 44 | json_encoders={ 45 | BaseModel: lambda v: v.model_dump() 46 | } 47 | ) 48 | 49 | # Simplified main agent state 50 | class AgentState(TypedDict, total=False): 51 | """Used for Aider Chat""" 52 | messages: List[Union[HumanMessage, AIMessage]] 53 | code: str 54 | repo_url: str 55 | repo_path: str 56 | branch_name: str 57 | github_token: str 58 | anthropic_api_key: Optional[str] 59 | openai_api_key: Optional[str] 60 | aider_state: AiderState 61 | chat_mode: Optional[str] 62 | model_name: Optional[str] 63 | accepted: bool 64 | show_diff: bool 65 | aider_summary: Optional[str] 66 | 67 | def initialize_state( 68 | repo_url: str, 69 | github_token: str, 70 | repo_path: str, 71 | anthropic_api_key: Optional[str] = None, 72 | openai_api_key: Optional[str] = None, 73 | branch_name: Optional[str] = None, 74 | chat_mode: Optional[str] = None, 75 | model_name: Optional[str] = None, 76 | query: Optional[str] = None, 77 | ) -> AgentState: 78 | """Initialize state with all required fields""" 79 | 80 | # Prioritize input API keys over environment variables 81 | anthropic_api_key = anthropic_api_key if anthropic_api_key is not None else os.getenv('ANTHROPIC_API_KEY') 82 | openai_api_key = openai_api_key if openai_api_key is not None else os.getenv('OPENAI_API_KEY') 83 | 84 | # Log which source we're using for each API key 85 | logger.debug(f"Using Anthropic API key from: {'parameter' if anthropic_api_key != os.getenv('ANTHROPIC_API_KEY') else 'environment'}") 86 | logger.debug(f"Using OpenAI API key from: {'parameter' if openai_api_key != os.getenv('OPENAI_API_KEY') else 'environment'}") 87 | 88 | repo_path = str(Path(repo_path)) 89 | 90 | if not branch_name: 91 | timestamp = int(datetime.now(timezone.utc).timestamp()) 92 | branch_name = f"feature/ai-changes-{timestamp}" 93 | 94 | valid_models = ['haiku', 'sonnet', '4o', 'o1', 'gpt-4o-mini'] 95 | if model_name and model_name not in valid_models: 96 | raise ValueError(f"model_name must be one of {valid_models}") 97 | 98 | # Validate API key based on model 99 | if model_name in ['haiku', 'sonnet'] and not anthropic_api_key: 100 | raise ValueError(f"Anthropic API key is required for model {model_name}") 101 | elif model_name in ['4o', 'o1', 'gpt-4o-mini'] and not openai_api_key: 102 | raise ValueError(f"OpenAI API key is required for model {model_name}") 103 | 104 | state = { 105 | "messages": [], 106 | "code": "", 107 | "repo_url": repo_url, 108 | 
"repo_path": repo_path, 109 | "branch_name": branch_name, 110 | "github_token": github_token, 111 | "anthropic_api_key": anthropic_api_key, 112 | "openai_api_key": openai_api_key, 113 | "chat_mode": chat_mode, 114 | "accepted": False, 115 | "show_diff": False, 116 | "aider_state": AiderState( 117 | initialized=True, 118 | model_name=model_name or '4o', 119 | conversation_history=[], 120 | last_files=[], 121 | waiting_for_input=False, 122 | setup_complete=True 123 | ), 124 | "model_name": model_name or '4o' 125 | } 126 | 127 | if query: 128 | state["messages"] = [ 129 | HumanMessage( 130 | content=query, 131 | additional_kwargs={"role": "user"} 132 | ) 133 | ] 134 | 135 | return state 136 | 137 | # Keep the serialization helpers 138 | def serialize_state_for_llm(state: AgentState) -> dict: 139 | """Prepare state for LLM by serializing all complex objects.""" 140 | serialized = dict(state) 141 | 142 | if 'messages' in serialized: 143 | serialized['messages'] = [ 144 | msg.dict() if hasattr(msg, 'dict') else 145 | {'type': msg.__class__.__name__, 'content': msg.content} 146 | for msg in serialized['messages'] 147 | ] 148 | 149 | if 'aider_state' in serialized: 150 | if isinstance(serialized['aider_state'], AiderState): 151 | serialized['aider_state'] = serialized['aider_state'].model_dump() 152 | elif isinstance(serialized['aider_state'], dict): 153 | serialized['aider_state'] = serialized['aider_state'] 154 | 155 | for key, value in serialized.items(): 156 | if hasattr(value, 'model_dump'): 157 | serialized[key] = value.model_dump() 158 | elif isinstance(value, (set, tuple)): 159 | serialized[key] = list(value) 160 | 161 | return serialized 162 | 163 | def validate_state(state: Dict) -> None: 164 | """Validate that all state values are serializable.""" 165 | try: 166 | json.dumps(serialize_state_for_llm(state)) 167 | except (TypeError, ValueError) as e: 168 | raise ValueError(f"State contains non-serializable data: {str(e)}") 169 | 170 | # Add this with the other TypedDict definitions 171 | class GraphConfig(TypedDict, total=False): 172 | """Configuration for the graph""" 173 | anthropic_api_key: Optional[str] 174 | openai_api_key: Optional[str] 175 | github_token: str 176 | repo_path: str 177 | branch_name: str 178 | chat_mode: Optional[str] 179 | model_name: Optional[str] 180 | 181 | -------------------------------------------------------------------------------- /src/langgraph_engineer/aider_node.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from langchain_core.messages import AIMessage, HumanMessage, AIMessageChunk, BaseMessage 4 | from langgraph_engineer.state import AgentState 5 | from langgraph_engineer.interactive_aider import InteractiveAider 6 | from langgraph.graph.message import add_messages 7 | from datetime import datetime, timezone 8 | from pathlib import Path 9 | import os 10 | logger = logging.getLogger(__name__) 11 | 12 | def create_aider_node(): 13 | """ 14 | Create a node function that handles interactive Aider execution, streaming the 15 | partial response back in real time. 16 | """ 17 | 18 | async def aider_node_fn(state: AgentState) -> AgentState: 19 | """ 20 | Node function that processes messages using InteractiveAider. 21 | Streams chunked output as AIMessageChunks and a final AIMessage. 
22 | """ 23 | try: 24 | logger.debug("=== Interactive Aider Node Starting ===") 25 | 26 | # Ensure state is properly initialized 27 | if not isinstance(state, dict): 28 | state = dict(state) 29 | 30 | # Initialize messages if not present 31 | if "messages" not in state: 32 | state["messages"] = [] 33 | 34 | # Get last message 35 | messages = state.get("messages", []) 36 | if not messages: 37 | raise ValueError("No messages in state") 38 | 39 | last_message = messages[-1] 40 | if not isinstance(last_message, (HumanMessage, AIMessage)): 41 | # Convert to proper message type if needed 42 | last_message = HumanMessage( 43 | content=str(last_message), 44 | additional_kwargs={"role": "user"} 45 | ) 46 | 47 | user_text = last_message.content 48 | 49 | # Initialize aider 50 | repo_path = state.get("repo_path") 51 | if not repo_path: 52 | raise ValueError("Repository path not found in state") 53 | 54 | # Determine which API key to use based on model 55 | model_name = state.get("model_name", "4o") 56 | anthropic_models = ["sonnet", "haiku"] 57 | if model_name in anthropic_models: 58 | # Prioritize state API key over environment variable 59 | api_key = state.get("anthropic_api_key") 60 | if api_key is None: # Only use env var if state key is None 61 | api_key = os.getenv("ANTHROPIC_API_KEY") 62 | source = "state" if state.get("anthropic_api_key") is not None else "environment" 63 | logger.info(f"Using Anthropic API key from: {source}") 64 | else: # Default to OpenAI 65 | # Prioritize state API key over environment variable 66 | api_key = state.get("openai_api_key") 67 | if api_key is None: # Only use env var if state key is None 68 | api_key = os.getenv("OPENAI_API_KEY") 69 | source = "state" if state.get("openai_api_key") is not None else "environment" 70 | logger.info(f"Using OpenAI API key from: {source}") 71 | logger.info(f"OpenAI key starts with: {api_key[:10]}..." if api_key else "No OpenAI key found!") 72 | 73 | if not api_key: 74 | raise ValueError(f"No API key found for model {model_name}. 
Please provide a valid API key.")
 75 | 
 76 |             logger.info(f"Resolved API key for model {model_name} (key material not logged)")
 77 | 
 78 |             aider = InteractiveAider(
 79 |                 repo_path=repo_path,
 80 |                 model=model_name,
 81 |                 api_key=api_key,
 82 |                 api_base=state.get("api_base")
 83 |             )
 84 | 
 85 |             # Initialize a content accumulator
 86 |             accumulated_content = ""
 87 | 
 88 |             # Track edits for this session
 89 |             if "edits" not in state:
 90 |                 state["edits"] = []
 91 | 
 92 |             try:
 93 |                 # Process chunks
 94 |                 async for chunk in aider.execute_command(user_text):
 95 |                     if chunk["type"] == "edit":
 96 |                         # Store edit information in state
 97 |                         edit = {
 98 |                             "type": "edit",
 99 |                             "commit_hash": chunk.get("commit_hash"),
100 |                             "commit_message": chunk.get("commit_message"),
101 |                             "files": chunk.get("files", []),
102 |                             "diff": chunk.get("content")
103 |                         }
104 |                         state["edits"].append(edit)
105 | 
106 |                         # Create a message for the diff
107 |                         diff_message = AIMessage(
108 |                             content=f"Made changes to: {', '.join(edit['files'])}\n\n```diff\n{edit['diff']}\n```",
109 |                             additional_kwargs={
110 |                                 "role": "assistant",
111 |                                 "type": "edit",
112 |                                 "edit_info": edit
113 |                             }
114 |                         )
115 |                         state["messages"].append(diff_message)
116 |                     else:
117 |                         # Handle regular message chunks as before
118 |                         accumulated_content += chunk.get("content", "")
119 | 
120 |                 # After processing all chunks, create a single AIMessage
121 |                 final_message = AIMessage(
122 |                     content=accumulated_content,
123 |                     additional_kwargs={"role": "assistant", "type": "message"}
124 |                 )
125 | 
126 |                 # Update state with the final AIMessage
127 |                 state["messages"].append(final_message)
128 | 
129 |             except Exception as e:
130 |                 error_msg = str(e)
131 |                 # Check for authentication-specific errors
132 |                 if any(auth_err in error_msg.lower() for auth_err in [
133 |                     "authentication_error",
134 |                     "invalid x-api-key",
135 |                     "invalid api key",
136 |                     "unable to authenticate",
137 |                     "auth"
138 |                 ]):
139 |                     auth_error_msg = (
140 |                         f"API Authentication Error: The provided API key for {state.get('model_name', 'unknown')} "
141 |                         f"model is invalid or has expired.
Please check your API key and try again.\n\n"
142 |                         f"Error details: {error_msg}"
143 |                     )
144 |                     error_message = AIMessage(
145 |                         content=auth_error_msg,
146 |                         additional_kwargs={
147 |                             "role": "assistant",
148 |                             "type": "error",
149 |                             "error_type": "authentication"
150 |                         }
151 |                     )
152 |                 else:
153 |                     # Handle other types of errors
154 |                     error_message = AIMessage(
155 |                         content=f"Error: {error_msg}",
156 |                         additional_kwargs={
157 |                             "role": "assistant",
158 |                             "type": "error",
159 |                             "error_type": "general"
160 |                         }
161 |                     )
162 |                 state["messages"].append(error_message)
163 |                 raise  # Re-raise to be caught by outer try/except
164 | 
165 |             return state
166 | 
167 |         except Exception as e:
168 |             logger.error(f"Error in aider_node_fn: {str(e)}")
169 |             logger.error("Full traceback:", exc_info=True)
170 | 
171 |             # If no error message has been added yet, add one
172 |             if not any(msg.additional_kwargs.get("type") == "error"
173 |                        for msg in state.get("messages", [])):
174 |                 error_message = AIMessage(
175 |                     content=f"Error in Aider node: {str(e)}",
176 |                     additional_kwargs={"role": "assistant", "type": "error"}
177 |                 )
178 |                 if "messages" not in state:
179 |                     state["messages"] = []
180 |                 state["messages"].append(error_message)
181 | 
182 |             return state
183 | 
184 |     return aider_node_fn
185 | 
--------------------------------------------------------------------------------
/src/langgraph_engineer/agent.py:
--------------------------------------------------------------------------------
 1 | from typing import Literal, Dict, Any, Annotated
 2 | import os
 3 | import shutil
 4 | import time
 5 | import logging
 6 | import json
 7 | from git import Repo
 8 | import difflib
 9 | 
10 | from langgraph.graph import StateGraph, END, MessagesState
11 | from langchain_core.messages import AIMessage, HumanMessage, FunctionMessage
12 | from langgraph_engineer.tools import (
13 |     AiderShellTool,
14 |     AiderToReactOutput,
15 |     aider_command
16 | )
17 | # Update imports
18 | from langgraph.prebuilt import ToolNode
19 | from langgraph_engineer.check import check
20 | from langgraph_engineer.state import (
21 |     AgentState,
22 |     OutputState,
23 |     GraphConfig,
24 |     InputState,
25 |     initialize_state,
26 |     AiderState
27 | )
28 | 
29 | from langgraph_engineer.setup_node import setup_repository
30 | from langgraph_engineer.git_push_node import git_push_changes
31 | from langchain_core.tools import BaseTool
32 | from langgraph.types import interrupt, Command
33 | from langgraph.checkpoint.memory import MemorySaver  # required by Engineer.process_request below
34 | from langgraph_engineer.aider_node import create_aider_node
35 | from langgraph_engineer.diff_node import show_file_diffs
36 | from langgraph_engineer.summarize_node import summarize_response
37 | 
38 | anthropic_api_key = os.getenv('ANTHROPIC_API_KEY', '')
39 | 
40 | logger = logging.getLogger(__name__)
41 | 
42 | def route_critique(state: AgentState) -> Literal["react_agent", "git_push_changes", END]:
43 |     """Route after critique based on validation results."""
44 |     if state.get("step_results"):
45 |         latest_result = list(state["step_results"].values())[-1]
46 |         completion_status = latest_result.get("args", {}).get("completion_status")
47 |         if completion_status in ["complete_with_changes", "complete_no_changes"]:
48 |             state["execution_status"] = "complete"
49 | 
50 |     if not state.get("accepted"):
51 |         state["execution_status"] = "planning"
52 |         return "react_agent"
53 |     elif state["execution_status"] == "complete":
54 |         return "git_push_changes"
55 |     else:
56 |         return "react_agent"
57 | 
58 | def route_git_push(state: AgentState) -> Literal[END]:
59 |     return END
60 | 
61 | def
route_start(state: AgentState) -> Literal["react_agent", "gather_requirements"]:
 62 |     if state.get('requirements'):
 63 |         return "react_agent"
 64 |     else:
 65 |         return "gather_requirements"
 66 | 
 67 | def aider_config_mapper(state: AgentState) -> Dict[str, Any]:
 68 |     # Handle both API keys
 69 |     anthropic_api_key = state.get("anthropic_api_key") or os.getenv("ANTHROPIC_API_KEY")
 70 |     openai_api_key = state.get("openai_api_key") or os.getenv("OPENAI_API_KEY")
 71 | 
 72 |     if not (anthropic_api_key or openai_api_key):
 73 |         raise ValueError("Either ANTHROPIC_API_KEY or OPENAI_API_KEY must be set")
 74 |     logger.info(f"Anthropic API key configured: {bool(anthropic_api_key)}")
 75 |     logger.info(f"OpenAI API key configured: {bool(openai_api_key)}")
 76 |     return {
 77 |         "configurable": {
 78 |             "repo_path": state.get("repo_path"),
 79 |             "anthropic_api_key": anthropic_api_key,
 80 |             "openai_api_key": openai_api_key,  # Add OpenAI key
 81 |             "aider_state": state.get("aider_state"),
 82 |             "chat_mode": state.get("chat_mode", ""),
 83 |             "model_name": state.get("model_name", "4o"),
 84 |         }
 85 |     }
 86 | 
 87 | def route_gather(state: AgentState) -> Literal["react_agent", END]:
 88 |     if state["messages"] and isinstance(state["messages"][-1], AIMessage):
 89 |         return END
 90 |     return "react_agent"
 91 | 
 92 | def route_react(state: AgentState) -> Literal["aider_node", "git_push_changes"]:
 93 |     if state["execution_status"] in ["planning", "executing"]:
 94 |         return "aider_node"
 95 |     elif state["execution_status"] == "complete" and state.get("accepted"):
 96 |         return "git_push_changes"
 97 |     return "aider_node"
 98 | 
 99 | def route_tool(state: AgentState) -> Literal["react_agent"]:
100 |     return "react_agent"
101 | 
102 | def route_aider(state: AgentState) -> Literal["git_push_changes", "summarize"]:
103 |     changes_req = state.get("router_analysis", {}).get("changes_req", True)
104 |     if not changes_req:
105 |         return "summarize"
106 |     return "git_push_changes"
107 | 
108 | def route_summarize(state: AgentState) -> Literal[END]:
109 |     return END
110 | 
111 | def route_message_node(state: AgentState) -> Literal["react_agent", "aider_node"]:
112 |     route_type = state.get("router_analysis", {}).get("route_type", "hard")
113 |     if route_type in ["chat", "easy"]:
114 |         return "aider_node"
115 |     else:
116 |         return "react_agent"
117 | 
118 | def route_post_critique(state: AgentState) -> Literal["git_push_changes", "react_agent"]:
119 |     try:
120 |         post_critique_result = state.get('step_results', {}).get('post_critique_router', {})
121 |         decision = post_critique_result.get('args', {}).get('decision', 'N')
122 |         return "git_push_changes" if decision == 'Y' else "react_agent"
123 |     except Exception as e:
124 |         logger.error(f"Error in route_post_critique: {str(e)}")
125 |         return "react_agent"
126 | 
127 | def route_human(state: AgentState) -> Literal["aider_node", "git_push_changes", "diff_changes"]:
128 |     """Route after human interaction based on state flags."""
129 |     if state.get("accepted"):
130 |         return "git_push_changes"
131 |     elif state.get("show_diff", False):
132 |         state["show_diff"] = False
133 |         return "diff_changes"
134 |     return "aider_node"
135 | 
136 | def route_diff(state: AgentState) -> Literal["human_interaction"]:
137 |     """Route after showing diffs - always return to human interaction."""
138 |     return "human_interaction"
139 | 
140 | def human_interaction(state: AgentState) -> AgentState:
141 |     """Node for handling human interaction with the agent."""
142 |     try:
143 |         if state.get('messages'):
144 |             human_input = interrupt("Please provide your message (or
enter {'accept': true} to finish, or {'show_diff': true} to see changes):") 145 | 146 | logger.debug(f"Received human input: {human_input}") 147 | 148 | # Check if input is a dictionary 149 | if isinstance(human_input, dict): 150 | if human_input.get("accept"): 151 | state["accepted"] = True 152 | return state 153 | elif human_input.get("show_diff"): 154 | state["show_diff"] = True 155 | return state 156 | 157 | # Add human message to state 158 | state.setdefault("messages", []).append(HumanMessage(content=human_input)) 159 | 160 | return state 161 | 162 | except Exception as e: 163 | logger.error(f"Error in human_interaction: {str(e)}") 164 | raise 165 | 166 | class Engineer: 167 | def __init__(self): 168 | self.base_repos_dir = "/repos" 169 | os.makedirs(self.base_repos_dir, exist_ok=True) 170 | self.test_repo_url = "https://github.com/RVCA212/portfolio-starter-kit" 171 | self.test_user_id = "test_user_123" 172 | 173 | async def process_request(self, input_state: InputState) -> dict: 174 | repo_url = input_state.get('repo_url', self.test_repo_url) 175 | user_id = input_state.get('user_id', self.test_user_id) 176 | query = input_state.get('query', '') 177 | github_token = input_state.get('github_token', '') 178 | anthropic_api_key = input_state.get('anthropic_api_key', '') 179 | openai_api_key = input_state.get('openai_api_key', '') # Add OpenAI key 180 | chat_mode = input_state.get('chat_mode', '') 181 | model_name = input_state.get('model_name', '4o') # Default to OpenAI model 182 | logger.info(f"Received chat_mode in input_state: '{chat_mode}'") 183 | logger.info(f"Using model: {model_name}") 184 | 185 | if not github_token: 186 | raise ValueError("GitHub token is required in input_state") 187 | if not repo_url: 188 | raise ValueError("Repository URL is required") 189 | 190 | # Validate API keys based on model 191 | if model_name in ['haiku', 'sonnet'] and not anthropic_api_key: 192 | raise ValueError(f"Anthropic API key is required for model {model_name}") 193 | elif model_name in ['4o', 'o1', 'gpt-4o-mini'] and not openai_api_key: 194 | raise ValueError(f"OpenAI API key is required for model {model_name}") 195 | 196 | repo_path = os.path.join("/repos", user_id.lstrip('/')) 197 | os.makedirs(repo_path, exist_ok=True) 198 | 199 | initial_state = initialize_state( 200 | repo_url=repo_url, 201 | github_token=github_token, 202 | repo_path=repo_path, 203 | anthropic_api_key=anthropic_api_key, 204 | openai_api_key=openai_api_key, 205 | chat_mode=chat_mode, 206 | model_name=model_name, 207 | query=query 208 | ) 209 | 210 | # Set up memory checkpointer for interrupts 211 | memory = MemorySaver() 212 | graph_with_memory = workflow.compile(checkpointer=memory) 213 | 214 | # Run the workflow 215 | final_state = await run_workflow(graph_with_memory, initial_state) 216 | return final_state 217 | 218 | 219 | # Register nodes 220 | aider_node = create_aider_node() 221 | workflow = StateGraph(AgentState) 222 | 223 | # Add nodes 224 | workflow.add_node("setup_node", setup_repository) 225 | workflow.add_node("human_interaction", human_interaction) 226 | workflow.add_node("aider_node", aider_node) 227 | workflow.add_node("git_push_changes", git_push_changes) 228 | workflow.add_node("diff_changes", show_file_diffs) 229 | 230 | # Set entry point and edges with routing 231 | workflow.set_entry_point("setup_node") 232 | workflow.add_edge("setup_node", "aider_node") 233 | workflow.add_edge("aider_node", "human_interaction") 234 | workflow.add_conditional_edges( 235 | "human_interaction", 236 | route_human 
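    # route_human returns "aider_node", "git_push_changes", or "diff_changes";
    # with no path map supplied, LangGraph uses the returned string directly as
    # the target node name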
237 | )
238 | workflow.add_edge("diff_changes", "human_interaction")
239 | workflow.add_edge("git_push_changes", END)
240 | 
241 | graph = workflow.compile()
242 | 
243 | async def run_workflow(graph, initial_state):
244 |     """Run workflow with properly configured state."""
245 |     try:
246 |         logger.debug("=== Workflow Starting ===")
247 |         logger.debug(f"Initial state chat_mode: '{initial_state.get('chat_mode')}'")
248 |         logger.debug(f"Initial state model_name: '{initial_state.get('model_name')}'")
249 |         logger.debug(f"Initial state keys: {list(initial_state.keys())}")
250 |         logger.debug(f"Initial state repo_path: {initial_state.get('repo_path')}")
251 |         logger.debug(f"Initial state anthropic_api_key present: {'yes' if initial_state.get('anthropic_api_key') else 'no'}")
252 | 
253 |         # Validate required state values; at least one provider API key must be present
254 |         required_keys = ["repo_path", "github_token"]
255 |         missing_keys = [key for key in required_keys if not initial_state.get(key)]
256 |         if not (initial_state.get("anthropic_api_key") or initial_state.get("openai_api_key")): missing_keys.append("anthropic_api_key or openai_api_key")
257 |         logger.debug("=== State Validation ===")
258 |         logger.debug(f"Required keys: {required_keys}")
259 |         logger.debug(f"Missing keys: {missing_keys}")
260 | 
261 |         if missing_keys:
262 |             logger.error(f"Missing required state values: {', '.join(missing_keys)}")
263 |             raise ValueError(f"Missing required state values: {', '.join(missing_keys)}")
264 | 
265 |         # Update config to include OpenAI API key
266 |         config = {
267 |             "configurable": {
268 |                 "repo_path": initial_state["repo_path"],
269 |                 "anthropic_api_key": initial_state.get("anthropic_api_key"),
270 |                 "openai_api_key": initial_state.get("openai_api_key"),  # Add OpenAI key
271 |                 "aider_state": initial_state.get("aider_state"),
272 |                 "github_token": initial_state["github_token"],
273 |                 "chat_mode": initial_state.get("chat_mode", ""),
274 |                 "model_name": initial_state.get("model_name", "4o"),
275 |             }
276 |         }
277 |         logger.debug("=== Config Creation ===")
278 |         logger.debug(f"Created config with keys: {list(config['configurable'].keys())}")  # values withheld: config carries tokens/API keys
279 | 
280 |         # Run graph
281 |         logger.debug("=== Running Graph ===")
282 |         state = await graph.ainvoke(
283 |             initial_state,
284 |             config=config
285 |         )
286 | 
287 |         logger.debug("=== Workflow Completed ===")
288 |         return state
289 | 
290 |     except Exception as e:
291 |         logger.error(f"Error running workflow: {str(e)}")
292 |         logger.error("Full traceback:", exc_info=True)
293 |         raise
--------------------------------------------------------------------------------
/src/langgraph_engineer/interactive_aider.py:
--------------------------------------------------------------------------------
 1 | import asyncio
 2 | import logging
 3 | from pathlib import Path
 4 | from typing import AsyncGenerator, Tuple, Optional, List, Dict, Any
 5 | from aider.io import InputOutput
 6 | from aider.coders import Coder
 7 | from aider.main import main as cli_main
 8 | from langchain_core.messages import AIMessageChunk
 9 | import os
10 | import sys
11 | 
12 | logger = logging.getLogger(__name__)
13 | 
14 | class StreamingAiderIO(InputOutput):
15 |     """Custom IO class that captures streaming output from Aider."""
16 | 
17 |     def __init__(self, *args, **kwargs):
18 |         # Match GUI's non-interactive settings
19 |         kwargs.update({
20 |             'yes': True,
21 |             'pretty': False,
22 |             'fancy_input': False,
23 |             'dry_run': False
24 |         })
25 |         super().__init__(*args, **kwargs)
26 |         self._buffer = []
27 |         self.stream = True
28 |         self.yield_stream = True
29 |         self.callbacks = []  # streaming callbacks registered by consumers
30 |         self.pretty = False  # Important for diff formatting
31 |         self._last_commit_hash = None
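        # commit bookkeeping: get_diff_info() compares _last_commit_hash with the
        # coder's last_aider_commit_hash to detect freshly landed edits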
32 | self._last_commit_message = None 33 | self._edited_files = set() 34 | 35 | async def stream_callback(self, chunk): 36 | """Async method to handle streaming chunks""" 37 | for callback in self.callbacks: 38 | await callback(chunk) 39 | 40 | def tool_output(self, msg, log_only=False): 41 | """Capture tool output for streaming.""" 42 | chunk = AIMessageChunk( 43 | content=msg, 44 | additional_kwargs={ 45 | "type": "message", 46 | "files": [] 47 | } 48 | ) 49 | 50 | # Create task to run stream callback 51 | if hasattr(asyncio, 'get_running_loop'): 52 | loop = asyncio.get_running_loop() 53 | loop.create_task(self.stream_callback(chunk)) 54 | 55 | self._buffer.append(chunk) 56 | super().tool_output(msg, log_only=log_only) 57 | 58 | def tool_error(self, msg): 59 | """Capture error messages for streaming.""" 60 | self._buffer.append(AIMessageChunk( 61 | content=msg, 62 | additional_kwargs={ 63 | "type": "error", 64 | "files": [] 65 | } 66 | )) 67 | super().tool_error(msg) 68 | 69 | def tool_warning(self, msg): 70 | """Capture warning messages for streaming.""" 71 | self._buffer.append(AIMessageChunk( 72 | content=msg, 73 | additional_kwargs={ 74 | "type": "warning", 75 | "files": [] 76 | } 77 | )) 78 | super().tool_warning(msg) 79 | 80 | def assistant_output(self, msg, pretty=False): 81 | """Capture assistant output for streaming.""" 82 | self._buffer.append(AIMessageChunk( 83 | content=msg, 84 | additional_kwargs={ 85 | "type": "message", 86 | "files": [] 87 | } 88 | )) 89 | super().assistant_output(msg, pretty=pretty) 90 | 91 | def get_buffer(self): 92 | """Get and clear the buffer.""" 93 | buffer = self._buffer 94 | self._buffer = [] 95 | return buffer 96 | 97 | def capture_edit_info(self, commit_hash=None, commit_message=None, files=None, diff=None): 98 | """Capture edit information for streaming""" 99 | edit_chunk = { 100 | "type": "edit", 101 | "content": diff or "", 102 | "commit_hash": commit_hash, 103 | "commit_message": commit_message, 104 | "files": list(files) if files else [] 105 | } 106 | 107 | chunk = AIMessageChunk( 108 | content=diff or "", 109 | additional_kwargs=edit_chunk 110 | ) 111 | 112 | self._buffer.append(chunk) 113 | return chunk 114 | 115 | class InteractiveAider: 116 | """A wrapper class for Aider to handle interactive code editing.""" 117 | 118 | def __init__( 119 | self, 120 | repo_path: str, 121 | model: str, 122 | api_base: Optional[str] = None, 123 | api_key: Optional[str] = None, 124 | encoding: str = "utf-8", 125 | ): 126 | self.repo_path = Path(repo_path).resolve() 127 | self.model = model 128 | self.api_base = api_base 129 | self.api_key = api_key 130 | self.encoding = encoding 131 | 132 | # Store original environment variables 133 | self._original_anthropic_key = os.environ.get("ANTHROPIC_API_KEY") 134 | self._original_openai_key = os.environ.get("OPENAI_API_KEY") 135 | 136 | # Set API key in environment based on model 137 | if api_key: 138 | # Check if model is an Anthropic model 139 | anthropic_models = ["sonnet", "haiku"] # Add all Anthropic models here 140 | if self.model in anthropic_models: 141 | logger.info(f"Using provided Anthropic API key from state for model {self.model}") 142 | os.environ["ANTHROPIC_API_KEY"] = api_key 143 | # Temporarily unset OpenAI key to prevent any confusion 144 | if "OPENAI_API_KEY" in os.environ: 145 | del os.environ["OPENAI_API_KEY"] 146 | else: # Default to OpenAI 147 | logger.info(f"Using provided OpenAI API key from state for model {self.model}") 148 | os.environ["OPENAI_API_KEY"] = api_key 149 | # Temporarily 
unset Anthropic key to prevent any confusion 150 | if "ANTHROPIC_API_KEY" in os.environ: 151 | del os.environ["ANTHROPIC_API_KEY"] 152 | 153 | # Get coder instance 154 | try: 155 | self.coder = self._get_coder() 156 | except Exception as e: 157 | # Restore original environment variables in case of error 158 | self._restore_env_vars() 159 | raise e 160 | 161 | # Initialize and attach our custom IO 162 | self.io = StreamingAiderIO( 163 | encoding=self.encoding, 164 | yes=True, 165 | pretty=False, 166 | fancy_input=False, 167 | dry_run=False 168 | ) 169 | self.io.yes = True 170 | self.coder.commands.io = self.io 171 | 172 | # Add reflection tracking 173 | self.num_reflections = 0 174 | self.max_reflections = 3 175 | 176 | def _restore_env_vars(self): 177 | """Restore original environment variables""" 178 | if hasattr(self, '_original_anthropic_key'): 179 | if self._original_anthropic_key is not None: 180 | os.environ["ANTHROPIC_API_KEY"] = self._original_anthropic_key 181 | elif "ANTHROPIC_API_KEY" in os.environ: 182 | del os.environ["ANTHROPIC_API_KEY"] 183 | 184 | if hasattr(self, '_original_openai_key'): 185 | if self._original_openai_key is not None: 186 | os.environ["OPENAI_API_KEY"] = self._original_openai_key 187 | elif "OPENAI_API_KEY" in os.environ: 188 | del os.environ["OPENAI_API_KEY"] 189 | 190 | def __del__(self): 191 | """Cleanup when object is destroyed""" 192 | self._restore_env_vars() 193 | 194 | def _get_coder(self): 195 | """Initialize and validate coder instance similar to Streamlit app.""" 196 | # Validate API key is set for selected model 197 | anthropic_models = ["sonnet", "haiku"] 198 | if self.model in anthropic_models: 199 | if not self.api_key: 200 | raise ValueError(f"Anthropic API key not found for {self.model} model") 201 | else: 202 | if not self.api_key: 203 | raise ValueError(f"OpenAI API key not found for {self.model} model") 204 | 205 | # Add API key to command arguments 206 | api_flag = "--anthropic-api-key" if self.model in anthropic_models else "--openai-api-key" 207 | argv = [ 208 | "--model", self.model, 209 | "--yes-always", 210 | "--stream", 211 | "--map-refresh", "auto", 212 | api_flag, self.api_key, # Explicitly pass API key 213 | str(self.repo_path) 214 | ] 215 | argv = [arg for arg in argv if arg] 216 | 217 | coder = cli_main( 218 | argv=argv, 219 | input=None, 220 | output=None, 221 | force_git_root=str(self.repo_path), 222 | return_coder=True 223 | ) 224 | 225 | # Validate coder instance 226 | if not isinstance(coder, Coder): 227 | raise ValueError("Failed to create valid Coder instance") 228 | 229 | # Validate repo 230 | if not coder.repo: 231 | raise ValueError("Aider can currently only be used inside a git repo") 232 | 233 | # Ensure chat mode settings 234 | coder.yield_stream = True 235 | coder.stream = True 236 | 237 | coder.pretty = False # Important for diff formatting 238 | return coder 239 | 240 | def auto_add_files(self, file_paths: List[str]): 241 | """ 242 | Automatically add the specified file paths to the Aider session, 243 | just as if the user had manually added them. 
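        Example (illustrative file names):
            aider.auto_add_files(["src/app.py", "README.md"])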
244 | """ 245 | for path in file_paths: 246 | # Use add_rel_fname to add files to the chat 247 | self.coder.add_rel_fname(path) 248 | self.io.yes = True 249 | # Optionally log or confirm the added files 250 | logger.debug(f"Auto-added files to chat: {file_paths}") 251 | 252 | async def get_diff_info(self): 253 | """Get diff information from the latest changes""" 254 | if not self.coder.last_aider_commit_hash: 255 | return None 256 | 257 | if (self.io._last_commit_hash != self.coder.last_aider_commit_hash): 258 | commits = f"{self.coder.last_aider_commit_hash}~1" 259 | diff = self.coder.repo.diff_commits( 260 | pretty=False, 261 | from_rev=commits, 262 | to_rev=self.coder.last_aider_commit_hash 263 | ) 264 | 265 | edit_info = { 266 | "commit_hash": self.coder.last_aider_commit_hash, 267 | "commit_message": self.coder.last_aider_commit_message, 268 | "files": list(self.coder.aider_edited_files), 269 | "diff": diff 270 | } 271 | 272 | self.io._last_commit_hash = self.coder.last_aider_commit_hash 273 | return edit_info 274 | return None 275 | 276 | async def execute_command(self, message: str) -> AsyncGenerator[Dict[str, Any], None]: 277 | """Execute command with enhanced streaming support.""" 278 | try: 279 | accumulated_content = "" 280 | prompt = message 281 | 282 | # Set up streaming callback 283 | async def stream_handler(chunk): 284 | if isinstance(chunk, AIMessageChunk): 285 | yield { 286 | "type": chunk.additional_kwargs.get("type", "message"), 287 | "content": chunk.content, 288 | "files": chunk.additional_kwargs.get("files", []) 289 | } 290 | 291 | # Add our stream handler to IO callbacks 292 | self.io.callbacks.append(stream_handler) 293 | 294 | while prompt: 295 | stream = self.coder.run_stream(prompt) 296 | for chunk in stream: 297 | # Process the current chunk in real-time 298 | if isinstance(chunk, str): 299 | chunk_dict = {"type": "message", "content": chunk, "files": []} 300 | accumulated_content += chunk 301 | yield chunk_dict 302 | elif hasattr(chunk, 'content'): 303 | chunk_dict = {"type": "message", "content": chunk.content, "files": []} 304 | accumulated_content += chunk.content 305 | yield chunk_dict 306 | elif hasattr(chunk, 'diff'): 307 | chunk_dict = { 308 | "type": "edit", 309 | "content": chunk.diff, 310 | "files": getattr(chunk, 'files', []) 311 | } 312 | yield chunk_dict 313 | 314 | # NEW: Flush the buffer immediately after each chunk 315 | buffered_chunks = self.io.get_buffer() 316 | for buffered_chunk in buffered_chunks: 317 | yield { 318 | "type": buffered_chunk.additional_kwargs.get("type", "message"), 319 | "content": buffered_chunk.content, 320 | "files": buffered_chunk.additional_kwargs.get("files", []) 321 | } 322 | 323 | # Check for reflections 324 | prompt = None 325 | if hasattr(self.coder, 'reflected_message') and self.coder.reflected_message: 326 | if self.num_reflections < self.max_reflections: 327 | self.num_reflections += 1 328 | prompt = self.coder.reflected_message 329 | yield { 330 | "type": "reflection", 331 | "content": f"Reflection {self.num_reflections}: {prompt}", 332 | "files": [] 333 | } 334 | 335 | # Final message completion 336 | yield { 337 | "type": "complete", 338 | "content": accumulated_content, 339 | "files": [] 340 | } 341 | 342 | # Clean up 343 | self.io.callbacks.remove(stream_handler) 344 | 345 | except Exception as e: 346 | logger.error(f"Error in execute_command: {str(e)}") 347 | logger.error("Full traceback:", exc_info=True) 348 | yield {"type": "error", "content": str(e), "files": []} 349 | raise 
-------------------------------------------------------------------------------- /src/langgraph_engineer/tools.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Annotated, List, Optional, TypedDict, Literal, Union, ClassVar, Type, Any, AsyncGenerator 2 | from typing_extensions import TypedDict 3 | from langchain_core.tools import tool, BaseTool 4 | from langchain_core.runnables import RunnableConfig 5 | from langgraph.prebuilt import ToolNode, InjectedState 6 | from git import Repo, RemoteProgress, GitCommandError 7 | from git.exc import GitCommandError, InvalidGitRepositoryError 8 | import os 9 | import logging 10 | import time 11 | import shutil 12 | from pathlib import Path 13 | from urllib.parse import urlparse, urlunparse 14 | from dotenv import load_dotenv 15 | import subprocess 16 | import asyncio 17 | from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk 18 | from langgraph.graph.message import add_messages 19 | from langgraph.store.base import BaseStore 20 | import glob 21 | import shlex 22 | from pydantic import BaseModel, Field, ConfigDict 23 | import json 24 | from langchain_core.callbacks import CallbackManagerForToolRun 25 | from pydantic import ConfigDict 26 | from logging import Logger 27 | import traceback 28 | import pty 29 | import fcntl 30 | import termios 31 | import struct 32 | from langgraph_engineer.state import ( 33 | AiderState, 34 | AgentState 35 | ) 36 | from typing import TypedDict, List 37 | from langchain_core.messages import BaseMessage 38 | from langgraph.graph.message import add_messages 39 | import sys 40 | 41 | 42 | 43 | load_dotenv() 44 | logger = logging.getLogger(__name__) 45 | 46 | openai_api_key = os.getenv('OPENAI_API_KEY') 47 | 48 | # Add these constants at the top of the file 49 | OPENAI_API_KEY = os.getenv('OPENAI_API_KEY') 50 | ANTHROPIC_API_KEY = os.getenv('ANTHROPIC_API_KEY') 51 | 52 | if not OPENAI_API_KEY: 53 | logger.warning("OPENAI_API_KEY not found in environment variables") 54 | if not ANTHROPIC_API_KEY: 55 | logger.warning("ANTHROPIC_API_KEY not found in environment variables") 56 | 57 | # Alias for backward compatibility 58 | RepoState = AgentState 59 | 60 | 61 | 62 | # class WriteFileArgs(BaseModel): 63 | # """Arguments for write_file tool""" 64 | # file_path: str = Field(description="Path to the file relative to repository root") 65 | # changes: Optional[List[FunctionChange]] = Field( 66 | # None, 67 | # description="List of function-level changes to apply" 68 | # ) 69 | # content: Optional[str] = Field( 70 | # None, 71 | # description="Full content if replacing entire file" 72 | # ) 73 | 74 | # @validator('changes', 'content', allow_reuse=True) 75 | # def validate_changes_or_content(cls, v: Optional[str], values: dict, field: str) -> Optional[str]: 76 | # if field == 'changes' and not v and 'content' not in values: 77 | # raise ValueError("Either changes or content must be provided") 78 | # if field == 'content' and not v and 'changes' not in values: 79 | # raise ValueError("Either changes or content must be provided") 80 | # return v 81 | 82 | def validate_file_path(repo_path: str, file_path: str) -> Path: 83 | """Validate and normalize file path.""" 84 | try: 85 | repo_path = Path(repo_path).resolve() 86 | full_path = (repo_path / file_path).resolve() 87 | 88 | # Security check: ensure path is within repo 89 | if not str(full_path).startswith(str(repo_path)): 90 | raise ValueError(f"File path {file_path} is outside repository") 91 | 92 | return 
full_path 93 | except Exception as e: 94 | raise ValueError(f"Invalid file path: {str(e)}") 95 | 96 | 97 | def validate_repo_path(state: dict) -> str: 98 | """Validate repository path exists in state.""" 99 | repo_path = state.get('repo_path') 100 | if not repo_path: 101 | raise ValueError("Repository path not found in state") 102 | 103 | # Add retry logic for filesystem sync 104 | max_retries = 3 105 | retry_delay = 0.5 # seconds 106 | 107 | for attempt in range(max_retries): 108 | if os.path.exists(repo_path): 109 | return repo_path 110 | if attempt < max_retries - 1: 111 | time.sleep(retry_delay) 112 | 113 | raise ValueError(f"Repository path does not exist: {repo_path}") 114 | 115 | 116 | 117 | 118 | class StepResult(BaseModel): 119 | """Result of a single execution step.""" 120 | step_id: str 121 | output: str 122 | success: bool 123 | 124 | class AiderToReactOutput(BaseModel): 125 | """Output format for aider commands.""" 126 | response: str 127 | success: bool = True 128 | error: Optional[str] = None 129 | 130 | class AiderShellInput(BaseModel): 131 | """Input schema for aider shell commands.""" 132 | message: str = Field(..., description="The message/instruction for aider") 133 | files: Union[str, List[str]] = Field( 134 | default=".", 135 | description="Files to process. Can be a single file or a list of files." 136 | ) 137 | 138 | class AiderCommand(BaseModel): 139 | """Model for aider command input""" 140 | command_type: Literal['ask', 'code'] = Field(description="Type of aider command") 141 | prompt: str = Field(description="The prompt to send to aider") 142 | 143 | def to_cli_format(self) -> str: 144 | """Convert to CLI command format""" 145 | return f"/{self.command_type} {self.prompt}" 146 | 147 | 148 | 149 | class ForceCloneInput(BaseModel): 150 | """Input schema for force_clone tool""" 151 | url: str = Field(description="Repository URL to clone") 152 | path: str = Field(description="Local path to clone to") 153 | state: Optional[Dict] = Field(default=None, description="State object") 154 | config: Optional[Dict] = Field(default=None, description="Config object") 155 | 156 | @tool(args_schema=ForceCloneInput) 157 | async def force_clone( 158 | input: ForceCloneInput, 159 | ) -> dict: 160 | """Force clone a repository to the specified path with retry logic.""" 161 | try: 162 | # Extract values from input 163 | url = input.url 164 | path = input.path 165 | state = input.state 166 | config = input.config 167 | 168 | logger.info(f"Starting clone operation for URL: {url.replace(state.get('github_token', ''), '*****')}") 169 | logger.info(f"Target path: {path}") 170 | logger.debug(f"State keys available: {list(state.keys() if state else [])}") 171 | 172 | # Validate state 173 | if not state: 174 | logger.error("State object is missing") 175 | return { 176 | "status": "error", 177 | "error": "State object is missing" 178 | } 179 | 180 | # Validate github token exists in state 181 | github_token = state.get('github_token') 182 | if not github_token: 183 | logger.error("GitHub token not found in state") 184 | return { 185 | "status": "error", 186 | "error": "GitHub token not found in state" 187 | } 188 | 189 | # Log directory state before cleanup 190 | if os.path.exists(path): 191 | logger.info(f"Existing directory found at {path}, will be removed") 192 | try: 193 | shutil.rmtree(path) 194 | logger.debug("Successfully cleaned up existing directory") 195 | except Exception as e: 196 | logger.error(f"Error cleaning up directory: {str(e)}") 197 | raise 198 | 199 | # Create parent 
directory 200 | try: 201 | os.makedirs(os.path.dirname(path), exist_ok=True) 202 | logger.debug(f"Created parent directory: {os.path.dirname(path)}") 203 | except Exception as e: 204 | logger.error(f"Failed to create parent directory: {str(e)}") 205 | raise 206 | 207 | # Parse and modify the URL to include the token 208 | try: 209 | parsed = urlparse(url) 210 | auth_url = urlunparse(parsed._replace( 211 | netloc=f"{github_token}@{parsed.netloc}" 212 | )) 213 | logger.debug(f"Authenticated URL created (token hidden): {auth_url.replace(github_token, '*****')}") 214 | except Exception as e: 215 | logger.error(f"Failed to create authenticated URL: {str(e)}") 216 | raise 217 | 218 | # Retry logic for clone operation 219 | max_retries = 3 220 | retry_delay = 2 # seconds 221 | last_exception = None 222 | 223 | for attempt in range(max_retries): 224 | logger.info(f"Starting clone attempt {attempt + 1}/{max_retries}") 225 | try: 226 | # Configure git with longer timeout 227 | git_config = { 228 | 'http.postBuffer': '524288000', # 500MB buffer 229 | 'http.lowSpeedLimit': '1000', # 1KB/s minimum speed 230 | 'http.lowSpeedTime': '60', # 60 seconds timeout 231 | 'core.compression': '0', # Disable compression 232 | 'http.version': 'HTTP/1.1', # Force HTTP/1.1 233 | 'git.protocol.version': '1' # Force Git protocol version 1 234 | } 235 | 236 | logger.debug("Initializing repository") 237 | repo = Repo.init(path) 238 | 239 | logger.debug("Applying Git configurations:") 240 | for key, value in git_config.items(): 241 | try: 242 | with repo.config_writer() as git_config_writer: 243 | git_config_writer.set_value(key.split('.')[0], key.split('.')[1], value) 244 | logger.debug(f"Set {key}={value}") 245 | except Exception as config_error: 246 | logger.warning(f"Failed to set config {key}: {str(config_error)}") 247 | 248 | # Perform the clone steps 249 | logger.debug("Creating remote 'origin'") 250 | repo.create_remote('origin', auth_url) 251 | 252 | logger.debug("Starting fetch operation") 253 | fetch_info = repo.remote('origin').fetch(progress=GitProgressHandler()) 254 | logger.debug(f"Fetch completed: {fetch_info}") 255 | 256 | logger.debug("Starting pull operation") 257 | pull_info = repo.remote('origin').pull('main', progress=GitProgressHandler()) 258 | logger.debug(f"Pull completed: {pull_info}") 259 | 260 | # Verify the clone 261 | if not os.path.exists(path): 262 | raise ValueError(f"Repository was cloned but path {path} does not exist") 263 | 264 | # Verify git repository integrity 265 | try: 266 | repo.git.status() 267 | logger.debug("Repository integrity verified") 268 | except Exception as integrity_error: 269 | logger.error(f"Repository integrity check failed: {str(integrity_error)}") 270 | raise 271 | 272 | logger.info(f"Successfully cloned repository to {path} on attempt {attempt + 1}") 273 | 274 | # Update state with repo path 275 | if state is not None: 276 | state["repo_path"] = path 277 | logger.debug("Updated state with repo path") 278 | 279 | return { 280 | "status": "success", 281 | "repo_path": path, 282 | "message": f"Repository cloned successfully to {path}" 283 | } 284 | 285 | except Exception as e: 286 | last_exception = e 287 | logger.error(f"Clone attempt {attempt + 1} failed with error: {str(e)}") 288 | logger.error("Full error traceback:", exc_info=True) 289 | 290 | # Log system state 291 | try: 292 | logger.debug(f"Directory exists: {os.path.exists(path)}") 293 | if os.path.exists(path): 294 | logger.debug(f"Directory contents: {os.listdir(path)}") 295 | except Exception as 
debug_error: 296 | logger.error(f"Error during debug logging: {str(debug_error)}") 297 | 298 | # Cleanup failed attempt 299 | if os.path.exists(path): 300 | try: 301 | shutil.rmtree(path) 302 | logger.debug("Cleaned up failed attempt directory") 303 | except Exception as cleanup_error: 304 | logger.error(f"Error cleaning up failed attempt: {str(cleanup_error)}") 305 | 306 | if attempt < max_retries - 1: 307 | delay = retry_delay * (attempt + 1) 308 | logger.info(f"Waiting {delay} seconds before retry...") 309 | await asyncio.sleep(delay) 310 | continue 311 | break 312 | 313 | # If all retries failed, log final error state 314 | logger.error(f"All {max_retries} clone attempts failed. Last error: {str(last_exception)}") 315 | return { 316 | "status": "error", 317 | "error": str(last_exception), 318 | "details": { 319 | "attempts": attempt + 1, 320 | "last_error": str(last_exception), 321 | "path": path, 322 | "url": url.replace(github_token, '*****') 323 | } 324 | } 325 | 326 | except Exception as e: 327 | logger.error(f"Unexpected error in force_clone: {str(e)}") 328 | logger.error("Full error traceback:", exc_info=True) 329 | return { 330 | "status": "error", 331 | "error": str(e), 332 | "details": { 333 | "type": "unexpected_error", 334 | "path": path if 'path' in locals() else None, 335 | "url": url.replace(github_token, '*****') if 'url' in locals() else None 336 | } 337 | } 338 | 339 | class ForceBranchInput(BaseModel): 340 | """Input schema for force_branch tool""" 341 | branch_name: str = Field(description="Name of branch to create") 342 | config: Optional[Dict] = Field(default=None, description="Config object") 343 | state: Dict = Field(description="State object") 344 | 345 | @tool(args_schema=ForceBranchInput) 346 | async def force_branch( 347 | input: ForceBranchInput, 348 | ) -> dict: 349 | """Force create and checkout a new branch.""" 350 | try: 351 | repo_path = validate_repo_path(input.state) 352 | repo = Repo(repo_path) 353 | 354 | # Check if branch already exists 355 | branch_name = input.branch_name 356 | if branch_name in repo.heads: 357 | # If branch exists, just checkout 358 | current = repo.heads[branch_name] 359 | else: 360 | # Create and checkout new branch 361 | current = repo.create_head(branch_name) 362 | 363 | # Checkout the branch 364 | current.checkout() 365 | 366 | logger.info(f"Created and checked out branch: {branch_name}") 367 | return { 368 | "status": "success", 369 | "branch_name": branch_name, 370 | "message": f"Created and checked out branch: {branch_name}" 371 | } 372 | except Exception as e: 373 | logger.error(f"Error creating branch: {str(e)}") 374 | return { 375 | "status": "error", 376 | "error": str(e) 377 | } 378 | 379 | # Add a proper input schema for git_status 380 | class GitStatusInput(BaseModel): 381 | """Input schema for git_status tool""" 382 | config: Optional[Dict] = Field(default=None, description="Tool configuration") 383 | state: Dict = Field(description="Current state") 384 | 385 | @tool(args_schema=GitStatusInput) 386 | def git_status( 387 | config: RunnableConfig, 388 | state: Annotated[RepoState, InjectedState] 389 | ) -> str: 390 | """Get the current git status of the repository.""" 391 | try: 392 | repo = Repo(state['repo_path']) 393 | status = [] 394 | 395 | # Get branch info 396 | try: 397 | branch = repo.active_branch 398 | status.append(f"On branch {branch.name}") 399 | except TypeError: 400 | status.append("Not currently on any branch") 401 | 402 | # Get tracking info 403 | if not repo.head.is_detached: 404 | 
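            # detached HEAD has no active_branch, so tracking info is only queried on a real branch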
tracking_branch = repo.active_branch.tracking_branch() 405 | if tracking_branch: 406 | status.append(f"Tracking {tracking_branch.name}") 407 | 408 | # Get changed files using GitPython's native methods 409 | changed_files = [item.a_path for item in repo.index.diff(None)] 410 | staged_files = [item.a_path for item in repo.index.diff('HEAD')] 411 | untracked = repo.untracked_files 412 | 413 | if staged_files: 414 | status.append("\nChanges to be committed:") 415 | status.extend(f" modified: {file}" for file in staged_files) 416 | 417 | if changed_files: 418 | status.append("\nChanges not staged for commit:") 419 | status.extend(f" modified: {file}" for file in changed_files) 420 | 421 | if untracked: 422 | status.append("\nUntracked files:") 423 | status.extend(f" {file}" for file in untracked) 424 | 425 | return '\n'.join(status) 426 | except InvalidGitRepositoryError: 427 | return "Error: Not a valid git repository" 428 | except Exception as e: 429 | return f"Error getting git status: {str(e)}" 430 | 431 | # Add new input schemas before the tool definitions 432 | class GitAddInput(BaseModel): 433 | """Input schema for git_add tool""" 434 | file_path: str = Field(description="Path to file to stage") 435 | config: Optional[Dict] = Field(default=None, description="Tool configuration") 436 | state: Dict = Field(description="Current state") 437 | 438 | class GitCommitInput(BaseModel): 439 | """Input schema for git_commit tool""" 440 | message: str = Field(description="Commit message") 441 | config: Optional[Dict] = Field(default=None, description="Tool configuration") 442 | state: Dict = Field(description="Current state") 443 | 444 | class GitPushInput(BaseModel): 445 | """Input schema for git_push tool""" 446 | config: Optional[Dict] = Field(default=None, description="Tool configuration") 447 | state: Dict = Field(description="Current state") 448 | 449 | # Update tool decorators and keep existing implementations 450 | @tool(args_schema=GitAddInput) 451 | def git_add( 452 | file_path: str, 453 | config: RunnableConfig, 454 | state: Annotated[RepoState, InjectedState] 455 | ) -> str: 456 | """Stage a file for commit.""" 457 | try: 458 | repo = Repo(state['repo_path']) 459 | 460 | # Handle wildcards and multiple files 461 | if file_path == '.': 462 | # Stage all changes including untracked files 463 | repo.git.add(A=True) 464 | return "Successfully staged all changes" 465 | 466 | # Validate file exists 467 | full_path = Path(repo.working_dir) / file_path 468 | if not full_path.exists(): 469 | return f"Error: File {file_path} does not exist" 470 | 471 | # Stage specific file 472 | repo.index.add([file_path]) 473 | 474 | # Verify file was staged 475 | staged_files = [item.a_path for item in repo.index.diff('HEAD')] 476 | if file_path in staged_files: 477 | return f"Successfully staged {file_path}" 478 | else: 479 | return f"File {file_path} was not staged (no changes detected)" 480 | 481 | except Exception as e: 482 | return f"Error staging file: {str(e)}" 483 | 484 | @tool(args_schema=GitCommitInput) 485 | def git_commit( 486 | message: str, 487 | config: RunnableConfig, 488 | state: Annotated[RepoState, InjectedState] 489 | ) -> str: 490 | """Commit staged changes.""" 491 | try: 492 | repo = Repo(state['repo_path']) 493 | 494 | # Check if there are staged changes 495 | if not repo.index.diff('HEAD'): 496 | return "No changes staged for commit" 497 | 498 | # Configure author/committer if available in state 499 | author = None 500 | if 'git_author_name' in state and 'git_author_email' in state: 501 
| author = f"{state['git_author_name']} <{state['git_author_email']}>" 502 | 503 | # Commit with optional author 504 | if author: 505 | commit = repo.index.commit(message, author=author) 506 | else: 507 | commit = repo.index.commit(message) 508 | 509 | # Return detailed commit info 510 | return (f"Successfully committed with hash: {commit.hexsha[:8]}\n" 511 | f"Author: {commit.author}\n" 512 | f"Message: {commit.message}") 513 | 514 | except Exception as e: 515 | return f"Error committing changes: {str(e)}" 516 | 517 | @tool(args_schema=GitPushInput) 518 | def git_push( 519 | config: RunnableConfig, 520 | state: Annotated[RepoState, InjectedState] 521 | ) -> str: 522 | """Push commits to remote repository with enhanced error handling and retry logic.""" 523 | try: 524 | repo_path = validate_repo_path(state) 525 | branch_name = state.get('branch_name') 526 | repo_url = state.get('repo_url') 527 | github_token = state.get('github_token') # Get token from state 528 | 529 | if not github_token: 530 | raise ValueError("GitHub token not found in state") 531 | 532 | logger.info(f"Push attempt with: path={repo_path}, branch={branch_name}") 533 | 534 | # Initialize repo 535 | repo = Repo(repo_path) 536 | 537 | # Configure the remote with authentication 538 | remote_url = configure_remote_with_auth(repo, repo_url, github_token) 539 | 540 | # Initialize progress handler 541 | progress = GitProgressHandler() 542 | 543 | # deploy 544 | 545 | 546 | # Add retry logic for push operation 547 | max_retries = 3 548 | retry_delay = 1 549 | last_error = None 550 | 551 | for attempt in range(max_retries): 552 | try: 553 | # Push with progress monitoring 554 | push_info = repo.remote('origin').push( 555 | refspec=f"refs/heads/{branch_name}:refs/heads/{branch_name}", 556 | force=True, 557 | progress=progress 558 | ) 559 | 560 | # Detailed push result checking 561 | for info in push_info: 562 | if info.flags & info.ERROR: 563 | raise GitCommandError(f"Push failed: {info.summary}") 564 | if info.flags & info.FAST_FORWARD: 565 | logger.info("Fast-forward push successful") 566 | if info.flags & info.FORCED_UPDATE: 567 | logger.info("Forced update successful") 568 | 569 | return f"Successfully pushed to branch: {branch_name}" 570 | 571 | except Exception as e: 572 | last_error = e 573 | logger.error(f"Push attempt {attempt + 1} failed: {e}") 574 | if attempt < max_retries - 1: 575 | time.sleep(retry_delay) 576 | continue 577 | break 578 | 579 | raise last_error if last_error else ValueError("Push failed with unknown error") 580 | 581 | except Exception as e: 582 | error_msg = f"Error pushing changes: {str(e)}" 583 | logger.error(error_msg, exc_info=True) 584 | return error_msg 585 | 586 | finally: 587 | # Cleanup sensitive information 588 | cleanup_remote(repo) 589 | 590 | @tool 591 | def git_diff( 592 | config: RunnableConfig, 593 | state: Annotated[RepoState, InjectedState] 594 | ) -> str: 595 | """Show the diff of the latest commit or working directory changes.""" 596 | try: 597 | repo = Repo(state['repo_path']) 598 | 599 | # Get the diff of staged and unstaged changes 600 | diff_output = [] 601 | 602 | # Check for staged changes (diff between HEAD and index) 603 | staged_diff = repo.git.diff('--cached') 604 | if staged_diff: 605 | diff_output.append("=== Staged Changes ===") 606 | diff_output.append(staged_diff) 607 | 608 | # Check for unstaged changes (diff between index and working tree) 609 | unstaged_diff = repo.git.diff() 610 | if unstaged_diff: 611 | diff_output.append("\n=== Unstaged Changes ===") 612 | 
diff_output.append(unstaged_diff) 613 | 614 | # If no current changes, show the last commit diff 615 | if not diff_output: 616 | if len(repo.heads) > 0: # Check if there are any commits 617 | last_commit = repo.head.commit 618 | if last_commit.parents: # If commit has a parent 619 | diff_output.append("=== Last Commit Diff ===") 620 | diff_output.append(repo.git.diff(f'{last_commit.parents[0].hexsha}..{last_commit.hexsha}')) 621 | else: # First commit 622 | diff_output.append("=== Initial Commit Diff ===") 623 | diff_output.append(repo.git.diff(last_commit.hexsha)) 624 | else: 625 | return "No commits in the repository yet." 626 | 627 | return "\n".join(diff_output) if diff_output else "No changes detected" 628 | 629 | except Exception as e: 630 | return f"Error getting diff: {str(e)}" 631 | 632 | 633 | # class AiderShellInput(BaseModel): 634 | # """Input schema for aider_shell tool.""" 635 | # message: str = Field(..., description="The message/instruction for aider.") 636 | # files: Union[str, List[str]] = Field( 637 | # default=".", 638 | # description="Files to process. Can be a single file or a list of files." 639 | # ) 640 | 641 | class AiderShellTool(BaseTool): 642 | """Tool to run aider shell commands.""" 643 | 644 | # Model configuration to ignore logger 645 | model_config = ConfigDict(arbitrary_types_allowed=True) 646 | 647 | # Class variables with proper type annotations 648 | name: ClassVar[str] = "aider_shell" 649 | description: ClassVar[str] = "Run aider commands for code modifications. Use this to interact with the codebase using aider." 650 | args_schema: ClassVar[Type[BaseModel]] = AiderShellInput 651 | # Add logger as ClassVar to indicate it's not a model field 652 | logger: ClassVar[Logger] = logging.getLogger(__name__) 653 | 654 | # Add state management and new fields for message/files 655 | state: Optional[AgentState] = Field(None, exclude=True) 656 | message: Optional[str] = Field(None, exclude=True) 657 | files: Optional[Union[str, List[str]]] = Field(None, exclude=True) 658 | 659 | def __init__(self, state: Optional[AgentState] = None): 660 | super().__init__() 661 | self.state = state 662 | 663 | def _run( 664 | self, 665 | message: str, 666 | files: Union[str, List[str]] = ".", 667 | run_manager: Optional[CallbackManagerForToolRun] = None, 668 | **kwargs 669 | ) -> str: 670 | """Run the tool synchronously.""" 671 | return asyncio.run(self._arun(message, files, run_manager, **kwargs)) 672 | 673 | async def _arun( 674 | self, 675 | message: Optional[str] = None, 676 | files: Union[str, List[str]] = ".", 677 | run_manager: Optional[CallbackManagerForToolRun] = None, 678 | **kwargs 679 | ) -> str: 680 | """Run the tool asynchronously.""" 681 | try: 682 | if not self.state: 683 | raise ValueError("State is required - ensure it's being passed to AiderShellTool.") 684 | 685 | # Ensure we have the API key 686 | anthropic_api_key = self.state.get('anthropic_api_key') or os.getenv('ANTHROPIC_API_KEY') 687 | if not anthropic_api_key: 688 | raise ValueError("ANTHROPIC_API_KEY not found in state or environment") 689 | openai_api_key = os.getenv('OPENAI_API_KEY') 690 | # Use pre-configured values if not provided in call 691 | message = message or self.message 692 | files = files or self.files 693 | 694 | if not message: 695 | raise ValueError("No message provided for aider command") 696 | 697 | self.logger.info(f"Aider command executing with message: {message}") 698 | self.logger.info(f"Aider command executing with files: {files}") 699 | 700 | # Initialize aider_state if None 
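            # (a checkpointer may hand aider_state back as a plain dict, hence the
            # dict -> AiderState conversion in the elif branch below)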
701 | if self.state.get('aider_state') is None: 702 | self.state['aider_state'] = AiderState( 703 | initialized=True, 704 | model_name='gpt-4o-mini', 705 | conversation_history=[], 706 | last_files=[], 707 | waiting_for_input=False, 708 | setup_complete=True 709 | ) 710 | elif isinstance(self.state['aider_state'], dict): 711 | # Convert dict to AiderState if necessary 712 | self.state['aider_state'] = AiderState(**self.state['aider_state']) 713 | 714 | aider_state = self.state['aider_state'] 715 | 716 | repo_path = self.state.get('repo_path') 717 | if not repo_path: 718 | raise ValueError("Repository path not found in state") 719 | 720 | # Process files argument 721 | if isinstance(files, list): 722 | files = [str(Path(repo_path) / Path(f).name) for f in files] 723 | else: 724 | files = str(Path(repo_path) / Path(files).name) 725 | 726 | # Escape message for shell 727 | escaped_message = shlex.quote(message) 728 | 729 | # Construct aider command with --yes flag to avoid prompts 730 | aider_cmd = f"aider --yes --no-stream --message {escaped_message} {files}" 731 | 732 | self.logger.info(f"Executing aider command: {aider_cmd} in {repo_path}") 733 | 734 | # Create process 735 | process = await asyncio.create_subprocess_shell( 736 | aider_cmd, 737 | stdout=asyncio.subprocess.PIPE, 738 | stderr=asyncio.subprocess.PIPE, 739 | cwd=repo_path 740 | ) 741 | 742 | try: 743 | stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=300) 744 | stdout_decoded = stdout.decode() 745 | stderr_decoded = stderr.decode() 746 | 747 | self.logger.info(f"Aider stdout: {stdout_decoded}") 748 | if stderr_decoded: 749 | self.logger.error(f"Aider stderr: {stderr_decoded}") 750 | 751 | if process.returncode != 0: 752 | raise RuntimeError(f"Aider command failed with exit code {process.returncode}") 753 | 754 | response = self._process_aider_output(stdout_decoded, stderr_decoded) 755 | 756 | # Update conversation history 757 | conversation_entry = { 758 | "message": message, 759 | "files": files, 760 | "response": response, 761 | "working_directory": repo_path 762 | } 763 | aider_state.conversation_history.append(conversation_entry) 764 | aider_state.last_prompt = message 765 | aider_state.last_files = files if isinstance(files, list) else [files] 766 | 767 | # After processing aider output 768 | # Update state messages 769 | if self.state: 770 | if 'messages' not in self.state: 771 | self.state['messages'] = [] 772 | self.state['messages'].append(AIMessage(content=response)) 773 | 774 | return response 775 | 776 | except asyncio.TimeoutError: 777 | process.kill() 778 | raise TimeoutError("Aider command timed out after 300 seconds") 779 | 780 | except Exception as e: 781 | error_msg = f"Error in aider_shell: {str(e)}" 782 | self.logger.error(error_msg, exc_info=True) 783 | # Log the traceback for detailed debugging 784 | self.logger.error(traceback.format_exc()) 785 | 786 | # Return error with current state 787 | return str(e) 788 | 789 | def _process_aider_output(self, stdout: str, stderr: str) -> str: 790 | """Process aider output to extract meaningful response.""" 791 | # As the aider shell no longer prompts, we can focus on capturing the output 792 | # Split output into lines 793 | lines = stdout.strip().split('\n') 794 | 795 | # Remove any empty lines and irrelevant output (headers, footers) 796 | filtered_lines = [ 797 | line for line in lines 798 | if not line.strip().startswith('Aider v') 799 | and not line.strip().startswith('Main model:') 800 | and not line.strip().startswith('Git repo:') 801 | 
and not line.strip().startswith('Repo-map:') 802 | and not line.strip().startswith('VSCode') 803 | and not line.strip().startswith('Use /help') 804 | and not line.strip().startswith('Tokens:') 805 | and line.strip() 806 | ] 807 | 808 | response = '\n'.join(filtered_lines).strip() 809 | 810 | # Add any error messages if present 811 | if stderr.strip(): 812 | response += f"\nErrors:\n{stderr.strip()}" 813 | 814 | return response 815 | 816 | class SingleAiderShellTool(BaseTool): 817 | """Modified tool to handle single step execution""" 818 | 819 | # Model configuration to ignore logger 820 | model_config = ConfigDict(arbitrary_types_allowed=True) 821 | 822 | # Class variables with proper type annotations 823 | name: ClassVar[str] = "single_aider_shell" 824 | description: ClassVar[str] = "Execute a single aider command for code modifications" 825 | args_schema: ClassVar[Type[BaseModel]] = AiderShellInput 826 | state: Optional[AgentState] = Field(None, exclude=True) 827 | logger: ClassVar[Logger] = logging.getLogger(__name__) 828 | 829 | def _run( 830 | self, 831 | message: str = None, 832 | files: Union[str, List[str]] = ".", 833 | step_id: Optional[str] = None, 834 | run_manager: Optional[CallbackManagerForToolRun] = None, 835 | **kwargs 836 | ) -> StepResult: 837 | """Execute synchronously by running async method""" 838 | return asyncio.run(self._arun(message, files, step_id, run_manager, **kwargs)) 839 | 840 | async def _arun( 841 | self, 842 | message: str = None, 843 | files: Union[str, List[str]] = ".", 844 | step_id: Optional[str] = None, 845 | run_manager: Optional[CallbackManagerForToolRun] = None, 846 | **kwargs 847 | ) -> StepResult: 848 | """Execute single step and return structured result""" 849 | try: 850 | # Get message from state if not provided 851 | if not message and self.state: 852 | if self.state.get('router_analysis', {}).get('route_type') == 'single-step': 853 | # For single-step, use the original requirements 854 | message = self.state.get('requirements', '') 855 | step_id = "S1" # Single-step ID 856 | else: 857 | # For multi-step, get from current step 858 | current_step = self.state['steps'][self.state['current_step']] 859 | message = current_step.tool_args['message'] 860 | files = current_step.tool_args['files'] 861 | step_id = current_step.step_id 862 | 863 | if not message: 864 | raise ValueError("No message content provided for aider tool") 865 | 866 | # Create instance of parent class for this call 867 | parent_tool = AiderShellTool(state=self.state) 868 | 869 | # Execute aider command using parent's implementation 870 | result = await parent_tool._arun( 871 | message=message, 872 | files=".", # Always use "." 
900 | # Add after existing tool input classes (around line 128)
901 | class InteractiveAiderInput(BaseModel):
902 |     """Input schema for the interactive aider command."""
903 |     message: str = Field(description="The message/instruction for aider")
904 |     state: Dict[str, Any] = Field(description="Current state object")
905 |     files: Optional[Union[str, List[str]]] = Field(
906 |         default=None,
907 |         description="Files to process"
908 |     )
909 |
910 |     model_config = ConfigDict(arbitrary_types_allowed=True)
911 |
912 | @tool(args_schema=InteractiveAiderInput)
913 | async def interactive_aider_command(
914 |     message: str, state: Dict[str, Any], files: Optional[Union[str, List[str]]] = None,
915 | ) -> AsyncGenerator[AIMessageChunk, None]:
916 |     """Execute an aider command and stream responses."""
917 |     try:
918 |         repo_path = state.get('repo_path')
919 |         if not repo_path:
920 |             raise ValueError("No repo_path in state")
921 |
922 |         aider = InteractiveAider(repo_path=repo_path)
923 |         async for raw_chunk in aider.execute_command(message):
924 |             # Ensure we have a valid chunk structure
925 |             if isinstance(raw_chunk, dict):
926 |                 chunk_type = raw_chunk.get("type", "message")
927 |                 chunk_content = raw_chunk.get("content", "")
928 |                 chunk_files = raw_chunk.get("files", [])
929 |             else:
930 |                 chunk_type = "message"
931 |                 chunk_content = str(raw_chunk)
932 |                 chunk_files = []
933 |
934 |             yield AIMessageChunk(
935 |                 content=chunk_content,
936 |                 additional_kwargs={
937 |                     "type": chunk_type,
938 |                     "files": chunk_files
939 |                 }
940 |             )
941 |
942 |     except Exception as e:
943 |         yield AIMessageChunk(
944 |             content=str(e),
945 |             additional_kwargs={"type": "error"}
946 |         )
947 |
948 |
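To show the chunk protocol the streaming tool emits, here is a runnable consumer against a stand-in stream; the fake generator only imitates the shape of `interactive_aider_command`'s output:

```python
import asyncio
from langchain_core.messages import AIMessageChunk

async def fake_stream():
    # Stand-in for interactive_aider_command's output shape.
    yield AIMessageChunk(content="Reading main.py",
                         additional_kwargs={"type": "message", "files": []})
    yield AIMessageChunk(content="Applied edit",
                         additional_kwargs={"type": "edit", "files": ["main.py"]})

async def drain(stream):
    """Print each chunk's type tag and content as it arrives."""
    async for chunk in stream:
        kind = chunk.additional_kwargs.get("type", "message")
        print(f"[{kind}] {chunk.content}")

asyncio.run(drain(fake_stream()))
```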
949 | # Create tools when needed, with state:
950 | def create_aider_tools(state: Optional[AgentState] = None) -> List[BaseTool]:
951 |     """Create aider tools with state."""
952 |     return [SingleAiderShellTool(state=state)]
953 |
954 | # Update the aider_node creation
955 | aider_node = lambda state: ToolNode(tools=create_aider_tools(state))
956 |
957 | # Keep only these tools
958 | setup_tools = [force_clone, force_branch]
959 |
960 | regular_tools = [
961 |     git_status,
962 |     git_add,
963 |     git_commit,
964 |     git_push,
965 |     git_diff,
966 |     interactive_aider_command  # Add the new interactive tool
967 | ]
968 |
969 | git_tools = [git_status, git_add, git_commit, git_push, git_diff]
970 |
971 | # aider_tools = [SingleAiderShellTool()]
972 |
973 | # Combined tools list for the react agent
974 | tools = setup_tools + regular_tools
975 |
976 | # Create separate tool nodes
977 | setup_node = ToolNode(tools=setup_tools)
978 | tool_node = ToolNode(tools=regular_tools)
979 |
980 | # At the bottom of tools.py, update the exports
981 | __all__ = [
982 |     'SingleAiderShellTool',
983 |     'AiderShellTool',
984 |     'create_aider_tools',
985 |     'interactive_aider_command',  # Add the new tool
986 |     'force_clone',
987 |     'force_branch',
988 |     'git_status',
989 |     'git_add',
990 |     'git_commit',
991 |     'git_push',
992 |     'git_diff'
993 | ]
994 |
995 | class AiderCommandInput(BaseModel):
996 |     message: str = Field(..., description="The instruction for aider")
997 |     files: Union[str, List[str]] = Field(..., description="Files to process")
998 |
999 | @tool(args_schema=AiderCommandInput)
1000 | async def aider_command(
1001 |     message: str,
1002 |     files: Union[str, List[str]],
1003 |     config: RunnableConfig,
1004 | ) -> str:
1005 |     """Execute an aider command on specified files in a container environment."""
1006 |     try:
1007 |         repo_path = config.get("configurable", {}).get("repo_path")
1008 |         chat_mode = config.get("configurable", {}).get("chat_mode", "")
1009 |
1010 |         # Log configuration for debugging chat_mode
1011 |         logger.info(f"Aider command configuration - chat_mode: '{chat_mode}', repo_path: '{repo_path}'")
1012 |
1013 |         if not repo_path:
1014 |             raise ValueError("repo_path not found in config")
1015 |
1016 |         repo_path = str(Path(repo_path))
1017 |         os.makedirs(repo_path, exist_ok=True)
1018 |
1019 |         # Handle file paths
1020 |         if isinstance(files, list):
1021 |             files = [
1022 |                 str(Path(f)) if f == "." else str(Path(repo_path) / Path(f).name)
1023 |                 for f in files
1024 |             ]
1025 |         else:
1026 |             files = str(Path(files)) if files == "." else str(Path(repo_path) / Path(files).name)
1027 |
1028 |         # Build the command with proper path handling
1029 |         files_str = " ".join(f'"{f}"' for f in files) if isinstance(files, list) else f'"{files}"'
1030 |         escaped_message = shlex.quote(message)
1031 |
1032 |         # Construct the chat mode argument, with logging
1033 |         chat_mode_arg = ""
1034 |         if chat_mode.strip():
1035 |             # Replace 'code' with 'architect' if specified
1036 |             mode = 'architect' if chat_mode.strip().lower() == 'code' else chat_mode.strip()
1037 |             chat_mode_arg = f"--chat-mode {mode}"
1038 |             logger.info(f"Constructed chat_mode_arg: '{chat_mode_arg}'")
1039 |
1040 |         # Get model name from config
1041 |         model_name = config.get("configurable", {}).get("model_name", "4o")
1042 |
1043 |         # Determine API key and flag based on model
1044 |         if model_name in ['haiku', 'sonnet']:
1045 |             api_key = ANTHROPIC_API_KEY
1046 |             api_flag = "--anthropic-api-key"
1047 |         else:  # '4o' or 'o1'
1048 |             api_key = OPENAI_API_KEY
1049 |             api_flag = "--openai-api-key"
1050 |
1051 |         # Construct the command with the proper model shortcut flag and API key
1052 |         aider_cmd = f"aider {chat_mode_arg} --yes-always --{model_name} {api_flag} {api_key} --no-stream --message {escaped_message} {files_str}"
1053 |         logger.info(f"Final aider command: {aider_cmd.replace(api_key, '***') if api_key else aider_cmd}")
1054 |
1055 |         # Create process with proper environment and stdin configuration
1056 |         process = await asyncio.create_subprocess_shell(
1057 |             aider_cmd,
1058 |             stdout=asyncio.subprocess.PIPE,
1059 |             stderr=asyncio.subprocess.PIPE,
1060 |             stdin=asyncio.subprocess.DEVNULL,  # Explicitly set stdin to DEVNULL
1061 |             cwd=repo_path,
1062 |             env={
1063 |                 **os.environ,
1064 |                 'TERM': 'xterm-256color',
1065 |                 'COLUMNS': '80',
1066 |                 'LINES': '24',
1067 |                 'PATH': f"{os.environ.get('PATH', '')}:/usr/local/bin",
1068 |                 'AIDER_NO_INTERACTIVE': '1'  # Environment variable to disable interactive mode
1069 |             }
1070 |         )
1071 |
1072 |         try:
1073 |             stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=300)
1074 |             stdout_decoded = stdout.decode() if stdout else ""
1075 |             stderr_decoded = stderr.decode() if stderr else ""
1076 |
1077 |             logger.info(f"Command output: {stdout_decoded}")
1078 |             if stderr_decoded:
1079 |                 logger.error(f"Command stderr: {stderr_decoded}")
1080 |
1081 |             if process.returncode != 0:
1082 |                 raise RuntimeError(f"Command failed (exit code {process.returncode}): {stderr_decoded}")
1083 |
1084 |             return stdout_decoded if stdout_decoded else "Command completed successfully"
1085 |
1086 |         except asyncio.TimeoutError:
1087 |             process.kill()
1088 |             raise TimeoutError("Command timed out after 300 seconds")
1089 |
1090 |     except Exception as e:
1091 |         error_msg = f"Error in aider_command: {str(e)}"
1092 |         logger.error(error_msg, exc_info=True)
1093 |         raise RuntimeError(error_msg)
1094 |
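A usage sketch for `aider_command`; it assumes the installed langchain version injects the `config` argument into tool coroutines, and every configurable value below is a placeholder:

```python
import asyncio

# RunnableConfig is a TypedDict, so a plain dict works here.
config = {
    "configurable": {
        "repo_path": "/repos/demo",   # hypothetical container path
        "chat_mode": "ask",
        "model_name": "4o",
    }
}

result = asyncio.run(
    aider_command.ainvoke(
        {"message": "Summarize README.md", "files": "README.md"},
        config=config,
    )
)
print(result)
```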
1095 | def configure_remote_with_auth(repo: Repo, repo_url: str, github_token: str) -> str:
1096 |     """Configure the git remote with authentication."""
1097 |     try:
1098 |         # Parse the URL
1099 |         parsed = urlparse(repo_url)
1100 |
1101 |         # Construct the authenticated URL
1102 |         auth_url = urlunparse(parsed._replace(
1103 |             netloc=f"oauth2:{github_token}@{parsed.netloc}"
1104 |         ))
1105 |
1106 |         # Set or update the remote
1107 |         if 'origin' in repo.remotes:
1108 |             repo.delete_remote('origin')
1109 |         repo.create_remote('origin', auth_url)
1110 |
1111 |         return auth_url
1112 |     except Exception as e:
1113 |         raise ValueError(f"Failed to configure remote: {str(e)}")
1114 |
1115 | async def _execute_aider_command(
1116 |     message: str,
1117 |     repo_path: str,
1118 |     anthropic_api_key: str,
1119 |     aider_state: Any,
1120 |     files: Optional[str] = None
1121 | ) -> str:
1122 |     """Execute an aider command and return the response."""
1123 |     try:
1124 |         # Construct the aider command
1125 |         cmd = f"aider --yes --no-stream --message {shlex.quote(message)}"
1126 |         if files:
1127 |             cmd += f" {files}"
1128 |
1129 |         # Create process; merge os.environ so PATH and friends survive
1130 |         process = await asyncio.create_subprocess_shell(
1131 |             cmd,
1132 |             stdout=asyncio.subprocess.PIPE,
1133 |             stderr=asyncio.subprocess.PIPE,
1134 |             cwd=repo_path,
1135 |             env={**os.environ, "OPENAI_API_KEY": OPENAI_API_KEY}
1136 |             # env={**os.environ, "ANTHROPIC_API_KEY": anthropic_api_key}
1137 |         )
1138 |
1139 |         # Wait for completion with timeout
1140 |         try:
1141 |             stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=300)
1142 |             return stdout.decode()
1143 |         except asyncio.TimeoutError:
1144 |             process.kill()
1145 |             raise TimeoutError("Aider command timed out after 300 seconds")
1146 |
1147 |     except Exception as e:
1148 |         raise RuntimeError(f"Failed to execute aider command: {str(e)}")
1149 |
1150 | def cleanup_remote(repo: Repo) -> None:
1151 |     """Remove sensitive information from the git remote."""
1152 |     try:
1153 |         if 'origin' in repo.remotes:
1154 |             repo.delete_remote('origin')
1155 |             # Recreate with a clean URL if needed
1156 |             if hasattr(repo, '_original_url'):
1157 |                 repo.create_remote('origin', repo._original_url)
1158 |     except Exception as e:
1159 |         logger.warning(f"Failed to cleanup remote: {str(e)}")
1160 |
1161 | # Git Progress Handler
1162 | class GitProgressHandler(RemoteProgress):
1163 |     """Handle git operation progress."""
1164 |     def __init__(self):
1165 |         super().__init__()
1166 |         self.logger = logging.getLogger(__name__)
1167 |
1168 |     def update(self, op_code, cur_count, max_count=None, message=''):
1169 |         """Called whenever the progress changes."""
1170 |         self.logger.debug(f'Progress: {op_code}, {cur_count}/{max_count}, {message}')
1171 |
1172 | class AiderChunk(BaseModel):
1173 |     type: str
1174 |     content: str
1175 |     files: Optional[List[str]] = None
1176 |
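A sketch of the push-then-scrub flow these helpers support; the repo path, remote URL, and token are placeholders:

```python
from git import Repo

repo = Repo("/repos/demo")                 # hypothetical local clone
token = "<github-token>"                   # placeholder; read from env in real use
configure_remote_with_auth(repo, "https://github.com/acme/demo.git", token)
try:
    # GitProgressHandler logs transfer progress at DEBUG level.
    repo.remotes.origin.push("main", progress=GitProgressHandler())
finally:
    cleanup_remote(repo)                   # drop the token-bearing remote URL
```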
--------------------------------------------------------------------------------