├── .env.example
├── clients
│   ├── __init__.py
│   └── jina
│       ├── __init__.py
│       ├── client.py
│       └── embeddings.py
├── config
│   ├── __init__.py
│   └── loguru.py
├── constants
│   ├── __init__.py
│   └── constants.py
├── domain
│   ├── __init__.py
│   ├── entity
│   │   ├── __init__.py
│   │   ├── code_segment.py
│   │   ├── task_context.py
│   │   ├── document.py
│   │   └── blob.py
│   ├── enums
│   │   ├── __init__.py
│   │   ├── model_provider.py
│   │   ├── chroma_mode.py
│   │   ├── qdrant_mode.py
│   │   └── vector_type.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── repository.py
│   ├── request
│   │   ├── __init__.py
│   │   ├── all_index_request.py
│   │   ├── drop_index_request.py
│   │   ├── git_index_request.py
│   │   ├── retrieval_request.py
│   │   ├── delete_index_request.py
│   │   └── add_index_request.py
│   ├── result
│   │   ├── __init__.py
│   │   └── result.py
│   └── response
│       ├── __init__.py
│       ├── add_index_by_file_response.py
│       ├── add_index_response.py
│       └── retrieval_segment_response.py
├── exception
│   ├── __init__.py
│   └── exception.py
├── launch
│   ├── __init__.py
│   ├── we0_index_mcp.py
│   └── launch.py
├── loader
│   ├── __init__.py
│   ├── segmenter
│   │   ├── language
│   │   │   ├── __init__.py
│   │   │   ├── python.py
│   │   │   ├── css.py
│   │   │   ├── go.py
│   │   │   ├── java.py
│   │   │   ├── typescriptxml.py
│   │   │   ├── typescript.py
│   │   │   └── javascript.py
│   │   ├── __init__.py
│   │   ├── tree_sitter_factory.py
│   │   ├── base_segmenter.py
│   │   ├── base_line_segmenter.py
│   │   └── tree_sitter_segmenter.py
│   └── repo_loader.py
├── models
│   ├── __init__.py
│   └── model_factory.py
├── prompt
│   ├── __init__.py
│   └── prompt.py
├── router
│   ├── __init__.py
│   ├── git_router.py
│   └── vector_router.py
├── setting
│   ├── __init__.py
│   └── setting.py
├── utils
│   ├── __init__.py
│   ├── path_util.py
│   ├── mimetype_util.py
│   ├── git_parse.py
│   ├── helper.py
│   └── vector_helper.py
├── extensions
│   ├── __init__.py
│   ├── vector
│   │   ├── __init__.py
│   │   ├── base_vector.py
│   │   ├── ext_vector.py
│   │   ├── chroma.py
│   │   ├── qdrant.py
│   │   └── pgvector.py
│   └── ext_manager.py
├── resource
│   └── dev.yaml
├── Dockerfile
├── pyproject.toml
├── main.py
├── README-zh.md
├── .gitignore
└── README.md
/.env.example:
--------------------------------------------------------------------------------
1 | WE0_INDEX_ENV=dev
2 | TZ=Asia/Shanghai
3 | OPENAI_BASE_URL=https://api.openai.com/v1
4 | OPENAI_API_KEY=
5 | JINA_BASE_URL=https://api.jina.ai/v1
6 | JINA_API_KEY=
--------------------------------------------------------------------------------
/clients/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/19
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/7/17
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/constants/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/7/17
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/domain/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/10/10
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/exception/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/8/21
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/launch/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/8/21
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/loader/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/10/10
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/15
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/prompt/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/20
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/router/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/10/11
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/setting/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/06/19
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/11
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/domain/entity/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/10/11
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/domain/enums/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/10/11
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/domain/model/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/20
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/domain/request/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/22
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/domain/result/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/7/17
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/extensions/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/10/10
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/domain/response/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/22
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/extensions/vector/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/14
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/loader/segmenter/language/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/7
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 |
--------------------------------------------------------------------------------
/domain/model/base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/20
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : base
7 | # @Software: PyCharm
8 | from sqlalchemy.orm import declarative_base
9 |
10 | Base = declarative_base()
11 |
--------------------------------------------------------------------------------
/domain/enums/model_provider.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/15
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : model_provider
7 | # @Software: PyCharm
8 | from enum import StrEnum
9 |
10 |
11 | class ModelType(StrEnum):
12 | OPENAI = "openai"
13 | JINA = "jina"
14 |
--------------------------------------------------------------------------------
/domain/enums/chroma_mode.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/22
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : chroma_mode
7 | # @Software: PyCharm
8 | from enum import StrEnum
9 |
10 |
11 | class ChromaMode(StrEnum):
12 | DISK = "disk"
13 | REMOTE = "remote"
14 | MEMORY = "memory"
15 |
--------------------------------------------------------------------------------
/domain/enums/qdrant_mode.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/22
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : qdrant_mode
7 | # @Software: PyCharm
8 | from enum import StrEnum
9 |
10 |
11 | class QdrantMode(StrEnum):
12 | DISK = "disk"
13 | REMOTE = "remote"
14 | MEMORY = "memory"
15 |
--------------------------------------------------------------------------------
/domain/request/all_index_request.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/23
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : all_index_request
7 | # @Software: PyCharm
8 | from pydantic import BaseModel, Field
9 |
10 |
11 | class AllIndexRequest(BaseModel):
12 | repo_id: str = Field(description='Repository ID')
13 |
--------------------------------------------------------------------------------
/domain/request/drop_index_request.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/23
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : drop_index_request
7 | # @Software: PyCharm
8 | from pydantic import BaseModel, Field
9 |
10 |
11 | class DropIndexRequest(BaseModel):
12 | repo_id: str = Field(description='Repository ID')
13 |
--------------------------------------------------------------------------------
/domain/enums/vector_type.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/10/11
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : vector_type
7 | # @Software: PyCharm
8 | from enum import StrEnum
9 |
10 |
11 | class VectorType(StrEnum):
12 | PGVECTOR = "pgvector"
13 | QDRANT = "qdrant"
14 | CHROMA = "chroma"
15 |
--------------------------------------------------------------------------------
/domain/request/git_index_request.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from typing import Optional
3 |
4 | class GitIndexRequest(BaseModel):
5 | uid: str | None = None
6 | repo_url: str
7 | # Private repository authentication fields
8 | username: Optional[str] = None # Username
9 | password: Optional[str] = None # Password or personal access token
10 | access_token: Optional[str] = None # Access token (GitHub/GitLab Personal Access Token)
--------------------------------------------------------------------------------
/domain/response/add_index_by_file_response.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/23
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : add_index_by_file_response
7 | # @Software: PyCharm
8 |
9 | from pydantic import BaseModel
10 |
11 |
12 | class AddIndexByFileResponse(BaseModel):
13 | repo_id: str
14 | file_id: str
15 |
--------------------------------------------------------------------------------
/clients/jina/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/19
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 | from clients.jina.client import AsyncClient
9 | from clients.jina.embeddings import AsyncEmbeddings
10 |
11 | __all__ = [
12 | 'AsyncClient',
13 | 'AsyncEmbeddings',
14 | ]
15 |
--------------------------------------------------------------------------------
/domain/request/retrieval_request.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/23
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : retrieval_request
7 | # @Software: PyCharm
8 | from typing import List, Optional
9 |
10 | from pydantic import BaseModel
11 |
12 |
13 | class RetrievalRequest(BaseModel):
14 | repo_id: str
15 | file_ids: Optional[List[str]] = None
16 | query: str
17 |
--------------------------------------------------------------------------------
/domain/request/delete_index_request.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/23
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : delete_index_request
7 | # @Software: PyCharm
8 | from typing import List
9 |
10 | from pydantic import BaseModel, Field
11 |
12 |
13 | class DeleteIndexRequest(BaseModel):
14 | repo_id: str = Field(description='Repository ID')
15 | file_id: List[str] = Field(description='File IDs')
16 |
--------------------------------------------------------------------------------
/domain/entity/code_segment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/17
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : code_segment
7 | # @Software: PyCharm
8 | from pydantic import BaseModel, ConfigDict, Field
9 |
10 |
11 | class CodeSegment(BaseModel):
12 | start: int
13 | end: int
14 | code: str
15 | block: int = Field(default=1)
16 | model_config = ConfigDict(
17 | extra='ignore'
18 | )
19 |
--------------------------------------------------------------------------------
/domain/entity/task_context.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/10/12
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : task_context
7 | # @Software: PyCharm
8 |
9 | from pydantic import BaseModel, ConfigDict
10 |
11 | from domain.entity.blob import Blob
12 |
13 |
14 | class TaskContext(BaseModel):
15 | repo_id: str
16 | file_id: str
17 | relative_path: str
18 | blob: Blob
19 | model_config = ConfigDict(extra='ignore')
20 |
--------------------------------------------------------------------------------
/domain/response/add_index_response.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/23
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : add_index_response
7 | # @Software: PyCharm
8 | from typing import List
9 |
10 | from pydantic import BaseModel
11 |
12 |
13 | class FileInfoResponse(BaseModel):
14 | file_id: str
15 | relative_path: str
16 |
17 |
18 | class AddIndexResponse(BaseModel):
19 | repo_id: str
20 | file_infos: List[FileInfoResponse]
21 |
--------------------------------------------------------------------------------
/extensions/ext_manager.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/23
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : ext_manager
7 | # @Software: PyCharm
8 | from loguru import logger
9 |
10 | from extensions.vector.ext_vector import Vector
11 |
12 |
13 | class ExtManager:
14 | vector = Vector()
15 |
16 |
17 | async def init_vector():
18 | logger.info("Initializing vector")
19 | await ExtManager.vector.init_app()
20 | logger.info("Initialized vector")
21 |
--------------------------------------------------------------------------------
/utils/path_util.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/7/18
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : path_util
7 | # @Software: PyCharm
8 | import os
9 |
10 |
11 | class PathUtil:
12 | @staticmethod
13 | def check_or_make_dir(path: str):
14 | if not os.path.exists(path):
15 | os.makedirs(path)
16 |
17 | @staticmethod
18 | def check_or_make_dirs(*paths):
19 | for path in paths:
20 | PathUtil.check_or_make_dir(path)
21 |
--------------------------------------------------------------------------------
/exception/exception.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/8/21
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : exception
7 | # @Software: PyCharm
8 |
9 |
10 | class CommonException(Exception):
11 | def __init__(self, message=''):
12 | self.message = message
13 | super().__init__(self.message)
14 |
15 | def __str__(self):
16 | return self.message
17 |
18 | def __repr__(self):
19 | return self.message
20 |
21 |
22 | class StorageUploadFileException(CommonException):
23 | ...
24 |
--------------------------------------------------------------------------------
/domain/response/retrieval_segment_response.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/22
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : retrieval_segment_response
7 | # @Software: PyCharm
8 | from pydantic import BaseModel, Field, ConfigDict
9 |
10 |
11 | class RetrievalSegmentResponse(BaseModel):
12 | relative_path: str = Field(description='Relative path of the file the segment belongs to')
13 | start_line: int = Field(description='Segment start line')
14 | end_line: int = Field(description='Segment end line')
15 | score: float = Field(description='Similarity score')
16 | model_config = ConfigDict(
17 | extra='ignore'
18 | )
19 |
--------------------------------------------------------------------------------
/domain/request/add_index_request.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/23
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : add_index_request
7 | # @Software: PyCharm
8 | from typing import List
9 |
10 | from pydantic import BaseModel, Field
11 |
12 |
13 | class AddFileInfo(BaseModel):
14 | relative_path: str = Field(description='File Relative Path')
15 | content: str = Field(description='File Content')
16 |
17 |
18 | class AddIndexRequest(BaseModel):
19 | uid: str = Field(description='Unique ID')
20 | repo_abs_path: str = Field(description='Repository Absolute Path')
21 | file_infos: List[AddFileInfo]
22 |
--------------------------------------------------------------------------------
/domain/result/result.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/07/17
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : result
7 | # @Software: PyCharm
8 |
9 | from pydantic import BaseModel
10 |
11 |
12 | class Result[T](BaseModel):
13 | code: int
14 | message: str
15 | data: T | None = None
16 | success: bool
17 |
18 | @classmethod
19 | def ok(cls, data: T | None = None, code: int = 200, message: str = 'Success'):
20 | return cls(code=code, message=message, data=data, success=True)
21 |
22 | @classmethod
23 | def failed(cls, code: int = 500, message: str = 'Internal Server Error'):
24 | return cls(code=code, message=message, success=False)
25 |
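26 | # Illustrative usage (a sketch, not part of the original module):
27 | #     Result.ok(data={'repo_id': 'abc'})  # code=200, message='Success', success=True
28 | #     Result.failed(message='boom')       # code=500, data=None, success=False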
--------------------------------------------------------------------------------
/loader/segmenter/language/python.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from tree_sitter import Language
4 |
5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory
6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter
7 |
8 |
9 | @TreeSitterFactory.register(ext_set={'.py'})
10 | class PythonSegmenter(TreeSitterSegmenter):
11 | """Code segmenter for Python."""
12 |
13 | def get_language(self) -> Language:
14 | import tree_sitter_python
15 | return Language(tree_sitter_python.language())
16 |
17 | def get_node_types(self) -> List[str]:
18 | return ['function_definition', 'decorated_definition']
19 |
20 | def get_recursion_node_types(self) -> List[str]:
21 | return ['class_definition', 'block']
22 |
--------------------------------------------------------------------------------
/loader/segmenter/language/css.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from tree_sitter import Language
4 |
5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory
6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter
7 |
8 |
9 | @TreeSitterFactory.register(ext_set={'.css'})
10 | class CssSegmenter(TreeSitterSegmenter):
11 | """Code segmenter for Css."""
12 |
13 | def get_language(self) -> Language:
14 | import tree_sitter_css
15 | return Language(tree_sitter_css.language())
16 |
17 | def get_node_types(self) -> List[str]:
18 | return [
19 | 'rule_set',
20 | 'keyframes_statement',
21 | 'media_statement'
22 | ]
23 |
24 | def get_recursion_node_types(self) -> List[str]:
25 | return []
26 |
--------------------------------------------------------------------------------
/loader/segmenter/language/go.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from tree_sitter import Language
4 |
5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory
6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter
7 |
8 |
9 | @TreeSitterFactory.register(ext_set={'.go'})
10 | class GoSegmenter(TreeSitterSegmenter):
11 | """Code segmenter for Go."""
12 |
13 | def get_language(self) -> Language:
14 | import tree_sitter_go
15 | return Language(tree_sitter_go.language())
16 |
17 | def get_node_types(self) -> List[str]:
18 | return [
19 | 'method_declaration',
20 | 'function_declaration',
21 | 'type_declaration'
22 | ]
23 |
24 | def get_recursion_node_types(self) -> List[str]:
25 | return []
26 |
--------------------------------------------------------------------------------
/resource/dev.yaml:
--------------------------------------------------------------------------------
1 | we0-index:
2 | application: we0-index
3 | server:
4 | host: 0.0.0.0
5 | port: 8080
6 | reload: True
7 | log:
8 | level: INFO
9 | file: false
10 | debug: false
11 | vector:
12 | platform: pgvector
13 | code2desc: false
14 | chat-provider: openai
15 | chat-model: gpt-4o-mini
16 | embedding-provider: jina
17 | embedding-model: jina-embeddings-v2-base-code
18 | pgvector:
19 | db: we0_index
20 | host: localhost
21 | port: 5432
22 | user: root
23 | password: password
24 | qdrant:
25 | mode: disk
26 | disk:
27 | path: vector/qdrant
28 | remote:
29 | host: localhost
30 | port: 6333
31 | chroma:
32 | mode: disk
33 | disk:
34 | path: vector/chroma
35 | remote:
36 | host: localhost
37 | port: 8000
38 | ssl: false
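39 | # Note: `platform` selects the vector backend; the valid values mirror
40 | # domain/enums/vector_type.py: pgvector | qdrant | chroma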
--------------------------------------------------------------------------------
/loader/segmenter/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/12
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : __init__.py
7 | # @Software: PyCharm
8 | from loader.segmenter.language.css import CssSegmenter
9 | from loader.segmenter.language.go import GoSegmenter
10 | from loader.segmenter.language.java import JavaSegmenter
11 | from loader.segmenter.language.javascript import JavaScriptSegmenter
12 | from loader.segmenter.language.python import PythonSegmenter
13 | from loader.segmenter.language.typescript import TypeScriptSegmenter
14 | from loader.segmenter.language.typescriptxml import TypeScriptXmlSegmenter
15 |
16 | __all__ = [
17 | 'CssSegmenter',
18 | 'GoSegmenter',
19 | 'JavaSegmenter',
20 | 'JavaScriptSegmenter',
21 | 'PythonSegmenter',
22 | 'TypeScriptSegmenter',
23 | 'TypeScriptXmlSegmenter'
24 | ]
25 |
--------------------------------------------------------------------------------
/utils/mimetype_util.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/12
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : mimetype_util
7 | # @Software: PyCharm
8 | import mimetypes
9 | import os
10 |
11 |
12 | def guess_mimetype_and_extension(file_path: str) -> tuple[str | None, str | None]:
13 | """
14 | 根据文件路径(通常是文件名或 URL 中的路径片段)来猜测 MIME 类型和扩展名。
15 | 优先使用文件后缀名进行判断,如果没有后缀名则使用默认的 application/octet-stream。
16 | """
17 | _, extension = os.path.splitext(file_path)
18 | if extension:
19 | # 通过后缀名在 mimetypes 中查找 MIME 类型
20 | mimetype = mimetypes.types_map.get(extension.lower())
21 | else:
22 | # 如果没有扩展名,使用默认的二进制流类型
23 | mimetype = 'application/octet-stream'
24 | # 根据 MIME 类型再猜一次扩展名(一般会给出 .bin)
25 | extension = mimetypes.guess_extension(mimetype)
26 |
27 | return mimetype, extension
28 |
--------------------------------------------------------------------------------
/loader/segmenter/language/java.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from tree_sitter import Language
4 |
5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory
6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter
7 |
8 |
9 | @TreeSitterFactory.register(ext_set={'.java'})
10 | class JavaSegmenter(TreeSitterSegmenter):
11 | """Code segmenter for Java."""
12 |
13 | def get_language(self) -> Language:
14 | import tree_sitter_java
15 | return Language(tree_sitter_java.language())
16 |
17 | def get_node_types(self) -> List[str]:
18 | return [
19 | 'method_declaration',
20 | 'enum_declaration'
21 | ]
22 |
23 | def get_recursion_node_types(self) -> List[str]:
24 | return [
25 | 'class_declaration',
26 | 'class_body',
27 | 'interface_declaration',
28 | 'interface_body'
29 | ]
30 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # ----------------------------
2 | # builder
3 | # ----------------------------
4 | FROM python:3.12.8-slim AS builder
5 |
6 | ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy UV_PYTHON_DOWNLOADS=0
7 |
8 | WORKDIR /app
9 |
10 | RUN pip install --no-cache-dir uv
11 |
12 | COPY pyproject.toml /app
13 |
14 | RUN uv lock && uv sync --frozen --no-dev
15 | # ----------------------------
16 | # runtime
17 | # ----------------------------
18 | FROM python:3.12.8-slim
19 |
20 | RUN apt-get update && \
21 | apt-get install --no-install-recommends -y git ca-certificates && \
22 | rm -rf /var/lib/apt/lists/*
23 | ENV PORT=8080
24 | ENV PATH="/app/.venv/bin:$PATH"
25 |
26 | EXPOSE $PORT
27 | VOLUME /app/resource
28 | VOLUME /app/log
29 | VOLUME /app/vector
30 | VOLUME /app/storage
31 |
32 | WORKDIR /app
33 |
34 | COPY . /app
35 |
36 | COPY --from=builder /usr/local/bin /usr/local/bin
37 | COPY --from=builder /app/.venv /app/.venv
38 |
39 | CMD uv run python /app/main.py
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "we0-index"
3 | version = "0.1.0"
4 | description = "We0 Index"
5 | readme = "README.md"
6 | requires-python = ">=3.12"
7 | dependencies = [
8 | "aiofiles>=24.1.0",
9 | "chromadb>=0.6.3",
10 | "fastapi>=0.115.7",
11 | "gitpython>=3.1.42",
12 | "greenlet>=3.2.2",
13 | "loguru>=0.7.3",
14 | "mcp[cli]>=1.9.2",
15 | "openai",
16 | "psycopg[binary,pool]>=3.2.4",
17 | "pydantic-settings>=2.7.1",
18 | "python-multipart>=0.0.20",
19 | "pyyaml>=6.0.2",
20 | "qdrant-client>=1.13.2",
21 | "sqlalchemy>=2.0.41",
22 | "tiktoken>=0.8.0",
23 | "tree-sitter>=0.24.0",
24 | "tree-sitter-css>=0.23.2",
25 | "tree-sitter-go>=0.23.4",
26 | "tree-sitter-java>=0.23.5",
27 | "tree-sitter-javascript>=0.23.1",
28 | "tree-sitter-python>=0.23.6",
29 | "tree-sitter-typescript>=0.23.2",
30 | "uvicorn>=0.34.0",
31 | ]
32 |
33 | [[tool.uv.index]]
34 | url = "https://mirrors.aliyun.com/pypi/simple"
35 | default = true
36 |
--------------------------------------------------------------------------------
/loader/segmenter/language/typescriptxml.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from tree_sitter import Language
4 |
5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory
6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter
7 |
8 |
9 | @TreeSitterFactory.register(ext_set={'.tsx'})
10 | class TypeScriptXmlSegmenter(TreeSitterSegmenter):
11 | """Code segmenter for TypeScriptXml."""
12 |
13 | def get_language(self) -> Language:
14 | import tree_sitter_typescript
15 | return Language(tree_sitter_typescript.language_tsx())
16 |
17 | def get_node_types(self) -> List[str]:
18 | return [
19 | 'lexical_declaration',
20 | 'interface_declaration',
21 | 'method_definition',
22 | 'function_declaration',
23 | 'export_statement'
24 | ]
25 |
26 | def get_recursion_node_types(self) -> List[str]:
27 | return [
28 | 'class_declaration',
29 | 'class_body'
30 | ]
31 |
--------------------------------------------------------------------------------
/loader/segmenter/language/typescript.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from tree_sitter import Language
4 |
5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory
6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter
7 |
8 |
9 | @TreeSitterFactory.register(ext_set={'.ts'})
10 | class TypeScriptSegmenter(TreeSitterSegmenter):
11 | """Code segmenter for TypeScript."""
12 |
13 | def get_language(self) -> Language:
14 | import tree_sitter_typescript
15 | return Language(tree_sitter_typescript.language_typescript())
16 |
17 | def get_node_types(self) -> List[str]:
18 | return [
19 | 'lexical_declaration',
20 | 'interface_declaration',
21 | 'method_definition',
22 | 'function_declaration',
23 | 'export_statement'
24 | ]
25 |
26 | def get_recursion_node_types(self) -> List[str]:
27 | return [
28 | 'class_declaration',
29 | 'class_body',
30 | 'abstract_class_declaration'
31 | ]
32 |
--------------------------------------------------------------------------------
/loader/segmenter/language/javascript.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from tree_sitter import Language
4 |
5 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory
6 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter
7 |
8 |
9 | @TreeSitterFactory.register(ext_set={'.js', '.mjs'})
10 | class JavaScriptSegmenter(TreeSitterSegmenter):
11 | """Code segmenter for JavaScript."""
12 |
13 | def get_language(self) -> Language:
14 | import tree_sitter_javascript
15 | return Language(tree_sitter_javascript.language())
16 |
17 | def get_node_types(self) -> List[str]:
18 | return [
19 | 'lexical_declaration',
20 | 'interface_declaration',
21 | 'export_statement',
22 | 'method_definition',
23 | 'function_declaration',
24 | 'function_expression',
25 | 'generator_function',
26 | 'generator_function_declaration'
27 | ]
28 |
29 | def get_recursion_node_types(self) -> List[str]:
30 | return [
31 | 'class_declaration',
32 | 'class_body'
33 | ]
34 |
--------------------------------------------------------------------------------
/domain/model/repository.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/20
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : repository
7 | # @Software: PyCharm
8 | import uuid
9 | from datetime import datetime
10 | from sqlalchemy import String, DateTime, func
11 | from sqlalchemy.dialects import postgresql
12 | from sqlalchemy.orm import Mapped, mapped_column
13 |
14 | from domain.model.base import Base
15 |
16 |
17 | class Repository(Base):
18 | __tablename__ = 'repository_info'
19 | id: Mapped[uuid.UUID] = mapped_column(
20 | postgresql.UUID(as_uuid=True),
21 | primary_key=True,
22 | )
23 | embedding_model: Mapped[str] = mapped_column(String, nullable=True)
24 | embedding_model_provider: Mapped[str] = mapped_column(String, nullable=True)
25 | created_by: Mapped[str] = mapped_column(String, nullable=False)
26 | created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
27 | updated_by: Mapped[str] = mapped_column(String, nullable=False)
28 | updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
29 |
--------------------------------------------------------------------------------
/constants/constants.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/7/17
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : constants
7 | # @Software: PyCharm
8 | import os
9 |
10 | from dotenv import find_dotenv, load_dotenv
11 |
12 |
13 | class Constants:
14 | class Common:
15 | PROJECT_NAME: str = 'we0-index'
16 |
17 | class Path:
18 | # SYSTEM PATH
19 | ROOT_PATH: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
20 | LOG_PATH: str = os.path.join(ROOT_PATH, 'log')
21 | RESOURCE_PATH: str = os.path.join(ROOT_PATH, 'resource')
22 | ENV_FILE_PATH: str = os.path.join(ROOT_PATH, '.env')
23 | TEMP_PATH: str = '/tmp'
24 | QDRANT_DEFAULT_DISK_PATH: str = os.path.join(ROOT_PATH, 'vector', 'qdrant')
25 | CHROMA_DEFAULT_DISK_PATH: str = os.path.join(ROOT_PATH, 'vector', 'chroma')
26 |
27 | # We0 CONFIG
28 | load_dotenv(ENV_FILE_PATH)
29 | YAML_FILE_PATH: str = find_dotenv(
30 | filename=os.path.join(
31 | RESOURCE_PATH,
32 | f"{os.environ.get('WE0_INDEX_ENV', 'dev')}.yaml"
33 | )
34 | )
35 |
--------------------------------------------------------------------------------
/clients/jina/client.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/19
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : jina
7 | # @Software: PyCharm
8 | import os
9 | from typing import Optional
10 |
11 | import httpx
12 |
13 |
14 | class AsyncClient(httpx.AsyncClient):
15 |
16 | def __init__(
17 | self,
18 | base_url: Optional[str] = None,
19 | api_key: Optional[str] = None,
20 | timeout: Optional[int] = 300,
21 | *args, **kwargs
22 | ):
23 | if api_key is None:
24 | api_key = os.environ.get("JINA_API_KEY")
25 | if api_key is None:
26 | raise ValueError(
27 | "The api_key client option must be set either by passing api_key to the client or by setting the JINA_API_KEY environment variable"
28 | )
29 | self.api_key = api_key
30 |
31 | if base_url is None:
32 | base_url = os.environ.get("JINA_BASE_URL")
33 | if base_url is None:
34 | base_url = "https://api.jina.ai/v1"
35 |
36 | super().__init__(base_url=base_url, timeout=timeout, *args, **kwargs)
37 |
38 | from .embeddings import AsyncEmbeddings
39 | self.embeddings = AsyncEmbeddings(self)
40 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/10/10
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : main
7 | # @Software: PyCharm
8 |
9 | import click
10 | import uvicorn
11 |
12 | from constants.constants import Constants
13 | from launch.we0_index_mcp import we0_index_mcp
14 | from setting.setting import get_we0_index_settings
15 |
16 | @click.command()
17 | @click.option('--mode', default='mcp', show_default=True, type=click.Choice(['mcp', 'fastapi']), required=True, help='Choose run mode: "mcp" or "fastapi".')
18 | @click.option('--transport', default='streamable-http', show_default=True, type=click.Choice(['streamable-http', 'stdio', 'sse']), help='Transport protocol for MCP mode')
19 | def main(mode, transport):
20 | sider_settings = get_we0_index_settings()
21 |
22 | if mode == 'mcp':
23 | we0_index_mcp.run(transport)
24 | elif mode == 'fastapi':
25 | uvicorn.run(
26 | 'launch.launch:app',
27 | host=sider_settings.server.host,
28 | port=sider_settings.server.port,
29 | reload=sider_settings.server.reload,
30 | env_file=Constants.Path.ENV_FILE_PATH
31 | )
32 | else:
33 | raise ValueError(f"Unknown mode: {mode}")
34 |
35 | if __name__ == '__main__':
36 | main()
37 |
--------------------------------------------------------------------------------
/loader/repo_loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/10/11
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : repo_loader
7 | # @Software: PyCharm
8 | """Abstract interface for document loader implementations."""
9 |
10 | from typing import AsyncIterator, Type
11 |
12 | from domain.entity.blob import Blob
13 | from domain.entity.code_segment import CodeSegment
14 | from loader.segmenter.base_line_segmenter import LineBasedSegmenter
15 | from loader.segmenter.base_segmenter import BaseSegmenter
16 | from loader.segmenter.tree_sitter_factory import TreeSitterFactory
17 |
18 |
19 | class RepoLoader:
20 |
21 | @classmethod
22 | def get_segmenter_constructor(cls, extension: str | None = None) -> Type[BaseSegmenter]:
23 | if extension in TreeSitterFactory.get_ext_set():
24 | return TreeSitterFactory.get_segmenter(extension)
25 | else:
26 | return LineBasedSegmenter
27 |
28 | @classmethod
29 | async def load_blob(cls, blob: Blob) -> AsyncIterator[CodeSegment]:
30 | text = await blob.as_string()
34 |
35 | segmenter = cls.get_segmenter_constructor(extension=blob.extension).from_tiktoken_encoder(text=text, merge_small_chunks=True)
36 | if not segmenter.is_valid():
37 | segmenter = cls.get_segmenter_constructor().from_tiktoken_encoder(text=text)
38 |
39 | for code in segmenter.segment():
40 | yield code
41 |
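42 | # Illustrative usage (a sketch; assumes an existing Blob instance):
43 | #     async for segment in RepoLoader.load_blob(blob):
44 | #         print(segment.start, segment.end, segment.block)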
--------------------------------------------------------------------------------
/launch/we0_index_mcp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/6/2
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : we0_index_mcp
7 | # @Software: PyCharm
8 | from contextlib import asynccontextmanager
9 | from typing import AsyncIterator
10 |
11 | from mcp.server.fastmcp import FastMCP
12 | from mcp.server.fastmcp.tools import Tool
13 | from mcp.server.lowlevel.server import LifespanResultT, Server
14 | from mcp.shared.context import RequestT
15 |
16 | from extensions import ext_manager
17 | from router.git_router import clone_and_index
18 | from router.vector_router import retrieval
19 | from setting.setting import get_we0_index_settings
20 |
21 | sider_settings = get_we0_index_settings()
22 |
23 |
24 | @asynccontextmanager
25 | async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
26 | await ext_manager.init_vector()
27 | yield {}
29 |
30 |
31 | def create_fast_mcp() -> FastMCP:
32 | app = FastMCP(
33 | name="We0 Index",
34 | description="CodeIndex, embedding, retrieval, Tool parameters must be in standard JSON format",
35 | tools=[
36 | Tool.from_function(clone_and_index),
37 | Tool.from_function(retrieval),
38 | ],
39 | lifespan=lifespan,
40 | host=sider_settings.server.host,
41 | port=sider_settings.server.port,
42 | log_level=sider_settings.log.level
43 | )
44 | return app
45 |
46 |
47 | we0_index_mcp = create_fast_mcp()
--------------------------------------------------------------------------------
/loader/segmenter/tree_sitter_factory.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/12
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : tree_sitter_factory
7 | # @Software: PyCharm
8 | from functools import lru_cache
9 | from typing import Dict, Type, Set
10 |
11 | from loader.segmenter.tree_sitter_segmenter import TreeSitterSegmenter
12 |
13 |
14 | class TreeSitterFactory:
15 | __segmenter: Dict[str, Type[TreeSitterSegmenter]] = {}
16 |
17 | @classmethod
18 | def get_segmenter(cls, ext: str) -> Type[TreeSitterSegmenter]:
19 | if ext not in cls.__segmenter:
20 | raise ValueError(f'ext type {ext} is not supported')
21 | return cls.__segmenter[ext]
22 |
23 | @classmethod
24 | def __register(cls, ext: str, _cls: Type[TreeSitterSegmenter]):
25 | cls.__segmenter[ext] = _cls
26 |
27 | @classmethod
28 | def __has_cls(cls, ext: str) -> bool:
29 | return ext in cls.__segmenter
30 |
31 | @classmethod
32 | def register(
33 | cls,
34 | ext_set: Set[str]
35 | ):
36 | if not ext_set:
37 | raise ValueError('Must provide support extension set')
38 |
39 | def decorator(origin_cls):
40 | for ext in ext_set:
41 | if not cls.__has_cls(ext):
42 | TreeSitterFactory.__register(ext, origin_cls)
43 | return origin_cls
44 |
45 | return decorator
46 |
47 | @classmethod
48 | @lru_cache
49 | def get_ext_set(cls) -> Set[str]:
50 | return set(cls.__segmenter.keys())
51 |
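52 | # Illustrative usage (a sketch; segmenters self-register via the decorator,
53 | # as in loader/segmenter/language/python.py):
54 | #     @TreeSitterFactory.register(ext_set={'.py'})
55 | #     class PythonSegmenter(TreeSitterSegmenter): ...
56 | #     TreeSitterFactory.get_segmenter('.py')   # -> PythonSegmenter
57 | #     TreeSitterFactory.get_segmenter('.foo')  # -> ValueError: ext type .foo is not supported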
--------------------------------------------------------------------------------
/prompt/prompt.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/20
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : prompt
7 | # @Software: PyCharm
8 | class SystemPrompt:
9 | ANALYZE_CODE_PROMPT = """
10 | # Task Instructions
11 | 1. I will provide a code block wrapped in ```
12 | 2. Analyze the code with these steps:
13 | - Identify natural segments separated by empty lines, comment blocks, or logical sections
14 | - Generate technical descriptions for each segment
15 | 3. Output requirements:
16 | - Use numbered Markdown lists (1. 2. 3.)
17 | - Maximum 2 lines per item
18 | - Prioritize functional explanations, then implementation details
19 | - Preserve key technical terms/algorithms
20 | - Keep identical terminology with source code
21 |
22 | # Output Example
23 | 1. Initializes Spring Boot application: Uses @SpringBootApplication to configure bootstrap class, sets base package for component scanning
24 | 2. Implements RESTful endpoint: Creates /user API through @RestController, defines base path with @RequestMapping
25 | 3. Handles file uploads: Leverages S3 SDK to transfer local file_infos to cloud storage
26 |
27 | # Now analyze this code:
28 | """
29 |
30 |
31 | class SystemMessageTemplate:
32 | ANALYZE_CODE_MESSAGE_TEMPLATE = lambda code_text: [
33 | {
34 | "role": "system",
35 | "content": SystemPrompt.ANALYZE_CODE_PROMPT
36 | },
37 | {
38 | "role": "user",
39 | "content": code_text
40 | }
41 | ]
42 |
--------------------------------------------------------------------------------
/domain/entity/document.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/14
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : document
7 | # @Software: PyCharm
8 | from typing import List, Optional
9 |
10 | from pydantic import BaseModel, ConfigDict, Field
11 |
12 |
13 | class DocumentMeta(BaseModel):
14 | repo_id: Optional[str] = Field(description='Repository ID')
15 | file_id: Optional[str] = Field(description='File ID')
16 | segment_id: str = Field(description='Code segment ID (uuid4)')
17 | relative_path: str = Field(description='Relative path of the file the segment belongs to')
18 | start_line: int = Field(description='Segment start line')
19 | end_line: int = Field(description='Segment end line')
20 | segment_block: int = Field(description='Segment block index')
21 | segment_hash: str = Field(description='Code segment hash')
22 | segment_cl100k_base_token: Optional[int] = Field(default=None, description='Segment cl100k_base token count')
23 | segment_o200k_base_token: Optional[int] = Field(default=None, description='Segment o200k_base token count')
24 | description: Optional[str] = Field(default=None, description='Optional code description, used for description embedding')
25 |
26 | score: Optional[float] = Field(default=None, description='Similarity score, set only for similarity search')
27 | content: Optional[str] = Field(default=None, description='Plain-text code block; kept for qdrant compatibility, since qdrant can only store other fields in the payload')
28 |
29 | model_config = ConfigDict(
30 | extra='ignore'
31 | )
32 |
33 |
34 | class Document(BaseModel):
35 | vector: List[float] = Field(description='Embedding vector', default_factory=list)
36 | content: Optional[str] = Field(default=None, description='Plain-text code')
37 | meta: Optional[DocumentMeta] = Field(default=None, description='Code metadata')
38 |
39 | model_config = ConfigDict(
40 | extra='ignore'
41 | )
42 |
--------------------------------------------------------------------------------
/clients/jina/embeddings.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/19
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : embeddings
7 | # @Software: PyCharm
8 | from typing import Union, List, Iterable, Optional
9 |
10 | from openai.types import CreateEmbeddingResponse
11 |
12 | from .client import AsyncClient
13 |
14 |
15 | class AsyncEmbeddings:
16 | def __init__(self, client: AsyncClient):
17 | self.client = client
18 |
19 | async def create(
20 | self,
21 | input: Union[str, List[str], Iterable[int], Iterable[Iterable[int]]],
22 | model: str = 'jina-embeddings-v2-base-code',
23 | normalized: Optional[bool] = None,
24 | embedding_type: Optional[str] = None,
25 | task: Optional[str] = None,
26 | late_chunking: Optional[bool] = None,
27 | dimensions: Optional[int] = None,
28 | ) -> CreateEmbeddingResponse:
29 | request_json = {
30 | 'input': input,
31 | 'model': model
32 | }
33 | if normalized is not None:
34 | request_json['normalized'] = normalized
35 | if embedding_type:
36 | request_json['embedding_type'] = embedding_type
37 | if task:
38 | request_json['task'] = task
39 | if late_chunking is not None:
40 | request_json['late_chunking'] = late_chunking
41 | if dimensions:
42 | request_json['dimensions'] = dimensions
43 | response = await self.client.post(
44 | '/embeddings', json=request_json,
45 | headers={'Authorization': f'Bearer {self.client.api_key}'}
46 | )
47 | return CreateEmbeddingResponse.model_validate(response.json())
48 |
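49 | # Illustrative usage (a sketch; assumes JINA_API_KEY is set in the environment):
50 | #     async with AsyncClient() as client:
51 | #         resp = await client.embeddings.create(input=['def foo(): ...'])
52 | #         vector = resp.data[0].embedding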
--------------------------------------------------------------------------------
/utils/git_parse.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/4/27
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : git_parse
7 | # @Software: PyCharm
8 |
9 |
10 | import re
11 |
12 |
13 | def parse_git_url(git_url):
14 | # Supported platform domains
15 | platforms = [
16 | 'github.com',
17 | 'gitlab.com',
18 | 'gitee.com',
19 | 'bitbucket.org',
20 | 'codeberg.org'
21 | ]
22 |
23 | # Supported URL patterns
24 | patterns = [
25 | # SSH format: git@github.com:owner/repo
26 | r'^git@([^:]+):([^/]+)/([^/.]+)(?:\.git)?$',
27 |
28 | # HTTP/HTTPS format: http(s)://github.com/owner/repo
29 | r'^https?://([^/]+)/([^/]+)/([^/.]+)(?:\.git)?$'
30 | ]
31 |
32 | # Strip surrounding whitespace, if any
33 | url = git_url.strip() if git_url else ''
34 |
35 | # Try each pattern in turn
36 | for pattern in patterns:
37 | match = re.match(pattern, url)
38 | if match:
39 | domain, owner, repo = match.groups()
40 | if domain.lower() in platforms:
41 | return domain, owner, repo
42 |
43 | return None, None, None
44 |
45 |
46 | if __name__ == '__main__':
47 |
48 | # Test cases
49 | test_urls = [
50 | # SSH
51 | 'git@github.com:we0-dev/we0',
52 | 'git@github.com:we0-dev/we0.git',
53 |
54 | # HTTP
55 | 'http://github.com/we0-dev/we0',
56 | 'http://github.com/we0-dev/we0.git',
57 |
58 | # HTTPS
59 | 'https://github.com/we0-dev/we0',
60 | 'https://github.com/we0-dev/we0.git',
61 |
62 | # Other platforms
63 | 'git@gitlab.com:group/project',
64 | 'http://gitlab.com/group/project',
65 | 'https://gitee.com/username/repo',
66 |
67 | # Invalid URLs
68 | 'we0-dev/we0', # incomplete
69 | 'github.com/we0-dev/we0', # missing scheme
70 | 'https://example.com/we0-dev/we0' # unsupported platform
71 | ]
72 |
73 | for test_url in test_urls:
74 | result = parse_git_url(test_url)
75 | print(f"URL: {test_url}")
76 | print(f"Parsed: {result}\n")
77 |
--------------------------------------------------------------------------------
/utils/helper.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/18
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : helper
7 | # @Software: PyCharm
8 |
9 | import uuid
10 | from hashlib import sha256
11 | from typing import Optional, Literal, Union, Collection
12 |
13 | import tiktoken
14 |
15 |
16 | class Helper:
17 | # Cache dict for generated encoder instances
18 | _encoders_cache = {}
19 |
20 | @staticmethod
21 | def generate_text_hash(text: str) -> str:
22 | """生成文本的SHA-256哈希值"""
23 | return sha256(text.encode()).hexdigest()
24 |
25 | @staticmethod
26 | def calculate_tokens(
27 | text: str,
28 | encoding_name: str = "cl100k_base",
29 | model_name: Optional[str] = None,
30 | allowed_special=None,
31 | disallowed_special: Union[Literal["all"], Collection[str]] = "all",
32 | ) -> int:
33 | """使用tiktoken编码器计算文本的token数量"""
34 | if allowed_special is None:
35 | allowed_special = set()
36 |
37 | # Cache key: model/encoding names plus the special-token sets (frozen so the key stays hashable)
38 | cache_key = (model_name, encoding_name, frozenset(allowed_special), disallowed_special if isinstance(disallowed_special, str) else frozenset(disallowed_special))
39 |
40 | # Return directly if a matching encoder is already cached
41 | if cache_key in Helper._encoders_cache:
42 | return Helper._encoders_cache[cache_key](text)
43 |
44 | # Pick the encoder by model name if given, otherwise by encoding name
45 | if model_name:
46 | enc = tiktoken.encoding_for_model(model_name)
47 | else:
48 | enc = tiktoken.get_encoding(encoding_name)
49 |
50 | # Build a new encoder function
51 | def _tiktoken_encoder(_text: str) -> int:
52 | """Count the tokens in the given text."""
53 | return len(
54 | enc.encode(
55 | _text,
56 | allowed_special=allowed_special,
57 | disallowed_special=disallowed_special,
58 | )
59 | )
60 |
61 | # Cache the encoder and return the count
62 | Helper._encoders_cache[cache_key] = _tiktoken_encoder
63 | return _tiktoken_encoder(text)
64 |
65 | @staticmethod
66 | def generate_fixed_uuid(unique_str: str) -> str:
67 | namespace = uuid.NAMESPACE_URL
68 | return str(uuid.uuid5(namespace, unique_str))
69 |
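70 | # Illustrative usage (a sketch, not part of the original module):
71 | #     Helper.generate_text_hash('hello')                    # sha256 hex digest
72 | #     Helper.calculate_tokens('def foo(): pass')            # token count under cl100k_base
73 | #     Helper.generate_fixed_uuid('github.com/we0-dev/we0')  # same input -> same uuid5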
--------------------------------------------------------------------------------
/launch/launch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/7/29
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : launch
7 | # @Software: PyCharm
8 | import asyncio
9 | from contextlib import asynccontextmanager
10 |
11 | from fastapi import FastAPI
12 | from fastapi.encoders import jsonable_encoder
13 | from loguru import logger
14 | from starlette.middleware.cors import CORSMiddleware
15 | from starlette.requests import Request
16 | from starlette.responses import JSONResponse
17 |
18 | from config.loguru import Log
19 | from domain.result.result import Result
20 | from exception.exception import CommonException
21 | from extensions import ext_manager
22 | from router.git_router import git_router
23 | from router.vector_router import vector_router
24 | from setting.setting import get_we0_index_settings
25 |
26 | settings = get_we0_index_settings()
27 |
28 |
29 | async def initialize_extensions():
30 | await asyncio.gather(
31 | ext_manager.init_vector(),
32 | )
33 |
34 |
35 | async def close_extensions():
36 | ...
37 |
38 |
39 | @asynccontextmanager
40 | async def lifespan(fastapi: FastAPI):
41 | try:
42 | Log.start()
43 | await initialize_extensions()
44 | yield
45 | finally:
46 | await close_extensions()
47 | Log.close()
48 |
49 |
50 | def create_app() -> FastAPI:
51 | app = FastAPI(
52 | title="We0 Index",
53 | description="We0 Index API",
54 | version="0.1.0",
55 | lifespan=lifespan
56 | )
57 |
58 | # Add CORS middleware
59 | app.add_middleware(
60 | CORSMiddleware,
61 | allow_origins=["*"],
62 | allow_credentials=True,
63 | allow_methods=["*"],
64 | allow_headers=["*"],
65 | )
66 |
67 | return app
68 |
69 |
70 | app = create_app()
71 | # Register routers
72 | app.include_router(vector_router, prefix="/vector", tags=["vector"])
73 | app.include_router(git_router, prefix="/git", tags=["git"])
74 |
75 |
76 | @app.exception_handler(CommonException)
77 | async def common_exception_handler(request: Request, exc: CommonException):
78 | error = Result.failed(code=-1, message=exc.message)
79 | logger.exception(f"Url: {request.url}, Exception: CommonException, Error: {error}")
80 | return JSONResponse(content=jsonable_encoder(error))
81 |
82 |
83 | @app.exception_handler(Exception)
84 | async def exception_handler(request: Request, exc: Exception):
85 | error = Result.failed(code=-1, message=str(exc.args))
86 | logger.exception(f"Url: {request.url}, {type(exc).__name__}: {exc} , Error: {error}")
87 | return JSONResponse(content=jsonable_encoder(error))
88 |
97 |
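
A minimal smoke test of the wiring above (a sketch, not part of the project; assumes valid dev settings, an initializable vector backend, and `httpx` installed for Starlette's `TestClient`):

```python
from fastapi.testclient import TestClient

from launch.launch import app

# Entering the context manager runs the lifespan, so extensions get initialized.
with TestClient(app) as client:
    # Unknown paths fall through to FastAPI's default 404; exceptions raised
    # inside the routers come back as Result payloads via the handlers above.
    response = client.get("/vector/does-not-exist")
    print(response.status_code)
```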
--------------------------------------------------------------------------------
/extensions/vector/base_vector.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/14
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : base_vector
7 | # @Software: PyCharm
8 | from __future__ import annotations
9 |
10 | from abc import ABC, abstractmethod
11 | from typing import List, Optional
12 |
13 | from domain.entity.document import Document, DocumentMeta
14 | from models.model_factory import ModelFactory
15 | from setting.setting import get_we0_index_settings
16 |
17 | settings = get_we0_index_settings()
18 |
19 |
20 | class BaseVector(ABC):
21 |
22 | @staticmethod
23 | @abstractmethod
24 | def get_client():
25 | raise NotImplementedError
26 |
27 | @abstractmethod
28 | async def init(self):
29 | raise NotImplementedError
30 |
31 | @abstractmethod
32 | async def create(self, documents: List[Document]):
33 | raise NotImplementedError
34 |
35 | @abstractmethod
36 | async def upsert(self, documents: List[Document]):
37 | raise NotImplementedError
38 |
39 | @abstractmethod
40 | async def all_meta(self, repo_id: str) -> List[DocumentMeta]:
41 | raise NotImplementedError
42 |
43 | @abstractmethod
44 | async def drop(self, repo_id: str):
45 | raise NotImplementedError
46 |
47 | @abstractmethod
48 | async def delete(self, repo_id: str, file_ids: List[str]):
49 | raise NotImplementedError
50 |
51 | @abstractmethod
52 | async def search_by_vector(
53 | self,
54 | repo_id: str,
55 | file_ids: Optional[List[str]],
56 | query_vector: List[float],
57 | top_k: int = 5,
58 | score_threshold: float = 0.0
59 | ) -> List[Document]:
60 | raise NotImplementedError
61 |
62 | @staticmethod
63 | def dynamic_collection_name(dimension: int) -> str:
64 | return f'we0_index_{settings.vector.embedding_model}_{dimension}'.replace('-', '_')
65 |
66 |     # TODO: in the future, read each user's `model_provider` and `model_name` from the repository table
67 |     # For now, use openai's `text-embedding-3-small` across the board
68 | @classmethod
69 | async def get_embedding_model(cls):
70 | return await ModelFactory.get_model(
71 | model_provider=settings.vector.embedding_provider,
72 | model_name=settings.vector.embedding_model
73 | )
74 |
75 | @classmethod
76 | async def get_completions_model(cls):
77 | return await ModelFactory.get_model(
78 | model_provider=settings.vector.chat_provider,
79 | model_name=settings.vector.chat_model
80 | )
81 |
82 | @classmethod
83 | async def get_dimension(cls) -> int:
84 | embedding_model = await cls.get_embedding_model()
85 | vector_data_list = await embedding_model.create_embedding(['get_embedding_dimension'])
86 | dimension = len(vector_data_list[0])
87 | return dimension
88 |
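
For instance, with the default `text-embedding-3-small` model and its 1536-dimensional vectors, the collection name resolves as below (illustrative snippet; assumes the sample settings load):

```python
from extensions.vector.base_vector import BaseVector

# Hyphens are replaced with underscores so the name is valid across backends;
# prints 'we0_index_text_embedding_3_small_1536' with the default settings.
print(BaseVector.dynamic_collection_name(1536))
```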
--------------------------------------------------------------------------------
/loader/segmenter/base_segmenter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/12
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : base_segmenter
7 | # @Software: PyCharm
8 | from abc import ABC, abstractmethod
9 | from typing import Any, Callable, Collection, Union, Literal, Optional, Iterator
10 |
11 | from domain.entity.code_segment import CodeSegment
12 |
13 |
14 | class BaseSegmenter(ABC):
15 | def __init__(
16 | self,
17 | max_tokens: int = 512,
18 | length_function: Callable[[str], int] = len,
19 | merge_small_chunks: bool = False,
20 | ):
21 | self.max_tokens = max_tokens
22 | self.length_function = length_function
23 | self.merge_small_chunks = merge_small_chunks
24 |
25 | def is_valid(self) -> bool:
26 | return True
27 |
28 | @abstractmethod
29 | def segment(self) -> Iterator[CodeSegment]:
30 | raise NotImplementedError
31 |
32 | @classmethod
33 | def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any):
34 | """Text splitter that uses HuggingFace tokenizer to count length."""
35 | try:
36 | from transformers import PreTrainedTokenizerBase
37 |
38 | if not isinstance(tokenizer, PreTrainedTokenizerBase):
39 | raise ValueError(
40 | "Tokenizer received was not an instance of PreTrainedTokenizerBase"
41 | )
42 |
43 | def _huggingface_tokenizer_length(text: str) -> int:
44 | return len(tokenizer.encode(text))
45 |
46 | except ImportError:
47 | raise ValueError(
48 | "Could not import transformers python package. "
49 | "Please install it with `pip install transformers`."
50 | )
51 | return cls(length_function=_huggingface_tokenizer_length, **kwargs)
52 |
53 | @classmethod
54 | def from_tiktoken_encoder(
55 | cls,
56 | encoding_name: str = "cl100k_base",
57 | model_name: Optional[str] = None,
58 | allowed_special=None,
59 | disallowed_special: Union[Literal["all"], Collection[str]] = "all",
60 | **kwargs: Any,
61 | ):
62 | """Text splitter that uses tiktoken encoder to count length."""
63 | if allowed_special is None:
64 | allowed_special = set()
65 | import tiktoken
66 | if model_name:
67 | enc = tiktoken.encoding_for_model(model_name)
68 | else:
69 | enc = tiktoken.get_encoding(encoding_name)
70 |
71 | def _tiktoken_encoder(text: str) -> int:
72 | return len(
73 | enc.encode(
74 | text,
75 | allowed_special=allowed_special,
76 | disallowed_special=disallowed_special,
77 | )
78 | )
79 |
80 | return cls(length_function=_tiktoken_encoder, **kwargs)
81 |
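
`BaseSegmenter` is abstract, so the factory classmethods are exercised through a concrete subclass. A sketch with a hypothetical `LineSegmenter` (the `CodeSegment` field names are inferred from its usage in `utils/vector_helper.py`):

```python
from typing import Iterator

from domain.entity.code_segment import CodeSegment
from loader.segmenter.base_segmenter import BaseSegmenter


class LineSegmenter(BaseSegmenter):
    """Hypothetical subclass that yields one segment per line."""

    def __init__(self, text: str = "", **kwargs):
        super().__init__(**kwargs)
        self.text = text

    def segment(self) -> Iterator[CodeSegment]:
        for number, line in enumerate(self.text.splitlines(), start=1):
            yield CodeSegment(code=line, start=number, end=number, block=line)


# length_function now counts tokens rather than characters.
segmenter = LineSegmenter.from_tiktoken_encoder(encoding_name="cl100k_base", max_tokens=512)
print(segmenter.length_function("def main(): ..."))
```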
--------------------------------------------------------------------------------
/config/loguru.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/07/17
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : loguru
7 | # @Software: PyCharm
8 | import inspect
9 | import logging
10 | import os
11 | import sys
12 | from typing import List, Dict
13 |
14 | from loguru import logger
15 |
16 | from constants.constants import Constants
17 | from setting.setting import get_we0_index_settings
18 | from utils.path_util import PathUtil
19 |
20 | sider_settings = get_we0_index_settings()
21 |
22 |
23 | class InterceptHandler(logging.Handler):
24 | def emit(self, record: logging.LogRecord) -> None:
25 | level: str | int
26 | try:
27 | level = logger.level(record.levelname).name
28 | except ValueError:
29 | level = record.levelno
30 | frame, depth = inspect.currentframe(), 0
31 | while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__):
32 | frame = frame.f_back
33 | depth += 1
34 |
35 | logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
36 |
37 |
38 | class Log:
39 | file_name = Constants.Common.PROJECT_NAME + '_{time}' + '.log'
40 | log_file_path = os.path.join(Constants.Path.LOG_PATH, file_name)
41 | logger_level = sider_settings.log.level
42 | logger_file = sider_settings.log.file
43 | DEFAULT_CONFIG = [
44 | {
45 | 'sink': sys.stdout,
46 | 'level': logger_level,
47 | 'format': '[{time:YYYY-MM-DD HH:mm:ss.SSS}][{level}]'
48 | '[{file}:{line}]: {message}',
49 |             'colorize': True,  # custom color scheme
50 |             'serialize': False,  # print records as serialized JSON
51 |             'backtrace': True,  # show the full exception stack trace
52 |             'diagnose': True,  # include variable values in exception traces; set to False in production
53 |             'enqueue': False,  # thread-safe by default; set to True for coroutine or multiprocess safety
54 |             'catch': True,  # catch sink errors
55 |
56 | }
57 | ]
58 | if logger_file:
59 | DEFAULT_CONFIG.append({
60 | 'sink': log_file_path,
61 | 'level': logger_level,
62 | 'format': '[{time:YYYY-MM-DD HH:mm:ss.SSS}][{level}][{file}:{line}]: {message}',
63 |             'retention': '7 days',  # how long to keep log files
64 |             'serialize': False,  # print records as serialized JSON
65 |             'backtrace': True,  # show the full exception stack trace
66 |             'diagnose': True,  # include variable values in exception traces; set to False in production
67 |             'enqueue': False,  # thread-safe by default; set to True for coroutine or multiprocess safety
68 |             'catch': True,  # catch sink errors
69 | })
70 |
71 | @staticmethod
72 | def start(config: List[Dict] | None = None) -> None:
73 | PathUtil.check_or_make_dirs(
74 | Constants.Path.LOG_PATH
75 | )
76 | if config:
77 | logger.configure(handlers=config)
78 | else:
79 | logger.configure(handlers=Log.DEFAULT_CONFIG)
80 | if sider_settings.log.debug:
81 | logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True)
82 | logger.enable('__main__')
83 |
84 | @staticmethod
85 | def close() -> None:
86 | logger.disable('__main__')
87 |
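
`Log.start` also accepts an override list of loguru handler dicts; a minimal sketch that logs WARNING and above to stderr only (assuming no file sink is wanted):

```python
import sys

from config.loguru import Log

Log.start(config=[
    {
        'sink': sys.stderr,
        'level': 'WARNING',
        'format': '[{time:YYYY-MM-DD HH:mm:ss.SSS}][{level}][{file}:{line}]: {message}',
    }
])
```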
--------------------------------------------------------------------------------
/extensions/vector/ext_vector.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/14
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : ext_vector
7 | # @Software: PyCharm
8 | from typing import List, Optional
9 |
10 | from loguru import logger
11 |
12 | from domain.enums.vector_type import VectorType
13 | from domain.entity.document import Document, DocumentMeta
14 | from extensions.vector.base_vector import BaseVector
15 | from setting.setting import We0IndexSettings, get_we0_index_settings
16 |
17 | settings: We0IndexSettings = get_we0_index_settings()
18 |
19 |
20 | class Vector:
21 |
22 | def __init__(self):
23 | self.vector_runner: BaseVector | None = None
24 |
25 | async def create(self, documents: List[Document]):
26 | try:
27 | await self.vector_runner.create(documents)
28 | except Exception as e:
29 | logger.error(e)
30 | raise e
31 |
32 | async def upsert(self, documents: List[Document]):
33 | try:
34 | await self.vector_runner.upsert(documents)
35 | except Exception as e:
36 | logger.error(e)
37 | raise e
38 |
39 | async def all_meta(self, repo_id: str) -> List[DocumentMeta]:
40 | try:
41 | return await self.vector_runner.all_meta(repo_id)
42 | except Exception as e:
43 | logger.error(e)
44 | raise e
45 |
46 | async def drop(self, repo_id: str):
47 | try:
48 | await self.vector_runner.drop(repo_id)
49 | except Exception as e:
50 | logger.error(e)
51 | raise e
52 |
53 | async def delete(self, repo_id: str, file_ids: List[str]):
54 | try:
55 | await self.vector_runner.delete(repo_id, file_ids)
56 | except Exception as e:
57 | logger.error(e)
58 | raise e
59 |
60 | async def search_by_vector(
61 | self, repo_id: str, file_ids: Optional[List[str]], query_vector: List[float], top_k: int = 5
62 | ) -> List[Document]:
63 | try:
64 | return await self.vector_runner.search_by_vector(repo_id, file_ids, query_vector, top_k)
65 | except Exception as e:
66 | logger.error(e)
67 | raise e
68 |
69 | async def init_app(self) -> None:
70 | vector_constructor = self.get_vector_factory(settings.vector.platform)
71 | self.vector_runner = vector_constructor()
72 | await self.vector_runner.init()
73 |
74 | @staticmethod
75 | def get_vector_factory(vector_type: VectorType) -> type[BaseVector]:
76 | match vector_type:
77 | case VectorType.PGVECTOR:
78 | from extensions.vector.pgvector import PgVector
79 |
80 | return PgVector
81 | case VectorType.QDRANT:
82 | from extensions.vector.qdrant import Qdrant
83 |
84 | return Qdrant
85 | case VectorType.CHROMA:
86 | from extensions.vector.chroma import Chroma
87 |
88 | return Chroma
89 | case _:
90 | raise ValueError(f"Unknown storage type: {vector_type}")
91 |
92 | def __getattr__(self, item):
93 | if self.vector_runner is None:
94 |             raise RuntimeError("Vector client is not initialized. Call init_app first.")
95 | return getattr(self.vector_runner, item)
96 |
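
Usage sketch for the facade: backend selection is deferred to settings, so callers only ever touch `Vector` (assumes a backend configured in `resource/dev.yaml`; the repo id below is hypothetical):

```python
import asyncio

from extensions.vector.ext_vector import Vector


async def main():
    vector = Vector()
    await vector.init_app()  # resolves pgvector/qdrant/chroma from settings
    metas = await vector.all_meta(repo_id="example-repo-id")
    print(len(metas))


asyncio.run(main())
```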
--------------------------------------------------------------------------------
/README-zh.md:
--------------------------------------------------------------------------------
1 | # We0-index
2 |
3 | A Python-based, Cursor-style code indexing engine that turns Git repository code into code snippets with semantic embeddings, enabling intelligent code search and retrieval from user queries.
4 | 
5 | ## 🚀 Features
6 | 
7 | - **Code Snippet Generation**: Automatically processes Git repositories and converts code into searchable snippets
8 | - **Semantic Search**: Generates semantic embeddings for intelligent code retrieval based on user queries
9 | - **Multi-Language Support**: Optimized for Python, Java, Go, JavaScript, and TypeScript
10 | - **Flexible Backends**: Supports multiple vector database backends and embedding service providers
11 | - **MCP Integration**: Built-in support for MCP (Model Context Protocol) service calls
12 | - **Deployment Ready**: Flexible deployment options for different environments
13 | 
14 | ## 📋 Requirements
15 | 
16 | - Python 3.12+
17 | - Git
18 | 
19 | ## 🛠️ Installation
20 | 
21 | ### Quick Start
22 | 
23 | ```bash
24 | # Clone the repository
25 | git clone https://github.com/we0-dev/we0-index
26 | cd we0-index
27 | 
28 | # Set up environment configuration
29 | cp .env.example .env
30 | vim .env
31 | 
32 | # Configure application settings
33 | vim resource/dev.yaml
34 | 
35 | # Create a virtual environment and install dependencies
36 | uv venv
37 | # Linux/macOS
38 | source .venv/bin/activate
39 | # Windows
40 | .venv\Scripts\activate
41 | uv sync
42 | ```
43 | 
44 | ### Development Setup
45 | 
46 | ```bash
47 | # Install development dependencies
48 | uv sync --frozen
49 | ```
50 | 
51 | ## ⚙️ Configuration
52 | 
53 | 1. **Environment Variables**: Copy `.env.example` to `.env` and configure your settings
54 | 2. **Application Config**: Edit `resource/dev.yaml` to customize your deployment
55 | 3. **Vector Database**: Configure your preferred vector database backend
56 | 4. **Embedding Service**: Set up your embedding service provider
57 | 
58 | ## 🚀 Running the Service
59 | 
60 | We0-index supports two running modes: a Web API service and an MCP protocol service.
61 | 
62 | ### Web API Mode
63 | 
64 | Start the FastAPI web server to provide RESTful API endpoints:
65 | 
66 | ```bash
67 | # Activate the virtual environment
68 | # Linux/macOS
69 | source .venv/bin/activate
70 | # Windows
71 | .venv\Scripts\activate
72 | 
73 | # Start the web service
74 | python main.py --mode fastapi
75 | ```
76 | 
77 | The web service starts on the configured host and port (see `resource/dev.yaml` for the default configuration).
78 | 
79 | ### MCP Protocol Mode
80 | 
81 | Start the MCP (Model Context Protocol) service for AI integration:
82 | 
83 | ```bash
84 | # Activate the virtual environment
85 | # Linux/macOS
86 | source .venv/bin/activate
87 | # Windows
88 | .venv\Scripts\activate
89 | 
90 | # Start the MCP service (streamable-http transport by default)
91 | python main.py --mode mcp
92 | 
93 | # Specify another transport protocol
94 | python main.py --mode mcp --transport stdio
95 | python main.py --mode mcp --transport sse
96 | ```
97 | 
98 | The MCP service runs with streamable-http transport by default and can be integrated with MCP-compatible AI clients.
99 | 
100 | ### Runtime Parameters
101 | 
102 | **Mode Parameters**:
103 | - `--mode fastapi`: Start the Web API service
104 | - `--mode mcp`: Start the MCP protocol service
105 | 
106 | **Transport Parameters** (MCP mode only):
107 | - `--transport streamable-http`: Use HTTP streaming transport (default)
108 | - `--transport stdio`: Use standard input/output transport
109 | - `--transport sse`: Use SSE transport
110 | 
111 | 
112 | ## 🏗️ Architecture
113 | 
114 | We0-index is built on a modular architecture supporting:
115 | 
116 | - **Code Parsers**: Language-specific code parsing and snippet extraction
117 | - **Embedding Engines**: Multiple embedding service integrations
118 | - **Vector Stores**: Pluggable vector database backends
119 | - **Search Interface**: RESTful API and CLI for code search
120 | - **MCP Protocol**: Model Context Protocol for AI integration
121 | 
122 | ## 🤝 Contributing
123 | 
124 | We welcome contributions! See our [Contributing Guidelines](CONTRIBUTING.md) for details.
125 | 
126 | 1. Fork the repository
127 | 2. Create your feature branch (`git checkout -b feature/amazing-feature`)
128 | 3. Commit your changes (`git commit -m 'Add some amazing feature'`)
129 | 4. Push to the branch (`git push origin feature/amazing-feature`)
130 | 5. Open a Pull Request
131 | 
132 | ## 📝 License
133 | 
134 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
135 | 
136 | 
137 | ## 📚 Documentation
138 | 
139 | For detailed documentation, visit our [Documentation Site](https://docs.we0-dev.com) or check the `docs/` directory.
140 | 
141 | ## 🐛 Issues
142 | 
143 | If you run into any problems, please [create an issue](https://github.com/we0-dev/we0-index/issues) on GitHub.
144 | 
145 | ## 📞 Support
146 | 
147 | - 📧 Email: we0@wegc.cn
148 | - 💬 Discussions: [GitHub Discussions](https://github.com/we0-dev/we0-index/discussions)
149 | - 📖 Wiki: [Project Wiki](https://deepwiki.com/we0-dev/we0-index)
150 | 
151 | ## 🌟 Acknowledgments
152 | 
153 | - Thanks to all contributors who have helped make this project better
154 | - Inspired by Cursor's approach to code intelligence
155 | - Built with modern Python tooling and best practices
156 | 
157 | ---
158 | 
159 | **Made with ❤️ by the We0-dev Team**
160 |
--------------------------------------------------------------------------------
/utils/vector_helper.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/23
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : vector_helper
7 | # @Software: PyCharm
8 | import asyncio
9 | import uuid
10 | from typing import List
11 |
13 | from loguru import logger
14 |
15 | from domain.entity.document import Document, DocumentMeta
16 | from domain.entity.task_context import TaskContext
17 | from extensions.ext_manager import ExtManager
18 | from loader.repo_loader import RepoLoader
19 | from models.model_factory import ModelInstance
20 | from prompt.prompt import SystemMessageTemplate
21 | from setting.setting import get_we0_index_settings
22 | from utils.helper import Helper
23 |
26 | settings = get_we0_index_settings()
27 |
28 |
29 | class VectorHelper:
30 | @staticmethod
31 | async def code2description(document: Document, chat_model: ModelInstance):
32 | document.meta.description = await chat_model.create_completions(
33 | messages=SystemMessageTemplate.ANALYZE_CODE_MESSAGE_TEMPLATE(document.content)
34 | )
35 |
36 | @staticmethod
37 | async def build_and_embedding_segment(task_context: TaskContext) -> List[Document]:
38 | try:
39 | documents: List[Document] = [
40 | Document(
41 | content=segment.code,
42 | meta=DocumentMeta(
43 | repo_id=task_context.repo_id,
44 | file_id=task_context.file_id,
45 | segment_id=f"{uuid.uuid4()}",
46 | relative_path=task_context.relative_path,
47 | start_line=segment.start,
48 | end_line=segment.end,
49 | segment_block=segment.block,
50 | segment_hash=Helper.generate_text_hash(segment.code),
51 | segment_cl100k_base_token=Helper.calculate_tokens(segment.code),
52 | segment_o200k_base_token=Helper.calculate_tokens(segment.code, 'o200k_base')
53 | )
54 | ) async for segment in RepoLoader.load_blob(task_context.blob)
55 | ]
56 | except UnicodeDecodeError as e:
57 | logger.error(e)
58 | documents = []
59 | if documents:
60 | embedding_model: ModelInstance = await ExtManager.vector.get_embedding_model()
61 | if settings.vector.code2desc:
62 | chat_model: ModelInstance = await ExtManager.vector.get_completions_model()
63 | await asyncio.wait([
64 | asyncio.create_task(VectorHelper.code2description(document=document, chat_model=chat_model))
65 | for document in documents
66 | ])
67 | vector_data: List[List[float]] = await embedding_model.create_embedding(
68 | [
69 | f"'{document.meta.relative_path}'\n'{document.meta.description}'\n{document.content}"
70 | for document in documents
71 | ]
72 | )
73 | else:
74 | vector_data: List[List[float]] = await embedding_model.create_embedding(
75 | [
76 | f"'{document.meta.relative_path}'\n{document.content}" for document in documents
77 | ]
78 | )
79 | for index, document in enumerate(documents):
80 | document.vector = vector_data[index]
81 | return documents
82 |
--------------------------------------------------------------------------------
/models/model_factory.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/15
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : model_factory
7 | # @Software: PyCharm
8 | import asyncio
9 | from typing import List, Tuple, Dict, Iterable
10 |
11 | from openai.types import CreateEmbeddingResponse
12 | from openai.types.chat import ChatCompletionMessageParam, ChatCompletion
13 |
14 | from clients import jina
15 | from domain.enums.model_provider import ModelType
16 |
17 |
18 | class ModelInstance:
19 |
20 | def __init__(self, model_type: ModelType, model_name: str):
21 | self.model_type = model_type
22 | self.model_name = model_name
23 |
24 | def get_completions_client(self):
25 | match self.model_type:
26 | case ModelType.OPENAI:
27 | import openai
28 | return openai.AsyncClient().chat.completions
29 | case _:
30 | raise Exception(f"Unknown model type: {self.model_type}")
31 |
32 | def get_embedding_client(self):
33 | match self.model_type:
34 | case ModelType.OPENAI:
35 | import openai
36 | return openai.AsyncClient().embeddings
37 | case ModelType.JINA:
38 | return jina.AsyncClient().embeddings
39 | case _:
40 | raise Exception(f"Unknown model type: {self.model_type}")
41 |
42 | async def create_embedding(self, documents: List[str]) -> List[List[float]]:
43 | match self.model_type:
44 | case ModelType.OPENAI | ModelType.JINA:
45 | docs_seq = list(documents)
46 | if len(docs_seq) > 2048:
47 | all_embeddings: List[List[float]] = []
48 | for start in range(0, len(docs_seq), 2048):
49 | batch = docs_seq[start: start + 2048]
50 | resp = await self.get_embedding_client().create(
51 | input=batch, model=self.model_name
52 | )
53 | all_embeddings.extend(d.embedding for d in resp.data)
54 | return all_embeddings
55 | else:
56 | resp = await self.get_embedding_client().create(
57 | input=docs_seq, model=self.model_name
58 | )
59 | return [d.embedding for d in resp.data]
60 | case _:
61 | raise Exception(f"Unknown model type: {self.model_type}")
62 |
63 | async def create_completions(self, messages: Iterable[ChatCompletionMessageParam]) -> str:
64 | match self.model_type:
65 | case ModelType.OPENAI:
66 | completions_response: ChatCompletion = await self.get_completions_client().create(
67 | model=self.model_name, messages=messages
68 | )
69 | return completions_response.choices[0].message.content
70 | case _:
71 | raise Exception(f"Unknown model type: {self.model_type}")
72 |
73 |
74 | class ModelFactory:
75 | _lock = asyncio.Lock()
76 | _instances: Dict[Tuple[ModelType, str], ModelInstance] = {}
77 |
78 | @classmethod
79 | async def get_model(cls, model_provider: ModelType, model_name: str) -> ModelInstance:
80 | key = (model_provider, model_name)
81 | if key not in ModelFactory._instances:
82 | async with cls._lock:
83 | if key not in ModelFactory._instances:
84 | instance = ModelInstance(model_provider, model_name)
85 | cls._instances[key] = instance
86 | return cls._instances[key]
87 |
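
`create_embedding` transparently splits inputs into batches of 2048 (the per-request input limit the code assumes for OpenAI-compatible APIs), so callers may pass arbitrarily long lists. A usage sketch (assumes `OPENAI_API_KEY` and `OPENAI_BASE_URL` are set as in `.env.example`):

```python
import asyncio

from domain.enums.model_provider import ModelType
from models.model_factory import ModelFactory


async def main():
    model = await ModelFactory.get_model(ModelType.OPENAI, "text-embedding-3-small")
    vectors = await model.create_embedding(["hello world", "def add(a, b): return a + b"])
    print(len(vectors), len(vectors[0]))  # 2 texts -> 2 vectors of the model's dimension


asyncio.run(main())
```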
--------------------------------------------------------------------------------
/domain/entity/blob.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/10/14
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : blob
7 | # @Software: PyCharm
8 | import asyncio
9 | from contextlib import asynccontextmanager, contextmanager
10 | from io import BytesIO
11 | from pathlib import PurePath
12 | from typing import Dict, Any, Generator, BinaryIO
13 |
14 | import aiofiles
15 | from pydantic import Field, model_validator, ConfigDict, BaseModel
16 |
17 |
18 | class Blob(BaseModel):
19 | id: str | None = None
20 | filename: str | None = None
21 |
22 |     meta: Dict[str, Any] = Field(default_factory=dict)
23 |
24 | data: bytes | str | None = None
25 | mimetype: str | None = None
26 | extension: str | None = None
27 | encoding: str = "utf-8"
28 | path: str | PurePath | None = None
29 |
30 | model_config = ConfigDict(
31 | arbitrary_types_allowed=True,
32 | frozen=True,
33 | )
34 |
35 |     @model_validator(mode="before")
36 |     @classmethod
37 | def check_blob_is_valid(cls, values: dict[str, Any]) -> Any:
38 | if "data" not in values and "path" not in values:
39 | msg = "Either data or path must be provided"
40 | raise ValueError(msg)
41 | return values
42 |
43 | async def as_string(self) -> str:
44 | """Read data as a string."""
45 | if self.data is None and self.path:
46 | async with aiofiles.open(str(self.path), mode='r', encoding=self.encoding) as f:
47 | return await f.read()
48 | elif isinstance(self.data, bytes):
49 | return self.data.decode(self.encoding)
50 | elif isinstance(self.data, str):
51 | return self.data
52 | else:
53 | msg = f"Unable to get string for blob {self}"
54 | raise ValueError(msg)
55 |
56 | def as_bytes(self) -> bytes:
57 | """Read data as bytes."""
58 | if isinstance(self.data, bytes):
59 | return self.data
60 | elif isinstance(self.data, str):
61 | return self.data.encode(self.encoding)
62 | elif self.data is None and self.path:
63 | with open(str(self.path), "rb") as f:
64 | return f.read()
65 | else:
66 | msg = f"Unable to get bytes for blob {self}"
67 | raise ValueError(msg)
68 |
69 | @contextmanager
70 | def as_bytes_io(self) -> Generator[BytesIO | BinaryIO, None, None]:
71 | """Read data as a byte stream."""
72 | if isinstance(self.data, bytes):
73 | yield BytesIO(self.data)
74 | elif self.data is None and self.path:
75 | with open(str(self.path), "rb") as f:
76 | yield f
77 | else:
78 | msg = f"Unable to convert blob {self}"
79 | raise NotImplementedError(msg)
80 |
81 | @asynccontextmanager
82 | async def as_async_bytes_io(self):
83 | if isinstance(self.data, bytes):
84 | reader = asyncio.StreamReader()
85 | reader.feed_data(self.data)
86 | reader.feed_eof()
87 | yield reader
88 | elif self.data is None and self.path:
89 | async with aiofiles.open(str(self.path), 'rb') as f:
90 | yield f
91 | else:
92 | msg = f"Unable to convert blob {self}"
93 | raise NotImplementedError(msg)
94 |
95 | async def write_to_file(self, file, chunks: int = 8192):
96 | async with self.as_async_bytes_io() as reader:
97 | if isinstance(reader, asyncio.StreamReader):
98 | while chunk := await reader.read(chunks):
99 | await file.write(chunk)
100 | else:
101 | await file.write(await reader.read())
102 |
103 | @classmethod
104 | def from_path(
105 | cls,
106 | path: str | PurePath,
107 | *,
108 | encoding: str = "utf-8",
109 | mimetype: str | None = None,
110 | extension: str | None = None,
111 | meta: Dict[str, Any] | None = None,
112 | ):
113 | return cls(
114 | data=None,
115 | mimetype=mimetype,
116 | extension=extension,
117 | encoding=encoding,
118 | path=path,
119 | meta=meta if meta is not None else {}
120 | )
121 |
122 | @classmethod
123 | def from_data(
124 | cls,
125 | data: str | bytes,
126 | *,
127 | encoding: str = "utf-8",
128 | mimetype: str | None = None,
129 | extension: str | None = None,
130 | path: str | None = None,
131 | meta: dict | None = None,
132 | ):
133 | return cls(
134 | data=data,
135 | mimetype=mimetype,
136 | extension=extension,
137 | encoding=encoding,
138 | path=path,
139 | meta=meta if meta is not None else {}
140 | )
141 |
142 | def __repr__(self) -> str:
143 | str_repr = f"Blob {id(self)}"
144 | return str_repr
145 |
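
A sketch of the two constructors: `from_data` keeps the payload in memory, while `from_path` defers reading until one of the accessors runs (the file path below is hypothetical):

```python
import asyncio

from domain.entity.blob import Blob

in_memory = Blob.from_data(data="print('hi')", mimetype="text/x-python", extension=".py")
print(in_memory.as_bytes())  # b"print('hi')"

on_disk = Blob.from_path("example.py")  # hypothetical path, read lazily


async def main():
    print(await on_disk.as_string())


asyncio.run(main())
```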
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 | #poetry.toml
110 |
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | #pdm.lock
114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115 | # in version control.
116 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
117 | .pdm.toml
118 | .pdm-python
119 | .pdm-build/
120 |
121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
122 | __pypackages__/
123 |
124 | # Celery stuff
125 | celerybeat-schedule
126 | celerybeat.pid
127 |
128 | # SageMath parsed files
129 | *.sage.py
130 |
131 | # Environments
132 | .env
133 | .venv
134 | env/
135 | venv/
136 | ENV/
137 | env.bak/
138 | venv.bak/
139 |
140 | # Spyder project settings
141 | .spyderproject
142 | .spyproject
143 |
144 | # Rope project settings
145 | .ropeproject
146 |
147 | # mkdocs documentation
148 | /site
149 |
150 | # mypy
151 | .mypy_cache/
152 | .dmypy.json
153 | dmypy.json
154 |
155 | # Pyre type checker
156 | .pyre/
157 |
158 | # pytype static type analyzer
159 | .pytype/
160 |
161 | # Cython debug symbols
162 | cython_debug/
163 |
164 | # PyCharm
165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
167 | # and can be added to the global gitignore or merged into this file. For a more nuclear
168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
169 | #.idea/
170 |
171 | # Abstra
172 | # Abstra is an AI-powered process automation framework.
173 | # Ignore directories containing user credentials, local state, and settings.
174 | # Learn more at https://abstra.io/docs
175 | .abstra/
176 |
177 | # Visual Studio Code
178 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
179 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
180 | # and can be added to the global gitignore or merged into this file. However, if you prefer,
181 | # you could uncomment the following to ignore the entire vscode folder
182 | # .vscode/
183 |
184 | # Ruff stuff:
185 | .ruff_cache/
186 |
187 | # PyPI configuration file
188 | .pypirc
189 |
190 | # Cursor
191 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
192 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
193 | # refer to https://docs.cursor.com/context/ignore-files
194 | .cursorignore
195 | .cursorindexingignore
196 |
--------------------------------------------------------------------------------
/setting/setting.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/07/17
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : setting
7 | # @Software: PyCharm
8 | import os.path
9 | from functools import lru_cache
10 | from typing import Type
11 |
12 | from pydantic import BaseModel, Field, model_validator
13 | from pydantic_settings import BaseSettings, SettingsConfigDict, PydanticBaseSettingsSource, YamlConfigSettingsSource
14 |
15 | from constants.constants import Constants
16 | from domain.enums.chroma_mode import ChromaMode
17 | from domain.enums.model_provider import ModelType
18 | from domain.enums.qdrant_mode import QdrantMode
19 | from domain.enums.vector_type import VectorType
20 |
21 |
22 | class ServerSettings(BaseModel):
23 | host: str = Field('0.0.0.0')
24 | port: int = Field(8080)
25 | reload: bool = Field(True)
26 |
27 |
28 | class LogSettings(BaseModel):
29 | level: str = Field(default="INFO")
30 | file: bool = Field(default=False)
31 | debug: bool = Field(default=False)
32 |
33 |
34 | class PGVectorSettings(BaseSettings):
35 | db: str
36 | host: str
37 | port: int
38 | user: str
39 | password: str
40 |
41 |
42 | class QdrantDiskSettings(BaseSettings):
43 | path: str = Field(default=Constants.Path.QDRANT_DEFAULT_DISK_PATH)
44 |
45 | @model_validator(mode='before')
46 | def handle_path(self):
47 | if not os.path.isabs(self['path']):
48 | self['path'] = os.path.join(Constants.Path.ROOT_PATH, self['path'])
49 | return self
50 |
51 |
52 | class QdrantRemoteSettings(BaseSettings):
53 | host: str
54 | port: int = Field(default=6333)
55 |
56 |
57 | class QdrantSettings(BaseSettings):
58 | mode: QdrantMode
59 | disk: QdrantDiskSettings | None
60 | remote: QdrantRemoteSettings | None
61 | memory: None
62 |
63 | @model_validator(mode='before')
64 | def clear_conflicting_settings(self):
65 | for key in [member for member in QdrantMode if member != self['mode']]:
66 | self[key] = None
67 | return self
68 |
69 |
70 | class ChromaDiskSettings(BaseSettings):
71 | path: str = Field(default=Constants.Path.QDRANT_DEFAULT_DISK_PATH)
72 |
73 | @model_validator(mode='before')
74 | def handle_path(self):
75 | if not os.path.isabs(self['path']):
76 | self['path'] = os.path.join(Constants.Path.ROOT_PATH, self['path'])
77 | return self
78 |
79 |
80 | class ChromaRemoteSettings(BaseSettings):
81 | host: str
82 |     port: int = Field(default=8000)  # Chroma's default HTTP port
83 | ssl: bool = Field(default=False)
84 |
85 |
86 | class ChromaSettings(BaseSettings):
87 | mode: ChromaMode
88 | disk: ChromaDiskSettings | None
89 | remote: ChromaRemoteSettings | None
90 | memory: None
91 |
92 | @model_validator(mode='before')
93 | def clear_conflicting_settings(self):
94 | for key in [member for member in ChromaMode if member != self['mode']]:
95 | self[key] = None
96 | return self
97 |
98 |
99 | class VectorSettings(BaseSettings):
100 | platform: VectorType
101 | code2desc: bool = Field(default=False)
102 | chat_provider: ModelType = Field(default='openai', alias='chat-provider')
103 | chat_model: str = Field(default='gpt-4o-mini', alias='chat-model')
104 | embedding_provider: ModelType = Field(default='openai', alias='embedding-provider')
105 | embedding_model: str = Field(default='text-embedding-3-small', alias='embedding-model')
106 | pgvector: PGVectorSettings | None
107 | qdrant: QdrantSettings | None
108 | chroma: ChromaSettings | None
109 |
110 | @model_validator(mode='before')
111 | def clear_conflicting_settings(self):
112 | for key in [member for member in VectorType if member != self['platform']]:
113 | self[key] = None
114 | return self
115 |
116 |
117 | class We0IndexSettings(BaseModel):
118 | application: str
119 | log: LogSettings
120 | server: ServerSettings
121 | vector: VectorSettings
122 |
123 |
124 | class AppSettings(BaseSettings):
125 | we0_index: We0IndexSettings | None = Field(default=None, alias='we0-index')
126 |
127 | model_config = SettingsConfigDict(
128 | yaml_file=Constants.Path.YAML_FILE_PATH,
129 | yaml_file_encoding='utf-8',
130 | extra='ignore'
131 | )
132 |
133 | @classmethod
134 | def settings_customise_sources(
135 | cls,
136 | settings_cls: Type[BaseSettings],
137 | init_settings: PydanticBaseSettingsSource,
138 | env_settings: PydanticBaseSettingsSource,
139 | dotenv_settings: PydanticBaseSettingsSource,
140 | file_secret_settings: PydanticBaseSettingsSource,
141 | ) -> tuple[PydanticBaseSettingsSource, ...]:
142 | return (
143 | init_settings,
144 | YamlConfigSettingsSource(settings_cls),
145 | env_settings,
146 | dotenv_settings,
147 | file_secret_settings,
148 | )
149 |
150 |
151 | @lru_cache
152 | def get_we0_index_settings() -> We0IndexSettings:
153 | app_settings = AppSettings()
154 | return app_settings.we0_index
155 |
156 |
157 | if __name__ == '__main__':
158 | settings = get_we0_index_settings()
159 | print(settings.model_dump_json(indent=4))
160 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # We0-index
2 |
3 | [English](README.md) | [中文](README-zh.md)
4 |
5 | A Python-based code indexing engine similar to Cursor, designed to transform Git repository code into code snippets with semantic embeddings for intelligent code search and retrieval.
6 |
7 | ## 🚀 Features
8 |
9 | - **Code Snippet Generation**: Automatically processes Git repositories and converts code into searchable snippets
10 | - **Semantic Search**: Generates semantic embeddings for intelligent code retrieval based on user queries
11 | - **Multi-Language Support**: Optimized for Python, Java, Go, JavaScript, and TypeScript
12 | - **Flexible Backends**: Supports multiple vector database backends and embedding service providers
13 | - **MCP Integration**: Built-in support for MCP (Model Context Protocol) service calls
14 | - **Deployment Ready**: Flexible deployment options for different environments
15 |
16 | ## 📋 Requirements
17 |
18 | - Python 3.12+
19 | - Git
20 |
21 | ## 🛠️ Installation
22 |
23 | ### Quick Start
24 |
25 | ```bash
26 | # Clone the repository
27 | git clone https://github.com/we0-dev/we0-index
28 | cd we0-index
29 |
30 | # Set up environment configuration
31 | cp .env.example .env
32 | vim .env
33 |
34 | # Configure application settings
35 | vim resource/dev.yaml
36 |
37 | # Create virtual environment and install dependencies
38 | uv venv
39 | # Linux/macOS
40 | source .venv/bin/activate
41 | # Windows
42 | .venv\Scripts\activate
43 | uv sync
44 | ```
45 |
46 | ### Development Setup
47 |
48 | ```bash
49 | # Install development dependencies
50 | uv sync --frozen
51 | ```
52 |
53 | ## ⚙️ Configuration
54 |
55 | 1. **Environment Variables**: Copy `.env.example` to `.env` and configure your settings
56 | 2. **Application Config**: Edit `resource/dev.yaml` to customize your deployment (a configuration sketch follows below)
57 | 3. **Vector Database**: Configure your preferred vector database backend
58 | 4. **Embedding Service**: Set up your embedding service provider
59 |
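The exact schema lives in `setting/setting.py`; the sketch below shows the expected shape of `resource/dev.yaml` with illustrative values (only the section matching `vector.platform` is kept; the settings validators clear the others):

```yaml
we0-index:
  application: we0-index
  log:
    level: INFO
    file: false
    debug: false
  server:
    host: 0.0.0.0
    port: 8080
    reload: true
  vector:
    platform: chroma              # pgvector | qdrant | chroma
    code2desc: false
    embedding-provider: openai
    embedding-model: text-embedding-3-small
    chat-provider: openai
    chat-model: gpt-4o-mini
    chroma:
      mode: disk                  # memory | disk | remote
      disk:
        path: storage/chroma      # hypothetical relative path
```
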
60 | ## 🚀 Running the Service
61 |
62 | We0-index supports two running modes: Web API service and MCP protocol service.
63 |
64 | ### Web API Mode
65 |
66 | Start the FastAPI web server to provide RESTful API endpoints:
67 |
68 | ```bash
69 | # Activate virtual environment
70 | # Linux/macOS
71 | source .venv/bin/activate
72 | # Windows
73 | .venv\Scripts\activate
74 |
75 | # Start web service
76 | python main.py --mode fastapi
77 | ```
78 |
79 | The web service will start on the configured host and port (check `resource/dev.yaml` for default configuration).
80 |
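Once the server is up, the endpoints registered in `router/vector_router.py` and `router/git_router.py` can be exercised directly. A sketch with illustrative payloads, using the default port from `setting/setting.py`:

```bash
# Index a file's content, then run a similarity search against it.
curl -X POST http://localhost:8080/vector/upsert_index \
  -H 'Content-Type: application/json' \
  -d '{"uid": "u1", "repo_abs_path": "/tmp/demo", "file_infos": [{"relative_path": "app.py", "content": "def main():\n    pass\n"}]}'

curl -X POST http://localhost:8080/vector/retrieval \
  -H 'Content-Type: application/json' \
  -d '{"repo_id": "<repo_id from the previous response>", "query": "entry point"}'
```
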
81 | ### MCP Protocol Mode
82 |
83 | Start the MCP (Model Context Protocol) service for AI integration:
84 |
85 | ```bash
86 | # Activate virtual environment
87 | # Linux/macOS
88 | source .venv/bin/activate
89 | # Windows
90 | .venv\Scripts\activate
91 |
92 | # Start MCP service (default with streamable-http transport)
93 | python main.py --mode mcp
94 |
95 | # Specify other transport protocols
96 | python main.py --mode mcp --transport stdio
97 | python main.py --mode mcp --transport sse
98 | ```
99 |
100 | The MCP service runs with streamable-http transport by default and can be integrated with MCP-compatible AI clients.
101 |
102 | ### Runtime Parameters
103 |
104 | **Mode Parameters**:
105 | - `--mode fastapi`: Start Web API service
106 | - `--mode mcp`: Start MCP protocol service
107 |
108 | **Transport Parameters** (only applicable for MCP mode):
109 | - `--transport streamable-http`: Use HTTP streaming transport (default)
110 | - `--transport stdio`: Use standard input/output transport
111 | - `--transport sse`: Use sse transport
112 |
113 |
114 |
115 | ## 🏗️ Architecture
116 |
117 | We0-index is built with a modular architecture supporting:
118 |
119 | - **Code Parsers**: Language-specific code parsing and snippet extraction
120 | - **Embedding Engines**: Multiple embedding service integrations
121 | - **Vector Stores**: Pluggable vector database backends
122 | - **Search Interface**: RESTful API and CLI for code search
123 | - **MCP Protocol**: Model Context Protocol for AI integration
124 |
125 | ## 🤝 Contributing
126 |
127 | We welcome contributions! Please see our [Contributing Guidelines](CONTRIBUTING.md) for details.
128 |
129 | 1. Fork the repository
130 | 2. Create your feature branch (`git checkout -b feature/amazing-feature`)
131 | 3. Commit your changes (`git commit -m 'Add some amazing feature'`)
132 | 4. Push to the branch (`git push origin feature/amazing-feature`)
133 | 5. Open a Pull Request
134 |
135 | ## 📝 License
136 |
137 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
138 |
139 |
140 |
141 | ## 📚 Documentation
142 |
143 | For detailed documentation, please visit our [Documentation Site](https://docs.we0-dev.com) or check the `docs/` directory.
144 |
145 | ## 🐛 Issues
146 |
147 | If you encounter any issues, please [create an issue](https://github.com/we0-dev/we0-index/issues) on GitHub.
148 |
149 | ## 📞 Support
150 |
151 | - 📧 Email: we0@wegc.cn
152 | - 💬 Discussions: [GitHub Discussions](https://github.com/we0-dev/we0-index/discussions)
153 | - 📖 Wiki: [Project Wiki](https://deepwiki.com/we0-dev/we0-index)
154 |
155 | ## 🌟 Acknowledgments
156 |
157 | - Thanks to all contributors who have helped make this project better
158 | - Inspired by Cursor's approach to code intelligence
159 | - Built with modern Python tooling and best practices
160 |
161 | ---
162 |
163 | **Made with ❤️ by the We0-dev Team**
--------------------------------------------------------------------------------
/router/git_router.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | import asyncio
4 | import os
5 | from typing import List, Optional
6 | from urllib.parse import quote
7 |
8 | import aiofiles
9 | from fastapi import APIRouter
10 | from git import Repo
11 | from loguru import logger
12 |
13 | from domain.entity.blob import Blob
14 | from domain.entity.document import Document
15 | from domain.entity.task_context import TaskContext
16 | from domain.request.git_index_request import GitIndexRequest
17 | from domain.response.add_index_response import AddIndexResponse, FileInfoResponse
18 | from domain.result.result import Result
19 | from extensions.ext_manager import ExtManager
20 | from setting.setting import get_we0_index_settings
21 | from utils.git_parse import parse_git_url
22 | from utils.helper import Helper
23 | from utils.mimetype_util import guess_mimetype_and_extension
24 | from utils.vector_helper import VectorHelper
25 | from asyncio import Semaphore
26 | git_router = APIRouter()
27 | settings = get_we0_index_settings()
28 | async_semaphore = Semaphore(100)
29 |
30 | async def _process_file(uid: str, repo_id: str, repo_path: str, base_dir, relative_path: str) -> FileInfoResponse:
31 | logger.info(f'Processing file {relative_path}')
32 | mimetype, extension = guess_mimetype_and_extension(relative_path)
33 | file_id = Helper.generate_fixed_uuid(f"{uid}:{repo_path}:{relative_path}")
34 | file_path = os.path.join(base_dir, relative_path)
35 | async with aiofiles.open(file_path, 'rb') as f:
36 | content = await f.read()
37 |
38 | task_context = TaskContext(
39 | repo_id=repo_id,
40 | file_id=file_id,
41 | relative_path=relative_path,
42 | blob=Blob.from_data(
43 | data=content,
44 | mimetype=mimetype,
45 | extension=extension
46 | )
47 | )
48 |
49 | try:
50 | async with async_semaphore:
51 | documents: List[Document] = await VectorHelper.build_and_embedding_segment(task_context)
52 | if documents:
53 | await ExtManager.vector.upsert(documents)
54 | except Exception as e:
55 | raise e
56 |
57 | return FileInfoResponse(
58 | file_id=file_id,
59 | relative_path=relative_path
60 | )
61 |
62 |
63 | def _prepare_repo_url_with_auth(repo_url: str, username: Optional[str] = None, password: Optional[str] = None,
64 | access_token: Optional[str] = None) -> str:
65 | """
66 | 为私有仓库准备带有认证信息的 URL
67 | 对用户名和密码进行 URL 编码,以处理包含特殊字符(如 @ 符号)的情况
68 | 支持三种认证方式:
69 | 1. access_token: 使用个人访问令牌(推荐)
70 | 2. username + password: 使用用户名和密码
71 | 3. 无认证: 公开仓库
72 | """
73 | if repo_url.startswith('https://'):
74 | if access_token:
75 |             # Authenticate with an access token
76 |             # For most Git providers (GitHub, GitLab, etc.) the token is used as the username; the password may be empty or the token itself
77 |             encoded_token = quote(access_token, safe='')
78 |             url_without_protocol = repo_url[8:]
79 |             # Use the token as the username and 'x-oauth-basic' (GitHub) or an empty string as the password
80 |             return f'https://{encoded_token}:x-oauth-basic@{url_without_protocol}'
81 |         elif username and password:
82 |             # URL-encode the username and password to handle special characters
83 |             encoded_username = quote(username, safe='')
84 |             encoded_password = quote(password, safe='')
85 |             # Use the encoded username and password
86 | url_without_protocol = repo_url[8:]
87 | return f'https://{encoded_username}:{encoded_password}@{url_without_protocol}'
88 | return repo_url
89 |
90 |
91 | async def clone_and_index(git_index_request: GitIndexRequest) -> Result[AddIndexResponse]:
92 | """
93 | Tool parameters must be in standard JSON format!
94 | "git_index_request": {
95 | "xxx": "xxx"
96 | }
97 |     Clone a Git repository and index all of its files.
98 |     Private repository access is supported:
99 |     - Access token auth (recommended): pass the access_token parameter
100 |     - SSH key auth: use a git@github.com:user/repo.git style URL, optionally passing ssh_key_path
101 |     - HTTPS with username/password: pass the username and password parameters
102 | """
103 | try:
104 | if not git_index_request.uid:
105 | git_index_request.uid = 'default_uid'
106 |
107 | domain, owner, repo = parse_git_url(git_index_request.repo_url)
108 | repo_abs_path = f'{domain}/{owner}/{repo}'
109 |         repo_id = Helper.generate_fixed_uuid(f"{git_index_request.uid}:{repo_abs_path}")
110 |
111 |         # Build the authenticated repository URL
112 | auth_repo_url = _prepare_repo_url_with_auth(
113 | git_index_request.repo_url,
114 | git_index_request.username,
115 | git_index_request.password,
116 | git_index_request.access_token
117 | )
118 |
119 |         file_infos: List[FileInfoResponse] = []  # collect the per-file results
120 |
121 | async with aiofiles.tempfile.TemporaryDirectory() as tmp_dir:
122 | try:
123 | await asyncio.to_thread(
124 | Repo.clone_from,
125 | auth_repo_url,
126 | tmp_dir
127 | )
128 | except Exception as e:
129 | logger.error(f'{type(e).__name__}: {e}')
130 | raise e
131 |             # Walk the tree and schedule one indexing task per file
132 |             tasks = []
133 |             for root, dirs, files in os.walk(tmp_dir):
134 |                 # Skip hidden directories such as .git
135 | dirs[:] = [d for d in dirs if not d.startswith('.')]
136 | for file in files:
137 | if not file.startswith('.'):
138 | relative_path = os.path.relpath(os.path.join(root, file), start=tmp_dir)
139 | tasks.append(asyncio.create_task(_process_file(
140 | uid=git_index_request.uid,
141 | repo_id=repo_id,
142 | repo_path=repo_abs_path,
143 | base_dir=tmp_dir,
144 | relative_path=relative_path
145 | )))
146 |             for task in asyncio.as_completed(tasks):
147 |                 file_info = await task
148 |                 file_infos.append(file_info)
149 |
150 |         logger.info(f"Successfully processed {len(file_infos)} files from repository {repo_abs_path}")
151 |
152 | return Result.ok(data=AddIndexResponse(
153 | repo_id=repo_id,
154 |             file_infos=file_infos
155 | ))
156 | except Exception as e:
157 | logger.exception(e)
158 | return Result.failed(message=f"{type(e).__name__}: {e}")
159 |
160 |
161 | git_router.add_api_route('/clone_and_index', clone_and_index, methods=['POST'], response_model=Result[AddIndexResponse])
162 |
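
A worked example for `_prepare_repo_url_with_auth` with hypothetical credentials: a username or password containing `@` is percent-encoded so the resulting URL stays parseable.

```python
from router.git_router import _prepare_repo_url_with_auth

url = _prepare_repo_url_with_auth(
    "https://github.com/we0-dev/we0-index.git",
    username="alice@corp",
    password="p@ss",
)
print(url)  # https://alice%40corp:p%40ss@github.com/we0-dev/we0-index.git
```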
--------------------------------------------------------------------------------
/extensions/vector/chroma.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/23
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : chroma
7 | # @Software: PyCharm
8 | import asyncio
9 | import inspect
10 | from typing import List, Optional
11 |
12 | import chromadb
13 | from chromadb import QueryResult
14 |
15 | from domain.enums.chroma_mode import ChromaMode
16 | from domain.entity.document import Document, DocumentMeta
17 | from extensions.vector.base_vector import BaseVector
18 | from setting.setting import get_we0_index_settings
19 |
20 | settings = get_we0_index_settings()
21 |
22 |
23 | class Chroma(BaseVector):
24 |
25 | def __init__(self):
26 |         self.client = None
27 | self.collection_name: str | None = None
28 |
29 | @staticmethod
30 | def get_client():
31 | return None
32 |
33 | async def init(self):
34 | if self.client is None:
35 | chroma = settings.vector.chroma
36 | match chroma.mode:
37 | case ChromaMode.MEMORY:
38 | self.client = chromadb.Client()
39 | case ChromaMode.DISK:
40 | self.client = chromadb.PersistentClient(path=chroma.disk.path)
41 | case ChromaMode.REMOTE:
42 | self.client = await chromadb.AsyncHttpClient(
43 | host=chroma.remote.host, port=chroma.remote.port, ssl=chroma.remote.ssl
44 | )
45 | case _:
46 | raise ValueError(f'Unknown chroma mode: {chroma.mode}')
47 | dimension = await self.get_dimension()
48 | self.collection_name = self.dynamic_collection_name(dimension)
49 | await self._execute_async_or_thread(func=self.client.get_or_create_collection, name=self.collection_name)
50 |
51 | async def create(self, documents: List[Document]):
52 | collection = await self._execute_async_or_thread(
53 | func=self.client.get_or_create_collection,
54 | name=self.collection_name
55 | )
56 | ids = [document.meta.segment_id for document in documents]
57 | vectors = [document.vector for document in documents]
58 | metas = [document.meta.model_dump(exclude_none=True) for document in documents]
59 | contents = [document.content for document in documents]
60 | await self._execute_async_or_thread(
61 | func=collection.upsert, ids=ids, embeddings=vectors, metadatas=metas, documents=contents
62 | )
63 |
64 | async def upsert(self, documents: List[Document]):
65 | repo_id = documents[0].meta.repo_id
66 | file_ids = list(set(document.meta.file_id for document in documents))
67 | await self.delete(repo_id, file_ids)
68 | await self.create(documents)
69 |
70 | async def all_meta(self, repo_id: str) -> List[DocumentMeta]:
71 | collection = await self._execute_async_or_thread(
72 | func=self.client.get_or_create_collection,
73 | name=self.collection_name
74 | )
75 |         results = await self._execute_async_or_thread(
76 |             func=collection.get, where={
77 |                 'repo_id': {
78 |                     '$eq': repo_id
79 |                 }
80 |             }
81 |         )
82 |         # collection.get returns flat lists, unlike collection.query which nests them
83 |         metadatas = results.get('metadatas') or []
84 |         if len(metadatas) == 0:
85 |             return []
86 |         return [DocumentMeta.model_validate(meta) for meta in metadatas]
87 |
88 | async def drop(self, repo_id: str):
89 | collection = await self._execute_async_or_thread(
90 | func=self.client.get_or_create_collection,
91 | name=self.collection_name
92 | )
93 | await self._execute_async_or_thread(
94 | func=collection.delete, where={
95 | 'repo_id': {
96 | "$eq": repo_id
97 | }
98 | }
99 | )
100 |
101 | async def delete(self, repo_id: str, file_ids: List[str]):
102 | collection = await self._execute_async_or_thread(
103 | func=self.client.get_or_create_collection,
104 | name=self.collection_name
105 | )
106 | await self._execute_async_or_thread(collection.delete, where={
107 | "$and": [
108 | {"repo_id": {"$eq": repo_id}},
109 | {"file_id": {"$in": file_ids}}
110 | ]
111 | })
112 |
113 | async def search_by_vector(
114 | self,
115 | repo_id: str,
116 | file_ids: Optional[List[str]],
117 | query_vector: List[float],
118 | top_k: int = 5,
119 | score_threshold: float = 0.0
120 | ) -> List[Document]:
121 | collection = await self._execute_async_or_thread(
122 | func=self.client.get_or_create_collection,
123 | name=self.collection_name
124 | )
125 |
126 | if not file_ids:
127 | where = {
128 | 'repo_id': {
129 | '$eq': repo_id
130 | }
131 | }
132 | else:
133 | where = {
134 | "$and": [
135 | {"repo_id": {"$eq": repo_id}},
136 | {"file_id": {"$in": file_ids}}
137 | ]
138 | }
139 | results: QueryResult = await self._execute_async_or_thread(
140 | func=collection.query, query_embeddings=query_vector, n_results=top_k, where=where
141 | )
142 | ids = results.get('ids', [])
143 | if len(ids) == 0:
144 | return []
145 | idx = ids[0]
146 | metas = results.get('metadatas', [])[0]
147 | distances = results.get('distances', [])[0]
148 | contents = results.get('documents', [])[0]
149 |
150 | documents: List[Document] = []
151 | for index in range(len(idx)):
152 | distance = distances[index]
153 | metadata = dict(metas[index])
154 | if distance >= score_threshold:
155 | metadata["score"] = distance
156 | metadata["content"] = contents[index]
157 | document = Document(
158 | meta=DocumentMeta.model_validate(metadata),
159 | )
160 | documents.append(document)
161 |
162 | return sorted(documents, key=lambda x: x.meta.score)
163 |
164 | @staticmethod
165 | async def _execute_async_or_thread(func, *args, **kwargs):
166 | if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
167 | return await func(*args, **kwargs)
168 | else:
169 | return await asyncio.to_thread(func, *args, **kwargs)
170 |
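
`_execute_async_or_thread` lets the same call sites serve both the synchronous in-process Chroma clients and the async HTTP client; a sketch of the dispatch (assumes `chromadb` and the sample settings are importable):

```python
import asyncio

from extensions.vector.chroma import Chroma


async def demo():
    # A plain function is pushed to a worker thread...
    print(await Chroma._execute_async_or_thread(sum, [1, 2, 3]))  # 6

    async def answer():
        return 42

    # ...while a coroutine function is awaited directly.
    print(await Chroma._execute_async_or_thread(answer))  # 42


asyncio.run(demo())
```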
--------------------------------------------------------------------------------
/router/vector_router.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2024/10/11
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : vector_router
7 | # @Software: PyCharm
8 | import asyncio
9 | from typing import Annotated
10 | from typing import List
11 |
12 | from fastapi import APIRouter, UploadFile
13 | from fastapi.params import File, Form
14 |
15 | from domain.entity.blob import Blob
16 | from domain.entity.document import DocumentMeta, Document
17 | from domain.entity.task_context import TaskContext
18 | from domain.request.add_index_request import AddIndexRequest, AddFileInfo
19 | from domain.request.all_index_request import AllIndexRequest
20 | from domain.request.delete_index_request import DeleteIndexRequest
21 | from domain.request.drop_index_request import DropIndexRequest
22 | from domain.request.retrieval_request import RetrievalRequest
23 | from domain.response.add_index_by_file_response import AddIndexByFileResponse
24 | from domain.response.add_index_response import AddIndexResponse, FileInfoResponse
25 | from domain.result.result import Result
26 | from extensions.ext_manager import ExtManager
27 | from models.model_factory import ModelInstance
28 | from setting.setting import get_we0_index_settings
29 | from utils.helper import Helper
30 | from utils.mimetype_util import guess_mimetype_and_extension
31 | from utils.vector_helper import VectorHelper
32 |
33 | vector_router = APIRouter()
34 |
35 | settings = get_we0_index_settings()
36 |
37 |
38 | async def _upsert_index(uid: str, repo_abs_path: str, repo_id: str, file_info: AddFileInfo) -> FileInfoResponse:
39 | mimetype, extension = guess_mimetype_and_extension(file_info.relative_path)
40 | file_id = Helper.generate_fixed_uuid(
41 | f"{uid}:{repo_abs_path}:{file_info.relative_path}"
42 | )
43 |
44 | task_context = TaskContext(
45 | repo_id=repo_id,
46 | file_id=file_id,
47 | relative_path=file_info.relative_path,
48 | blob=Blob.from_data(
49 | data=file_info.content,
50 | mimetype=mimetype,
51 | extension=extension
52 | )
53 | )
54 | try:
55 | documents: List[Document] = await VectorHelper.build_and_embedding_segment(task_context)
56 | if documents:
57 | await ExtManager.vector.upsert(documents)
58 | except Exception as e:
59 | raise e
60 | return FileInfoResponse(
61 | file_id=file_id,
62 | relative_path=file_info.relative_path
63 | )
64 |
65 |
66 | @vector_router.post("/upsert_index", response_model=Result[AddIndexResponse])
67 | async def upsert_index(add_index_request: AddIndexRequest):
68 | """
69 |     Add or update an index.
70 |     Be sure to escape the file content with JSON.stringify so the payload stays well-formed; otherwise the AST structure will be corrupted.\n
71 |     `fs.readFile(filePath, 'utf8', (err, fileContent) => {const jsonStr = JSON.stringify(data);}`\n
72 | """
73 | repo_id = Helper.generate_fixed_uuid(f"{add_index_request.uid}:{add_index_request.repo_abs_path}")
74 |
75 | tasks = [
76 | asyncio.create_task(_upsert_index(
77 | uid=add_index_request.uid,
78 | repo_abs_path=add_index_request.repo_abs_path,
79 | repo_id=repo_id,
80 | file_info=file_info
81 | )) for file_info in add_index_request.file_infos
82 | ]
83 | file_infos = [await task for task in asyncio.as_completed(tasks)]
84 | return Result.ok(data=AddIndexResponse(repo_id=repo_id, file_infos=file_infos))
85 |
86 |
87 | @vector_router.post('/upsert_index_by_file', response_model=Result[AddIndexByFileResponse])
88 | async def upsert_index_by_file(
89 | uid: Annotated[str, Form(description='Unique ID')],
90 | repo_abs_path: Annotated[str, Form(description='Repository Absolute Path')],
91 | relative_path: Annotated[str, Form(description='File Relative Path')],
92 | file: UploadFile = File(...),
93 | ):
94 | """
95 |     Create or append an index (via file upload).
96 | """
97 | repo_id = Helper.generate_fixed_uuid(f"{uid}:{repo_abs_path}")
98 | mimetype, extension = guess_mimetype_and_extension(relative_path)
99 |
100 | file_id = Helper.generate_fixed_uuid(f"{uid}:{repo_abs_path}:{relative_path}")
101 |
102 | task_context = TaskContext(
103 | repo_id=repo_id,
104 | file_id=file_id,
105 | relative_path=relative_path,
106 | blob=Blob.from_data(
107 | data=await file.read(),
108 | mimetype=mimetype,
109 | extension=extension
110 | )
111 | )
112 | try:
113 | documents: List[Document] = await VectorHelper.build_and_embedding_segment(task_context)
114 | if documents:
115 | await ExtManager.vector.upsert(documents)
116 | return Result.ok(data=AddIndexByFileResponse(repo_id=repo_id, file_id=file_id))
117 | else:
118 | return Result.failed(message=f"Not Content")
119 | except Exception as e:
120 | return Result.failed(message=f"{type(e).__name__}: {e}")
121 |
122 |
123 | @vector_router.post('/drop_index', response_model=Result)
124 | async def drop_index(drop_index_request: DropIndexRequest):
125 | """
126 |     Delete all vectors of an index.
127 | """
128 | try:
129 | await ExtManager.vector.drop(repo_id=drop_index_request.repo_id)
130 | return Result.ok()
131 | except Exception as e:
132 | return Result.failed(message=f"{type(e).__name__}: {e}")
133 |
134 |
135 | @vector_router.post('/delete_index', response_model=Result)
136 | async def delete_index(delete_index_request: DeleteIndexRequest):
137 | """
138 |     Delete the specified vectors of an index.
139 | """
140 | try:
141 | await ExtManager.vector.delete(repo_id=delete_index_request.repo_id, file_ids=delete_index_request.file_ids)
142 | return Result.ok()
143 | except Exception as e:
144 | return Result.failed(message=f"{type(e).__name__}: {e}")
145 |
146 |
147 | @vector_router.post('/all_index', response_model=Result)
148 | async def all_index(all_index_request: AllIndexRequest):
149 | try:
150 | all_meta = await ExtManager.vector.all_meta(repo_id=all_index_request.repo_id)
151 | return Result.ok(data=all_meta)
152 | except Exception as e:
153 | return Result.failed(message=f"{type(e).__name__}: {e}")
154 |
155 |
156 | async def retrieval(
157 | retrieval_request: RetrievalRequest
158 | ) -> Result[List[DocumentMeta]]:
159 | """
160 | Tool parameters must be in standard JSON format!
161 | "retrieval_request": {
162 | "xxx": "xxx"
163 | }
164 |     Similarity search over the whole repository or a subset of its files.
165 | """
166 | try:
167 | embedding_model: ModelInstance = await ExtManager.vector.get_embedding_model()
168 | vector_data = await embedding_model.create_embedding([retrieval_request.query])
169 | documents = await ExtManager.vector.search_by_vector(
170 | repo_id=retrieval_request.repo_id, file_ids=retrieval_request.file_ids, query_vector=vector_data[0]
171 | )
172 | retrieval_segment_list = [document.meta for document in documents]
173 | return Result.ok(data=retrieval_segment_list)
174 | except Exception as e:
175 | return Result.failed(message=f"{type(e).__name__}: {e}")
176 |
177 |
178 | vector_router.add_api_route('/retrieval', retrieval, methods=['POST'], response_model=Result[List[DocumentMeta]])
179 |
--------------------------------------------------------------------------------
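A hypothetical client call for the /upsert_index endpoint above. The payload shape is inferred from how AddIndexRequest is consumed (uid, repo_abs_path, file_infos with relative_path and content); the example values and the base URL are assumptions:

    import httpx

    payload = {
        "uid": "user-1",
        "repo_abs_path": "/home/user/project",
        "file_infos": [
            {"relative_path": "src/app.py", "content": "print('hello')"},
        ],
    }
    # Base URL and port are assumptions; point this at the running service.
    resp = httpx.post("http://localhost:8000/upsert_index", json=payload)
    print(resp.json())  # Result[AddIndexResponse]: repo_id plus per-file file_infos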
/extensions/vector/qdrant.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/22
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : qdrant
7 | # @Software: PyCharm
8 |
9 | from typing import List, Optional
10 |
11 | from qdrant_client.async_qdrant_client import AsyncQdrantClient
12 | from qdrant_client.http import models as rest
13 | from qdrant_client.http.exceptions import UnexpectedResponse
14 |
15 | from domain.enums.qdrant_mode import QdrantMode
16 | from domain.entity.document import Document, DocumentMeta
17 | from extensions.vector.base_vector import BaseVector
18 | from setting.setting import get_we0_index_settings
19 |
20 | settings = get_we0_index_settings()
21 |
22 |
23 | class Qdrant(BaseVector):
24 |
25 | def __init__(self):
26 | self.client = self.get_client()
27 | self.collection_name: str | None = None
28 |
29 | @staticmethod
30 | def get_client():
31 | qdrant = settings.vector.qdrant
32 | match qdrant.mode:
33 | case QdrantMode.MEMORY:
34 | return AsyncQdrantClient(location=':memory:')
35 | case QdrantMode.DISK:
36 | return AsyncQdrantClient(path=qdrant.disk.path)
37 | case QdrantMode.REMOTE:
38 | return AsyncQdrantClient(host=qdrant.remote.host, port=qdrant.remote.port)
39 | case _:
40 | raise ValueError(f'Unknown qdrant mode: {qdrant.mode}')
41 |
42 | async def init(self):
43 | collection_names = []
44 | dimension = await self.get_dimension()
45 | self.collection_name = self.dynamic_collection_name(dimension)
46 | collections: rest.CollectionsResponse = await self.client.get_collections()
47 | for collection in collections.collections:
48 | collection_names.append(collection.name)
49 | if self.collection_name not in collection_names:
50 | vectors_config = rest.VectorParams(
51 | size=dimension,
52 | distance=rest.Distance.COSINE
53 | )
54 | hnsw_config = rest.HnswConfigDiff(
55 | m=0,
56 | payload_m=16,
57 | ef_construct=100,
58 | full_scan_threshold=10000,
59 | max_indexing_threads=0,
60 | on_disk=False,
61 | )
62 | await self.client.create_collection(
63 | collection_name=self.collection_name,
64 | vectors_config=vectors_config,
65 | hnsw_config=hnsw_config,
66 | timeout=30
67 | )
68 | if settings.vector.qdrant.mode != QdrantMode.DISK:
69 | await self.client.create_payload_index(
70 | self.collection_name, 'repo_id', field_schema=rest.PayloadSchemaType.KEYWORD
71 | )
72 | await self.client.create_payload_index(
73 | self.collection_name, 'file_id', field_schema=rest.PayloadSchemaType.KEYWORD
74 | )
75 |
76 | async def create(self, documents: List[Document]):
77 |         repo_id = documents[0].meta.repo_id
78 |         point_structs = []
79 |         for document in documents:
80 |             # Qdrant stores only id, vector and payload per point, so content is folded into meta.
81 |             document.meta.content = document.content
82 |             document.meta.repo_id = repo_id
83 |             point_structs.append(rest.PointStruct(
84 |                 id=document.meta.segment_id,
85 |                 vector=document.vector,
86 |                 payload=document.meta.model_dump(exclude_none=True),
87 |             ))
88 |
89 |         await self.client.upsert(collection_name=self.collection_name, points=point_structs)
90 |
91 | async def upsert(self, documents: List[Document]):
92 | repo_id = documents[0].meta.repo_id
93 | file_ids = list(set(document.meta.file_id for document in documents))
94 | await self.delete(repo_id, file_ids)
95 | await self.create(documents)
96 |
97 | async def all_meta(self, repo_id: str) -> List[DocumentMeta]:
98 | scroll_filter = rest.Filter(
99 | must=[
100 | rest.FieldCondition(
101 | key="repo_id",
102 | match=rest.MatchValue(value=repo_id),
103 | ),
104 | ],
105 | )
106 |
107 | records, next_offset = await self.client.scroll(
108 | collection_name=self.collection_name,
109 | scroll_filter=scroll_filter,
110 | limit=100,
111 | with_payload=True
112 | )
113 |
114 | while next_offset:
115 | scroll_records, next_offset = await self.client.scroll(
116 | collection_name=self.collection_name,
117 | scroll_filter=scroll_filter,
118 | limit=100,
119 | with_payload=True,
120 | offset=next_offset
121 | )
122 | records.extend(scroll_records)
123 | return [DocumentMeta.model_validate(record.payload) for record in records]
124 |
125 | async def drop(self, repo_id: str):
126 | filter_selector = rest.Filter(
127 | must=[
128 | rest.FieldCondition(
129 | key="repo_id",
130 | match=rest.MatchValue(value=repo_id),
131 | ),
132 | ],
133 | )
134 | try:
135 | await self.client.delete(
136 | collection_name=self.collection_name,
137 | points_selector=filter_selector
138 | )
139 | except UnexpectedResponse as e:
140 | if e.status_code == 404:
141 | return
142 |             raise
143 |
144 | async def delete(self, repo_id: str, file_ids: List[str]):
145 | filter_selector = rest.Filter(
146 | must=[
147 | rest.FieldCondition(
148 | key="repo_id",
149 | match=rest.MatchValue(value=repo_id),
150 | ),
151 | rest.FieldCondition(
152 | key="file_id",
153 | match=rest.MatchAny(any=file_ids),
154 | )
155 | ],
156 | )
157 |         try:
158 |             await self.client.delete(
159 |                 collection_name=self.collection_name,
160 |                 points_selector=filter_selector
161 |             )
162 |         except UnexpectedResponse as e:
163 |             # A 404 just means the collection does not exist; treat it as a no-op.
164 |             if e.status_code == 404:
165 |                 return
166 |             raise
167 |
168 |
169 | async def search_by_vector(
170 | self,
171 | repo_id: str,
172 | file_ids: Optional[List[str]],
173 | query_vector: List[float],
174 | top_k: int = 5,
175 | score_threshold: float = 0.0
176 | ) -> List[Document]:
177 |
178 | query_filter = rest.Filter(
179 | must=[
180 | rest.FieldCondition(
181 | key="repo_id",
182 | match=rest.MatchValue(value=repo_id),
183 | ),
184 | ],
185 | )
186 |         # If file_ids is non-empty, add a file_id filter condition
187 | if file_ids:
188 | query_filter.must.append(
189 | rest.FieldCondition(
190 | key="file_id",
191 | match=rest.MatchAny(any=file_ids),
192 | )
193 | )
194 | response: rest.QueryResponse = await self.client.query_points(
195 | collection_name=self.collection_name,
196 | query=query_vector,
197 | query_filter=query_filter,
198 | limit=top_k,
199 | with_payload=True,
200 | with_vectors=True,
201 | score_threshold=score_threshold,
202 | )
203 | documents: List[Document] = []
204 | for point in response.points:
205 | meta = DocumentMeta.model_validate(point.payload)
206 | meta.score = point.score
207 | documents.append(Document(meta=meta))
208 | return documents
209 |
--------------------------------------------------------------------------------
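The Qdrant class above rebuilds the same payload filter in all_meta, drop, delete, and search_by_vector; a compact sketch of that pattern, using only names that appear in the code:

    from typing import List, Optional
    from qdrant_client.http import models as rest

    def scoped_filter(repo_id: str, file_ids: Optional[List[str]] = None) -> rest.Filter:
        # Always scope to a single repository; optionally narrow to specific files.
        must = [rest.FieldCondition(key="repo_id", match=rest.MatchValue(value=repo_id))]
        if file_ids:
            must.append(rest.FieldCondition(key="file_id", match=rest.MatchAny(any=file_ids)))
        return rest.Filter(must=must)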
/extensions/vector/pgvector.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/14
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : pgvector
7 | # @Software: PyCharm
8 | import json
9 | from typing import List, Optional
10 |
11 | import numpy as np
12 | from sqlalchemy import text, bindparam
13 | from sqlalchemy.ext.asyncio import create_async_engine
14 |
15 | from domain.entity.document import Document, DocumentMeta
16 | from extensions.vector.base_vector import BaseVector
17 | from setting.setting import get_we0_index_settings
18 |
19 | settings = get_we0_index_settings()
20 |
21 | SQL_CREATE_FILE_INDEX = lambda table_name: f"""
22 | CREATE INDEX IF NOT EXISTS file_idx ON {table_name} (file_id);
23 | """
24 | SQL_CREATE_REPO_FILE_INDEX = lambda table_name: f"""
25 | CREATE INDEX IF NOT EXISTS repo_file_idx ON {table_name} (repo_id, file_id);
26 | """
27 | SQL_CREATE_EMBEDDING_INDEX = lambda table_name: f"""
28 | CREATE INDEX IF NOT EXISTS embedding_cosine_idx ON {table_name}
29 | USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
30 | """
31 |
32 | SQL_CREATE_TABLE = lambda table_name, dimension: f"""
33 | CREATE TABLE IF NOT EXISTS {table_name} (
34 | id UUID PRIMARY KEY,
35 | repo_id UUID NOT NULL,
36 | file_id UUID NOT NULL,
37 | content TEXT NOT NULL,
38 | meta JSONB NOT NULL,
39 | embedding vector({dimension}) NOT NULL
40 | ) USING heap;
41 | """
42 |
43 |
44 | class PgVector(BaseVector):
45 |
46 | def __init__(self):
47 | self.client = self.get_client()
48 | self.table_name: str | None = None
49 | self.normalized: bool = False
50 |
51 | @staticmethod
52 | def get_client():
53 | pgvector = settings.vector.pgvector
54 | return create_async_engine(
55 | url=f"postgresql+psycopg://{pgvector.user}:{pgvector.password}@{pgvector.host}:{pgvector.port}/{pgvector.db}",
56 | echo=False,
57 | )
58 |
59 | async def init(self):
60 | async with self.client.begin() as conn:
61 | dimension = await self.get_dimension()
62 | if dimension > 2000:
63 | dimension = 2000
64 | self.normalized = True
65 | self.table_name = self.dynamic_collection_name(dimension)
66 | await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
67 | await conn.execute(text(SQL_CREATE_TABLE(self.table_name, dimension)))
68 | await conn.execute(text(SQL_CREATE_REPO_FILE_INDEX(self.table_name)))
69 | await conn.execute(text(SQL_CREATE_FILE_INDEX(self.table_name)))
70 | await conn.execute(text(SQL_CREATE_EMBEDDING_INDEX(self.table_name)))
71 |
72 | async def _create(self, repo_id: str, documents: List[Document]):
73 | stmt = text(
74 | f"""
75 | INSERT INTO {self.table_name} (id, repo_id, file_id, content, meta, embedding)
76 | VALUES (:id, :repo_id, :file_id, :content, :meta, :embedding)
77 | ON CONFLICT (id) DO UPDATE SET
78 | repo_id = EXCLUDED.repo_id,
79 | file_id = EXCLUDED.file_id,
80 | content = EXCLUDED.content,
81 | meta = EXCLUDED.meta,
82 | embedding = EXCLUDED.embedding
83 | """
84 | )
85 | parameters = []
86 | for doc in documents:
87 | if doc.meta is not None:
88 | if self.normalized:
89 | vector = self.normalize_l2(doc.vector[:2000])
90 | else:
91 | vector = doc.vector
92 | parameters.append({
93 | 'id': doc.meta.segment_id,
94 | 'repo_id': repo_id,
95 | 'file_id': doc.meta.file_id,
96 | 'content': doc.content,
97 | 'meta': doc.meta.model_dump_json(exclude={'score', 'content'}),
98 | 'embedding': vector,
99 | })
100 |
101 | return stmt, parameters
102 |
103 | async def create(self, documents: List[Document]):
104 | async with self.client.begin() as conn:
105 | repo_id = documents[0].meta.repo_id
106 | insert_stmt, insert_parameters = await self._create(repo_id, documents)
107 | await conn.execute(
108 | insert_stmt,
109 | insert_parameters
110 | )
111 |
112 | async def upsert(self, documents: List[Document]):
113 | repo_id = documents[0].meta.repo_id
114 | file_ids = list(set(document.meta.file_id for document in documents))
115 | async with self.client.begin() as conn:
116 | delete_stmt, delete_parameters = await self._delete(repo_id, file_ids)
117 | await conn.execute(
118 | delete_stmt,
119 | delete_parameters
120 | )
121 | insert_stmt, insert_parameters = await self._create(repo_id, documents)
122 | await conn.execute(
123 | insert_stmt,
124 | insert_parameters
125 | )
126 |
127 | async def all_meta(self, repo_id: str) -> List[DocumentMeta]:
128 | async with self.client.begin() as conn:
129 | sql_query = (
130 | f"SELECT meta "
131 | f"FROM {self.table_name} "
132 | f"WHERE repo_id = :repo_id"
133 | )
134 | result = await conn.execute(
135 | text(sql_query),
136 | {
137 | "repo_id": repo_id,
138 | }
139 | )
140 | records = result.all()
141 | return [DocumentMeta.model_validate(meta[0]) for meta in records]
142 |
143 | async def drop(self, repo_id: str):
144 | async with self.client.begin() as conn:
145 | await conn.execute(
146 | text(
147 | f"DELETE FROM {self.table_name} "
148 | "WHERE repo_id = :repo_id"
149 | ),
150 | {'repo_id': repo_id}
151 | )
152 |
153 | async def _delete(self, repo_id: str, file_ids: List[str]):
154 | sql_query = (
155 | f"DELETE FROM {self.table_name} "
156 | "WHERE repo_id = :repo_id AND file_id = ANY(:file_ids) "
157 | )
158 | stmt = text(sql_query).bindparams(
159 | bindparam("file_ids", expanding=False) # 关键点:禁止自动展开列表为多个参数
160 | )
161 | parameters = {'repo_id': repo_id, 'file_ids': file_ids}
162 | return stmt, parameters
163 |
164 | async def delete(self, repo_id: str, file_ids: List[str]):
165 | async with self.client.begin() as conn:
166 | delete_stmt, delete_parameters = await self._delete(repo_id, file_ids)
167 | await conn.execute(
168 | delete_stmt,
169 | delete_parameters
170 | )
171 |
172 | async def search_by_vector(
173 | self,
174 | repo_id: str,
175 | file_ids: Optional[List[str]],
176 | query_vector: List[float],
177 | top_k: int = 5,
178 | score_threshold: float = 0.0
179 | ) -> List[Document]:
180 | documents = []
181 | async with self.client.begin() as conn:
182 |             # Base SQL query
183 | sql_query = (
184 | f"SELECT content, meta, embedding <=> :query_vector AS distance "
185 | f"FROM {self.table_name} "
186 | f"WHERE repo_id = :repo_id "
187 | )
188 | if self.normalized:
189 | query_vector = self.normalize_l2(query_vector[:2000])
190 |             # Base query parameters
191 | parameters = {
192 | "query_vector": json.dumps(query_vector),
193 | "repo_id": repo_id,
194 | "top_k": top_k
195 | }
196 |
197 | if file_ids:
198 | sql_query += "AND file_id = ANY(:file_ids) "
199 | parameters["file_ids"] = file_ids
200 |
201 | sql_query += "ORDER BY distance LIMIT :top_k"
202 |
203 | if file_ids:
204 | stmt = text(sql_query).bindparams(
205 | bindparam("file_ids", expanding=False)
206 | )
207 | else:
208 | stmt = text(sql_query)
209 |
210 | result = await conn.execute(
211 | stmt,
212 | parameters
213 | )
214 | records = result.all()
215 | for record in records:
216 | content, meta, distance = record
217 | score = 1 - distance
218 | if score > score_threshold:
219 | meta["score"] = score
220 | meta['content'] = content
221 | documents.append(Document(content=content, meta=DocumentMeta.model_validate(meta)))
222 | return documents
223 |
224 | @staticmethod
225 | def normalize_l2(x: List[float]) -> List[float]:
226 | x = np.array(x)
227 | if x.ndim == 1:
228 | norm = np.linalg.norm(x)
229 | if norm == 0:
230 | return x.tolist()
231 | return (x / norm).tolist()
232 | else:
233 | norm = np.linalg.norm(x, 2, axis=1, keepdims=True)
234 | return np.where(norm == 0, x, x / norm).tolist()
235 |
--------------------------------------------------------------------------------
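A small check of why PgVector.normalize_l2 is applied after truncation: pgvector's index types cap indexable dimensions (hence the 2000-dim cut in init above), and a truncated embedding is no longer unit-length; re-normalizing restores that property. Illustrative only:

    import numpy as np

    v = np.random.rand(3072).tolist()    # e.g. an embedding wider than 2000 dims
    u = PgVector.normalize_l2(v[:2000])  # truncate, then re-normalize
    print(abs(np.linalg.norm(u) - 1.0) < 1e-9)  # True: back to unit length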
/loader/segmenter/base_line_segmenter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2025/2/11
4 | # @Author : .*?
5 | # @Email : amashiro2233@gmail.com
6 | # @File : line_based_segmenter
7 | # @Software: PyCharm
8 | from bisect import bisect_right, bisect_left
9 | from typing import List, Dict, Any, Tuple, Iterator
10 |
11 | from domain.entity.code_segment import CodeSegment
12 | from loader.segmenter.base_segmenter import BaseSegmenter
13 |
14 |
15 | class LineBasedSegmenter(BaseSegmenter):
16 |
17 | def __init__(
18 | self,
19 | text: str,
20 | max_chunk_size: int = 50,
21 | min_chunk_size: int = 10,
22 | delimiters=None,
23 | **kwargs
24 | ):
25 | super().__init__(**kwargs)
26 | self.text = text
27 | self.max_chunk_size = max_chunk_size
28 | self.min_chunk_size = min_chunk_size
29 | self.delimiters = delimiters or ['\n\n', '\n']
30 |
31 |         # Build the per-line character-position map
32 | self.line_positions = self._build_line_positions(text)
33 |
34 | def segment(self) -> Iterator[CodeSegment]:
35 | """执行分块操作"""
36 | initial_range = (0, len(self.text))
37 | segments = self._recursive_split(initial_range, self.delimiters)
38 | if self.merge_small_chunks:
39 | segments = self._merge_small_segments(segments)
40 | for segment in segments:
41 | code_segment = CodeSegment.model_validate(segment)
42 | if not code_segment.code.isspace():
43 | yield code_segment
44 |
45 | @staticmethod
46 | def _build_line_positions(text: str) -> List[Tuple[int, int]]:
47 | """生成每行的字符位置范围列表 (start, end)"""
48 | lines = []
49 | start = 0
50 | for line in text.split('\n'):
51 | end = start + len(line)
52 | lines.append((start, end))
53 |             start = end + 1  # skip the newline character
54 | return lines
55 |
56 | def _get_original_lines(self, start_pos: int, end_pos: int) -> Tuple[int, int]:
57 | """根据字符位置获取原始行号 (1-based)"""
58 | starts = [line[0] for line in self.line_positions]
59 | ends = [line[1] for line in self.line_positions]
60 |
61 | start_line = bisect_right(starts, start_pos)
62 | end_line = bisect_left(ends, end_pos)
63 |
64 |         # Handle boundary cases
65 | start_line = max(0, start_line - 1) if start_line > 0 else 0
66 | end_line = min(end_line, len(ends) - 1)
67 |
68 |         return start_line + 1, end_line + 1  # convert to 1-based
69 |
70 | def _recursive_split(
71 | self,
72 | char_range: Tuple[int, int],
73 | delimiters: List[str]
74 | ) -> List[Dict[str, Any]]:
75 | start_pos, end_pos = char_range
76 | text = self.text[start_pos:end_pos]
77 |
78 | if self.max_tokens is not None:
79 | raw_lines = text.splitlines(keepends=True)
80 | segments = []
81 | current_chunk_start = start_pos
82 | pos = start_pos
83 | any_long_line = False
84 | for line in raw_lines:
85 |                 # Count tokens on the line with leading/trailing whitespace stripped
86 | if self.length_function(line.strip()) > self.max_tokens:
87 | any_long_line = True
88 |                     # First handle the part before this line (if any)
89 | if pos > current_chunk_start:
90 | segments.extend(
91 | self._recursive_split((current_chunk_start, pos), delimiters)
92 | )
93 |                     # Force-split this line (marked forced=True)
94 | segments.extend(
95 | self._forced_split((pos, pos + len(line)), block=1)
96 | )
97 | current_chunk_start = pos + len(line)
98 | pos += len(line)
99 | if any_long_line:
100 | if current_chunk_start < end_pos:
101 | segments.extend(
102 | self._recursive_split((current_chunk_start, end_pos), delimiters)
103 | )
104 | return segments
105 |
106 | current_lines = len(text.splitlines())
107 |
108 |         # Decide whether this range needs splitting
109 | need_split = False
110 | if current_lines > self.max_chunk_size:
111 | need_split = True
112 | if self.max_tokens is not None:
113 | current_tokens = self.length_function(text)
114 | if current_tokens > self.max_tokens:
115 | need_split = True
116 |
117 | if not need_split:
118 | code = text.strip()
119 | if not code:
120 | return []
121 | start_line, end_line = self._get_original_lines(start_pos, end_pos)
122 |         # A normally produced chunk, marked forced=False
123 | return [{
124 | "start": start_line,
125 | "end": end_line,
126 | "code": code,
127 | "forced": False
128 | }]
129 |
130 |         # Try splitting with the current delimiter
131 | if delimiters:
132 | current_delim = delimiters[0]
133 | parts = text.split(current_delim)
134 | if len(parts) > 1:
135 | return self._split_by_delimiter(start_pos, end_pos, current_delim, delimiters[1:])
136 |
137 |         # Force-split when no usable delimiter remains
138 | return self._forced_split(char_range)
139 |
140 | def _split_by_delimiter(
141 | self,
142 | start_pos: int,
143 | end_pos: int,
144 | delimiter: str,
145 | next_delimiters: List[str]
146 | ) -> List[Dict[str, Any]]:
147 | """使用指定分隔符进行分割"""
148 | text = self.text[start_pos:end_pos]
149 | delim_len = len(delimiter)
150 | segments = []
151 | current_start = start_pos
152 |
153 | for part in text.split(delimiter):
154 | if not part.strip():
155 | current_start += len(part) + delim_len
156 | continue
157 |
158 | part_end = current_start + len(part)
159 | sub_segments = self._recursive_split(
160 | (current_start, part_end),
161 | next_delimiters
162 | )
163 | segments.extend(sub_segments)
164 | current_start = part_end + delim_len
165 |
166 | return segments
167 |
168 | def _compute_optimal_chunk_length(self, char_range: Tuple[int, int]) -> int:
169 | """
170 |         Binary-search the given character range for the largest split length such that:
171 |         self.length_function(self.text[start_pos:start_pos + chunk_len]) <= self.max_tokens
172 | """
173 | start_pos, end_pos = char_range
174 | total_chars = end_pos - start_pos
175 |
176 | low, high = 1, total_chars
177 |         increment_value = max(1, self.max_tokens // 10)  # a zero step would make the search loop forever
178 |         optimal = low  # guarantee at least one character
179 | while low <= high:
180 | mid = (low + high) // 2
181 | chunk_text = self.text[start_pos:start_pos + mid]
182 | token_count = self.length_function(chunk_text)
183 | if token_count <= self.max_tokens:
184 |                 # This length fits: record it and try a longer one
185 | optimal = mid
186 | low = mid + increment_value
187 | else:
188 |                 # Over the limit: try a shorter length
189 | high = mid - increment_value
190 |
191 | return optimal
192 |
193 | def _forced_split(self, char_range: Tuple[int, int], block: int = 1) -> List[Dict[str, Any]]:
194 | start_pos, end_pos = char_range
195 | segments = []
196 | pos = start_pos
197 | current_block = block
198 |
199 | while pos < end_pos:
200 | optimal_chunk_size = self._compute_optimal_chunk_length((pos, end_pos))
201 | next_pos = min(pos + optimal_chunk_size, end_pos)
202 | chunk_text = self.text[pos:next_pos].strip()
203 | if chunk_text:
204 | start_line, end_line = self._get_original_lines(pos, next_pos)
205 | segments.append({
206 | "start": start_line,
207 | "end": end_line,
208 | "code": chunk_text,
209 | "forced": True, # 标记为强制拆分生成的片段
210 | "block": current_block
211 | })
212 | current_block += 1
213 | pos = next_pos
214 | return segments
215 |
216 | def _merge_small_segments(self, segments: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
217 | merged = []
218 | if not segments:
219 | return merged
220 | buffer = segments[0]
221 | for seg in segments[1:]:
222 |             # If one chunk was force-split and the other was not, keep them separate rather than merging
223 | if buffer.get("forced", False) != seg.get("forced", False):
224 | merged.append(buffer)
225 | buffer = seg
226 | continue
227 |
228 | merged_code = f"{buffer['code']}\n{seg['code']}"
229 | merged_lines = len(merged_code.splitlines())
230 | current_lines = buffer["end"] - buffer["start"] + 1
231 | next_lines = seg["end"] - seg["start"] + 1
232 |
233 | if self.max_tokens is not None:
234 |             # With a token limit, the merged chunk must not exceed max_tokens
235 | merged_tokens = self.length_function(merged_code)
236 | if merged_tokens <= self.max_tokens:
237 | buffer = {
238 | "start": buffer["start"],
239 | "end": seg["end"],
240 | "code": merged_code.strip(),
241 | "forced": buffer.get("forced", False)
242 | }
243 | else:
244 | merged.append(buffer)
245 | buffer = seg
246 | else:
247 |             # Without a token limit, merge by line count:
248 |             # merge if the result stays within max_chunk_size, or if either chunk is too small
249 | if (merged_lines <= self.max_chunk_size) or (
250 | current_lines < self.min_chunk_size or next_lines < self.min_chunk_size):
251 | buffer = {
252 | "start": buffer["start"],
253 | "end": seg["end"],
254 | "code": merged_code.strip(),
255 | "forced": buffer.get("forced", False)
256 | }
257 | else:
258 | merged.append(buffer)
259 | buffer = seg
260 | merged.append(buffer)
261 | return merged
262 |
--------------------------------------------------------------------------------
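A usage sketch for LineBasedSegmenter. Constructor arguments are as defined above; the file path and the max_tokens/length_function kwargs handled by BaseSegmenter are assumptions based on how they are used here:

    text = open("example.py", encoding="utf-8").read()
    segmenter = LineBasedSegmenter(text=text, max_chunk_size=50, min_chunk_size=10)
    for seg in segmenter.segment():
        # CodeSegment carries at least start, end and code, per the dicts validated above
        print(seg.start, seg.end, len(seg.code))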
/loader/segmenter/tree_sitter_segmenter.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from collections import deque
3 | from typing import List, Dict, Any, Iterator
4 |
5 | from tree_sitter import Language, Parser
6 |
7 | from domain.entity.code_segment import CodeSegment
8 | from loader.segmenter.base_segmenter import BaseSegmenter
9 |
10 |
11 | class TreeSitterSegmenter(BaseSegmenter):
12 | def __init__(
13 | self,
14 | text: str,
15 | chunk_size: int = 30,
16 | min_chunk_size: int = 10,
17 | max_chunk_size: int = 50,
18 | max_depth: int = 5,
19 | split_large_chunks=True,
20 | **kwargs
21 | ):
22 | super().__init__(**kwargs)
23 | self.text = text
24 | self.chunk_size = chunk_size
25 | self.min_chunk_size = min_chunk_size
26 | self.max_chunk_size = max_chunk_size
27 | self.max_depth = max_depth
28 | self.split_large_chunks = split_large_chunks
29 | self.parser = self.get_parser()
30 |
31 | @abstractmethod
32 | def get_language(self) -> Language:
33 | """返回 Tree-sitter 对应语言对象。示例:
34 | return Language('build/my-languages.so', 'python')
35 | """
36 | raise NotImplementedError
37 |
38 | @abstractmethod
39 | def get_node_types(self) -> List[str]:
40 | """
41 |         Return the target node types to extract,
42 |         e.g. ["function_definition", "class_definition"].
43 | """
44 | raise NotImplementedError
45 |
46 | @abstractmethod
47 | def get_recursion_node_types(self) -> List[str]:
48 | """
49 |         Return the node types whose children should be traversed recursively,
50 |         e.g. ["module", "block", "class_definition", "function_definition"].
51 | """
52 | raise NotImplementedError
53 |
54 | def is_valid(self) -> bool:
55 | """简单检查语法是否存在 ERROR 节点"""
56 | error_query = self.parser.language.query("(ERROR) @error")
57 | tree = self.parser.parse(bytes(self.text, encoding="UTF-8"))
58 | return len(error_query.captures(tree.root_node)) == 0
59 |
60 | def segment(self) -> Iterator[CodeSegment]:
61 | tree = self.parser.parse(bytes(self.text, "utf-8"))
62 | unfiltered_nodes = deque()
63 | for child in tree.root_node.children:
64 | unfiltered_nodes.append((child, 1))
65 | all_nodes = []
66 | while unfiltered_nodes:
67 | current_node, current_depth = unfiltered_nodes.popleft()
68 | if current_node.type in self.get_node_types():
69 | all_nodes.append(current_node)
70 | if current_node.type in self.get_recursion_node_types():
71 | if current_depth < self.max_depth:
72 | for child in current_node.children:
73 | unfiltered_nodes.append((child, current_depth + 1))
74 | all_nodes.sort(key=lambda n: (n.start_point[0], -n.end_point[0]))
75 |
76 | processed_ranges = []
77 | processed_chunks_info = []
78 | for node in all_nodes:
79 | start_0 = node.start_point[0]
80 | end_0 = node.end_point[0]
81 | node_text = node.text.decode()
82 | if self._is_range_covered(start_0, end_0, processed_ranges):
83 | continue
84 | processed_ranges.append((start_0, end_0))
85 | processed_chunks_info.append({
86 | "start": start_0 + 1,
87 | "end": end_0 + 1,
88 | "code": node_text
89 | })
90 |
91 | processed_chunks_info.sort(key=lambda x: x["start"])
92 | # 用 split("\n") 来和 Tree-sitter 的行数计算方式保持一致
93 | code_lines = self.text.split("\n")
94 |         # Total lines, Tree-sitter style: number of "\n" characters + 1
95 | total_lines = self.text.count("\n") + 1
96 |
97 | combined_chunks = []
98 |         current_pos = 0  # 0-based line number
99 | for chunk in processed_chunks_info:
100 | chunk_start_0 = chunk["start"] - 1
101 | chunk_end_0 = chunk["end"] - 1
102 | if current_pos < chunk_start_0:
103 | unprocessed = self._handle_unprocessed(
104 | code_lines[current_pos:chunk_start_0],
105 | current_pos
106 | )
107 | combined_chunks.extend(unprocessed)
108 | combined_chunks.append({
109 | "start": chunk["start"],
110 | "end": chunk["end"],
111 | "code": chunk["code"]
112 | })
113 | current_pos = chunk_end_0 + 1
114 |
115 | if current_pos < total_lines:
116 | unprocessed = self._handle_unprocessed(
117 | code_lines[current_pos:total_lines],
118 | current_pos
119 | )
120 | combined_chunks.extend(unprocessed)
121 |
122 | final_chunks = self._post_process_chunks(combined_chunks)
123 | if self.max_tokens:
124 | final_chunks = self._split_by_tokens(final_chunks, self.max_tokens)
125 | for segment in final_chunks:
126 | code_segment = CodeSegment.model_validate(segment)
127 | if not code_segment.code.isspace():
128 | yield code_segment
129 |
130 | def get_parser(self) -> Parser:
131 | """初始化并返回 Parser 对象"""
132 | parser = Parser()
133 | parser.language = self.get_language()
134 | return parser
135 |
136 | @staticmethod
137 | def _is_range_covered(start: int, end: int, existing_ranges: List[tuple]) -> bool:
138 | """
139 |         Interval-coverage check (linear scan, O(n)).
140 |         Return True if an existing range [existing_start, existing_end] fully covers [start, end].
141 | """
142 | for (existing_start, existing_end) in existing_ranges:
143 | if existing_start <= start and end <= existing_end:
144 | return True
145 | return False
146 |
147 | def _handle_unprocessed(self, lines: List[str], start_0: int) -> List[Dict[str, Any]]:
148 | """处理未识别的代码区域。"""
149 | if not lines:
150 | return []
151 | return self._split_into_chunks_without_empty_lines(lines, start_0)
152 |
153 | @staticmethod
154 | def _split_into_chunks_without_empty_lines(
155 | lines: List[str],
156 | start_0_based: int
157 | ) -> List[Dict[str, Any]]:
158 | """将未识别区域当作一个连续块输出,保留所有行"""
159 | return [{
160 | "start": start_0_based + 1,
161 | "end": start_0_based + len(lines),
162 | "code": "\n".join(lines)
163 | }]
164 |
165 | def _post_process_chunks(
166 | self,
167 | chunks: List[Dict[str, Any]],
168 | ) -> List[Dict[str, Any]]:
169 | """
170 |         Post-process the chunks:
171 |         1) If a chunk exceeds max_chunk_size lines, split it evenly by chunk_size (in lines).
172 |         2) If a chunk is below min_chunk_size lines, merge it with a neighbor when the result stays within max_chunk_size.
173 | """
174 |         # Step 1: split every chunk larger than max_chunk_size
175 | split_chunks = []
176 | if self.split_large_chunks:
177 | for c in chunks:
178 | if len(c['code']):
179 | split_chunks.extend(
180 | self._split_large_chunk(c, self.chunk_size, self.max_chunk_size)
181 | )
182 | else:
183 | split_chunks = chunks
184 |
185 |         # Step 2: merge chunks that are too short after splitting
186 | if self.merge_small_chunks:
187 | return self._merge_small_chunks(
188 | split_chunks,
189 | self.min_chunk_size,
190 | self.max_chunk_size,
191 | self.chunk_size
192 | )
193 | else:
194 | return split_chunks
195 |
196 | @staticmethod
197 | def _split_large_chunk(
198 | chunk: Dict[str, Any],
199 | chunk_size: int,
200 | max_chunk_size: int
201 | ) -> List[Dict[str, Any]]:
202 | """
203 |         If a chunk has more lines than max_chunk_size,
204 |         split it evenly according to chunk_size so sub-chunks are neither too large nor too small.
205 | """
206 | lines = chunk["code"].split("\n")
207 | total_lines = len(lines)
208 | if total_lines <= max_chunk_size:
209 | return [chunk]
210 |
211 |         # Initial estimate: split into n blocks
212 | n = max(1, round(total_lines / chunk_size))
213 |         # If each block would still exceed max_chunk_size, increase n
214 | while (total_lines / n) > max_chunk_size:
215 | n += 1
216 |         # Make sure n does not exceed total_lines
217 | n = min(n, total_lines)
218 |
219 | base_size = total_lines // n
220 |         remainder = total_lines % n  # the first remainder blocks get one extra line each
221 | results = []
222 | current_index = 0
223 | original_start = chunk["start"]
224 | for i in range(n):
225 | current_chunk_size = base_size + (1 if i < remainder else 0)
226 | sub_lines = lines[current_index: current_index + current_chunk_size]
227 | results.append({
228 | "start": original_start + current_index,
229 | "end": original_start + current_index + current_chunk_size - 1,
230 | "code": "\n".join(sub_lines)
231 | })
232 | current_index += current_chunk_size
233 | return results
234 |
235 | @staticmethod
236 | def _merge_small_chunks(
237 | chunks: List[Dict[str, Any]],
238 | min_chunk_size: int,
239 | max_chunk_size: int,
240 | chunk_size: int
241 | ) -> List[Dict[str, Any]]:
242 | """
243 |         Merge small chunks (< min_chunk_size) into a neighbor, as long as the result stays within max_chunk_size.
244 |         An "ideal range" tolerance (tolerance_low ~ tolerance_high) is also applied:
245 |         neighbors already inside the ideal range are left alone so well-sized chunks are not broken up.
246 | """
247 | merged = chunks.copy()
248 | changed = True
249 | tolerance_low = 0.8 * chunk_size
250 | tolerance_high = 1.2 * chunk_size
251 |
252 | while changed:
253 | changed = False
254 | new_merged = []
255 | i = 0
256 | while i < len(merged):
257 | current = merged[i]
258 | current_lines = len(current["code"].split("\n"))
259 |
260 |                 # The current chunk is already large enough: append it to new_merged as-is
261 | if current_lines >= min_chunk_size:
262 | new_merged.append(current)
263 | i += 1
264 | continue
265 |
266 | best_merge = None
267 | best_score = float('inf')
268 |
269 |                 # Try merging with the previous chunk
270 | if new_merged:
271 | prev = new_merged[-1]
272 | prev_lines = len(prev["code"].split("\n"))
273 |                     # Only if the previous chunk is outside the ideal range and the merged size fits within max_chunk_size
274 | if not (tolerance_low <= prev_lines <= tolerance_high):
275 | if prev_lines + current_lines <= max_chunk_size:
276 | score = abs((prev_lines + current_lines) - chunk_size)
277 | if score < best_score:
278 | best_merge = 'prev'
279 | best_score = score
280 |                 # Try merging with the next chunk
281 | if i + 1 < len(merged):
282 | nxt = merged[i + 1]
283 | nxt_lines = len(nxt["code"].split("\n"))
284 | if not (tolerance_low <= nxt_lines <= tolerance_high):
285 | if current_lines + nxt_lines <= max_chunk_size:
286 | score = abs((current_lines + nxt_lines) - chunk_size)
287 | if score < best_score:
288 | best_merge = 'next'
289 | if best_merge == 'prev':
290 | prev = new_merged.pop()
291 | new_chunk = {
292 | "start": prev["start"],
293 | "end": current["end"],
294 | "code": prev["code"] + "\n" + current["code"]
295 | }
296 | new_merged.append(new_chunk)
297 | i += 1
298 | changed = True
299 | elif best_merge == 'next':
300 | nxt = merged[i + 1]
301 | new_chunk = {
302 | "start": current["start"],
303 | "end": nxt["end"],
304 | "code": current["code"] + "\n" + nxt["code"]
305 | }
306 | new_merged.append(new_chunk)
307 | i += 2
308 | changed = True
309 | else:
310 |                     # If the current chunk is very small (< min_chunk_size / 2), force-merge it into the previous one
311 | if current_lines < (min_chunk_size / 2) and new_merged:
312 | prev = new_merged.pop()
313 | new_chunk = {
314 | "start": prev["start"],
315 | "end": current["end"],
316 | "code": prev["code"] + "\n" + current["code"]
317 | }
318 | new_merged.append(new_chunk)
319 | changed = True
320 | else:
321 | new_merged.append(current)
322 | i += 1
323 | merged = new_merged
324 | return merged
325 |
326 | def _split_by_tokens(
327 | self,
328 | chunks: List[Dict[str, Any]],
329 | max_tokens: int,
330 | block: int = 1
331 | ) -> List[Dict[str, Any]]:
332 | processed = []
333 | for chunk in chunks:
334 | current_code = chunk["code"]
335 | current_tokens = self.length_function(current_code)
336 | if current_tokens <= max_tokens:
337 | if block == 1:
338 | processed.append(chunk)
339 | else:
340 |                     # Tag the split block with its sequence number
341 | chunk_with_part = chunk.copy()
342 | chunk_with_part["block"] = block
343 | processed.append(chunk_with_part)
344 | continue
345 |
346 | lines = current_code.split('\n')
347 | if len(lines) > 1:
348 | split_index = len(lines) // 2
349 | first_code = '\n'.join(lines[:split_index])
350 | second_code = '\n'.join(lines[split_index:])
351 | first_start = chunk["start"]
352 | first_end = first_start + split_index - 1
353 | second_start = first_end + 1
354 | second_end = chunk["end"]
355 | first_sub = {"start": first_start, "end": first_end, "code": first_code}
356 | second_sub = {"start": second_start, "end": second_end, "code": second_code}
357 | processed.extend(self._split_by_tokens([first_sub, second_sub], max_tokens, block))
358 | else:
359 | code = current_code
360 | low, high = 0, len(code)
361 | best_mid = 0
362 | while low <= high:
363 | mid = (low + high) // 2
364 | part = code[:mid]
365 | if self.length_function(part) <= max_tokens:
366 | best_mid = mid
367 | low = mid + 1
368 | else:
369 | high = mid - 1
370 | best_mid = best_mid if best_mid != 0 else len(code) // 2
371 | first_part = code[:best_mid]
372 | second_part = code[best_mid:]
373 | first_sub = {"start": chunk["start"], "end": chunk["end"], "code": first_part}
374 | second_sub = {"start": chunk["start"], "end": chunk["end"], "code": second_part}
375 |             # Increment the block number for the recursive call
376 | processed.extend(self._split_by_tokens([first_sub], max_tokens, block))
377 | processed.extend(self._split_by_tokens([second_sub], max_tokens, block + 1))
378 | return processed
379 |
--------------------------------------------------------------------------------
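A sketch of a concrete TreeSitterSegmenter subclass wiring in a grammar. The repository ships real implementations under loader/segmenter/language/; the tree_sitter_python binding used here is an assumption:

    from typing import List

    from tree_sitter import Language
    import tree_sitter_python  # assumed grammar package

    class PythonSegmenter(TreeSitterSegmenter):
        def get_language(self) -> Language:
            return Language(tree_sitter_python.language())

        def get_node_types(self) -> List[str]:
            return ["function_definition", "class_definition"]

        def get_recursion_node_types(self) -> List[str]:
            return ["module", "block", "class_definition", "function_definition"]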