├── .python-version ├── coderetrx ├── utils │ ├── __init__.py │ ├── _extras.py │ ├── git.py │ ├── path.py │ ├── logger.py │ ├── concurrency.py │ ├── cost_tracking.py │ └── stats.py ├── static │ ├── codebase │ │ ├── queries │ │ │ ├── treesitter │ │ │ │ ├── fine_imports │ │ │ │ │ ├── c.scm │ │ │ │ │ ├── README.md │ │ │ │ │ ├── cpp.scm │ │ │ │ │ ├── java.scm │ │ │ │ │ ├── csharp.scm │ │ │ │ │ ├── go.scm │ │ │ │ │ ├── python.scm │ │ │ │ │ ├── rust.scm │ │ │ │ │ ├── javascript.scm │ │ │ │ │ ├── typescript.scm │ │ │ │ │ ├── php.scm │ │ │ │ │ └── elixir.scm │ │ │ │ ├── tags │ │ │ │ │ ├── elisp.scm │ │ │ │ │ ├── kotlin.scm │ │ │ │ │ ├── ql.scm │ │ │ │ │ ├── python.scm │ │ │ │ │ ├── elm.scm │ │ │ │ │ ├── java.scm │ │ │ │ │ ├── c.scm │ │ │ │ │ ├── c_sharp.scm │ │ │ │ │ ├── go.scm │ │ │ │ │ ├── ruby.scm │ │ │ │ │ ├── php.scm │ │ │ │ │ ├── csharp.scm │ │ │ │ │ ├── typescript.scm │ │ │ │ │ ├── cpp.scm │ │ │ │ │ ├── rust.scm │ │ │ │ │ ├── elixir.scm │ │ │ │ │ ├── hcl.scm │ │ │ │ │ ├── dart.scm │ │ │ │ │ ├── javascript.scm │ │ │ │ │ └── ocaml.scm │ │ │ │ └── tests │ │ │ │ │ ├── php.scm │ │ │ │ │ ├── elixir.scm │ │ │ │ │ ├── javascript.scm │ │ │ │ │ ├── typescript.scm │ │ │ │ │ ├── go.scm │ │ │ │ │ ├── rust.scm │ │ │ │ │ └── python.scm │ │ │ └── codeql │ │ │ │ └── python │ │ │ │ ├── qlpack.yml │ │ │ │ ├── imports.ql │ │ │ │ ├── classes.ql │ │ │ │ └── functions.ql │ │ ├── parsers │ │ │ ├── treesitter │ │ │ │ ├── __init__.py │ │ │ │ └── queries.py │ │ │ ├── __init__.py │ │ │ ├── codeql │ │ │ │ ├── __init__.py │ │ │ │ └── queries.py │ │ │ └── factory.py │ │ ├── __init__.py │ │ └── languages.py │ ├── ripgrep │ │ ├── __init__.py │ │ └── installer.py │ ├── __init__.py │ └── codeql │ │ └── installer.py ├── __init__.py ├── impl │ └── default │ │ ├── factory.py │ │ ├── topic_extractor.py │ │ ├── smart_codebase.py │ │ ├── __init__.py │ │ └── prompt.py ├── retrieval │ ├── strategy │ │ ├── filter_symbol_name_by_llm.py │ │ ├── filter_symbol_content_by_vector.py │ │ ├── filter_keyword_by_vector.py │ │ ├── filter_filename_by_llm.py │ │ ├── filter_dependency_by_llm.py │ │ ├── __init__.py │ │ ├── factory.py │ │ ├── adaptive_filter_keyword_by_vector_and_llm.py │ │ ├── filter_keyword_by_vector_and_llm.py │ │ ├── filter_symbol_content_by_vector_and_llm.py │ │ └── adaptive_filter_symbol_content_by_vector_and_llm.py │ ├── __init__.py │ └── topic_extractor.py └── tools │ ├── __init__.py │ ├── find_file_by_name.py │ ├── base.py │ ├── view_file.py │ └── get_references.py ├── bench ├── recall_rate_comparison.png ├── effectiveness_efficiency_comparison.png ├── EXPERIMENTS.md ├── repos.txt ├── repos.lock └── queries.json ├── justfile ├── .env.example ├── .gitignore ├── scripts ├── try_codelines.py ├── view_chunks.py ├── benchmark.py └── example.py ├── pyproject.toml ├── test ├── impl │ └── default │ │ ├── test_determine_strategy.py │ │ ├── test_smart_codebase.py │ │ └── test_code_recall.py └── tools │ └── test_tools.py ├── STRATEGIES.md └── USAGE.md /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /coderetrx/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bench/recall_rate_comparison.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XuanwuAI/CodeRetrX/HEAD/bench/recall_rate_comparison.png -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/fine_imports/c.scm: -------------------------------------------------------------------------------- 1 | ;; All includes 2 | (preproc_include path: (_) @module) -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/fine_imports/README.md: -------------------------------------------------------------------------------- 1 | Import module extraction queries. 2 | source: Atum 3 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/fine_imports/cpp.scm: -------------------------------------------------------------------------------- 1 | ;; All includes 2 | (preproc_include path: (_) @module) -------------------------------------------------------------------------------- /justfile: -------------------------------------------------------------------------------- 1 | local_qdrant: 2 | docker run -d -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/fine_imports/java.scm: -------------------------------------------------------------------------------- 1 | ;; All import types 2 | (import_declaration (identifier) @module) -------------------------------------------------------------------------------- /coderetrx/static/codebase/parsers/treesitter/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser import TreeSitterParser 2 | 3 | __all__ = ["TreeSitterParser"] 4 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/fine_imports/csharp.scm: -------------------------------------------------------------------------------- 1 | ;; All using types 2 | (using_directive 3 | (identifier) @dependency) -------------------------------------------------------------------------------- /bench/effectiveness_efficiency_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuanwuAI/CodeRetrX/HEAD/bench/effectiveness_efficiency_comparison.png -------------------------------------------------------------------------------- /coderetrx/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import retrieval, static 2 | from .static import Codebase 3 | 4 | __all__ = ["Codebase", "retrieval", "static"] 5 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/fine_imports/go.scm: -------------------------------------------------------------------------------- 1 | ;; All import specs 2 | (import_spec path: (interpreted_string_literal) @module) -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/codeql/python/qlpack.yml: -------------------------------------------------------------------------------- 1 | name: python-queries 2 | version: 1.0.0 3 | dependencies: 4 | codeql/python-all: "*" 5 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/fine_imports/python.scm: -------------------------------------------------------------------------------- 1 | ;; Standard imports 2 | (import_statement (dotted_name) @module) 3 | (import_from_statement module_name: (_) @module) -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/fine_imports/rust.scm: -------------------------------------------------------------------------------- 1 | ;; All use patterns 2 | ( 3 | (use_declaration 4 | argument: (_ 5 | path: (identifier) @module 6 | ) 7 | ) 8 | ) 9 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/parsers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import CodebaseParser, ExtractionType 2 | from .factory import ParserFactory 3 | from .treesitter import TreeSitterParser 4 | 5 | __all__ = ["CodebaseParser", "ExtractionType", "ParserFactory", "TreeSitterParser"] 6 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | LLM_TOPIC_EXTRACTION_MODEL_ID=openai/gpt-4.1-mini 2 | MAX_CHUNKS_ONE_FILE=500 3 | EMBEDDING_MODEL_ID=text-embedding-3-large 4 | EMBEDDING_BASE_URL= 5 | EMBEDDING_API_KEY= 6 | OPENAI_API_KEY= 7 | OPENAI_BASE_URL= 8 | GITHUB_TOKEN= 9 | KEYWORD_SENTENCE_EXTRACTION=false 10 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/parsers/codeql/__init__.py: -------------------------------------------------------------------------------- 1 | from ....codeql.codeql import CodeQLWrapper, CodeQLDatabase 2 | from .parser import CodeQLParser 3 | from .queries import CodeQLQueryTemplates 4 | 5 | __all__ = ["CodeQLWrapper", "CodeQLDatabase", "CodeQLParser", "CodeQLQueryTemplates"] 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python-generated files 2 | __pycache__/ 3 | *.py[oc] 4 | build/ 5 | dist/ 6 | wheels/ 7 | *.egg-info 8 | .vscode 9 | .idea 10 | .pytest_cache 11 | 12 | # Virtual environments 13 | .venv 14 | .env 15 | 16 | # Ripgrep binary 17 | rg 18 | rg.exe 19 | .tmp 20 | .cache 21 | .data 22 | 23 | .DS_Store 24 | code_reports/ 25 | logs/ 26 | qdrant_storage/ -------------------------------------------------------------------------------- /coderetrx/static/ripgrep/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.ripgrep import ( 2 | ripgrep_glob, 3 | ripgrep_search, 4 | ripgrep_search_symbols, 5 | ripgrep_raw, 6 | GrepMatchResult, 7 | ) 8 | 9 | __all__ = [ 10 | "ripgrep_glob", 11 | "ripgrep_search", 12 | "ripgrep_search_symbols", 13 | "ripgrep_raw", 14 | "GrepMatchResult", 15 | ] 16 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/fine_imports/javascript.scm: -------------------------------------------------------------------------------- 1 | ;; Static imports 2 | (import_statement source: (string) @module) 3 | 4 | ;; Require 5 | (call_expression 6 | function: (identifier) @func 7 | arguments: (arguments (string) @module) 8 | (#eq? @func "require")) 9 | 10 | ;; Re-exports 11 | (export_statement source: (string) @module) -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/fine_imports/typescript.scm: -------------------------------------------------------------------------------- 1 | ;; Static imports 2 | (import_statement source: (string) @module) 3 | 4 | ;; Require 5 | (call_expression 6 | function: (identifier) @func 7 | arguments: (arguments (string) @module) 8 | (#eq? @func "require")) 9 | 10 | ;; Re-exports 11 | (export_statement source: (string) @module) -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/elisp.scm: -------------------------------------------------------------------------------- 1 | ;; defun/defsubst 2 | (function_definition name: (symbol) @name.definition.function) @definition.function 3 | 4 | ;; Treat macros as function definitions for the sake of TAGS. 5 | (macro_definition name: (symbol) @name.definition.function) @definition.function 6 | 7 | ;; Match function calls 8 | (list (symbol) @name.reference.function) @reference.function 9 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tests/php.scm: -------------------------------------------------------------------------------- 1 | ;; PHPUnit test methods (methods starting with test) 2 | ;; Only match test methods, not test classes, to allow non-test methods in test classes 3 | ( 4 | (method_declaration 5 | name: (name) @run @_test_method_name 6 | (#match? @_test_method_name "^test") 7 | ) @_php-test-method 8 | (#set! tag php-test-method) 9 | ) 10 | -------------------------------------------------------------------------------- /coderetrx/utils/_extras.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def require_extra(pkg_name: str, extra: str): 5 | """ 6 | Raises an ImportError if pkg_name isn't importable, pointing users 7 | to install the given extra. 8 | """ 9 | try: 10 | importlib.import_module(pkg_name) 11 | except ImportError as e: 12 | raise ImportError( 13 | f"This feature requires the '{extra}' extra. 
" 14 | f"Install with: uv add coderetrx[{extra}]" 15 | ) from e 16 | -------------------------------------------------------------------------------- /coderetrx/impl/default/factory.py: -------------------------------------------------------------------------------- 1 | # Compatibility import stub for backwards compatibility 2 | # CodebaseFactory has been moved to coderetrx.retrieval.factory 3 | 4 | # Will be removed in future versions 5 | from coderetrx.retrieval.factory import CodebaseFactory 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | logger.warning("The 'coderetrx.impl.default.factory' module is deprecated and will be removed in future versions, use coderetrx.retrieval.factory instead.") 9 | 10 | __all__ = ["CodebaseFactory"] -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/fine_imports/php.scm: -------------------------------------------------------------------------------- 1 | ;; PHP import statements 2 | ;; Basic namespace imports (use statements) 3 | (namespace_use_declaration 4 | (namespace_use_clause 5 | (qualified_name) @module)) 6 | 7 | ;; require/include statements with string literals 8 | (require_expression 9 | (string) @module) 10 | 11 | (require_once_expression 12 | (string) @module) 13 | 14 | (include_expression 15 | (string) @module) 16 | 17 | (include_once_expression 18 | (string) @module) 19 | -------------------------------------------------------------------------------- /coderetrx/impl/default/topic_extractor.py: -------------------------------------------------------------------------------- 1 | # Compatibility import stub for backwards compatibility 2 | # TopicExtractor has been moved to coderetrx.retrieval.topic_extractor 3 | # Will be removed in future versions 4 | 5 | from coderetrx.retrieval.topic_extractor import TopicExtractor 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | logger.warning("The 'coderetrx.impl.default.topic_extractor' module is deprecated and will be removed in future versions, use coderetrx.retrieval.topic_extractor instead.") 9 | 10 | 11 | __all__ = ["TopicExtractor"] -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/codeql/python/imports.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Python Imports 3 | * @description Find all import statements in Python code 4 | * @kind table 5 | * @id codeql/python-imports 6 | */ 7 | 8 | import python 9 | 10 | from Import imp 11 | where exists(imp.getLocation()) and exists(imp.getAnImportedModuleName()) 12 | select imp.getLocation().getFile().getRelativePath(), imp.getAnImportedModuleName(), 13 | imp.getLocation().getStartLine(), imp.getLocation().getEndLine(), 14 | imp.getLocation().getStartColumn(), imp.getLocation().getEndColumn() -------------------------------------------------------------------------------- /coderetrx/impl/default/smart_codebase.py: -------------------------------------------------------------------------------- 1 | # Compatibility import stub for backwards compatibility 2 | # SmartCodebase has been moved to coderetrx.retrieval.smart_codebase 3 | # Will be removed in future versions 4 | 5 | from coderetrx.retrieval.smart_codebase import SmartCodebase, SmartCodebaseSettings 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | logger.warning("The 'coderetrx.impl.default.smart_codebase' module is deprecated and will be removed in future 
versions, use coderetrx.retrieval.smart_codebase instead.") 9 | 10 | __all__ = ["SmartCodebase", "SmartCodebaseSettings"] -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/fine_imports/elixir.scm: -------------------------------------------------------------------------------- 1 | ; tree-sitter query to extract Elixir dependencies 2 | 3 | ; Match alias statements 4 | (call 5 | target: (identifier) @directive (#eq? @directive "alias") 6 | (arguments (alias) @module) 7 | ) @alias_stmt 8 | 9 | ; Match import statements 10 | (call 11 | target: (identifier) @directive (#eq? @directive "import") 12 | (arguments (alias) @module) 13 | ) @import_stmt 14 | 15 | ; Match require statements 16 | (call 17 | target: (identifier) @directive (#eq? @directive "require") 18 | (arguments (alias) @module) 19 | ) @require_stmt -------------------------------------------------------------------------------- /scripts/try_codelines.py: -------------------------------------------------------------------------------- 1 | from rich import print 2 | from coderetrx.static.codebase import Codebase, File 3 | 4 | code_file = File.jit_for_testing('tst.py', """ 5 | import os 6 | 7 | def x(): 8 | def y(): 9 | return x 10 | print("hello world!") 11 | for i in range(1000): 12 | print(i) 13 | 14 | print("hello world! Again!!") 15 | g = 20 16 | do(sth) 17 | print("What on earth are we doing") 18 | y = 10 19 | return y 20 | 21 | x() 22 | """ 23 | ) 24 | 25 | code_file.init_all() 26 | 27 | for line in code_file.get_lines(max_chars=100): 28 | print(line) 29 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/codeql/python/classes.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Python Classes with last statement 3 | * @description Find all class definitions and their last statement 4 | * @kind table 5 | * @id codeql/python-classes-laststmt 6 | */ 7 | 8 | import python 9 | 10 | from Class cls, Stmt stmt 11 | where 12 | stmt = cls.getBody().getLastItem() 13 | select 14 | cls.getLocation().getFile().getRelativePath(), 15 | cls.getQualifiedName(), 16 | cls.getLocation().getStartLine(), 17 | stmt.getLocation().getEndLine(), 18 | cls.getLocation().getStartColumn(), 19 | cls.getLocation().getEndColumn() -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/codeql/python/functions.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Python Functions with last statement 3 | * @description Find all functions and their last statement 4 | * @kind table 5 | * @id codeql/python-functions-laststmt 6 | */ 7 | 8 | import python 9 | 10 | from Function func, Stmt stmt 11 | where 12 | stmt = func.getBody().getLastItem() 13 | select 14 | func.getLocation().getFile().getRelativePath(), 15 | func.getQualifiedName(), 16 | func.getLocation().getStartLine(), 17 | stmt.getLocation().getEndLine(), 18 | func.getLocation().getStartColumn(), 19 | func.getLocation().getEndColumn() -------------------------------------------------------------------------------- /coderetrx/impl/default/__init__.py: -------------------------------------------------------------------------------- 1 | # Compatibility import stubs for backwards compatibility 2 | # All functionality has been moved to coderetrx.retrieval 3 | # Will be removed in future versions 4 | 5 | from 
coderetrx.retrieval import ( 6 | CodebaseFactory, 7 | SmartCodebase, 8 | TopicExtractor, 9 | ) 10 | import logging 11 | logger = logging.getLogger(__name__) 12 | logger.warning("The 'coderetrx.impl.default' module is deprecated and will be removed in future versions. use coderetrx.retrieval instead.") 13 | 14 | # For backwards compatibility, re-export the main classes 15 | __all__ = [ 16 | "CodebaseFactory", 17 | "SmartCodebase", 18 | "TopicExtractor", 19 | ] 20 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tests/elixir.scm: -------------------------------------------------------------------------------- 1 | ; Macros `describe`, `test` and `property`. 2 | ; This matches the ExUnit test style. 3 | ( 4 | (call 5 | target: (identifier) @run (#any-of? @run "describe" "test" "property") 6 | ) @_elixir-test 7 | (#set! tag elixir-test) 8 | ) 9 | 10 | ; Modules containing at least one `describe`, `test` and `property`. 11 | ; This matches the ExUnit test style. 12 | ( 13 | (call 14 | target: (identifier) @run (#eq? @run "defmodule") 15 | (do_block 16 | (call target: (identifier) @_keyword (#any-of? @_keyword "describe" "test" "property")) 17 | ) 18 | ) @_elixir-module-test 19 | (#set! tag elixir-module-test) 20 | ) -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tests/javascript.scm: -------------------------------------------------------------------------------- 1 | ; Add support for (node:test, bun:test and Jest) runnable 2 | ; Function expression that has `it`, `test` or `describe` as the function name 3 | ( 4 | (call_expression 5 | function: [ 6 | (identifier) @_name 7 | (member_expression 8 | object: [ 9 | (identifier) @_name 10 | (member_expression object: (identifier) @_name) 11 | ] 12 | ) 13 | ] 14 | (#any-of? @_name "it" "test" "describe" "context" "suite") 15 | arguments: ( 16 | arguments . (string (string_fragment) @run) 17 | ) 18 | ) @_js-test 19 | 20 | (#set! tag js-test) 21 | ) 22 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tests/typescript.scm: -------------------------------------------------------------------------------- 1 | ; Add support for (node:test, bun:test and Jest) runnable 2 | ; Function expression that has `it`, `test` or `describe` as the function name 3 | ( 4 | (call_expression 5 | function: [ 6 | (identifier) @_name 7 | (member_expression 8 | object: [ 9 | (identifier) @_name 10 | (member_expression object: (identifier) @_name) 11 | ] 12 | ) 13 | ] 14 | (#any-of? @_name "it" "test" "describe" "context" "suite") 15 | arguments: ( 16 | arguments . (string (string_fragment) @run) 17 | ) 18 | ) @_js-test 19 | 20 | (#set! 
tag js-test) 21 | ) 22 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/kotlin.scm: -------------------------------------------------------------------------------- 1 | ; Definitions 2 | 3 | (class_declaration 4 | (type_identifier) @name.definition.class) @definition.class 5 | 6 | (function_declaration 7 | (simple_identifier) @name.definition.function) @definition.function 8 | 9 | (object_declaration 10 | (type_identifier) @name.definition.object) @definition.object 11 | 12 | ; References 13 | 14 | (call_expression 15 | [ 16 | (simple_identifier) @name.reference.call 17 | (navigation_expression 18 | (navigation_suffix 19 | (simple_identifier) @name.reference.call)) 20 | ]) @reference.call 21 | 22 | (delegation_specifier 23 | [ 24 | (user_type) @name.reference.type 25 | (constructor_invocation 26 | (user_type) @name.reference.type) 27 | ]) @reference.type 28 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/ql.scm: -------------------------------------------------------------------------------- 1 | (classlessPredicate 2 | name: (predicateName) @name.definition.function) @definition.function 3 | 4 | (memberPredicate 5 | name: (predicateName) @name.definition.method) @definition.method 6 | 7 | (aritylessPredicateExpr 8 | name: (literalId) @name.reference.call) @reference.call 9 | 10 | (module 11 | name: (moduleName) @name.definition.module) @definition.module 12 | 13 | (dataclass 14 | name: (className) @name.definition.class) @definition.class 15 | 16 | (datatype 17 | name: (className) @name.definition.class) @definition.class 18 | 19 | (datatypeBranch 20 | name: (className) @name.definition.class) @definition.class 21 | 22 | (qualifiedRhs 23 | name: (predicateName) @name.reference.call) @reference.call 24 | 25 | (typeExpr 26 | name: (className) @name.reference.type) @reference.type 27 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/python.scm: -------------------------------------------------------------------------------- 1 | (class_definition 2 | name: (identifier) @name.definition.class) @definition.class 3 | 4 | (function_definition 5 | name: (identifier) @name.definition.function) @definition.function 6 | 7 | (call 8 | function: [ 9 | (identifier) @name.reference.call 10 | (attribute 11 | attribute: (identifier) @name.reference.call) 12 | ]) @reference.call 13 | 14 | ; Handle imports 15 | (import_statement) @import 16 | (import_from_statement) @import 17 | 18 | ; Variable definitions - all assignment statements 19 | (assignment 20 | left: (identifier) @name.definition.variable) @definition.variable 21 | 22 | (assignment 23 | left: (pattern_list 24 | (identifier) @name.definition.variable)) @definition.variable 25 | 26 | (assignment 27 | left: (attribute) @name.definition.variable) @definition.variable 28 | -------------------------------------------------------------------------------- /coderetrx/retrieval/strategy/filter_symbol_name_by_llm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Strategy for filtering symbols using LLM. 
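Given a natural-language criterion, this strategy has the LLM judge symbol names
(the "symbol_name" target type) and collects the paths of the files that own the
matching symbols.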
3 | """ 4 | 5 | from typing import List, override 6 | from .base import FilterByLLMStrategy 7 | from ..smart_codebase import SmartCodebase as Codebase, LLMMapFilterTargetType 8 | from coderetrx.static import Symbol 9 | 10 | 11 | class FilterSymbolNameByLLMStrategy(FilterByLLMStrategy[Symbol]): 12 | """Strategy to filter symbols using LLM.""" 13 | 14 | name: str = "FILTER_SYMBOL_NAME_BY_LLM" 15 | 16 | @override 17 | def get_strategy_name(self) -> str: 18 | return self.name 19 | 20 | @override 21 | def get_target_type(self) -> LLMMapFilterTargetType: 22 | return "symbol_name" 23 | 24 | @override 25 | def extract_file_paths( 26 | self, elements: List[Symbol], codebase: Codebase 27 | ) -> List[str]: 28 | return [str(symbol.file.path) for symbol in elements] 29 | -------------------------------------------------------------------------------- /coderetrx/retrieval/__init__.py: -------------------------------------------------------------------------------- 1 | from .code_recall import ( 2 | coderetrx_filter, 3 | coderetrx_mapping, 4 | llm_traversal_filter, 5 | llm_traversal_mapping, 6 | ) 7 | from .strategy import ( 8 | RecallStrategy, 9 | ) 10 | from .smart_codebase import ( 11 | SmartCodebase, 12 | SmartCodebaseSettings, 13 | LLMMapFilterTargetType, 14 | SimilaritySearchTargetType, 15 | LLMCallMode, 16 | CodeMapFilterResult, 17 | ) 18 | from .topic_extractor import TopicExtractor 19 | from .factory import CodebaseFactory 20 | 21 | __all__ = [ 22 | "coderetrx_filter", 23 | "coderetrx_mapping", 24 | "llm_traversal_filter", 25 | "llm_traversal_mapping", 26 | "RecallStrategy", 27 | "SmartCodebase", 28 | "SmartCodebaseSettings", 29 | "LLMMapFilterTargetType", 30 | "SimilaritySearchTargetType", 31 | "LLMCallMode", 32 | "CodeMapFilterResult", 33 | "TopicExtractor", 34 | "CodebaseFactory", 35 | ] 36 | -------------------------------------------------------------------------------- /coderetrx/impl/default/prompt.py: -------------------------------------------------------------------------------- 1 | # Compatibility import stub for backwards compatibility 2 | # Prompt utilities have been moved to coderetrx.retrieval.prompt 3 | # Will be removed in future versions 4 | 5 | from coderetrx.retrieval.prompt import * 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | logger.warning("The 'coderetrx.impl.default.prompt' module is deprecated and will be removed in future versions, use coderetrx.retrieval.prompt instead.") 9 | 10 | __all__ = [ 11 | "llm_filter_prompt_template", 12 | "llm_mapping_prompt_template", 13 | "topic_extraction_prompt_template", 14 | "KeywordExtractorResult", 15 | "llm_filter_function_call_system_prompt", 16 | "llm_mapping_function_call_system_prompt", 17 | "topic_extraction_function_call_system_prompt", 18 | "filter_and_mapping_function_call_user_prompt_template", 19 | "topic_extraction_function_call_user_prompt_template", 20 | "get_filter_function_definition", 21 | "get_mapping_function_definition", 22 | "get_topic_extraction_function_definition", 23 | ] -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/elm.scm: -------------------------------------------------------------------------------- 1 | (value_declaration (function_declaration_left (lower_case_identifier) @name.definition.function)) @definition.function 2 | 3 | (function_call_expr (value_expr (value_qid) @name.reference.function)) @reference.function 4 | (exposed_value (lower_case_identifier) @name.reference.function) 
@reference.function 5 | (type_annotation ((lower_case_identifier) @name.reference.function) (colon)) @reference.function 6 | 7 | (type_declaration ((upper_case_identifier) @name.definition.type)) @definition.type 8 | 9 | (type_ref (upper_case_qid (upper_case_identifier) @name.reference.type)) @reference.type 10 | (exposed_type (upper_case_identifier) @name.reference.type) @reference.type 11 | 12 | (type_declaration (union_variant (upper_case_identifier) @name.definition.union)) @definition.union 13 | 14 | (value_expr (upper_case_qid (upper_case_identifier) @name.reference.union)) @reference.union 15 | 16 | 17 | (module_declaration 18 | (upper_case_qid (upper_case_identifier)) @name.definition.module 19 | ) @definition.module 20 | -------------------------------------------------------------------------------- /coderetrx/static/__init__.py: -------------------------------------------------------------------------------- 1 | from .codebase import ( 2 | Codebase, 3 | File, 4 | FileType, 5 | Symbol, 6 | Keyword, 7 | Dependency, 8 | CallGraphEdge, 9 | CodebaseModel, 10 | FileModel, 11 | SymbolModel, 12 | KeywordModel, 13 | DependencyModel, 14 | CallGraphEdgeModel, 15 | CodeElement, 16 | CodeElementTypeVar, 17 | ) 18 | 19 | from .ripgrep import ( 20 | ripgrep_glob, 21 | ripgrep_search, 22 | ripgrep_search_symbols, 23 | ripgrep_raw, 24 | GrepMatchResult, 25 | ) 26 | 27 | __all__ = [ 28 | # Codebase exports 29 | "Codebase", 30 | "File", 31 | "FileType", 32 | "Symbol", 33 | "Keyword", 34 | "Dependency", 35 | "CallGraphEdge", 36 | "CodebaseModel", 37 | "FileModel", 38 | "SymbolModel", 39 | "KeywordModel", 40 | "DependencyModel", 41 | "CallGraphEdgeModel", 42 | "CodeElementTypeVar", 43 | # Ripgrep exports 44 | "ripgrep_glob", 45 | "ripgrep_search", 46 | "ripgrep_search_symbols", 47 | "ripgrep_raw", 48 | "GrepMatchResult", 49 | "CodeElement" 50 | ] 51 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/__init__.py: -------------------------------------------------------------------------------- 1 | from .codebase import ( 2 | Codebase, 3 | File, 4 | FileType, 5 | Symbol, 6 | Keyword, 7 | Dependency, 8 | CallGraphEdge, 9 | CodeElementTypeVar, 10 | CodeElement, 11 | CodeChunk, 12 | ChunkType, 13 | CodeHunk, 14 | CodeLine, 15 | ) 16 | from .models import ( 17 | CodebaseModel, 18 | FileModel, 19 | SymbolModel, 20 | KeywordModel, 21 | DependencyModel, 22 | CallGraphEdgeModel, 23 | ) 24 | from .languages import IDXSupportedLanguage, IDXSupportedTag, BUILTIN_CRYPTO_LIBS 25 | 26 | __all__ = [ 27 | "Codebase", 28 | "File", 29 | "FileType", 30 | "Symbol", 31 | "Keyword", 32 | "Dependency", 33 | "CallGraphEdge", 34 | "CodebaseModel", 35 | "FileModel", 36 | "SymbolModel", 37 | "KeywordModel", 38 | "DependencyModel", 39 | "CallGraphEdgeModel", 40 | "CodeElementTypeVar", 41 | "CodeElement", 42 | "CodeChunk", 43 | "ChunkType", 44 | "CodeHunk", 45 | "CodeLine", 46 | "IDXSupportedLanguage", 47 | "IDXSupportedTag", 48 | "BUILTIN_CRYPTO_LIBS", 49 | ] 50 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/java.scm: -------------------------------------------------------------------------------- 1 | (class_declaration 2 | name: (identifier) @name.definition.class) @definition.class 3 | 4 | (method_declaration 5 | name: (identifier) @name.definition.method) @definition.method 6 | 7 | (method_invocation 8 | name: (identifier) @name.reference.call 9 | arguments: (argument_list) 
@reference.call) 10 | 11 | (interface_declaration 12 | name: (identifier) @name.definition.interface) @definition.interface 13 | 14 | (type_list 15 | (type_identifier) @name.reference.implementation) @reference.implementation 16 | 17 | (object_creation_expression 18 | type: (type_identifier) @name.reference.class) @reference.class 19 | 20 | (superclass (type_identifier) @name.reference.class) @reference.class 21 | 22 | (import_declaration) @import 23 | 24 | ; Variable definitions - field declarations 25 | (field_declaration 26 | declarator: (variable_declarator 27 | name: (identifier) @name.definition.variable)) @definition.variable 28 | 29 | ; Local variable declarations 30 | (local_variable_declaration 31 | declarator: (variable_declarator 32 | name: (identifier) @name.definition.variable)) @definition.variable 33 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/c.scm: -------------------------------------------------------------------------------- 1 | (struct_specifier name: (type_identifier) @name.definition.class body:(_)) @definition.class 2 | 3 | (declaration type: (union_specifier name: (type_identifier) @name.definition.class)) @definition.class 4 | 5 | (function_definition 6 | declarator: (function_declarator 7 | declarator: (identifier) @name.definition.function)) @definition.function 8 | 9 | (function_definition 10 | declarator: (pointer_declarator 11 | declarator: (function_declarator 12 | declarator: (identifier) @name.definition.function))) @definition.function 13 | 14 | (type_definition declarator: (type_identifier) @name.definition.type) @definition.type 15 | 16 | (enum_specifier name: (type_identifier) @name.definition.type) @definition.type 17 | 18 | (preproc_include) @import 19 | 20 | ; Variable definitions - all declarations 21 | (declaration 22 | declarator: (identifier) @name.definition.variable) @definition.variable 23 | 24 | (declaration 25 | declarator: (init_declarator 26 | declarator: (identifier) @name.definition.variable)) @definition.variable 27 | 28 | (declaration 29 | declarator: (pointer_declarator 30 | declarator: (identifier) @name.definition.variable)) @definition.variable 31 | -------------------------------------------------------------------------------- /bench/EXPERIMENTS.md: -------------------------------------------------------------------------------- 1 | # Experiments 2 | 3 | We conducted comprehensive experiments on large-scale benchmarks across multiple programming languages and repository sizes to validate the effectiveness of our code retrieval strategies. The analysis demonstrates how **`coderetrx_filter`** performs across various bug types and complexity levels. 4 | 5 | ## Experiment Setup 6 | 7 | All methods share the same indexing and parsing pipeline (repository snapshot, language parser, and symbol table extraction). The Symbol Vector Embedding baseline encodes identifier-level semantics only, while Line-per-Symbol indexes each line within a function / class's body, enabling precise structure-aware retrieval. We used **gpt-oss-120b** for dataset construction and evaluation. 8 | 9 | ## Performance Benchmarks 10 | 11 | The following results demonstrate the performance across different strategies. Figure 1 shows the Recall Rate Comparisons across languages and repository sizes and Table 1 compares the Effectiveness and Efficiency between CodeRetrX and baselines. 
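As context for these comparisons, the sketch below contrasts the two indexing granularities described in the setup; the function and attribute names are illustrative assumptions, not the actual CodeRetrX API.

```python
# Illustrative sketch of the two indexing granularities (assumed names, not the real API).

def symbol_embedding_units(symbol) -> list[str]:
    # Symbol Vector Embedding baseline: one embedded unit per identifier.
    return [symbol.name]

def line_per_symbol_units(symbol, source_lines: list[str]) -> list[str]:
    # Line-per-Symbol: one embedded unit per source line inside the symbol's body,
    # so a hit points at a specific line within a known function or class.
    return source_lines[symbol.start_line : symbol.end_line + 1]
```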
12 | 13 | ![Figure 1: Recall Rate Comparisons across languages and repository sizes](recall_rate_comparison.png) 14 | ![Table 1: Effectiveness and Efficiency Comparison](effectiveness_efficiency_comparison.png) 15 | 16 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/c_sharp.scm: -------------------------------------------------------------------------------- 1 | (class_declaration 2 | name: (identifier) @name.definition.class 3 | ) @definition.class 4 | 5 | (class_declaration 6 | bases: (base_list (_) @name.reference.class) 7 | ) @reference.class 8 | 9 | (interface_declaration 10 | name: (identifier) @name.definition.interface 11 | ) @definition.interface 12 | 13 | (interface_declaration 14 | bases: (base_list (_) @name.reference.interface) 15 | ) @reference.interface 16 | 17 | (method_declaration 18 | name: (identifier) @name.definition.method 19 | ) @definition.method 20 | 21 | (object_creation_expression 22 | type: (identifier) @name.reference.class 23 | ) @reference.class 24 | 25 | (type_parameter_constraints_clause 26 | target: (identifier) @name.reference.class 27 | ) @reference.class 28 | 29 | (type_constraint 30 | type: (identifier) @name.reference.class 31 | ) @reference.class 32 | 33 | (variable_declaration 34 | type: (identifier) @name.reference.class 35 | ) @reference.class 36 | 37 | (invocation_expression 38 | function: 39 | (member_access_expression 40 | name: (identifier) @name.reference.send 41 | ) 42 | ) @reference.send 43 | 44 | (namespace_declaration 45 | name: (identifier) @name.definition.module 46 | ) @definition.module 47 | -------------------------------------------------------------------------------- /bench/repos.txt: -------------------------------------------------------------------------------- 1 | // Boringtun, rust, Small, 7937 2 | https://github.com/cloudflare/boringtun 3 | 4 | // Rosenpass, rust, Mid, 30063 5 | https://github.com/rosenpass/rosenpass 6 | 7 | // Neon, rust, Large, 375060 8 | https://github.com/neondatabase/neon 9 | 10 | // Python p2p, python, Small, 1189 11 | https://github.com/GianisTsol/python-p2p 12 | 13 | // Magic Wormhole, python, Mid, 27538 14 | https://github.com/magic-wormhole/magic-wormhole 15 | 16 | // Zulip, python, Large, 499528 17 | https://github.com/zulip/zulip 18 | 19 | // Anubis, golang, Small, 5163 20 | https://github.com/TecharoHQ/anubis 21 | 22 | // FSCrypt, golang, Mid, 18815 23 | https://github.com/google/fscrypt 24 | 25 | // Ethereum, golang, Large, 272576 26 | https://github.com/ethereum/go-ethereum 27 | 28 | // Mastercard Client Encryption, java, Small, 8654 29 | https://github.com/Mastercard/client-encryption-java 30 | 31 | // Keycloak, java, Large, 955351 32 | https://github.com/keycloak/keycloak 33 | 34 | // Cryptomator, java, Mid, 38883 35 | https://github.com/cryptomator/cryptomator 36 | 37 | // Padloc, js/ts, Large, 136233 38 | https://github.com/padloc/padloc 39 | 40 | // Vaultwarden, js/ts, Mid, 62609 41 | https://github.com/dani-garcia/vaultwarden 42 | 43 | // Swifty, js/ts, Small, 8295 44 | https://github.com/swiftyapp/swifty 45 | -------------------------------------------------------------------------------- /coderetrx/tools/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains tools exposed by mcp server. 
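Use get_tool(repo_url, name) to obtain (and cache) a tool instance for a given
repository, or list_tool_class() to enumerate the available tool classes.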
3 | """ 4 | 5 | from .base import BaseTool 6 | from .find_file_by_name import FindFileByNameTool 7 | from .get_references import GetReferenceTool 8 | from .keyword_search import KeywordSearchTool 9 | from .list_dir import ListDirTool 10 | from .view_file import ViewFileTool 11 | from typing import Type 12 | import sys 13 | 14 | # Export the tools as the default 15 | __all__ = [ 16 | "FindFileByNameTool", 17 | "GetReferenceTool", 18 | "KeywordSearchTool", 19 | "ListDirTool", 20 | "ViewFileTool", 21 | ] 22 | tool_classes = [ 23 | FindFileByNameTool, 24 | GetReferenceTool, 25 | KeywordSearchTool, 26 | ListDirTool, 27 | ViewFileTool, 28 | ] 29 | tool_map: dict[str, dict[str, BaseTool]] = {} 30 | 31 | 32 | def get_tool_class(name: str) -> Type[BaseTool]: 33 | for cls in tool_classes: 34 | if getattr(cls, "name", None) == name: 35 | return cls 36 | raise ValueError(f"{name} is not a valid tool") 37 | 38 | 39 | def get_tool(repo_url: str, name: str) -> BaseTool: 40 | if tool_map.get(repo_url) is None: 41 | tool_map[repo_url] = {} 42 | if tool_map[repo_url].get(name) is None: 43 | tool_map[repo_url][name] = get_tool_class(name)(repo_url) 44 | return tool_map[repo_url][name] 45 | 46 | 47 | def list_tool_class() -> list[BaseTool]: 48 | return tool_classes 49 | -------------------------------------------------------------------------------- /coderetrx/retrieval/strategy/filter_symbol_content_by_vector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Strategy for filtering symbols using vector similarity search. 3 | """ 4 | 5 | from typing import List, override 6 | from .base import FilterByVectorStrategy 7 | from ..smart_codebase import SmartCodebase as Codebase, SimilaritySearchTargetType 8 | from coderetrx.static import Symbol 9 | 10 | 11 | class FilterSymbolContentByVectorStrategy(FilterByVectorStrategy[Symbol]): 12 | """Strategy to filter symbols using vector similarity search.""" 13 | 14 | name: str = "FILTER_SYMBOL_CONTENT_BY_VECTOR" 15 | 16 | @override 17 | def get_strategy_name(self) -> str: 18 | return self.name 19 | 20 | @override 21 | def get_target_types_for_vector(self) -> List[SimilaritySearchTargetType]: 22 | return ["symbol_content"] 23 | 24 | @override 25 | def get_collection_size(self, codebase: Codebase) -> int: 26 | return len(codebase.symbols) 27 | 28 | @override 29 | def extract_file_paths( 30 | self, elements: List[Symbol], codebase: Codebase, subdirs_or_files: List[str] 31 | ) -> List[str]: 32 | file_paths = [] 33 | for symbol in elements: 34 | if isinstance(symbol, Symbol): 35 | file_path = str(symbol.file.path) 36 | if file_path.startswith(tuple(subdirs_or_files)): 37 | file_paths.append(file_path) 38 | return list(set(file_paths)) 39 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/go.scm: -------------------------------------------------------------------------------- 1 | ( 2 | (comment)* @doc 3 | . 4 | (function_declaration 5 | name: (identifier) @name.definition.function) @definition.function 6 | (#strip! @doc "^//\\s*") 7 | (#set-adjacent! @doc @definition.function) 8 | ) 9 | 10 | ( 11 | (comment)* @doc 12 | . 13 | (method_declaration 14 | name: (field_identifier) @name.definition.method) @definition.method 15 | (#strip! @doc "^//\\s*") 16 | (#set-adjacent! 
@doc @definition.method) 17 | ) 18 | 19 | (call_expression 20 | function: [ 21 | (identifier) @name.reference.call 22 | (parenthesized_expression (identifier) @name.reference.call) 23 | (selector_expression field: (field_identifier) @name.reference.call) 24 | (parenthesized_expression (selector_expression field: (field_identifier) @name.reference.call)) 25 | ]) @reference.call 26 | 27 | (type_spec 28 | name: (type_identifier) @name.definition.type) @definition.type 29 | 30 | (type_identifier) @name.reference.type @reference.type 31 | 32 | (import_spec) @import 33 | 34 | ; Variable definitions 35 | (var_declaration 36 | (var_spec 37 | name: (identifier) @name.definition.variable)) @definition.variable 38 | 39 | (const_declaration 40 | (const_spec 41 | name: (identifier) @name.definition.variable)) @definition.variable 42 | 43 | (short_var_declaration 44 | left: (expression_list 45 | (identifier) @name.definition.variable)) @definition.variable 46 | -------------------------------------------------------------------------------- /coderetrx/retrieval/strategy/filter_keyword_by_vector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Strategy for filtering keywords using vector similarity search. 3 | """ 4 | 5 | from typing import List, override, Any 6 | from .base import FilterByVectorStrategy 7 | from ..smart_codebase import SmartCodebase as Codebase, SimilaritySearchTargetType 8 | from coderetrx.static import Keyword 9 | 10 | 11 | class FilterKeywordByVectorStrategy(FilterByVectorStrategy[Keyword]): 12 | """Strategy to filter keywords using vector similarity search.""" 13 | 14 | name: str = "FILTER_KEYWORD_BY_VECTOR" 15 | 16 | @override 17 | def get_strategy_name(self) -> str: 18 | return self.name 19 | 20 | @override 21 | def get_target_types_for_vector(self) -> List[SimilaritySearchTargetType]: 22 | return ["keyword"] 23 | 24 | @override 25 | def get_collection_size(self, codebase: Codebase) -> int: 26 | return len(codebase.keywords) 27 | 28 | @override 29 | def extract_file_paths( 30 | self, elements: List[Keyword], codebase: Codebase, subdirs_or_files: List[str] 31 | ) -> List[str]: 32 | referenced_paths = set() 33 | for item in elements: 34 | if isinstance(item, Keyword) and item.referenced_by: 35 | for ref_file in item.referenced_by: 36 | if str(ref_file.path).startswith(tuple(subdirs_or_files)): 37 | referenced_paths.add(str(ref_file.path)) 38 | return list(referenced_paths) 39 | -------------------------------------------------------------------------------- /coderetrx/utils/git.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | from pathlib import Path 5 | from typing import Tuple 6 | from git import Repo, GitCommandError 7 | 8 | import logging 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def clone_repo_if_not_exists(repo_url: str, target_path: str) -> str: 14 | logger.info(f"clone repo: {repo_url}") 15 | repo_path = Path(target_path) 16 | if not repo_path.exists(): 17 | try: 18 | logger.info(f"Cloning repository {repo_url} into {repo_path}") 19 | # todo: remove ssh clone 20 | Repo.clone_from( 21 | repo_url, 22 | repo_path, 23 | env=( 24 | {"GIT_SSH_COMMAND": f'ssh -i {Path(__file__).parent/"sg_rsa"}'} 25 | if repo_url.endswith(".git") 26 | else None 27 | ), 28 | depth=1, 29 | ) 30 | except GitCommandError as e: 31 | repo_path.unlink(missing_ok=True) 32 | raise Exception(f"Clone failed: {e}") 33 | return repo_path.as_posix() 34 | 35 | 
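# Example (illustrative): get_repo_id derives a stable identifier from a clone URL,
# e.g. "https://github.com/org/repo.git" -> "org_repo"; it is used to name
# per-repository data directories (see coderetrx.utils.path.get_repo_path).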
36 | def get_repo_id(repo_url: str) -> str: 37 | if repo_url.startswith("http"): 38 | repo_id = "_".join(repo_url.split("/")[-2:]) 39 | else: 40 | # todo: only for backward compatibility, we need to remove this in the future 41 | repo_id = repo_url.split("/")[-1] 42 | repo_id = repo_id.replace(".git", "") 43 | return repo_id 44 | 45 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/ruby.scm: -------------------------------------------------------------------------------- 1 | ; Method definitions 2 | 3 | ( 4 | (comment)* @doc 5 | . 6 | [ 7 | (method 8 | name: (_) @name.definition.method) @definition.method 9 | (singleton_method 10 | name: (_) @name.definition.method) @definition.method 11 | ] 12 | (#strip! @doc "^#\\s*") 13 | (#select-adjacent! @doc @definition.method) 14 | ) 15 | 16 | (alias 17 | name: (_) @name.definition.method) @definition.method 18 | 19 | (setter 20 | (identifier) @ignore) 21 | 22 | ; Class definitions 23 | 24 | ( 25 | (comment)* @doc 26 | . 27 | [ 28 | (class 29 | name: [ 30 | (constant) @name.definition.class 31 | (scope_resolution 32 | name: (_) @name.definition.class) 33 | ]) @definition.class 34 | (singleton_class 35 | value: [ 36 | (constant) @name.definition.class 37 | (scope_resolution 38 | name: (_) @name.definition.class) 39 | ]) @definition.class 40 | ] 41 | (#strip! @doc "^#\\s*") 42 | (#select-adjacent! @doc @definition.class) 43 | ) 44 | 45 | ; Module definitions 46 | 47 | ( 48 | (module 49 | name: [ 50 | (constant) @name.definition.module 51 | (scope_resolution 52 | name: (_) @name.definition.module) 53 | ]) @definition.module 54 | ) 55 | 56 | ; Calls 57 | 58 | (call method: (identifier) @name.reference.call) @reference.call 59 | 60 | ( 61 | [(identifier) (constant)] @name.reference.call @reference.call 62 | (#is-not? local) 63 | (#not-match? 
@name.reference.call "^(lambda|load|require|require_relative|__FILE__|__LINE__)$") 64 | ) 65 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/php.scm: -------------------------------------------------------------------------------- 1 | ;; Class definitions 2 | (class_declaration 3 | name: (name) @name.definition.class) @definition.class 4 | 5 | ;; Interface definitions 6 | (interface_declaration 7 | name: (name) @name.definition.interface) @definition.interface 8 | 9 | ;; Trait definitions 10 | (trait_declaration 11 | name: (name) @name.definition.class) @definition.class 12 | 13 | ;; Function definitions 14 | (function_definition 15 | name: (name) @name.definition.function) @definition.function 16 | 17 | ;; Method definitions 18 | (method_declaration 19 | name: (name) @name.definition.method) @definition.method 20 | 21 | ;; Namespace imports 22 | (namespace_use_declaration 23 | (namespace_use_clause 24 | (qualified_name) @name.import)) @import 25 | 26 | ;; Object creation (class references) 27 | (object_creation_expression 28 | [ 29 | (qualified_name (name) @name.reference.class) 30 | (variable_name (name) @name.reference.class) 31 | ]) @reference.class 32 | 33 | ;; Function calls 34 | (function_call_expression 35 | function: [ 36 | (qualified_name (name) @name.reference.call) 37 | (variable_name (name)) @name.reference.call 38 | ]) @reference.call 39 | 40 | ;; Static method calls 41 | (scoped_call_expression 42 | name: (name) @name.reference.call) @reference.call 43 | 44 | ;; Member method calls 45 | (member_call_expression 46 | name: (name) @name.reference.call) @reference.call 47 | 48 | ;; Property declarations (class variables) 49 | (property_declaration 50 | (property_element 51 | (variable_name) @name.definition.variable)) @definition.variable 52 | -------------------------------------------------------------------------------- /scripts/view_chunks.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import sys 3 | 4 | from coderetrx import Codebase 5 | from coderetrx.static.codebase import CodeChunk 6 | from coderetrx.retrieval.factory import CodebaseFactory 7 | from coderetrx.retrieval.smart_codebase import SmartCodebaseSettings 8 | 9 | 10 | def load_coderetrx_codebase(codebase_path: str) -> Codebase: 11 | abs_repo_path = Path(codebase_path).resolve() 12 | 13 | smart_settings = SmartCodebaseSettings() 14 | smart_settings.symbol_codeline_embedding = False 15 | smart_settings.keyword_embedding = False 16 | smart_settings.symbol_content_embedding = False 17 | smart_settings.symbol_name_embedding = False 18 | 19 | codebase = CodebaseFactory.new( 20 | str(abs_repo_path), 21 | abs_repo_path, 22 | smart_settings, 23 | ) 24 | return codebase 25 | 26 | 27 | def show_code_chunks(codebase: Codebase, file_name: str): 28 | for chunk in codebase.get_splited_distinct_chunks(100): 29 | # for chunk in codebase.all_chunks: 30 | if str(chunk.src.path) != file_name: 31 | continue 32 | print("Type:", chunk.type) 33 | print() 34 | print( 35 | chunk.ast_codeblock( 36 | show_line_numbers=True, 37 | zero_based_line_numbers=False, 38 | show_imports=True, 39 | ) 40 | ) 41 | print("=" * 40) 42 | 43 | 44 | def main(): 45 | codebase_path = sys.argv[1] 46 | codebase = load_coderetrx_codebase(codebase_path) 47 | 48 | file_name = sys.argv[2] 49 | show_code_chunks(codebase, file_name) 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | 
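# Example invocation (paths are illustrative):
#   python scripts/view_chunks.py /path/to/repo src/module.py
# The second argument must match the file path as stored in the codebase (chunk.src.path).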
-------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tests/go.scm: -------------------------------------------------------------------------------- 1 | ; Functions names start with `Test` 2 | ( 3 | ( 4 | (function_declaration name: (_) @run 5 | (#match? @run "^Test.*")) 6 | ) @_ 7 | (#set! tag go-test) 8 | ) 9 | 10 | ; `go:generate` comments 11 | ( 12 | ((comment) @_comment @run 13 | (#match? @_comment "^//go:generate")) 14 | (#set! tag go-generate) 15 | ) 16 | 17 | ; `t.Run` 18 | ( 19 | ( 20 | (call_expression 21 | function: ( 22 | selector_expression 23 | field: _ @run @_name 24 | (#eq? @_name "Run") 25 | ) 26 | arguments: ( 27 | argument_list 28 | . 29 | (interpreted_string_literal) @_subtest_name 30 | . 31 | (func_literal 32 | parameters: ( 33 | parameter_list 34 | (parameter_declaration 35 | name: (identifier) @_param_name 36 | type: (pointer_type 37 | (qualified_type 38 | package: (package_identifier) @_pkg 39 | name: (type_identifier) @_type 40 | (#eq? @_pkg "testing") 41 | (#eq? @_type "T") 42 | ) 43 | ) 44 | ) 45 | ) 46 | ) @_second_argument 47 | ) 48 | ) 49 | ) @_ 50 | (#set! tag go-subtest) 51 | ) 52 | 53 | ; Functions names start with `Benchmark` 54 | ( 55 | ( 56 | (function_declaration name: (_) @run @_name 57 | (#match? @_name "^Benchmark.+")) 58 | ) @_ 59 | (#set! tag go-benchmark) 60 | ) 61 | 62 | ; Functions names start with `Fuzz` 63 | ( 64 | ( 65 | (function_declaration name: (_) @run @_name 66 | (#match? @_name "^Fuzz")) 67 | ) @_ 68 | (#set! tag go-fuzz) 69 | ) 70 | 71 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "coderetrx" 3 | version = "0.2.0beta" 4 | description = "Library for Code Analysis via Static Analysis and LLMs" 5 | readme = "README.md" 6 | license = { file = "LICENSE" } 7 | requires-python = ">=3.12" 8 | dependencies = [ 9 | "attrs>=25.3.0", 10 | "git-python>=1.0.3", 11 | "httpx>=0.27.2", 12 | "ignore-python>=0.2.0", 13 | "json-repair>=0.46.0", 14 | "nanoid>=2.0.0", 15 | "pydantic>=2.11.4", 16 | "pydantic-settings>=2.9.1", 17 | "python-dotenv>=1.1.0", 18 | "python-git>=2018.2.1", 19 | "tomli-w>=1.2.0", 20 | "tree-sitter>=0.25.2", 21 | "tree-sitter-language-pack>=0.7.3", 22 | "openai>=1.0.0", 23 | "qdrant-client>=1.12.0", 24 | "tenacity>=9.1.2", 25 | "aiofiles>=24.1.0", 26 | "tiktoken>=0.9.0", 27 | "filetype>=1.2.0", 28 | "mcp>=1.13.1", 29 | "starlette>=0.47.3", 30 | "anyio>=4.9.0", 31 | "pytest-asyncio>=1.2.0", 32 | ] 33 | 34 | [tool.setuptools.packages.find] 35 | where = ["."] 36 | include = ["coderetrx*"] 37 | exclude = [] 38 | namespaces = false 39 | 40 | [dependency-groups] 41 | dev = [ 42 | "httpx[socks]>=0.28.1", 43 | "ipython>=9.2.0", 44 | "pytest>=8.3.5", 45 | "rich>=14.0.0", 46 | "black", 47 | ] 48 | 49 | [project.optional-dependencies] 50 | stats = ["tiktoken>=0.9.0"] 51 | cli = ["typer>=0.16.0"] 52 | redis = ["redis>=5.0.0"] 53 | chromadb = ["chromadb>=0.4.0"] 54 | 55 | 56 | [tool.pytest.ini_options] 57 | asyncio_default_fixture_loop_scope = "function" 58 | 59 | [build-system] 60 | requires = ["setuptools>=61.0"] 61 | build-backend = "setuptools.build_meta" 62 | 63 | [tool.setuptools] 64 | include-package-data = true 65 | 66 | [tool.setuptools.package-data] 67 | "coderetrx" = ["**/*.scm", "**/*.py"] 68 | -------------------------------------------------------------------------------- 
/coderetrx/retrieval/strategy/filter_filename_by_llm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Strategy for filtering filenames using LLM. 3 | """ 4 | 5 | from typing import List, override 6 | from .base import FilterByLLMStrategy, StrategyExecuteResult 7 | from ..smart_codebase import SmartCodebase as Codebase, LLMMapFilterTargetType 8 | from coderetrx.static import File 9 | 10 | 11 | class FilterFilenameByLLMStrategy(FilterByLLMStrategy[File]): 12 | """Strategy to filter filenames using LLM.""" 13 | 14 | @override 15 | def get_strategy_name(self) -> str: 16 | return "FILTER_FILENAME_BY_LLM" 17 | 18 | @override 19 | def get_target_type(self) -> LLMMapFilterTargetType: 20 | return "file_name" 21 | 22 | @override 23 | def extract_file_paths(self, elements: List[File], codebase: Codebase) -> List[str]: 24 | return [str(file.path) for file in elements] 25 | 26 | @override 27 | async def execute( 28 | self, 29 | codebase: Codebase, 30 | prompt: str, 31 | subdirs_or_files: List[str], 32 | target_type: str = "symbol_content", 33 | ) -> StrategyExecuteResult: 34 | prompt = f""" 35 | A file with this path is highly likely to contain content that matches the following criteria: 36 | 37 | {prompt} 38 | 39 | 40 | The objective of this requirement is to preliminarily identify files based on their paths that are likely to meet specific content criteria. 41 | Files with matching paths will proceed to a deeper analysis in the content filter (content_criterias) at a later stage (not in this run). 42 | 43 | """ 44 | return await super().execute(codebase, prompt, subdirs_or_files, target_type) 45 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/csharp.scm: -------------------------------------------------------------------------------- 1 | ; Based on https://github.com/tree-sitter/tree-sitter-c-sharp/blob/master/queries/tags.scm 2 | ; MIT License. 
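; Captures class/interface/method/namespace/variable definitions, class and
; interface references, member-access call sites, and using directives.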
3 | 4 | (class_declaration name: (identifier) @name.definition.class) @definition.class 5 | 6 | (class_declaration (base_list (_) @name.reference.class)) @reference.class 7 | 8 | (interface_declaration name: (identifier) @name.definition.interface) @definition.interface 9 | 10 | (interface_declaration (base_list (_) @name.reference.interface)) @reference.interface 11 | 12 | (method_declaration name: (identifier) @name.definition.method) @definition.method 13 | 14 | (object_creation_expression type: (identifier) @name.reference.class) @reference.class 15 | 16 | (type_parameter_constraints_clause (identifier) @name.reference.class) @reference.class 17 | 18 | (type_parameter_constraint (type type: (identifier) @name.reference.class)) @reference.class 19 | 20 | (variable_declaration type: (identifier) @name.reference.class) @reference.class 21 | 22 | (invocation_expression function: (member_access_expression name: (identifier) @name.reference.send)) @reference.send 23 | 24 | (namespace_declaration name: (identifier) @name.definition.module) @definition.module 25 | 26 | (namespace_declaration name: (identifier) @name.definition.module) @module 27 | 28 | (using_directive) @import 29 | 30 | ; Variable definitions - field and variable declarations 31 | (field_declaration 32 | (variable_declaration 33 | (variable_declarator 34 | (identifier) @name.definition.variable))) @definition.variable 35 | 36 | (local_declaration_statement 37 | (variable_declaration 38 | (variable_declarator 39 | (identifier) @name.definition.variable))) @definition.variable 40 | 41 | (property_declaration 42 | name: (identifier) @name.definition.variable) @definition.variable 43 | -------------------------------------------------------------------------------- /coderetrx/utils/path.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from coderetrx.utils.git import get_repo_id 3 | import os 4 | 5 | 6 | def get_data_dir(): 7 | """Returns the path to the data directory.""" 8 | data_dir = os.getenv("DATA_DIR") 9 | if data_dir: 10 | data_dir = Path(data_dir) 11 | else: 12 | data_dir = Path(__file__).parent.parent.parent / ".data" 13 | data_dir.mkdir(parents=True, exist_ok=True) 14 | return data_dir 15 | 16 | def get_repo_path(repo_url): 17 | repo_path = get_data_dir() / "repos" / get_repo_id(repo_url) 18 | return repo_path 19 | 20 | def get_cache_dir(): 21 | cache_dir = os.getenv("CACHE_DIR") 22 | if cache_dir: 23 | cache_dir = Path(cache_dir) 24 | else: 25 | cache_dir = Path(__file__).parent.parent.parent / ".cache" 26 | (cache_dir / "llm").mkdir(parents=True, exist_ok=True) 27 | (cache_dir / "embedding").mkdir(parents=True, exist_ok=True) 28 | cache_dir.mkdir(parents=True, exist_ok=True) 29 | return cache_dir 30 | 31 | def safe_join(path1: str | Path, path2: str | Path) -> Path: 32 | if isinstance(path1, str): 33 | path1 = Path(path1) 34 | if isinstance(path2, str): 35 | path2 = Path(path2) 36 | result = path1 / path2 37 | # Basic check: disallow absolute path2 that would ignore path1 38 | if not result.is_relative_to(path1): 39 | raise ValueError(f"Path {path2} is not relative to the base directory {path1}") 40 | # Symlink-aware check: resolve both sides (non-strict to allow non-existing leaf) 41 | resolved_base = path1.resolve(strict=False) 42 | resolved_result = (path1 / path2).resolve(strict=False) 43 | if not resolved_result.is_relative_to(resolved_base): 44 | raise ValueError( 45 | f"Path {path2} escapes base directory {path1} via symlink resolution" 46 | ) 
47 | return result 48 | -------------------------------------------------------------------------------- /test/impl/default/test_determine_strategy.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import asyncio 3 | from coderetrx.retrieval.code_recall import _determine_strategy_by_llm 4 | import json 5 | import logging 6 | 7 | logger = logging.getLogger(__name__) 8 | logging.basicConfig(level=logging.INFO) 9 | 10 | def generate_prompts(limit=10): 11 | all_prompts = [] 12 | bug_name = [] 13 | feature_file = "features/test.json" 14 | try: 15 | with open(feature_file, 'r', encoding='utf-8') as f: 16 | features = json.load(f) 17 | for feature in features: 18 | if len(all_prompts) >= limit: 19 | break 20 | # Extract filter prompts from resources 21 | for resource in feature.get('resources', []): 22 | if resource.get('type') == 'ToolCallingResource': 23 | filter_prompt = resource.get('tool_input_kwargs', {}).get('filter_prompt') 24 | if filter_prompt: 25 | all_prompts.append(filter_prompt) 26 | bug_name.append(feature.get('name')) 27 | if len(all_prompts) >= limit: 28 | break 29 | except Exception as e: 30 | logger.warning(f"Failed to load feature file {feature_file}: {e}") 31 | return all_prompts[:limit], bug_name[:limit] 32 | 33 | class TestDetermineStrategy(unittest.TestCase): 34 | def test_determine_strategy(self): 35 | async def run_test(): 36 | res = generate_prompts(25) 37 | for idx in range(len(res[0])): 38 | bug_name = res[1][idx] 39 | prompt = res[0][idx] 40 | print(f"Bug Name: {bug_name}") 41 | strategy = await _determine_strategy_by_llm(prompt) 42 | print(f"Strategy: {strategy}") 43 | asyncio.run(run_test()) 44 | 45 | if __name__ == "__main__": 46 | unittest.main() -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/typescript.scm: -------------------------------------------------------------------------------- 1 | (function_signature 2 | name: (identifier) @name.definition.function) @definition.function 3 | 4 | (method_signature 5 | name: (property_identifier) @name.definition.method) @definition.method 6 | 7 | (abstract_method_signature 8 | name: (property_identifier) @name.definition.method) @definition.method 9 | 10 | (abstract_class_declaration 11 | name: (type_identifier) @name.definition.class) @definition.class 12 | 13 | (module 14 | name: (identifier) @name.definition.module) @definition.module 15 | 16 | (interface_declaration 17 | name: (type_identifier) @name.definition.interface) @definition.interface 18 | 19 | (type_annotation 20 | (type_identifier) @name.reference.type) @reference.type 21 | 22 | (new_expression 23 | constructor: (identifier) @name.reference.class) @reference.class 24 | 25 | (function_declaration 26 | name: (identifier) @name.definition.function) @definition.function 27 | 28 | (arrow_function) @definition.function 29 | 30 | (method_definition 31 | name: (property_identifier) @name.definition.method) @definition.method 32 | 33 | (class_declaration 34 | name: (type_identifier) @name.definition.class) @definition.class 35 | 36 | (interface_declaration 37 | name: (type_identifier) @name.definition.class) @definition.class 38 | 39 | (type_alias_declaration 40 | name: (type_identifier) @name.definition.type) @definition.type 41 | 42 | (enum_declaration 43 | name: (identifier) @name.definition.enum) @definition.enum 44 | 45 | (import_statement) @import 46 | 47 | ; Variable definitions - const/let/var declarations 48 | ; 
Note: We capture all variable declarations, function assignments are already captured above 49 | (lexical_declaration 50 | (variable_declarator 51 | name: (identifier) @name.definition.variable)) @definition.variable 52 | 53 | (variable_declaration 54 | (variable_declarator 55 | name: (identifier) @name.definition.variable)) @definition.variable 56 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tests/rust.scm: -------------------------------------------------------------------------------- 1 | ; Rust mod test 2 | ( 3 | (mod_item 4 | name: (_) @run 5 | (#eq? @run "tests") 6 | ) @test 7 | (#set! tag rust-mod-test) 8 | ) 9 | 10 | ; Rust test 11 | ( 12 | ( 13 | (attribute_item (attribute 14 | [((identifier) @_attribute) 15 | (scoped_identifier (identifier) @_attribute) 16 | ]) 17 | (#match? @_attribute "test") 18 | ) @_start 19 | . 20 | (attribute_item) * 21 | . 22 | [(line_comment) (block_comment)] * 23 | . 24 | (function_item 25 | name: (_) @run @_test_name 26 | body: _ 27 | ) @_end 28 | ) @test 29 | (#set! tag rust-test) 30 | ) 31 | 32 | ; Rust doc test 33 | ; ( 34 | ; ( 35 | ; (line_comment) * 36 | ; (line_comment 37 | ; doc: (_) @_comment_content 38 | ; ) @_start @run 39 | ; (#match? @_comment_content "```") 40 | ; . 41 | ; (line_comment) * 42 | ; . 43 | ; (line_comment 44 | ; doc: (_) @_end_comment_content 45 | ; ) @_end_code_block 46 | ; (#match? @_end_comment_content "```") 47 | ; . 48 | ; (line_comment) * 49 | ; (attribute_item) * 50 | ; . 51 | ; [(function_item 52 | ; name: (_) @_doc_test_name 53 | ; body: _ 54 | ; ) (function_signature_item 55 | ; name: (_) @_doc_test_name 56 | ; ) (struct_item 57 | ; name: (_) @_doc_test_name 58 | ; ) (enum_item 59 | ; name: (_) @_doc_test_name 60 | ; body: _ 61 | ; ) ( 62 | ; (attribute_item) ? 63 | ; (macro_definition 64 | ; name: (_) @_doc_test_name) 65 | ; ) (mod_item 66 | ; name: (_) @_doc_test_name 67 | ; )] @_end 68 | ; ) @test 69 | ; (#set! 
tag rust-doc-test) 70 | ; ) 71 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/cpp.scm: -------------------------------------------------------------------------------- 1 | (struct_specifier name: (type_identifier) @name.definition.class body:(_)) @definition.class 2 | 3 | (declaration type: (union_specifier name: (type_identifier) @name.definition.class)) @definition.class 4 | 5 | (function_definition 6 | declarator: (function_declarator 7 | declarator: (identifier) @name.definition.function)) @definition.function 8 | 9 | (function_definition 10 | declarator: (pointer_declarator 11 | declarator: (function_declarator 12 | declarator: (identifier) @name.definition.function))) @definition.function 13 | 14 | (function_definition 15 | declarator: (function_declarator 16 | declarator: (field_identifier) @name.definition.function)) @definition.function 17 | 18 | (function_definition 19 | declarator: (function_declarator 20 | declarator: (qualified_identifier scope: (namespace_identifier) @scope name: (identifier) @name.definition.method))) @definition.method 21 | 22 | (function_definition 23 | declarator: (reference_declarator 24 | (function_declarator 25 | declarator: (qualified_identifier scope: (namespace_identifier) @scope name: (identifier) @name.definition.method)))) @definition.method 26 | 27 | (type_definition declarator: (type_identifier) @name.definition.type) @definition.type 28 | 29 | (enum_specifier name: (type_identifier) @name.definition.type) @definition.type 30 | 31 | (class_specifier name: (type_identifier) @name.definition.class) @definition.class 32 | 33 | (preproc_include) @import 34 | 35 | ; Variable definitions - all declarations 36 | (declaration 37 | declarator: (identifier) @name.definition.variable) @definition.variable 38 | 39 | (declaration 40 | declarator: (init_declarator 41 | declarator: (identifier) @name.definition.variable)) @definition.variable 42 | 43 | (declaration 44 | declarator: (pointer_declarator 45 | declarator: (identifier) @name.definition.variable)) @definition.variable 46 | 47 | (declaration 48 | declarator: (reference_declarator 49 | (identifier) @name.definition.variable)) @definition.variable 50 | -------------------------------------------------------------------------------- /scripts/benchmark.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import re 3 | import subprocess 4 | import time 5 | from pathlib import Path 6 | 7 | async def run_command(cmd): 8 | process = await asyncio.create_subprocess_exec(*cmd, cwd=Path(__file__).parent.parent) 9 | return await process.wait() 10 | 11 | def parse_repos(): 12 | repos = [] 13 | with open("bench/repos.txt", 'r') as f: 14 | content = f.read() 15 | 16 | blocks = [block.strip() for block in content.split('\n\n') if block.strip()] 17 | for block in blocks: 18 | lines = block.strip().split('\n') 19 | if len(lines) >= 2 and lines[1].startswith('https://'): 20 | repos.append(lines[1].strip()) 21 | 22 | return repos 23 | 24 | async def main(): 25 | repos = parse_repos() 26 | modes = ["file_name", "symbol_name", "line_per_symbol", "precise", "auto"] 27 | 28 | print(f"Running benchmark on {len(repos)} repositories with {len(modes)} modes") 29 | 30 | repo_url = "https://github.com/ollama/ollama" 31 | for mode in modes: 32 | print(f" Running {mode} mode...") 33 | cmd = ["uv", "run", "scripts/code_retriever.py", "-l", "9", "-f", "--mode", mode, "--repo", repo_url] 34 | 
exit_code = await run_command(cmd) 35 | if exit_code != 0: 36 | print(f" Failed with exit code {exit_code}") 37 | 38 | # for i, repo_url in enumerate(repos, 1): 39 | # print(f"\n[{i}/{len(repos)}] Processing {repo_url}") 40 | # 41 | # for mode in modes: 42 | # print(f" Running {mode} mode...") 43 | # cmd = ["uv", "run", "scripts/code_retriever.py", "-l", "9", "-f", "--mode", mode, "--repo", repo_url, "-t"] 44 | # exit_code = await run_command(cmd) 45 | # if exit_code != 0: 46 | # print(f" Failed with exit code {exit_code}") 47 | # 48 | # break 49 | 50 | print("\nRunning analyze_code_reports...") 51 | await run_command(["uv", "run", "scripts/analyze_code_reports.py"]) 52 | 53 | print("Benchmark completed") 54 | 55 | if __name__ == "__main__": 56 | asyncio.run(main()) -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/rust.scm: -------------------------------------------------------------------------------- 1 | ; ADT definitions 2 | 3 | (struct_item 4 | name: (type_identifier) @name.definition.class) @definition.class 5 | 6 | (enum_item 7 | name: (type_identifier) @name.definition.class) @definition.class 8 | 9 | (union_item 10 | name: (type_identifier) @name.definition.class) @definition.class 11 | 12 | ; type aliases 13 | 14 | (type_item 15 | name: (type_identifier) @name.definition.class) @definition.class 16 | 17 | ; method definitions 18 | 19 | ; (declaration_list 20 | ; (function_item 21 | ; name: (identifier) @name.definition.method) @definition.method) 22 | 23 | ; function definitions 24 | 25 | (function_item 26 | name: (identifier) @name.definition.function) @definition.function 27 | 28 | ; trait definitions 29 | (trait_item 30 | name: (type_identifier) @name.definition.interface) @definition.interface 31 | 32 | ; module definitions 33 | (mod_item 34 | name: (identifier) @name.definition.module) @definition.module 35 | 36 | ; macro definitions 37 | 38 | (macro_definition 39 | name: (identifier) @name.definition.macro) @definition.macro 40 | 41 | ; references 42 | 43 | (call_expression 44 | function: (identifier) @name.reference.call) @reference.call 45 | 46 | (call_expression 47 | function: (field_expression 48 | field: (field_identifier) @name.reference.call)) @reference.call 49 | 50 | (macro_invocation 51 | macro: (identifier) @name.reference.call) @reference.call 52 | 53 | ; implementations 54 | 55 | (impl_item 56 | trait: (type_identifier) @name.reference.implementation) @reference.implementation 57 | 58 | (impl_item 59 | type: (type_identifier) @name.reference.implementation 60 | !trait) @reference.implementation 61 | 62 | ; handle imports 63 | 64 | ( use_declaration ) @import 65 | 66 | ; Variable definitions 67 | (let_declaration 68 | pattern: (identifier) @name.definition.variable) @definition.variable 69 | 70 | (const_item 71 | name: (identifier) @name.definition.variable) @definition.variable 72 | 73 | (static_item 74 | name: (identifier) @name.definition.variable) @definition.variable 75 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/elixir.scm: -------------------------------------------------------------------------------- 1 | ; Definitions 2 | 3 | ; * modules and protocols 4 | (call 5 | target: (identifier) @ignore 6 | (arguments (alias) @name.definition.module) 7 | (#match? 
@ignore "^(defmodule|defprotocol)$")) @definition.module 8 | 9 | ; * functions/macros 10 | (call 11 | target: (identifier) @ignore 12 | (arguments 13 | [ 14 | ; zero-arity functions with no parentheses 15 | (identifier) @name.definition.function 16 | ; regular function clause 17 | (call target: (identifier) @name.definition.function) 18 | ; function clause with a guard clause 19 | (binary_operator 20 | left: (call target: (identifier) @name.definition.function) 21 | operator: "when") 22 | ]) 23 | (#match? @ignore "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @definition.function 24 | 25 | ; References 26 | 27 | ; ignore calls to kernel/special-forms keywords 28 | (call 29 | target: (identifier) @ignore 30 | (#match? @ignore "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp|defmodule|defprotocol|defimpl|defstruct|defexception|defoverridable|alias|case|cond|else|for|if|import|quote|raise|receive|require|reraise|super|throw|try|unless|unquote|unquote_splicing|use|with)$")) 31 | 32 | ; ignore module attributes 33 | (unary_operator 34 | operator: "@" 35 | operand: (call 36 | target: (identifier) @ignore)) 37 | 38 | ; * function call 39 | (call 40 | target: [ 41 | ; local 42 | (identifier) @name.reference.call 43 | ; remote 44 | (dot 45 | right: (identifier) @name.reference.call) 46 | ]) @reference.call 47 | 48 | ; * pipe into function call 49 | (binary_operator 50 | operator: "|>" 51 | right: (identifier) @name.reference.call) @reference.call 52 | 53 | ; * modules 54 | (alias) @name.reference.module @reference.module 55 | 56 | (call 57 | target: (identifier)@_cap 58 | (#any-of? @_cap "import" "use" "require"))@import 59 | 60 | ; Variable definitions - module attributes 61 | (unary_operator 62 | operator: "@" 63 | operand: (call 64 | target: (identifier) @name.definition.variable)) @definition.variable 65 | -------------------------------------------------------------------------------- /coderetrx/retrieval/strategy/filter_dependency_by_llm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Strategy for filtering dependencies using LLM. 
3 | """ 4 | 5 | from typing import List, override 6 | from .base import FilterByLLMStrategy, StrategyExecuteResult 7 | from ..smart_codebase import SmartCodebase as Codebase, LLMMapFilterTargetType 8 | from coderetrx.static import Dependency, Symbol 9 | 10 | 11 | class FilterDependencyByLLMStrategy(FilterByLLMStrategy[Dependency]): 12 | """Strategy to filter dependencies by LLM and retrieve code chunks that use these dependencies.""" 13 | 14 | @override 15 | def get_strategy_name(self) -> str: 16 | return "FILTER_DEPENDENCY_BY_LLM" 17 | 18 | @override 19 | def get_target_type(self) -> LLMMapFilterTargetType: 20 | return "dependency_name" 21 | 22 | @override 23 | def extract_file_paths( 24 | self, elements: List[Dependency], codebase: Codebase 25 | ) -> List[str]: 26 | """Extract file paths from dependencies by getting files that import these dependencies.""" 27 | file_paths = [] 28 | for dependency in elements: 29 | if isinstance(dependency, Dependency): 30 | file_paths.extend([str(f.path) for f in dependency.imported_by]) 31 | elif isinstance(dependency, Symbol) and dependency.type == "dependency": 32 | file_paths.append(str(dependency.file.path)) 33 | return list(set(file_paths)) 34 | 35 | @override 36 | async def execute( 37 | self, 38 | codebase: Codebase, 39 | prompt: str, 40 | subdirs_or_files: List[str], 41 | target_type: str = "symbol_content", 42 | ) -> StrategyExecuteResult: 43 | enhanced_prompt = f""" 44 | A dependency that matches the following criteria is highly likely to be relevant: 45 | 46 | {prompt} 47 | 48 | 49 | The objective is to identify dependencies based on their names that match the specified criteria. 50 | Files that import these matching dependencies will be retrieved for further analysis. 51 | Focus on dependency names, package names, module names, and library names that are relevant to the criteria. 52 | 53 | """ 54 | return await super().execute( 55 | codebase, enhanced_prompt, subdirs_or_files, target_type 56 | ) 57 | -------------------------------------------------------------------------------- /STRATEGIES.md: -------------------------------------------------------------------------------- 1 | # Coarse Recall Strategies 2 | 3 | In the coarse recall stage, multiple strategies are used to efficiently retrieve potentially relevant code snippets at low cost. The retrieved results will then be further processed in the refined filter stage to improve precision and accuracy. 4 | 5 | ## Available Strategies 6 | 7 | - **`file_name`**: The fastest filtering strategy, ideal for coarse recall by directly determining relevance based on filenames. This approach is highly effective for cases where the query can be matched through filenames alone, such as "retrieve all configuration files." Its simplicity makes it extremely efficient for structural queries and file discovery. 8 | 9 | - **`symbol_name`**: A filtering approach that focuses on symbol names (e.g., function or class names) during coarse recall. This method excels in cases like "retrieve functions implementing cryptographic algorithms," where relevance can be inferred directly from the symbol name. It balances recall and precision for targeted queries at the symbol level. 10 | 11 | - **`line_per_symbol`**: A high-accuracy filtering strategy that leverages line-level vector search combined with LLM analysis. It identifies relevance by focusing on the most relevant code lines (`top-k`) within a function body. 
This method is particularly effective for complex cases where understanding the function's implementation is necessary, such as "retrieve code implementing authentication logic." While powerful, it's also versatile enough to handle most queries effectively. Our benchmarking shows that the algorithm achieves over 90% recall with approximately 25% of the computational cost, making it both precise and efficient. 12 | 13 | - **`auto`**: Automatically selects the optimal filtering strategy based on query complexity, routing requests to the most appropriate method. For the majority of cases, `line_per_symbol` is a well-balanced and reliable choice, but `file_name` or `symbol_name` can also be explicitly used for scenarios where they excel. 14 | 15 | ## Performance Characteristics 16 | 17 | - **Speed**: `file_name` > `symbol_name` > `auto` > `line_per_symbol` > `precise` 18 | - **Accuracy**: `precise` > `line_per_symbol` > `auto` > `symbol_name` > `file_name` 19 | - **Cost**: `file_name` < `line_per_symbol` < `auto` < `symbol_name` < `precise` 20 | 21 | Use `file_name` for structural queries, `symbol_name` for API search, `line_per_symbol` for specific code related analysis, `auto` for general purpose, and `precise` for ground truth. 22 | 23 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/hcl.scm: -------------------------------------------------------------------------------- 1 | ;; Based on https://github.com/tree-sitter-grammars/tree-sitter-hcl/blob/main/make_grammar.js 2 | ;; Which has Apache 2.0 License 3 | ;; tags.scm for Terraform (tree-sitter-hcl) 4 | 5 | ; === Definitions: Terraform Blocks === 6 | (block 7 | (identifier) @block_type 8 | (string_lit (template_literal) @resource_type) 9 | (string_lit (template_literal) @name.definition.resource) 10 | (body) @definition.resource 11 | ) (#eq? @block_type "resource") 12 | 13 | (block 14 | (identifier) @block_type 15 | (string_lit (template_literal) @name.definition.module) 16 | (body) @definition.module 17 | ) (#eq? @block_type "module") 18 | 19 | (block 20 | (identifier) @block_type 21 | (string_lit (template_literal) @name.definition.variable) 22 | (body) @definition.variable 23 | ) (#eq? @block_type "variable") 24 | 25 | (block 26 | (identifier) @block_type 27 | (string_lit (template_literal) @name.definition.output) 28 | (body) @definition.output 29 | ) (#eq? @block_type "output") 30 | 31 | (block 32 | (identifier) @block_type 33 | (string_lit (template_literal) @name.definition.provider) 34 | (body) @definition.provider 35 | ) (#eq? @block_type "provider") 36 | 37 | (block 38 | (identifier) @block_type 39 | (body 40 | (attribute 41 | (identifier) @name.definition.local 42 | (expression) @definition.local 43 | )+ 44 | ) 45 | ) (#eq? @block_type "locals") 46 | 47 | ; === References: Variables, Locals, Modules, Data, Resources === 48 | ((variable_expr) @ref_type 49 | (get_attr (identifier) @name.reference.variable) 50 | ) @reference.variable 51 | (#eq? @ref_type "var") 52 | 53 | ((variable_expr) @ref_type 54 | (get_attr (identifier) @name.reference.local) 55 | ) @reference.local 56 | (#eq? @ref_type "local") 57 | 58 | ((variable_expr) @ref_type 59 | (get_attr (identifier) @name.reference.module) 60 | ) @reference.module 61 | (#eq? @ref_type "module") 62 | 63 | ((variable_expr) @ref_type 64 | (get_attr (identifier) @data_source_type) 65 | (get_attr (identifier) @name.reference.data) 66 | ) @reference.data 67 | (#eq? 
@ref_type "data") 68 | 69 | ((variable_expr) @resource_type 70 | (get_attr (identifier) @name.reference.resource) 71 | ) @reference.resource 72 | (#not-eq? @resource_type "var") 73 | (#not-eq? @resource_type "local") 74 | (#not-eq? @resource_type "module") 75 | (#not-eq? @resource_type "data") 76 | (#not-eq? @resource_type "provider") 77 | (#not-eq? @resource_type "output") 78 | -------------------------------------------------------------------------------- /coderetrx/retrieval/strategy/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Strategy package for code retrieval strategies. 3 | 4 | This package contains individual strategy implementations split from the original strategies.py file. 5 | Each strategy is now in its own module for better organization and maintainability. 6 | """ 7 | 8 | from .base import ( 9 | RecallStrategy, 10 | StrategyExecuteResult, 11 | RecallStrategyExecutor, 12 | FilterByLLMStrategy, 13 | FilterByVectorStrategy, 14 | FilterByVectorAndLLMStrategy, 15 | AdaptiveFilterByVectorAndLLMStrategy, 16 | deduplicate_elements, 17 | rel_path, 18 | ) 19 | from .factory import StrategyFactory 20 | 21 | from .filter_filename_by_llm import FilterFilenameByLLMStrategy 22 | from .filter_symbol_name_by_llm import FilterSymbolNameByLLMStrategy 23 | from .filter_dependency_by_llm import FilterDependencyByLLMStrategy 24 | from .filter_keyword_by_vector import FilterKeywordByVectorStrategy 25 | from .filter_symbol_content_by_vector import FilterSymbolContentByVectorStrategy 26 | from .filter_keyword_by_vector_and_llm import FilterKeywordByVectorAndLLMStrategy 27 | from .filter_symbol_content_by_vector_and_llm import ( 28 | FilterSymbolContentByVectorAndLLMStrategy, 29 | ) 30 | from .adaptive_filter_keyword_by_vector_and_llm import ( 31 | AdaptiveFilterKeywordByVectorAndLLMStrategy, 32 | ) 33 | from .adaptive_filter_symbol_content_by_vector_and_llm import ( 34 | AdaptiveFilterSymbolContentByVectorAndLLMStrategy, 35 | ) 36 | from .filter_line_per_symbol_by_vector_and_llm import ( 37 | FilterLinePerSymbolByVectorAndLLMStrategy, 38 | ) 39 | 40 | __all__ = [ 41 | # Enums and Models 42 | "RecallStrategy", 43 | "StrategyExecuteResult", 44 | # Base Classes 45 | "RecallStrategyExecutor", 46 | "FilterByLLMStrategy", 47 | "FilterByVectorStrategy", 48 | "FilterByVectorAndLLMStrategy", 49 | "AdaptiveFilterByVectorAndLLMStrategy", 50 | # Concrete Strategy Implementations 51 | "FilterFilenameByLLMStrategy", 52 | "FilterSymbolNameByLLMStrategy", 53 | "FilterDependencyByLLMStrategy", 54 | "FilterKeywordByVectorStrategy", 55 | "FilterSymbolContentByVectorStrategy", 56 | "FilterKeywordByVectorAndLLMStrategy", 57 | "FilterSymbolContentByVectorAndLLMStrategy", 58 | "AdaptiveFilterKeywordByVectorAndLLMStrategy", 59 | "AdaptiveFilterSymbolContentByVectorAndLLMStrategy", 60 | "FilterLinePerSymbolByVectorAndLLMStrategy", 61 | # Factory and Utilities 62 | "StrategyFactory", 63 | "deduplicate_elements", 64 | "rel_path", 65 | ] 66 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/dart.scm: -------------------------------------------------------------------------------- 1 | (class_definition 2 | name: (identifier) @name.definition.class) @definition.class 3 | 4 | (method_signature 5 | (function_signature)) @definition.method 6 | 7 | (type_alias 8 | (type_identifier) @name.definition.type) @definition.type 9 | 10 | (method_signature 11 | (getter_signature 12 | name: 
(identifier) @name.definition.method)) @definition.method 13 | 14 | (method_signature 15 | (setter_signature 16 | name: (identifier) @name.definition.method)) @definition.method 17 | 18 | (method_signature 19 | (function_signature 20 | name: (identifier) @name.definition.method)) @definition.method 21 | 22 | (method_signature 23 | (factory_constructor_signature 24 | (identifier) @name.definition.method)) @definition.method 25 | 26 | (method_signature 27 | (constructor_signature 28 | name: (identifier) @name.definition.method)) @definition.method 29 | 30 | (method_signature 31 | (operator_signature)) @definition.method 32 | 33 | (method_signature) @definition.method 34 | 35 | (mixin_declaration 36 | (mixin) 37 | (identifier) @name.definition.mixin) @definition.mixin 38 | 39 | (extension_declaration 40 | name: (identifier) @name.definition.extension) @definition.extension 41 | 42 | (enum_declaration 43 | name: (identifier) @name.definition.enum) @definition.enum 44 | 45 | (function_signature 46 | name: (identifier) @name.definition.function) @definition.function 47 | 48 | (new_expression 49 | (type_identifier) @name.reference.class) @reference.class 50 | 51 | (initialized_variable_definition 52 | name: (identifier) 53 | value: (identifier) @name.reference.class 54 | value: (selector 55 | "!"? 56 | (argument_part 57 | (arguments 58 | (argument)*))?)?) @reference.class 59 | 60 | (assignment_expression 61 | left: (assignable_expression 62 | (identifier) 63 | (unconditional_assignable_selector 64 | "." 65 | (identifier) @name.reference.call))) @reference.call 66 | 67 | (assignment_expression 68 | left: (assignable_expression 69 | (identifier) 70 | (conditional_assignable_selector 71 | "?." 72 | (identifier) @name.reference.call))) @reference.call 73 | 74 | ((identifier) @name 75 | (selector 76 | "!"? 77 | (conditional_assignable_selector 78 | "?." (identifier) @name.reference.call)? 79 | (unconditional_assignable_selector 80 | "."? (identifier) @name.reference.call)? 81 | (argument_part 82 | (arguments 83 | (argument)*))?)* 84 | (cascade_section 85 | (cascade_selector 86 | (identifier)) @name.reference.call 87 | (argument_part 88 | (arguments 89 | (argument)*))?)?) @reference.call 90 | 91 | 92 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tests/python.scm: -------------------------------------------------------------------------------- 1 | ; subclasses of unittest.TestCase or TestCase 2 | ( 3 | (class_definition 4 | name: (identifier) @run @_unittest_class_name 5 | superclasses: (argument_list 6 | [(identifier) @_superclass 7 | (attribute (identifier) @_superclass)] 8 | ) 9 | (#eq? @_superclass "TestCase") 10 | ) @_python-unittest-class 11 | (#set! tag python-unittest-class) 12 | ) 13 | 14 | ; test methods whose names start with `test` in a TestCase 15 | ( 16 | (class_definition 17 | name: (identifier) @_unittest_class_name 18 | superclasses: (argument_list 19 | [(identifier) @_superclass 20 | (attribute (identifier) @_superclass)] 21 | ) 22 | (#eq? @_superclass "TestCase") 23 | body: (block 24 | (function_definition 25 | name: (identifier) @run @_unittest_method_name 26 | (#match? @_unittest_method_name "^test.*") 27 | ) @_python-unittest-method 28 | (#set! tag python-unittest-method) 29 | ) 30 | ) 31 | ) 32 | 33 | ; pytest functions 34 | ( 35 | (module 36 | (function_definition 37 | name: (identifier) @run @_pytest_method_name 38 | (#match? 
@_pytest_method_name "^test_") 39 | ) @_python-pytest-method 40 | ) 41 | (#set! tag python-pytest-method) 42 | ) 43 | 44 | ; decorated pytest functions 45 | ( 46 | (module 47 | (decorated_definition 48 | (decorator)+ @_decorator 49 | definition: (function_definition 50 | name: (identifier) @run @_pytest_method_name 51 | (#match? @_pytest_method_name "^test_") 52 | ) 53 | ) @_python-pytest-method 54 | ) 55 | (#set! tag python-pytest-method) 56 | ) 57 | 58 | ; pytest classes 59 | ( 60 | (module 61 | (class_definition 62 | name: (identifier) @run @_pytest_class_name 63 | (#match? @_pytest_class_name "^Test") 64 | ) 65 | (#set! tag python-pytest-class) 66 | ) 67 | ) 68 | 69 | ; pytest class methods 70 | ( 71 | (module 72 | (class_definition 73 | name: (identifier) @_pytest_class_name 74 | (#match? @_pytest_class_name "^Test") 75 | body: (block 76 | (function_definition 77 | name: (identifier) @run @_pytest_method_name 78 | (#match? @_pytest_method_name "^test") 79 | ) @_python-pytest-method 80 | (#set! tag python-pytest-method) 81 | ) 82 | ) 83 | ) 84 | ) 85 | -------------------------------------------------------------------------------- /coderetrx/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Pydantic based structured logging into a JSONL file. 2 | # The "logging" here should specifically be for dataset creation. 3 | from datetime import datetime 4 | from os import PathLike 5 | from pathlib import Path 6 | from typing import Literal, Optional, List, TYPE_CHECKING 7 | 8 | from pydantic import BaseModel, RootModel 9 | 10 | from coderetrx.static.codebase.models import IsolatedCodeChunkModel 11 | from coderetrx.static.codebase.languages import IDXSupportedTag 12 | 13 | 14 | class FilteringLog(BaseModel): 15 | type: Literal["filtering"] = "filtering" 16 | query: str 17 | total_chunks: int 18 | strategy: Literal["static", "dynamic"] 19 | limit: int 20 | filter_tags: Optional[List["IDXSupportedTag"]] = None 21 | 22 | 23 | class CodeChunkClassificationLog(BaseModel): 24 | type: Literal["code_chunk_classification"] = "code_chunk_classification" 25 | code_chunk: IsolatedCodeChunkModel 26 | classification: str 27 | rationale: str 28 | 29 | 30 | class VecSearchLog(BaseModel): 31 | type: Literal["vec_search"] = "vec_search" 32 | query: str 33 | total_retrieved: int 34 | matched_count: int 35 | strategy: Literal["static", "dynamic"] 36 | initial_limit: int 37 | final_limit: int 38 | filter_tags: Optional[List["IDXSupportedTag"]] = None 39 | success_ratio: float 40 | llm_model: str 41 | 42 | 43 | class LLMCallLog(BaseModel): 44 | type: Literal["llm_call"] = "llm_call" 45 | completion_id: str 46 | model: str 47 | completion_tokens: int 48 | prompt_tokens: int 49 | total_tokens: int 50 | call_url: str 51 | cached: bool = False 52 | 53 | 54 | class ErrLog(BaseModel): 55 | type: Literal["err"] = "err" 56 | error_type: str 57 | error: str 58 | 59 | 60 | type LogData = ( 61 | FilteringLog | CodeChunkClassificationLog | VecSearchLog | LLMCallLog | ErrLog 62 | ) 63 | 64 | 65 | class LogEntry(BaseModel): 66 | timestamp: str 67 | data: LogData 68 | 69 | 70 | def write_log(entry: LogEntry, file: PathLike | str): 71 | with open(file, "a") as f: 72 | f.write(entry.model_dump_json() + "\n") 73 | 74 | 75 | def read_logs(file: PathLike | str): 76 | with open(file, "r") as f: 77 | for line in f: 78 | yield LogEntry.model_validate_json(line) 79 | 80 | class JsonLogger: 81 | def __init__(self, file: PathLike | str): 82 | self.file = file 83 | 
Path(file).touch(exist_ok=True) 84 | 85 | def log(self, data: LogData): 86 | entry = LogEntry(timestamp=datetime.now().isoformat(), data=data) 87 | write_log(entry, self.file) 88 | 89 | if __name__ == "__main__": 90 | from rich import print 91 | 92 | print(LogEntry.model_json_schema()) 93 | -------------------------------------------------------------------------------- /coderetrx/tools/find_file_by_name.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from typing import ClassVar, List, Type 4 | from pathlib import Path 5 | from pydantic import BaseModel, Field 6 | from coderetrx.static.ripgrep import ripgrep_glob # type: ignore 7 | from coderetrx.utils.path import safe_join 8 | from coderetrx.tools.base import BaseTool 9 | import filetype 10 | 11 | 12 | class FindFileByNameArgs(BaseModel): 13 | dir_path: str = Field(description="The directory to search within") 14 | pattern: str = Field( 15 | description="pattern to search for. Based on keyword matching, wildcards are not supported and are not required." 16 | ) 17 | 18 | 19 | class FindFileByNameResult(BaseModel): 20 | path: str = Field(description="The path of matched file name") 21 | type: str = Field(description="The type of the file") 22 | 23 | @classmethod 24 | def repr(cls, entries: list["FindFileByNameResult"]): 25 | """ 26 | convert a list of FindFileByNameResult to a readable string 27 | """ 28 | if not entries: 29 | return "No result Found." 30 | if not entries[0].path: 31 | return "The SearchDirectory does not exist. Please make sure that the SearchDirectory you are investigating exists." 32 | tool_result = "" 33 | for i, file in enumerate(entries, 1): 34 | tool_result += f"# **{i}. {file.path}**\n" 35 | 36 | return tool_result 37 | 38 | 39 | class FindFileByNameTool(BaseTool): 40 | name = "find_file_by_name" 41 | description = ( 42 | "This tool searches for files and directories within a specified directory, similar to the Linux `find` command. " 43 | "The returned result paths are relative to the root path." 
44 | ) 45 | args_schema: ClassVar[Type[FindFileByNameArgs]] = FindFileByNameArgs 46 | 47 | def _get_file_type(self, path: str) -> str: 48 | type = filetype.guess(path) 49 | if type: 50 | return type.mime 51 | else: 52 | return "Unknown" 53 | 54 | def forward(self, dir_path: str, pattern: str) -> str: 55 | """Synchronous wrapper for async _run method.""" 56 | return self.run_sync(dir_path=dir_path, pattern=pattern) 57 | 58 | async def _run(self, dir_path: str, pattern: str) -> list[FindFileByNameResult]: 59 | """Search for files matching the pattern in the specified directory""" 60 | full_dir_path = safe_join(self.repo_path, dir_path.lstrip("/")) 61 | 62 | if not os.path.exists(full_dir_path): 63 | return [FindFileByNameResult(path="", type="Directory Not Exists")] 64 | 65 | matched_files = await ripgrep_glob( 66 | full_dir_path, pattern, extra_argv=["-g", "!.git"] 67 | ) 68 | 69 | results = [] 70 | for file_path in matched_files: 71 | full_file_path = full_dir_path / file_path 72 | file_type = self._get_file_type(str(full_file_path)) 73 | results.append(FindFileByNameResult(path=file_path, type=file_type)) 74 | 75 | return results 76 | -------------------------------------------------------------------------------- /coderetrx/tools/base.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from pathlib import Path 4 | from typing import ClassVar, Optional 5 | 6 | from coderetrx.utils.path import get_data_dir, get_repo_path 7 | from coderetrx.utils.git import clone_repo_if_not_exists, get_repo_id 8 | from typing import Any, Type 9 | from pydantic import BaseModel 10 | from abc import abstractmethod 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class BaseTool: 16 | """Base class for tools that work with repositories.""" 17 | 18 | name: str 19 | description: str 20 | args_schema: ClassVar[Type[BaseModel]] 21 | 22 | def __init__(self, repo_url: str, uuid: Optional[str] = None): 23 | super().__init__() 24 | logger.info(f"Init base repo tool {self.name} with uuid: {uuid} ...") 25 | 26 | self.repo_url = repo_url 27 | self.repo_id = get_repo_id(repo_url) 28 | self.uuid = uuid 29 | self.repo_path = get_repo_path(repo_url) 30 | 31 | clone_repo_if_not_exists(repo_url, str(self.repo_path)) 32 | 33 | def run_sync(self, *args, **kwargs): 34 | """Synchronous wrapper for async _run method.""" 35 | try: 36 | loop = asyncio.get_running_loop() 37 | # If we're already in a running loop, we need to use a different approach 38 | import concurrent.futures 39 | import threading 40 | 41 | def run_in_thread(): 42 | new_loop = asyncio.new_event_loop() 43 | asyncio.set_event_loop(new_loop) 44 | try: 45 | return new_loop.run_until_complete(self.run(*args, **kwargs)) 46 | finally: 47 | new_loop.close() 48 | 49 | with concurrent.futures.ThreadPoolExecutor() as executor: 50 | future = executor.submit(run_in_thread) 51 | return future.result() 52 | 53 | except RuntimeError: 54 | # No event loop is running, we can create one 55 | loop = asyncio.new_event_loop() 56 | asyncio.set_event_loop(loop) 57 | try: 58 | return loop.run_until_complete(self.run(*args, **kwargs)) 59 | finally: 60 | loop.close() 61 | 62 | async def _run_repr(self, *args, **kwargs: dict[str, str]): 63 | result = await self._run(*args, **kwargs) 64 | if not result: 65 | return "No result Found." 
66 | if type(result) == str: 67 | return result 68 | return type(result[0]).repr(result) 69 | 70 | @abstractmethod 71 | async def _run( 72 | self, 73 | *args: Any, 74 | **kwargs: Any, 75 | ) -> Any: 76 | """Async implementation to be overridden by subclasses.""" 77 | raise NotImplementedError("Subclasses must implement _run method") 78 | 79 | async def run( 80 | self, 81 | *args: Any, 82 | **kwargs: Any, 83 | ) -> Any: 84 | repr_output = kwargs.pop("repr_output", True) 85 | if repr_output: 86 | return await self._run_repr(*args, **kwargs) 87 | else: 88 | return await self._run(*args, **kwargs) 89 | -------------------------------------------------------------------------------- /scripts/example.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from coderetrx.retrieval import coderetrx_filter, CodebaseFactory, TopicExtractor, SmartCodebaseSettings 3 | from coderetrx.retrieval.code_recall import CodeRecallSettings 4 | from coderetrx.utils.git import clone_repo_if_not_exists, get_repo_id 5 | from coderetrx.utils.path import get_data_dir 6 | 7 | async def main(): 8 | # Set up the repository URL and path 9 | repo_url = "https://github.com/ollama/ollama.git" 10 | repo_path = get_data_dir() / "repos" / get_repo_id(repo_url) 11 | 12 | # Clone the repository if it does not exist 13 | clone_repo_if_not_exists(repo_url, str(repo_path)) 14 | 15 | # Create codebase settings with symbol_codeline_embedding enabled 16 | codebase_settings = SmartCodebaseSettings() 17 | codebase_settings.symbol_codeline_embedding = True 18 | 19 | # Create a codebase instance 20 | codebase = CodebaseFactory.new(get_repo_id(repo_url), repo_path, settings=codebase_settings) 21 | 22 | # Create a topic extractor instance 23 | topic_extractor = TopicExtractor() 24 | 25 | # Initialize code recall settings 26 | settings = CodeRecallSettings() 27 | 28 | # Set the target_type and coarse recall strategy 29 | result, llm_output = await coderetrx_filter( 30 | codebase=codebase, 31 | subdirs_or_files=["/"], 32 | prompt="The code snippet contains a function call that dynamically executes code or system commands. Examples include Python's `eval()`, `exec()`, or functions like `os.system()`, `subprocess.run()` (especially with `shell=True`), `subprocess.call()` (with `shell=True`), or `popen()`. The critical feature is that the string representing the code or command to be executed is not a hardcoded literal; instead, it's derived from a variable, function argument, string concatenation/formatting, or an external source such as user input, network request, or LLM output.", 33 | target_type="symbol_content", 34 | coarse_recall_strategy="line_per_symbol", 35 | topic_extractor=topic_extractor, 36 | settings=settings 37 | ) 38 | 39 | ''' 40 | result, llm_output = await coderetrx_filter( 41 | codebase=codebase, 42 | subdirs_or_files=["/"], 43 | prompt="The code snippet contains a function call that dynamically executes code or system commands. Examples include Python's `eval()`, `exec()`, or functions like `os.system()`, `subprocess.run()` (especially with `shell=True`), `subprocess.call()` (with `shell=True`), or `popen()`. 
The critical feature is that the string representing the code or command to be executed is not a hardcoded literal; instead, it's derived from a variable, function argument, string concatenation/formatting, or an external source such as user input, network request, or LLM output.", 44 | target_type="symbol_content", 45 | coarse_recall_strategy="line_per_symbol", 46 | topic_extractor=topic_extractor, 47 | settings=settings 48 | ) 49 | ''' 50 | 51 | print(f"Find {len(result)} results, first {min(5, len(result))} are:") 52 | for i, location in enumerate(result[:5], 1): 53 | print(f" {i}. {location}") 54 | 55 | if __name__ == "__main__": 56 | asyncio.run(main()) -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/javascript.scm: -------------------------------------------------------------------------------- 1 | ( 2 | (comment)* @doc 3 | . 4 | (method_definition 5 | name: (property_identifier) @name.definition.method) @definition.method 6 | (#not-eq? @name.definition.method "constructor") 7 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") 8 | (#select-adjacent! @doc @definition.method) 9 | ) 10 | 11 | ( 12 | (comment)* @doc 13 | . 14 | [ 15 | (class 16 | name: (_) @name.definition.class) 17 | (class_declaration 18 | name: (_) @name.definition.class) 19 | ] @definition.class 20 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") 21 | (#select-adjacent! @doc @definition.class) 22 | ) 23 | 24 | ( 25 | (comment)* @doc 26 | . 27 | [ 28 | (function_expression 29 | name: (identifier) @name.definition.function) 30 | (function_declaration 31 | name: (identifier) @name.definition.function) 32 | (generator_function 33 | name: (identifier) @name.definition.function) 34 | (generator_function_declaration 35 | name: (identifier) @name.definition.function) 36 | ] @definition.function 37 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") 38 | (#select-adjacent! @doc @definition.function) 39 | ) 40 | 41 | ( 42 | (comment)* @doc 43 | . 44 | (lexical_declaration 45 | (variable_declarator 46 | name: (identifier) @name.definition.function 47 | value: [(arrow_function) (function_expression)]) @definition.function) 48 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") 49 | (#select-adjacent! @doc @definition.function) 50 | ) 51 | 52 | ( 53 | (comment)* @doc 54 | . 55 | (variable_declaration 56 | (variable_declarator 57 | name: (identifier) @name.definition.function 58 | value: [(arrow_function) (function_expression)]) @definition.function) 59 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$") 60 | (#select-adjacent! @doc @definition.function) 61 | ) 62 | 63 | (assignment_expression 64 | left: [ 65 | (identifier) @name.definition.function 66 | (member_expression 67 | property: (property_identifier) @name.definition.function) 68 | ] 69 | right: [(arrow_function) (function_expression)] 70 | ) @definition.function 71 | 72 | (pair 73 | key: (property_identifier) @name.definition.function 74 | value: [(arrow_function) (function_expression)]) @definition.function 75 | 76 | ( 77 | (call_expression 78 | function: (identifier) @name.reference.call) @reference.call 79 | (#not-match? 
@name.reference.call "^(require)$") 80 | ) 81 | 82 | (call_expression 83 | function: (member_expression 84 | property: (property_identifier) @name.reference.call) 85 | arguments: (_) @reference.call) 86 | 87 | (new_expression 88 | constructor: (_) @name.reference.class) @reference.class 89 | 90 | (import_statement) @import 91 | 92 | ; Variable definitions - const/let/var declarations (non-function values) 93 | ; Note: We capture all variable declarations, function assignments are already captured above 94 | (lexical_declaration 95 | (variable_declarator 96 | name: (identifier) @name.definition.variable)) @definition.variable 97 | 98 | (variable_declaration 99 | (variable_declarator 100 | name: (identifier) @name.definition.variable)) @definition.variable 101 | -------------------------------------------------------------------------------- /test/tools/test_tools.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import pytest 3 | from coderetrx.tools.get_references import GetReferenceTool 4 | from coderetrx.tools.view_file import ViewFileTool 5 | from coderetrx.tools.find_file_by_name import FindFileByNameTool 6 | from coderetrx.tools.keyword_search import KeywordSearchTool 7 | from coderetrx.tools.list_dir import ListDirTool 8 | import logging 9 | logger =logging.getLogger(__name__) 10 | logging.basicConfig(level=logging.INFO) 11 | TEST_REPO = "https://github.com/apache/flink.git" 12 | 13 | class TestGetReferenceTool: 14 | def test(self): 15 | """Test finding references to a symbol""" 16 | logger.info("Testing GetReferenceTool...") 17 | tool = GetReferenceTool(TEST_REPO) 18 | result = asyncio.run(tool._run(symbol_name="upload")) 19 | logger.info(f"GetReferenceTool result: {result}") 20 | assert isinstance(result, (list, dict)), "Result should be a list or dict" 21 | 22 | 23 | class TestViewFileTool: 24 | def test(self): 25 | """Test viewing file content""" 26 | logger.info("Testing ViewFileTool...") 27 | tool = ViewFileTool(TEST_REPO) 28 | result = asyncio.run(tool._run(file_path="./README.md", start_line=0, end_line=10)) 29 | logger.info(f"ViewFileTool result: {result}") 30 | assert isinstance(result, str), "Result should be a string (file content)" 31 | assert len(result) > 0, "File content should not be empty" 32 | 33 | 34 | class TestFindFileByNameTool: 35 | def test(self): 36 | """Test finding files by name pattern""" 37 | logger.info("Testing FindFileByNameTool...") 38 | tool = FindFileByNameTool(TEST_REPO) 39 | results = asyncio.run(tool._run(dir_path="/", pattern="*.md")) 40 | logger.info(f"FindFileByNameTool result: {results}") 41 | assert isinstance(results, list), "Result should be a list of file paths" 42 | 43 | 44 | class TestKeywordSearchTool: 45 | def test(self): 46 | """Test keyword search functionality""" 47 | logger.info("Testing KeywordSearchTool...") 48 | tool = KeywordSearchTool(TEST_REPO) 49 | result = asyncio.run( 50 | tool._run( 51 | query="README", 52 | dir_path="/", 53 | case_insensitive=False, 54 | include_content=False, 55 | ) 56 | ) 57 | logger.info(f"KeywordSearchTool result: {result}") 58 | assert isinstance(result, list), "Result should be a list of matches" 59 | 60 | 61 | class TestListDirTool: 62 | def test(self): 63 | """Test listing directory contents""" 64 | logger.info("Testing ListDirTool...") 65 | tool = ListDirTool(TEST_REPO) 66 | result = asyncio.run(tool._run(directory_path=".")) 67 | logger.info(f"ListDirTool result: {result}") 68 | assert isinstance(result, list), "Result should be a list 
of directory entries" 69 | 70 | def test_all_tools(): 71 | """Test all tools""" 72 | tool_testers = [ 73 | TestGetReferenceTool(), 74 | TestViewFileTool(), 75 | TestFindFileByNameTool(), 76 | TestKeywordSearchTool(), 77 | TestListDirTool(), 78 | ] 79 | for tester in tool_testers: 80 | tester.test() 81 | 82 | if __name__ == "__main__": 83 | test_all_tools() -------------------------------------------------------------------------------- /coderetrx/static/codebase/queries/treesitter/tags/ocaml.scm: -------------------------------------------------------------------------------- 1 | ; Modules 2 | ;-------- 3 | 4 | ( 5 | (comment)? @doc . 6 | (module_definition (module_binding (module_name) @name.definition.module) @definition.module) 7 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 8 | ) 9 | 10 | (module_path (module_name) @name.reference.module) @reference.module 11 | 12 | ; Module types 13 | ;-------------- 14 | 15 | ( 16 | (comment)? @doc . 17 | (module_type_definition (module_type_name) @name.definition.interface) @definition.interface 18 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 19 | ) 20 | 21 | (module_type_path (module_type_name) @name.reference.implementation) @reference.implementation 22 | 23 | ; Functions 24 | ;---------- 25 | 26 | ( 27 | (comment)? @doc . 28 | (value_definition 29 | [ 30 | (let_binding 31 | pattern: (value_name) @name.definition.function 32 | (parameter)) 33 | (let_binding 34 | pattern: (value_name) @name.definition.function 35 | body: [(fun_expression) (function_expression)]) 36 | ] @definition.function 37 | ) 38 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 39 | ) 40 | 41 | ( 42 | (comment)? @doc . 43 | (external (value_name) @name.definition.function) @definition.function 44 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 45 | ) 46 | 47 | (application_expression 48 | function: (value_path (value_name) @name.reference.call)) @reference.call 49 | 50 | (infix_expression 51 | left: (value_path (value_name) @name.reference.call) 52 | operator: (concat_operator) @reference.call 53 | (#eq? @reference.call "@@")) 54 | 55 | (infix_expression 56 | operator: (rel_operator) @reference.call 57 | right: (value_path (value_name) @name.reference.call) 58 | (#eq? @reference.call "|>")) 59 | 60 | ; Operator 61 | ;--------- 62 | 63 | ( 64 | (comment)? @doc . 65 | (value_definition 66 | (let_binding 67 | pattern: (parenthesized_operator (_) @name.definition.function)) @definition.function) 68 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 69 | ) 70 | 71 | [ 72 | (prefix_operator) 73 | (sign_operator) 74 | (pow_operator) 75 | (mult_operator) 76 | (add_operator) 77 | (concat_operator) 78 | (rel_operator) 79 | (and_operator) 80 | (or_operator) 81 | (assign_operator) 82 | (hash_operator) 83 | (indexing_operator) 84 | (let_operator) 85 | (let_and_operator) 86 | (match_operator) 87 | ] @name.reference.call @reference.call 88 | 89 | ; Classes 90 | ;-------- 91 | 92 | ( 93 | (comment)? @doc . 94 | [ 95 | (class_definition (class_binding (class_name) @name.definition.class) @definition.class) 96 | (class_type_definition (class_type_binding (class_type_name) @name.definition.class) @definition.class) 97 | ] 98 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 99 | ) 100 | 101 | [ 102 | (class_path (class_name) @name.reference.class) 103 | (class_type_path (class_type_name) @name.reference.class) 104 | ] @reference.class 105 | 106 | ; Methods 107 | ;-------- 108 | 109 | ( 110 | (comment)? @doc . 111 | (method_definition (method_name) @name.definition.method) @definition.method 112 | (#strip! 
@doc "^\\(\\*\\*?\\s*|\\s\\*\\)$") 113 | ) 114 | 115 | (method_invocation (method_name) @name.reference.call) @reference.call 116 | -------------------------------------------------------------------------------- /coderetrx/utils/concurrency.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Callable, Any, TypeVar, Coroutine, Optional 3 | from concurrent.futures import ThreadPoolExecutor 4 | import asyncio 5 | import threading 6 | 7 | T = TypeVar("T") 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | async def abatch_func_call( 13 | max_concurrency: int, func: Callable[..., Any], kwargs_list: list[dict] 14 | ) -> list[Any]: 15 | """Execute a batch of function calls with controlled concurrency. 16 | 17 | Handles both async and sync functions automatically: 18 | - Async functions run directly in the event loop 19 | - Sync functions run in thread pool executor to avoid blocking 20 | 21 | Features: 22 | - Semaphore-based concurrency control 23 | - Automatic error logging with function context 24 | - Preserves original exception stack traces 25 | 26 | Args: 27 | max_concurrency: Maximum parallel executions allowed 28 | func: Callable to execute (async or sync) 29 | kwargs_list: List of keyword arguments dictionaries for each call 30 | 31 | Returns: 32 | list: Results in the same order as kwargs_list, exceptions will propagate 33 | 34 | Raises: 35 | Exception: Re-raises the first encountered exception from any task 36 | """ 37 | semaphore = asyncio.Semaphore(max_concurrency) 38 | 39 | async def a_func_call(func: Callable, kwargs: dict) -> Any: 40 | """Execute a single function call with concurrency control. 41 | 42 | Args: 43 | func: Target function to execute 44 | kwargs: Keyword arguments for this call 45 | 46 | Returns: 47 | Any: Result of the function call 48 | 49 | Raises: 50 | Exception: Propagates any exceptions from the function call 51 | """ 52 | async with semaphore: # Acquire semaphore slot 53 | try: 54 | if asyncio.iscoroutinefunction(func): 55 | # Directly await async functions 56 | result = await func(**kwargs) 57 | else: 58 | # Run sync functions in thread pool to prevent blocking 59 | result = await asyncio.to_thread(func, **kwargs) 60 | return result 61 | except Exception as e: 62 | logger.error(f"Error in {func.__name__} with args {kwargs}: {e}") 63 | raise # Re-raise to maintain stack trace 64 | 65 | # Create and schedule all tasks 66 | tasks = [a_func_call(func, kwargs) for kwargs in kwargs_list] 67 | 68 | # Execute all tasks concurrently with controlled concurrency 69 | results = await asyncio.gather(*tasks, return_exceptions=False) 70 | 71 | return results 72 | 73 | 74 | def run_coroutine_sync( 75 | coroutine: Coroutine[Any, Any, T], timeout: Optional[float] = None 76 | ) -> T: 77 | def run_in_new_loop(): 78 | new_loop = asyncio.new_event_loop() 79 | asyncio.set_event_loop(new_loop) 80 | try: 81 | return new_loop.run_until_complete(coroutine) 82 | finally: 83 | new_loop.close() 84 | 85 | try: 86 | loop = asyncio.get_running_loop() 87 | except RuntimeError: 88 | return asyncio.run(coroutine) 89 | 90 | if threading.current_thread() is threading.main_thread(): 91 | if not loop.is_running(): 92 | return loop.run_until_complete(coroutine) 93 | else: 94 | with ThreadPoolExecutor() as pool: 95 | future = pool.submit(run_in_new_loop) 96 | return future.result(timeout=timeout) 97 | else: 98 | return asyncio.run_coroutine_threadsafe(coroutine, loop).result() 99 | 
-------------------------------------------------------------------------------- /coderetrx/retrieval/strategy/factory.py: -------------------------------------------------------------------------------- 1 | from coderetrx.retrieval.strategy.base import RecallStrategy 2 | from typing import Optional, Union, List 3 | 4 | from coderetrx.retrieval.smart_codebase import LLMCallMode 5 | 6 | from coderetrx.retrieval.topic_extractor import TopicExtractor 7 | from coderetrx.retrieval.strategy.base import RecallStrategyExecutor 8 | from coderetrx.retrieval.strategy.filter_filename_by_llm import ( 9 | FilterFilenameByLLMStrategy, 10 | ) 11 | from coderetrx.retrieval.strategy.filter_keyword_by_vector import ( 12 | FilterKeywordByVectorStrategy, 13 | ) 14 | from coderetrx.retrieval.strategy.filter_symbol_content_by_vector import ( 15 | FilterSymbolContentByVectorStrategy, 16 | ) 17 | from coderetrx.retrieval.strategy.filter_symbol_name_by_llm import ( 18 | FilterSymbolNameByLLMStrategy, 19 | ) 20 | from coderetrx.retrieval.strategy.filter_dependency_by_llm import ( 21 | FilterDependencyByLLMStrategy, 22 | ) 23 | from coderetrx.retrieval.strategy.filter_keyword_by_vector_and_llm import ( 24 | FilterKeywordByVectorAndLLMStrategy, 25 | ) 26 | from coderetrx.retrieval.strategy.filter_symbol_content_by_vector_and_llm import ( 27 | FilterSymbolContentByVectorAndLLMStrategy, 28 | ) 29 | from coderetrx.retrieval.strategy.adaptive_filter_symbol_content_by_vector_and_llm import ( 30 | AdaptiveFilterSymbolContentByVectorAndLLMStrategy, 31 | ) 32 | from coderetrx.retrieval.strategy.adaptive_filter_keyword_by_vector_and_llm import ( 33 | AdaptiveFilterKeywordByVectorAndLLMStrategy, 34 | ) 35 | from coderetrx.retrieval.strategy.filter_line_per_symbol_by_vector_and_llm import ( 36 | FilterLinePerSymbolByVectorAndLLMStrategy, 37 | ) 38 | from coderetrx.retrieval.strategy.filter_line_per_file_by_vector_and_llm import ( 39 | FilterLinePerFileByVectorAndLLMStrategy, 40 | ) 41 | 42 | 43 | class StrategyFactory: 44 | """Factory for creating strategy executors.""" 45 | 46 | def __init__( 47 | self, 48 | topic_extractor: Optional[TopicExtractor] = None, 49 | llm_call_mode: LLMCallMode = "traditional", 50 | ): 51 | self.topic_extractor = topic_extractor 52 | self.llm_call_mode = llm_call_mode 53 | 54 | def create_strategy(self, strategy: RecallStrategy) -> RecallStrategyExecutor: 55 | """Create a strategy executor based on the strategy enum.""" 56 | strategy_map = { 57 | RecallStrategy.FILTER_FILENAME_BY_LLM: FilterFilenameByLLMStrategy, 58 | RecallStrategy.FILTER_KEYWORD_BY_VECTOR: FilterKeywordByVectorStrategy, 59 | RecallStrategy.FILTER_SYMBOL_CONTENT_BY_VECTOR: FilterSymbolContentByVectorStrategy, 60 | RecallStrategy.FILTER_SYMBOL_NAME_BY_LLM: FilterSymbolNameByLLMStrategy, 61 | RecallStrategy.FILTER_DEPENDENCY_BY_LLM: FilterDependencyByLLMStrategy, 62 | RecallStrategy.FILTER_KEYWORD_BY_VECTOR_AND_LLM: FilterKeywordByVectorAndLLMStrategy, 63 | RecallStrategy.FILTER_SYMBOL_CONTENT_BY_VECTOR_AND_LLM: FilterSymbolContentByVectorAndLLMStrategy, 64 | RecallStrategy.ADAPTIVE_FILTER_KEYWORD_BY_VECTOR_AND_LLM: AdaptiveFilterKeywordByVectorAndLLMStrategy, 65 | RecallStrategy.ADAPTIVE_FILTER_SYMBOL_CONTENT_BY_VECTOR_AND_LLM: AdaptiveFilterSymbolContentByVectorAndLLMStrategy, 66 | RecallStrategy.FILTER_LINE_PER_SYMBOL_BY_VECTOR_AND_LLM: FilterLinePerSymbolByVectorAndLLMStrategy, 67 | RecallStrategy.FILTER_LINE_PER_FILE_BY_VECTOR_AND_LLM: FilterLinePerFileByVectorAndLLMStrategy, 68 | } 69 | 70 | if strategy not in strategy_map: 71 | 
raise ValueError(f"Unknown strategy: {strategy}") 72 | 73 | return strategy_map[strategy]( 74 | topic_extractor=self.topic_extractor, llm_call_mode=self.llm_call_mode 75 | ) 76 | -------------------------------------------------------------------------------- /coderetrx/retrieval/strategy/adaptive_filter_keyword_by_vector_and_llm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Strategy for adaptive filtering of keywords using vector similarity search followed by LLM refinement. 3 | """ 4 | 5 | from typing import List, Union, Optional, override, Any, Literal 6 | from .base import AdaptiveFilterByVectorAndLLMStrategy, StrategyExecuteResult 7 | from ..smart_codebase import ( 8 | SmartCodebase as Codebase, 9 | LLMMapFilterTargetType, 10 | SimilaritySearchTargetType, 11 | ) 12 | from coderetrx.static import Keyword, Symbol, File 13 | 14 | 15 | class AdaptiveFilterKeywordByVectorAndLLMStrategy(AdaptiveFilterByVectorAndLLMStrategy): 16 | """Strategy to filter keywords using adaptive vector similarity search followed by LLM refinement.""" 17 | 18 | name: str = "ADAPTIVE_FILTER_KEYWORD_BY_VECTOR_AND_LLM" 19 | 20 | @override 21 | def get_strategy_name(self) -> str: 22 | return self.name 23 | 24 | @override 25 | def get_target_types_for_vector(self) -> List[SimilaritySearchTargetType]: 26 | return ["keyword"] 27 | 28 | @override 29 | def get_target_type_for_llm(self) -> LLMMapFilterTargetType: 30 | return "keyword" 31 | 32 | @override 33 | def get_collection_size(self, codebase: Codebase) -> int: 34 | return len(codebase.keywords) 35 | 36 | @override 37 | def filter_elements( 38 | self, 39 | codebase: Codebase, 40 | elements: List[Any], 41 | target_type: LLMMapFilterTargetType = "symbol_content", 42 | subdirs_or_files: List[str] = [], 43 | ) -> List[Union[Keyword, Symbol, File]]: 44 | keyword_elements: List[Union[Keyword, Symbol, File]] = [] 45 | for element in elements: 46 | if isinstance(element, Keyword): 47 | if subdirs_or_files and codebase: 48 | for ref_file in element.referenced_by: 49 | if any( 50 | str(ref_file.path).startswith(subdir) 51 | for subdir in subdirs_or_files 52 | ): 53 | keyword_elements.append(element) 54 | break 55 | else: 56 | keyword_elements.append(element) 57 | return keyword_elements 58 | 59 | @override 60 | def collect_file_paths( 61 | self, 62 | filtered_elements: List[Any], 63 | codebase: Codebase, 64 | subdirs_or_files: List[str], 65 | ) -> List[str]: 66 | referenced_paths = set() 67 | for keyword in filtered_elements: 68 | if isinstance(keyword, Keyword) and keyword.referenced_by: 69 | for ref_file in keyword.referenced_by: 70 | if str(ref_file.path).startswith(tuple(subdirs_or_files)): 71 | referenced_paths.add(str(ref_file.path)) 72 | return list(referenced_paths) 73 | 74 | @override 75 | async def execute( 76 | self, 77 | codebase: Codebase, 78 | prompt: str, 79 | subdirs_or_files: List[str], 80 | target_type: LLMMapFilterTargetType = "symbol_content", 81 | ) -> StrategyExecuteResult: 82 | prompt = f""" 83 | A code chunk containing the specified keywords is highly likely to meet the following criteria: 84 | 85 | {prompt} 86 | 87 | 88 | The objective of this requirement is to preliminarily filter files that are likely to meet specific content criteria based on the keywords they contain. 89 | Files with matching keywords will proceed to deeper analysis in the content filter (content_criteria) at a later stage (not in this run). 
90 | 91 | """ 92 | return await super().execute(codebase, prompt, subdirs_or_files, target_type) 93 | -------------------------------------------------------------------------------- /coderetrx/retrieval/strategy/filter_keyword_by_vector_and_llm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Strategy for filtering keywords using vector similarity search followed by LLM refinement. 3 | """ 4 | 5 | from typing import Any, List, Union, Optional, override 6 | 7 | from coderetrx.retrieval.strategy.base import StrategyExecuteResult 8 | from coderetrx.retrieval.strategy.base import FilterByVectorAndLLMStrategy 9 | from coderetrx.retrieval.smart_codebase import ( 10 | SmartCodebase as Codebase, 11 | LLMMapFilterTargetType, 12 | SimilaritySearchTargetType, 13 | ) 14 | from coderetrx.static import Keyword, Symbol, File 15 | 16 | 17 | class FilterKeywordByVectorAndLLMStrategy(FilterByVectorAndLLMStrategy): 18 | """Strategy to filter keywords using vector similarity search followed by LLM refinement.""" 19 | 20 | name: str = "FILTER_KEYWORD_BY_VECTOR_AND_LLM" 21 | 22 | @override 23 | def get_strategy_name(self) -> str: 24 | return self.name 25 | 26 | @override 27 | def get_target_types_for_vector(self) -> List[SimilaritySearchTargetType]: 28 | return ["keyword"] 29 | 30 | @override 31 | def get_target_type_for_llm(self) -> LLMMapFilterTargetType: 32 | return "keyword" 33 | 34 | @override 35 | def get_collection_size(self, codebase: Codebase) -> int: 36 | return len(codebase.keywords) 37 | 38 | @override 39 | def filter_elements( 40 | self, 41 | codebase: Codebase, 42 | elements: List[Any], 43 | target_type: LLMMapFilterTargetType = "symbol_content", 44 | subdirs_or_files: List[str] = [], 45 | ) -> List[Union[Keyword, Symbol, File]]: 46 | keyword_elements: List[Union[Keyword, Symbol, File]] = [] 47 | for element in elements: 48 | if isinstance(element, Keyword): 49 | # If subdirs_or_files is provided and codebase is available, filter by subdirs 50 | if subdirs_or_files and codebase: 51 | for ref_file in element.referenced_by: 52 | if any( 53 | str(ref_file.path).startswith(subdir) 54 | for subdir in subdirs_or_files 55 | ): 56 | keyword_elements.append(element) 57 | break 58 | else: 59 | keyword_elements.append(element) 60 | return keyword_elements 61 | 62 | @override 63 | def collect_file_paths( 64 | self, 65 | filtered_elements: List[Any], 66 | codebase: Codebase, 67 | subdirs_or_files: List[str], 68 | ) -> List[str]: 69 | referenced_paths = set() 70 | for keyword in filtered_elements: 71 | if isinstance(keyword, Keyword) and keyword.referenced_by: 72 | for ref_file in keyword.referenced_by: 73 | if str(ref_file.path).startswith(tuple(subdirs_or_files)): 74 | referenced_paths.add(str(ref_file.path)) 75 | return list(referenced_paths) 76 | 77 | @override 78 | async def execute( 79 | self, 80 | codebase: Codebase, 81 | prompt: str, 82 | subdirs_or_files: List[str], 83 | target_type: LLMMapFilterTargetType = "symbol_content", 84 | ) -> StrategyExecuteResult: 85 | prompt = f""" 86 | A code chunk containing the specified keywords is highly likely to meet the following criteria: 87 | 88 | {prompt} 89 | 90 | 91 | The objective of this requirement is to preliminarily filter files that are likely to meet specific content criteria based on the keywords they contain. 92 | Files with matching keywords will proceed to deeper analysis in the content filter (content_criteria) at a later stage (not in this run). 
93 | 94 | """ 95 | return await super().execute(codebase, prompt, subdirs_or_files, target_type) 96 | -------------------------------------------------------------------------------- /coderetrx/tools/view_file.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from typing import Optional, Type, ClassVar 4 | from pathlib import Path 5 | from pydantic import BaseModel, Field 6 | from coderetrx.tools.base import BaseTool 7 | import aiofiles 8 | from coderetrx.utils.path import safe_join 9 | 10 | 11 | class ViewFileArgs(BaseModel): 12 | """Parameters for file viewing operation""" 13 | 14 | file_path: str = Field( 15 | ..., 16 | description="Absolute path of file to view", 17 | ) 18 | start_line: Optional[int] = Field( 19 | None, description="Starting line number. Optional, default to be 0.", ge=0 20 | ) 21 | end_line: Optional[int] = Field( 22 | None, 23 | description="Ending line number. Optional, default to be the last line.", 24 | ge=0, 25 | ) 26 | 27 | 28 | class ViewFileTool(BaseTool): 29 | name = "view_file" 30 | description = ( 31 | "View the contents of a file. The lines of the file are 0-indexed, and the output of this tool call will be the file contents from StartLine to EndLine. The line range should be less than or equal 1000.\n\n" 32 | "When using this tool to gather information, it's your responsibility to ensure you have the COMPLETE context. Specifically, each time you call this command you should:\n" 33 | "1) Assess if the file contents you viewed are sufficient to proceed with your task.\n" 34 | "2) Take note of where there are lines not shown. These are represented by <... XX more lines from [code item] not shown ...> in the tool response.\n" 35 | "3) If the file contents you have viewed are insufficient, and you suspect they may be in lines not shown, proactively call the tool again to view those lines.\n" 36 | "4) When in doubt, call this tool again to gather more information. Remember that partial file views may miss critical dependencies, imports, or functionality." 
37 | ) 38 | args_schema: ClassVar[Type[ViewFileArgs]] = ViewFileArgs 39 | 40 | def forward(self, file_path: str, start_line: int, end_line: int) -> str: 41 | """Synchronous wrapper for async _run method.""" 42 | return self.run_sync( 43 | file_path=file_path, start_line=start_line, end_line=end_line 44 | ) 45 | 46 | async def _run( 47 | self, 48 | file_path: str, 49 | start_line: Optional[int] = None, 50 | end_line: Optional[int] = None, 51 | ) -> str: 52 | """View file content with optional line range.""" 53 | full_path = safe_join(self.repo_path, file_path.lstrip("/")) 54 | 55 | if not full_path.exists(): 56 | return "File Not Exists.\n" 57 | 58 | if full_path.is_dir(): 59 | return "Path is a directory, not a file.\n" 60 | 61 | try: 62 | async with aiofiles.open(full_path, "r", encoding="utf-8") as f: 63 | content = await f.read() 64 | except UnicodeDecodeError: 65 | return "File is not a text file or uses unsupported encoding.\n" 66 | 67 | lines = content.split("\n") 68 | total_lines = len(lines) 69 | 70 | # Handle line range 71 | if start_line is not None or end_line is not None: 72 | start = start_line if start_line is not None else 0 73 | end = end_line if end_line is not None else total_lines 74 | 75 | # Validate line numbers 76 | if start < 0 or end > total_lines or start > end: 77 | return f"Invalid line range (0-{total_lines}).\n" 78 | 79 | threshold_line = 1000 80 | if end - start >= threshold_line: 81 | return f"File is too large ({total_lines} lines), please specify a line range less than or equal {threshold_line}, or search keyword in the file.\n" 82 | 83 | selected_lines = lines[start:end] 84 | result = "\n".join(selected_lines) 85 | else: 86 | result = content 87 | 88 | # Add line count info unless we're showing the whole file 89 | if start_line is not None or end_line is not None: 90 | result += f"\n\n(This file has total {total_lines} lines.)" 91 | 92 | return result 93 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/languages.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from os import PathLike 3 | from typing import List, Literal, Optional 4 | from pathlib import Path 5 | import fnmatch 6 | 7 | 8 | IDXSupportedLanguage = Literal[ 9 | "javascript", 10 | "typescript", 11 | "python", 12 | "rust", 13 | "c", 14 | "cpp", 15 | "csharp", 16 | "go", 17 | "elixir", 18 | "java", 19 | "php", 20 | ] 21 | 22 | IDXSupportedTag = Literal[ 23 | "definition.function", 24 | "definition.type", 25 | "definition.method", 26 | "definition.class", 27 | "definition.interface", 28 | "definition.module", 29 | "definition.reexport", 30 | "definition.variable", 31 | "reference.implementation", 32 | "reference.call", 33 | "reference.class", 34 | "import", 35 | ] 36 | 37 | EXTENSION_MAP: dict[str, IDXSupportedLanguage] = { 38 | "js": "javascript", 39 | "ts": "typescript", 40 | "py": "python", 41 | "rs": "rust", 42 | "c": "c", 43 | "cpp": "cpp", 44 | "h": "c", 45 | "hpp": "cpp", 46 | "cs": "csharp", 47 | "go": "go", 48 | "ex": "elixir", 49 | "exs": "elixir", 50 | "java": "java", 51 | "php": "php", 52 | } 53 | 54 | BLOCKED_PATTERNS = ["*.min.js", "*_test.go"] 55 | 56 | DEP_FILES: List[str] = [ 57 | # JS / TS 58 | "package.json", 59 | # Python 60 | "requirements.txt", 61 | "setup.py", 62 | "Pipfile", 63 | "pyproject.toml", 64 | # Rust 65 | "Cargo.toml", 66 | # C/CPP 67 | "Makefile", 68 | # Golang 69 | "go.mod", 70 | # Elixir 71 | "mix.exs", 72 | "pom.xml", 73 | # 
Java 74 | "build.gradle", 75 | "build.gradle.kts", 76 | "build.sbt", 77 | "build.gradle", 78 | "build.gradle.kts", 79 | "build.sbt", 80 | ] 81 | 82 | BUILTIN_CRYPTO_LIBS: dict[IDXSupportedLanguage, List[str]] = { 83 | "javascript": ["crypto", "node:crypto", "webcrypto"], 84 | "typescript": ["crypto", "node:crypto", "webcrypto"], 85 | "python": ["hashlib", "hmac", "secrets", "ssl"], 86 | "rust": [], 87 | "c": [], 88 | "cpp": [], 89 | "csharp": ["System.Security.Cryptography"], 90 | "go": ["crypto"], 91 | "elixir": [":crypto"], 92 | "java": ["java.security", "javax.crypto"], 93 | "php": ["openssl", "hash", "sodium"], 94 | } 95 | 96 | FUNCLIKE_TAGS: List[IDXSupportedTag] = [ 97 | "definition.function", 98 | "definition.method", 99 | ] 100 | 101 | OBJLIKE_TAGS: List[IDXSupportedTag] = [ 102 | "definition.class", 103 | "definition.interface", 104 | "definition.module", 105 | "definition.reexport", 106 | "reference.implementation", 107 | ] 108 | 109 | PRIMARY_TAGS: List[IDXSupportedTag] = [ 110 | "definition.class", 111 | "definition.type", 112 | "definition.function", 113 | "definition.interface", 114 | "definition.method", 115 | "definition.module", 116 | # Special case: imports something and introduces a new symbol to the global pool 117 | "definition.reexport", 118 | "reference.implementation", 119 | ] 120 | 121 | REFERENCE_TAGS: List[IDXSupportedTag] = [ 122 | "reference.call", 123 | "reference.class", 124 | ] 125 | 126 | IMPORT_TAGS: List[IDXSupportedTag] = [ 127 | "import", 128 | ] 129 | 130 | VARIABLE_TAGS: List[IDXSupportedTag] = [ 131 | "definition.variable", 132 | ] 133 | 134 | 135 | def get_extension(filepath: PathLike | str) -> str: 136 | return str(filepath).split(".")[-1].lower() 137 | 138 | 139 | def is_blocked_file(filepath: PathLike | str) -> bool: 140 | return any(fnmatch.fnmatch(str(filepath), pattern) for pattern in BLOCKED_PATTERNS) 141 | 142 | 143 | def is_sourcecode(filepath: PathLike | str) -> bool: 144 | if is_blocked_file(filepath): 145 | return False 146 | extension = get_extension(filepath) 147 | return extension in EXTENSION_MAP 148 | 149 | 150 | def is_dependency(filepath: PathLike | str) -> bool: 151 | return str(filepath).split("/")[-1] in DEP_FILES 152 | 153 | 154 | def get_language(filepath: PathLike | str) -> Optional[IDXSupportedLanguage]: 155 | extension = get_extension(filepath) 156 | return EXTENSION_MAP.get(extension) 157 | -------------------------------------------------------------------------------- /coderetrx/utils/cost_tracking.py: -------------------------------------------------------------------------------- 1 | # Cost tracking for LLM calls; currently only for OpenRouter. 
2 | 3 | import httpx 4 | from .logger import JsonLogger, LLMCallLog, ErrLog 5 | from pydantic import BaseModel, Field 6 | from typing import Dict, Optional 7 | from attrs import define 8 | from httpx import AsyncClient 9 | from .logger import read_logs 10 | from os import PathLike 11 | 12 | a_client = AsyncClient() 13 | 14 | 15 | @define 16 | class ModelCost: 17 | prompt: float 18 | completion: float 19 | 20 | 21 | type ModelCosts = Dict[str, ModelCost] 22 | 23 | def get_cost_hook(json_logger: JsonLogger, base_url: str = "https://openrouter.ai/api/v1"): 24 | async def on_response(response: httpx.Response): 25 | if not str(response.request.url).startswith(base_url): 26 | return 27 | 28 | try: 29 | await response.aread() 30 | response_json = response.json() 31 | model = response_json["model"] 32 | usage_data = response_json["usage"] 33 | json_logger.log( 34 | LLMCallLog( 35 | model=model, 36 | completion_id=response_json["id"], 37 | completion_tokens=usage_data["completion_tokens"], 38 | prompt_tokens=usage_data["prompt_tokens"], 39 | total_tokens=usage_data["total_tokens"], 40 | call_url=str(response.request.url), 41 | ) 42 | ) 43 | except Exception as e: 44 | json_logger.log( 45 | ErrLog( 46 | error_type="LLM_CALL_ERROR", 47 | error=str(e), 48 | ) 49 | ) 50 | return on_response 51 | 52 | async def load_model_costs() -> ModelCosts: 53 | all_model_costs_rsp = await a_client.get("https://openrouter.ai/api/v1/models") 54 | all_model_costs = all_model_costs_rsp.json() 55 | model_costs_parsed = {} 56 | for model in all_model_costs["data"]: 57 | try: 58 | model_id = model["id"] 59 | model_slug = model["canonical_slug"] 60 | model_pricing = model["pricing"] 61 | model_cost = ModelCost( 62 | prompt=float(model_pricing["prompt"]), 63 | completion=float(model_pricing["completion"]), 64 | ) 65 | model_costs_parsed[model_id] = model_cost 66 | model_costs_parsed[model_slug] = model_cost 67 | except Exception as e: 68 | print(f"Error parsing model {model_slug}: {e}") 69 | print(f"Loaded {len(model_costs_parsed)} models") 70 | return model_costs_parsed 71 | 72 | async def calc_llm_costs(log_path: PathLike | str, model_costs: Optional[ModelCosts] = None): 73 | from pathlib import Path 74 | 75 | # Check if log file exists 76 | if not Path(log_path).exists(): 77 | return 0.0 78 | 79 | model_costs_parsed = model_costs or await load_model_costs() 80 | total_cost = 0 81 | for log_item in read_logs(log_path): 82 | if log_item.data.type == "llm_call": 83 | log = log_item.data 84 | if log.model not in model_costs_parsed: 85 | print(f"Model {log.model} not found in model costs") 86 | continue 87 | cost_info = model_costs_parsed[log.model] 88 | total_cost += ( 89 | log.prompt_tokens * cost_info.prompt 90 | + log.completion_tokens * cost_info.completion 91 | ) 92 | return total_cost 93 | 94 | def calc_input_tokens(log_path: PathLike | str): 95 | from pathlib import Path 96 | 97 | # Check if log file exists 98 | if not Path(log_path).exists(): 99 | return 0 100 | 101 | total_input_tokens = 0 102 | for log_item in read_logs(log_path): 103 | if log_item.data.type == "llm_call": 104 | log = log_item.data 105 | total_input_tokens += log.prompt_tokens 106 | return total_input_tokens 107 | 108 | def calc_output_tokens(log_path: PathLike | str): 109 | from pathlib import Path 110 | 111 | # Check if log file exists 112 | if not Path(log_path).exists(): 113 | return 0 114 | 115 | total_output_tokens = 0 116 | for log_item in read_logs(log_path): 117 | if log_item.data.type == "llm_call": 118 | log = log_item.data 119 | 
total_output_tokens += log.completion_tokens 120 | return total_output_tokens 121 | -------------------------------------------------------------------------------- /coderetrx/retrieval/strategy/filter_symbol_content_by_vector_and_llm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Strategy for filtering symbols using vector similarity search followed by LLM refinement. 3 | """ 4 | 5 | from collections import defaultdict 6 | from typing import List, Union, Optional, override 7 | from .base import FilterByVectorAndLLMStrategy 8 | from ..smart_codebase import ( 9 | SmartCodebase as Codebase, 10 | LLMMapFilterTargetType, 11 | SimilaritySearchTargetType, 12 | ) 13 | from coderetrx.static import Keyword, Symbol, File 14 | 15 | 16 | class FilterSymbolContentByVectorAndLLMStrategy(FilterByVectorAndLLMStrategy): 17 | """Strategy to filter symbols using vector similarity search followed by LLM refinement.""" 18 | 19 | @override 20 | def get_strategy_name(self) -> str: 21 | return "FILTER_SYMBOL_CONTENT_BY_VECTOR_AND_LLM" 22 | 23 | @override 24 | def get_target_types_for_vector(self) -> List[SimilaritySearchTargetType]: 25 | return ["symbol_content"] 26 | 27 | @override 28 | def get_target_type_for_llm(self) -> LLMMapFilterTargetType: 29 | return "symbol_content" 30 | 31 | @override 32 | def get_collection_size(self, codebase: Codebase) -> int: 33 | return len(codebase.symbols) 34 | 35 | @override 36 | def filter_elements( 37 | self, 38 | codebase: Codebase, 39 | elements: List[Symbol], 40 | target_type: LLMMapFilterTargetType = "symbol_content", 41 | subdirs_or_files: List[str] = [], 42 | ) -> List[Union[Keyword, Symbol, File]]: 43 | filtered_symbols: List[Symbol] = [] 44 | for element in elements: 45 | if not isinstance(element, Symbol): 46 | continue 47 | # If subdirs_or_files is provided and codebase is available, filter by subdirs 48 | if subdirs_or_files and codebase: 49 | # Get the relative path from the codebase directory 50 | rpath = str(element.file.path) 51 | if any(rpath.startswith(subdir) for subdir in subdirs_or_files): 52 | filtered_symbols.append(element) 53 | else: 54 | filtered_symbols.append(element) 55 | if target_type == "class_content": 56 | # If the target type is class_content, filter symbols that are classes 57 | filtered_symbols = [ 58 | elem for elem in filtered_symbols if elem.type == "class" 59 | ] 60 | elif target_type == "function_content": 61 | # If the target type is function_content, filter symbols that are functions 62 | filtered_symbols = [ 63 | elem for elem in filtered_symbols if elem.type == "function" 64 | ] 65 | elif target_type == "leaf_symbol_content": 66 | # If the target type is leaf_symbol_content, filter symbols that are leaves 67 | 68 | parent_of_symbol = { 69 | symbol.id: symbol.chunk.parent.id 70 | for symbol in codebase.symbols 71 | if symbol.chunk.parent 72 | } 73 | childs_of_symbol = defaultdict(list) 74 | for child, parent in parent_of_symbol.items(): 75 | childs_of_symbol[parent].append(child) 76 | filtered_symbols = [ 77 | elem for elem in filtered_symbols if not childs_of_symbol[elem.id] 78 | ] 79 | elif target_type == "root_symbol_content": 80 | parent_of_symbol = { 81 | symbol.id: symbol.chunk.parent.id 82 | for symbol in codebase.symbols 83 | if symbol.chunk.parent 84 | } 85 | filtered_symbols = [ 86 | elem for elem in filtered_symbols if not parent_of_symbol[elem.id] 87 | ] 88 | return filtered_symbols 89 | 90 | @override 91 | def collect_file_paths( 92 | self, 93 | filtered_elements: 
List[Symbol], 94 | codebase: Codebase, 95 | subdirs_or_files: List[str], 96 | ) -> List[str]: 97 | """Collect file paths from the filtered symbols.""" 98 | file_paths = set() 99 | for symbol in filtered_elements: 100 | if isinstance(symbol, Symbol): 101 | file_path = str(symbol.file.path) 102 | if not subdirs_or_files or any( 103 | file_path.startswith(subdir) for subdir in subdirs_or_files 104 | ): 105 | file_paths.add(file_path) 106 | return list(file_paths) 107 | -------------------------------------------------------------------------------- /bench/repos.lock: -------------------------------------------------------------------------------- 1 | [repositories.cloudflare_boringtun] 2 | url = "https://github.com/cloudflare/boringtun" 3 | commit = "2f3c85f5c4a601018c10b464b1ca890d9504bf6e" 4 | 5 | [repositories.cloudflare_boringtun.stats] 6 | num_files = 30 7 | num_lines = 7937 8 | num_tokens = 73507 9 | primary_language = "rust" 10 | 11 | [repositories.rosenpass_rosenpass] 12 | url = "https://github.com/rosenpass/rosenpass" 13 | commit = "d98815fa7f8dbe6fd2ea2e024e17447bf587d134" 14 | 15 | [repositories.rosenpass_rosenpass.stats] 16 | num_files = 147 17 | num_lines = 30063 18 | num_tokens = 237990 19 | primary_language = "rust" 20 | 21 | [repositories.neondatabase_neon] 22 | url = "https://github.com/neondatabase/neon" 23 | commit = "1b935b1958a7f508807f1bd241e715f33cdc386e" 24 | 25 | [repositories.neondatabase_neon.stats] 26 | num_files = 934 27 | num_lines = 375060 28 | num_tokens = 3021451 29 | primary_language = "rust" 30 | 31 | [repositories.GianisTsol_python-p2p] 32 | url = "https://github.com/GianisTsol/python-p2p" 33 | commit = "69056e1634ec108c7863d36b837f93043ee89ac6" 34 | 35 | [repositories.GianisTsol_python-p2p.stats] 36 | num_files = 8 37 | num_lines = 1189 38 | num_tokens = 11875 39 | primary_language = "python" 40 | 41 | [repositories.magic-wormhole_magic-wormhole] 42 | url = "https://github.com/magic-wormhole/magic-wormhole" 43 | commit = "e5f2ba2c77fe041ffd83db2045591a4d5a65a8de" 44 | 45 | [repositories.magic-wormhole_magic-wormhole.stats] 46 | num_files = 99 47 | num_lines = 27538 48 | num_tokens = 175095 49 | primary_language = "python" 50 | 51 | [repositories.zulip_zulip] 52 | url = "https://github.com/zulip/zulip" 53 | commit = "e243fc67fac9f228dffbbbc5a0ba30abb99e298e" 54 | 55 | [repositories.zulip_zulip.stats] 56 | num_files = 2300 57 | num_lines = 499528 58 | num_tokens = 5711007 59 | primary_language = "python" 60 | 61 | [repositories.TecharoHQ_anubis] 62 | url = "https://github.com/TecharoHQ/anubis" 63 | commit = "6e2eeb9e6562e9330024aa3bcb393d223713e40e" 64 | 65 | [repositories.TecharoHQ_anubis.stats] 66 | num_files = 48 67 | num_lines = 5163 68 | num_tokens = 34501 69 | primary_language = "go" 70 | 71 | [repositories.google_fscrypt] 72 | url = "https://github.com/google/fscrypt" 73 | commit = "827c13689b39814552a3a18449f922b123725b49" 74 | 75 | [repositories.google_fscrypt.stats] 76 | num_files = 42 77 | num_lines = 11698 78 | num_tokens = 18502 79 | primary_language = "go" 80 | 81 | [repositories.ethereum_go-ethereum] 82 | url = "https://github.com/ethereum/go-ethereum" 83 | commit = "0983cd789ee1905aedaed96f72793e5af8466f34" 84 | 85 | [repositories.ethereum_go-ethereum.stats] 86 | num_files = 920 87 | num_lines = 272576 88 | num_tokens = 1104733 89 | primary_language = "go" 90 | 91 | [repositories.Mastercard_client-encryption-java] 92 | url = "https://github.com/Mastercard/client-encryption-java" 93 | commit = "20a5b880b39a590d10785e4b6dbd957112548e31" 94 | 95 
| [repositories.Mastercard_client-encryption-java.stats] 96 | num_files = 86 97 | num_lines = 8654 98 | num_tokens = 188192 99 | primary_language = "java" 100 | 101 | [repositories.keycloak_keycloak] 102 | url = "https://github.com/keycloak/keycloak" 103 | commit = "c37e597117a0999960106e6a9a420f514249ccd4" 104 | 105 | [repositories.keycloak_keycloak.stats] 106 | num_files = 7031 107 | num_lines = 955351 108 | num_tokens = 12068541 109 | primary_language = "java" 110 | 111 | [repositories.cryptomator_cryptomator] 112 | url = "https://github.com/cryptomator/cryptomator" 113 | commit = "0178ca4080b9adb774006af53e3517a104476077" 114 | 115 | [repositories.cryptomator_cryptomator.stats] 116 | num_files = 387 117 | num_lines = 29637 118 | num_tokens = 363437 119 | primary_language = "java" 120 | 121 | [repositories.padloc_padloc] 122 | url = "https://github.com/padloc/padloc" 123 | commit = "1d2e9129d65afad72cf7fefc80fd5eb20b5671d9" 124 | 125 | [repositories.padloc_padloc.stats] 126 | num_files = 304 127 | num_lines = 136352 128 | num_tokens = 1363775 129 | primary_language = "javascript" 130 | 131 | [repositories.dani-garcia_vaultwarden] 132 | url = "https://github.com/dani-garcia/vaultwarden" 133 | commit = "0d3f283c3720b82e97c1aa383a66ebc8b4951dfb" 134 | 135 | [repositories.dani-garcia_vaultwarden.stats] 136 | num_files = 67 137 | num_lines = 62609 138 | num_tokens = 401294 139 | primary_language = "javascript" 140 | 141 | [repositories.swiftyapp_swifty] 142 | url = "https://github.com/swiftyapp/swifty" 143 | commit = "858209dc2b67183065de40e298f3fa1e60aad2c3" 144 | 145 | [repositories.swiftyapp_swifty.stats] 146 | num_files = 190 147 | num_lines = 8295 148 | num_tokens = 43671 149 | primary_language = "javascript" 150 | -------------------------------------------------------------------------------- /test/impl/default/test_smart_codebase.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | load_dotenv() 3 | import json 4 | from pathlib import Path 5 | from coderetrx.impl.default import CodebaseFactory 6 | from coderetrx.impl.default import SmartCodebase 7 | import os 8 | import asyncio 9 | import unittest 10 | from coderetrx.utils.embedding import create_documents_embedding 11 | from coderetrx.utils.git import clone_repo_if_not_exists, get_repo_id 12 | from coderetrx.utils.path import get_data_dir 13 | import logging 14 | 15 | logger = logging.getLogger(__name__) 16 | logging.basicConfig(level=logging.INFO) 17 | 18 | TEST_REPOS = ["https://github.com/apache/dubbo-admin.git"] 19 | 20 | 21 | def prepare_codebase(repo_url: str, repo_path: Path): 22 | """Helper function to prepare codebase for testing""" 23 | database_path = get_data_dir() / "databases" / f"{get_repo_id(repo_url)}.json" 24 | # Create a test codebase 25 | clone_repo_if_not_exists(repo_url, str(repo_path)) 26 | 27 | if database_path.exists(): 28 | codebase = CodebaseFactory.from_json( 29 | json.load(open(database_path, "r", encoding="utf-8")) 30 | ) 31 | else: 32 | codebase = CodebaseFactory.new(get_repo_id(repo_url), repo_path) 33 | with open(f"{repo_path}.json", "w") as f: 34 | json.dump(codebase.to_json(), f, indent=4) 35 | return codebase 36 | 37 | 38 | class TestSmartCodebase(unittest.TestCase): 39 | """Test SmartCodebase functionality""" 40 | 41 | def setUp(self): 42 | """Set up test environment""" 43 | self.repo_url = TEST_REPOS[0] 44 | self.repo_path = get_data_dir() / "repos" / get_repo_id(self.repo_url) 45 | 46 | def 
test_codebase_initialization(self): 47 | """Test codebase initialization""" 48 | codebase = prepare_codebase(self.repo_url, self.repo_path) 49 | self.assertIsInstance(codebase, SmartCodebase) 50 | self.assertEqual(codebase.id, get_repo_id(self.repo_url)) 51 | 52 | def test_keyword_extraction(self): 53 | """Test keyword extraction functionality""" 54 | os.environ["KEYWORD_EMBEDDING"] = "True" 55 | codebase = prepare_codebase(self.repo_url, self.repo_path) 56 | 57 | # Verify keywords were extracted 58 | self.assertGreater(len(codebase.keywords), 0) 59 | logger.info(f"Extracted {len(codebase.keywords)} keywords") 60 | logger.info(f"Sample keywords: {[k.content for k in codebase.keywords[:10]]}") 61 | 62 | def test_keyword_search(self): 63 | """Test keyword search functionality""" 64 | os.environ["KEYWORD_EMBEDDING"] = "True" 65 | codebase = prepare_codebase(self.repo_url, self.repo_path) 66 | 67 | try: 68 | # Generate embeddings 69 | test_query = "Is the code snippet used for user authentication?" 70 | logger.info(f"\nTesting keyword search with query: '{test_query}'") 71 | 72 | # Create a searcher 73 | results = asyncio.run(codebase.similarity_search( 74 | target_types=["keyword"], query=test_query 75 | )) 76 | 77 | logger.info("Search results:") 78 | logger.info(results) 79 | 80 | # Verify results 81 | self.assertIsNotNone(results) 82 | 83 | except Exception as e: 84 | self.fail(f"Error during keyword search test: {e}") 85 | 86 | def test_symbol_extraction(self): 87 | """Test symbol extraction functionality""" 88 | os.environ["SYMBOL_NAME_EMBEDDING"] = "True" 89 | codebase = prepare_codebase(self.repo_url, self.repo_path) 90 | 91 | # Verify symbols were extracted 92 | self.assertGreater(len(codebase.symbols), 0) 93 | logger.info(f"Sample symbols: {[k.name for k in codebase.symbols[:10]]}") 94 | 95 | def test_symbol_search(self): 96 | """Test symbol search functionality""" 97 | os.environ["SYMBOL_NAME_EMBEDDING"] = "True" 98 | codebase = prepare_codebase(self.repo_url, self.repo_path) 99 | 100 | try: 101 | # Test search with a sample query 102 | test_query = "Is the code snippet used for user authentication?" 103 | logger.info(f"\nTesting symbol search with query: '{test_query}'") 104 | 105 | # Create a searcher 106 | results = asyncio.run(codebase.similarity_search( 107 | target_types=["symbol_name"], query=test_query 108 | )) 109 | 110 | logger.info("Search results:") 111 | logger.info(results) 112 | 113 | # Verify results 114 | self.assertIsNotNone(results) 115 | 116 | except Exception as e: 117 | self.fail(f"Error during symbol search test: {e}") 118 | 119 | 120 | # Run tests if specified 121 | if __name__ == "__main__": 122 | # Run unittest tests 123 | unittest.main(argv=['first-arg-is-ignored'], exit=False) -------------------------------------------------------------------------------- /coderetrx/tools/get_references.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from typing import List, ClassVar, Type 4 | from pathlib import Path 5 | from pydantic import BaseModel, Field 6 | from coderetrx.static.ripgrep import ripgrep_search # type: ignore 7 | from coderetrx.tools.base import BaseTool 8 | from coderetrx.utils.llm import count_tokens_openai 9 | from coderetrx.tools.keyword_search import KeywordSearchResult 10 | 11 | 12 | class GetReferenceArgs(BaseModel): 13 | symbol_name: str = Field( 14 | description="The symbolic name whose reference is to be retrieved." 
15 | ) 16 | 17 | 18 | class GetReferenceResult(KeywordSearchResult): 19 | symbol_idx: int = Field( 20 | default=-1, 21 | description="The index of the symbol in the list of symbols. if -1, it means not applicable.", 22 | ) 23 | symbol_name: str = Field( 24 | default="", description="The symbolic name whose reference is to be retrieved." 25 | ) 26 | symbol_location: str = Field(default="", description="The location of the symbol.") 27 | 28 | @classmethod 29 | def repr(cls, entries: list["GetReferenceResult"]): 30 | """ 31 | convert a list of GetReferenceResult to a readable string 32 | """ 33 | if not entries: 34 | return "No entries found." 35 | 36 | result = "" 37 | current_symbol_idx = -1 38 | entries_count = {} 39 | entries = sorted( 40 | entries, key=lambda x: (x.symbol_idx, x.start_line, x.end_line) 41 | ) 42 | # symbol_idx is -1 means the multi-symbol cross-reference is not applicable 43 | if entries[0].symbol_idx == -1: 44 | for idx, ref in enumerate(entries): 45 | result += f"## **{idx}. {ref.path}:L{ref.start_line}:{ref.end_line}**\n" 46 | result += f"{ref.content}\n" 47 | return result 48 | 49 | threshold_tokens = 5000 # reference: 6262 tokens = 830 lines python 50 | cur_tokens = 0 51 | for idx, ref in enumerate(entries): 52 | each_result = "" 53 | # When processing a new symbol, add a symbol title. 54 | if ref.symbol_idx != current_symbol_idx: 55 | current_symbol_idx = ref.symbol_idx 56 | entries_count[current_symbol_idx] = 0 57 | each_result += f"# **{current_symbol_idx}. {ref.symbol_name}** location:{ref.symbol_location}\n" 58 | 59 | if not ref.path: 60 | each_result += "No entries Found.\n" 61 | continue 62 | 63 | entries_count[current_symbol_idx] += 1 64 | 65 | if entries_count[current_symbol_idx] == 1: 66 | # calulate the total reference count of the current symbol 67 | total_refs = sum( 68 | 1 for r in entries if r.symbol_idx == current_symbol_idx and r.path 69 | ) 70 | each_result += f"{total_refs} entries found:\n" 71 | 72 | each_result += f"## **{entries_count[current_symbol_idx]}. {ref.path}:L{ref.start_line}:{ref.end_line}**\n" 73 | each_result += f"{ref.content}\n" 74 | 75 | cur_tokens += count_tokens_openai(each_result) 76 | if cur_tokens > threshold_tokens: 77 | result += f"Omitted...\nReturn contents are too long. Please refine your query.\n" 78 | break 79 | result += each_result 80 | 81 | return result 82 | 83 | 84 | class GetReferenceTool(BaseTool): 85 | name = "get_reference" 86 | description = ( 87 | "Used to find symbol direct references in the codebase, should be used when tracking code usage. " 88 | "Finding multiple levels of references requires multiple calls." 
89 | ) 90 | args_schema: ClassVar[Type[GetReferenceArgs]] = GetReferenceArgs 91 | 92 | def forward(self, symbol_name: str) -> str: 93 | """Synchronous wrapper for async _run method.""" 94 | return self.run_sync(symbol_name=symbol_name) 95 | 96 | async def _run(self, symbol_name: str) -> list[GetReferenceResult]: 97 | """Find references to a symbol in the codebase""" 98 | # Convert the query to a list of regexes 99 | regexes = [f"\\b{symbol_name}\\b"] 100 | 101 | # Call ripgrep_search with the appropriate parameters 102 | rg_results = await ripgrep_search( 103 | search_dir=Path(self.repo_path), 104 | regexes=regexes, 105 | case_sensitive=True, 106 | exclude_file_pattern=".git", 107 | ) 108 | 109 | # Convert GrepMatchResult to KeywordSearchResult 110 | results = [] 111 | for result in rg_results: 112 | search_result = GetReferenceResult( 113 | path=str(result.file_path), 114 | start_line=result.line_number, 115 | end_line=result.line_number, 116 | content=result.line_text, 117 | ) 118 | results.append(search_result) 119 | 120 | return results 121 | -------------------------------------------------------------------------------- /coderetrx/utils/stats.py: -------------------------------------------------------------------------------- 1 | from ._extras import require_extra 2 | 3 | require_extra("tiktoken", "stats") 4 | 5 | from pydantic import BaseModel 6 | from collections import Counter 7 | from typing import Dict, List, ClassVar 8 | from pathlib import Path 9 | import tiktoken 10 | from concurrent.futures import ThreadPoolExecutor, as_completed 11 | 12 | from coderetrx.static.codebase import Codebase, File, CodeChunk, ChunkType 13 | from coderetrx.static.codebase.languages import get_language 14 | 15 | 16 | class ChunkStats(BaseModel): 17 | chunk_id: str 18 | file_path: str 19 | num_lines: int 20 | num_tokens: int 21 | 22 | # Class variable for the tokenizer 23 | _tokenizer: ClassVar = tiktoken.encoding_for_model("gpt-4o") 24 | 25 | @classmethod 26 | def from_chunk(cls, chunk: CodeChunk, strict: bool = False) -> "ChunkStats": 27 | lines = chunk.lines() 28 | text = "\n".join(lines) 29 | 30 | SAMPLE_SIZE = 100_000 31 | if not strict and len(text) > SAMPLE_SIZE: 32 | tokens = len(cls._tokenizer.encode(text[:SAMPLE_SIZE], disallowed_special=())) * ( 33 | len(text) / SAMPLE_SIZE 34 | ) 35 | else: 36 | tokens = len(cls._tokenizer.encode(text, disallowed_special=())) 37 | 38 | return cls( 39 | chunk_id=chunk.id, 40 | file_path=str(chunk.src.path), 41 | num_lines=len(lines), 42 | num_tokens=int(tokens), 43 | ) 44 | 45 | 46 | class FileStats(BaseModel): 47 | file_path: str 48 | num_chunks: int 49 | num_lines: int 50 | num_tokens: int 51 | 52 | @classmethod 53 | def from_file(cls, file: File) -> "FileStats": 54 | with ThreadPoolExecutor() as executor: 55 | # Process chunks in parallel 56 | future_to_chunk = { 57 | executor.submit(ChunkStats.from_chunk, chunk): chunk 58 | for chunk in file.chunks 59 | if chunk.type in [ChunkType.PRIMARY, ChunkType.IMPORT] 60 | } 61 | chunk_stats = [future.result() for future in as_completed(future_to_chunk)] 62 | 63 | return cls( 64 | file_path=str(file.path), 65 | num_chunks=len(file.chunks), 66 | num_lines=len(file.lines), 67 | num_tokens=sum(stat.num_tokens for stat in chunk_stats), 68 | ) 69 | 70 | 71 | class CodebaseStats(BaseModel): 72 | codebase_id: str 73 | num_chunks: int 74 | num_files: int 75 | num_lines: int 76 | num_tokens: int 77 | language_distribution: Dict[str, int] # Maps language name to line count 78 | primary_language: str # The language with the 
most lines of code 79 | file_stats: List[FileStats] 80 | chunk_stats: List[ChunkStats] 81 | 82 | @classmethod 83 | def from_codebase(cls, codebase: Codebase) -> "CodebaseStats": 84 | with ThreadPoolExecutor() as executor: 85 | # Process files in parallel 86 | future_to_file = { 87 | executor.submit(FileStats.from_file, file): file_path 88 | for file_path, file in codebase.source_files.items() 89 | } 90 | file_stats = [future.result() for future in as_completed(future_to_file)] 91 | 92 | # Calculate language distribution by line count 93 | language_counts = Counter() 94 | for file_path, file in codebase.source_files.items(): 95 | lang = get_language(Path(file_path)) 96 | if lang: 97 | language_counts[lang] += len(file.lines) 98 | 99 | # Determine primary language (language with most lines) 100 | primary_language = max(language_counts, key=language_counts.get) if language_counts else "Unknown" 101 | 102 | # Collect all chunk stats 103 | chunk_stats = [] 104 | for file in codebase.source_files.values(): 105 | for chunk in file.chunks: 106 | if chunk.type in [ChunkType.PRIMARY, ChunkType.IMPORT]: 107 | chunk_stats.append(ChunkStats.from_chunk(chunk)) 108 | 109 | return cls( 110 | codebase_id=codebase.id, 111 | num_chunks=len(codebase.all_chunks), 112 | num_files=len(codebase.source_files), 113 | num_lines=sum(stat.num_lines for stat in file_stats), 114 | num_tokens=sum(stat.num_tokens for stat in file_stats), 115 | language_distribution=dict(language_counts), 116 | primary_language=primary_language, 117 | file_stats=file_stats, 118 | chunk_stats=chunk_stats, 119 | ) 120 | 121 | 122 | if __name__ == "__main__": 123 | import argparse 124 | from textwrap import dedent 125 | 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument("target", type=str, help="Path to codebase") 128 | args = parser.parse_args() 129 | 130 | codebase = Codebase.new("tmp", Path(args.target)) 131 | codebase.init_chunks() 132 | stats = CodebaseStats.from_codebase(codebase) 133 | print( 134 | dedent(f""" 135 | Codebase Stats for {args.target}: 136 | - Number of chunks: {stats.num_chunks} 137 | - Number of files: {stats.num_files} 138 | - Number of lines: {stats.num_lines} 139 | - Number of tokens: {stats.num_tokens} 140 | - Primary language: {stats.primary_language} 141 | - Language distribution: {stats.language_distribution} 142 | """) 143 | ) 144 | -------------------------------------------------------------------------------- /coderetrx/retrieval/strategy/adaptive_filter_symbol_content_by_vector_and_llm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Strategy for adaptive filtering of symbols using vector similarity search followed by LLM refinement. 
3 | """ 4 | 5 | from collections import defaultdict 6 | from typing import List, Union, Optional, override, Literal, Any 7 | from .base import ( 8 | AdaptiveFilterByVectorAndLLMStrategy, 9 | FilterByVectorAndLLMStrategy, 10 | StrategyExecuteResult, 11 | ) 12 | from ..smart_codebase import ( 13 | SmartCodebase as Codebase, 14 | LLMMapFilterTargetType, 15 | SimilaritySearchTargetType, 16 | ) 17 | from coderetrx.static import Keyword, Symbol, File 18 | 19 | 20 | class AdaptiveFilterSymbolContentByVectorAndLLMStrategy( 21 | AdaptiveFilterByVectorAndLLMStrategy 22 | ): 23 | """Strategy to filter symbols using adaptive vector similarity search followed by LLM refinement.""" 24 | 25 | name: str = "ADAPTIVE_FILTER_SYMBOL_CONTENT_BY_VECTOR_AND_LLM" 26 | 27 | @override 28 | def get_strategy_name(self) -> str: 29 | return self.name 30 | 31 | @override 32 | def get_target_types_for_vector(self) -> List[SimilaritySearchTargetType]: 33 | return ["symbol_content"] 34 | 35 | @override 36 | def get_target_type_for_llm(self) -> LLMMapFilterTargetType: 37 | return "symbol_content" 38 | 39 | @override 40 | def get_collection_size(self, codebase: Codebase) -> int: 41 | return len(codebase.symbols) 42 | 43 | @override 44 | def filter_elements( 45 | self, 46 | codebase: Codebase, 47 | elements: List[Any], 48 | target_type: LLMMapFilterTargetType = "symbol_content", 49 | subdirs_or_files: List[str] = [], 50 | ) -> List[Union[Keyword, Symbol, File]]: 51 | filtered_symbols: List[Symbol] = [] 52 | for element in elements: 53 | if not isinstance(element, Symbol): 54 | continue 55 | # If subdirs_or_files is provided and codebase is available, filter by subdirs 56 | if subdirs_or_files and codebase: 57 | # Get the relative path from the codebase directory 58 | rpath = str(element.file.path) 59 | if any(rpath.startswith(subdir) for subdir in subdirs_or_files): 60 | filtered_symbols.append(element) 61 | else: 62 | filtered_symbols.append(element) 63 | if target_type == "class_content": 64 | # If the target type is class_content, filter symbols that are classes 65 | filtered_symbols = [ 66 | elem for elem in filtered_symbols if elem.type == "class" 67 | ] 68 | elif target_type == "function_content": 69 | # If the target type is function_content, filter symbols that are functions 70 | filtered_symbols = [ 71 | elem for elem in filtered_symbols if elem.type == "function" 72 | ] 73 | elif target_type == "leaf_symbol_content": 74 | # If the target type is leaf_symbol_content, filter symbols that are leaves 75 | 76 | parent_of_symbol = { 77 | symbol.id: symbol.chunk.parent.id 78 | for symbol in codebase.symbols 79 | if symbol.chunk.parent 80 | } 81 | childs_of_symbol = defaultdict(list) 82 | for child, parent in parent_of_symbol.items(): 83 | childs_of_symbol[parent].append(child) 84 | filtered_symbols = [ 85 | elem for elem in filtered_symbols if not childs_of_symbol[elem.id] 86 | ] 87 | elif target_type == "root_symbol_content": 88 | parent_of_symbol = { 89 | symbol.id: symbol.chunk.parent.id 90 | for symbol in codebase.symbols 91 | if symbol.chunk.parent 92 | } 93 | filtered_symbols = [ 94 | elem for elem in filtered_symbols if not parent_of_symbol[elem.id] 95 | ] 96 | return filtered_symbols # type: ignore 97 | 98 | @override 99 | def collect_file_paths( 100 | self, 101 | filtered_elements: List[Any], 102 | codebase: Codebase, 103 | subdirs_or_files: List[str], 104 | ) -> List[str]: 105 | file_paths = [] 106 | for symbol in filtered_elements: 107 | if isinstance(symbol, Symbol): 108 | file_path = str(symbol.file.path) 109 
| if file_path.startswith(tuple(subdirs_or_files)): 110 | file_paths.append(file_path) 111 | return file_paths 112 | 113 | @override 114 | async def execute( 115 | self, 116 | codebase: Codebase, 117 | prompt: str, 118 | subdirs_or_files: List[str], 119 | target_type: LLMMapFilterTargetType = "symbol_content", 120 | ) -> StrategyExecuteResult: 121 | prompt = f""" 122 | requirement: A code chunk with this name is highly likely to meet the following criteria: 123 | 124 | {prompt} 125 | 126 | 127 | The objective of this requirement is to preliminarily identify code chunks that are likely to meet specific content criteria based on their names. 128 | Code chunks with matching names will proceed to deeper analysis in the content filter (content_criteria) at a later stage (not in this run). 129 | 130 | """ 131 | return await super().execute(codebase, prompt, subdirs_or_files, target_type) 132 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/parsers/codeql/queries.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | from pathlib import Path 3 | import logging 4 | from ...languages import IDXSupportedLanguage 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class CodeQLQueryTemplates: 10 | """ 11 | Provides CodeQL query templates for different languages and symbol types. 12 | 13 | Each template is loaded from .ql files in the queries/codeql directory, 14 | ensuring consistent results between parsers and maintainable query files. 15 | """ 16 | 17 | # Supported query types for each language 18 | LANGUAGE_QUERY_TYPES = { 19 | "python": ["functions", "classes", "imports"], 20 | "javascript": ["functions", "classes", "imports"], 21 | "typescript": ["functions", "classes", "imports"], 22 | "java": ["functions", "classes", "imports"], 23 | "cpp": ["functions", "classes", "includes"], 24 | "c": ["functions", "classes", "includes"], 25 | "go": ["functions", "types", "imports"], 26 | "rust": ["functions", "structs"], 27 | "csharp": ["functions", "classes", "imports"], 28 | } 29 | 30 | @classmethod 31 | def get_supported_languages(cls) -> List[IDXSupportedLanguage]: 32 | """ 33 | Get list of languages that have CodeQL query templates. 34 | 35 | Returns: 36 | List of supported languages 37 | """ 38 | from typing import cast 39 | 40 | supported = [] 41 | for language_str in cls.LANGUAGE_QUERY_TYPES.keys(): 42 | language = cast(IDXSupportedLanguage, language_str) 43 | # Check if at least one query file exists for this language 44 | query_types = cls.LANGUAGE_QUERY_TYPES[language_str] 45 | for query_type in query_types: 46 | if cls._get_query_file_path(language, query_type).exists(): 47 | supported.append(language) 48 | break 49 | 50 | return supported 51 | 52 | @classmethod 53 | def get_query(cls, language: IDXSupportedLanguage, query_type: str) -> Path: 54 | """ 55 | Get a specific query for a language. 
56 | 57 | Args: 58 | language: The language 59 | query_type: The type of query (e.g., 'functions', 'classes') 60 | 61 | Returns: 62 | Query text 63 | 64 | Raises: 65 | KeyError: If language or query type is not supported 66 | FileNotFoundError: If query file is not found 67 | """ 68 | if language not in cls.LANGUAGE_QUERY_TYPES: 69 | raise KeyError(f"Language not supported: {language}") 70 | 71 | if query_type not in cls.LANGUAGE_QUERY_TYPES[language]: 72 | raise KeyError( 73 | f"Query type '{query_type}' not available for language: {language}" 74 | ) 75 | 76 | query_file = cls._get_query_file_path(language, query_type) 77 | if not query_file.exists(): 78 | raise FileNotFoundError(f"CodeQL query file not found: {query_file}") 79 | 80 | return query_file 81 | 82 | @classmethod 83 | def has_query(cls, language: IDXSupportedLanguage, query_type: str) -> bool: 84 | """ 85 | Check if a specific query is available for a language. 86 | 87 | Args: 88 | language: The language 89 | query_type: The type of query 90 | 91 | Returns: 92 | True if query is available, False otherwise 93 | """ 94 | if language not in cls.LANGUAGE_QUERY_TYPES: 95 | return False 96 | 97 | if query_type not in cls.LANGUAGE_QUERY_TYPES[language]: 98 | return False 99 | 100 | query_file = cls._get_query_file_path(language, query_type) 101 | return query_file.exists() 102 | 103 | @classmethod 104 | def get_available_queries(cls) -> Dict[str, List[str]]: 105 | """ 106 | Get all available queries organized by language. 107 | 108 | Returns: 109 | Dictionary mapping language -> list of available query types 110 | """ 111 | from typing import cast 112 | 113 | available = {} 114 | for language_str in cls.LANGUAGE_QUERY_TYPES: 115 | language = cast(IDXSupportedLanguage, language_str) 116 | query_types = [] 117 | for query_type in cls.LANGUAGE_QUERY_TYPES[language_str]: 118 | if cls.has_query(language, query_type): 119 | query_types.append(query_type) 120 | if query_types: 121 | available[language_str] = query_types 122 | 123 | return available 124 | 125 | @classmethod 126 | def _get_query_file_path( 127 | cls, language: IDXSupportedLanguage, query_type: str 128 | ) -> Path: 129 | """ 130 | Get the file path for a specific query. 131 | 132 | Args: 133 | language: The language 134 | query_type: The type of query (e.g., 'functions', 'classes') 135 | 136 | Returns: 137 | Path to the query file 138 | """ 139 | # Get the directory where this file is located 140 | current_file = Path(__file__) 141 | # Navigate to the queries directory: ../../../queries/codeql/ 142 | queries_dir = current_file.parent.parent.parent / "queries" / "codeql" 143 | 144 | # Construct the path: queries/codeql/{language}/{query_type}.ql 145 | query_file_path = queries_dir / language / f"{query_type}.ql" 146 | 147 | return query_file_path 148 | -------------------------------------------------------------------------------- /coderetrx/static/codebase/parsers/treesitter/queries.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | from pathlib import Path 3 | import logging 4 | from ...languages import IDXSupportedLanguage 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class TreeSitterQueryTemplates: 10 | """ 11 | Provides Tree-sitter query templates for different languages and symbol types. 12 | 13 | Each template is loaded from .scm files in the queries/treesitter directory, 14 | ensuring consistent results between parsers and maintainable query files. 
15 | """ 16 | 17 | # Supported query types for each language 18 | LANGUAGE_QUERY_TYPES = { 19 | "javascript": ["tags", "tests", "fine_imports"], 20 | "typescript": ["tags", "tests", "fine_imports"], 21 | "python": ["tags", "tests", "fine_imports"], 22 | "rust": ["tags", "tests", "fine_imports"], 23 | "c": ["tags", "fine_imports"], 24 | "cpp": ["tags", "fine_imports"], 25 | "csharp": ["tags", "fine_imports"], 26 | "go": ["tags", "tests", "fine_imports"], 27 | "elixir": ["tags", "tests", "fine_imports"], 28 | "java": ["tags", "fine_imports"], 29 | "php": ["tags", "tests", "fine_imports"], 30 | } 31 | 32 | @classmethod 33 | def get_supported_languages(cls) -> List[IDXSupportedLanguage]: 34 | """ 35 | Get list of languages that have Tree-sitter query templates. 36 | 37 | Returns: 38 | List of supported languages 39 | """ 40 | from typing import cast 41 | 42 | supported = [] 43 | for language_str in cls.LANGUAGE_QUERY_TYPES.keys(): 44 | language = cast(IDXSupportedLanguage, language_str) 45 | # Check if at least one query file exists for this language 46 | query_types = cls.LANGUAGE_QUERY_TYPES[language_str] 47 | for query_type in query_types: 48 | if cls._get_query_file_path(language, query_type).exists(): 49 | supported.append(language) 50 | break 51 | 52 | return supported 53 | 54 | @classmethod 55 | def get_query(cls, language: IDXSupportedLanguage, query_type: str = "tags") -> str: 56 | """ 57 | Get a specific query for a language. 58 | 59 | Args: 60 | language: The language 61 | query_type: The type of query (e.g., 'tags', 'tests', 'fine_imports') 62 | 63 | Returns: 64 | Query text 65 | 66 | Raises: 67 | KeyError: If language or query type is not supported 68 | FileNotFoundError: If query file is not found 69 | """ 70 | if language not in cls.LANGUAGE_QUERY_TYPES: 71 | raise KeyError(f"Language not supported: {language}") 72 | 73 | if query_type not in cls.LANGUAGE_QUERY_TYPES[language]: 74 | raise KeyError( 75 | f"Query type '{query_type}' not available for language: {language}" 76 | ) 77 | 78 | query_file = cls._get_query_file_path(language, query_type) 79 | if not query_file.exists(): 80 | raise FileNotFoundError(f"Tree-sitter query file not found: {query_file}") 81 | 82 | with open(query_file, "r") as f: 83 | return f.read() 84 | 85 | @classmethod 86 | def has_query(cls, language: IDXSupportedLanguage, query_type: str) -> bool: 87 | """ 88 | Check if a specific query is available for a language. 89 | 90 | Args: 91 | language: The language 92 | query_type: The type of query 93 | 94 | Returns: 95 | True if query is available, False otherwise 96 | """ 97 | if language not in cls.LANGUAGE_QUERY_TYPES: 98 | return False 99 | 100 | if query_type not in cls.LANGUAGE_QUERY_TYPES[language]: 101 | return False 102 | 103 | query_file = cls._get_query_file_path(language, query_type) 104 | return query_file.exists() 105 | 106 | @classmethod 107 | def get_available_queries(cls) -> Dict[str, List[str]]: 108 | """ 109 | Get all available queries organized by language. 
110 | 
111 |         Returns:
112 |             Dictionary mapping language -> list of available query types
113 |         """
114 |         from typing import cast
115 | 
116 |         available = {}
117 |         for language_str in cls.LANGUAGE_QUERY_TYPES:
118 |             language = cast(IDXSupportedLanguage, language_str)
119 |             query_types = []
120 |             for query_type in cls.LANGUAGE_QUERY_TYPES[language_str]:
121 |                 if cls.has_query(language, query_type):
122 |                     query_types.append(query_type)
123 |             if query_types:
124 |                 available[language_str] = query_types
125 | 
126 |         return available
127 | 
128 |     @classmethod
129 |     def _get_query_file_path(
130 |         cls, language: IDXSupportedLanguage, query_type: str
131 |     ) -> Path:
132 |         """
133 |         Get the file path for a specific query.
134 | 
135 |         Args:
136 |             language: The language
137 |             query_type: The type of query (e.g., 'tags', 'tests', 'fine_imports')
138 | 
139 |         Returns:
140 |             Path to the query file
141 |         """
142 |         # Get the directory where this file is located
143 |         current_file = Path(__file__)
144 |         # Navigate to the queries directory: ../../../queries/treesitter/
145 |         queries_dir = current_file.parent.parent.parent / "queries" / "treesitter"
146 | 
147 |         # Construct the path: queries/treesitter/{query_type}/{language}.scm
148 |         query_file_path = queries_dir / query_type / f"{language}.scm"
149 | 
150 |         return query_file_path
151 | 
--------------------------------------------------------------------------------
/USAGE.md:
--------------------------------------------------------------------------------
1 | # Usage Guide
2 | 
3 | ## Programmatic API
4 | 
5 | CodeRetrX provides a powerful programmatic interface through the `coderetrx_filter` and `llm_traversal_filter` APIs, which enable flexible code analysis and retrieval across different search strategies and filtering modes.
6 | 
7 | The coderetrx_filter implements the retrieval strategies designed by us. It offers a cost-effective semantic recall approach for large-scale repositories, achieving approximately 90% recall with only about 25% of the resource consumption in practical tests (line_per_symbol strategy); the larger the repository, the greater the savings.
8 | 
9 | The llm_traversal_filter provides the most comprehensive and accurate analysis, ideal for establishing ground truth. For small-scale repositories, this strategy can also be a good choice.
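### Caching the parsed codebase

The examples below build the codebase with `CodebaseFactory.new`, which indexes the repository from scratch. If you run repeated queries against the same repository, you can persist the parsed codebase and reload it instead of re-indexing. This is a minimal sketch that mirrors the caching pattern used in the project's test suite (`CodebaseFactory.to_json` / `CodebaseFactory.from_json`); the cache location is illustrative.

```python
import json
from pathlib import Path
from coderetrx.retrieval import CodebaseFactory

cache_file = Path("codebase_cache.json")  # illustrative cache location
if cache_file.exists():
    # Reload a previously serialized codebase.
    codebase = CodebaseFactory.from_json(json.loads(cache_file.read_text(encoding="utf-8")))
else:
    # Build the codebase once and cache its JSON representation for later runs.
    codebase = CodebaseFactory.new("repo_name", Path("/path/to/your/repo"))
    cache_file.write_text(json.dumps(codebase.to_json()), encoding="utf-8")
```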
10 | 
11 | ### Using coderetrx_filter with symbol_name strategy
12 | 
13 | ```python
14 | from pathlib import Path
15 | from coderetrx.retrieval import coderetrx_filter, CodebaseFactory
16 | 
17 | # Initialize codebase
18 | codebase = CodebaseFactory.new("repo_name", Path("/path/to/your/repo"))
19 | 
20 | # Basic symbol search
21 | elements, llm_results = await coderetrx_filter(
22 |     codebase=codebase,
23 |     prompt="your_filter_prompt",
24 |     subdirs_or_files=["src/"],
25 |     target_type="symbol_content",
26 |     coarse_recall_strategy="symbol_name"
27 | )
28 | 
29 | # Process results
30 | for element in elements:
31 |     print(f"Found: {element.name} in {element.file.path}")
32 | ```
33 | 
34 | ### Using coderetrx_filter with line_per_symbol strategy
35 | 
36 | ```python
37 | from coderetrx.retrieval import coderetrx_filter
38 | from coderetrx.retrieval.code_recall import CodeRecallSettings
39 | 
40 | # Configure advanced settings
41 | settings = CodeRecallSettings(
42 |     llm_primary_recall_model_id="google/gemini-2.5-flash-lite-preview-06-17",
43 |     llm_secondary_recall_model_id="openai/gpt-4o-mini"
44 | )
45 | 
46 | # Cost-efficient and complete search with line_per_symbol mode and secondary recall
47 | elements, llm_results = await coderetrx_filter(
48 |     codebase=codebase,
49 |     prompt="your_filter_prompt",
50 |     subdirs_or_files=["src/", "lib/"],
51 |     target_type="symbol_content",
52 |     coarse_recall_strategy="line_per_symbol",
53 |     settings=settings,
54 |     enable_secondary_recall=True
55 | )
56 | ```
57 | 
58 | ### Using llm_traversal_filter (Ground Truth & Maximum Accuracy)
59 | 
60 | ```python
61 | from coderetrx.retrieval import llm_traversal_filter
62 | 
63 | # Ground truth search - most comprehensive and accurate
64 | elements, llm_results = await llm_traversal_filter(
65 |     codebase=codebase,
66 |     prompt="your_filter_prompt",
67 |     subdirs_or_files=["src/", "lib/"],
68 |     target_type="symbol_content",
69 |     settings=settings
70 | )
71 | ```
72 | 
73 | ## Coarse Recall Strategy
74 | 
75 | **coderetrx_filter** supports `file_name`, `symbol_name`, `line_per_symbol`, and `auto` strategies for different speed/accuracy tradeoffs. **llm_traversal_filter** uses full LLM processing for maximum accuracy.
76 | 
77 | [See detailed strategy comparison in README.md](README.md#-coarse-recall-strategies)
78 | 
79 | ## Target Type Options
80 | 
81 | Target type defines the retrieval target, determining the type of code object to be recalled and returned. For example, if `target_type` is set to `class_content`, the result will include the relevant classes whose content matches the query. Below are the available `target_type` options:
82 | 
83 | - **`symbol_name`**: Matches symbols (e.g., functions, classes) whose **name** satisfies the query.
84 | - **`symbol_content`**: Matches symbols whose **entire code content** satisfies the query.
85 | - **`leaf_symbol_name`**: Matches **leaf symbols** (symbols without child elements, such as methods) whose **name** satisfies the query.
86 | - **`leaf_symbol_content`**: Matches **leaf symbols** whose **code content** satisfies the query.
87 | - **`root_symbol_name`**: Matches **root symbols** (symbols without parent elements, such as top-level classes and functions) whose **name** satisfies the query.
88 | - **`root_symbol_content`**: Matches **root symbols** whose **entire code content** satisfies the query.
89 | - **`class_name`**: Matches **classes** whose **name** satisfies the query.
90 | - **`class_content`**: Matches **classes** whose **entire code content** satisfies the query.
91 | - **`function_name`**: Matches **functions** whose **name** satisfy the query. 92 | - **`function_content`**: Matches **functions** whose **entire code content** satisfies the query. 93 | - **`dependency_name`**: Matches **dependency names** (e.g., imported libraries or modules) that satisfy the query. 94 | 95 | Note: The coderetrx_filter only supports the xxx_content series of target_type, while the llm_traversal_filter supports all target_type options. 96 | 97 | ## Settings Configuration 98 | 99 | The `CodeRecallSettings` class allows fine-tuning of the search behavior: 100 | 101 | ```python 102 | settings = CodeRecallSettings( 103 | llm_primary_recall_model_id="...", # Model used for coarse recall and the primary recall in the refined stage. 104 | llm_secondary_recall_model_id="...", # Model used for secondary recall in the refined stage. If set (not None), secondary recall will be enabled. 105 | llm_selector_strategy_model_id="...", # Model used for strategy selection in "auto" mode during the coarse recall stage. 106 | llm_call_mode="function_call" # LLM call mode. If set to "function_call", the LLM will return results as a function call (recommended for models supporting this feature). 107 | # If set to "traditional", the LLM will return results in plain text format. 108 | ) 109 | ``` 110 | 111 | ## Working with Results 112 | 113 | ```python 114 | # Process different types of results 115 | for element in elements: 116 | if hasattr(element, 'name'): # Symbol 117 | print(f"Symbol: {element.name} in {element.file.path}") 118 | print(f"Content: {element.chunk.code()}") 119 | elif hasattr(element, 'path'): # File 120 | print(f"File: {element.path}") 121 | elif hasattr(element, 'content'): # Keyword 122 | print(f"Keyword: {element.content}") 123 | 124 | # Access LLM analysis results 125 | for result in llm_results: 126 | print(f"Index: {result.index}") 127 | print(f"Reason: {result.reason}") 128 | print(f"Result: {result.result}") 129 | ``` 130 | 131 | -------------------------------------------------------------------------------- /bench/queries.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "3b866a1d-8de9-4999-bbf3-651c1eb6ec5e", 4 | "name": "Unsafe Evaluation Functions", 5 | "filter_prompt": "The code snippet contains a function call that dynamically executes code or system commands. Examples include Python's `eval()`, `exec()`, or functions like `os.system()`, `subprocess.run()` (especially with `shell=True`), `subprocess.call()` (with `shell=True`), or `popen()`. The critical feature is that the string representing the code or command to be executed is not a hardcoded literal; instead, it's derived from a variable, function argument, string concatenation/formatting, or an external source such as user input, network request, or LLM output.", 6 | "subdirs_or_files": [ 7 | "/" 8 | ] 9 | }, 10 | { 11 | "id": "b2565e7b-dacc-44e5-9817-a31898326bc1", 12 | "name": "Pickle/Cloudpickle Deserialization", 13 | "filter_prompt": "This code snippet deserializes data using functions from Python's `pickle` or `cloudpickle` libraries, such as `load()` or `loads()`. 
The input data for the deserialization operation is not a hardcoded literal.", 14 | "subdirs_or_files": [ 15 | "/" 16 | ] 17 | }, 18 | { 19 | "id": "a8a337fc-4a32-4bf9-b98c-4b30a666352a", 20 | "name": "Magic Bytes File Type Manipulation", 21 | "filter_prompt": "This code snippet implements logic to determine or validate a file's type by reading and analyzing its initial bytes (e.g., magic bytes, file signature, or header). This is often part of a file upload handling mechanism or file processing pipeline where verifying the actual content type based on its leading bytes is critical.", 22 | "subdirs_or_files": [ 23 | "/" 24 | ] 25 | }, 26 | { 27 | "id": "9011537d-3f3a-4587-9500-068c453194f9", 28 | "name": "Shell Command Execution", 29 | "filter_prompt": "This code snippet executes a shell command, system command, an external program, or evaluates a string as code. This is often done using functions like `os.system`, `subprocess.call`, `subprocess.run` (especially with `shell=True`), `subprocess.Popen` (especially with `shell=True`), `commands.getoutput`, `Runtime.getRuntime().exec`, `ProcessBuilder`, `php.system`, `php.exec`, `php.shell_exec`, `php.passthru`, `php.popen`, PHP backticks (` `), `Node.child_process.exec`, `Node.child_process.execSync`, `eval`, `exec`, `ScriptEngine.eval()`, `execCommand`, `Perl.system`, `Ruby.system`, `Ruby.exec`, Ruby backticks (``), `Go.os/exec.Command`, etc. The command string, arguments to the command, or the string being evaluated as code, are derived from variables, function parameters, or other dynamic sources, rather than being solely hardcoded string literals.", 30 | "subdirs_or_files": [ 31 | "/" 32 | ] 33 | }, 34 | { 35 | "id": "ab624bd9-92e7-4f34-a2b4-02a914a78040", 36 | "name": "Command Injection in CLI Applications", 37 | "filter_prompt": "This code snippet executes operating system commands using functions like `os.system`, `subprocess.run`, `subprocess.Popen`, `subprocess.call`, `subprocess.check_output`, `commands.getoutput`, or `pty.spawn`. The command being executed is dynamically constructed using string operations (e.g., concatenation, f-strings, `.format()`) with variables that could hold data from external sources like command-line arguments or file content. Prioritize instances where `subprocess` functions are used with `shell=True` or where command components are assembled from non-literal string variables.", 38 | "subdirs_or_files": [ 39 | "/" 40 | ] 41 | }, 42 | { 43 | "id": "1f7cc885-ad42-42a7-a051-0bb5cd97374f", 44 | "name": "Other Deserialization Mechanisms", 45 | "filter_prompt": "This code snippet performs deserialization of data using PyTorch's `torch.load()` (or similar model loading functions in AI/ML frameworks), Python's `shelve` module (e.g., `shelve.open()`, `shelf[key]`), or JDBC connection mechanisms (e.g., constructing connection URLs or using drivers). The deserialization is flagged if the input data (such as a model file path or content, data from a shelve file, or components of a JDBC URL) is not a hardcoded literal and could originate from an untrusted external source.", 46 | "subdirs_or_files": [ 47 | "/" 48 | ] 49 | }, 50 | { 51 | "id": "dbfd2d19-9e1a-4c1a-b1f1-79318deebb06", 52 | "name": "Path Traversal and File Operations", 53 | "filter_prompt": "Locate code snippets that perform file system operations (such as reading, writing, deleting, moving files or directories, extracting archives, or including files) or use file paths or names within system commands. 
Focus on cases where these file paths or names are derived from, or can be influenced by, external sources (e.g., user input, network data, API parameters, environment variables, or function arguments traceable to such sources) and where there is a potential lack of, or insufficient, sanitization or validation against path traversal techniques (e.g., sequences like '..', absolute paths, symbolic links, null bytes, or encoding tricks).", 54 | "subdirs_or_files": [ 55 | "/" 56 | ] 57 | }, 58 | { 59 | "id": "bd1ec518-2837-40a2-ab2d-fc38099b3686", 60 | "name": "Arbitrary File Write", 61 | "filter_prompt": "This code snippet involves a file system write operation (such as creating, writing to, or moving a file). The destination path, filename, or the content of the file appears to be constructed or influenced by data originating from an external source (e.g., user input, API request parameters, network data, configuration files, environment variables) and there is no clear evidence of robust sanitization, validation, or restriction of the path to a predefined safe directory.", 62 | "subdirs_or_files": [ 63 | "/" 64 | ] 65 | }, 66 | { 67 | "id": "50794dd5-45fd-43e4-afd3-37359ef5b076", 68 | "name": "Bypass of File Extension Restrictions", 69 | "filter_prompt": "This code snippet is involved in processing files uploaded by users. This includes operations such as retrieving the original filename or file extension, determining the destination path or filename for storage, moving or saving the uploaded file to the server's filesystem, and/or implementing validation rules to restrict allowed file types. These validation rules might be based on the file's extension (e.g., checking against a list of permitted or forbidden extensions like '.php', '.jsp', '.asp', '.exe', '.gif', '.jpg'), its MIME type, or its initial bytes (magic bytes/file signatures).", 70 | "subdirs_or_files": [ 71 | "/" 72 | ] 73 | } 74 | ] -------------------------------------------------------------------------------- /coderetrx/static/ripgrep/installer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import sys 4 | import asyncio 5 | import shutil 6 | import tarfile 7 | from tempfile import TemporaryDirectory 8 | import zipfile 9 | from pathlib import Path 10 | from typing import Optional, Tuple, Dict 11 | 12 | import httpx 13 | 14 | # Ripgrep release information 15 | RG_VERSION = "14.1.1" 16 | RG_REPO = "BurntSushi/ripgrep" 17 | RG_BASE_URL = f"https://github.com/{RG_REPO}/releases/download/{RG_VERSION}" 18 | 19 | 20 | def get_platform_info() -> Optional[str]: 21 | """Get platform information and return the corresponding download filename 22 | 23 | Returns: 24 | Optional[str]: Matching platform filename or None if unsupported 25 | """ 26 | system = platform.system().lower() 27 | machine = platform.machine().lower() 28 | 29 | # Map platform to ripgrep release filename 30 | platform_map: Dict[Tuple[str, str], str] = { 31 | ("darwin", "x86_64"): f"ripgrep-{RG_VERSION}-x86_64-apple-darwin.tar.gz", 32 | ("darwin", "arm64"): f"ripgrep-{RG_VERSION}-aarch64-apple-darwin.tar.gz", 33 | ("linux", "x86_64"): f"ripgrep-{RG_VERSION}-x86_64-unknown-linux-musl.tar.gz", 34 | ("linux", "i686"): f"ripgrep-{RG_VERSION}-i686-unknown-linux-gnu.tar.gz", 35 | ("linux", "aarch64"): f"ripgrep-{RG_VERSION}-aarch64-unknown-linux-gnu.tar.gz", 36 | ( 37 | "linux", 38 | "armv7l", 39 | ): f"ripgrep-{RG_VERSION}-armv7-unknown-linux-gnueabihf.tar.gz", 40 | ("windows", "amd64"): 
f"ripgrep-{RG_VERSION}-x86_64-pc-windows-msvc.zip",
41 |         ("windows", "x86"): f"ripgrep-{RG_VERSION}-i686-pc-windows-msvc.zip",
42 |     }
43 | 
44 |     return platform_map.get((system, machine))
45 | 
46 | 
47 | async def download_file(url: str, dest_path: Path) -> None:
48 |     """Download a file from a URL to a destination path."""
49 |     async with httpx.AsyncClient(follow_redirects=True) as client:
50 |         async with client.stream("GET", url) as response:
51 |             response.raise_for_status()
52 |             with open(dest_path, "wb") as f:
53 |                 async for chunk in response.aiter_bytes():
54 |                     f.write(chunk)
55 | 
56 | 
57 | def extract_single_file(archive_path: Path, filename: str, dest: Path) -> bool:
58 |     """Extract a single file from archive without extracting everything
59 | 
60 |     Args:
61 |         archive_path: Path to the archive file
62 |         filename: Name of the file to extract
63 |         dest: Destination path for the extracted file
64 | 
65 |     Returns:
66 |         bool: True if extraction succeeded, False otherwise
67 |     """
68 |     try:
69 |         if (
70 |             archive_path.suffixes[-2:] == [".tar", ".gz"]
71 |             or archive_path.suffix == ".tgz"
72 |         ):
73 |             with tarfile.open(archive_path, "r:gz") as tar:
74 |                 member = next(
75 |                     (m for m in tar.getmembers() if m.name.endswith(f"/{filename}")),
76 |                     None,
77 |                 )
78 |                 if member:
79 |                     extracted_file = tar.extractfile(member)
80 |                     if extracted_file is None:
81 |                         return False
82 |                     with extracted_file as src, open(dest, "wb") as dst:
83 |                         shutil.copyfileobj(src, dst)
84 |                     return True
85 |         elif archive_path.suffix == ".zip":
86 |             with zipfile.ZipFile(archive_path, "r") as zip_ref:
87 |                 member = next(
88 |                     (m for m in zip_ref.infolist() if m.filename.endswith(filename)),
89 |                     None,
90 |                 )
91 |                 if member:
92 |                     with zip_ref.open(member) as src, open(dest, "wb") as dst:
93 |                         shutil.copyfileobj(src, dst)
94 |                     return True
95 |         return False
96 |     except Exception as e:
97 |         print(f"Error extracting {filename}: {e}")
98 |         return False
99 | 
100 | 
101 | async def install_rg(install_path: Path) -> Optional[Path]:
102 |     """Main installation function
103 | 
104 |     Returns:
105 |         Optional[Path]: Path to the installed rg binary or None if failed
106 |     """
107 |     # Get appropriate ripgrep filename for current platform
108 |     rg_file = get_platform_info()
109 |     if not rg_file:
110 |         print("Error: Unsupported platform")
111 |         print(f"System: {platform.system()}, Machine: {platform.machine()}")
112 |         return None
113 | 
114 |     install_path.mkdir(parents=True, exist_ok=True)
115 |     with TemporaryDirectory() as temp_dir:  # ty: ignore[no-matching-overload]
116 |         try:
117 |             # Download the file
118 |             download_url = f"{RG_BASE_URL}/{rg_file}"
119 |             archive_path = Path(temp_dir) / rg_file
120 |             await download_file(download_url, archive_path)
121 | 
122 |             # Determine binary filename based on platform
123 |             binary_name = "rg.exe" if os.name == "nt" else "rg"
124 |             temp_binary = Path(temp_dir) / binary_name
125 | 
126 |             # Extract just the binary file
127 |             if not extract_single_file(archive_path, binary_name, temp_binary):
128 |                 print(f"Error: Could not find {binary_name} in the downloaded archive")
129 |                 return None
130 | 
131 |             # Set executable permissions (Unix-like systems)
132 |             if os.name != "nt":
133 |                 os.chmod(temp_binary, 0o755)
134 | 
135 |             # Install to destination
136 |             dest_path = install_path / binary_name
137 | 
138 |             # Move the binary to final location
139 |             shutil.move(temp_binary, dest_path)
140 | 
141 |             print("Installation complete!")
142 |             print(f"ripgrep has been installed to: {dest_path}")
143 |             print(f"Install directory: {install_path}")
144 | 145 | return dest_path 146 | 147 | except httpx.HTTPError as e: 148 | print(f"Download failed: {e}") 149 | except Exception as e: 150 | print(f"An unexpected error occurred: {e}") 151 | 152 | return None 153 | 154 | 155 | if __name__ == "__main__": 156 | from pathlib import Path 157 | 158 | if len(sys.argv) < 2: 159 | path = Path(".") 160 | else: 161 | path = Path(sys.argv[1]) 162 | installed_path = asyncio.run(install_rg(path)) 163 | if installed_path: 164 | print(f"Successfully installed rg at: {installed_path}") 165 | else: 166 | print("Failed to install ripgrep") 167 | sys.exit(1) 168 | -------------------------------------------------------------------------------- /coderetrx/retrieval/topic_extractor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List, Optional, Literal 2 | import logging 3 | import json 4 | from pydantic import BaseModel 5 | 6 | from .prompt import ( 7 | KeywordExtractorResult, 8 | topic_extraction_prompt_template, 9 | topic_extraction_function_call_system_prompt, 10 | get_topic_extraction_function_definition, topic_extraction_function_call_user_prompt_template, 11 | ) 12 | from .smart_codebase import SmartCodebaseSettings 13 | from coderetrx.utils.llm import call_llm_with_fallback, call_llm_with_function_call 14 | from .smart_codebase import ( 15 | SmartCodebase, 16 | SimilaritySearchTargetType, 17 | LLMCallMode, 18 | ) 19 | import os 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class TopicExtractor: 25 | """ 26 | A class to extract topics from input text using LLM before performing vector similarity searching. 27 | """ 28 | async def extract_topic(self, input_text: str, llm_call_mode: Literal["traditional", "function_call"] = "traditional") -> Optional[str]: 29 | """ 30 | Extract the core topic from the input text using LLM. 31 | 32 | Args: 33 | input_text: The input text to extract topic from 34 | llm_call_mode: Whether to use traditional prompt-based extraction or function call mode. 
35 | 36 | Returns: 37 | The extracted topic as a string, or None if extraction fails 38 | """ 39 | try: 40 | if llm_call_mode == "function_call": 41 | return await self._extract_topic_with_function_call(input_text) 42 | else: 43 | return await self._extract_topic_traditional(input_text) 44 | except Exception as e: 45 | logger.error(f"Error extracting topic: {str(e)}", exc_info=True) 46 | return None 47 | 48 | async def _extract_topic_traditional(self, input_text: str) -> Optional[str]: 49 | """Extract topic using traditional prompt-based approach.""" 50 | try: 51 | # Prepare input data for the prompt template 52 | input_data = {"input": input_text} 53 | 54 | # Call LLM with the topic extraction prompt template 55 | # Use topic extraction model_id from settings 56 | settings = SmartCodebaseSettings() 57 | topic_model_id = settings.llm_topic_extraction_model_id or settings.default_model_id 58 | model_ids = [topic_model_id] 59 | result = await call_llm_with_fallback( 60 | response_model=KeywordExtractorResult, 61 | input_data=input_data, 62 | prompt_template=topic_extraction_prompt_template, 63 | model_ids=model_ids, 64 | ) 65 | 66 | assert isinstance(result, KeywordExtractorResult) 67 | # Extract the topic from the result 68 | extracted_topic = result.result 69 | logger.info(f"Successfully extracted topic: '{extracted_topic}' from input using traditional mode") 70 | return extracted_topic 71 | 72 | except Exception as e: 73 | logger.error(f"Error in traditional topic extraction: {str(e)}", exc_info=True) 74 | return None 75 | 76 | async def _extract_topic_with_function_call(self, input_text: str) -> Optional[str]: 77 | """Extract topic using function call approach.""" 78 | try: 79 | # Prepare prompts for function call 80 | system_prompt = topic_extraction_function_call_system_prompt 81 | user_prompt = topic_extraction_function_call_user_prompt_template.format( 82 | input=input_text, 83 | ) 84 | function_definition = get_topic_extraction_function_definition() 85 | 86 | # Call LLM with function call 87 | # Use topic extraction model_id from settings 88 | settings = SmartCodebaseSettings() 89 | topic_model_id = settings.llm_topic_extraction_model_id or settings.default_model_id 90 | model_ids = [topic_model_id] 91 | function_args = await call_llm_with_function_call( 92 | system_prompt=system_prompt, 93 | user_prompt=user_prompt, 94 | function_definition=function_definition, 95 | model_ids=model_ids, 96 | ) 97 | 98 | # Extract topic from function call result 99 | extracted_topic = function_args.get("result") 100 | reason = function_args.get("reason", "") 101 | 102 | if extracted_topic: 103 | logger.info(f"Successfully extracted topic: '{extracted_topic}' from input using function call mode. Reason: {reason}") 104 | return extracted_topic 105 | else: 106 | logger.warning("Function call returned empty result for topic extraction") 107 | return None 108 | 109 | except Exception as e: 110 | logger.error(f"Error in function call topic extraction: {str(e)}", exc_info=True) 111 | return None 112 | 113 | async def extract_and_search( 114 | self, 115 | codebase: SmartCodebase, 116 | input_text: str, 117 | target_types: List[SimilaritySearchTargetType], 118 | threshold: float = 0.1, 119 | top_k: int = 100, 120 | llm_call_mode: Literal["traditional", "function_call"] = "traditional", 121 | ) -> List[Any]: 122 | """ 123 | Extract topic from input text and use it for vector similarity search. 
124 | 125 | Args: 126 | codebase: The codebase to search in 127 | input_text: The input text to extract topic from 128 | target_types: List of target types for similarity search 129 | threshold: Similarity threshold 130 | top_k: Number of top results to return 131 | llm_call_mode: Whether to use traditional prompt-based extraction or function call mode. 132 | 133 | Returns: 134 | List of search results 135 | """ 136 | # Extract topic from input text 137 | topic = await self.extract_topic(input_text, llm_call_mode) 138 | 139 | if not topic: 140 | logger.warning( 141 | "Using original input text for search as topic extraction failed" 142 | ) 143 | topic = input_text 144 | 145 | logger.info(f"Performing similarity search with topic: '{topic}'") 146 | 147 | # Perform similarity search using the extracted topic 148 | results = await codebase.similarity_search( 149 | target_types=target_types, query=topic, threshold=threshold, top_k=top_k 150 | ) 151 | 152 | logger.info(f"Found {len(results)} results for topic '{topic}'") 153 | return results 154 | -------------------------------------------------------------------------------- /test/impl/default/test_code_recall.py: -------------------------------------------------------------------------------- 1 | from dotenv import load_dotenv 2 | load_dotenv() 3 | import json 4 | from pathlib import Path 5 | from coderetrx.impl.default import CodebaseFactory 6 | from coderetrx.impl.default import TopicExtractor 7 | from coderetrx.retrieval.code_recall import multi_strategy_code_filter, multi_strategy_code_mapping 8 | import os 9 | import asyncio 10 | import unittest 11 | from typing import Literal 12 | from unittest.mock import patch, MagicMock 13 | from coderetrx.utils.embedding import create_documents_embedding 14 | from coderetrx.utils.git import clone_repo_if_not_exists, get_repo_id 15 | from coderetrx.utils.path import get_data_dir 16 | import logging 17 | 18 | logger = logging.getLogger(__name__) 19 | logging.basicConfig(level=logging.INFO) 20 | 21 | TEST_REPOS = ["https://github.com/apache/dubbo-admin.git"] 22 | 23 | 24 | def prepare_codebase(repo_url: str, repo_path: Path): 25 | """Helper function to prepare codebase for testing""" 26 | database_path = get_data_dir() / "databases" / f"{get_repo_id(repo_url)}.json" 27 | # Create a test codebase 28 | clone_repo_if_not_exists(repo_url, str(repo_path)) 29 | 30 | if database_path.exists(): 31 | codebase = CodebaseFactory.from_json( 32 | json.load(open(database_path, "r", encoding="utf-8")) 33 | ) 34 | else: 35 | codebase = CodebaseFactory.new(get_repo_id(repo_url), repo_path) 36 | with open(f"{repo_path}.json", "w") as f: 37 | json.dump(codebase.to_json(), f, indent=4) 38 | return codebase 39 | 40 | 41 | class TestLLMCodeFilterTool(unittest.TestCase): 42 | """Test LLMCodeFilterTool functionality""" 43 | 44 | def setUp(self): 45 | """Set up test environment""" 46 | self.repo_url = TEST_REPOS[0] 47 | self.repo_path = get_data_dir() / "repos" / get_repo_id(self.repo_url) 48 | self.codebase = prepare_codebase(self.repo_url, self.repo_path) 49 | self.test_dir = "/" 50 | self.test_prompt = "Is the code snippet used for user authentication?" 
51 | self.topic_extractor = TopicExtractor() 52 | 53 | async def recall_with_mode(self, mode: Literal["fast", "balance", "precise", "custom"]): 54 | """Helper method to run the tool with a specific mode""" 55 | result, llm_output = await multi_strategy_code_filter( 56 | codebase=self.codebase, 57 | subdirs_or_files=[self.test_dir], 58 | prompt=self.test_prompt, 59 | target_type="symbol_content", 60 | mode=mode, 61 | topic_extractor=self.topic_extractor, 62 | ) 63 | return result 64 | 65 | def test_initialization(self): 66 | """Test initialization of test environment""" 67 | self.assertIsNotNone(self.codebase) 68 | self.assertIsNotNone(self.test_dir) 69 | 70 | def test_fast_mode(self): 71 | """Test LLMCodeFilterTool in fast mode""" 72 | result = asyncio.run(self.recall_with_mode("fast")) 73 | 74 | # Verify results 75 | self.assertIsNotNone(result) 76 | self.assertIsInstance(result, list) 77 | logger.info(f"Fast mode results count: {len(result)}") 78 | if result: 79 | logger.info(f"Sample result: {result[0]}") 80 | 81 | def test_balance_mode(self): 82 | """Test LLMCodeFilterTool in balance mode""" 83 | result = asyncio.run(self.recall_with_mode("balance")) 84 | 85 | # Verify results 86 | self.assertIsNotNone(result) 87 | self.assertIsInstance(result, list) 88 | logger.info(f"Balance mode results count: {len(result)}") 89 | if result: 90 | logger.info(f"Sample result: {result[0]}") 91 | 92 | def test_precise_mode(self): 93 | """Test LLMCodeFilterTool in precise mode""" 94 | result = asyncio.run(self.recall_with_mode("precise")) 95 | 96 | # Verify results 97 | self.assertIsNotNone(result) 98 | self.assertIsInstance(result, list) 99 | logger.info(f"Precise mode results count: {len(result)}") 100 | if result: 101 | logger.info(f"Sample result: {result[0]}") 102 | 103 | 104 | class TestLLMCodeMappingTool(unittest.TestCase): 105 | """Test multi_strategy_code_mapping functionality""" 106 | 107 | def setUp(self): 108 | """Set up test environment""" 109 | self.repo_url = TEST_REPOS[0] 110 | self.repo_path = get_data_dir() / "repos" / get_repo_id(self.repo_url) 111 | self.codebase = prepare_codebase(self.repo_url, self.repo_path) 112 | self.test_dir = "/" 113 | self.test_prompt = "Extract the function call sink that may result in arbitrary code execution" 114 | 115 | async def recall_with_mode(self, mode: Literal["fast", "balance", "precise", "custom"]): 116 | """Helper method to run the tool with a specific mode""" 117 | result, llm_output = await multi_strategy_code_mapping( 118 | codebase=self.codebase, 119 | subdirs_or_files=[self.test_dir], 120 | prompt=self.test_prompt, 121 | target_type="symbol_content", 122 | mode=mode, 123 | ) 124 | return result 125 | 126 | def test_initialization(self): 127 | """Test initialization of test environment""" 128 | self.assertIsNotNone(self.codebase) 129 | self.assertIsNotNone(self.test_dir) 130 | 131 | def test_fast_mode(self): 132 | """Test LLMCodeMappingTool in fast mode""" 133 | result = asyncio.run(self.recall_with_mode("fast")) 134 | 135 | # Verify results 136 | self.assertIsNotNone(result) 137 | self.assertIsInstance(result, list) 138 | logger.info(f"Fast mode results count: {len(result)}") 139 | if result: 140 | logger.info(f"Sample result: {result[0]}") 141 | 142 | def test_balance_mode(self): 143 | """Test LLMCodeMappingTool in balance mode""" 144 | result = asyncio.run(self.recall_with_mode("balance")) 145 | 146 | # Verify results 147 | self.assertIsNotNone(result) 148 | self.assertIsInstance(result, list) 149 | logger.info(f"Balance mode results 
count: {len(result)}") 150 | if result: 151 | logger.info(f"Sample result: {result[0]}") 152 | 153 | def test_precise_mode(self): 154 | """Test LLMCodeMappingTool in precise mode""" 155 | result = asyncio.run(self.recall_with_mode("precise")) 156 | 157 | # Verify results 158 | self.assertIsNotNone(result) 159 | self.assertIsInstance(result, list) 160 | logger.info(f"Precise mode results count: {len(result)}") 161 | if result: 162 | logger.info(f"Sample result: {result[0]}") 163 | 164 | 165 | # Run tests if specified 166 | if __name__ == "__main__": 167 | # Run unittest tests 168 | unittest.main(argv=['first-arg-is-ignored'], exit=False) 169 | -------------------------------------------------------------------------------- /coderetrx/static/codeql/installer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import sys 4 | import asyncio 5 | import shutil 6 | import tarfile 7 | from tempfile import TemporaryDirectory 8 | from pathlib import Path 9 | from typing import Optional, Tuple, Dict 10 | 11 | import httpx 12 | 13 | # CodeQL release information 14 | CODEQL_VERSION = "v2.23.1" 15 | CODEQL_REPO = "github/codeql-action" 16 | CODEQL_BASE_URL = ( 17 | f"https://github.com/{CODEQL_REPO}/releases/download/codeql-bundle-{CODEQL_VERSION}" 18 | ) 19 | 20 | 21 | def get_platform_info() -> Optional[str]: 22 | """Get platform information and return the corresponding download filename 23 | 24 | Returns: 25 | Optional[str]: Matching platform filename or None if unsupported 26 | """ 27 | system = platform.system().lower() 28 | 29 | # Map platform to CodeQL release filename 30 | platform_map: Dict[str, str] = { 31 | "darwin": f"codeql-bundle-osx64.tar.gz", 32 | "linux": f"codeql-bundle-linux64.tar.gz", 33 | } 34 | 35 | return platform_map.get(system) 36 | 37 | 38 | async def download_file(url: str, dest_path: Path) -> None: 39 | """Download a file from a URL to a destination path.""" 40 | async with httpx.AsyncClient(follow_redirects=True) as client: 41 | async with client.stream("GET", url) as response: 42 | response.raise_for_status() 43 | with open(dest_path, "wb") as f: 44 | async for chunk in response.aiter_bytes(): 45 | f.write(chunk) 46 | 47 | 48 | def extract_codeql_bundle(archive_path: Path, dest_dir: Path) -> bool: 49 | """Extract the entire CodeQL bundle to destination directory 50 | 51 | Args: 52 | archive_path: Path to the archive file 53 | dest_dir: Destination directory for extraction 54 | 55 | Returns: 56 | bool: True if extraction succeeded, False otherwise 57 | """ 58 | try: 59 | with tarfile.open(archive_path, "r:gz") as tar: 60 | # Extract all contents 61 | tar.extractall(path=dest_dir) 62 | return True 63 | except Exception as e: 64 | print(f"Error extracting CodeQL bundle: {e}") 65 | return False 66 | 67 | 68 | async def install_codeql(install_path: Path) -> Optional[Path]: 69 | """Main installation function 70 | 71 | Args: 72 | install_path: Path where CodeQL should be installed (e.g., /opt/codeql) 73 | 74 | Returns: 75 | Optional[Path]: Path to the installed CodeQL CLI binary or None if failed 76 | """ 77 | # Get appropriate CodeQL filename for current platform 78 | codeql_file = get_platform_info() 79 | if not codeql_file: 80 | print("Error: Unsupported platform") 81 | print(f"System: {platform.system()}") 82 | print("CodeQL installer only supports Linux and macOS") 83 | return None 84 | 85 | # Create install directory if it doesn't exist 86 | install_path.mkdir(parents=True, exist_ok=True) 87 | 88 | with 
TemporaryDirectory() as temp_dir: 89 | try: 90 | # Download the file 91 | download_url = f"{CODEQL_BASE_URL}/{codeql_file}" 92 | archive_path = Path(temp_dir) / codeql_file 93 | 94 | print(f"Downloading CodeQL from: {download_url}") 95 | await download_file(download_url, archive_path) 96 | print("Download completed") 97 | 98 | # Extract the bundle to temp directory 99 | temp_extract_dir = Path(temp_dir) / "extracted" 100 | temp_extract_dir.mkdir() 101 | 102 | print("Extracting CodeQL bundle...") 103 | if not extract_codeql_bundle(archive_path, temp_extract_dir): 104 | print("Error: Could not extract CodeQL bundle") 105 | return None 106 | 107 | # Find the codeql directory in extracted contents 108 | # The bundle typically extracts to a 'codeql' directory 109 | codeql_dir = temp_extract_dir / "codeql" 110 | if not codeql_dir.exists(): 111 | # Look for any directory that might contain CodeQL 112 | extracted_dirs = [d for d in temp_extract_dir.iterdir() if d.is_dir()] 113 | if extracted_dirs: 114 | codeql_dir = extracted_dirs[0] 115 | else: 116 | print("Error: Could not find CodeQL directory in extracted bundle") 117 | return None 118 | 119 | # Verify that the codeql binary exists 120 | binary_name = "codeql.exe" if os.name == "nt" else "codeql" 121 | codeql_binary = codeql_dir / binary_name 122 | if not codeql_binary.exists(): 123 | print(f"Error: Could not find {binary_name} in the extracted bundle") 124 | return None 125 | 126 | # Remove existing installation if it exists 127 | if install_path.exists(): 128 | print(f"Removing existing CodeQL installation at {install_path}") 129 | shutil.rmtree(install_path) 130 | 131 | # Move the entire codeql directory to the install location 132 | shutil.move(str(codeql_dir), str(install_path)) 133 | 134 | # Set executable permissions for the binary (Unix-like systems) 135 | final_binary = install_path / binary_name 136 | if os.name != "nt": 137 | os.chmod(final_binary, 0o755) 138 | 139 | print("Installation complete!") 140 | print(f"CodeQL has been installed to: {install_path}") 141 | print(f"CodeQL CLI binary: {final_binary}") 142 | print( 143 | f"To use CodeQL, add {install_path} to your PATH or use the full path: {final_binary}" 144 | ) 145 | 146 | return final_binary 147 | 148 | except httpx.HTTPError as e: 149 | print(f"Download failed: {e}") 150 | except Exception as e: 151 | print(f"An unexpected error occurred: {e}") 152 | 153 | return None 154 | 155 | 156 | if __name__ == "__main__": 157 | import argparse 158 | 159 | parser = argparse.ArgumentParser(description="Install CodeQL CLI") 160 | parser.add_argument( 161 | "--install-path", 162 | type=Path, 163 | default=Path("/opt/codeql"), 164 | help="Installation path for CodeQL (default: /opt/codeql)", 165 | ) 166 | 167 | args = parser.parse_args() 168 | 169 | # Check if we have write permissions to the install path 170 | try: 171 | args.install_path.parent.mkdir(parents=True, exist_ok=True) 172 | if not os.access(args.install_path.parent, os.W_OK): 173 | print(f"Error: No write permission to {args.install_path.parent}") 174 | print( 175 | "You may need to run this script with sudo or choose a different install path" 176 | ) 177 | sys.exit(1) 178 | except Exception as e: 179 | print(f"Error checking install path: {e}") 180 | sys.exit(1) 181 | 182 | installed_path = asyncio.run(install_codeql(args.install_path)) 183 | if installed_path: 184 | print(f"Successfully installed CodeQL CLI at: {installed_path}") 185 | else: 186 | print("Failed to install CodeQL") 187 | sys.exit(1) 188 | 
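# Usage note (illustrative sketch): besides the CLI entry point above, the installer
# can be driven programmatically. The install prefix below is only an example of a
# user-writable location; any writable path works.
#
#     import asyncio
#     from pathlib import Path
#
#     binary = asyncio.run(install_codeql(Path.home() / ".local" / "opt" / "codeql"))
#     if binary is None:
#         raise RuntimeError("CodeQL installation failed")
#
# Equivalent CLI invocation from the repository root:
#
#     python coderetrx/static/codeql/installer.py --install-path ~/.local/opt/codeql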
-------------------------------------------------------------------------------- /coderetrx/static/codebase/parsers/factory.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Dict, Any, List 2 | import logging 3 | from pathlib import Path 4 | 5 | from .base import CodebaseParser, UnsupportedLanguageError 6 | from .treesitter import TreeSitterParser 7 | from .codeql import CodeQLParser 8 | from ..languages import IDXSupportedLanguage 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class ParserFactory: 14 | """ 15 | Factory for creating and managing codebase parsers. 16 | 17 | Supports auto-selection of parsers based on availability and preferences, 18 | with fallback mechanisms for robustness. 19 | """ 20 | 21 | # Parser priority order for auto-selection 22 | PARSER_PRIORITY = ["treesitter", "codeql"] 23 | 24 | @classmethod 25 | def get_parser(cls, parser_type: str = "auto", **kwargs) -> CodebaseParser: 26 | """ 27 | Get a parser instance. 28 | 29 | Args: 30 | parser_type: Type of parser to create: 31 | - "auto": Auto-select best available parser 32 | - "codeql": CodeQL parser 33 | - "treesitter": Tree-sitter parser 34 | - "hybrid": Use CodeQL where available, fallback to tree-sitter 35 | **kwargs: Parser-specific configuration options 36 | 37 | Returns: 38 | CodebaseParser instance 39 | 40 | Raises: 41 | ValueError: If parser_type is invalid 42 | RuntimeError: If no suitable parser is available 43 | """ 44 | if parser_type == "auto": 45 | return cls._auto_select_parser(**kwargs) 46 | elif parser_type == "codeql": 47 | return cls._create_codeql_parser(**kwargs) 48 | elif parser_type == "treesitter": 49 | return cls._create_treesitter_parser(**kwargs) 50 | else: 51 | raise ValueError(f"Unknown parser type: {parser_type}") 52 | 53 | @classmethod 54 | def _auto_select_parser(cls, **kwargs) -> CodebaseParser: 55 | """ 56 | Auto-select the best available parser. 57 | 58 | Tries parsers in priority order and returns the first working one. 59 | """ 60 | for parser_name in cls.PARSER_PRIORITY: 61 | try: 62 | if parser_name == "codeql": 63 | parser = cls._create_codeql_parser(**kwargs) 64 | # Test if CodeQL is actually available 65 | if cls._test_codeql_availability(parser): 66 | logger.info("Auto-selected CodeQL parser") 67 | return parser 68 | else: 69 | logger.info("CodeQL CLI not available, trying next parser") 70 | continue 71 | elif parser_name == "treesitter": 72 | parser = cls._create_treesitter_parser(**kwargs) 73 | logger.info("Auto-selected Tree-sitter parser") 74 | return parser 75 | except Exception as e: 76 | logger.debug(f"Failed to create {parser_name} parser: {e}") 77 | continue 78 | 79 | raise RuntimeError("No suitable parser available") 80 | 81 | @classmethod 82 | def _create_codeql_parser(cls, **kwargs) -> CodeQLParser: 83 | """Create a CodeQL parser instance.""" 84 | return CodeQLParser(**kwargs) 85 | 86 | @classmethod 87 | def _create_treesitter_parser(cls, **kwargs) -> TreeSitterParser: 88 | """Create a Tree-sitter parser instance.""" 89 | return TreeSitterParser(**kwargs) 90 | 91 | @classmethod 92 | def _test_codeql_availability(cls, parser: CodeQLParser) -> bool: 93 | """ 94 | Test if CodeQL is actually available and working. 
95 | 96 | Args: 97 | parser: CodeQL parser instance to test 98 | 99 | Returns: 100 | True if CodeQL is available, False otherwise 101 | """ 102 | try: 103 | # Test basic CodeQL functionality 104 | supported_langs = parser.get_supported_languages() 105 | return len(supported_langs) > 0 106 | except Exception as e: 107 | logger.debug(f"CodeQL availability test failed: {e}") 108 | return False 109 | 110 | @classmethod 111 | def get_available_parsers(cls) -> Dict[str, bool]: 112 | """ 113 | Get status of all available parsers. 114 | 115 | Returns: 116 | Dictionary mapping parser names to availability status 117 | """ 118 | status = {} 119 | 120 | # Test Tree-sitter 121 | try: 122 | parser = cls._create_treesitter_parser() 123 | status["treesitter"] = len(parser.get_supported_languages()) > 0 124 | except Exception: 125 | status["treesitter"] = False 126 | 127 | # Test CodeQL 128 | try: 129 | parser = cls._create_codeql_parser() 130 | status["codeql"] = cls._test_codeql_availability(parser) 131 | except Exception: 132 | status["codeql"] = False 133 | 134 | return status 135 | 136 | @classmethod 137 | def recommend_parser( 138 | cls, languages: Optional[List[IDXSupportedLanguage]] = None 139 | ) -> str: 140 | """ 141 | Recommend the best parser for given languages. 142 | 143 | Args: 144 | languages: List of languages to support (None for all) 145 | 146 | Returns: 147 | Recommended parser name 148 | """ 149 | available = cls.get_available_parsers() 150 | 151 | if not languages: 152 | # If no specific languages, prefer CodeQL if available 153 | if available.get("codeql", False): 154 | return "codeql" 155 | elif available.get("treesitter", False): 156 | return "treesitter" 157 | else: 158 | return "auto" # Let auto-selection handle the error 159 | 160 | # Check language support for each parser 161 | codeql_coverage = 0 162 | treesitter_coverage = 0 163 | 164 | if available.get("codeql", False): 165 | try: 166 | codeql_parser = cls._create_codeql_parser() 167 | codeql_supported = set(codeql_parser.get_supported_languages()) 168 | codeql_coverage = len(set(languages).intersection(codeql_supported)) 169 | except Exception: 170 | pass 171 | 172 | if available.get("treesitter", False): 173 | try: 174 | treesitter_parser = cls._create_treesitter_parser() 175 | treesitter_supported = set(treesitter_parser.get_supported_languages()) 176 | treesitter_coverage = len( 177 | set(languages).intersection(treesitter_supported) 178 | ) 179 | except Exception: 180 | pass 181 | 182 | # Prefer CodeQL if it covers more languages, otherwise tree-sitter 183 | if codeql_coverage > treesitter_coverage: 184 | return "codeql" 185 | elif treesitter_coverage > 0: 186 | return "treesitter" 187 | elif codeql_coverage > 0: 188 | return "codeql" 189 | else: 190 | return "hybrid" # Use hybrid for maximum coverage 191 | --------------------------------------------------------------------------------
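The parser factory above can be exercised as in the following sketch. It relies only on the classmethods shown in factory.py; the import path follows the repository layout, and the language literals are examples rather than an exhaustive list.

```python
from coderetrx.static.codebase.parsers.factory import ParserFactory

# Report which parser backends are usable in the current environment.
print(ParserFactory.get_available_parsers())  # e.g. {"treesitter": True, "codeql": False}

# Ask for a recommendation for a concrete set of languages, then build a parser.
# recommend_parser may return "hybrid", which get_parser does not accept directly,
# so fall back to auto-selection in that case.
choice = ParserFactory.recommend_parser(["python", "go"])
parser = ParserFactory.get_parser(choice if choice in ("codeql", "treesitter") else "auto")
```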