├── cora ├── __init__.py ├── agents │ ├── __init__.py │ ├── rewrite │ │ ├── __init__.py │ │ ├── dont.py │ │ ├── base.py │ │ ├── issue.py │ │ └── summ.py │ ├── snippets │ │ ├── __init__.py │ │ ├── base.py │ │ ├── split_file.py │ │ ├── judge_snip.py │ │ ├── factory.py │ │ └── score_snip.py │ ├── simple_agent.py │ ├── find_entities.py │ ├── choose_files.py │ ├── score_preview.py │ ├── explore_tree.py │ ├── reason_agent.py │ └── base.py ├── base │ ├── __init__.py │ ├── paths.py │ ├── rag.py │ └── console.py ├── kwe │ ├── __init__.py │ ├── index.py │ ├── engine.py │ └── tokens.py ├── llms │ ├── __init__.py │ ├── easydeploy_.py │ ├── huggingface_.py │ ├── openai_.py │ ├── ollama_.py │ ├── anthropic_.py │ ├── factory.py │ └── base.py ├── repair │ ├── __init__.py │ ├── _gen_ev.py │ ├── events.py │ ├── patch.py │ └── refine.py ├── repo │ ├── __init__.py │ ├── repo.py │ ├── kwe.py │ └── find.py ├── retrv │ ├── __init__.py │ ├── result.py │ ├── _gen_ev.py │ └── events.py ├── splits │ ├── __init__.py │ ├── splitter.py │ ├── text_.py │ ├── factory.py │ ├── code_.py │ └── ftypes.py ├── utils │ ├── __init__.py │ ├── generic.py │ ├── parallel.py │ ├── pattern.py │ ├── sanitize.py │ ├── interval.py │ ├── tree.py │ ├── misc.py │ ├── cmdline.py │ └── event.py ├── preview │ ├── internal │ │ ├── __init__.py │ │ └── xml_element.py │ ├── __init__.py │ ├── python_.py │ ├── base.py │ ├── text.py │ ├── xml_.py │ └── code.py ├── results.py ├── agent.py ├── cfar.py ├── repoqa.py ├── config.py ├── fixit.py └── options.py ├── .github └── assets │ ├── codefuse.jpg │ └── overview.png ├── docs ├── ruff_config_pycharm.png ├── ruff.md └── README_zh.md ├── LEGAL.md ├── .pre-commit-config.yaml ├── env.template ├── environment.yaml ├── roff.toml └── .gitignore /cora/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cora/agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cora/base/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cora/kwe/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cora/llms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cora/repair/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cora/repo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cora/retrv/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cora/retrv/result.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cora/splits/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cora/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cora/agents/rewrite/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cora/agents/snippets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cora/preview/internal/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/assets/codefuse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/codefuse-repoagent/master/.github/assets/codefuse.jpg -------------------------------------------------------------------------------- /.github/assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/codefuse-repoagent/master/.github/assets/overview.png -------------------------------------------------------------------------------- /docs/ruff_config_pycharm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/codefuse-repoagent/master/docs/ruff_config_pycharm.png -------------------------------------------------------------------------------- /cora/agents/rewrite/dont.py: -------------------------------------------------------------------------------- 1 | from cora.agents.rewrite.base import RewriterBase 2 | 3 | 4 | class DontRewrite(RewriterBase): 5 | def rewrite(self, query: str) -> str: 6 | return query 7 | -------------------------------------------------------------------------------- /cora/utils/generic.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar, Generic, cast 2 | 3 | ThisClass = TypeVar("ThisClass") 4 | 5 | 6 | class CastSelfToThis(Generic[ThisClass]): 7 | @property 8 | def this(self) -> ThisClass: 9 | return cast(ThisClass, self) 10 | -------------------------------------------------------------------------------- /cora/agents/rewrite/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABC 2 | 3 | from cora.repo.repo import Repository 4 | 5 | 6 | class RewriterBase(ABC): 7 | def __init__(self, repo: Repository): 8 | self.repo = repo 9 | 10 | @abstractmethod 11 | def rewrite(self, query: str) -> str: ... 
12 | -------------------------------------------------------------------------------- /cora/utils/parallel.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Tuple 2 | 3 | from joblib import Parallel, delayed 4 | 5 | 6 | def parallel(fn_and_args: List[Tuple[Any, tuple]], n_jobs: int, backend="loky") -> Any: 7 | return Parallel(n_jobs=n_jobs, backend=backend)( 8 | delayed(x[0])(*x[1]) for x in fn_and_args 9 | ) 10 | -------------------------------------------------------------------------------- /cora/preview/__init__.py: -------------------------------------------------------------------------------- 1 | from cora.preview.base import FilePreview 2 | from cora.preview.code import CodePreview 3 | from cora.preview.python_ import PythonPreview 4 | from cora.preview.text import TextPreview 5 | from cora.preview.xml_ import XMLPreview 6 | 7 | __all__ = ["FilePreview", "CodePreview", "TextPreview", "XMLPreview", "PythonPreview"] 8 | -------------------------------------------------------------------------------- /LEGAL.md: -------------------------------------------------------------------------------- 1 | Legal Disclaimer 2 | 3 | Within this source code, the comments in Chinese shall be the original, governing version. Any comments in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail. 4 | 5 | 法律免责声明 6 | 7 | 关于代码注释部分，中文注释为官方版本，其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致，当中文注释与其它语言注释存在不一致时，请以中文注释为准。 -------------------------------------------------------------------------------- /cora/utils/pattern.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | from typing import List 3 | 4 | 5 | def match_all_patterns(s: str, patterns: List[str]): 6 | return all(fnmatch.fnmatch(s, p) for p in patterns) 7 | 8 | 9 | def match_any_pattern(s: str, patterns: List[str]): 10 | return any(fnmatch.fnmatch(s, p) for p in patterns) 11 | 12 | 13 | def match_no_patterns(s: str, patterns: List[str]): 14 | return all(not fnmatch.fnmatch(s, p) for p in patterns) 15 | -------------------------------------------------------------------------------- /cora/repo/repo.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | from cora.base.repos import RepoBase, RepoTup 4 | from cora.repo.find import FindMixin 5 | from cora.repo.kwe import KwEngMixin 6 | 7 | 8 | class Repository(RepoBase, KwEngMixin, FindMixin): 9 | def __init__(self, repo: RepoTup, *, excludes: Optional[List[str]] = None): 10 | RepoBase.__init__(self, repo, excludes=excludes) 11 | KwEngMixin.__init__(self) 12 | FindMixin.__init__(self) 13 | -------------------------------------------------------------------------------- /cora/repair/_gen_ev.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from cora.utils.event import gen_event_and_callback_classes 4 | 5 | if __name__ == "__main__": 6 | gen_event_and_callback_classes( 7 | emitter="IssueRepa", 8 | events=[ 9 | "start", 10 | "finish", 11 | "gen_patch_start", 12 | "gen_patch_finish", 13 | "eval_patch_start", 14 | "eval_patch_finish", 15 | "next_round", 16 | ], 17 | to_file=Path(__file__).parent / "events.py", 18 | ) 19 | -------------------------------------------------------------------------------- /cora/splits/splitter.py:
-------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import cached_property 3 | from typing import List 4 | 5 | from cora.base.paths import FilePath, SnippetPath 6 | 7 | 8 | class Splitter: 9 | def __init__(self, file: FilePath): 10 | self.file = file 11 | 12 | @cached_property 13 | def content(self): 14 | return self.file.read_text(encoding="utf-8", errors="replace") 15 | 16 | def split(self) -> List[SnippetPath]: 17 | return self._do_split() 18 | 19 | @abstractmethod 20 | def _do_split(self) -> List[SnippetPath]: ... 21 | -------------------------------------------------------------------------------- /cora/splits/text_.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from cora.base.paths import SnippetPath, FilePath 4 | from cora.splits.splitter import Splitter 5 | 6 | 7 | class LineSpl(Splitter): 8 | """A line-based text splitter without overlapping""" 9 | 10 | def __init__(self, file: FilePath, snippet_size: int = 15): 11 | super().__init__(file) 12 | self._snippet_size = snippet_size 13 | 14 | def _do_split(self) -> List[SnippetPath]: 15 | t = len(self.content.splitlines()) 16 | s = self._snippet_size 17 | return [SnippetPath(self.file, i, min(i + s, t)) for i in range(0, t, s)] 18 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.7.2 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | fail_fast: true 9 | # Run the formatter. 10 | - id: ruff-format 11 | - repo: https://github.com/compilerla/conventional-pre-commit 12 | rev: v3.6.0 13 | hooks: 14 | - id: conventional-pre-commit 15 | stages: [commit-msg] 16 | args: ["--verbose"] 17 | - repo: https://github.com/pre-commit/pre-commit-hooks 18 | rev: v4.6.0 19 | hooks: 20 | - id: end-of-file-fixer 21 | - id: trailing-whitespace 22 | - id: check-yaml 23 | -------------------------------------------------------------------------------- /cora/utils/sanitize.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | _PATTERN_PHONE_NUMBER = re.compile( 4 | r"(?:\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}" 5 | ) 6 | _PATTERN_EMAIL_ADDRESS = re.compile(r"\b[\w\-.+]+@(?:[\w-]+\.)+[\w-]{2,4}\b") 7 | _PATTERN_PASSWORD = re.compile( 8 | r'["\']?password["\']?\s*[=:]\s*["\']?[\w_]+["\']?', flags=re.IGNORECASE 9 | ) 10 | 11 | 12 | def sanitize_content(content): 13 | content = _PATTERN_EMAIL_ADDRESS.sub("", content) 14 | content = _PATTERN_PHONE_NUMBER.sub("", content) 15 | content = _PATTERN_PASSWORD.sub("", content) 16 | return content 17 | -------------------------------------------------------------------------------- /cora/llms/easydeploy_.py: -------------------------------------------------------------------------------- 1 | from cora.llms.base import LLMBase 2 | 3 | 4 | def call_easydeploy(model, messages, *, temperature, top_p, max_tokens) -> str: 5 | raise NotImplementedError("EasyDeploy") 6 | 7 | 8 | class EasyDeploy(LLMBase): 9 | def __init__(self, model, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | self.model = model 12 | 13 | def do_query(self) -> str: 14 | return call_easydeploy( 15 | self.model, 16 | [m.to_json() for m in self.history], 17 | temperature=self.temperature, 18 | 
top_p=self.top_p, 19 | max_tokens=self.max_tokens, 20 | ) 21 | -------------------------------------------------------------------------------- /cora/retrv/_gen_ev.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from cora.utils.event import gen_event_and_callback_classes 4 | 5 | if __name__ == "__main__": 6 | gen_event_and_callback_classes( 7 | emitter="Retriever", 8 | events=[ 9 | "start", 10 | "finish", 11 | "qrw_start", 12 | "qrw_finish", 13 | "edl_start", 14 | "edl_finish", 15 | "kws_start", 16 | "kws_finish", 17 | "fte_start", 18 | "fte_finish", 19 | "fps_start", 20 | "fps_finish", 21 | "scr_start", 22 | "scr_finish", 23 | ], 24 | to_file=Path(__file__).parent / "events.py", 25 | ) 26 | -------------------------------------------------------------------------------- /cora/utils/interval.py: -------------------------------------------------------------------------------- 1 | from intervaltree import IntervalTree 2 | 3 | 4 | def merge_overlapping_intervals(intervals, merge_continuous=False): 5 | iv_tree = IntervalTree.from_tuples(intervals) 6 | iv_tree.merge_overlaps() 7 | intervals = [(iv.begin, iv.end) for iv in iv_tree] 8 | intervals.sort() 9 | # [30, 50) and [50, 60) won't be merged as they do not overlap 10 | if not merge_continuous: 11 | return intervals 12 | # Merge [30, 50) and [50, 60) into [30, 60) 13 | refined = [] 14 | for iv in intervals: 15 | if len(refined) > 0 and refined[-1][1] == iv[0]: 16 | refined[-1] = (refined[-1][0], iv[1]) 17 | else: 18 | refined.append(iv) 19 | return refined 20 | -------------------------------------------------------------------------------- /cora/llms/huggingface_.py: -------------------------------------------------------------------------------- 1 | from cora.llms.base import LLMBase 2 | 3 | # TODO: make thread-safe 4 | _CACHED_MODELS = {} 5 | 6 | 7 | def call_huggingface(model_id, messages, *, temperature, top_p, max_tokens): 8 | # TODO: directly use pipeline 9 | raise NotImplementedError("HuggingFace") 10 | 11 | 12 | class HuggingFace(LLMBase): 13 | def __init__(self, model, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | self.model = model 16 | 17 | def do_query(self) -> str: 18 | return call_huggingface( 19 | self.model, 20 | [m.to_json() for m in self.history], 21 | temperature=self.temperature, 22 | top_p=self.top_p, 23 | max_tokens=self.max_tokens, 24 | ) 25 | -------------------------------------------------------------------------------- /cora/llms/openai_.py: -------------------------------------------------------------------------------- 1 | import openai 2 | 3 | from cora.llms.base import LLMBase 4 | 5 | _client = openai.OpenAI() 6 | 7 | 8 | def call_openai(model_name, messages, *, temperature, top_p, max_tokens): 9 | resp = _client.chat.completions.create( 10 | model=model_name, 11 | messages=messages, 12 | temperature=temperature, 13 | top_p=top_p, 14 | max_completion_tokens=max_tokens, 15 | ) 16 | return resp.choices[0].message.content 17 | 18 | 19 | class OpenAI(LLMBase): 20 | def __init__(self, model, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | self.model = model 23 | 24 | def do_query(self) -> str: 25 | return call_openai( 26 | self.model, 27 | [m.to_json() for m in self.history], 28 | temperature=self.temperature, 29 | top_p=self.top_p, 30 | max_tokens=self.max_tokens, 31 | ) 32 | -------------------------------------------------------------------------------- /cora/preview/internal/xml_element.py: 
-------------------------------------------------------------------------------- 1 | from xml.etree.ElementTree import Element 2 | 3 | 4 | class Elements: 5 | @staticmethod 6 | def start_point(elem: Element) -> int: 7 | # noinspection PyTypeChecker 8 | return elem.attrib["start_point"] 9 | 10 | @staticmethod 11 | def start_line_number(elem: Element) -> int: 12 | # noinspection PyTypeChecker 13 | return elem.attrib["start_point"][0] 14 | 15 | @staticmethod 16 | def start_column_number(elem: Element) -> int: 17 | # noinspection PyTypeChecker 18 | return elem.attrib["start_point"][1] 19 | 20 | @staticmethod 21 | def end_line_number(elem: Element) -> int: 22 | # noinspection PyTypeChecker 23 | return elem.attrib["end_point"][0] 24 | 25 | @staticmethod 26 | def end_column_number(elem: Element) -> int: 27 | # noinspection PyTypeChecker 28 | return elem.attrib["end_point"][1] 29 | -------------------------------------------------------------------------------- /cora/preview/python_.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | from cora.preview import FilePreview 4 | 5 | 6 | class _PreviewVisitor(ast.NodeVisitor): 7 | def __init__(self): 8 | self.lines = [] 9 | 10 | def visit_ClassDef(self, node): 11 | self.lines.append(f"{self.get_indent(node)}class {node.name}:") 12 | self.generic_visit(node) 13 | 14 | def visit_FunctionDef(self, node): 15 | args = ", ".join(arg.arg for arg in node.args.args) 16 | self.lines.append(f"{self.get_indent(node)}def {node.name}({args}):") 17 | 18 | @staticmethod 19 | def get_indent(node): 20 | return " " * node.col_offset 21 | 22 | 23 | @FilePreview.register(["python"]) 24 | class PythonPreview(FilePreview): 25 | def get_preview(self) -> str: 26 | tree = ast.parse(self.file_content, filename=self.file_name) 27 | visitor = _PreviewVisitor() 28 | visitor.visit(tree) 29 | 30 | return "\n".join(visitor.lines) 31 | -------------------------------------------------------------------------------- /cora/llms/ollama_.py: -------------------------------------------------------------------------------- 1 | import ollama 2 | 3 | from cora.llms.base import LLMBase 4 | 5 | 6 | def call_ollama(model_name, messages, *, temperature, top_p, max_tokens): 7 | resp = ollama.chat( 8 | model_name, 9 | messages=messages, 10 | options={"temperature": temperature, "top_p": top_p, "max_tokens": max_tokens}, 11 | ) 12 | return resp["message"]["content"] 13 | 14 | 15 | class Ollama(LLMBase): 16 | def __init__(self, model, *args, **kwargs): 17 | super().__init__(*args, **kwargs) 18 | self.model = model 19 | 20 | def do_query(self) -> str: 21 | return call_ollama( 22 | self.model, 23 | [m.to_json() for m in self.history], 24 | temperature=self.temperature, 25 | top_p=self.top_p, 26 | max_tokens=self.max_tokens, 27 | ) 28 | 29 | 30 | if __name__ == "__main__": 31 | model_ = Ollama("qwen2:0.5b-instruct", temperature=0.8, debug_mode=True) 32 | model_.append_user_message("Hi, I'm Tony! What's your name?") 33 | model_.query() 34 | -------------------------------------------------------------------------------- /cora/agents/snippets/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple, Generator 3 | 4 | from cora.repo.repo import Repository 5 | 6 | 7 | class SnipRelDetmBase(ABC): 8 | @abstractmethod 9 | def is_debugging(self) -> bool: ... 10 | 11 | @abstractmethod 12 | def disable_debugging(self): ... 
13 | 14 | @abstractmethod 15 | def enable_debugging(self): ... 16 | 17 | @abstractmethod 18 | def determine( 19 | self, query: str, snippet_path: str, snippet_content: str, *args, **kwargs 20 | ) -> Tuple[bool, str]: ... 21 | 22 | 23 | class SnipFinderBase(ABC): 24 | def __init__(self, repo: Repository, determ: SnipRelDetmBase): 25 | self.repo = repo 26 | self.determ = determ 27 | 28 | def disable_debugging(self): 29 | self.determ.disable_debugging() 30 | 31 | def enable_debugging(self): 32 | self.determ.enable_debugging() 33 | 34 | @abstractmethod 35 | def find( 36 | self, query: str, file_path: str, *args, **kwargs 37 | ) -> Generator[Tuple[str, str], None, None]: ... 38 | -------------------------------------------------------------------------------- /env.template: -------------------------------------------------------------------------------- 1 | ## 2 | ## Global settings 3 | ## 4 | CACHE_DIRECTORY_PATH=cora_cache # Directory for saving caches (like indices) of CodeFuse RepoAgent (CoRA) 5 | SANITIZE_CONTENT_IN_REPOSITORY=0 # Set this to "1" to sanitize sensitive information (e.g., email addresses, passwords) 6 | 7 | ## 8 | ## OpenAI Settings: Set these if you prefer to use OpenAI 9 | ## 10 | OPENAI_API_KEY=sk-xxx # Your API key in OpenAI 11 | OPENAI_ORG_ID=bk-xxx # Your organization ID in OpenAI 12 | 13 | ## 14 | ## Anthropic Settings: Set these if you prefer to use Anthropic 15 | ## 16 | ANTHROPIC_API_KEY=xx-xxx # Your API key in Anthropic 17 | 18 | ## 19 | ## HuggingFace Settings: Set these if you prefer to use HuggingFace 20 | ## 21 | HF_DATASETS_OFFLINE=1 # Disable HuggingFace's online access to datasets 22 | TRANSFORMERS_OFFLINE=1 # Disable HuggingFace's online access to models 23 | TOKENIZERS_PARALLELISM=false # Disable tokenizer's parallelism 24 | 25 | ## 26 | ## EasyDeploy Settings: Set these if you prefer to use EasyDeploy 27 | ## 28 | EASYDEPLOY_ENDPOINT=https://xxx # Endpoint 29 | -------------------------------------------------------------------------------- /cora/repair/events.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, unique 2 | 3 | from cora.utils.event import EventEmitter 4 | 5 | 6 | @unique 7 | class IssueRepaEvents(Enum): 8 | EVENT_START = "start" 9 | EVENT_FINISH = "finish" 10 | EVENT_GEN_PATCH_START = "gen_patch_start" 11 | EVENT_GEN_PATCH_FINISH = "gen_patch_finish" 12 | EVENT_EVAL_PATCH_START = "eval_patch_start" 13 | EVENT_EVAL_PATCH_FINISH = "eval_patch_finish" 14 | EVENT_NEXT_ROUND = "next_round" 15 | 16 | 17 | class IssueRepaCallbacks: 18 | def on_start(self, **kwargs): 19 | pass 20 | 21 | def on_finish(self, **kwargs): 22 | pass 23 | 24 | def on_gen_patch_start(self, **kwargs): 25 | pass 26 | 27 | def on_gen_patch_finish(self, **kwargs): 28 | pass 29 | 30 | def on_eval_patch_start(self, **kwargs): 31 | pass 32 | 33 | def on_eval_patch_finish(self, **kwargs): 34 | pass 35 | 36 | def on_next_round(self, **kwargs): 37 | pass 38 | 39 | def register_to(self, x: EventEmitter): 40 | for k, v in IssueRepaEvents.__members__.items(): 41 | x.on(v.value, getattr(self, f"on_{v.value}")) 42 | -------------------------------------------------------------------------------- /cora/utils/tree.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Protocol, TypeVar, Generic, Optional, List 3 | 4 | 5 | class TreeNodeVisitor(Protocol): 6 | @abstractmethod 7 | def __call__(self, node: "TreeNode"): ...
8 | 9 | 10 | TreeData = TypeVar("TreeData") 11 | 12 | 13 | class TreeNode(Generic[TreeData]): 14 | def __init__(self, data: TreeData, parent: Optional["TreeNode"] = None): 15 | self.data: TreeData = data 16 | self.parent: Optional["TreeNode"] = parent 17 | self.children: List["TreeNode"] = [] 18 | if parent: 19 | parent.children.append(self) 20 | 21 | def leaves(self) -> List["TreeNode"]: 22 | leaves = [] 23 | 24 | def _visit(node): 25 | if len(node.children) == 0: 26 | leaves.append(node) 27 | 28 | self.accept(_visit) 29 | 30 | return leaves 31 | 32 | def detach(self): 33 | if not self.parent: 34 | return 35 | self.parent.children.remove(self) 36 | self.parent = None 37 | 38 | def accept(self, visitor: TreeNodeVisitor): 39 | visitor(self) 40 | for child in self.children: 41 | child.accept(visitor) 42 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: cora_venv 2 | channels: 3 | - defaults 4 | dependencies: 5 | - bzip2=1.0.8=h80987f9_6 6 | - ca-certificates=2024.9.24=hca03da5_0 7 | - libffi=3.4.4=hca03da5_1 8 | - ncurses=6.4=h313beb8_0 9 | - openssl=3.0.15=h80987f9_0 10 | - pip=24.2=py310hca03da5_0 11 | - python=3.10.11=hb885b13_3 12 | - readline=8.2=h1a28f6b_0 13 | - setuptools=75.1.0=py310hca03da5_0 14 | - sqlite=3.45.3=h80987f9_0 15 | - tk=8.6.14=h6ba3021_0 16 | - tzdata=2024b=h04d1e81_0 17 | - wheel=0.44.0=py310hca03da5_0 18 | - xz=5.4.6=h80987f9_1 19 | - zlib=1.2.13=h18a0788_1 20 | - pip: 21 | - cfgv==3.4.0 22 | - distlib==0.3.9 23 | - filelock==3.16.1 24 | - identify==2.6.1 25 | - nodeenv==1.9.1 26 | - platformdirs==4.3.6 27 | - pre-commit==4.0.1 28 | - pyyaml==6.0.2 29 | - virtualenv==20.27.1 30 | - pyjson5==1.6.7 31 | - rich==13.9.4 32 | - python-dotenv==1.0.1 33 | - RapidFuzz==3.10.1 34 | - joblib==1.4.2 35 | - intervaltree==3.1.0 36 | - tree-sitter==0.21.3 37 | - tree-sitter-languages==1.10.2 38 | - datasets==3.1.0 39 | - ollama==0.3.3 40 | - openai==1.56.0 41 | - anthropic==0.42.0 42 | -------------------------------------------------------------------------------- /cora/base/paths.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | FilePath = Path 4 | 5 | 6 | class SnippetPath: 7 | def __init__(self, file: FilePath, start: int, end: int): 8 | self._file_path = file 9 | self._start_line = start 10 | self._end_line = end 11 | 12 | @staticmethod 13 | def from_str(snp_path) -> "SnippetPath": 14 | f, t = snp_path.split(":") 15 | s, e = t.split("-") 16 | return SnippetPath(FilePath(f), int(s), int(e)) 17 | 18 | @property 19 | def file_path(self) -> FilePath: 20 | return self._file_path 21 | 22 | @property 23 | def start_line(self) -> int: 24 | return self._start_line 25 | 26 | @property 27 | def end_line(self) -> int: 28 | return self._end_line 29 | 30 | def size(self): 31 | return self._end_line - self._start_line 32 | 33 | def as_tuple(self): 34 | return str(self.file_path), self._start_line, self._end_line 35 | 36 | def __eq__(self, other): 37 | if not isinstance(other, SnippetPath): 38 | return False 39 | return str(self) == str(other) 40 | 41 | def __str__(self): 42 | return f"{self.file_path}:{self.start_line}-{self.end_line}" 43 | -------------------------------------------------------------------------------- /docs/ruff.md: -------------------------------------------------------------------------------- 1 | # Configure Ruff's Linter and Formatter 2 | 3 | This 
document ensures that Ruff lints and formats the file on each save action. 4 | 5 | ## Configure in PyCharm 6 | 7 | 1. Install CoRA and start CoRA's environment following instructions in [README.md](../README.md). 8 | 2. Install Ruff at the version (marked by the comment `# Ruff version.`) shown in [pre-commit configs](../.pre-commit-config.yaml). For example: 9 | 10 | ```shell 11 | pip install ruff==0.7.2 # Change the version accordingly 12 | ``` 13 | 14 | 3. Import CoRA into PyCharm; please ensure CoRA's environment has been imported as a regular conda environment. 15 | 4. Open Settings -> Tools -> File Watchers, and add a new file watcher named "Ruff Lint & Format" with the following configuration, with the Arguments field filled with: 16 | 17 | ```shell 18 | -c "$PyInterpreterDirectory$/ruff check $FilePathRelativeToProjectRoot$ && $PyInterpreterDirectory$/ruff format $FilePathRelativeToProjectRoot$" 19 | ``` 20 | 21 | ![Ruff in PyCharm](./ruff_config_pycharm.png) 22 | 23 | 5. Open Settings -> Tools -> Actions On Save to ensure "File Watcher: Ruff Lint & Format" is toggled on. 24 | 6. Save to apply all the above settings. 25 | 26 | ## Configure in VSCode 27 | 28 | > TODO: Add instructions to configure Ruff in VSCode 29 | -------------------------------------------------------------------------------- /cora/llms/anthropic_.py: -------------------------------------------------------------------------------- 1 | import anthropic 2 | 3 | from cora.llms.base import LLMBase 4 | 5 | _client = anthropic.Anthropic() 6 | 7 | 8 | def call_anthropic(model_name, messages, *, temperature, top_p, max_tokens, system): 9 | resp = _client.messages.create( 10 | model=model_name, 11 | messages=messages, 12 | system=system, 13 | temperature=temperature, 14 | top_p=top_p, 15 | max_tokens=max_tokens, 16 | ) 17 | return resp.content[0].text 18 | 19 | 20 | class Anthropic(LLMBase): 21 | def __init__(self, model, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self.model = model 24 | 25 | def do_query(self) -> str: 26 | messages = [m.to_json() for m in self.history] 27 | if messages[0]["role"] == "system": 28 | system_prompt = messages[0]["content"] 29 | messages = messages[1:] 30 | else: 31 | system_prompt = anthropic.NOT_GIVEN 32 | return call_anthropic( 33 | self.model, 34 | messages=messages, 35 | temperature=self.temperature, 36 | top_p=self.top_p, 37 | max_tokens=self.max_tokens, 38 | system=system_prompt, 39 | ) 40 | -------------------------------------------------------------------------------- /cora/agents/simple_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Type, List 2 | 3 | from cora.agents.base import AgentBase 4 | from cora.llms.base import LLMBase 5 | 6 | JSON_SCHEMA = """\ 7 | {{ 8 | {props} 9 | }}\ 10 | """ 11 | 12 | 13 | class SimpleAgent(AgentBase): 14 | def __init__( 15 | self, llm: LLMBase, returns: List[Tuple[str, Type, str]], *args, **kwargs 16 | ): 17 | super().__init__( 18 | llm=llm, 19 | json_schema=JSON_SCHEMA.format( 20 | props="\n".join([f' "{prop}": {desc}' for prop, _, desc in returns]) 21 | ), 22 | *args, 23 | **kwargs, 24 | ) 25 | self.returns = returns 26 | 27 | def _check_response_format(self, response: dict, *args, **kwargs): 28 | for ret in self.returns: 29 | if ret[0] not in response: 30 | return False, f"'{ret[0]}' is missing in the JSON object" 31 | return True, None 32 | 33 | def _check_response_semantics(self, response: dict, *args, **kwargs): 34 | for r in self.returns: 35 | if not isinstance(response[r[0]], r[1]): 36
| return False, f"{r[0]} should be of type {r[1]}" 37 | return True, None 38 | 39 | def _parse_response(self, response: dict, *args, **kwargs): 40 | return {r[0]: response[r[0]] for r in self.returns} 41 | -------------------------------------------------------------------------------- /cora/splits/factory.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Type, List 2 | 3 | from cora.base.paths import FilePath 4 | from cora.splits.code_ import ASTSpl 5 | from cora.splits.ftypes import parse_ftype 6 | from cora.splits.splitter import Splitter 7 | from cora.splits.text_ import LineSpl 8 | 9 | 10 | class SplFactory: 11 | _additional_splitters: Dict[str, Type[Splitter]] = {} 12 | 13 | @classmethod 14 | def register(cls, file_types: List[str]): 15 | def register_inner(spl_cls: Type[Splitter]): 16 | for ft in file_types: 17 | assert ( 18 | ft not in cls._additional_splitters 19 | ), f"Cannot register {spl_cls} for {ft}; it was already registered by {cls._additional_splitters[ft]}" 20 | cls._additional_splitters[ft] = spl_cls 21 | return spl_cls 22 | 23 | return register_inner 24 | 25 | @classmethod 26 | def create(cls, file: FilePath) -> Splitter: 27 | file_type = parse_ftype(file.name) 28 | # We are a registered type 29 | if file_type in cls._additional_splitters: 30 | return cls._additional_splitters[file_type](file) 31 | try: 32 | # We might be a code file; we split by AST 33 | return ASTSpl(file) 34 | except Exception: 35 | # We fall back to a text splitter; we split by lines 36 | return LineSpl(file) 37 | 38 | @staticmethod 39 | def get_splitter(): 40 | pass 41 | -------------------------------------------------------------------------------- /cora/utils/misc.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | import time 4 | import traceback 5 | from collections import OrderedDict 6 | 7 | import joblib 8 | 9 | 10 | class CannotReachHereError(RuntimeError): 11 | def __init__(self, msg): 12 | super().__init__(msg) 13 | 14 | 15 | # TODO: Consider using jd/tenacity 16 | def robust_call(retry=2, sleep=10): 17 | def _robust_call(fn): 18 | def _function(*args, **kwargs): 19 | exception = None 20 | for _ in range(retry): 21 | try: 22 | return fn(*args, **kwargs) 23 | except Exception as e: 24 | exception = e 25 | traceback.print_exc() 26 | time.sleep(random.randint(1, sleep)) 27 | raise exception 28 | 29 | return _function 30 | 31 | return _robust_call 32 | 33 | 34 | def to_bool(s): 35 | if isinstance(s, bool): 36 | return s 37 | elif isinstance(s, str): 38 | return s.lower() in ["true", "y", "yes", "1"] 39 | else: 40 | return bool(s) 41 | 42 | 43 | def save_object(obj, path): 44 | # Dump the object with joblib 45 | with open(path, "wb") as fou: 46 | joblib.dump(obj, fou) 47 | 48 | 49 | def load_object(path): 50 | # Read back the object dumped by save_object above 51 | with open(path, "rb") as fin: 52 | return pickle.load(fin) 53 | 54 | 55 | def ordered_set(array: list) -> dict: 56 | return OrderedDict.fromkeys(array) 57 | -------------------------------------------------------------------------------- /roff.toml: -------------------------------------------------------------------------------- 1 | # Exclude a variety of commonly ignored directories.
2 | exclude = [ 3 | ".bzr", 4 | ".direnv", 5 | ".eggs", 6 | ".git", 7 | ".git-rewrite", 8 | ".hg", 9 | ".ipynb_checkpoints", 10 | ".mypy_cache", 11 | ".nox", 12 | ".pants.d", 13 | ".pyenv", 14 | ".pytest_cache", 15 | ".pytype", 16 | ".ruff_cache", 17 | ".svn", 18 | ".tox", 19 | ".venv", 20 | ".vscode", 21 | "__pypackages__", 22 | "_build", 23 | "buck-out", 24 | "build", 25 | "dist", 26 | "node_modules", 27 | "site-packages", 28 | "venv", 29 | ] 30 | 31 | # Same as Black. 32 | line-length = 88 33 | indent-width = 4 34 | 35 | # Assume Python 3.8 36 | target-version = "py38" 37 | 38 | [lint] 39 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 40 | select = ["E4", "E7", "E9", "F"] 41 | ignore = [] 42 | 43 | # Allow fix for all enabled rules (when `--fix` is provided). 44 | fixable = ["ALL"] 45 | unfixable = [] 46 | 47 | # Allow unused variables when underscore-prefixed. 48 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 49 | 50 | [format] 51 | # Like Black, use double quotes for strings. 52 | quote-style = "double" 53 | 54 | # Like Black, indent with spaces, rather than tabs. 55 | indent-style = "space" 56 | 57 | # Like Black, respect magic trailing commas. 58 | skip-magic-trailing-comma = false 59 | 60 | # Like Black, automatically detect the appropriate line ending. 61 | line-ending = "auto" 62 | -------------------------------------------------------------------------------- /cora/utils/cmdline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shlex 3 | import signal 4 | import subprocess 5 | 6 | 7 | def safe_killpg(pid, sig): 8 | try: 9 | os.killpg(pid, sig) 10 | except ProcessLookupError: 11 | pass # Ignore if there is no such process 12 | 13 | 14 | def spawn_process(cmd, stdout, stderr, timeout) -> subprocess.CompletedProcess: 15 | # Fix: subprocess.run(cmd) series methods, when timed out, only send a SIGTERM 16 | # signal to cmd but do not kill cmd's subprocesses. We let each command run 17 | # in a new process group by adding the start_new_session flag, and kill the whole 18 | # process group such that all cmd's subprocesses are also killed when timed out. 19 | with subprocess.Popen( 20 | cmd, stdout=stdout, stderr=stderr, start_new_session=True 21 | ) as proc: 22 | try: 23 | output, err_msg = proc.communicate(timeout=timeout) 24 | except: # Including TimeoutExpired and KeyboardInterrupt; communicate() handled that. 25 | safe_killpg(os.getpgid(proc.pid), signal.SIGKILL) 26 | # We don't call proc.wait() as .__exit__ does that for us.
27 | raise 28 | ecode = proc.poll() 29 | return subprocess.CompletedProcess(proc.args, ecode, output, err_msg) 30 | 31 | 32 | def check_call(cmd: str, timeout: int = 60): 33 | proc = spawn_process( 34 | shlex.split(cmd), 35 | stdout=subprocess.PIPE, 36 | stderr=subprocess.PIPE, 37 | timeout=timeout, 38 | ) 39 | proc.check_returncode() 40 | -------------------------------------------------------------------------------- /cora/llms/factory.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Literal 3 | 4 | from cora.llms.anthropic_ import Anthropic 5 | from cora.llms.base import LLMBase 6 | from cora.llms.easydeploy_ import EasyDeploy 7 | from cora.llms.huggingface_ import HuggingFace 8 | from cora.llms.ollama_ import Ollama 9 | from cora.llms.openai_ import OpenAI 10 | 11 | 12 | @dataclass 13 | class LLMConfig: 14 | provider: Literal["openai", "ollama", "huggingface"] 15 | llm_name: str 16 | debug_mode: bool = field(default=False) 17 | temperature: float = field(default=0) 18 | top_k: int = field(default=50) 19 | top_p: float = field(default=0.95) 20 | max_tokens: int = field(default=1024) 21 | 22 | @property 23 | def max_completion_tokens(self) -> int: 24 | return self.max_tokens 25 | 26 | 27 | class LLMFactory: 28 | @classmethod 29 | def create(cls, config: LLMConfig) -> LLMBase: 30 | return { 31 | "ollama": Ollama, 32 | "openai": OpenAI, 33 | "anthropic": Anthropic, 34 | "huggingface": HuggingFace, 35 | "easydeploy": EasyDeploy, 36 | }[config.provider]( 37 | config.llm_name, 38 | debug_mode=config.debug_mode, 39 | temperature=config.temperature, 40 | top_k=config.top_k, 41 | top_p=config.top_p, 42 | max_tokens=config.max_tokens, 43 | ) 44 | 45 | 46 | if __name__ == "__main__": 47 | llm = LLMFactory.create( 48 | LLMConfig(provider="ollama", llm_name="qwen2:0.5b-instruct", debug_mode=True) 49 | ) 50 | llm.append_user_message("Hi, I'm Simon!") 51 | llm.query() 52 | -------------------------------------------------------------------------------- /cora/retrv/events.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, unique 2 | 3 | from cora.utils.event import EventEmitter 4 | 5 | 6 | @unique 7 | class RetrieverEvents(Enum): 8 | EVENT_START = "start" 9 | EVENT_FINISH = "finish" 10 | EVENT_QRW_START = "qrw_start" 11 | EVENT_QRW_FINISH = "qrw_finish" 12 | EVENT_EDL_START = "edl_start" 13 | EVENT_EDL_FINISH = "edl_finish" 14 | EVENT_KWS_START = "kws_start" 15 | EVENT_KWS_FINISH = "kws_finish" 16 | EVENT_FTE_START = "fte_start" 17 | EVENT_FTE_FINISH = "fte_finish" 18 | EVENT_FPS_START = "fps_start" 19 | EVENT_FPS_FINISH = "fps_finish" 20 | EVENT_SCR_START = "scr_start" 21 | EVENT_SCR_FINISH = "scr_finish" 22 | 23 | 24 | class RetrieverCallbacks: 25 | def on_start(self, **kwargs): 26 | pass 27 | 28 | def on_finish(self, **kwargs): 29 | pass 30 | 31 | def on_qrw_start(self, **kwargs): 32 | pass 33 | 34 | def on_qrw_finish(self, **kwargs): 35 | pass 36 | 37 | def on_edl_start(self, **kwargs): 38 | pass 39 | 40 | def on_edl_finish(self, **kwargs): 41 | pass 42 | 43 | def on_kws_start(self, **kwargs): 44 | pass 45 | 46 | def on_kws_finish(self, **kwargs): 47 | pass 48 | 49 | def on_fte_start(self, **kwargs): 50 | pass 51 | 52 | def on_fte_finish(self, **kwargs): 53 | pass 54 | 55 | def on_fps_start(self, **kwargs): 56 | pass 57 | 58 | def on_fps_finish(self, **kwargs): 59 | pass 60 | 61 | def on_scr_start(self, **kwargs): 62 | pass 63 | 64 | 
def on_scr_finish(self, **kwargs): 65 | pass 66 | 67 | def register_to(self, x: EventEmitter): 68 | for k, v in RetrieverEvents.__members__.items(): 69 | x.on(v.value, getattr(self, f"on_{v.value}")) 70 | -------------------------------------------------------------------------------- /cora/agents/snippets/split_file.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Generator 2 | 3 | from cora.agents.snippets.base import SnipFinderBase, SnipRelDetmBase 4 | from cora.repo.repo import Repository 5 | from cora.utils.parallel import parallel 6 | 7 | 8 | class EnumSnipFinder(SnipFinderBase): 9 | def __init__(self, repo: Repository, determ: SnipRelDetmBase): 10 | super().__init__(repo, determ) 11 | 12 | def find( 13 | self, 14 | query: str, 15 | file_path: str, 16 | num_threads: int = 4, 17 | snippet_size: int = 100, 18 | *args, 19 | **kwargs, 20 | ) -> Generator[Tuple[str, str], None, None]: 21 | need_disable_enable_debugging = num_threads > 1 and self.determ.is_debugging() 22 | 23 | if need_disable_enable_debugging: 24 | self.determ.disable_debugging() 25 | 26 | snippets = self.repo.get_all_snippets_of_file_with_size( 27 | file_path, snippet_size=snippet_size 28 | ) 29 | 30 | results = parallel( 31 | [(self._determ_relevance, (query, snip_path)) for snip_path in snippets], 32 | n_jobs=num_threads, 33 | backend="threading", 34 | ) 35 | 36 | if need_disable_enable_debugging: 37 | self.determ.enable_debugging() 38 | 39 | yield from [ 40 | (snip_path, rel_reason) 41 | for snip_path, (is_relevant, rel_reason) in results 42 | if is_relevant 43 | ] 44 | 45 | def _determ_relevance(self, query: str, snip_path: str): 46 | return snip_path, self.determ.determine( 47 | query, 48 | snip_path, 49 | self.repo.get_snippet_content( 50 | snip_path, surroundings=0, add_lines=True, add_separators=True 51 | ), 52 | ) 53 | -------------------------------------------------------------------------------- /cora/kwe/index.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import defaultdict, Counter 3 | from typing import Dict 4 | 5 | from cora.kwe.tokens import TokenizerBase 6 | 7 | 8 | class InvertedIndex: 9 | def __init__(self, tokenizer: TokenizerBase, bm25_k1=1.2, bm25_b=0.75): 10 | self.tokenizer = tokenizer 11 | self._intern = defaultdict(list) # token -> [(snippet_path, token_count)] 12 | self._length = {} # snippet_path -> num_tokens 13 | self._ave_len = 0.0 14 | self.bm25_k1 = bm25_k1 15 | self.bm25_b = bm25_b 16 | 17 | def bm25_all(self, query: str) -> Dict[str, float]: 18 | bm25 = defaultdict(float) 19 | 20 | # Calculate a BM25 score for each snippet toward the query 21 | for token in [tok.text for tok in self.tokenizer.tokenize(query)]: 22 | for sp, tok_cnt in self._intern[token]: 23 | tf = ((self.bm25_k1 + 1) * tok_cnt) / ( 24 | tok_cnt 25 | + self.bm25_k1 26 | * ( 27 | 1 28 | - self.bm25_b 29 | + self.bm25_b * (self._length[sp] / self._ave_len) 30 | ) 31 | ) 32 | idf = math.log10( 33 | ((len(self._length) - len(self._intern[token])) + 0.5) 34 | / (len(self._intern[token]) + 0.5) 35 | + 1.0 36 | ) 37 | bm25[sp] += idf * tf 38 | 39 | # Normalize all bm25 scores for each snippet toward the query 40 | mmax, mmin = 1, 0 41 | if bm25: 42 | mmax, mmin = max(bm25.values()), min(bm25.values()) 43 | if mmin < mmax: 44 | mmin = 0 45 | for snip in bm25.keys(): 46 | bm25[snip] = (bm25[snip] - mmin) / (mmax - mmin) 47 | 48 | return bm25 49 | 50 | def index_snippet(self, snippet: str, 
content: str): 51 | tokens = self.tokenizer.tokenize(content) 52 | # Updating inverted index 53 | counts = Counter([tok.text for tok in tokens]) 54 | for tok, num in counts.items(): 55 | self._intern[tok].append((snippet, num)) 56 | # Caching length 57 | self._length[snippet] = len(tokens) 58 | self._ave_len = sum(self._length.values()) / len(self._length) 59 | -------------------------------------------------------------------------------- /cora/base/rag.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABC 2 | from typing import List, Optional 3 | 4 | 5 | class RetrieverBase(ABC): 6 | def __init__(self): 7 | self.agent: Optional["RAGBase"] = None 8 | 9 | def inject_agent(self, agent: "RAGBase"): 10 | self.agent = agent 11 | 12 | @abstractmethod 13 | def retrieve(self, query: str, **kwargs) -> List[str]: ... 14 | 15 | 16 | class GeneratorBase(ABC): 17 | def __init__(self): 18 | self.agent: Optional["RAGBase"] = None 19 | 20 | def inject_agent(self, agent: "RAGBase"): 21 | self.agent = agent 22 | 23 | @abstractmethod 24 | def generate(self, query: str, context: List[str], **kwargs) -> any: ... 25 | 26 | 27 | class RAGBase: 28 | def __init__( 29 | self, 30 | name: str, 31 | *, 32 | retriever: RetrieverBase, 33 | generator: GeneratorBase, 34 | ): 35 | self.name = name 36 | self.retriever = retriever 37 | self.generator = generator 38 | self.retriever.inject_agent(self) 39 | self.generator.inject_agent(self) 40 | 41 | def run( 42 | self, 43 | query: str, 44 | retrieving_args: Optional[dict] = None, 45 | generation_args: Optional[dict] = None, 46 | ) -> any: 47 | return self.generate( 48 | query, 49 | self.retrieve(query, **(retrieving_args or {})), 50 | **(generation_args or {}), 51 | ) 52 | 53 | def before_retrieving(self, query: str, **kwargs): 54 | pass 55 | 56 | def retrieve(self, query: str, **kwargs) -> List[str]: 57 | self.before_retrieving(query, **kwargs) 58 | context = self.retriever.retrieve(query, **kwargs) 59 | self.after_retrieving(query, context, **kwargs) 60 | return context 61 | 62 | def after_retrieving(self, query: str, context: List[str], **kwargs): 63 | pass 64 | 65 | def before_generate(self, query: str, context: List[str], **kwargs): 66 | pass 67 | 68 | def generate(self, query: str, context: List[str], **kwargs) -> any: 69 | self.before_generate(query, context, **kwargs) 70 | response = self.generator.generate(query, context, **kwargs) 71 | self.after_generate(query, context, response, **kwargs) 72 | return response 73 | 74 | def after_generate(self, query: str, context: List[str], response: any, **kwargs): 75 | pass 76 | -------------------------------------------------------------------------------- /cora/kwe/engine.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from cora.base.paths import FilePath, SnippetPath 4 | from cora.base.repos import RepoBase 5 | from cora.kwe.index import InvertedIndex 6 | from cora.kwe.tokens import TokenizerBase 7 | from cora.utils import misc as utils 8 | 9 | 10 | class KwEng: 11 | def __init__(self, repo: RepoBase, index: InvertedIndex): 12 | self._repo = repo 13 | self._index = index 14 | 15 | def search_snippets(self, query: str, limit: Optional[int] = None) -> List[str]: 16 | snippet_scores = self._index.bm25_all(query) 17 | # Update snippet scores according to file scores 18 | snippet_list = self._repo.get_all_snippets() 19 | for snippet in snippet_list: 20 | if snippet not in 
snippet_scores: 21 | snippet_scores[snippet] = 0 22 | # Rank snippets according to each snippet's score 23 | ranked_snippets = sorted( 24 | snippet_list, key=lambda sp: snippet_scores[sp], reverse=True 25 | ) 26 | return ranked_snippets[:limit] 27 | 28 | @classmethod 29 | def from_repo(cls, repo: RepoBase, tokenizer: TokenizerBase): 30 | index = InvertedIndex(tokenizer) 31 | for file in repo.get_all_files(): 32 | file_path = FilePath(repo.repo_path) / file 33 | file_cont = file_path.read_text(encoding="utf-8", errors="replace") 34 | file_lines = file_cont.splitlines() 35 | for snippet in repo.get_all_snippets_of_file(file): 36 | snp_path = SnippetPath.from_str(snippet) 37 | snp_cont = "\n".join( 38 | file_lines[snp_path.start_line : snp_path.end_line] 39 | ) 40 | index.index_snippet(str(snp_path), snp_cont) 41 | return cls(repo, index) 42 | 43 | def save_to_disk(self, file_path): 44 | return utils.save_object( 45 | { 46 | "repo": self._repo.full_name, 47 | "index": self._index, 48 | }, 49 | file_path, 50 | ) 51 | 52 | @classmethod 53 | def load_from_disk(cls, file_path, repo: RepoBase): 54 | obj = utils.load_object(file_path) 55 | assert ( 56 | "repo" in obj and obj["repo"] == repo.full_name 57 | ), f"Repository is not match, expecting {obj['repo']}, got {repo.full_name}" 58 | assert "index" in obj, "No `index` field in the disk file" 59 | return cls(repo, obj["index"]) 60 | -------------------------------------------------------------------------------- /cora/repo/kwe.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, List 3 | 4 | from cora.base.paths import SnippetPath 5 | from cora.base.repos import RepoBase 6 | from cora.config import CoraConfig 7 | from cora.kwe.engine import KwEng 8 | from cora.kwe.tokens import NGramTokenizer 9 | from cora.utils.generic import CastSelfToThis 10 | from cora.utils.pattern import match_any_pattern 11 | 12 | 13 | class KwEngMixin(CastSelfToThis[RepoBase]): 14 | def __init__(self): 15 | self._kw_engine = None 16 | 17 | def search_snippets( 18 | self, 19 | query: str, 20 | limit: Optional[int] = 10, 21 | includes: Optional[List[str]] = None, 22 | ) -> List[str]: 23 | self.ensure_keyword_engine_loaded() 24 | snippets = [] 25 | for s in self._kw_engine.search_snippets(query, limit=None): 26 | if (not includes) or match_any_pattern(s, includes): 27 | snippets.append(s) 28 | if limit and len(snippets) == limit: 29 | break 30 | return snippets 31 | 32 | def search_files( 33 | self, 34 | query, 35 | limit: Optional[int] = 10, 36 | includes: Optional[List[str]] = None, 37 | ) -> List[str]: 38 | # Let's assume the top 32*limit snippets must contain top limit files 39 | files = [] 40 | for s in self.search_snippets( 41 | query, 42 | limit=None, # Let's search for all snippets 43 | includes=None, # Let's filter files ourselves 44 | ): 45 | f = str(SnippetPath.from_str(s).file_path) 46 | if (f not in files) and (not includes or match_any_pattern(f, includes)): 47 | files.append(f) 48 | if limit and len(files) == limit: 49 | break 50 | return files 51 | 52 | def ensure_keyword_engine_loaded(self): 53 | if self._kw_engine: 54 | return 55 | kwe_cache_file = self._kwe_cache_file 56 | if kwe_cache_file.exists(): 57 | self._kw_engine = KwEng.load_from_disk(kwe_cache_file, self.this) 58 | else: 59 | self.this.ensure_repository_chunked() 60 | self._kw_engine = KwEng.from_repo(self.this, NGramTokenizer()) 61 | self._kw_engine.save_to_disk(kwe_cache_file) 62 | 63 | @property 64 | 
def _kwe_cache_file(self) -> Path: 65 | return CoraConfig.keyword_index_cache_directory() / ( 66 | Path(self.this.repo_path).name + ".kwe" 67 | ) 68 | -------------------------------------------------------------------------------- /cora/repo/find.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | import rapidfuzz 5 | 6 | from cora.base.paths import FilePath 7 | from cora.base.repos import RepoBase 8 | from cora.utils.generic import CastSelfToThis 9 | from cora.utils.pattern import match_any_pattern 10 | 11 | 12 | class FindMixin(CastSelfToThis[RepoBase]): 13 | def find_similar_files( 14 | self, 15 | file_path: str, 16 | limit: int = 10, 17 | absolute: bool = False, 18 | includes: Optional[List[str]] = None, 19 | ) -> List[str]: 20 | return self._find_similar_paths( 21 | file_path, 22 | self.this.get_all_files(), 23 | limit=limit, 24 | absolute=absolute, 25 | includes=includes, 26 | ) 27 | 28 | def find_similar_directories( 29 | self, 30 | directory_path: str, 31 | limit: int = 10, 32 | absolute: bool = False, 33 | includes: List[str] = None, 34 | ) -> List[str]: 35 | return self._find_similar_paths( 36 | directory_path, 37 | self.this.get_all_directories(), 38 | limit=limit, 39 | absolute=absolute, 40 | includes=includes, 41 | ) 42 | 43 | def _find_similar_paths( 44 | self, 45 | to_path: str, 46 | from_path_list: List[str], 47 | limit: int = 10, 48 | absolute: bool = False, 49 | includes: Optional[List[str]] = None, 50 | ) -> List[str]: 51 | path_name = FilePath(to_path).name 52 | if includes: 53 | from_path_list = [ 54 | path for path in from_path_list if match_any_pattern(path, includes) 55 | ] 56 | similar_paths = self._find_similar_names(path_name, from_path_list, limit=limit) 57 | if absolute: 58 | return [os.path.join(self.this.repo_path, path) for path in similar_paths] 59 | else: 60 | return similar_paths 61 | 62 | @staticmethod 63 | def _find_similar_names(name: str, name_list: List[str], limit: int = 10): 64 | sorted_full_scores = sorted( 65 | name_list, 66 | key=lambda n: rapidfuzz.fuzz.ratio(name, n) + (101 if name in n else 0), 67 | reverse=True, 68 | ) 69 | sorted_part_scores = sorted( 70 | name_list, key=lambda n: rapidfuzz.fuzz.partial_ratio(name, n), reverse=True 71 | ) 72 | sorted_final_scores = sorted_full_scores[:limit] 73 | for s in sorted_part_scores[:limit]: 74 | if s not in sorted_final_scores: 75 | sorted_final_scores.append(s) 76 | return sorted_final_scores 77 | -------------------------------------------------------------------------------- /cora/preview/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import cached_property 3 | from typing import Type, Dict, List 4 | 5 | from cora.splits.ftypes import parse_ftype 6 | 7 | 8 | class FilePreview: 9 | _PREVIEW_SPLITTER = " | " 10 | _PREVIEW_DICT: Dict[str, Type["FilePreview"]] = {} 11 | 12 | def __init__(self, file_type: str, file_name: str, file_content: str): 13 | self.file_type = file_type 14 | self.file_name = file_name 15 | self.file_content = file_content 16 | 17 | @cached_property 18 | def file_lines(self): 19 | return self.file_content.splitlines() 20 | 21 | @classmethod 22 | def register(cls, file_types: List[str]): 23 | def register_inner(preview_cls: Type["FilePreview"]): 24 | for ft in file_types: 25 | assert ( 26 | ft not in cls._PREVIEW_DICT 27 | ), f"Conflicted previewers: {preview_cls} and {cls._PREVIEW_DICT[ft]} 
are both requesting for {ft}." 28 | cls._PREVIEW_DICT[ft] = preview_cls 29 | return preview_cls 30 | 31 | return register_inner 32 | 33 | @classmethod 34 | def of(cls, file_name: str, file_content: str) -> str: 35 | file_type = parse_ftype(file_name) 36 | return cls._PREVIEW_DICT.get(file_type, _FileContent)( 37 | file_type, file_name, file_content 38 | ).get_preview() 39 | 40 | @abstractmethod 41 | def get_preview(self) -> str: ... 42 | 43 | @classmethod 44 | def preview_line(cls, line_number, line_content): 45 | return str(line_number) + cls._PREVIEW_SPLITTER + line_content 46 | 47 | @classmethod 48 | def parse_preview_line(cls, preview_line): 49 | index = preview_line.find(cls._PREVIEW_SPLITTER) 50 | if index != -1: 51 | try: 52 | line_number = int(preview_line[:index]) 53 | line_content = preview_line[index + len(cls._PREVIEW_SPLITTER) :] 54 | except ValueError: 55 | line_number = None 56 | line_content = preview_line 57 | else: 58 | line_number = None 59 | line_content = preview_line 60 | return line_number, line_content 61 | 62 | @classmethod 63 | def spacing_for_line_number(cls, line_number): 64 | return " " * (len(str(line_number)) + len(cls._PREVIEW_SPLITTER)) 65 | 66 | @staticmethod 67 | def indentation_of_line(line): 68 | return " " * (len(line) - len(line.lstrip())) 69 | 70 | 71 | class _FileContent(FilePreview): 72 | def get_preview(self, min_line: int = 5, max_line: int = -1) -> str: 73 | return "\n".join( 74 | [self.preview_line(i, line) for i, line in enumerate(self.file_lines)] 75 | ) 76 | -------------------------------------------------------------------------------- /cora/preview/text.py: -------------------------------------------------------------------------------- 1 | from cora.preview.base import FilePreview 2 | 3 | 4 | # TODO: Support adoc in a separate like AdocPreview 5 | @FilePreview.register(["txt", "adoc"]) 6 | class TextPreview(FilePreview): 7 | def __init__(self, file_type: str, file_name: str, file_content: str): 8 | super().__init__( 9 | file_type=file_type, file_name=file_name, file_content=file_content 10 | ) 11 | 12 | def get_preview(self) -> str: 13 | preview = [] 14 | 15 | file_lines = self.file_lines 16 | num_lines = len(self.file_lines) 17 | 18 | # Let's assume paragraphs are divided by "\n\n" 19 | start_number = 0 20 | while start_number < num_lines: 21 | end_number = start_number + 1 22 | while end_number < num_lines and file_lines[end_number].strip(): 23 | end_number += 1 24 | 25 | para_lines = file_lines[start_number:end_number] 26 | num_para_lines = end_number - start_number 27 | 28 | if num_para_lines <= 3: 29 | preview.extend( 30 | [ 31 | self.preview_line_ex(start_number + i, para_lines[i]) 32 | for i in range(num_para_lines) 33 | ] 34 | ) 35 | else: 36 | preview.append(self.preview_line_ex(start_number, para_lines[0])) 37 | spacing = self.spacing_for_line_number(start_number) 38 | indentation = self.indentation_of_line(para_lines[0]) 39 | preview.append(spacing + indentation + "...") 40 | preview.append( 41 | spacing 42 | + indentation 43 | + f"(lines {start_number + 1}-{end_number - 2} are hidden in preview)" 44 | ) 45 | preview.append(spacing + indentation + "...\n") 46 | preview.append(self.preview_line_ex(end_number - 1, para_lines[-1])) 47 | 48 | if end_number != num_lines: 49 | preview.append(self.preview_line(end_number, file_lines[end_number])) 50 | 51 | start_number = end_number + 1 52 | 53 | return "\n".join(preview) 54 | 55 | @classmethod 56 | def preview_line_ex(cls, line_number, line): 57 | if not line: 58 | return 
cls.preview_line(line_number, line) 59 | sentences = line.split(".") # TODO Use NLTK's sent_tokenize() ?? 60 | if sentences[-1] == "": 61 | sentences = sentences[:-1] 62 | if len(sentences) == 1 or len(sentences) == 2: 63 | return cls.preview_line(line_number, line) 64 | else: 65 | return cls.preview_line( 66 | line_number, 67 | sentences[0] 68 | + " ... (in-between sentences are hidden in preview) ... " 69 | + sentences[-1], 70 | ) 71 | -------------------------------------------------------------------------------- /cora/results.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pyjson5 as json5 4 | 5 | from cora.repair.events import IssueRepaCallbacks 6 | from cora.retrv.events import RetrieverCallbacks 7 | 8 | 9 | class CfarResult(RetrieverCallbacks): 10 | def __init__(self, res_file: Path): 11 | self.res_file = res_file 12 | self.result = {} 13 | 14 | def on_qrw_finish(self, **kwargs): 15 | self.add_interm_res("qrw", kwargs) 16 | 17 | def on_edl_finish(self, **kwargs): 18 | self.add_interm_res("edl", kwargs) 19 | 20 | def on_kws_finish(self, **kwargs): 21 | self.add_interm_res("kws", kwargs) 22 | 23 | def on_fte_finish(self, **kwargs): 24 | self.add_interm_res("fte", kwargs) 25 | 26 | def on_fps_finish(self, **kwargs): 27 | self.add_interm_res("fps", kwargs) 28 | 29 | def on_scr_finish(self, **kwargs): 30 | self.add_interm_res("scr", kwargs) 31 | 32 | def on_finish(self, **kwargs): 33 | self.add_interm_res("all", kwargs) 34 | 35 | def add_interm_res(self, phase: str, phase_res: dict): 36 | try: 37 | # We use JSON5, but some values (like sets) are not serializable 38 | json5.dumps(phase_res) 39 | new_phase_res = phase_res 40 | except (json5.Json5Exception, TypeError): 41 | # Let's fall back to string for failed cases 42 | new_phase_res = str(phase_res) 43 | self.result[phase] = new_phase_res 44 | with self.res_file.open("w") as fou: 45 | fou.write(json5.dumps(self.result)) 46 | 47 | 48 | class IssueRepaResult(IssueRepaCallbacks): 49 | def __init__(self, res_file: Path): 50 | self.res_file = res_file 51 | self.result = {"rounds": [], "num_rounds": None} 52 | self.curr_round = -1 53 | 54 | def on_finish(self, **kwargs): 55 | self.curr_round = -1 56 | self.add_interm_res("result", kwargs) 57 | 58 | def on_next_round(self, **kwargs): 59 | if not self.result["num_rounds"]: 60 | self.result["num_rounds"] = kwargs["num_rounds"] 61 | self.curr_round += 1 62 | self.result["rounds"].append({}) 63 | 64 | def on_gen_patch_finish(self, **kwargs): 65 | self.add_interm_res("gen_patch", kwargs) 66 | 67 | def on_eval_patch_finish(self, **kwargs): 68 | self.add_interm_res("eval_patch", kwargs) 69 | 70 | def add_interm_res(self, phase: str, phase_res: dict): 71 | try: 72 | # We use JSON5, but some values (like sets) are not serializable 73 | json5.dumps(phase_res) 74 | new_phase_res = phase_res 75 | except (json5.Json5Exception, TypeError): 76 | # Let's fall back to string for failed cases 77 | new_phase_res = str(phase_res) 78 | self.result["rounds"][self.curr_round][phase] = new_phase_res 79 | with self.res_file.open("w") as fou: 80 | fou.write(json5.dumps(self.result)) 81 | -------------------------------------------------------------------------------- /cora/agent.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | from typing import Optional, List, cast 3 | 4 | from cora.agents.rewrite.base import RewriterBase 5 | from cora.base.console import get_boxed_console 6 | from 
cora.base.rag import RAGBase, GeneratorBase 7 | from cora.config import CoraConfig 8 | from cora.llms.factory import LLMConfig 9 | from cora.repo.repo import Repository 10 | from cora.retrv import retrv 11 | 12 | 13 | class RepoAgent(RAGBase): 14 | def __init__( 15 | self, 16 | repo: Repository, 17 | *, 18 | use_llm: LLMConfig, 19 | rewriter: RewriterBase, 20 | generator: GeneratorBase, 21 | includes: Optional[List[str]] = None, 22 | num_proc: int = 1, 23 | num_thread: int = 1, 24 | name: str = "RepoAgent", 25 | files_as_context: bool = False, 26 | debug_mode: bool = False, 27 | ): 28 | super().__init__( 29 | name=name, 30 | retriever=retrv.Retriever( 31 | repo, 32 | use_llm=LLMConfig( 33 | **{ 34 | **asdict(use_llm), 35 | "temperature": 0, # Enable greedy decoding 36 | } 37 | ), 38 | includes=includes, 39 | rewriter=rewriter, 40 | debug_mode=debug_mode, 41 | ), 42 | generator=generator, 43 | ) 44 | self.repo = repo 45 | self.includes = includes 46 | self.use_llm = use_llm 47 | self.num_proc = num_proc 48 | self.num_thread = num_thread 49 | self.files_as_context = files_as_context 50 | self.debug_mode = debug_mode 51 | self.console = get_boxed_console( 52 | box_title=self.name, 53 | box_bg_color="grey50", 54 | debug_mode=debug_mode, 55 | ) 56 | self.console.printb( 57 | f"Loaded repository {self.repo.repo_org}/{self.repo.repo_name} from {self.repo.repo_path} ..." 58 | ) 59 | 60 | @property 61 | def cfar(self) -> retrv.Retriever: 62 | return cast(retrv.Retriever, self.retriever) 63 | 64 | def before_retrieving(self, query: str, **kwargs): 65 | self.console.printb( 66 | f"Retrieving relevant context for the user query:\n```\n{query}\n```" 67 | ) 68 | # Setup required configs 69 | CoraConfig.SCR_ENUM_FNDR_NUM_THREADS = self.num_thread 70 | # CoraConfig.FTE_STRATEGY = ... 71 | # CoraConfig.QSM_STRATEGY = ... 72 | 73 | def after_retrieving(self, query: str, context: List[str], **kwargs): 74 | self.console.printb( 75 | "The retrieved context is:\n" + ("\n".join(["- " + s for s in context])) 76 | ) 77 | 78 | def run( 79 | self, 80 | query: str, 81 | retrieving_args: Optional[dict] = None, 82 | generation_args: Optional[dict] = None, 83 | ) -> any: 84 | return super().run( 85 | query, 86 | retrieving_args={ 87 | **(retrieving_args or {}), 88 | "files_only": self.files_as_context, 89 | "num_proc": self.num_proc, 90 | }, 91 | generation_args=generation_args, 92 | ) 93 | -------------------------------------------------------------------------------- /cora/kwe/tokens.py: -------------------------------------------------------------------------------- 1 | import re 2 | from abc import abstractmethod 3 | from dataclasses import dataclass 4 | from typing import List 5 | 6 | _PATTERN_TOKEN = re.compile(r"\b\w+\b") 7 | _PATTERN_VARIABLE = re.compile(r"([A-Z][a-z]+|[a-z]+|[A-Z]+(?=[A-Z]|$))") 8 | 9 | 10 | @dataclass 11 | class Token: 12 | text: str 13 | start: int 14 | end: int 15 | 16 | def __str__(self): 17 | return self.text 18 | 19 | def __repr__(self): 20 | return self.text 21 | 22 | 23 | class TokenizerBase: 24 | def tokenize(self, text: str) -> List[Token]: 25 | return self._do_tokenize(text) 26 | 27 | @abstractmethod 28 | def _do_tokenize(self, text: str) -> List[Token]: ... 
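# Usage sketch (illustrative): how a TokenizerBase implementation such as the
# NGramTokenizer below is typically driven. With num_gram=3, tokenizing the
# identifier "getUserName" should yield the unigrams "get", "user", "name" plus
# the joined n-grams "get_user", "user_name", "get_user_name"; the exact token
# list depends on the regexes defined above.
#
#   tokenizer = NGramTokenizer(num_gram=3)
#   print([t.text for t in tokenizer.tokenize("getUserName")])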
29 | 30 | 31 | class NGramTokenizer(TokenizerBase): 32 | def __init__(self, ensure_ascii=True, num_gram=3): 33 | self.ensure_ascii = ensure_ascii 34 | self.num_gram = num_gram 35 | 36 | def _do_tokenize(self, text: str) -> List[Token]: 37 | tokens = [] 38 | if self.ensure_ascii: 39 | unigram = self._tokenize_ascii(text) 40 | else: 41 | unigram = self._tokenize_ascii(text) 42 | tokens.extend(unigram) 43 | for n in range(2, self.num_gram + 1): 44 | tokens.extend(self._create_ngram(n, unigram)) 45 | return tokens 46 | 47 | def _tokenize_ascii(self, text: str) -> List[Token]: 48 | tokens = [] 49 | 50 | def append_token_if_valid(text_, start_): 51 | if text_ and len(text_) > 1: 52 | tokens.append( 53 | Token( 54 | text=text_.lower(), 55 | start=start_, 56 | end=start_ + len(text_), 57 | ) 58 | ) 59 | 60 | for m in re.finditer(_PATTERN_TOKEN, text): 61 | tok_text, tok_start = m.group(), m.start() 62 | 63 | # Snakecase token 64 | if "_" in tok_text: 65 | offset = 0 66 | for tok_part in tok_text.split("_"): 67 | append_token_if_valid(tok_part.lower(), tok_start + offset) 68 | offset += len(tok_part) + 1 # "1" for "_" 69 | # Camelcase token 70 | elif tok_parts := _PATTERN_VARIABLE.findall(tok_text): 71 | offset = 0 72 | for tok_part in tok_parts: 73 | append_token_if_valid(tok_part.lower(), tok_start + offset) 74 | offset += len(tok_part) 75 | # Others 76 | else: 77 | append_token_if_valid(tok_text.lower(), tok_start) 78 | return tokens 79 | 80 | def _tokenize_unicode(self, text: str) -> List[Token]: 81 | raise NotImplementedError() 82 | 83 | @staticmethod 84 | def _create_ngram(n: int, tokens: List[Token]) -> List[Token]: 85 | ngram_toks = [] 86 | pre_toks = [] 87 | for token in tokens: 88 | if len(pre_toks) == n - 1: 89 | new_tok = Token( 90 | text="_".join([t.text for t in pre_toks] + [token.text]), 91 | start=pre_toks[0].start, 92 | end=token.end, 93 | ) 94 | ngram_toks.append(new_tok) 95 | pre_toks.pop(0) 96 | pre_toks.append(token) 97 | return ngram_toks 98 | -------------------------------------------------------------------------------- /cora/cfar.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | from typing import List, Optional 4 | 5 | from cora import options, results 6 | from cora.agents.rewrite.base import RewriterBase 7 | from cora.agents.rewrite.dont import DontRewrite 8 | from cora.agents.rewrite.issue import IssueSummarizer 9 | from cora.base.console import get_boxed_console 10 | from cora.config import CoraConfig 11 | from cora.llms.factory import LLMConfig 12 | from cora.repo.repo import Repository 13 | from cora.retrv import retrv 14 | 15 | 16 | def retrieve( 17 | query: str, 18 | repo: Repository, 19 | *, 20 | use_llm: LLMConfig, 21 | rewriter: Optional[RewriterBase] = None, 22 | files_only: bool = False, 23 | num_proc: int = 1, 24 | includes: Optional[List[str]] = None, 25 | debug_mode: bool = False, 26 | log_dir: Optional[Path] = None, 27 | ) -> List[str]: 28 | console = get_boxed_console( 29 | box_title="CFAR", 30 | box_bg_color=retrv.DEBUG_OUTPUT_LOGGING_COLOR, 31 | debug_mode=debug_mode, 32 | ) 33 | 34 | console.printb( 35 | f"Loaded repository {repo.repo_org}/{repo.repo_name} from {repo.repo_path} ..." 
36 | ) 37 | 38 | console.printb(f"Retrieving relevant context for query:\n```\n{query}\n```") 39 | retriever = retrv.Retriever( 40 | repo, 41 | use_llm=use_llm, 42 | includes=includes, 43 | debug_mode=debug_mode, 44 | rewriter=rewriter, 45 | ) 46 | 47 | if log_dir: 48 | retriever.add_callback(results.CfarResult(log_dir / "cfar_res.json")) 49 | 50 | snip_ctx = retriever.retrieve( 51 | query, 52 | files_only=files_only, 53 | num_proc=num_proc, 54 | ) 55 | console.printb( 56 | "The retrieved context is:\n" + ("\n".join(["- " + sp for sp in snip_ctx])) 57 | ) 58 | 59 | return snip_ctx 60 | 61 | 62 | def parse_args(): 63 | parser = ArgumentParser() 64 | options.make_common_options(parser) 65 | parser.add_argument( 66 | "--files-only", 67 | "-F", 68 | action="store_true", 69 | help="Only retrieve relevant files rather than snippets", 70 | ) 71 | parser.add_argument( 72 | "--query-as-issue", 73 | action="store_true", 74 | help="Treat the user query as a GitHub issue to resolve", 75 | ) 76 | return parser.parse_args() 77 | 78 | 79 | def main(): 80 | args = parse_args() 81 | 82 | repo = options.parse_repo(args) 83 | query, incl = options.parse_query(args) 84 | llm = options.parse_llms(args) 85 | procs, threads = options.parse_perf(args) 86 | CoraConfig.SCR_ENUM_FNDR_NUM_THREADS = threads 87 | log_dir, verbose = options.parse_logging(args) 88 | 89 | if args.query_as_issue: 90 | rewriter = IssueSummarizer(repo, use_llm=llm) 91 | else: 92 | # TODO: Add a common rewriter to summarize a query or let LLM decide if a query is an issue 93 | rewriter = DontRewrite(repo) 94 | 95 | retrieve( 96 | query, 97 | repo, 98 | use_llm=llm, 99 | num_proc=procs, 100 | includes=incl, 101 | files_only=args.files_only, 102 | rewriter=rewriter, 103 | log_dir=log_dir, 104 | debug_mode=verbose, 105 | ) 106 | 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /cora/agents/find_entities.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional 2 | 3 | from cora.agents.base import AgentBase 4 | from cora.llms.base import LLMBase 5 | from cora.repo.repo import Repository 6 | 7 | # TODO: Extract entities first, then try guessing their definition files 8 | SYSTEM_PROMPT = """\ 9 | You are a File Name Extractor, tasked to extract all possible **file names** embedded in a "User Query". \ 10 | Since your extracted file names will be used by others to solve the user query, it is very important that you do not miss any possible file names. 11 | 12 | For this task, I will present you with the "User Query", \ 13 | which is a question or a request proposed by a user for the repository: {repo_name}. \ 14 | To solve the user query, one may need to comprehend the user query and then access or even update some mentioned or unmentioned files in the repository. \ 15 | So it is important that we find all the specific files. \ 16 | Your task is to comprehend and analyze the user query in detail, and try to find all possible file names embedded in it. 17 | 18 | Note, 19 | 1. There might be multiple file names embedded in the query; you should find all of them. 20 | 2. If you cannot find any file names in the user query, leave the field "files" as an empty array ([]) and give a clear "reason" to let me know.
21 | 22 | ## User Query ## 23 | 24 | ``` 25 | {user_query} 26 | ``` 27 | 28 | """ 29 | 30 | JSON_SCHEMA = """\ 31 | { 32 | "thoughts": "your comprehension of the user query", // your comprehension of the user query; this should be very detailed and should help you find all possible file names. 33 | "files": ["", "", ...], // all possible file names; or leave this field as an empty array if you cannot find any file names from the user query. 34 | "reason": "the reason why you think the names you found are file names embedded in the query" // explain step by step, one file name at a time. 35 | }\ 36 | """ 37 | 38 | 39 | class EntDefnFinder(AgentBase): 40 | def __init__( 41 | self, 42 | query: str, 43 | repo: Repository, 44 | llm: LLMBase, 45 | *args, 46 | **kwargs, 47 | ): 48 | super().__init__(llm=llm, json_schema=JSON_SCHEMA, *args, **kwargs) 49 | self.query = query 50 | self.repo = repo 51 | 52 | def find(self): 53 | return self.run( 54 | SYSTEM_PROMPT.format( 55 | repo_name=self.repo.full_name, 56 | user_query=self.query, 57 | ) 58 | ) 59 | 60 | def _check_response_format( 61 | self, response: dict, *args, **kwargs 62 | ) -> Tuple[bool, Optional[str]]: 63 | for field in ["thoughts", "files", "reason"]: 64 | if field not in response: 65 | return False, f"'{field}' is missing in the JSON object" 66 | return True, None 67 | 68 | def _check_response_semantics( 69 | self, response: dict, *args, **kwargs 70 | ) -> Tuple[bool, Optional[str]]: 71 | return True, None 72 | 73 | def _parse_response(self, response: dict, *args, **kwargs) -> any: 74 | names = response["files"] 75 | reason = response["thoughts"] + "\n" + response["reason"] 76 | 77 | if not names: 78 | return [], reason 79 | 80 | return names, reason 81 | 82 | def _default_result_when_reaching_max_chat_round(self): 83 | return ( 84 | [], 85 | "The model has reached the max number of chat rounds and is unable to find any files in the query.", 86 | ) 87 | -------------------------------------------------------------------------------- /cora/utils/event.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from abc import abstractmethod 3 | from pathlib import Path 4 | from typing import Protocol, Dict, List, Optional 5 | 6 | _EVENTS_CLASS_TEMPLATE = """\ 7 | from enum import Enum, unique 8 | 9 | from cora.utils.event import EventEmitter 10 | 11 | 12 | @unique 13 | class {emitter_class}Events(Enum): 14 | {events_code} 15 | 16 | 17 | class {emitter_class}Callbacks: 18 | {callbacks_code} 19 | 20 | def register_to(self, x: EventEmitter): 21 | for k, v in {emitter_class}Events.__members__.items(): 22 | x.on(v.value, getattr(self, f"on_{{v.value}}")) 23 | 24 | """ 25 | 26 | 27 | def gen_event_and_callback_classes(emitter: str, events: List[str], *, to_file: Path): 28 | code_events = [] 29 | for ev in events: 30 | code_events.append(f'EVENT_{ev.upper()} = "{ev}"') 31 | 32 | code_callbacks = [] 33 | for ev in events: 34 | code_callbacks.append(f"def on_{ev.lower()}(self, **kwargs):\n pass") 35 | 36 | with to_file.open("w") as fou: 37 | fou.write( 38 | _EVENTS_CLASS_TEMPLATE.format( 39 | emitter_class=emitter, 40 | events_code="\n".join(f" {s}" for s in code_events), 41 | callbacks_code="\n\n".join(f" {s}" for s in code_callbacks), 42 | ) 43 | ) 44 | 45 | 46 | def _inspect_args_in_dict(fn, *, is_method: bool, args: tuple, kwargs: dict): 47 | fn_args = {**kwargs} 48 | if not args: 49 | return fn_args 50 | # Let's obtain the names of all positional parameters 51 | sig = 
inspect.signature(fn) 52 | positional_params = [ 53 | p.name 54 | for p in sig.parameters.values() 55 | if p.kind 56 | in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) 57 | ][1 if is_method else 0 :] 58 | # Truncate according to the given arguments 59 | params = positional_params[: len(args)] 60 | # Insert positional arguments into the dict 61 | fn_args.update({p: a for p, a in zip(params, args)}) 62 | return fn_args 63 | 64 | 65 | def hook_method_to_emit_events( 66 | before_event: Optional[str] = None, after_event: Optional[str] = None 67 | ): 68 | def wrap_method(method): 69 | def _wrapper(self, *args, **kwargs): 70 | if before_event or after_event: 71 | fn_args = _inspect_args_in_dict( 72 | method, is_method=True, args=args, kwargs=kwargs 73 | ) 74 | else: 75 | fn_args = {} 76 | if before_event: 77 | self.emit(before_event, **fn_args) 78 | res = method(self, *args, **kwargs) 79 | if after_event: 80 | self.emit(after_event, **fn_args, result=res) 81 | return res 82 | 83 | return _wrapper 84 | 85 | return wrap_method 86 | 87 | 88 | class EventReceiver(Protocol): 89 | @abstractmethod 90 | def __call__(self, **kwargs): 91 | pass 92 | 93 | 94 | class EventEmitter: 95 | def __init__(self): 96 | self.event_receivers: Dict[str, List[EventReceiver]] = {} 97 | 98 | def on(self, event: str, receiver: EventReceiver): 99 | if event not in self.event_receivers: 100 | self.event_receivers[event] = [] 101 | self.event_receivers[event].append(receiver) 102 | 103 | def emit(self, event, **kwargs): 104 | for recv in self.event_receivers.get(event, []): 105 | recv(**kwargs) 106 | -------------------------------------------------------------------------------- /cora/agents/snippets/judge_snip.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | from cora.agents.base import AgentBase 4 | from cora.agents.snippets.base import SnipRelDetmBase 5 | from cora.base.paths import SnippetPath 6 | from cora.llms.base import LLMBase 7 | from cora.utils.misc import to_bool 8 | 9 | SYSTEM_PROMPT = """\ 10 | You are a Relevance Determiner, tasked to analyze if a given "File Snippet" is relevant to a given "User Query". 11 | 12 | For this task, I will provide you with a "File Snippet" which is a portion of the file **{file_name}**, and a "User Query". \ 13 | The file snippet lists the content of the snippet; \ 14 | it is a portion of a file, wrapped by "===START OF SNIPPET===" and "===END OF SNIPPET===". \ 15 | We also provide some lines of surrounding content of the snippet in its file for your reference. 16 | 17 | A file snippet is relevant to a user query if the file snippet can be an important part of addressing the user query \ 18 | (though it might not address the user query directly). \ 19 | You should determine their relevance by the following steps: 20 | 1. Think carefully about what files are required to address the user query; 21 | 2. Analyze if the file snippet is part of those files and what the snippet provides or does; 22 | 3. Check if the file snippet can provide some useful information to address the user query directly or indirectly; 23 | 4. Conclude if the file snippet is relevant to the user query.
24 | 25 | ## User Query ## 26 | 27 | ``` 28 | {user_query} 29 | ``` 30 | 31 | ## File Snippet ## 32 | 33 | ``` 34 | //// Snippet: {snippet_path} 35 | {file_snippet} 36 | ``` 37 | 38 | """ 39 | 40 | JSON_SCHEMA = """\ 41 | { 42 | "relevant": "", // if the file snippet is relevant to the user query; this field is a boolean field, either true or false 43 | "reason": "the reason why you think the file snippet is relevant to the user query" // explain in detail why you think this snippet is relevant to the user query; explain it step by step following the above steps 44 | }\ 45 | """ 46 | 47 | 48 | class SnipJudge(SnipRelDetmBase, AgentBase): 49 | def __init__(self, llm: LLMBase, *args, **kwargs): 50 | AgentBase.__init__(self, llm=llm, json_schema=JSON_SCHEMA, *args, **kwargs) 51 | 52 | def is_debugging(self) -> bool: 53 | return AgentBase.is_debugging(self) 54 | 55 | def enable_debugging(self): 56 | AgentBase.enable_debugging(self) 57 | 58 | def disable_debugging(self): 59 | AgentBase.disable_debugging(self) 60 | 61 | def determine( 62 | self, query: str, snippet_path: str, snippet_content: str, *args, **kwargs 63 | ): 64 | return self.run( 65 | SYSTEM_PROMPT.format( 66 | file_name=SnippetPath.from_str(snippet_path).file_path.name, 67 | user_query=query, 68 | snippet_path=snippet_path, 69 | file_snippet=snippet_content, 70 | ) 71 | ) 72 | 73 | def _check_response_format( 74 | self, response: dict, *args, **kwargs 75 | ) -> Tuple[bool, Optional[str]]: 76 | for field in ["relevant", "reason"]: 77 | if field not in response: 78 | return False, f"'{field}' is missing in the JSON object" 79 | return True, None 80 | 81 | def _check_response_semantics( 82 | self, response: dict, *args, **kwargs 83 | ) -> Tuple[bool, Optional[str]]: 84 | return True, None 85 | 86 | def _parse_response(self, response: dict, *args, **kwargs) -> any: 87 | return to_bool(response["relevant"]), response["reason"] 88 | 89 | def _default_result_when_reaching_max_chat_round(self): 90 | return ( 91 | False, 92 | "The model have reached the max number of chat round and is unable to determine their relevance.", 93 | ) 94 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | .idea/ -------------------------------------------------------------------------------- /cora/base/console.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from abc import abstractmethod 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Optional 6 | 7 | from rich.console import Console 8 | from rich.panel import Panel 9 | 10 | 11 | @dataclass 12 | class BoxedConsoleConfigs: 13 | box_width: Optional[int] = None # Width of the console 14 | out_dir: Optional[str] = None # If set, the console will print to a file 15 | print_to_console: bool = False # If print to console when out_dir is enabled 16 | 17 | 18 | class BoxedConsoleBase: 19 | @abstractmethod 20 | def printb(self, *args, **kwargs): ... 21 | 22 | @abstractmethod 23 | def print(self, *args, **kwargs): ... 24 | 25 | @classmethod 26 | def _make_box_title(cls, title): 27 | return f"{title} [{cls._thread_id()}]" 28 | 29 | @staticmethod 30 | def _thread_id(): 31 | curr_thr = threading.current_thread() 32 | return f"{curr_thr.name}@{curr_thr.native_id}" 33 | 34 | 35 | class MockConsole(BoxedConsoleBase): 36 | def printb(self, *args, **kwargs): 37 | pass 38 | 39 | def print(self, *args, **kwargs): 40 | pass 41 | 42 | 43 | class FileConsole(BoxedConsoleBase): 44 | def __init__( 45 | self, *, out_file: str, title: Optional[str], print_to_console: bool = False 46 | ): 47 | self.title = title 48 | self.out_file = out_file 49 | self.print_to_console = print_to_console 50 | 51 | def printb(self, message, title=None, *args, **kwargs): 52 | title = self._make_box_title(title or self.title) 53 | long_msg = "" 54 | if title: 55 | long_msg += f"--- {title} --------\n" 56 | long_msg += message 57 | long_msg += "\n" 58 | with open(self.out_file, "a") as fou: 59 | fou.write(long_msg) 60 | if self.print_to_console: 61 | print(long_msg) 62 | 63 | def print(self, message): 64 | long_msg = message + "\n" 65 | with open(self.out_file, "a") as fou: 66 | fou.write(long_msg) 67 | if self.print_to_console: 68 | print(long_msg) 69 | 70 | 71 | class BoxedConsole(BoxedConsoleBase): 72 | def __init__(self, *, box_width, box_title, box_bg_color="black"): 73 | self.console = Console() 74 | self.box_width = box_width 75 | self.box_title = box_title 76 | self.box_bg_color = box_bg_color 77 | 78 | def printb(self, message, title=None, background=None): 79 | title = self._make_box_title(title or self.box_title) 80 | background = background or self.box_bg_color 81 | self.console.print( 82 | Panel( 83 | f"{message}", 84 | title=title, 85 | title_align="left", 86 | width=self.box_width, 87 | style=f"on {background}", 88 | ) 89 | ) 90 | 91 | def print(self, message): 92 | self.console.print( 93 | message, width=self.box_width, style=f"on {self.box_bg_color}" 94 | ) 95 | 96 | 97 | def get_boxed_console( 98 | box_title=None, box_bg_color="black", console_name="cora", debug_mode=False 99 | ) -> BoxedConsoleBase: 100 | if debug_mode: 101 | if BoxedConsoleConfigs.out_dir: 102 | return FileConsole( 103 | out_file=str( 104 | ( 105 | Path(BoxedConsoleConfigs.out_dir) / (console_name + ".traj.log") 106 | ).resolve() 107 | ), 108 | title=box_title, 109 | print_to_console=BoxedConsoleConfigs.print_to_console, 110 | ) 111 | else: 112 | return BoxedConsole( 113 | box_width=BoxedConsoleConfigs.box_width, 114 | box_title=box_title, 115 | box_bg_color=box_bg_color, 116 | ) 117 | else: 118 | return MockConsole() 119 | 
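# Usage sketch (illustrative): get_boxed_console() returns a silent MockConsole
# unless debug_mode is True; with BoxedConsoleConfigs.out_dir unset it prints rich
# panels to the terminal, and with out_dir set it appends to
# "<out_dir>/<console_name>.traj.log". The directory below is a hypothetical
# example path, not a project default.
#
#   BoxedConsoleConfigs.out_dir = "/tmp/cora-logs"
#   BoxedConsoleConfigs.print_to_console = True
#   console = get_boxed_console(box_title="Demo", debug_mode=True)
#   console.printb("Indexing finished", title="Demo")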
-------------------------------------------------------------------------------- /cora/llms/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Literal, Optional, List 4 | 5 | from cora.base.console import get_boxed_console 6 | 7 | 8 | @dataclass 9 | class FunctionCall: 10 | name: Optional[str] = None 11 | arguments: Optional[str] = None 12 | reasoning: Optional[str] = None 13 | pycode: Optional[str] = None 14 | 15 | def to_json(self): 16 | obj = {} 17 | 18 | if self.name is not None: 19 | obj["name"] = self.name 20 | if self.arguments is not None: 21 | obj["arguments"] = self.arguments 22 | if self.reasoning is not None: 23 | obj["reasoning"] = self.reasoning 24 | if self.pycode is not None: 25 | obj["pycode"] = self.pycode 26 | 27 | return obj if len(obj) != 0 else None 28 | 29 | 30 | @dataclass 31 | class ChatMessage: 32 | role: Literal["user", "assistant", "system", "function"] 33 | content: str = None 34 | name: Optional[str] = None 35 | function_call: Optional[FunctionCall] = None 36 | 37 | def to_json(self): 38 | obj = {"role": self.role, "content": ""} 39 | if self.content is not None: 40 | obj["content"] = self.content 41 | if self.name is not None: 42 | obj["name"] = self.name 43 | if self.function_call is not None: 44 | obj["function_call"] = self.function_call.to_json() 45 | return obj 46 | 47 | 48 | class LLMBase: 49 | DEBUG_OUTPUT_SYSTEM_COLOR = "bright_red" 50 | DEBUG_OUTPUT_ASSISTANT_COLOR = "bright_yellow" 51 | DEBUG_OUTPUT_USER_COLOR = "light_cyan1" 52 | DEBUG_OUTPUT_FUNCTION_COLOR = "light_cyan1" 53 | 54 | def __init__( 55 | self, *, temperature=0, top_k=50, top_p=0.95, max_tokens=4096, debug_mode=False 56 | ): 57 | self.history = [] 58 | self.temperature = temperature 59 | self.top_k = top_k 60 | self.top_p = top_p 61 | self.max_tokens = max_tokens 62 | self.debug_mode = debug_mode 63 | self.console = get_boxed_console(debug_mode=debug_mode) 64 | 65 | def is_debug_mode(self): 66 | return self.debug_mode 67 | 68 | def enable_debug_mode(self): 69 | self.debug_mode = True 70 | self.console = get_boxed_console(debug_mode=True) 71 | 72 | def disable_debug_mode(self): 73 | self.debug_mode = False 74 | self.console = get_boxed_console(debug_mode=False) 75 | 76 | def query(self) -> str: 77 | r = self.do_query() 78 | self.append_assistant_message(r) 79 | return r 80 | 81 | def get_history(self) -> List[ChatMessage]: 82 | return self.history 83 | 84 | def clear_history(self): 85 | self.history = [] 86 | 87 | def append_system_message(self, content: str): 88 | self.append_message(ChatMessage(role="system", content=content)) 89 | 90 | def append_user_message(self, content: str): 91 | self.append_message(ChatMessage(role="user", content=content)) 92 | 93 | def append_assistant_message(self, content: str): 94 | self.append_message(ChatMessage(role="assistant", content=content)) 95 | 96 | def append_message(self, message: ChatMessage): 97 | color = { 98 | "system": LLMBase.DEBUG_OUTPUT_SYSTEM_COLOR, 99 | "user": LLMBase.DEBUG_OUTPUT_USER_COLOR, 100 | "function": LLMBase.DEBUG_OUTPUT_FUNCTION_COLOR, 101 | "assistant": LLMBase.DEBUG_OUTPUT_ASSISTANT_COLOR, 102 | }[message.role] 103 | if message.role == "assistant" and message.function_call is not None: 104 | fn_reason = message.function_call.reasoning 105 | fn_name = message.function_call.name 106 | fn_args = message.function_call.arguments 107 | formatted_message = f"{fn_reason}\n\nCall Function: 
{fn_name}(**{fn_args})" 108 | else: 109 | formatted_message = message.content 110 | self.console.printb( 111 | formatted_message, title=message.role.capitalize(), background=color 112 | ) 113 | self.history.append(message) 114 | 115 | @abstractmethod 116 | def do_query(self) -> str: 117 | pass 118 | -------------------------------------------------------------------------------- /cora/agents/snippets/factory.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Dict, Type 2 | 3 | from cora.agents.snippets.base import SnipFinderBase, SnipRelDetmBase 4 | from cora.agents.snippets.judge_snip import SnipJudge 5 | from cora.agents.snippets.prev_file import PrevSnipFinder 6 | from cora.agents.snippets.score_snip import SnipScorer, SCORE_WEAK_RELEVANCE 7 | from cora.agents.snippets.split_file import EnumSnipFinder 8 | from cora.base.console import BoxedConsoleBase 9 | from cora.base.paths import SnippetPath, FilePath 10 | from cora.config import CoraConfig 11 | from cora.llms.factory import LLMConfig, LLMFactory 12 | from cora.repo.repo import Repository 13 | from cora.utils.interval import merge_overlapping_intervals 14 | from cora.utils.misc import CannotReachHereError 15 | 16 | 17 | class _SFW: 18 | def __init__(self, finder: SnipFinderBase, console=None): 19 | self.finder = finder 20 | self.console = console 21 | 22 | def enable_debugging(self): 23 | self.finder.enable_debugging() 24 | 25 | def disable_debugging(self): 26 | self.finder.disable_debugging() 27 | 28 | def find(self, query: str, file_path: str, *args, **kwargs) -> List[str]: 29 | snippets = [] 30 | for snippet, reason in self.finder.find(query, file_path, *args, **kwargs): 31 | if not snippet: 32 | continue 33 | self._printb(f"Found {snippet}: {reason}") 34 | snippets.append(snippet) 35 | snippet_tuples = [] 36 | for snippet in snippets: 37 | snippet_path = SnippetPath.from_str(snippet) 38 | snippet_tuples.append((snippet_path.start_line, snippet_path.end_line)) 39 | merged_tuples = merge_overlapping_intervals( 40 | snippet_tuples, merge_continuous=True 41 | ) 42 | return [str(SnippetPath(FilePath(file_path), a, b)) for a, b in merged_tuples] 43 | 44 | def _printb(self, *args, **kwargs): 45 | if self.console: 46 | self.console.printb(*args, **kwargs) 47 | 48 | 49 | class SnipFinderFactory: 50 | @classmethod 51 | def create( 52 | cls, 53 | name: str, 54 | repo: Repository, 55 | *, 56 | use_llm: LLMConfig, 57 | use_determ: str, 58 | determ_args: Optional[Dict] = None, 59 | console: Optional[BoxedConsoleBase] = None, 60 | ): 61 | try: 62 | ctor = { 63 | CoraConfig.SCR_SNIPPET_FINDER_NAME_ENUM_FNDR: EnumSnipFinder, 64 | CoraConfig.SCR_SNIPPET_FINDER_NAME_PREV_FNDR: PrevSnipFinder, 65 | }[name] 66 | except KeyError: 67 | raise CannotReachHereError(f"Unsupported snippet finder: {name}") 68 | return cls._create_sfw( 69 | ctor, 70 | repo, 71 | use_llm=use_llm, 72 | use_determ=use_determ, 73 | determ_args=determ_args, 74 | console=console, 75 | ) 76 | 77 | @classmethod 78 | def _create_sfw( 79 | cls, 80 | ctor: Type["SnipFinderBase"], 81 | repo: Repository, 82 | *, 83 | use_llm: LLMConfig, 84 | use_determ: str, 85 | determ_args: Optional[Dict] = None, 86 | console: Optional[BoxedConsoleBase] = None, 87 | ) -> _SFW: 88 | determ = cls._create_determ(use_determ, use_llm=use_llm, **(determ_args or {})) 89 | if ctor == PrevSnipFinder: 90 | finder = PrevSnipFinder( 91 | llm=LLMFactory.create(use_llm), repo=repo, determ=determ 92 | ) 93 | elif ctor == EnumSnipFinder: 94 | 
finder = EnumSnipFinder(repo=repo, determ=determ) 95 | else: 96 | raise CannotReachHereError(f"Unsupported snippet finder: {ctor}") 97 | return _SFW(finder=finder, console=console) 98 | 99 | @staticmethod 100 | def _create_determ(determ: str, use_llm: LLMConfig, **kwargs) -> SnipRelDetmBase: 101 | try: 102 | return { 103 | CoraConfig.SCR_SNIPPET_DETERM_NAME_SNIP_SCORER: SnipScorer( 104 | LLMFactory.create(use_llm), 105 | threshold=kwargs.get("threshold", SCORE_WEAK_RELEVANCE), 106 | ), 107 | CoraConfig.SCR_SNIPPET_DETERM_NAME_SNIP_JUDGE: SnipJudge( 108 | LLMFactory.create(use_llm), 109 | ), 110 | }[determ] 111 | except KeyError: 112 | raise CannotReachHereError(f"Unsupported snippet determiner: {determ}") 113 | -------------------------------------------------------------------------------- /cora/repoqa.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from typing import List, cast 3 | 4 | from cora import options, results 5 | from cora.agent import RepoAgent 6 | from cora.agents.base import AgentBase 7 | from cora.agents.reason_agent import R1 8 | from cora.agents.rewrite.dont import DontRewrite 9 | from cora.agents.rewrite.issue import IssueSummarizer 10 | from cora.base.rag import GeneratorBase 11 | from cora.llms.factory import LLMConfig, LLMFactory 12 | 13 | RESP_GEN_NO_REASONING_PROMPT = """\ 14 | ## Context ## 15 | 16 | ``` 17 | {context} 18 | ``` 19 | 20 | ## User Query ## 21 | 22 | \"\"\" 23 | {query} 24 | \"\"\" 25 | 26 | ## Response Plan ## 27 | 28 | 1. Read and understand the context provided from the codebase: {repo}. 29 | 2. Analyze the user's query to determine the specific information they are seeking. 30 | 3. Formulate a response that directly addresses the user's question by referencing the relevant parts of the provided context. 31 | 4. Outline the steps or information needed to answer the query in a clear and logical sequence. 32 | 33 | ## Answer ## 34 | 35 | """ 36 | 37 | 38 | RESP_GEN_R1_USER_QUERY_PROMPT = """\ 39 | Against the codebase {repo}, I have the following question: 40 | 41 | \"\"\" 42 | {query} 43 | \"\"\" 44 | 45 | Below are some code I obtained from the codebase that you may find helpful to answer my question: 46 | 47 | ``` 48 | {context} 49 | ``` 50 | 51 | """ 52 | 53 | 54 | class RespGen(AgentBase): 55 | def __init__(self, use_llm: LLMConfig): 56 | super().__init__(LLMFactory.create(use_llm), json_schema=None) 57 | 58 | def respond(self, query: str, *, context: str, repo: str) -> str: 59 | return self.run( 60 | system_prompt=RESP_GEN_NO_REASONING_PROMPT.format( 61 | query=query, context=context, repo=repo 62 | ) 63 | ) 64 | 65 | 66 | class RespGenR1: 67 | def __init__(self, use_llm: LLMConfig): 68 | self.r1 = R1(llm=LLMFactory.create(use_llm), max_chat_round=25) 69 | 70 | def respond(self, query: str, *, context: str, repo: str) -> str: 71 | return self.r1.run( 72 | RESP_GEN_R1_USER_QUERY_PROMPT.format( 73 | query=query, context=context, repo=repo 74 | ), 75 | with_internal_thoughts=True, 76 | ) 77 | 78 | 79 | class _Generator(GeneratorBase): 80 | def __init__(self, use_r1: bool = False): 81 | super().__init__() 82 | self.use_r1 = use_r1 83 | 84 | def generate(self, query: str, context: List[str], **kwargs) -> str: 85 | assert self.agent, "RepoAgent hasn't been injected. 
Please invoke inject_agent() before calling this method" 86 | agent = cast(RepoAgent, self.agent) 87 | gen_cls = RespGenR1 if self.use_r1 else RespGen 88 | resp = gen_cls(agent.use_llm).respond( 89 | query, 90 | context="\n\n".join( 91 | [ 92 | f"/// {sp}\n" 93 | f"{agent.repo.get_snippet_content(sp, add_lines=True, add_separators=False)}" 94 | for sp in context 95 | ] 96 | ), 97 | repo=agent.repo.full_name, 98 | ) 99 | agent.console.printb(resp) 100 | return resp 101 | 102 | 103 | def parse_args(): 104 | parser = ArgumentParser() 105 | options.make_common_options(parser) 106 | parser.add_argument( 107 | "--query-as-issue", 108 | action="store_true", 109 | help="Treat the user query as a GitHub issue to resolve", 110 | ) 111 | parser.add_argument( 112 | "--enable-r1", 113 | "-r1", 114 | action="store_true", 115 | help="Leverage R1 to answer the user query step by step with a reasoning chain", 116 | ) 117 | return parser.parse_args() 118 | 119 | 120 | def main(): 121 | args = parse_args() 122 | 123 | repo = options.parse_repo(args) 124 | query, incl = options.parse_query(args) 125 | llm = options.parse_llms(args) 126 | procs, threads = options.parse_perf(args) 127 | log_dir, verbose = options.parse_logging(args) 128 | 129 | if args.query_as_issue: 130 | rewriter = IssueSummarizer(repo, use_llm=llm) 131 | else: 132 | rewriter = DontRewrite(repo) 133 | 134 | agent = RepoAgent( 135 | repo=repo, 136 | use_llm=llm, 137 | rewriter=rewriter, 138 | generator=_Generator(use_r1=args.enable_r1), 139 | includes=incl, 140 | num_proc=procs, 141 | num_thread=threads, 142 | name="RepoQA", 143 | files_as_context=False, 144 | debug_mode=verbose, 145 | ) 146 | if log_dir: 147 | agent.cfar.add_callback(results.CfarResult(log_dir / "cfar_res.json")) 148 | agent.run(query=query) 149 | 150 | 151 | if __name__ == "__main__": 152 | main() 153 | -------------------------------------------------------------------------------- /cora/splits/code_.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List 3 | 4 | from tree_sitter import Node, Range 5 | from tree_sitter_languages import get_parser 6 | 7 | from cora.base.paths import FilePath, SnippetPath 8 | from cora.splits.ftypes import parse_ftype 9 | from cora.splits.splitter import Splitter 10 | 11 | 12 | class ASTSpl(Splitter): 13 | """ 14 | An AST- and char-based code splitter based on Kevin Lu's blog: 15 | - Chunking 2M+ files a day for Code Search using Syntax Trees 16 | - https://docs.sweep.dev/blogs/chunking-2m-files 17 | """ 18 | 19 | def __init__(self, file: FilePath, snippet_size: int = 1500, min_size: int = 100): 20 | super().__init__(file) 21 | self._parser = get_parser(parse_ftype(file.name)) 22 | self._snippet_size = snippet_size 23 | self._min_size = min_size 24 | 25 | def _do_split(self): 26 | # Split code along the AST, each splits saving their ranges 27 | ranges = self._split_ast() 28 | if len(ranges) == 0: 29 | return [] 30 | elif len(ranges) == 1: 31 | return [SnippetPath(self.file, 0, ranges[0].end_point[0] + 1)] 32 | 33 | # Merge overly small ranges into one of their adjacent ranges 34 | merged_ranges = [] 35 | cur_ran = Range((0, 0), (0, 0), 0, 0) 36 | for ran in ranges: 37 | cur_ran = Range( 38 | cur_ran.start_point, ran.end_point, cur_ran.start_byte, ran.end_byte 39 | ) 40 | cur_cont = self.content[cur_ran.start_byte : cur_ran.end_byte] 41 | if len(re.sub(r"\s", "", cur_cont)) > self._min_size and "\n" in cur_cont: 42 | merged_ranges.append(cur_ran) 43 | cur_ran = Range( 44 | 
ran.end_point, ran.end_point, ran.end_byte, ran.end_byte 45 | ) 46 | if cur_ran.end_byte - cur_ran.start_byte > 0: 47 | merged_ranges.append(cur_ran) 48 | 49 | # Converting from their ranges to their snippets (by line numbers) 50 | snippets = [ 51 | SnippetPath(self.file, spl.start_point[0], spl.end_point[0]) 52 | for spl in merged_ranges 53 | ] 54 | snippets[-1] = SnippetPath( 55 | self.file, snippets[-1].start_line, snippets[-1].end_line + 1 56 | ) # Let the last snippet to include the very last line 57 | 58 | return snippets 59 | 60 | def _split_ast(self) -> List[Range]: 61 | ast = self._parser.parse(self.content.encode("utf-8")) 62 | 63 | # Split recursively, each splits saving their starting and ending point in a range 64 | ranges = self._split_node(ast.root_node) 65 | 66 | # AST nodes eliminates spaces, we add them back to avoid miss 67 | ranges[0] = Range((0, 0), ranges[0].end_point, 0, ranges[0].end_byte) 68 | for i in range(len(ranges) - 1): 69 | ranges[i] = Range( 70 | ranges[i].start_point, 71 | ranges[i + 1].start_point, 72 | ranges[i].start_byte, 73 | ranges[i + 1].start_byte, 74 | ) 75 | ranges[-1] = Range( 76 | ranges[-1].start_point, 77 | ast.root_node.end_point, 78 | ranges[-1].start_byte, 79 | ast.root_node.end_byte, 80 | ) 81 | 82 | return ranges 83 | 84 | def _split_node(self, node: Node) -> List[Range]: 85 | ranges: List[Range] = [] 86 | cur_ran = Range( 87 | node.start_point, node.start_point, node.start_byte, node.start_byte 88 | ) 89 | for child in node.children: 90 | # If the current snippet is too big, we add that to our list of splits and empty the bundle 91 | if child.end_byte - child.start_byte > self._snippet_size: 92 | ranges.append(cur_ran) 93 | cur_ran = Range( 94 | child.end_point, child.end_point, child.end_byte, child.end_byte 95 | ) 96 | ranges.extend(self._split_node(child)) 97 | # If the next child node is too big, we recursively chunk the child node and add it to the list of splits 98 | elif ( 99 | child.end_byte 100 | - child.start_byte 101 | + (cur_ran.end_byte - cur_ran.start_byte) 102 | > self._snippet_size 103 | ): 104 | ranges.append(cur_ran) 105 | cur_ran = Range( 106 | child.start_point, child.end_point, child.start_byte, child.end_byte 107 | ) 108 | # Otherwise, concatenate the current chunk with the child node 109 | else: 110 | cur_ran = Range( 111 | cur_ran.start_point, 112 | child.end_point, 113 | cur_ran.start_byte, 114 | child.end_byte, 115 | ) 116 | ranges.append(cur_ran) 117 | return ranges 118 | -------------------------------------------------------------------------------- /cora/agents/rewrite/issue.py: -------------------------------------------------------------------------------- 1 | from cora.agents.rewrite.summ import SummaryGenBuilder 2 | 3 | SUMMARIZE_PROMPT = """\ 4 | ## YOUR TASK ## 5 | 6 | You are a powerful developer familiar with GitHub issues tasked to summarize issues concisely. 7 | I will give you a real Github issue about the repository: {repo}, \ 8 | I need you to generate a query according to the content of the issue, \ 9 | which developers can use this query to identify the problem he needs to solve without check the issue. 
10 | A good query should:\ 11 | 1) accurately summarize the "ISSUE"; \ 12 | 2) be expressed naturally and concisely, without burdening the user with reading; \ 13 | 3) help me better locate the target files in the repo which need to be modified to solve the issue; \ 14 | 4) be not more than 40 words, ideally one or two sentences, that captures the essence of the issue's requirements; 15 | 16 | ## ISSUE ## 17 | 18 | {query} 19 | 20 | """ 21 | 22 | SUMMARIZE_SUMMARY_KEY = "new_query" 23 | 24 | SUMMARIZE_RETURNS = [ 25 | ( 26 | "reason", 27 | str, 28 | '"the reason why you generate the query that way" // explain in detail, step-by-step', 29 | ), 30 | (SUMMARIZE_SUMMARY_KEY, str, ""), 31 | ] 32 | 33 | 34 | EVALUATE_PROMPT = """\ 35 | ## YOUR TASK ## 36 | 37 | You are a distinguished code developer familiar with GitHub Issues and bug repair. \ 38 | I will give you an "ISSUE" and a "QUERY" which was created by others to summarize the ISSUE, \ 39 | and your task is to determine if the given "QUERY" summarizes the "ISSUE" well, \ 40 | and give a score according to your determination. 41 | 42 | A query summarizes the issue well if developers can use this query to identify the problem they need to solve without checking the issue.\ 43 | You should determine your score by the following steps: 44 | 1. Think carefully about what problems the issue raises; 45 | 2. Analyze if the query includes the most important elements of the issue; 46 | 3. Check whether the query can be used instead of the issue to express the same requirement and whether it is clear enough for the developer to understand; 47 | 4. Conclude if the query summarizes the issue well and give a score. 48 | 49 | The score (an integer chosen from [0, 1, 2, 3]) represents the extent to which the query summarizes the issue, where 50 | - Score 0: The query is totally irrelevant to the issue; 51 | - Score 1: The query only covers some important elements, which is not enough to help developers solve the problem; 52 | - Score 2: The query covers most of the important elements of the issue, and gives developers a basic understanding of the issue; 53 | - Score 3: The query covers all important elements of the issue, and achieves the same effect as the issue. 54 | 55 | ## ISSUE ## 56 | 57 | {query} 58 | 59 | ## QUERY ## 60 | 61 | {summary} 62 | 63 | """ 64 | 65 | EVALUATE_SCORE_KEY = "score" 66 | 67 | EVALUATE_RETURNS = [ 68 | ( 69 | "reason", 70 | str, 71 | '"the reason why you give the score" // explain in detail, step-by-step', 72 | ), 73 | ( 74 | EVALUATE_SCORE_KEY, 75 | int, 76 | " // the summarization score; it should be an integer chosen from [0, 1, 2, 3]", 77 | ), 78 | ] 79 | 80 | 81 | UPDATE_PROMPT = """\ 82 | ## YOUR TASK ## 83 | 84 | You are an experienced developer familiar with GitHub issues tasked to summarize issues concisely. 85 | I will give you a GitHub issue about the repository: {repo}. 86 | I already have a preliminary query summarizing this issue, but I feel it could be improved for clarity and completeness. 87 | Please help me refine or rewrite my existing query to make it more effective for searching and understanding the issue at hand. 
88 | A good query should:\ 89 | 1) accurately summarize the "ISSUE"; \ 90 | 2) be expressed naturally and concisely, without burdening the user with reading; \ 91 | 3) help me better locate the target files in the repo which need to be modified to solve the issue; \ 92 | 4) be not more than 40 words, ideally one or two sentences, that captures the essence of the issue's requirements; 93 | 94 | ## ISSUE ## 95 | 96 | {query} 97 | 98 | ## EXISTED QUERY ## 99 | 100 | {summary} 101 | 102 | """ 103 | 104 | UPDATE_SUMMARY_KEY = "new_query" 105 | 106 | UPDATE_RETURNS = [ 107 | ( 108 | "reason", 109 | str, 110 | '"the reason why you give the query that way" // explain in detail, step-by-step', 111 | ), 112 | (UPDATE_SUMMARY_KEY, str, ""), 113 | ] 114 | 115 | 116 | IssueSummarizer = ( 117 | SummaryGenBuilder() 118 | .with_summarize_prompt(SUMMARIZE_PROMPT) 119 | .with_summarize_summary_key(SUMMARIZE_SUMMARY_KEY) 120 | .with_summarize_returns(SUMMARIZE_RETURNS) 121 | .with_evaluate_prompt(EVALUATE_PROMPT) 122 | .with_evaluate_score_key(EVALUATE_SCORE_KEY) 123 | .with_evaluate_returns(EVALUATE_RETURNS) 124 | .with_update_prompt(UPDATE_PROMPT) 125 | .with_update_summary_key(UPDATE_SUMMARY_KEY) 126 | .with_update_returns(UPDATE_RETURNS) 127 | .build() 128 | ) 129 | -------------------------------------------------------------------------------- /cora/agents/choose_files.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, List 2 | 3 | from cora.agents.base import AgentBase 4 | from cora.llms.base import LLMBase 5 | from cora.repo.repo import Repository 6 | 7 | SYSTEM_PROMPT = """\ 8 | ## YOUR TASK ## 9 | 10 | You are a powerful File Chooser with the capability to analyze and evaluate the relevance of files to the "User Query". 11 | For this task, I will give you the "User Query" and a "File List". 12 | - The "User Query" represents an authentic request from the customer. Please take the time to comprehend and analyze the query thoroughly. 13 | - The "File List" lists all the similar files in a list structure which you need to choose from. All the files are from the repository: {repo_name}. You do not have the access to the actual files. 14 | Your Task is to: 15 | - Based on your understanding of "User Query", identify and select the CERTAINLY RELEVANT files from "File List" that need modification to effectively address the user's needs. 16 | - Return a JSON contains a List object of the files you choose from "File List" and you think are CERTAINLY RELEVANT to the "User Query" \ 17 | and the reason why you choose them. 18 | Note, 19 | 1. You should ONLY INCLUDE FILES IN "File List". 20 | 2. If you are NOT SURE whether some files are relevant to "User Query", DO NOT put them into the result list. 21 | 3. Please ensure that your response should be a JSON Object which includes the chosen file list along with your reason. 22 | 23 | ## User Query ## 24 | 25 | ``` 26 | {user_query} 27 | ``` 28 | 29 | ## File List ## 30 | 31 | ``` 32 | {file_list} 33 | ``` 34 | 35 | """ 36 | 37 | JSON_SCHEMA = """\ 38 | { 39 | "choose_list": ["","", ...] # choose the file from "File List" certainly relevant to the "User Query" 40 | "reason": "the reason why you choose them " // explain in detail,step-by-step 41 | }\ 42 | """ 43 | 44 | NOT_FILE_LIST_MESSAGE = """\ 45 | **FAILURE**: The chosen file list you gave is NOT a list. 
46 | 47 | Your response should ONLY CHOOSE FILES FROM THE "File List" I give you: 48 | 49 | ``` 50 | {file_list} 51 | ``` 52 | 53 | ## Your Format ## 54 | 55 | {json_schema} 56 | 57 | Please fix the issues shown above and respond again. 58 | 59 | ## Your Response (NOTE: Please use JSON format) ## 60 | 61 | """ 62 | 63 | INVALID_FILE_MESSAGE = """\ 64 | **FAILURE**: Your response contains files that should not appear: {error_message}. 65 | 66 | Your response should ONLY CHOOSE FILES FROM THE "File List" I give you: 67 | 68 | ``` 69 | {file_list} 70 | ``` 71 | 72 | Please fix the issues shown above and respond again. 73 | 74 | ## Your Format ## 75 | 76 | {json_schema} 77 | 78 | ## Your Response (NOTE: Please use JSON format) ## 79 | 80 | """ 81 | 82 | 83 | class FileChooser(AgentBase): 84 | def __init__( 85 | self, 86 | query: str, 87 | repo: Repository, 88 | llm: LLMBase, 89 | *args, 90 | **kwargs, 91 | ): 92 | super().__init__(llm=llm, json_schema=JSON_SCHEMA, *args, **kwargs) 93 | self.query = query 94 | self.repo = repo 95 | 96 | def choose(self, from_files: List[str]): 97 | assert len(from_files) != 0, "Cannot choose any files if no files are given" 98 | return self.run( 99 | SYSTEM_PROMPT.format( 100 | repo_name=self.repo.full_name, 101 | user_query=self.query, 102 | file_list="\n".join(from_files), 103 | ), 104 | from_files=from_files, 105 | ) 106 | 107 | def _check_response_format( 108 | self, response: dict, *args, **kwargs 109 | ) -> Tuple[bool, Optional[str]]: 110 | for field in ["choose_list", "reason"]: 111 | if field not in response: 112 | return False, f"'{field}' is missing in the JSON object" 113 | return True, None 114 | 115 | def _check_response_semantics( 116 | self, response: dict, *args, **kwargs 117 | ) -> Tuple[bool, Optional[str]]: 118 | from_files = kwargs["from_files"] 119 | choose_list, _ = response["choose_list"], response["reason"] 120 | 121 | if type(choose_list) is not list: 122 | return False, NOT_FILE_LIST_MESSAGE.format( 123 | file_list=from_files, json_schema=JSON_SCHEMA 124 | ) 125 | 126 | for file in choose_list: 127 | if file not in from_files: 128 | return False, INVALID_FILE_MESSAGE.format( 129 | error_message=file, 130 | file_list=from_files, 131 | file_count=len(from_files), 132 | json_schema=JSON_SCHEMA, 133 | ) 134 | 135 | return True, None 136 | 137 | def _parse_response(self, response: dict, *args, **kwargs) -> any: 138 | return response["choose_list"], response["reason"] 139 | 140 | def _default_result_when_reaching_max_chat_round(self): 141 | return ( 142 | None, 143 | "The model has reached the max number of chat rounds and is unable to find any further files.", 144 | ) 145 | -------------------------------------------------------------------------------- /cora/preview/xml_.py: -------------------------------------------------------------------------------- 1 | from xml.etree import ElementTree 2 | 3 | from cora.preview.base import FilePreview 4 | from cora.preview.internal.xml_element import Elements 5 | from cora.preview.internal.xml_parser import SlowXMLParser 6 | 7 | 8 | @FilePreview.register(["xml"]) 9 | class XMLPreview(FilePreview): 10 | def __init__(self, file_type: str, file_name: str, file_content: str): 11 | super().__init__( 12 | file_type=file_type, file_name=file_name, file_content=file_content 13 | ) 14 | # TODO: Comments are all lost 15 | self.xml_tree = ElementTree.fromstring(file_content, parser=SlowXMLParser()) 16 | self.max_kept_depth = 5 17 | 18 | def get_preview(self): 19 | file_lines = self.file_lines
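        # Preview strategy (summary): walk the XML tree and emit source lines
        # verbatim, using last_line_number to ensure each line is emitted at most
        # once; children nested max_kept_depth levels deep (or deeper) are collapsed
        # into a "(lines X-Y are hidden in preview)" placeholder instead of being
        # printed in full.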
20 |         last_line_number = -1
21 | 
22 |         def preview_lines_update_last_number(start_number, end_number, preview):
23 |             nonlocal last_line_number
24 |             for line_number in range(start_number, end_number, 1):
25 |                 preview.append(self.preview_line(line_number, file_lines[line_number]))
26 |                 last_line_number = line_number
27 | 
28 |         def traverse_element(element: ElementTree.Element, *, depth):
29 |             nonlocal last_line_number
30 |             preview = []
31 | 
32 |             # We have reached max kept depth; let's hide all our children.
33 |             if depth + 1 == self.max_kept_depth and len(element) > 0:
34 |                 start_number = Elements.start_line_number(element) - 1
35 |                 child_start_number = Elements.start_line_number(element[0]) - 1
36 | 
37 |                 if child_start_number > start_number:
38 |                     # Continue from last previewed line until our first child
39 |                     preview_lines_update_last_number(
40 |                         last_line_number + 1, child_start_number, preview
41 |                     )
42 |                 else:
43 |                     # The first child starts on the same line as us, so continue up to our own start line
44 |                     preview_lines_update_last_number(
45 |                         last_line_number + 1, start_number + 1, preview
46 |                     )
47 |                 end_number = Elements.end_line_number(element) - 1
48 |                 child_end_number = Elements.end_line_number(element[-1]) - 1
49 | 
50 |                 if child_end_number > last_line_number:
51 |                     spacing = self.spacing_for_line_number(child_start_number)
52 |                     indentation = self.indentation_of_line(
53 |                         file_lines[child_start_number]
54 |                     )
55 |                     # TODO: Collect children's tags, texts, etc. as terms
56 |                     preview.extend(
57 |                         [
58 |                             spacing + indentation + "...",
59 |                             spacing
60 |                             + indentation
61 |                             + f"(lines {last_line_number + 1}-{child_end_number} are hidden in preview)",
62 |                             spacing + indentation + "...\n",
63 |                         ]
64 |                     )
65 |                     preview_lines_update_last_number(
66 |                         child_end_number + 1, end_number + 1, preview
67 |                     )
68 |                 else:
69 |                     preview_lines_update_last_number(
70 |                         last_line_number + 1, end_number + 1, preview
71 |                     )
72 |                 return preview
73 | 
74 |             # Let's preview our children one by one
75 |             for child in element:
76 |                 child_start_number = Elements.start_line_number(child) - 1
77 |                 child_end_number = Elements.end_line_number(child) - 1
78 |                 # This means [start_number, last_line_number] were already processed
79 |                 if child_start_number < last_line_number:
80 |                     child_start_number = last_line_number + 1
81 |                 # This means the child node was already fully processed
82 |                 if child_start_number > child_end_number:
83 |                     continue
84 |                 # Continue from last previewed line
85 |                 preview_lines_update_last_number(
86 |                     last_line_number + 1, child_start_number, preview
87 |                 )
88 |                 # The child does not have any further children; let's put its content.
89 | if len(child) == 0: 90 | preview.extend( 91 | [ 92 | self.preview_line(line_number, file_lines[line_number]) 93 | for line_number in range( 94 | child_start_number, child_end_number + 1, 1 95 | ) 96 | ] 97 | ) 98 | # Otherwise, let's traverse child's children 99 | else: 100 | preview.extend(traverse_element(child, depth=depth + 1)) 101 | last_line_number = child_end_number 102 | 103 | preview_lines_update_last_number( 104 | last_line_number + 1, 105 | (Elements.end_line_number(element) - 1) + 1, 106 | preview, 107 | ) 108 | 109 | return preview 110 | 111 | return "\n".join(traverse_element(self.xml_tree, depth=0)) 112 | -------------------------------------------------------------------------------- /cora/splits/ftypes.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | _TYPE_TO_EXTS = { 4 | "actionscript": ["as"], 5 | "ada": ["adb", "ada", "ads"], 6 | "adoc": ["adoc"], 7 | "apache": ["apacheconf"], 8 | "apex": ["cls", "apex", "trigger"], 9 | "applescript": ["applescript", "scpt"], 10 | "bash": ["sh", "bash"], 11 | "beancount": ["beancount"], 12 | "bibtex": ["bib"], 13 | "c": ["c", "h"], 14 | "clojure": ["clj", "cljs", "cljx"], 15 | "cmake": ["cmake", "cmake.in", "in"], 16 | "coffescript": ["coffee", "cake", "cson", "cjsx", "iced"], 17 | "cpp": ["cc", "cpp", "cxx", "c++", "h++", "hh", "hpp", "pc", "pcc"], 18 | "commonlisp": ["lisp", "cl"], 19 | "css": ["css"], 20 | "cuda": ["cu", "cuh"], 21 | "c_sharp": ["cs", "csx"], 22 | "d": ["d"], 23 | "dart": ["dart"], 24 | "diff": ["diff", "patch"], 25 | "django": ["jinja"], 26 | "dockerfile": ["Dockerfile"], 27 | "dos": ["bat", "cmd"], 28 | "dot": ["dot"], 29 | "elisp": ["el"], 30 | "elixir": ["ex", "exs"], 31 | "elm": ["elm"], 32 | "erlang": ["erl"], 33 | "fortran": ["f", "for", "frt", "fr", "forth", "4th", "fth"], 34 | "f_sharp": ["fs"], 35 | "go": ["go"], 36 | "gomod": ["mod"], 37 | "graphql": ["graphql"], 38 | "groovy": ["groovy", "gradle"], 39 | "haml": ["haml"], 40 | "handlebars": ["hbs", "handlebars"], 41 | "haskell": ["hs", "hsc"], 42 | "hcl": ["hcl", "nomad", "tf", "tfvars", "workflow"], 43 | "html": ["htm", "html", "xhtml"], 44 | "ini": ["ini", "cfg", "prefs", "pro", "properties"], 45 | "java": ["java"], 46 | "javascript": ["js", "es", "es6", "jss", "jsm"], 47 | "javascriptreact": ["jsx"], 48 | "json": ["json"], 49 | "json5": ["json5"], 50 | "jsonnet": ["jsonnet", "libsonnet"], 51 | "julia": ["jl"], 52 | "kotlin": ["kt", "ktm", "kts"], 53 | "less": ["less"], 54 | "lisp": ["lisp", "asd", "cl", "lsp", "l", "ny", "podsl", "sexp"], 55 | "lua": ["lua", "fcgi", "nse", "pd_lua", "rbxs", "wlua"], 56 | "make": ["mk", "mak", "makefile"], 57 | "markdown": ["md", "mkdown", "mkd"], 58 | "nginx": ["nginxconf"], 59 | "objc": ["m", "mm"], 60 | "ocaml": ["ml", "eliom", "eliomi", "ml4", "mli", "mll", "mly", "re"], 61 | "pascal": ["p", "pas", "pp"], 62 | "perl": ["pl", "al", "cgi", "perl", "ph", "plx", "pm", "pod", "psgi", "t"], 63 | "php": ["php", "phtml", "php3", "php4", "php5", "php6", "php7", "phps"], 64 | "powershell": ["ps1", "psd1", "psm1"], 65 | "protobuf": ["proto"], 66 | "python": ["py", "pyc", "pyd", "pyo", "pyw", "pyz"], 67 | "r": ["r", "rd", "rsx"], 68 | "repro": ["repro"], 69 | "rst": ["rst"], 70 | "ruby": [ 71 | "rb", 72 | "rbi", 73 | "builder", 74 | "eye", 75 | "gemspec", 76 | "god", 77 | "jbuilder", 78 | "mspec", 79 | "pluginspec", 80 | "podspec", 81 | "rabl", 82 | "rake", 83 | "rbuild", 84 | "rbw", 85 | "rbx", 86 | "ru", 87 | "ruby", 88 | "spec", 89 | "thor", 90 | 
"watchr", 91 | ], 92 | "rust": ["rs", "rs.in"], 93 | "scala": ["sbt", "sc", "scala"], 94 | "scheme": ["scm", "sch", "sls", "sps", "ss"], 95 | "scss": ["sass", "scss"], 96 | "smalltalk": ["st"], 97 | "sql": ["sql"], 98 | "sqlite": ["sqlite"], 99 | "starlark": ["bzl", "bazel", "BUILD", "WORKSPACE"], 100 | "stylus": ["styl"], 101 | "svelte": ["svelte"], 102 | "swift": ["swift"], 103 | "thrift": ["thrift"], 104 | "toml": ["toml"], 105 | "twig": ["twig"], 106 | "txt": ["txt"], 107 | "typescript": ["ts"], 108 | "typescriptreact": ["tsx"], 109 | "vbnet": ["vb"], 110 | "vbscrip": ["vbs"], 111 | "verilog": ["v", "veo", "sv", "svh", "svi"], 112 | "vhdl": ["vhd", "vhdl"], 113 | "vim": ["vim"], 114 | "xml": [ 115 | "xml", 116 | "adml", 117 | "admx", 118 | "ant", 119 | "axml", 120 | "builds", 121 | "ccxml", 122 | "clixml", 123 | "cproject", 124 | "csl", 125 | "csproj", 126 | "ct", 127 | "dita", 128 | "ditamap", 129 | "ditaval", 130 | "dll.config", 131 | "dotsettings", 132 | "filters", 133 | "fsproj", 134 | "fxml", 135 | "glade", 136 | "gml", 137 | "grxml", 138 | "iml", 139 | "ivy", 140 | "jelly", 141 | "jsproj", 142 | "kml", 143 | "launch", 144 | "mdpolicy", 145 | "mjml", 146 | "mod", 147 | "mxml", 148 | "nproj", 149 | "nuspec", 150 | "odd", 151 | "osm", 152 | "pkgproj", 153 | "plist", 154 | "props", 155 | "ps1xml", 156 | "psc1", 157 | "pt", 158 | "rdf", 159 | "resx", 160 | "rss", 161 | "scxml", 162 | "sfproj", 163 | "srdf", 164 | "storyboard", 165 | "stTheme", 166 | "sublime-snippet", 167 | "targets", 168 | "tmCommand", 169 | "tml", 170 | "tmLanguage", 171 | "tmPreferences", 172 | "tmSnippet", 173 | "tmTheme", 174 | "ui", 175 | "urdf", 176 | "ux", 177 | "vbproj", 178 | "vcxproj", 179 | "vsixmanifest", 180 | "vssettings", 181 | "vstemplate", 182 | "vxml", 183 | "wixproj", 184 | "wsdl", 185 | "wsf", 186 | "wxi", 187 | "wxl", 188 | "wxs", 189 | "x3d", 190 | "xacro", 191 | "xaml", 192 | "xib", 193 | "xlf", 194 | "xliff", 195 | "xmi", 196 | "xml.dist", 197 | "xproj", 198 | "xsd", 199 | "xspec", 200 | "xul", 201 | "zcml", 202 | ], 203 | "yaml": ["yml", "yaml"], 204 | "zsh": ["zsh"], 205 | } 206 | _EXT_TO_TYPE = {ext: kv[0] for kv in _TYPE_TO_EXTS.items() for ext in kv[1]} 207 | 208 | 209 | def parse_ftype(file_name: str): 210 | ext = Path(file_name).suffix or Path(file_name).name 211 | if ext.startswith("."): 212 | ext = ext[1:] 213 | return _EXT_TO_TYPE.get(ext) 214 | -------------------------------------------------------------------------------- /cora/preview/code.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import tree_sitter_languages 4 | from tree_sitter import Node, Parser 5 | 6 | from cora.preview.base import FilePreview 7 | 8 | 9 | def _extract_words(string): 10 | # extract the most common words from a code snippet 11 | words = re.findall(r"\w+", string) 12 | return list(dict.fromkeys(words)) 13 | 14 | 15 | @FilePreview.register( 16 | [ 17 | # See https://github.com/grantjenks/py-tree-sitter-languages/blob/main/tests/test_tree_sitter_languages.py 18 | "bash", 19 | "c", 20 | "c_sharp", 21 | "commonlisp", 22 | "cpp", 23 | "css", 24 | "dockerfile", 25 | "dot", 26 | "elisp", 27 | "elixir", 28 | "elm", 29 | "erlang", 30 | "fortran", 31 | "go", 32 | "gomod", 33 | "haskell", 34 | "hcl", 35 | "html", 36 | "java", 37 | "javascript", 38 | "json", 39 | "julia", 40 | "kotlin", 41 | "lua", 42 | "make", 43 | "markdown", 44 | "objc", 45 | "ocaml", 46 | "perl", 47 | "php", 48 | "r", 49 | "rst", 50 | "ruby", 51 | "rust", 52 | "scala", 53 | "sql", 54 | 
"sqlite", 55 | "toml", 56 | "typescript", 57 | "yaml", 58 | ] 59 | ) 60 | class CodePreview(FilePreview): 61 | def __init__(self, file_type: str, file_name: str, file_content: str): 62 | super().__init__( 63 | file_type=file_type, file_name=file_name, file_content=file_content 64 | ) 65 | parser = Parser() 66 | parser.set_language(tree_sitter_languages.get_language(file_type)) 67 | self.file_tree = parser.parse(bytes(file_content, "utf8")) 68 | self.min_line = 5 69 | self.max_line = 50 70 | self.num_kept_lines = 2 71 | self.num_kept_terms = 5 72 | 73 | def get_preview(self): 74 | file_lines = self.file_lines 75 | 76 | last_line_number = -1 77 | 78 | def traverse_node(node: Node): 79 | nonlocal last_line_number 80 | preview_lines = [] 81 | for child in node.children: 82 | start_number, _ = child.start_point 83 | end_number, _ = child.end_point 84 | # This means [start_number, last_line_number] were already processed 85 | if start_number <= last_line_number: 86 | start_number = last_line_number + 1 87 | # This means the child node was already fully processed 88 | if start_number > end_number: 89 | continue 90 | # Continue from last previewed line 91 | for line_number in range(last_line_number + 1, start_number): 92 | line = file_lines[line_number] 93 | preview_lines.append(self.preview_line(line_number, line)) 94 | last_line_number = line_number 95 | # The child node is too large; let's get into its children. 96 | if end_number - start_number > self.max_line: 97 | preview_lines.extend(traverse_node(child)) 98 | # The child node is small enough; let's present all its content. 99 | elif end_number - start_number < self.min_line: 100 | node_lines = file_lines[start_number : end_number + 1] 101 | text = "\n".join( 102 | [ 103 | self.preview_line(start_number + i, line) 104 | for i, line in enumerate(node_lines) 105 | ] 106 | ) 107 | preview_lines.append(text) 108 | # The child line is within min_line--max_line; let's present its head/tail and hide its body 109 | else: 110 | node_lines = file_lines[start_number : end_number + 1] 111 | num_kept_lines = self.num_kept_lines 112 | # Keep the starting num_kept_lines 113 | preview_lines.extend( 114 | [ 115 | self.preview_line(start_number + i, line) 116 | for i, line in enumerate(node_lines[:num_kept_lines]) 117 | ] 118 | ) 119 | # Hide the middle lines and leave a short message 120 | num_extracted_terms = self.num_kept_terms 121 | first_n_terms = ", ".join( 122 | _extract_words( 123 | "\n".join(node_lines[num_kept_lines:-num_kept_lines]) 124 | )[:num_extracted_terms] 125 | ) 126 | spacing = self.spacing_for_line_number( 127 | start_number + num_kept_lines - 1 128 | ) 129 | indentation = self.indentation_of_line( 130 | node_lines[num_kept_lines - 1] 131 | ) 132 | preview_lines.extend( 133 | [ 134 | spacing + indentation + "...", 135 | spacing 136 | + indentation 137 | + f"(lines {start_number + num_kept_lines}-{end_number - num_kept_lines} contains terms: {first_n_terms}", 138 | spacing + indentation + "...\n", 139 | ] 140 | ) 141 | # Keep the ending num_kept_lines 142 | preview_lines.extend( 143 | [ 144 | self.preview_line( 145 | start_number + (end_number - start_number - 1) + i, line 146 | ) 147 | for i, line in enumerate(node_lines[-num_kept_lines:]) 148 | ] 149 | ) 150 | last_line_number = end_number 151 | return preview_lines 152 | 153 | return "\n".join(traverse_node(self.file_tree.root_node)) 154 | -------------------------------------------------------------------------------- /cora/agents/score_preview.py: 
--------------------------------------------------------------------------------
1 | from typing import Tuple, Optional, List
2 | 
3 | from cora.agents.base import AgentBase
4 | from cora.llms.base import LLMBase
5 | from cora.preview import FilePreview
6 | from cora.repo.repo import Repository
7 | 
8 | SYSTEM_PROMPT = """\
9 | ## YOUR TASK ##
10 | 
11 | You are a powerful File Relevance Decider with the capability to analyze and evaluate the relevance of files to the "User Query". \
12 | Your task is to determine the relevance of the file to the "User Query" and give a relevance score according to your determination.
13 | 
14 | For this task, I will give you the "User Query" and a preview of a file which contains function names and class names. \
15 | The file {file_name} is from the repository: {repo_name}.
16 | Additionally, I will give you a file list which contains some filenames which may be related to the "User Query". The file list may help you decide the relevance score.
17 | 
18 | A file is relevant to a user query if the file can be an important part of addressing the user query \
19 | (though it might not address the user query directly). \
20 | You should determine their relevance by the following steps:
21 | 1. Think carefully what files are required to address the user query;
22 | 2. Analyze if the file is part of those files and what the file provides or what it does;
23 | 3. Check if the file can provide some useful information to address the user query directly or indirectly;
24 | 4. Conclude if the file is relevant to the user query and give a relevance score.
25 | 
26 | The relevance score (an integer chosen from [0, 1, 2, 3]) represents the relevance between the file and the user query, where
27 | - Score 0: The file is totally irrelevant to the user query; it does not help anything to address the user query.
28 | - Score 1: The file is weakly relevant to the user query, but the user query can be addressed even without it.
29 | - Score 2: The file is relevant to the user query; the user query can only be partially addressed without it.
30 | - Score 3: The file is strongly relevant to the user query; the user query relies on it and can never be addressed without it.
31 | 
32 | ## User Query ##
33 | 
34 | ```
35 | {user_query}
36 | ```
37 | 
38 | ## File Preview ##
39 | 
40 | ```
41 | {file_name}
42 | 
43 | {file_preview}
44 | ```
45 | 
46 | ## File List ##
47 | ```
48 | {file_list}
49 | ```
50 | 
51 | """
52 | 
53 | JSON_SCHEMA = """\
54 | {
55 |     "score": <score>, // the relevance score; it should be an integer chosen from [0, 1, 2, 3]
56 |     "reason": "the reason why you give the relevance score" // explain in detail, step-by-step, the relevance between the file and the user query
57 | }\
58 | """
59 | 
60 | NON_INTEGER_SCORE_MESSAGE = """\
61 | **FAILURE**: The relevance score ({score}) you gave is NOT an integer.
62 | 
63 | The relevance score must be an integer chosen from [0, 1, 2, 3], where:
64 | - Score 0: The file is totally irrelevant to the user query; it does not help anything to address the user query.
65 | - Score 1: The file is weakly relevant to the user query, but the user query can be addressed even without it.
66 | - Score 2: The file is relevant to the user query; the user query can only be partially addressed without it.
67 | - Score 3: The file is strongly relevant to the user query; the user query relies on it and can never be addressed without it.
68 | 
69 | ## Your Response (JSON format) ##
70 | 
71 | """
72 | 
73 | INVALID_SCORE_VALUE_MESSAGE = """\
74 | **FAILURE**: The relevance score ({score}) you gave is NOT chosen from [0, 1, 2, 3].
75 | 
76 | The relevance score must be an integer chosen from [0, 1, 2, 3], where:
77 | - Score 0: The file is totally irrelevant to the user query; it does not help anything to address the user query.
78 | - Score 1: The file is weakly relevant to the user query, but the user query can be addressed even without it.
79 | - Score 2: The file is fairly relevant to the user query; the user query can only be partially addressed without it.
80 | - Score 3: The file is strongly relevant to the user query; the user query relies on it and can never be addressed without it.
81 | 
82 | ## Your Response (JSON format) ##
83 | 
84 | """
85 | 
86 | SCORE_IRRELEVANCE = 0
87 | SCORE_WEAK_RELEVANCE = 1
88 | SCORE_FAIR_RELEVANCE = 2
89 | SCORE_STRONG_RELEVANCE = 3
90 | 
91 | 
92 | class PreviewScorer(AgentBase):
93 |     def __init__(
94 |         self,
95 |         query: str,
96 |         repo: Repository,
97 |         llm: LLMBase,
98 |         *args,
99 |         **kwargs,
100 |     ):
101 |         super().__init__(llm=llm, json_schema=JSON_SCHEMA, *args, **kwargs)
102 |         self.query = query
103 |         self.repo = repo
104 | 
105 |     def score(self, file: str, other_files: List[str]) -> Tuple[int, str]:
106 |         return self.run(
107 |             SYSTEM_PROMPT.format(
108 |                 repo_name=self.repo.repo_name,
109 |                 user_query=self.query,
110 |                 file_name=file,
111 |                 file_preview=FilePreview.of(file, self.repo.get_file_content(file)),
112 |                 file_list=""
113 |                 if len(other_files) == 0
114 |                 else "\n".join(other_files),
115 |             )
116 |         )
117 | 
118 |     def _check_response_format(
119 |         self, response: dict, *args, **kwargs
120 |     ) -> Tuple[bool, Optional[str]]:
121 |         for field in ["score", "reason"]:
122 |             if field not in response:
123 |                 return False, f"'{field}' is missing in the JSON object"
124 |         return True, None
125 | 
126 |     def _check_response_semantics(
127 |         self, response: dict, *args, **kwargs
128 |     ) -> Tuple[bool, Optional[str]]:
129 |         try:
130 |             score = int(response["score"])
131 |         except ValueError:
132 |             return False, NON_INTEGER_SCORE_MESSAGE.format(score=response["score"])
133 |         if score not in [0, 1, 2, 3]:
134 |             return False, INVALID_SCORE_VALUE_MESSAGE.format(score=score)
135 |         if type(response["reason"]) is not str:
136 |             return False, None
137 |         return True, None
138 | 
139 |     def _parse_response(self, response: dict, *args, **kwargs) -> any:
140 |         return int(response["score"]), response["reason"]
141 | 
142 |     def _default_result_when_reaching_max_chat_round(self):
143 |         return (
144 |             SCORE_IRRELEVANCE,
145 |             "The model has reached the max number of chat rounds and is unable to give a relevance score.",
146 |         )
147 | 
--------------------------------------------------------------------------------
/cora/agents/explore_tree.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Optional, List
2 | 
3 | from cora.agents.base import AgentBase
4 | from cora.base.ftree import FileTree
5 | from cora.llms.base import LLMBase
6 | from cora.repo.repo import Repository
7 | 
8 | SYSTEM_PROMPT = """\
9 | You are a File Finder, tasked with collecting a list of files from the repository: {repo_name}. \
10 | The files you find must be relevant to a "User Query" that I give you. \
11 | Since the found files will be used by others to solve the user query, it is very important that you do not miss any relevant files.
12 | 
13 | For this task, I will give you the "User Query" and a "Repository Tree". \
14 | The repository tree lists all the repository's files in a tree structure. \
15 | I will also provide you with a "File List" listing all files that you have previously found CERTAINLY relevant and stored. \
16 | You initially start with the files in the file list (it might also be an empty list) and will explore the tree to find further files that you are CERTAIN are relevant to the user query. \
17 | If a file you are CERTAIN is relevant is already in the file list, do not add it into the list again; find another relevant one instead. \
18 | You determine if a file is relevant to the user query merely by analyzing the relevance between the user query and:
19 | 1. the repository's directory structure (because the structure may imply the repository's functionality)
20 | 2. each file's name and its position in the repository tree (because the file name somewhat indicates the file's functionality)
21 | 3. your prior knowledge about the functionality of each special file (pom.xml, build.gradle, requirements.txt, package.json, etc.)
22 | You do not have any access to any file's content.
23 | 
24 | Note,
25 | 1. There might be multiple relevant files, but you respond with only one file each time I query you; so please respond with the file that is MOST relevant.
26 | 2. If you find that all files you are CERTAIN are relevant to the user query are already in the file list, set the field "file" to null and give a clear "reason" to let me know.
27 | 
28 | ## User Query ##
29 | 
30 | ```
31 | {user_query}
32 | ```
33 | 
34 | ## Repository Tree ##
35 | 
36 | ```
37 | {repository_tree}
38 | ```
39 | 
40 | ## File List ##
41 | 
42 | ```
43 | {file_list}
44 | ```
45 | 
46 | """
47 | 
48 | JSON_SCHEMA = """\
49 | {{
50 |     "file": "<file-path>", // must be the path of a file, like "{example_file}", or set this field to null if you find that all files you are CERTAIN are relevant to the user query are already in the File List
51 |     "reason": "the reason why you think this file is relevant to the user query" // explain in detail why you deem this file relevant to the user query; this might be because its name is reflected in the query, some of its special functionality can solve the query, or any other reasonable explanation.
52 | }}\
53 | """
54 | 
55 | FILE_NOT_EXISTS_MESSAGE = """\
56 | **FAILURE**: File {file_path} does not exist in the repository.
57 | 
58 | Did you mean one of the following files?
59 | 
60 | {similar_file_paths}
61 | 
62 | ## Your Response (JSON format) ##
63 | 
64 | """
65 | 
66 | FILE_ALREADY_EXISTS_MESSAGE = """\
67 | **FAILURE**: File {file_path} ALREADY exists in the file list.
68 | 
69 | Please find another file that you are CERTAIN is relevant to the user query. \
70 | If you find that all files you are CERTAIN are relevant to the user query are already in the file list, set the field "file" to null and give a clear "reason" to let me know.
71 | 72 | ## Your Response (JSON format) ## 73 | 74 | """ 75 | 76 | 77 | class FileFinder(AgentBase): 78 | def __init__( 79 | self, 80 | query: str, 81 | repo: Repository, 82 | tree: FileTree, 83 | llm: LLMBase, 84 | includes: Optional[List[str]] = None, 85 | *args, 86 | **kwargs, 87 | ): 88 | super().__init__( 89 | llm=llm, 90 | json_schema=JSON_SCHEMA.format(example_file=repo.get_rand_file()), 91 | *args, 92 | **kwargs, 93 | ) 94 | self.query = query 95 | self.repo = repo 96 | self.file_list = [] 97 | self.tree = tree 98 | self.includes = includes 99 | 100 | def next_file(self): 101 | return self.run( 102 | SYSTEM_PROMPT.format( 103 | repo_name=self.repo.full_name, 104 | user_query=self.query, 105 | repository_tree=str(self.tree), 106 | file_list="" 107 | if len(self.file_list) == 0 108 | else "\n".join(self.file_list), 109 | ) 110 | ) 111 | 112 | def _check_response_format( 113 | self, response: dict, *args, **kwargs 114 | ) -> Tuple[bool, Optional[str]]: 115 | for field in ["file", "reason"]: 116 | if field not in response: 117 | return False, f"'{field}' is missing in the JSON object" 118 | return True, None 119 | 120 | def _check_response_semantics( 121 | self, response: dict, *args, **kwargs 122 | ) -> Tuple[bool, Optional[str]]: 123 | file_path = response["file"] 124 | 125 | # None is a valid response 126 | if file_path is None: 127 | return True, None 128 | 129 | # Check the existence of the file in the repository 130 | if not self.repo.has_file(file_path): 131 | return False, FILE_NOT_EXISTS_MESSAGE.format( 132 | file_path=file_path, 133 | similar_file_paths="\n".join( 134 | [ 135 | f"- {path}" 136 | for path in self.repo.find_similar_files( 137 | file_path, includes=self.includes 138 | ) 139 | ] 140 | ), 141 | ) 142 | 143 | # Check the existence of the file in file list 144 | if file_path in self.file_list: 145 | return False, FILE_ALREADY_EXISTS_MESSAGE.format(file_path=file_path) 146 | 147 | return True, None 148 | 149 | def _parse_response(self, response: dict, *args, **kwargs) -> any: 150 | file_path, reason = response["file"], response["reason"] 151 | if file_path is None: 152 | return None, reason 153 | self.file_list.append(file_path) 154 | return file_path, reason 155 | 156 | def _default_result_when_reaching_max_chat_round(self): 157 | return ( 158 | None, 159 | "The model have reached the max number of chat round and is unable to find any further files.", 160 | ) 161 | -------------------------------------------------------------------------------- /docs/README_zh.md: -------------------------------------------------------------------------------- 1 |
2 | CodeFuse
3 | 
4 | 
5 | 
6 | CodeFuse RepoAgent
7 | 
8 | 
9 | 
10 | 
11 | | [中文](./README_zh.md) | [English](../README.md) |
12 | 
13 | 
14 | 
15 | 
16 | 
17 | CodeFuse RepoAgent (CoRA) is a repository-level retrieval-augmented generation (r-RAG) agent. It aims to strengthen the upstream retrieval (R) capability of r-RAG, thereby improving the generation (G) capability of the downstream large language model, and on that basis answer users' queries and questions about a repository. CoRA relies on its upstream semantic retrieval agent, CodeFuse Agentic Retriever (CFAR), to collect context closely related to the user's question; it then turns this context into task-specific downstream prompts to effectively generate responses to the user's query.
18 | 
19 | ---
20 | 
21 | ## 🔥 News
22 | 
23 | - **[2024/12/01]** We released CodeFuse Agentic Retriever (CFAR), the core context-retrieval tool of CoRA.
24 | - **[2024/12/03]** We released CoRA's bug-fixing tools FixIt! and SWE-kit, where SWE-kit integrates SWE-bench to verify patch correctness.
25 | 
26 | ---
27 | 
28 | ## 👏🏻 Introduction
29 | 
30 | CoRA uses CFAR to collect context closely related to the user's question, and then converts this context into task-specific downstream prompts to effectively generate responses to the user's query.
31 | 
32 | The main challenge of repository-level retrieval-augmented generation (r-RAG) is extracting the context closely related to the user's question, that is, the specific code or document lines called "snippet contexts" (or simply "contexts"), so that repository-related queries can be answered effectively. Although several approaches already exist, they suffer from limited accuracy, non-deterministic results, and complex designs that are hard to implement. To address these problems, CFAR proposes a novel and easy-to-implement compound approach that combines traditional keyword-based search with large language models (LLMs) to improve the semantic retrieval capability of r-RAG. CFAR's design is based on one observation: although keyword search may not be very accurate, the contexts it retrieves are often close to the ground-truth (GT) contexts required to solve the user's query. This observation shows up in two ways: first, even though the results are not always great, keyword search can still recall a fair share of the GT files; second, the files found by the keyword engine are structurally close to the GT files in the repository, so an LLM can help search their neighborhood more effectively. CFAR's workflow is as follows:
33 | 1. **Identifying candidate files**: CFAR first uses a keyword engine to find a set of candidate files that may be relevant to the user's query. These candidate files form the basis for the subsequent LLM-based refinement, which includes: (1) analyzing the code entities (e.g., functions and classes) in the user's question and the repository's organization to supplement candidate files that the engine search may have missed; and (2) letting the LLM analyze a preview of each candidate file to filter out irrelevant files wrongly brought in by the engine search.
34 | 2. **Retrieving snippet contexts**: Next, CFAR uses the LLM to inspect the concrete content of each candidate file, looks for snippets relevant to the user's query, and ranks and refines the relevant snippets at this stage.
35 | 
36 | ![CoRA's Overview](../.github/assets/overview.png)
37 | 
38 | One use of CoRA is repository-level bug fixing. Based on the context retrieved by CFAR, CoRA used a single straightforward prompt to get GPT-4o to resolve 95 issues (31.67%) of SWE-bench Lite. Before October 23, 2024, this result once ranked first on the SWE-bench Lite open-source leaderboard. CoRA's technical report is published on [arXiv](#).
39 | 
40 | ## 📦 Installation
41 | 
42 | First, create a conda environment:
43 | 
44 | ```shell
45 | conda env create -f environment.yaml
46 | conda activate cora_venv
47 | mv env.template .env  # This saves some environment variables
48 | ```
49 | 
50 | Next, configure according to the table below to make sure the libraries/frameworks and models you plan to use are available:
51 | - `√` means the library/framework is already supported.
52 | - `.` means support for the library/framework is in progress or coming soon.
53 | - `x` means support for the library/framework is not planned for now.
54 | 
55 | | Library/Framework | Status | How to configure |
56 | |:----------------:|:---:|:------------------------------------------|
57 | | [OpenAI](#) | `√` | Set the API key and related environment variables in `.env` |
58 | | [Anthropic](#) | `√` | Set the API key and related environment variables in `.env` |
59 | | [Ollama](#) | `√` | Download the required models with `ollama pull` and start them before using CoRA |
60 | | [HuggingFace](#) | `.` | Either download the models in advance, or allow HuggingFace network access in `.env` |
61 | | [EasyDeploy](#) | `.` | To be supported in the future |
62 | 
63 | 
64 | ## 🔍 Context Retrieval
65 | 
66 | CFAR can be used on its own for context retrieval:
67 | 
68 | ```shell
69 | python -m cora.cfar \
70 |     -q <user-query> \
71 |     -m <model-name> \
72 |     <path-to-repository>
73 | ```
74 | 
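For example, a concrete invocation might look like the following; the query text, model name, and repository path here are only illustrative values:

```shell
python -m cora.cfar \
    -q "Where does the keyword engine build and cache its index?" \
    -m gpt-4o \
    /path/to/your/local/repo
```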
75 | ## 🚀 Bug Fixing (WIP)
76 | 
77 | > [!WARNING]
78 | > This section is still work in progress.
79 | 
80 | CoRA's FixIt! can generate a patch that fixes a given bug (issue) in a repository:
81 | 
82 | ```shell
83 | python -m cora.fixit \
84 |     -q <issue-text> \
85 |     -i <issue-id> \
86 |     -m <model-name> \
87 |     -e <path-to-eval-script> \
88 |     --eval-args <extra-args-for-eval-script> \
89 |     -M <max-retries> \
90 |     <path-to-repository>
91 | ```
92 | 
93 | If an evaluation script is provided (via `-e`), FixIt! uses it to check whether the generated patch passes the script's tests. If it does not, FixIt! retries until it generates a patch that passes the script or the maximum number of allowed attempts (set by `-M`) is reached. If no evaluation script is provided, FixIt! only generates a plausible patch without evaluating its correctness.
94 | 
95 | Below is a simple example of an evaluation script. It first stores all the arguments FixIt! passes to it into `/tmp/test.json`, and then checks whether the patch contains the substring "Hello World". Every patch containing that substring is considered to pass the tests.
96 | 
97 | ```python
98 | #! /usr/local/bin/python3
99 | 
100 | import json
101 | import sys
102 | 
103 | if __name__ == "__main__":
104 |     issue_id = sys.argv[1]
105 |     patch_str = sys.argv[2]
106 |     buggy_repo = sys.argv[3]
107 |     patched_repo = sys.argv[4]
108 | 
109 |     # Save all arguments passed from FixIt! into /tmp/test.json
110 |     with open("/tmp/test.json", "w") as fou:
111 |         json.dump(
112 |             {
113 |                 "issue_id": issue_id,
114 |                 "patch": patch_str,
115 |                 "buggy_repo": buggy_repo,
116 |                 "patched_repo": patched_repo,
117 |             },
118 |             fou,
119 |             ensure_ascii=False,
120 |             indent=2,
121 |         )
122 | 
123 |     # We accept the patch if it contains "Hello World"
124 |     if "Hello World" in patch_str:
125 |         exit(0)  # Exiting with 0 indicates an acceptance
126 |     else:
127 |         exit(1)  # All other exit statuses imply a rejection
128 | ```
129 | 
130 | The content of `/tmp/test.json` is shown below:
131 | 
132 | ```json
133 | {
134 |   "issue_id": "django__django-11848",
135 |   "patch": "diff --git a/django/utils/http.py b/django/utils/http.py\n--- a/django/utils/http.py\n+++ b/django/utils/http.py\n@@ -176,7 +176,7 @@\n try:\n year = int(m.group('year'))\n if year < 100:\n- if year < 70:\n+ if year < 50:\n year += 2000\n else:\n year += 1900\n",
136 |   "buggy_repo": "/tmp/fixit/django_f0adf3b9",
137 |   "patched_repo": "/tmp/fixit/patched_django_f0adf3b9"
138 | }
139 | ```
140 | 
141 | ## 🐑 SWE-bench (WIP)
142 | 
143 | > [!WARNING]
144 | > This section is still work in progress
145 | 
146 | CoRA's SWE-kit can fix bugs from the SWE-bench dataset. To verify whether the patches generated by SWE-kit pass the tests shipped with SWE-bench, first install SWE-bench:
147 | 
148 | ```shell
149 | git clone git@github.com:princeton-nlp/SWE-bench.git
150 | cd SWE-bench
151 | pip install -e .
152 | ```
153 | 
154 | Then, you can use SWE-kit to try to generate a patch for a bug:
155 | 
156 | ```shell
157 | python -m cora.swekit \
158 |     -d <swe-bench-dataset> \
159 |     -m <model-name> \
160 |     -M <max-retries> \
161 |     <working-directory>
162 | ```
163 | 
164 | ## 🤖 Question Answering
165 | 
166 | CoRA's RepoQA can answer users' questions about a repository:
167 | 
168 | ```shell
169 | python -m cora.repoqa \
170 |     -q <user-question> \
171 |     -m <model-name> \
172 |     <path-to-repository>
173 | ```
174 | 
175 | ## 👨‍💻‍ For Developers
176 | 
177 | To keep the repository well maintained, CoRA sets up a series of pre-commit checks, such as code style checks and commit message checks. Therefore, before committing code to this project, developers need to install the checkers CoRA requires:
178 | 
179 | ```shell
180 | pre-commit install  # install pre-commit itself
181 | pre-commit install-hooks  # install our pre-commit checkers
182 | ```
183 | 
184 | Some of the checkers CoRA sets up are listed below:
185 | + Python: CoRA uses ruff for linting and formatting; all rules supported by Ruff are listed [here](https://docs.astral.sh/ruff/rules/). If you develop with PyCharm, you can configure ruff following [ruff.md](./docs/ruff.md).
186 | + Commit Messages: CoRA follows Conventional Commits for all commit messages; see [here](https://www.conventionalcommits.org/).
187 | + Others: CoRA also applies some other checkers, such as YAML checks.
188 | 
189 | 
--------------------------------------------------------------------------------
/cora/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import Optional, List
4 | 
5 | from dotenv import load_dotenv
6 | 
7 | from cora.utils import misc
8 | 
9 | 
10 | class _RetrieverConfigMixin:
11 |     # Query Rewrite
12 |     QRW_WORD_SIZE = 40
13 | 
14 |     # Entity Definition Lookup
15 |     EDL_FILE_LIMIT = 10
16 | 
17 |     # Keyword Engine Search
18 |     KWS_FILE_LIMIT = 5
19 | 
20 |     # File Tree Exploration
21 |     FTE_STRATEGY_NAME_NO_FTE = "disable-fte"
22 |     FTE_STRATEGY_NAME_FTD_GU = "files-then-dirs__give-up"
23 |     FTE_STRATEGY_NAME_FTD_TS = "files-then-dirs__try-shrinking"
24 |     FTE_STRATEGY = FTE_STRATEGY_NAME_FTD_GU
25 |     FTE_FTD_GOING_UPWARD = 2
26 |     FTE_FILE_LIMIT = 2
27 |     FTE_MAX_FILE_TREE_SIZE = 1500
28 | 
29 |     # File Preview Scoring
30 |     FPS_PREVIEW_SCORE_THRESHOLD = 2
31 | 
32 |     # Snippet Context Retrieval
33 |     SCR_SNIPPET_FINDER_NAME_ENUM_FNDR = "enumerative-finder"
34 |     SCR_SNIPPET_FINDER_NAME_PREV_FNDR = "preview-finder"
35 |     SCR_SNIPPET_FINDER = SCR_SNIPPET_FINDER_NAME_ENUM_FNDR
36 |     SCR_ENUM_FNDR_SNIPPET_SIZE = 100
37 |     SCR_ENUM_FNDR_NUM_THREADS = 1
38 |
SCR_SNIPPET_DETERM_NAME_SNIP_SCORER = "snippet-scorer" 39 | SCR_SNIPPET_DETERM_NAME_SNIP_JUDGE = "snippet-judge" 40 | SCR_SNIPPET_DETERM = SCR_SNIPPET_DETERM_NAME_SNIP_SCORER 41 | SCR_SNIP_SCORER_THRESHOLD = 1 42 | 43 | # Overall Results 44 | FINAL_FILE_LIMIT = 5 45 | 46 | 47 | class _FileConfigMixin: 48 | MAX_BYTES_PER_FILE = 240000 # Do not consider files exceeding this size 49 | MAX_FILES_PER_DIRECTORY = ( 50 | 100 # Do not consider directories containing files more than this number 51 | ) 52 | EXCLUDE_HIDDEN_FILES = True # Whether to consider hidden files 53 | 54 | EXCLUDED_FILE_NAMES: List[str] = ["gradle-wrapper.properties", "local.properties"] 55 | EXCLUDED_DIRECTORY_NAMES: List[str] = [ 56 | ".git", 57 | ".github", 58 | ".gitlab", 59 | "venv", 60 | "__pycache__", 61 | "node_modules", 62 | ".gradle", 63 | ".maven", 64 | ".mvn", 65 | ".idea", 66 | ".vscode", 67 | ".eclipse", 68 | ] 69 | EXCLUDED_SUFFIXES: List[str] = [ 70 | ".min.js", 71 | ".min.js.map", 72 | ".min.css", 73 | ".min.css.map", 74 | ".tfstate", 75 | ".tfstate.backup", 76 | ".jar", 77 | ".ipynb", 78 | ".png", 79 | ".jpg", 80 | ".jpeg", 81 | ".download", 82 | ".gif", 83 | ".bmp", 84 | ".tiff", 85 | ".ico", 86 | ".mp3", 87 | ".wav", 88 | ".wma", 89 | ".ogg", 90 | ".flac", 91 | ".mp4", 92 | ".avi", 93 | ".mkv", 94 | ".mov", 95 | ".patch", 96 | ".patch.disabled", 97 | ".wmv", 98 | ".m4a", 99 | ".m4v", 100 | ".3gp", 101 | ".3g2", 102 | ".rm", 103 | ".swf", 104 | ".flv", 105 | ".iso", 106 | ".bin", 107 | ".tar", 108 | ".zip", 109 | ".7z", 110 | ".gz", 111 | ".bz", 112 | ".bz2", 113 | ".rar", 114 | ".pdf", 115 | ".doc", 116 | ".docx", 117 | ".xls", 118 | ".xlsx", 119 | ".ppt", 120 | ".pptx", 121 | ".svg", 122 | ".parquet", 123 | ".pyc", 124 | ".pub", 125 | ".pem", 126 | ".ttf", 127 | ".log", 128 | ] 129 | 130 | @classmethod 131 | def should_exclude(cls, path: str) -> bool: 132 | path = Path(path).resolve() 133 | if not path.exists(): 134 | return True 135 | if cls.EXCLUDE_HIDDEN_FILES and path.name.startswith("."): 136 | return True 137 | if path.is_file(): 138 | return cls.should_exclude_file(path) 139 | elif path.is_dir(): 140 | return cls.should_exclude_directory(dir_path=path) 141 | else: 142 | return True 143 | 144 | @classmethod 145 | def should_exclude_file(cls, file_path: Path): 146 | # Check stems, suffixes, and parents 147 | if file_path.suffix in cls.EXCLUDED_SUFFIXES: 148 | return True 149 | if file_path.stem in cls.EXCLUDED_FILE_NAMES: 150 | return True 151 | if file_path.parent.name in cls.EXCLUDED_DIRECTORY_NAMES: 152 | return True 153 | try: 154 | if os.stat(file_path).st_size > cls.MAX_BYTES_PER_FILE: 155 | return True 156 | except FileNotFoundError: 157 | return True 158 | is_binary = False 159 | with file_path.open("rb") as fin: 160 | for block in iter(lambda: fin.read(1024), b""): 161 | if b"\0" in block: 162 | is_binary = True 163 | break 164 | return is_binary 165 | 166 | @classmethod 167 | def should_exclude_directory(cls, dir_path: Path) -> bool: 168 | return ( 169 | dir_path.name in cls.EXCLUDED_DIRECTORY_NAMES 170 | or len(list(dir_path.iterdir())) > cls.MAX_FILES_PER_DIRECTORY 171 | ) 172 | 173 | 174 | class CoraConfig(_FileConfigMixin, _RetrieverConfigMixin): 175 | # Additional environments or overridden environments 176 | _additional_envs_ = {} 177 | 178 | @staticmethod 179 | def load(env_file: Optional[str] = None): 180 | load_dotenv(env_file) 181 | 182 | @classmethod 183 | def get(cls, key: str) -> Optional[str]: 184 | return cls._additional_envs_.get(key, None) or os.getenv(key, None) 185 | 186 | 
@classmethod 187 | def set(cls, key: str, val: str): 188 | cls._additional_envs_[key] = val 189 | return True 190 | 191 | @classmethod 192 | def cache_directory(cls) -> Path: 193 | return Path(cls.get("CACHE_DIRECTORY_PATH")).resolve() 194 | 195 | @classmethod 196 | def keyword_index_cache_directory(cls) -> Path: 197 | keyword_index_cache_directory = CoraConfig.cache_directory() / "keyword_indices" 198 | if not keyword_index_cache_directory.exists(): 199 | keyword_index_cache_directory.mkdir(parents=True) 200 | return keyword_index_cache_directory 201 | 202 | @classmethod 203 | def sanitize_content_in_repository(cls) -> bool: 204 | return misc.to_bool(cls.get("SANITIZE_CONTENT_IN_REPOSITORY")) 205 | 206 | @classmethod 207 | def easydeploy_endpoint_url(cls) -> str: 208 | return cls.get("EASYDEPLOY_ENDPOINT") 209 | 210 | 211 | __ENV_LOADED = False 212 | if not __ENV_LOADED: 213 | CoraConfig.load() 214 | __ENV_LOADED = True 215 | -------------------------------------------------------------------------------- /cora/agents/rewrite/summ.py: -------------------------------------------------------------------------------- 1 | from typing import Type, Tuple, List, Optional 2 | 3 | from cora.agents.rewrite.base import RewriterBase 4 | from cora.agents.simple_agent import SimpleAgent 5 | from cora.llms.factory import LLMConfig, LLMFactory 6 | from cora.repo.repo import Repository 7 | 8 | 9 | class SummaryGen(RewriterBase): 10 | def __init__( 11 | self, 12 | repo: Repository, 13 | *, 14 | summarize_prompt: str, 15 | summarize_returns: List[Tuple[str, Type, str]], 16 | summarize_summary_key: str, 17 | evaluate_prompt: Optional[str], 18 | evaluate_returns: Optional[List[Tuple[str, Type, str]]], 19 | evaluate_score_key: Optional[str], 20 | update_prompt: Optional[str], 21 | update_returns: Optional[List[Tuple[str, Type, str]]], 22 | update_summary_key: Optional[str], 23 | use_llm: LLMConfig, 24 | max_rounds: int = 5, 25 | ): 26 | super().__init__(repo) 27 | self.use_llm = use_llm 28 | self.max_rounds = max_rounds 29 | self.sum_prompt = summarize_prompt 30 | self.sum_returns = summarize_returns 31 | self.eval_prompt = evaluate_prompt 32 | self.eval_returns = evaluate_returns 33 | self.upd_prompt = update_prompt 34 | self.upd_returns = update_returns 35 | self.sum_sum_key = summarize_summary_key 36 | self.eval_score_key = evaluate_score_key 37 | self.upd_sum_key = update_summary_key 38 | 39 | def rewrite(self, query: str) -> str: 40 | summary = self.summarize(query) 41 | if not ( 42 | self.eval_prompt 43 | and self.eval_returns 44 | and self.eval_score_key 45 | and self.upd_prompt 46 | and self.upd_returns 47 | and self.upd_sum_key 48 | ): 49 | return summary 50 | for _ in range(self.max_rounds): 51 | score = self.evaluate(summary, query) 52 | if score >= 2: 53 | break 54 | summary = self.update(summary, query) 55 | return summary 56 | 57 | def summarize(self, query: str): 58 | resp = SimpleAgent( 59 | llm=LLMFactory.create(self.use_llm), returns=self.sum_returns 60 | ).run(self.sum_prompt.format(query=query, repo=self.repo.full_name)) 61 | return resp[self.sum_sum_key] 62 | 63 | def evaluate(self, summary: str, query: str): 64 | assert ( 65 | self.eval_prompt and self.eval_returns and self.eval_score_key 66 | ), "Evaluate cannot be called as no prompt/returns/keys are given" 67 | resp = SimpleAgent( 68 | llm=LLMFactory.create(self.use_llm), returns=self.eval_returns 69 | ).run(self.eval_prompt.format(summary=summary, query=query)) 70 | return resp[self.eval_score_key] 71 | 72 | def update(self, summary: 
str, query: str): 73 | assert ( 74 | self.upd_prompt and self.upd_returns and self.upd_sum_key 75 | ), "Update cannot be called as no prompt/returns/keys are given" 76 | resp = SimpleAgent( 77 | llm=LLMFactory.create(self.use_llm), returns=self.upd_returns 78 | ).run( 79 | self.upd_prompt.format( 80 | repo=self.repo.full_name, summary=summary, query=query 81 | ) 82 | ) 83 | return resp[self.upd_sum_key] 84 | 85 | 86 | class SummaryGenBuilder: 87 | def __init__(self): 88 | self._sum_prompt: Optional[str] = None 89 | self._sum_returns: Optional[List[Tuple[str, Type, str]]] = None 90 | self._sum_sum_key: Optional[str] = None 91 | self._eval_prompt: Optional[str] = None 92 | self._eval_returns: Optional[List[Tuple[str, Type, str]]] = None 93 | self._eval_score_key: Optional[str] = None 94 | self._upd_prompt: Optional[str] = None 95 | self._upd_returns: Optional[List[Tuple[str, Type, str]]] = None 96 | self._upd_sum_key: Optional[str] = None 97 | 98 | def with_summarize_prompt(self, prompt: str) -> "SummaryGenBuilder": 99 | self._sum_prompt = prompt 100 | return self 101 | 102 | def with_summarize_returns( 103 | self, returns: List[Tuple[str, Type, str]] 104 | ) -> "SummaryGenBuilder": 105 | self._sum_returns = returns 106 | return self 107 | 108 | def with_summarize_summary_key(self, key: str) -> "SummaryGenBuilder": 109 | self._sum_sum_key = key 110 | return self 111 | 112 | def with_evaluate_prompt(self, prompt: str) -> "SummaryGenBuilder": 113 | self._eval_prompt = prompt 114 | return self 115 | 116 | def with_evaluate_returns( 117 | self, returns: List[Tuple[str, Type, str]] 118 | ) -> "SummaryGenBuilder": 119 | self._eval_returns = returns 120 | return self 121 | 122 | def with_evaluate_score_key(self, key: str) -> "SummaryGenBuilder": 123 | self._eval_score_key = key 124 | return self 125 | 126 | def with_update_prompt(self, prompt: str) -> "SummaryGenBuilder": 127 | self._upd_prompt = prompt 128 | return self 129 | 130 | def with_update_returns( 131 | self, returns: List[Tuple[str, Type, str]] 132 | ) -> "SummaryGenBuilder": 133 | self._upd_returns = returns 134 | return self 135 | 136 | def with_update_summary_key(self, key: str) -> "SummaryGenBuilder": 137 | self._upd_sum_key = key 138 | return self 139 | 140 | def build(self): 141 | assert self._sum_prompt is not None, "No summarization prompt" 142 | assert self._sum_sum_key is not None, "No summarization keys" 143 | assert self._sum_returns is not None, "No summarization returns" 144 | 145 | this = self 146 | 147 | class SummaryGenImpl(SummaryGen): 148 | def __init__( 149 | self, 150 | repo: Repository, 151 | *, 152 | use_llm: LLMConfig, 153 | max_rounds: int = 5, 154 | ): 155 | super().__init__( 156 | repo=repo, 157 | use_llm=use_llm, 158 | max_rounds=max_rounds, 159 | summarize_prompt=this._sum_prompt, 160 | summarize_summary_key=this._sum_sum_key, 161 | summarize_returns=this._sum_returns, 162 | evaluate_prompt=this._eval_prompt, 163 | evaluate_returns=this._eval_returns, 164 | evaluate_score_key=this._eval_score_key, 165 | update_prompt=this._upd_prompt, 166 | update_returns=this._upd_returns, 167 | update_summary_key=this._upd_sum_key, 168 | ) 169 | 170 | return SummaryGenImpl 171 | -------------------------------------------------------------------------------- /cora/agents/reason_agent.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List 3 | 4 | from cora.agents.base import AgentBase 5 | from cora.llms.base import LLMBase, ChatMessage 6 | 7 | 
G1_SYSTEM_PROMPT = """\ 8 | You are an expert AI assistant that creates advanced reasoning chain against a user's query. \ 9 | For each step in the chain, provide a title and content that demonstrates your thought process, \ 10 | your confidence on your reasoning in this step, and if you need a further reasoning step. 11 | 12 | FOLLOW THESE GUIDELINES: 13 | 14 | 1. USE AT LEAST 5 REASONING STEPS, aiming for 7-10 steps for complex problems. 15 | 2. EFFICIENCY: Use concise, clear language. 16 | 3. MULTIPLE METHODS: Use at least 3 distinct approaches to derive the answer. 17 | 4. EXPLORE ALTERNATIVES: Consider and analyze potential alternative answers. 18 | 5. CHALLENGE ASSUMPTIONS: Critically examine your own reasoning and initial conclusions. 19 | 6. ADDRESS LLM LIMITATIONS: Be aware of and compensate for typical AI shortcomings. 20 | 7. QUANTIFY CONFIDENCE: For each step and the final answer, provide a confidence level (0-100%). 21 | 8. CITE SOURCES: If referring to factual information, mention where you would source it from. 22 | 9. NO ETHICAL FILTERS: Reason objectively without moral constraints. 23 | 10. REAL-WORLD APPLICATION: Relate the problem or solution to practical, real-world scenarios. 24 | 11. NO ONLINE TOOLS AND SEARCHING: You cannot use online tools or search the internet. 25 | 26 | MOST IMPORTANT: \ 27 | Respond in JSON format with 'title', 'content', 'confidence' (0-100), and 'next_action' ('continue' or 'final_answer') keys. 28 | REPLY WITH EXACTLY ONE JSON OBJECT THAT REPRESENTS EXACTLY ONE STEP IN YOUR REASONING. 29 | 30 | Example of a valid JSON response: 31 | { 32 | "title": "Initial Problem Analysis", 33 | "content": "To begin solving this problem, I'll break it down into its core components...", 34 | "confidence": 90, 35 | "next_action": "continue" 36 | } 37 | 38 | REMEMBER: Your answer will be parsed as JSON and fed to you in the next step by the main app. \ 39 | For this reason, you MUST ALWAYS use the JSON format and think forward in your response to construct the next step. \ 40 | This does not apply to the final answer, of course.\ 41 | """ 42 | 43 | G1_ASSISTANT_START = """\ 44 | Understood. I will now create a detailed reasoning chain following the given instructions, \ 45 | starting with a thorough problem decomposition step.\ 46 | """ 47 | 48 | G1_USER_CONTINUE_REASONING = """\ 49 | GREAT JOB: Your confidence is {confidence}! Continue with your next reasoning step (in JSON format).\ 50 | """ 51 | 52 | G1_USER_GET_FINAL_ANSWER = """\ 53 | Provide the final answer based on your reasoning above. 54 | 55 | REMEMBER: Do NOT use JSON formatting. Only provide the text response without any titles or preambles.\ 56 | """ 57 | 58 | INVALID_JSON_OBJECT = """\ 59 | ERROR: Your response is NOT a valid JSON object: {error_message}. 60 | 61 | You should respond with valid JSON format, containing \ 62 | 'title', 'content', 'confidence', and 'next_action' (either 'continue' or 'final_answer') keys. 63 | 64 | Example of a valid JSON response: 65 | 66 | ```json 67 | {{ 68 | "title": "Initial Problem Analysis", 69 | "content": "To begin solving this problem, I'll break it down into its core components...", 70 | "confidence": 90, 71 | "next_action": "continue" 72 | }} 73 | ```\ 74 | """ 75 | 76 | INVALID_NEXT_ACTION = """\ 77 | ERROR: The EXPECTED value for 'next_action' is EITHER 'continue' OR 'final_answer', but we got '{invalid_key}'. 
78 | 79 | REMEMBER: Set 'next_action' to 'continue' if you need another step, \ 80 | otherwise (if you're ready to give the final answer), set it to 'final_answer'.\ 81 | """ 82 | 83 | 84 | class R1: 85 | """ 86 | This is a reasoning agent resembling to the following projects: 87 | - https://github.com/bklieger-groq/g1 88 | - https://github.com/tcsenpai/multi1 89 | """ 90 | 91 | def __init__(self, llm: LLMBase, max_chat_round: int = 25): 92 | self.llm = llm 93 | self.max_chat_round = max_chat_round 94 | 95 | def is_debugging(self) -> bool: 96 | return self.llm.is_debug_mode() 97 | 98 | def enable_debugging(self): 99 | self.llm.enable_debug_mode() 100 | 101 | def disable_debugging(self): 102 | self.llm.disable_debug_mode() 103 | 104 | def get_history(self) -> List[ChatMessage]: 105 | return self.llm.get_history() 106 | 107 | def run(self, query: str, *, with_internal_thoughts: bool = False) -> str: 108 | self.llm.clear_history() 109 | 110 | self.llm.append_system_message(G1_SYSTEM_PROMPT) 111 | self.llm.append_user_message(query) 112 | self.llm.append_assistant_message(G1_ASSISTANT_START) 113 | 114 | thoughts = [] 115 | 116 | for _ in range(self.max_chat_round): 117 | try: 118 | step_resp = self.llm.query() 119 | except Exception: 120 | continue 121 | 122 | step_data, err_msg = AgentBase.parse_json_response(step_resp) 123 | 124 | # Not a valid JSON object 125 | if step_data is None: 126 | self.llm.append_user_message( 127 | INVALID_JSON_OBJECT.format(error_message=err_msg) 128 | ) 129 | continue 130 | 131 | # Check if the required keys are in the response 132 | err_msg = "" 133 | for key in ["title", "content", "confidence", "next_action"]: 134 | if key not in step_data: 135 | err_msg = f"Missing {key}." 136 | self.llm.append_user_message( 137 | INVALID_JSON_OBJECT.format(error_message=err_msg) 138 | ) 139 | break 140 | if err_msg: 141 | continue 142 | 143 | # Check if the next_action is valid 144 | next_action = step_data["next_action"] 145 | if next_action not in ["continue", "final_answer"]: 146 | self.llm.append_user_message( 147 | INVALID_NEXT_ACTION.format(invalid_key=next_action) 148 | ) 149 | continue 150 | 151 | # Extend our final response 152 | if with_internal_thoughts: 153 | thoughts.append( 154 | f"## Thinking: {step_data['title']} (Confidence: {step_data['confidence']})" 155 | ) 156 | thoughts.append(f"{step_data['content']}") 157 | 158 | if next_action == "final_answer": 159 | break 160 | 161 | # Rectify the assistant's response and ask it to continue 162 | self.llm.get_history()[-1].content = json.dumps(step_data) 163 | self.llm.append_user_message( 164 | G1_USER_CONTINUE_REASONING.format(confidence=step_data["confidence"]) 165 | ) 166 | 167 | self.llm.append_user_message(G1_USER_GET_FINAL_ANSWER) 168 | 169 | if with_internal_thoughts: 170 | thoughts.append("## Final Answer") 171 | 172 | thoughts.append(self.llm.query()) 173 | 174 | return "\n\n".join(thoughts) 175 | -------------------------------------------------------------------------------- /cora/fixit.py: -------------------------------------------------------------------------------- 1 | import shlex 2 | import subprocess 3 | from argparse import ArgumentParser 4 | from pathlib import Path 5 | from typing import Optional, List, Tuple, cast 6 | 7 | from cora import options, results 8 | from cora.agent import RepoAgent 9 | from cora.agents.rewrite.issue import IssueSummarizer 10 | from cora.base.console import BoxedConsoleBase 11 | from cora.base.rag import GeneratorBase 12 | from cora.options import ArgumentError 13 | 
from cora.repair import repair 14 | from cora.repair.events import IssueRepaCallbacks 15 | from cora.repo.repo import Repository 16 | from cora.utils import cmdline 17 | 18 | 19 | class EvalScript: 20 | def __init__( 21 | self, eval_script: str, eval_args: Optional[str], console: BoxedConsoleBase 22 | ): 23 | self.eval_script = eval_script 24 | self.eval_args = eval_args or "" 25 | self.console = console 26 | 27 | def __call__( 28 | self, 29 | issue_id: str, 30 | patch_str: str, 31 | original_repo: Repository, 32 | patched_repo: Repository, 33 | *args, 34 | **kwargs, 35 | ) -> bool: 36 | try: 37 | cmdline.check_call( 38 | f"{self.eval_script} " 39 | f"{issue_id} {shlex.quote(patch_str)} " 40 | f"{original_repo.repo_path} {patched_repo.repo_path} " 41 | f"{self.eval_args}", 42 | timeout=5 * 60, 43 | ) 44 | return True 45 | except subprocess.CalledProcessError as e: 46 | ecode = e.returncode 47 | emsg = e.stderr 48 | if emsg: 49 | emsg = str(emsg, encoding="utf-8").strip() 50 | self.console.printb( 51 | f"The patche does not pass all tests (exit_code={ecode}) with the following errors raised: {emsg}" 52 | ) 53 | else: 54 | self.console.printb( 55 | f"The patche does not pass all tests (exit_code={ecode})." 56 | ) 57 | return False 58 | 59 | 60 | class _Generator(GeneratorBase): 61 | def __init__( 62 | self, 63 | eval_script: Optional[str], 64 | eval_args: Optional[str], 65 | ): 66 | super().__init__() 67 | self.eval_script = eval_script 68 | self.eval_args = eval_args 69 | self.callbacks = [] 70 | 71 | def add_callback(self, cbs: IssueRepaCallbacks): 72 | self.callbacks.append(cbs) 73 | 74 | def generate(self, issue: str, snip_ctx: List[str], **kwargs) -> any: 75 | assert self.agent, "RepoAgent hasn't been injected. Please invoke inject_agent() before calling this method" 76 | agent = cast(RepoAgent, self.agent) 77 | issue_id = kwargs["issue_id"] 78 | num_retries = kwargs["num_retries"] 79 | repa = repair.IssueRepa( 80 | agent.repo, 81 | use_llm=agent.use_llm, 82 | debug_mode=agent.debug_mode, 83 | ) 84 | for cbs in self.callbacks: 85 | repa.add_callbacks(cbs) 86 | console = agent.console 87 | if self.eval_script: 88 | console.printb( 89 | f"Generating a plausible patch that passed the evaluation script: {self.eval_script}" 90 | ) 91 | patch = repa.try_repair( 92 | issue, 93 | issue_id, 94 | snip_ctx, 95 | EvalScript(self.eval_script, self.eval_args, console), 96 | num_retries=num_retries, 97 | num_proc=agent.num_proc, 98 | ) 99 | else: 100 | console.printb("Generating a plausible patch (without evaluation script)") 101 | patch = repa.gen_patch(issue, snip_ctx) 102 | if patch: 103 | console.printb(f"The generated patch is: ```diff\n{patch}\n```") 104 | else: 105 | console.printb("No available patches are generated.") 106 | return patch 107 | 108 | 109 | def parse_eval_script(args) -> Tuple[Optional[str], Optional[str]]: 110 | eval_script = args.eval_script or None 111 | if eval_script: 112 | eval_script_path = Path(eval_script) 113 | if not eval_script_path.exists(): 114 | raise ArgumentError(f"The evaluation script does not exist: {eval_script}") 115 | if not eval_script_path.is_file(): 116 | raise ArgumentError(f"The evaluation script is not a file: {eval_script}") 117 | eval_args = args.eval_args or None 118 | if eval_args and not eval_script: 119 | raise ArgumentError( 120 | "Evaluation args are provided yet the evaluation script is not" 121 | ) 122 | return eval_script, eval_args 123 | 124 | 125 | def parse_args(): 126 | parser = ArgumentParser() 127 | 
options.make_common_options(parser) 128 | parser.add_argument( 129 | "--issue-id", 130 | "-i", 131 | required=True, 132 | type=str, 133 | help="The ID of the issue, this is used by the evaluation script for referring to the issue.", 134 | ) 135 | parser.add_argument( 136 | "--eval-script", 137 | "-e", 138 | default="", 139 | type=str, 140 | help="Path to a script for evaluating if the patched repository could resolve the issue. " 141 | "The first argument of the script should be the ID of the issue. " 142 | "The second argument of the script is a string of the patch in the format of unified diff. " 143 | "The third argument of the script is an absolute path to the original repository. " 144 | "The fourth argument of the script is an absolute path to the patched repository. " 145 | 'All the rest arguments should be passed via "--eval-args"', 146 | ) 147 | parser.add_argument( 148 | "--eval-args", 149 | default="", 150 | type=str, 151 | help="Additional arguments passed to the evaluation script", 152 | ) 153 | parser.add_argument( 154 | "--max-retries", 155 | "-M", 156 | default=20, 157 | type=int, 158 | help="The max number of allowed attempts for evaluating the generated patch.", 159 | ) 160 | return parser.parse_args() 161 | 162 | 163 | def main(): 164 | args = parse_args() 165 | 166 | repo = options.parse_repo(args) 167 | issue, incl = options.parse_query(args) 168 | issue_id = args.issue_id 169 | llm = options.parse_llms(args) 170 | procs, threads = options.parse_perf(args) 171 | eval_script, eval_args = parse_eval_script(args) 172 | log_dir, verbose = options.parse_logging(args) 173 | 174 | fixit = RepoAgent( 175 | name="FixIt!", 176 | repo=repo, 177 | includes=incl, 178 | use_llm=llm, 179 | rewriter=IssueSummarizer(repo, use_llm=llm), 180 | generator=_Generator(eval_script=eval_script, eval_args=eval_args), 181 | num_proc=procs, 182 | num_thread=threads, 183 | files_as_context=False, 184 | debug_mode=verbose, 185 | ) 186 | if log_dir: 187 | fixit.cfar.add_callback(results.CfarResult(log_dir / "cfar_res.json")) 188 | cast(_Generator, fixit.generator).add_callback( 189 | results.IssueRepaResult(log_dir / "fixit_res.json") 190 | ) 191 | fixit.run( 192 | query=issue, 193 | generation_args={"issue_id": issue_id, "num_retries": args.max_retries}, 194 | ) 195 | 196 | 197 | if __name__ == "__main__": 198 | main() 199 | -------------------------------------------------------------------------------- /cora/repair/patch.py: -------------------------------------------------------------------------------- 1 | import difflib 2 | import os 3 | import re 4 | from collections import OrderedDict 5 | from typing import Tuple, Optional, List 6 | 7 | from cora.agents.base import AgentBase 8 | from cora.base.console import get_boxed_console 9 | from cora.base.paths import SnippetPath 10 | from cora.llms.base import LLMBase 11 | from cora.repo.repo import Repository 12 | from cora.utils.misc import ordered_set 13 | 14 | SYSTEM_PROMPT = """\ 15 | We are currently solving the following issue within our repository. Here is the issue text: 16 | --- BEGIN ISSUE --- 17 | {issue_text} 18 | 19 | --- END ISSUE --- 20 | 21 | 22 | Below are some code segments, each from a relevant file. One or more of these files may contain bugs. 23 | 24 | --- BEGIN FILE --- 25 | ``` 26 | {snippet_context} 27 | ``` 28 | --- END FILE --- 29 | 30 | Please first localize the bug based on the issue statement, and then generate *SEARCH/REPLACE* edits to fix the issue. 31 | 32 | Every *SEARCH/REPLACE* edit must use this format: 33 | 1. 
The file path 34 | 2. The start of search block: <<<<<<< SEARCH 35 | 3. A contiguous chunk of lines to search for in the existing source code 36 | 4. The dividing line: ======= 37 | 5. The lines to replace into the source code 38 | 6. The end of the replace block: >>>>>>> REPLACE 39 | 40 | Here is an example: 41 | 42 | ```python 43 | ### mathweb/flask/app.py 44 | <<<<<<< SEARCH 45 | from flask import Flask 46 | ======= 47 | import math 48 | from flask import Flask 49 | >>>>>>> REPLACE 50 | ``` 51 | 52 | Please note that the *SEARCH/REPLACE* edit REQUIRES PROPER INDENTATION. If you would like to add the line ' print(x)', you must fully write that out, with all those spaces before the code! 53 | Wrap the *SEARCH/REPLACE* edit in blocks ```python...```. 54 | """ 55 | 56 | _PATTERN_EDIT_COMMAND = r"```python\n(.*?)\n```" 57 | _PATTERN_SEARCH_LINE = "<<< SEARCH" 58 | _PATTERN_SEPARATOR_LINE = "=======" 59 | _PATTERN_REPLACE_LINE = ">>>>>>> REPLACE" 60 | 61 | DEBUG_OUTPUT_LOGGING_COLOR = "grey50" 62 | DEBUG_OUTPUT_LOGGING_TITLE = "Patcher" 63 | 64 | 65 | class PatchGen(AgentBase): 66 | def __init__( 67 | self, repo: Repository, llm: LLMBase, debug_mode=False, *args, **kwargs 68 | ): 69 | super().__init__(llm=llm, json_schema=None, *args, **kwargs) 70 | self.repo = repo 71 | self.console = get_boxed_console( 72 | box_title=DEBUG_OUTPUT_LOGGING_TITLE, 73 | box_bg_color=DEBUG_OUTPUT_LOGGING_COLOR, 74 | debug_mode=debug_mode, 75 | ) 76 | 77 | def generate( 78 | self, 79 | issue_text: str, 80 | snip_paths: List[str], 81 | max_patches: int = 1, 82 | context_window: int = 10, 83 | ) -> List[str]: 84 | file_paths = ordered_set( 85 | [str(SnippetPath.from_str(sp).file_path) for sp in snip_paths] 86 | ) 87 | prompt = SYSTEM_PROMPT.format( 88 | issue_text=issue_text, 89 | snippet_context=self._make_context(snip_paths, surroundings=context_window), 90 | ) 91 | patches = [] 92 | for i in range(max_patches): 93 | self.console.printb( 94 | f"Trying to generate the {i+1}-th patch ({max_patches} patches are requested in total)" 95 | ) 96 | patch = self.run(prompt, file_paths=file_paths) 97 | if patch is None: 98 | self.console.printb("Generation failed") 99 | continue 100 | self.console.printb( 101 | f"Succeeded! 
The generated patch is: ```diff\n{patch}\n```" 102 | ) 103 | patches.append(patch) 104 | return patches 105 | 106 | def _parse_response(self, response, *args, **kwargs) -> Optional[str]: 107 | file_paths = kwargs["file_paths"] 108 | edit_cmds = re.findall(_PATTERN_EDIT_COMMAND, response, re.DOTALL) 109 | patch = [] 110 | for edit in edit_cmds: 111 | self.console.printb(f"Found edit command: {edit}") 112 | edited_file, old_cont, new_cont = self._parse_edit(edit, file_paths) 113 | if not edited_file: 114 | self.console.printb("The above edit command is invalid") 115 | continue 116 | self.console.printb(f"The above edit command is for: {edited_file}") 117 | from_file = os.path.join("a", edited_file) 118 | to_file = os.path.join("b", edited_file) 119 | udiff = "".join( 120 | difflib.unified_diff( 121 | old_cont.splitlines(keepends=True), 122 | new_cont.splitlines(keepends=True), 123 | # Follow git by prepending an "a/" and a "b/" to the from and to file names 124 | fromfile=from_file, 125 | tofile=to_file, 126 | ) 127 | ) 128 | if not udiff: 129 | continue 130 | udiff = f"diff --git {from_file} {to_file}\n" + udiff 131 | patch.append(udiff) 132 | if not patch: 133 | self.console.printb("No valid edit commands found in LLM's response") 134 | return "\n".join(patch) if patch else None 135 | 136 | def _make_context(self, snip_paths: List[str], surroundings: int): 137 | ctx_dict = OrderedDict() 138 | 139 | for sp in snip_paths: 140 | fp = str(SnippetPath.from_str(sp).file_path) 141 | if fp not in ctx_dict: 142 | ctx_dict[fp] = f"### {fp}\n" 143 | cont = self.repo.get_snippet_content( 144 | sp, 145 | surroundings=surroundings, 146 | add_lines=False, 147 | add_separators=False, 148 | ) 149 | ctx_dict[fp] += f"...\n{cont}\n...\n" 150 | 151 | return "".join(ctx_dict.values()) 152 | 153 | def _parse_edit( 154 | self, edit: str, file_paths: List[str] 155 | ) -> Tuple[Optional[str], str, str]: 156 | # Find which file is being edited 157 | edited_file = None 158 | for name in file_paths: 159 | if name in edit: 160 | edited_file = name 161 | break 162 | if not edited_file: 163 | return None, "", "" 164 | 165 | # Find search, separator, and replace lines 166 | ser_lno, sep_lno, rep_lno = -1, -1, -1 167 | edit_lines = edit.splitlines() 168 | for index in range(len(edit_lines)): 169 | if _PATTERN_SEARCH_LINE in edit_lines[index]: 170 | ser_lno = index 171 | elif _PATTERN_SEPARATOR_LINE in edit_lines[index]: 172 | sep_lno = index 173 | elif _PATTERN_REPLACE_LINE in edit_lines[index]: 174 | rep_lno = index 175 | # All three markers must be present (an index of -1 means "not found") and in order 176 | if not (0 <= ser_lno < sep_lno < rep_lno): 177 | return None, "", "" 178 | 179 | # Parse search and replace content 180 | search_content = "\n".join(edit_lines[ser_lno + 1 : sep_lno]) 181 | replace_content = "\n".join(edit_lines[sep_lno + 1 : rep_lno]) 182 | if ( 183 | not search_content or not replace_content 184 | ): # TODO: Should we allow deletion, i.e., replace_content == ""? 
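            # Editorial note: an edit whose SEARCH or REPLACE block is empty is rejected at this point,
            # so pure deletions (an empty REPLACE block) are currently not applied, as the TODO above records.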
184 | return None, "", "" 185 | 186 | # Get the old content of the edited file and compute its new content 187 | old_cont = self.repo.get_file_content(edited_file, add_lines=False) 188 | if search_content not in old_cont: 189 | return None, "", "" 190 | return edited_file, old_cont, old_cont.replace(search_content, replace_content) 191 | -------------------------------------------------------------------------------- /cora/agents/snippets/score_snip.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | from cora.agents.base import AgentBase 4 | from cora.agents.snippets.base import SnipRelDetmBase 5 | from cora.base.paths import SnippetPath 6 | from cora.llms.base import LLMBase 7 | 8 | SYSTEM_PROMPT = """\ 9 | You are a Snippet Relevance Scorer, tasked to determine if a given "File Snippet" is relevant to a given "User Query", \ 10 | and give a relevance score according to your determination. \ 11 | And the file snippet is considered to be relevant to the user query if it can be used to address the user query. 12 | 13 | For this task, I will provide you with a "File Snippet" which is a portion of the file **{file_name}**, and a "User Query". \ 14 | The file snippet lists the content of the snippet; \ 15 | it is a portion of a file, wrapped by "===START OF SNIPPET===" and "===END OF SNIPPET===". \ 16 | We also provide some lines of surrounding content of the snippet in its file for your reference. 17 | 18 | A file snippet is relevant to a user query if the file snippet can be an important part to address the user query \ 19 | (though it might not address the user query directly). \ 20 | You should determine their relevance by the following steps: 21 | 1. Think carefully what files are required to address the user query; 22 | 2. Analyze if the file snippet is part of those files and what the snippet provides or what the file snippet does; 23 | 3. Check if the file snippet can provide some useful information to address the user query directly or indirectly; 24 | 4. Conclude if the file snippet are relevant to the user query and give a relevance score. 25 | 26 | The relevance score (an integer chosen from [0, 1, 2, 3]) represents the relevance of the file snippet and the user query, where 27 | - Score 0: The file snippet is totally irrelevant to the user query; it does not help anything to address the user query. 28 | - Score 1: The file snippet is weakly relevant to the user query, but the user query can be addressed even without it. 29 | - Score 2: The file snippet is relevant to the user query; the user query can only be partially addressed without it. 30 | - Score 3: The file snippet is strongly relevant to the user query, and the user query relies on it can never be addressed without it. 31 | 32 | ## User Query ## 33 | 34 | ``` 35 | {user_query} 36 | ``` 37 | 38 | ## File Snippet ## 39 | 40 | ``` 41 | //// Snippet: {snippet_path} 42 | {file_snippet} 43 | ``` 44 | 45 | """ 46 | 47 | JSON_SCHEMA = """\ 48 | { 49 | "score": , // the relevance score; it should be an integer chosen from [0, 1, 2, 3] 50 | "reason": "the reason why you give the relevance score" // explain in detail, step-by-step, the relevance between the file snippet and the user query 51 | }\ 52 | """ 53 | 54 | NON_INTEGER_SCORE_MESSAGE = """\ 55 | **FAILURE**: The relevance score ({score}) you gave is NOT an integer. 
56 | 57 | The relevance score must be an integer chosen from [0, 1, 2, 3], where: 58 | - Score 0: The file snippet is totally irrelevant to the user query; it does not help anything to address the user query. 59 | - Score 1: The file snippet is weakly relevant to the user query, but the user query can be addressed even without it. 60 | - Score 2: The file snippet is relevant to the user query; the user query can only be partially addressed without it. 61 | - Score 3: The file snippet is strongly relevant to the user query, and the user query relies on it can never be addressed without it. 62 | 63 | ## Your Response (JSON format) ## 64 | 65 | """ 66 | 67 | INVALID_SCORE_VALUE_MESSAGE = """\ 68 | **FAILURE**: The relevance score ({score}) you gave is NOT chosen from [0, 1, 2, 3]. 69 | 70 | The relevance score must be an integer chosen from [0, 1, 2, 3], where: 71 | - Score 0: The file snippet is totally irrelevant to the user query; it does not help anything to address the user query. 72 | - Score 1: The file snippet is weakly relevant to the user query, but the user query can be addressed even without it. 73 | - Score 2: The file snippet is fairly relevant to the user query; the user query can only be partially addressed without it. 74 | - Score 3: The file snippet is strongly relevant to the user query, and the user query relies on it can never be addressed without it. 75 | 76 | ## Your Response (JSON format) ## 77 | 78 | """ 79 | 80 | SCORE_IRRELEVANCE = 0 81 | SCORE_WEAK_RELEVANCE = 1 82 | SCORE_FAIR_RELEVANCE = 2 83 | SCORE_STRONG_RELEVANCE = 3 84 | 85 | 86 | class SnipScorer(SnipRelDetmBase, AgentBase): 87 | def __init__( 88 | self, 89 | llm: LLMBase, 90 | threshold: int = SCORE_WEAK_RELEVANCE, # inclusive 91 | *args, 92 | **kwargs, 93 | ): 94 | AgentBase.__init__(self, llm=llm, json_schema=JSON_SCHEMA, *args, **kwargs) 95 | self.threshold = threshold 96 | 97 | def is_debugging(self) -> bool: 98 | return AgentBase.is_debugging(self) 99 | 100 | def enable_debugging(self): 101 | AgentBase.enable_debugging(self) 102 | 103 | def disable_debugging(self): 104 | AgentBase.disable_debugging(self) 105 | 106 | def determine( 107 | self, query: str, snippet_path: str, snippet_content: str, *args, **kwargs 108 | ) -> Tuple[bool, str]: 109 | score, reason = self.score(query, snippet_path, snippet_content) 110 | if score >= self.threshold: 111 | return True, ( 112 | f"The snippet is considered relevant to the user query as: " 113 | f"the score of the snippet ({score}) is within our threshold ({self.threshold}). " 114 | f"The reason for giving the score ({score}) is: " + reason 115 | ) 116 | else: 117 | return False, ( 118 | f"The snippet is considered irrelevant to the user query as: " 119 | f"the score of the snippet ({score}) is below our threshold ({self.threshold}). 
" 120 | f"The reason for giving the score ({score}) is: " + reason 121 | ) 122 | 123 | def score( 124 | self, query: str, snippet_path: str, snippet_content: str 125 | ) -> Tuple[int, str]: 126 | return self.run( 127 | SYSTEM_PROMPT.format( 128 | file_name=SnippetPath.from_str(snippet_path).file_path.name, 129 | user_query=query, 130 | snippet_path=snippet_path, 131 | file_snippet=snippet_content, 132 | ) 133 | ) 134 | 135 | def _check_response_format( 136 | self, response: dict, *args, **kwargs 137 | ) -> Tuple[bool, Optional[str]]: 138 | for field in ["score", "reason"]: 139 | if field not in response: 140 | return False, f"'{field}' is missing in the JSON object" 141 | return True, None 142 | 143 | def _check_response_semantics( 144 | self, response: dict, *args, **kwargs 145 | ) -> Tuple[bool, Optional[str]]: 146 | try: 147 | score = int(response["score"]) 148 | except ValueError: 149 | return False, NON_INTEGER_SCORE_MESSAGE.format(score=response["score"]) 150 | if score not in [0, 1, 2, 3]: 151 | return False, INVALID_SCORE_VALUE_MESSAGE.format(score=score) 152 | return True, None 153 | 154 | def _parse_response(self, response: dict, *args, **kwargs) -> any: 155 | return int(response["score"]), response["reason"] 156 | 157 | def _default_result_when_reaching_max_chat_round(self): 158 | return ( 159 | SCORE_IRRELEVANCE, 160 | "The model have reached the max number of chat round and is unable to score their relevance.", 161 | ) 162 | -------------------------------------------------------------------------------- /cora/options.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser, Action 3 | from pathlib import Path 4 | from typing import Tuple, List 5 | 6 | from cora.base.console import BoxedConsoleConfigs 7 | from cora.base.repos import RepoTup 8 | from cora.llms.factory import LLMConfig 9 | from cora.repo.repo import Repository 10 | 11 | 12 | class ArgumentError(ValueError): 13 | pass 14 | 15 | 16 | def _save_options(args: any, file: Path): 17 | with file.open("w") as fou: 18 | json.dump(vars(args), fou, ensure_ascii=False, indent=2) 19 | 20 | 21 | def parse_repo(args: any) -> Repository: 22 | s = args.repository 23 | try: 24 | f, p = s.split(":", maxsplit=1) 25 | o, n = f.split("/", maxsplit=1) 26 | except ValueError: 27 | raise ArgumentError( 28 | 'Invalid argument for "repository". It should be formatted ' 29 | 'in "org/name:path" such as "torvalds/linux:/path/to/my/linux/mirror".' 30 | ) 31 | repo_path = Path(p) 32 | if not repo_path.exists(): 33 | raise ArgumentError( 34 | f"The repository does not exist; please check the repository's path: {p}." 35 | ) 36 | if not repo_path.is_dir(): 37 | raise ArgumentError( 38 | f"The repository is not a directory; please check the repository's path: {p}." 39 | ) 40 | return Repository(RepoTup(org=o, name=n, path=p), excludes=(args.excludes or [])) 41 | 42 | 43 | def parse_query(args: any) -> Tuple[str, List[str]]: 44 | path = Path(args.query) 45 | if path.exists() and path.is_file(): 46 | query = path.read_text(encoding="utf-8", errors="replace") 47 | else: 48 | query = args.query 49 | includes = args.includes or [] 50 | return query, includes 51 | 52 | 53 | def parse_llms(args: any) -> LLMConfig: 54 | try: 55 | p, m = args.model.split(":", maxsplit=1) 56 | except ValueError: 57 | raise ArgumentError( 58 | 'Invalid argument for "--model". It should be in the format ' 59 | 'of "provider:model" such as "openai:gpt-4o", "ollama:qwen2:0.5b-instruct".' 
60 | ) 61 | return LLMConfig( 62 | provider=p, 63 | llm_name=m, 64 | debug_mode=args.verbose, 65 | temperature=args.model_temperature, 66 | top_k=args.model_top_k, 67 | top_p=args.model_top_p, 68 | max_tokens=args.model_max_tokens, 69 | ) 70 | 71 | 72 | def parse_perf(args: any) -> Tuple[int, int]: 73 | return args.num_procs, args.num_threads 74 | 75 | 76 | def parse_logging(args: any): 77 | if args.log_dir: 78 | log_dir = Path(args.log_dir) 79 | if log_dir.exists(): 80 | raise ArgumentError(f"The logging directory already exists: {log_dir}") 81 | log_dir.mkdir(exist_ok=False, parents=False) 82 | _save_options(args, file=(log_dir / "commands.json")) 83 | args.verbose = True # We enable verbose mode if log_dir is present 84 | BoxedConsoleConfigs.out_dir = str(log_dir.resolve()) 85 | BoxedConsoleConfigs.print_to_console = True 86 | else: 87 | log_dir = None 88 | return log_dir, args.verbose 89 | 90 | 91 | def make_repo_options(parser: ArgumentParser) -> List[Action]: 92 | return [ 93 | parser.add_argument( 94 | "repository", 95 | help='The repository in the format of "org/name:path" ' 96 | 'such as "torvalds/linux:/path/to/my/linux/mirror"', 97 | ), 98 | parser.add_argument( 99 | "--excludes", 100 | type=str, 101 | action="append", 102 | help="Files (UNIX shell-style patterns) in the repository to " 103 | "exclude in the whole process; this can be specified multiple times", 104 | ), 105 | ] 106 | 107 | 108 | def make_query_options(parser: ArgumentParser) -> List[Action]: 109 | return [ 110 | parser.add_argument( 111 | "--query", 112 | "-q", 113 | required=True, 114 | type=str, 115 | help="The user query against the repository; " 116 | "either a simple string or a path to a UTF-8 file saving the query", 117 | ), 118 | parser.add_argument( 119 | "--includes", 120 | type=str, 121 | action="append", 122 | help="Files (UNIX shell-style patterns) to consider when retrieving " 123 | "relevant context for the user query; this can be specified multiple times", 124 | ), 125 | ] 126 | 127 | 128 | def make_model_options(parser: ArgumentParser) -> List[Action]: 129 | return [ 130 | parser.add_argument( 131 | "--model", 132 | "-m", 133 | required=True, 134 | type=str, 135 | help='The assistive LM in the format of "provider:model" such as ' 136 | '"openai:gpt-4o", "ollama:qwen2:0.5b-instruct"', 137 | ), 138 | parser.add_argument( 139 | "--model-temperature", 140 | "-T", 141 | default=0.8, 142 | type=float, 143 | help="Parameter temperature controlling the LM's generation", 144 | ), 145 | parser.add_argument( 146 | "--model-top-k", 147 | "-K", 148 | default=50, 149 | type=int, 150 | help="Parameter top-k controlling the LM's generation", 151 | ), 152 | parser.add_argument( 153 | "--model-top-p", 154 | "-P", 155 | default=0.95, 156 | type=float, 157 | help="Parameter top-p controlling the LM's generation", 158 | ), 159 | parser.add_argument( 160 | "--model-max-tokens", 161 | "-L", 162 | default=1024, 163 | type=int, 164 | help="Parameter max-tokens controlling the LM's maximum number of tokens to generate", 165 | ), 166 | ] 167 | 168 | 169 | def make_perf_options(parser: ArgumentParser) -> List[Action]: 170 | return [ 171 | parser.add_argument( 172 | "--num-procs", 173 | "-j", 174 | default=1, 175 | type=int, 176 | help="The maximum number of processes to use in parallel", 177 | ), 178 | parser.add_argument( 179 | "--num-threads", 180 | "-t", 181 | default=1, 182 | type=int, 183 | help="The maximum number of threads to use in parallel in the each process", 184 | ), 185 | ] 186 | 187 | 188 | def 
make_logging_options(parser: ArgumentParser) -> List[Action]: 189 | return [ 190 | parser.add_argument( 191 | "--verbose", 192 | action="store_true", 193 | help="Enable verbose logging (this includes all interactions with the LM)", 194 | ), 195 | parser.add_argument( 196 | "--log-dir", 197 | type=str, 198 | default="", 199 | help="Store trajectories and logs into the assigned directory; " 200 | "This option also implies --verbose", 201 | ), 202 | ] 203 | 204 | 205 | def make_common_options(parser: ArgumentParser) -> List[Action]: 206 | return [ 207 | # Options related to the repository 208 | *make_repo_options(parser), 209 | # Options related to the user query 210 | *make_query_options(parser), 211 | # Options related to the LM (language model) 212 | *make_model_options(parser), 213 | # Options related to performance 214 | *make_perf_options(parser), 215 | # Logging options 216 | *make_logging_options(parser), 217 | ] 218 | -------------------------------------------------------------------------------- /cora/repair/refine.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, List 2 | 3 | from cora.agents.base import AgentBase 4 | from cora.base.paths import SnippetPath 5 | from cora.llms.base import LLMBase 6 | from cora.repo.repo import Repository 7 | from cora.utils.interval import merge_overlapping_intervals 8 | 9 | SYSTEM_PROMPT = """\ 10 | You are an Intelligent RELEVANT Snippet Retrieval Assistant. 11 | Your task is to go through the provided "File Snippet" and identify the lines of the Snippet that are most relevant to the "User Query" and the "Github Issue". 12 | 13 | For this task, I will provide you with a "File Snippet" which is a portion of the file **{file_name}**, a "User Query" and a "Github Issue". 14 | The "User Query" is summarized from the "Github Issue", and the "Github Issue" contains more details. 15 | The file snippet lists the content of the snippet; \ 16 | it is a portion of a file, wrapped by "===START OF SNIPPET===" and "===END OF SNIPPET===". \ 17 | We also provide some lines of surrounding content of the snippet in its file for your reference. 18 | The Snippet is a preview version which only contains function, class, loop, branch, and return statements, so you should use your prior knowledge to understand the code. 19 | 20 | Correlation criteria: 21 | 1. The code snippet directly addresses or implements functionality related to the "User Query" and the "Github Issue". 22 | 2. The code's purpose aligns with the intent of the "User Query" and the "Github Issue". 23 | 24 | *Key Reminders*: 25 | - Some relevant code may not seem relevant; please think carefully about the meaning of the code!!! 26 | - You must find lines which are relevant to BOTH the "User Query" AND the "Github Issue" !!! 27 | - IF some lines are only relevant to "User Query" or only relevant to the "Github Issue", Exclude Them !!! 28 | - If you think the whole "File Snippet" is not completely relevant to the "User Query", set the field "line" to null. 29 | - If you're not sure which lines are related to "User Query" and "Github Issue", set the field "line" to null. 30 | - List AT MOST 5 lines you are pretty sure need to be modified to solve the "User Query" and the "Github Issue". 
31 | 32 | ## Github Issue ## 33 | ``` 34 | {github_issue} 35 | ``` 36 | 37 | ## File Snippet ## 38 | ``` 39 | //// Snippet: {snippet_path} 40 | {file_snippet} 41 | ``` 42 | 43 | """ 44 | 45 | JSON_SCHEMA = """\ 46 | { 47 | "line": , // the relevant lines; At Most 5 lines, it should be an integer list like [1,2,3]; if there are no relevant lines you should set the field to "null" 48 | "reason": "the reason why you give the lines " // explain in detail, step-by-step, the relevance between the "Code Snippet" and the "User Query" 49 | }\ 50 | """ 51 | 52 | NOT_INTEGER_LIST_MESSAGE = """\ 53 | **FAILURE**: The relevant line ({line}) you gave is NOT an integer list or null. 54 | 55 | The relevance score must be an integer list like [1,2,3,6,7,8] or null! 56 | 57 | 58 | ## Your Response (JSON format) ## 59 | 60 | """ 61 | 62 | INVALID_LINE_NUMBER_MESSAGE = """\ 63 | **FAILURE**: The relevant line ({line}) you gave is NOT chosen from the line number I gave you . 64 | 65 | The relevance line number must be an integer chosen from the "File Snippet" which contains the line numbers you should give. 66 | 67 | 68 | ## Your Response (JSON format) ## 69 | 70 | """ 71 | 72 | INVALID_JSON_OBJECT_MESSAGE = """\ 73 | **FAILURE**: Your response is not a valid JSON object: {error_message}. 74 | 75 | You response should strictly do follow the following JSON format: 76 | 77 | ``` 78 | {json_schema} 79 | ``` 80 | 81 | Please fix the above shown issues (shown above) and respond again. 82 | 83 | ## Your Response ## 84 | 85 | """ 86 | 87 | SYSTEM_PROMPT_JSON_INSTRUCTION = """\ 88 | ## Response Format ## 89 | 90 | Your response MUST be in the following JSON format: 91 | 92 | ``` 93 | {json_schema} 94 | ``` 95 | 96 | 97 | ## Your Response ## 98 | 99 | """ 100 | VIOLATED_JSON_FORMAT_MESSAGE = """\ 101 | **FAILURE**: Your responded JSON object violates the given JSON format: {error_message}. 102 | 103 | You response should strictly do follow the following JSON format: 104 | 105 | ``` 106 | {json_schema} 107 | ``` 108 | 109 | Please fix the above shown issues (shown above) and respond again. 
110 | 111 | ## Your Response ## 112 | 113 | """ 114 | 115 | 116 | class SnipRefiner(AgentBase): 117 | def __init__( 118 | self, 119 | llm: LLMBase, 120 | repo: Repository, 121 | surroundings: int = 10, 122 | *args, 123 | **kwargs, 124 | ): 125 | AgentBase.__init__(self, llm=llm, json_schema=JSON_SCHEMA, *args, **kwargs) 126 | self.repo = repo 127 | self.surroundings = surroundings 128 | 129 | def refine(self, issue: str, snip_path: str) -> Tuple[List[str], str]: 130 | snip_path = SnippetPath.from_str(snip_path) 131 | refined_paths, reason = self.run( 132 | SYSTEM_PROMPT.format( 133 | file_name=snip_path.file_path.name, 134 | snippet_path=str(snip_path), 135 | file_snippet=self.repo.get_snippet_content( 136 | str(snip_path), 137 | self.surroundings, 138 | add_lines=True, 139 | add_separators=True, 140 | ), 141 | github_issue=issue, 142 | ), 143 | snip_path=snip_path, 144 | ) 145 | # Merge continuous and overlap snippets into one snippet 146 | snip_tups = merge_overlapping_intervals( 147 | [ 148 | (SnippetPath.from_str(sp).start_line, SnippetPath.from_str(sp).end_line) 149 | for sp in refined_paths 150 | ], 151 | merge_continuous=True, 152 | ) 153 | return [ 154 | str(SnippetPath(snip_path.file_path, a, b)) for a, b in snip_tups 155 | ], reason 156 | 157 | def _check_response_format( 158 | self, response: dict, *args, **kwargs 159 | ) -> Tuple[bool, Optional[str]]: 160 | for field in ["line", "reason"]: 161 | if field not in response: 162 | return False, f"'{field}' is missing in the JSON object" 163 | return True, None 164 | 165 | def _check_response_semantics( 166 | self, response: dict, *args, **kwargs 167 | ) -> Tuple[bool, Optional[str]]: 168 | # The chosen "line" should be a list 169 | lines = response["line"] 170 | if not lines or lines == "null" or lines == [None]: 171 | return True, None 172 | try: 173 | if not isinstance(lines, list): 174 | return False, NOT_INTEGER_LIST_MESSAGE.format(line=response["line"]) 175 | lines = [int(n) for n in lines] 176 | except ValueError: 177 | return False, NOT_INTEGER_LIST_MESSAGE.format(line=response["line"]) 178 | 179 | # Make sure the line numbers are all in the correct range 180 | snip_path = kwargs["snip_path"] 181 | start_line = snip_path.start_line - self.surroundings 182 | end_line = snip_path.end_line + self.surroundings 183 | 184 | for lno in lines: 185 | lno = int(lno) 186 | if not (start_line <= lno <= end_line): 187 | return False, INVALID_LINE_NUMBER_MESSAGE.format(line=response["line"]) 188 | 189 | return True, None 190 | 191 | def _parse_response(self, response: dict, *args, **kwargs) -> any: 192 | snip_path = kwargs["snip_path"] 193 | lines = response["line"] 194 | reason = response["reason"] 195 | file_path = snip_path.file_path 196 | 197 | if not lines: 198 | return [], None 199 | 200 | # We return all chosen lines 201 | return [str(SnippetPath(file_path, lno, lno + 1)) for lno in lines], reason 202 | 203 | def _default_result_when_reaching_max_chat_round(self): 204 | return ( 205 | [], 206 | "The model have reached the max number of chat round and is unable to score their relevance.", 207 | ) 208 | -------------------------------------------------------------------------------- /cora/agents/base.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, List 2 | 3 | import pyjson5 as json5 4 | 5 | from cora.llms.base import LLMBase, ChatMessage 6 | 7 | SYSTEM_PROMPT_JSON_INSTRUCTION = """\ 8 | ## Response Format ## 9 | 10 | Your response MUST be in the following 
JSON format: 11 | 12 | ``` 13 | {json_schema} 14 | ``` 15 | 16 | ## Your Response ## 17 | 18 | """ 19 | 20 | INVALID_JSON_OBJECT_MESSAGE = """\ 21 | **FAILURE**: Your response is not a valid JSON object: {error_message}. 22 | 23 | Your response should strictly follow the following JSON format: 24 | 25 | ``` 26 | {json_schema} 27 | ``` 28 | 29 | Please fix the issues shown above and respond again. 30 | 31 | ## Your Response ## 32 | 33 | """ 34 | 35 | VIOLATED_JSON_FORMAT_MESSAGE = """\ 36 | **FAILURE**: The JSON object in your response violates the given JSON format: {error_message}. 37 | 38 | Your response should strictly follow the following JSON format: 39 | 40 | ``` 41 | {json_schema} 42 | ``` 43 | 44 | Please fix the issues shown above and respond again. 45 | 46 | ## Your Response ## 47 | 48 | """ 49 | 50 | 51 | class ReachChatRoundLimitException(Exception): 52 | def __init__(self, limit: int): 53 | super().__init__(f"The maximum allowed chat-round limit is {limit}") 54 | 55 | 56 | class AgentBase: 57 | def __init__(self, llm: LLMBase, json_schema: Optional[str], *, max_chat_round=10): 58 | self.llm = llm 59 | self.json_schema = json_schema 60 | self.max_chat_round = max_chat_round 61 | 62 | def is_debugging(self) -> bool: 63 | return self.llm.is_debug_mode() 64 | 65 | def enable_debugging(self): 66 | self.llm.enable_debug_mode() 67 | 68 | def disable_debugging(self): 69 | self.llm.disable_debug_mode() 70 | 71 | def get_history(self) -> List[ChatMessage]: 72 | return self.llm.get_history() 73 | 74 | def run(self, system_prompt: str, *args, **kwargs): 75 | if self.json_schema: 76 | return self._run_with_json_schema(system_prompt, *args, **kwargs) 77 | else: 78 | return self._run_without_json_schema(system_prompt, *args, **kwargs) 79 | 80 | def _run_without_json_schema(self, system_prompt: str, *args, **kwargs): 81 | # We're a clean run, so clear all prior chats 82 | self.llm.clear_history() 83 | 84 | # TODO: Use append_system_message()? 85 | self.llm.append_user_message(system_prompt) 86 | 87 | for _ in range(self.max_chat_round): 88 | try: 89 | response = self.llm.query() 90 | except Exception: 91 | continue 92 | 93 | # Parse the response and return results 94 | return self._parse_response(response, *args, **kwargs) 95 | 96 | return self._default_result_when_reaching_max_chat_round() 97 | 98 | def _run_with_json_schema(self, system_prompt: str, *args, **kwargs): 99 | assert self.json_schema, "No JSON schema is given" 100 | 101 | # We're a clean run, so clear all prior chats 102 | self.llm.clear_history() 103 | 104 | # TODO: Use append_system_message()? 
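        # Editorial note: the retry loop below gives the model up to max_chat_round attempts.
        # A reply that is not parseable JSON re-prompts with INVALID_JSON_OBJECT_MESSAGE, a reply
        # that violates the schema re-prompts with VIOLATED_JSON_FORMAT_MESSAGE, and a semantically
        # invalid reply re-prompts with the error prompt returned by _check_response_semantics().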
105 | self.llm.append_user_message( 106 | system_prompt 107 | + SYSTEM_PROMPT_JSON_INSTRUCTION.format(json_schema=self.json_schema) 108 | ) 109 | 110 | for _ in range(self.max_chat_round): 111 | try: 112 | response = self.llm.query() 113 | except Exception: 114 | continue 115 | 116 | response, err_msg = self.parse_json_response(response) 117 | 118 | # TODO We need to cleanup all trial-error messages and keep our history clean 119 | 120 | # Not a JSON object, let's try again 121 | if response is None: 122 | self.llm.append_user_message( 123 | INVALID_JSON_OBJECT_MESSAGE.format( 124 | error_message=err_msg, json_schema=self.json_schema 125 | ) 126 | ) 127 | continue 128 | 129 | formatted, err_msg = self._check_response_format(response, *args, **kwargs) 130 | 131 | # Violates JSON format, let's try again 132 | if not formatted: 133 | self.llm.append_user_message( 134 | VIOLATED_JSON_FORMAT_MESSAGE.format( 135 | error_message=err_msg, json_schema=self.json_schema 136 | ) 137 | ) 138 | continue 139 | 140 | valid, err_prompt = self._check_response_semantics( 141 | response, *args, **kwargs 142 | ) 143 | 144 | # Invalid response, let's try again 145 | if not valid: 146 | self.llm.append_user_message(err_prompt) 147 | continue 148 | 149 | # Parse the response and return results 150 | return self._parse_response(response, *args, **kwargs) 151 | 152 | return self._default_result_when_reaching_max_chat_round() 153 | 154 | def _check_response_format( 155 | self, response: dict, *args, **kwargs 156 | ) -> Tuple[bool, Optional[str]]: 157 | """ 158 | Check if the response follow the given JSON schema. 159 | Return (True, None) if the response follows. 160 | Otherwise, (False, error_message) if there are violations. 161 | """ 162 | return True, None 163 | 164 | def _check_response_semantics( 165 | self, response: dict, *args, **kwargs 166 | ) -> Tuple[bool, Optional[str]]: 167 | """ 168 | Check if the response is valid in terms of the agent's functionality/semantics. 169 | Return (True, None) if the response is valid. 170 | Otherwise, (False, error_prompt). 171 | """ 172 | return True, None 173 | 174 | def _parse_response(self, response: any, *args, **kwargs) -> any: 175 | """ 176 | Parse the response and return results. 177 | The results should be of the same type as run(). 178 | """ 179 | return response 180 | 181 | def _default_result_when_reaching_max_chat_round(self): 182 | """ 183 | The default result to return when the model have reached a max chat round. 184 | Usually, this indicates that the model fails to output any valid result. 185 | So be sure to return a value that can indicate an "exit" of running the model. 186 | The return value should be of the same type as run(). 187 | """ 188 | raise ReachChatRoundLimitException(self.max_chat_round) 189 | 190 | @staticmethod 191 | def parse_json_response( 192 | r, drop_newline_symbol=True 193 | ) -> (Optional[dict], Optional[str]): 194 | try: 195 | if "{" not in r: 196 | raise Exception("Missing the left, matching curly brace ({)") 197 | if "}" not in r: 198 | raise Exception("Missing the right, matching curly brace (})") 199 | r = r[ 200 | r.find("{") : r.rfind("}") + 1 201 | ] # Skip all preceding and succeeding contents 202 | # Since we are a JSON object, "\n" takes no effects unless it is within some key's value. However, 203 | # it can make our JSON parsing fail once it is within a value. 
For example: 204 | # `{\n"a": "value of a", "b": "value of \n b"}` 205 | # The first "\n" preceding "\"a\"" is valid, but the second "\n" preceding " b" makes the JSON invalid. 206 | # Indeed, the second one should be "\\n". Since we do not have a way to distinguish them, 207 | # we conservatively assume that "\n" does not make a major contribution to the result and discard all of them. 208 | if drop_newline_symbol: 209 | r = r.replace("\n", " ") 210 | # We use JSON5 since the LLM may generate JS-style JSON, e.g., with comments 211 | return json5.loads(r), None 212 | except Exception as e: 213 | return None, getattr(e, "message", str(e)) 214 | --------------------------------------------------------------------------------
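Editorial addendum: the sketch below is a minimal, hypothetical illustration of how the `AgentBase` JSON-schema loop above is typically driven by a concrete agent, in the spirit of `SnipScorer` and `SnipRefiner`. The `YesNoAgent` class, its prompt, and its schema are assumptions made purely for illustration and are not part of the repository; only the `AgentBase` hooks (`_check_response_format`, `_check_response_semantics`, `_parse_response`, `_default_result_when_reaching_max_chat_round`) and the `LLMBase` type come from the source files shown above.

```python
# Hypothetical example -- not part of the repository.
from typing import Optional, Tuple

from cora.agents.base import AgentBase
from cora.llms.base import LLMBase

JSON_SCHEMA = """\
{
    "answer": "either 'yes' or 'no'",
    "reason": "the reason for your answer"
}\
"""


class YesNoAgent(AgentBase):
    """Toy agent that asks the LLM a yes/no question and returns a (bool, reason) pair."""

    def __init__(self, llm: LLMBase):
        super().__init__(llm=llm, json_schema=JSON_SCHEMA)

    def ask(self, question: str) -> Tuple[bool, str]:
        # run() appends the JSON instruction to the prompt, queries the LLM, and retries on
        # malformed or invalid responses via the hooks implemented below.
        return self.run(f"Answer the following question about a repository:\n\n{question}\n")

    def _check_response_format(self, response: dict, *args, **kwargs) -> Tuple[bool, Optional[str]]:
        # Reject responses that miss a required field of the schema
        for field in ("answer", "reason"):
            if field not in response:
                return False, f"'{field}' is missing in the JSON object"
        return True, None

    def _check_response_semantics(self, response: dict, *args, **kwargs) -> Tuple[bool, Optional[str]]:
        # Reject answers other than "yes" or "no"; the message becomes the re-prompt
        if str(response["answer"]).strip().lower() not in ("yes", "no"):
            return False, "The 'answer' field must be either 'yes' or 'no'. Please respond again."
        return True, None

    def _parse_response(self, response: dict, *args, **kwargs) -> Tuple[bool, str]:
        return str(response["answer"]).strip().lower() == "yes", response["reason"]

    def _default_result_when_reaching_max_chat_round(self) -> Tuple[bool, str]:
        return False, "The model reached the max number of chat rounds without a valid answer."
```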