161 |
162 | """
163 | + self.cssb
164 | )
165 | code = gr.HTML(
166 | self.cssa
167 | + """
168 |
170 |
171 |
172 | """
173 | + self.cssb
174 | )
175 | with gr.Row():
176 | with gr.Column():
177 | output3 = gr.Textbox(label="key words", lines=2)
178 | output4 = gr.Textbox(label="key words code", lines=14)
179 |
180 | btn.click(
181 | self.wrapper_respond,
182 | inputs=[msg, system],
183 | outputs=[msg, output1, output2, output3, code, output4],
184 | )
185 | btnc.click(
186 | self.clean, outputs=[msg, output1, output2, output3, code, output4]
187 | )
188 | msg.submit(
189 | self.wrapper_respond,
190 | inputs=[msg, system],
191 | outputs=[msg, output1, output2, output3, code, output4],
192 | ) # Press enter to submit
193 |
194 | gr.close_all()
195 | demo.queue().launch(share=False, height=800)
196 |
197 |
198 | # Usage example
199 | if __name__ == "__main__":
200 |
201 | def respond_function(msg, system):
202 | RAG = """
203 |
204 |
205 | """
206 | return msg, RAG, "Embedding_recall_output", "Key_words_output", "Code_output"
207 |
208 | gradio_interface = GradioInterface(respond_function)
209 |
--------------------------------------------------------------------------------
/repo_agent/chat_with_repo/json_handler.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 |
4 | from repo_agent.log import logger
5 |
6 |
7 | class JsonFileProcessor:
8 | def __init__(self, file_path):
9 | self.file_path = file_path
10 |
11 | def read_json_file(self):
12 | try:
13 | with open(self.file_path, "r", encoding="utf-8") as file:
14 | data = json.load(file)
15 | return data
16 | except FileNotFoundError:
17 | logger.exception(f"File not found: {self.file_path}")
18 | sys.exit(1)
19 |
20 | def extract_data(self):
21 | # Load JSON data from a file
22 | json_data = self.read_json_file()
23 | md_contents = []
24 | extracted_contents = []
25 | # Iterate through each file in the JSON data
26 | for file, items in json_data.items():
27 | # Check if the value is a list (new format)
28 | if isinstance(items, list):
29 | # Iterate through each item in the list
30 | for item in items:
31 | # Check if 'md_content' exists and is not empty
32 | if "md_content" in item and item["md_content"]:
33 | # Append the first element of 'md_content' to the result list
34 | md_contents.append(item["md_content"][0])
35 | # Build a dictionary containing the required information
36 | item_dict = {
37 | "type": item.get("type", "UnknownType"),
38 | "name": item.get("name", "Unnamed"),
39 | "code_start_line": item.get("code_start_line", -1),
40 | "code_end_line": item.get("code_end_line", -1),
41 | "have_return": item.get("have_return", False),
42 | "code_content": item.get("code_content", "NoContent"),
43 | "name_column": item.get("name_column", 0),
44 | "item_status": item.get("item_status", "UnknownStatus"),
45 | # Adapt or remove fields based on new structure requirements
46 | }
47 | extracted_contents.append(item_dict)
48 | return md_contents, extracted_contents
49 |
50 | def recursive_search(self, data_item, search_text, code_results, md_results):
51 | if isinstance(data_item, dict):
52 | # Direct comparison is removed as there's no direct key==search_text in the new format
53 | for key, value in data_item.items():
54 | # Recursively search through dictionary values and lists
55 | if isinstance(value, (dict, list)):
56 | self.recursive_search(value, search_text, code_results, md_results)
57 | elif isinstance(data_item, list):
58 | for item in data_item:
59 | # Now we check for the 'name' key in each item of the list
60 | if isinstance(item, dict) and item.get("name") == search_text:
61 | # If 'code_content' exists, append it (and any 'md_content') to the results
62 | if "code_content" in item:
63 | code_results.append(item["code_content"])
64 | md_results.append(item.get("md_content", []))
65 | # Recursive call in case of nested lists or dicts
66 | self.recursive_search(item, search_text, code_results, md_results)
67 |
68 | def search_code_contents_by_name(self, file_path, search_text):
69 | # Attempt to retrieve code from the JSON file
70 | try:
71 | with open(file_path, "r", encoding="utf-8") as file:
72 | data = json.load(file)
73 | code_results = []
74 | md_results = [] # List to store matching items' code_content and md_content
75 | self.recursive_search(data, search_text, code_results, md_results)
76 | # Always return two values, regardless of the outcome
77 | if code_results or md_results:
78 | return code_results, md_results
79 | else:
80 | return ["No matching item found."], ["No matching item found."]
81 | except FileNotFoundError:
82 | return ["File not found."], ["File not found."]
83 | except json.JSONDecodeError:
84 | return ["Invalid JSON file."], ["Invalid JSON file."]
85 | except Exception as e:
86 | return [f"An error occurred: {e}"], [f"An error occurred: {e}"]
87 |
88 |
89 | if __name__ == "__main__":
90 | processor = JsonFileProcessor("database.json")
91 | md_contents, extracted_contents = processor.extract_data()
92 |
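93 | # Search sketch ("extract_data" is an illustrative object name that may not
94 | # exist in your database.json):
95 | # code_hits, md_hits = processor.search_code_contents_by_name(
96 | #     "database.json", "extract_data"
97 | # )
98 | # print(code_hits[0], md_hits[0])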
--------------------------------------------------------------------------------
/repo_agent/chat_with_repo/main.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from repo_agent.chat_with_repo.gradio_interface import GradioInterface
4 | from repo_agent.chat_with_repo.rag import RepoAssistant
5 | from repo_agent.log import logger
6 | from repo_agent.settings import SettingsManager
7 |
8 |
9 | def main():
10 | logger.info("Initializing the RepoAgent chat with doc module.")
11 |
12 | # Load settings
13 | setting = SettingsManager.get_setting()
14 |
15 | api_key = setting.chat_completion.openai_api_key.get_secret_value()
16 | api_base = str(setting.chat_completion.openai_base_url)
17 | db_path = (
18 | setting.project.target_repo
19 | / setting.project.hierarchy_name
20 | / "project_hierarchy.json"
21 | )
22 |
23 | # Initialize RepoAssistant
24 | assistant = RepoAssistant(api_key, api_base, db_path)
25 |
26 | # Extract data
27 | md_contents, meta_data = assistant.json_data.extract_data()
28 |
29 | # Create vector store and measure runtime
30 | logger.info("Starting vector store creation...")
31 | start_time = time.time()
32 | assistant.vector_store_manager.create_vector_store(
33 | md_contents, meta_data, api_key, api_base
34 | )
35 | elapsed_time = time.time() - start_time
36 | logger.info(f"Vector store created successfully in {elapsed_time:.2f} seconds.")
37 |
38 | # Launch Gradio interface
39 | GradioInterface(assistant.respond)
40 |
41 |
42 | if __name__ == "__main__":
43 | main()
44 |
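45 | # Launch sketch (assumes the package is installed and settings such as the
46 | # OpenAI API key are provided via environment variables):
47 | #   python -m repo_agent.chat_with_repo.main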
--------------------------------------------------------------------------------
/repo_agent/chat_with_repo/prompt.py:
--------------------------------------------------------------------------------
1 | from llama_index.core import ChatPromptTemplate, PromptTemplate
2 | from llama_index.core.llms import ChatMessage, MessageRole
3 |
4 | # Query Generation Prompt
5 | query_generation_prompt_str = (
6 | "You are a helpful assistant that generates multiple search queries based on a "
7 | "single input query. Generate {num_queries} search queries, one on each line, "
8 | "related to the following input query:\n"
9 | "Query: {query}\n"
10 | "Queries:\n"
11 | )
12 | query_generation_template = PromptTemplate(query_generation_prompt_str)
13 |
14 | # Relevance Ranking Prompt
15 | relevance_ranking_instruction = (
16 | "You are an expert relevance ranker. Given a list of documents and a query, your job is to determine how relevant each document is for answering the query. "
17 | "Your output is JSON, which is a list of documents. Each document has two fields, content and relevance_score. relevance_score is from 0.0 to 100.0. "
18 | "Higher relevance means higher score."
19 | )
20 | relevance_ranking_guideline = "Query: {query} Docs: {docs}"
21 |
22 | relevance_ranking_message_template = [
23 | ChatMessage(content=relevance_ranking_instruction, role=MessageRole.SYSTEM),
24 | ChatMessage(
25 | content=relevance_ranking_guideline,
26 | role=MessageRole.USER,
27 | ),
28 | ]
29 | relevance_ranking_chat_template = ChatPromptTemplate(
30 | message_templates=relevance_ranking_message_template
31 | )
32 |
33 | # RAG (Retrieve and Generate) Prompt
34 | rag_prompt_str = (
35 | "You are a helpful assistant in repository Q&A. Users will ask questions about something contained in a repository. "
36 | "You will be shown the user's question, and the relevant information from the repository. Answer the user's question only with information given.\n\n"
37 | "Question: {query}.\n\n"
38 | "Information: {information}"
39 | )
40 | rag_template = PromptTemplate(rag_prompt_str)
41 |
42 | # RAG_AR (Advanced RAG) Prompt
43 | rag_ar_prompt_str = (
44 | "You are a helpful Repository-Level Software Q&A assistant. Your task is to answer users' questions based on the given information about a software repository, "
45 | "including related code and documents.\n\n"
46 | "Currently, you're in the {project_name} project. The user's question is:\n"
47 | "{query}\n\n"
48 | "Now, you are given related code and documents as follows:\n\n"
49 | "-------------------Code-------------------\n"
50 | "Some most likely related code snippets recalled by the retriever are:\n"
51 | "{related_code}\n\n"
52 | "-------------------Document-------------------\n"
53 | "Some most relevant documents recalled by the retriever are:\n"
54 | "{embedding_recall}\n\n"
55 | "Please note: \n"
56 | "1. All the provided recall results are related to the current project {project_name}. Please filter useful information according to the user's question and provide corresponding answers or solutions.\n"
57 | "2. Ensure that your responses are accurate and detailed. Present specific answers in a professional manner and tone.\n"
58 | "3. The user's question may be asked in any language. You must respond **in the same language** as the user's question, even if the input language is not English.\n"
59 | "4. If you find the user's question completely unrelated to the provided information or if you believe you cannot provide an accurate answer, kindly decline. Note: DO NOT fabricate any non-existent information.\n\n"
60 | "Now, focusing on the user's query, and incorporating the given information to offer a specific, detailed, and professional answer IN THE SAME LANGUAGE AS the user's question."
61 | )
62 |
63 |
64 | rag_ar_template = PromptTemplate(rag_ar_prompt_str)
65 |
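66 | # Rendering sketch: PromptTemplate.format fills the placeholders, e.g.
67 | # print(rag_template.format(query="Where is logging configured?",
68 | #                            information="repo_agent/log.py configures loguru."))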
--------------------------------------------------------------------------------
/repo_agent/chat_with_repo/rag.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from llama_index.llms.openai import OpenAI
4 |
5 | from repo_agent.chat_with_repo.json_handler import JsonFileProcessor
6 | from repo_agent.chat_with_repo.prompt import (
7 | query_generation_template,
8 | rag_ar_template,
9 | rag_template,
10 | relevance_ranking_chat_template,
11 | )
12 | from repo_agent.chat_with_repo.text_analysis_tool import TextAnalysisTool
13 | from repo_agent.chat_with_repo.vector_store_manager import VectorStoreManager
14 | from repo_agent.log import logger
15 |
16 |
17 | class RepoAssistant:
18 | def __init__(self, api_key, api_base, db_path):
19 | self.db_path = db_path
20 | self.md_contents = []
21 |
22 | self.weak_model = OpenAI(
23 | api_key=api_key,
24 | api_base=api_base,
25 | model="gpt-4o-mini",
26 | )
27 | self.strong_model = OpenAI(
28 | api_key=api_key,
29 | api_base=api_base,
30 | model="gpt-4o",
31 | )
32 | self.textanslys = TextAnalysisTool(self.weak_model, db_path)
33 | self.json_data = JsonFileProcessor(db_path)
34 | self.vector_store_manager = VectorStoreManager(top_k=5, llm=self.weak_model)
35 |
36 | def generate_queries(self, query_str: str, num_queries: int = 4):
37 | fmt_prompt = query_generation_template.format(
38 | num_queries=num_queries - 1, query=query_str
39 | )
40 | response = self.weak_model.complete(fmt_prompt)
41 | queries = response.text.split("\n")
42 | return queries
43 |
44 | def rerank(self, query, docs): # Guard against malformed response formats here
45 | response = self.weak_model.chat(
46 | response_format={"type": "json_object"},
47 | temperature=0,
48 | messages=relevance_ranking_chat_template.format_messages(
49 | query=query, docs=docs
50 | ),
51 | )
52 | scores = json.loads(response.message.content)["documents"] # type: ignore
53 | logger.debug(f"scores: {scores}")
54 | sorted_data = sorted(scores, key=lambda x: x["relevance_score"], reverse=True)
55 | top_5_contents = [doc["content"] for doc in sorted_data[:5]]
56 | return top_5_contents
57 |
58 | def rag(self, query, retrieved_documents):
59 | rag_prompt = rag_template.format(
60 | query=query, information="\n\n".join(retrieved_documents)
61 | )
62 | response = self.weak_model.complete(rag_prompt)
63 | return response.text
64 |
65 | def list_to_markdown(self, list_items):
66 | markdown_content = ""
67 |
68 | # Add a numbered list item for each element in the list
69 | for index, item in enumerate(list_items, start=1):
70 | markdown_content += f"{index}. {item}\n"
71 |
72 | return markdown_content
73 |
74 | def rag_ar(self, query, related_code, embedding_recall, project_name):
75 | rag_ar_prompt = rag_ar_template.format_messages(
76 | query=query,
77 | related_code=related_code,
78 | embedding_recall=embedding_recall,
79 | project_name=project_name,
80 | )
81 | response = self.strong_model.chat(rag_ar_prompt)
82 | return response.message.content
83 |
84 | def respond(self, message, instruction):
85 | """
86 | Respond to a user query by processing input, querying the vector store,
87 | reranking results, and generating a final response.
88 | """
89 | logger.debug("Starting response generation.")
90 |
91 | # Step 1: Format the chat prompt
92 | prompt = self.textanslys.format_chat_prompt(message, instruction)
93 | logger.debug(f"Formatted prompt: {prompt}")
94 |
95 | questions = self.textanslys.keyword(prompt)
96 | logger.debug(f"Generated keywords from prompt: {questions}")
97 |
98 | # Step 2: Generate additional queries
99 | prompt_queries = self.generate_queries(prompt, 3)
100 | logger.debug(f"Generated queries: {prompt_queries}")
101 |
102 | all_results = []
103 | all_documents = []
104 |
105 | # Step 3: Query the VectorStoreManager for each query
106 | for query in prompt_queries:
107 | logger.debug(f"Querying vector store with: {query}")
108 | query_results = self.vector_store_manager.query_store(query)
109 | logger.debug(f"Results for query '{query}': {query_results}")
110 | all_results.extend(query_results)
111 |
112 | # Step 4: Deduplicate results by content
113 | unique_results = {result["text"]: result for result in all_results}.values()
114 | unique_documents = [result["text"] for result in unique_results]
115 | logger.debug(f"Unique documents: {unique_documents}")
116 |
117 | unique_code = [
118 | result.get("metadata", {}).get("code_content") for result in unique_results
119 | ]
120 | logger.debug(f"Unique code content: {unique_code}")
121 |
122 | # Step 5: Rerank documents based on relevance
123 | retrieved_documents = self.rerank(message, unique_documents)
124 | logger.debug(f"Reranked documents: {retrieved_documents}")
125 |
126 | # Step 6: Generate a response using RAG (Retrieve and Generate)
127 | response = self.rag(prompt, retrieved_documents)
128 | chunkrecall = self.list_to_markdown(retrieved_documents)
129 | logger.debug(f"RAG-generated response: {response}")
130 | logger.debug(f"Markdown chunk recall: {chunkrecall}")
131 |
132 | bot_message = str(response)
133 | logger.debug(f"Initial bot_message: {bot_message}")
134 |
135 | # Step 7: Perform NER and queryblock processing
136 | keyword = str(self.textanslys.nerquery(bot_message))
137 | keywords = str(self.textanslys.nerquery(str(prompt) + str(questions)))
138 | logger.debug(f"Extracted keywords: {keyword}, {keywords}")
139 |
140 | codez, mdz = self.textanslys.queryblock(keyword)
141 | codey, mdy = self.textanslys.queryblock(keywords)
142 |
143 | # Ensure all returned items are lists
144 | codez = codez if isinstance(codez, list) else [codez]
145 | mdz = mdz if isinstance(mdz, list) else [mdz]
146 | codey = codey if isinstance(codey, list) else [codey]
147 | mdy = mdy if isinstance(mdy, list) else [mdy]
148 |
149 | # Step 8: Merge and deduplicate results
150 | codex = list(dict.fromkeys(codez + codey))
151 | md = list(dict.fromkeys(mdz + mdy))
152 | unique_mdx = list(set([item for sublist in md for item in sublist]))
153 | uni_codex = list(dict.fromkeys(codex))
154 | uni_md = list(dict.fromkeys(unique_mdx))
155 |
156 | # Convert to Markdown format
157 | codex_md = self.textanslys.list_to_markdown(uni_codex)
158 | retrieved_documents = list(dict.fromkeys(retrieved_documents + uni_md))
159 |
160 | # Final rerank and response generation
161 | retrieved_documents = self.rerank(message, retrieved_documents[:6])
162 | logger.debug(f"Final retrieved documents after rerank: {retrieved_documents}")
163 |
164 | uni_code = self.rerank(
165 | message, list(dict.fromkeys(uni_codex + unique_code))[:6]
166 | )
167 | logger.debug(f"Final unique code after rerank: {uni_code}")
168 |
169 | unique_code_md = self.textanslys.list_to_markdown(unique_code)
170 | logger.debug(f"Unique code in Markdown: {unique_code_md}")
171 |
172 | # Generate final response using RAG_AR
173 | bot_message = self.rag_ar(prompt, uni_code, retrieved_documents, "test")
174 | logger.debug(f"Final bot_message after RAG_AR: {bot_message}")
175 |
176 | return message, bot_message, chunkrecall, questions, unique_code_md, codex_md
177 |
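178 | # End-to-end sketch (assumes a valid OpenAI-compatible endpoint and a
179 | # project_hierarchy.json produced by RepoAgent):
180 | # assistant = RepoAssistant(api_key, api_base, "project_hierarchy.json")
181 | # md_contents, meta_data = assistant.json_data.extract_data()
182 | # assistant.vector_store_manager.create_vector_store(md_contents, meta_data, api_key, api_base)
183 | # print(assistant.respond("How does rerank work?", "You are a helpful assistant.")[1])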
--------------------------------------------------------------------------------
/repo_agent/chat_with_repo/text_analysis_tool.py:
--------------------------------------------------------------------------------
1 | from llama_index.core.llms.function_calling import FunctionCallingLLM
2 | from llama_index.llms.openai import OpenAI
3 |
4 | from repo_agent.chat_with_repo.json_handler import JsonFileProcessor
5 |
6 |
7 | class TextAnalysisTool:
8 | def __init__(self, llm: FunctionCallingLLM, db_path):
9 | self.jsonsearch = JsonFileProcessor(db_path)
10 | self.llm = llm
11 | self.db_path = db_path
12 |
13 | def keyword(self, query):
14 | prompt = f"Please provide a list of Code keywords according to the following query, please output no more than 3 keywords, Input: {query}, Output:"
15 | response = self.llm.complete(prompt)
16 | return response
17 |
18 | def tree(self, query):
19 | prompt = f"Please analyze the following text and generate a tree structure based on its hierarchy:\n\n{query}"
20 | response = self.llm.complete(prompt)
21 | return response
22 |
23 | def format_chat_prompt(self, message, instruction):
24 | prompt = f"System:{instruction}\nUser: {message}\nAssistant:"
25 | return prompt
26 |
27 | def queryblock(self, message):
28 | search_result, md = self.jsonsearch.search_code_contents_by_name(
29 | self.db_path, message
30 | )
31 | return search_result, md
32 |
33 | def list_to_markdown(self, search_result):
34 | markdown_str = ""
35 | # Iterate over the list, converting each element into a numbered Markdown item
36 | for index, content in enumerate(search_result, start=1):
37 | # Append the item to the Markdown string, followed by a blank line
38 | markdown_str += f"{index}. {content}\n\n"
39 |
40 | return markdown_str
41 |
42 | def nerquery(self, message):
43 | instruction = """
44 | Extract the most relevant class or function based on the following instruction:
45 |
46 | The output must strictly be a pure function name or class name, without any additional characters.
47 | For example:
48 | Pure function names: calculateSum, processData
49 | Pure class names: MyClass, DataProcessor
50 | The output must contain exactly one function or class name.
51 | """
52 | query = f"{instruction}\n\nThe input is shown below:\n{message}\n\nNow directly give your Output:"
53 | response = self.llm.complete(query)
54 | # logger.debug(f"Input: {message}, Output: {response}")
55 | return response
56 |
57 |
58 | if __name__ == "__main__":
59 | api_base = "https://api.openai.com/v1"
60 | api_key = "your_api_key"
61 | log_file = "your_logfile_path"
62 | llm = OpenAI(api_key=api_key, api_base=api_base)
63 | db_path = "your_database_path"
64 | test = TextAnalysisTool(llm, db_path)
65 |
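66 | # Example calls (sketch; these require a real API key and database path):
67 | # print(test.keyword("How is the project hierarchy parsed?"))
68 | # print(test.nerquery("The extract_data method parses md_content entries."))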
--------------------------------------------------------------------------------
/repo_agent/chat_with_repo/vector_store_manager.py:
--------------------------------------------------------------------------------
1 | import chromadb
2 | from llama_index.core import (
3 | Document,
4 | StorageContext,
5 | VectorStoreIndex,
6 | get_response_synthesizer,
7 | )
8 | from llama_index.core.node_parser import (
9 | SemanticSplitterNodeParser,
10 | SentenceSplitter,
11 | )
12 | from llama_index.core.query_engine import RetrieverQueryEngine
13 | from llama_index.core.retrievers import VectorIndexRetriever
14 | from llama_index.embeddings.openai import OpenAIEmbedding
15 | from llama_index.vector_stores.chroma import ChromaVectorStore
16 |
17 | from repo_agent.log import logger
18 |
19 |
20 | class VectorStoreManager:
21 | def __init__(self, top_k, llm):
22 | """
23 | Initialize the VectorStoreManager.
24 | """
25 | self.query_engine = None # Initialize as None
26 | self.chroma_db_path = "./chroma_db" # Path to Chroma database
27 | self.collection_name = "test" # Default collection name
28 | self.similarity_top_k = top_k
29 | self.llm = llm
30 |
31 | def create_vector_store(self, md_contents, meta_data, api_key, api_base):
32 | """
33 | Add markdown content and metadata to the index.
34 | """
35 | if not md_contents or not meta_data:
36 | logger.warning("No content or metadata provided. Skipping.")
37 | return
38 |
39 | # Ensure lengths match
40 | min_length = min(len(md_contents), len(meta_data))
41 | md_contents = md_contents[:min_length]
42 | meta_data = meta_data[:min_length]
43 |
44 | logger.debug(f"Number of markdown contents: {len(md_contents)}")
45 | logger.debug(f"Number of metadata entries: {len(meta_data)}")
46 |
47 | # Initialize Chroma client and collection
48 | db = chromadb.PersistentClient(path=self.chroma_db_path)
49 | chroma_collection = db.get_or_create_collection(self.collection_name)
50 |
51 | # Define embedding model
52 | embed_model = OpenAIEmbedding(
53 | model_name="text-embedding-3-large",
54 | api_key=api_key,
55 | api_base=api_base,
56 | )
57 |
58 | # Initialize the semantic chunker, with a sentence splitter as fallback
59 | logger.debug("Initializing semantic chunker (SemanticSplitterNodeParser).")
60 | splitter = SemanticSplitterNodeParser(
61 | buffer_size=1, breakpoint_percentile_threshold=95, embed_model=embed_model
62 | )
63 | base_splitter = SentenceSplitter(chunk_size=1024)
64 |
65 | documents = [
66 | Document(text=content, extra_info=meta)
67 | for content, meta in zip(md_contents, meta_data)
68 | ]
69 |
70 | all_nodes = []
71 | for i, doc in enumerate(documents):
72 | logger.debug(
73 | f"Processing document {i+1}: Content length={len(doc.get_text())}"
74 | )
75 |
76 | try:
77 | # Try semantic splitting first
78 | nodes = splitter.get_nodes_from_documents([doc])
79 | logger.debug(f"Document {i+1} split into {len(nodes)} semantic chunks.")
80 |
81 | except Exception as e:
82 | # Fallback to baseline sentence splitting
83 | logger.warning(
84 | f"Semantic splitting failed for document {i+1}, falling back to SentenceSplitter. Error: {e}"
85 | )
86 | nodes = base_splitter.get_nodes_from_documents([doc])
87 | logger.debug(f"Document {i+1} split into {len(nodes)} sentence chunks.")
88 |
89 | all_nodes.extend(nodes)
90 |
91 | if not all_nodes:
92 | logger.warning("No valid nodes to add to the index after chunking.")
93 | return
94 |
95 | logger.debug(f"Number of valid chunks: {len(all_nodes)}")
96 |
97 | # Set up ChromaVectorStore and load data
98 | vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
99 | storage_context = StorageContext.from_defaults(vector_store=vector_store)
100 | index = VectorStoreIndex(
101 | all_nodes, storage_context=storage_context, embed_model=embed_model
102 | )
103 | retriever = VectorIndexRetriever(
104 | index=index, similarity_top_k=self.similarity_top_k, embed_model=embed_model
105 | )
106 |
107 | response_synthesizer = get_response_synthesizer(llm=self.llm)
108 |
109 | # Set the query engine
110 | self.query_engine = RetrieverQueryEngine(
111 | retriever=retriever,
112 | response_synthesizer=response_synthesizer,
113 | )
114 |
115 | logger.info(f"Vector store created and loaded with {len(documents)} documents.")
116 |
117 | def query_store(self, query):
118 | """
119 | Query the vector store for relevant documents.
120 | """
121 | if not self.query_engine:
122 | logger.error(
123 | "Query engine is not initialized. Please create a vector store first."
124 | )
125 | return []
126 |
127 | # Query the vector store
128 | logger.debug(f"Querying vector store with: {query}")
129 | results = self.query_engine.query(query)
130 |
131 | # Return the synthesized response text along with its source metadata
132 | return [{"text": results.response, "metadata": results.metadata}]
133 |
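134 | # Usage sketch (md_contents and meta_data come from
135 | # JsonFileProcessor.extract_data(); llm is any llama-index LLM):
136 | # manager = VectorStoreManager(top_k=5, llm=llm)
137 | # manager.create_vector_store(md_contents, meta_data, api_key, api_base)
138 | # print(manager.query_store("Where is logging configured?"))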
--------------------------------------------------------------------------------
/repo_agent/log.py:
--------------------------------------------------------------------------------
1 | # repo_agent/log.py
2 | import inspect
3 | import logging
4 | import sys
5 |
6 | from loguru import logger
7 |
8 | logger = logger.opt(colors=True)
9 | """
10 | RepoAgent 日志记录器对象。
11 |
12 | 默认信息:
13 | - 格式: `[%(asctime)s %(name)s] %(levelname)s: %(message)s`
14 | - 等级: `INFO` ,根据 `CONFIG["log_level"]` 配置改变
15 | - 输出: 输出至 stdout
16 |
17 | 用法示例:
18 | ```python
19 | from repo_agent.log import logger
20 |
21 | # 基本消息记录
22 | logger.info("It works>!") # 使用颜色
23 |
24 | # 记录异常信息
25 | try:
26 | 1 / 0
27 | except ZeroDivisionError:
28 | # 使用 `logger.exception` 可以在记录异常消息时自动附加异常的堆栈跟踪信息。
29 | logger.exception("ZeroDivisionError occurred")
30 |
31 | # 记录调试信息
32 | logger.debug(f"Debugging info: {some_debug_variable}")
33 |
34 | # 记录警告信息
35 | logger.warning("This is a warning message")
36 |
37 | # 记录错误信息
38 | logger.error("An error occurred")
39 | ```
40 |
41 | """
42 |
43 |
44 | class InterceptHandler(logging.Handler):
45 | def emit(self, record: logging.LogRecord) -> None:
46 | # Get corresponding Loguru level if it exists.
47 | level: str | int
48 | try:
49 | level = logger.level(record.levelname).name
50 | except ValueError:
51 | level = record.levelno
52 |
53 | # Find caller from where originated the logged message.
54 | frame, depth = inspect.currentframe(), 0
55 | while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__):
56 | frame = frame.f_back
57 | depth += 1
58 |
59 | logger.opt(depth=depth, exception=record.exc_info).log(
60 | level, record.getMessage()
61 | )
62 |
63 |
64 | def set_logger_level_from_config(log_level):
65 | """
66 | Configures the loguru logger with specified log level and integrates it with the standard logging module.
67 |
68 | Args:
69 | log_level (str): The log level to set for loguru (e.g., "DEBUG", "INFO", "WARNING").
70 |
71 | This function:
72 | - Removes any existing loguru handlers to ensure a clean slate.
73 | - Adds a new handler to loguru, directing output to stderr with the specified level.
74 | - `enqueue=True` ensures thread-safe logging by using a queue, helpful in multi-threaded contexts.
75 | - `backtrace=False` minimizes detailed traceback to prevent overly verbose output.
76 | - `diagnose=False` suppresses additional loguru diagnostic information for more concise logs.
77 | - Redirects the standard logging output to loguru using the InterceptHandler, allowing loguru to handle
78 | all logs consistently across the application.
79 | """
80 | logger.remove()
81 | logger.add(
82 | sys.stderr, level=log_level, enqueue=True, backtrace=False, diagnose=False
83 | )
84 |
85 | # Intercept standard logging
86 | logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True)
87 |
88 | logger.success(f"Log level set to {log_level}!")
89 |
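90 | # Example (sketch): set the level and route standard logging through loguru
91 | # set_logger_level_from_config("DEBUG")
92 | # logging.getLogger("some_library").info("This record is handled by loguru.")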
--------------------------------------------------------------------------------
/repo_agent/main.py:
--------------------------------------------------------------------------------
1 | from importlib import metadata
2 |
3 | import click
4 | from pydantic import ValidationError
5 |
6 | from repo_agent.doc_meta_info import DocItem, MetaInfo
7 | from repo_agent.log import logger, set_logger_level_from_config
8 | from repo_agent.runner import Runner
9 | from repo_agent.settings import SettingsManager, LogLevel
10 | from repo_agent.utils.meta_info_utils import delete_fake_files, make_fake_files
11 |
12 | try:
13 | version_number = metadata.version("repoagent")
14 | except metadata.PackageNotFoundError:
15 | version_number = "0.0.0"
16 |
17 |
18 | @click.group()
19 | @click.version_option(version_number)
20 | def cli():
21 | """An LLM-Powered Framework for Repository-level Code Documentation Generation."""
22 | pass
23 |
24 |
25 | def handle_setting_error(e: ValidationError):
26 | """Handle configuration errors for settings."""
27 | # Print more detailed missing-field information, color-coded for readability
28 | for error in e.errors():
29 | field = error["loc"][-1]
30 | if error["type"] == "missing":
31 | message = click.style(
32 | f"Missing required field `{field}`. Please set the `{field}` environment variable.",
33 | fg="yellow",
34 | )
35 | else:
36 | message = click.style(error["msg"], fg="yellow")
37 | click.echo(message, err=True, color=True)
38 |
39 | # Exit the program gracefully via ClickException
40 | raise click.ClickException(
41 | click.style(
42 | "Program terminated due to configuration errors.", fg="red", bold=True
43 | )
44 | )
45 |
46 |
47 | @cli.command()
48 | @click.option(
49 | "--model",
50 | "-m",
51 | default="gpt-4o-mini",
52 | show_default=True,
53 | help="Specifies the model to use for completion.",
54 | type=str,
55 | )
56 | @click.option(
57 | "--temperature",
58 | "-t",
59 | default=0.2,
60 | show_default=True,
61 | help="Sets the generation temperature for the model. Lower values make the model more deterministic.",
62 | type=float,
63 | )
64 | @click.option(
65 | "--request-timeout",
66 | "-r",
67 | default=60,
68 | show_default=True,
69 | help="Defines the timeout in seconds for the API request.",
70 | type=int,
71 | )
72 | @click.option(
73 | "--base-url",
74 | "-b",
75 | default="https://api.openai.com/v1",
76 | show_default=True,
77 | help="The base URL for the API calls.",
78 | type=str,
79 | )
80 | @click.option(
81 | "--target-repo-path",
82 | "-tp",
83 | default="",
84 | show_default=True,
85 | help="The file system path to the target repository. This path is used as the root for documentation generation.",
86 | type=click.Path(file_okay=False),
87 | )
88 | @click.option(
89 | "--hierarchy-path",
90 | "-hp",
91 | default=".project_doc_record",
92 | show_default=True,
93 | help="The name or path for the project hierarchy file, used to organize documentation structure.",
94 | type=str,
95 | )
96 | @click.option(
97 | "--markdown-docs-path",
98 | "-mdp",
99 | default="markdown_docs",
100 | show_default=True,
101 | help="The folder path where Markdown documentation will be stored or generated.",
102 | type=str,
103 | )
104 | @click.option(
105 | "--ignore-list",
106 | "-i",
107 | default="",
108 | help="A comma-separated list of files or directories to ignore during documentation generation.",
109 | )
110 | @click.option(
111 | "--language",
112 | "-l",
113 | default="English",
114 | show_default=True,
115 | help="The ISO 639 code or language name for the documentation. ",
116 | type=str,
117 | )
118 | @click.option(
119 | "--max-thread-count",
120 | "-mtc",
121 | default=4,
122 | show_default=True, help="Maximum number of threads used for documentation generation.",
123 | )
124 | @click.option(
125 | "--log-level",
126 | "-ll",
127 | default="INFO",
128 | show_default=True,
129 | help="Sets the logging level (e.g., DEBUG, INFO, WARNING, ERROR, CRITICAL) for the application. Default is INFO.",
130 | type=click.Choice([level.value for level in LogLevel], case_sensitive=False),
131 | )
132 | @click.option(
133 | "--print-hierarchy",
134 | "-pr",
135 | is_flag=True,
136 | show_default=True,
137 | default=False,
138 | help="If set, prints the hierarchy of the target repository when finished running the main task.",
139 | )
140 | def run(
141 | model,
142 | temperature,
143 | request_timeout,
144 | base_url,
145 | target_repo_path,
146 | hierarchy_path,
147 | markdown_docs_path,
148 | ignore_list,
149 | language,
150 | max_thread_count,
151 | log_level,
152 | print_hierarchy,
153 | ):
154 | """Run the program with the specified parameters."""
155 | try:
156 | # Fetch and validate the settings using the SettingsManager
157 | setting = SettingsManager.initialize_with_params(
158 | target_repo=target_repo_path,
159 | hierarchy_name=hierarchy_path,
160 | markdown_docs_name=markdown_docs_path,
161 | ignore_list=[item.strip() for item in ignore_list.split(",") if item],
162 | language=language,
163 | log_level=log_level,
164 | model=model,
165 | temperature=temperature,
166 | request_timeout=request_timeout,
167 | openai_base_url=base_url,
168 | max_thread_count=max_thread_count,
169 | )
170 | set_logger_level_from_config(log_level=log_level)
171 | except ValidationError as e:
172 | handle_setting_error(e)
173 | return
174 |
175 | # If the settings loaded successfully, run the task
176 | runner = Runner()
177 | runner.run()
178 | logger.success("Documentation task completed.")
179 | if print_hierarchy:
180 | runner.meta_info.target_repo_hierarchical_tree.print_recursive()
181 | logger.success("Hierarchy printed.")
182 |
183 |
184 | @cli.command()
185 | def clean():
186 | """Clean the fake files generated by the documentation process."""
187 | delete_fake_files()
188 | logger.success("Fake files have been cleaned up.")
189 |
190 |
191 | @cli.command()
192 | def diff():
193 | """Check for changes and print which documents will be updated or generated."""
194 | try:
195 | # Fetch and validate the settings using the SettingsManager
196 | setting = SettingsManager.get_setting()
197 | except ValidationError as e:
198 | handle_setting_error(e)
199 | return
200 |
201 | runner = Runner()
202 | if runner.meta_info.in_generation_process: # only pre-generation change detection is supported
203 | click.echo("This command only supports pre-check")
204 | raise click.Abort()
205 |
206 | file_path_reflections, jump_files = make_fake_files()
207 | new_meta_info = MetaInfo.init_meta_info(file_path_reflections, jump_files)
208 | new_meta_info.load_doc_from_older_meta(runner.meta_info)
209 | delete_fake_files()
210 |
211 | DocItem.check_has_task(
212 | new_meta_info.target_repo_hierarchical_tree,
213 | ignore_list=setting.project.ignore_list,
214 | )
215 | if new_meta_info.target_repo_hierarchical_tree.has_task:
216 | click.echo("The following docs will be generated/updated:")
217 | new_meta_info.target_repo_hierarchical_tree.print_recursive(
218 | diff_status=True, ignore_list=setting.project.ignore_list
219 | )
220 | else:
221 | click.echo("No docs will be generated/updated, check your source-code update")
222 |
223 |
224 | @cli.command()
225 | def chat_with_repo():
226 | """
227 | Start an interactive chat session with the repository.
228 | """
229 | try:
230 | # Fetch and validate the settings using the SettingsManager
231 | setting = SettingsManager.get_setting()
232 | except ValidationError as e:
233 | # Handle configuration errors if the settings are invalid
234 | handle_setting_error(e)
235 | return
236 |
237 | from repo_agent.chat_with_repo import main
238 |
239 | main()
240 |
241 |
242 | if __name__ == "__main__":
243 | cli()
244 |
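245 | # Example invocations (sketch; "repoagent" is assumed to be the installed
246 | # console-script name, per the package metadata lookup above):
247 | #   repoagent run --target-repo-path ./my_repo --model gpt-4o-mini
248 | #   repoagent diff
249 | #   repoagent chat-with-repo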
--------------------------------------------------------------------------------
/repo_agent/multi_task_dispatch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import random
4 | import threading
5 | import time
6 | from typing import Any, Callable, Dict, List
7 |
8 | from colorama import Fore, Style
9 |
10 |
11 | class Task:
12 | def __init__(self, task_id: int, dependencies: List[Task], extra_info: Any = None):
13 | self.task_id = task_id
14 | self.extra_info = extra_info
15 | self.dependencies = dependencies
16 | self.status = 0 # Task status: 0 = not started, 1 = in progress, 2 = completed, 3 = error
17 |
18 |
19 | class TaskManager:
20 | def __init__(self):
21 | """
22 | Initialize a MultiTaskDispatch object.
23 |
24 | This method initializes the MultiTaskDispatch object by setting up the necessary attributes.
25 |
26 | Attributes:
27 | - task_dict (Dict[int, Task]): A dictionary that maps task IDs to Task objects.
28 | - task_lock (threading.Lock): A lock used for thread synchronization when accessing the task_dict.
29 | - now_id (int): The current task ID.
30 | - query_id (int): The current query ID.
31 | - sync_func (None): A placeholder for a synchronization function.
32 |
33 | """
34 | self.task_dict: Dict[int, Task] = {}
35 | self.task_lock = threading.Lock()
36 | self.now_id = 0
37 | self.query_id = 0
38 |
39 | @property
40 | def all_success(self) -> bool:
41 | return len(self.task_dict) == 0
42 |
43 | def add_task(self, dependency_task_id: List[int], extra=None) -> int:
44 | """
45 | Adds a new task to the task dictionary.
46 |
47 | Args:
48 | dependency_task_id (List[int]): List of task IDs that the new task depends on.
49 | extra (Any, optional): Extra information associated with the task. Defaults to None.
50 |
51 | Returns:
52 | int: The ID of the newly added task.
53 | """
54 | with self.task_lock:
55 | depend_tasks = [self.task_dict[task_id] for task_id in dependency_task_id]
56 | self.task_dict[self.now_id] = Task(
57 | task_id=self.now_id, dependencies=depend_tasks, extra_info=extra
58 | )
59 | self.now_id += 1
60 | return self.now_id - 1
61 |
62 | def get_next_task(self, process_id: int):
63 | """
64 | Get the next task for a given process ID.
65 |
66 | Args:
67 | process_id (int): The ID of the process.
68 |
69 | Returns:
70 | tuple: A tuple containing the next task object and its ID.
71 | If there are no available tasks, returns (None, -1).
72 | """
73 | with self.task_lock:
74 | self.query_id += 1
75 | for task_id in self.task_dict.keys():
76 | ready = (
77 | len(self.task_dict[task_id].dependencies) == 0
78 | ) and self.task_dict[task_id].status == 0
79 | if ready:
80 | self.task_dict[task_id].status = 1
81 | print(
82 | f"{Fore.RED}[process {process_id}]{Style.RESET_ALL}: get task({task_id}), remain({len(self.task_dict)})"
83 | )
84 | return self.task_dict[task_id], task_id
85 | return None, -1
86 |
87 | def mark_completed(self, task_id: int):
88 | """
89 | Marks a task as completed and removes it from the task dictionary.
90 |
91 | Args:
92 | task_id (int): The ID of the task to mark as completed.
93 |
94 | """
95 | with self.task_lock:
96 | target_task = self.task_dict[task_id]
97 | for task in self.task_dict.values():
98 | if target_task in task.dependencies:
99 | task.dependencies.remove(target_task)
100 | self.task_dict.pop(task_id) # 从任务字典中移除
101 |
102 |
103 | def worker(task_manager, process_id: int, handler: Callable):
104 | """
105 | Worker function that performs tasks assigned by the task manager.
106 |
107 | Args:
108 | task_manager: The task manager object that assigns tasks to workers.
109 | process_id (int): The ID of the current worker process.
110 | handler (Callable): The function that handles the tasks.
111 |
112 | Returns:
113 | None
114 | """
115 | while True:
116 | if task_manager.all_success:
117 | return
118 | task, task_id = task_manager.get_next_task(process_id)
119 | if task is None:
120 | time.sleep(0.5)
121 | continue
122 | # print(f"will perform task: {task_id}")
123 | handler(task.extra_info)
124 | task_manager.mark_completed(task.task_id)
125 | # print(f"task complete: {task_id}")
126 |
127 |
128 | if __name__ == "__main__":
129 | task_manager = TaskManager()
130 |
131 | def some_function(): # 随机睡一会
132 | time.sleep(random.random() * 3)
133 |
134 | # 添加任务,例如:
135 | i1 = task_manager.add_task(some_function, []) # type: ignore
136 | i2 = task_manager.add_task(some_function, []) # type: ignore
137 | i3 = task_manager.add_task(some_function, [i1]) # type: ignore
138 | i4 = task_manager.add_task(some_function, [i2, i3]) # type: ignore
139 | i5 = task_manager.add_task(some_function, [i2, i3]) # type: ignore
140 | i6 = task_manager.add_task(some_function, [i1]) # type: ignore
141 |
142 | threads = [threading.Thread(target=worker, args=(task_manager,)) for _ in range(4)]
143 | for thread in threads:
144 | thread.start()
145 | for thread in threads:
146 | thread.join()
147 |
--------------------------------------------------------------------------------
/repo_agent/project_manager.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import jedi
4 |
5 |
6 | class ProjectManager:
7 | def __init__(self, repo_path, project_hierarchy):
8 | self.repo_path = repo_path
9 | self.project = jedi.Project(self.repo_path)
10 | self.project_hierarchy = os.path.join(
11 | self.repo_path, project_hierarchy, "project_hierarchy.json"
12 | )
13 |
14 | def get_project_structure(self):
15 | """
16 | Returns the structure of the project by recursively walking through the directory tree.
17 |
18 | Returns:
19 | str: The project structure as a string.
20 | """
21 |
22 | def walk_dir(root, prefix=""):
23 | structure.append(prefix + os.path.basename(root))
24 | new_prefix = prefix + " "
25 | for name in sorted(os.listdir(root)):
26 | if name.startswith("."): # 忽略隐藏文件和目录
27 | continue
28 | path = os.path.join(root, name)
29 | if os.path.isdir(path):
30 | walk_dir(path, new_prefix)
31 | elif os.path.isfile(path) and name.endswith(".py"):
32 | structure.append(new_prefix + name)
33 |
34 | structure = []
35 | walk_dir(self.repo_path)
36 | return "\n".join(structure)
37 |
38 | def build_path_tree(self, who_reference_me, reference_who, doc_item_path):
39 | from collections import defaultdict
40 |
41 | def tree():
42 | return defaultdict(tree)
43 |
44 | path_tree = tree()
45 |
46 | # Build the tree from the who_reference_me and reference_who path lists
47 | for path_list in [who_reference_me, reference_who]:
48 | for path in path_list:
49 | parts = path.split(os.sep)
50 | node = path_tree
51 | for part in parts:
52 | node = node[part]
53 |
54 | # Handle doc_item_path
55 | parts = doc_item_path.split(os.sep)
56 | parts[-1] = "✳️" + parts[-1] # Prefix the last component with a star marker
57 | node = path_tree
58 | for part in parts:
59 | node = node[part]
60 |
61 | def tree_to_string(tree, indent=0):
62 | s = ""
63 | for key, value in sorted(tree.items()):
64 | s += " " * indent + key + "\n"
65 | if isinstance(value, dict):
66 | s += tree_to_string(value, indent + 1)
67 | return s
68 |
69 | return tree_to_string(path_tree)
70 |
71 |
72 | if __name__ == "__main__":
73 | project_manager = ProjectManager(repo_path="", project_hierarchy="")
74 | print(project_manager.get_project_structure())
75 |
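76 | # build_path_tree sketch with illustrative (hypothetical) paths:
77 | # print(project_manager.build_path_tree(
78 | #     ["repo_agent/runner.py"], ["repo_agent/settings.py"], "repo_agent/main.py"
79 | # ))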
--------------------------------------------------------------------------------
/repo_agent/prompt.py:
--------------------------------------------------------------------------------
1 | from llama_index.core import ChatPromptTemplate
2 | from llama_index.core.llms import ChatMessage, MessageRole
3 |
4 | doc_generation_instruction = (
5 | "You are an AI documentation assistant, and your task is to generate documentation based on the given code of an object. "
6 | "The purpose of the documentation is to help developers and beginners understand the function and specific usage of the code.\n\n"
7 | "Currently, you are in a project{project_structure_prefix}\n"
8 | "{project_structure}\n\n"
9 | "The path of the document you need to generate in this project is {file_path}.\n"
10 | 'Now you need to generate a document for a {code_type_tell}, whose name is "{code_name}".\n\n'
11 | "The content of the code is as follows:\n"
12 | "{code_content}\n\n"
13 | "{reference_letter}\n"
14 | "{referencer_content}\n\n"
15 | "Please generate a detailed explanation document for this object based on the code of the target object itself {combine_ref_situation}.\n\n"
16 | "Please write out the function of this {code_type_tell} in bold plain text, followed by a detailed analysis in plain text "
17 | "(including all details), in language {language} to serve as the documentation for this part of the code.\n\n"
18 | "The standard format is as follows:\n\n"
19 | "**{code_name}**: The function of {code_name} is XXX. (Only code name and one sentence function description are required)\n"
20 | "**{parameters_or_attribute}**: The {parameters_or_attribute} of this {code_type_tell}.\n"
21 | "· parameter1: XXX\n"
22 | "· parameter2: XXX\n"
23 | "· ...\n"
24 | "**Code Description**: The description of this {code_type_tell}.\n"
25 | "(Detailed and CERTAIN code analysis and description...{has_relationship})\n"
26 | "**Note**: Points to note about the use of the code\n"
27 | "{have_return_tell}\n\n"
28 | "Please note:\n"
29 | "- Any part of the content you generate SHOULD NOT CONTAIN Markdown hierarchical heading and divider syntax.\n"
30 | "- Write mainly in the desired language. If necessary, you can write with some English words in the analysis and description "
31 | "to enhance the document's readability because you do not need to translate the function name or variable name into the target language.\n"
32 | )
33 |
34 | documentation_guideline = (
35 | "Keep in mind that your audience is document readers, so use a deterministic tone to generate precise content and don't let them know "
36 | "you're provided with code snippet and documents. AVOID ANY SPECULATION and inaccurate descriptions! Now, provide the documentation "
37 | "for the target object in {language} in a professional way."
38 | )
39 |
40 |
41 | message_templates = [
42 | ChatMessage(content=doc_generation_instruction, role=MessageRole.SYSTEM),
43 | ChatMessage(
44 | content=documentation_guideline,
45 | role=MessageRole.USER,
46 | ),
47 | ]
48 |
49 | chat_template = ChatPromptTemplate(message_templates=message_templates)
50 |
--------------------------------------------------------------------------------
/repo_agent/settings.py:
--------------------------------------------------------------------------------
1 | from enum import StrEnum
2 | from typing import Optional
3 |
4 | from iso639 import Language, LanguageNotFoundError
5 | from pydantic import (
6 | DirectoryPath,
7 | Field,
8 | HttpUrl,
9 | PositiveFloat,
10 | PositiveInt,
11 | SecretStr,
12 | field_validator,
13 | )
14 | from pathlib import Path
15 | from pydantic_settings import BaseSettings
16 |
17 |
18 | class LogLevel(StrEnum):
19 | DEBUG = "DEBUG"
20 | INFO = "INFO"
21 | WARNING = "WARNING"
22 | ERROR = "ERROR"
23 | CRITICAL = "CRITICAL"
24 |
25 |
26 | class ProjectSettings(BaseSettings):
27 | target_repo: DirectoryPath = "" # type: ignore
28 | hierarchy_name: str = ".project_doc_record"
29 | markdown_docs_name: str = "markdown_docs"
30 | ignore_list: list[str] = []
31 | language: str = "English"
32 | max_thread_count: PositiveInt = 4
33 | log_level: LogLevel = LogLevel.INFO
34 |
35 | @field_validator("language")
36 | @classmethod
37 | def validate_language_code(cls, v: str) -> str:
38 | try:
39 | language_name = Language.match(v).name
40 | return language_name # Returning the resolved language name
41 | except LanguageNotFoundError:
42 | raise ValueError(
43 | "Invalid language input. Please enter a valid ISO 639 code or language name."
44 | )
45 |
46 | @field_validator("log_level", mode="before")
47 | @classmethod
48 | def set_log_level(cls, v: str) -> LogLevel:
49 | if isinstance(v, str):
50 | v = v.upper() # Convert input to uppercase
51 | if (
52 | v in LogLevel._value2member_map_
53 | ): # Check if the converted value is in enum members
54 | return LogLevel(v)
55 | raise ValueError(f"Invalid log level: {v}")
56 |
57 |
58 | class ChatCompletionSettings(BaseSettings):
59 | model: str = "gpt-4o-mini" # NOTE: No model restrictions for user flexibility, but it's recommended to use models with a larger context window.
60 | temperature: PositiveFloat = 0.2
61 | request_timeout: PositiveInt = 60
62 | openai_base_url: str = "https://api.openai.com/v1"
63 | openai_api_key: SecretStr = Field(..., exclude=True)
64 |
65 | @field_validator("openai_base_url", mode="before")
66 | @classmethod
67 | def convert_base_url_to_str(cls, openai_base_url: HttpUrl) -> str:
68 | return str(openai_base_url)
69 |
70 |
71 | class Setting(BaseSettings):
72 | project: ProjectSettings = {} # type: ignore
73 | chat_completion: ChatCompletionSettings = {} # type: ignore
74 |
75 |
76 | class SettingsManager:
77 | _setting_instance: Optional[Setting] = (
78 | None # Private class attribute, initially None
79 | )
80 |
81 | @classmethod
82 | def get_setting(cls):
83 | if cls._setting_instance is None:
84 | cls._setting_instance = Setting()
85 | return cls._setting_instance
86 |
87 | @classmethod
88 | def initialize_with_params(
89 | cls,
90 | target_repo: Path,
91 | markdown_docs_name: str,
92 | hierarchy_name: str,
93 | ignore_list: list[str],
94 | language: str,
95 | max_thread_count: int,
96 | log_level: str,
97 | model: str,
98 | temperature: float,
99 | request_timeout: int,
100 | openai_base_url: str,
101 | ):
102 | project_settings = ProjectSettings(
103 | target_repo=target_repo,
104 | hierarchy_name=hierarchy_name,
105 | markdown_docs_name=markdown_docs_name,
106 | ignore_list=ignore_list,
107 | language=language,
108 | max_thread_count=max_thread_count,
109 | log_level=LogLevel(log_level),
110 | )
111 |
112 | chat_completion_settings = ChatCompletionSettings(
113 | model=model,
114 | temperature=temperature,
115 | request_timeout=request_timeout,
116 | openai_base_url=openai_base_url,
117 | )
118 |
119 | cls._setting_instance = Setting(
120 | project=project_settings,
121 | chat_completion=chat_completion_settings,
122 | )
123 |
124 |
125 | if __name__ == "__main__":
126 | setting = SettingsManager.get_setting()
127 | print(setting.model_dump())
128 |
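129 | # Configuration sketch: values come from environment variables via
130 | # pydantic-settings' case-insensitive field-name matching (an assumption
131 | # about the default behavior), e.g.:
132 | #   export OPENAI_API_KEY=sk-...
133 | #   export TARGET_REPO=/path/to/repo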
--------------------------------------------------------------------------------
/repo_agent/utils/gitignore_checker.py:
--------------------------------------------------------------------------------
1 | import fnmatch
2 | import os
3 |
4 |
5 | class GitignoreChecker:
6 | def __init__(self, directory: str, gitignore_path: str):
7 | """
8 | Initialize the GitignoreChecker with a specific directory and the path to a .gitignore file.
9 |
10 | Args:
11 | directory (str): The directory to be checked.
12 | gitignore_path (str): The path to the .gitignore file.
13 | """
14 | self.directory = directory
15 | self.gitignore_path = gitignore_path
16 | self.folder_patterns, self.file_patterns = self._load_gitignore_patterns()
17 |
18 | def _load_gitignore_patterns(self) -> tuple:
19 | """
20 | Load and parse the .gitignore file, then split the patterns into folder and file patterns.
21 |
22 | If the specified .gitignore file is not found, fall back to the default path.
23 |
24 | Returns:
25 | tuple: A tuple containing two lists - one for folder patterns and one for file patterns.
26 | """
27 | try:
28 | with open(self.gitignore_path, "r", encoding="utf-8") as file:
29 | gitignore_content = file.read()
30 | except FileNotFoundError:
31 | # Fallback to the default .gitignore path if the specified file is not found
32 | default_path = os.path.join(
33 | os.path.dirname(__file__), "..", "..", ".gitignore"
34 | )
35 | with open(default_path, "r", encoding="utf-8") as file:
36 | gitignore_content = file.read()
37 |
38 | patterns = self._parse_gitignore(gitignore_content)
39 | return self._split_gitignore_patterns(patterns)
40 |
41 | @staticmethod
42 | def _parse_gitignore(gitignore_content: str) -> list:
43 | """
44 | Parse the .gitignore content and return patterns as a list.
45 |
46 | Args:
47 | gitignore_content (str): The content of the .gitignore file.
48 |
49 | Returns:
50 | list: A list of patterns extracted from the .gitignore content.
51 | """
52 | patterns = []
53 | for line in gitignore_content.splitlines():
54 | line = line.strip()
55 | if line and not line.startswith("#"):
56 | patterns.append(line)
57 | return patterns
58 |
59 | @staticmethod
60 | def _split_gitignore_patterns(gitignore_patterns: list) -> tuple:
61 | """
62 | Split the .gitignore patterns into folder patterns and file patterns.
63 |
64 | Args:
65 | gitignore_patterns (list): A list of patterns from the .gitignore file.
66 |
67 | Returns:
68 | tuple: Two lists, one for folder patterns and one for file patterns.
69 | """
70 | folder_patterns = []
71 | file_patterns = []
72 | for pattern in gitignore_patterns:
73 | if pattern.endswith("/"):
74 | folder_patterns.append(pattern.rstrip("/"))
75 | else:
76 | file_patterns.append(pattern)
77 | return folder_patterns, file_patterns
78 |
79 | @staticmethod
80 | def _is_ignored(path: str, patterns: list, is_dir: bool = False) -> bool:
81 | """
82 | Check if the given path matches any of the patterns.
83 |
84 | Args:
85 | path (str): The path to check.
86 | patterns (list): A list of patterns to check against.
87 | is_dir (bool): True if the path is a directory, False otherwise.
88 |
89 | Returns:
90 | bool: True if the path matches any pattern, False otherwise.
91 | """
92 | for pattern in patterns:
93 | if fnmatch.fnmatch(path, pattern):
94 | return True
95 | if is_dir and pattern.endswith("/") and fnmatch.fnmatch(path, pattern[:-1]):
96 | return True
97 | return False
98 |
99 | def check_files_and_folders(self) -> list:
100 | """
101 | Check all files and folders in the given directory against the split gitignore patterns.
102 | Return a list of files that are not ignored and have the '.py' extension.
103 | The returned file paths are relative to the self.directory.
104 |
105 | Returns:
106 | list: A list of paths to files that are not ignored and have the '.py' extension.
107 | """
108 | not_ignored_files = []
109 | for root, dirs, files in os.walk(self.directory):
110 | dirs[:] = [
111 | d
112 | for d in dirs
113 | if not self._is_ignored(d, self.folder_patterns, is_dir=True)
114 | ]
115 |
116 | for file in files:
117 | file_path = os.path.join(root, file)
118 | relative_path = os.path.relpath(file_path, self.directory)
119 | if not self._is_ignored(
120 | file, self.file_patterns
121 | ) and file_path.endswith(".py"):
122 | not_ignored_files.append(relative_path)
123 |
124 | return not_ignored_files
125 |
126 |
127 | # Example usage:
128 | # gitignore_checker = GitignoreChecker('path_to_directory', 'path_to_gitignore_file')
129 | # not_ignored_files = gitignore_checker.check_files_and_folders()
130 | # print(not_ignored_files)
131 |
--------------------------------------------------------------------------------
/repo_agent/utils/meta_info_utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import os
3 |
4 | import git
5 | from colorama import Fore, Style
6 |
7 | from repo_agent.log import logger
8 | from repo_agent.settings import SettingsManager
9 |
10 | latest_verison_substring = "_latest_version.py"
11 |
12 |
13 | def make_fake_files():
14 | """根据git status检测暂存区信息。如果有文件:
15 | 1. 新增文件,没有add。无视
16 | 2. 修改文件内容,没有add,原始文件重命名为fake_file,新建原本的文件名内容为git status中的文件内容
17 | 3. 删除文件,没有add,原始文件重命名为fake_file,新建原本的文件名内容为git status中的文件内容
18 | 注意: 目标仓库的文件不能以latest_verison_substring结尾
19 | """
20 | delete_fake_files()
21 | setting = SettingsManager.get_setting()
22 |
23 | repo = git.Repo(setting.project.target_repo)
24 | unstaged_changes = repo.index.diff(None) # in git status: modified but not staged
25 | untracked_files = repo.untracked_files # on disk but not tracked by git
26 |
27 | jump_files = [] # these files are neither parsed nor documented, and their references are not computed
28 | for file_name in untracked_files:
29 | if file_name.endswith(".py"):
30 | print(
31 | f"{Fore.LIGHTMAGENTA_EX}[SKIP untracked files]: {Style.RESET_ALL}{file_name}"
32 | )
33 | jump_files.append(file_name)
34 | for diff_file in unstaged_changes.iter_change_type(
35 | "A"
36 | ): # newly added, unstaged files are all skipped
37 | if diff_file.a_path.endswith(latest_verison_substring):
38 | logger.error(
39 | "FAKE_FILE_IN_GIT_STATUS detected! suggest to use `delete_fake_files` and re-generate document"
40 | )
41 | exit()
42 | jump_files.append(diff_file.a_path)
43 |
44 | file_path_reflections = {}
45 | for diff_file in itertools.chain(
46 | unstaged_changes.iter_change_type("M"), unstaged_changes.iter_change_type("D")
47 | ): # collect the modified and deleted files
48 | if diff_file.a_path.endswith(latest_verison_substring):
49 | logger.error(
50 | "FAKE_FILE_IN_GIT_STATUS detected! suggest to use `delete_fake_files` and re-generate document"
51 | )
52 | exit()
53 | now_file_path = diff_file.a_path # path relative to repo_path
54 | if now_file_path.endswith(".py"):
55 | raw_file_content = diff_file.a_blob.data_stream.read().decode("utf-8")
56 | latest_file_path = now_file_path[:-3] + latest_verison_substring
57 | if os.path.exists(os.path.join(setting.project.target_repo, now_file_path)):
58 | os.rename(
59 | os.path.join(setting.project.target_repo, now_file_path),
60 | os.path.join(setting.project.target_repo, latest_file_path),
61 | )
62 |
63 | print(
64 | f"{Fore.LIGHTMAGENTA_EX}[Save Latest Version of Code]: {Style.RESET_ALL}{now_file_path} -> {latest_file_path}"
65 | )
66 | else:
67 | print(
68 | f"{Fore.LIGHTMAGENTA_EX}[Create Temp-File for Deleted(But not Staged) Files]: {Style.RESET_ALL}{now_file_path} -> {latest_file_path}"
69 | )
70 | with open(
71 | os.path.join(setting.project.target_repo, latest_file_path), "w"
72 | ) as writer:
73 | pass
74 | with open(
75 | os.path.join(setting.project.target_repo, now_file_path), "w"
76 | ) as writer:
77 | writer.write(raw_file_content)
78 |             file_path_reflections[now_file_path] = latest_file_path  # map each real path to its fake counterpart
79 | return file_path_reflections, jump_files
80 |
81 |
82 | def delete_fake_files():
83 |     """Delete all fake files once the task has finished."""
84 | setting = SettingsManager.get_setting()
85 |
86 | def gci(filepath):
87 |         # recursively walk every file under filepath, including subdirectories
88 | files = os.listdir(filepath)
89 | for fi in files:
90 | fi_d = os.path.join(filepath, fi)
91 | if os.path.isdir(fi_d):
92 | gci(fi_d)
93 | elif fi_d.endswith(latest_verison_substring):
94 | origin_name = fi_d.replace(latest_verison_substring, ".py")
95 | os.remove(origin_name)
96 | if os.path.getsize(fi_d) == 0:
97 | print(
98 | f"{Fore.LIGHTRED_EX}[Deleting Temp File]: {Style.RESET_ALL}{fi_d[len(str(setting.project.target_repo)):]}, {origin_name[len(str(setting.project.target_repo)):]}"
99 | ) # type: ignore
100 | os.remove(fi_d)
101 | else:
102 | print(
103 | f"{Fore.LIGHTRED_EX}[Recovering Latest Version]: {Style.RESET_ALL}{origin_name[len(str(setting.project.target_repo)):]} <- {fi_d[len(str(setting.project.target_repo)):]}"
104 | ) # type: ignore
105 | os.rename(fi_d, origin_name)
106 |
107 | gci(setting.project.target_repo)
108 |
--------------------------------------------------------------------------------
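How the two functions above fit together, as a minimal sketch (the try/finally call site is an assumption, not the project's actual runner code):

from repo_agent.utils.meta_info_utils import delete_fake_files, make_fake_files

# Park unstaged working-tree edits in *_latest_version.py fake files and
# rewrite each real path with its staged blob content, so documents are
# generated against what git actually tracks.
file_path_reflections, jump_files = make_fake_files()
try:
    ...  # parse and document the repo, skipping jump_files and resolving
         # real paths through file_path_reflections
finally:
    delete_fake_files()  # restore the working-tree edits, drop empty temp files

--------------------------------------------------------------------------------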
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenBMB/RepoAgent/825d988127d7bfd757237d9c4e8678d9104030f0/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_change_detector.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 |
4 | from git import Repo
5 |
6 | from repo_agent.change_detector import ChangeDetector
7 |
8 |
9 | class TestChangeDetector(unittest.TestCase):
10 | @classmethod
11 | def setUpClass(cls):
12 |         # Define the path of the test repository
13 | cls.test_repo_path = os.path.join(os.path.dirname(__file__), 'test_repo')
14 |
15 |         # Create the test repository folder if it does not exist
16 | if not os.path.exists(cls.test_repo_path):
17 | os.makedirs(cls.test_repo_path)
18 |
19 |         # Initialize a Git repository
20 | cls.repo = Repo.init(cls.test_repo_path)
21 |
22 |         # Configure Git user information
23 | cls.repo.git.config('user.email', 'ci@example.com')
24 | cls.repo.git.config('user.name', 'CI User')
25 |
26 |         # Create some test files
27 | with open(os.path.join(cls.test_repo_path, 'test_file.py'), 'w') as f:
28 | f.write('print("Hello, Python")')
29 |
30 | with open(os.path.join(cls.test_repo_path, 'test_file.md'), 'w') as f:
31 | f.write('# Hello, Markdown')
32 |
33 |         # Simulate Git operations: add and commit the files
34 | cls.repo.git.add(A=True)
35 | cls.repo.git.commit('-m', 'Initial commit')
36 |
37 | def test_get_staged_pys(self):
38 |         # Create a new Python file and stage it
39 | new_py_file = os.path.join(self.test_repo_path, 'new_test_file.py')
40 | with open(new_py_file, 'w') as f:
41 | f.write('print("New Python File")')
42 | self.repo.git.add(new_py_file)
43 |
44 |         # Use ChangeDetector to inspect the staged files
45 | change_detector = ChangeDetector(self.test_repo_path)
46 | staged_files = change_detector.get_staged_pys()
47 |
48 |         # Assert that the new file appears in the list of staged files
49 | self.assertIn('new_test_file.py', [os.path.basename(path) for path in staged_files])
50 |
51 | print(f"\ntest_get_staged_pys: Staged Python files: {staged_files}")
52 |
53 |
54 | def test_get_unstaged_mds(self):
55 |         # Modify a Markdown file without staging it
56 | md_file = os.path.join(self.test_repo_path, 'test_file.md')
57 | with open(md_file, 'a') as f:
58 | f.write('\nAdditional Markdown content')
59 |
60 |         # Use ChangeDetector to fetch the unstaged Markdown files
61 | change_detector = ChangeDetector(self.test_repo_path)
62 | unstaged_files = change_detector.get_to_be_staged_files()
63 |
64 |         # Assert that the modified file appears in the list of unstaged files
65 | self.assertIn('test_file.md', [os.path.basename(path) for path in unstaged_files])
66 |
67 | print(f"\ntest_get_unstaged_mds: Unstaged Markdown files: {unstaged_files}")
68 |
69 |
70 | def test_add_unstaged_mds(self):
71 |         # Make sure there is an unstaged Markdown file
72 | self.test_get_unstaged_mds()
73 |
74 |         # Use ChangeDetector to stage the unstaged Markdown files
75 | change_detector = ChangeDetector(self.test_repo_path)
76 | change_detector.add_unstaged_files()
77 |
78 |         # Check whether the files were staged
79 | unstaged_files_after_add = change_detector.get_to_be_staged_files()
80 |
81 |         # Assert that no unstaged Markdown files remain after staging
82 | self.assertEqual(len(unstaged_files_after_add), 0)
83 |
84 | remaining_unstaged_files = len(unstaged_files_after_add)
85 | print(f"\ntest_add_unstaged_mds: Number of remaining unstaged Markdown files after add: {remaining_unstaged_files}")
86 |
87 |
88 | @classmethod
89 | def tearDownClass(cls):
90 |         # Clean up the test repository
91 | cls.repo.close()
92 | os.system('rm -rf ' + cls.test_repo_path)
93 |
94 | if __name__ == '__main__':
95 | unittest.main()
96 |
--------------------------------------------------------------------------------
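Read as a usage example, the tests above imply the following ChangeDetector surface; a condensed sketch (return shapes are inferred from the assertions, so treat them as assumptions):

from repo_agent.change_detector import ChangeDetector

detector = ChangeDetector("/path/to/repo")    # any git working tree
staged = detector.get_staged_pys()            # staged .py files
pending = detector.get_to_be_staged_files()   # e.g. edited .md files not yet staged
detector.add_unstaged_files()                 # stage them; pending would then be empty

--------------------------------------------------------------------------------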
/tests/test_json_handler.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from unittest.mock import mock_open, patch
3 |
4 | from repo_agent.chat_with_repo.json_handler import (
5 |     JsonFileProcessor,  # Adjust the import according to your project structure
6 | )
7 |
8 |
9 | class TestJsonFileProcessor(unittest.TestCase):
10 |
11 | def setUp(self):
12 | self.processor = JsonFileProcessor("test.json")
13 |
14 | @patch("builtins.open", new_callable=mock_open, read_data='{"files": [{"objects": [{"md_content": "content1"}]}]}')
15 | def test_read_json_file(self, mock_file):
16 | # Test read_json_file method
17 | data = self.processor.read_json_file()
18 | self.assertEqual(data, {"files": [{"objects": [{"md_content": "content1"}]}]})
19 | mock_file.assert_called_with("test.json", "r", encoding="utf-8")
20 |
21 | @patch.object(JsonFileProcessor, 'read_json_file')
22 | def test_extract_md_contents(self, mock_read_json):
23 | # Test extract_md_contents method
24 | mock_read_json.return_value = {"files": [{"objects": [{"md_content": "content1"}]}]}
25 | md_contents = self.processor.extract_md_contents()
26 | self.assertIn("content1", md_contents)
27 |
28 | @patch("builtins.open", new_callable=mock_open, read_data='{"name": "test", "files": [{"name": "file1"}]}')
29 | def test_search_in_json_nested(self, mock_file):
30 | # Test search_in_json_nested method
31 | result = self.processor.search_in_json_nested("test.json", "file1")
32 | self.assertEqual(result, {"name": "file1"})
33 | mock_file.assert_called_with("test.json", "r", encoding="utf-8")
34 |
35 | # Additional tests for error handling (FileNotFoundError, JSONDecodeError, etc.) can be added here
36 |
37 | if __name__ == '__main__':
38 | unittest.main()
39 |
--------------------------------------------------------------------------------
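The placeholder comment above invites error-handling tests. One possible addition, assuming read_json_file logs the failure and calls sys.exit(1) on a missing file, so the failure surfaces as SystemExit (the test class name here is hypothetical):

import unittest
from unittest.mock import patch

from repo_agent.chat_with_repo.json_handler import JsonFileProcessor


class TestJsonFileProcessorErrors(unittest.TestCase):
    def setUp(self):
        self.processor = JsonFileProcessor("missing.json")

    @patch("builtins.open", side_effect=FileNotFoundError)
    def test_read_json_file_not_found(self, mock_file):
        # sys.exit(1) inside read_json_file reaches the test as SystemExit
        with self.assertRaises(SystemExit):
            self.processor.read_json_file()

--------------------------------------------------------------------------------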
/tests/test_structure_tree.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import defaultdict
3 |
4 |
5 | def build_path_tree(who_reference_me, reference_who, doc_item_path):
6 | def tree():
7 | return defaultdict(tree)
8 | path_tree = tree()
9 |
10 | for path_list in [who_reference_me, reference_who]:
11 | for path in path_list:
12 | parts = path.split(os.sep)
13 | node = path_tree
14 | for part in parts:
15 | node = node[part]
16 |
17 |     # Handle doc_item_path
18 | parts = doc_item_path.split(os.sep)
19 |     parts[-1] = '✳️' + parts[-1]  # prepend a star to the final component
20 | node = path_tree
21 | for part in parts:
22 | node = node[part]
23 |
24 | def tree_to_string(tree, indent=0):
25 | s = ''
26 | for key, value in sorted(tree.items()):
27 | s += ' ' * indent + key + '\n'
28 | if isinstance(value, dict):
29 | s += tree_to_string(value, indent + 1)
30 | return s
31 |
32 | return tree_to_string(path_tree)
33 |
34 |
35 | if __name__ == '__main__':
36 | who_reference_me = [
37 | "repo_agent/file_handler.py/FileHandler/__init__",
38 | "repo_agent/runner.py/need_to_generate"
39 | ]
40 | reference_who = [
41 | "repo_agent/file_handler.py/FileHandler/__init__",
42 | "repo_agent/runner.py/need_to_generate",
43 | ]
44 |
45 | doc_item_path = 'tests/test_change_detector.py/TestChangeDetector'
46 |
47 |     result = build_path_tree(who_reference_me, reference_who, doc_item_path)
48 | print(result)
49 |
--------------------------------------------------------------------------------
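A closing note on build_path_tree: `tree()` is the classic autovivifying dictionary, a defaultdict whose default factory is the function itself, so indexing through any chain of missing keys creates the intermediate nodes on the fly. On a POSIX system (where os.sep == '/'), the __main__ block above should print roughly:

repo_agent
 file_handler.py
  FileHandler
   __init__
 runner.py
  need_to_generate
tests
 test_change_detector.py
  ✳️TestChangeDetector

--------------------------------------------------------------------------------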