├── .gitignore ├── Dockerfile ├── README.md ├── localembedding.py ├── requirements.txt └── script └── build.sh /.gitignore: -------------------------------------------------------------------------------- 1 | HELP.md 2 | target/ 3 | !.mvn/wrapper/maven-wrapper.jar 4 | !**/src/main/** 5 | !**/src/test/** 6 | /logs* 7 | 8 | ### STS ### 9 | .apt_generated 10 | .classpath 11 | .factorypath 12 | .project 13 | .settings 14 | .springBeans 15 | .sts4-cache 16 | 17 | ### IntelliJ IDEA ### 18 | .idea 19 | *.iws 20 | *.iml 21 | *.ipr 22 | 23 | ### NetBeans ### 24 | /nbproject/private/ 25 | /nbbuild/ 26 | /dist/ 27 | /nbdist/ 28 | /.nb-gradle/ 29 | build/ 30 | 31 | ### VS Code ### 32 | .vscode/ 33 | /**/application-hccake.yml 34 | /**/application-preview.yml 35 | /aty/ 36 | /ballcat-job/logs/ballcat-job/ 37 | .flattened-pom.xml 38 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # 构建参数 2 | ARG ARCH 3 | 4 | # 使用官方Python运行时作为父镜像 5 | FROM ${ARCH}/python:3.10-bullseye 6 | 7 | # 设置工作目录 8 | WORKDIR /app 9 | 10 | # 将当前目录内容复制到容器的/app中 11 | ADD . /app 12 | 13 | RUN pip install --upgrade pip 14 | # 安装程序需要的包 15 | RUN pip install --no-cache-dir -r requirements.txt 16 | 17 | # 运行时监听的端口 18 | EXPOSE 6008 19 | 20 | # 运行app.py时的命令及其参数 21 | CMD ["uvicorn", "localembedding:app", "--host", "0.0.0.0", "--port", "6008"] 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | m3e-large-api 2 | -------------------------------------------------------------------------------- /localembedding.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Depends, HTTPException, status,Request 2 | from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials 3 | from sentence_transformers import SentenceTransformer 4 | from pydantic import BaseModel, Field 5 | from fastapi.middleware.cors import CORSMiddleware 6 | import uvicorn 7 | import tiktoken 8 | import numpy as np 9 | from scipy.interpolate import interp1d 10 | from typing import List, Literal, Optional, Union,Dict 11 | from sklearn.preprocessing import PolynomialFeatures 12 | import torch 13 | import os 14 | import time 15 | 16 | 17 | #环境变量传入 18 | sk_key = os.environ.get('sk-key', 'sk-aaabbbcccdddeeefffggghhhiiijjjkkk') 19 | 20 | # 创建一个FastAPI实例 21 | app = FastAPI() 22 | 23 | app.add_middleware( 24 | CORSMiddleware, 25 | allow_origins=["*"], 26 | allow_credentials=True, 27 | allow_methods=["*"], 28 | allow_headers=["*"], 29 | ) 30 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 检测是否有GPU可用,如果有则使用cuda设备,否则使用cpu设备 31 | if torch.cuda.is_available(): 32 | print('本次加载模型的设备为GPU: ', torch.cuda.get_device_name(0)) 33 | else: 34 | print('本次加载模型的设备为CPU.') 35 | model = SentenceTransformer('./moka-ai_m3e-large',device=device) 36 | 37 | # 创建一个HTTPBearer实例 38 | security = HTTPBearer() 39 | 40 | 41 | 42 | class ChatMessage(BaseModel): 43 | role: Literal["user", "assistant", "system"] 44 | content: str 45 | 46 | 47 | class DeltaMessage(BaseModel): 48 | role: Optional[Literal["user", "assistant", "system"]] = None 49 | content: Optional[str] = None 50 | 51 | class ChatCompletionRequest(BaseModel): 52 | model: str 53 | messages: List[ChatMessage] 54 | temperature: Optional[float] = None 55 | top_p: Optional[float] = None 56 | max_length: Optional[int] = None 57 | stream: Optional[bool] = False 58 | 59 | 60 | class ChatCompletionResponseChoice(BaseModel): 61 | index: int 62 | message: ChatMessage 63 | finish_reason: Literal["stop", "length"] 64 | 65 | 66 | class ChatCompletionResponseStreamChoice(BaseModel): 67 | index: int 68 | delta: DeltaMessage 69 | finish_reason: Optional[Literal["stop", "length"]] 70 | 71 | class ChatCompletionResponse(BaseModel): 72 | model: str 73 | object: Literal["chat.completion", "chat.completion.chunk"] 74 | choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] 75 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 76 | 77 | class EmbeddingRequest(BaseModel): 78 | input: List[str] 79 | model: str 80 | 81 | class EmbeddingResponse(BaseModel): 82 | data: list 83 | model: str 84 | object: str 85 | usage: dict 86 | 87 | def num_tokens_from_string(string: str) -> int: 88 | """Returns the number of tokens in a text string.""" 89 | encoding = tiktoken.get_encoding('cl100k_base') 90 | num_tokens = len(encoding.encode(string)) 91 | return num_tokens 92 | 93 | # 插值法 94 | def interpolate_vector(vector, target_length): 95 | original_indices = np.arange(len(vector)) 96 | target_indices = np.linspace(0, len(vector)-1, target_length) 97 | f = interp1d(original_indices, vector, kind='linear') 98 | return f(target_indices) 99 | 100 | def expand_features(embedding, target_length): 101 | poly = PolynomialFeatures(degree=2) 102 | expanded_embedding = poly.fit_transform(embedding.reshape(1, -1)) 103 | expanded_embedding = expanded_embedding.flatten() 104 | if len(expanded_embedding) > target_length: 105 | # 如果扩展后的特征超过目标长度,可以通过截断或其他方法来减少维度 106 | expanded_embedding = expanded_embedding[:target_length] 107 | elif len(expanded_embedding) < target_length: 108 | # 如果扩展后的特征少于目标长度,可以通过填充或其他方法来增加维度 109 | expanded_embedding = np.pad(expanded_embedding, (0, target_length - len(expanded_embedding))) 110 | return expanded_embedding 111 | 112 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) 113 | async def create_chat_completion(request: ChatCompletionRequest, credentials: HTTPAuthorizationCredentials = Depends(security)): 114 | if credentials.credentials != sk_key: 115 | raise HTTPException( 116 | status_code=status.HTTP_401_UNAUTHORIZED, 117 | detail="Invalid authorization code", 118 | ) 119 | choice_data = ChatCompletionResponseChoice( 120 | index=0, 121 | message=ChatMessage(role="assistant", content='你说得对,但这个是向量模型不能对话'), 122 | finish_reason="stop" 123 | ) 124 | 125 | return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion") 126 | 127 | @app.post("/v1/embeddings", response_model=EmbeddingResponse) 128 | async def get_embeddings(http_request: Request, request: EmbeddingRequest, credentials: HTTPAuthorizationCredentials = Depends(security)): 129 | client_host = http_request.client.host 130 | headers = http_request.headers 131 | print(f"Client IP: {client_host}") 132 | print(f"Request headers: {headers}") 133 | 134 | if credentials.credentials != sk_key: 135 | raise HTTPException( 136 | status_code=status.HTTP_401_UNAUTHORIZED, 137 | detail="Invalid authorization code", 138 | ) 139 | 140 | # 计算嵌入向量和tokens数量 141 | embeddings = [model.encode(text) for text in request.input] 142 | 143 | 144 | # 如果嵌入向量的维度不为1536,则使用插值法扩展至1536维度 145 | # embeddings = [interpolate_vector(embedding, 1536) if len(embedding) < 1536 else embedding for embedding in embeddings] 146 | # 如果嵌入向量的维度不为1536,则使用特征扩展法扩展至1536维度 147 | embeddings = [expand_features(embedding, 1536) if len(embedding) < 1536 else embedding for embedding in embeddings] 148 | 149 | # Min-Max normalization 150 | # embeddings = [(embedding - np.min(embedding)) / (np.max(embedding) - np.min(embedding)) if np.max(embedding) != np.min(embedding) else embedding for embedding in embeddings] 151 | embeddings = [embedding / np.linalg.norm(embedding) for embedding in embeddings] 152 | # 将numpy数组转换为列表 153 | embeddings = [embedding.tolist() for embedding in embeddings] 154 | prompt_tokens = sum(len(text.split()) for text in request.input) 155 | total_tokens = sum(num_tokens_from_string(text) for text in request.input) 156 | 157 | 158 | response = { 159 | "data": [ 160 | { 161 | "embedding": embedding, 162 | "index": index, 163 | "object": "embedding" 164 | } for index, embedding in enumerate(embeddings) 165 | ], 166 | "model": request.model, 167 | "object": "list", 168 | "usage": { 169 | "prompt_tokens": prompt_tokens, 170 | "total_tokens": total_tokens, 171 | } 172 | } 173 | 174 | 175 | return response 176 | 177 | if __name__ == "__main__": 178 | # 预加载模型 179 | 180 | uvicorn.run("localembedding:app", host='0.0.0.0', port=6008, workers=1) 181 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.99.1 2 | pydantic==1.10.7 3 | sentence-transformers==2.2.2 4 | uvicorn==0.23.1 5 | tiktoken==0.4.0 6 | numpy==1.24.4 7 | scipy==1.10.1 8 | scikit-learn==1.3.0 9 | torchvision 10 | torchaudio 11 | torch -------------------------------------------------------------------------------- /script/build.sh: -------------------------------------------------------------------------------- 1 | docker build --build-arg ARCH=arm64v8 -t wavecode:m3e-large-api . --------------------------------------------------------------------------------