├── .env.example ├── docs └── images │ ├── logo.png │ └── api-doc.png ├── models ├── __init__.py └── rerank_result.py ├── main.py ├── services ├── __init__.py ├── rerank_service.py └── cross_encoder │ └── cross_encoder_rerank_service.py ├── http_requests ├── __init__.py └── rerank │ └── rerank_request.py ├── http_responses ├── __init__.py └── rerank │ └── rerank_response.py ├── docker-compose.yaml ├── pyproject.toml ├── Dockerfile ├── endpoints └── __init__.py ├── .github └── workflows │ └── docker-image.yaml ├── LICENSE ├── README.md └── .gitignore /.env.example: -------------------------------------------------------------------------------- 1 | HF_ENDPOINT=https://huggingface.co -------------------------------------------------------------------------------- /docs/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sugarforever/peanut-shell/main/docs/images/logo.png -------------------------------------------------------------------------------- /docs/images/api-doc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sugarforever/peanut-shell/main/docs/images/api-doc.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.rerank_result import RerankResult 2 | 3 | __all__ = [ 4 | "RerankResult" 5 | ] 6 | -------------------------------------------------------------------------------- /models/rerank_result.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class RerankResult(BaseModel): 5 | index: int 6 | relevance_score: float 7 | -------------------------------------------------------------------------------- /main.py: 
-------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from endpoints import router 3 | 4 | app = FastAPI() 5 | 6 | app.include_router(router=router, prefix="/v1") 7 | -------------------------------------------------------------------------------- /services/__init__.py: -------------------------------------------------------------------------------- 1 | from services.cross_encoder.cross_encoder_rerank_service import CrossEncoderRerankService 2 | 3 | __all__ = [ 4 | "CrossEncoderRerankService" 5 | ] 6 | -------------------------------------------------------------------------------- /http_requests/__init__.py: -------------------------------------------------------------------------------- 1 | from http_requests.rerank.rerank_request import RerankRequest, CROSS_ENCODER_MODELS 2 | 3 | __all__ = [ 4 | "CROSS_ENCODER_MODELS", 5 | "RerankRequest" 6 | ] 7 | -------------------------------------------------------------------------------- /http_responses/__init__.py: -------------------------------------------------------------------------------- 1 | from http_responses.rerank.rerank_response import RerankModelsResponse, RerankResponse 2 | 3 | __all__ = [ 4 | "RerankModelsResponse", 5 | "RerankResponse" 6 | ] 7 | -------------------------------------------------------------------------------- /http_responses/rerank/rerank_response.py: -------------------------------------------------------------------------------- 1 | 2 | from pydantic import BaseModel 3 | from models import RerankResult 4 | 5 | 6 | class RerankModelsResponse(BaseModel): 7 | models: list[str] 8 | 9 | 10 | class RerankResponse(BaseModel): 11 | results: list[RerankResult] 12 | -------------------------------------------------------------------------------- /services/rerank_service.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from http_requests import RerankRequest 
3 | from http_responses import RerankResponse 4 | 5 | 6 | class RerankService(ABC): 7 | @abstractmethod 8 | def rerank(self, request: RerankRequest) -> RerankResponse: 9 | pass 10 | -------------------------------------------------------------------------------- /docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: '3.1' 2 | services: 3 | peanutshell: 4 | image: ghcr.io/sugarforever/peanut-shell:latest 5 | pull_policy: always 6 | environment: 7 | - HF_ENDPOINT=${HF_ENDPOINT:-https://huggingface.co} 8 | ports: 9 | - "8000:8000" 10 | volumes: 11 | - hf_data:/root/.cache 12 | 13 | volumes: 14 | hf_data: 15 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "peanut-shell" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["sugarforever "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.11" 10 | fastapi = "^0.111.0" 11 | uvicorn = "^0.29.0" 12 | pydantic = "^2.7.1" 13 | sentence-transformers = "3.0.0" 14 | requests = "^2.31.0" 15 | transformers = "4.41.2" 16 | 17 | 18 | [build-system] 19 | requires = ["poetry-core"] 20 | build-backend = "poetry.core.masonry.api" 21 | -------------------------------------------------------------------------------- /http_requests/rerank/rerank_request.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | CROSS_ENCODER_MODELS = [ 4 | "ms-marco-TinyBERT-L-2-v2", 5 | "ms-marco-MiniLM-L-2-v2", 6 | "ms-marco-MiniLM-L-4-v2", 7 | "ms-marco-MiniLM-L-6-v2", 8 | "ms-marco-MiniLM-L-12-v2", 9 | "ms-marco-TinyBERT-L-2", 10 | "ms-marco-TinyBERT-L-4", 11 | "ms-marco-TinyBERT-L-6", 12 | "ms-marco-electra-base" 13 | ] 14 | 15 | class RerankRequest(BaseModel): 16 | model: str 17 | query: str 18 | top_n: int 19 | documents: 
list[str] 20 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim as builder 2 | 3 | WORKDIR /app 4 | COPY pyproject.toml poetry.lock README.md ./ 5 | RUN pip3 install --no-cache-dir poetry && \ 6 | poetry config virtualenvs.create false && \ 7 | poetry install --no-root --no-dev && \ 8 | pip uninstall -y poetry 9 | 10 | FROM python:3.11-slim 11 | 12 | WORKDIR /app 13 | 14 | COPY --from=builder /usr/local/lib/python3.11/site-packages/ /usr/local/lib/python3.11/site-packages/ 15 | COPY --from=builder /usr/local/bin/ /usr/local/bin/ 16 | COPY . . 17 | 18 | CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] -------------------------------------------------------------------------------- /endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException 2 | from http_requests import RerankRequest, CROSS_ENCODER_MODELS 3 | from http_responses import RerankModelsResponse, RerankResponse 4 | from services import CrossEncoderRerankService 5 | 6 | router = APIRouter() 7 | 8 | 9 | @router.get("/models/", response_model=RerankModelsResponse) 10 | def models(): 11 | return RerankModelsResponse(models=CROSS_ENCODER_MODELS) 12 | 13 | 14 | @router.post("/rerank/", response_model=RerankResponse) 15 | def rerank(request: RerankRequest): 16 | if request.model not in CROSS_ENCODER_MODELS: 17 | detail = f"Invalid model name: {request.model}. 
Available models: {','.join(CROSS_ENCODER_MODELS)}" 18 | raise HTTPException(status_code=400, detail=detail) 19 | 20 | if len(request.documents) == 0: 21 | return RerankResponse(results=[]) 22 | 23 | model_name = f"cross-encoder/{request.model}" 24 | service = CrossEncoderRerankService(modelName=model_name) 25 | return service.rerank(request) 26 | -------------------------------------------------------------------------------- /.github/workflows/docker-image.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Publish Docker Image 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build-and-publish: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Delete cached images 14 | run: | 15 | sudo docker system prune --all --force 16 | - name: Checkout code 17 | uses: actions/checkout@v2 18 | 19 | - name: Set up Docker Buildx 20 | uses: docker/setup-buildx-action@v1 21 | 22 | - name: Log in to GitHub Container Registry 23 | run: echo "${{ secrets.SECRET_GITHUB_TOKEN }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin 24 | 25 | - name: Build and push Docker image 26 | uses: docker/build-push-action@v2 27 | with: 28 | context: . 
29 | push: true 30 | tags: | 31 | ghcr.io/${{ github.repository_owner }}/peanut-shell:${{ github.sha }} 32 | ghcr.io/${{ github.repository_owner }}/peanut-shell:latest 33 | platforms: linux/amd64,linux/arm64,linux/arm64/v8 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 sugarforever 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /services/cross_encoder/cross_encoder_rerank_service.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import CrossEncoder 2 | from models import RerankResult 3 | from http_requests import RerankRequest 4 | from http_responses import RerankResponse 5 | from services.rerank_service import RerankService 6 | 7 | class CrossEncoderRerankService(RerankService): 8 | 9 | def __init__(self, modelName: str) -> None: 10 | super().__init__() 11 | 12 | try: 13 | self.cross_encoder = CrossEncoder(model_name=modelName, local_files_only=True) 14 | except Exception as e: 15 | print(f"Cached model files may not exist. Failed to load model {modelName} with local_files_only=True. Error: {e}") 16 | print(f"Attempting to load model {modelName} from remote repository") 17 | self.cross_encoder = CrossEncoder(model_name=modelName, local_files_only=False) 18 | 19 | def rerank(self, request: RerankRequest) -> RerankResponse: 20 | pairs = [] 21 | for document in request.documents: 22 | pairs.append([request.query, document]) 23 | 24 | scores = self.cross_encoder.predict(pairs) 25 | 26 | scored_pairs = list(zip(scores, range(len(request.documents)))) 27 | scored_pairs.sort(key=lambda x: x[0], reverse=True) 28 | 29 | # Generate a list of indices in descending order of their corresponding scores 30 | results = [RerankResult(index=index, relevance_score=score) for score, index in scored_pairs] 31 | 32 | return RerankResponse(results=results[:request.top_n]) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Peanut Shell 2 | 3 | 4 | 5 | `Peanut Shell` is a reranking service for better RAG performance. The models supported are from [CrossEncoder](https://huggingface.co/cross-encoder). 
6 | 7 | Check out the models supported so far: 8 | 9 | - ms-marco-TinyBERT-L-2-v2 10 | - ms-marco-MiniLM-L-2-v2 11 | - ms-marco-MiniLM-L-4-v2 12 | - ms-marco-MiniLM-L-6-v2 13 | - ms-marco-MiniLM-L-12-v2 14 | - ms-marco-TinyBERT-L-2 15 | - ms-marco-TinyBERT-L-4 16 | - ms-marco-TinyBERT-L-6 17 | - ms-marco-electra-base 18 | 19 | For more details, please refer to the models table on [cross-encoder/ms-marco-MiniLM-L-6-v2](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-6-v2). 20 | 21 | ## User Guide 22 | 23 | It's recommended to run `Peanut Shell` as a docker container. 24 | 25 | ### How to install 26 | 27 | Use the commands below to clone this repo and run the docker container. 28 | 29 | ```shell 30 | $ git clone git@github.com:sugarforever/peanut-shell.git 31 | $ cd peanut-shell 32 | $ docker compose up -d 33 | ``` 34 | 35 | ### How to use 36 | 37 | There are 2 APIs available. Refer to the examples below for how to make an API call. 38 | 39 | #### List supported models 40 | 41 | ```shell 42 | curl --location 'http://localhost:8000/v1/models' 43 | ``` 44 | 45 | Expected response 46 | 47 | ```json 48 | { 49 | "models": [ 50 | "ms-marco-TinyBERT-L-2-v2", 51 | "ms-marco-MiniLM-L-2-v2", 52 | "ms-marco-MiniLM-L-4-v2", 53 | "ms-marco-MiniLM-L-6-v2", 54 | "ms-marco-MiniLM-L-12-v2", 55 | "ms-marco-TinyBERT-L-2", 56 | "ms-marco-TinyBERT-L-4", 57 | "ms-marco-TinyBERT-L-6", 58 | "ms-marco-electra-base" 59 | ] 60 | } 61 | ``` 62 | 63 | #### Rerank documents 64 | 65 | ```shell 66 | curl --location 'http://localhost:8000/v1/rerank/' \ 67 | --header 'Content-Type: application/json' \ 68 | --data '{ 69 | "model": "ms-marco-MiniLM-L-6-v2", 70 | "query": "What is the capital of the United States?", 71 | "top_n": 3, 72 | "documents": [ 73 | "Carson City is the capital city of the American state of Nevada.", 74 | "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", 75 | "Washington, D.C. 
(also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", 76 | "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states." 77 | ] 78 | }' 79 | ``` 80 | 81 | You should expect a response similar to the one below: 82 | 83 | ```json 84 | { 85 | "results": [ 86 | { 87 | "index": 2, 88 | "relevance_score": 8.186723709106445 89 | }, 90 | { 91 | "index": 3, 92 | "relevance_score": 0.2271757274866104 93 | }, 94 | { 95 | "index": 0, 96 | "relevance_score": -2.0371546745300293 97 | } 98 | ] 99 | } 100 | ``` 101 | 102 | ## API Documentation 103 | 104 | When you launch the API on localhost, you should be able to view the API documentation on [http://localhost:8000/docs](http://localhost:8000/docs). 105 | 106 | ![API Doc](./docs/images/api-doc.png) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | .DS_Store --------------------------------------------------------------------------------