├── .env.example ├── clients ├── __init__.py └── jina │ ├── __init__.py │ ├── client.py │ └── embeddings.py ├── config ├── __init__.py └── loguru.py ├── constants ├── __init__.py └── constants.py ├── domain ├── __init__.py ├── entity │ ├── __init__.py │ ├── code_segment.py │ ├── task_context.py │ ├── document.py │ └── blob.py ├── enums │ ├── __init__.py │ ├── model_provider.py │ ├── chroma_mode.py │ ├── qdrant_mode.py │ └── vector_type.py ├── model │ ├── __init__.py │ ├── base.py │ └── repository.py ├── request │ ├── __init__.py │ ├── all_index_request.py │ ├── drop_index_request.py │ ├── git_index_request.py │ ├── retrieval_request.py │ ├── delete_index_request.py │ └── add_index_request.py ├── result │ ├── __init__.py │ └── result.py └── response │ ├── __init__.py │ ├── add_index_by_file_response.py │ ├── add_index_response.py │ └── retrieval_segment_response.py ├── exception ├── __init__.py └── exception.py ├── launch ├── __init__.py ├── we0_index_mcp.py └── launch.py ├── loader ├── __init__.py ├── segmenter │ ├── language │ │ ├── __init__.py │ │ ├── python.py │ │ ├── css.py │ │ ├── go.py │ │ ├── java.py │ │ ├── typescriptxml.py │ │ ├── typescript.py │ │ └── javascript.py │ ├── __init__.py │ ├── tree_sitter_factory.py │ ├── base_segmenter.py │ ├── base_line_segmenter.py │ └── tree_sitter_segmenter.py └── repo_loader.py ├── models ├── __init__.py └── model_factory.py ├── prompt ├── __init__.py └── prompt.py ├── router ├── __init__.py ├── git_router.py └── vector_router.py ├── setting ├── __init__.py └── setting.py ├── utils ├── __init__.py ├── path_util.py ├── mimetype_util.py ├── git_parse.py ├── helper.py └── vector_helper.py ├── extensions ├── __init__.py ├── vector │ ├── __init__.py │ ├── base_vector.py │ ├── ext_vector.py │ ├── chroma.py │ ├── qdrant.py │ └── pgvector.py └── ext_manager.py ├── resource └── dev.yaml ├── Dockerfile ├── pyproject.toml ├── main.py ├── README-zh.md ├── .gitignore └── README.md /.env.example: -------------------------------------------------------------------------------- 1 | WE0_INDEX_ENV=dev 2 | TZ=Asia/Shanghai 3 | OPENAI_BASE_URL=https://api.openai.com/v1 4 | OPENAI_API_KEY= 5 | JINA_BASE_URL=https://api.jina.ai/v1 6 | JINA_API_KEY= -------------------------------------------------------------------------------- /clients/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/19 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/7/17 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /constants/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/7/17 4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /domain/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/10/10 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /exception/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/8/21 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /launch/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/8/21 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /loader/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/10/10 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/15 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /prompt/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/20 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /router/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/10/11 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /setting/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/06/19 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/11 4 | # @Author : .*? 
5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /domain/entity/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/10/11 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /domain/enums/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/10/11 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /domain/model/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/20 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /domain/request/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/22 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /domain/result/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/7/17 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /extensions/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/10/10 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /domain/response/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/22 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /extensions/vector/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/14 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /loader/segmenter/language/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/7 4 | # @Author : .*? 
5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | -------------------------------------------------------------------------------- /domain/model/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/20 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : base 7 | # @Software: PyCharm 8 | from sqlalchemy.orm import declarative_base 9 | 10 | Base = declarative_base() 11 | -------------------------------------------------------------------------------- /domain/enums/model_provider.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/15 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : model_provider 7 | # @Software: PyCharm 8 | from enum import StrEnum 9 | 10 | 11 | class ModelType(StrEnum): 12 | OPENAI = "openai" 13 | JINA = "jina" 14 | -------------------------------------------------------------------------------- /domain/enums/chroma_mode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/22 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : chroma_mode 7 | # @Software: PyCharm 8 | from enum import StrEnum 9 | 10 | 11 | class ChromaMode(StrEnum): 12 | DISK = "disk" 13 | REMOTE = "remote" 14 | MEMORY = "memory" 15 | -------------------------------------------------------------------------------- /domain/enums/qdrant_mode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/22 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : qdrant_mode 7 | # @Software: PyCharm 8 | from enum import StrEnum 9 | 10 | 11 | class QdrantMode(StrEnum): 12 | DISK = "disk" 13 | REMOTE = "remote" 14 | MEMORY = "memory" 15 | -------------------------------------------------------------------------------- /domain/request/all_index_request.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/23 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : all_index_request 7 | # @Software: PyCharm 8 | from pydantic import BaseModel, Field 9 | 10 | 11 | class AllIndexRequest(BaseModel): 12 | repo_id: str = Field(description='Repository ID') 13 | -------------------------------------------------------------------------------- /domain/request/drop_index_request.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/23 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : drop_index_request 7 | # @Software: PyCharm 8 | from pydantic import BaseModel, Field 9 | 10 | 11 | class DropIndexRequest(BaseModel): 12 | repo_id: str = Field(description='Repository ID') 13 | -------------------------------------------------------------------------------- /domain/enums/vector_type.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/10/11 4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com 6 | # @File : vector_type 7 | # @Software: PyCharm 8 | from enum import StrEnum 9 | 10 | 11 | class VectorType(StrEnum): 12 | PGVECTOR = "pgvector" 13 | QDRANT = "qdrant" 14 | CHROMA = "chroma" 15 | -------------------------------------------------------------------------------- /domain/request/git_index_request.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import Optional 3 | 4 | class GitIndexRequest(BaseModel): 5 | uid: str | None = None 6 | repo_url: str 7 | # Credentials for private repositories 8 | username: Optional[str] = None # Username 9 | password: Optional[str] = None # Password or personal access token 10 | access_token: Optional[str] = None # Access token (GitHub/GitLab personal access token) -------------------------------------------------------------------------------- /domain/response/add_index_by_file_response.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/23 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : add_index_by_file_response 7 | # @Software: PyCharm 8 | 9 | from pydantic import BaseModel 10 | 11 | 12 | class AddIndexByFileResponse(BaseModel): 13 | repo_id: str 14 | file_id: str 15 | -------------------------------------------------------------------------------- /clients/jina/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/19 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | from clients.jina.client import AsyncClient 9 | from clients.jina.embeddings import AsyncEmbeddings 10 | 11 | __all__ = [ 12 | 'AsyncClient', 13 | 'AsyncEmbeddings', 14 | ] 15 | -------------------------------------------------------------------------------- /domain/request/retrieval_request.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/23 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : retrieval_request 7 | # @Software: PyCharm 8 | from typing import List, Optional 9 | 10 | from pydantic import BaseModel 11 | 12 | 13 | class RetrievalRequest(BaseModel): 14 | repo_id: str 15 | file_ids: Optional[List[str]] = None 16 | query: str 17 | -------------------------------------------------------------------------------- /domain/request/delete_index_request.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/23 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : delete_index_request 7 | # @Software: PyCharm 8 | from typing import List 9 | 10 | from pydantic import BaseModel, Field 11 | 12 | 13 | class DeleteIndexRequest(BaseModel): 14 | repo_id: str = Field(description='Repository ID') 15 | file_id: List[str] = Field(description='File IDs') 16 | -------------------------------------------------------------------------------- /domain/entity/code_segment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/17 4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com 6 | # @File : code_segment 7 | # @Software: PyCharm 8 | from pydantic import BaseModel, ConfigDict, Field 9 | 10 | 11 | class CodeSegment(BaseModel): 12 | start: int 13 | end: int 14 | code: str 15 | block: int = Field(default=1) 16 | model_config = ConfigDict( 17 | extra='ignore' 18 | ) 19 | -------------------------------------------------------------------------------- /domain/entity/task_context.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/10/12 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : task_context 7 | # @Software: PyCharm 8 | 9 | from pydantic import BaseModel, ConfigDict 10 | 11 | from domain.entity.blob import Blob 12 | 13 | 14 | class TaskContext(BaseModel): 15 | repo_id: str 16 | file_id: str 17 | relative_path: str 18 | blob: Blob 19 | model_config = ConfigDict(extra='ignore') 20 | -------------------------------------------------------------------------------- /domain/response/add_index_response.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/23 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : add_index_response 7 | # @Software: PyCharm 8 | from typing import List 9 | 10 | from pydantic import BaseModel 11 | 12 | 13 | class FileInfoResponse(BaseModel): 14 | file_id: str 15 | relative_path: str 16 | 17 | 18 | class AddIndexResponse(BaseModel): 19 | repo_id: str 20 | file_infos: List[FileInfoResponse] 21 | -------------------------------------------------------------------------------- /extensions/ext_manager.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/23 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : ext_manager 7 | # @Software: PyCharm 8 | from loguru import logger 9 | 10 | from extensions.vector.ext_vector import Vector 11 | 12 | 13 | class ExtManager: 14 | vector = Vector() 15 | 16 | 17 | async def init_vector(): 18 | logger.info("Initializing vector") 19 | await ExtManager.vector.init_app() 20 | logger.info("Initialized vector") 21 | -------------------------------------------------------------------------------- /utils/path_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/7/18 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : path_util 7 | # @Software: PyCharm 8 | import os 9 | 10 | 11 | class PathUtil: 12 | @staticmethod 13 | def check_or_make_dir(path: str): 14 | os.makedirs(path, exist_ok=True) 15 | 16 | @staticmethod 17 | def check_or_make_dirs(*paths): 18 | for path in paths: 19 | PathUtil.check_or_make_dir(path) 20 | -------------------------------------------------------------------------------- /exception/exception.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/8/21 4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com 6 | # @File : exception 7 | # @Software: PyCharm 8 | 9 | 10 | class CommonException(Exception): 11 | def __init__(self, message=''): 12 | self.message = message 13 | super().__init__(self.message) 14 | 15 | def __str__(self): 16 | return self.message 17 | 18 | def __repr__(self): 19 | return self.message 20 | 21 | 22 | class StorageUploadFileException(CommonException): 23 | ... 24 | -------------------------------------------------------------------------------- /domain/response/retrieval_segment_response.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/22 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : retrieval_segment_response 7 | # @Software: PyCharm 8 | from pydantic import BaseModel, Field, ConfigDict 9 | 10 | 11 | class RetrievalSegmentResponse(BaseModel): 12 | relative_path: str = Field(description='Relative path of the file the segment belongs to') 13 | start_line: int = Field(description='Segment start line') 14 | end_line: int = Field(description='Segment end line') 15 | score: float = Field(description='Similarity score') 16 | model_config = ConfigDict( 17 | extra='ignore' 18 | ) 19 | -------------------------------------------------------------------------------- /domain/request/add_index_request.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/23 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : add_index_request 7 | # @Software: PyCharm 8 | from typing import List 9 | 10 | from pydantic import BaseModel, Field 11 | 12 | 13 | class AddFileInfo(BaseModel): 14 | relative_path: str = Field(description='File Relative Path') 15 | content: str = Field(description='File Content') 16 | 17 | 18 | class AddIndexRequest(BaseModel): 19 | uid: str = Field(description='Unique ID') 20 | repo_abs_path: str = Field(description='Repository Absolute Path') 21 | file_infos: List[AddFileInfo] 22 | -------------------------------------------------------------------------------- /domain/result/result.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/07/17 4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com 6 | # @File : result 7 | # @Software: PyCharm 8 | 9 | from pydantic import BaseModel 10 | 11 | 12 | class Result[T](BaseModel): 13 | code: int 14 | message: str 15 | data: T | None = None 16 | success: bool 17 | 18 | @classmethod 19 | def ok(cls, data: T | None = None, code: int = 200, message: str = 'Success'): 20 | return cls(code=code, message=message, data=data, success=True) 21 | 22 | @classmethod 23 | def failed(cls, code: int = 500, message: str = 'Internal Server Error'): 24 | return cls(code=code, message=message, success=False) 25 | -------------------------------------------------------------------------------- /loader/segmenter/language/python.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from tree_sitter import Language 4 | 5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory 6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter 7 | 8 | 9 | @TreeSitterFactory.register(ext_set={'.py'}) 10 | class PythonSegmenter(TreeSitterSegmenter): 11 | """Code segmenter for Python.""" 12 | 13 | def get_language(self) -> Language: 14 | import tree_sitter_python 15 | return Language(tree_sitter_python.language()) 16 | 17 | def get_node_types(self) -> List[str]: 18 | return ['function_definition', 'decorated_definition'] 19 | 20 | def get_recursion_node_types(self) -> List[str]: 21 | return ['class_definition', 'block'] 22 | -------------------------------------------------------------------------------- /loader/segmenter/language/css.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from tree_sitter import Language 4 | 5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory 6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter 7 | 8 | 9 | @TreeSitterFactory.register(ext_set={'.css'}) 10 | class CssSegmenter(TreeSitterSegmenter): 11 | """Code segmenter for CSS.""" 12 | 13 | def get_language(self) -> Language: 14 | import tree_sitter_css 15 | return Language(tree_sitter_css.language()) 16 | 17 | def get_node_types(self) -> List[str]: 18 | return [ 19 | 'rule_set', 20 | 'keyframes_statement', 21 | 'media_statement' 22 | ] 23 | 24 | def get_recursion_node_types(self) -> List[str]: 25 | return [] 26 | -------------------------------------------------------------------------------- /loader/segmenter/language/go.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from tree_sitter import Language 4 | 5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory 6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter 7 | 8 | 9 | @TreeSitterFactory.register(ext_set={'.go'}) 10 | class GoSegmenter(TreeSitterSegmenter): 11 | """Code segmenter for Go.""" 12 | 13 | def get_language(self) -> Language: 14 | import tree_sitter_go 15 | return Language(tree_sitter_go.language()) 16 | 17 | def get_node_types(self) -> List[str]: 18 | return [ 19 | 'method_declaration', 20 | 'function_declaration', 21 | 'type_declaration' 22 | ] 23 | 24 | def get_recursion_node_types(self) -> List[str]: 25 | return [] 26 | -------------------------------------------------------------------------------- /resource/dev.yaml: -------------------------------------------------------------------------------- 1 | we0-index: 2 | application: we0-index 3 | server: 4 | host: 0.0.0.0 5 | port: 
8080 6 | reload: True 7 | log: 8 | level: INFO 9 | file: false 10 | debug: false 11 | vector: 12 | platform: pgvector 13 | code2desc: false 14 | chat-provider: openai 15 | chat-model: gpt-4o-mini 16 | embedding-provider: jina 17 | embedding-model: jina-embeddings-v2-base-code 18 | pgvector: 19 | db: we0_index 20 | host: localhost 21 | port: 5432 22 | user: root 23 | password: password 24 | qdrant: 25 | mode: disk 26 | disk: 27 | path: vector/qdrant 28 | remote: 29 | host: localhost 30 | port: 6333 31 | chroma: 32 | mode: disk 33 | disk: 34 | path: vector/chroma 35 | remote: 36 | host: localhost 37 | port: 8000 38 | ssl: false -------------------------------------------------------------------------------- /loader/segmenter/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/12 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : __init__.py 7 | # @Software: PyCharm 8 | from loader.segmenter.language.css import CssSegmenter 9 | from loader.segmenter.language.go import GoSegmenter 10 | from loader.segmenter.language.java import JavaSegmenter 11 | from loader.segmenter.language.javascript import JavaScriptSegmenter 12 | from loader.segmenter.language.python import PythonSegmenter 13 | from loader.segmenter.language.typescript import TypeScriptSegmenter 14 | from loader.segmenter.language.typescriptxml import TypeScriptXmlSegmenter 15 | 16 | __all__ = [ 17 | 'CssSegmenter', 18 | 'GoSegmenter', 19 | 'JavaSegmenter', 20 | 'JavaScriptSegmenter', 21 | 'PythonSegmenter', 22 | 'TypeScriptSegmenter', 23 | 'TypeScriptXmlSegmenter' 24 | ] 25 | -------------------------------------------------------------------------------- /utils/mimetype_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/12 4 | # @Author : .*? 
5 | # @Email : amashiro2233@gmail.com 6 | # @File : mimetype_util 7 | # @Software: PyCharm 8 | import mimetypes 9 | import os 10 | 11 | 12 | def guess_mimetype_and_extension(file_path: str) -> tuple[str | None, str | None]: 13 | """ 14 | Guess the MIME type and extension from a file path (usually a file name or a path fragment of a URL). 15 | The file extension takes priority; without one, the default application/octet-stream is used. 16 | """ 17 | _, extension = os.path.splitext(file_path) 18 | if extension: 19 | # Look up the MIME type for this extension in mimetypes 20 | mimetype = mimetypes.types_map.get(extension.lower()) 21 | else: 22 | # No extension: fall back to the default binary stream type 23 | mimetype = 'application/octet-stream' 24 | # Guess the extension back from the MIME type (usually yields .bin) 25 | extension = mimetypes.guess_extension(mimetype) 26 | 27 | return mimetype, extension 28 | -------------------------------------------------------------------------------- /loader/segmenter/language/java.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from tree_sitter import Language 4 | 5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory 6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter 7 | 8 | 9 | @TreeSitterFactory.register(ext_set={'.java'}) 10 | class JavaSegmenter(TreeSitterSegmenter): 11 | """Code segmenter for Java.""" 12 | 13 | def get_language(self) -> Language: 14 | import tree_sitter_java 15 | return Language(tree_sitter_java.language()) 16 | 17 | def get_node_types(self) -> List[str]: 18 | return [ 19 | 'method_declaration', 20 | 'enum_declaration' 21 | ] 22 | 23 | def get_recursion_node_types(self) -> List[str]: 24 | return [ 25 | 'class_declaration', 26 | 'class_body', 27 | 'interface_declaration', 28 | 'interface_body' 29 | ] 30 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # ---------------------------- 2 | # builder 3 | # ---------------------------- 4 | FROM python:3.12.8-slim AS builder 5 | 6 | ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_PYTHON_DOWNLOADS=0 7 | 8 | WORKDIR /app 9 | 10 | RUN pip install --no-cache-dir uv 11 | 12 | COPY pyproject.toml /app 13 | 14 | RUN uv lock && uv sync --frozen --no-dev 15 | # ---------------------------- 16 | # runtime 17 | # ---------------------------- 18 | FROM python:3.12.8-slim 19 | 20 | RUN apt-get update && \ 21 | apt-get install --no-install-recommends -y git ca-certificates && \ 22 | rm -rf /var/lib/apt/lists/* 23 | ENV PORT=8080 24 | ENV PATH="/app/.venv/bin:$PATH" 25 | 26 | EXPOSE $PORT 27 | VOLUME /app/resource 28 | VOLUME /app/log 29 | VOLUME /app/vector 30 | VOLUME /app/storage 31 | 32 | WORKDIR /app 33 | 34 | COPY . 
/app 35 | 36 | COPY --from=builder /usr/local/bin /usr/local/bin 37 | COPY --from=builder /app/.venv /app/.venv 38 | 39 | CMD ["python", "/app/main.py"] -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "we0-index" 3 | version = "0.1.0" 4 | description = "We0 Index" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "aiofiles>=24.1.0", 9 | "chromadb>=0.6.3", 10 | "click>=8.1.7", 11 | "fastapi>=0.115.7", 12 | "gitpython>=3.1.42", 13 | "greenlet>=3.2.2", 14 | "httpx>=0.27.0", 15 | "loguru>=0.7.3", 16 | "mcp[cli]>=1.9.2", 17 | "openai", 18 | "psycopg[binary,pool]>=3.2.4", 19 | "pydantic-settings>=2.7.1", 20 | "python-dotenv>=1.0.1", 21 | "python-multipart>=0.0.20", 22 | "pyyaml>=6.0.2", 23 | "qdrant-client>=1.13.2", 24 | "sqlalchemy>=2.0.41", 25 | "tiktoken>=0.8.0", 26 | "tree-sitter>=0.24.0", 27 | "tree-sitter-css>=0.23.2", 28 | "tree-sitter-go>=0.23.4", 29 | "tree-sitter-java>=0.23.5", 30 | "tree-sitter-javascript>=0.23.1", 31 | "tree-sitter-python>=0.23.6", 32 | "tree-sitter-typescript>=0.23.2", 33 | "uvicorn>=0.34.0", 34 | ] 35 | 36 | [[tool.uv.index]] 37 | url = "https://mirrors.aliyun.com/pypi/simple" 38 | default = true -------------------------------------------------------------------------------- /loader/segmenter/language/typescriptxml.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from tree_sitter import Language 4 | 5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory 6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter 7 | 8 | 9 | @TreeSitterFactory.register(ext_set={'.tsx'}) 10 | class TypeScriptXmlSegmenter(TreeSitterSegmenter): 11 | """Code segmenter for TypeScriptXml.""" 12 | 13 | def get_language(self) -> Language: 14 | import tree_sitter_typescript 15 | return Language(tree_sitter_typescript.language_tsx()) 16 | 17 | def get_node_types(self) -> List[str]: 18 | return [ 19 | 'lexical_declaration', 20 | 'interface_declaration', 21 | 'method_definition', 22 | 'function_declaration', 23 | 'export_statement' 24 | ] 25 | 26 | def get_recursion_node_types(self) -> List[str]: 27 | return [ 28 | 'class_declaration', 29 | 'class_body' 30 | ] 31 | -------------------------------------------------------------------------------- /loader/segmenter/language/typescript.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from tree_sitter import Language 4 | 5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory 6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter 7 | 8 | 9 | @TreeSitterFactory.register(ext_set={'.ts'}) 10 | class TypeScriptSegmenter(TreeSitterSegmenter): 11 | """Code segmenter for TypeScript.""" 12 | 13 | def get_language(self) -> Language: 14 | import tree_sitter_typescript 15 | return Language(tree_sitter_typescript.language_typescript()) 16 | 17 | def get_node_types(self) -> List[str]: 18 | return [ 19 | 'lexical_declaration', 20 | 'interface_declaration', 21 | 'method_definition', 22 | 'function_declaration', 23 | 'export_statement' 24 | ] 25 | 26 | def get_recursion_node_types(self) -> List[str]: 27 | return [ 28 | 'class_declaration', 29 | 'class_body', 30 | 'abstract_class_declaration' 31 | ] 32 | -------------------------------------------------------------------------------- /loader/segmenter/language/javascript.py: 
-------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from tree_sitter import Language 4 | 5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory 6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter 7 | 8 | 9 | @TreeSitterFactory.register(ext_set={'.js', '.mjs'}) 10 | class JavaScriptSegmenter(TreeSitterSegmenter): 11 | """Code segmenter for JavaScript.""" 12 | 13 | def get_language(self) -> Language: 14 | import tree_sitter_javascript 15 | return Language(tree_sitter_javascript.language()) 16 | 17 | def get_node_types(self) -> List[str]: 18 | return [ 19 | 'lexical_declaration', 20 | 'export_statement', 21 | 'method_definition', 22 | 'function_declaration', 23 | 'function_expression', 24 | 'generator_function', 25 | 'generator_function_declaration' 26 | ] 27 | 28 | def get_recursion_node_types(self) -> List[str]: 29 | return [ 30 | 'class_declaration', 31 | 'class_body' 32 | ] 33 | -------------------------------------------------------------------------------- /domain/model/repository.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/20 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : repository 7 | # @Software: PyCharm 8 | import uuid 9 | from datetime import datetime 10 | from sqlalchemy import String, DateTime, func 11 | from sqlalchemy.dialects import postgresql 12 | from sqlalchemy.orm import Mapped, mapped_column 13 | 14 | from domain.model.base import Base 15 | 16 | 17 | class Repository(Base): 18 | __tablename__ = 'repository_info' 19 | id: Mapped[uuid.UUID] = mapped_column( 20 | postgresql.UUID(as_uuid=True), 21 | primary_key=True, 22 | ) 23 | embedding_model: Mapped[str] = mapped_column(String, nullable=True) 24 | embedding_model_provider: Mapped[str] = mapped_column(String, nullable=True) 25 | created_by: Mapped[str] = mapped_column(String, nullable=False) 26 | created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) 27 | updated_by: Mapped[str] = mapped_column(String, nullable=False) 28 | updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) 29 | -------------------------------------------------------------------------------- /constants/constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/7/17 4 | # @Author : .*?
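Every segmenter under loader/segmenter/language follows the same three-hook registration pattern, so a new grammar can be bolted on without touching the factory. A rough sketch of what a Rust registration could look like (the tree_sitter_rust package and the node-type choices are illustrative assumptions, not part of this repo):

from typing import List

from tree_sitter import Language

from loader.segmenter.tree_sitter_factory import TreeSitterFactory
from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter


@TreeSitterFactory.register(ext_set={'.rs'})
class RustSegmenter(TreeSitterSegmenter):
    """Code segmenter for Rust (illustrative sketch)."""

    def get_language(self) -> Language:
        import tree_sitter_rust  # hypothetical extra dependency, mirroring tree-sitter-python
        return Language(tree_sitter_rust.language())

    def get_node_types(self) -> List[str]:
        # top-level nodes that become standalone segments
        return ['function_item', 'struct_item', 'enum_item']

    def get_recursion_node_types(self) -> List[str]:
        # container nodes whose children are segmented recursively
        return ['impl_item']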
5 | # @Email : amashiro2233@gmail.com 6 | # @File : constants 7 | # @Software: PyCharm 8 | import os 9 | 10 | from dotenv import find_dotenv, load_dotenv 11 | 12 | 13 | class Constants: 14 | class Common: 15 | PROJECT_NAME: str = 'we0-index' 16 | 17 | class Path: 18 | # SYSTEM PATH 19 | ROOT_PATH: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 20 | LOG_PATH: str = os.path.join(ROOT_PATH, 'log') 21 | RESOURCE_PATH: str = os.path.join(ROOT_PATH, 'resource') 22 | ENV_FILE_PATH: str = os.path.join(ROOT_PATH, '.env') 23 | TEMP_PATH: str = '/tmp' 24 | QDRANT_DEFAULT_DISK_PATH: str = os.path.join(ROOT_PATH, 'vector', 'qdrant') 25 | CHROMA_DEFAULT_DISK_PATH: str = os.path.join(ROOT_PATH, 'vector', 'chroma') 26 | 27 | # We0 CONFIG 28 | load_dotenv(ENV_FILE_PATH) 29 | YAML_FILE_PATH: str = find_dotenv( 30 | filename=os.path.join( 31 | RESOURCE_PATH, 32 | f"{os.environ.get('WE0_INDEX_ENV', 'dev')}.yaml" 33 | ) 34 | ) 35 | -------------------------------------------------------------------------------- /clients/jina/client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/19 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : jina 7 | # @Software: PyCharm 8 | import os 9 | from typing import Optional 10 | 11 | import httpx 12 | 13 | 14 | class AsyncClient(httpx.AsyncClient): 15 | 16 | def __init__( 17 | self, 18 | base_url: Optional[str] = None, 19 | api_key: Optional[str] = None, 20 | timeout: Optional[int] = 300, 21 | *args, **kwargs 22 | ): 23 | if api_key is None: 24 | api_key = os.environ.get("JINA_API_KEY") 25 | if api_key is None: 26 | raise ValueError( 27 | "The api_key client option must be set either by passing api_key to the client or by setting the JINA_API_KEY environment variable" 28 | ) 29 | self.api_key = api_key 30 | 31 | if base_url is None: 32 | base_url = os.environ.get("JINA_BASE_URL") 33 | if base_url is None: 34 | base_url = "https://api.jina.ai/v1" 35 | 36 | super().__init__(base_url=base_url, timeout=timeout, *args, **kwargs) 37 | 38 | from .embeddings import AsyncEmbeddings 39 | self.embeddings = AsyncEmbeddings(self) 40 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/10/10 4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com 6 | # @File : main 7 | # @Software: PyCharm 8 | 9 | import click 10 | import uvicorn 11 | 12 | from constants.constants import Constants 13 | from launch.we0_index_mcp import we0_index_mcp 14 | from setting.setting import get_we0_index_settings 15 | 16 | @click.command() 17 | @click.option('--mode', default='mcp', show_default=True, type=click.Choice(['mcp', 'fastapi']), help='Choose run mode: "mcp" or "fastapi".') 18 | @click.option('--transport', default='streamable-http', show_default=True, type=click.Choice(['streamable-http', 'stdio', 'sse']), help='Transport protocol for MCP mode') 19 | def main(mode, transport): 20 | settings = get_we0_index_settings() 21 | 22 | if mode == 'mcp': 23 | we0_index_mcp.run(transport) 24 | elif mode == 'fastapi': 25 | uvicorn.run( 26 | 'launch.launch:app', 27 | host=settings.server.host, 28 | port=settings.server.port, 29 | reload=settings.server.reload, 30 | env_file=Constants.Path.ENV_FILE_PATH 31 | ) 32 | else: 33 | raise ValueError(f"Unknown mode: {mode}") 34 | 35 | if __name__ == '__main__': 36 | main() 37 | -------------------------------------------------------------------------------- /loader/repo_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/10/11 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : repo_loader 7 | # @Software: PyCharm 8 | """Abstract interface for document loader implementations.""" 9 | 10 | from typing import AsyncIterator, Type 11 | 12 | from domain.entity.blob import Blob 13 | from domain.entity.code_segment import CodeSegment 14 | from loader.segmenter.base_line_segmenter import LineBasedSegmenter 15 | from loader.segmenter.base_segmenter import BaseSegmenter 16 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory 17 | 18 | 19 | class RepoLoader: 20 | 21 | @classmethod 22 | def get_segmenter_constructor(cls, extension: str | None = None) -> Type[BaseSegmenter]: 23 | if extension in TreeSitterFactory.get_ext_set(): 24 | return TreeSitterFactory.get_segmenter(extension) 25 | else: 26 | return LineBasedSegmenter 27 | 28 | @classmethod 29 | async def load_blob(cls, blob: Blob) -> AsyncIterator[CodeSegment]: 30 | text = await blob.as_string() 31 | 32 | segmenter = cls.get_segmenter_constructor(extension=blob.extension).from_tiktoken_encoder(text=text, merge_small_chunks=True) 33 | if not segmenter.is_valid(): 34 | segmenter = cls.get_segmenter_constructor().from_tiktoken_encoder(text=text) 35 | 36 | for code in segmenter.segment(): 37 | yield code 38 | -------------------------------------------------------------------------------- /launch/we0_index_mcp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/6/2 4 | # @Author : .*?
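repo_loader.py only glues the factory lookup to a tiktoken-based length function, so the same flow can be driven by hand without a Blob. A minimal sketch, assuming a Python source file at an arbitrary path:

from loader.repo_loader import RepoLoader

with open('utils/helper.py', encoding='utf-8') as fh:
    source = fh.read()

# '.py' is registered by PythonSegmenter, so the factory returns it;
# unregistered extensions fall back to LineBasedSegmenter.
segmenter_cls = RepoLoader.get_segmenter_constructor('.py')
segmenter = segmenter_cls.from_tiktoken_encoder(text=source, merge_small_chunks=True)
for segment in segmenter.segment():
    print(segment.start, segment.end, segment.block)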
5 | # @Email : amashiro2233@gmail.com 6 | # @File : we0_index_mcp 7 | # @Software: PyCharm 8 | from contextlib import asynccontextmanager 9 | from typing import AsyncIterator 10 | 11 | from mcp.server.fastmcp import FastMCP 12 | from mcp.server.fastmcp.tools import Tool 13 | from mcp.server.lowlevel.server import LifespanResultT, Server 14 | from mcp.shared.context import RequestT 15 | 16 | from extensions import ext_manager 17 | from router.git_router import clone_and_index 18 | from router.vector_router import retrieval 19 | from setting.setting import get_we0_index_settings 20 | 21 | settings = get_we0_index_settings() 22 | 23 | 24 | @asynccontextmanager 25 | async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]: 26 | await ext_manager.init_vector() 27 | yield {} 28 | # no shutdown cleanup required yet 29 | 30 | 31 | def create_fast_mcp() -> FastMCP: 32 | app = FastMCP( 33 | name="We0 Index", 34 | description="CodeIndex, embedding, retrieval, Tool parameters must be in standard JSON format", 35 | tools=[ 36 | Tool.from_function(clone_and_index), 37 | Tool.from_function(retrieval), 38 | ], 39 | lifespan=lifespan, 40 | host=settings.server.host, 41 | port=settings.server.port, 42 | log_level=settings.log.level 43 | ) 44 | return app 45 | 46 | 47 | we0_index_mcp = create_fast_mcp() -------------------------------------------------------------------------------- /loader/segmenter/tree_sitter_factory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/12 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : tree_sitter_factory 7 | # @Software: PyCharm 8 | from functools import lru_cache 9 | from typing import Dict, Type, Set 10 | 11 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter 12 | 13 | 14 | class TreeSitterFactory: 15 | __segmenter: Dict[str, Type[TreeSitterSegmenter]] = {} 16 | 17 | @classmethod 18 | def get_segmenter(cls, ext: str) -> Type[TreeSitterSegmenter]: 19 | if ext not in cls.__segmenter: 20 | raise ValueError(f'ext type {ext} is not supported') 21 | return cls.__segmenter[ext] 22 | 23 | @classmethod 24 | def __register(cls, ext: str, _cls: Type[TreeSitterSegmenter]): 25 | cls.__segmenter[ext] = _cls 26 | 27 | @classmethod 28 | def __has_cls(cls, ext: str) -> bool: 29 | return ext in cls.__segmenter 30 | 31 | @classmethod 32 | def register( 33 | cls, 34 | ext_set: Set[str] 35 | ): 36 | if not ext_set: 37 | raise ValueError('Must provide a non-empty extension set') 38 | 39 | def decorator(origin_cls): 40 | for ext in ext_set: 41 | if not cls.__has_cls(ext): 42 | TreeSitterFactory.__register(ext, origin_cls) 43 | return origin_cls 44 | 45 | return decorator 46 | 47 | @classmethod 48 | @lru_cache 49 | def get_ext_set(cls) -> Set[str]: 50 | return set(cls.__segmenter.keys()) 51 | -------------------------------------------------------------------------------- /prompt/prompt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/20 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : prompt 7 | # @Software: PyCharm 8 | class SystemPrompt: 9 | ANALYZE_CODE_PROMPT = """ 10 | # Task Instructions 11 | 1. I will provide a code block wrapped in ``` 12 | 2. 
Analyze the code with these steps: 13 | - Identify natural segments separated by empty lines, comment blocks, or logical sections 14 | - Generate technical descriptions for each segment 15 | 3. Output requirements: 16 | - Use numbered Markdown lists (1. 2. 3.) 17 | - Maximum 2 lines per item 18 | - Prioritize functional explanations, then implementation details 19 | - Preserve key technical terms/algorithms 20 | - Keep terminology consistent with the source code 21 | 22 | # Output Example 23 | 1. Initializes Spring Boot application: Uses @SpringBootApplication to configure bootstrap class, sets base package for component scanning 24 | 2. Implements RESTful endpoint: Creates /user API through @RestController, defines base path with @RequestMapping 25 | 3. Handles file uploads: Leverages S3 SDK to transfer local files to cloud storage 26 | 27 | # Now analyze this code: 28 | """ 29 | 30 | 31 | class SystemMessageTemplate: 32 | ANALYZE_CODE_MESSAGE_TEMPLATE = lambda code_text: [ 33 | { 34 | "role": "system", 35 | "content": SystemPrompt.ANALYZE_CODE_PROMPT 36 | }, 37 | { 38 | "role": "user", 39 | "content": code_text 40 | } 41 | ] 42 | -------------------------------------------------------------------------------- /domain/entity/document.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/14 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : document 7 | # @Software: PyCharm 8 | from typing import List, Optional 9 | 10 | from pydantic import BaseModel, ConfigDict, Field 11 | 12 | 13 | class DocumentMeta(BaseModel): 14 | repo_id: Optional[str] = Field(description='Repository ID') 15 | file_id: Optional[str] = Field(description='File ID') 16 | segment_id: str = Field(description='Segment ID (uuid4)') 17 | relative_path: str = Field(description='Relative path of the file the segment belongs to') 18 | start_line: int = Field(description='Segment start line') 19 | end_line: int = Field(description='Segment end line') 20 | segment_block: int = Field(description='Segment block index') 21 | segment_hash: str = Field(description='Segment hash') 22 | segment_cl100k_base_token: Optional[int] = Field(default=None, description='Segment cl100k_base token count') 23 | segment_o200k_base_token: Optional[int] = Field(default=None, description='Segment o200k_base token count') 24 | description: Optional[str] = Field(default=None, description='Optional code description, used for description embedding') 25 | 26 | score: Optional[float] = Field(default=None, description='Similarity score, only set during similarity search') 27 | content: Optional[str] = Field(default=None, description='Plain-text code; kept for Qdrant compatibility, since Qdrant can only store the other fields in the payload') 28 | 29 | model_config = ConfigDict( 30 | extra='ignore' 31 | ) 32 | 33 | 34 | class Document(BaseModel): 35 | vector: List[float] = Field(description='Embedding vector', default_factory=list) 36 | content: Optional[str] = Field(default=None, description='Plain-text code') 37 | meta: Optional[DocumentMeta] = Field(default=None, description='Code metadata') 38 | 39 | model_config = ConfigDict( 40 | extra='ignore' 41 | ) 42 | -------------------------------------------------------------------------------- /clients/jina/embeddings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/19 4 | # @Author : .*?
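SystemMessageTemplate plugs directly into any OpenAI-compatible chat client, which is how the optional code2desc description step can be driven. A minimal sketch (describe_code is not part of this repo; the model name mirrors the chat-model in resource/dev.yaml):

from openai import AsyncOpenAI

from prompt.prompt import SystemMessageTemplate


async def describe_code(code_text: str) -> str:
    # AsyncOpenAI reads OPENAI_API_KEY / OPENAI_BASE_URL from the environment,
    # matching the variables in .env.example.
    client = AsyncOpenAI()
    response = await client.chat.completions.create(
        model='gpt-4o-mini',
        messages=SystemMessageTemplate.ANALYZE_CODE_MESSAGE_TEMPLATE(f'```\n{code_text}\n```'),
    )
    return response.choices[0].message.content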
5 | # @Email : amashiro2233@gmail.com 6 | # @File : embeddings 7 | # @Software: PyCharm 8 | from typing import Union, List, Iterable, Optional 9 | 10 | from openai.types import CreateEmbeddingResponse 11 | 12 | from .client import AsyncClient 13 | 14 | 15 | class AsyncEmbeddings: 16 | def __init__(self, client: AsyncClient): 17 | self.client = client 18 | 19 | async def create( 20 | self, 21 | input: Union[str, List[str], Iterable[int], Iterable[Iterable[int]]], 22 | model: str = 'jina-embeddings-v2-base-code', 23 | normalized: Optional[bool] = None, 24 | embedding_type: Optional[str] = None, 25 | task: Optional[str] = None, 26 | late_chunking: Optional[bool] = None, 27 | dimensions: Optional[int] = None, 28 | ) -> CreateEmbeddingResponse: 29 | request_json = { 30 | 'input': input, 31 | 'model': model 32 | } 33 | if normalized is not None: 34 | request_json['normalized'] = normalized 35 | if embedding_type: 36 | request_json['embedding_type'] = embedding_type 37 | if task: 38 | request_json['task'] = task 39 | if late_chunking is not None: 40 | request_json['late_chunking'] = late_chunking 41 | if dimensions: 42 | request_json['dimensions'] = dimensions 43 | response = await self.client.post( 44 | '/embeddings', json=request_json, 45 | headers={'Authorization': f'Bearer {self.client.api_key}'} 46 | ) 47 | return CreateEmbeddingResponse.model_validate(response.json()) 48 | -------------------------------------------------------------------------------- /utils/git_parse.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/4/27 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : git_parse 7 | # @Software: PyCharm 8 | 9 | 10 | import re 11 | 12 | 13 | def parse_git_url(git_url): 14 | # Supported platform domains 15 | platforms = [ 16 | 'github.com', 17 | 'gitlab.com', 18 | 'gitee.com', 19 | 'bitbucket.org', 20 | 'codeberg.org' 21 | ] 22 | 23 | # Supported URL patterns 24 | patterns = [ 25 | # SSH format: git@github.com:owner/repo 26 | r'^git@([^:]+):([^/]+)/([^/.]+)(?:\.git)?$', 27 | 28 | # HTTP/HTTPS format: http(s)://github.com/owner/repo 29 | r'^https?://([^/]+)/([^/]+)/([^/.]+)(?:\.git)?$' 30 | ] 31 | 32 | # Strip surrounding whitespace 33 | url = git_url.strip() if git_url else '' 34 | 35 | # Try each pattern in turn 36 | for pattern in patterns: 37 | match = re.match(pattern, url) 38 | if match: 39 | domain, owner, repo = match.groups() 40 | if domain.lower() in platforms: 41 | return domain, owner, repo 42 | 43 | return None, None, None 44 | 45 | 46 | if __name__ == '__main__': 47 | 48 | # Test cases 49 | test_urls = [ 50 | # SSH protocol 51 | 'git@github.com:we0-dev/we0', 52 | 'git@github.com:we0-dev/we0.git', 53 | 54 | # HTTP protocol 55 | 'http://github.com/we0-dev/we0', 56 | 'http://github.com/we0-dev/we0.git', 57 | 58 | # HTTPS protocol 59 | 'https://github.com/we0-dev/we0', 60 | 'https://github.com/we0-dev/we0.git', 61 | 62 | # Other platforms 63 | 'git@gitlab.com:group/project', 64 | 'http://gitlab.com/group/project', 65 | 'https://gitee.com/username/repo', 66 | 67 | # Invalid URLs 68 | 'we0-dev/we0', # incomplete 69 | 'github.com/we0-dev/we0', # missing protocol 70 | 'https://example.com/we0-dev/we0' # unsupported platform 71 | ] 72 | 73 | for test_url in test_urls: 74 | result = parse_git_url(test_url) 75 | print(f"URL: {test_url}") 76 | print(f"Parsed: {result}\n") 77 | -------------------------------------------------------------------------------- /utils/helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/18 4 | # 
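Because AsyncClient subclasses httpx.AsyncClient, one async context manager covers connection setup and teardown. A minimal usage sketch, assuming JINA_API_KEY is set in the environment:

import asyncio

from clients.jina import AsyncClient


async def embed() -> None:
    async with AsyncClient() as client:
        response = await client.embeddings.create(
            input=['def add(a, b):\n    return a + b'],
            model='jina-embeddings-v2-base-code',
        )
        print(len(response.data[0].embedding))  # embedding dimension


asyncio.run(embed())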
@Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : helper 7 | # @Software: PyCharm 8 | 9 | import uuid 10 | from hashlib import sha256 11 | from typing import Optional, Literal, Union, Collection 12 | 13 | import tiktoken 14 | 15 | 16 | class Helper: 17 | # Cache of encoder callables, keyed by model/encoding and special-token sets 18 | _encoders_cache = {} 19 | 20 | @staticmethod 21 | def generate_text_hash(text: str) -> str: 22 | """Generate the SHA-256 hash of a text.""" 23 | return sha256(text.encode()).hexdigest() 24 | 25 | @staticmethod 26 | def calculate_tokens( 27 | text: str, 28 | encoding_name: str = "cl100k_base", 29 | model_name: Optional[str] = None, 30 | allowed_special=None, 31 | disallowed_special: Union[Literal["all"], Collection[str]] = "all", 32 | ) -> int: 33 | """Count the tokens in a text with a tiktoken encoder.""" 34 | if allowed_special is None: 35 | allowed_special = set() 36 | 37 | # Build the cache key from model_name/encoding_name and the special-token sets 38 | cache_key = (model_name, encoding_name, frozenset(allowed_special), disallowed_special) 39 | 40 | # Reuse a cached encoder instance if one exists 41 | if cache_key in Helper._encoders_cache: 42 | return Helper._encoders_cache[cache_key](text) 43 | 44 | # Pick the encoder from the model name or the encoding name 45 | if model_name: 46 | enc = tiktoken.encoding_for_model(model_name) 47 | else: 48 | enc = tiktoken.get_encoding(encoding_name) 49 | 50 | # Build a new encoder 51 | def _tiktoken_encoder(_text: str) -> int: 52 | """Count the tokens in the given text.""" 53 | return len( 54 | enc.encode( 55 | _text, 56 | allowed_special=allowed_special, 57 | disallowed_special=disallowed_special, 58 | ) 59 | ) 60 | 61 | # Cache the encoder and return the count 62 | Helper._encoders_cache[cache_key] = _tiktoken_encoder 63 | return _tiktoken_encoder(text) 64 | 65 | @staticmethod 66 | def generate_fixed_uuid(unique_str: str) -> str: 67 | namespace = uuid.NAMESPACE_URL 68 | return str(uuid.uuid5(namespace, unique_str)) 69 | -------------------------------------------------------------------------------- /launch/launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/7/29 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : launch 7 | # @Software: PyCharm 8 | import asyncio 9 | from contextlib import asynccontextmanager 10 | 11 | from fastapi import FastAPI 12 | from fastapi.encoders import jsonable_encoder 13 | from loguru import logger 14 | from starlette.middleware.cors import CORSMiddleware 15 | from starlette.requests import Request 16 | from starlette.responses import JSONResponse 17 | 18 | from config.loguru import Log 19 | from domain.result.result import Result 20 | from exception.exception import CommonException 21 | from extensions import ext_manager 22 | from router.git_router import git_router 23 | from router.vector_router import vector_router 24 | from setting.setting import get_we0_index_settings 25 | 26 | settings = get_we0_index_settings() 27 | 28 | 29 | async def initialize_extensions(): 30 | await asyncio.gather( 31 | ext_manager.init_vector(), 32 | ) 33 | 34 | 35 | async def close_extensions(): 36 | ...
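Helper caches one tiktoken encoder per (model, encoding, special-token) combination, so repeated token counts avoid rebuilding the encoder. A short usage sketch:

from utils.helper import Helper

text = 'def add(a, b):\n    return a + b'
print(Helper.generate_text_hash(text))                            # SHA-256 hex digest
print(Helper.calculate_tokens(text))                              # cl100k_base token count
print(Helper.calculate_tokens(text, encoding_name='o200k_base'))  # o200k_base token count
print(Helper.generate_fixed_uuid('https://github.com/we0-dev/we0'))  # deterministic uuid5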
 37 | 38 | 39 | @asynccontextmanager 40 | async def lifespan(fastapi: FastAPI): 41 | try: 42 | Log.start() 43 | await initialize_extensions() 44 | yield 45 | finally: 46 | await close_extensions() 47 | Log.close() 48 | 49 | 50 | def create_app() -> FastAPI: 51 | app = FastAPI( 52 | title="We0 Index", 53 | description="We0 Index API", 54 | version="0.1.0", 55 | lifespan=lifespan 56 | ) 57 | 58 | # Register the CORS middleware 59 | app.add_middleware( 60 | CORSMiddleware, 61 | allow_origins=["*"], 62 | allow_credentials=True, 63 | allow_methods=["*"], 64 | allow_headers=["*"], 65 | ) 66 | 67 | return app 68 | 69 | 70 | app = create_app() 71 | # Register routers 72 | app.include_router(vector_router, prefix="/vector", tags=["vector"]) 73 | app.include_router(git_router, prefix="/git", tags=["git"]) 74 | 75 | 76 | @app.exception_handler(CommonException) 77 | async def common_exception_handler(request: Request, exc: CommonException): 78 | error = Result.failed(code=-1, message=exc.message) 79 | logger.exception(f"Url: {request.url}, Exception: CommonException, Error: {error}") 80 | return JSONResponse(content=jsonable_encoder(error)) 81 | 82 | 83 | @app.exception_handler(Exception) 84 | async def exception_handler(request: Request, exc: Exception): 85 | error = Result.failed(code=-1, message=str(exc.args)) 86 | logger.exception(f"Url: {request.url}, {type(exc).__name__}: {exc}, Error: {error}") 87 | return JSONResponse(content=jsonable_encoder(error)) 88 | -------------------------------------------------------------------------------- /extensions/vector/base_vector.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/14 4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com 6 | # @File : base_vector 7 | # @Software: PyCharm 8 | from __future__ import annotations 9 | 10 | from abc import ABC, abstractmethod 11 | from typing import List, Optional 12 | 13 | from domain.entity.document import Document, DocumentMeta 14 | from models.model_factory import ModelFactory 15 | from setting.setting import get_we0_index_settings 16 | 17 | settings = get_we0_index_settings() 18 | 19 | 20 | class BaseVector(ABC): 21 | 22 | @staticmethod 23 | @abstractmethod 24 | def get_client(): 25 | raise NotImplementedError 26 | 27 | @abstractmethod 28 | async def init(self): 29 | raise NotImplementedError 30 | 31 | @abstractmethod 32 | async def create(self, documents: List[Document]): 33 | raise NotImplementedError 34 | 35 | @abstractmethod 36 | async def upsert(self, documents: List[Document]): 37 | raise NotImplementedError 38 | 39 | @abstractmethod 40 | async def all_meta(self, repo_id: str) -> List[DocumentMeta]: 41 | raise NotImplementedError 42 | 43 | @abstractmethod 44 | async def drop(self, repo_id: str): 45 | raise NotImplementedError 46 | 47 | @abstractmethod 48 | async def delete(self, repo_id: str, file_ids: List[str]): 49 | raise NotImplementedError 50 | 51 | @abstractmethod 52 | async def search_by_vector( 53 | self, 54 | repo_id: str, 55 | file_ids: Optional[List[str]], 56 | query_vector: List[float], 57 | top_k: int = 5, 58 | score_threshold: float = 0.0 59 | ) -> List[Document]: 60 | raise NotImplementedError 61 | 62 | @staticmethod 63 | def dynamic_collection_name(dimension: int) -> str: 64 | return f'we0_index_{settings.vector.embedding_model}_{dimension}'.replace('-', '_') 65 | 66 | # TODO 以后这边应该从仓库数据表中读取用户的`model_provider`和`model_name` 67 | # 前期先全部使用`openai`的`text-embedding-3-small` 68 | @classmethod 69 | async def get_embedding_model(cls): 70 | return await ModelFactory.get_model( 71 | model_provider=settings.vector.embedding_provider, 72 | model_name=settings.vector.embedding_model 73 | ) 74 | 75 | @classmethod 76 | async def get_completions_model(cls): 77 | return await ModelFactory.get_model( 78 | model_provider=settings.vector.chat_provider, 79 | model_name=settings.vector.chat_model 80 | ) 81 | 82 | @classmethod 83 | async def get_dimension(cls) -> int: 84 | embedding_model = await cls.get_embedding_model() 85 | vector_data_list = await embedding_model.create_embedding(['get_embedding_dimension']) 86 | dimension = len(vector_data_list[0]) 87 | return dimension 88 | -------------------------------------------------------------------------------- /loader/segmenter/base_segmenter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/12 4 | # @Author : .*? 
5 | # @Email : amashiro2233@gmail.com 6 | # @File : base_segmenter 7 | # @Software: PyCharm 8 | from abc import ABC, abstractmethod 9 | from typing import Any, Callable, Collection, Union, Literal, Optional, Iterator 10 | 11 | from domain.entity.code_segment import CodeSegment 12 | 13 | 14 | class BaseSegmenter(ABC): 15 | def __init__( 16 | self, 17 | max_tokens: int = 512, 18 | length_function: Callable[[str], int] = len, 19 | merge_small_chunks: bool = False, 20 | ): 21 | self.max_tokens = max_tokens 22 | self.length_function = length_function 23 | self.merge_small_chunks = merge_small_chunks 24 | 25 | def is_valid(self) -> bool: 26 | return True 27 | 28 | @abstractmethod 29 | def segment(self) -> Iterator[CodeSegment]: 30 | raise NotImplementedError 31 | 32 | @classmethod 33 | def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any): 34 | """Text splitter that uses HuggingFace tokenizer to count length.""" 35 | try: 36 | from transformers import PreTrainedTokenizerBase 37 | 38 | if not isinstance(tokenizer, PreTrainedTokenizerBase): 39 | raise ValueError( 40 | "Tokenizer received was not an instance of PreTrainedTokenizerBase" 41 | ) 42 | 43 | def _huggingface_tokenizer_length(text: str) -> int: 44 | return len(tokenizer.encode(text)) 45 | 46 | except ImportError: 47 | raise ValueError( 48 | "Could not import transformers python package. " 49 | "Please install it with `pip install transformers`." 50 | ) 51 | return cls(length_function=_huggingface_tokenizer_length, **kwargs) 52 | 53 | @classmethod 54 | def from_tiktoken_encoder( 55 | cls, 56 | encoding_name: str = "cl100k_base", 57 | model_name: Optional[str] = None, 58 | allowed_special=None, 59 | disallowed_special: Union[Literal["all"], Collection[str]] = "all", 60 | **kwargs: Any, 61 | ): 62 | """Text splitter that uses tiktoken encoder to count length.""" 63 | if allowed_special is None: 64 | allowed_special = set() 65 | import tiktoken 66 | if model_name: 67 | enc = tiktoken.encoding_for_model(model_name) 68 | else: 69 | enc = tiktoken.get_encoding(encoding_name) 70 | 71 | def _tiktoken_encoder(text: str) -> int: 72 | return len( 73 | enc.encode( 74 | text, 75 | allowed_special=allowed_special, 76 | disallowed_special=disallowed_special, 77 | ) 78 | ) 79 | 80 | return cls(length_function=_tiktoken_encoder, **kwargs) 81 | -------------------------------------------------------------------------------- /config/loguru.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/07/17 4 | # @Author : .*? 
5 | # @Email : amashiro2233@gmail.com 6 | # @File : loguru 7 | # @Software: PyCharm 8 | import inspect 9 | import logging 10 | import os 11 | import sys 12 | from typing import List, Dict 13 | 14 | from loguru import logger 15 | 16 | from constants.constants import Constants 17 | from setting.setting import get_we0_index_settings 18 | from utils.path_util import PathUtil 19 | 20 | sider_settings = get_we0_index_settings() 21 | 22 | 23 | class InterceptHandler(logging.Handler): 24 | def emit(self, record: logging.LogRecord) -> None: 25 | level: str | int 26 | try: 27 | level = logger.level(record.levelname).name 28 | except ValueError: 29 | level = record.levelno 30 | frame, depth = inspect.currentframe(), 0 31 | while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__): 32 | frame = frame.f_back 33 | depth += 1 34 | 35 | logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) 36 | 37 | 38 | class Log: 39 | file_name = Constants.Common.PROJECT_NAME + '_{time}' + '.log' 40 | log_file_path = os.path.join(Constants.Path.LOG_PATH, file_name) 41 | logger_level = sider_settings.log.level 42 | logger_file = sider_settings.log.file 43 | DEFAULT_CONFIG = [ 44 | { 45 | 'sink': sys.stdout, 46 | 'level': logger_level, 47 | 'format': '[{time:YYYY-MM-DD HH:mm:ss.SSS}][{level}]' 48 | '[{file}:{line}]: {message}', 49 | 'colorize': True, # 自定义配色 50 | 'serialize': False, # 序列化数据打印 51 | 'backtrace': True, # 是否显示完整的异常堆栈跟踪 52 | 'diagnose': True, # 异常跟踪是否显示触发异常的方法或语句所使用的变量,生产环境应设为 False 53 | 'enqueue': False, # 默认线程安全。若想实现协程安全 或 进程安全,该参数设为 True 54 | 'catch': True, # 捕获异常 55 | 56 | } 57 | ] 58 | if logger_file: 59 | DEFAULT_CONFIG.append({ 60 | 'sink': log_file_path, 61 | 'level': logger_level, 62 | 'format': '[{time:YYYY-MM-DD HH:mm:ss.SSS}][{level}][{file}:{line}]: {message}', 63 | 'retention': '7 days', # 日志保留时间 64 | 'serialize': False, # 序列化数据打印 65 | 'backtrace': True, # 是否显示完整的异常堆栈跟踪 66 | 'diagnose': True, # 异常跟踪是否显示触发异常的方法或语句所使用的变量,生产环境应设为 False 67 | 'enqueue': False, # 默认线程安全。若想实现协程安全 或 进程安全,该参数设为 True 68 | 'catch': True, # 捕获异常 69 | }) 70 | 71 | @staticmethod 72 | def start(config: List[Dict] | None = None) -> None: 73 | PathUtil.check_or_make_dirs( 74 | Constants.Path.LOG_PATH 75 | ) 76 | if config: 77 | logger.configure(handlers=config) 78 | else: 79 | logger.configure(handlers=Log.DEFAULT_CONFIG) 80 | if sider_settings.log.debug: 81 | logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True) 82 | logger.enable('__main__') 83 | 84 | @staticmethod 85 | def close() -> None: 86 | logger.disable('__main__') 87 | -------------------------------------------------------------------------------- /extensions/vector/ext_vector.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/14 4 | # @Author : .*? 
5 | # @Email : amashiro2233@gmail.com 6 | # @File : ext_vector 7 | # @Software: PyCharm 8 | from typing import List, Optional 9 | 10 | from loguru import logger 11 | 12 | from domain.enums.vector_type import VectorType 13 | from domain.entity.document import Document, DocumentMeta 14 | from extensions.vector.base_vector import BaseVector 15 | from setting.setting import We0IndexSettings, get_we0_index_settings 16 | 17 | settings: We0IndexSettings = get_we0_index_settings() 18 | 19 | 20 | class Vector: 21 | 22 | def __init__(self): 23 | self.vector_runner: BaseVector | None = None 24 | 25 | async def create(self, documents: List[Document]): 26 | try: 27 | await self.vector_runner.create(documents) 28 | except Exception as e: 29 | logger.error(e) 30 | raise e 31 | 32 | async def upsert(self, documents: List[Document]): 33 | try: 34 | await self.vector_runner.upsert(documents) 35 | except Exception as e: 36 | logger.error(e) 37 | raise e 38 | 39 | async def all_meta(self, repo_id: str) -> List[DocumentMeta]: 40 | try: 41 | return await self.vector_runner.all_meta(repo_id) 42 | except Exception as e: 43 | logger.error(e) 44 | raise e 45 | 46 | async def drop(self, repo_id: str): 47 | try: 48 | await self.vector_runner.drop(repo_id) 49 | except Exception as e: 50 | logger.error(e) 51 | raise e 52 | 53 | async def delete(self, repo_id: str, file_ids: List[str]): 54 | try: 55 | await self.vector_runner.delete(repo_id, file_ids) 56 | except Exception as e: 57 | logger.error(e) 58 | raise e 59 | 60 | async def search_by_vector( 61 | self, repo_id: str, file_ids: Optional[List[str]], query_vector: List[float], top_k: int = 5 62 | ) -> List[Document]: 63 | try: 64 | return await self.vector_runner.search_by_vector(repo_id, file_ids, query_vector, top_k) 65 | except Exception as e: 66 | logger.error(e) 67 | raise e 68 | 69 | async def init_app(self) -> None: 70 | vector_constructor = self.get_vector_factory(settings.vector.platform) 71 | self.vector_runner = vector_constructor() 72 | await self.vector_runner.init() 73 | 74 | @staticmethod 75 | def get_vector_factory(vector_type: VectorType) -> type[BaseVector]: 76 | match vector_type: 77 | case VectorType.PGVECTOR: 78 | from extensions.vector.pgvector import PgVector 79 | 80 | return PgVector 81 | case VectorType.QDRANT: 82 | from extensions.vector.qdrant import Qdrant 83 | 84 | return Qdrant 85 | case VectorType.CHROMA: 86 | from extensions.vector.chroma import Chroma 87 | 88 | return Chroma 89 | case _: 90 | raise ValueError(f"Unknown storage type: {vector_type}") 91 | 92 | def __getattr__(self, item): 93 | if self.vector_runner is None: 94 | raise RuntimeError("Vector clients is not initialized. 
Call init_app first.") 95 | return getattr(self.vector_runner, item) 96 | -------------------------------------------------------------------------------- /README-zh.md: -------------------------------------------------------------------------------- 1 | # We0-index 2 | 3 | 一个基于Python的类似Cursor的代码索引引擎,可将Git仓库代码转化为代码片段并生成语义embedding,根据用户query实现智能代码搜索和检索。 4 | 5 | ## 🚀 功能特性 6 | 7 | - **代码片段生成**:自动处理Git仓库,将代码转换为可搜索的代码片段 8 | - **语义搜索**:生成语义embeddings,基于用户查询实现智能代码检索 9 | - **多语言支持**:针对Python、Java、Go、JavaScript和TypeScript进行优化 10 | - **灵活后端**:支持多种向量数据库后端和embedding服务提供商 11 | - **MCP集成**:内置MCP(模型上下文协议)服务调用支持 12 | - **部署就绪**:为不同环境提供灵活的部署选项 13 | 14 | ## 📋 环境要求 15 | 16 | - Python 3.12+ 17 | - Git 18 | 19 | ## 🛠️ 安装 20 | 21 | ### 快速开始 22 | 23 | ```bash 24 | # 克隆仓库 25 | git clone https://github.com/we0-dev/we0-index 26 | cd we0-index 27 | 28 | # 设置环境配置 29 | cp .env.example .env 30 | vim .env 31 | 32 | # 配置应用设置 33 | vim resource/dev.yaml 34 | 35 | # 创建虚拟环境并安装依赖 36 | uv venv 37 | # Linux/macOS 38 | source .venv/bin/activate 39 | # Windows 40 | .venv\Scripts\activate 41 | uv sync 42 | ``` 43 | 44 | ### 开发环境设置 45 | 46 | ```bash 47 | # 安装开发依赖 48 | uv sync --frozen 49 | ``` 50 | 51 | ## ⚙️ 配置 52 | 53 | 1. **环境变量**:将`.env.example`复制为`.env`并配置您的设置 54 | 2. **应用配置**:编辑`resource/dev.yaml`以自定义您的部署 55 | 3. **向量数据库**:配置您首选的向量数据库后端 56 | 4. **Embedding服务**:设置您的embedding服务提供商 57 | 58 | ## 🚀 启动服务 59 | 60 | We0-index支持两种运行模式:Web API服务和MCP协议服务。 61 | 62 | ### Web API 模式 63 | 64 | 启动FastAPI Web服务器,提供RESTful API接口: 65 | 66 | ```bash 67 | # 激活虚拟环境 68 | # Linux/macOS 69 | source .venv/bin/activate 70 | # Windows 71 | .venv\Scripts\activate 72 | 73 | # 启动Web服务 74 | python main.py --mode fastapi 75 | ``` 76 | 77 | Web服务将在配置的主机和端口上启动(默认配置请查看`resource/dev.yaml`)。 78 | 79 | ### MCP 协议模式 80 | 81 | 启动MCP(模型上下文协议)服务,用于AI集成: 82 | 83 | ```bash 84 | # 激活虚拟环境 85 | # Linux/macOS 86 | source .venv/bin/activate 87 | # Windows 88 | .venv\Scripts\activate 89 | 90 | # 启动MCP服务(默认使用streamable-http传输协议) 91 | python main.py --mode mcp 92 | 93 | # 指定其他传输协议 94 | python main.py --mode mcp --transport stdio 95 | python main.py --mode mcp --transport websocket 96 | ``` 97 | 98 | MCP服务默认使用streamable-http传输协议运行,可与支持MCP的AI客户端集成。 99 | 100 | ### 运行参数 101 | 102 | **模式参数**: 103 | - `--mode fastapi`:启动Web API服务 104 | - `--mode mcp`:启动MCP协议服务 105 | 106 | **传输协议参数**(仅适用于MCP模式): 107 | - `--transport streamable-http`:使用HTTP流传输(默认) 108 | - `--transport stdio`:使用标准输入输出传输 109 | - `--transport websocket`:使用WebSocket传输 110 | 111 | 112 | ## 🏗️ 架构 113 | 114 | We0-index采用模块化架构,支持: 115 | 116 | - **代码解析器**:特定语言的代码解析和片段提取 117 | - **Embedding引擎**:多种embedding服务集成 118 | - **向量存储**:可插拔的向量数据库后端 119 | - **搜索接口**:用于代码搜索的RESTful API和CLI 120 | - **MCP协议**:用于AI集成的模型上下文协议 121 | 122 | ## 🤝 贡献 123 | 124 | 我们欢迎贡献!请查看我们的[贡献指南](CONTRIBUTING.md)了解详情。 125 | 126 | 1. Fork 仓库 127 | 2. 创建您的功能分支 (`git checkout -b feature/amazing-feature`) 128 | 3. 提交您的更改 (`git commit -m 'Add some amazing feature'`) 129 | 4. 推送到分支 (`git push origin feature/amazing-feature`) 130 | 5. 
打开一个Pull Request 131 | 132 | ## 📝 许可证 133 | 134 | 本项目基于MIT许可证 - 查看[LICENSE](LICENSE)文件了解详情。 135 | 136 | 137 | ## 📚 文档 138 | 139 | 详细文档请访问我们的[文档站点](https://docs.we0-dev.com)或查看`docs/`目录。 140 | 141 | ## 🐛 问题反馈 142 | 143 | 如果您遇到任何问题,请在GitHub上[创建issue](https://github.com/we0-dev/we0-index/issues)。 144 | 145 | ## 📞 支持 146 | 147 | - 📧 邮箱:we0@wegc.cn 148 | - 💬 讨论:[GitHub Discussions](https://github.com/we0-dev/we0-index/discussions) 149 | - 📖 Wiki:[项目Wiki](https://deepwiki.com/we0-dev/we0-index) 150 | 151 | ## 🌟 致谢 152 | 153 | - 感谢所有帮助改进这个项目的贡献者 154 | - 灵感来源于Cursor的代码智能方法 155 | - 使用现代Python工具和最佳实践构建 156 | 157 | --- 158 | 159 | **由We0-dev团队用❤️制作** 160 | -------------------------------------------------------------------------------- /utils/vector_helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/23 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : vector_helper 7 | # @Software: PyCharm 8 | import asyncio 9 | import uuid 10 | from typing import List 11 | 12 | from fastapi import APIRouter 13 | from loguru import logger 14 | 15 | from domain.entity.document import Document, DocumentMeta 16 | from domain.entity.task_context import TaskContext 17 | from extensions.ext_manager import ExtManager 18 | from loader.repo_loader import RepoLoader 19 | from models.model_factory import ModelInstance 20 | from prompt.prompt import SystemMessageTemplate 21 | from setting.setting import get_we0_index_settings 22 | from utils.helper import Helper 23 | 24 | vector_router = APIRouter() 25 | 26 | settings = get_we0_index_settings() 27 | 28 | 29 | class VectorHelper: 30 | @staticmethod 31 | async def code2description(document: Document, chat_model: ModelInstance): 32 | document.meta.description = await chat_model.create_completions( 33 | messages=SystemMessageTemplate.ANALYZE_CODE_MESSAGE_TEMPLATE(document.content) 34 | ) 35 | 36 | @staticmethod 37 | async def build_and_embedding_segment(task_context: TaskContext) -> List[Document]: 38 | try: 39 | documents: List[Document] = [ 40 | Document( 41 | content=segment.code, 42 | meta=DocumentMeta( 43 | repo_id=task_context.repo_id, 44 | file_id=task_context.file_id, 45 | segment_id=f"{uuid.uuid4()}", 46 | relative_path=task_context.relative_path, 47 | start_line=segment.start, 48 | end_line=segment.end, 49 | segment_block=segment.block, 50 | segment_hash=Helper.generate_text_hash(segment.code), 51 | segment_cl100k_base_token=Helper.calculate_tokens(segment.code), 52 | segment_o200k_base_token=Helper.calculate_tokens(segment.code, 'o200k_base') 53 | ) 54 | ) async for segment in RepoLoader.load_blob(task_context.blob) 55 | ] 56 | except UnicodeDecodeError as e: 57 | logger.error(e) 58 | documents = [] 59 | if documents: 60 | embedding_model: ModelInstance = await ExtManager.vector.get_embedding_model() 61 | if settings.vector.code2desc: 62 | chat_model: ModelInstance = await ExtManager.vector.get_completions_model() 63 | await asyncio.wait([ 64 | asyncio.create_task(VectorHelper.code2description(document=document, chat_model=chat_model)) 65 | for document in documents 66 | ]) 67 | vector_data: List[List[float]] = await embedding_model.create_embedding( 68 | [ 69 | f"'{document.meta.relative_path}'\n'{document.meta.description}'\n{document.content}" 70 | for document in documents 71 | ] 72 | ) 73 | else: 74 | vector_data: List[List[float]] = await embedding_model.create_embedding( 75 | [ 76 | 
f"'{document.meta.relative_path}'\n{document.content}" for document in documents 77 | ] 78 | ) 79 | for index, document in enumerate(documents): 80 | document.vector = vector_data[index] 81 | return documents 82 | -------------------------------------------------------------------------------- /models/model_factory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/15 4 | # @Author : .*? 5 | # @Email : amashiro2233@gmail.com 6 | # @File : model_factory 7 | # @Software: PyCharm 8 | import asyncio 9 | from typing import List, Tuple, Dict, Iterable 10 | 11 | from openai.types import CreateEmbeddingResponse 12 | from openai.types.chat import ChatCompletionMessageParam, ChatCompletion 13 | 14 | from clients import jina 15 | from domain.enums.model_provider import ModelType 16 | 17 | 18 | class ModelInstance: 19 | 20 | def __init__(self, model_type: ModelType, model_name: str): 21 | self.model_type = model_type 22 | self.model_name = model_name 23 | 24 | def get_completions_client(self): 25 | match self.model_type: 26 | case ModelType.OPENAI: 27 | import openai 28 | return openai.AsyncClient().chat.completions 29 | case _: 30 | raise Exception(f"Unknown model type: {self.model_type}") 31 | 32 | def get_embedding_client(self): 33 | match self.model_type: 34 | case ModelType.OPENAI: 35 | import openai 36 | return openai.AsyncClient().embeddings 37 | case ModelType.JINA: 38 | return jina.AsyncClient().embeddings 39 | case _: 40 | raise Exception(f"Unknown model type: {self.model_type}") 41 | 42 | async def create_embedding(self, documents: List[str]) -> List[List[float]]: 43 | match self.model_type: 44 | case ModelType.OPENAI | ModelType.JINA: 45 | docs_seq = list(documents) 46 | if len(docs_seq) > 2048: 47 | all_embeddings: List[List[float]] = [] 48 | for start in range(0, len(docs_seq), 2048): 49 | batch = docs_seq[start: start + 2048] 50 | resp = await self.get_embedding_client().create( 51 | input=batch, model=self.model_name 52 | ) 53 | all_embeddings.extend(d.embedding for d in resp.data) 54 | return all_embeddings 55 | else: 56 | resp = await self.get_embedding_client().create( 57 | input=docs_seq, model=self.model_name 58 | ) 59 | return [d.embedding for d in resp.data] 60 | case _: 61 | raise Exception(f"Unknown model type: {self.model_type}") 62 | 63 | async def create_completions(self, messages: Iterable[ChatCompletionMessageParam]) -> str: 64 | match self.model_type: 65 | case ModelType.OPENAI: 66 | completions_response: ChatCompletion = await self.get_completions_client().create( 67 | model=self.model_name, messages=messages 68 | ) 69 | return completions_response.choices[0].message.content 70 | case _: 71 | raise Exception(f"Unknown model type: {self.model_type}") 72 | 73 | 74 | class ModelFactory: 75 | _lock = asyncio.Lock() 76 | _instances: Dict[Tuple[ModelType, str], ModelInstance] = {} 77 | 78 | @classmethod 79 | async def get_model(cls, model_provider: ModelType, model_name: str) -> ModelInstance: 80 | key = (model_provider, model_name) 81 | if key not in ModelFactory._instances: 82 | async with cls._lock: 83 | if key not in ModelFactory._instances: 84 | instance = ModelInstance(model_provider, model_name) 85 | cls._instances[key] = instance 86 | return cls._instances[key] 87 | -------------------------------------------------------------------------------- /domain/entity/blob.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time : 2024/10/14
# @Author : .*?
# @Email : amashiro2233@gmail.com
# @File : blob
# @Software: PyCharm
import asyncio
from contextlib import asynccontextmanager, contextmanager
from io import BytesIO
from pathlib import PurePath
from typing import Dict, Any, Generator, BinaryIO

import aiofiles
from pydantic import Field, model_validator, ConfigDict, BaseModel


class Blob(BaseModel):
    id: str | None = None
    filename: str | None = None

    meta: Dict[str, Any] = Field(default_factory=dict)

    data: bytes | str | None = None
    mimetype: str | None = None
    extension: str | None = None
    encoding: str = "utf-8"
    path: str | PurePath | None = None

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        frozen=True,
    )

    @model_validator(mode="before")
    @classmethod
    def check_blob_is_valid(cls, values: dict[str, Any]) -> Any:
        if "data" not in values and "path" not in values:
            msg = "Either data or path must be provided"
            raise ValueError(msg)
        return values

    async def as_string(self) -> str:
        """Read data as a string."""
        if self.data is None and self.path:
            async with aiofiles.open(str(self.path), mode='r', encoding=self.encoding) as f:
                return await f.read()
        elif isinstance(self.data, bytes):
            return self.data.decode(self.encoding)
        elif isinstance(self.data, str):
            return self.data
        else:
            msg = f"Unable to get string for blob {self}"
            raise ValueError(msg)

    def as_bytes(self) -> bytes:
        """Read data as bytes."""
        if isinstance(self.data, bytes):
            return self.data
        elif isinstance(self.data, str):
            return self.data.encode(self.encoding)
        elif self.data is None and self.path:
            with open(str(self.path), "rb") as f:
                return f.read()
        else:
            msg = f"Unable to get bytes for blob {self}"
            raise ValueError(msg)

    @contextmanager
    def as_bytes_io(self) -> Generator[BytesIO | BinaryIO, None, None]:
        """Read data as a byte stream."""
        if isinstance(self.data, bytes):
            yield BytesIO(self.data)
        elif self.data is None and self.path:
            with open(str(self.path), "rb") as f:
                yield f
        else:
            msg = f"Unable to convert blob {self}"
            raise NotImplementedError(msg)

    @asynccontextmanager
    async def as_async_bytes_io(self):
        if isinstance(self.data, bytes):
            reader = asyncio.StreamReader()
            reader.feed_data(self.data)
            reader.feed_eof()
            yield reader
        elif self.data is None and self.path:
            async with aiofiles.open(str(self.path), 'rb') as f:
                yield f
        else:
            msg = f"Unable to convert blob {self}"
            raise NotImplementedError(msg)

    async def write_to_file(self, file, chunks: int = 8192):
        async with self.as_async_bytes_io() as reader:
            if isinstance(reader, asyncio.StreamReader):
                while chunk := await reader.read(chunks):
                    await file.write(chunk)
            else:
                await file.write(await reader.read())

    @classmethod
    def from_path(
            cls,
            path: str | PurePath,
            *,
            encoding: str = "utf-8",
            mimetype: str | None = None,
            extension: str | None = None,
            meta: Dict[str, Any] | None = None,
    ):
        return cls(
            data=None,
            mimetype=mimetype,
            extension=extension,
            encoding=encoding,
path=path, 119 | meta=meta if meta is not None else {} 120 | ) 121 | 122 | @classmethod 123 | def from_data( 124 | cls, 125 | data: str | bytes, 126 | *, 127 | encoding: str = "utf-8", 128 | mimetype: str | None = None, 129 | extension: str | None = None, 130 | path: str | None = None, 131 | meta: dict | None = None, 132 | ): 133 | return cls( 134 | data=data, 135 | mimetype=mimetype, 136 | extension=extension, 137 | encoding=encoding, 138 | path=path, 139 | meta=meta if meta is not None else {} 140 | ) 141 | 142 | def __repr__(self) -> str: 143 | str_repr = f"Blob {id(self)}" 144 | return str_repr 145 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | #poetry.toml 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/

# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
--------------------------------------------------------------------------------
/setting/setting.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time : 2024/07/17
# @Author : .*?
5 | # @Email : amashiro2233@gmail.com 6 | # @File : setting 7 | # @Software: PyCharm 8 | import os.path 9 | from functools import lru_cache 10 | from typing import Type 11 | 12 | from pydantic import BaseModel, Field, model_validator 13 | from pydantic_settings import BaseSettings, SettingsConfigDict, PydanticBaseSettingsSource, YamlConfigSettingsSource 14 | 15 | from constants.constants import Constants 16 | from domain.enums.chroma_mode import ChromaMode 17 | from domain.enums.model_provider import ModelType 18 | from domain.enums.qdrant_mode import QdrantMode 19 | from domain.enums.vector_type import VectorType 20 | 21 | 22 | class ServerSettings(BaseModel): 23 | host: str = Field('0.0.0.0') 24 | port: int = Field(8080) 25 | reload: bool = Field(True) 26 | 27 | 28 | class LogSettings(BaseModel): 29 | level: str = Field(default="INFO") 30 | file: bool = Field(default=False) 31 | debug: bool = Field(default=False) 32 | 33 | 34 | class PGVectorSettings(BaseSettings): 35 | db: str 36 | host: str 37 | port: int 38 | user: str 39 | password: str 40 | 41 | 42 | class QdrantDiskSettings(BaseSettings): 43 | path: str = Field(default=Constants.Path.QDRANT_DEFAULT_DISK_PATH) 44 | 45 | @model_validator(mode='before') 46 | def handle_path(self): 47 | if not os.path.isabs(self['path']): 48 | self['path'] = os.path.join(Constants.Path.ROOT_PATH, self['path']) 49 | return self 50 | 51 | 52 | class QdrantRemoteSettings(BaseSettings): 53 | host: str 54 | port: int = Field(default=6333) 55 | 56 | 57 | class QdrantSettings(BaseSettings): 58 | mode: QdrantMode 59 | disk: QdrantDiskSettings | None 60 | remote: QdrantRemoteSettings | None 61 | memory: None 62 | 63 | @model_validator(mode='before') 64 | def clear_conflicting_settings(self): 65 | for key in [member for member in QdrantMode if member != self['mode']]: 66 | self[key] = None 67 | return self 68 | 69 | 70 | class ChromaDiskSettings(BaseSettings): 71 | path: str = Field(default=Constants.Path.QDRANT_DEFAULT_DISK_PATH) 72 | 73 | @model_validator(mode='before') 74 | def handle_path(self): 75 | if not os.path.isabs(self['path']): 76 | self['path'] = os.path.join(Constants.Path.ROOT_PATH, self['path']) 77 | return self 78 | 79 | 80 | class ChromaRemoteSettings(BaseSettings): 81 | host: str 82 | port: int = Field(default=6333) 83 | ssl: bool = Field(default=False) 84 | 85 | 86 | class ChromaSettings(BaseSettings): 87 | mode: ChromaMode 88 | disk: ChromaDiskSettings | None 89 | remote: ChromaRemoteSettings | None 90 | memory: None 91 | 92 | @model_validator(mode='before') 93 | def clear_conflicting_settings(self): 94 | for key in [member for member in ChromaMode if member != self['mode']]: 95 | self[key] = None 96 | return self 97 | 98 | 99 | class VectorSettings(BaseSettings): 100 | platform: VectorType 101 | code2desc: bool = Field(default=False) 102 | chat_provider: ModelType = Field(default='openai', alias='chat-provider') 103 | chat_model: str = Field(default='gpt-4o-mini', alias='chat-model') 104 | embedding_provider: ModelType = Field(default='openai', alias='embedding-provider') 105 | embedding_model: str = Field(default='text-embedding-3-small', alias='embedding-model') 106 | pgvector: PGVectorSettings | None 107 | qdrant: QdrantSettings | None 108 | chroma: ChromaSettings | None 109 | 110 | @model_validator(mode='before') 111 | def clear_conflicting_settings(self): 112 | for key in [member for member in VectorType if member != self['platform']]: 113 | self[key] = None 114 | return self 115 | 116 | 117 | class We0IndexSettings(BaseModel): 118 | 
application: str 119 | log: LogSettings 120 | server: ServerSettings 121 | vector: VectorSettings 122 | 123 | 124 | class AppSettings(BaseSettings): 125 | we0_index: We0IndexSettings | None = Field(default=None, alias='we0-index') 126 | 127 | model_config = SettingsConfigDict( 128 | yaml_file=Constants.Path.YAML_FILE_PATH, 129 | yaml_file_encoding='utf-8', 130 | extra='ignore' 131 | ) 132 | 133 | @classmethod 134 | def settings_customise_sources( 135 | cls, 136 | settings_cls: Type[BaseSettings], 137 | init_settings: PydanticBaseSettingsSource, 138 | env_settings: PydanticBaseSettingsSource, 139 | dotenv_settings: PydanticBaseSettingsSource, 140 | file_secret_settings: PydanticBaseSettingsSource, 141 | ) -> tuple[PydanticBaseSettingsSource, ...]: 142 | return ( 143 | init_settings, 144 | YamlConfigSettingsSource(settings_cls), 145 | env_settings, 146 | dotenv_settings, 147 | file_secret_settings, 148 | ) 149 | 150 | 151 | @lru_cache 152 | def get_we0_index_settings() -> We0IndexSettings: 153 | app_settings = AppSettings() 154 | return app_settings.we0_index 155 | 156 | 157 | if __name__ == '__main__': 158 | settings = get_we0_index_settings() 159 | print(settings.model_dump_json(indent=4)) 160 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # We0-index 2 | 3 | [English](README.md) | [中文](README-zh.md) 4 | 5 | A Python-based code indexing engine similar to Cursor, designed to transform Git repository code into code snippets with semantic embeddings for intelligent code search and retrieval. 6 | 7 | ## 🚀 Features 8 | 9 | - **Code Snippet Generation**: Automatically processes Git repositories and converts code into searchable snippets 10 | - **Semantic Search**: Generates semantic embeddings for intelligent code retrieval based on user queries 11 | - **Multi-Language Support**: Optimized for Python, Java, Go, JavaScript, and TypeScript 12 | - **Flexible Backends**: Supports multiple vector database backends and embedding service providers 13 | - **MCP Integration**: Built-in support for MCP (Model Context Protocol) service calls 14 | - **Deployment Ready**: Flexible deployment options for different environments 15 | 16 | ## 📋 Requirements 17 | 18 | - Python 3.12+ 19 | - Git 20 | 21 | ## 🛠️ Installation 22 | 23 | ### Quick Start 24 | 25 | ```bash 26 | # Clone the repository 27 | git clone https://github.com/we0-dev/we0-index 28 | cd we0-index 29 | 30 | # Set up environment configuration 31 | cp .env.example .env 32 | vim .env 33 | 34 | # Configure application settings 35 | vim resource/dev.yaml 36 | 37 | # Create virtual environment and install dependencies 38 | uv venv 39 | # Linux/macOS 40 | source .venv/bin/activate 41 | # Windows 42 | .venv\Scripts\activate 43 | uv sync 44 | ``` 45 | 46 | ### Development Setup 47 | 48 | ```bash 49 | # Install development dependencies 50 | uv sync --frozen 51 | ``` 52 | 53 | ## ⚙️ Configuration 54 | 55 | 1. **Environment Variables**: Copy `.env.example` to `.env` and configure your settings 56 | 2. **Application Config**: Edit `resource/dev.yaml` to customize your deployment 57 | 3. **Vector Database**: Configure your preferred vector database backend 58 | 4. **Embedding Service**: Set up your embedding service provider 59 | 60 | ## 🚀 Running the Service 61 | 62 | We0-index supports two running modes: Web API service and MCP protocol service. 
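Both modes read the same YAML configuration. For reference, a minimal `resource/dev.yaml` might look like the sketch below — the keys are inferred from `setting/setting.py` (the hyphenated names such as `embedding-model` come from its field aliases), and the exact literals for `platform` and `mode` are assumptions to verify against `domain/enums/`:

```yaml
we0-index:
  application: we0-index          # assumed value
  log:
    level: INFO
    file: false
    debug: false
  server:
    host: 0.0.0.0
    port: 8080
    reload: true
  vector:
    platform: qdrant              # assumed literal; see domain/enums/vector_type.py
    code2desc: false
    chat-provider: openai
    chat-model: gpt-4o-mini
    embedding-provider: openai
    embedding-model: text-embedding-3-small
    qdrant:
      mode: disk                  # assumed literal; see domain/enums/qdrant_mode.py
      disk:
        path: storage/qdrant      # relative paths are resolved against the project root
```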
63 | 64 | ### Web API Mode 65 | 66 | Start the FastAPI web server to provide RESTful API endpoints: 67 | 68 | ```bash 69 | # Activate virtual environment 70 | # Linux/macOS 71 | source .venv/bin/activate 72 | # Windows 73 | .venv\Scripts\activate 74 | 75 | # Start web service 76 | python main.py --mode fastapi 77 | ``` 78 | 79 | The web service will start on the configured host and port (check `resource/dev.yaml` for default configuration). 80 | 81 | ### MCP Protocol Mode 82 | 83 | Start the MCP (Model Context Protocol) service for AI integration: 84 | 85 | ```bash 86 | # Activate virtual environment 87 | # Linux/macOS 88 | source .venv/bin/activate 89 | # Windows 90 | .venv\Scripts\activate 91 | 92 | # Start MCP service (default with streamable-http transport) 93 | python main.py --mode mcp 94 | 95 | # Specify other transport protocols 96 | python main.py --mode mcp --transport stdio 97 | python main.py --mode mcp --transport sse 98 | ``` 99 | 100 | The MCP service runs with streamable-http transport by default and can be integrated with MCP-compatible AI clients. 101 | 102 | ### Runtime Parameters 103 | 104 | **Mode Parameters**: 105 | - `--mode fastapi`: Start Web API service 106 | - `--mode mcp`: Start MCP protocol service 107 | 108 | **Transport Parameters** (only applicable for MCP mode): 109 | - `--transport streamable-http`: Use HTTP streaming transport (default) 110 | - `--transport stdio`: Use standard input/output transport 111 | - `--transport sse`: Use sse transport 112 | 113 | 114 | 115 | ## 🏗️ Architecture 116 | 117 | We0-index is built with a modular architecture supporting: 118 | 119 | - **Code Parsers**: Language-specific code parsing and snippet extraction 120 | - **Embedding Engines**: Multiple embedding service integrations 121 | - **Vector Stores**: Pluggable vector database backends 122 | - **Search Interface**: RESTful API and CLI for code search 123 | - **MCP Protocol**: Model Context Protocol for AI integration 124 | 125 | ## 🤝 Contributing 126 | 127 | We welcome contributions! Please see our [Contributing Guidelines](CONTRIBUTING.md) for details. 128 | 129 | 1. Fork the repository 130 | 2. Create your feature branch (`git checkout -b feature/amazing-feature`) 131 | 3. Commit your changes (`git commit -m 'Add some amazing feature'`) 132 | 4. Push to the branch (`git push origin feature/amazing-feature`) 133 | 5. Open a Pull Request 134 | 135 | ## 📝 License 136 | 137 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 138 | 139 | 140 | 141 | ## 📚 Documentation 142 | 143 | For detailed documentation, please visit our [Documentation Site](https://docs.we0-dev.com) or check the `docs/` directory. 144 | 145 | ## 🐛 Issues 146 | 147 | If you encounter any issues, please [create an issue](https://github.com/we0-dev/we0-index/issues) on GitHub. 
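## 🔎 Example: Querying the Index

Once the Web API is up, retrieval is a single POST. A minimal sketch (route prefix from `launch/launch.py`, request fields inferred from `router/vector_router.py`, host/port assuming the defaults in `setting/setting.py`):

```python
import requests

resp = requests.post(
    "http://localhost:8080/vector/retrieval",
    json={
        "repo_id": "<repo id returned by /vector/upsert_index>",
        "file_ids": None,  # or a list of file ids to narrow the search
        "query": "where are embeddings created?",
    },
)
print(resp.json())  # Result envelope wrapping a list of DocumentMeta matches
```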
## 📞 Support

- 📧 Email: we0@wegc.cn
- 💬 Discussions: [GitHub Discussions](https://github.com/we0-dev/we0-index/discussions)
- 📖 Wiki: [Project Wiki](https://deepwiki.com/we0-dev/we0-index)

## 🌟 Acknowledgments

- Thanks to all contributors who have helped make this project better
- Inspired by Cursor's approach to code intelligence
- Built with modern Python tooling and best practices

---

**Made with ❤️ by the We0-dev Team**
--------------------------------------------------------------------------------
/router/git_router.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import os
from typing import List, Optional
from urllib.parse import quote

import aiofiles
from fastapi import APIRouter
from git import Repo
from loguru import logger

from domain.entity.blob import Blob
from domain.entity.document import Document
from domain.entity.task_context import TaskContext
from domain.request.git_index_request import GitIndexRequest
from domain.response.add_index_response import AddIndexResponse, FileInfoResponse
from domain.result.result import Result
from extensions.ext_manager import ExtManager
from setting.setting import get_we0_index_settings
from utils.git_parse import parse_git_url
from utils.helper import Helper
from utils.mimetype_util import guess_mimetype_and_extension
from utils.vector_helper import VectorHelper
from asyncio import Semaphore

git_router = APIRouter()
settings = get_we0_index_settings()
async_semaphore = Semaphore(100)


async def _process_file(uid: str, repo_id: str, repo_path: str, base_dir, relative_path: str) -> FileInfoResponse:
    logger.info(f'Processing file {relative_path}')
    mimetype, extension = guess_mimetype_and_extension(relative_path)
    file_id = Helper.generate_fixed_uuid(f"{uid}:{repo_path}:{relative_path}")
    file_path = os.path.join(base_dir, relative_path)
    async with aiofiles.open(file_path, 'rb') as f:
        content = await f.read()

    task_context = TaskContext(
        repo_id=repo_id,
        file_id=file_id,
        relative_path=relative_path,
        blob=Blob.from_data(
            data=content,
            mimetype=mimetype,
            extension=extension
        )
    )

    try:
        async with async_semaphore:
            documents: List[Document] = await VectorHelper.build_and_embedding_segment(task_context)
            if documents:
                await ExtManager.vector.upsert(documents)
    except Exception as e:
        raise e

    return FileInfoResponse(
        file_id=file_id,
        relative_path=relative_path
    )


def _prepare_repo_url_with_auth(repo_url: str, username: Optional[str] = None, password: Optional[str] = None,
                                access_token: Optional[str] = None) -> str:
    """
    Build a URL with embedded credentials for private repositories.
    The username and password are URL-encoded so that special characters (such as @) are handled.
    Three authentication modes are supported:
    1. access_token: use a personal access token (recommended)
    2. username + password: use plain credentials
    3. no authentication: public repositories
    """
    if repo_url.startswith('https://'):
        if access_token:
            # Authenticate with the access token.
            # Most Git providers (GitHub, GitLab, ...) accept the token as the
            # username, with an empty password or the token itself.
            encoded_token = quote(access_token, safe='')
            url_without_protocol = repo_url[8:]
            # Use the token as the username and 'x-oauth-basic' (GitHub) or an empty string as the password
            return f'https://{encoded_token}:x-oauth-basic@{url_without_protocol}'
        elif username and password:
            # URL-encode the username and password to handle special characters
            encoded_username = quote(username, safe='')
            encoded_password = quote(password, safe='')
            # Use the encoded username and password
            url_without_protocol = repo_url[8:]
            return f'https://{encoded_username}:{encoded_password}@{url_without_protocol}'
    return repo_url


async def clone_and_index(git_index_request: GitIndexRequest) -> Result[AddIndexResponse]:
    """
    Tool parameters must be in standard JSON format!
    "git_index_request": {
        "xxx": "xxx"
    }
    Clone a Git repository and index every file in it.
    Private repositories are supported:
    - Access token auth (recommended): pass access_token
    - SSH key auth: use a git@github.com:user/repo.git style URL, optionally with ssh_key_path
    - HTTPS with username/password: pass username and password
    """
    try:
        if not git_index_request.uid:
            git_index_request.uid = 'default_uid'

        domain, owner, repo = parse_git_url(git_index_request.repo_url)
        repo_abs_path = f'{domain}/{owner}/{repo}'
        repo_id = Helper.generate_fixed_uuid(f"{git_index_request.uid}:{repo_abs_path}")

        # Build the authenticated repository URL
        auth_repo_url = _prepare_repo_url_with_auth(
            git_index_request.repo_url,
            git_index_request.username,
            git_index_request.password,
            git_index_request.access_token
        )

        file_count = 0  # Number of successfully processed files

        async with aiofiles.tempfile.TemporaryDirectory() as tmp_dir:
            try:
                await asyncio.to_thread(
                    Repo.clone_from,
                    auth_repo_url,
                    tmp_dir
                )
            except Exception as e:
                logger.error(f'{type(e).__name__}: {e}')
                raise e
            # Walk the tree and process every file
            tasks = []
            for root, dirs, files in os.walk(tmp_dir):
                # Skip hidden directories such as .git
                dirs[:] = [d for d in dirs if not d.startswith('.')]
                for file in files:
                    if not file.startswith('.'):
                        relative_path = os.path.relpath(os.path.join(root, file), start=tmp_dir)
                        tasks.append(asyncio.create_task(_process_file(
                            uid=git_index_request.uid,
                            repo_id=repo_id,
                            repo_path=repo_abs_path,
                            base_dir=tmp_dir,
                            relative_path=relative_path
                        )))
            for task in asyncio.as_completed(tasks):
                await task
                file_count += 1

        logger.info(f"Successfully processed {file_count} files from repository {repo_abs_path}")

        return Result.ok(data=AddIndexResponse(
            repo_id=repo_id,
            file_infos=[]
        ))
    except Exception as e:
        logger.exception(e)
        return Result.failed(message=f"{type(e).__name__}: {e}")


git_router.add_api_route('/clone_and_index', clone_and_index, methods=['POST'], response_model=Result[AddIndexResponse])
--------------------------------------------------------------------------------
/extensions/vector/chroma.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time : 2025/2/23
# @Author : .*?
# @Email : amashiro2233@gmail.com
# @File : chroma
# @Software: PyCharm
import asyncio
import inspect
from typing import List, Optional

import chromadb
from chromadb import QueryResult

from domain.enums.chroma_mode import ChromaMode
from domain.entity.document import Document, DocumentMeta
from extensions.vector.base_vector import BaseVector
from setting.setting import get_we0_index_settings

settings = get_we0_index_settings()


class Chroma(BaseVector):

    def __init__(self):
        self.client = None
        self.collection_name: str | None = None

    @staticmethod
    def get_client():
        return None

    async def init(self):
        if self.client is None:
            chroma = settings.vector.chroma
            match chroma.mode:
                case ChromaMode.MEMORY:
                    self.client = chromadb.Client()
                case ChromaMode.DISK:
                    self.client = chromadb.PersistentClient(path=chroma.disk.path)
                case ChromaMode.REMOTE:
                    self.client = await chromadb.AsyncHttpClient(
                        host=chroma.remote.host, port=chroma.remote.port, ssl=chroma.remote.ssl
                    )
                case _:
                    raise ValueError(f'Unknown chroma mode: {chroma.mode}')
        dimension = await self.get_dimension()
        self.collection_name = self.dynamic_collection_name(dimension)
        await self._execute_async_or_thread(func=self.client.get_or_create_collection, name=self.collection_name)

    async def create(self, documents: List[Document]):
        collection = await self._execute_async_or_thread(
            func=self.client.get_or_create_collection,
            name=self.collection_name
        )
        ids = [document.meta.segment_id for document in documents]
        vectors = [document.vector for document in documents]
        metas = [document.meta.model_dump(exclude_none=True) for document in documents]
        contents = [document.content for document in documents]
        await self._execute_async_or_thread(
            func=collection.upsert, ids=ids, embeddings=vectors, metadatas=metas, documents=contents
        )

    async def upsert(self, documents: List[Document]):
        repo_id = documents[0].meta.repo_id
        file_ids = list(set(document.meta.file_id for document in documents))
        await self.delete(repo_id, file_ids)
        await self.create(documents)

    async def all_meta(self, repo_id: str) -> List[DocumentMeta]:
        collection = await self._execute_async_or_thread(
            func=self.client.get_or_create_collection,
            name=self.collection_name
        )
        results = await self._execute_async_or_thread(
            func=collection.get, where={
                'repo_id': {
                    '$eq': repo_id
                }
            }
        )
        # collection.get returns flat lists (unlike query, which nests one list
        # per query vector), so validate each metadata dict directly
        metadatas = results.get('metadatas') or []
        return [DocumentMeta.model_validate(meta) for meta in metadatas]

    async def drop(self, repo_id: str):
        collection = await self._execute_async_or_thread(
            func=self.client.get_or_create_collection,
            name=self.collection_name
        )
        await self._execute_async_or_thread(
            func=collection.delete, where={
                'repo_id': {
                    "$eq": repo_id
                }
            }
        )

    async def delete(self, repo_id: str, file_ids: List[str]):
        collection = await self._execute_async_or_thread(
            func=self.client.get_or_create_collection,
            name=self.collection_name
        )
        await self._execute_async_or_thread(collection.delete, where={
            "$and": [
                {"repo_id": {"$eq": repo_id}},
{"file_id": {"$in": file_ids}} 110 | ] 111 | }) 112 | 113 | async def search_by_vector( 114 | self, 115 | repo_id: str, 116 | file_ids: Optional[List[str]], 117 | query_vector: List[float], 118 | top_k: int = 5, 119 | score_threshold: float = 0.0 120 | ) -> List[Document]: 121 | collection = await self._execute_async_or_thread( 122 | func=self.client.get_or_create_collection, 123 | name=self.collection_name 124 | ) 125 | 126 | if not file_ids: 127 | where = { 128 | 'repo_id': { 129 | '$eq': repo_id 130 | } 131 | } 132 | else: 133 | where = { 134 | "$and": [ 135 | {"repo_id": {"$eq": repo_id}}, 136 | {"file_id": {"$in": file_ids}} 137 | ] 138 | } 139 | results: QueryResult = await self._execute_async_or_thread( 140 | func=collection.query, query_embeddings=query_vector, n_results=top_k, where=where 141 | ) 142 | ids = results.get('ids', []) 143 | if len(ids) == 0: 144 | return [] 145 | idx = ids[0] 146 | metas = results.get('metadatas', [])[0] 147 | distances = results.get('distances', [])[0] 148 | contents = results.get('documents', [])[0] 149 | 150 | documents: List[Document] = [] 151 | for index in range(len(idx)): 152 | distance = distances[index] 153 | metadata = dict(metas[index]) 154 | if distance >= score_threshold: 155 | metadata["score"] = distance 156 | metadata["content"] = contents[index] 157 | document = Document( 158 | meta=DocumentMeta.model_validate(metadata), 159 | ) 160 | documents.append(document) 161 | 162 | return sorted(documents, key=lambda x: x.meta.score) 163 | 164 | @staticmethod 165 | async def _execute_async_or_thread(func, *args, **kwargs): 166 | if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func): 167 | return await func(*args, **kwargs) 168 | else: 169 | return await asyncio.to_thread(func, *args, **kwargs) 170 | -------------------------------------------------------------------------------- /router/vector_router.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2024/10/11 4 | # @Author : .*? 
5 | # @Email : amashiro2233@gmail.com 6 | # @File : vector_router 7 | # @Software: PyCharm 8 | import asyncio 9 | from typing import Annotated 10 | from typing import List 11 | 12 | from fastapi import APIRouter, UploadFile 13 | from fastapi.params import File, Form 14 | 15 | from domain.entity.blob import Blob 16 | from domain.entity.document import DocumentMeta, Document 17 | from domain.entity.task_context import TaskContext 18 | from domain.request.add_index_request import AddIndexRequest, AddFileInfo 19 | from domain.request.all_index_request import AllIndexRequest 20 | from domain.request.delete_index_request import DeleteIndexRequest 21 | from domain.request.drop_index_request import DropIndexRequest 22 | from domain.request.retrieval_request import RetrievalRequest 23 | from domain.response.add_index_by_file_response import AddIndexByFileResponse 24 | from domain.response.add_index_response import AddIndexResponse, FileInfoResponse 25 | from domain.result.result import Result 26 | from extensions.ext_manager import ExtManager 27 | from models.model_factory import ModelInstance 28 | from setting.setting import get_we0_index_settings 29 | from utils.helper import Helper 30 | from utils.mimetype_util import guess_mimetype_and_extension 31 | from utils.vector_helper import VectorHelper 32 | 33 | vector_router = APIRouter() 34 | 35 | settings = get_we0_index_settings() 36 | 37 | 38 | async def _upsert_index(uid: str, repo_abs_path: str, repo_id: str, file_info: AddFileInfo) -> FileInfoResponse: 39 | mimetype, extension = guess_mimetype_and_extension(file_info.relative_path) 40 | file_id = Helper.generate_fixed_uuid( 41 | f"{uid}:{repo_abs_path}:{file_info.relative_path}" 42 | ) 43 | 44 | task_context = TaskContext( 45 | repo_id=repo_id, 46 | file_id=file_id, 47 | relative_path=file_info.relative_path, 48 | blob=Blob.from_data( 49 | data=file_info.content, 50 | mimetype=mimetype, 51 | extension=extension 52 | ) 53 | ) 54 | try: 55 | documents: List[Document] = await VectorHelper.build_and_embedding_segment(task_context) 56 | if documents: 57 | await ExtManager.vector.upsert(documents) 58 | except Exception as e: 59 | raise e 60 | return FileInfoResponse( 61 | file_id=file_id, 62 | relative_path=file_info.relative_path 63 | ) 64 | 65 | 66 | @vector_router.post("/upsert_index", response_model=Result[AddIndexResponse]) 67 | async def upsert_index(add_index_request: AddIndexRequest): 68 | """ 69 | 新增或追加索引 70 | 请务必使用JSON.stringify对文本进行转义,确保格式正确,否则AST结构将被破坏\n 71 | `fs.readFile(filePath, 'utf8', (err, fileContent) => {const jsonStr = JSON.stringify(data);}`\n 72 | """ 73 | repo_id = Helper.generate_fixed_uuid(f"{add_index_request.uid}:{add_index_request.repo_abs_path}") 74 | 75 | tasks = [ 76 | asyncio.create_task(_upsert_index( 77 | uid=add_index_request.uid, 78 | repo_abs_path=add_index_request.repo_abs_path, 79 | repo_id=repo_id, 80 | file_info=file_info 81 | )) for file_info in add_index_request.file_infos 82 | ] 83 | file_infos = [await task for task in asyncio.as_completed(tasks)] 84 | return Result.ok(data=AddIndexResponse(repo_id=repo_id, file_infos=file_infos)) 85 | 86 | 87 | @vector_router.post('/upsert_index_by_file', response_model=Result[AddIndexByFileResponse]) 88 | async def upsert_index_by_file( 89 | uid: Annotated[str, Form(description='Unique ID')], 90 | repo_abs_path: Annotated[str, Form(description='Repository Absolute Path')], 91 | relative_path: Annotated[str, Form(description='File Relative Path')], 92 | file: UploadFile = File(...), 93 | ): 94 | """ 95 | 新增或追加索引(通过文件) 96 
| """ 97 | repo_id = Helper.generate_fixed_uuid(f"{uid}:{repo_abs_path}") 98 | mimetype, extension = guess_mimetype_and_extension(relative_path) 99 | 100 | file_id = Helper.generate_fixed_uuid(f"{uid}:{repo_abs_path}:{relative_path}") 101 | 102 | task_context = TaskContext( 103 | repo_id=repo_id, 104 | file_id=file_id, 105 | relative_path=relative_path, 106 | blob=Blob.from_data( 107 | data=await file.read(), 108 | mimetype=mimetype, 109 | extension=extension 110 | ) 111 | ) 112 | try: 113 | documents: List[Document] = await VectorHelper.build_and_embedding_segment(task_context) 114 | if documents: 115 | await ExtManager.vector.upsert(documents) 116 | return Result.ok(data=AddIndexByFileResponse(repo_id=repo_id, file_id=file_id)) 117 | else: 118 | return Result.failed(message=f"Not Content") 119 | except Exception as e: 120 | return Result.failed(message=f"{type(e).__name__}: {e}") 121 | 122 | 123 | @vector_router.post('/drop_index', response_model=Result) 124 | async def drop_index(drop_index_request: DropIndexRequest): 125 | """ 126 | 删除索引的全部向量 127 | """ 128 | try: 129 | await ExtManager.vector.drop(repo_id=drop_index_request.repo_id) 130 | return Result.ok() 131 | except Exception as e: 132 | return Result.failed(message=f"{type(e).__name__}: {e}") 133 | 134 | 135 | @vector_router.post('/delete_index', response_model=Result) 136 | async def delete_index(delete_index_request: DeleteIndexRequest): 137 | """ 138 | 删除索引的指定向量 139 | """ 140 | try: 141 | await ExtManager.vector.delete(repo_id=delete_index_request.repo_id, file_ids=delete_index_request.file_ids) 142 | return Result.ok() 143 | except Exception as e: 144 | return Result.failed(message=f"{type(e).__name__}: {e}") 145 | 146 | 147 | @vector_router.post('/all_index', response_model=Result) 148 | async def all_index(all_index_request: AllIndexRequest): 149 | try: 150 | all_meta = await ExtManager.vector.all_meta(repo_id=all_index_request.repo_id) 151 | return Result.ok(data=all_meta) 152 | except Exception as e: 153 | return Result.failed(message=f"{type(e).__name__}: {e}") 154 | 155 | 156 | async def retrieval( 157 | retrieval_request: RetrievalRequest 158 | ) -> Result[List[DocumentMeta]]: 159 | """ 160 | Tool parameters must be in standard JSON format! 161 | "retrieval_request": { 162 | "xxx": "xxx" 163 | } 164 | 相似度匹配,从整个仓库或指定仓库的部分文件 165 | """ 166 | try: 167 | embedding_model: ModelInstance = await ExtManager.vector.get_embedding_model() 168 | vector_data = await embedding_model.create_embedding([retrieval_request.query]) 169 | documents = await ExtManager.vector.search_by_vector( 170 | repo_id=retrieval_request.repo_id, file_ids=retrieval_request.file_ids, query_vector=vector_data[0] 171 | ) 172 | retrieval_segment_list = [document.meta for document in documents] 173 | return Result.ok(data=retrieval_segment_list) 174 | except Exception as e: 175 | return Result.failed(message=f"{type(e).__name__}: {e}") 176 | 177 | 178 | vector_router.add_api_route('/retrieval', retrieval, methods=['POST'], response_model=Result[List[DocumentMeta]]) 179 | -------------------------------------------------------------------------------- /extensions/vector/qdrant.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2025/2/22 4 | # @Author : .*? 
--------------------------------------------------------------------------------
/extensions/vector/qdrant.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2025/2/22
# @Author  : .*?
# @Email   : amashiro2233@gmail.com
# @File    : qdrant
# @Software: PyCharm

from typing import List, Optional

from qdrant_client.async_qdrant_client import AsyncQdrantClient
from qdrant_client.http import models as rest
from qdrant_client.http.exceptions import UnexpectedResponse

from domain.enums.qdrant_mode import QdrantMode
from domain.entity.document import Document, DocumentMeta
from extensions.vector.base_vector import BaseVector
from setting.setting import get_we0_index_settings

settings = get_we0_index_settings()


class Qdrant(BaseVector):

    def __init__(self):
        self.client = self.get_client()
        self.collection_name: str | None = None

    @staticmethod
    def get_client():
        qdrant = settings.vector.qdrant
        match qdrant.mode:
            case QdrantMode.MEMORY:
                return AsyncQdrantClient(location=':memory:')
            case QdrantMode.DISK:
                return AsyncQdrantClient(path=qdrant.disk.path)
            case QdrantMode.REMOTE:
                return AsyncQdrantClient(host=qdrant.remote.host, port=qdrant.remote.port)
            case _:
                raise ValueError(f'Unknown qdrant mode: {qdrant.mode}')

    async def init(self):
        collection_names = []
        dimension = await self.get_dimension()
        self.collection_name = self.dynamic_collection_name(dimension)
        collections: rest.CollectionsResponse = await self.client.get_collections()
        for collection in collections.collections:
            collection_names.append(collection.name)
        if self.collection_name not in collection_names:
            vectors_config = rest.VectorParams(
                size=dimension,
                distance=rest.Distance.COSINE
            )
            hnsw_config = rest.HnswConfigDiff(
                m=0,
                payload_m=16,
                ef_construct=100,
                full_scan_threshold=10000,
                max_indexing_threads=0,
                on_disk=False,
            )
            await self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=vectors_config,
                hnsw_config=hnsw_config,
                timeout=30
            )
            if settings.vector.qdrant.mode != QdrantMode.DISK:
                await self.client.create_payload_index(
                    self.collection_name, 'repo_id', field_schema=rest.PayloadSchemaType.KEYWORD
                )
                await self.client.create_payload_index(
                    self.collection_name, 'file_id', field_schema=rest.PayloadSchemaType.KEYWORD
                )

    async def create(self, documents: List[Document]):
        repo_id = documents[0].meta.repo_id
        point_structs = []
        for document in documents:
            # Qdrant points only carry three values (id, vector, payload),
            # so content has to be folded into the meta payload.
            document.meta.content = document.content
            document.meta.repo_id = repo_id
            point_structs.append(rest.PointStruct(
                id=document.meta.segment_id,
                vector=document.vector,
                payload=document.meta.model_dump(exclude_none=True),
            ))

        await self.client.upsert(collection_name=self.collection_name, points=point_structs)

    async def upsert(self, documents: List[Document]):
        repo_id = documents[0].meta.repo_id
        file_ids = list(set(document.meta.file_id for document in documents))
        await self.delete(repo_id, file_ids)
        await self.create(documents)
    async def all_meta(self, repo_id: str) -> List[DocumentMeta]:
        scroll_filter = rest.Filter(
            must=[
                rest.FieldCondition(
                    key="repo_id",
                    match=rest.MatchValue(value=repo_id),
                ),
            ],
        )

        records, next_offset = await self.client.scroll(
            collection_name=self.collection_name,
            scroll_filter=scroll_filter,
            limit=100,
            with_payload=True
        )

        while next_offset:
            scroll_records, next_offset = await self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=scroll_filter,
                limit=100,
                with_payload=True,
                offset=next_offset
            )
            records.extend(scroll_records)
        return [DocumentMeta.model_validate(record.payload) for record in records]

    async def drop(self, repo_id: str):
        filter_selector = rest.Filter(
            must=[
                rest.FieldCondition(
                    key="repo_id",
                    match=rest.MatchValue(value=repo_id),
                ),
            ],
        )
        try:
            await self.client.delete(
                collection_name=self.collection_name,
                points_selector=filter_selector
            )
        except UnexpectedResponse as e:
            if e.status_code == 404:
                return
            raise e

    async def delete(self, repo_id: str, file_ids: List[str]):
        filter_selector = rest.Filter(
            must=[
                rest.FieldCondition(
                    key="repo_id",
                    match=rest.MatchValue(value=repo_id),
                ),
                rest.FieldCondition(
                    key="file_id",
                    match=rest.MatchAny(any=file_ids),
                )
            ],
        )
        try:
            await self.client.delete(
                collection_name=self.collection_name,
                points_selector=filter_selector
            )
        except UnexpectedResponse as e:
            if e.status_code == 404:
                return
            raise e

    async def search_by_vector(
            self,
            repo_id: str,
            file_ids: Optional[List[str]],
            query_vector: List[float],
            top_k: int = 5,
            score_threshold: float = 0.0
    ) -> List[Document]:

        query_filter = rest.Filter(
            must=[
                rest.FieldCondition(
                    key="repo_id",
                    match=rest.MatchValue(value=repo_id),
                ),
            ],
        )
        # If file_ids is provided, add a file_id filter condition
        if file_ids:
            query_filter.must.append(
                rest.FieldCondition(
                    key="file_id",
                    match=rest.MatchAny(any=file_ids),
                )
            )
        response: rest.QueryResponse = await self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            query_filter=query_filter,
            limit=top_k,
            with_payload=True,
            with_vectors=True,
            score_threshold=score_threshold,
        )
        documents: List[Document] = []
        for point in response.points:
            meta = DocumentMeta.model_validate(point.payload)
            meta.score = point.score
            documents.append(Document(meta=meta))
        return documents
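A minimal sketch of the scroll-pagination pattern used by all_meta above (hedged: the collection name and filter are left out, and the sketch assumes an existing, populated collection):

    from qdrant_client.async_qdrant_client import AsyncQdrantClient


    async def scroll_all_payloads(client: AsyncQdrantClient, collection: str):
        payloads, offset = [], None
        while True:
            # scroll() returns (records, next_offset); next_offset is None on the last page
            records, offset = await client.scroll(
                collection_name=collection, limit=100, with_payload=True, offset=offset
            )
            payloads.extend(record.payload for record in records)
            if offset is None:
                return payloads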
--------------------------------------------------------------------------------
/extensions/vector/pgvector.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2025/2/14
# @Author  : .*?
# @Email   : amashiro2233@gmail.com
# @File    : pgvector
# @Software: PyCharm
import json
from typing import List, Optional

import numpy as np
from sqlalchemy import text, bindparam
from sqlalchemy.ext.asyncio import create_async_engine

from domain.entity.document import Document, DocumentMeta
from extensions.vector.base_vector import BaseVector
from setting.setting import get_we0_index_settings

settings = get_we0_index_settings()

SQL_CREATE_FILE_INDEX = lambda table_name: f"""
CREATE INDEX IF NOT EXISTS file_idx ON {table_name} (file_id);
"""
SQL_CREATE_REPO_FILE_INDEX = lambda table_name: f"""
CREATE INDEX IF NOT EXISTS repo_file_idx ON {table_name} (repo_id, file_id);
"""
SQL_CREATE_EMBEDDING_INDEX = lambda table_name: f"""
CREATE INDEX IF NOT EXISTS embedding_cosine_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""

SQL_CREATE_TABLE = lambda table_name, dimension: f"""
CREATE TABLE IF NOT EXISTS {table_name} (
    id UUID PRIMARY KEY,
    repo_id UUID NOT NULL,
    file_id UUID NOT NULL,
    content TEXT NOT NULL,
    meta JSONB NOT NULL,
    embedding vector({dimension}) NOT NULL
) USING heap;
"""


class PgVector(BaseVector):

    def __init__(self):
        self.client = self.get_client()
        self.table_name: str | None = None
        self.normalized: bool = False

    @staticmethod
    def get_client():
        pgvector = settings.vector.pgvector
        return create_async_engine(
            url=f"postgresql+psycopg://{pgvector.user}:{pgvector.password}@{pgvector.host}:{pgvector.port}/{pgvector.db}",
            echo=False,
        )

    async def init(self):
        async with self.client.begin() as conn:
            dimension = await self.get_dimension()
            # pgvector's HNSW index supports at most 2000 dimensions, so larger
            # embeddings are truncated and re-normalized before being stored.
            if dimension > 2000:
                dimension = 2000
                self.normalized = True
            self.table_name = self.dynamic_collection_name(dimension)
            await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
            await conn.execute(text(SQL_CREATE_TABLE(self.table_name, dimension)))
            await conn.execute(text(SQL_CREATE_REPO_FILE_INDEX(self.table_name)))
            await conn.execute(text(SQL_CREATE_FILE_INDEX(self.table_name)))
            await conn.execute(text(SQL_CREATE_EMBEDDING_INDEX(self.table_name)))

    async def _create(self, repo_id: str, documents: List[Document]):
        stmt = text(
            f"""
            INSERT INTO {self.table_name} (id, repo_id, file_id, content, meta, embedding)
            VALUES (:id, :repo_id, :file_id, :content, :meta, :embedding)
            ON CONFLICT (id) DO UPDATE SET
                repo_id = EXCLUDED.repo_id,
                file_id = EXCLUDED.file_id,
                content = EXCLUDED.content,
                meta = EXCLUDED.meta,
                embedding = EXCLUDED.embedding
            """
        )
        parameters = []
        for doc in documents:
            if doc.meta is not None:
                if self.normalized:
                    vector = self.normalize_l2(doc.vector[:2000])
                else:
                    vector = doc.vector
                parameters.append({
                    'id': doc.meta.segment_id,
                    'repo_id': repo_id,
                    'file_id': doc.meta.file_id,
                    'content': doc.content,
                    'meta': doc.meta.model_dump_json(exclude={'score', 'content'}),
                    'embedding': vector,
                })

        return stmt, parameters
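    # A worked sketch of what init() sets up for an oversized embedding model
    # (hedged: the table name is hypothetical; dynamic_collection_name lives in
    # BaseVector and is not shown in this file):
    #
    #     get_dimension() -> 3072        # above the 2000-dimension cap
    #     dimension = 2000, self.normalized = True
    #     SQL_CREATE_TABLE("we0_index_2000", 2000)
    #     -> CREATE TABLE IF NOT EXISTS we0_index_2000 ( ... embedding vector(2000) NOT NULL ... )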
    async def create(self, documents: List[Document]):
        async with self.client.begin() as conn:
            repo_id = documents[0].meta.repo_id
            insert_stmt, insert_parameters = await self._create(repo_id, documents)
            await conn.execute(
                insert_stmt,
                insert_parameters
            )

    async def upsert(self, documents: List[Document]):
        repo_id = documents[0].meta.repo_id
        file_ids = list(set(document.meta.file_id for document in documents))
        async with self.client.begin() as conn:
            delete_stmt, delete_parameters = await self._delete(repo_id, file_ids)
            await conn.execute(
                delete_stmt,
                delete_parameters
            )
            insert_stmt, insert_parameters = await self._create(repo_id, documents)
            await conn.execute(
                insert_stmt,
                insert_parameters
            )

    async def all_meta(self, repo_id: str) -> List[DocumentMeta]:
        async with self.client.begin() as conn:
            sql_query = (
                f"SELECT meta "
                f"FROM {self.table_name} "
                f"WHERE repo_id = :repo_id"
            )
            result = await conn.execute(
                text(sql_query),
                {
                    "repo_id": repo_id,
                }
            )
            records = result.all()
            return [DocumentMeta.model_validate(meta[0]) for meta in records]

    async def drop(self, repo_id: str):
        async with self.client.begin() as conn:
            await conn.execute(
                text(
                    f"DELETE FROM {self.table_name} "
                    "WHERE repo_id = :repo_id"
                ),
                {'repo_id': repo_id}
            )

    async def _delete(self, repo_id: str, file_ids: List[str]):
        sql_query = (
            f"DELETE FROM {self.table_name} "
            "WHERE repo_id = :repo_id AND file_id = ANY(:file_ids) "
        )
        stmt = text(sql_query).bindparams(
            bindparam("file_ids", expanding=False)  # key point: bind the list as a single parameter instead of expanding it
        )
        parameters = {'repo_id': repo_id, 'file_ids': file_ids}
        return stmt, parameters

    async def delete(self, repo_id: str, file_ids: List[str]):
        async with self.client.begin() as conn:
            delete_stmt, delete_parameters = await self._delete(repo_id, file_ids)
            await conn.execute(
                delete_stmt,
                delete_parameters
            )

    async def search_by_vector(
            self,
            repo_id: str,
            file_ids: Optional[List[str]],
            query_vector: List[float],
            top_k: int = 5,
            score_threshold: float = 0.0
    ) -> List[Document]:
        documents = []
        async with self.client.begin() as conn:
            # Base SQL query
            sql_query = (
                f"SELECT content, meta, embedding <=> :query_vector AS distance "
                f"FROM {self.table_name} "
                f"WHERE repo_id = :repo_id "
            )
            if self.normalized:
                query_vector = self.normalize_l2(query_vector[:2000])
            # Base parameters
            parameters = {
                "query_vector": json.dumps(query_vector),
                "repo_id": repo_id,
                "top_k": top_k
            }

            if file_ids:
                sql_query += "AND file_id = ANY(:file_ids) "
                parameters["file_ids"] = file_ids

            sql_query += "ORDER BY distance LIMIT :top_k"

            if file_ids:
                stmt = text(sql_query).bindparams(
                    bindparam("file_ids", expanding=False)
                )
            else:
                stmt = text(sql_query)

            result = await conn.execute(
                stmt,
                parameters
            )
            records = result.all()
            for record in records:
                content, meta, distance = record
                # <=> is pgvector's cosine-distance operator; convert to a similarity score
                score = 1 - distance
                if score > score_threshold:
                    meta["score"] = score
                    meta['content'] = content
                    documents.append(Document(content=content, meta=DocumentMeta.model_validate(meta)))
        return documents
    @staticmethod
    def normalize_l2(x: List[float]) -> List[float]:
        x = np.array(x)
        if x.ndim == 1:
            norm = np.linalg.norm(x)
            if norm == 0:
                return x.tolist()
            return (x / norm).tolist()
        else:
            norm = np.linalg.norm(x, 2, axis=1, keepdims=True)
            return np.where(norm == 0, x, x / norm).tolist()
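# A quick numeric check of the truncate-then-renormalize step used above
# (a hedged sketch; 3072 stands in for a hypothetical oversized embedding model):
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    full = rng.random(3072)
    truncated = full[:2000]                  # truncation shrinks the L2 norm
    unit = np.array(PgVector.normalize_l2(truncated.tolist()))
    print(np.linalg.norm(truncated), np.linalg.norm(unit))  # the second value is 1.0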
--------------------------------------------------------------------------------
/loader/segmenter/base_line_segmenter.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2025/2/11
# @Author  : .*?
# @Email   : amashiro2233@gmail.com
# @File    : line_based_segmenter
# @Software: PyCharm
from bisect import bisect_right, bisect_left
from typing import List, Dict, Any, Tuple, Iterator

from domain.entity.code_segment import CodeSegment
from loader.segmenter.base_segmenter import BaseSegmenter


class LineBasedSegmenter(BaseSegmenter):

    def __init__(
            self,
            text: str,
            max_chunk_size: int = 50,
            min_chunk_size: int = 10,
            delimiters=None,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.text = text
        self.max_chunk_size = max_chunk_size
        self.min_chunk_size = min_chunk_size
        self.delimiters = delimiters or ['\n\n', '\n']

        # Build the line-number position map
        self.line_positions = self._build_line_positions(text)

    def segment(self) -> Iterator[CodeSegment]:
        """Run the segmentation"""
        initial_range = (0, len(self.text))
        segments = self._recursive_split(initial_range, self.delimiters)
        if self.merge_small_chunks:
            segments = self._merge_small_segments(segments)
        for segment in segments:
            code_segment = CodeSegment.model_validate(segment)
            if not code_segment.code.isspace():
                yield code_segment

    @staticmethod
    def _build_line_positions(text: str) -> List[Tuple[int, int]]:
        """Build a list of (start, end) character ranges, one per line"""
        lines = []
        start = 0
        for line in text.split('\n'):
            end = start + len(line)
            lines.append((start, end))
            start = end + 1  # skip the newline character
        return lines

    def _get_original_lines(self, start_pos: int, end_pos: int) -> Tuple[int, int]:
        """Map character positions back to original line numbers (1-based)"""
        starts = [line[0] for line in self.line_positions]
        ends = [line[1] for line in self.line_positions]

        start_line = bisect_right(starts, start_pos)
        end_line = bisect_left(ends, end_pos)

        # Handle boundary cases
        start_line = max(0, start_line - 1) if start_line > 0 else 0
        end_line = min(end_line, len(ends) - 1)

        return start_line + 1, end_line + 1  # convert to 1-based

    def _recursive_split(
            self,
            char_range: Tuple[int, int],
            delimiters: List[str]
    ) -> List[Dict[str, Any]]:
        start_pos, end_pos = char_range
        text = self.text[start_pos:end_pos]

        if self.max_tokens is not None:
            raw_lines = text.splitlines(keepends=True)
            segments = []
            current_chunk_start = start_pos
            pos = start_pos
            any_long_line = False
            for line in raw_lines:
                # Count the line's tokens with surrounding whitespace stripped
                if self.length_function(line.strip()) > self.max_tokens:
                    any_long_line = True
                    # First handle the part before this line (if any)
                    if pos > current_chunk_start:
                        segments.extend(
                            self._recursive_split((current_chunk_start, pos), delimiters)
                        )
                    # Force-split this single line (marked forced=True)
                    segments.extend(
                        self._forced_split((pos, pos + len(line)), block=1)
                    )
                    current_chunk_start = pos + len(line)
                pos += len(line)
            if any_long_line:
                if current_chunk_start < end_pos:
                    segments.extend(
                        self._recursive_split((current_chunk_start, end_pos), delimiters)
                    )
                return segments

        current_lines = len(text.splitlines())

        # Decide whether a split is needed
        need_split = False
        if current_lines > self.max_chunk_size:
            need_split = True
        if self.max_tokens is not None:
            current_tokens = self.length_function(text)
            if current_tokens > self.max_tokens:
                need_split = True

        if not need_split:
            code = text.strip()
            if not code:
                return []
            start_line, end_line = self._get_original_lines(start_pos, end_pos)
            # A normally produced chunk, marked forced=False
            return [{
                "start": start_line,
                "end": end_line,
                "code": code,
                "forced": False
            }]

        # Try splitting with the current delimiter
        if delimiters:
            current_delim = delimiters[0]
            parts = text.split(current_delim)
            if len(parts) > 1:
                return self._split_by_delimiter(start_pos, end_pos, current_delim, delimiters[1:])

        # No usable delimiter left: force a split
        return self._forced_split(char_range)

    def _split_by_delimiter(
            self,
            start_pos: int,
            end_pos: int,
            delimiter: str,
            next_delimiters: List[str]
    ) -> List[Dict[str, Any]]:
        """Split using the given delimiter"""
        text = self.text[start_pos:end_pos]
        delim_len = len(delimiter)
        segments = []
        current_start = start_pos

        for part in text.split(delimiter):
            if not part.strip():
                current_start += len(part) + delim_len
                continue

            part_end = current_start + len(part)
            sub_segments = self._recursive_split(
                (current_start, part_end),
                next_delimiters
            )
            segments.extend(sub_segments)
            current_start = part_end + delim_len

        return segments

    def _compute_optimal_chunk_length(self, char_range: Tuple[int, int]) -> int:
        """
        Binary-search, within the given character range, a split length such that:
        self.length_function(text[start_pos:start_pos + chunk_len]) <= self.max_tokens
        """
        start_pos, end_pos = char_range
        total_chars = end_pos - start_pos

        low, high = 1, total_chars
        # Step by at least 1 so the search always terminates, even for small max_tokens
        increment_value = max(1, self.max_tokens // 10)
        optimal = low  # guarantee at least one character
        while low <= high:
            mid = (low + high) // 2
            chunk_text = self.text[start_pos:start_pos + mid]
            token_count = self.length_function(chunk_text)
            if token_count <= self.max_tokens:
                # This length fits; record it and try a larger one
                optimal = mid
                low = mid + increment_value
            else:
                # Over the limit; try a smaller length
                high = mid - increment_value

        return optimal

    def _forced_split(self, char_range: Tuple[int, int], block: int = 1) -> List[Dict[str, Any]]:
        start_pos, end_pos = char_range
        segments = []
        pos = start_pos
        current_block = block

        while pos < end_pos:
            optimal_chunk_size = self._compute_optimal_chunk_length((pos, end_pos))
            next_pos = min(pos + optimal_chunk_size, end_pos)
            chunk_text = self.text[pos:next_pos].strip()
            if chunk_text:
                start_line, end_line = self._get_original_lines(pos, next_pos)
                segments.append({
                    "start": start_line,
                    "end": end_line,
                    "code": chunk_text,
                    "forced": True,  # mark as produced by forced splitting
                    "block": current_block
                })
                current_block += 1
            pos = next_pos
        return segments
    def _merge_small_segments(self, segments: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        merged = []
        if not segments:
            return merged
        buffer = segments[0]
        for seg in segments[1:]:
            # Never merge a forced chunk with a non-forced one; keep them separate
            if buffer.get("forced", False) != seg.get("forced", False):
                merged.append(buffer)
                buffer = seg
                continue

            merged_code = f"{buffer['code']}\n{seg['code']}"
            merged_lines = len(merged_code.splitlines())
            current_lines = buffer["end"] - buffer["start"] + 1
            next_lines = seg["end"] - seg["start"] + 1

            if self.max_tokens is not None:
                # With a token limit, the merged chunk must stay within it
                merged_tokens = self.length_function(merged_code)
                if merged_tokens <= self.max_tokens:
                    buffer = {
                        "start": buffer["start"],
                        "end": seg["end"],
                        "code": merged_code.strip(),
                        "forced": buffer.get("forced", False)
                    }
                else:
                    merged.append(buffer)
                    buffer = seg
            else:
                # Without a token limit, merge by line count: merge if the result
                # stays within max_chunk_size, or if either chunk is too small
                if (merged_lines <= self.max_chunk_size) or (
                        current_lines < self.min_chunk_size or next_lines < self.min_chunk_size):
                    buffer = {
                        "start": buffer["start"],
                        "end": seg["end"],
                        "code": merged_code.strip(),
                        "forced": buffer.get("forced", False)
                    }
                else:
                    merged.append(buffer)
                    buffer = seg
        merged.append(buffer)
        return merged
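# A minimal usage sketch (hedged: BaseSegmenter's constructor is defined
# elsewhere; the attribute usage above suggests it accepts merge_small_chunks,
# max_tokens and length_function keyword arguments):
if __name__ == "__main__":
    sample = "def a():\n    return 1\n\n\ndef b():\n    return 2\n"
    segmenter = LineBasedSegmenter(
        text=sample,
        max_chunk_size=50,
        min_chunk_size=10,
        merge_small_chunks=True,  # assumed BaseSegmenter kwarg
        max_tokens=None,          # assumed BaseSegmenter kwarg
        length_function=len,      # assumed BaseSegmenter kwarg
    )
    for seg in segmenter.segment():
        print(f"lines {seg.start}-{seg.end}: {seg.code!r}")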
--------------------------------------------------------------------------------
/loader/segmenter/tree_sitter_segmenter.py:
--------------------------------------------------------------------------------
from abc import abstractmethod
from collections import deque
from typing import List, Dict, Any, Iterator

from tree_sitter import Language, Parser

from domain.entity.code_segment import CodeSegment
from loader.segmenter.base_segmenter import BaseSegmenter


class TreeSitterSegmenter(BaseSegmenter):
    def __init__(
            self,
            text: str,
            chunk_size: int = 30,
            min_chunk_size: int = 10,
            max_chunk_size: int = 50,
            max_depth: int = 5,
            split_large_chunks=True,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.text = text
        self.chunk_size = chunk_size
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size
        self.max_depth = max_depth
        self.split_large_chunks = split_large_chunks
        self.parser = self.get_parser()

    @abstractmethod
    def get_language(self) -> Language:
        """Return the Tree-sitter Language object for this language, e.g.:
        return Language('build/my-languages.so', 'python')
        """
        raise NotImplementedError

    @abstractmethod
    def get_node_types(self) -> List[str]:
        """
        Return the target node types to extract,
        e.g. ["function_definition", "class_definition"]
        """
        raise NotImplementedError

    @abstractmethod
    def get_recursion_node_types(self) -> List[str]:
        """
        Return the node types whose children should be traversed recursively,
        e.g. ["module", "block", "class_definition", "function_definition"]
        """
        raise NotImplementedError

    def is_valid(self) -> bool:
        """Quick syntax check: valid only if the tree contains no ERROR nodes"""
        error_query = self.parser.language.query("(ERROR) @error")
        tree = self.parser.parse(bytes(self.text, encoding="UTF-8"))
        return len(error_query.captures(tree.root_node)) == 0

    def segment(self) -> Iterator[CodeSegment]:
        tree = self.parser.parse(bytes(self.text, "utf-8"))
        unfiltered_nodes = deque()
        for child in tree.root_node.children:
            unfiltered_nodes.append((child, 1))
        all_nodes = []
        while unfiltered_nodes:
            current_node, current_depth = unfiltered_nodes.popleft()
            if current_node.type in self.get_node_types():
                all_nodes.append(current_node)
            if current_node.type in self.get_recursion_node_types():
                if current_depth < self.max_depth:
                    for child in current_node.children:
                        unfiltered_nodes.append((child, current_depth + 1))
        all_nodes.sort(key=lambda n: (n.start_point[0], -n.end_point[0]))

        processed_ranges = []
        processed_chunks_info = []
        for node in all_nodes:
            start_0 = node.start_point[0]
            end_0 = node.end_point[0]
            node_text = node.text.decode()
            if self._is_range_covered(start_0, end_0, processed_ranges):
                continue
            processed_ranges.append((start_0, end_0))
            processed_chunks_info.append({
                "start": start_0 + 1,
                "end": end_0 + 1,
                "code": node_text
            })

        processed_chunks_info.sort(key=lambda x: x["start"])
        # Use split("\n") to stay consistent with how Tree-sitter counts lines
        code_lines = self.text.split("\n")
        # Compute the total explicitly: tree-sitter's line count = number of "\n" + 1
        total_lines = self.text.count("\n") + 1

        combined_chunks = []
        current_pos = 0  # 0-based line number
        for chunk in processed_chunks_info:
            chunk_start_0 = chunk["start"] - 1
            chunk_end_0 = chunk["end"] - 1
            if current_pos < chunk_start_0:
                unprocessed = self._handle_unprocessed(
                    code_lines[current_pos:chunk_start_0],
                    current_pos
                )
                combined_chunks.extend(unprocessed)
            combined_chunks.append({
                "start": chunk["start"],
                "end": chunk["end"],
                "code": chunk["code"]
            })
            current_pos = chunk_end_0 + 1

        if current_pos < total_lines:
            unprocessed = self._handle_unprocessed(
                code_lines[current_pos:total_lines],
                current_pos
            )
            combined_chunks.extend(unprocessed)

        final_chunks = self._post_process_chunks(combined_chunks)
        if self.max_tokens:
            final_chunks = self._split_by_tokens(final_chunks, self.max_tokens)
        for segment in final_chunks:
            code_segment = CodeSegment.model_validate(segment)
            if not code_segment.code.isspace():
                yield code_segment

    def get_parser(self) -> Parser:
        """Initialize and return a Parser object"""
        parser = Parser()
        parser.language = self.get_language()
        return parser

    @staticmethod
    def _is_range_covered(start: int, end: int, existing_ranges: List[tuple]) -> bool:
        """
        Interval coverage check (O(n) time):
        return True if an existing range [existing_start, existing_end] fully covers [start, end]
        """
        for (existing_start, existing_end) in existing_ranges:
            if existing_start <= start and end <= existing_end:
                return True
        return False

    def _handle_unprocessed(self, lines: List[str], start_0: int) -> List[Dict[str, Any]]:
        """Handle code regions not captured by any target node."""
        if not lines:
            return []
        return self._split_into_chunks_without_empty_lines(lines, start_0)

    @staticmethod
    def _split_into_chunks_without_empty_lines(
            lines: List[str],
            start_0_based: int
    ) -> List[Dict[str, Any]]:
        """Emit the unrecognized region as one contiguous chunk, keeping all lines"""
        return [{
            "start": start_0_based + 1,
            "end": start_0_based + len(lines),
            "code": "\n".join(lines)
        }]
    def _post_process_chunks(
            self,
            chunks: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Post-process the chunks in two passes:
        1) If a chunk has more than max_chunk_size lines, split it evenly by chunk_size (in lines).
        2) If a chunk has fewer than min_chunk_size lines, try merging it with its previous
           or next neighbor (as long as the result stays within max_chunk_size).
        """
        # Pass 1: split every chunk larger than max_chunk_size
        split_chunks = []
        if self.split_large_chunks:
            for c in chunks:
                if len(c['code']):
                    split_chunks.extend(
                        self._split_large_chunk(c, self.chunk_size, self.max_chunk_size)
                    )
        else:
            split_chunks = chunks

        # Pass 2: merge chunks that came out too short
        if self.merge_small_chunks:
            return self._merge_small_chunks(
                split_chunks,
                self.min_chunk_size,
                self.max_chunk_size,
                self.chunk_size
            )
        else:
            return split_chunks

    @staticmethod
    def _split_large_chunk(
            chunk: Dict[str, Any],
            chunk_size: int,
            max_chunk_size: int
    ) -> List[Dict[str, Any]]:
        """
        If a chunk has more lines than max_chunk_size, split it evenly based on
        chunk_size, so sub-chunks end up neither too large nor too small.
        """
        lines = chunk["code"].split("\n")
        total_lines = len(lines)
        if total_lines <= max_chunk_size:
            return [chunk]

        # First estimate: split into n chunks
        n = max(1, round(total_lines / chunk_size))
        # If each chunk would still exceed max_chunk_size, increase n
        while (total_lines / n) > max_chunk_size:
            n += 1
        # Ensure n does not exceed total_lines
        n = min(n, total_lines)

        base_size = total_lines // n
        remainder = total_lines % n  # the first `remainder` chunks get one extra line
        results = []
        current_index = 0
        original_start = chunk["start"]
        for i in range(n):
            current_chunk_size = base_size + (1 if i < remainder else 0)
            sub_lines = lines[current_index: current_index + current_chunk_size]
            results.append({
                "start": original_start + current_index,
                "end": original_start + current_index + current_chunk_size - 1,
                "code": "\n".join(sub_lines)
            })
            current_index += current_chunk_size
        return results
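    # Worked example of the even-split arithmetic above (hedged; the numbers are
    # chosen purely for illustration): with total_lines = 125 and chunk_size = 30,
    # n = round(125 / 30) = 4 and 125 / 4 = 31.25 <= max_chunk_size = 50, so
    # base_size = 31 and remainder = 1, giving sub-chunks of 32, 31, 31 and 31 lines.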
    @staticmethod
    def _merge_small_chunks(
            chunks: List[Dict[str, Any]],
            min_chunk_size: int,
            max_chunk_size: int,
            chunk_size: int
    ) -> List[Dict[str, Any]]:
        """
        Merge small chunks (< min_chunk_size) into a neighbor, as long as the result
        stays within max_chunk_size. A tolerance band (tolerance_low .. tolerance_high)
        around the ideal chunk_size protects neighbors that are already well-sized
        from being merged and broken up.
        """
        merged = chunks.copy()
        changed = True
        tolerance_low = 0.8 * chunk_size
        tolerance_high = 1.2 * chunk_size

        while changed:
            changed = False
            new_merged = []
            i = 0
            while i < len(merged):
                current = merged[i]
                current_lines = len(current["code"].split("\n"))

                # The current chunk is already large enough: keep it as-is
                if current_lines >= min_chunk_size:
                    new_merged.append(current)
                    i += 1
                    continue

                best_merge = None
                best_score = float('inf')

                # Try merging with the previous chunk
                if new_merged:
                    prev = new_merged[-1]
                    prev_lines = len(prev["code"].split("\n"))
                    # Only if the previous chunk is outside the ideal band
                    # and the merged size stays within max_chunk_size
                    if not (tolerance_low <= prev_lines <= tolerance_high):
                        if prev_lines + current_lines <= max_chunk_size:
                            score = abs((prev_lines + current_lines) - chunk_size)
                            if score < best_score:
                                best_merge = 'prev'
                                best_score = score
                # Try merging with the next chunk
                if i + 1 < len(merged):
                    nxt = merged[i + 1]
                    nxt_lines = len(nxt["code"].split("\n"))
                    if not (tolerance_low <= nxt_lines <= tolerance_high):
                        if current_lines + nxt_lines <= max_chunk_size:
                            score = abs((current_lines + nxt_lines) - chunk_size)
                            if score < best_score:
                                best_merge = 'next'
                                best_score = score
                if best_merge == 'prev':
                    prev = new_merged.pop()
                    new_chunk = {
                        "start": prev["start"],
                        "end": current["end"],
                        "code": prev["code"] + "\n" + current["code"]
                    }
                    new_merged.append(new_chunk)
                    i += 1
                    changed = True
                elif best_merge == 'next':
                    nxt = merged[i + 1]
                    new_chunk = {
                        "start": current["start"],
                        "end": nxt["end"],
                        "code": current["code"] + "\n" + nxt["code"]
                    }
                    new_merged.append(new_chunk)
                    i += 2
                    changed = True
                else:
                    # A very small chunk (< min_chunk_size / 2) is force-merged into the previous one
                    if current_lines < (min_chunk_size / 2) and new_merged:
                        prev = new_merged.pop()
                        new_chunk = {
                            "start": prev["start"],
                            "end": current["end"],
                            "code": prev["code"] + "\n" + current["code"]
                        }
                        new_merged.append(new_chunk)
                        changed = True
                    else:
                        new_merged.append(current)
                    i += 1
            merged = new_merged
        return merged

    def _split_by_tokens(
            self,
            chunks: List[Dict[str, Any]],
            max_tokens: int,
            block: int = 1
    ) -> List[Dict[str, Any]]:
        processed = []
        for chunk in chunks:
            current_code = chunk["code"]
            current_tokens = self.length_function(current_code)
            if current_tokens <= max_tokens:
                if block == 1:
                    processed.append(chunk)
                else:
                    # Tag chunks produced by splitting with a block number
                    chunk_with_part = chunk.copy()
                    chunk_with_part["block"] = block
                    processed.append(chunk_with_part)
                continue

            lines = current_code.split('\n')
            if len(lines) > 1:
                split_index = len(lines) // 2
                first_code = '\n'.join(lines[:split_index])
                second_code = '\n'.join(lines[split_index:])
                first_start = chunk["start"]
                first_end = first_start + split_index - 1
                second_start = first_end + 1
                second_end = chunk["end"]
                first_sub = {"start": first_start, "end": first_end, "code": first_code}
                second_sub = {"start": second_start, "end": second_end, "code": second_code}
                processed.extend(self._split_by_tokens([first_sub, second_sub], max_tokens, block))
            else:
                code = current_code
                low, high = 0, len(code)
                best_mid = 0
                while low <= high:
                    mid = (low + high) // 2
                    part = code[:mid]
                    if self.length_function(part) <= max_tokens:
                        best_mid = mid
                        low = mid + 1
                    else:
                        high = mid - 1
                best_mid = best_mid if best_mid != 0 else len(code) // 2
                first_part = code[:best_mid]
                second_part = code[best_mid:]
                first_sub = {"start": chunk["start"], "end": chunk["end"], "code": first_part}
                second_sub = {"start": chunk["start"], "end": chunk["end"], "code": second_part}
                # Increment the block number on the recursive call
                processed.extend(self._split_by_tokens([first_sub], max_tokens, block))
                processed.extend(self._split_by_tokens([second_sub], max_tokens, block + 1))
        return processed
--------------------------------------------------------------------------------
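For reference, a hypothetical concrete subclass (a hedged sketch, not the repository's actual loader/segmenter/language/python.py; it assumes the tree-sitter-python binding is installed and a tree-sitter release where Language wraps the binding's language pointer):

    from typing import List

    import tree_sitter_python
    from tree_sitter import Language

    from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter


    class PythonSegmenterSketch(TreeSitterSegmenter):
        """Hypothetical Python segmenter; node type names follow tree-sitter-python's grammar."""

        def get_language(self) -> Language:
            return Language(tree_sitter_python.language())

        def get_node_types(self) -> List[str]:
            return ["function_definition", "class_definition"]

        def get_recursion_node_types(self) -> List[str]:
            return ["module", "block", "class_definition", "decorated_definition"]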