├── codebased ├── __init__.py ├── constants.py ├── GREETING.txt ├── exceptions.py ├── utils.py ├── editor.py ├── migrations │ └── 000_initial.sql ├── models.py ├── filesystem.py ├── embeddings.py ├── background_worker.py ├── storage.py ├── stats.py ├── settings.py ├── main.py ├── gitignore.py ├── tui.py ├── search.py ├── parser.py └── index.py ├── assets ├── editor.png ├── search.png └── empty_search.png ├── Makefile ├── .github └── workflows │ ├── push.yml │ └── release.yml ├── LICENSE ├── pyproject.toml ├── README.md ├── .gitignore └── tests ├── test_parser.py └── test_main.py /codebased/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/editor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codebased-sh/codebased/HEAD/assets/editor.png -------------------------------------------------------------------------------- /assets/search.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codebased-sh/codebased/HEAD/assets/search.png -------------------------------------------------------------------------------- /assets/empty_search.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codebased-sh/codebased/HEAD/assets/empty_search.png -------------------------------------------------------------------------------- /codebased/constants.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | DEFAULT_MODEL = 'text-embedding-3-large' 4 | DEFAULT_MODEL_DIMENSIONS = 256 5 | EMBEDDING_MODEL_CONTEXT_LENGTH = 8192 6 | DEFAULT_EDITOR = "vi" 7 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .gitignore: 2 | curl -s https://raw.githubusercontent.com/github/gitignore/main/Python.gitignore > .gitignore 3 | curl -s https://raw.githubusercontent.com/github/gitignore/main/Global/JetBrains.gitignore >> .gitignore 4 | 5 | package-build: 6 | poetry build 7 | 8 | package-publish: package-build 9 | poetry publish 10 | 11 | test: 12 | poetry run pytest -------------------------------------------------------------------------------- /codebased/GREETING.txt: -------------------------------------------------------------------------------- 1 | I'm Max Conradt, founder of Codebased. 2 | If you ever need any help, text me at +1 (913) 808-7343. 3 | You can also check out the Discord: https://discord.gg/cQrQCAKZ. 4 | The Discord server has an OpenAI key for early adopters to use: https://discord.com/channels/1276709641073590364/1276709641744814132/1282513891787800659. 5 | There's a limit of $10 in usage per day, but that should be sufficient for regular use, even on large projects. -------------------------------------------------------------------------------- /codebased/exceptions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | class CodebasedException(Exception): 4 | """ 5 | Differentiate between business logic exceptions and Knightian exceptions. 6 | """ 7 | pass 8 | 9 | 10 | class MissingConfigFileException(CodebasedException): 11 | """ 12 | Raised when the config directory is not found. 
13 | """ 14 | pass 15 | 16 | class NotFoundException(CodebasedException, LookupError): 17 | """ 18 | Raised when something is not found. 19 | """ 20 | 21 | def __init__(self, identifier: object): 22 | self.identifier = identifier 23 | super().__init__(identifier) 24 | -------------------------------------------------------------------------------- /codebased/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import chardet 3 | 4 | 5 | def decode_text(file_bytes: bytes) -> str: 6 | # Try UTF-8 first 7 | try: 8 | return file_bytes.decode('utf-8') 9 | except UnicodeDecodeError: 10 | pass 11 | 12 | # Use chardet to detect encoding 13 | detected = chardet.detect(file_bytes) 14 | if detected['encoding']: 15 | try: 16 | return file_bytes.decode(detected['encoding']) 17 | except UnicodeDecodeError: 18 | pass 19 | 20 | # If chardet fails, try some common encodings 21 | for encoding in ['windows-1252', 'iso-8859-1']: 22 | try: 23 | return file_bytes.decode(encoding) 24 | except UnicodeDecodeError: 25 | continue 26 | 27 | # If all else fails, use 'replace' error handling with UTF-8 28 | return file_bytes.decode('utf-8', errors='replace') 29 | -------------------------------------------------------------------------------- /codebased/editor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import subprocess 4 | from pathlib import Path 5 | from typing import Literal 6 | 7 | VSCODE_STYLE_EDITORS = {"code", "cursor"} 8 | VIM_STYLE_EDITORS = {"vi", "nvim", "vim"} 9 | 10 | 11 | def suspends(editor: Literal["vi", "idea", "code"]) -> bool: 12 | return editor in VIM_STYLE_EDITORS 13 | 14 | 15 | def open_editor(editor: Literal["vi", "idea", "code"], *, file: Path, row: int, column: int): 16 | line_number = row + 1 17 | if editor in VIM_STYLE_EDITORS: 18 | subprocess.run([editor, str(file), f"+{line_number}"]) 19 | elif editor == "idea": 20 | subprocess.run(["idea", "--line", str(line_number), str(file)]) 21 | elif editor in VSCODE_STYLE_EDITORS: 22 | subprocess.run([editor, "--goto", f"{file}:{line_number}:{column}"]) 23 | else: 24 | raise NotImplementedError(editor) 25 | 26 | 27 | Editor = Literal["vi", "vim", "nvim", "idea", "code", "cursor"] 28 | 29 | ALLOWED_EDITORS = {"idea", *VSCODE_STYLE_EDITORS, *VIM_STYLE_EDITORS} 30 | -------------------------------------------------------------------------------- /.github/workflows/push.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | strategy: 15 | fail-fast: false 16 | matrix: 17 | python-version: [ "3.9", "3.10", "3.11", "3.12" ] 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v3 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | curl -sSL https://install.python-poetry.org | python - 28 | poetry install --with=dev 29 | - name: Test with pytest 30 | env: 31 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 32 | run: | 33 | poetry run pytest -vvv 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Max Conradt 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /codebased/migrations/000_initial.sql: -------------------------------------------------------------------------------- 1 | create table file 2 | ( 3 | -- Note: this is relative to the repository root. 4 | path text primary key, 5 | size_bytes integer, 6 | -- st_mtime_ns 7 | last_modified_ns integer, 8 | -- sha256 digest bytes 9 | sha256_digest blob 10 | ); 11 | 12 | create table object 13 | ( 14 | id integer primary key, 15 | path text, 16 | name text, 17 | language text, 18 | context_before text, 19 | context_after text, 20 | kind text, 21 | byte_range text, 22 | coordinates text, 23 | foreign key (path) references file (path) 24 | ); 25 | 26 | create index object_path_index on object (path); 27 | 28 | create table embedding 29 | ( 30 | object_id integer primary key, 31 | data blob, 32 | content_sha256 blob, 33 | foreign key (object_id) references object (id) 34 | ); 35 | 36 | create index embedding_content_sha256_index on embedding (content_sha256); 37 | 38 | -- rowid is object id 39 | create virtual table fts using fts5(path, name, content, tokenize="trigram"); -------------------------------------------------------------------------------- /codebased/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | import typing as T 5 | from pathlib import Path 6 | 7 | 8 | @dataclasses.dataclass 9 | class Object: 10 | path: Path 11 | name: str 12 | language: str 13 | context_before: list[int] 14 | context_after: list[int] 15 | kind: str 16 | # text: bytes # This field is an excellent candidate for removal / using a memoryview. 
17 | byte_range: T.Tuple[int, int] # [start, end) 18 | coordinates: Coordinates 19 | id: T.Union[int, None] = None 20 | 21 | def __len__(self): 22 | start, end = self.byte_range 23 | return end - start 24 | 25 | @property 26 | def line_length(self) -> int: 27 | return self.coordinates[1][0] - self.coordinates[0][0] + 1 28 | 29 | 30 | @dataclasses.dataclass 31 | class EmbeddingRequest: 32 | object_id: int 33 | content: str 34 | content_hash: str 35 | 36 | 37 | @dataclasses.dataclass 38 | class Embedding: 39 | object_id: int 40 | data: list[float] 41 | content_hash: str 42 | 43 | 44 | Coordinates = T.Tuple[T.Tuple[int, int], T.Tuple[int, int]] 45 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "codebased" 3 | version = "0.6.2" 4 | description = "AI-powered code search for the terminal and beyond." 5 | authors = ["Max Conradt "] 6 | readme = "README.md" 7 | repository = "https://github.com/codebased-sh/codebased" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.9" 11 | tree-sitter = "^0.23" 12 | openai = "^1.42.0" 13 | tree-sitter-c = "^0.23" 14 | tree-sitter-c-sharp = "^0.23" 15 | tree-sitter-cpp = "^0.23" 16 | tree-sitter-go = "^0.23" 17 | tree-sitter-java = "^0.23" 18 | tree-sitter-javascript = "^0.23" 19 | tree-sitter-php = "^0.23" 20 | tree-sitter-python = "^0.23" 21 | tree-sitter-ruby = "^0.23" 22 | tree-sitter-rust = "^0.23" 23 | tree-sitter-typescript = "^0.23" 24 | faiss-cpu = "^1.8.0.post1" 25 | toml = "^0.10.2" 26 | tiktoken = "^0.7.0" 27 | colorama = "^0.4.6" 28 | watchdog = "^5.0.0" 29 | chardet = "^5.2.0" 30 | textual = "^0.79.1" 31 | typer = "^0.12.5" 32 | 33 | 34 | [tool.poetry.scripts] 35 | codebased = "codebased.main:cli" 36 | 37 | [tool.poetry.group.dev.dependencies] 38 | ipython = "^7.0" 39 | textual-dev = "^1.6.1" 40 | pytest = "^8.3.3" 41 | 42 | [build-system] 43 | requires = ["poetry-core"] 44 | build-backend = "poetry.core.masonry.api" 45 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [ published ] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | curl -sSL https://install.python-poetry.org | python - 32 | poetry install --with=dev 33 | - name: Build package 34 | run: poetry build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /codebased/filesystem.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | import queue 5 | from pathlib import Path 6 | 7 | import watchdog.events 8 | import watchdog.observers 9 | 10 | _OBSERVER = watchdog.observers.Observer() 11 | 12 | 13 | @dataclasses.dataclass 14 | class EventWrapper: 15 | event: watchdog.events.FileSystemEvent 16 | time: float 17 | 18 | 19 | class QueueEventHandler(watchdog.events.FileSystemEventHandler): 20 | def __init__(self, q: queue.Queue[Path]): 21 | self.q = q 22 | 23 | def on_any_event(self, event: watchdog.events.FileSystemEvent): 24 | if event.is_directory: 25 | return 26 | for path in get_paths(event): 27 | self.q.put(path) 28 | 29 | 30 | def get_filesystem_events_queue(root: Path) -> queue.Queue[Path]: 31 | observer = _OBSERVER 32 | q = queue.Queue() 33 | observer.schedule(QueueEventHandler(q), root, recursive=True) 34 | observer.start() 35 | return q 36 | 37 | 38 | def get_paths( 39 | event: watchdog.events.FileSystemEvent 40 | ) -> list[Path]: 41 | abs_paths: list[str] = [] 42 | if event.event_type == 'moved': 43 | abs_paths = [event.src_path, event.dest_path] 44 | elif event.event_type == 'created': 45 | abs_paths = [event.src_path] 46 | elif event.event_type == 'deleted': 47 | abs_paths = [event.src_path] 48 | elif event.event_type == 'modified': 49 | abs_paths = [event.src_path] 50 | return [Path(path) for path in abs_paths] 51 | -------------------------------------------------------------------------------- /codebased/embeddings.py: -------------------------------------------------------------------------------- 1 | import typing as T 2 | 3 | if T.TYPE_CHECKING: 4 | from openai import OpenAI 5 | 6 | from codebased.settings import EmbeddingsConfig 7 | from codebased.models import Embedding, EmbeddingRequest 8 | from codebased.stats import STATS 9 | 10 | 11 | def get_embedding_kwargs(config: EmbeddingsConfig) -> dict: 12 | kwargs = {"model": config.model} 13 | if config.model in {'text-embedding-3-large', 'text-embedding-3-small'}: 14 | kwargs["dimensions"] = config.dimensions 15 | return kwargs 16 | 17 | 18 | def create_openai_embeddings_sync_batched( 19 | client: "OpenAI", 20 | requests: T.List[EmbeddingRequest], 21 | config: EmbeddingsConfig 22 | ) -> T.Iterable[Embedding]: 23 | with STATS.timer("codebased.embeddings.batch.duration"): 24 | text = [o.content for o in requests] 25 | response = client.embeddings.create(input=text, **get_embedding_kwargs(config)) 26 | STATS.increment("codebased.embeddings.usage.total_tokens", response.usage.total_tokens) 27 | return [ 28 | Embedding( 29 | object_id=o.object_id, 30 | data=e.embedding, 31 | content_hash=o.content_hash 
32 | ) 33 | for o, e in zip(requests, response.data) 34 | ] 35 | 36 | 37 | def create_ephemeral_embedding(client: "OpenAI", text: str, config: EmbeddingsConfig) -> list[float]: 38 | with STATS.timer("codebased.embeddings.ephemeral.duration"): 39 | response = client.embeddings.create(input=text, **get_embedding_kwargs(config)) 40 | return response.data[0].embedding 41 | -------------------------------------------------------------------------------- /codebased/background_worker.py: -------------------------------------------------------------------------------- 1 | import queue 2 | import threading 3 | import time 4 | from pathlib import Path 5 | 6 | from codebased.index import Dependencies, Config, index_paths 7 | from codebased.stats import STATS 8 | 9 | 10 | def background_worker( 11 | dependencies: Dependencies, 12 | config: Config, 13 | shutdown_event: threading.Event, 14 | event_queue: queue.Queue[Path] 15 | ): 16 | def pre_filter(event: Path) -> bool: 17 | if event.is_relative_to(config.codebased_directory): 18 | return False 19 | if event.is_relative_to(config.git_directory): 20 | return False 21 | if dependencies.ignore_checker(event): 22 | return False 23 | return True 24 | 25 | while not shutdown_event.is_set(): 26 | # Wait indefinitely for an event. 27 | try: 28 | events: list[Path] = [event_queue.get(timeout=1.0)] 29 | except (queue.Empty, TimeoutError): 30 | continue 31 | start = time.monotonic() 32 | loop_timeout = .1 33 | while time.monotonic() - start < loop_timeout: 34 | try: 35 | events.append(event_queue.get(timeout=loop_timeout)) 36 | except queue.Empty: 37 | break 38 | try: 39 | while not event_queue.empty(): 40 | events.append(event_queue.get(block=False)) 41 | except queue.Empty: 42 | pass 43 | # Don't create events when we write to the index, especially from this thread. 
44 | events = [event for event in events if pre_filter(event)] 45 | if not events: 46 | continue 47 | if shutdown_event.is_set(): 48 | break 49 | STATS.increment("codebased.background_worker.updates.total") 50 | STATS.increment("codebased.background_worker.updates.events", len(events)) 51 | index_paths(dependencies, config, events, total=False) 52 | STATS.increment("codebased.background_worker.updates.index") 53 | -------------------------------------------------------------------------------- /codebased/storage.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import os 5 | import re 6 | import sqlite3 7 | import struct 8 | from pathlib import Path 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def serialize_embedding_data(vector: list[float]) -> bytes: 14 | dimension = len(vector) 15 | return struct.pack(f'{dimension}f', *vector) 16 | 17 | 18 | def deserialize_embedding_data(data: bytes) -> list[float]: 19 | dimension = len(data) // struct.calcsize('f') 20 | return list(struct.unpack(f'{dimension}f', data)) 21 | 22 | 23 | class DatabaseMigrations: 24 | def __init__(self, db: sqlite3.Connection, migrations_directory: Path): 25 | self.db = db 26 | self.migrations_directory = migrations_directory 27 | 28 | def initialize(self): 29 | self.db.execute("CREATE TABLE IF NOT EXISTS migrations (version INTEGER PRIMARY KEY)") 30 | 31 | def get_current_version(self) -> int | None: 32 | cursor = self.db.execute("SELECT version FROM migrations ORDER BY version DESC LIMIT 1") 33 | version_row = cursor.fetchone() 34 | return version_row[0] if version_row else None 35 | 36 | def add_version(self, version: int): 37 | self.db.execute("INSERT INTO migrations (version) VALUES (?)", (version,)) 38 | 39 | def migrate(self): 40 | migration_file_names = [f for f in os.listdir(self.migrations_directory) if f.endswith(".sql")] 41 | migration_paths = [self.migrations_directory / file for file in migration_file_names] 42 | 43 | def get_migration_version(path: Path) -> int: 44 | return int(re.match(r'(\d+)', path.name).group(1)) 45 | 46 | migration_paths.sort(key=get_migration_version) 47 | current_version = self.get_current_version() 48 | for migration_path in migration_paths: 49 | version = get_migration_version(migration_path) 50 | if current_version is not None and current_version >= version: 51 | logger.debug(f"Skipping migration {migration_path}") 52 | continue 53 | logger.debug(f"Running migration {migration_path}") 54 | with open(migration_path) as f: 55 | migration_text = f.read() 56 | self.db.executescript(migration_text) 57 | self.add_version(version) 58 | self.db.commit() 59 | -------------------------------------------------------------------------------- /codebased/stats.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import dataclasses 4 | import threading 5 | import time 6 | import typing as T 7 | from collections import defaultdict 8 | from contextlib import contextmanager 9 | 10 | 11 | @dataclasses.dataclass 12 | class Stats: 13 | class Key(str, Enum): 14 | index_creation_embedding_tokens_consumed = "index.creation.embedding.tokens_consumed" 15 | 16 | _lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) 17 | counters: dict[str, float] = dataclasses.field(default_factory=lambda: defaultdict(float)) 18 | ratios: dict[str, tuple[int, int]] = dataclasses.field(default_factory=lambda: 
defaultdict(lambda: (0, 0)))
19 |
20 |     def import_cache_info(self, key: str, cache_info):
21 |         self.import_ratio(
22 |             key,
23 |             cache_info.hits,
24 |             cache_info.hits + cache_info.misses
25 |         )
26 |
27 |     def import_ratio(self, key: str, num: int, denom: int):
28 |         with self._lock:
29 |             self.ratios[key] = (num, denom)
30 |
31 |     def increment(self, key: str, by: T.Union[int, float] = 1):
32 |         with self._lock:
33 |             self.counters[key] += by
34 |
35 |     @contextmanager
36 |     def timer(self, key: str):
37 |         start = time.perf_counter()
38 |         try:
39 |             yield
40 |         finally:
41 |             self.increment(key, time.perf_counter() - start)
42 |
43 |     def hit(self, key: str, yes: bool = True):
44 |         with self._lock:
45 |             if yes:
46 |                 self.ratios[key] = (self.ratios[key][0] + 1, self.ratios[key][1] + 1)
47 |             else:
48 |                 self.ratios[key] = (self.ratios[key][0], self.ratios[key][1] + 1)
49 |
50 |     @contextmanager
51 |     def except_rate(self, key: str):
52 |         try:
53 |             yield
54 |         except (Exception,):
55 |             self.hit(key, yes=True)
56 |         else:
57 |             self.hit(key, yes=False)
58 |
59 |     def dumps(self) -> str:
60 |         lines = ["Counters:"]
61 |         for key, value in self.counters.items():
62 |             lines.append(f" {key}: {value}")
63 |         ratio_lines = ["Ratios:"]
64 |         for key, (num, denom) in self.ratios.items():
65 |             if denom > 0:
66 |                 ratio_lines.append(f" {key}: {num / denom:.3f}")
67 |         if len(ratio_lines) > 1:
68 |             lines.extend(ratio_lines)
69 |         return '\n'.join(lines)
70 |
71 |
72 | STATS = Stats()
73 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Codebased
2 |
3 | Codebased is the most powerful code search tool that runs on your computer.
4 |
5 | Here's why it's great:
6 |
7 | - Search, simplified: Combines semantic and full-text search to find what you're looking for, not just what you typed.
8 | - Search code, not just text: Searches for complete code structures, not just lines of text, in 11 languages.
9 | - Ignores uninteresting files: Respects your `.gitignore` file(s) and ignores hidden directories.
10 | - Instant editing: Selecting a search result opens your favorite editor at the correct file and line number.
11 | - Fast: Indexes your code in seconds, searches in milliseconds, and updates in real-time as you edit your code.
12 | - Open-source, runs locally: No code is stored on remote servers.
13 |
14 | ## Getting Started
15 |
16 | The fastest way to install Codebased is with [pipx](https://github.com/pypa/pipx?tab=readme-ov-file#install-pipx):
17 |
18 | ```shell
19 | pipx install codebased
20 | ```
21 |
22 | Verify the installation by running:
23 |
24 | ```shell
25 | codebased --version
26 | ```
27 |
28 | If this fails, please double-check your pipx installation.
29 |
30 | Next, run the following command in a Git repository to start searching:
31 |
32 | ```shell
33 | codebased search
34 | ```
35 |
36 | The first time you run Codebased, it will create a configuration file at `~/.codebased/config.toml`.
37 | It will prompt you for an OpenAI API key; you can get a testing key on the [Discord](https://discord.gg/cQrQCAKZ).
38 |
39 | Once this is finished, `codebased` will create an index of your codebase, stored in `.codebased` at the root of your
40 | repository.
41 | This takes seconds for most projects, but can take a few minutes for large projects.
42 | Most of the time is spent creating embeddings using the OpenAI API.
43 |
44 | Once the index is created, a terminal UI will open with a search bar.
45 | At this point, you can start typing a search query and results will appear as you type.
46 |
47 | - You can use the arrow keys and the mouse to navigate the results.
48 | - A preview of the selected result is displayed.
49 | - Pressing enter on the highlighted result opens the file in your editor at the correct line number.
50 | - Pressing escape returns to the search bar.
51 | - As you edit your code, the index will be updated in real-time, and future searches will reflect your changes.
52 |
53 | On startup, Codebased runs `stat` on every non-ignored file in your repository, which can take a few seconds; after
54 | that it listens for filesystem events, so it's recommended to keep the TUI running rather than re-running searches.
55 |
56 | # Development
57 |
58 | If you'd like to contribute, bug fixes are welcome, as well as anything in the list
59 | of [issues](https://github.com/codebased-sh/codebased/issues).
60 |
61 | Especially welcome is support for your favorite language, as long as:
62 |
63 | 1. There's a tree-sitter grammar for it.
64 | 2. There are Python bindings for it maintained by the excellent [amaanq](https://pypi.org/user/amaanq/).
65 |
66 | Also, if there's anything ripgrep does that Codebased doesn't, feel free to file an issue / PR.
67 |
68 | Clone the repository:
69 |
70 | ```shell
71 | git clone https://github.com/codebased-sh/codebased.git
72 | ```
73 |
74 | Install the project's dependencies (requires [poetry](https://python-poetry.org); using a virtual environment is
75 | recommended):
76 |
77 | ```shell
78 | poetry install
79 | ```
80 |
81 | Run the tests (some tests require an `OPENAI_API_KEY` environment variable; usage is de minimis):
82 |
83 | ```shell
84 | poetry run pytest
85 | ```
86 |
87 | # Appendix
88 |
89 | ## Languages
90 |
91 | - [X] C
92 | - [X] C#
93 | - [X] C++
94 | - [X] Go
95 | - [X] Java
96 | - [X] JavaScript
97 | - [X] PHP
98 | - [X] Python
99 | - [X] Ruby
100 | - [X] Rust
101 | - [X] TypeScript
102 | - [ ] HTML / CSS
103 | - [ ] SQL
104 | - [ ] Shell
105 | - [ ] Swift
106 | - [ ] Lua
107 | - [ ] Kotlin
108 | - [ ] Dart
109 | - [ ] R
110 | - [ ] Assembly language
111 | - [ ] OCaml
112 | - [ ] Zig
113 | - [ ] Haskell
114 | - [ ] Elixir
115 | - [ ] Erlang
116 | - [ ] TOML?
117 | - [ ] YAML?
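118 |
119 | ## Example configuration
120 |
121 | For reference, here is a sketch of what the generated `~/.codebased/config.toml` might look like. The keys mirror the
122 | `Settings` and `EmbeddingsConfig` dataclasses in `codebased/settings.py`; the API key shown is a placeholder:
123 |
124 | ```toml
125 | editor = "vi"
126 | OPENAI_API_KEY = "sk-..."
127 |
128 | [embeddings]
129 | model = "text-embedding-3-large"
130 | dimensions = 256
131 | ```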
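132 |
133 | ## Programmatic usage
134 |
135 | The wiring in `codebased/main.py`'s `search` command can also be driven from Python. The sketch below is unverified
136 | and simply mirrors `main.py`; the exact `Flags`, `Config`, and `Dependencies` constructors live in
137 | `codebased/index.py` and may differ:
138 |
139 | ```python
140 | import math
141 | from pathlib import Path
142 |
143 | from codebased.index import Config, Dependencies, Flags, index_paths
144 | from codebased.search import search_once, print_results
145 | from codebased.settings import Settings
146 |
147 | # Mirror the defaults of the `codebased search` CLI options.
148 | flags = Flags(
149 |     directory=Path.cwd(),
150 |     rebuild_faiss_index=False,
151 |     cached_only=False,
152 |     query="database migrations",
153 |     background=False,
154 |     stats=False,
155 |     semantic=True,
156 |     full_text_search=True,
157 |     top_k=32,
158 |     rerank=True,
159 |     radius=math.sqrt(2),
160 | )
161 | config = Config(flags=flags)
162 | dependencies = Dependencies(config=config, settings=Settings.always())
163 | index_paths(dependencies, config, [config.root], total=True)  # build or refresh the index
164 | results, times = search_once(dependencies, flags)
165 | print_results(config, flags, results)
166 | dependencies.db.close()
167 | ```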
--------------------------------------------------------------------------------
/codebased/settings.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import dataclasses
4 | import getpass
5 | import logging
6 | import os
7 | import sqlite3
8 | import sys
9 | import typing as T
10 | from pathlib import Path
11 |
12 | import toml
13 |
14 | from codebased.constants import DEFAULT_MODEL, DEFAULT_MODEL_DIMENSIONS, DEFAULT_EDITOR
15 | from codebased.exceptions import MissingConfigFileException
16 | from codebased.editor import Editor, ALLOWED_EDITORS
17 |
18 | logger = logging.getLogger(__name__)
19 | PACKAGE_DIR: Path = Path(__file__).parent
20 | CONFIG_DIRECTORY = Path.home() / ".codebased"
21 | CONFIG_FILE = CONFIG_DIRECTORY / "config.toml"
22 |
23 |
24 | @dataclasses.dataclass
25 | class EmbeddingsConfig:
26 |     model: str = DEFAULT_MODEL
27 |     dimensions: int = DEFAULT_MODEL_DIMENSIONS
28 |
29 |
30 | @dataclasses.dataclass
31 | class Settings:
32 |     """
33 |     Combined class for Settings, Config, and Secrets
34 |     """
35 |     embeddings: EmbeddingsConfig = dataclasses.field(default_factory=EmbeddingsConfig)
36 |     editor: Editor = DEFAULT_EDITOR
37 |     OPENAI_API_KEY: str = dataclasses.field(default_factory=lambda: os.environ.get("OPENAI_API_KEY"))
38 |
39 |     @classmethod
40 |     def always(cls) -> "Settings":
41 |         try:
42 |             cls.verify()
43 |         except MissingConfigFileException:
44 |             cls.create()
45 |         return cls.load_file(CONFIG_FILE)
46 |
47 |     def __post_init__(self):
48 |         if not self.OPENAI_API_KEY:
49 |             raise ValueError(
50 |                 "Codebased requires an OpenAI API key for now. "
51 |                 "Join the Discord to get access to a key for testing."
52 |             )
53 |
54 |     @staticmethod
55 |     def verify():
56 |         if not CONFIG_FILE.exists():
57 |             raise MissingConfigFileException()
58 |
59 |     @classmethod
60 |     def create(cls):
61 |         if sys.stdin.isatty():
62 |             greet()
63 |         CONFIG_DIRECTORY.mkdir(parents=True, exist_ok=True)
64 |         CONFIG_FILE.touch()
65 |         # TODO: Windows?
66 |         if sys.stdin.isatty():
67 |             effective_defaults = cls.from_prompt()
68 |         else:
69 |             effective_defaults = Settings()
70 |         effective_defaults.save(CONFIG_FILE)
71 |
72 |     def ensure_ok(self):
73 |         try:
74 |             self.verify()
75 |         except MissingConfigFileException:
76 |             if sys.stdin.isatty():
77 |                 print(f"Looks like you're new here, setting up {str(CONFIG_FILE)}.")
78 |             self.create()
79 |
80 |     @classmethod
81 |     def load_file(cls, path: Path):
82 |         with open(path) as f:
83 |             data = toml.load(f)
84 |         try:
85 |             embeddings_config = EmbeddingsConfig(**data.pop('embeddings'))
86 |         except KeyError:
87 |             embeddings_config = EmbeddingsConfig()
88 |         return cls(**data, embeddings=embeddings_config)
89 |
90 |     @classmethod
91 |     def from_prompt(cls):
92 |         embedding_model = cls.prompt_default_model()
93 |         dimensions = cls.prompt_default_dimensions()
94 |         editor = cls.prompt_default_editor()
95 |         env = os.getenv("OPENAI_API_KEY")
96 |         if env:
97 |             openai_api_key = getpass.getpass(f"What is your OpenAI API key? [OPENAI_API_KEY={env[:7]}...]: ")
98 |             if not openai_api_key:
99 |                 openai_api_key = env
100 |         else:
101 |             openai_api_key = getpass.getpass("What is your OpenAI API key? ")
102 |         return cls(
103 |             embeddings=EmbeddingsConfig(
104 |                 model=embedding_model,
105 |                 dimensions=dimensions
106 |             ),
107 |             editor=editor,
108 |             OPENAI_API_KEY=openai_api_key
109 |         )
110 |
111 |     @classmethod
112 |     def prompt_default_model(cls) -> str:
113 |         embedding_model = input(f"What model do you want to use for embeddings? 
[{DEFAULT_MODEL}]: ") 114 | return embedding_model if embedding_model else DEFAULT_MODEL 115 | 116 | @classmethod 117 | def prompt_default_dimensions(cls) -> int: 118 | text = input(f"What dimensions do you want to use for embeddings? [{DEFAULT_MODEL_DIMENSIONS}]: ") 119 | dimensions = int(text) if text else DEFAULT_MODEL_DIMENSIONS 120 | return dimensions 121 | 122 | @classmethod 123 | def prompt_default_editor(cls) -> Editor: 124 | prompt = f"What editor do you want to use? ({'|'.join(sorted(ALLOWED_EDITORS))}) [{DEFAULT_EDITOR}]: " 125 | while True: 126 | editor = input(prompt) 127 | if not editor: 128 | return DEFAULT_EDITOR 129 | if editor in ALLOWED_EDITORS: 130 | return T.cast(Editor, editor) 131 | print("Invalid editor. Try again.") 132 | 133 | def save(self, path: Path): 134 | with open(path, 'w') as f: 135 | toml.dump(dataclasses.asdict(self), f) 136 | 137 | 138 | def get_db(database_file: Path) -> sqlite3.Connection: 139 | db = sqlite3.connect(database_file, check_same_thread=False) 140 | db.row_factory = sqlite3.Row 141 | return db 142 | 143 | 144 | def greet(): 145 | with open(PACKAGE_DIR / "GREETING.txt") as f: 146 | print(f.read()) 147 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 164 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 165 | 166 | # User-specific stuff 167 | .idea/**/workspace.xml 168 | .idea/**/tasks.xml 169 | .idea/**/usage.statistics.xml 170 | .idea/**/dictionaries 171 | .idea/**/shelf 172 | 173 | # AWS User-specific 174 | .idea/**/aws.xml 175 | 176 | # Generated files 177 | .idea/**/contentModel.xml 178 | 179 | # Sensitive or high-churn files 180 | .idea/**/dataSources/ 181 | .idea/**/dataSources.ids 182 | .idea/**/dataSources.local.xml 183 | .idea/**/sqlDataSources.xml 184 | .idea/**/dynamic.xml 185 | .idea/**/uiDesigner.xml 186 | .idea/**/dbnavigator.xml 187 | 188 | # Gradle 189 | .idea/**/gradle.xml 190 | .idea/**/libraries 191 | 192 | # Gradle and Maven with auto-import 193 | # When using Gradle or Maven with auto-import, you should exclude module files, 194 | # since they will be recreated, and may cause churn. Uncomment if using 195 | # auto-import. 
196 | # .idea/artifacts 197 | # .idea/compiler.xml 198 | # .idea/jarRepositories.xml 199 | # .idea/modules.xml 200 | # .idea/*.iml 201 | # .idea/modules 202 | # *.iml 203 | # *.ipr 204 | 205 | # CMake 206 | cmake-build-*/ 207 | 208 | # Mongo Explorer plugin 209 | .idea/**/mongoSettings.xml 210 | 211 | # File-based project format 212 | *.iws 213 | 214 | # IntelliJ 215 | out/ 216 | 217 | # mpeltonen/sbt-idea plugin 218 | .idea_modules/ 219 | 220 | # JIRA plugin 221 | atlassian-ide-plugin.xml 222 | 223 | # Cursive Clojure plugin 224 | .idea/replstate.xml 225 | 226 | # SonarLint plugin 227 | .idea/sonarlint/ 228 | 229 | # Crashlytics plugin (for Android Studio and IntelliJ) 230 | com_crashlytics_export_strings.xml 231 | crashlytics.properties 232 | crashlytics-build.properties 233 | fabric.properties 234 | 235 | # Editor-based Rest Client 236 | .idea/httpRequests 237 | 238 | # Android studio 3.1+ serialized cache file 239 | .idea/caches/build_file_checksums.ser 240 | GBPs.txt 241 | .idea/ 242 | codebased.iml 243 | .codebased/ -------------------------------------------------------------------------------- /codebased/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | import sqlite3 5 | import sys 6 | import threading 7 | from pathlib import Path 8 | 9 | import typer 10 | 11 | from codebased.background_worker import background_worker 12 | from codebased.filesystem import get_filesystem_events_queue 13 | from codebased.index import Config, Dependencies, index_paths, Flags 14 | from codebased.search import search_once, print_results 15 | from codebased.settings import Settings 16 | from codebased.stats import STATS 17 | from codebased.tui import Codebased 18 | 19 | VERSION = "0.6.2" 20 | 21 | cli = typer.Typer( 22 | name="Codebased CLI", 23 | ) 24 | 25 | 26 | def version_callback(value: bool): 27 | if value: 28 | print(f"Codebased {VERSION}") # Replace with actual version 29 | raise typer.Exit() 30 | 31 | 32 | @cli.callback() 33 | def main( 34 | version: bool = typer.Option( 35 | None, 36 | "-v", "-V", "--version", 37 | callback=version_callback, 38 | is_eager=True, 39 | help="Show the version and exit.", 40 | ), 41 | ): 42 | pass 43 | 44 | 45 | @cli.command("debug") 46 | def debug(): 47 | import faiss 48 | import openai 49 | import sqlite3 50 | components: dict[str, str] = { 51 | "Codebased": VERSION, 52 | "Python": sys.version, 53 | "SQLite": sqlite3.sqlite_version, 54 | "FAISS": faiss.__version__, 55 | "OpenAI": openai.__version__, 56 | } 57 | lines = [f"{key}: {value}" for key, value in components.items()] 58 | typer.echo("\n".join(lines)) 59 | 60 | 61 | @cli.command("search") 62 | def search( 63 | query: str = typer.Argument("", help="The search query"), 64 | directory: Path = typer.Option( 65 | Path.cwd(), 66 | "-d", 67 | "--directory", 68 | help="Directory to search in.", 69 | exists=True, 70 | file_okay=False, 71 | dir_okay=True, 72 | readable=True, 73 | resolve_path=True, 74 | allow_dash=True, 75 | ), 76 | rebuild_faiss_index: bool = typer.Option( 77 | False, 78 | help="Rebuild the FAISS index.", 79 | is_flag=True, 80 | ), 81 | cached_only: bool = typer.Option( 82 | False, 83 | help="Only read from cache. 
Avoids running stat on every file / reading files.",
84 |         is_flag=True,
85 |     ),
86 |     stats: bool = typer.Option(
87 |         False,
88 |         help="Print stats.",
89 |         is_flag=True,
90 |     ),
91 |     semantic: bool = typer.Option(
92 |         True,
93 |         "--semantic-search/--no-semantic-search",
94 |         help="Use semantic search.",
95 |     ),
96 |     full_text: bool = typer.Option(
97 |         True,
98 |         "--full-text-search/--no-full-text-search",
99 |         help="Use full-text search.",
100 |     ),
101 |     top_k: int = typer.Option(
102 |         32,
103 |         "-k",
104 |         "--top-k",
105 |         help="Number of results to return.",
106 |         min=1,
107 |     ),
108 |     background: bool = typer.Option(
109 |         True,
110 |         "--background/--no-background",
111 |         help="Run in the background.",
112 |     ),
113 |     rerank: bool = typer.Option(
114 |         True,
115 |         "--rerank/--no-rerank",
116 |         help="Rerank results.",
117 |     ),
118 |     radius: float = typer.Option(
119 |         # The pairwise L2 distance between two code embedding vectors from the same codebase often
120 |         # clusters around 1.2 and is Chi-distributed, which means we can use the CDF to determine
121 |         # a distance cutoff.
122 |         # This can be higher or lower depending on whether the codebase is written in multiple languages.
123 |         # There's also a systematic difference between the "search query" vector space and the "code"
124 |         # vector space, so we need to increase the radius to account for this discrepancy.
125 |         # In the future, we could (deep) learn a projection from "query" vector space to "code" vector space,
126 |         # or even project each space to a unified space.
127 |         math.sqrt(2),
128 |         "--radius",
129 |         help="Maximum L2 distance for semantic search. The higher this is, the more results there are.",
130 |         min=0.0,
131 |         max=2.0,
132 |     ),
133 | ):
134 |     sqlite_version = tuple(map(int, sqlite3.sqlite_version.split('.')))
135 |     if sqlite_version < (3, 34, 0):
136 |         typer.echo(f"Codebased requires SQLite 3.34.0 or higher, found {sqlite3.sqlite_version}.", err=True)
137 |         raise typer.Exit(1)
138 |     flags = Flags(
139 |         directory=directory,
140 |         rebuild_faiss_index=rebuild_faiss_index,
141 |         cached_only=cached_only,
142 |         query=query,
143 |         # We gather the set of gitignore files during startup and don't specially cache these.
144 |         # So if we ran the background worker without gathering the .gitignore files, we would not properly ignore
145 |         # changed files.
146 | background=background and not cached_only, 147 | stats=stats, 148 | semantic=semantic, 149 | full_text_search=full_text, 150 | top_k=top_k, 151 | rerank=rerank, 152 | radius=radius 153 | ) 154 | config = Config(flags=flags) 155 | settings = Settings.always() 156 | dependencies = Dependencies(config=config, settings=settings) 157 | __ = config.root, dependencies.db, dependencies.index 158 | fs_events = get_filesystem_events_queue(config.root) 159 | shutdown_event = threading.Event() 160 | if flags.background: 161 | thread = threading.Thread( 162 | target=background_worker, 163 | args=(dependencies, config, shutdown_event, fs_events), 164 | daemon=True 165 | ) 166 | else: 167 | thread = None 168 | 169 | try: 170 | if not flags.cached_only: 171 | index_paths(dependencies, config, [config.root], total=True) 172 | if thread is not None: 173 | thread.start() 174 | if flags.query: 175 | results, times = search_once(dependencies, flags) 176 | print_results(config, flags, results) 177 | else: 178 | Codebased(flags=flags, config=config, dependencies=dependencies).run() 179 | finally: 180 | dependencies.db.close() 181 | shutdown_event.set() 182 | if thread is not None and thread.is_alive(): 183 | thread.join() 184 | if flags.stats: 185 | print(STATS.dumps()) 186 | 187 | 188 | if __name__ == '__main__': 189 | cli() 190 | -------------------------------------------------------------------------------- /codebased/gitignore.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/mherrmann/gitignore_parser 2 | 3 | import collections 4 | import os 5 | import re 6 | 7 | from os.path import abspath, dirname 8 | from pathlib import Path 9 | from typing import Reversible, Union 10 | 11 | 12 | def handle_negation(file_path, rules: Reversible["IgnoreRule"]): 13 | for rule in reversed(rules): 14 | if rule.match(file_path): 15 | return not rule.negation 16 | return None 17 | 18 | 19 | def parse_gitignore(full_path, base_dir=None): 20 | if base_dir is None: 21 | base_dir = dirname(full_path) 22 | rules = [] 23 | with open(full_path) as ignore_file: 24 | counter = 0 25 | for line in ignore_file: 26 | counter += 1 27 | line = line.rstrip('\n') 28 | rule = rule_from_pattern( 29 | line, base_path=_normalize_path(base_dir), 30 | source=(full_path, counter) 31 | ) 32 | if rule: 33 | rules.append(rule) 34 | if not any(r.negation for r in rules): 35 | return lambda file_path: any(r.match(file_path) for r in rules) or None 36 | else: 37 | # We have negation rules. We can't use a simple "any" to evaluate them. 38 | # Later rules override earlier rules. 39 | return lambda file_path: handle_negation(file_path, rules) 40 | 41 | 42 | def rule_from_pattern(pattern, base_path=None, source=None): 43 | """ 44 | Take a .gitignore match pattern, such as "*.py[cod]" or "**/*.bak", 45 | and return an IgnoreRule suitable for matching against files and 46 | directories. Patterns which do not match files, such as comments 47 | and blank lines, will return None. 48 | Because git allows for nested .gitignore files, a base_path value 49 | is required for correct behavior. The base path should be absolute. 
50 | """ 51 | # Store the exact pattern for our repr and string functions 52 | orig_pattern = pattern 53 | # Early returns follow 54 | # Discard comments and separators 55 | if pattern.strip() == '' or pattern[0] == '#': 56 | return 57 | # Strip leading bang before examining double asterisks 58 | if pattern[0] == '!': 59 | negation = True 60 | pattern = pattern[1:] 61 | else: 62 | negation = False 63 | # Multi-asterisks not surrounded by slashes (or at the start/end) should 64 | # be treated like single-asterisks. 65 | pattern = re.sub(r'([^/])\*{2,}', r'\1*', pattern) 66 | pattern = re.sub(r'\*{2,}([^/])', r'*\1', pattern) 67 | 68 | # Special-casing '/', which doesn't match any files or directories 69 | if pattern.rstrip() == '/': 70 | return 71 | 72 | directory_only = pattern[-1] == '/' 73 | # A slash is a sign that we're tied to the base_path of our rule 74 | # set. 75 | anchored = '/' in pattern[:-1] 76 | if pattern[0] == '/': 77 | pattern = pattern[1:] 78 | if pattern[0] == '*' and len(pattern) >= 2 and pattern[1] == '*': 79 | pattern = pattern[2:] 80 | anchored = False 81 | if pattern[0] == '/': 82 | pattern = pattern[1:] 83 | if pattern[-1] == '/': 84 | pattern = pattern[:-1] 85 | # patterns with leading hashes or exclamation marks are escaped with a 86 | # backslash in front, unescape it 87 | if pattern[0] == '\\' and pattern[1] in ('#', '!'): 88 | pattern = pattern[1:] 89 | # trailing spaces are ignored unless they are escaped with a backslash 90 | i = len(pattern) - 1 91 | striptrailingspaces = True 92 | while i > 1 and pattern[i] == ' ': 93 | if pattern[i - 1] == '\\': 94 | pattern = pattern[:i - 1] + pattern[i:] 95 | i = i - 1 96 | striptrailingspaces = False 97 | else: 98 | if striptrailingspaces: 99 | pattern = pattern[:i] 100 | i = i - 1 101 | regex = fnmatch_pathname_to_regex( 102 | pattern, directory_only, negation, anchored=bool(anchored) 103 | ) 104 | return IgnoreRule( 105 | pattern=orig_pattern, 106 | regex=regex, 107 | negation=negation, 108 | directory_only=directory_only, 109 | anchored=anchored, 110 | base_path=_normalize_path(base_path) if base_path else None, 111 | source=source 112 | ) 113 | 114 | 115 | IGNORE_RULE_FIELDS = [ 116 | 'pattern', 'regex', # Basic values 117 | 'negation', 'directory_only', 'anchored', # Behavior flags 118 | 'base_path', # Meaningful for gitignore-style behavior 119 | 'source' # (file, line) tuple for reporting 120 | ] 121 | 122 | 123 | class IgnoreRule(collections.namedtuple('IgnoreRule_', IGNORE_RULE_FIELDS)): 124 | def __str__(self): 125 | return self.pattern 126 | 127 | def __repr__(self): 128 | return ''.join(['IgnoreRule(\'', self.pattern, '\')']) 129 | 130 | def match(self, abs_path: Union[str, Path]): 131 | matched = False 132 | if self.base_path: 133 | rel_path = str(_normalize_path(abs_path).relative_to(self.base_path)) 134 | else: 135 | rel_path = str(_normalize_path(abs_path)) 136 | # Path() strips the trailing slash, so we need to preserve it 137 | # in case of directory-only negation 138 | if self.negation and type(abs_path) == str and abs_path[-1] == '/': 139 | rel_path += '/' 140 | if rel_path.startswith('./'): 141 | rel_path = rel_path[2:] 142 | if re.search(self.regex, rel_path): 143 | matched = True 144 | return matched 145 | 146 | 147 | # Frustratingly, python's fnmatch doesn't provide the FNM_PATHNAME 148 | # option that .gitignore's behavior depends on. 
149 | def fnmatch_pathname_to_regex( 150 | pattern, directory_only: bool, negation: bool, anchored: bool = False 151 | ): 152 | """ 153 | Implements fnmatch style-behavior, as though with FNM_PATHNAME flagged; 154 | the path separator will not match shell-style '*' and '.' wildcards. 155 | """ 156 | i, n = 0, len(pattern) 157 | 158 | seps = [re.escape(os.sep)] 159 | if os.altsep is not None: 160 | seps.append(re.escape(os.altsep)) 161 | seps_group = '[' + '|'.join(seps) + ']' 162 | nonsep = r'[^{}]'.format('|'.join(seps)) 163 | 164 | res = [] 165 | while i < n: 166 | c = pattern[i] 167 | i += 1 168 | if c == '*': 169 | try: 170 | if pattern[i] == '*': 171 | i += 1 172 | if i < n and pattern[i] == '/': 173 | i += 1 174 | res.append(''.join(['(.*', seps_group, ')?'])) 175 | else: 176 | res.append('.*') 177 | else: 178 | res.append(''.join([nonsep, '*'])) 179 | except IndexError: 180 | res.append(''.join([nonsep, '*'])) 181 | elif c == '?': 182 | res.append(nonsep) 183 | elif c == '/': 184 | res.append(seps_group) 185 | elif c == '[': 186 | j = i 187 | if j < n and pattern[j] == '!': 188 | j += 1 189 | if j < n and pattern[j] == ']': 190 | j += 1 191 | while j < n and pattern[j] != ']': 192 | j += 1 193 | if j >= n: 194 | res.append('\\[') 195 | else: 196 | stuff = pattern[i:j].replace('\\', '\\\\').replace('/', '') 197 | i = j + 1 198 | if stuff[0] == '!': 199 | stuff = ''.join(['^', stuff[1:]]) 200 | elif stuff[0] == '^': 201 | stuff = ''.join('\\' + stuff) 202 | res.append('[{}]'.format(stuff)) 203 | else: 204 | res.append(re.escape(c)) 205 | if anchored: 206 | res.insert(0, '^') 207 | else: 208 | res.insert(0, f"(^|{seps_group})") 209 | if not directory_only: 210 | res.append('$') 211 | elif directory_only and negation: 212 | res.append('/$') 213 | else: 214 | res.append('($|\\/)') 215 | return ''.join(res) 216 | 217 | 218 | def _normalize_path(path: Union[str, Path]) -> Path: 219 | """Normalize a path without resolving symlinks. 220 | 221 | This is equivalent to `Path.resolve()` except that it does not resolve symlinks. 222 | Note that this simplifies paths by removing double slashes, `..`, `.` etc. like 223 | `Path.resolve()` does. 
224 | """ 225 | return Path(abspath(path)) 226 | -------------------------------------------------------------------------------- /codebased/tui.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import contextlib 4 | import dataclasses 5 | import threading 6 | import time 7 | from enum import Enum 8 | from rich.syntax import Syntax 9 | from textual import work, events 10 | from textual.app import App, ComposeResult 11 | from textual.containers import Horizontal, VerticalScroll 12 | from textual.message import Message 13 | from textual.reactive import var 14 | from textual.widgets import Input, Footer, Header, Static, ListView, ListItem 15 | from typing import TypeVar, Generic 16 | 17 | from codebased.editor import open_editor, suspends 18 | from codebased.index import Flags, Config, Dependencies 19 | from codebased.search import search_once, RenderedResult, CombinedSearchResult, render_result, find_highlights, Query 20 | 21 | 22 | class Id(str, Enum): 23 | LATENCY = "latency" 24 | PREVIEW_CONTAINER = "preview-container" 25 | SEARCH_INPUT = "search-input" 26 | RESULTS_LIST = "results-list" 27 | RESULTS_CONTAINER = "results-container" 28 | PREVIEW = "preview" 29 | 30 | @property 31 | def selector(self) -> str: 32 | return "#" + self.value 33 | 34 | 35 | V = TypeVar('V') 36 | 37 | 38 | @dataclasses.dataclass 39 | class HWM(Generic[V]): 40 | key: float = float('-inf') 41 | _value: V | None = None 42 | _lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) 43 | 44 | @property 45 | def value(self) -> V: 46 | with self._lock: 47 | return self._value 48 | 49 | def set(self, key: float, value: V) -> bool: 50 | with self._lock: 51 | if key > self.key: 52 | self.key = key 53 | self._value = value 54 | return True 55 | return False 56 | 57 | 58 | class Codebased(App): 59 | BINDINGS = [ 60 | ("q", "quit", "Quit"), 61 | ("escape", "focus_search", "Focus search"), 62 | ("tab", "focus_preview", "Focus preview"), 63 | ("down", "focus_results", "Focus results"), 64 | ("d", "debug_mode", "Toggle debug mode"), 65 | ("f", "full_text_search", "Toggle full text search"), 66 | ("s", "semantic_search", "Toggle semantic search"), 67 | ("r", "rerank", "Toggle reranking"), 68 | ] 69 | 70 | show_results = var(False) 71 | 72 | CSS = """ 73 | #results-container { 74 | width: 100%; 75 | height: 100%; 76 | } 77 | 78 | #results-list { 79 | width: 30%; 80 | border-right: solid green; 81 | } 82 | 83 | #preview-container { 84 | width: 70%; 85 | } 86 | """ 87 | 88 | def __init__( 89 | self, 90 | flags: Flags, 91 | config: Config, 92 | dependencies: Dependencies, 93 | ): 94 | super().__init__() 95 | self.debug_mode = False 96 | self.flags = flags 97 | self.config = config 98 | self.dependencies = dependencies 99 | self.results: HWM[list[CombinedSearchResult]] = HWM() 100 | self.results.set(0, []) 101 | 102 | def compose(self) -> ComposeResult: 103 | yield Header() 104 | yield Input(placeholder="Enter your search query", id=Id.SEARCH_INPUT.value) 105 | yield Static(id=Id.LATENCY.value, shrink=True) 106 | with Horizontal(id=Id.RESULTS_CONTAINER.value): 107 | yield ListView(id=Id.RESULTS_LIST.value, initial_index=0) 108 | with VerticalScroll(id=Id.PREVIEW_CONTAINER.value): 109 | yield Static(id=Id.PREVIEW.value, expand=True) 110 | yield Footer() 111 | 112 | def on_mount(self): 113 | self.query_one(Id.SEARCH_INPUT.selector).focus() 114 | 115 | async def on_input_changed(self, event: Input.Changed): 116 | query = event.value 117 | 
        self.flags = dataclasses.replace(self.flags, query=query)
118 |         if len(query) >= 3:
119 |             self.search_background(self.flags, time.monotonic())
120 |         else:
121 |             await self.clear_results()
122 |
123 |     def action_rerank(self):
124 |         self.flags = dataclasses.replace(self.flags, rerank=not self.flags.rerank)
125 |         self.search_background(self.flags, time.monotonic())
126 |
127 |     def action_full_text_search(self):
128 |         self.flags = dataclasses.replace(self.flags, full_text_search=not self.flags.full_text_search)
129 |         self.search_background(self.flags, time.monotonic())
130 |
131 |     def action_semantic_search(self):
132 |         self.flags = dataclasses.replace(self.flags, semantic=not self.flags.semantic)
133 |         self.search_background(self.flags, time.monotonic())
134 |
135 |     @work(thread=True)
136 |     def search_background(self, flags: Flags, start_time: float):
137 |         results, times = search_once(self.dependencies, flags)
138 |         self.post_message(self.SearchCompleted(results, start_time, time.monotonic(), times))
139 |
140 |     class SearchCompleted(Message):
141 |         def __init__(
142 |                 self,
143 |                 results: list[CombinedSearchResult],
144 |                 start: float,
145 |                 finish: float,
146 |                 times: dict[str, float]
147 |         ):
148 |             self.results = results
149 |             self.start = start
150 |             self.finish = finish
151 |             self.times = times
152 |             super().__init__()
153 |
154 |         @property
155 |         def latency(self) -> float:
156 |             return self.finish - self.start
157 |
158 |     class RenderResults(Message):
159 |         pass
160 |
161 |     def on_key(self, event: events.Key):
162 |         if event.key == "enter":
163 |             self.select_result()
164 |         elif event.key == "up":
165 |             focused = self.focused
166 |             if isinstance(focused, ListView) and focused.id == Id.RESULTS_LIST.value:
167 |                 if focused.index == 0:
168 |                     self.action_focus_search()
169 |
170 |     def select_result(self):
171 |         focused = self.focused
172 |         if isinstance(focused, ListView) and focused.id == Id.RESULTS_LIST.value:
173 |             try:
174 |                 result = self.results.value[focused.index]
175 |                 self.open_result_in_editor(result)
176 |             except IndexError:
177 |                 return
178 |         elif focused and focused.id == Id.SEARCH_INPUT.value:
179 |             self.query_one(Id.RESULTS_LIST.selector, ListView).focus()
180 |
181 |     def open_result_in_editor(self, result: CombinedSearchResult):
182 |         file_path = self.config.root / result.obj.path
183 |         row, col = result.obj.coordinates[0]
184 |
185 |         editor = self.dependencies.settings.editor
186 |         with self.suspend() if suspends(editor) else contextlib.nullcontext():
187 |             open_editor(editor, file=file_path, row=row, column=col)
188 |
189 |     def action_focus_search(self):
190 |         self.query_one(Id.SEARCH_INPUT.selector, Input).focus()
191 |
192 |     def action_focus_preview(self):
193 |         self.query_one(Id.PREVIEW.selector, Static).focus()
194 |
195 |     def action_focus_results(self):
196 |         if self.focused and self.focused.id == Id.SEARCH_INPUT.value:
197 |             self.query_one(Id.RESULTS_LIST.selector, ListView).focus()
198 |
199 |     def action_debug_mode(self):
200 |         self.debug_mode = not self.debug_mode
201 |         self.post_message(self.RenderResults())
202 |
203 |     async def on_codebased_search_completed(self, message: SearchCompleted):
204 |         def print_latency(total: float, times: dict[str, float]) -> str:
205 |             filtered = {k: v for k, v in times.items() if v >= 0.001}
206 |             breakdown = " + ".join(f"{k}: {v:.3f}s" for k, v in filtered.items())
207 |             return f"Completed in {total:.3f}s" + (f" ({breakdown})" if breakdown else "")
208 |
209 |         if not self.results.set(message.start, 
message.results): 210 | return 211 | self.query_one(Id.LATENCY.selector, Static).update(print_latency(message.latency, message.times)) 212 | self.post_message(self.RenderResults()) 213 | 214 | async def on_codebased_render_results(self, event: RenderResults): 215 | results_list = await self.clear_results() 216 | for result in self.results.value: 217 | obj = result.obj 218 | lines = [str(obj.path)] 219 | if obj.kind != 'file': 220 | lines.append(f"[{obj.kind}] {obj.name}") 221 | if self.debug_mode: 222 | hit_categories = [] 223 | if result.l2 is not None: 224 | hit_categories.append(f"Semantic ({result.l2:.2f})") 225 | if result.bm25 is not None: 226 | hit_categories.append(f"Full text ({result.bm25:.1f})") 227 | reasoning = ' + '.join(hit_categories) 228 | lines.append(reasoning) 229 | item_text = '\n'.join(lines) 230 | await results_list.append(ListItem(Static(item_text), id=f"result-{obj.id}")) 231 | 232 | self.show_results = True 233 | if self.results.value: 234 | try: 235 | self.update_preview(self.results.value[0]) 236 | return 237 | except IndexError: 238 | pass 239 | preview = self.query_one(Id.PREVIEW.selector, Static) 240 | preview.update("") 241 | 242 | async def clear_results(self): 243 | results_list = self.query_one(Id.RESULTS_LIST.selector, ListView) 244 | await results_list.clear() 245 | return results_list 246 | 247 | def on_list_view_highlighted(self, event: ListView.Highlighted): 248 | item = event.item 249 | self.update_item_preview(item) 250 | 251 | def update_item_preview(self, item: ListItem | None): 252 | if item is not None: 253 | result_id = int(item.id.split("-")[1]) 254 | try: 255 | result = next(r for r in self.results.value if r.obj.id == result_id) 256 | self.update_preview(result) 257 | except StopIteration: 258 | pass 259 | 260 | def on_list_view_selected(self, event: ListView.Selected): 261 | item = event.item 262 | self.update_item_preview(item) 263 | 264 | def update_preview(self, result: CombinedSearchResult): 265 | preview = self.query_one(Id.PREVIEW.selector, Static) 266 | start_line, end_line = result.obj.coordinates[0][0], result.obj.coordinates[1][0] 267 | rendered_result, _ = render_result(self.config, self.flags, result, file=False, context=False) 268 | if rendered_result is None: 269 | return 270 | file_bytes = rendered_result.file_bytes 271 | try: 272 | code = file_bytes.decode('utf-8') 273 | except UnicodeDecodeError: 274 | code = file_bytes.decode('utf-16') 275 | lexer = Syntax.guess_lexer(str(result.obj.path), code) 276 | highlight_lines = rendered_result.highlighted_lines 277 | syntax = Syntax( 278 | code, 279 | lexer, 280 | theme="dracula", 281 | line_numbers=True, 282 | line_range=(start_line + 1, end_line + 1), 283 | highlight_lines={start_line + x + 1 for x in highlight_lines}, 284 | word_wrap=True 285 | ) 286 | preview.update(syntax) 287 | 288 | def watch_show_results(self, show_results: bool): 289 | self.query_one(Id.RESULTS_CONTAINER.selector).display = show_results 290 | -------------------------------------------------------------------------------- /codebased/search.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import textwrap 4 | 5 | import dataclasses 6 | import hashlib 7 | import json 8 | import math 9 | import numpy as np 10 | import re 11 | import sqlite3 12 | import time 13 | import typing as T 14 | from pathlib import Path 15 | 16 | from codebased.utils import decode_text 17 | 18 | if T.TYPE_CHECKING: 19 | from openai import OpenAI 20 | 21 
| from codebased.embeddings import create_ephemeral_embedding 22 | from codebased.index import Dependencies, Flags, Config 23 | 24 | from codebased.models import Object 25 | from codebased.parser import render_object 26 | 27 | import bisect 28 | 29 | 30 | @dataclasses.dataclass(frozen=True) 31 | class Query: 32 | phrases: list[str] 33 | keywords: list[str] 34 | original: str 35 | 36 | @classmethod 37 | def parse(cls, query: str) -> Query: 38 | original = query 39 | phrases = [] 40 | keywords = [] 41 | 42 | pattern = r'(?:"((?:[^"\\]|\\.)*)"|\S+)' 43 | matches = re.finditer(pattern, query) 44 | 45 | for match in matches: 46 | if match.group(1) is not None: 47 | phrase = match.group(1).replace('\\"', '"') 48 | if phrase: 49 | phrases.append(phrase) 50 | else: 51 | keywords.append(match.group()) 52 | 53 | return cls(phrases=phrases, keywords=keywords, original=original) 54 | 55 | 56 | def get_offsets(text: str, byte_start: int) -> tuple[int, int]: 57 | return byte_start, ... 58 | 59 | 60 | def find_highlights(query: Query, text: str) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]: 61 | highlights = [] 62 | 63 | # Create a list of newline positions 64 | newline_positions = [m.start() for m in re.finditer('\n', text)] 65 | 66 | def get_line_number(char_index): 67 | return bisect.bisect(newline_positions, char_index) 68 | 69 | # Highlight keywords 70 | for keyword in query.keywords: 71 | for match in re.finditer(re.escape(keyword), text, re.IGNORECASE): 72 | highlights.append(match.span()) 73 | 74 | # Highlight phrases 75 | for phrase in query.phrases: 76 | for match in re.finditer(re.escape(phrase), text, re.IGNORECASE): 77 | highlights.append(match.span()) 78 | 79 | # Sort and merge overlapping highlights 80 | highlights.sort(key=lambda x: x[0]) 81 | merged = [] 82 | for start, end in highlights: 83 | if merged and merged[-1][1] >= start: 84 | merged[-1] = (merged[-1][0], max(merged[-1][1], end)) 85 | else: 86 | merged.append((start, end)) 87 | 88 | # Create parallel list of line numbers 89 | line_numbers = [(get_line_number(start), get_line_number(end - 1)) for start, end in merged] 90 | 91 | return merged, line_numbers 92 | 93 | 94 | @dataclasses.dataclass 95 | class SemanticSearchResult: 96 | obj: Object 97 | distance: float 98 | content_sha256: bytes 99 | 100 | 101 | @dataclasses.dataclass 102 | class FullTextSearchResult: 103 | obj: Object 104 | name_match: bool 105 | bm25: float 106 | content_sha256: bytes 107 | 108 | 109 | def l2_is_close(l2: float) -> bool: 110 | return l2 < math.sqrt(2) * .9 111 | 112 | 113 | @dataclasses.dataclass 114 | class CombinedSearchResult: 115 | obj: Object 116 | l2: T.Union[float, None] 117 | bm25: T.Union[float, None] 118 | content_sha256: bytes 119 | 120 | 121 | def semantic_search(dependencies: Dependencies, flags: Flags) -> tuple[list[SemanticSearchResult], dict[str, float]]: 122 | times = {} 123 | semantic_results: list[SemanticSearchResult] = [] 124 | start = time.perf_counter() 125 | emb = create_ephemeral_embedding( 126 | dependencies.openai_client, 127 | flags.query, 128 | dependencies.settings.embeddings 129 | ) 130 | end = time.perf_counter() 131 | times['embedding'] = end - start 132 | 133 | start = time.perf_counter() 134 | _, distances, object_ids = dependencies.index.range_search( 135 | np.array([emb]), 136 | flags.radius, 137 | ) 138 | distances, object_ids = distances[:flags.top_k], object_ids[:flags.top_k] 139 | end = time.perf_counter() 140 | times['vss'] = end - start 141 | 142 | start = time.perf_counter() 143 | placeholders = 
','.join(['?'] * len(object_ids)) 144 | query = f""" 145 | SELECT o.*, 146 | (SELECT sha256_digest FROM file WHERE path = o.path) AS file_sha256_digest 147 | FROM object o 148 | WHERE o.id IN ({placeholders}) 149 | """ 150 | object_rows = dependencies.db.execute(query, [int(object_id) for object_id in object_ids]).fetchall() 151 | end = time.perf_counter() 152 | times['sqlite'] = end - start 153 | 154 | # Create a mapping of object_id to its index in the search results 155 | id_to_index = {int(id): index for index, id in enumerate(object_ids)} 156 | 157 | # Sort the object_rows based on their order in the search results 158 | sorted_object_rows = sorted(object_rows, key=lambda row: id_to_index[row['id']]) 159 | 160 | for object_row, distance in zip(sorted_object_rows, distances): 161 | obj = deserialize_object_row(object_row) 162 | result = SemanticSearchResult(obj, float(distance), object_row['file_sha256_digest']) 163 | semantic_results.append(result) 164 | 165 | return semantic_results, times 166 | 167 | 168 | _quote_fts_re = re.compile(r'\s+|(".*?")') 169 | 170 | 171 | def quote_fts_query(query: str) -> str: 172 | if query.count('"') % 2: 173 | query += '"' 174 | bits = _quote_fts_re.split(query) 175 | bits = [b for b in bits if b and b != '""'] 176 | query = " ".join( 177 | '"{}"'.format(bit) if not bit.startswith('"') else bit for bit in bits 178 | ) 179 | return query 180 | 181 | 182 | def rerank_results( 183 | query: str, 184 | results: list[CombinedSearchResult], 185 | oai_client: "OpenAI" 186 | ) -> list[CombinedSearchResult]: 187 | if not results: 188 | return results 189 | json_results = [ 190 | { 191 | "id": r.obj.id, 192 | "path": str(r.obj.path), 193 | "name": r.obj.name, 194 | "kind": r.obj.kind, 195 | "line_length": r.obj.line_length, 196 | "byte_length": r.obj.byte_range[1] - r.obj.byte_range[0], 197 | } 198 | for r in results 199 | ] 200 | json_results = json.dumps(json_results) 201 | system_prompt = textwrap.dedent( 202 | """ 203 | You're acting as the reranking component in a search engine for code. 204 | The following is a list of results from a search query. 205 | Please respond with a JSON list of result IDs, in order of relevance, excluding irrelevant or low quality results. 206 | Implementations are typically better than tests/mocks/documentation, unless the query 207 | asked for these specifically. 208 | Prefer code elements like structs, classes, functions, etc. to entire files. 209 | Include ALL relevant results, there is no limit on the number of results you should return. 210 | Including any non-JSON content will cause the application to crash and / or increase latency, which is bad. 
211 | """ 212 | ) 213 | messages = [ 214 | {"role": "system", "content": system_prompt}, 215 | {"role": "system", "content": f"Query: {query}\nResults: {json_results}"} 216 | ] 217 | response = oai_client.chat.completions.create( 218 | model="gpt-4o-mini", 219 | messages=messages, 220 | # temperature=0.0, 221 | ) 222 | content = response.choices[0].message.content 223 | cleaned_content = content[content.find('['):content.rfind(']') + 1] 224 | parsed_reranking_results = json.loads(cleaned_content) 225 | results_by_id = {r.obj.id: r for r in results} 226 | out = [] 227 | for result_id in parsed_reranking_results: 228 | try: 229 | out.append(results_by_id.pop(result_id)) 230 | except KeyError: 231 | continue 232 | return out 233 | 234 | 235 | def full_text_search(dependencies: Dependencies, flags: Flags) -> tuple[list[FullTextSearchResult], dict[str, float]]: 236 | fts_results = [] 237 | query = quote_fts_query(flags.query) 238 | times = {} 239 | start = time.perf_counter() 240 | object_rows = dependencies.db.execute( 241 | """ 242 | with name_matches as ( 243 | select rowid, true as name_match, rank 244 | from fts 245 | where name match :query 246 | order by rank 247 | limit :top_k 248 | ), 249 | content_matches as ( 250 | select rowid, false as name_match, rank 251 | from fts(:query) 252 | order by rank 253 | limit :top_k 254 | ), 255 | all_matches as ( 256 | select * from name_matches 257 | union all 258 | select * from content_matches 259 | ), 260 | min_rank_by_rowid as ( 261 | select 262 | rowid, 263 | max(name_match) as name_match, 264 | min(rank) as rank 265 | from all_matches 266 | group by rowid 267 | order by name_match desc, rank 268 | ), 269 | sorted_limited_results as ( 270 | select 271 | rowid, 272 | name_match, 273 | rank 274 | from min_rank_by_rowid 275 | order by name_match desc, rank 276 | limit :top_k 277 | ), 278 | ranked_objects as ( 279 | select o.id, 280 | o.path, 281 | o.name, 282 | o.language, 283 | o.context_before, 284 | o.context_after, 285 | o.kind, 286 | o.byte_range, 287 | o.coordinates, 288 | s.name_match, 289 | s.rank 290 | from object o 291 | inner join sorted_limited_results s on o.id = s.rowid 292 | ) 293 | select *, 294 | (select sha256_digest from file where path = o.path) as file_sha256_digest 295 | from ranked_objects o 296 | order by o.name_match desc, o.rank; 297 | """, 298 | { 299 | 'query': query, 300 | 'top_k': flags.top_k 301 | } 302 | ).fetchall() 303 | times['fts'] = time.perf_counter() - start 304 | for object_row in object_rows: 305 | obj = deserialize_object_row(object_row) 306 | fts_results.append( 307 | FullTextSearchResult( 308 | obj, 309 | object_row['name_match'], 310 | object_row['rank'], 311 | object_row['file_sha256_digest'] 312 | ) 313 | ) 314 | return fts_results, times 315 | 316 | 317 | def merge_results( 318 | semantic_results: list[SemanticSearchResult], 319 | full_text_results: list[FullTextSearchResult] 320 | ) -> list[CombinedSearchResult]: 321 | results: list[CombinedSearchResult] = [] 322 | semantic_ids = {result.obj.id: i for i, result in enumerate(semantic_results)} 323 | full_text_ids = {result.obj.id: i for i, result in enumerate(full_text_results)} 324 | both = set(semantic_ids) & set(full_text_ids) 325 | name_matches = {x.obj.id for x in full_text_results if x.name_match} 326 | sort_key = {} 327 | for obj_id in both: 328 | semantic_index = semantic_ids.pop(obj_id) 329 | full_text_index = full_text_ids.pop(obj_id) 330 | semantic_result = semantic_results[semantic_index] 331 | full_text_result = 
full_text_results[full_text_index] 332 | assert semantic_result.content_sha256 == full_text_result.content_sha256 333 | result = CombinedSearchResult( 334 | semantic_result.obj, 335 | semantic_result.distance, 336 | full_text_result.bm25, 337 | semantic_result.content_sha256 338 | ) 339 | sort_key[obj_id] = ( 340 | 0, 341 | min( 342 | semantic_index, 343 | full_text_index 344 | ) 345 | ) 346 | results.append(result) 347 | for obj_id, full_text_index in full_text_ids.items(): 348 | full_text_result = full_text_results[full_text_index] 349 | results.append( 350 | CombinedSearchResult( 351 | full_text_result.obj, 352 | None, 353 | full_text_result.bm25, 354 | full_text_result.content_sha256 355 | ) 356 | ) 357 | sort_key[obj_id] = (1, full_text_index) 358 | for obj_id, semantic_index in semantic_ids.items(): 359 | semantic_result = semantic_results[semantic_index] 360 | results.append( 361 | CombinedSearchResult( 362 | semantic_result.obj, 363 | semantic_result.distance, 364 | None, 365 | semantic_result.content_sha256 366 | ) 367 | ) 368 | sort_key[obj_id] = (1, semantic_index) 369 | for i, result in enumerate(full_text_results): 370 | obj_id = result.obj.id 371 | if obj_id in name_matches: 372 | sort_key[obj_id] = (-1, i) 373 | else: 374 | break 375 | return sorted(results, key=lambda r: sort_key[r.obj.id]) 376 | 377 | 378 | @dataclasses.dataclass 379 | class SearchResults: 380 | results: list[CombinedSearchResult] 381 | times: dict[str, float] 382 | 383 | 384 | def search_once(dependencies: Dependencies, flags: Flags) -> tuple[list[CombinedSearchResult], dict[str, float]]: 385 | try: 386 | return dependencies.search_cache[flags], {} 387 | except KeyError: 388 | pass 389 | semantic_results, semantic_times = semantic_search(dependencies, flags) if flags.semantic else ([], {}) 390 | full_text_results, full_text_times = full_text_search(dependencies, flags) if flags.full_text_search else ([], {}) 391 | results = merge_results(semantic_results, full_text_results) 392 | total_times = semantic_times 393 | if flags.rerank: 394 | rerank_start = time.perf_counter() 395 | results = rerank_results(flags.query, results, dependencies.openai_client) 396 | total_times['reranking'] = time.perf_counter() - rerank_start 397 | # results = results[:flags.top_k] 398 | dependencies.search_cache[flags] = results 399 | for key, value in full_text_times.items(): 400 | total_times[key] = total_times.get(key, 0) + value 401 | return results, total_times 402 | 403 | 404 | def deserialize_object_row(object_row: sqlite3.Row) -> Object: 405 | return Object( 406 | id=object_row['id'], 407 | path=Path(object_row['path']), 408 | name=object_row['name'], 409 | language=object_row['language'], 410 | context_before=json.loads(object_row['context_before']), 411 | context_after=json.loads(object_row['context_after']), 412 | kind=object_row['kind'], 413 | byte_range=json.loads(object_row['byte_range']), 414 | coordinates=json.loads(object_row['coordinates']) 415 | ) 416 | 417 | 418 | @dataclasses.dataclass 419 | class RenderedResult(CombinedSearchResult): 420 | content: str 421 | file_bytes: bytes 422 | highlights: list[tuple[int, int]] 423 | highlighted_lines: list[int] 424 | 425 | 426 | def render_result( 427 | config: Config, 428 | flags: Flags, 429 | result: CombinedSearchResult, 430 | **kwargs 431 | ) -> tuple[RenderedResult | None, dict[str, float]]: 432 | abs_path = config.root / result.obj.path 433 | times = {'disk': 0, 'render': 0} 434 | parsed = Query.parse(flags.query) 435 | try: 436 | # TODO: Memoize, at least within 
a search result set. 437 | start = time.perf_counter() 438 | underlying_file_bytes = abs_path.read_bytes() 439 | times['disk'] += time.perf_counter() - start 440 | actual_sha256 = hashlib.sha256(underlying_file_bytes).digest() 441 | if result.content_sha256 != actual_sha256: 442 | return None, times 443 | start = time.perf_counter() 444 | decoded_text = decode_text(underlying_file_bytes) 445 | if decoded_text is None: 446 | return None, times 447 | lines = decoded_text.splitlines() 448 | rendered = render_object(result.obj, lines, **kwargs) 449 | times['render'] += time.perf_counter() - start 450 | highlights, highlighted_lines = find_highlights(parsed, rendered) 451 | rendered_result = RenderedResult( 452 | obj=result.obj, 453 | l2=result.l2, 454 | bm25=result.bm25, 455 | content_sha256=result.content_sha256, 456 | content=rendered, 457 | file_bytes=underlying_file_bytes, 458 | highlights=highlights, 459 | highlighted_lines=sorted({y for x in highlighted_lines for y in x}) 460 | ) 461 | return rendered_result, times 462 | except FileNotFoundError: 463 | return None, times 464 | 465 | 466 | def render_results( 467 | config: Config, 468 | flags: Flags, 469 | results: list[CombinedSearchResult], 470 | **kwargs 471 | ) -> tuple[list[RenderedResult], dict[str, float]]: 472 | rendered_results, times = [], {} 473 | for result in results: 474 | rendered_result, result_times = render_result(config, flags, result, **kwargs) 475 | rendered_results.append(rendered_result) 476 | for key, value in result_times.items(): 477 | times[key] = times.get(key, 0) + value 478 | return rendered_results, times 479 | 480 | 481 | def print_results( 482 | config: Config, 483 | flags: Flags, 484 | results: list[CombinedSearchResult] 485 | ): 486 | rendered_results, times = render_results(config, flags, results) 487 | for result in rendered_results: 488 | print(result.content) 489 | print() 490 | -------------------------------------------------------------------------------- /codebased/parser.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import lru_cache 4 | from pathlib import Path 5 | 6 | import tree_sitter 7 | import tree_sitter_c 8 | import tree_sitter_c_sharp 9 | import tree_sitter_cpp 10 | import tree_sitter_go 11 | import tree_sitter_java 12 | import tree_sitter_javascript 13 | import tree_sitter_php 14 | import tree_sitter_python 15 | import tree_sitter_ruby 16 | import tree_sitter_rust 17 | import tree_sitter_typescript 18 | 19 | from codebased.models import Object, Coordinates 20 | from codebased.utils import decode_text 21 | 22 | CPP_TAG_QUERY = """ 23 | (field_declaration (function_declarator declarator: (field_identifier) @name)) @definition.method 24 | ; removed the local scope from the following line after namespace_identifier 25 | (function_definition (function_declarator declarator: (qualified_identifier scope: (namespace_identifier) name: (identifier) @name))) @definition.method 26 | (class_specifier . 
name: (type_identifier) @name) @definition.class 27 | """ 28 | 29 | C_TAG_QUERY = """ 30 | (struct_specifier name: (type_identifier) @name body:(_)) @definition.struct 31 | (declaration type: (union_specifier name: (type_identifier) @name)) @definition.class 32 | (function_definition declarator: (function_declarator declarator: (identifier) @name)) @definition.function 33 | (type_definition declarator: (type_identifier) @name) @definition.type 34 | (enum_specifier name: (type_identifier) @name) @definition.type 35 | """ 36 | 37 | 38 | class LanguageImpl: 39 | def __init__( 40 | self, 41 | name: str, 42 | parser: tree_sitter.Parser, 43 | language: tree_sitter.Language, 44 | file_types: list[str], 45 | tags: tree_sitter.Query, 46 | ): 47 | self.name = name 48 | self.parser = parser 49 | self.language = language 50 | self.file_types = file_types 51 | self.tags = tags 52 | 53 | @classmethod 54 | def from_language(cls, language: tree_sitter.Language, *, tags: str, file_types: list[str], name: str): 55 | parser = tree_sitter.Parser(language) 56 | return cls( 57 | name=name, 58 | parser=parser, 59 | language=language, 60 | file_types=file_types, 61 | tags=language.query(tags) 62 | ) 63 | 64 | 65 | def get_node_coordinates(node: tree_sitter.Node) -> Coordinates: 66 | return node.start_point, node.end_point 67 | 68 | 69 | def get_text_coordinates(text: bytes) -> Coordinates: 70 | line_count = text.count(b'\n') + 1 71 | last_newline_pos = text.rfind(b'\n') 72 | total_length = len(text) 73 | last_line_length = total_length - last_newline_pos - 1 74 | if last_newline_pos == -1: 75 | last_line_length = total_length 76 | return (0, 0), (line_count - 1, last_line_length) 77 | 78 | 79 | def get_all_parents(node: tree_sitter.Node) -> list[tree_sitter.Node]: 80 | parents = [] 81 | parent = node.parent 82 | while parent: 83 | parents.append(parent) 84 | parent = parent.parent 85 | return parents 86 | 87 | 88 | def get_context(node: tree_sitter.Node) -> tuple[list[int], list[int]]: 89 | parents = get_all_parents(node) 90 | before, after = [], [] 91 | start_line, end_line = float('-inf'), float('inf') 92 | try: 93 | # The root node is typically like a file or something. 
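# get_context therefore walks the ancestor chain top-down: for every named
# enclosing scope it records the scope's start line above the node
# (context_before) and, when the scope closes strictly below the node, its end
# line (context_after). A minimal sketch of the effect, using parse_objects
# below on a hypothetical one-class source:
#
#     >>> objs = parse_objects(Path("point.py"), b"class Point:\n    def __init__(self):\n        pass\n")
#     >>> objs[2].name, objs[2].context_before
#     ('__init__', [0])  # line 0 holds the enclosing `class Point:` header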
94 | parents.pop() 95 | while parents: 96 | parent = parents.pop() 97 | if parent.children_by_field_name('name') or parent.type == 'impl_item' and parent.children_by_field_name( 98 | 'type' 99 | ): 100 | pass 101 | else: 102 | continue 103 | parent_start_line = parent.start_point.row 104 | assert parent_start_line >= start_line 105 | if start_line < parent_start_line < node.start_point.row: 106 | # first_line_text = parent.text[:parent.text.find(b'\n')] 107 | before.append(parent_start_line) 108 | parent_end_line = parent.end_point.row 109 | assert parent_end_line <= end_line 110 | if node.end_point.row < parent_end_line < end_line: 111 | # last_line_text = parent.text[parent.text.rfind(b'\n') + 1:] 112 | after.append(parent_end_line) 113 | start_line = parent_start_line 114 | end_line = parent_end_line 115 | except IndexError: 116 | pass 117 | return before, after 118 | 119 | 120 | def parse_objects(path: Path, text: bytes) -> list[Object]: 121 | file_type = path.suffix[1:] 122 | impl = get_language_for_file_type(file_type) 123 | language_name = impl.name if impl else 'text' 124 | objects = [ 125 | Object( 126 | path=path, 127 | name=str(path), 128 | language=language_name, 129 | kind='file', 130 | byte_range=(0, len(text)), 131 | coordinates=get_text_coordinates(text), 132 | context_before=[], 133 | context_after=[] 134 | ) 135 | ] 136 | if impl is None: 137 | return objects 138 | tree = impl.parser.parse(text) 139 | root_node = tree.root_node 140 | chunks = objects 141 | matches = impl.tags.matches(root_node) 142 | for _, captures in matches: 143 | name_node = captures.pop('name')[0] 144 | for definition_kind, definition_nodes in captures.items(): 145 | for definition_node in definition_nodes: 146 | before, after = get_context(definition_node) 147 | chunks.append( 148 | Object( 149 | path=path, 150 | name=decode_text(name_node.text), 151 | kind=definition_kind, 152 | language=impl.name, 153 | context_before=before, 154 | context_after=after, 155 | byte_range=definition_node.byte_range, 156 | coordinates=get_node_coordinates(definition_node) 157 | ) 158 | ) 159 | return chunks 160 | 161 | 162 | def get_language_for_file_type(file_type: str) -> LanguageImpl | None: 163 | if file_type == 'py': 164 | return get_python_impl() 165 | elif file_type == 'rs': 166 | return get_rust_impl() 167 | elif file_type in {'cc', 'cpp', 'cxx', 'hpp', 'hxx', 'h'}: 168 | return get_cpp_impl() 169 | elif file_type == 'c': 170 | return get_c_impl() 171 | elif file_type == 'cs': 172 | return get_c_sharp_impl() 173 | elif file_type == 'go': 174 | return get_go_impl() 175 | elif file_type == 'java': 176 | return get_java_impl() 177 | elif file_type == 'js' or file_type == 'mjs' or file_type == 'cjs' or file_type == 'jsx': 178 | return get_javascript_impl() 179 | elif file_type == 'php': 180 | return get_php_impl() 181 | elif file_type == 'rb': 182 | return get_ruby_impl() 183 | elif file_type == 'ts': 184 | return get_typescript_impl() 185 | elif file_type == 'tsx': 186 | return get_tsx_impl() 187 | else: 188 | return None 189 | 190 | 191 | @lru_cache(1) 192 | def get_php_impl() -> LanguageImpl: 193 | PHP_IMPL = LanguageImpl.from_language( 194 | tree_sitter.Language(tree_sitter_php.language_php()), 195 | tags=""" 196 | (namespace_definition 197 | name: (namespace_name) @name) @definition.module 198 | 199 | (interface_declaration 200 | name: (name) @name) @definition.interface 201 | 202 | (trait_declaration 203 | name: (name) @name) @definition.interface 204 | 205 | (class_declaration 206 | name: (name) @name) 
@definition.class 207 | 208 | (class_interface_clause [(name) (qualified_name)] @name) @definition.class_interface_clause 209 | 210 | (property_declaration 211 | (property_element (variable_name (name) @name))) @definition.field 212 | 213 | (function_definition 214 | name: (name) @name) @definition.function 215 | 216 | (method_declaration 217 | name: (name) @name) @definition.method 218 | """, 219 | file_types=['php'], 220 | name='php' 221 | ) 222 | return PHP_IMPL 223 | 224 | 225 | @lru_cache(1) 226 | def get_ruby_impl() -> LanguageImpl: 227 | RUBY_IMPL = LanguageImpl.from_language( 228 | tree_sitter.Language(tree_sitter_ruby.language()), 229 | tags=""" 230 | ; Method definitions 231 | (method 232 | name: (_) @name) @definition.method 233 | (singleton_method 234 | name: (_) @name) @definition.method 235 | 236 | (alias 237 | name: (_) @name) @definition.method 238 | 239 | (class 240 | name: [ 241 | (constant) @name 242 | (scope_resolution 243 | name: (_) @name) 244 | ]) @definition.class 245 | (singleton_class 246 | value: [ 247 | (constant) @name 248 | (scope_resolution 249 | name: (_) @name) 250 | ]) @definition.class 251 | 252 | ; Module definitions 253 | 254 | (module 255 | name: [ 256 | (constant) @name 257 | (scope_resolution 258 | name: (_) @name) 259 | ]) @definition.module 260 | """, 261 | file_types=['rb'], 262 | name='ruby' 263 | ) 264 | return RUBY_IMPL 265 | 266 | 267 | _TYPESCRIPT_ONLY_TAG_QUERY = """ 268 | (function_signature 269 | name: (identifier) @name) @definition.function 270 | 271 | (method_signature 272 | name: (property_identifier) @name) @definition.method 273 | 274 | (abstract_method_signature 275 | name: (property_identifier) @name) @definition.method 276 | 277 | (abstract_class_declaration 278 | name: (type_identifier) @name) @definition.class 279 | 280 | (module 281 | name: (identifier) @name) @definition.module 282 | 283 | (interface_declaration 284 | name: (type_identifier) @name) @definition.interface 285 | """ 286 | _JAVASCRIPT_TAG_QUERY = """ 287 | (program 288 | (lexical_declaration 289 | (variable_declarator 290 | name: (identifier) @name 291 | value: (_ !parameters) 292 | ) 293 | ) @definition.constant) 294 | 295 | (program 296 | (export_statement 297 | (lexical_declaration 298 | (variable_declarator 299 | name: (identifier) @name 300 | value: (_ !parameters) 301 | ) 302 | ) 303 | ) @definition.constant) 304 | 305 | (program 306 | (variable_declaration 307 | (variable_declarator 308 | name: (identifier) @name 309 | value: (_ !parameters) 310 | ) 311 | ) @definition.constant) 312 | 313 | (program 314 | (export_statement 315 | (variable_declaration 316 | (variable_declarator 317 | name: (identifier) @name 318 | value: (_ !parameters) 319 | ) 320 | ) 321 | ) @definition.constant) 322 | 323 | (method_definition 324 | name: (property_identifier) @name) @definition.method 325 | 326 | (class 327 | name: (_) @name) @definition.class 328 | 329 | (class_declaration 330 | name: (_) @name) @definition.class 331 | 332 | (function_expression 333 | name: (identifier) @name) @definition.function 334 | 335 | (function_declaration 336 | name: (identifier) @name) @definition.function 337 | 338 | (generator_function 339 | name: (identifier) @name) @definition.function 340 | 341 | (generator_function_declaration 342 | name: (identifier) @name) @definition.function 343 | 344 | (variable_declarator 345 | name: (identifier) @name 346 | value: [(arrow_function) (function_expression)]) @definition.function 347 | 348 | (assignment_expression 349 | left: [ 350 | (identifier) @name 
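; Both arms of this alternation feed the same @definition.function capture:
; a plain `name = function () {}` assignment binds @name via (identifier),
; while `obj.name = function () {}` binds it via the member_expression's
; property_identifier below.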
351 | (member_expression 352 | property: (property_identifier) @name) 353 | ] 354 | right: [(arrow_function) (function_expression)]) @definition.function 355 | 356 | (pair 357 | key: (property_identifier) @name 358 | value: [(arrow_function) (function_expression)]) @definition.function 359 | """ 360 | _TYPESCRIPT_TAG_QUERY = '\n'.join([_TYPESCRIPT_ONLY_TAG_QUERY, _JAVASCRIPT_TAG_QUERY]) 361 | 362 | 363 | @lru_cache(1) 364 | def get_typescript_impl() -> LanguageImpl: 365 | TYPESCRIPT_IMPL = LanguageImpl.from_language( 366 | tree_sitter.Language(tree_sitter_typescript.language_typescript()), 367 | tags=_TYPESCRIPT_TAG_QUERY, 368 | file_types=[ 369 | 'ts', 370 | ], 371 | name='typescript' 372 | ) 373 | return TYPESCRIPT_IMPL 374 | 375 | 376 | @lru_cache(1) 377 | def get_tsx_impl() -> LanguageImpl: 378 | TSX_IMPL = LanguageImpl.from_language( 379 | tree_sitter.Language(tree_sitter_typescript.language_tsx()), 380 | tags=_TYPESCRIPT_TAG_QUERY, 381 | file_types=[ 382 | 'tsx', 383 | ], 384 | name='tsx' 385 | ) 386 | return TSX_IMPL 387 | 388 | 389 | @lru_cache(1) 390 | def get_python_impl() -> LanguageImpl: 391 | PYTHON_IMPL = LanguageImpl.from_language( 392 | # Don't make breaking changes on me dawg. 393 | tree_sitter.Language(tree_sitter_python.language()), 394 | tags=""" 395 | (module (expression_statement (assignment left: (identifier) @name) @definition.constant)) 396 | 397 | (class_definition 398 | name: (identifier) @name) @definition.class 399 | 400 | (function_definition 401 | name: (identifier) @name) @definition.function 402 | """, 403 | file_types=['py'], 404 | name='python' 405 | ) 406 | return PYTHON_IMPL 407 | 408 | 409 | @lru_cache(1) 410 | def get_rust_impl() -> LanguageImpl: 411 | RUST_IMPL = LanguageImpl.from_language( 412 | tree_sitter.Language(tree_sitter_rust.language()), 413 | tags=""" 414 | ; ADT definitions 415 | 416 | (struct_item 417 | name: (type_identifier) @name) @definition.struct 418 | 419 | (enum_item 420 | name: (type_identifier) @name) @definition.class 421 | 422 | (union_item 423 | name: (type_identifier) @name) @definition.class 424 | 425 | ; type aliases 426 | 427 | (type_item 428 | name: (type_identifier) @name) @definition.class 429 | 430 | ; method definitions 431 | 432 | (function_item 433 | name: (identifier) @name) @definition.function 434 | 435 | ; trait definitions 436 | (trait_item 437 | name: (type_identifier) @name) @definition.interface 438 | 439 | ; module definitions 440 | (mod_item 441 | name: (identifier) @name) @definition.module 442 | 443 | ; macro definitions 444 | 445 | (macro_definition 446 | name: (identifier) @name) @definition.macro 447 | 448 | ; implementations 449 | 450 | (impl_item 451 | trait: (type_identifier) @name) @definition.trait.impl 452 | 453 | (impl_item 454 | type: (type_identifier) @name 455 | !trait) @definition.struct.impl 456 | 457 | """, 458 | file_types=['rs'], 459 | name='rust' 460 | ) 461 | return RUST_IMPL 462 | 463 | 464 | @lru_cache(1) 465 | def get_c_impl() -> LanguageImpl: 466 | C_IMPL = LanguageImpl.from_language( 467 | tree_sitter.Language(tree_sitter_c.language()), 468 | tags=C_TAG_QUERY, 469 | file_types=['c'], 470 | name='c' 471 | ) 472 | return C_IMPL 473 | 474 | 475 | @lru_cache(1) 476 | def get_cpp_impl() -> LanguageImpl: 477 | CPP_IMPL = LanguageImpl.from_language( 478 | tree_sitter.Language(tree_sitter_cpp.language()), 479 | tags='\n'.join([C_TAG_QUERY, CPP_TAG_QUERY]), 480 | file_types=[ 481 | "cc", 482 | "cpp", 483 | "cxx", 484 | "hpp", 485 | "hxx", 486 | "h" 487 | ], 488 | name='cpp' 489 | ) 490 | 
return CPP_IMPL 491 | 492 | 493 | @lru_cache(1) 494 | def get_c_sharp_impl() -> LanguageImpl: 495 | C_SHARP_IMPL = LanguageImpl.from_language( 496 | tree_sitter.Language(tree_sitter_c_sharp.language()), 497 | tags=""" 498 | (class_declaration name: (identifier) @name) @definition.class 499 | (interface_declaration name: (identifier) @name) @definition.interface 500 | (method_declaration name: (identifier) @name) @definition.method 501 | (namespace_declaration name: (identifier) @name) @definition.module 502 | """, 503 | file_types=['cs'], 504 | name='csharp' 505 | ) 506 | return C_SHARP_IMPL 507 | 508 | 509 | @lru_cache(1) 510 | def get_go_impl() -> LanguageImpl: 511 | GO_IMPL = LanguageImpl.from_language( 512 | tree_sitter.Language(tree_sitter_go.language()), 513 | # TODO: Need to add constants to this. 514 | tags=""" 515 | (function_declaration 516 | name: (identifier) @name) @definition.function 517 | (method_declaration 518 | name: (field_identifier) @name) @definition.method 519 | (type_declaration (type_spec 520 | name: (type_identifier) @name)) @definition.type 521 | """, 522 | file_types=['go'], 523 | name='go' 524 | ) 525 | return GO_IMPL 526 | 527 | 528 | @lru_cache(1) 529 | def get_java_impl() -> LanguageImpl: 530 | JAVA_IMPL = LanguageImpl.from_language( 531 | tree_sitter.Language(tree_sitter_java.language()), 532 | tags=""" 533 | (class_declaration 534 | name: (identifier) @name) @definition.class 535 | 536 | (method_declaration 537 | name: (identifier) @name) @definition.method 538 | 539 | (interface_declaration 540 | name: (identifier) @name) @definition.interface 541 | """, 542 | file_types=['java'], 543 | name='java' 544 | ) 545 | return JAVA_IMPL 546 | 547 | 548 | @lru_cache(1) 549 | def get_javascript_impl() -> LanguageImpl: 550 | JAVASCRIPT_IMPL = LanguageImpl.from_language( 551 | tree_sitter.Language(tree_sitter_javascript.language()), 552 | tags=_JAVASCRIPT_TAG_QUERY, 553 | file_types=[ 554 | "js", 555 | "mjs", 556 | "cjs", 557 | "jsx" 558 | ], 559 | name='javascript' 560 | ) 561 | return JAVASCRIPT_IMPL 562 | 563 | 564 | def render_object( 565 | obj: Object, 566 | in_lines: list[str], 567 | *, 568 | context: bool = True, 569 | file: bool = True, 570 | line_numbers: bool = False, 571 | ) -> str: 572 | out_lines = [] 573 | if file: 574 | out_lines.append(str(obj.path)) 575 | out_lines.append('') 576 | max_line_no = max( 577 | obj.coordinates[0][0], 578 | obj.coordinates[1][0], 579 | *obj.context_before, 580 | # *obj.context_after 581 | ) + 1 582 | line_width = len(str(max_line_no)) 583 | 584 | def line_formatter(line_index: int, line_content: str) -> str: 585 | if line_numbers: 586 | line_number = line_index + 1 587 | return str(line_number).rjust(line_width) + " " + line_content 588 | return line_content 589 | 590 | if context: 591 | for line in obj.context_before: 592 | out_lines.append(line_formatter(line, in_lines[line])) 593 | start_line, end_line = obj.coordinates[0][0], obj.coordinates[1][0] 594 | for i in range(start_line, end_line + 1): 595 | try: 596 | out_lines.append(line_formatter(i, in_lines[i])) 597 | except IndexError: 598 | # If there's a newline at the end of the file. 
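# str.splitlines() drops the empty string after a trailing newline, while the
# object's end coordinate (from the tree-sitter end_point, or from
# get_text_coordinates above) can still point at that phantom final row:
#
#     >>> "a\nb\n".splitlines()
#     ['a', 'b']
#
# yet a file-spanning object over those 4 bytes ends at row 2, so
# in_lines[2] raises IndexError exactly when i == end_line.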
599 | if i == end_line: 600 | break 601 | raise 602 | # if context: 603 | # for line in obj.context_after[::-1]: 604 | # out_lines.append(line_formatter(line, in_lines[line])) 605 | return '\n'.join(out_lines) 606 | -------------------------------------------------------------------------------- /tests/test_parser.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import textwrap 6 | 7 | import pytest 8 | 9 | from codebased.parser import parse_objects 10 | 11 | 12 | @pytest.mark.parametrize("file_type", ["ts", "js", "jsx", "tsx"]) 13 | def test_javascript_top_level_variable_declarations(file_type): 14 | source = textwrap.dedent( 15 | """ 16 | let stringData = "Hello, world!"; 17 | export const numberData = 123; 18 | const booleanData = true; 19 | export const nullData = null; 20 | export let undefinedData = undefined; 21 | export var objectData = { id: 1, name: 'John', age: 30 }; 22 | var arrayData = [ 23 | { id: 1, name: 'John', age: 30 }, 24 | { id: 2, name: 'Jane', age: 25 }, 25 | { id: 3, name: 'Bob', age: 35 }, 26 | ]; 27 | 28 | export const hidePII = (datum) => { 29 | return {id: datum.id}; 30 | }; 31 | function maskPII(datum) { 32 | return { 33 | id: datum.id, 34 | name: datum.name.replace(/./g, '*'), 35 | age: string(datum.age).replace(/./g, '*'), 36 | }; 37 | } 38 | 39 | export const sanitizedData = hidePII(objectData); 40 | """ 41 | ).encode() 42 | file_name = f'src/constants.{file_type}' 43 | objects = parse_objects( 44 | Path(file_name), 45 | source 46 | ) 47 | assert len(objects) == 11 48 | file_o, string_o, number_o, boolean_o, null_o, undefined_o = objects[:6] 49 | object_o, array_o, hide_pii_o, mask_pii_o, sanitized_o = objects[6:] 50 | assert file_o.name == file_name 51 | assert file_o.kind == 'file' 52 | assert string_o.name == 'stringData' 53 | assert string_o.kind == 'definition.constant' 54 | assert number_o.name == 'numberData' 55 | assert number_o.kind == 'definition.constant' 56 | assert boolean_o.name == 'booleanData' 57 | assert boolean_o.kind == 'definition.constant' 58 | assert null_o.name == 'nullData' 59 | assert null_o.kind == 'definition.constant' 60 | assert undefined_o.name == 'undefinedData' 61 | assert undefined_o.kind == 'definition.constant' 62 | assert object_o.name == 'objectData' 63 | assert object_o.kind == 'definition.constant' 64 | assert array_o.name == 'arrayData' 65 | assert array_o.kind == 'definition.constant' 66 | assert hide_pii_o.name == 'hidePII' 67 | assert hide_pii_o.kind == 'definition.function' 68 | assert mask_pii_o.name == 'maskPII' 69 | assert mask_pii_o.kind == 'definition.function' 70 | assert sanitized_o.name == 'sanitizedData' 71 | assert sanitized_o.kind == 'definition.constant' 72 | 73 | 74 | def test_parse_cxx_header_file(): 75 | file_name = 'src/shapes.h' 76 | source = textwrap.dedent( 77 | """ 78 | #ifndef SHAPES_H 79 | #define SHAPES_H 80 | 81 | #include 82 | 83 | struct Point { 84 | double x; 85 | double y; 86 | }; 87 | 88 | class Shape { 89 | public: 90 | Shape(); 91 | virtual ~Shape(); 92 | virtual double area() = 0; 93 | }; 94 | 95 | class Circle : public Shape { 96 | public: 97 | Circle(double radius); 98 | double area() override; 99 | private: 100 | double radius_; 101 | }; 102 | 103 | class Rectangle : public Shape { 104 | public: 105 | Rectangle(double width, double height); 106 | double area() override; 107 | private: 108 | double width_; 109 | double height_; 110 | }; 111 | 112 | #endif 113 | """ 
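# parse_objects takes bytes rather than str because tree-sitter addresses the
# source by byte offset; byte_range and coordinates on each Object come
# straight from the matched node. A minimal illustration (hypothetical path,
# exercising only the plain-text fallback, so no grammar is needed):
#
#     >>> parse_objects(Path('notes.txt'), b'')[0].kind
#     'file'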
114 | ).encode() 115 | source_lines = source.splitlines() 116 | objects = parse_objects( 117 | Path(file_name), 118 | source 119 | ) 120 | assert len(objects) == 8 121 | 122 | file, point, shape, shape_area, circle, circle_area, rectangle, rectangle_area = objects 123 | 124 | assert file.name == file_name 125 | assert file.kind == 'file' 126 | assert file.language == 'cpp' 127 | assert file.context_before == [] 128 | assert file.context_after == [] 129 | 130 | ifndef_start, ifndef_end = source_lines.index(b'#ifndef SHAPES_H'), source_lines.index(b'#endif') 131 | 132 | assert point.name == 'Point' 133 | assert point.kind == 'definition.struct' 134 | assert point.context_before == [ifndef_start] 135 | assert point.context_after == [ifndef_end] 136 | 137 | assert shape.name == 'Shape' 138 | assert shape.kind == 'definition.class' 139 | assert shape.context_before == [ifndef_start] 140 | assert shape.context_after == [ifndef_end] 141 | 142 | shape_start = shape.coordinates[0][0] 143 | shape_end = shape.coordinates[1][0] 144 | 145 | assert shape_area.name == 'area' 146 | assert shape_area.kind == 'definition.method' 147 | assert shape_area.context_before == [ifndef_start, shape_start] 148 | assert shape_area.context_after == [ifndef_end, shape_end] 149 | 150 | assert circle.name == 'Circle' 151 | assert circle.kind == 'definition.class' 152 | assert circle.context_before == [ifndef_start] 153 | assert circle.context_after == [ifndef_end] 154 | 155 | circle_start = circle.coordinates[0][0] 156 | circle_end = circle.coordinates[1][0] 157 | 158 | assert circle_area.name == 'area' 159 | assert circle_area.kind == 'definition.method' 160 | assert circle_area.context_before == [ifndef_start, circle_start] 161 | assert circle_area.context_after == [ifndef_end, circle_end] 162 | 163 | assert rectangle.name == 'Rectangle' 164 | assert rectangle.kind == 'definition.class' 165 | assert rectangle.context_before == [ifndef_start] 166 | assert rectangle.context_after == [ifndef_end] 167 | 168 | rectangle_start = rectangle.coordinates[0][0] 169 | rectangle_end = rectangle.coordinates[1][0] 170 | 171 | assert rectangle_area.name == 'area' 172 | assert rectangle_area.kind == 'definition.method' 173 | assert rectangle_area.context_before == [ifndef_start, rectangle_start] 174 | assert rectangle_area.context_after == [ifndef_end, rectangle_end] 175 | 176 | 177 | def test_parse_c_header_file(): 178 | # TODO: Properly parse C function declarations. 179 | # Note: Definitions are correctly parsed. 
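# Concretely: C_TAG_QUERY only tags function_definition nodes, and a bare
# prototype like `double circle_area(const Shape*);` is a declaration, so the
# five prototypes below contribute no objects (hence the count of 6). A query
# pattern that might capture them (an untested sketch against the
# tree-sitter-c grammar, not part of the current query):
#
#     (declaration
#       declarator: (function_declarator
#         declarator: (identifier) @name)) @definition.function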
180 | file_name = 'src/shapes.h' 181 | source = textwrap.dedent( 182 | """ 183 | #ifndef SHAPES_H 184 | #define SHAPES_H 185 | 186 | #include 187 | 188 | typedef struct { 189 | double x; 190 | double y; 191 | } Point; 192 | 193 | typedef struct Shape Shape; 194 | 195 | typedef double (*AreaFunc)(const Shape*); 196 | 197 | struct Shape { 198 | AreaFunc area; 199 | }; 200 | 201 | typedef struct { 202 | Shape base; 203 | double radius; 204 | } Circle; 205 | 206 | typedef struct { 207 | Shape base; 208 | double width; 209 | double height; 210 | } Rectangle; 211 | 212 | double circle_area(const Shape* shape); 213 | double rectangle_area(const Shape* shape); 214 | 215 | Circle* create_circle(double radius); 216 | Rectangle* create_rectangle(double width, double height); 217 | 218 | void destroy_shape(Shape* shape); 219 | 220 | #endif 221 | """ 222 | ).encode() 223 | source_lines = source.splitlines() 224 | objects = parse_objects( 225 | Path(file_name), 226 | source 227 | ) 228 | assert len(objects) == 6 229 | 230 | file, point, shape_fwd, shape, circle, rectangle = objects 231 | 232 | assert file.name == file_name 233 | assert file.kind == 'file' 234 | # This is not ideal, but it's fine. 235 | assert file.language == 'cpp' 236 | assert file.context_before == [] 237 | assert file.context_after == [] 238 | 239 | ifndef_start, ifndef_end = source_lines.index(b'#ifndef SHAPES_H'), source_lines.index(b'#endif') 240 | 241 | assert point.name == 'Point' 242 | assert point.kind == 'definition.type' 243 | assert point.context_before == [ifndef_start] 244 | assert point.context_after == [ifndef_end] 245 | 246 | assert shape_fwd.name == 'Shape' 247 | assert shape_fwd.kind == 'definition.type' 248 | assert shape_fwd.context_before == [ifndef_start] 249 | assert shape_fwd.context_after == [ifndef_end] 250 | 251 | assert shape.name == 'Shape' 252 | assert shape.kind == 'definition.struct' 253 | assert shape.context_before == [ifndef_start] 254 | assert shape.context_after == [ifndef_end] 255 | 256 | assert circle.name == 'Circle' 257 | assert circle.kind == 'definition.type' 258 | assert circle.context_before == [ifndef_start] 259 | assert circle.context_after == [ifndef_end] 260 | 261 | assert rectangle.name == 'Rectangle' 262 | assert rectangle.kind == 'definition.type' 263 | assert rectangle.context_before == [ifndef_start] 264 | assert rectangle.context_after == [ifndef_end] 265 | 266 | 267 | def test_parse_rust(): 268 | file_name = 'src/main.rs' 269 | source = textwrap.dedent( 270 | """ 271 | #[derive(Debug)] 272 | pub struct Point { 273 | x: f64, 274 | y: f64, 275 | } 276 | 277 | impl Point { 278 | pub fn new(x: f64, y: f64) -> Self { 279 | Self { x, y } 280 | } 281 | } 282 | 283 | fn main() { 284 | let p = Point::new(1.0, 2.0); 285 | println!("Hello, world!"); 286 | } 287 | """ 288 | ).encode() 289 | source_lines = source.splitlines() 290 | objects = parse_objects( 291 | Path(file_name), 292 | source 293 | ) 294 | assert len(objects) == 5 295 | 296 | file, point, impl, function, main = objects 297 | 298 | assert file.name == file_name 299 | assert file.kind == 'file' 300 | assert file.language == 'rust' 301 | assert file.context_before == [] 302 | assert file.context_after == [] 303 | 304 | assert point.name == 'Point' 305 | assert point.kind == 'definition.struct' 306 | 307 | assert impl.name == 'Point' 308 | assert impl.kind == 'definition.struct.impl' 309 | 310 | assert function.name == 'new' 311 | assert function.kind == 'definition.function' 312 | assert function.context_before == 
[impl.coordinates[0][0]] 313 | assert function.context_after == [impl.coordinates[1][0]] 314 | 315 | assert main.name == 'main' 316 | assert main.kind == 'definition.function' 317 | 318 | 319 | def test_parse_python(): 320 | file_name = 'src/main.py' 321 | source = textwrap.dedent( 322 | """ 323 | class Point: 324 | def __init__(self, x, y): 325 | self.x = x 326 | self.y = y 327 | 328 | ORIGIN = Point(0, 0) 329 | 330 | def main(): 331 | p = Point(1, 2) 332 | print("Hello, world!") 333 | """ 334 | ).encode() 335 | objects = parse_objects( 336 | Path(file_name), 337 | source 338 | ) 339 | assert len(objects) == 5 340 | 341 | file, class_, __init__, origin, main = objects 342 | 343 | assert file.name == file_name 344 | assert file.kind == 'file' 345 | assert file.language == 'python' 346 | assert file.context_before == [] 347 | assert file.context_after == [] 348 | 349 | assert class_.name == 'Point' 350 | assert class_.kind == 'definition.class' 351 | assert class_.context_before == [] 352 | assert class_.context_after == [] 353 | 354 | assert __init__.name == '__init__' 355 | assert __init__.kind == 'definition.function' 356 | assert __init__.context_before == [class_.coordinates[0][0]] 357 | assert __init__.context_after == [] 358 | 359 | assert origin.name == 'ORIGIN' 360 | assert origin.kind == 'definition.constant' 361 | assert origin.context_before == [] 362 | assert origin.context_after == [] 363 | 364 | assert main.name == 'main' 365 | assert main.kind == 'definition.function' 366 | assert main.context_before == [] 367 | assert main.context_after == [] 368 | 369 | 370 | def test_parse_c_sharp(): 371 | # I know literally nothing about this language. 372 | # If you're reading this because I made a mistake, please let me know. 373 | file_name = 'src/Main.cs' 374 | source = textwrap.dedent( 375 | """ 376 | public class Point { 377 | public double X { get; set; } 378 | public double Y { get; set; } 379 | } 380 | 381 | public static void Main() { 382 | var p = new Point { X = 1, Y = 2 }; 383 | Console.WriteLine("Hello, world!"); 384 | } 385 | """ 386 | ).encode() 387 | source_lines = source.splitlines() 388 | objects = parse_objects( 389 | Path(file_name), 390 | source 391 | ) 392 | assert len(objects) == 2 393 | 394 | file, point = objects 395 | 396 | assert file.name == file_name 397 | assert file.kind == 'file' 398 | assert file.language == 'csharp' 399 | assert file.context_before == [] 400 | assert file.context_after == [] 401 | 402 | assert point.name == 'Point' 403 | assert point.kind == 'definition.class' 404 | assert point.context_before == [] 405 | 406 | 407 | def test_go(): 408 | file_name = 'src/main.go' 409 | source = textwrap.dedent( 410 | """ 411 | package main 412 | 413 | import "fmt" 414 | 415 | type Point struct { 416 | X float64 417 | Y float64 418 | } 419 | 420 | func (*Point) Area() float64 { 421 | return 0 422 | } 423 | 424 | func main() { 425 | p := Point{X: 1, Y: 2} 426 | fmt.Println("Hello, world!") 427 | } 428 | """ 429 | ).encode() 430 | objects = parse_objects( 431 | Path(file_name), 432 | source 433 | ) 434 | file, point, area, main = objects 435 | 436 | assert file.name == file_name 437 | assert file.kind == 'file' 438 | assert file.language == 'go' 439 | assert file.context_before == [] 440 | assert file.context_after == [] 441 | 442 | assert point.name == 'Point' 443 | assert point.kind == 'definition.type' 444 | assert point.context_before == [] 445 | assert point.context_after == [] 446 | 447 | assert area.name == 'Area' 448 | assert area.kind == 
'definition.method' 449 | 450 | assert main.name == 'main' 451 | assert main.kind == 'definition.function' 452 | assert main.context_before == [] 453 | assert main.context_after == [] 454 | 455 | 456 | def test_java(): 457 | file_name = 'src/Main.java' 458 | source = textwrap.dedent( 459 | """ 460 | public class Point { 461 | public double x; 462 | public double y; 463 | 464 | public double area() { 465 | return 0; 466 | } 467 | } 468 | 469 | public class Main { 470 | public static void main(String[] args) { 471 | Point p = new Point(); 472 | System.out.println("Hello, world!"); 473 | } 474 | } 475 | """ 476 | ).encode() 477 | objects = parse_objects( 478 | Path(file_name), 479 | source 480 | ) 481 | assert len(objects) == 5 482 | file, point, area, main_class, main = objects 483 | 484 | assert file.name == file_name 485 | assert file.kind == 'file' 486 | assert file.language == 'java' 487 | assert file.context_before == [] 488 | assert file.context_after == [] 489 | 490 | assert point.name == 'Point' 491 | assert point.kind == 'definition.class' 492 | assert point.context_before == [] 493 | assert point.context_after == [] 494 | 495 | assert area.name == 'area' 496 | assert area.kind == 'definition.method' 497 | assert area.context_before == [point.coordinates[0][0]] 498 | assert area.context_after == [point.coordinates[1][0]] 499 | 500 | assert main_class.name == 'Main' 501 | assert main_class.kind == 'definition.class' 502 | assert main_class.context_before == [] 503 | assert main_class.context_after == [] 504 | 505 | assert main.name == 'main' 506 | assert main.kind == 'definition.method' 507 | assert main.context_before == [main_class.coordinates[0][0]] 508 | assert main.context_after == [main_class.coordinates[1][0]] 509 | 510 | 511 | def test_ruby(): 512 | file_name = 'src/main.rb' 513 | source = textwrap.dedent( 514 | """ 515 | class Point 516 | attr_accessor :x, :y 517 | 518 | def area 519 | 0 520 | end 521 | end 522 | 523 | def main 524 | p = Point.new 525 | puts "Hello, world!" 526 | end 527 | """ 528 | ).encode() 529 | objects = parse_objects( 530 | Path(file_name), 531 | source 532 | ) 533 | assert len(objects) == 4 534 | file, point, area, main = objects 535 | 536 | assert file.name == file_name 537 | assert file.kind == 'file' 538 | assert file.language == 'ruby' 539 | assert file.context_before == [] 540 | assert file.context_after == [] 541 | 542 | assert point.name == 'Point' 543 | assert point.kind == 'definition.class' 544 | assert point.context_before == [] 545 | assert point.context_after == [] 546 | 547 | assert area.name == 'area' 548 | assert area.kind == 'definition.method' 549 | assert area.context_before == [point.coordinates[0][0]] 550 | assert area.context_after == [point.coordinates[1][0]] 551 | 552 | assert main.name == 'main' 553 | # In Ruby, all "functions" are methods. 
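# A top-level `def` is a private method on Object, so tree-sitter-ruby parses
# it as a plain (method ...) node and the query above can only tag it
# @definition.method. A minimal check, assuming the ruby grammar is installed:
#
#     >>> parse_objects(Path('x.rb'), b"def helper\nend\n")[1].kind
#     'definition.method'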
554 | assert main.kind == 'definition.method' 555 | assert main.context_before == [] 556 | assert main.context_after == [] 557 | 558 | 559 | def test_php(): 560 | file_name = 'src/main.php' 561 | source = textwrap.dedent( 562 | """ 563 | T.NoReturn: 53 | print(message, file=sys.stderr) 54 | sys.exit(exit_code) 55 | 56 | 57 | def get_db(database_file: Path) -> sqlite3.Connection: 58 | db = sqlite3.connect(database_file) 59 | db.row_factory = sqlite3.Row 60 | return db 61 | 62 | 63 | class Events: 64 | FlushEmbeddings = namedtuple('FlushEmbeddings', []) 65 | StoreEmbeddings = namedtuple('StoreEmbeddings', ['embeddings']) 66 | Commit = namedtuple('Commit', []) 67 | FaissInserts = namedtuple('FaissInserts', ['embeddings']) 68 | IndexObjects = namedtuple('IndexObjects', ['file_bytes', 'objects', 'lines']) 69 | ScheduleEmbeddingRequests = namedtuple('EmbeddingRequests', ['requests']) 70 | FaissDeletes = namedtuple('FaissDeletes', ['ids']) 71 | IndexFile = namedtuple('IndexFile', ['path', 'content', 'lines']) 72 | DeleteFile = namedtuple('DeleteFile', ['path']) 73 | DeleteFileObjects = namedtuple('DeleteFile', ['path']) 74 | Directory = namedtuple('Directory', ['path']) 75 | File = namedtuple('File', ['path']) 76 | DeleteNotVisited = namedtuple('DeleteNotVisited', ['paths']) 77 | ReloadFileEmbeddings = namedtuple('ReloadFileEmbeddings', ['path']) 78 | 79 | 80 | def is_binary(file_bytes: bytes) -> bool: 81 | return b'\x00' in file_bytes 82 | 83 | 84 | # Put this on Dependencies object. 85 | class OpenAIRequestScheduler: 86 | def __init__(self, oai_client: "OpenAI", embedding_config: EmbeddingsConfig): 87 | self.oai_client = oai_client 88 | self.embedding_config = embedding_config 89 | self.batch = [] 90 | self.batch_tokens = 0 91 | self.batch_size_limit = 2048 92 | self.batch_token_limit = 400_000 93 | # TODO: Rate limiting. 94 | # This essentially runs a single background thread that processes requests. 95 | # This will give a modest performance boost by allowing the main thread to do other tasks. 96 | # It would be nice to run requests in parallel for short bursts to make medium-sized project indexing faster. 97 | self.max_concurrent_requests = 1 98 | self.executor = concurrent.futures.ThreadPoolExecutor( 99 | max_workers=self.max_concurrent_requests, 100 | thread_name_prefix="OpenAIRequestScheduler" 101 | ) 102 | self.futures = [] 103 | 104 | @cached_property 105 | def encoding(self) -> tiktoken.Encoding: 106 | return tiktoken.encoding_for_model(self.embedding_config.model) 107 | 108 | def schedule(self, req: EmbeddingRequest) -> T.Iterable[Embedding]: 109 | request_tokens = len(self.encoding.encode(req.content, disallowed_special=())) 110 | results = [] 111 | if request_tokens > 8192: 112 | STATS.increment("codebased.embeddings.skipped.too_long") 113 | return results 114 | if len(self.batch) >= self.batch_size_limit or self.batch_tokens + request_tokens > self.batch_token_limit: 115 | STATS.increment(STATS.Key.index_creation_embedding_tokens_consumed, self.batch_tokens) 116 | self.futures.append(self.executor.submit(self._process_batch, self.batch)) 117 | self.batch = [] 118 | self.batch_tokens = 0 119 | self.batch.append(req) 120 | self.batch_tokens += request_tokens 121 | 122 | futures = [] 123 | for future in self.futures: 124 | if future.done(): 125 | # This may raise an exception that should propagate. 
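# concurrent.futures re-raises the worker's exception from Future.result(),
# so an OpenAI API failure surfaces in the indexing loop instead of being
# silently dropped; checking done() first keeps this drain non-blocking.
# Standard-library behavior, for reference:
#
#     >>> from concurrent.futures import ThreadPoolExecutor
#     >>> ThreadPoolExecutor(1).submit(lambda: 1 / 0).result()
#     Traceback (most recent call last):
#       ...
#     ZeroDivisionError: division by zero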
126 | results.extend(future.result()) 127 | else: 128 | futures.append(future) 129 | self.futures = futures 130 | 131 | return results 132 | 133 | def flush(self) -> T.Iterable[Embedding]: 134 | if self.batch: 135 | self.futures.append(self.executor.submit(self._process_batch, self.batch)) 136 | self.batch = [] 137 | self.batch_tokens = 0 138 | 139 | results = [] 140 | for future in concurrent.futures.as_completed(self.futures): 141 | results.extend(future.result()) 142 | self.futures.clear() 143 | return results 144 | 145 | def _process_batch(self, batch: T.List[EmbeddingRequest]) -> T.Iterable[Embedding]: 146 | results = create_openai_embeddings_sync_batched(self.oai_client, batch, self.embedding_config) 147 | return results 148 | 149 | 150 | @dataclasses.dataclass 151 | class Config: 152 | flags: Flags 153 | 154 | @cached_property 155 | def root(self) -> Path: 156 | git_repository_dir = find_root_git_repository(self.flags.directory) 157 | if git_repository_dir is None: 158 | exit_with_error('Codebased must be run within a Git repository.') 159 | print(f'Found Git repository {git_repository_dir}') 160 | git_repository_dir: Path = git_repository_dir 161 | return git_repository_dir 162 | 163 | @property 164 | def git_directory(self) -> Path: 165 | return self.root / '.git' 166 | 167 | @property 168 | def codebased_directory(self) -> Path: 169 | directory = self.root / '.codebased' 170 | directory.mkdir(exist_ok=True) 171 | return directory 172 | 173 | @property 174 | def index_path(self) -> Path: 175 | return self.codebased_directory / 'index.faiss' 176 | 177 | @cached_property 178 | def rebuild_faiss_index(self) -> bool: 179 | return self.flags.rebuild_faiss_index or not self.index_path.exists() 180 | 181 | 182 | K = TypeVar('K') 183 | V = TypeVar('V') 184 | 185 | 186 | @dataclasses.dataclass 187 | class ThreadSafeCache(Generic[K, V]): 188 | _lock: threading.Lock = dataclasses.field(default_factory=threading.Lock) 189 | _cache: Dict[K, V] = dataclasses.field(default_factory=dict) 190 | 191 | def __getitem__(self, key: K) -> V: 192 | with self._lock: 193 | return self._cache[key] 194 | 195 | def __setitem__(self, key: K, value: V) -> None: 196 | with self._lock: 197 | self._cache[key] = value 198 | 199 | def __delitem__(self, key: K) -> None: 200 | with self._lock: 201 | del self._cache[key] 202 | 203 | def clear(self) -> None: 204 | with self._lock: 205 | self._cache.clear() 206 | 207 | def get(self, key: K, default: V = None) -> V: 208 | with self._lock: 209 | return self._cache.get(key, default) 210 | 211 | 212 | class thread_local_cached_property: 213 | def __init__(self, func): 214 | self.func = func 215 | self.name = None 216 | 217 | def __set_name__(self, owner, name): 218 | self.name = name 219 | 220 | def __get__(self, obj, cls=None): 221 | if obj is None: 222 | return self 223 | 224 | thread_id = threading.get_ident() 225 | attr_name = f'_thread_local_cache_{self.name}_{thread_id}' 226 | 227 | if not hasattr(obj, attr_name): 228 | setattr(obj, attr_name, self.func(obj)) 229 | return getattr(obj, attr_name) 230 | 231 | def clear_cache(self, obj): 232 | thread_id = threading.get_ident() 233 | attr_name = f'_thread_local_cache_{self.name}_{thread_id}' 234 | if hasattr(obj, attr_name): 235 | delattr(obj, attr_name) 236 | 237 | 238 | def clear_thread_local_cache(func): 239 | @functools.wraps(func) 240 | def wrapper(self, *args, **kwargs): 241 | result = func(self, *args, **kwargs) 242 | for name, attr in type(self).__dict__.items(): 243 | if isinstance(attr, 
thread_local_cached_property): 244 | attr.clear_cache(self) 245 | return result 246 | 247 | return wrapper 248 | 249 | 250 | @dataclasses.dataclass 251 | class Dependencies: 252 | # config must be passed in explicitly. 253 | config: Config 254 | settings: Settings 255 | search_cache: ThreadSafeCache[Flags, list] = dataclasses.field(default_factory=ThreadSafeCache) 256 | ignores: dict[Path, T.Callable[[Path], bool | None]] = dataclasses.field(default_factory=dict) 257 | 258 | @cached_property 259 | def openai_client(self) -> "OpenAI": 260 | from openai import OpenAI 261 | 262 | return OpenAI(api_key=self.settings.OPENAI_API_KEY) 263 | 264 | @cached_property 265 | def index(self) -> faiss.Index: 266 | if self.config.rebuild_faiss_index: 267 | index = faiss.IndexIDMap2(faiss.IndexFlatL2(self.settings.embeddings.dimensions)) 268 | if not self.config.index_path.exists(): 269 | faiss.write_index(index, str(self.config.index_path)) 270 | else: 271 | index = faiss.read_index(str(self.config.index_path)) 272 | return index 273 | 274 | @thread_local_cached_property 275 | def db(self) -> sqlite3.Connection: 276 | db = get_db(self.config.codebased_directory / 'codebased.db') 277 | migrations = DatabaseMigrations(db, Path(__file__).parent / 'migrations') 278 | migrations.initialize() 279 | migrations.migrate() 280 | return db 281 | 282 | def ignore_checker(self, path: Path) -> bool: 283 | original_path = path 284 | while path.is_relative_to(self.config.root): 285 | try: 286 | ignore_rule_set = self.ignores[path / '.gitignore'] 287 | ignore_result = ignore_rule_set(original_path) 288 | if ignore_result is not None: 289 | return ignore_result 290 | except KeyError: 291 | continue 292 | finally: 293 | path = path.parent 294 | return False 295 | 296 | @cached_property 297 | def request_scheduler(self) -> OpenAIRequestScheduler: 298 | return OpenAIRequestScheduler(self.openai_client, self.settings.embeddings) 299 | 300 | 301 | class FileExceptions: 302 | class AlreadyIndexed(Exception): 303 | """ 304 | File has already been indexed. 305 | """ 306 | pass 307 | 308 | class Ignore(Exception): 309 | """ 310 | File cannot be indexed because it's binary or not UTF-8 / UTF-16. 311 | """ 312 | pass 313 | 314 | class Delete(Exception): 315 | """ 316 | File should be deleted. 317 | """ 318 | pass 319 | 320 | 321 | def index_paths( 322 | dependencies: Dependencies, 323 | config: Config, 324 | paths_to_index: list[Path], 325 | *, 326 | total: bool = True 327 | ): 328 | ignore = dependencies.ignore_checker 329 | db = dependencies.db 330 | index = dependencies.index 331 | 332 | rebuilding_faiss_index = config.rebuild_faiss_index 333 | if not total: 334 | rebuilding_faiss_index = False 335 | 336 | dependencies.db.execute("begin;") 337 | # We can actually be sure we visit each file at most once. 338 | # Also, we don't need O(1) contains checks. 339 | # So use a list instead of a set. 340 | # May be useful to see what the traversal order was too. 341 | embeddings_to_index: list[Embedding] = [] 342 | deletion_markers = [] 343 | paths_visited = [] 344 | events = [ 345 | Events.Commit(), 346 | # Add to FAISS after deletes, because SQLite can reuse row ids. 347 | Events.FaissInserts(embeddings_to_index), 348 | Events.FaissDeletes(deletion_markers), 349 | Events.FlushEmbeddings(), 350 | *[Events.Directory(x) if x.is_dir() else Events.File(x) for x in paths_to_index] 351 | ] 352 | if total: 353 | events.insert(3, Events.DeleteNotVisited(paths_visited)) 354 | 355 | # Why do we need to put space before file? 
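    # Presumably because tqdm renders the count and its unit with no separator
    # ("42file"), so the leading space in " file" displays as "42 file".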
356 |     pbar = tqdm(total=None, desc=f"Indexing {config.root.name}", unit=" file")
357 | 
358 |     try:
359 |         while events:
360 |             event = events.pop()
361 |             STATS.increment(f"codebased.index.events.{type(event).__name__}.total")
362 |             if isinstance(event, Events.Directory):
363 |                 path = event.path
364 |                 if path == config.root / '.git' or path == config.root / '.codebased':
365 |                     continue
366 |                 dir_entries = list(os.scandir(path))
367 |                 # TODO: We don't handle changes to .gitignore files in the background worker
368 |                 # because it could require re-scanning the entire index.
369 |                 # i.e. if you remove an ignore rule, we might need to add previously ignored files to the index.
370 |                 # or if you add a new ignore rule, we might need to remove existing files from the index.
371 |                 try:
372 |                     gitignore_file = next(e for e in dir_entries if e.name == '.gitignore' and e.is_file())
373 |                     ignore_path = Path(gitignore_file.path)
374 |                     gitignore_parsed = parse_gitignore(ignore_path, base_dir=path)
375 |                     dependencies.ignores[ignore_path] = gitignore_parsed
376 |                 except StopIteration:
377 |                     pass
378 |                 for entry in dir_entries:
379 |                     entry_path = Path(entry.path)
380 |                     if ignore(entry_path):  # noqa
381 |                         continue
382 |                     try:
383 |                         if entry.is_symlink():
384 |                             continue
385 |                         if entry.is_dir() and not entry.name.startswith('.'):
386 |                             events.append(Events.Directory(entry_path))
387 |                         elif entry.is_file():
388 |                             events.append(Events.File(entry_path))
389 |                     except PermissionError:
390 |                         continue
391 |             elif isinstance(event, Events.File):
392 |                 path = event.path
393 |                 assert isinstance(path, Path)
394 |                 relative_path = path.relative_to(config.root)
395 |                 pbar.update(1)
396 |                 try:
397 |                     if not (path.exists() and path.is_file()):
398 |                         raise FileExceptions.Delete()
399 |                     # TODO: This is hilariously slow.
400 |                     paths_visited.append(relative_path)
401 | 
402 |                     result = db.execute(
403 |                         """
404 |                         select
405 |                             size_bytes,
406 |                             last_modified_ns,
407 |                             sha256_digest
408 |                         from file
409 |                         where path = :path;
410 |                         """,
411 |                         {'path': str(relative_path)}
412 |                     ).fetchone()
413 | 
414 |                     stat = path.stat()
415 |                     if result is not None:
416 |                         size, last_modified, previous_sha256_digest = result
417 |                         if stat.st_size == size and stat.st_mtime_ns == last_modified:
418 |                             raise FileExceptions.AlreadyIndexed()
419 |                     else:
420 |                         previous_sha256_digest = None
421 | 
422 |                     try:
423 |                         file_bytes = path.read_bytes()
424 |                     except FileNotFoundError:
425 |                         raise FileExceptions.Delete()
426 |                     # Ignore binary files.
427 |                     if is_binary(file_bytes):
428 |                         raise FileExceptions.Ignore()
429 |                     # TODO: See how long this takes on large repos.
430 |                     # TODO: We might want to memoize the "skip" results if this is an issue.
431 |                     decoded_text = decode_text(file_bytes)
432 |                     if decoded_text is None:
433 |                         raise FileExceptions.Ignore()
434 |                     real_sha256_digest = hashlib.sha256(file_bytes).digest()
435 |                     # TODO: To support incremental indexing, i.e. allowing this loop to make progress if interrupted
436 |                     # we would need to wait until the objects, embeddings, FTS index, etc. are computed to insert.
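                    # As written, the entire run shares one transaction (begun above and
                    # committed by the Commit event), so an interrupted run rolls back
                    # entirely instead of keeping partial progress.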
437 |                     db.execute(
438 |                         """
439 |                         insert into file
440 |                             (path, size_bytes, last_modified_ns, sha256_digest)
441 |                         values
442 |                             (:path, :size_bytes, :last_modified_ns, :sha256_digest)
443 |                         on conflict (path) do update
444 |                         set size_bytes = :size_bytes,
445 |                             last_modified_ns = :last_modified_ns,
446 |                             sha256_digest = :sha256_digest;
447 |                         """,
448 |                         {
449 |                             'path': str(relative_path),
450 |                             'size_bytes': stat.st_size,
451 |                             'last_modified_ns': stat.st_mtime_ns,
452 |                             'sha256_digest': real_sha256_digest
453 |                         }
454 |                     )
455 |                     # Do this after updating the DB, because a write to SQLite is cheaper than reading a file.
456 |                     # https://www.sqlite.org/fasterthanfs.html
457 |                     if previous_sha256_digest == real_sha256_digest:
458 |                         raise FileExceptions.AlreadyIndexed()
459 |                     # Actually schedule the file for indexing.
460 |                     events.append(Events.IndexFile(relative_path, file_bytes, decoded_text.splitlines()))
461 |                     # Delete old objects before adding new ones.
462 |                     events.append(Events.DeleteFileObjects(relative_path))
463 |                     continue
464 |                 except FileExceptions.Delete:
465 |                     events.append(Events.DeleteFile(relative_path))
466 |                     # Need to run this first due to foreign key constraints.
467 |                     events.append(Events.DeleteFileObjects(relative_path))
468 |                     continue
469 |                 except FileExceptions.AlreadyIndexed:
470 |                     if rebuilding_faiss_index:
471 |                         events.append(Events.ReloadFileEmbeddings(relative_path))
472 |                     continue
473 |                 except FileExceptions.Ignore:
474 |                     continue
475 |             elif isinstance(event, Events.ReloadFileEmbeddings):
476 |                 # Could do this in a single query at the end.
477 |                 path = event.path
478 |                 assert isinstance(path, Path)
479 |                 embedding_rows = db.execute(
480 |                     """
481 |                     select
482 |                         object_id,
483 |                         content_sha256,
484 |                         data
485 |                     from embedding
486 |                     where object_id in (
487 |                         select id from object
488 |                         where path = :path
489 |                     )
490 |                     """,
491 |                     {'path': str(path)}
492 |                 ).fetchall()
493 |                 embeddings = [
494 |                     Embedding(
495 |                         object_id=x['object_id'],
496 |                         data=deserialize_embedding_data(x['data']),
497 |                         content_hash=x['content_sha256']
498 |                     )
499 |                     for x in embedding_rows
500 |                 ]
501 |                 embeddings_to_index.extend(embeddings)
502 |             elif isinstance(event, Events.DeleteFile):
503 |                 relative_path = event.path
504 |                 assert isinstance(relative_path, Path)
505 |                 db.execute(
506 |                     """
507 |                     delete from file
508 |                     where path = :path
509 |                     """,
510 |                     {'path': str(relative_path)}
511 |                 )
512 |             elif isinstance(event, Events.DeleteFileObjects):
513 |                 relative_path = event.path
514 |                 id_tuples = db.execute(
515 |                     """
516 |                     delete from object
517 |                     where path = :path
518 |                     returning id;
519 |                     """,
520 |                     {'path': str(relative_path)}
521 |                 ).fetchall()
522 |                 deleted_ids = [x[0] for x in id_tuples]
523 |                 if deleted_ids:
524 |                     in_clause = ', '.join(['?'] * len(deleted_ids))
525 |                     db.execute(
526 |                         f"""
527 |                         delete from fts where rowid in ( {in_clause} );
528 |                         """,
529 |                         deleted_ids
530 |                     )
531 |                 # These are relatively expensive to compute, and accessible by their hash, so keep them around.
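                # They are also reused: ScheduleEmbeddingRequests looks embeddings up
                # by content_sha256, so identical content that reappears later skips
                # an OpenAI round-trip.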
532 | # db.execute( 533 | # f""" 534 | # delete from embedding 535 | # where object_id = ( {in_clause} ); 536 | # """, 537 | # deleted_ids 538 | # ) 539 | deletion_markers.extend(deleted_ids) 540 | elif isinstance(event, Events.IndexFile): 541 | relative_path, file_bytes, lines = event.path, event.content, event.lines 542 | assert isinstance(relative_path, Path) 543 | assert isinstance(file_bytes, bytes) 544 | 545 | objects = parse_objects(relative_path, file_bytes) 546 | objects_by_id: dict[int, Object] = {} 547 | for obj in objects: 548 | object_id, = db.execute( 549 | """ 550 | insert into object 551 | (path, name, language, context_before, context_after, kind, byte_range, coordinates) 552 | values 553 | (:path, :name, :language, :context_before, :context_after, :kind, :byte_range, :coordinates) 554 | returning id; 555 | """, 556 | { 557 | 'path': str(obj.path), 558 | 'name': obj.name, 559 | 'language': obj.language, 560 | 'context_before': json.dumps(obj.context_before), 561 | 'context_after': json.dumps(obj.context_after), 562 | 'kind': obj.kind, 563 | 'byte_range': json.dumps(obj.byte_range), 564 | 'coordinates': json.dumps(obj.coordinates) 565 | } 566 | 567 | ).fetchone() 568 | objects_by_id[object_id] = obj 569 | events.append(Events.IndexObjects(file_bytes, objects_by_id, lines)) 570 | elif isinstance(event, Events.IndexObjects): 571 | file_bytes = event.file_bytes 572 | in_lines = event.lines 573 | objects_by_id = event.objects 574 | # dict[int, Object] 575 | assert isinstance(objects_by_id, dict) 576 | requests_to_schedule = [] 577 | for obj_id, obj in objects_by_id.items(): 578 | rendered = render_object(obj, in_lines=in_lines, file=False) 579 | if not rendered: 580 | STATS.increment("codebased.embeddings.skipped.empty") 581 | continue 582 | request = EmbeddingRequest( 583 | object_id=obj_id, 584 | content=rendered, 585 | content_hash=hashlib.sha256(rendered.encode('utf-8')).hexdigest(), 586 | ) 587 | requests_to_schedule.append(request) 588 | events.append(Events.ScheduleEmbeddingRequests(requests=requests_to_schedule)) 589 | db.executemany( 590 | """ 591 | insert into fts 592 | (rowid, path, name, content) 593 | values 594 | (:object_id, :path, :name, :content); 595 | """, 596 | [ 597 | { 598 | 'object_id': obj_id, 599 | 'path': str(obj.path), 600 | 'name': obj.name, 601 | 'content': file_bytes[obj.byte_range[0]:obj.byte_range[1]] 602 | } 603 | for obj_id, obj in objects_by_id.items() 604 | ] 605 | ) 606 | elif isinstance(event, Events.ScheduleEmbeddingRequests): 607 | requests_to_schedule = event.requests 608 | embeddings_batch = [] 609 | for request in requests_to_schedule: 610 | existing_embedding = db.execute( 611 | """ 612 | select data from embedding 613 | where content_sha256 = :content_sha256; 614 | """, 615 | {'content_sha256': request.content_hash} 616 | ).fetchone() 617 | if existing_embedding is not None: 618 | embedding = Embedding( 619 | object_id=request.object_id, 620 | data=deserialize_embedding_data(existing_embedding['data']), 621 | content_hash=request.content_hash 622 | ) 623 | embeddings_batch.append(embedding) 624 | else: 625 | embeddings = dependencies.request_scheduler.schedule(request) 626 | embeddings_batch.extend(embeddings) 627 | events.append(Events.StoreEmbeddings(embeddings=embeddings_batch)) 628 | elif isinstance(event, Events.FlushEmbeddings): 629 | if 'request_scheduler' in dependencies.__dict__: 630 | results = dependencies.request_scheduler.flush() 631 | events.append(Events.StoreEmbeddings(embeddings=results)) 632 | elif isinstance(event, 
Events.StoreEmbeddings): 633 | embeddings_batch = event.embeddings 634 | if not embeddings_batch: 635 | continue 636 | db.executemany( 637 | """ 638 | insert into embedding 639 | (object_id, data, content_sha256) 640 | values 641 | (:object_id, :data, :content_sha256) 642 | on conflict (object_id) do update 643 | set data = :data, 644 | content_sha256 = :content_sha256; 645 | """, 646 | [ 647 | { 648 | 'object_id': e1.object_id, 649 | 'data': serialize_embedding_data(e1.data), 650 | 'content_sha256': e1.content_hash 651 | } 652 | for e1 in embeddings_batch 653 | ] 654 | ) 655 | embeddings_to_index.extend(embeddings_batch) 656 | elif isinstance(event, Events.FaissInserts): 657 | if embeddings_to_index: 658 | index.add_with_ids( 659 | np.array([e.data for e in event.embeddings]), 660 | [e.object_id for e in event.embeddings] 661 | ) 662 | event.embeddings.clear() 663 | elif isinstance(event, Events.FaissDeletes): 664 | delete_ids = event.ids 665 | if delete_ids: 666 | index.remove_ids(np.array(delete_ids)) 667 | delete_ids.clear() 668 | elif isinstance(event, Events.Commit): 669 | dependencies.search_cache.clear() 670 | db.execute("insert into fts(fts) values ('optimize');") 671 | db.commit() 672 | faiss.write_index(index, str(config.index_path)) 673 | elif isinstance(event, Events.DeleteNotVisited): 674 | inverse_paths = [str(path) for path in event.paths] 675 | in_clause = ', '.join(['?'] * len(inverse_paths)) 676 | id_tuples = dependencies.db.execute( 677 | f""" 678 | delete from object 679 | where path not in ({in_clause}) 680 | returning id; 681 | """, 682 | inverse_paths 683 | ).fetchall() 684 | dependencies.db.execute( 685 | f""" 686 | delete from file 687 | where path not in ( {in_clause} ); 688 | """, 689 | inverse_paths 690 | ) 691 | deleted_ids = [x[0] for x in id_tuples] 692 | in_clause = ', '.join(['?'] * len(deleted_ids)) 693 | dependencies.db.execute( 694 | f""" 695 | delete from fts where rowid in ( {in_clause} ); 696 | """, 697 | deleted_ids 698 | ) 699 | deletion_markers.extend(deleted_ids) 700 | else: 701 | raise NotImplementedError(event) 702 | else: 703 | pass 704 | except: 705 | db.rollback() 706 | raise 707 | 708 | 709 | @dataclasses.dataclass(frozen=True) 710 | class Flags: 711 | directory: Path 712 | background: bool 713 | # TODO: These conflict and suck. 
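    # Presumably rebuild_faiss_index vs. cached_only: the former forces the FAISS
    # index to be rebuilt from stored embeddings, the latter promises not to touch
    # the index at all.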
714 | rebuild_faiss_index: bool 715 | cached_only: bool 716 | stats: bool 717 | semantic: bool 718 | full_text_search: bool 719 | top_k: int 720 | query: str 721 | rerank: bool 722 | radius: float 723 | -------------------------------------------------------------------------------- /tests/test_main.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | import string 5 | 6 | import dataclasses 7 | import os 8 | import re 9 | import sqlite3 10 | import tempfile 11 | import typing as T 12 | import unittest 13 | from contextlib import contextmanager 14 | from pathlib import Path 15 | from typing import Union 16 | 17 | import faiss 18 | from rich.syntax import Syntax 19 | from textual.widgets import Input, ListView, Static 20 | 21 | from codebased.index import find_root_git_repository, Flags, Config, Dependencies, index_paths 22 | from codebased.main import VERSION 23 | from codebased.search import Query, find_highlights 24 | from codebased.settings import Settings 25 | from codebased.tui import Codebased, Id 26 | 27 | SUBMODULE_REPO_TREE = (Path('.'), ( 28 | (Path('README.md'), b'Hello, world!'), 29 | (Path('.git'), ()), 30 | (Path('submodule'), ( 31 | (Path('README.md'), b'Hello, world!'), 32 | ( 33 | Path('below-submodule'), 34 | ( 35 | ( 36 | ( 37 | Path('code.py'), 38 | b'print("Hello, world!")'), 39 | ) 40 | ) 41 | ), 42 | # Git submodules contain a .git **FILE**. 43 | # This points to a subdirectory of the parent .git/modules directory. 44 | (Path( 45 | '.git' 46 | ), b'') 47 | )) 48 | ) 49 | ) 50 | 51 | SIMPLE_NOT_REPO_TREE = ( 52 | Path('.'), 53 | ( 54 | (Path('README.md'), b'Hello, world!'), 55 | (Path('a-directory'), ( 56 | ((Path('code.py'), b'print("Hello, world!")'),) 57 | )), 58 | ) 59 | ) 60 | 61 | SIMPLE_REPO_TREE = ( 62 | Path( 63 | '.' 
64 | ), 65 | ( 66 | (Path( 67 | 'README.md' 68 | ), b'Hello, world!'), 69 | (Path( 70 | 'a-directory' 71 | ), ( 72 | ((Path( 73 | 'code.py' 74 | ), b'print("Hello, world!")'),) 75 | )), 76 | (Path( 77 | '.git' 78 | ), ()), 79 | ) 80 | ) 81 | 82 | # node_modules in .gitignore 83 | GITIGNORE_FOLDER_TREE = ( 84 | Path('.'), 85 | ( 86 | (Path('README.md'), b'Hello, world!'), 87 | (Path('.git'), ()), 88 | (Path('.gitignore'), b'node_modules/'), 89 | ( 90 | Path( 91 | 'node_modules' 92 | ), ( 93 | ( 94 | Path('slop'), ( 95 | (Path('slop.js'), b'console.log("Hello, world!");'), 96 | (Path('slop.d.ts'), b'declare function slop(): void;'), 97 | ) 98 | ), 99 | ) 100 | ), 101 | (Path('src'), ( 102 | (Path('index.js'), 103 | b'const express = require("express");\nconst app = express();\napp.get("/", (req, res) => {\n res.send("Hello, world!");\n});\n\napp.listen(3000, () => {\n console.log("Server started on port 3000");\n});\n'), 104 | )), 105 | ( 106 | Path('package.json'), 107 | b'{\n "name": "test",\n "version": "1.0.0",\n "description": "",\n "main": "index.js",\n "scripts": {\n "test": "echo "Error: no test specified" && exit 1"\n },\n "author": "",\n "license": "ISC",\n "dependencies": {\n "slop": "^1.0.0"\n }\n}\n' 108 | ) 109 | ) 110 | ) 111 | 112 | HIDDEN_FOLDER_TREE = ( 113 | ( 114 | Path('.'), 115 | ( 116 | (Path('.git'), ()), 117 | (Path('README.md'), b'Hello, world!'), 118 | ( 119 | Path('a-directory'), ( 120 | (Path('code.py'), b'print("Hello, world!")'), 121 | ) 122 | ), 123 | ( 124 | Path('.venv'), 125 | ( 126 | ( 127 | Path('bin'), 128 | ( 129 | (Path('activate'), b'this is a script of some sort'), 130 | ) 131 | ), 132 | ( 133 | Path('lib'), 134 | ( 135 | ( 136 | Path('python3.10'), 137 | ( 138 | ( 139 | Path('site-packages'), 140 | ( 141 | (Path('slop'), 142 | (( 143 | Path('site.py'), 144 | b'print("Hello, world!")' 145 | ),), 146 | 147 | ), 148 | 149 | ) 150 | ), 151 | ) 152 | ), 153 | ) 154 | ), 155 | ) 156 | ) 157 | ) 158 | ) 159 | ) 160 | 161 | NESTED_GITIGNORE_TREE = ( 162 | Path('.'), 163 | ( 164 | (Path('.gitignore'), b'*.txt'), 165 | # Should be ignored. 166 | (Path('trash.txt'), b'Hello, world!'), 167 | (Path('README.md'), b'Hello, world!'), 168 | (Path('.git'), ()), 169 | ( 170 | Path('app'), ( 171 | (Path('.gitignore'), b'node_modules/'), 172 | # Should be ignored. 173 | (Path('trash.txt'), b'Hello, world!'), 174 | (Path('src'), ( 175 | (Path('index.d.ts'), b'console.log("Hello, world!")'), 176 | (Path('index.js'), b'console.log("Hello, world!");'), 177 | )), 178 | (Path('package.json'), b'{"name": "slop"}'), 179 | # Should be ignored. 180 | (Path('node_modules'), ( 181 | (Path('slop'), ( 182 | (Path('slop.js'), b'console.log("Hello, world!");'), 183 | (Path('slop.d.ts'), b'declare function slop(): void;'), 184 | )), 185 | )), 186 | ) 187 | ), 188 | ( 189 | Path('server'), ( 190 | (Path('.gitignore'), b'venv/\n__pycache__/'), 191 | # Should be ignored. 192 | (Path('trash.txt'), b'Hello, world!'), 193 | (Path('src'), ( 194 | (Path('__pycache__'), ( 195 | (Path('main.cpython-311.pyc'), b''), 196 | (Path('__init__.cpython-311.pyc'), b''), 197 | )), 198 | (Path('main.py'), b'print("Hello, world!")'), 199 | (Path('__init__.py'), b'from .main import *'), 200 | )), 201 | (Path('setup.py'), b'{"name": "slop"}'), 202 | # Should be ignored. 
203 | (Path('venv'), ( 204 | (Path('slop'), ( 205 | (Path('slop.py'), b'slop = 1'), 206 | (Path('__init__.py'), b''), 207 | )), 208 | )), 209 | ) 210 | ), 211 | ) 212 | ) 213 | 214 | 215 | @dataclasses.dataclass 216 | class CliTestCase: 217 | tree: T.Any 218 | objects: int 219 | files: int 220 | 221 | def create(self, path: Path): 222 | create_tree(self.tree, path) 223 | 224 | 225 | IGNORE_FOLDER_TEST_CASE = CliTestCase(tree=GITIGNORE_FOLDER_TREE, objects=6, files=4) 226 | 227 | HIDDEN_FOLDER_TEST_CASE = CliTestCase(tree=HIDDEN_FOLDER_TREE, objects=2, files=2) 228 | 229 | SIMPLE_REPO_TEST_CASE = CliTestCase(tree=SIMPLE_REPO_TREE, objects=2, files=2) 230 | NESTED_GITIGNORE_TEST_CASE = CliTestCase(tree=NESTED_GITIGNORE_TREE, objects=10, files=10) 231 | 232 | 233 | # Algebraically: 234 | # File = tuple[Path, bytes] 235 | # Directory = tuple[Path, tuple[DirEntry]] 236 | # DirEntry = Union[File | Directory] 237 | 238 | def create_tree(dir_entry, relative_to: Path): 239 | path, contents = dir_entry 240 | absolute_path = relative_to / path 241 | if isinstance(contents, bytes): 242 | absolute_path.write_bytes(contents) 243 | else: 244 | absolute_path.mkdir(exist_ok=True) 245 | for entry in contents: 246 | create_tree(entry, absolute_path) 247 | 248 | 249 | class TestGitDetection(unittest.TestCase): 250 | def test_in_a_regular_git_repository(self): 251 | with tempfile.TemporaryDirectory() as tempdir: 252 | path = Path(tempdir).resolve() 253 | create_tree(SIMPLE_REPO_TREE, path) 254 | test_paths = [ 255 | path, 256 | path / 'a-directory', 257 | path / 'a-directory' / 'code.py', 258 | path / '.git', 259 | path / 'README.md' 260 | ] 261 | for test_path in test_paths: 262 | with self.subTest(test_path=test_path): 263 | result = find_root_git_repository(test_path) 264 | self.assertEqual(result, path) 265 | 266 | def test_in_a_git_repository_with_submodules(self): 267 | with tempfile.TemporaryDirectory() as tempdir: 268 | path = Path(tempdir).resolve() 269 | create_tree(SUBMODULE_REPO_TREE, path) 270 | test_paths = [ 271 | path, 272 | path / 'submodule', 273 | path / 'submodule' / 'below-submodule', 274 | path / 'submodule' / 'below-submodule' / 'code.py', 275 | path / 'README.md' 276 | ] 277 | for test_path in test_paths: 278 | with self.subTest(test_path=test_path): 279 | result = find_root_git_repository(test_path) 280 | self.assertEqual(result, path) 281 | 282 | def test_run_outside_a_git_repository(self): 283 | with tempfile.TemporaryDirectory() as tempdir: 284 | path = Path(tempdir).resolve() 285 | create_tree(SIMPLE_NOT_REPO_TREE, path) 286 | test_paths = [ 287 | path / 'a-directory', 288 | path / 'a-directory' / 'code.py', 289 | path / 'README.md', 290 | path 291 | ] 292 | for test_path in test_paths: 293 | with self.subTest(test_path=test_path): 294 | result = find_root_git_repository(test_path) 295 | self.assertIs(result, None) 296 | 297 | def test_run_at_root(self): 298 | # You never know what's going to happen on someone's laptop. 299 | # Someone's probably managing their entire filesystem using Git. 300 | # That's how you *win* this test, instead of merely passing it. 
301 |         root = Path('/')
302 |         if (root / '.git').is_dir():
303 |             assert find_root_git_repository(root) == root
304 |         else:
305 |             assert find_root_git_repository(root) is None
306 | 
307 | 
308 | StreamAssertion = Union[bytes, re.Pattern, None]
309 | 
310 | import subprocess
311 | 
312 | 
313 | 
314 | 
315 | 
316 | 
317 | 
318 | def check_codebased_cli(
319 |         *,
320 |         cwd: Path,
321 |         exit_code: int,
322 |         stderr: StreamAssertion,
323 |         stdout: StreamAssertion,
324 |         args: list[str],
325 |         ascii_only: bool = False
326 | ):
327 |     proc = subprocess.run(
328 |         ['python', '-m', 'codebased.main', *args],
329 |         cwd=cwd.resolve(),
330 |         stdout=subprocess.PIPE,
331 |         stderr=subprocess.PIPE,
332 |         stdin=subprocess.PIPE,
333 |         env=os.environ,
334 |     )
335 |     actual_stdout, actual_stderr = proc.stdout, proc.stderr
336 | 
337 |     if ascii_only:
338 |         # Keep only printable ASCII characters.
339 |         ascii_pattern = re.compile(b"[^" + re.escape(string.printable.encode()) + b"]")
340 |         actual_stdout = ascii_pattern.sub(b'', actual_stdout)
341 |         actual_stderr = ascii_pattern.sub(b'', actual_stderr)
342 | 
343 |     if proc.returncode != 0 and proc.returncode != exit_code:
344 |         print(f'stdout: {actual_stdout.decode("utf-8")}')
345 |         print(f'stderr: {actual_stderr.decode("utf-8")}')
346 | 
347 |     assert proc.returncode == exit_code, f'{proc.returncode} != {exit_code}, stdout: {actual_stdout}, stderr: {actual_stderr}'
348 | 
349 |     if isinstance(stdout, bytes):
350 |         assert actual_stdout == stdout, f'{actual_stdout} != {stdout}'
351 |     elif isinstance(stdout, re.Pattern):
352 |         assert stdout.search(actual_stdout), f'Pattern not found in stdout: {actual_stdout}'
353 | 
354 |     if isinstance(stderr, bytes):
355 |         assert actual_stderr == stderr, f'{actual_stderr} != {stderr}'
356 |     elif isinstance(stderr, re.Pattern):
357 |         assert stderr.search(actual_stderr), f'Pattern not found in stderr: {actual_stderr}'
358 | 
359 |     return proc
360 | 
361 | 
362 | def check_search_command(
363 |         *,
364 |         args: list[str],
365 |         root: Path,
366 |         cwd: Path,
367 |         exit_code: int,
368 |         stderr: StreamAssertion,
369 |         stdout: StreamAssertion,
370 |         expected_object_count: int | None = None,
371 |         expected_file_count: int | None = None
372 | ):
373 |     # Runs the CLI from the working directory (cwd), then checks
374 |     # the index artifacts under the root directory.
375 |     check_codebased_cli(
376 |         cwd=cwd,
377 |         exit_code=exit_code,
378 |         stderr=stderr,
379 |         stdout=stdout,
380 |         args=args
381 |     )
382 |     codebased_dir = root / '.codebased'
383 |     assert codebased_dir.exists()
384 |     check_db(
385 |         codebased_dir / 'codebased.db',
386 |         expected_object_count=expected_object_count,
387 |         expected_file_count=expected_file_count
388 |     )
389 |     check_faiss_index(
390 |         codebased_dir / 'index.faiss',
391 |         expected_object_count=expected_object_count
392 |     )
393 | 
394 | 
395 | def check_faiss_index(path: Path, *, expected_object_count: int | None):
396 |     assert path.exists()
397 |     if expected_object_count is not None:
398 |         faiss_index = faiss.read_index(str(path))
399 |         assert faiss_index.id_map.size() == expected_object_count
400 | 
401 | 
402 | def check_db(
403 |         db_path: Path,
404 |         *,
405 |         expected_object_count: int | None,
406 |         expected_file_count: int | None
407 | ):
408 |     assert db_path.exists()
409 |     with sqlite3.connect(db_path) as db:
410 |         if expected_object_count is not None:
411 |             cursor = db.execute('select count(*) from object')
412 |             actual_object_count = cursor.fetchone()[0]
413 |             assert actual_object_count == 
expected_object_count 414 | cursor = db.execute('select count(*) from fts') 415 | actual_fts_object_count = cursor.fetchone()[0] 416 | assert actual_fts_object_count == expected_object_count 417 | cursor = db.execute('select count(*) from embedding where object_id in (select id from object)') 418 | actual_embedding_count = cursor.fetchone()[0] 419 | assert actual_embedding_count == expected_object_count 420 | if expected_file_count is not None: 421 | cursor = db.execute('select count(*) from file') 422 | actual_file_count = cursor.fetchone()[0] 423 | assert actual_file_count == expected_file_count 424 | 425 | 426 | @contextmanager 427 | def check_file_did_not_change(path: Path): 428 | stat = path.stat() 429 | try: 430 | yield 431 | finally: 432 | # May be overly strict. 433 | assert path.stat() == stat 434 | 435 | 436 | class TestCli(unittest.TestCase): 437 | def test_debug(self): 438 | with tempfile.TemporaryDirectory() as tempdir: 439 | stdout = re.compile( 440 | b"Codebased: \d+\.\d+\.\d+.*Python: \d+\.\d+\.\d+.*SQLite: \d+\.\d+\.\d+.*FAISS: \d+\.\d+\.\d+.*OpenAI: \d+\.\d+\.\d+.*", 441 | re.ASCII | re.DOTALL 442 | ) 443 | stderr = b"" 444 | check_codebased_cli( 445 | args=["debug"], 446 | cwd=Path(tempdir).resolve(), 447 | exit_code=0, 448 | stderr=stderr, 449 | stdout=stdout, 450 | ) 451 | 452 | def test_run_outside_a_git_repository(self): 453 | with tempfile.TemporaryDirectory() as tempdir: 454 | path = Path(tempdir).resolve() 455 | create_tree(SIMPLE_NOT_REPO_TREE, path) 456 | exit_code = 1 457 | stdout = b"" 458 | stderr = b"Codebased must be run within a Git repository.\n" 459 | check_codebased_cli( 460 | cwd=path, 461 | exit_code=exit_code, 462 | stderr=stderr, 463 | stdout=stdout, 464 | args=["search", "Hello world"] 465 | ) 466 | 467 | def test_run_inside_a_git_repository(self): 468 | with tempfile.TemporaryDirectory() as tempdir: 469 | path = Path(tempdir).resolve() 470 | # Simple repo has two files. 
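            # (README.md and a-directory/code.py; .git is skipped, so files == objects == 2.)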
471 |             SIMPLE_REPO_TEST_CASE.create(path)
472 |             exit_code = 0
473 |             stdout = re.compile(b"Found Git repository " + str(path).encode("utf-8") + b"\n", re.ASCII)
474 |             stderr = re.compile(b".*Indexing " + path.name.encode("utf-8") + b".*", re.ASCII | re.DOTALL)
475 |             check_search_command(
476 |                 args=["search", "Hello world"],
477 |                 root=path,
478 |                 cwd=path,
479 |                 exit_code=exit_code,
480 |                 stderr=stderr,
481 |                 stdout=stdout,
482 |                 expected_file_count=SIMPLE_REPO_TEST_CASE.files,
483 |                 expected_object_count=SIMPLE_REPO_TEST_CASE.objects
484 |             )
485 |             check_search_command(
486 |                 args=["search", "Hello world"],
487 |                 root=path,
488 |                 cwd=path / "a-directory",
489 |                 exit_code=exit_code,
490 |                 stderr=stderr,
491 |                 stdout=stdout,
492 |                 expected_file_count=SIMPLE_REPO_TEST_CASE.files,
493 |                 expected_object_count=SIMPLE_REPO_TEST_CASE.objects
494 |             )
495 | 
496 |     def test_delete_files_between_runs(self):
497 |         with tempfile.TemporaryDirectory() as tempdir:
498 |             path = Path(tempdir).resolve()
499 |             create_tree(SIMPLE_REPO_TREE, path)
500 |             exit_code = 0
501 |             stdout = re.compile(b"Found Git repository " + str(path).encode("utf-8") + b"\n", re.ASCII)
502 |             stderr = re.compile(b".*Indexing " + path.name.encode("utf-8") + b".*", re.ASCII | re.DOTALL)
503 |             search_args = ["search", "Hello world"]
504 |             check_search_command(
505 |                 args=search_args,
506 |                 root=path,
507 |                 cwd=path,
508 |                 exit_code=exit_code,
509 |                 stdout=stdout,
510 |                 stderr=stderr,
511 |                 expected_file_count=SIMPLE_REPO_TEST_CASE.files,
512 |                 expected_object_count=SIMPLE_REPO_TEST_CASE.objects
513 |             )
514 |             code_dot_py_path = path / "a-directory" / "code.py"
515 |             os.remove(code_dot_py_path)
516 |             check_search_command(
517 |                 args=search_args,
518 |                 root=path,
519 |                 cwd=path,
520 |                 exit_code=exit_code,
521 |                 stdout=stdout,
522 |                 stderr=stderr,
523 |                 expected_file_count=SIMPLE_REPO_TEST_CASE.files - 1,
524 |                 expected_object_count=SIMPLE_REPO_TEST_CASE.objects - 1
525 |             )
526 | 
527 |     def test_version(self):
528 |         with tempfile.TemporaryDirectory() as tempdir:
529 |             path = Path(tempdir)
530 |             exit_code = 0
531 |             stdout = f"Codebased {VERSION}\n".encode("utf-8")
532 |             stderr = b""
533 |             check_codebased_cli(
534 |                 cwd=path,
535 |                 exit_code=exit_code,
536 |                 stderr=stderr,
537 |                 stdout=stdout,
538 |                 args=["--version"]
539 |             )
540 | 
541 |     @pytest.mark.xfail
542 |     def test_help(self):
543 |         with tempfile.TemporaryDirectory() as tempdir:
544 |             path = Path(tempdir)
545 |             exit_code = 0
546 |             stderr = b""
547 |             check_codebased_cli(
548 |                 cwd=path,
549 |                 exit_code=exit_code,
550 |                 stderr=stderr,
551 |                 stdout=re.compile(
552 |                     rb".*COMMAND.*--version.*-V.*--help.*Commands.*search.*",
553 |                     re.DOTALL | re.ASCII
554 |                 ),
555 |                 args=["--help"],
556 |                 ascii_only=True
557 |             )
558 |             check_codebased_cli(
559 |                 cwd=path,
560 |                 exit_code=exit_code,
561 |                 stderr=stderr,
562 |                 stdout=re.compile(rb".*search.*QUERY.*", re.DOTALL | re.ASCII),
563 |                 args=["search", "--help"],
564 |                 ascii_only=True
565 |             )
566 |             # Note: We're not checking the exact help output as it might change and be system-dependent
567 | 
568 |     def test_directory_argument(self):
569 |         with tempfile.TemporaryDirectory() as tempdir:
570 |             path = Path(tempdir).resolve()
571 |             create_tree(SIMPLE_REPO_TREE, path)
572 |             exit_code = 0
573 |             stdout = re.compile(
574 |                 b"Found Git repository " + str(path).encode("utf-8") + b".*",
575 |                 re.ASCII | re.DOTALL
576 |             )
577 |             stderr = re.compile(b".*Indexing " + path.name.encode("utf-8") + b".*", re.ASCII | re.DOTALL)
578 | 
579 |             # Test with -d argument
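            # The CLI is invoked from an unrelated working directory and pointed at the
            # repository via the -d flag.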
580 |             workdir = Path.cwd()
581 |             assert workdir != path
582 |             check_search_command(
583 |                 cwd=workdir,
584 |                 root=path,
585 |                 exit_code=exit_code,
586 |                 stderr=stderr,
587 |                 stdout=stdout,
588 |                 args=["search", "Hello world", "-d", str(path)],
589 |                 expected_file_count=SIMPLE_REPO_TEST_CASE.files,
590 |                 expected_object_count=SIMPLE_REPO_TEST_CASE.objects
591 |             )
592 | 
593 |             # Test with --directory argument
594 |             check_search_command(
595 |                 root=path,
596 |                 cwd=workdir,
597 |                 exit_code=exit_code,
598 |                 stderr=stderr,
599 |                 stdout=stdout,
600 |                 args=["search", "Hello world", "--directory", str(path)],
601 |                 expected_file_count=SIMPLE_REPO_TEST_CASE.files,
602 |                 expected_object_count=SIMPLE_REPO_TEST_CASE.objects
603 |             )
604 | 
605 |     def test_with_gitignore(self):
606 |         with tempfile.TemporaryDirectory() as tempdir:
607 |             path = Path(tempdir).resolve()
608 |             create_tree(SIMPLE_REPO_TREE, path)
609 |             gitignore_path = path / ".gitignore"
610 |             gitignore_path.write_text("*.py\n")
611 |             exit_code = 0
612 |             stdout = re.compile(b"Found Git repository " + str(path).encode("utf-8") + b".*")
613 |             stderr = re.compile(b".*Indexing " + path.name.encode("utf-8") + b".*", re.ASCII | re.DOTALL)
614 |             check_search_command(
615 |                 args=["search", "Hello world"],
616 |                 root=path,
617 |                 cwd=path,
618 |                 exit_code=exit_code,
619 |                 stderr=stderr,
620 |                 stdout=stdout,
621 |                 # +1 for the gitignore file
622 |                 # -1 for the .py file because it's ignored
623 |                 expected_file_count=SIMPLE_REPO_TEST_CASE.files - 1 + 1,
624 |                 expected_object_count=SIMPLE_REPO_TEST_CASE.objects - 1 + 1
625 |             )
626 | 
627 |     def test_with_nested_gitignore(self):
628 |         with tempfile.TemporaryDirectory() as tempdir:
629 |             path = Path(tempdir).resolve()
630 |             create_tree(NESTED_GITIGNORE_TREE, path)
631 |             exit_code = 0
632 |             stdout = re.compile(b"Found Git repository " + str(path).encode("utf-8") + b".*")
633 |             stderr = re.compile(b".*Indexing " + path.name.encode("utf-8") + b".*", re.ASCII | re.DOTALL)
634 |             check_search_command(
635 |                 args=["search", "Hello world"],
636 |                 root=path,
637 |                 cwd=path,
638 |                 exit_code=exit_code,
639 |                 stderr=stderr,
640 |                 stdout=stdout,
641 |                 # Expected counts come from NESTED_GITIGNORE_TEST_CASE,
642 |                 # which already accounts for the nested ignore rules.
643 |                 expected_file_count=NESTED_GITIGNORE_TEST_CASE.files,
644 |                 expected_object_count=NESTED_GITIGNORE_TEST_CASE.objects
645 |             )
646 | 
647 |     def test_rebuild_faiss_index(self):
648 |         with tempfile.TemporaryDirectory() as tempdir:
649 |             path = Path(tempdir).resolve()
650 |             create_tree(SIMPLE_REPO_TREE, path)
651 |             exit_code = 0
652 |             stdout = re.compile(b"Found Git repository " + str(path).encode("utf-8") + b".*")
653 |             stderr = re.compile(b".*Indexing " + path.name.encode("utf-8") + b".*", re.ASCII | re.DOTALL)
654 |             search_args = ["search", "Hello world"]
655 |             check_search_command(
656 |                 args=search_args,
657 |                 root=path,
658 |                 cwd=path,
659 |                 exit_code=exit_code,
660 |                 stderr=stderr,
661 |                 stdout=stdout,
662 |                 expected_file_count=SIMPLE_REPO_TEST_CASE.files,
663 |                 expected_object_count=SIMPLE_REPO_TEST_CASE.objects
664 |             )
665 |             check_search_command(
666 |                 args=search_args + ["--rebuild-faiss-index"],
667 |                 root=path,
668 |                 cwd=path,
669 |                 exit_code=exit_code,
670 |                 stderr=stderr,
671 |                 stdout=stdout,
672 |                 expected_file_count=SIMPLE_REPO_TEST_CASE.files,
673 |                 expected_object_count=SIMPLE_REPO_TEST_CASE.objects
674 |             )
675 | 
676 |     def test_cached_only(self):
677 |         with tempfile.TemporaryDirectory() as tempdir:
678 |             path = Path(tempdir).resolve()
679 |             create_tree(SIMPLE_REPO_TREE, path)
680 |             exit_code = 0
681 |             stdout = re.compile(b"Found Git repository " + str(path).encode("utf-8") + b".*")
682 |             stderr = re.compile(b".*Indexing " + path.name.encode("utf-8") + b".*", re.ASCII | re.DOTALL)
683 |             search_args = ["search", "Hello world"]
684 |             check_search_command(
685 |                 args=search_args,
686 |                 root=path,
687 |                 cwd=path,
688 |                 exit_code=exit_code,
689 |                 stderr=stderr,
690 |                 stdout=stdout,
691 |                 expected_file_count=SIMPLE_REPO_TEST_CASE.files,
692 |                 expected_object_count=SIMPLE_REPO_TEST_CASE.objects
693 |             )
694 |             with check_file_did_not_change(path / ".codebased" / "codebased.db"), \
695 |                     check_file_did_not_change(path / ".codebased" / "index.faiss"):
696 |                 check_search_command(
697 |                     args=search_args + ["--cached-only"],
698 |                     root=path,
699 |                     cwd=path,
700 |                     exit_code=exit_code,
701 |                     stderr=b"",
702 |                     stdout=stdout,
703 |                     expected_file_count=SIMPLE_REPO_TEST_CASE.files,
704 |                     expected_object_count=SIMPLE_REPO_TEST_CASE.objects
705 |                 )
706 | 
707 |     def test_cache_only_without_warm_cache(self):
708 |         with tempfile.TemporaryDirectory() as tempdir:
709 |             path = Path(tempdir).resolve()
710 |             create_tree(SIMPLE_REPO_TREE, path)
711 |             exit_code = 0
712 |             stdout = re.compile(b"Found Git repository " + str(path).encode("utf-8") + b".*")
713 |             stderr = b""
714 |             search_args = ["search", "--cached-only", "Hello world"]
715 |             check_search_command(
716 |                 args=search_args,
717 |                 root=path,
718 |                 cwd=path,
719 |                 exit_code=exit_code,
720 |                 stderr=stderr,
721 |                 stdout=stdout,
722 |                 expected_file_count=0,
723 |                 expected_object_count=0
724 |             )
725 |             # Check that we only touched the index the first time because it didn't exist.
726 |             with check_file_did_not_change(path / ".codebased" / "codebased.db"), \
727 |                     check_file_did_not_change(path / ".codebased" / "index.faiss"):
728 |                 check_search_command(
729 |                     args=search_args,
730 |                     root=path,
731 |                     cwd=path,
732 |                     exit_code=exit_code,
733 |                     stderr=stderr,
734 |                     stdout=stdout,
735 |                     expected_file_count=0,
736 |                     expected_object_count=0
737 |                 )
738 | 
739 |     def test_semantic_search(self):
740 |         with tempfile.TemporaryDirectory() as tempdir:
741 |             path = Path(tempdir).resolve()
742 |             create_tree(SIMPLE_REPO_TREE, path)
743 |             exit_code = 0
744 |             stdout = None
745 |             stderr = re.compile(b".*Indexing " + path.name.encode("utf-8") + b".*", re.ASCII | re.DOTALL)
746 |             search_args = ["search", "--semantic-search", "Hello world"]
747 |             check_search_command(
748 |                 args=search_args,
749 |                 root=path,
750 |                 cwd=path,
751 |                 exit_code=exit_code,
752 |                 stderr=stderr,
753 |                 stdout=stdout,
754 |                 expected_file_count=SIMPLE_REPO_TEST_CASE.files,
755 |                 expected_object_count=SIMPLE_REPO_TEST_CASE.objects
756 |             )
757 | 
758 |     def test_full_text_search(self):
759 |         with tempfile.TemporaryDirectory() as tempdir:
760 |             path = Path(tempdir).resolve()
761 |             create_tree(SIMPLE_REPO_TREE, path)
762 |             exit_code = 0
763 |             stdout = None
764 |             stderr = re.compile(b".*Indexing " + path.name.encode("utf-8") + b".*", re.ASCII | re.DOTALL)
765 |             search_args = ["search", "Hello world", "--full-text-search"]
766 |             check_search_command(
767 |                 args=search_args,
768 |                 root=path,
769 |                 cwd=path,
770 |                 exit_code=exit_code,
771 |                 stderr=stderr,
772 |                 stdout=stdout,
773 |                 expected_file_count=SIMPLE_REPO_TEST_CASE.files,
774 |                 expected_object_count=SIMPLE_REPO_TEST_CASE.objects
775 |             )
776 | 
777 |     def test_full_text_search_bad_characters(self):
778 |         with tempfile.TemporaryDirectory() as tempdir:
779 |             path = Path(tempdir).resolve()
780 |             create_tree(SIMPLE_REPO_TREE, path)
781 |             exit_code = 0
782 |             stdout = 
None 783 | stderr = re.compile(b".*Indexing " + path.name.encode("utf-8") + b".*", re.ASCII | re.DOTALL) 784 | search_args = ["search", """print('print("Hello world");');""", "--full-text-search"] 785 | check_search_command( 786 | args=search_args, 787 | root=path, 788 | cwd=path, 789 | exit_code=exit_code, 790 | stderr=stderr, 791 | stdout=stdout, 792 | expected_file_count=SIMPLE_REPO_TEST_CASE.files, 793 | expected_object_count=SIMPLE_REPO_TEST_CASE.objects 794 | ) 795 | 796 | def test_hybrid_search(self): 797 | with tempfile.TemporaryDirectory() as tempdir: 798 | path = Path(tempdir).resolve() 799 | create_tree(SIMPLE_REPO_TREE, path) 800 | exit_code = 0 801 | stdout = None 802 | stderr = re.compile(b".*Indexing " + path.name.encode("utf-8") + b".*", re.ASCII | re.DOTALL) 803 | search_args = ["search", "Hello world"] 804 | check_search_command( 805 | args=search_args, 806 | root=path, 807 | cwd=path, 808 | exit_code=exit_code, 809 | stderr=stderr, 810 | stdout=stdout, 811 | expected_file_count=SIMPLE_REPO_TEST_CASE.files, 812 | expected_object_count=SIMPLE_REPO_TEST_CASE.objects 813 | ) 814 | 815 | def test_ignore_folder(self): 816 | with tempfile.TemporaryDirectory() as tempdir: 817 | path = Path(tempdir).resolve() 818 | create_tree(GITIGNORE_FOLDER_TREE, path) 819 | exit_code = 0 820 | stdout = re.compile(b"Found Git repository " + str(path).encode("utf-8") + b"\n", re.ASCII) 821 | stderr = re.compile(b".*Indexing " + path.name.encode("utf-8") + b".*", re.ASCII | re.DOTALL) 822 | search_args = ["search", "Server started"] 823 | check_search_command( 824 | args=search_args, 825 | root=path, 826 | cwd=path, 827 | exit_code=exit_code, 828 | stderr=stderr, 829 | stdout=stdout, 830 | expected_file_count=IGNORE_FOLDER_TEST_CASE.files, 831 | expected_object_count=IGNORE_FOLDER_TEST_CASE.objects 832 | ) 833 | 834 | def test_ignore_hidden_folder(self): 835 | with tempfile.TemporaryDirectory() as tempdir: 836 | path = Path(tempdir).resolve() 837 | create_tree(HIDDEN_FOLDER_TREE, path) 838 | exit_code = 0 839 | stdout = re.compile(b"Found Git repository " + str(path).encode("utf-8") + b"\n", re.ASCII) 840 | stderr = re.compile(b".*Indexing " + path.name.encode("utf-8") + b".*", re.ASCII | re.DOTALL) 841 | search_args = ["search", "Hello world"] 842 | check_search_command( 843 | args=search_args, 844 | root=path, 845 | cwd=path, 846 | exit_code=exit_code, 847 | stderr=stderr, 848 | stdout=stdout, 849 | expected_file_count=HIDDEN_FOLDER_TEST_CASE.files, 850 | expected_object_count=HIDDEN_FOLDER_TEST_CASE.objects 851 | ) 852 | 853 | 854 | class TestQueryParsing(unittest.TestCase): 855 | def test_empty_query(self): 856 | query = Query.parse('') 857 | self.assertEqual(query.phrases, []) 858 | self.assertEqual(query.keywords, []) 859 | self.assertEqual(query.original, '') 860 | query = Query.parse('""') 861 | self.assertEqual(query.phrases, []) 862 | self.assertEqual(query.keywords, []) 863 | self.assertEqual(query.original, '""') 864 | 865 | def test_escape_double_quotes(self): 866 | query = Query.parse('"print(\\\"hello world\\\")"') 867 | self.assertEqual(query.phrases, ['print("hello world")']) 868 | 869 | def test_parse_basic(self): 870 | query = Query.parse('hello "world" how are you') 871 | self.assertEqual(query.phrases, ['world']) 872 | self.assertEqual(query.keywords, ['hello', 'how', 'are', 'you']) 873 | self.assertEqual(query.original, 'hello "world" how are you') 874 | 875 | def test_parse_multiple_exact_phrases(self): 876 | query = Query.parse('"hello world" test "foo bar" baz') 877 | 
self.assertEqual(query.phrases, ['hello world', 'foo bar']) 878 | self.assertEqual(query.keywords, ['test', 'baz']) 879 | self.assertEqual(query.original, '"hello world" test "foo bar" baz') 880 | 881 | def test_parse_empty_query(self): 882 | query = Query.parse('') 883 | self.assertEqual(query.phrases, []) 884 | self.assertEqual(query.keywords, []) 885 | self.assertEqual(query.original, '') 886 | 887 | def test_parse_only_exact_phrase(self): 888 | query = Query.parse('"this is a test"') 889 | self.assertEqual(query.phrases, ['this is a test']) 890 | self.assertEqual(query.keywords, []) 891 | self.assertEqual(query.original, '"this is a test"') 892 | 893 | def test_parse_with_special_characters(self): 894 | query = Query.parse('hello! "world?" how_are_you') 895 | self.assertEqual(query.phrases, ['world?']) 896 | self.assertEqual(query.keywords, ['hello!', 'how_are_you']) 897 | self.assertEqual(query.original, 'hello! "world?" how_are_you') 898 | 899 | def test_parse_pathological_input(self): 900 | # This test case creates a pathological input that could cause exponential backtracking 901 | pathological_input = '"' + 'a' * 100 + '" ' + 'b' * 100 902 | import time 903 | start_time = time.time() 904 | query = Query.parse(pathological_input) 905 | end_time = time.time() 906 | parsing_time = end_time - start_time 907 | 908 | self.assertEqual(query.phrases, ['a' * 100]) 909 | self.assertEqual(query.keywords, ['b' * 100]) 910 | self.assertEqual(query.original, pathological_input) 911 | 912 | # Assert that parsing time is reasonable (e.g., less than 1 second) 913 | self.assertLess(parsing_time, 1.0, "Parsing took too long, possible exponential backtracking") 914 | 915 | 916 | class TestHighlighting(unittest.TestCase): 917 | def test_empty_query(self): 918 | query = Query.parse('') 919 | self.assertEqual(find_highlights(query, ''), ([], [])) 920 | self.assertEqual(find_highlights(query, '""'), ([], [])) 921 | query = Query.parse('""') 922 | self.assertEqual(find_highlights(query, ''), ([], [])) 923 | self.assertEqual(find_highlights(query, '""'), ([], [])) 924 | 925 | def test_highlights(self): 926 | query = Query.parse('hello "world" how are you') 927 | highlights, lines = find_highlights(query, 'hello "world" how are you') 928 | self.assertEqual( 929 | highlights, 930 | [(0, 5), (7, 12), (14, 17), (18, 21), (22, 25)] 931 | ) 932 | self.assertEqual(lines, [(0, 0)] * len(highlights)) 933 | highlights, lines = find_highlights(query, "hello world how are you") 934 | self.assertEqual( 935 | highlights, 936 | [(0, 5), (6, 11), (12, 15), (16, 19), (20, 23)] 937 | ) 938 | self.assertEqual(lines, [(0, 0)] * len(highlights)) 939 | 940 | def test_out_of_order_highlights(self): 941 | query = Query.parse('hello "world" how are you') 942 | text = 'you are how hello world' 943 | highlights, lines = find_highlights(query, text) 944 | self.assertEqual( 945 | highlights, 946 | [(0, 3), (4, 7), (8, 11), (12, 17), (18, 23)] 947 | ) 948 | self.assertEqual(lines, [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]) 949 | query = Query.parse('"sea world"') 950 | text = "have you been to sea world?" 
951 | highlights, lines = find_highlights(query, text) 952 | self.assertEqual( 953 | highlights, 954 | [(17, 26)] 955 | ) 956 | self.assertEqual(lines, [(0, 0)]) 957 | text = "world seap" 958 | highlights, lines = find_highlights(query, text) 959 | self.assertEqual( 960 | highlights, 961 | [] 962 | ) 963 | self.assertEqual(lines, []) 964 | 965 | def test_multiline_highlights(self): 966 | query = Query.parse('hello "world" how are you') 967 | text = 'hello\nworld\nhow\nare\nyou' 968 | highlights, lines = find_highlights(query, text) 969 | self.assertEqual( 970 | highlights, 971 | [(0, 5), (6, 11), (12, 15), (16, 19), (20, 23)] 972 | ) 973 | self.assertEqual(lines, [(i, i) for i in range(5)]) 974 | text = '\nhello\nworld\n' 975 | highlights, lines = find_highlights(query, text) 976 | self.assertEqual( 977 | highlights, 978 | [(1, 6), (7, 12)] 979 | ) 980 | self.assertEqual( 981 | lines, 982 | [(1, 1), (2, 2)] 983 | ) 984 | query = Query.parse('"hello world"') 985 | highlights, lines = find_highlights(query, text) 986 | self.assertEqual( 987 | highlights, 988 | [] 989 | ) 990 | 991 | def test_case_insensitive_highlights(self): 992 | query = Query.parse('HELLO "WoRlD" how ARE you') 993 | 994 | text = 'hello world HOW are YOU' 995 | actual_highlights, lines = find_highlights(query, text) 996 | self.assertEqual( 997 | actual_highlights, 998 | [(0, 5), (6, 11), (12, 15), (16, 19), (20, 23)] 999 | ) 1000 | self.assertEqual(lines, [(0, 0)] * 5) 1001 | 1002 | def test_partial_phrase_match(self): 1003 | query = Query.parse('"hello world" python') 1004 | text = 'hello worlds of python' 1005 | actual_highlights, lines = find_highlights(query, text) 1006 | self.assertEqual( 1007 | actual_highlights, 1008 | [(0, 11), (16, 22)] 1009 | ) 1010 | 1011 | def test_overlapping_highlights(self): 1012 | query = Query.parse('overlapping overlap lap') 1013 | text = 'this is an overlapping text' 1014 | actual_highlights, lines = find_highlights(query, text) 1015 | left = text.index('overlapping') 1016 | self.assertEqual( 1017 | actual_highlights, 1018 | [(left, left + len('overlapping'))] 1019 | ) 1020 | query = Query.parse('overlapping overlap lap over ping') 1021 | text = 'this is an overlapping text' 1022 | actual_highlights, lines = find_highlights(query, text) 1023 | left = text.index('overlapping') 1024 | self.assertEqual( 1025 | actual_highlights, 1026 | [(left, left + len('overlapping'))] 1027 | ) 1028 | query = Query.parse('overlapping "an over"') 1029 | text = 'this is an overlapping text' 1030 | actual_highlights, lines = find_highlights(query, text) 1031 | left = text.index('an') 1032 | self.assertEqual( 1033 | actual_highlights, 1034 | [(left, left + len('an overlapping'))] 1035 | ) 1036 | 1037 | 1038 | class AppTestBase(unittest.IsolatedAsyncioTestCase): 1039 | def setUp(self): 1040 | super().setUp() 1041 | self.tempdir = tempfile.TemporaryDirectory() 1042 | path = Path(self.tempdir.name).resolve() 1043 | SIMPLE_REPO_TEST_CASE.create(path) 1044 | self.flags = Flags( 1045 | rerank=False, 1046 | directory=path, 1047 | rebuild_faiss_index=False, 1048 | cached_only=False, 1049 | query="Hello world", 1050 | background=False, 1051 | stats=False, 1052 | semantic=True, 1053 | full_text_search=True, 1054 | top_k=10, 1055 | radius=1.0 1056 | ) 1057 | self.settings = Settings() 1058 | self.config = Config(flags=self.flags) 1059 | self.dependencies = Dependencies( 1060 | config=self.config, 1061 | settings=self.settings 1062 | ) 1063 | self.setUpIndex() 1064 | self.app = Codebased(flags=self.flags, 
config=self.config, dependencies=self.dependencies)
1065 | 
1066 |     def setUpIndex(self):
1067 |         index_paths(
1068 |             self.dependencies,
1069 |             self.config,
1070 |             [self.config.root],
1071 |             total=True
1072 |         )
1073 | 
1074 |     @pytest.mark.xfail
1075 |     async def test_search(self):
1076 |         async with self.app.run_test() as pilot:
1077 |             query = "Hello world"
1078 |             for i in range(11):
1079 |                 this = query[i]
1080 |                 await pilot.press(this)
1081 |                 so_far = query[:i + 1]
1082 |                 search_bar = self.app.query_one(f"#{Id.SEARCH_INPUT}", Input)
1083 |                 self.assertEqual(search_bar.value, so_far)
1084 |             result_list = self.app.query_one(f"#{Id.RESULTS_LIST}", ListView)
1085 |             # There should be 2 items.
1086 |             self.assertEqual(len(result_list.children), 2)
1087 |             preview = self.app.query_one(f"#{Id.PREVIEW}", Static)
1088 |             preview_text = preview.renderable
1089 |             self.assertIsInstance(preview_text, Syntax)
1090 |             code = preview_text.code
1091 |             self.assertEqual(code, "Hello, world!")
1092 |             self.assertEqual(preview_text.lexer.name, "Markdown")
1093 |             focused = self.app.focused
1094 |             self.assertEqual(focused.id, Id.SEARCH_INPUT.value)
1095 |             await pilot.press("enter")
1096 |             focused = self.app.focused
1097 |             self.assertEqual(focused.id, Id.RESULTS_LIST.value)
1098 |             await pilot.press("tab")
1099 |             focused = self.app.focused
1100 |             self.assertEqual(focused.id, Id.PREVIEW_CONTAINER.value)
1101 |             await pilot.press("escape")
1102 |             focused = self.app.focused
1103 |             self.assertEqual(focused.id, Id.SEARCH_INPUT.value)
1104 |             await pilot.press("r")
1105 |             await pilot.press("f")
1106 |             # await pilot.press("f")
1107 |             # await pilot.press("d")
1108 | 
1109 |     def tearDown(self):
1110 |         super().tearDown()
1111 |         self.dependencies.db.close()  # Close the DB before removing the tempdir that holds it.
1112 |         self.tempdir.cleanup()
1113 | 
--------------------------------------------------------------------------------