├── .python-version
├── coderetrx
│   ├── utils
│   │   ├── __init__.py
│   │   ├── _extras.py
│   │   ├── git.py
│   │   ├── path.py
│   │   ├── logger.py
│   │   ├── concurrency.py
│   │   ├── cost_tracking.py
│   │   └── stats.py
│   ├── static
│   │   ├── codebase
│   │   │   ├── queries
│   │   │   │   ├── treesitter
│   │   │   │   │   ├── fine_imports
│   │   │   │   │   │   ├── c.scm
│   │   │   │   │   │   ├── README.md
│   │   │   │   │   │   ├── cpp.scm
│   │   │   │   │   │   ├── java.scm
│   │   │   │   │   │   ├── csharp.scm
│   │   │   │   │   │   ├── go.scm
│   │   │   │   │   │   ├── python.scm
│   │   │   │   │   │   ├── rust.scm
│   │   │   │   │   │   ├── javascript.scm
│   │   │   │   │   │   ├── typescript.scm
│   │   │   │   │   │   ├── php.scm
│   │   │   │   │   │   └── elixir.scm
│   │   │   │   │   ├── tags
│   │   │   │   │   │   ├── elisp.scm
│   │   │   │   │   │   ├── kotlin.scm
│   │   │   │   │   │   ├── ql.scm
│   │   │   │   │   │   ├── python.scm
│   │   │   │   │   │   ├── elm.scm
│   │   │   │   │   │   ├── java.scm
│   │   │   │   │   │   ├── c.scm
│   │   │   │   │   │   ├── c_sharp.scm
│   │   │   │   │   │   ├── go.scm
│   │   │   │   │   │   ├── ruby.scm
│   │   │   │   │   │   ├── php.scm
│   │   │   │   │   │   ├── csharp.scm
│   │   │   │   │   │   ├── typescript.scm
│   │   │   │   │   │   ├── cpp.scm
│   │   │   │   │   │   ├── rust.scm
│   │   │   │   │   │   ├── elixir.scm
│   │   │   │   │   │   ├── hcl.scm
│   │   │   │   │   │   ├── dart.scm
│   │   │   │   │   │   ├── javascript.scm
│   │   │   │   │   │   └── ocaml.scm
│   │   │   │   │   └── tests
│   │   │   │   │       ├── php.scm
│   │   │   │   │       ├── elixir.scm
│   │   │   │   │       ├── javascript.scm
│   │   │   │   │       ├── typescript.scm
│   │   │   │   │       ├── go.scm
│   │   │   │   │       ├── rust.scm
│   │   │   │   │       └── python.scm
│   │   │   │   └── codeql
│   │   │   │       └── python
│   │   │   │           ├── qlpack.yml
│   │   │   │           ├── imports.ql
│   │   │   │           ├── classes.ql
│   │   │   │           └── functions.ql
│   │   │   ├── parsers
│   │   │   │   ├── treesitter
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── queries.py
│   │   │   │   ├── __init__.py
│   │   │   │   ├── codeql
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── queries.py
│   │   │   │   └── factory.py
│   │   │   ├── __init__.py
│   │   │   └── languages.py
│   │   ├── ripgrep
│   │   │   ├── __init__.py
│   │   │   └── installer.py
│   │   ├── __init__.py
│   │   └── codeql
│   │       └── installer.py
│   ├── __init__.py
│   ├── impl
│   │   └── default
│   │       ├── factory.py
│   │       ├── topic_extractor.py
│   │       ├── smart_codebase.py
│   │       ├── __init__.py
│   │       └── prompt.py
│   ├── retrieval
│   │   ├── strategy
│   │   │   ├── filter_symbol_name_by_llm.py
│   │   │   ├── filter_symbol_content_by_vector.py
│   │   │   ├── filter_keyword_by_vector.py
│   │   │   ├── filter_filename_by_llm.py
│   │   │   ├── filter_dependency_by_llm.py
│   │   │   ├── __init__.py
│   │   │   ├── factory.py
│   │   │   ├── adaptive_filter_keyword_by_vector_and_llm.py
│   │   │   ├── filter_keyword_by_vector_and_llm.py
│   │   │   ├── filter_symbol_content_by_vector_and_llm.py
│   │   │   └── adaptive_filter_symbol_content_by_vector_and_llm.py
│   │   ├── __init__.py
│   │   └── topic_extractor.py
│   └── tools
│       ├── __init__.py
│       ├── find_file_by_name.py
│       ├── base.py
│       ├── view_file.py
│       └── get_references.py
├── bench
│   ├── recall_rate_comparison.png
│   ├── effectiveness_efficiency_comparison.png
│   ├── EXPERIMENTS.md
│   ├── repos.txt
│   ├── repos.lock
│   └── queries.json
├── justfile
├── .env.example
├── .gitignore
├── scripts
│   ├── try_codelines.py
│   ├── view_chunks.py
│   ├── benchmark.py
│   └── example.py
├── pyproject.toml
├── test
│   ├── impl
│   │   └── default
│   │       ├── test_determine_strategy.py
│   │       ├── test_smart_codebase.py
│   │       └── test_code_recall.py
│   └── tools
│       └── test_tools.py
├── STRATEGIES.md
└── USAGE.md
/.python-version:
--------------------------------------------------------------------------------
1 | 3.12
2 |
--------------------------------------------------------------------------------
/coderetrx/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/bench/recall_rate_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XuanwuAI/CodeRetrX/HEAD/bench/recall_rate_comparison.png
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/fine_imports/c.scm:
--------------------------------------------------------------------------------
1 | ;; All includes
2 | (preproc_include path: (_) @module)
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/fine_imports/README.md:
--------------------------------------------------------------------------------
1 | Import module extraction queries.
2 | source: Atum
3 |
--------------------------------------------------------------------------------
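The fine_imports queries above are ordinary tree-sitter S-expression queries, so they can be exercised on their own. A minimal sketch follows, assuming the `tree-sitter-language-pack` helpers listed in pyproject.toml and the pre-0.25 `Language.query` / `Query.captures` API (newer py-tree-sitter releases build a `Query` directly and read captures through `QueryCursor`); the snippet is illustrative, not part of the repository.

```python
# Illustrative only: run the python.scm fine_imports captures against a snippet.
# Assumes the pre-0.25 Query API; later py-tree-sitter versions differ.
from tree_sitter_language_pack import get_language, get_parser

QUERY_SRC = """
(import_statement (dotted_name) @module)
(import_from_statement module_name: (_) @module)
"""

source = b"import os\nfrom pathlib import Path\n"
tree = get_parser("python").parse(source)

query = get_language("python").query(QUERY_SRC)
for capture_name, nodes in query.captures(tree.root_node).items():
    for node in nodes:
        print(capture_name, node.text.decode())  # e.g. "module os", "module pathlib"
```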
/coderetrx/static/codebase/queries/treesitter/fine_imports/cpp.scm:
--------------------------------------------------------------------------------
1 | ;; All includes
2 | (preproc_include path: (_) @module)
--------------------------------------------------------------------------------
/justfile:
--------------------------------------------------------------------------------
1 | local_qdrant:
2 |     docker run -d -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/fine_imports/java.scm:
--------------------------------------------------------------------------------
1 | ;; All import types
2 | (import_declaration (identifier) @module)
--------------------------------------------------------------------------------
/coderetrx/static/codebase/parsers/treesitter/__init__.py:
--------------------------------------------------------------------------------
1 | from .parser import TreeSitterParser
2 |
3 | __all__ = ["TreeSitterParser"]
4 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/fine_imports/csharp.scm:
--------------------------------------------------------------------------------
1 | ;; All using types
2 | (using_directive
3 | (identifier) @dependency)
--------------------------------------------------------------------------------
/bench/effectiveness_efficiency_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XuanwuAI/CodeRetrX/HEAD/bench/effectiveness_efficiency_comparison.png
--------------------------------------------------------------------------------
/coderetrx/__init__.py:
--------------------------------------------------------------------------------
1 | from . import retrieval, static
2 | from .static import Codebase
3 |
4 | __all__ = ["Codebase", "retrieval", "static"]
5 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/fine_imports/go.scm:
--------------------------------------------------------------------------------
1 | ;; All import specs
2 | (import_spec path: (interpreted_string_literal) @module)
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/codeql/python/qlpack.yml:
--------------------------------------------------------------------------------
1 | name: python-queries
2 | version: 1.0.0
3 | dependencies:
4 | codeql/python-all: "*"
5 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/fine_imports/python.scm:
--------------------------------------------------------------------------------
1 | ;; Standard imports
2 | (import_statement (dotted_name) @module)
3 | (import_from_statement module_name: (_) @module)
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/fine_imports/rust.scm:
--------------------------------------------------------------------------------
1 | ;; All use patterns
2 | (
3 | (use_declaration
4 | argument: (_
5 | path: (identifier) @module
6 | )
7 | )
8 | )
9 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/parsers/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import CodebaseParser, ExtractionType
2 | from .factory import ParserFactory
3 | from .treesitter import TreeSitterParser
4 |
5 | __all__ = ["CodebaseParser", "ExtractionType", "ParserFactory", "TreeSitterParser"]
6 |
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | LLM_TOPIC_EXTRACTION_MODEL_ID=openai/gpt-4.1-mini
2 | MAX_CHUNKS_ONE_FILE=500
3 | EMBEDDING_MODEL_ID=text-embedding-3-large
4 | EMBEDDING_BASE_URL=
5 | EMBEDDING_API_KEY=
6 | OPENAI_API_KEY=
7 | OPENAI_BASE_URL=
8 | GITHUB_TOKEN=
9 | KEYWORD_SENTENCE_EXTRACTION=false
10 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/parsers/codeql/__init__.py:
--------------------------------------------------------------------------------
1 | from ....codeql.codeql import CodeQLWrapper, CodeQLDatabase
2 | from .parser import CodeQLParser
3 | from .queries import CodeQLQueryTemplates
4 |
5 | __all__ = ["CodeQLWrapper", "CodeQLDatabase", "CodeQLParser", "CodeQLQueryTemplates"]
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python-generated files
2 | __pycache__/
3 | *.py[oc]
4 | build/
5 | dist/
6 | wheels/
7 | *.egg-info
8 | .vscode
9 | .idea
10 | .pytest_cache
11 |
12 | # Virtual environments
13 | .venv
14 | .env
15 |
16 | # Ripgrep binary
17 | rg
18 | rg.exe
19 | .tmp
20 | .cache
21 | .data
22 |
23 | .DS_Store
24 | code_reports/
25 | logs/
26 | qdrant_storage/
--------------------------------------------------------------------------------
/coderetrx/static/ripgrep/__init__.py:
--------------------------------------------------------------------------------
1 | from .ripgrep import (
2 | ripgrep_glob,
3 | ripgrep_search,
4 | ripgrep_search_symbols,
5 | ripgrep_raw,
6 | GrepMatchResult,
7 | )
8 |
9 | __all__ = [
10 | "ripgrep_glob",
11 | "ripgrep_search",
12 | "ripgrep_search_symbols",
13 | "ripgrep_raw",
14 | "GrepMatchResult",
15 | ]
16 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/fine_imports/javascript.scm:
--------------------------------------------------------------------------------
1 | ;; Static imports
2 | (import_statement source: (string) @module)
3 |
4 | ;; Require
5 | (call_expression
6 | function: (identifier) @func
7 | arguments: (arguments (string) @module)
8 | (#eq? @func "require"))
9 |
10 | ;; Re-exports
11 | (export_statement source: (string) @module)
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/fine_imports/typescript.scm:
--------------------------------------------------------------------------------
1 | ;; Static imports
2 | (import_statement source: (string) @module)
3 |
4 | ;; Require
5 | (call_expression
6 | function: (identifier) @func
7 | arguments: (arguments (string) @module)
8 | (#eq? @func "require"))
9 |
10 | ;; Re-exports
11 | (export_statement source: (string) @module)
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/elisp.scm:
--------------------------------------------------------------------------------
1 | ;; defun/defsubst
2 | (function_definition name: (symbol) @name.definition.function) @definition.function
3 |
4 | ;; Treat macros as function definitions for the sake of TAGS.
5 | (macro_definition name: (symbol) @name.definition.function) @definition.function
6 |
7 | ;; Match function calls
8 | (list (symbol) @name.reference.function) @reference.function
9 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tests/php.scm:
--------------------------------------------------------------------------------
1 | ;; PHPUnit test methods (methods starting with test)
2 | ;; Only match test methods, not test classes, to allow non-test methods in test classes
3 | (
4 | (method_declaration
5 | name: (name) @run @_test_method_name
6 | (#match? @_test_method_name "^test")
7 | ) @_php-test-method
8 | (#set! tag php-test-method)
9 | )
10 |
--------------------------------------------------------------------------------
/coderetrx/utils/_extras.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 |
4 | def require_extra(pkg_name: str, extra: str):
5 | """
6 | Raises an ImportError if pkg_name isn't importable, pointing users
7 | to install the given extra.
8 | """
9 | try:
10 | importlib.import_module(pkg_name)
11 | except ImportError as e:
12 | raise ImportError(
13 | f"This feature requires the '{extra}' extra. "
14 | f"Install with: uv add coderetrx[{extra}]"
15 | ) from e
16 |
--------------------------------------------------------------------------------
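A hypothetical call site for `require_extra`, using the `chromadb` extra that pyproject.toml declares under `[project.optional-dependencies]`; the guarded module shown here is illustrative.

```python
# Hypothetical guard for an optional dependency: the extra name "chromadb"
# exists in pyproject.toml, the call site itself is an example.
from coderetrx.utils._extras import require_extra

require_extra("chromadb", "chromadb")  # raises ImportError with an install hint if missing
import chromadb  # safe to import past this point
```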
/coderetrx/impl/default/factory.py:
--------------------------------------------------------------------------------
1 | # Compatibility import stub for backwards compatibility
2 | # CodebaseFactory has been moved to coderetrx.retrieval.factory
3 |
4 | # Will be removed in future versions
5 | from coderetrx.retrieval.factory import CodebaseFactory
6 | import logging
7 | logger = logging.getLogger(__name__)
8 | logger.warning("The 'coderetrx.impl.default.factory' module is deprecated and will be removed in future versions, use coderetrx.retrieval.factory instead.")
9 |
10 | __all__ = ["CodebaseFactory"]
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/fine_imports/php.scm:
--------------------------------------------------------------------------------
1 | ;; PHP import statements
2 | ;; Basic namespace imports (use statements)
3 | (namespace_use_declaration
4 | (namespace_use_clause
5 | (qualified_name) @module))
6 |
7 | ;; require/include statements with string literals
8 | (require_expression
9 | (string) @module)
10 |
11 | (require_once_expression
12 | (string) @module)
13 |
14 | (include_expression
15 | (string) @module)
16 |
17 | (include_once_expression
18 | (string) @module)
19 |
--------------------------------------------------------------------------------
/coderetrx/impl/default/topic_extractor.py:
--------------------------------------------------------------------------------
1 | # Compatibility import stub for backwards compatibility
2 | # TopicExtractor has been moved to coderetrx.retrieval.topic_extractor
3 | # Will be removed in future versions
4 |
5 | from coderetrx.retrieval.topic_extractor import TopicExtractor
6 | import logging
7 | logger = logging.getLogger(__name__)
8 | logger.warning("The 'coderetrx.impl.default.topic_extractor' module is deprecated and will be removed in future versions, use coderetrx.retrieval.topic_extractor instead.")
9 |
10 |
11 | __all__ = ["TopicExtractor"]
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/codeql/python/imports.ql:
--------------------------------------------------------------------------------
1 | /**
2 | * @name Python Imports
3 | * @description Find all import statements in Python code
4 | * @kind table
5 | * @id codeql/python-imports
6 | */
7 |
8 | import python
9 |
10 | from Import imp
11 | where exists(imp.getLocation()) and exists(imp.getAnImportedModuleName())
12 | select imp.getLocation().getFile().getRelativePath(), imp.getAnImportedModuleName(),
13 | imp.getLocation().getStartLine(), imp.getLocation().getEndLine(),
14 | imp.getLocation().getStartColumn(), imp.getLocation().getEndColumn()
--------------------------------------------------------------------------------
/coderetrx/impl/default/smart_codebase.py:
--------------------------------------------------------------------------------
1 | # Compatibility import stub for backwards compatibility
2 | # SmartCodebase has been moved to coderetrx.retrieval.smart_codebase
3 | # Will be removed in future versions
4 |
5 | from coderetrx.retrieval.smart_codebase import SmartCodebase, SmartCodebaseSettings
6 | import logging
7 | logger = logging.getLogger(__name__)
8 | logger.warning("The 'coderetrx.impl.default.smart_codebase' module is deprecated and will be removed in future versions, use coderetrx.retrieval.smart_codebase instead.")
9 |
10 | __all__ = ["SmartCodebase", "SmartCodebaseSettings"]
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/fine_imports/elixir.scm:
--------------------------------------------------------------------------------
1 | ; tree-sitter query to extract Elixir dependencies
2 |
3 | ; Match alias statements
4 | (call
5 | target: (identifier) @directive (#eq? @directive "alias")
6 | (arguments (alias) @module)
7 | ) @alias_stmt
8 |
9 | ; Match import statements
10 | (call
11 | target: (identifier) @directive (#eq? @directive "import")
12 | (arguments (alias) @module)
13 | ) @import_stmt
14 |
15 | ; Match require statements
16 | (call
17 | target: (identifier) @directive (#eq? @directive "require")
18 | (arguments (alias) @module)
19 | ) @require_stmt
--------------------------------------------------------------------------------
/scripts/try_codelines.py:
--------------------------------------------------------------------------------
1 | from rich import print
2 | from coderetrx.static.codebase import Codebase, File
3 |
4 | code_file = File.jit_for_testing('tst.py', """
5 | import os
6 |
7 | def x():
8 | def y():
9 | return x
10 | print("hello world!")
11 | for i in range(1000):
12 | print(i)
13 |
14 | print("hello world! Again!!")
15 | g = 20
16 | do(sth)
17 | print("What on earth are we doing")
18 | y = 10
19 | return y
20 |
21 | x()
22 | """
23 | )
24 |
25 | code_file.init_all()
26 |
27 | for line in code_file.get_lines(max_chars=100):
28 | print(line)
29 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/codeql/python/classes.ql:
--------------------------------------------------------------------------------
1 | /**
2 | * @name Python Classes with last statement
3 | * @description Find all class definitions and their last statement
4 | * @kind table
5 | * @id codeql/python-classes-laststmt
6 | */
7 |
8 | import python
9 |
10 | from Class cls, Stmt stmt
11 | where
12 | stmt = cls.getBody().getLastItem()
13 | select
14 | cls.getLocation().getFile().getRelativePath(),
15 | cls.getQualifiedName(),
16 | cls.getLocation().getStartLine(),
17 | stmt.getLocation().getEndLine(),
18 | cls.getLocation().getStartColumn(),
19 | cls.getLocation().getEndColumn()
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/codeql/python/functions.ql:
--------------------------------------------------------------------------------
1 | /**
2 | * @name Python Functions with last statement
3 | * @description Find all functions and their last statement
4 | * @kind table
5 | * @id codeql/python-functions-laststmt
6 | */
7 |
8 | import python
9 |
10 | from Function func, Stmt stmt
11 | where
12 | stmt = func.getBody().getLastItem()
13 | select
14 | func.getLocation().getFile().getRelativePath(),
15 | func.getQualifiedName(),
16 | func.getLocation().getStartLine(),
17 | stmt.getLocation().getEndLine(),
18 | func.getLocation().getStartColumn(),
19 | func.getLocation().getEndColumn()
--------------------------------------------------------------------------------
/coderetrx/impl/default/__init__.py:
--------------------------------------------------------------------------------
1 | # Compatibility import stubs for backwards compatibility
2 | # All functionality has been moved to coderetrx.retrieval
3 | # Will be removed in future versions
4 |
5 | from coderetrx.retrieval import (
6 | CodebaseFactory,
7 | SmartCodebase,
8 | TopicExtractor,
9 | )
10 | import logging
11 | logger = logging.getLogger(__name__)
12 | logger.warning("The 'coderetrx.impl.default' module is deprecated and will be removed in future versions. use coderetrx.retrieval instead.")
13 |
14 | # For backwards compatibility, re-export the main classes
15 | __all__ = [
16 | "CodebaseFactory",
17 | "SmartCodebase",
18 | "TopicExtractor",
19 | ]
20 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tests/elixir.scm:
--------------------------------------------------------------------------------
1 | ; Macros `describe`, `test` and `property`.
2 | ; This matches the ExUnit test style.
3 | (
4 | (call
5 | target: (identifier) @run (#any-of? @run "describe" "test" "property")
6 | ) @_elixir-test
7 | (#set! tag elixir-test)
8 | )
9 |
10 | ; Modules containing at least one `describe`, `test` and `property`.
11 | ; This matches the ExUnit test style.
12 | (
13 | (call
14 | target: (identifier) @run (#eq? @run "defmodule")
15 | (do_block
16 | (call target: (identifier) @_keyword (#any-of? @_keyword "describe" "test" "property"))
17 | )
18 | ) @_elixir-module-test
19 | (#set! tag elixir-module-test)
20 | )
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tests/javascript.scm:
--------------------------------------------------------------------------------
1 | ; Add support for (node:test, bun:test and Jest) runnable
2 | ; Function expression that has `it`, `test` or `describe` as the function name
3 | (
4 | (call_expression
5 | function: [
6 | (identifier) @_name
7 | (member_expression
8 | object: [
9 | (identifier) @_name
10 | (member_expression object: (identifier) @_name)
11 | ]
12 | )
13 | ]
14 | (#any-of? @_name "it" "test" "describe" "context" "suite")
15 | arguments: (
16 | arguments . (string (string_fragment) @run)
17 | )
18 | ) @_js-test
19 |
20 | (#set! tag js-test)
21 | )
22 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tests/typescript.scm:
--------------------------------------------------------------------------------
1 | ; Add support for (node:test, bun:test and Jest) runnable
2 | ; Function expression that has `it`, `test` or `describe` as the function name
3 | (
4 | (call_expression
5 | function: [
6 | (identifier) @_name
7 | (member_expression
8 | object: [
9 | (identifier) @_name
10 | (member_expression object: (identifier) @_name)
11 | ]
12 | )
13 | ]
14 | (#any-of? @_name "it" "test" "describe" "context" "suite")
15 | arguments: (
16 | arguments . (string (string_fragment) @run)
17 | )
18 | ) @_js-test
19 |
20 | (#set! tag js-test)
21 | )
22 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/kotlin.scm:
--------------------------------------------------------------------------------
1 | ; Definitions
2 |
3 | (class_declaration
4 | (type_identifier) @name.definition.class) @definition.class
5 |
6 | (function_declaration
7 | (simple_identifier) @name.definition.function) @definition.function
8 |
9 | (object_declaration
10 | (type_identifier) @name.definition.object) @definition.object
11 |
12 | ; References
13 |
14 | (call_expression
15 | [
16 | (simple_identifier) @name.reference.call
17 | (navigation_expression
18 | (navigation_suffix
19 | (simple_identifier) @name.reference.call))
20 | ]) @reference.call
21 |
22 | (delegation_specifier
23 | [
24 | (user_type) @name.reference.type
25 | (constructor_invocation
26 | (user_type) @name.reference.type)
27 | ]) @reference.type
28 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/ql.scm:
--------------------------------------------------------------------------------
1 | (classlessPredicate
2 | name: (predicateName) @name.definition.function) @definition.function
3 |
4 | (memberPredicate
5 | name: (predicateName) @name.definition.method) @definition.method
6 |
7 | (aritylessPredicateExpr
8 | name: (literalId) @name.reference.call) @reference.call
9 |
10 | (module
11 | name: (moduleName) @name.definition.module) @definition.module
12 |
13 | (dataclass
14 | name: (className) @name.definition.class) @definition.class
15 |
16 | (datatype
17 | name: (className) @name.definition.class) @definition.class
18 |
19 | (datatypeBranch
20 | name: (className) @name.definition.class) @definition.class
21 |
22 | (qualifiedRhs
23 | name: (predicateName) @name.reference.call) @reference.call
24 |
25 | (typeExpr
26 | name: (className) @name.reference.type) @reference.type
27 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/python.scm:
--------------------------------------------------------------------------------
1 | (class_definition
2 | name: (identifier) @name.definition.class) @definition.class
3 |
4 | (function_definition
5 | name: (identifier) @name.definition.function) @definition.function
6 |
7 | (call
8 | function: [
9 | (identifier) @name.reference.call
10 | (attribute
11 | attribute: (identifier) @name.reference.call)
12 | ]) @reference.call
13 |
14 | ; Handle imports
15 | (import_statement) @import
16 | (import_from_statement) @import
17 |
18 | ; Variable definitions - all assignment statements
19 | (assignment
20 | left: (identifier) @name.definition.variable) @definition.variable
21 |
22 | (assignment
23 | left: (pattern_list
24 | (identifier) @name.definition.variable)) @definition.variable
25 |
26 | (assignment
27 | left: (attribute) @name.definition.variable) @definition.variable
28 |
--------------------------------------------------------------------------------
/coderetrx/retrieval/strategy/filter_symbol_name_by_llm.py:
--------------------------------------------------------------------------------
1 | """
2 | Strategy for filtering symbols using LLM.
3 | """
4 |
5 | from typing import List, override
6 | from .base import FilterByLLMStrategy
7 | from ..smart_codebase import SmartCodebase as Codebase, LLMMapFilterTargetType
8 | from coderetrx.static import Symbol
9 |
10 |
11 | class FilterSymbolNameByLLMStrategy(FilterByLLMStrategy[Symbol]):
12 | """Strategy to filter symbols using LLM."""
13 |
14 | name: str = "FILTER_SYMBOL_NAME_BY_LLM"
15 |
16 | @override
17 | def get_strategy_name(self) -> str:
18 | return self.name
19 |
20 | @override
21 | def get_target_type(self) -> LLMMapFilterTargetType:
22 | return "symbol_name"
23 |
24 | @override
25 | def extract_file_paths(
26 | self, elements: List[Symbol], codebase: Codebase
27 | ) -> List[str]:
28 | return [str(symbol.file.path) for symbol in elements]
29 |
--------------------------------------------------------------------------------
/coderetrx/retrieval/__init__.py:
--------------------------------------------------------------------------------
1 | from .code_recall import (
2 | coderetrx_filter,
3 | coderetrx_mapping,
4 | llm_traversal_filter,
5 | llm_traversal_mapping,
6 | )
7 | from .strategy import (
8 | RecallStrategy,
9 | )
10 | from .smart_codebase import (
11 | SmartCodebase,
12 | SmartCodebaseSettings,
13 | LLMMapFilterTargetType,
14 | SimilaritySearchTargetType,
15 | LLMCallMode,
16 | CodeMapFilterResult,
17 | )
18 | from .topic_extractor import TopicExtractor
19 | from .factory import CodebaseFactory
20 |
21 | __all__ = [
22 | "coderetrx_filter",
23 | "coderetrx_mapping",
24 | "llm_traversal_filter",
25 | "llm_traversal_mapping",
26 | "RecallStrategy",
27 | "SmartCodebase",
28 | "SmartCodebaseSettings",
29 | "LLMMapFilterTargetType",
30 | "SimilaritySearchTargetType",
31 | "LLMCallMode",
32 | "CodeMapFilterResult",
33 | "TopicExtractor",
34 | "CodebaseFactory",
35 | ]
36 |
--------------------------------------------------------------------------------
/coderetrx/impl/default/prompt.py:
--------------------------------------------------------------------------------
1 | # Compatibility import stub for backwards compatibility
2 | # Prompt utilities have been moved to coderetrx.retrieval.prompt
3 | # Will be removed in future versions
4 |
5 | from coderetrx.retrieval.prompt import *
6 | import logging
7 | logger = logging.getLogger(__name__)
8 | logger.warning("The 'coderetrx.impl.default.prompt' module is deprecated and will be removed in future versions, use coderetrx.retrieval.prompt instead.")
9 |
10 | __all__ = [
11 | "llm_filter_prompt_template",
12 | "llm_mapping_prompt_template",
13 | "topic_extraction_prompt_template",
14 | "KeywordExtractorResult",
15 | "llm_filter_function_call_system_prompt",
16 | "llm_mapping_function_call_system_prompt",
17 | "topic_extraction_function_call_system_prompt",
18 | "filter_and_mapping_function_call_user_prompt_template",
19 | "topic_extraction_function_call_user_prompt_template",
20 | "get_filter_function_definition",
21 | "get_mapping_function_definition",
22 | "get_topic_extraction_function_definition",
23 | ]
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/elm.scm:
--------------------------------------------------------------------------------
1 | (value_declaration (function_declaration_left (lower_case_identifier) @name.definition.function)) @definition.function
2 |
3 | (function_call_expr (value_expr (value_qid) @name.reference.function)) @reference.function
4 | (exposed_value (lower_case_identifier) @name.reference.function) @reference.function
5 | (type_annotation ((lower_case_identifier) @name.reference.function) (colon)) @reference.function
6 |
7 | (type_declaration ((upper_case_identifier) @name.definition.type)) @definition.type
8 |
9 | (type_ref (upper_case_qid (upper_case_identifier) @name.reference.type)) @reference.type
10 | (exposed_type (upper_case_identifier) @name.reference.type) @reference.type
11 |
12 | (type_declaration (union_variant (upper_case_identifier) @name.definition.union)) @definition.union
13 |
14 | (value_expr (upper_case_qid (upper_case_identifier) @name.reference.union)) @reference.union
15 |
16 |
17 | (module_declaration
18 | (upper_case_qid (upper_case_identifier)) @name.definition.module
19 | ) @definition.module
20 |
--------------------------------------------------------------------------------
/coderetrx/static/__init__.py:
--------------------------------------------------------------------------------
1 | from .codebase import (
2 | Codebase,
3 | File,
4 | FileType,
5 | Symbol,
6 | Keyword,
7 | Dependency,
8 | CallGraphEdge,
9 | CodebaseModel,
10 | FileModel,
11 | SymbolModel,
12 | KeywordModel,
13 | DependencyModel,
14 | CallGraphEdgeModel,
15 | CodeElement,
16 | CodeElementTypeVar,
17 | )
18 |
19 | from .ripgrep import (
20 | ripgrep_glob,
21 | ripgrep_search,
22 | ripgrep_search_symbols,
23 | ripgrep_raw,
24 | GrepMatchResult,
25 | )
26 |
27 | __all__ = [
28 | # Codebase exports
29 | "Codebase",
30 | "File",
31 | "FileType",
32 | "Symbol",
33 | "Keyword",
34 | "Dependency",
35 | "CallGraphEdge",
36 | "CodebaseModel",
37 | "FileModel",
38 | "SymbolModel",
39 | "KeywordModel",
40 | "DependencyModel",
41 | "CallGraphEdgeModel",
42 | "CodeElementTypeVar",
43 | # Ripgrep exports
44 | "ripgrep_glob",
45 | "ripgrep_search",
46 | "ripgrep_search_symbols",
47 | "ripgrep_raw",
48 | "GrepMatchResult",
49 | "CodeElement"
50 | ]
51 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/__init__.py:
--------------------------------------------------------------------------------
1 | from .codebase import (
2 | Codebase,
3 | File,
4 | FileType,
5 | Symbol,
6 | Keyword,
7 | Dependency,
8 | CallGraphEdge,
9 | CodeElementTypeVar,
10 | CodeElement,
11 | CodeChunk,
12 | ChunkType,
13 | CodeHunk,
14 | CodeLine,
15 | )
16 | from .models import (
17 | CodebaseModel,
18 | FileModel,
19 | SymbolModel,
20 | KeywordModel,
21 | DependencyModel,
22 | CallGraphEdgeModel,
23 | )
24 | from .languages import IDXSupportedLanguage, IDXSupportedTag, BUILTIN_CRYPTO_LIBS
25 |
26 | __all__ = [
27 | "Codebase",
28 | "File",
29 | "FileType",
30 | "Symbol",
31 | "Keyword",
32 | "Dependency",
33 | "CallGraphEdge",
34 | "CodebaseModel",
35 | "FileModel",
36 | "SymbolModel",
37 | "KeywordModel",
38 | "DependencyModel",
39 | "CallGraphEdgeModel",
40 | "CodeElementTypeVar",
41 | "CodeElement",
42 | "CodeChunk",
43 | "ChunkType",
44 | "CodeHunk",
45 | "CodeLine",
46 | "IDXSupportedLanguage",
47 | "IDXSupportedTag",
48 | "BUILTIN_CRYPTO_LIBS",
49 | ]
50 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/java.scm:
--------------------------------------------------------------------------------
1 | (class_declaration
2 | name: (identifier) @name.definition.class) @definition.class
3 |
4 | (method_declaration
5 | name: (identifier) @name.definition.method) @definition.method
6 |
7 | (method_invocation
8 | name: (identifier) @name.reference.call
9 | arguments: (argument_list) @reference.call)
10 |
11 | (interface_declaration
12 | name: (identifier) @name.definition.interface) @definition.interface
13 |
14 | (type_list
15 | (type_identifier) @name.reference.implementation) @reference.implementation
16 |
17 | (object_creation_expression
18 | type: (type_identifier) @name.reference.class) @reference.class
19 |
20 | (superclass (type_identifier) @name.reference.class) @reference.class
21 |
22 | (import_declaration) @import
23 |
24 | ; Variable definitions - field declarations
25 | (field_declaration
26 | declarator: (variable_declarator
27 | name: (identifier) @name.definition.variable)) @definition.variable
28 |
29 | ; Local variable declarations
30 | (local_variable_declaration
31 | declarator: (variable_declarator
32 | name: (identifier) @name.definition.variable)) @definition.variable
33 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/c.scm:
--------------------------------------------------------------------------------
1 | (struct_specifier name: (type_identifier) @name.definition.class body:(_)) @definition.class
2 |
3 | (declaration type: (union_specifier name: (type_identifier) @name.definition.class)) @definition.class
4 |
5 | (function_definition
6 | declarator: (function_declarator
7 | declarator: (identifier) @name.definition.function)) @definition.function
8 |
9 | (function_definition
10 | declarator: (pointer_declarator
11 | declarator: (function_declarator
12 | declarator: (identifier) @name.definition.function))) @definition.function
13 |
14 | (type_definition declarator: (type_identifier) @name.definition.type) @definition.type
15 |
16 | (enum_specifier name: (type_identifier) @name.definition.type) @definition.type
17 |
18 | (preproc_include) @import
19 |
20 | ; Variable definitions - all declarations
21 | (declaration
22 | declarator: (identifier) @name.definition.variable) @definition.variable
23 |
24 | (declaration
25 | declarator: (init_declarator
26 | declarator: (identifier) @name.definition.variable)) @definition.variable
27 |
28 | (declaration
29 | declarator: (pointer_declarator
30 | declarator: (identifier) @name.definition.variable)) @definition.variable
31 |
--------------------------------------------------------------------------------
/bench/EXPERIMENTS.md:
--------------------------------------------------------------------------------
1 | # Experiments
2 |
3 | We conducted comprehensive experiments on large-scale benchmarks across multiple programming languages and repository sizes to validate the effectiveness of our code retrieval strategies. The analysis demonstrates how **`coderetrx_filter`** performs across various bug types and complexity levels.
4 |
5 | ## Experiment Setup
6 |
7 | All methods share the same indexing and parsing pipeline (repository snapshot, language parser, and symbol-table extraction). The Symbol Vector Embedding baseline encodes identifier-level semantics only, while Line-per-Symbol indexes each line within a function's or class's body, enabling precise structure-aware retrieval. We used **gpt-oss-120b** for dataset construction and evaluation.
8 |
9 | ## Performance Benchmarks
10 |
11 | The following results demonstrate performance across the different strategies. Figure 1 shows the recall rate comparison across languages and repository sizes, and Table 1 compares the effectiveness and efficiency of CodeRetrX against the baselines.
12 |
13 | 
14 | 
15 |
16 |
--------------------------------------------------------------------------------
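The two indexing granularities contrasted in the setup above can be sketched as follows; this is an illustration of the idea only (the `embed` callable and record layout are placeholders, not the CodeRetrX implementation).

```python
# Illustration of the two granularities described in EXPERIMENTS.md:
# one vector per symbol body vs. one vector per non-empty line of the body.
from typing import Callable, List, Tuple

Vector = List[float]

def index_symbol(
    qualified_name: str, body: str, embed: Callable[[str], Vector]
) -> Tuple[List[Tuple[str, Vector]], List[Tuple[str, Vector]]]:
    symbol_level = [(qualified_name, embed(body))]          # Symbol Vector Embedding baseline
    line_per_symbol = [                                      # Line-per-Symbol indexing
        (f"{qualified_name}:{lineno}", embed(line))
        for lineno, line in enumerate(body.splitlines(), start=1)
        if line.strip()
    ]
    return symbol_level, line_per_symbol
```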
/coderetrx/static/codebase/queries/treesitter/tags/c_sharp.scm:
--------------------------------------------------------------------------------
1 | (class_declaration
2 | name: (identifier) @name.definition.class
3 | ) @definition.class
4 |
5 | (class_declaration
6 | bases: (base_list (_) @name.reference.class)
7 | ) @reference.class
8 |
9 | (interface_declaration
10 | name: (identifier) @name.definition.interface
11 | ) @definition.interface
12 |
13 | (interface_declaration
14 | bases: (base_list (_) @name.reference.interface)
15 | ) @reference.interface
16 |
17 | (method_declaration
18 | name: (identifier) @name.definition.method
19 | ) @definition.method
20 |
21 | (object_creation_expression
22 | type: (identifier) @name.reference.class
23 | ) @reference.class
24 |
25 | (type_parameter_constraints_clause
26 | target: (identifier) @name.reference.class
27 | ) @reference.class
28 |
29 | (type_constraint
30 | type: (identifier) @name.reference.class
31 | ) @reference.class
32 |
33 | (variable_declaration
34 | type: (identifier) @name.reference.class
35 | ) @reference.class
36 |
37 | (invocation_expression
38 | function:
39 | (member_access_expression
40 | name: (identifier) @name.reference.send
41 | )
42 | ) @reference.send
43 |
44 | (namespace_declaration
45 | name: (identifier) @name.definition.module
46 | ) @definition.module
47 |
--------------------------------------------------------------------------------
/bench/repos.txt:
--------------------------------------------------------------------------------
1 | // Boringtun, rust, Small, 7937
2 | https://github.com/cloudflare/boringtun
3 |
4 | // Rosenpass, rust, Mid, 30063
5 | https://github.com/rosenpass/rosenpass
6 |
7 | // Neon, rust, Large, 375060
8 | https://github.com/neondatabase/neon
9 |
10 | // Python p2p, python, Small, 1189
11 | https://github.com/GianisTsol/python-p2p
12 |
13 | // Magic Wormhole, python, Mid, 27538
14 | https://github.com/magic-wormhole/magic-wormhole
15 |
16 | // Zulip, python, Large, 499528
17 | https://github.com/zulip/zulip
18 |
19 | // Anubis, golang, Small, 5163
20 | https://github.com/TecharoHQ/anubis
21 |
22 | // FSCrypt, golang, Mid, 18815
23 | https://github.com/google/fscrypt
24 |
25 | // Ethereum, golang, Large, 272576
26 | https://github.com/ethereum/go-ethereum
27 |
28 | // Mastercard Client Encryption, java, Small, 8654
29 | https://github.com/Mastercard/client-encryption-java
30 |
31 | // Keycloak, java, Large, 955351
32 | https://github.com/keycloak/keycloak
33 |
34 | // Cryptomator, java, Mid, 38883
35 | https://github.com/cryptomator/cryptomator
36 |
37 | // Padloc, js/ts, Large, 136233
38 | https://github.com/padloc/padloc
39 |
40 | // Vaultwarden, js/ts, Mid, 62609
41 | https://github.com/dani-garcia/vaultwarden
42 |
43 | // Swifty, js/ts, Small, 8295
44 | https://github.com/swiftyapp/swifty
45 |
--------------------------------------------------------------------------------
/coderetrx/tools/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains tools exposed by mcp server.
3 | """
4 |
5 | from .base import BaseTool
6 | from .find_file_by_name import FindFileByNameTool
7 | from .get_references import GetReferenceTool
8 | from .keyword_search import KeywordSearchTool
9 | from .list_dir import ListDirTool
10 | from .view_file import ViewFileTool
11 | from typing import Type
12 | import sys
13 |
14 | # Export the tools as the default
15 | __all__ = [
16 | "FindFileByNameTool",
17 | "GetReferenceTool",
18 | "KeywordSearchTool",
19 | "ListDirTool",
20 | "ViewFileTool",
21 | ]
22 | tool_classes = [
23 | FindFileByNameTool,
24 | GetReferenceTool,
25 | KeywordSearchTool,
26 | ListDirTool,
27 | ViewFileTool,
28 | ]
29 | tool_map: dict[str, dict[str, BaseTool]] = {}
30 |
31 |
32 | def get_tool_class(name: str) -> Type[BaseTool]:
33 | for cls in tool_classes:
34 | if getattr(cls, "name", None) == name:
35 | return cls
36 | raise ValueError(f"{name} is not a valid tool")
37 |
38 |
39 | def get_tool(repo_url: str, name: str) -> BaseTool:
40 | if tool_map.get(repo_url) is None:
41 | tool_map[repo_url] = {}
42 | if tool_map[repo_url].get(name) is None:
43 | tool_map[repo_url][name] = get_tool_class(name)(repo_url)
44 | return tool_map[repo_url][name]
45 |
46 |
47 | def list_tool_class() -> list[Type[BaseTool]]:
48 | return tool_classes
49 |
--------------------------------------------------------------------------------
/coderetrx/retrieval/strategy/filter_symbol_content_by_vector.py:
--------------------------------------------------------------------------------
1 | """
2 | Strategy for filtering symbols using vector similarity search.
3 | """
4 |
5 | from typing import List, override
6 | from .base import FilterByVectorStrategy
7 | from ..smart_codebase import SmartCodebase as Codebase, SimilaritySearchTargetType
8 | from coderetrx.static import Symbol
9 |
10 |
11 | class FilterSymbolContentByVectorStrategy(FilterByVectorStrategy[Symbol]):
12 | """Strategy to filter symbols using vector similarity search."""
13 |
14 | name: str = "FILTER_SYMBOL_CONTENT_BY_VECTOR"
15 |
16 | @override
17 | def get_strategy_name(self) -> str:
18 | return self.name
19 |
20 | @override
21 | def get_target_types_for_vector(self) -> List[SimilaritySearchTargetType]:
22 | return ["symbol_content"]
23 |
24 | @override
25 | def get_collection_size(self, codebase: Codebase) -> int:
26 | return len(codebase.symbols)
27 |
28 | @override
29 | def extract_file_paths(
30 | self, elements: List[Symbol], codebase: Codebase, subdirs_or_files: List[str]
31 | ) -> List[str]:
32 | file_paths = []
33 | for symbol in elements:
34 | if isinstance(symbol, Symbol):
35 | file_path = str(symbol.file.path)
36 | if file_path.startswith(tuple(subdirs_or_files)):
37 | file_paths.append(file_path)
38 | return list(set(file_paths))
39 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/go.scm:
--------------------------------------------------------------------------------
1 | (
2 | (comment)* @doc
3 | .
4 | (function_declaration
5 | name: (identifier) @name.definition.function) @definition.function
6 | (#strip! @doc "^//\\s*")
7 | (#set-adjacent! @doc @definition.function)
8 | )
9 |
10 | (
11 | (comment)* @doc
12 | .
13 | (method_declaration
14 | name: (field_identifier) @name.definition.method) @definition.method
15 | (#strip! @doc "^//\\s*")
16 | (#set-adjacent! @doc @definition.method)
17 | )
18 |
19 | (call_expression
20 | function: [
21 | (identifier) @name.reference.call
22 | (parenthesized_expression (identifier) @name.reference.call)
23 | (selector_expression field: (field_identifier) @name.reference.call)
24 | (parenthesized_expression (selector_expression field: (field_identifier) @name.reference.call))
25 | ]) @reference.call
26 |
27 | (type_spec
28 | name: (type_identifier) @name.definition.type) @definition.type
29 |
30 | (type_identifier) @name.reference.type @reference.type
31 |
32 | (import_spec) @import
33 |
34 | ; Variable definitions
35 | (var_declaration
36 | (var_spec
37 | name: (identifier) @name.definition.variable)) @definition.variable
38 |
39 | (const_declaration
40 | (const_spec
41 | name: (identifier) @name.definition.variable)) @definition.variable
42 |
43 | (short_var_declaration
44 | left: (expression_list
45 | (identifier) @name.definition.variable)) @definition.variable
46 |
--------------------------------------------------------------------------------
/coderetrx/retrieval/strategy/filter_keyword_by_vector.py:
--------------------------------------------------------------------------------
1 | """
2 | Strategy for filtering keywords using vector similarity search.
3 | """
4 |
5 | from typing import List, override, Any
6 | from .base import FilterByVectorStrategy
7 | from ..smart_codebase import SmartCodebase as Codebase, SimilaritySearchTargetType
8 | from coderetrx.static import Keyword
9 |
10 |
11 | class FilterKeywordByVectorStrategy(FilterByVectorStrategy[Keyword]):
12 | """Strategy to filter keywords using vector similarity search."""
13 |
14 | name: str = "FILTER_KEYWORD_BY_VECTOR"
15 |
16 | @override
17 | def get_strategy_name(self) -> str:
18 | return self.name
19 |
20 | @override
21 | def get_target_types_for_vector(self) -> List[SimilaritySearchTargetType]:
22 | return ["keyword"]
23 |
24 | @override
25 | def get_collection_size(self, codebase: Codebase) -> int:
26 | return len(codebase.keywords)
27 |
28 | @override
29 | def extract_file_paths(
30 | self, elements: List[Keyword], codebase: Codebase, subdirs_or_files: List[str]
31 | ) -> List[str]:
32 | referenced_paths = set()
33 | for item in elements:
34 | if isinstance(item, Keyword) and item.referenced_by:
35 | for ref_file in item.referenced_by:
36 | if str(ref_file.path).startswith(tuple(subdirs_or_files)):
37 | referenced_paths.add(str(ref_file.path))
38 | return list(referenced_paths)
39 |
--------------------------------------------------------------------------------
/coderetrx/utils/git.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | from pathlib import Path
5 | from typing import Tuple
6 | from git import Repo, GitCommandError
7 |
8 | import logging
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def clone_repo_if_not_exists(repo_url: str, target_path: str) -> str:
14 | logger.info(f"clone repo: {repo_url}")
15 | repo_path = Path(target_path)
16 | if not repo_path.exists():
17 | try:
18 | logger.info(f"Cloning repository {repo_url} into {repo_path}")
19 | # todo: remove ssh clone
20 | Repo.clone_from(
21 | repo_url,
22 | repo_path,
23 | env=(
24 | {"GIT_SSH_COMMAND": f'ssh -i {Path(__file__).parent/"sg_rsa"}'}
25 | if repo_url.endswith(".git")
26 | else None
27 | ),
28 | depth=1,
29 | )
30 | except GitCommandError as e:
31 | repo_path.unlink(missing_ok=True)
32 | raise Exception(f"Clone failed: {e}")
33 | return repo_path.as_posix()
34 |
35 |
36 | def get_repo_id(repo_url: str) -> str:
37 | if repo_url.startswith("http"):
38 | repo_id = "_".join(repo_url.split("/")[-2:])
39 | else:
40 | # todo: only for backward compatibility, we need to remove this in the future
41 | repo_id = repo_url.split("/")[-1]
42 | repo_id = repo_id.replace(".git", "")
43 | return repo_id
44 |
45 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/ruby.scm:
--------------------------------------------------------------------------------
1 | ; Method definitions
2 |
3 | (
4 | (comment)* @doc
5 | .
6 | [
7 | (method
8 | name: (_) @name.definition.method) @definition.method
9 | (singleton_method
10 | name: (_) @name.definition.method) @definition.method
11 | ]
12 | (#strip! @doc "^#\\s*")
13 | (#select-adjacent! @doc @definition.method)
14 | )
15 |
16 | (alias
17 | name: (_) @name.definition.method) @definition.method
18 |
19 | (setter
20 | (identifier) @ignore)
21 |
22 | ; Class definitions
23 |
24 | (
25 | (comment)* @doc
26 | .
27 | [
28 | (class
29 | name: [
30 | (constant) @name.definition.class
31 | (scope_resolution
32 | name: (_) @name.definition.class)
33 | ]) @definition.class
34 | (singleton_class
35 | value: [
36 | (constant) @name.definition.class
37 | (scope_resolution
38 | name: (_) @name.definition.class)
39 | ]) @definition.class
40 | ]
41 | (#strip! @doc "^#\\s*")
42 | (#select-adjacent! @doc @definition.class)
43 | )
44 |
45 | ; Module definitions
46 |
47 | (
48 | (module
49 | name: [
50 | (constant) @name.definition.module
51 | (scope_resolution
52 | name: (_) @name.definition.module)
53 | ]) @definition.module
54 | )
55 |
56 | ; Calls
57 |
58 | (call method: (identifier) @name.reference.call) @reference.call
59 |
60 | (
61 | [(identifier) (constant)] @name.reference.call @reference.call
62 | (#is-not? local)
63 | (#not-match? @name.reference.call "^(lambda|load|require|require_relative|__FILE__|__LINE__)$")
64 | )
65 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/php.scm:
--------------------------------------------------------------------------------
1 | ;; Class definitions
2 | (class_declaration
3 | name: (name) @name.definition.class) @definition.class
4 |
5 | ;; Interface definitions
6 | (interface_declaration
7 | name: (name) @name.definition.interface) @definition.interface
8 |
9 | ;; Trait definitions
10 | (trait_declaration
11 | name: (name) @name.definition.class) @definition.class
12 |
13 | ;; Function definitions
14 | (function_definition
15 | name: (name) @name.definition.function) @definition.function
16 |
17 | ;; Method definitions
18 | (method_declaration
19 | name: (name) @name.definition.method) @definition.method
20 |
21 | ;; Namespace imports
22 | (namespace_use_declaration
23 | (namespace_use_clause
24 | (qualified_name) @name.import)) @import
25 |
26 | ;; Object creation (class references)
27 | (object_creation_expression
28 | [
29 | (qualified_name (name) @name.reference.class)
30 | (variable_name (name) @name.reference.class)
31 | ]) @reference.class
32 |
33 | ;; Function calls
34 | (function_call_expression
35 | function: [
36 | (qualified_name (name) @name.reference.call)
37 | (variable_name (name)) @name.reference.call
38 | ]) @reference.call
39 |
40 | ;; Static method calls
41 | (scoped_call_expression
42 | name: (name) @name.reference.call) @reference.call
43 |
44 | ;; Member method calls
45 | (member_call_expression
46 | name: (name) @name.reference.call) @reference.call
47 |
48 | ;; Property declarations (class variables)
49 | (property_declaration
50 | (property_element
51 | (variable_name) @name.definition.variable)) @definition.variable
52 |
--------------------------------------------------------------------------------
/scripts/view_chunks.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import sys
3 |
4 | from coderetrx import Codebase
5 | from coderetrx.static.codebase import CodeChunk
6 | from coderetrx.retrieval.factory import CodebaseFactory
7 | from coderetrx.retrieval.smart_codebase import SmartCodebaseSettings
8 |
9 |
10 | def load_coderetrx_codebase(codebase_path: str) -> Codebase:
11 | abs_repo_path = Path(codebase_path).resolve()
12 |
13 | smart_settings = SmartCodebaseSettings()
14 | smart_settings.symbol_codeline_embedding = False
15 | smart_settings.keyword_embedding = False
16 | smart_settings.symbol_content_embedding = False
17 | smart_settings.symbol_name_embedding = False
18 |
19 | codebase = CodebaseFactory.new(
20 | str(abs_repo_path),
21 | abs_repo_path,
22 | smart_settings,
23 | )
24 | return codebase
25 |
26 |
27 | def show_code_chunks(codebase: Codebase, file_name: str):
28 | for chunk in codebase.get_splited_distinct_chunks(100):
29 | # for chunk in codebase.all_chunks:
30 | if str(chunk.src.path) != file_name:
31 | continue
32 | print("Type:", chunk.type)
33 | print()
34 | print(
35 | chunk.ast_codeblock(
36 | show_line_numbers=True,
37 | zero_based_line_numbers=False,
38 | show_imports=True,
39 | )
40 | )
41 | print("=" * 40)
42 |
43 |
44 | def main():
45 | codebase_path = sys.argv[1]
46 | codebase = load_coderetrx_codebase(codebase_path)
47 |
48 | file_name = sys.argv[2]
49 | show_code_chunks(codebase, file_name)
50 |
51 |
52 | if __name__ == "__main__":
53 | main()
54 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tests/go.scm:
--------------------------------------------------------------------------------
1 | ; Functions names start with `Test`
2 | (
3 | (
4 | (function_declaration name: (_) @run
5 | (#match? @run "^Test.*"))
6 | ) @_
7 | (#set! tag go-test)
8 | )
9 |
10 | ; `go:generate` comments
11 | (
12 | ((comment) @_comment @run
13 | (#match? @_comment "^//go:generate"))
14 | (#set! tag go-generate)
15 | )
16 |
17 | ; `t.Run`
18 | (
19 | (
20 | (call_expression
21 | function: (
22 | selector_expression
23 | field: _ @run @_name
24 | (#eq? @_name "Run")
25 | )
26 | arguments: (
27 | argument_list
28 | .
29 | (interpreted_string_literal) @_subtest_name
30 | .
31 | (func_literal
32 | parameters: (
33 | parameter_list
34 | (parameter_declaration
35 | name: (identifier) @_param_name
36 | type: (pointer_type
37 | (qualified_type
38 | package: (package_identifier) @_pkg
39 | name: (type_identifier) @_type
40 | (#eq? @_pkg "testing")
41 | (#eq? @_type "T")
42 | )
43 | )
44 | )
45 | )
46 | ) @_second_argument
47 | )
48 | )
49 | ) @_
50 | (#set! tag go-subtest)
51 | )
52 |
53 | ; Functions names start with `Benchmark`
54 | (
55 | (
56 | (function_declaration name: (_) @run @_name
57 | (#match? @_name "^Benchmark.+"))
58 | ) @_
59 | (#set! tag go-benchmark)
60 | )
61 |
62 | ; Functions names start with `Fuzz`
63 | (
64 | (
65 | (function_declaration name: (_) @run @_name
66 | (#match? @_name "^Fuzz"))
67 | ) @_
68 | (#set! tag go-fuzz)
69 | )
70 |
71 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "coderetrx"
3 | version = "0.2.0beta"
4 | description = "Library for Code Analysis via Static Analysis and LLMs"
5 | readme = "README.md"
6 | license = { file = "LICENSE" }
7 | requires-python = ">=3.12"
8 | dependencies = [
9 | "attrs>=25.3.0",
10 | "git-python>=1.0.3",
11 | "httpx>=0.27.2",
12 | "ignore-python>=0.2.0",
13 | "json-repair>=0.46.0",
14 | "nanoid>=2.0.0",
15 | "pydantic>=2.11.4",
16 | "pydantic-settings>=2.9.1",
17 | "python-dotenv>=1.1.0",
18 | "python-git>=2018.2.1",
19 | "tomli-w>=1.2.0",
20 | "tree-sitter>=0.25.2",
21 | "tree-sitter-language-pack>=0.7.3",
22 | "openai>=1.0.0",
23 | "qdrant-client>=1.12.0",
24 | "tenacity>=9.1.2",
25 | "aiofiles>=24.1.0",
26 | "tiktoken>=0.9.0",
27 | "filetype>=1.2.0",
28 | "mcp>=1.13.1",
29 | "starlette>=0.47.3",
30 | "anyio>=4.9.0",
31 | "pytest-asyncio>=1.2.0",
32 | ]
33 |
34 | [tool.setuptools.packages.find]
35 | where = ["."]
36 | include = ["coderetrx*"]
37 | exclude = []
38 | namespaces = false
39 |
40 | [dependency-groups]
41 | dev = [
42 | "httpx[socks]>=0.28.1",
43 | "ipython>=9.2.0",
44 | "pytest>=8.3.5",
45 | "rich>=14.0.0",
46 | "black",
47 | ]
48 |
49 | [project.optional-dependencies]
50 | stats = ["tiktoken>=0.9.0"]
51 | cli = ["typer>=0.16.0"]
52 | redis = ["redis>=5.0.0"]
53 | chromadb = ["chromadb>=0.4.0"]
54 |
55 |
56 | [tool.pytest.ini_options]
57 | asyncio_default_fixture_loop_scope = "function"
58 |
59 | [build-system]
60 | requires = ["setuptools>=61.0"]
61 | build-backend = "setuptools.build_meta"
62 |
63 | [tool.setuptools]
64 | include-package-data = true
65 |
66 | [tool.setuptools.package-data]
67 | "coderetrx" = ["**/*.scm", "**/*.py"]
68 |
--------------------------------------------------------------------------------
/coderetrx/retrieval/strategy/filter_filename_by_llm.py:
--------------------------------------------------------------------------------
1 | """
2 | Strategy for filtering filenames using LLM.
3 | """
4 |
5 | from typing import List, override
6 | from .base import FilterByLLMStrategy, StrategyExecuteResult
7 | from ..smart_codebase import SmartCodebase as Codebase, LLMMapFilterTargetType
8 | from coderetrx.static import File
9 |
10 |
11 | class FilterFilenameByLLMStrategy(FilterByLLMStrategy[File]):
12 | """Strategy to filter filenames using LLM."""
13 |
14 | @override
15 | def get_strategy_name(self) -> str:
16 | return "FILTER_FILENAME_BY_LLM"
17 |
18 | @override
19 | def get_target_type(self) -> LLMMapFilterTargetType:
20 | return "file_name"
21 |
22 | @override
23 | def extract_file_paths(self, elements: List[File], codebase: Codebase) -> List[str]:
24 | return [str(file.path) for file in elements]
25 |
26 | @override
27 | async def execute(
28 | self,
29 | codebase: Codebase,
30 | prompt: str,
31 | subdirs_or_files: List[str],
32 | target_type: str = "symbol_content",
33 | ) -> StrategyExecuteResult:
34 | prompt = f"""
35 | A file with this path is highly likely to contain content that matches the following criteria:
36 |
37 | {prompt}
38 |
39 |
40 | The objective of this requirement is to preliminarily identify, based on their paths, files that are likely to meet specific content criteria.
41 | Files with matching paths will proceed to deeper analysis in the content filter (content_criterias) at a later stage (not in this run).
42 |
43 | """
44 | return await super().execute(codebase, prompt, subdirs_or_files, target_type)
45 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/csharp.scm:
--------------------------------------------------------------------------------
1 | ; Based on https://github.com/tree-sitter/tree-sitter-c-sharp/blob/master/queries/tags.scm
2 | ; MIT License.
3 |
4 | (class_declaration name: (identifier) @name.definition.class) @definition.class
5 |
6 | (class_declaration (base_list (_) @name.reference.class)) @reference.class
7 |
8 | (interface_declaration name: (identifier) @name.definition.interface) @definition.interface
9 |
10 | (interface_declaration (base_list (_) @name.reference.interface)) @reference.interface
11 |
12 | (method_declaration name: (identifier) @name.definition.method) @definition.method
13 |
14 | (object_creation_expression type: (identifier) @name.reference.class) @reference.class
15 |
16 | (type_parameter_constraints_clause (identifier) @name.reference.class) @reference.class
17 |
18 | (type_parameter_constraint (type type: (identifier) @name.reference.class)) @reference.class
19 |
20 | (variable_declaration type: (identifier) @name.reference.class) @reference.class
21 |
22 | (invocation_expression function: (member_access_expression name: (identifier) @name.reference.send)) @reference.send
23 |
24 | (namespace_declaration name: (identifier) @name.definition.module) @definition.module
25 |
26 | (namespace_declaration name: (identifier) @name.definition.module) @module
27 |
28 | (using_directive) @import
29 |
30 | ; Variable definitions - field and variable declarations
31 | (field_declaration
32 | (variable_declaration
33 | (variable_declarator
34 | (identifier) @name.definition.variable))) @definition.variable
35 |
36 | (local_declaration_statement
37 | (variable_declaration
38 | (variable_declarator
39 | (identifier) @name.definition.variable))) @definition.variable
40 |
41 | (property_declaration
42 | name: (identifier) @name.definition.variable) @definition.variable
43 |
--------------------------------------------------------------------------------
/coderetrx/utils/path.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from coderetrx.utils.git import get_repo_id
3 | import os
4 |
5 |
6 | def get_data_dir():
7 | """Returns the path to the data directory."""
8 | data_dir = os.getenv("DATA_DIR")
9 | if data_dir:
10 | data_dir = Path(data_dir)
11 | else:
12 | data_dir = Path(__file__).parent.parent.parent / ".data"
13 | data_dir.mkdir(parents=True, exist_ok=True)
14 | return data_dir
15 |
16 | def get_repo_path(repo_url):
17 | repo_path = get_data_dir() / "repos" / get_repo_id(repo_url)
18 | return repo_path
19 |
20 | def get_cache_dir():
21 | cache_dir = os.getenv("CACHE_DIR")
22 | if cache_dir:
23 | cache_dir = Path(cache_dir)
24 | else:
25 | cache_dir = Path(__file__).parent.parent.parent / ".cache"
26 | (cache_dir / "llm").mkdir(parents=True, exist_ok=True)
27 | (cache_dir / "embedding").mkdir(parents=True, exist_ok=True)
28 | cache_dir.mkdir(parents=True, exist_ok=True)
29 | return cache_dir
30 |
31 | def safe_join(path1: str | Path, path2: str | Path) -> Path:
32 | if isinstance(path1, str):
33 | path1 = Path(path1)
34 | if isinstance(path2, str):
35 | path2 = Path(path2)
36 | result = path1 / path2
37 | # Basic check: disallow absolute path2 that would ignore path1
38 | if not result.is_relative_to(path1):
39 | raise ValueError(f"Path {path2} is not relative to the base directory {path1}")
40 | # Symlink-aware check: resolve both sides (non-strict to allow non-existing leaf)
41 | resolved_base = path1.resolve(strict=False)
42 | resolved_result = (path1 / path2).resolve(strict=False)
43 | if not resolved_result.is_relative_to(resolved_base):
44 | raise ValueError(
45 | f"Path {path2} escapes base directory {path1} via symlink resolution"
46 | )
47 | return result
48 |
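49 | 
50 | if __name__ == "__main__":
51 |     # Illustrative sketch (not part of the module's API): safe_join accepts joins that
52 |     # stay inside the base directory and raises ValueError on traversal attempts.
53 |     base = get_data_dir()
54 |     print(safe_join(base, "repos/example"))  # stays under the base directory
55 |     try:
56 |         safe_join(base, "../outside")  # escapes the base directory
57 |     except ValueError as err:
58 |         print(f"rejected: {err}")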
--------------------------------------------------------------------------------
/test/impl/default/test_determine_strategy.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import asyncio
3 | from coderetrx.retrieval.code_recall import _determine_strategy_by_llm
4 | import json
5 | import logging
6 |
7 | logger = logging.getLogger(__name__)
8 | logging.basicConfig(level=logging.INFO)
9 |
10 | def generate_prompts(limit=10):
11 | all_prompts = []
12 | bug_name = []
13 | feature_file = "features/test.json"
14 | try:
15 | with open(feature_file, 'r', encoding='utf-8') as f:
16 | features = json.load(f)
17 | for feature in features:
18 | if len(all_prompts) >= limit:
19 | break
20 | # Extract filter prompts from resources
21 | for resource in feature.get('resources', []):
22 | if resource.get('type') == 'ToolCallingResource':
23 | filter_prompt = resource.get('tool_input_kwargs', {}).get('filter_prompt')
24 | if filter_prompt:
25 | all_prompts.append(filter_prompt)
26 | bug_name.append(feature.get('name'))
27 | if len(all_prompts) >= limit:
28 | break
29 | except Exception as e:
30 | logger.warning(f"Failed to load feature file {feature_file}: {e}")
31 | return all_prompts[:limit], bug_name[:limit]
32 |
33 | class TestDetermineStrategy(unittest.TestCase):
34 | def test_determine_strategy(self):
35 | async def run_test():
36 | res = generate_prompts(25)
37 | for idx in range(len(res[0])):
38 | bug_name = res[1][idx]
39 | prompt = res[0][idx]
40 | print(f"Bug Name: {bug_name}")
41 | strategy = await _determine_strategy_by_llm(prompt)
42 | print(f"Strategy: {strategy}")
43 | asyncio.run(run_test())
44 |
45 | if __name__ == "__main__":
46 | unittest.main()
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/typescript.scm:
--------------------------------------------------------------------------------
1 | (function_signature
2 | name: (identifier) @name.definition.function) @definition.function
3 |
4 | (method_signature
5 | name: (property_identifier) @name.definition.method) @definition.method
6 |
7 | (abstract_method_signature
8 | name: (property_identifier) @name.definition.method) @definition.method
9 |
10 | (abstract_class_declaration
11 | name: (type_identifier) @name.definition.class) @definition.class
12 |
13 | (module
14 | name: (identifier) @name.definition.module) @definition.module
15 |
16 | (interface_declaration
17 | name: (type_identifier) @name.definition.interface) @definition.interface
18 |
19 | (type_annotation
20 | (type_identifier) @name.reference.type) @reference.type
21 |
22 | (new_expression
23 | constructor: (identifier) @name.reference.class) @reference.class
24 |
25 | (function_declaration
26 | name: (identifier) @name.definition.function) @definition.function
27 |
28 | (arrow_function) @definition.function
29 |
30 | (method_definition
31 | name: (property_identifier) @name.definition.method) @definition.method
32 |
33 | (class_declaration
34 | name: (type_identifier) @name.definition.class) @definition.class
35 |
36 | (interface_declaration
37 | name: (type_identifier) @name.definition.class) @definition.class
38 |
39 | (type_alias_declaration
40 | name: (type_identifier) @name.definition.type) @definition.type
41 |
42 | (enum_declaration
43 | name: (identifier) @name.definition.enum) @definition.enum
44 |
45 | (import_statement) @import
46 |
47 | ; Variable definitions - const/let/var declarations
48 | ; Note: We capture all variable declarations, function assignments are already captured above
49 | (lexical_declaration
50 | (variable_declarator
51 | name: (identifier) @name.definition.variable)) @definition.variable
52 |
53 | (variable_declaration
54 | (variable_declarator
55 | name: (identifier) @name.definition.variable)) @definition.variable
56 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tests/rust.scm:
--------------------------------------------------------------------------------
1 | ; Rust mod test
2 | (
3 | (mod_item
4 | name: (_) @run
5 | (#eq? @run "tests")
6 | ) @test
7 | (#set! tag rust-mod-test)
8 | )
9 |
10 | ; Rust test
11 | (
12 | (
13 | (attribute_item (attribute
14 | [((identifier) @_attribute)
15 | (scoped_identifier (identifier) @_attribute)
16 | ])
17 | (#match? @_attribute "test")
18 | ) @_start
19 | .
20 | (attribute_item) *
21 | .
22 | [(line_comment) (block_comment)] *
23 | .
24 | (function_item
25 | name: (_) @run @_test_name
26 | body: _
27 | ) @_end
28 | ) @test
29 | (#set! tag rust-test)
30 | )
31 |
32 | ; Rust doc test
33 | ; (
34 | ; (
35 | ; (line_comment) *
36 | ; (line_comment
37 | ; doc: (_) @_comment_content
38 | ; ) @_start @run
39 | ; (#match? @_comment_content "```")
40 | ; .
41 | ; (line_comment) *
42 | ; .
43 | ; (line_comment
44 | ; doc: (_) @_end_comment_content
45 | ; ) @_end_code_block
46 | ; (#match? @_end_comment_content "```")
47 | ; .
48 | ; (line_comment) *
49 | ; (attribute_item) *
50 | ; .
51 | ; [(function_item
52 | ; name: (_) @_doc_test_name
53 | ; body: _
54 | ; ) (function_signature_item
55 | ; name: (_) @_doc_test_name
56 | ; ) (struct_item
57 | ; name: (_) @_doc_test_name
58 | ; ) (enum_item
59 | ; name: (_) @_doc_test_name
60 | ; body: _
61 | ; ) (
62 | ; (attribute_item) ?
63 | ; (macro_definition
64 | ; name: (_) @_doc_test_name)
65 | ; ) (mod_item
66 | ; name: (_) @_doc_test_name
67 | ; )] @_end
68 | ; ) @test
69 | ; (#set! tag rust-doc-test)
70 | ; )
71 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/cpp.scm:
--------------------------------------------------------------------------------
1 | (struct_specifier name: (type_identifier) @name.definition.class body:(_)) @definition.class
2 |
3 | (declaration type: (union_specifier name: (type_identifier) @name.definition.class)) @definition.class
4 |
5 | (function_definition
6 | declarator: (function_declarator
7 | declarator: (identifier) @name.definition.function)) @definition.function
8 |
9 | (function_definition
10 | declarator: (pointer_declarator
11 | declarator: (function_declarator
12 | declarator: (identifier) @name.definition.function))) @definition.function
13 |
14 | (function_definition
15 | declarator: (function_declarator
16 | declarator: (field_identifier) @name.definition.function)) @definition.function
17 |
18 | (function_definition
19 | declarator: (function_declarator
20 | declarator: (qualified_identifier scope: (namespace_identifier) @scope name: (identifier) @name.definition.method))) @definition.method
21 |
22 | (function_definition
23 | declarator: (reference_declarator
24 | (function_declarator
25 | declarator: (qualified_identifier scope: (namespace_identifier) @scope name: (identifier) @name.definition.method)))) @definition.method
26 |
27 | (type_definition declarator: (type_identifier) @name.definition.type) @definition.type
28 |
29 | (enum_specifier name: (type_identifier) @name.definition.type) @definition.type
30 |
31 | (class_specifier name: (type_identifier) @name.definition.class) @definition.class
32 |
33 | (preproc_include) @import
34 |
35 | ; Variable definitions - all declarations
36 | (declaration
37 | declarator: (identifier) @name.definition.variable) @definition.variable
38 |
39 | (declaration
40 | declarator: (init_declarator
41 | declarator: (identifier) @name.definition.variable)) @definition.variable
42 |
43 | (declaration
44 | declarator: (pointer_declarator
45 | declarator: (identifier) @name.definition.variable)) @definition.variable
46 |
47 | (declaration
48 | declarator: (reference_declarator
49 | (identifier) @name.definition.variable)) @definition.variable
50 |
--------------------------------------------------------------------------------
/scripts/benchmark.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import re
3 | import subprocess
4 | import time
5 | from pathlib import Path
6 |
7 | async def run_command(cmd):
8 | process = await asyncio.create_subprocess_exec(*cmd, cwd=Path(__file__).parent.parent)
9 | return await process.wait()
10 |
11 | def parse_repos():
12 | repos = []
13 | with open("bench/repos.txt", 'r') as f:
14 | content = f.read()
15 |
16 | blocks = [block.strip() for block in content.split('\n\n') if block.strip()]
17 | for block in blocks:
18 | lines = block.strip().split('\n')
19 | if len(lines) >= 2 and lines[1].startswith('https://'):
20 | repos.append(lines[1].strip())
21 |
22 | return repos
23 |
24 | async def main():
25 | repos = parse_repos()
26 | modes = ["file_name", "symbol_name", "line_per_symbol", "precise", "auto"]
27 |
28 | print(f"Running benchmark on {len(repos)} repositories with {len(modes)} modes")
29 |
30 | repo_url = "https://github.com/ollama/ollama"
31 | for mode in modes:
32 | print(f" Running {mode} mode...")
33 | cmd = ["uv", "run", "scripts/code_retriever.py", "-l", "9", "-f", "--mode", mode, "--repo", repo_url]
34 | exit_code = await run_command(cmd)
35 | if exit_code != 0:
36 | print(f" Failed with exit code {exit_code}")
37 |
38 | # for i, repo_url in enumerate(repos, 1):
39 | # print(f"\n[{i}/{len(repos)}] Processing {repo_url}")
40 | #
41 | # for mode in modes:
42 | # print(f" Running {mode} mode...")
43 | # cmd = ["uv", "run", "scripts/code_retriever.py", "-l", "9", "-f", "--mode", mode, "--repo", repo_url, "-t"]
44 | # exit_code = await run_command(cmd)
45 | # if exit_code != 0:
46 | # print(f" Failed with exit code {exit_code}")
47 | #
48 | # break
49 |
50 | print("\nRunning analyze_code_reports...")
51 | await run_command(["uv", "run", "scripts/analyze_code_reports.py"])
52 |
53 | print("Benchmark completed")
54 |
55 | if __name__ == "__main__":
56 | asyncio.run(main())
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/rust.scm:
--------------------------------------------------------------------------------
1 | ; ADT definitions
2 |
3 | (struct_item
4 | name: (type_identifier) @name.definition.class) @definition.class
5 |
6 | (enum_item
7 | name: (type_identifier) @name.definition.class) @definition.class
8 |
9 | (union_item
10 | name: (type_identifier) @name.definition.class) @definition.class
11 |
12 | ; type aliases
13 |
14 | (type_item
15 | name: (type_identifier) @name.definition.class) @definition.class
16 |
17 | ; method definitions
18 |
19 | ; (declaration_list
20 | ; (function_item
21 | ; name: (identifier) @name.definition.method) @definition.method)
22 |
23 | ; function definitions
24 |
25 | (function_item
26 | name: (identifier) @name.definition.function) @definition.function
27 |
28 | ; trait definitions
29 | (trait_item
30 | name: (type_identifier) @name.definition.interface) @definition.interface
31 |
32 | ; module definitions
33 | (mod_item
34 | name: (identifier) @name.definition.module) @definition.module
35 |
36 | ; macro definitions
37 |
38 | (macro_definition
39 | name: (identifier) @name.definition.macro) @definition.macro
40 |
41 | ; references
42 |
43 | (call_expression
44 | function: (identifier) @name.reference.call) @reference.call
45 |
46 | (call_expression
47 | function: (field_expression
48 | field: (field_identifier) @name.reference.call)) @reference.call
49 |
50 | (macro_invocation
51 | macro: (identifier) @name.reference.call) @reference.call
52 |
53 | ; implementations
54 |
55 | (impl_item
56 | trait: (type_identifier) @name.reference.implementation) @reference.implementation
57 |
58 | (impl_item
59 | type: (type_identifier) @name.reference.implementation
60 | !trait) @reference.implementation
61 |
62 | ; handle imports
63 |
64 | ( use_declaration ) @import
65 |
66 | ; Variable definitions
67 | (let_declaration
68 | pattern: (identifier) @name.definition.variable) @definition.variable
69 |
70 | (const_item
71 | name: (identifier) @name.definition.variable) @definition.variable
72 |
73 | (static_item
74 | name: (identifier) @name.definition.variable) @definition.variable
75 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/elixir.scm:
--------------------------------------------------------------------------------
1 | ; Definitions
2 |
3 | ; * modules and protocols
4 | (call
5 | target: (identifier) @ignore
6 | (arguments (alias) @name.definition.module)
7 | (#match? @ignore "^(defmodule|defprotocol)$")) @definition.module
8 |
9 | ; * functions/macros
10 | (call
11 | target: (identifier) @ignore
12 | (arguments
13 | [
14 | ; zero-arity functions with no parentheses
15 | (identifier) @name.definition.function
16 | ; regular function clause
17 | (call target: (identifier) @name.definition.function)
18 | ; function clause with a guard clause
19 | (binary_operator
20 | left: (call target: (identifier) @name.definition.function)
21 | operator: "when")
22 | ])
23 | (#match? @ignore "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @definition.function
24 |
25 | ; References
26 |
27 | ; ignore calls to kernel/special-forms keywords
28 | (call
29 | target: (identifier) @ignore
30 | (#match? @ignore "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp|defmodule|defprotocol|defimpl|defstruct|defexception|defoverridable|alias|case|cond|else|for|if|import|quote|raise|receive|require|reraise|super|throw|try|unless|unquote|unquote_splicing|use|with)$"))
31 |
32 | ; ignore module attributes
33 | (unary_operator
34 | operator: "@"
35 | operand: (call
36 | target: (identifier) @ignore))
37 |
38 | ; * function call
39 | (call
40 | target: [
41 | ; local
42 | (identifier) @name.reference.call
43 | ; remote
44 | (dot
45 | right: (identifier) @name.reference.call)
46 | ]) @reference.call
47 |
48 | ; * pipe into function call
49 | (binary_operator
50 | operator: "|>"
51 | right: (identifier) @name.reference.call) @reference.call
52 |
53 | ; * modules
54 | (alias) @name.reference.module @reference.module
55 |
56 | (call
57 | target: (identifier)@_cap
58 | (#any-of? @_cap "import" "use" "require"))@import
59 |
60 | ; Variable definitions - module attributes
61 | (unary_operator
62 | operator: "@"
63 | operand: (call
64 | target: (identifier) @name.definition.variable)) @definition.variable
65 |
--------------------------------------------------------------------------------
/coderetrx/retrieval/strategy/filter_dependency_by_llm.py:
--------------------------------------------------------------------------------
1 | """
2 | Strategy for filtering dependencies using LLM.
3 | """
4 |
5 | from typing import List, override
6 | from .base import FilterByLLMStrategy, StrategyExecuteResult
7 | from ..smart_codebase import SmartCodebase as Codebase, LLMMapFilterTargetType
8 | from coderetrx.static import Dependency, Symbol
9 |
10 |
11 | class FilterDependencyByLLMStrategy(FilterByLLMStrategy[Dependency]):
12 | """Strategy to filter dependencies by LLM and retrieve code chunks that use these dependencies."""
13 |
14 | @override
15 | def get_strategy_name(self) -> str:
16 | return "FILTER_DEPENDENCY_BY_LLM"
17 |
18 | @override
19 | def get_target_type(self) -> LLMMapFilterTargetType:
20 | return "dependency_name"
21 |
22 | @override
23 | def extract_file_paths(
24 | self, elements: List[Dependency], codebase: Codebase
25 | ) -> List[str]:
26 | """Extract file paths from dependencies by getting files that import these dependencies."""
27 | file_paths = []
28 | for dependency in elements:
29 | if isinstance(dependency, Dependency):
30 | file_paths.extend([str(f.path) for f in dependency.imported_by])
31 | elif isinstance(dependency, Symbol) and dependency.type == "dependency":
32 | file_paths.append(str(dependency.file.path))
33 | return list(set(file_paths))
34 |
35 | @override
36 | async def execute(
37 | self,
38 | codebase: Codebase,
39 | prompt: str,
40 | subdirs_or_files: List[str],
41 | target_type: str = "symbol_content",
42 | ) -> StrategyExecuteResult:
43 | enhanced_prompt = f"""
44 | A dependency that matches the following criteria is highly likely to be relevant:
45 |
46 | {prompt}
47 |
48 |
49 | The objective is to identify dependencies whose names match the specified criteria.
50 | Files that import these matching dependencies will be retrieved for further analysis.
51 | Focus on dependency names, package names, module names, and library names that are relevant to the criteria.
52 |
53 | """
54 | return await super().execute(
55 | codebase, enhanced_prompt, subdirs_or_files, target_type
56 | )
57 |
--------------------------------------------------------------------------------
/STRATEGIES.md:
--------------------------------------------------------------------------------
1 | # Coarse Recall Strategies
2 |
3 | In the coarse recall stage, multiple strategies are used to efficiently retrieve potentially relevant code snippets at low cost. The retrieved results will then be further processed in the refined filter stage to improve precision and accuracy.
4 |
5 | ## Available Strategies
6 |
7 | - **`file_name`**: The fastest filtering strategy, ideal for coarse recall by directly determining relevance based on filenames. This approach is highly effective for cases where the query can be matched through filenames alone, such as "retrieve all configuration files." Its simplicity makes it extremely efficient for structural queries and file discovery.
8 |
9 | - **`symbol_name`**: A filtering approach that focuses on symbol names (e.g., function or class names) during coarse recall. This method excels in cases like "retrieve functions implementing cryptographic algorithms," where relevance can be inferred directly from the symbol name. It balances recall and precision for targeted queries at the symbol level.
10 |
11 | - **`line_per_symbol`**: A high-accuracy filtering strategy that leverages line-level vector search combined with LLM analysis. It identifies relevance by focusing on the most relevant code lines (`top-k`) within a function body. This method is particularly effective for complex cases where understanding the function's implementation is necessary, such as "retrieve code implementing authentication logic." While powerful, it's also versatile enough to handle most queries effectively. Our benchmarking shows that the algorithm achieves over 90% recall with approximately 25% of the computational cost, making it both precise and efficient.
12 |
13 | - **`auto`**: Automatically selects the optimal filtering strategy based on query complexity, routing requests to the most appropriate method. For the majority of cases, `line_per_symbol` is a well-balanced and reliable choice, but `file_name` or `symbol_name` can also be explicitly used for scenarios where they excel.
14 |
15 | ## Performance Characteristics
16 |
17 | - **Speed**: `file_name` > `symbol_name` > `auto` > `line_per_symbol` > `precise`
18 | - **Accuracy**: `precise` > `line_per_symbol` > `auto` > `symbol_name` > `file_name`
19 | - **Cost**: `file_name` < `line_per_symbol` < `auto` < `symbol_name` < `precise`
20 |
21 | Use `file_name` for structural queries, `symbol_name` for API search, `line_per_symbol` for analysis tied to specific code behavior, `auto` for general-purpose use, and `precise` for ground truth (see the usage sketch below).
22 |
23 |
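24 | ## Usage Sketch
25 | 
26 | The sketch below is a minimal illustration of selecting a coarse recall strategy explicitly. It mirrors `scripts/example.py` and assumes the same public entry points (`coderetrx_filter`, `CodebaseFactory`, `TopicExtractor`, `SmartCodebaseSettings`, `CodeRecallSettings`); the query string is a placeholder, so treat the snippet as illustrative rather than a reference.
27 | 
28 | ```python
29 | import asyncio
30 | 
31 | from coderetrx.retrieval import (
32 |     CodebaseFactory,
33 |     SmartCodebaseSettings,
34 |     TopicExtractor,
35 |     coderetrx_filter,
36 | )
37 | from coderetrx.retrieval.code_recall import CodeRecallSettings
38 | from coderetrx.utils.git import clone_repo_if_not_exists, get_repo_id
39 | from coderetrx.utils.path import get_repo_path
40 | 
41 | async def main():
42 |     repo_url = "https://github.com/ollama/ollama.git"
43 |     repo_path = get_repo_path(repo_url)
44 |     clone_repo_if_not_exists(repo_url, str(repo_path))
45 | 
46 |     settings = SmartCodebaseSettings()
47 |     settings.symbol_codeline_embedding = True  # line-level embeddings used by line_per_symbol
48 |     codebase = CodebaseFactory.new(get_repo_id(repo_url), repo_path, settings=settings)
49 | 
50 |     result, _ = await coderetrx_filter(
51 |         codebase=codebase,
52 |         subdirs_or_files=["/"],
53 |         prompt="The code snippet implements authentication logic.",
54 |         target_type="symbol_content",
55 |         coarse_recall_strategy="line_per_symbol",  # or "file_name", "symbol_name", "auto"
56 |         topic_extractor=TopicExtractor(),
57 |         settings=CodeRecallSettings(),
58 |     )
59 |     print(f"{len(result)} candidate locations")
60 | 
61 | asyncio.run(main())
62 | ```
63 | 
64 | Passing `coarse_recall_strategy="auto"` lets the router pick a strategy per query; the explicit modes are for cases where you already know which signal (file names, symbol names, or code lines) carries the query.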
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/hcl.scm:
--------------------------------------------------------------------------------
1 | ;; Based on https://github.com/tree-sitter-grammars/tree-sitter-hcl/blob/main/make_grammar.js
2 | ;; Which has Apache 2.0 License
3 | ;; tags.scm for Terraform (tree-sitter-hcl)
4 |
5 | ; === Definitions: Terraform Blocks ===
6 | (block
7 | (identifier) @block_type
8 | (string_lit (template_literal) @resource_type)
9 | (string_lit (template_literal) @name.definition.resource)
10 | (body) @definition.resource
11 | ) (#eq? @block_type "resource")
12 |
13 | (block
14 | (identifier) @block_type
15 | (string_lit (template_literal) @name.definition.module)
16 | (body) @definition.module
17 | ) (#eq? @block_type "module")
18 |
19 | (block
20 | (identifier) @block_type
21 | (string_lit (template_literal) @name.definition.variable)
22 | (body) @definition.variable
23 | ) (#eq? @block_type "variable")
24 |
25 | (block
26 | (identifier) @block_type
27 | (string_lit (template_literal) @name.definition.output)
28 | (body) @definition.output
29 | ) (#eq? @block_type "output")
30 |
31 | (block
32 | (identifier) @block_type
33 | (string_lit (template_literal) @name.definition.provider)
34 | (body) @definition.provider
35 | ) (#eq? @block_type "provider")
36 |
37 | (block
38 | (identifier) @block_type
39 | (body
40 | (attribute
41 | (identifier) @name.definition.local
42 | (expression) @definition.local
43 | )+
44 | )
45 | ) (#eq? @block_type "locals")
46 |
47 | ; === References: Variables, Locals, Modules, Data, Resources ===
48 | ((variable_expr) @ref_type
49 | (get_attr (identifier) @name.reference.variable)
50 | ) @reference.variable
51 | (#eq? @ref_type "var")
52 |
53 | ((variable_expr) @ref_type
54 | (get_attr (identifier) @name.reference.local)
55 | ) @reference.local
56 | (#eq? @ref_type "local")
57 |
58 | ((variable_expr) @ref_type
59 | (get_attr (identifier) @name.reference.module)
60 | ) @reference.module
61 | (#eq? @ref_type "module")
62 |
63 | ((variable_expr) @ref_type
64 | (get_attr (identifier) @data_source_type)
65 | (get_attr (identifier) @name.reference.data)
66 | ) @reference.data
67 | (#eq? @ref_type "data")
68 |
69 | ((variable_expr) @resource_type
70 | (get_attr (identifier) @name.reference.resource)
71 | ) @reference.resource
72 | (#not-eq? @resource_type "var")
73 | (#not-eq? @resource_type "local")
74 | (#not-eq? @resource_type "module")
75 | (#not-eq? @resource_type "data")
76 | (#not-eq? @resource_type "provider")
77 | (#not-eq? @resource_type "output")
78 |
--------------------------------------------------------------------------------
/coderetrx/retrieval/strategy/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Strategy package for code retrieval strategies.
3 |
4 | This package contains individual strategy implementations split from the original strategies.py file.
5 | Each strategy is now in its own module for better organization and maintainability.
6 | """
7 |
8 | from .base import (
9 | RecallStrategy,
10 | StrategyExecuteResult,
11 | RecallStrategyExecutor,
12 | FilterByLLMStrategy,
13 | FilterByVectorStrategy,
14 | FilterByVectorAndLLMStrategy,
15 | AdaptiveFilterByVectorAndLLMStrategy,
16 | deduplicate_elements,
17 | rel_path,
18 | )
19 | from .factory import StrategyFactory
20 |
21 | from .filter_filename_by_llm import FilterFilenameByLLMStrategy
22 | from .filter_symbol_name_by_llm import FilterSymbolNameByLLMStrategy
23 | from .filter_dependency_by_llm import FilterDependencyByLLMStrategy
24 | from .filter_keyword_by_vector import FilterKeywordByVectorStrategy
25 | from .filter_symbol_content_by_vector import FilterSymbolContentByVectorStrategy
26 | from .filter_keyword_by_vector_and_llm import FilterKeywordByVectorAndLLMStrategy
27 | from .filter_symbol_content_by_vector_and_llm import (
28 | FilterSymbolContentByVectorAndLLMStrategy,
29 | )
30 | from .adaptive_filter_keyword_by_vector_and_llm import (
31 | AdaptiveFilterKeywordByVectorAndLLMStrategy,
32 | )
33 | from .adaptive_filter_symbol_content_by_vector_and_llm import (
34 | AdaptiveFilterSymbolContentByVectorAndLLMStrategy,
35 | )
36 | from .filter_line_per_symbol_by_vector_and_llm import (
37 | FilterLinePerSymbolByVectorAndLLMStrategy,
38 | )
39 |
40 | __all__ = [
41 | # Enums and Models
42 | "RecallStrategy",
43 | "StrategyExecuteResult",
44 | # Base Classes
45 | "RecallStrategyExecutor",
46 | "FilterByLLMStrategy",
47 | "FilterByVectorStrategy",
48 | "FilterByVectorAndLLMStrategy",
49 | "AdaptiveFilterByVectorAndLLMStrategy",
50 | # Concrete Strategy Implementations
51 | "FilterFilenameByLLMStrategy",
52 | "FilterSymbolNameByLLMStrategy",
53 | "FilterDependencyByLLMStrategy",
54 | "FilterKeywordByVectorStrategy",
55 | "FilterSymbolContentByVectorStrategy",
56 | "FilterKeywordByVectorAndLLMStrategy",
57 | "FilterSymbolContentByVectorAndLLMStrategy",
58 | "AdaptiveFilterKeywordByVectorAndLLMStrategy",
59 | "AdaptiveFilterSymbolContentByVectorAndLLMStrategy",
60 | "FilterLinePerSymbolByVectorAndLLMStrategy",
61 | # Factory and Utilities
62 | "StrategyFactory",
63 | "deduplicate_elements",
64 | "rel_path",
65 | ]
66 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/dart.scm:
--------------------------------------------------------------------------------
1 | (class_definition
2 | name: (identifier) @name.definition.class) @definition.class
3 |
4 | (method_signature
5 | (function_signature)) @definition.method
6 |
7 | (type_alias
8 | (type_identifier) @name.definition.type) @definition.type
9 |
10 | (method_signature
11 | (getter_signature
12 | name: (identifier) @name.definition.method)) @definition.method
13 |
14 | (method_signature
15 | (setter_signature
16 | name: (identifier) @name.definition.method)) @definition.method
17 |
18 | (method_signature
19 | (function_signature
20 | name: (identifier) @name.definition.method)) @definition.method
21 |
22 | (method_signature
23 | (factory_constructor_signature
24 | (identifier) @name.definition.method)) @definition.method
25 |
26 | (method_signature
27 | (constructor_signature
28 | name: (identifier) @name.definition.method)) @definition.method
29 |
30 | (method_signature
31 | (operator_signature)) @definition.method
32 |
33 | (method_signature) @definition.method
34 |
35 | (mixin_declaration
36 | (mixin)
37 | (identifier) @name.definition.mixin) @definition.mixin
38 |
39 | (extension_declaration
40 | name: (identifier) @name.definition.extension) @definition.extension
41 |
42 | (enum_declaration
43 | name: (identifier) @name.definition.enum) @definition.enum
44 |
45 | (function_signature
46 | name: (identifier) @name.definition.function) @definition.function
47 |
48 | (new_expression
49 | (type_identifier) @name.reference.class) @reference.class
50 |
51 | (initialized_variable_definition
52 | name: (identifier)
53 | value: (identifier) @name.reference.class
54 | value: (selector
55 | "!"?
56 | (argument_part
57 | (arguments
58 | (argument)*))?)?) @reference.class
59 |
60 | (assignment_expression
61 | left: (assignable_expression
62 | (identifier)
63 | (unconditional_assignable_selector
64 | "."
65 | (identifier) @name.reference.call))) @reference.call
66 |
67 | (assignment_expression
68 | left: (assignable_expression
69 | (identifier)
70 | (conditional_assignable_selector
71 | "?."
72 | (identifier) @name.reference.call))) @reference.call
73 |
74 | ((identifier) @name
75 | (selector
76 | "!"?
77 | (conditional_assignable_selector
78 | "?." (identifier) @name.reference.call)?
79 | (unconditional_assignable_selector
80 | "."? (identifier) @name.reference.call)?
81 | (argument_part
82 | (arguments
83 | (argument)*))?)*
84 | (cascade_section
85 | (cascade_selector
86 | (identifier)) @name.reference.call
87 | (argument_part
88 | (arguments
89 | (argument)*))?)?) @reference.call
90 |
91 |
92 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tests/python.scm:
--------------------------------------------------------------------------------
1 | ; subclasses of unittest.TestCase or TestCase
2 | (
3 | (class_definition
4 | name: (identifier) @run @_unittest_class_name
5 | superclasses: (argument_list
6 | [(identifier) @_superclass
7 | (attribute (identifier) @_superclass)]
8 | )
9 | (#eq? @_superclass "TestCase")
10 | ) @_python-unittest-class
11 | (#set! tag python-unittest-class)
12 | )
13 |
14 | ; test methods whose names start with `test` in a TestCase
15 | (
16 | (class_definition
17 | name: (identifier) @_unittest_class_name
18 | superclasses: (argument_list
19 | [(identifier) @_superclass
20 | (attribute (identifier) @_superclass)]
21 | )
22 | (#eq? @_superclass "TestCase")
23 | body: (block
24 | (function_definition
25 | name: (identifier) @run @_unittest_method_name
26 | (#match? @_unittest_method_name "^test.*")
27 | ) @_python-unittest-method
28 | (#set! tag python-unittest-method)
29 | )
30 | )
31 | )
32 |
33 | ; pytest functions
34 | (
35 | (module
36 | (function_definition
37 | name: (identifier) @run @_pytest_method_name
38 | (#match? @_pytest_method_name "^test_")
39 | ) @_python-pytest-method
40 | )
41 | (#set! tag python-pytest-method)
42 | )
43 |
44 | ; decorated pytest functions
45 | (
46 | (module
47 | (decorated_definition
48 | (decorator)+ @_decorator
49 | definition: (function_definition
50 | name: (identifier) @run @_pytest_method_name
51 | (#match? @_pytest_method_name "^test_")
52 | )
53 | ) @_python-pytest-method
54 | )
55 | (#set! tag python-pytest-method)
56 | )
57 |
58 | ; pytest classes
59 | (
60 | (module
61 | (class_definition
62 | name: (identifier) @run @_pytest_class_name
63 | (#match? @_pytest_class_name "^Test")
64 | )
65 | (#set! tag python-pytest-class)
66 | )
67 | )
68 |
69 | ; pytest class methods
70 | (
71 | (module
72 | (class_definition
73 | name: (identifier) @_pytest_class_name
74 | (#match? @_pytest_class_name "^Test")
75 | body: (block
76 | (function_definition
77 | name: (identifier) @run @_pytest_method_name
78 | (#match? @_pytest_method_name "^test")
79 | ) @_python-pytest-method
80 | (#set! tag python-pytest-method)
81 | )
82 | )
83 | )
84 | )
85 |
--------------------------------------------------------------------------------
/coderetrx/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Pydantic-based structured logging into a JSONL file.
2 | # The "logging" here should specifically be for dataset creation.
3 | from datetime import datetime
4 | from os import PathLike
5 | from pathlib import Path
6 | from typing import Literal, Optional, List, TYPE_CHECKING
7 |
8 | from pydantic import BaseModel, RootModel
9 |
10 | from coderetrx.static.codebase.models import IsolatedCodeChunkModel
11 | from coderetrx.static.codebase.languages import IDXSupportedTag
12 |
13 |
14 | class FilteringLog(BaseModel):
15 | type: Literal["filtering"] = "filtering"
16 | query: str
17 | total_chunks: int
18 | strategy: Literal["static", "dynamic"]
19 | limit: int
20 | filter_tags: Optional[List["IDXSupportedTag"]] = None
21 |
22 |
23 | class CodeChunkClassificationLog(BaseModel):
24 | type: Literal["code_chunk_classification"] = "code_chunk_classification"
25 | code_chunk: IsolatedCodeChunkModel
26 | classification: str
27 | rationale: str
28 |
29 |
30 | class VecSearchLog(BaseModel):
31 | type: Literal["vec_search"] = "vec_search"
32 | query: str
33 | total_retrieved: int
34 | matched_count: int
35 | strategy: Literal["static", "dynamic"]
36 | initial_limit: int
37 | final_limit: int
38 | filter_tags: Optional[List["IDXSupportedTag"]] = None
39 | success_ratio: float
40 | llm_model: str
41 |
42 |
43 | class LLMCallLog(BaseModel):
44 | type: Literal["llm_call"] = "llm_call"
45 | completion_id: str
46 | model: str
47 | completion_tokens: int
48 | prompt_tokens: int
49 | total_tokens: int
50 | call_url: str
51 | cached: bool = False
52 |
53 |
54 | class ErrLog(BaseModel):
55 | type: Literal["err"] = "err"
56 | error_type: str
57 | error: str
58 |
59 |
60 | type LogData = (
61 | FilteringLog | CodeChunkClassificationLog | VecSearchLog | LLMCallLog | ErrLog
62 | )
63 |
64 |
65 | class LogEntry(BaseModel):
66 | timestamp: str
67 | data: LogData
68 |
69 |
70 | def write_log(entry: LogEntry, file: PathLike | str):
71 | with open(file, "a") as f:
72 | f.write(entry.model_dump_json() + "\n")
73 |
74 |
75 | def read_logs(file: PathLike | str):
76 | with open(file, "r") as f:
77 | for line in f:
78 | yield LogEntry.model_validate_json(line)
79 |
80 | class JsonLogger:
81 | def __init__(self, file: PathLike | str):
82 | self.file = file
83 | Path(file).touch(exist_ok=True)
84 |
85 | def log(self, data: LogData):
86 | entry = LogEntry(timestamp=datetime.now().isoformat(), data=data)
87 | write_log(entry, self.file)
88 |
89 | if __name__ == "__main__":
90 | from rich import print
91 |
92 | print(LogEntry.model_json_schema())
93 |
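94 |     # Illustrative sketch (not part of the module's API): write one entry to a
95 |     # throwaway JSONL file and read it back.
96 |     demo = JsonLogger("demo_logs.jsonl")
97 |     demo.log(ErrLog(error_type="ExampleError", error="illustrative entry"))
98 |     print(next(read_logs("demo_logs.jsonl")))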
--------------------------------------------------------------------------------
/coderetrx/tools/find_file_by_name.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | from typing import ClassVar, List, Type
4 | from pathlib import Path
5 | from pydantic import BaseModel, Field
6 | from coderetrx.static.ripgrep import ripgrep_glob # type: ignore
7 | from coderetrx.utils.path import safe_join
8 | from coderetrx.tools.base import BaseTool
9 | import filetype
10 |
11 |
12 | class FindFileByNameArgs(BaseModel):
13 | dir_path: str = Field(description="The directory to search within")
14 | pattern: str = Field(
15 |         description="Pattern to search for. Matching is keyword-based; wildcards are not supported and are not required."
16 | )
17 |
18 |
19 | class FindFileByNameResult(BaseModel):
20 | path: str = Field(description="The path of matched file name")
21 | type: str = Field(description="The type of the file")
22 |
23 | @classmethod
24 | def repr(cls, entries: list["FindFileByNameResult"]):
25 | """
26 | convert a list of FindFileByNameResult to a readable string
27 | """
28 | if not entries:
29 | return "No result Found."
30 | if not entries[0].path:
31 |             return "The search directory does not exist. Please make sure that the directory you are investigating exists."
32 | tool_result = ""
33 | for i, file in enumerate(entries, 1):
34 | tool_result += f"# **{i}. {file.path}**\n"
35 |
36 | return tool_result
37 |
38 |
39 | class FindFileByNameTool(BaseTool):
40 | name = "find_file_by_name"
41 | description = (
42 | "This tool searches for files and directories within a specified directory, similar to the Linux `find` command. "
43 | "The returned result paths are relative to the root path."
44 | )
45 | args_schema: ClassVar[Type[FindFileByNameArgs]] = FindFileByNameArgs
46 |
47 | def _get_file_type(self, path: str) -> str:
48 | type = filetype.guess(path)
49 | if type:
50 | return type.mime
51 | else:
52 | return "Unknown"
53 |
54 | def forward(self, dir_path: str, pattern: str) -> str:
55 | """Synchronous wrapper for async _run method."""
56 | return self.run_sync(dir_path=dir_path, pattern=pattern)
57 |
58 | async def _run(self, dir_path: str, pattern: str) -> list[FindFileByNameResult]:
59 | """Search for files matching the pattern in the specified directory"""
60 | full_dir_path = safe_join(self.repo_path, dir_path.lstrip("/"))
61 |
62 | if not os.path.exists(full_dir_path):
63 | return [FindFileByNameResult(path="", type="Directory Not Exists")]
64 |
65 | matched_files = await ripgrep_glob(
66 | full_dir_path, pattern, extra_argv=["-g", "!.git"]
67 | )
68 |
69 | results = []
70 | for file_path in matched_files:
71 | full_file_path = full_dir_path / file_path
72 | file_type = self._get_file_type(str(full_file_path))
73 | results.append(FindFileByNameResult(path=file_path, type=file_type))
74 |
75 | return results
76 |
--------------------------------------------------------------------------------
/coderetrx/tools/base.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | from pathlib import Path
4 | from typing import ClassVar, Optional
5 |
6 | from coderetrx.utils.path import get_data_dir, get_repo_path
7 | from coderetrx.utils.git import clone_repo_if_not_exists, get_repo_id
8 | from typing import Any, Type
9 | from pydantic import BaseModel
10 | from abc import abstractmethod
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | class BaseTool:
16 | """Base class for tools that work with repositories."""
17 |
18 | name: str
19 | description: str
20 | args_schema: ClassVar[Type[BaseModel]]
21 |
22 | def __init__(self, repo_url: str, uuid: Optional[str] = None):
23 | super().__init__()
24 | logger.info(f"Init base repo tool {self.name} with uuid: {uuid} ...")
25 |
26 | self.repo_url = repo_url
27 | self.repo_id = get_repo_id(repo_url)
28 | self.uuid = uuid
29 | self.repo_path = get_repo_path(repo_url)
30 |
31 | clone_repo_if_not_exists(repo_url, str(self.repo_path))
32 |
33 | def run_sync(self, *args, **kwargs):
34 | """Synchronous wrapper for async _run method."""
35 | try:
36 | loop = asyncio.get_running_loop()
37 | # If we're already in a running loop, we need to use a different approach
38 | import concurrent.futures
39 | import threading
40 |
41 | def run_in_thread():
42 | new_loop = asyncio.new_event_loop()
43 | asyncio.set_event_loop(new_loop)
44 | try:
45 | return new_loop.run_until_complete(self.run(*args, **kwargs))
46 | finally:
47 | new_loop.close()
48 |
49 | with concurrent.futures.ThreadPoolExecutor() as executor:
50 | future = executor.submit(run_in_thread)
51 | return future.result()
52 |
53 | except RuntimeError:
54 | # No event loop is running, we can create one
55 | loop = asyncio.new_event_loop()
56 | asyncio.set_event_loop(loop)
57 | try:
58 | return loop.run_until_complete(self.run(*args, **kwargs))
59 | finally:
60 | loop.close()
61 |
62 |     async def _run_repr(self, *args: Any, **kwargs: Any):
63 | result = await self._run(*args, **kwargs)
64 | if not result:
65 | return "No result Found."
66 |         if isinstance(result, str):
67 | return result
68 | return type(result[0]).repr(result)
69 |
70 | @abstractmethod
71 | async def _run(
72 | self,
73 | *args: Any,
74 | **kwargs: Any,
75 | ) -> Any:
76 | """Async implementation to be overridden by subclasses."""
77 | raise NotImplementedError("Subclasses must implement _run method")
78 |
79 | async def run(
80 | self,
81 | *args: Any,
82 | **kwargs: Any,
83 | ) -> Any:
84 | repr_output = kwargs.pop("repr_output", True)
85 | if repr_output:
86 | return await self._run_repr(*args, **kwargs)
87 | else:
88 | return await self._run(*args, **kwargs)
89 |
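90 | 
91 | if __name__ == "__main__":
92 |     # Illustrative sketch only: `EchoTool` is a hypothetical subclass, not part of
93 |     # the library. It shows how `name`, `description`, `args_schema`, and `_run`
94 |     # fit together, with `run_sync()` handling the event-loop plumbing above.
95 |     from pydantic import Field
96 | 
97 |     class EchoArgs(BaseModel):
98 |         rel_path: str = Field(description="Path relative to the repository root")
99 | 
100 |     class EchoTool(BaseTool):
101 |         name = "echo"
102 |         description = "Returns the absolute path of a file inside the repository."
103 |         args_schema = EchoArgs
104 | 
105 |         async def _run(self, rel_path: str) -> str:
106 |             return str(self.repo_path / rel_path)
107 | 
108 |     tool = EchoTool("https://github.com/ollama/ollama")  # clones the repo on first use
109 |     print(tool.run_sync(rel_path="README.md"))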
--------------------------------------------------------------------------------
/scripts/example.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from coderetrx.retrieval import coderetrx_filter, CodebaseFactory, TopicExtractor, SmartCodebaseSettings
3 | from coderetrx.retrieval.code_recall import CodeRecallSettings
4 | from coderetrx.utils.git import clone_repo_if_not_exists, get_repo_id
5 | from coderetrx.utils.path import get_data_dir
6 |
7 | async def main():
8 | # Set up the repository URL and path
9 | repo_url = "https://github.com/ollama/ollama.git"
10 | repo_path = get_data_dir() / "repos" / get_repo_id(repo_url)
11 |
12 | # Clone the repository if it does not exist
13 | clone_repo_if_not_exists(repo_url, str(repo_path))
14 |
15 | # Create codebase settings with symbol_codeline_embedding enabled
16 | codebase_settings = SmartCodebaseSettings()
17 | codebase_settings.symbol_codeline_embedding = True
18 |
19 | # Create a codebase instance
20 | codebase = CodebaseFactory.new(get_repo_id(repo_url), repo_path, settings=codebase_settings)
21 |
22 | # Create a topic extractor instance
23 | topic_extractor = TopicExtractor()
24 |
25 | # Initialize code recall settings
26 | settings = CodeRecallSettings()
27 |
28 | # Set the target_type and coarse recall strategy
29 | result, llm_output = await coderetrx_filter(
30 | codebase=codebase,
31 | subdirs_or_files=["/"],
32 | prompt="The code snippet contains a function call that dynamically executes code or system commands. Examples include Python's `eval()`, `exec()`, or functions like `os.system()`, `subprocess.run()` (especially with `shell=True`), `subprocess.call()` (with `shell=True`), or `popen()`. The critical feature is that the string representing the code or command to be executed is not a hardcoded literal; instead, it's derived from a variable, function argument, string concatenation/formatting, or an external source such as user input, network request, or LLM output.",
33 | target_type="symbol_content",
34 | coarse_recall_strategy="line_per_symbol",
35 | topic_extractor=topic_extractor,
36 | settings=settings
37 | )
38 |
39 | '''
40 | result, llm_output = await coderetrx_filter(
41 | codebase=codebase,
42 | subdirs_or_files=["/"],
43 | prompt="The code snippet contains a function call that dynamically executes code or system commands. Examples include Python's `eval()`, `exec()`, or functions like `os.system()`, `subprocess.run()` (especially with `shell=True`), `subprocess.call()` (with `shell=True`), or `popen()`. The critical feature is that the string representing the code or command to be executed is not a hardcoded literal; instead, it's derived from a variable, function argument, string concatenation/formatting, or an external source such as user input, network request, or LLM output.",
44 | target_type="symbol_content",
45 | coarse_recall_strategy="line_per_symbol",
46 | topic_extractor=topic_extractor,
47 | settings=settings
48 | )
49 | '''
50 |
51 |     print(f"Found {len(result)} results; the first {min(5, len(result))} are:")
52 | for i, location in enumerate(result[:5], 1):
53 | print(f" {i}. {location}")
54 |
55 | if __name__ == "__main__":
56 | asyncio.run(main())
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/javascript.scm:
--------------------------------------------------------------------------------
1 | (
2 | (comment)* @doc
3 | .
4 | (method_definition
5 | name: (property_identifier) @name.definition.method) @definition.method
6 | (#not-eq? @name.definition.method "constructor")
7 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$")
8 | (#select-adjacent! @doc @definition.method)
9 | )
10 |
11 | (
12 | (comment)* @doc
13 | .
14 | [
15 | (class
16 | name: (_) @name.definition.class)
17 | (class_declaration
18 | name: (_) @name.definition.class)
19 | ] @definition.class
20 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$")
21 | (#select-adjacent! @doc @definition.class)
22 | )
23 |
24 | (
25 | (comment)* @doc
26 | .
27 | [
28 | (function_expression
29 | name: (identifier) @name.definition.function)
30 | (function_declaration
31 | name: (identifier) @name.definition.function)
32 | (generator_function
33 | name: (identifier) @name.definition.function)
34 | (generator_function_declaration
35 | name: (identifier) @name.definition.function)
36 | ] @definition.function
37 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$")
38 | (#select-adjacent! @doc @definition.function)
39 | )
40 |
41 | (
42 | (comment)* @doc
43 | .
44 | (lexical_declaration
45 | (variable_declarator
46 | name: (identifier) @name.definition.function
47 | value: [(arrow_function) (function_expression)]) @definition.function)
48 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$")
49 | (#select-adjacent! @doc @definition.function)
50 | )
51 |
52 | (
53 | (comment)* @doc
54 | .
55 | (variable_declaration
56 | (variable_declarator
57 | name: (identifier) @name.definition.function
58 | value: [(arrow_function) (function_expression)]) @definition.function)
59 | (#strip! @doc "^[\\s\\*/]+|^[\\s\\*/]$")
60 | (#select-adjacent! @doc @definition.function)
61 | )
62 |
63 | (assignment_expression
64 | left: [
65 | (identifier) @name.definition.function
66 | (member_expression
67 | property: (property_identifier) @name.definition.function)
68 | ]
69 | right: [(arrow_function) (function_expression)]
70 | ) @definition.function
71 |
72 | (pair
73 | key: (property_identifier) @name.definition.function
74 | value: [(arrow_function) (function_expression)]) @definition.function
75 |
76 | (
77 | (call_expression
78 | function: (identifier) @name.reference.call) @reference.call
79 | (#not-match? @name.reference.call "^(require)$")
80 | )
81 |
82 | (call_expression
83 | function: (member_expression
84 | property: (property_identifier) @name.reference.call)
85 | arguments: (_) @reference.call)
86 |
87 | (new_expression
88 | constructor: (_) @name.reference.class) @reference.class
89 |
90 | (import_statement) @import
91 |
92 | ; Variable definitions - const/let/var declarations (non-function values)
93 | ; Note: We capture all variable declarations, function assignments are already captured above
94 | (lexical_declaration
95 | (variable_declarator
96 | name: (identifier) @name.definition.variable)) @definition.variable
97 |
98 | (variable_declaration
99 | (variable_declarator
100 | name: (identifier) @name.definition.variable)) @definition.variable
101 |
--------------------------------------------------------------------------------
/test/tools/test_tools.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import pytest
3 | from coderetrx.tools.get_references import GetReferenceTool
4 | from coderetrx.tools.view_file import ViewFileTool
5 | from coderetrx.tools.find_file_by_name import FindFileByNameTool
6 | from coderetrx.tools.keyword_search import KeywordSearchTool
7 | from coderetrx.tools.list_dir import ListDirTool
8 | import logging
9 | logger = logging.getLogger(__name__)
10 | logging.basicConfig(level=logging.INFO)
11 | TEST_REPO = "https://github.com/apache/flink.git"
12 |
13 | class TestGetReferenceTool:
14 | def test(self):
15 | """Test finding references to a symbol"""
16 | logger.info("Testing GetReferenceTool...")
17 | tool = GetReferenceTool(TEST_REPO)
18 | result = asyncio.run(tool._run(symbol_name="upload"))
19 | logger.info(f"GetReferenceTool result: {result}")
20 | assert isinstance(result, (list, dict)), "Result should be a list or dict"
21 |
22 |
23 | class TestViewFileTool:
24 | def test(self):
25 | """Test viewing file content"""
26 | logger.info("Testing ViewFileTool...")
27 | tool = ViewFileTool(TEST_REPO)
28 | result = asyncio.run(tool._run(file_path="./README.md", start_line=0, end_line=10))
29 | logger.info(f"ViewFileTool result: {result}")
30 | assert isinstance(result, str), "Result should be a string (file content)"
31 | assert len(result) > 0, "File content should not be empty"
32 |
33 |
34 | class TestFindFileByNameTool:
35 | def test(self):
36 | """Test finding files by name pattern"""
37 | logger.info("Testing FindFileByNameTool...")
38 | tool = FindFileByNameTool(TEST_REPO)
39 | results = asyncio.run(tool._run(dir_path="/", pattern="*.md"))
40 | logger.info(f"FindFileByNameTool result: {results}")
41 | assert isinstance(results, list), "Result should be a list of file paths"
42 |
43 |
44 | class TestKeywordSearchTool:
45 | def test(self):
46 | """Test keyword search functionality"""
47 | logger.info("Testing KeywordSearchTool...")
48 | tool = KeywordSearchTool(TEST_REPO)
49 | result = asyncio.run(
50 | tool._run(
51 | query="README",
52 | dir_path="/",
53 | case_insensitive=False,
54 | include_content=False,
55 | )
56 | )
57 | logger.info(f"KeywordSearchTool result: {result}")
58 | assert isinstance(result, list), "Result should be a list of matches"
59 |
60 |
61 | class TestListDirTool:
62 | def test(self):
63 | """Test listing directory contents"""
64 | logger.info("Testing ListDirTool...")
65 | tool = ListDirTool(TEST_REPO)
66 | result = asyncio.run(tool._run(directory_path="."))
67 | logger.info(f"ListDirTool result: {result}")
68 | assert isinstance(result, list), "Result should be a list of directory entries"
69 |
70 | def test_all_tools():
71 | """Test all tools"""
72 | tool_testers = [
73 | TestGetReferenceTool(),
74 | TestViewFileTool(),
75 | TestFindFileByNameTool(),
76 | TestKeywordSearchTool(),
77 | TestListDirTool(),
78 | ]
79 | for tester in tool_testers:
80 | tester.test()
81 |
82 | if __name__ == "__main__":
83 | test_all_tools()
--------------------------------------------------------------------------------
/coderetrx/static/codebase/queries/treesitter/tags/ocaml.scm:
--------------------------------------------------------------------------------
1 | ; Modules
2 | ;--------
3 |
4 | (
5 | (comment)? @doc .
6 | (module_definition (module_binding (module_name) @name.definition.module) @definition.module)
7 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$")
8 | )
9 |
10 | (module_path (module_name) @name.reference.module) @reference.module
11 |
12 | ; Module types
13 | ;--------------
14 |
15 | (
16 | (comment)? @doc .
17 | (module_type_definition (module_type_name) @name.definition.interface) @definition.interface
18 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$")
19 | )
20 |
21 | (module_type_path (module_type_name) @name.reference.implementation) @reference.implementation
22 |
23 | ; Functions
24 | ;----------
25 |
26 | (
27 | (comment)? @doc .
28 | (value_definition
29 | [
30 | (let_binding
31 | pattern: (value_name) @name.definition.function
32 | (parameter))
33 | (let_binding
34 | pattern: (value_name) @name.definition.function
35 | body: [(fun_expression) (function_expression)])
36 | ] @definition.function
37 | )
38 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$")
39 | )
40 |
41 | (
42 | (comment)? @doc .
43 | (external (value_name) @name.definition.function) @definition.function
44 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$")
45 | )
46 |
47 | (application_expression
48 | function: (value_path (value_name) @name.reference.call)) @reference.call
49 |
50 | (infix_expression
51 | left: (value_path (value_name) @name.reference.call)
52 | operator: (concat_operator) @reference.call
53 | (#eq? @reference.call "@@"))
54 |
55 | (infix_expression
56 | operator: (rel_operator) @reference.call
57 | right: (value_path (value_name) @name.reference.call)
58 | (#eq? @reference.call "|>"))
59 |
60 | ; Operator
61 | ;---------
62 |
63 | (
64 | (comment)? @doc .
65 | (value_definition
66 | (let_binding
67 | pattern: (parenthesized_operator (_) @name.definition.function)) @definition.function)
68 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$")
69 | )
70 |
71 | [
72 | (prefix_operator)
73 | (sign_operator)
74 | (pow_operator)
75 | (mult_operator)
76 | (add_operator)
77 | (concat_operator)
78 | (rel_operator)
79 | (and_operator)
80 | (or_operator)
81 | (assign_operator)
82 | (hash_operator)
83 | (indexing_operator)
84 | (let_operator)
85 | (let_and_operator)
86 | (match_operator)
87 | ] @name.reference.call @reference.call
88 |
89 | ; Classes
90 | ;--------
91 |
92 | (
93 | (comment)? @doc .
94 | [
95 | (class_definition (class_binding (class_name) @name.definition.class) @definition.class)
96 | (class_type_definition (class_type_binding (class_type_name) @name.definition.class) @definition.class)
97 | ]
98 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$")
99 | )
100 |
101 | [
102 | (class_path (class_name) @name.reference.class)
103 | (class_type_path (class_type_name) @name.reference.class)
104 | ] @reference.class
105 |
106 | ; Methods
107 | ;--------
108 |
109 | (
110 | (comment)? @doc .
111 | (method_definition (method_name) @name.definition.method) @definition.method
112 | (#strip! @doc "^\\(\\*\\*?\\s*|\\s\\*\\)$")
113 | )
114 |
115 | (method_invocation (method_name) @name.reference.call) @reference.call
116 |
--------------------------------------------------------------------------------
/coderetrx/utils/concurrency.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Callable, Any, TypeVar, Coroutine, Optional
3 | from concurrent.futures import ThreadPoolExecutor
4 | import asyncio
5 | import threading
6 |
7 | T = TypeVar("T")
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | async def abatch_func_call(
13 | max_concurrency: int, func: Callable[..., Any], kwargs_list: list[dict]
14 | ) -> list[Any]:
15 | """Execute a batch of function calls with controlled concurrency.
16 |
17 | Handles both async and sync functions automatically:
18 | - Async functions run directly in the event loop
19 | - Sync functions run in thread pool executor to avoid blocking
20 |
21 | Features:
22 | - Semaphore-based concurrency control
23 | - Automatic error logging with function context
24 | - Preserves original exception stack traces
25 |
26 | Args:
27 | max_concurrency: Maximum parallel executions allowed
28 | func: Callable to execute (async or sync)
29 | kwargs_list: List of keyword arguments dictionaries for each call
30 |
31 | Returns:
32 | list: Results in the same order as kwargs_list, exceptions will propagate
33 |
34 | Raises:
35 | Exception: Re-raises the first encountered exception from any task
36 | """
37 | semaphore = asyncio.Semaphore(max_concurrency)
38 |
39 | async def a_func_call(func: Callable, kwargs: dict) -> Any:
40 | """Execute a single function call with concurrency control.
41 |
42 | Args:
43 | func: Target function to execute
44 | kwargs: Keyword arguments for this call
45 |
46 | Returns:
47 | Any: Result of the function call
48 |
49 | Raises:
50 | Exception: Propagates any exceptions from the function call
51 | """
52 | async with semaphore: # Acquire semaphore slot
53 | try:
54 | if asyncio.iscoroutinefunction(func):
55 | # Directly await async functions
56 | result = await func(**kwargs)
57 | else:
58 | # Run sync functions in thread pool to prevent blocking
59 | result = await asyncio.to_thread(func, **kwargs)
60 | return result
61 | except Exception as e:
62 | logger.error(f"Error in {func.__name__} with args {kwargs}: {e}")
63 | raise # Re-raise to maintain stack trace
64 |
65 | # Create and schedule all tasks
66 | tasks = [a_func_call(func, kwargs) for kwargs in kwargs_list]
67 |
68 | # Execute all tasks concurrently with controlled concurrency
69 | results = await asyncio.gather(*tasks, return_exceptions=False)
70 |
71 | return results
72 |
73 |
74 | def run_coroutine_sync(
75 | coroutine: Coroutine[Any, Any, T], timeout: Optional[float] = None
76 | ) -> T:
77 | def run_in_new_loop():
78 | new_loop = asyncio.new_event_loop()
79 | asyncio.set_event_loop(new_loop)
80 | try:
81 | return new_loop.run_until_complete(coroutine)
82 | finally:
83 | new_loop.close()
84 |
85 | try:
86 | loop = asyncio.get_running_loop()
87 | except RuntimeError:
88 | return asyncio.run(coroutine)
89 |
90 | if threading.current_thread() is threading.main_thread():
91 | if not loop.is_running():
92 | return loop.run_until_complete(coroutine)
93 | else:
94 | with ThreadPoolExecutor() as pool:
95 | future = pool.submit(run_in_new_loop)
96 | return future.result(timeout=timeout)
97 | else:
98 | return asyncio.run_coroutine_threadsafe(coroutine, loop).result()
99 |
--------------------------------------------------------------------------------
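A minimal usage sketch for the two helpers above. `double` is a hypothetical worker used only for illustration; the import path follows the file location shown here.

import asyncio
from coderetrx.utils.concurrency import abatch_func_call, run_coroutine_sync

async def double(x: int) -> int:
    await asyncio.sleep(0)  # stand-in for real I/O
    return x * 2

async def main() -> list[int]:
    # At most 8 of the 100 calls run at the same time.
    kwargs_list = [{"x": i} for i in range(100)]
    return await abatch_func_call(max_concurrency=8, func=double, kwargs_list=kwargs_list)

# From synchronous code (e.g. a CLI entry point), run_coroutine_sync picks the
# right execution path whether or not an event loop is already running.
results = run_coroutine_sync(main())
assert results[3] == 6
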
/coderetrx/retrieval/strategy/factory.py:
--------------------------------------------------------------------------------
1 | from coderetrx.retrieval.strategy.base import RecallStrategy
2 | from typing import Optional, Union, List
3 |
4 | from coderetrx.retrieval.smart_codebase import LLMCallMode
5 |
6 | from coderetrx.retrieval.topic_extractor import TopicExtractor
7 | from coderetrx.retrieval.strategy.base import RecallStrategyExecutor
8 | from coderetrx.retrieval.strategy.filter_filename_by_llm import (
9 | FilterFilenameByLLMStrategy,
10 | )
11 | from coderetrx.retrieval.strategy.filter_keyword_by_vector import (
12 | FilterKeywordByVectorStrategy,
13 | )
14 | from coderetrx.retrieval.strategy.filter_symbol_content_by_vector import (
15 | FilterSymbolContentByVectorStrategy,
16 | )
17 | from coderetrx.retrieval.strategy.filter_symbol_name_by_llm import (
18 | FilterSymbolNameByLLMStrategy,
19 | )
20 | from coderetrx.retrieval.strategy.filter_dependency_by_llm import (
21 | FilterDependencyByLLMStrategy,
22 | )
23 | from coderetrx.retrieval.strategy.filter_keyword_by_vector_and_llm import (
24 | FilterKeywordByVectorAndLLMStrategy,
25 | )
26 | from coderetrx.retrieval.strategy.filter_symbol_content_by_vector_and_llm import (
27 | FilterSymbolContentByVectorAndLLMStrategy,
28 | )
29 | from coderetrx.retrieval.strategy.adaptive_filter_symbol_content_by_vector_and_llm import (
30 | AdaptiveFilterSymbolContentByVectorAndLLMStrategy,
31 | )
32 | from coderetrx.retrieval.strategy.adaptive_filter_keyword_by_vector_and_llm import (
33 | AdaptiveFilterKeywordByVectorAndLLMStrategy,
34 | )
35 | from coderetrx.retrieval.strategy.filter_line_per_symbol_by_vector_and_llm import (
36 | FilterLinePerSymbolByVectorAndLLMStrategy,
37 | )
38 | from coderetrx.retrieval.strategy.filter_line_per_file_by_vector_and_llm import (
39 | FilterLinePerFileByVectorAndLLMStrategy,
40 | )
41 |
42 |
43 | class StrategyFactory:
44 | """Factory for creating strategy executors."""
45 |
46 | def __init__(
47 | self,
48 | topic_extractor: Optional[TopicExtractor] = None,
49 | llm_call_mode: LLMCallMode = "traditional",
50 | ):
51 | self.topic_extractor = topic_extractor
52 | self.llm_call_mode = llm_call_mode
53 |
54 | def create_strategy(self, strategy: RecallStrategy) -> RecallStrategyExecutor:
55 | """Create a strategy executor based on the strategy enum."""
56 | strategy_map = {
57 | RecallStrategy.FILTER_FILENAME_BY_LLM: FilterFilenameByLLMStrategy,
58 | RecallStrategy.FILTER_KEYWORD_BY_VECTOR: FilterKeywordByVectorStrategy,
59 | RecallStrategy.FILTER_SYMBOL_CONTENT_BY_VECTOR: FilterSymbolContentByVectorStrategy,
60 | RecallStrategy.FILTER_SYMBOL_NAME_BY_LLM: FilterSymbolNameByLLMStrategy,
61 | RecallStrategy.FILTER_DEPENDENCY_BY_LLM: FilterDependencyByLLMStrategy,
62 | RecallStrategy.FILTER_KEYWORD_BY_VECTOR_AND_LLM: FilterKeywordByVectorAndLLMStrategy,
63 | RecallStrategy.FILTER_SYMBOL_CONTENT_BY_VECTOR_AND_LLM: FilterSymbolContentByVectorAndLLMStrategy,
64 | RecallStrategy.ADAPTIVE_FILTER_KEYWORD_BY_VECTOR_AND_LLM: AdaptiveFilterKeywordByVectorAndLLMStrategy,
65 | RecallStrategy.ADAPTIVE_FILTER_SYMBOL_CONTENT_BY_VECTOR_AND_LLM: AdaptiveFilterSymbolContentByVectorAndLLMStrategy,
66 | RecallStrategy.FILTER_LINE_PER_SYMBOL_BY_VECTOR_AND_LLM: FilterLinePerSymbolByVectorAndLLMStrategy,
67 | RecallStrategy.FILTER_LINE_PER_FILE_BY_VECTOR_AND_LLM: FilterLinePerFileByVectorAndLLMStrategy,
68 | }
69 |
70 | if strategy not in strategy_map:
71 | raise ValueError(f"Unknown strategy: {strategy}")
72 |
73 | return strategy_map[strategy](
74 | topic_extractor=self.topic_extractor, llm_call_mode=self.llm_call_mode
75 | )
76 |
--------------------------------------------------------------------------------
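A sketch of how the factory above is meant to be used, assuming `RecallStrategy` is the enum exposing the members listed in `strategy_map`; the codebase object and the criteria string are placeholders, not part of the source.

import asyncio
from coderetrx.retrieval.strategy.base import RecallStrategy
from coderetrx.retrieval.strategy.factory import StrategyFactory

factory = StrategyFactory(topic_extractor=None, llm_call_mode="traditional")
executor = factory.create_strategy(RecallStrategy.FILTER_KEYWORD_BY_VECTOR_AND_LLM)

# `codebase` would be a SmartCodebase built elsewhere; shown here only as a placeholder.
# result = asyncio.run(
#     executor.execute(codebase, "code that handles user authentication", [])
# )
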
/coderetrx/retrieval/strategy/adaptive_filter_keyword_by_vector_and_llm.py:
--------------------------------------------------------------------------------
1 | """
2 | Strategy for adaptive filtering of keywords using vector similarity search followed by LLM refinement.
3 | """
4 |
5 | from typing import List, Union, Optional, override, Any, Literal
6 | from .base import AdaptiveFilterByVectorAndLLMStrategy, StrategyExecuteResult
7 | from ..smart_codebase import (
8 | SmartCodebase as Codebase,
9 | LLMMapFilterTargetType,
10 | SimilaritySearchTargetType,
11 | )
12 | from coderetrx.static import Keyword, Symbol, File
13 |
14 |
15 | class AdaptiveFilterKeywordByVectorAndLLMStrategy(AdaptiveFilterByVectorAndLLMStrategy):
16 | """Strategy to filter keywords using adaptive vector similarity search followed by LLM refinement."""
17 |
18 | name: str = "ADAPTIVE_FILTER_KEYWORD_BY_VECTOR_AND_LLM"
19 |
20 | @override
21 | def get_strategy_name(self) -> str:
22 | return self.name
23 |
24 | @override
25 | def get_target_types_for_vector(self) -> List[SimilaritySearchTargetType]:
26 | return ["keyword"]
27 |
28 | @override
29 | def get_target_type_for_llm(self) -> LLMMapFilterTargetType:
30 | return "keyword"
31 |
32 | @override
33 | def get_collection_size(self, codebase: Codebase) -> int:
34 | return len(codebase.keywords)
35 |
36 | @override
37 | def filter_elements(
38 | self,
39 | codebase: Codebase,
40 | elements: List[Any],
41 | target_type: LLMMapFilterTargetType = "symbol_content",
42 | subdirs_or_files: List[str] = [],
43 | ) -> List[Union[Keyword, Symbol, File]]:
44 | keyword_elements: List[Union[Keyword, Symbol, File]] = []
45 | for element in elements:
46 | if isinstance(element, Keyword):
47 | if subdirs_or_files and codebase:
48 | for ref_file in element.referenced_by:
49 | if any(
50 | str(ref_file.path).startswith(subdir)
51 | for subdir in subdirs_or_files
52 | ):
53 | keyword_elements.append(element)
54 | break
55 | else:
56 | keyword_elements.append(element)
57 | return keyword_elements
58 |
59 | @override
60 | def collect_file_paths(
61 | self,
62 | filtered_elements: List[Any],
63 | codebase: Codebase,
64 | subdirs_or_files: List[str],
65 | ) -> List[str]:
66 | referenced_paths = set()
67 | for keyword in filtered_elements:
68 | if isinstance(keyword, Keyword) and keyword.referenced_by:
69 | for ref_file in keyword.referenced_by:
70 | if str(ref_file.path).startswith(tuple(subdirs_or_files)):
71 | referenced_paths.add(str(ref_file.path))
72 | return list(referenced_paths)
73 |
74 | @override
75 | async def execute(
76 | self,
77 | codebase: Codebase,
78 | prompt: str,
79 | subdirs_or_files: List[str],
80 | target_type: LLMMapFilterTargetType = "symbol_content",
81 | ) -> StrategyExecuteResult:
82 | prompt = f"""
83 | A code chunk containing the specified keywords is highly likely to meet the following criteria:
84 |
85 | {prompt}
86 |
87 |
88 | The objective of this requirement is to preliminarily filter files that are likely to meet specific content criteria based on the keywords they contain.
89 | Files with matching keywords will proceed to deeper analysis in the content filter (content_criteria) at a later stage (not in this run).
90 |
91 | """
92 | return await super().execute(codebase, prompt, subdirs_or_files, target_type)
93 |
--------------------------------------------------------------------------------
/coderetrx/retrieval/strategy/filter_keyword_by_vector_and_llm.py:
--------------------------------------------------------------------------------
1 | """
2 | Strategy for filtering keywords using vector similarity search followed by LLM refinement.
3 | """
4 |
5 | from typing import Any, List, Union, Optional, override
6 |
7 | from coderetrx.retrieval.strategy.base import StrategyExecuteResult
8 | from coderetrx.retrieval.strategy.base import FilterByVectorAndLLMStrategy
9 | from coderetrx.retrieval.smart_codebase import (
10 | SmartCodebase as Codebase,
11 | LLMMapFilterTargetType,
12 | SimilaritySearchTargetType,
13 | )
14 | from coderetrx.static import Keyword, Symbol, File
15 |
16 |
17 | class FilterKeywordByVectorAndLLMStrategy(FilterByVectorAndLLMStrategy):
18 | """Strategy to filter keywords using vector similarity search followed by LLM refinement."""
19 |
20 | name: str = "FILTER_KEYWORD_BY_VECTOR_AND_LLM"
21 |
22 | @override
23 | def get_strategy_name(self) -> str:
24 | return self.name
25 |
26 | @override
27 | def get_target_types_for_vector(self) -> List[SimilaritySearchTargetType]:
28 | return ["keyword"]
29 |
30 | @override
31 | def get_target_type_for_llm(self) -> LLMMapFilterTargetType:
32 | return "keyword"
33 |
34 | @override
35 | def get_collection_size(self, codebase: Codebase) -> int:
36 | return len(codebase.keywords)
37 |
38 | @override
39 | def filter_elements(
40 | self,
41 | codebase: Codebase,
42 | elements: List[Any],
43 | target_type: LLMMapFilterTargetType = "symbol_content",
44 | subdirs_or_files: List[str] = [],
45 | ) -> List[Union[Keyword, Symbol, File]]:
46 | keyword_elements: List[Union[Keyword, Symbol, File]] = []
47 | for element in elements:
48 | if isinstance(element, Keyword):
49 | # If subdirs_or_files is provided and codebase is available, filter by subdirs
50 | if subdirs_or_files and codebase:
51 | for ref_file in element.referenced_by:
52 | if any(
53 | str(ref_file.path).startswith(subdir)
54 | for subdir in subdirs_or_files
55 | ):
56 | keyword_elements.append(element)
57 | break
58 | else:
59 | keyword_elements.append(element)
60 | return keyword_elements
61 |
62 | @override
63 | def collect_file_paths(
64 | self,
65 | filtered_elements: List[Any],
66 | codebase: Codebase,
67 | subdirs_or_files: List[str],
68 | ) -> List[str]:
69 | referenced_paths = set()
70 | for keyword in filtered_elements:
71 | if isinstance(keyword, Keyword) and keyword.referenced_by:
72 | for ref_file in keyword.referenced_by:
73 | if str(ref_file.path).startswith(tuple(subdirs_or_files)):
74 | referenced_paths.add(str(ref_file.path))
75 | return list(referenced_paths)
76 |
77 | @override
78 | async def execute(
79 | self,
80 | codebase: Codebase,
81 | prompt: str,
82 | subdirs_or_files: List[str],
83 | target_type: LLMMapFilterTargetType = "symbol_content",
84 | ) -> StrategyExecuteResult:
85 | prompt = f"""
86 | A code chunk containing the specified keywords is highly likely to meet the following criteria:
87 |
88 | {prompt}
89 |
90 |
91 | The objective of this requirement is to preliminarily filter files that are likely to meet specific content criteria based on the keywords they contain.
92 | Files with matching keywords will proceed to deeper analysis in the content filter (content_criteria) at a later stage (not in this run).
93 |
94 | """
95 | return await super().execute(codebase, prompt, subdirs_or_files, target_type)
96 |
--------------------------------------------------------------------------------
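A toy illustration of the subdirectory scoping used by `filter_elements` and `collect_file_paths` above: a keyword survives when at least one file that references it lives under a requested prefix. Plain dataclasses stand in for coderetrx's `Keyword`/`File` types.

from dataclasses import dataclass, field

@dataclass
class FakeFile:
    path: str

@dataclass
class FakeKeyword:
    content: str
    referenced_by: list = field(default_factory=list)

subdirs_or_files = ["src/crypto"]
kw = FakeKeyword("aes_encrypt", [FakeFile("src/crypto/aes.py"), FakeFile("docs/notes.md")])

# Same check as the strategy: keep the keyword if any referencing file is in scope.
keep = any(
    str(ref.path).startswith(subdir)
    for ref in kw.referenced_by
    for subdir in subdirs_or_files
)
assert keep  # referenced from src/crypto/aes.py, so it passes the scope filter
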
/coderetrx/tools/view_file.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | from typing import Optional, Type, ClassVar
4 | from pathlib import Path
5 | from pydantic import BaseModel, Field
6 | from coderetrx.tools.base import BaseTool
7 | import aiofiles
8 | from coderetrx.utils.path import safe_join
9 |
10 |
11 | class ViewFileArgs(BaseModel):
12 | """Parameters for file viewing operation"""
13 |
14 | file_path: str = Field(
15 | ...,
16 | description="Absolute path of file to view",
17 | )
18 | start_line: Optional[int] = Field(
19 |         None, description="Starting line number. Optional, defaults to 0.", ge=0
20 | )
21 | end_line: Optional[int] = Field(
22 | None,
23 |         description="Ending line number. Optional, defaults to the last line.",
24 | ge=0,
25 | )
26 |
27 |
28 | class ViewFileTool(BaseTool):
29 | name = "view_file"
30 | description = (
31 |         "View the contents of a file. The lines of the file are 0-indexed, and the output of this tool call will be the file contents from StartLine to EndLine. The line range should be less than or equal to 1000.\n\n"
32 | "When using this tool to gather information, it's your responsibility to ensure you have the COMPLETE context. Specifically, each time you call this command you should:\n"
33 | "1) Assess if the file contents you viewed are sufficient to proceed with your task.\n"
34 | "2) Take note of where there are lines not shown. These are represented by <... XX more lines from [code item] not shown ...> in the tool response.\n"
35 | "3) If the file contents you have viewed are insufficient, and you suspect they may be in lines not shown, proactively call the tool again to view those lines.\n"
36 | "4) When in doubt, call this tool again to gather more information. Remember that partial file views may miss critical dependencies, imports, or functionality."
37 | )
38 | args_schema: ClassVar[Type[ViewFileArgs]] = ViewFileArgs
39 |
40 | def forward(self, file_path: str, start_line: int, end_line: int) -> str:
41 | """Synchronous wrapper for async _run method."""
42 | return self.run_sync(
43 | file_path=file_path, start_line=start_line, end_line=end_line
44 | )
45 |
46 | async def _run(
47 | self,
48 | file_path: str,
49 | start_line: Optional[int] = None,
50 | end_line: Optional[int] = None,
51 | ) -> str:
52 | """View file content with optional line range."""
53 | full_path = safe_join(self.repo_path, file_path.lstrip("/"))
54 |
55 | if not full_path.exists():
56 | return "File Not Exists.\n"
57 |
58 | if full_path.is_dir():
59 | return "Path is a directory, not a file.\n"
60 |
61 | try:
62 | async with aiofiles.open(full_path, "r", encoding="utf-8") as f:
63 | content = await f.read()
64 | except UnicodeDecodeError:
65 | return "File is not a text file or uses unsupported encoding.\n"
66 |
67 | lines = content.split("\n")
68 | total_lines = len(lines)
69 |
70 | # Handle line range
71 | if start_line is not None or end_line is not None:
72 | start = start_line if start_line is not None else 0
73 | end = end_line if end_line is not None else total_lines
74 |
75 | # Validate line numbers
76 | if start < 0 or end > total_lines or start > end:
77 | return f"Invalid line range (0-{total_lines}).\n"
78 |
79 | threshold_line = 1000
80 | if end - start >= threshold_line:
81 |                 return f"File is too large ({total_lines} lines); please specify a line range of at most {threshold_line} lines, or search for a keyword in the file.\n"
82 |
83 | selected_lines = lines[start:end]
84 | result = "\n".join(selected_lines)
85 | else:
86 | result = content
87 |
88 | # Add line count info unless we're showing the whole file
89 | if start_line is not None or end_line is not None:
90 | result += f"\n\n(This file has total {total_lines} lines.)"
91 |
92 | return result
93 |
--------------------------------------------------------------------------------
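A small sketch of the argument schema above. Line numbers are 0-indexed and the viewed range must stay at or below 1000 lines, as enforced in `_run`; the tool instance itself is bound to a repository path via `BaseTool`, which is not shown here.

from coderetrx.tools.view_file import ViewFileArgs

# Validate a request for the first 40 lines of a file; pydantic enforces ge=0 on both bounds.
args = ViewFileArgs(file_path="coderetrx/utils/concurrency.py", start_line=0, end_line=40)
print(args)
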
/coderetrx/static/codebase/languages.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 | from os import PathLike
3 | from typing import List, Literal, Optional
4 | from pathlib import Path
5 | import fnmatch
6 |
7 |
8 | IDXSupportedLanguage = Literal[
9 | "javascript",
10 | "typescript",
11 | "python",
12 | "rust",
13 | "c",
14 | "cpp",
15 | "csharp",
16 | "go",
17 | "elixir",
18 | "java",
19 | "php",
20 | ]
21 |
22 | IDXSupportedTag = Literal[
23 | "definition.function",
24 | "definition.type",
25 | "definition.method",
26 | "definition.class",
27 | "definition.interface",
28 | "definition.module",
29 | "definition.reexport",
30 | "definition.variable",
31 | "reference.implementation",
32 | "reference.call",
33 | "reference.class",
34 | "import",
35 | ]
36 |
37 | EXTENSION_MAP: dict[str, IDXSupportedLanguage] = {
38 | "js": "javascript",
39 | "ts": "typescript",
40 | "py": "python",
41 | "rs": "rust",
42 | "c": "c",
43 | "cpp": "cpp",
44 | "h": "c",
45 | "hpp": "cpp",
46 | "cs": "csharp",
47 | "go": "go",
48 | "ex": "elixir",
49 | "exs": "elixir",
50 | "java": "java",
51 | "php": "php",
52 | }
53 |
54 | BLOCKED_PATTERNS = ["*.min.js", "*_test.go"]
55 |
56 | DEP_FILES: List[str] = [
57 | # JS / TS
58 | "package.json",
59 | # Python
60 | "requirements.txt",
61 | "setup.py",
62 | "Pipfile",
63 | "pyproject.toml",
64 | # Rust
65 | "Cargo.toml",
66 | # C/CPP
67 | "Makefile",
68 | # Golang
69 | "go.mod",
70 | # Elixir
71 | "mix.exs",
72 |     # Java
73 |     "pom.xml",
74 | "build.gradle",
75 | "build.gradle.kts",
76 | "build.sbt",
77 | "build.gradle",
78 | "build.gradle.kts",
79 | "build.sbt",
80 | ]
81 |
82 | BUILTIN_CRYPTO_LIBS: dict[IDXSupportedLanguage, List[str]] = {
83 | "javascript": ["crypto", "node:crypto", "webcrypto"],
84 | "typescript": ["crypto", "node:crypto", "webcrypto"],
85 | "python": ["hashlib", "hmac", "secrets", "ssl"],
86 | "rust": [],
87 | "c": [],
88 | "cpp": [],
89 | "csharp": ["System.Security.Cryptography"],
90 | "go": ["crypto"],
91 | "elixir": [":crypto"],
92 | "java": ["java.security", "javax.crypto"],
93 | "php": ["openssl", "hash", "sodium"],
94 | }
95 |
96 | FUNCLIKE_TAGS: List[IDXSupportedTag] = [
97 | "definition.function",
98 | "definition.method",
99 | ]
100 |
101 | OBJLIKE_TAGS: List[IDXSupportedTag] = [
102 | "definition.class",
103 | "definition.interface",
104 | "definition.module",
105 | "definition.reexport",
106 | "reference.implementation",
107 | ]
108 |
109 | PRIMARY_TAGS: List[IDXSupportedTag] = [
110 | "definition.class",
111 | "definition.type",
112 | "definition.function",
113 | "definition.interface",
114 | "definition.method",
115 | "definition.module",
116 | # Special case: imports something and introduces a new symbol to the global pool
117 | "definition.reexport",
118 | "reference.implementation",
119 | ]
120 |
121 | REFERENCE_TAGS: List[IDXSupportedTag] = [
122 | "reference.call",
123 | "reference.class",
124 | ]
125 |
126 | IMPORT_TAGS: List[IDXSupportedTag] = [
127 | "import",
128 | ]
129 |
130 | VARIABLE_TAGS: List[IDXSupportedTag] = [
131 | "definition.variable",
132 | ]
133 |
134 |
135 | def get_extension(filepath: PathLike | str) -> str:
136 | return str(filepath).split(".")[-1].lower()
137 |
138 |
139 | def is_blocked_file(filepath: PathLike | str) -> bool:
140 | return any(fnmatch.fnmatch(str(filepath), pattern) for pattern in BLOCKED_PATTERNS)
141 |
142 |
143 | def is_sourcecode(filepath: PathLike | str) -> bool:
144 | if is_blocked_file(filepath):
145 | return False
146 | extension = get_extension(filepath)
147 | return extension in EXTENSION_MAP
148 |
149 |
150 | def is_dependency(filepath: PathLike | str) -> bool:
151 | return str(filepath).split("/")[-1] in DEP_FILES
152 |
153 |
154 | def get_language(filepath: PathLike | str) -> Optional[IDXSupportedLanguage]:
155 | extension = get_extension(filepath)
156 | return EXTENSION_MAP.get(extension)
157 |
--------------------------------------------------------------------------------
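A quick sanity check of the pure helpers above (no I/O involved); the example paths are made up.

from coderetrx.static.codebase.languages import (
    get_language,
    is_blocked_file,
    is_dependency,
    is_sourcecode,
)

assert get_language("src/server/main.go") == "go"
assert is_sourcecode("lib/util.ts")
assert is_blocked_file("dist/app.min.js")         # matches the *.min.js blocked pattern
assert not is_sourcecode("dist/app.min.js")       # blocked files are never treated as source
assert is_dependency("backend/requirements.txt")  # filename is in DEP_FILES
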
/coderetrx/utils/cost_tracking.py:
--------------------------------------------------------------------------------
1 | # Cost tracking for LLM calls; currently only for OpenRouter.
2 |
3 | import httpx
4 | from .logger import JsonLogger, LLMCallLog, ErrLog
5 | from pydantic import BaseModel, Field
6 | from typing import Dict, Optional
7 | from attrs import define
8 | from httpx import AsyncClient
9 | from .logger import read_logs
10 | from os import PathLike
11 |
12 | a_client = AsyncClient()
13 |
14 |
15 | @define
16 | class ModelCost:
17 | prompt: float
18 | completion: float
19 |
20 |
21 | type ModelCosts = Dict[str, ModelCost]
22 |
23 | def get_cost_hook(json_logger: JsonLogger, base_url: str = "https://openrouter.ai/api/v1"):
24 | async def on_response(response: httpx.Response):
25 | if not str(response.request.url).startswith(base_url):
26 | return
27 |
28 | try:
29 | await response.aread()
30 | response_json = response.json()
31 | model = response_json["model"]
32 | usage_data = response_json["usage"]
33 | json_logger.log(
34 | LLMCallLog(
35 | model=model,
36 | completion_id=response_json["id"],
37 | completion_tokens=usage_data["completion_tokens"],
38 | prompt_tokens=usage_data["prompt_tokens"],
39 | total_tokens=usage_data["total_tokens"],
40 | call_url=str(response.request.url),
41 | )
42 | )
43 | except Exception as e:
44 | json_logger.log(
45 | ErrLog(
46 | error_type="LLM_CALL_ERROR",
47 | error=str(e),
48 | )
49 | )
50 | return on_response
51 |
52 | async def load_model_costs() -> ModelCosts:
53 | all_model_costs_rsp = await a_client.get("https://openrouter.ai/api/v1/models")
54 | all_model_costs = all_model_costs_rsp.json()
55 | model_costs_parsed = {}
56 | for model in all_model_costs["data"]:
57 | try:
58 | model_id = model["id"]
59 | model_slug = model["canonical_slug"]
60 | model_pricing = model["pricing"]
61 | model_cost = ModelCost(
62 | prompt=float(model_pricing["prompt"]),
63 | completion=float(model_pricing["completion"]),
64 | )
65 | model_costs_parsed[model_id] = model_cost
66 | model_costs_parsed[model_slug] = model_cost
67 | except Exception as e:
68 |             print(f"Error parsing model {model.get('id', '<unknown>')}: {e}")
69 | print(f"Loaded {len(model_costs_parsed)} models")
70 | return model_costs_parsed
71 |
72 | async def calc_llm_costs(log_path: PathLike | str, model_costs: Optional[ModelCosts] = None):
73 | from pathlib import Path
74 |
75 | # Check if log file exists
76 | if not Path(log_path).exists():
77 | return 0.0
78 |
79 | model_costs_parsed = model_costs or await load_model_costs()
80 | total_cost = 0
81 | for log_item in read_logs(log_path):
82 | if log_item.data.type == "llm_call":
83 | log = log_item.data
84 | if log.model not in model_costs_parsed:
85 | print(f"Model {log.model} not found in model costs")
86 | continue
87 | cost_info = model_costs_parsed[log.model]
88 | total_cost += (
89 | log.prompt_tokens * cost_info.prompt
90 | + log.completion_tokens * cost_info.completion
91 | )
92 | return total_cost
93 |
94 | def calc_input_tokens(log_path: PathLike | str):
95 | from pathlib import Path
96 |
97 | # Check if log file exists
98 | if not Path(log_path).exists():
99 | return 0
100 |
101 | total_input_tokens = 0
102 | for log_item in read_logs(log_path):
103 | if log_item.data.type == "llm_call":
104 | log = log_item.data
105 | total_input_tokens += log.prompt_tokens
106 | return total_input_tokens
107 |
108 | def calc_output_tokens(log_path: PathLike | str):
109 | from pathlib import Path
110 |
111 | # Check if log file exists
112 | if not Path(log_path).exists():
113 | return 0
114 |
115 | total_output_tokens = 0
116 | for log_item in read_logs(log_path):
117 | if log_item.data.type == "llm_call":
118 | log = log_item.data
119 | total_output_tokens += log.completion_tokens
120 | return total_output_tokens
121 |
--------------------------------------------------------------------------------
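A sketch of summing cost and token counts from a log produced by the response hook above. The log path is a placeholder; per-token prices are fetched live from the OpenRouter models endpoint, so this needs network access when the log file exists.

import asyncio
from coderetrx.utils.cost_tracking import (
    calc_input_tokens,
    calc_llm_costs,
    calc_output_tokens,
)

log_path = "logs/llm_calls.jsonl"  # hypothetical location of the JsonLogger output
total_usd = asyncio.run(calc_llm_costs(log_path))
print(
    f"~${total_usd:.4f} spent across "
    f"{calc_input_tokens(log_path)} prompt / {calc_output_tokens(log_path)} completion tokens"
)
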
/coderetrx/retrieval/strategy/filter_symbol_content_by_vector_and_llm.py:
--------------------------------------------------------------------------------
1 | """
2 | Strategy for filtering symbols using vector similarity search followed by LLM refinement.
3 | """
4 |
5 | from collections import defaultdict
6 | from typing import List, Union, Optional, override
7 | from .base import FilterByVectorAndLLMStrategy
8 | from ..smart_codebase import (
9 | SmartCodebase as Codebase,
10 | LLMMapFilterTargetType,
11 | SimilaritySearchTargetType,
12 | )
13 | from coderetrx.static import Keyword, Symbol, File
14 |
15 |
16 | class FilterSymbolContentByVectorAndLLMStrategy(FilterByVectorAndLLMStrategy):
17 | """Strategy to filter symbols using vector similarity search followed by LLM refinement."""
18 |
19 | @override
20 | def get_strategy_name(self) -> str:
21 | return "FILTER_SYMBOL_CONTENT_BY_VECTOR_AND_LLM"
22 |
23 | @override
24 | def get_target_types_for_vector(self) -> List[SimilaritySearchTargetType]:
25 | return ["symbol_content"]
26 |
27 | @override
28 | def get_target_type_for_llm(self) -> LLMMapFilterTargetType:
29 | return "symbol_content"
30 |
31 | @override
32 | def get_collection_size(self, codebase: Codebase) -> int:
33 | return len(codebase.symbols)
34 |
35 | @override
36 | def filter_elements(
37 | self,
38 | codebase: Codebase,
39 | elements: List[Symbol],
40 | target_type: LLMMapFilterTargetType = "symbol_content",
41 | subdirs_or_files: List[str] = [],
42 | ) -> List[Union[Keyword, Symbol, File]]:
43 | filtered_symbols: List[Symbol] = []
44 | for element in elements:
45 | if not isinstance(element, Symbol):
46 | continue
47 | # If subdirs_or_files is provided and codebase is available, filter by subdirs
48 | if subdirs_or_files and codebase:
49 | # Get the relative path from the codebase directory
50 | rpath = str(element.file.path)
51 | if any(rpath.startswith(subdir) for subdir in subdirs_or_files):
52 | filtered_symbols.append(element)
53 | else:
54 | filtered_symbols.append(element)
55 | if target_type == "class_content":
56 | # If the target type is class_content, filter symbols that are classes
57 | filtered_symbols = [
58 | elem for elem in filtered_symbols if elem.type == "class"
59 | ]
60 | elif target_type == "function_content":
61 | # If the target type is function_content, filter symbols that are functions
62 | filtered_symbols = [
63 | elem for elem in filtered_symbols if elem.type == "function"
64 | ]
65 | elif target_type == "leaf_symbol_content":
66 | # If the target type is leaf_symbol_content, filter symbols that are leaves
67 |
68 | parent_of_symbol = {
69 | symbol.id: symbol.chunk.parent.id
70 | for symbol in codebase.symbols
71 | if symbol.chunk.parent
72 | }
73 | childs_of_symbol = defaultdict(list)
74 | for child, parent in parent_of_symbol.items():
75 | childs_of_symbol[parent].append(child)
76 | filtered_symbols = [
77 | elem for elem in filtered_symbols if not childs_of_symbol[elem.id]
78 | ]
79 | elif target_type == "root_symbol_content":
80 | parent_of_symbol = {
81 | symbol.id: symbol.chunk.parent.id
82 | for symbol in codebase.symbols
83 | if symbol.chunk.parent
84 | }
85 | filtered_symbols = [
86 |                 elem for elem in filtered_symbols if not parent_of_symbol.get(elem.id)
87 | ]
88 | return filtered_symbols
89 |
90 | @override
91 | def collect_file_paths(
92 | self,
93 | filtered_elements: List[Symbol],
94 | codebase: Codebase,
95 | subdirs_or_files: List[str],
96 | ) -> List[str]:
97 | """Collect file paths from the filtered symbols."""
98 | file_paths = set()
99 | for symbol in filtered_elements:
100 | if isinstance(symbol, Symbol):
101 | file_path = str(symbol.file.path)
102 | if not subdirs_or_files or any(
103 | file_path.startswith(subdir) for subdir in subdirs_or_files
104 | ):
105 | file_paths.add(file_path)
106 | return list(file_paths)
107 |
--------------------------------------------------------------------------------
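A toy illustration of the leaf/root selection performed in `filter_elements` above, using plain string ids instead of coderetrx `Symbol` objects: a symbol is a leaf when nothing lists it as a parent, and a root when it has no parent entry of its own.

from collections import defaultdict

parent_of_symbol = {"method_a": "ClassA", "method_b": "ClassA"}  # child id -> parent id
all_symbols = ["ClassA", "method_a", "method_b", "top_level_fn"]

childs_of_symbol = defaultdict(list)
for child, parent in parent_of_symbol.items():
    childs_of_symbol[parent].append(child)

leaves = [s for s in all_symbols if not childs_of_symbol[s]]
roots = [s for s in all_symbols if not parent_of_symbol.get(s)]

assert leaves == ["method_a", "method_b", "top_level_fn"]
assert roots == ["ClassA", "top_level_fn"]
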
/bench/repos.lock:
--------------------------------------------------------------------------------
1 | [repositories.cloudflare_boringtun]
2 | url = "https://github.com/cloudflare/boringtun"
3 | commit = "2f3c85f5c4a601018c10b464b1ca890d9504bf6e"
4 |
5 | [repositories.cloudflare_boringtun.stats]
6 | num_files = 30
7 | num_lines = 7937
8 | num_tokens = 73507
9 | primary_language = "rust"
10 |
11 | [repositories.rosenpass_rosenpass]
12 | url = "https://github.com/rosenpass/rosenpass"
13 | commit = "d98815fa7f8dbe6fd2ea2e024e17447bf587d134"
14 |
15 | [repositories.rosenpass_rosenpass.stats]
16 | num_files = 147
17 | num_lines = 30063
18 | num_tokens = 237990
19 | primary_language = "rust"
20 |
21 | [repositories.neondatabase_neon]
22 | url = "https://github.com/neondatabase/neon"
23 | commit = "1b935b1958a7f508807f1bd241e715f33cdc386e"
24 |
25 | [repositories.neondatabase_neon.stats]
26 | num_files = 934
27 | num_lines = 375060
28 | num_tokens = 3021451
29 | primary_language = "rust"
30 |
31 | [repositories.GianisTsol_python-p2p]
32 | url = "https://github.com/GianisTsol/python-p2p"
33 | commit = "69056e1634ec108c7863d36b837f93043ee89ac6"
34 |
35 | [repositories.GianisTsol_python-p2p.stats]
36 | num_files = 8
37 | num_lines = 1189
38 | num_tokens = 11875
39 | primary_language = "python"
40 |
41 | [repositories.magic-wormhole_magic-wormhole]
42 | url = "https://github.com/magic-wormhole/magic-wormhole"
43 | commit = "e5f2ba2c77fe041ffd83db2045591a4d5a65a8de"
44 |
45 | [repositories.magic-wormhole_magic-wormhole.stats]
46 | num_files = 99
47 | num_lines = 27538
48 | num_tokens = 175095
49 | primary_language = "python"
50 |
51 | [repositories.zulip_zulip]
52 | url = "https://github.com/zulip/zulip"
53 | commit = "e243fc67fac9f228dffbbbc5a0ba30abb99e298e"
54 |
55 | [repositories.zulip_zulip.stats]
56 | num_files = 2300
57 | num_lines = 499528
58 | num_tokens = 5711007
59 | primary_language = "python"
60 |
61 | [repositories.TecharoHQ_anubis]
62 | url = "https://github.com/TecharoHQ/anubis"
63 | commit = "6e2eeb9e6562e9330024aa3bcb393d223713e40e"
64 |
65 | [repositories.TecharoHQ_anubis.stats]
66 | num_files = 48
67 | num_lines = 5163
68 | num_tokens = 34501
69 | primary_language = "go"
70 |
71 | [repositories.google_fscrypt]
72 | url = "https://github.com/google/fscrypt"
73 | commit = "827c13689b39814552a3a18449f922b123725b49"
74 |
75 | [repositories.google_fscrypt.stats]
76 | num_files = 42
77 | num_lines = 11698
78 | num_tokens = 18502
79 | primary_language = "go"
80 |
81 | [repositories.ethereum_go-ethereum]
82 | url = "https://github.com/ethereum/go-ethereum"
83 | commit = "0983cd789ee1905aedaed96f72793e5af8466f34"
84 |
85 | [repositories.ethereum_go-ethereum.stats]
86 | num_files = 920
87 | num_lines = 272576
88 | num_tokens = 1104733
89 | primary_language = "go"
90 |
91 | [repositories.Mastercard_client-encryption-java]
92 | url = "https://github.com/Mastercard/client-encryption-java"
93 | commit = "20a5b880b39a590d10785e4b6dbd957112548e31"
94 |
95 | [repositories.Mastercard_client-encryption-java.stats]
96 | num_files = 86
97 | num_lines = 8654
98 | num_tokens = 188192
99 | primary_language = "java"
100 |
101 | [repositories.keycloak_keycloak]
102 | url = "https://github.com/keycloak/keycloak"
103 | commit = "c37e597117a0999960106e6a9a420f514249ccd4"
104 |
105 | [repositories.keycloak_keycloak.stats]
106 | num_files = 7031
107 | num_lines = 955351
108 | num_tokens = 12068541
109 | primary_language = "java"
110 |
111 | [repositories.cryptomator_cryptomator]
112 | url = "https://github.com/cryptomator/cryptomator"
113 | commit = "0178ca4080b9adb774006af53e3517a104476077"
114 |
115 | [repositories.cryptomator_cryptomator.stats]
116 | num_files = 387
117 | num_lines = 29637
118 | num_tokens = 363437
119 | primary_language = "java"
120 |
121 | [repositories.padloc_padloc]
122 | url = "https://github.com/padloc/padloc"
123 | commit = "1d2e9129d65afad72cf7fefc80fd5eb20b5671d9"
124 |
125 | [repositories.padloc_padloc.stats]
126 | num_files = 304
127 | num_lines = 136352
128 | num_tokens = 1363775
129 | primary_language = "javascript"
130 |
131 | [repositories.dani-garcia_vaultwarden]
132 | url = "https://github.com/dani-garcia/vaultwarden"
133 | commit = "0d3f283c3720b82e97c1aa383a66ebc8b4951dfb"
134 |
135 | [repositories.dani-garcia_vaultwarden.stats]
136 | num_files = 67
137 | num_lines = 62609
138 | num_tokens = 401294
139 | primary_language = "javascript"
140 |
141 | [repositories.swiftyapp_swifty]
142 | url = "https://github.com/swiftyapp/swifty"
143 | commit = "858209dc2b67183065de40e298f3fa1e60aad2c3"
144 |
145 | [repositories.swiftyapp_swifty.stats]
146 | num_files = 190
147 | num_lines = 8295
148 | num_tokens = 43671
149 | primary_language = "javascript"
150 |
--------------------------------------------------------------------------------
/test/impl/default/test_smart_codebase.py:
--------------------------------------------------------------------------------
1 | from dotenv import load_dotenv
2 | load_dotenv()
3 | import json
4 | from pathlib import Path
5 | from coderetrx.impl.default import CodebaseFactory
6 | from coderetrx.impl.default import SmartCodebase
7 | import os
8 | import asyncio
9 | import unittest
10 | from coderetrx.utils.embedding import create_documents_embedding
11 | from coderetrx.utils.git import clone_repo_if_not_exists, get_repo_id
12 | from coderetrx.utils.path import get_data_dir
13 | import logging
14 |
15 | logger = logging.getLogger(__name__)
16 | logging.basicConfig(level=logging.INFO)
17 |
18 | TEST_REPOS = ["https://github.com/apache/dubbo-admin.git"]
19 |
20 |
21 | def prepare_codebase(repo_url: str, repo_path: Path):
22 | """Helper function to prepare codebase for testing"""
23 | database_path = get_data_dir() / "databases" / f"{get_repo_id(repo_url)}.json"
24 | # Create a test codebase
25 | clone_repo_if_not_exists(repo_url, str(repo_path))
26 |
27 | if database_path.exists():
28 | codebase = CodebaseFactory.from_json(
29 | json.load(open(database_path, "r", encoding="utf-8"))
30 | )
31 | else:
32 | codebase = CodebaseFactory.new(get_repo_id(repo_url), repo_path)
33 | with open(f"{repo_path}.json", "w") as f:
34 | json.dump(codebase.to_json(), f, indent=4)
35 | return codebase
36 |
37 |
38 | class TestSmartCodebase(unittest.TestCase):
39 | """Test SmartCodebase functionality"""
40 |
41 | def setUp(self):
42 | """Set up test environment"""
43 | self.repo_url = TEST_REPOS[0]
44 | self.repo_path = get_data_dir() / "repos" / get_repo_id(self.repo_url)
45 |
46 | def test_codebase_initialization(self):
47 | """Test codebase initialization"""
48 | codebase = prepare_codebase(self.repo_url, self.repo_path)
49 | self.assertIsInstance(codebase, SmartCodebase)
50 | self.assertEqual(codebase.id, get_repo_id(self.repo_url))
51 |
52 | def test_keyword_extraction(self):
53 | """Test keyword extraction functionality"""
54 | os.environ["KEYWORD_EMBEDDING"] = "True"
55 | codebase = prepare_codebase(self.repo_url, self.repo_path)
56 |
57 | # Verify keywords were extracted
58 | self.assertGreater(len(codebase.keywords), 0)
59 | logger.info(f"Extracted {len(codebase.keywords)} keywords")
60 | logger.info(f"Sample keywords: {[k.content for k in codebase.keywords[:10]]}")
61 |
62 | def test_keyword_search(self):
63 | """Test keyword search functionality"""
64 | os.environ["KEYWORD_EMBEDDING"] = "True"
65 | codebase = prepare_codebase(self.repo_url, self.repo_path)
66 |
67 | try:
68 | # Generate embeddings
69 | test_query = "Is the code snippet used for user authentication?"
70 | logger.info(f"\nTesting keyword search with query: '{test_query}'")
71 |
72 | # Create a searcher
73 | results = asyncio.run(codebase.similarity_search(
74 | target_types=["keyword"], query=test_query
75 | ))
76 |
77 | logger.info("Search results:")
78 | logger.info(results)
79 |
80 | # Verify results
81 | self.assertIsNotNone(results)
82 |
83 | except Exception as e:
84 | self.fail(f"Error during keyword search test: {e}")
85 |
86 | def test_symbol_extraction(self):
87 | """Test symbol extraction functionality"""
88 | os.environ["SYMBOL_NAME_EMBEDDING"] = "True"
89 | codebase = prepare_codebase(self.repo_url, self.repo_path)
90 |
91 | # Verify symbols were extracted
92 | self.assertGreater(len(codebase.symbols), 0)
93 | logger.info(f"Sample symbols: {[k.name for k in codebase.symbols[:10]]}")
94 |
95 | def test_symbol_search(self):
96 | """Test symbol search functionality"""
97 | os.environ["SYMBOL_NAME_EMBEDDING"] = "True"
98 | codebase = prepare_codebase(self.repo_url, self.repo_path)
99 |
100 | try:
101 | # Test search with a sample query
102 | test_query = "Is the code snippet used for user authentication?"
103 | logger.info(f"\nTesting symbol search with query: '{test_query}'")
104 |
105 | # Create a searcher
106 | results = asyncio.run(codebase.similarity_search(
107 | target_types=["symbol_name"], query=test_query
108 | ))
109 |
110 | logger.info("Search results:")
111 | logger.info(results)
112 |
113 | # Verify results
114 | self.assertIsNotNone(results)
115 |
116 | except Exception as e:
117 | self.fail(f"Error during symbol search test: {e}")
118 |
119 |
120 | # Run tests if specified
121 | if __name__ == "__main__":
122 | # Run unittest tests
123 | unittest.main(argv=['first-arg-is-ignored'], exit=False)
--------------------------------------------------------------------------------
/coderetrx/tools/get_references.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | from typing import List, ClassVar, Type
4 | from pathlib import Path
5 | from pydantic import BaseModel, Field
6 | from coderetrx.static.ripgrep import ripgrep_search # type: ignore
7 | from coderetrx.tools.base import BaseTool
8 | from coderetrx.utils.llm import count_tokens_openai
9 | from coderetrx.tools.keyword_search import KeywordSearchResult
10 |
11 |
12 | class GetReferenceArgs(BaseModel):
13 | symbol_name: str = Field(
14 | description="The symbolic name whose reference is to be retrieved."
15 | )
16 |
17 |
18 | class GetReferenceResult(KeywordSearchResult):
19 | symbol_idx: int = Field(
20 | default=-1,
21 |         description="The index of the symbol in the list of symbols. If -1, it is not applicable.",
22 | )
23 | symbol_name: str = Field(
24 | default="", description="The symbolic name whose reference is to be retrieved."
25 | )
26 | symbol_location: str = Field(default="", description="The location of the symbol.")
27 |
28 | @classmethod
29 | def repr(cls, entries: list["GetReferenceResult"]):
30 | """
31 |         Convert a list of GetReferenceResult entries into a readable string.
32 | """
33 | if not entries:
34 | return "No entries found."
35 |
36 | result = ""
37 | current_symbol_idx = -1
38 | entries_count = {}
39 | entries = sorted(
40 | entries, key=lambda x: (x.symbol_idx, x.start_line, x.end_line)
41 | )
42 | # symbol_idx is -1 means the multi-symbol cross-reference is not applicable
43 | if entries[0].symbol_idx == -1:
44 | for idx, ref in enumerate(entries):
45 | result += f"## **{idx}. {ref.path}:L{ref.start_line}:{ref.end_line}**\n"
46 | result += f"{ref.content}\n"
47 | return result
48 |
49 | threshold_tokens = 5000 # reference: 6262 tokens = 830 lines python
50 | cur_tokens = 0
51 | for idx, ref in enumerate(entries):
52 | each_result = ""
53 | # When processing a new symbol, add a symbol title.
54 | if ref.symbol_idx != current_symbol_idx:
55 | current_symbol_idx = ref.symbol_idx
56 | entries_count[current_symbol_idx] = 0
57 | each_result += f"# **{current_symbol_idx}. {ref.symbol_name}** location:{ref.symbol_location}\n"
58 |
59 | if not ref.path:
60 | each_result += "No entries Found.\n"
61 | continue
62 |
63 | entries_count[current_symbol_idx] += 1
64 |
65 | if entries_count[current_symbol_idx] == 1:
66 |                 # calculate the total reference count of the current symbol
67 | total_refs = sum(
68 | 1 for r in entries if r.symbol_idx == current_symbol_idx and r.path
69 | )
70 | each_result += f"{total_refs} entries found:\n"
71 |
72 | each_result += f"## **{entries_count[current_symbol_idx]}. {ref.path}:L{ref.start_line}:{ref.end_line}**\n"
73 | each_result += f"{ref.content}\n"
74 |
75 | cur_tokens += count_tokens_openai(each_result)
76 | if cur_tokens > threshold_tokens:
77 |                 result += "Omitted...\nThe returned contents are too long. Please refine your query.\n"
78 | break
79 | result += each_result
80 |
81 | return result
82 |
83 |
84 | class GetReferenceTool(BaseTool):
85 | name = "get_reference"
86 | description = (
87 | "Used to find symbol direct references in the codebase, should be used when tracking code usage. "
88 | "Finding multiple levels of references requires multiple calls."
89 | )
90 | args_schema: ClassVar[Type[GetReferenceArgs]] = GetReferenceArgs
91 |
92 | def forward(self, symbol_name: str) -> str:
93 | """Synchronous wrapper for async _run method."""
94 | return self.run_sync(symbol_name=symbol_name)
95 |
96 | async def _run(self, symbol_name: str) -> list[GetReferenceResult]:
97 | """Find references to a symbol in the codebase"""
98 | # Convert the query to a list of regexes
99 | regexes = [f"\\b{symbol_name}\\b"]
100 |
101 | # Call ripgrep_search with the appropriate parameters
102 | rg_results = await ripgrep_search(
103 | search_dir=Path(self.repo_path),
104 | regexes=regexes,
105 | case_sensitive=True,
106 | exclude_file_pattern=".git",
107 | )
108 |
109 |         # Convert ripgrep matches into GetReferenceResult entries
110 | results = []
111 | for result in rg_results:
112 | search_result = GetReferenceResult(
113 | path=str(result.file_path),
114 | start_line=result.line_number,
115 | end_line=result.line_number,
116 | content=result.line_text,
117 | )
118 | results.append(search_result)
119 |
120 | return results
121 |
--------------------------------------------------------------------------------
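A small check of the word-boundary pattern that `GetReferenceTool._run` hands to ripgrep: `\bNAME\b` matches the symbol as a whole word, not as a substring of a longer identifier. Python's `re` is used here as a stand-in for ripgrep.

import re

symbol_name = "encrypt"
pattern = re.compile(rf"\b{symbol_name}\b")  # same shape as the regex built in _run

assert pattern.search("cipher = encrypt(key, msg)")  # whole-word use is found
assert not pattern.search("encrypted_data = b''")    # longer identifiers are not
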
/coderetrx/utils/stats.py:
--------------------------------------------------------------------------------
1 | from ._extras import require_extra
2 |
3 | require_extra("tiktoken", "stats")
4 |
5 | from pydantic import BaseModel
6 | from collections import Counter
7 | from typing import Dict, List, ClassVar
8 | from pathlib import Path
9 | import tiktoken
10 | from concurrent.futures import ThreadPoolExecutor, as_completed
11 |
12 | from coderetrx.static.codebase import Codebase, File, CodeChunk, ChunkType
13 | from coderetrx.static.codebase.languages import get_language
14 |
15 |
16 | class ChunkStats(BaseModel):
17 | chunk_id: str
18 | file_path: str
19 | num_lines: int
20 | num_tokens: int
21 |
22 | # Class variable for the tokenizer
23 | _tokenizer: ClassVar = tiktoken.encoding_for_model("gpt-4o")
24 |
25 | @classmethod
26 | def from_chunk(cls, chunk: CodeChunk, strict: bool = False) -> "ChunkStats":
27 | lines = chunk.lines()
28 | text = "\n".join(lines)
29 |
30 | SAMPLE_SIZE = 100_000
31 | if not strict and len(text) > SAMPLE_SIZE:
32 | tokens = len(cls._tokenizer.encode(text[:SAMPLE_SIZE], disallowed_special=())) * (
33 | len(text) / SAMPLE_SIZE
34 | )
35 | else:
36 | tokens = len(cls._tokenizer.encode(text, disallowed_special=()))
37 |
38 | return cls(
39 | chunk_id=chunk.id,
40 | file_path=str(chunk.src.path),
41 | num_lines=len(lines),
42 | num_tokens=int(tokens),
43 | )
44 |
45 |
46 | class FileStats(BaseModel):
47 | file_path: str
48 | num_chunks: int
49 | num_lines: int
50 | num_tokens: int
51 |
52 | @classmethod
53 | def from_file(cls, file: File) -> "FileStats":
54 | with ThreadPoolExecutor() as executor:
55 | # Process chunks in parallel
56 | future_to_chunk = {
57 | executor.submit(ChunkStats.from_chunk, chunk): chunk
58 | for chunk in file.chunks
59 | if chunk.type in [ChunkType.PRIMARY, ChunkType.IMPORT]
60 | }
61 | chunk_stats = [future.result() for future in as_completed(future_to_chunk)]
62 |
63 | return cls(
64 | file_path=str(file.path),
65 | num_chunks=len(file.chunks),
66 | num_lines=len(file.lines),
67 | num_tokens=sum(stat.num_tokens for stat in chunk_stats),
68 | )
69 |
70 |
71 | class CodebaseStats(BaseModel):
72 | codebase_id: str
73 | num_chunks: int
74 | num_files: int
75 | num_lines: int
76 | num_tokens: int
77 | language_distribution: Dict[str, int] # Maps language name to line count
78 | primary_language: str # The language with the most lines of code
79 | file_stats: List[FileStats]
80 | chunk_stats: List[ChunkStats]
81 |
82 | @classmethod
83 | def from_codebase(cls, codebase: Codebase) -> "CodebaseStats":
84 | with ThreadPoolExecutor() as executor:
85 | # Process files in parallel
86 | future_to_file = {
87 | executor.submit(FileStats.from_file, file): file_path
88 | for file_path, file in codebase.source_files.items()
89 | }
90 | file_stats = [future.result() for future in as_completed(future_to_file)]
91 |
92 | # Calculate language distribution by line count
93 | language_counts = Counter()
94 | for file_path, file in codebase.source_files.items():
95 | lang = get_language(Path(file_path))
96 | if lang:
97 | language_counts[lang] += len(file.lines)
98 |
99 | # Determine primary language (language with most lines)
100 | primary_language = max(language_counts, key=language_counts.get) if language_counts else "Unknown"
101 |
102 | # Collect all chunk stats
103 | chunk_stats = []
104 | for file in codebase.source_files.values():
105 | for chunk in file.chunks:
106 | if chunk.type in [ChunkType.PRIMARY, ChunkType.IMPORT]:
107 | chunk_stats.append(ChunkStats.from_chunk(chunk))
108 |
109 | return cls(
110 | codebase_id=codebase.id,
111 | num_chunks=len(codebase.all_chunks),
112 | num_files=len(codebase.source_files),
113 | num_lines=sum(stat.num_lines for stat in file_stats),
114 | num_tokens=sum(stat.num_tokens for stat in file_stats),
115 | language_distribution=dict(language_counts),
116 | primary_language=primary_language,
117 | file_stats=file_stats,
118 | chunk_stats=chunk_stats,
119 | )
120 |
121 |
122 | if __name__ == "__main__":
123 | import argparse
124 | from textwrap import dedent
125 |
126 | parser = argparse.ArgumentParser()
127 | parser.add_argument("target", type=str, help="Path to codebase")
128 | args = parser.parse_args()
129 |
130 | codebase = Codebase.new("tmp", Path(args.target))
131 | codebase.init_chunks()
132 | stats = CodebaseStats.from_codebase(codebase)
133 | print(
134 | dedent(f"""
135 | Codebase Stats for {args.target}:
136 | - Number of chunks: {stats.num_chunks}
137 | - Number of files: {stats.num_files}
138 | - Number of lines: {stats.num_lines}
139 | - Number of tokens: {stats.num_tokens}
140 | - Primary language: {stats.primary_language}
141 | - Language distribution: {stats.language_distribution}
142 | """)
143 | )
144 |
--------------------------------------------------------------------------------
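A toy version of the sampling shortcut in `ChunkStats.from_chunk`: for very large chunks, only the first `SAMPLE_SIZE` characters are tokenized and the count is scaled by the full length. A whitespace split stands in for tiktoken here, so the numbers are illustrative only.

SAMPLE_SIZE = 100_000

def estimate_tokens(text: str, strict: bool = False) -> int:
    tokenize = lambda s: len(s.split())  # placeholder tokenizer
    if not strict and len(text) > SAMPLE_SIZE:
        # Tokenize a prefix and extrapolate proportionally to the full length.
        return int(tokenize(text[:SAMPLE_SIZE]) * (len(text) / SAMPLE_SIZE))
    return tokenize(text)

big_text = "word " * 60_000       # 300,000 characters
print(estimate_tokens(big_text))  # ~60,000, estimated from a 100k-character sample
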
/coderetrx/retrieval/strategy/adaptive_filter_symbol_content_by_vector_and_llm.py:
--------------------------------------------------------------------------------
1 | """
2 | Strategy for adaptive filtering of symbols using vector similarity search followed by LLM refinement.
3 | """
4 |
5 | from collections import defaultdict
6 | from typing import List, Union, Optional, override, Literal, Any
7 | from .base import (
8 | AdaptiveFilterByVectorAndLLMStrategy,
9 | FilterByVectorAndLLMStrategy,
10 | StrategyExecuteResult,
11 | )
12 | from ..smart_codebase import (
13 | SmartCodebase as Codebase,
14 | LLMMapFilterTargetType,
15 | SimilaritySearchTargetType,
16 | )
17 | from coderetrx.static import Keyword, Symbol, File
18 |
19 |
20 | class AdaptiveFilterSymbolContentByVectorAndLLMStrategy(
21 | AdaptiveFilterByVectorAndLLMStrategy
22 | ):
23 | """Strategy to filter symbols using adaptive vector similarity search followed by LLM refinement."""
24 |
25 | name: str = "ADAPTIVE_FILTER_SYMBOL_CONTENT_BY_VECTOR_AND_LLM"
26 |
27 | @override
28 | def get_strategy_name(self) -> str:
29 | return self.name
30 |
31 | @override
32 | def get_target_types_for_vector(self) -> List[SimilaritySearchTargetType]:
33 | return ["symbol_content"]
34 |
35 | @override
36 | def get_target_type_for_llm(self) -> LLMMapFilterTargetType:
37 | return "symbol_content"
38 |
39 | @override
40 | def get_collection_size(self, codebase: Codebase) -> int:
41 | return len(codebase.symbols)
42 |
43 | @override
44 | def filter_elements(
45 | self,
46 | codebase: Codebase,
47 | elements: List[Any],
48 | target_type: LLMMapFilterTargetType = "symbol_content",
49 | subdirs_or_files: List[str] = [],
50 | ) -> List[Union[Keyword, Symbol, File]]:
51 | filtered_symbols: List[Symbol] = []
52 | for element in elements:
53 | if not isinstance(element, Symbol):
54 | continue
55 | # If subdirs_or_files is provided and codebase is available, filter by subdirs
56 | if subdirs_or_files and codebase:
57 | # Get the relative path from the codebase directory
58 | rpath = str(element.file.path)
59 | if any(rpath.startswith(subdir) for subdir in subdirs_or_files):
60 | filtered_symbols.append(element)
61 | else:
62 | filtered_symbols.append(element)
63 | if target_type == "class_content":
64 | # If the target type is class_content, filter symbols that are classes
65 | filtered_symbols = [
66 | elem for elem in filtered_symbols if elem.type == "class"
67 | ]
68 | elif target_type == "function_content":
69 | # If the target type is function_content, filter symbols that are functions
70 | filtered_symbols = [
71 | elem for elem in filtered_symbols if elem.type == "function"
72 | ]
73 | elif target_type == "leaf_symbol_content":
74 | # If the target type is leaf_symbol_content, filter symbols that are leaves
75 |
76 | parent_of_symbol = {
77 | symbol.id: symbol.chunk.parent.id
78 | for symbol in codebase.symbols
79 | if symbol.chunk.parent
80 | }
81 | childs_of_symbol = defaultdict(list)
82 | for child, parent in parent_of_symbol.items():
83 | childs_of_symbol[parent].append(child)
84 | filtered_symbols = [
85 | elem for elem in filtered_symbols if not childs_of_symbol[elem.id]
86 | ]
87 | elif target_type == "root_symbol_content":
88 | parent_of_symbol = {
89 | symbol.id: symbol.chunk.parent.id
90 | for symbol in codebase.symbols
91 | if symbol.chunk.parent
92 | }
93 | filtered_symbols = [
94 |                 elem for elem in filtered_symbols if not parent_of_symbol.get(elem.id)
95 | ]
96 | return filtered_symbols # type: ignore
97 |
98 | @override
99 | def collect_file_paths(
100 | self,
101 | filtered_elements: List[Any],
102 | codebase: Codebase,
103 | subdirs_or_files: List[str],
104 | ) -> List[str]:
105 | file_paths = []
106 | for symbol in filtered_elements:
107 | if isinstance(symbol, Symbol):
108 | file_path = str(symbol.file.path)
109 | if file_path.startswith(tuple(subdirs_or_files)):
110 | file_paths.append(file_path)
111 | return file_paths
112 |
113 | @override
114 | async def execute(
115 | self,
116 | codebase: Codebase,
117 | prompt: str,
118 | subdirs_or_files: List[str],
119 | target_type: LLMMapFilterTargetType = "symbol_content",
120 | ) -> StrategyExecuteResult:
121 | prompt = f"""
122 | requirement: A code chunk with this name is highly likely to meet the following criteria:
123 |
124 | {prompt}
125 |
126 |
127 | The objective of this requirement is to preliminarily identify code chunks that are likely to meet specific content criteria based on their names.
128 | Code chunks with matching names will proceed to deeper analysis in the content filter (content_criteria) at a later stage (not in this run).
129 |
130 | """
131 | return await super().execute(codebase, prompt, subdirs_or_files, target_type)
132 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/parsers/codeql/queries.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 | from pathlib import Path
3 | import logging
4 | from ...languages import IDXSupportedLanguage
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class CodeQLQueryTemplates:
10 | """
11 | Provides CodeQL query templates for different languages and symbol types.
12 |
13 | Each template is loaded from .ql files in the queries/codeql directory,
14 | ensuring consistent results between parsers and maintainable query files.
15 | """
16 |
17 | # Supported query types for each language
18 | LANGUAGE_QUERY_TYPES = {
19 | "python": ["functions", "classes", "imports"],
20 | "javascript": ["functions", "classes", "imports"],
21 | "typescript": ["functions", "classes", "imports"],
22 | "java": ["functions", "classes", "imports"],
23 | "cpp": ["functions", "classes", "includes"],
24 | "c": ["functions", "classes", "includes"],
25 | "go": ["functions", "types", "imports"],
26 | "rust": ["functions", "structs"],
27 | "csharp": ["functions", "classes", "imports"],
28 | }
29 |
30 | @classmethod
31 | def get_supported_languages(cls) -> List[IDXSupportedLanguage]:
32 | """
33 | Get list of languages that have CodeQL query templates.
34 |
35 | Returns:
36 | List of supported languages
37 | """
38 | from typing import cast
39 |
40 | supported = []
41 | for language_str in cls.LANGUAGE_QUERY_TYPES.keys():
42 | language = cast(IDXSupportedLanguage, language_str)
43 | # Check if at least one query file exists for this language
44 | query_types = cls.LANGUAGE_QUERY_TYPES[language_str]
45 | for query_type in query_types:
46 | if cls._get_query_file_path(language, query_type).exists():
47 | supported.append(language)
48 | break
49 |
50 | return supported
51 |
52 | @classmethod
53 | def get_query(cls, language: IDXSupportedLanguage, query_type: str) -> Path:
54 | """
55 | Get a specific query for a language.
56 |
57 | Args:
58 | language: The language
59 | query_type: The type of query (e.g., 'functions', 'classes')
60 |
61 | Returns:
62 |             Path to the .ql query file
63 |
64 | Raises:
65 | KeyError: If language or query type is not supported
66 | FileNotFoundError: If query file is not found
67 | """
68 | if language not in cls.LANGUAGE_QUERY_TYPES:
69 | raise KeyError(f"Language not supported: {language}")
70 |
71 | if query_type not in cls.LANGUAGE_QUERY_TYPES[language]:
72 | raise KeyError(
73 | f"Query type '{query_type}' not available for language: {language}"
74 | )
75 |
76 | query_file = cls._get_query_file_path(language, query_type)
77 | if not query_file.exists():
78 | raise FileNotFoundError(f"CodeQL query file not found: {query_file}")
79 |
80 | return query_file
81 |
82 | @classmethod
83 | def has_query(cls, language: IDXSupportedLanguage, query_type: str) -> bool:
84 | """
85 | Check if a specific query is available for a language.
86 |
87 | Args:
88 | language: The language
89 | query_type: The type of query
90 |
91 | Returns:
92 | True if query is available, False otherwise
93 | """
94 | if language not in cls.LANGUAGE_QUERY_TYPES:
95 | return False
96 |
97 | if query_type not in cls.LANGUAGE_QUERY_TYPES[language]:
98 | return False
99 |
100 | query_file = cls._get_query_file_path(language, query_type)
101 | return query_file.exists()
102 |
103 | @classmethod
104 | def get_available_queries(cls) -> Dict[str, List[str]]:
105 | """
106 | Get all available queries organized by language.
107 |
108 | Returns:
109 | Dictionary mapping language -> list of available query types
110 | """
111 | from typing import cast
112 |
113 | available = {}
114 | for language_str in cls.LANGUAGE_QUERY_TYPES:
115 | language = cast(IDXSupportedLanguage, language_str)
116 | query_types = []
117 | for query_type in cls.LANGUAGE_QUERY_TYPES[language_str]:
118 | if cls.has_query(language, query_type):
119 | query_types.append(query_type)
120 | if query_types:
121 | available[language_str] = query_types
122 |
123 | return available
124 |
125 | @classmethod
126 | def _get_query_file_path(
127 | cls, language: IDXSupportedLanguage, query_type: str
128 | ) -> Path:
129 | """
130 | Get the file path for a specific query.
131 |
132 | Args:
133 | language: The language
134 | query_type: The type of query (e.g., 'functions', 'classes')
135 |
136 | Returns:
137 | Path to the query file
138 | """
139 | # Get the directory where this file is located
140 | current_file = Path(__file__)
141 | # Navigate to the queries directory: ../../../queries/codeql/
142 | queries_dir = current_file.parent.parent.parent / "queries" / "codeql"
143 |
144 | # Construct the path: queries/codeql/{language}/{query_type}.ql
145 | query_file_path = queries_dir / language / f"{query_type}.ql"
146 |
147 | return query_file_path
148 |
--------------------------------------------------------------------------------
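A sketch of looking up a bundled CodeQL query through the class above. Note that this variant returns a `Path` to the `.ql` file, whereas the Tree-sitter templates below return the query text itself.

from coderetrx.static.codebase.parsers.codeql.queries import CodeQLQueryTemplates

if CodeQLQueryTemplates.has_query("python", "functions"):
    ql_path = CodeQLQueryTemplates.get_query("python", "functions")
    print(ql_path)  # .../static/codebase/queries/codeql/python/functions.ql
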
/coderetrx/static/codebase/parsers/treesitter/queries.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 | from pathlib import Path
3 | import logging
4 | from ...languages import IDXSupportedLanguage
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class TreeSitterQueryTemplates:
10 | """
11 | Provides Tree-sitter query templates for different languages and symbol types.
12 |
13 | Each template is loaded from .scm files in the queries/treesitter directory,
14 | ensuring consistent results between parsers and maintainable query files.
15 | """
16 |
17 | # Supported query types for each language
18 | LANGUAGE_QUERY_TYPES = {
19 | "javascript": ["tags", "tests", "fine_imports"],
20 | "typescript": ["tags", "tests", "fine_imports"],
21 | "python": ["tags", "tests", "fine_imports"],
22 | "rust": ["tags", "tests", "fine_imports"],
23 | "c": ["tags", "fine_imports"],
24 | "cpp": ["tags", "fine_imports"],
25 | "csharp": ["tags", "fine_imports"],
26 | "go": ["tags", "tests", "fine_imports"],
27 | "elixir": ["tags", "tests", "fine_imports"],
28 | "java": ["tags", "fine_imports"],
29 | "php": ["tags", "tests", "fine_imports"],
30 | }
31 |
32 | @classmethod
33 | def get_supported_languages(cls) -> List[IDXSupportedLanguage]:
34 | """
35 | Get list of languages that have Tree-sitter query templates.
36 |
37 | Returns:
38 | List of supported languages
39 | """
40 | from typing import cast
41 |
42 | supported = []
43 | for language_str in cls.LANGUAGE_QUERY_TYPES.keys():
44 | language = cast(IDXSupportedLanguage, language_str)
45 | # Check if at least one query file exists for this language
46 | query_types = cls.LANGUAGE_QUERY_TYPES[language_str]
47 | for query_type in query_types:
48 | if cls._get_query_file_path(language, query_type).exists():
49 | supported.append(language)
50 | break
51 |
52 | return supported
53 |
54 | @classmethod
55 | def get_query(cls, language: IDXSupportedLanguage, query_type: str = "tags") -> str:
56 | """
57 | Get a specific query for a language.
58 |
59 | Args:
60 | language: The language
61 | query_type: The type of query (e.g., 'tags', 'tests', 'fine_imports')
62 |
63 | Returns:
64 | Query text
65 |
66 | Raises:
67 | KeyError: If language or query type is not supported
68 | FileNotFoundError: If query file is not found
69 | """
70 | if language not in cls.LANGUAGE_QUERY_TYPES:
71 | raise KeyError(f"Language not supported: {language}")
72 |
73 | if query_type not in cls.LANGUAGE_QUERY_TYPES[language]:
74 | raise KeyError(
75 | f"Query type '{query_type}' not available for language: {language}"
76 | )
77 |
78 | query_file = cls._get_query_file_path(language, query_type)
79 | if not query_file.exists():
80 | raise FileNotFoundError(f"Tree-sitter query file not found: {query_file}")
81 |
82 | with open(query_file, "r") as f:
83 | return f.read()
84 |
85 | @classmethod
86 | def has_query(cls, language: IDXSupportedLanguage, query_type: str) -> bool:
87 | """
88 | Check if a specific query is available for a language.
89 |
90 | Args:
91 | language: The language
92 | query_type: The type of query
93 |
94 | Returns:
95 | True if query is available, False otherwise
96 | """
97 | if language not in cls.LANGUAGE_QUERY_TYPES:
98 | return False
99 |
100 | if query_type not in cls.LANGUAGE_QUERY_TYPES[language]:
101 | return False
102 |
103 | query_file = cls._get_query_file_path(language, query_type)
104 | return query_file.exists()
105 |
106 | @classmethod
107 | def get_available_queries(cls) -> Dict[str, List[str]]:
108 | """
109 | Get all available queries organized by language.
110 |
111 | Returns:
112 | Dictionary mapping language -> list of available query types
113 | """
114 | from typing import cast
115 |
116 | available = {}
117 | for language_str in cls.LANGUAGE_QUERY_TYPES:
118 | language = cast(IDXSupportedLanguage, language_str)
119 | query_types = []
120 | for query_type in cls.LANGUAGE_QUERY_TYPES[language_str]:
121 | if cls.has_query(language, query_type):
122 | query_types.append(query_type)
123 | if query_types:
124 | available[language_str] = query_types
125 |
126 | return available
127 |
128 | @classmethod
129 | def _get_query_file_path(
130 | cls, language: IDXSupportedLanguage, query_type: str
131 | ) -> Path:
132 | """
133 | Get the file path for a specific query.
134 |
135 | Args:
136 | language: The language
137 | query_type: The type of query (e.g., 'tags', 'tests', 'fine_imports')
138 |
139 | Returns:
140 | Path to the query file
141 | """
142 | # Get the directory where this file is located
143 | current_file = Path(__file__)
144 | # Navigate to the queries directory: ../../../queries/treesitter/
145 | queries_dir = current_file.parent.parent.parent / "queries" / "treesitter"
146 |
147 | # Construct the path: queries/treesitter/{query_type}/{language}.scm
148 | query_file_path = queries_dir / query_type / f"{language}.scm"
149 |
150 | return query_file_path
151 |
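The same lookup pattern applies here; all methods are classmethods that read `.scm` files from the `queries/treesitter/` tree, so a minimal sketch needs nothing beyond the import:

```python
from coderetrx.static.codebase.parsers.treesitter.queries import TreeSitterQueryTemplates

# Languages with at least one query file present on disk.
print(TreeSitterQueryTemplates.get_supported_languages())

# get_query returns the raw query text; the query type defaults to "tags".
tags_query = TreeSitterQueryTemplates.get_query("python")
imports_query = TreeSitterQueryTemplates.get_query("python", "fine_imports")

# has_query never raises, so use it to probe optional query types
# (e.g. "tests" is not registered for C, so this branch is skipped).
if TreeSitterQueryTemplates.has_query("c", "tests"):
    tests_query = TreeSitterQueryTemplates.get_query("c", "tests")
```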
--------------------------------------------------------------------------------
/USAGE.md:
--------------------------------------------------------------------------------
1 | # Usage Guide
2 |
3 | ## Programmatic API
4 |
5 | CodeRetrX provides a powerful programmatic interface through the `coderetrx_filter` and `llm_traversal_filter` APIs, which enable flexible code analysis and retrieval across different search strategies and filtering modes.
6 | 
7 | The coderetrx_filter implements retrieval strategies designed by us. It offers a cost-effective semantic recall approach for large-scale repositories, achieving approximately 90% recall at only about 25% of the resource consumption in practical tests (line_per_symbol strategy); the larger the repository, the greater the savings.
8 | 
9 | The llm_traversal_filter provides the most comprehensive and accurate analysis, making it ideal for establishing ground truth. For small-scale repositories, it can also be a good choice.
10 |
11 | ### Using coderetrx_filter with symbol_name strategy
12 |
13 | ```python
14 | from pathlib import Path
15 | from coderetrx.retrieval import coderetrx_filter, CodebaseFactory
16 |
17 | # Initialize codebase
18 | codebase = CodebaseFactory.new("repo_name", Path("/path/to/your/repo"))
19 |
20 | # Basic symbol search
21 | elements, llm_results = await coderetrx_filter(
22 | codebase=codebase,
23 | prompt="your_filter_prompt",
24 | subdirs_or_files=["src/"],
25 | target_type="symbol_content",
26 | coarse_recall_strategy="symbol_name"
27 | )
28 |
29 | # Process results
30 | for element in elements:
31 | print(f"Found: {element.name} in {element.file.path}")
32 | ```
33 |
34 | ### Using coderetrx_filter with line_per_symbol strategy
35 |
36 | ```python
37 | from coderetrx.retrieval import coderetrx_filter
38 | from coderetrx.retrieval.code_recall import CodeRecallSettings
39 |
40 | # Configure advanced settings
41 | settings = CodeRecallSettings(
42 | llm_primary_recall_model_id="google/gemini-2.5-flash-lite-preview-06-17",
43 | llm_secondary_recall_model_id="openai/gpt-4o-mini"
44 | )
45 |
46 | # Cost-efficient and complete search with line_per_symbol mode and secondary recall
47 | elements, llm_results = await coderetrx_filter(
48 | codebase=codebase,
49 | prompt="your_filter_prompt",
50 | subdirs_or_files=["src/", "lib/"],
51 | target_type="symbol_content",
52 | coarse_recall_strategy="line_per_symbol",
53 | settings=settings,
54 | enable_secondary_recall=True
55 | )
56 | ```
57 |
58 | ### Using llm_traversal_filter (Ground Truth & Maximum Accuracy)
59 |
60 | ```python
61 | from coderetrx.retrieval import llm_traversal_filter
62 |
63 | # Ground truth search - most comprehensive and accurate
64 | elements, llm_results = await llm_traversal_filter(
65 | codebase=codebase,
66 | prompt="your_filter_prompt",
67 | subdirs_or_files=["src/", "lib/"],
68 | target_type="symbol_content",
69 | settings=settings
70 | )
71 | ```
72 |
73 | ## Coarse Recall Strategy
74 |
75 | **coderetrx_filter** supports `file_name`, `symbol_name`, `line_per_symbol`, and `auto` strategies for different speed/accuracy tradeoffs. **llm_traversal_filter** uses full LLM processing for maximum accuracy.
76 |
77 | [See detailed strategy comparison in README.md](#-coarse-recall-strategies)
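As a quick illustration, the `auto` strategy defers the choice among the strategies above to an LLM selector; the call shape is otherwise identical to the earlier examples (the prompt and paths below are placeholders):

```python
elements, llm_results = await coderetrx_filter(
    codebase=codebase,
    prompt="your_filter_prompt",
    subdirs_or_files=["src/"],
    target_type="symbol_content",
    coarse_recall_strategy="auto",  # selection is delegated to the model set via llm_selector_strategy_model_id
)
```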
78 |
79 | ## Target Type Options
80 |
81 | Target type defines the retrieval target, i.e., the kind of code object to be recalled and returned. For example, if target_type is set to class_content, the result will include the classes whose content matches the query. The available target_type options are:
82 | 
83 | - **`symbol_name`**: Matches symbols (e.g., functions, classes) whose **name** satisfies the query.
84 | - **`symbol_content`**: Matches symbols whose **entire code content** satisfies the query.
85 | - **`leaf_symbol_name`**: Matches **leaf symbols** (symbols without child elements, such as methods) whose **name** satisfies the query.
86 | - **`leaf_symbol_content`**: Matches **leaf symbols** whose **code content** satisfies the query.
87 | - **`root_symbol_name`**: Matches **root symbols** (symbols without parent elements, such as top-level classes and functions) whose **name** satisfies the query.
88 | - **`root_symbol_content`**: Matches **root symbols** whose **entire code content** satisfies the query.
89 | - **`class_name`**: Matches **classes** whose **name** satisfies the query.
90 | - **`class_content`**: Matches **classes** whose **entire code content** satisfies the query.
91 | - **`function_name`**: Matches **functions** whose **name** satisfies the query.
92 | - **`function_content`**: Matches **functions** whose **entire code content** satisfies the query.
93 | - **`dependency_name`**: Matches **dependency names** (e.g., imported libraries or modules) that satisfy the query.
94 | 
95 | Note: coderetrx_filter only supports the `*_content` target types, while llm_traversal_filter supports all of the options above, as shown in the sketch below.
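For instance, the same prompt-style call can target whole classes or just function names (prompts and paths here are illustrative placeholders):

```python
# class_content: return classes whose entire body matches the prompt.
classes, _ = await llm_traversal_filter(
    codebase=codebase,
    prompt="Classes that implement authentication or session handling",
    subdirs_or_files=["/"],
    target_type="class_content",
)

# function_name: only the names are judged, which is cheaper but coarser.
functions, _ = await llm_traversal_filter(
    codebase=codebase,
    prompt="Functions whose name suggests password hashing",
    subdirs_or_files=["/"],
    target_type="function_name",
)
```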
96 |
97 | ## Settings Configuration
98 |
99 | The `CodeRecallSettings` class allows fine-tuning of the search behavior:
100 |
101 | ```python
102 | settings = CodeRecallSettings(
103 | llm_primary_recall_model_id="...", # Model used for coarse recall and the primary recall in the refined stage.
104 | llm_secondary_recall_model_id="...", # Model used for secondary recall in the refined stage. If set (not None), secondary recall will be enabled.
105 | llm_selector_strategy_model_id="...", # Model used for strategy selection in "auto" mode during the coarse recall stage.
106 | llm_call_mode="function_call" # LLM call mode. If set to "function_call", the LLM will return results as a function call (recommended for models supporting this feature).
107 | # If set to "traditional", the LLM will return results in plain text format.
108 | )
109 | ```
110 |
111 | ## Working with Results
112 |
113 | ```python
114 | # Process different types of results
115 | for element in elements:
116 | if hasattr(element, 'name'): # Symbol
117 | print(f"Symbol: {element.name} in {element.file.path}")
118 | print(f"Content: {element.chunk.code()}")
119 | elif hasattr(element, 'path'): # File
120 | print(f"File: {element.path}")
121 | elif hasattr(element, 'content'): # Keyword
122 | print(f"Keyword: {element.content}")
123 |
124 | # Access LLM analysis results
125 | for result in llm_results:
126 | print(f"Index: {result.index}")
127 | print(f"Reason: {result.reason}")
128 | print(f"Result: {result.result}")
129 | ```
130 |
131 |
--------------------------------------------------------------------------------
/bench/queries.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "id": "3b866a1d-8de9-4999-bbf3-651c1eb6ec5e",
4 | "name": "Unsafe Evaluation Functions",
5 | "filter_prompt": "The code snippet contains a function call that dynamically executes code or system commands. Examples include Python's `eval()`, `exec()`, or functions like `os.system()`, `subprocess.run()` (especially with `shell=True`), `subprocess.call()` (with `shell=True`), or `popen()`. The critical feature is that the string representing the code or command to be executed is not a hardcoded literal; instead, it's derived from a variable, function argument, string concatenation/formatting, or an external source such as user input, network request, or LLM output.",
6 | "subdirs_or_files": [
7 | "/"
8 | ]
9 | },
10 | {
11 | "id": "b2565e7b-dacc-44e5-9817-a31898326bc1",
12 | "name": "Pickle/Cloudpickle Deserialization",
13 | "filter_prompt": "This code snippet deserializes data using functions from Python's `pickle` or `cloudpickle` libraries, such as `load()` or `loads()`. The input data for the deserialization operation is not a hardcoded literal.",
14 | "subdirs_or_files": [
15 | "/"
16 | ]
17 | },
18 | {
19 | "id": "a8a337fc-4a32-4bf9-b98c-4b30a666352a",
20 | "name": "Magic Bytes File Type Manipulation",
21 | "filter_prompt": "This code snippet implements logic to determine or validate a file's type by reading and analyzing its initial bytes (e.g., magic bytes, file signature, or header). This is often part of a file upload handling mechanism or file processing pipeline where verifying the actual content type based on its leading bytes is critical.",
22 | "subdirs_or_files": [
23 | "/"
24 | ]
25 | },
26 | {
27 | "id": "9011537d-3f3a-4587-9500-068c453194f9",
28 | "name": "Shell Command Execution",
29 | "filter_prompt": "This code snippet executes a shell command, system command, an external program, or evaluates a string as code. This is often done using functions like `os.system`, `subprocess.call`, `subprocess.run` (especially with `shell=True`), `subprocess.Popen` (especially with `shell=True`), `commands.getoutput`, `Runtime.getRuntime().exec`, `ProcessBuilder`, `php.system`, `php.exec`, `php.shell_exec`, `php.passthru`, `php.popen`, PHP backticks (` `), `Node.child_process.exec`, `Node.child_process.execSync`, `eval`, `exec`, `ScriptEngine.eval()`, `execCommand`, `Perl.system`, `Ruby.system`, `Ruby.exec`, Ruby backticks (``), `Go.os/exec.Command`, etc. The command string, arguments to the command, or the string being evaluated as code, are derived from variables, function parameters, or other dynamic sources, rather than being solely hardcoded string literals.",
30 | "subdirs_or_files": [
31 | "/"
32 | ]
33 | },
34 | {
35 | "id": "ab624bd9-92e7-4f34-a2b4-02a914a78040",
36 | "name": "Command Injection in CLI Applications",
37 | "filter_prompt": "This code snippet executes operating system commands using functions like `os.system`, `subprocess.run`, `subprocess.Popen`, `subprocess.call`, `subprocess.check_output`, `commands.getoutput`, or `pty.spawn`. The command being executed is dynamically constructed using string operations (e.g., concatenation, f-strings, `.format()`) with variables that could hold data from external sources like command-line arguments or file content. Prioritize instances where `subprocess` functions are used with `shell=True` or where command components are assembled from non-literal string variables.",
38 | "subdirs_or_files": [
39 | "/"
40 | ]
41 | },
42 | {
43 | "id": "1f7cc885-ad42-42a7-a051-0bb5cd97374f",
44 | "name": "Other Deserialization Mechanisms",
45 | "filter_prompt": "This code snippet performs deserialization of data using PyTorch's `torch.load()` (or similar model loading functions in AI/ML frameworks), Python's `shelve` module (e.g., `shelve.open()`, `shelf[key]`), or JDBC connection mechanisms (e.g., constructing connection URLs or using drivers). The deserialization is flagged if the input data (such as a model file path or content, data from a shelve file, or components of a JDBC URL) is not a hardcoded literal and could originate from an untrusted external source.",
46 | "subdirs_or_files": [
47 | "/"
48 | ]
49 | },
50 | {
51 | "id": "dbfd2d19-9e1a-4c1a-b1f1-79318deebb06",
52 | "name": "Path Traversal and File Operations",
53 | "filter_prompt": "Locate code snippets that perform file system operations (such as reading, writing, deleting, moving files or directories, extracting archives, or including files) or use file paths or names within system commands. Focus on cases where these file paths or names are derived from, or can be influenced by, external sources (e.g., user input, network data, API parameters, environment variables, or function arguments traceable to such sources) and where there is a potential lack of, or insufficient, sanitization or validation against path traversal techniques (e.g., sequences like '..', absolute paths, symbolic links, null bytes, or encoding tricks).",
54 | "subdirs_or_files": [
55 | "/"
56 | ]
57 | },
58 | {
59 | "id": "bd1ec518-2837-40a2-ab2d-fc38099b3686",
60 | "name": "Arbitrary File Write",
61 | "filter_prompt": "This code snippet involves a file system write operation (such as creating, writing to, or moving a file). The destination path, filename, or the content of the file appears to be constructed or influenced by data originating from an external source (e.g., user input, API request parameters, network data, configuration files, environment variables) and there is no clear evidence of robust sanitization, validation, or restriction of the path to a predefined safe directory.",
62 | "subdirs_or_files": [
63 | "/"
64 | ]
65 | },
66 | {
67 | "id": "50794dd5-45fd-43e4-afd3-37359ef5b076",
68 | "name": "Bypass of File Extension Restrictions",
69 | "filter_prompt": "This code snippet is involved in processing files uploaded by users. This includes operations such as retrieving the original filename or file extension, determining the destination path or filename for storage, moving or saving the uploaded file to the server's filesystem, and/or implementing validation rules to restrict allowed file types. These validation rules might be based on the file's extension (e.g., checking against a list of permitted or forbidden extensions like '.php', '.jsp', '.asp', '.exe', '.gif', '.jpg'), its MIME type, or its initial bytes (magic bytes/file signatures).",
70 | "subdirs_or_files": [
71 | "/"
72 | ]
73 | }
74 | ]
--------------------------------------------------------------------------------
/coderetrx/static/ripgrep/installer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import sys
4 | import asyncio
5 | import shutil
6 | import tarfile
7 | from tempfile import TemporaryDirectory
8 | import zipfile
9 | from pathlib import Path
10 | from typing import Optional, Tuple, Dict
11 |
12 | import httpx
13 |
14 | # Ripgrep release information
15 | RG_VERSION = "14.1.1"
16 | RG_REPO = "BurntSushi/ripgrep"
17 | RG_BASE_URL = f"https://github.com/{RG_REPO}/releases/download/{RG_VERSION}"
18 |
19 |
20 | def get_platform_info() -> Optional[str]:
21 | """Get platform information and return the corresponding download filename
22 |
23 | Returns:
24 | Optional[str]: Matching platform filename or None if unsupported
25 | """
26 | system = platform.system().lower()
27 | machine = platform.machine().lower()
28 |
29 | # Map platform to ripgrep release filename
30 | platform_map: Dict[Tuple[str, str], str] = {
31 | ("darwin", "x86_64"): f"ripgrep-{RG_VERSION}-x86_64-apple-darwin.tar.gz",
32 | ("darwin", "arm64"): f"ripgrep-{RG_VERSION}-aarch64-apple-darwin.tar.gz",
33 | ("linux", "x86_64"): f"ripgrep-{RG_VERSION}-x86_64-unknown-linux-musl.tar.gz",
34 | ("linux", "i686"): f"ripgrep-{RG_VERSION}-i686-unknown-linux-gnu.tar.gz",
35 | ("linux", "aarch64"): f"ripgrep-{RG_VERSION}-aarch64-unknown-linux-gnu.tar.gz",
36 | (
37 | "linux",
38 | "armv7l",
39 | ): f"ripgrep-{RG_VERSION}-armv7-unknown-linux-gnueabihf.tar.gz",
40 | ("windows", "amd64"): f"ripgrep-{RG_VERSION}-x86_64-pc-windows-msvc.zip",
41 | ("windows", "x86"): f"ripgrep-{RG_VERSION}-i686-pc-windows-msvc.zip",
42 | }
43 |
44 | return platform_map.get((system, machine))
45 |
46 |
47 | async def download_file(url: str, dest_path: Path) -> None:
48 | """Download a file from a URL to a destination path."""
49 | async with httpx.AsyncClient(follow_redirects=True) as client:
50 | async with client.stream("GET", url) as response:
51 | response.raise_for_status()
52 | with open(dest_path, "wb") as f:
53 | async for chunk in response.aiter_bytes():
54 | f.write(chunk)
55 |
56 |
57 | def extract_single_file(archive_path: Path, filename: str, dest: Path) -> bool:
58 | """Extract a single file from archive without extracting everything
59 |
60 | Args:
61 | archive_path: Path to the archive file
62 | filename: Name of the file to extract
63 | dest: Destination path for the extracted file
64 |
65 | Returns:
66 | bool: True if extraction succeeded, False otherwise
67 | """
68 | try:
69 | if (
70 | archive_path.suffixes[-2:] == [".tar", ".gz"]
71 |             or archive_path.suffix == ".tgz"
72 | ):
73 | with tarfile.open(archive_path, "r:gz") as tar:
74 | member = next(
75 | (m for m in tar.getmembers() if m.name.endswith(f"/{filename}")),
76 | None,
77 | )
78 | if member:
79 | extracted_file = tar.extractfile(member)
80 | if extracted_file is None:
81 | return False
82 | with extracted_file as src, open(dest, "wb") as dst:
83 | shutil.copyfileobj(src, dst)
84 | return True
85 |         elif archive_path.suffix == ".zip":
86 | with zipfile.ZipFile(archive_path, "r") as zip_ref:
87 | member = next(
88 | (m for m in zip_ref.infolist() if m.filename.endswith(filename)),
89 | None,
90 | )
91 | if member:
92 | with zip_ref.open(member) as src, open(dest, "wb") as dst:
93 | shutil.copyfileobj(src, dst)
94 | return True
95 | return False
96 | except Exception as e:
97 | print(f"Error extracting {filename}: {e}")
98 | return False
99 |
100 |
101 | async def install_rg(install_path: Path) -> Optional[Path]:
102 | """Main installation function
103 |
104 | Returns:
105 | Optional[Path]: Path to the installed rg binary or None if failed
106 | """
107 | # Get appropriate ripgrep filename for current platform
108 | rg_file = get_platform_info()
109 | if not rg_file:
110 | print("Error: Unsupported platform")
111 | print(f"System: {platform.system()}, Machine: {platform.machine()}")
112 | return None
113 |
114 | install_path.mkdir(parents=True, exist_ok=True)
115 | with TemporaryDirectory() as temp_dir: # ty: ignore[no-matching-overload]
116 | try:
117 | # Download the file
118 | download_url = f"{RG_BASE_URL}/{rg_file}"
119 | archive_path = Path(temp_dir) / rg_file
120 | await download_file(download_url, archive_path)
121 |
122 | # Determine binary filename based on platform
123 | binary_name = "rg.exe" if os.name == "nt" else "rg"
124 | temp_binary = Path(temp_dir) / binary_name
125 |
126 | # Extract just the binary file
127 | if not extract_single_file(archive_path, binary_name, temp_binary):
128 | print(f"Error: Could not find {binary_name} in the downloaded archive")
129 | return None
130 |
131 | # Set executable permissions (Unix-like systems)
132 | if os.name != "nt":
133 | os.chmod(temp_binary, 0o755)
134 |
135 | # Install to destination
136 | dest_path = install_path / binary_name
137 |
138 | # Move the binary to final location
139 | shutil.move(temp_binary, dest_path)
140 |
141 | print("Installation complete!")
142 | print(f"ripgrep has been installed to: {dest_path}")
143 |             print(f"Installation directory: {install_path}")
144 |
145 | return dest_path
146 |
147 | except httpx.HTTPError as e:
148 | print(f"Download failed: {e}")
149 | except Exception as e:
150 | print(f"An unexpected error occurred: {e}")
151 |
152 | return None
153 |
154 |
155 | if __name__ == "__main__":
156 | from pathlib import Path
157 |
158 | if len(sys.argv) < 2:
159 | path = Path(".")
160 | else:
161 | path = Path(sys.argv[1])
162 | installed_path = asyncio.run(install_rg(path))
163 | if installed_path:
164 | print(f"Successfully installed rg at: {installed_path}")
165 | else:
166 | print("Failed to install ripgrep")
167 | sys.exit(1)
168 |
--------------------------------------------------------------------------------
/coderetrx/retrieval/topic_extractor.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any, List, Optional, Literal
2 | import logging
3 | import json
4 | from pydantic import BaseModel
5 |
6 | from .prompt import (
7 | KeywordExtractorResult,
8 | topic_extraction_prompt_template,
9 | topic_extraction_function_call_system_prompt,
10 | get_topic_extraction_function_definition, topic_extraction_function_call_user_prompt_template,
11 | )
12 | from .smart_codebase import SmartCodebaseSettings
13 | from coderetrx.utils.llm import call_llm_with_fallback, call_llm_with_function_call
14 | from .smart_codebase import (
15 | SmartCodebase,
16 | SimilaritySearchTargetType,
17 | LLMCallMode,
18 | )
19 | import os
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | class TopicExtractor:
25 | """
26 | A class to extract topics from input text using LLM before performing vector similarity searching.
27 | """
28 | async def extract_topic(self, input_text: str, llm_call_mode: Literal["traditional", "function_call"] = "traditional") -> Optional[str]:
29 | """
30 | Extract the core topic from the input text using LLM.
31 |
32 | Args:
33 | input_text: The input text to extract topic from
34 | llm_call_mode: Whether to use traditional prompt-based extraction or function call mode.
35 |
36 | Returns:
37 | The extracted topic as a string, or None if extraction fails
38 | """
39 | try:
40 | if llm_call_mode == "function_call":
41 | return await self._extract_topic_with_function_call(input_text)
42 | else:
43 | return await self._extract_topic_traditional(input_text)
44 | except Exception as e:
45 | logger.error(f"Error extracting topic: {str(e)}", exc_info=True)
46 | return None
47 |
48 | async def _extract_topic_traditional(self, input_text: str) -> Optional[str]:
49 | """Extract topic using traditional prompt-based approach."""
50 | try:
51 | # Prepare input data for the prompt template
52 | input_data = {"input": input_text}
53 |
54 | # Call LLM with the topic extraction prompt template
55 | # Use topic extraction model_id from settings
56 | settings = SmartCodebaseSettings()
57 | topic_model_id = settings.llm_topic_extraction_model_id or settings.default_model_id
58 | model_ids = [topic_model_id]
59 | result = await call_llm_with_fallback(
60 | response_model=KeywordExtractorResult,
61 | input_data=input_data,
62 | prompt_template=topic_extraction_prompt_template,
63 | model_ids=model_ids,
64 | )
65 |
66 | assert isinstance(result, KeywordExtractorResult)
67 | # Extract the topic from the result
68 | extracted_topic = result.result
69 | logger.info(f"Successfully extracted topic: '{extracted_topic}' from input using traditional mode")
70 | return extracted_topic
71 |
72 | except Exception as e:
73 | logger.error(f"Error in traditional topic extraction: {str(e)}", exc_info=True)
74 | return None
75 |
76 | async def _extract_topic_with_function_call(self, input_text: str) -> Optional[str]:
77 | """Extract topic using function call approach."""
78 | try:
79 | # Prepare prompts for function call
80 | system_prompt = topic_extraction_function_call_system_prompt
81 | user_prompt = topic_extraction_function_call_user_prompt_template.format(
82 | input=input_text,
83 | )
84 | function_definition = get_topic_extraction_function_definition()
85 |
86 | # Call LLM with function call
87 | # Use topic extraction model_id from settings
88 | settings = SmartCodebaseSettings()
89 | topic_model_id = settings.llm_topic_extraction_model_id or settings.default_model_id
90 | model_ids = [topic_model_id]
91 | function_args = await call_llm_with_function_call(
92 | system_prompt=system_prompt,
93 | user_prompt=user_prompt,
94 | function_definition=function_definition,
95 | model_ids=model_ids,
96 | )
97 |
98 | # Extract topic from function call result
99 | extracted_topic = function_args.get("result")
100 | reason = function_args.get("reason", "")
101 |
102 | if extracted_topic:
103 | logger.info(f"Successfully extracted topic: '{extracted_topic}' from input using function call mode. Reason: {reason}")
104 | return extracted_topic
105 | else:
106 | logger.warning("Function call returned empty result for topic extraction")
107 | return None
108 |
109 | except Exception as e:
110 | logger.error(f"Error in function call topic extraction: {str(e)}", exc_info=True)
111 | return None
112 |
113 | async def extract_and_search(
114 | self,
115 | codebase: SmartCodebase,
116 | input_text: str,
117 | target_types: List[SimilaritySearchTargetType],
118 | threshold: float = 0.1,
119 | top_k: int = 100,
120 | llm_call_mode: Literal["traditional", "function_call"] = "traditional",
121 | ) -> List[Any]:
122 | """
123 | Extract topic from input text and use it for vector similarity search.
124 |
125 | Args:
126 | codebase: The codebase to search in
127 | input_text: The input text to extract topic from
128 | target_types: List of target types for similarity search
129 | threshold: Similarity threshold
130 | top_k: Number of top results to return
131 | llm_call_mode: Whether to use traditional prompt-based extraction or function call mode.
132 |
133 | Returns:
134 | List of search results
135 | """
136 | # Extract topic from input text
137 | topic = await self.extract_topic(input_text, llm_call_mode)
138 |
139 | if not topic:
140 | logger.warning(
141 | "Using original input text for search as topic extraction failed"
142 | )
143 | topic = input_text
144 |
145 | logger.info(f"Performing similarity search with topic: '{topic}'")
146 |
147 | # Perform similarity search using the extracted topic
148 | results = await codebase.similarity_search(
149 | target_types=target_types, query=topic, threshold=threshold, top_k=top_k
150 | )
151 |
152 | logger.info(f"Found {len(results)} results for topic '{topic}'")
153 | return results
154 |
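A minimal sketch of driving the extractor directly; model choice comes from `SmartCodebaseSettings`, so the usual LLM environment configuration is assumed, and `extract_and_search` additionally requires a `SmartCodebase` instance:

```python
import asyncio

from coderetrx.retrieval.topic_extractor import TopicExtractor


async def main() -> None:
    extractor = TopicExtractor()
    # Distill a verbose filter prompt into a short topic suitable for vector search.
    topic = await extractor.extract_topic(
        "Find code that builds shell commands from user-controlled input",
        llm_call_mode="function_call",
    )
    print(topic)  # None if the LLM call failed; callers should fall back to the raw input


asyncio.run(main())
```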
--------------------------------------------------------------------------------
/test/impl/default/test_code_recall.py:
--------------------------------------------------------------------------------
1 | from dotenv import load_dotenv
2 | load_dotenv()
3 | import json
4 | from pathlib import Path
5 | from coderetrx.impl.default import CodebaseFactory
6 | from coderetrx.impl.default import TopicExtractor
7 | from coderetrx.retrieval.code_recall import multi_strategy_code_filter, multi_strategy_code_mapping
8 | import os
9 | import asyncio
10 | import unittest
11 | from typing import Literal
12 | from unittest.mock import patch, MagicMock
13 | from coderetrx.utils.embedding import create_documents_embedding
14 | from coderetrx.utils.git import clone_repo_if_not_exists, get_repo_id
15 | from coderetrx.utils.path import get_data_dir
16 | import logging
17 |
18 | logger = logging.getLogger(__name__)
19 | logging.basicConfig(level=logging.INFO)
20 |
21 | TEST_REPOS = ["https://github.com/apache/dubbo-admin.git"]
22 |
23 |
24 | def prepare_codebase(repo_url: str, repo_path: Path):
25 | """Helper function to prepare codebase for testing"""
26 | database_path = get_data_dir() / "databases" / f"{get_repo_id(repo_url)}.json"
27 | # Create a test codebase
28 | clone_repo_if_not_exists(repo_url, str(repo_path))
29 |
30 | if database_path.exists():
31 | codebase = CodebaseFactory.from_json(
32 | json.load(open(database_path, "r", encoding="utf-8"))
33 | )
34 | else:
35 | codebase = CodebaseFactory.new(get_repo_id(repo_url), repo_path)
36 |         with open(database_path, "w") as f:
37 | json.dump(codebase.to_json(), f, indent=4)
38 | return codebase
39 |
40 |
41 | class TestLLMCodeFilterTool(unittest.TestCase):
42 | """Test LLMCodeFilterTool functionality"""
43 |
44 | def setUp(self):
45 | """Set up test environment"""
46 | self.repo_url = TEST_REPOS[0]
47 | self.repo_path = get_data_dir() / "repos" / get_repo_id(self.repo_url)
48 | self.codebase = prepare_codebase(self.repo_url, self.repo_path)
49 | self.test_dir = "/"
50 | self.test_prompt = "Is the code snippet used for user authentication?"
51 | self.topic_extractor = TopicExtractor()
52 |
53 | async def recall_with_mode(self, mode: Literal["fast", "balance", "precise", "custom"]):
54 | """Helper method to run the tool with a specific mode"""
55 | result, llm_output = await multi_strategy_code_filter(
56 | codebase=self.codebase,
57 | subdirs_or_files=[self.test_dir],
58 | prompt=self.test_prompt,
59 | target_type="symbol_content",
60 | mode=mode,
61 | topic_extractor=self.topic_extractor,
62 | )
63 | return result
64 |
65 | def test_initialization(self):
66 | """Test initialization of test environment"""
67 | self.assertIsNotNone(self.codebase)
68 | self.assertIsNotNone(self.test_dir)
69 |
70 | def test_fast_mode(self):
71 | """Test LLMCodeFilterTool in fast mode"""
72 | result = asyncio.run(self.recall_with_mode("fast"))
73 |
74 | # Verify results
75 | self.assertIsNotNone(result)
76 | self.assertIsInstance(result, list)
77 | logger.info(f"Fast mode results count: {len(result)}")
78 | if result:
79 | logger.info(f"Sample result: {result[0]}")
80 |
81 | def test_balance_mode(self):
82 | """Test LLMCodeFilterTool in balance mode"""
83 | result = asyncio.run(self.recall_with_mode("balance"))
84 |
85 | # Verify results
86 | self.assertIsNotNone(result)
87 | self.assertIsInstance(result, list)
88 | logger.info(f"Balance mode results count: {len(result)}")
89 | if result:
90 | logger.info(f"Sample result: {result[0]}")
91 |
92 | def test_precise_mode(self):
93 | """Test LLMCodeFilterTool in precise mode"""
94 | result = asyncio.run(self.recall_with_mode("precise"))
95 |
96 | # Verify results
97 | self.assertIsNotNone(result)
98 | self.assertIsInstance(result, list)
99 | logger.info(f"Precise mode results count: {len(result)}")
100 | if result:
101 | logger.info(f"Sample result: {result[0]}")
102 |
103 |
104 | class TestLLMCodeMappingTool(unittest.TestCase):
105 | """Test multi_strategy_code_mapping functionality"""
106 |
107 | def setUp(self):
108 | """Set up test environment"""
109 | self.repo_url = TEST_REPOS[0]
110 | self.repo_path = get_data_dir() / "repos" / get_repo_id(self.repo_url)
111 | self.codebase = prepare_codebase(self.repo_url, self.repo_path)
112 | self.test_dir = "/"
113 | self.test_prompt = "Extract the function call sink that may result in arbitrary code execution"
114 |
115 | async def recall_with_mode(self, mode: Literal["fast", "balance", "precise", "custom"]):
116 | """Helper method to run the tool with a specific mode"""
117 | result, llm_output = await multi_strategy_code_mapping(
118 | codebase=self.codebase,
119 | subdirs_or_files=[self.test_dir],
120 | prompt=self.test_prompt,
121 | target_type="symbol_content",
122 | mode=mode,
123 | )
124 | return result
125 |
126 | def test_initialization(self):
127 | """Test initialization of test environment"""
128 | self.assertIsNotNone(self.codebase)
129 | self.assertIsNotNone(self.test_dir)
130 |
131 | def test_fast_mode(self):
132 | """Test LLMCodeMappingTool in fast mode"""
133 | result = asyncio.run(self.recall_with_mode("fast"))
134 |
135 | # Verify results
136 | self.assertIsNotNone(result)
137 | self.assertIsInstance(result, list)
138 | logger.info(f"Fast mode results count: {len(result)}")
139 | if result:
140 | logger.info(f"Sample result: {result[0]}")
141 |
142 | def test_balance_mode(self):
143 | """Test LLMCodeMappingTool in balance mode"""
144 | result = asyncio.run(self.recall_with_mode("balance"))
145 |
146 | # Verify results
147 | self.assertIsNotNone(result)
148 | self.assertIsInstance(result, list)
149 | logger.info(f"Balance mode results count: {len(result)}")
150 | if result:
151 | logger.info(f"Sample result: {result[0]}")
152 |
153 | def test_precise_mode(self):
154 | """Test LLMCodeMappingTool in precise mode"""
155 | result = asyncio.run(self.recall_with_mode("precise"))
156 |
157 | # Verify results
158 | self.assertIsNotNone(result)
159 | self.assertIsInstance(result, list)
160 | logger.info(f"Precise mode results count: {len(result)}")
161 | if result:
162 | logger.info(f"Sample result: {result[0]}")
163 |
164 |
165 | # Run tests if specified
166 | if __name__ == "__main__":
167 | # Run unittest tests
168 | unittest.main(argv=['first-arg-is-ignored'], exit=False)
169 |
--------------------------------------------------------------------------------
/coderetrx/static/codeql/installer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import sys
4 | import asyncio
5 | import shutil
6 | import tarfile
7 | from tempfile import TemporaryDirectory
8 | from pathlib import Path
9 | from typing import Optional, Tuple, Dict
10 |
11 | import httpx
12 |
13 | # CodeQL release information
14 | CODEQL_VERSION = "v2.23.1"
15 | CODEQL_REPO = "github/codeql-action"
16 | CODEQL_BASE_URL = (
17 | f"https://github.com/{CODEQL_REPO}/releases/download/codeql-bundle-{CODEQL_VERSION}"
18 | )
19 |
20 |
21 | def get_platform_info() -> Optional[str]:
22 | """Get platform information and return the corresponding download filename
23 |
24 | Returns:
25 | Optional[str]: Matching platform filename or None if unsupported
26 | """
27 | system = platform.system().lower()
28 |
29 | # Map platform to CodeQL release filename
30 | platform_map: Dict[str, str] = {
31 |         "darwin": "codeql-bundle-osx64.tar.gz",
32 |         "linux": "codeql-bundle-linux64.tar.gz",
33 | }
34 |
35 | return platform_map.get(system)
36 |
37 |
38 | async def download_file(url: str, dest_path: Path) -> None:
39 | """Download a file from a URL to a destination path."""
40 | async with httpx.AsyncClient(follow_redirects=True) as client:
41 | async with client.stream("GET", url) as response:
42 | response.raise_for_status()
43 | with open(dest_path, "wb") as f:
44 | async for chunk in response.aiter_bytes():
45 | f.write(chunk)
46 |
47 |
48 | def extract_codeql_bundle(archive_path: Path, dest_dir: Path) -> bool:
49 | """Extract the entire CodeQL bundle to destination directory
50 |
51 | Args:
52 | archive_path: Path to the archive file
53 | dest_dir: Destination directory for extraction
54 |
55 | Returns:
56 | bool: True if extraction succeeded, False otherwise
57 | """
58 | try:
59 | with tarfile.open(archive_path, "r:gz") as tar:
60 | # Extract all contents
61 | tar.extractall(path=dest_dir)
62 | return True
63 | except Exception as e:
64 | print(f"Error extracting CodeQL bundle: {e}")
65 | return False
66 |
67 |
68 | async def install_codeql(install_path: Path) -> Optional[Path]:
69 | """Main installation function
70 |
71 | Args:
72 | install_path: Path where CodeQL should be installed (e.g., /opt/codeql)
73 |
74 | Returns:
75 | Optional[Path]: Path to the installed CodeQL CLI binary or None if failed
76 | """
77 | # Get appropriate CodeQL filename for current platform
78 | codeql_file = get_platform_info()
79 | if not codeql_file:
80 | print("Error: Unsupported platform")
81 | print(f"System: {platform.system()}")
82 | print("CodeQL installer only supports Linux and macOS")
83 | return None
84 |
85 | # Create install directory if it doesn't exist
86 | install_path.mkdir(parents=True, exist_ok=True)
87 |
88 | with TemporaryDirectory() as temp_dir:
89 | try:
90 | # Download the file
91 | download_url = f"{CODEQL_BASE_URL}/{codeql_file}"
92 | archive_path = Path(temp_dir) / codeql_file
93 |
94 | print(f"Downloading CodeQL from: {download_url}")
95 | await download_file(download_url, archive_path)
96 | print("Download completed")
97 |
98 | # Extract the bundle to temp directory
99 | temp_extract_dir = Path(temp_dir) / "extracted"
100 | temp_extract_dir.mkdir()
101 |
102 | print("Extracting CodeQL bundle...")
103 | if not extract_codeql_bundle(archive_path, temp_extract_dir):
104 | print("Error: Could not extract CodeQL bundle")
105 | return None
106 |
107 | # Find the codeql directory in extracted contents
108 | # The bundle typically extracts to a 'codeql' directory
109 | codeql_dir = temp_extract_dir / "codeql"
110 | if not codeql_dir.exists():
111 | # Look for any directory that might contain CodeQL
112 | extracted_dirs = [d for d in temp_extract_dir.iterdir() if d.is_dir()]
113 | if extracted_dirs:
114 | codeql_dir = extracted_dirs[0]
115 | else:
116 | print("Error: Could not find CodeQL directory in extracted bundle")
117 | return None
118 |
119 | # Verify that the codeql binary exists
120 | binary_name = "codeql.exe" if os.name == "nt" else "codeql"
121 | codeql_binary = codeql_dir / binary_name
122 | if not codeql_binary.exists():
123 | print(f"Error: Could not find {binary_name} in the extracted bundle")
124 | return None
125 |
126 | # Remove existing installation if it exists
127 | if install_path.exists():
128 | print(f"Removing existing CodeQL installation at {install_path}")
129 | shutil.rmtree(install_path)
130 |
131 | # Move the entire codeql directory to the install location
132 | shutil.move(str(codeql_dir), str(install_path))
133 |
134 | # Set executable permissions for the binary (Unix-like systems)
135 | final_binary = install_path / binary_name
136 | if os.name != "nt":
137 | os.chmod(final_binary, 0o755)
138 |
139 | print("Installation complete!")
140 | print(f"CodeQL has been installed to: {install_path}")
141 | print(f"CodeQL CLI binary: {final_binary}")
142 | print(
143 | f"To use CodeQL, add {install_path} to your PATH or use the full path: {final_binary}"
144 | )
145 |
146 | return final_binary
147 |
148 | except httpx.HTTPError as e:
149 | print(f"Download failed: {e}")
150 | except Exception as e:
151 | print(f"An unexpected error occurred: {e}")
152 |
153 | return None
154 |
155 |
156 | if __name__ == "__main__":
157 | import argparse
158 |
159 | parser = argparse.ArgumentParser(description="Install CodeQL CLI")
160 | parser.add_argument(
161 | "--install-path",
162 | type=Path,
163 | default=Path("/opt/codeql"),
164 | help="Installation path for CodeQL (default: /opt/codeql)",
165 | )
166 |
167 | args = parser.parse_args()
168 |
169 | # Check if we have write permissions to the install path
170 | try:
171 | args.install_path.parent.mkdir(parents=True, exist_ok=True)
172 | if not os.access(args.install_path.parent, os.W_OK):
173 | print(f"Error: No write permission to {args.install_path.parent}")
174 | print(
175 | "You may need to run this script with sudo or choose a different install path"
176 | )
177 | sys.exit(1)
178 | except Exception as e:
179 | print(f"Error checking install path: {e}")
180 | sys.exit(1)
181 |
182 | installed_path = asyncio.run(install_codeql(args.install_path))
183 | if installed_path:
184 | print(f"Successfully installed CodeQL CLI at: {installed_path}")
185 | else:
186 | print("Failed to install CodeQL")
187 | sys.exit(1)
188 |
--------------------------------------------------------------------------------
/coderetrx/static/codebase/parsers/factory.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, Dict, Any, List
2 | import logging
3 | from pathlib import Path
4 |
5 | from .base import CodebaseParser, UnsupportedLanguageError
6 | from .treesitter import TreeSitterParser
7 | from .codeql import CodeQLParser
8 | from ..languages import IDXSupportedLanguage
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | class ParserFactory:
14 | """
15 | Factory for creating and managing codebase parsers.
16 |
17 | Supports auto-selection of parsers based on availability and preferences,
18 | with fallback mechanisms for robustness.
19 | """
20 |
21 | # Parser priority order for auto-selection
22 | PARSER_PRIORITY = ["treesitter", "codeql"]
23 |
24 | @classmethod
25 | def get_parser(cls, parser_type: str = "auto", **kwargs) -> CodebaseParser:
26 | """
27 | Get a parser instance.
28 |
29 | Args:
30 | parser_type: Type of parser to create:
31 | - "auto": Auto-select best available parser
32 | - "codeql": CodeQL parser
33 | - "treesitter": Tree-sitter parser
34 |                 - "hybrid": Use CodeQL where available, fall back to Tree-sitter (not yet implemented)
35 | **kwargs: Parser-specific configuration options
36 |
37 | Returns:
38 | CodebaseParser instance
39 |
40 | Raises:
41 | ValueError: If parser_type is invalid
42 | RuntimeError: If no suitable parser is available
43 | """
44 | if parser_type == "auto":
45 | return cls._auto_select_parser(**kwargs)
46 | elif parser_type == "codeql":
47 | return cls._create_codeql_parser(**kwargs)
48 | elif parser_type == "treesitter":
49 | return cls._create_treesitter_parser(**kwargs)
50 | else:
51 | raise ValueError(f"Unknown parser type: {parser_type}")
52 |
53 | @classmethod
54 | def _auto_select_parser(cls, **kwargs) -> CodebaseParser:
55 | """
56 | Auto-select the best available parser.
57 |
58 | Tries parsers in priority order and returns the first working one.
59 | """
60 | for parser_name in cls.PARSER_PRIORITY:
61 | try:
62 | if parser_name == "codeql":
63 | parser = cls._create_codeql_parser(**kwargs)
64 | # Test if CodeQL is actually available
65 | if cls._test_codeql_availability(parser):
66 | logger.info("Auto-selected CodeQL parser")
67 | return parser
68 | else:
69 | logger.info("CodeQL CLI not available, trying next parser")
70 | continue
71 | elif parser_name == "treesitter":
72 | parser = cls._create_treesitter_parser(**kwargs)
73 | logger.info("Auto-selected Tree-sitter parser")
74 | return parser
75 | except Exception as e:
76 | logger.debug(f"Failed to create {parser_name} parser: {e}")
77 | continue
78 |
79 | raise RuntimeError("No suitable parser available")
80 |
81 | @classmethod
82 | def _create_codeql_parser(cls, **kwargs) -> CodeQLParser:
83 | """Create a CodeQL parser instance."""
84 | return CodeQLParser(**kwargs)
85 |
86 | @classmethod
87 | def _create_treesitter_parser(cls, **kwargs) -> TreeSitterParser:
88 | """Create a Tree-sitter parser instance."""
89 | return TreeSitterParser(**kwargs)
90 |
91 | @classmethod
92 | def _test_codeql_availability(cls, parser: CodeQLParser) -> bool:
93 | """
94 | Test if CodeQL is actually available and working.
95 |
96 | Args:
97 | parser: CodeQL parser instance to test
98 |
99 | Returns:
100 | True if CodeQL is available, False otherwise
101 | """
102 | try:
103 | # Test basic CodeQL functionality
104 | supported_langs = parser.get_supported_languages()
105 | return len(supported_langs) > 0
106 | except Exception as e:
107 | logger.debug(f"CodeQL availability test failed: {e}")
108 | return False
109 |
110 | @classmethod
111 | def get_available_parsers(cls) -> Dict[str, bool]:
112 | """
113 | Get status of all available parsers.
114 |
115 | Returns:
116 | Dictionary mapping parser names to availability status
117 | """
118 | status = {}
119 |
120 | # Test Tree-sitter
121 | try:
122 | parser = cls._create_treesitter_parser()
123 | status["treesitter"] = len(parser.get_supported_languages()) > 0
124 | except Exception:
125 | status["treesitter"] = False
126 |
127 | # Test CodeQL
128 | try:
129 | parser = cls._create_codeql_parser()
130 | status["codeql"] = cls._test_codeql_availability(parser)
131 | except Exception:
132 | status["codeql"] = False
133 |
134 | return status
135 |
136 | @classmethod
137 | def recommend_parser(
138 | cls, languages: Optional[List[IDXSupportedLanguage]] = None
139 | ) -> str:
140 | """
141 | Recommend the best parser for given languages.
142 |
143 | Args:
144 | languages: List of languages to support (None for all)
145 |
146 | Returns:
147 | Recommended parser name
148 | """
149 | available = cls.get_available_parsers()
150 |
151 | if not languages:
152 | # If no specific languages, prefer CodeQL if available
153 | if available.get("codeql", False):
154 | return "codeql"
155 | elif available.get("treesitter", False):
156 | return "treesitter"
157 | else:
158 | return "auto" # Let auto-selection handle the error
159 |
160 | # Check language support for each parser
161 | codeql_coverage = 0
162 | treesitter_coverage = 0
163 |
164 | if available.get("codeql", False):
165 | try:
166 | codeql_parser = cls._create_codeql_parser()
167 | codeql_supported = set(codeql_parser.get_supported_languages())
168 | codeql_coverage = len(set(languages).intersection(codeql_supported))
169 | except Exception:
170 | pass
171 |
172 | if available.get("treesitter", False):
173 | try:
174 | treesitter_parser = cls._create_treesitter_parser()
175 | treesitter_supported = set(treesitter_parser.get_supported_languages())
176 | treesitter_coverage = len(
177 | set(languages).intersection(treesitter_supported)
178 | )
179 | except Exception:
180 | pass
181 |
182 | # Prefer CodeQL if it covers more languages, otherwise tree-sitter
183 | if codeql_coverage > treesitter_coverage:
184 | return "codeql"
185 | elif treesitter_coverage > 0:
186 | return "treesitter"
187 | elif codeql_coverage > 0:
188 | return "codeql"
189 | else:
190 |             return "auto"  # Hybrid coverage is not implemented yet; fall back to auto-selection
191 |
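A minimal sketch of the factory in use; language names are plain strings at runtime, and anything other than "codeql" or "treesitter" is mapped back to "auto", since those are the values get_parser accepts:

```python
from coderetrx.static.codebase.parsers.factory import ParserFactory

# Which backends are usable in the current environment?
print(ParserFactory.get_available_parsers())  # e.g. {"treesitter": True, "codeql": False}

# Ask for a recommendation for a concrete language mix, then build the parser.
choice = ParserFactory.recommend_parser(["python", "go"])
parser = ParserFactory.get_parser(choice if choice in ("codeql", "treesitter") else "auto")
print(parser.get_supported_languages())
```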
--------------------------------------------------------------------------------