├── src └── replit │ ├── ai │ ├── __init__.py │ └── modelfarm │ │ ├── structs │ │ ├── __init__.py │ │ ├── shared.py │ │ ├── google.py │ │ ├── embeddings.py │ │ ├── completions.py │ │ └── chat.py │ │ ├── identity │ │ ├── __init__.py │ │ ├── goval │ │ │ ├── api │ │ │ │ ├── __init__.py │ │ │ │ ├── features │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── features_pb2.py │ │ │ │ ├── repl │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── repl_pb2.py │ │ │ │ ├── signing_pb2.py │ │ │ │ └── client_pb2.py │ │ │ └── __init__.py │ │ ├── exceptions.py │ │ ├── sign.py │ │ └── verify.py │ │ ├── exceptions.py │ │ ├── google │ │ ├── language_models │ │ │ ├── __init__.py │ │ │ ├── text_embedding_model.py │ │ │ └── text_generation_model.py │ │ ├── preview │ │ │ └── language_models │ │ │ │ ├── __init__.py │ │ │ │ └── chat_model.py │ │ ├── utils.py │ │ └── structs.py │ │ ├── __init__.py │ │ ├── config.py │ │ ├── replit_identity_token_manager.py │ │ ├── embeddings.py │ │ ├── client.py │ │ ├── completions.py │ │ └── chat_completions.py │ └── tests │ └── ai │ └── modelfarm │ ├── __init__.py │ ├── conftest.py │ ├── test_loadmodel.py │ ├── test_config.py │ ├── google │ ├── language_models │ │ ├── test_text_embedding_model.py │ │ └── test_text_generation_model.py │ └── preview │ │ └── language_model │ │ └── test_chat_model.py │ ├── test_embeddings.py │ ├── test_identity.py │ ├── test_completions.py │ └── test_chat_completions.py ├── RELEASING ├── .gitignore ├── README.md ├── LICENSE └── pyproject.toml /src/replit/ai/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/structs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/replit/tests/ai/modelfarm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/identity/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/identity/goval/api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/identity/goval/api/features/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/identity/goval/api/repl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/identity/goval/__init__.py: -------------------------------------------------------------------------------- 1 | """The replit Python package.""" 2 | -------------------------------------------------------------------------------- /RELEASING: -------------------------------------------------------------------------------- 1 | Update version marker in: 2 | 3 | - ./pyproject.toml 4 | 5 | $ poetry publish --build -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Python 2 | /dist/ 3 | __pycache__/ 4 | 5 | # Replit 6 | .breakpoints 7 | .replit 8 | replit.nix 9 | .pythonlibs/ 10 | 11 | # Environments 12 | .env 13 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/structs/shared.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class Usage(BaseModel): 5 | completion_tokens: int 6 | prompt_tokens: int 7 | total_tokens: int 8 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/identity/exceptions.py: -------------------------------------------------------------------------------- 1 | """Exceptions that can be thrown by identity.""" 2 | 3 | 4 | class VerifyError(Exception): 5 | """An Exception occurred when a verification process failed.""" 6 | 7 | pass 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Replit AI 2 | 3 | ## Modelfarm 4 | 5 | A library for building AI applications in Python. 6 | 7 | [Model farm general documentation](https://docs.replit.com/model-farm/) 8 | 9 | [Python library API reference](https://docs.replit.com/model-farm/python/) 10 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/exceptions.py: -------------------------------------------------------------------------------- 1 | class BadRequestException(ValueError): 2 | """Exception raised for a bad request.""" 3 | 4 | pass 5 | 6 | 7 | class InvalidResponseException(ValueError): 8 | """Exception raised for an invalid response.""" 9 | 10 | pass 11 | -------------------------------------------------------------------------------- /src/replit/tests/ai/modelfarm/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from replit.ai.modelfarm import AsyncModelfarm, Modelfarm 3 | 4 | 5 | @pytest.fixture 6 | def client() -> Modelfarm: 7 | return Modelfarm() 8 | 9 | 10 | @pytest.fixture 11 | def async_client() -> AsyncModelfarm: 12 | return AsyncModelfarm() 13 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/google/language_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .text_embedding_model import TextEmbedding as TextEmbedding 2 | from .text_embedding_model import TextEmbeddingModel as TextEmbeddingModel 3 | from .text_generation_model import TextGenerationModel as TextGenerationModel 4 | from .text_generation_model import TextGenerationResponse as TextGenerationResponse 5 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/google/preview/language_models/__init__.py: -------------------------------------------------------------------------------- 1 | from ...structs import ( 2 | TextGenerationResponse as TextGenerationResponse, 3 | ) 4 | from .chat_model import ( 5 | ChatMessage as ChatMessage, 6 | ) 7 | from .chat_model import ( 8 | ChatModel as ChatModel, 9 | ) 10 | from .chat_model import ( 11 | InputOutputTextPair as InputOutputTextPair, 12 | ) 13 | -------------------------------------------------------------------------------- /src/replit/tests/ai/modelfarm/test_loadmodel.py: 
-------------------------------------------------------------------------------- 1 | from replit.ai.modelfarm import Modelfarm 2 | 3 | 4 | def test_loadmodel_complete_endpoint(): 5 | client = Modelfarm() 6 | response = client.completions.create( 7 | model="loadtesting", 8 | prompt=["1 + 1 = "], 9 | ) 10 | 11 | assert len(response.choices) == 1 12 | 13 | choice = response.choices[0] 14 | 15 | assert "Content!" in choice.text 16 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/structs/google.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class TokenCountMetadata(BaseModel): 7 | billableTokens: int = 0 8 | unbilledTokens: int = 0 9 | billableCharacters: int = 0 10 | unbilledCharacters: int = 0 11 | 12 | 13 | class GoogleMetadata(BaseModel): 14 | inputTokenCount: Optional[TokenCountMetadata] = None 15 | outputTokenCount: Optional[TokenCountMetadata] = None 16 | 17 | 18 | class GoogleEmbeddingMetadata(BaseModel): 19 | tokenCountMetadata: Optional[TokenCountMetadata] = None 20 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/__init__.py: -------------------------------------------------------------------------------- 1 | from .client import AsyncModelfarm, Modelfarm 2 | from .structs.chat import ( 3 | ChatCompletionMessageRequestParam, 4 | ChatCompletionResponse, 5 | ChatCompletionStreamChunkResponse, 6 | ) 7 | from .structs.completions import CompletionModelResponse, PromptParameter 8 | from .structs.embeddings import EmbeddingModelResponse 9 | 10 | __version__ = "1.0.0" 11 | 12 | __all__ = [ 13 | "AsyncModelfarm", 14 | "Modelfarm", 15 | "ChatCompletionMessageRequestParam", 16 | "ChatCompletionResponse", 17 | "ChatCompletionStreamChunkResponse", 18 | "CompletionModelResponse", 19 | "PromptParameter", 20 | "EmbeddingModelResponse", 21 | ] 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Replit 2 | 3 | Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies. 4 | 5 | THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
6 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/structs/embeddings.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, TypeAlias, Union 2 | 3 | from pydantic import BaseModel 4 | 5 | from .google import GoogleEmbeddingMetadata 6 | from .shared import Usage 7 | 8 | ## 9 | # Request params 10 | ## 11 | 12 | InputParameter: TypeAlias = Union[str, List[str], List[int], List[List[int]]] 13 | 14 | ## 15 | # Response models 16 | ## 17 | 18 | 19 | class Embedding(BaseModel): 20 | object: str 21 | embedding: List[float] 22 | index: int 23 | metadata: Optional[Dict[str, Any]] 24 | 25 | 26 | class EmbeddingModelResponse(BaseModel): 27 | object: str 28 | data: List[Embedding] 29 | model: str 30 | usage: Optional[Usage] 31 | metadata: Optional[GoogleEmbeddingMetadata] 32 | -------------------------------------------------------------------------------- /src/replit/tests/ai/modelfarm/test_config.py: -------------------------------------------------------------------------------- 1 | from replit.ai.modelfarm.config import get_config, initialize 2 | 3 | 4 | def test_config_initialization(): 5 | 6 | old_config = get_config() 7 | old_rootUrl = old_config.rootUrl 8 | old_audience = old_config.audience 9 | assert old_rootUrl is not None 10 | assert old_audience is not None 11 | 12 | initialize("https://new-url.com", "new_audience") 13 | 14 | new_config = get_config() 15 | 16 | assert new_config.rootUrl == "https://new-url.com" 17 | assert new_config.audience == "new_audience" 18 | 19 | # Reset config back to original 20 | initialize(old_rootUrl, old_audience) 21 | new_config2 = get_config() 22 | assert new_config2.rootUrl == old_rootUrl 23 | assert new_config2.audience == old_audience 24 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class Config: 6 | """Config for the Model Farm API client.""" 7 | 8 | rootUrl: str = "https://production-modelfarm.replit.com" 9 | audience: str = "modelfarm@replit.com" 10 | 11 | 12 | _config = Config() 13 | 14 | 15 | def initialize(rootUrl=None, serverAudience=None): 16 | """Initializes the global config for the Model Farm API client.""" 17 | if rootUrl: 18 | _config.rootUrl = rootUrl 19 | if serverAudience: 20 | _config.audience = serverAudience 21 | 22 | 23 | def get_config() -> Config: 24 | """Returns the global config for the Model Farm API client. 25 | 26 | Returns: 27 | Config: the global config for the Model Farm API client. 
28 | """ 29 | return _config 30 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/structs/completions.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, TypeAlias, Union 2 | 3 | from pydantic import BaseModel 4 | 5 | from .google import GoogleMetadata 6 | from .shared import Usage 7 | 8 | ## 9 | # Request params 10 | ## 11 | 12 | PromptParameter: TypeAlias = Optional[Union[str, List[str], List[int], 13 | List[List[int]]]] 14 | 15 | ## 16 | # Response models 17 | ## 18 | 19 | 20 | class Choice(BaseModel): 21 | index: int 22 | text: str 23 | finish_reason: str 24 | logprobs: Optional[Dict[str, Any]] = None 25 | metadata: Optional[Dict[str, Any]] = None 26 | 27 | 28 | class CompletionModelResponse(BaseModel): 29 | id: str 30 | choices: List[Choice] 31 | model: str 32 | created: Optional[int] = None 33 | object: Optional[str] = None 34 | usage: Optional[Usage] = None 35 | metadata: Optional[GoogleMetadata] = None 36 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/google/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | _PROVIDER_EXTRA_PARAMS = {"context", "examples", "top_k"} 4 | 5 | 6 | def ready_parameters(parameters: Dict[str, Any]) -> Dict[str, Any]: 7 | """ 8 | Private method to prep parameter dict keys to send to API. 9 | 10 | Args: 11 | parameters (Dict[str, Any]): Dictionary of parameters. 12 | 13 | Returns: 14 | Dict[str, any]: New dictionary with keys in correct format for API. 15 | """ 16 | remap = { 17 | "max_output_tokens": "max_tokens", 18 | "candidate_count": "n", 19 | "stop_sequences": "stop", 20 | } 21 | params = {remap.get(k, k): v for k, v in parameters.items()} 22 | 23 | provider_extra_parameters = { 24 | k: params.pop(k) 25 | for k in _PROVIDER_EXTRA_PARAMS if k in params 26 | } 27 | params["provider_extra_parameters"] = provider_extra_parameters 28 | 29 | return params 30 | -------------------------------------------------------------------------------- /src/replit/tests/ai/modelfarm/google/language_models/test_text_embedding_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from replit.ai.modelfarm.google.language_models import TextEmbeddingModel 3 | 4 | 5 | def test_text_embedding_model_get_embeddings(): 6 | model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001") 7 | embeddings = model.get_embeddings(["What is life?"]) 8 | for embedding in embeddings: 9 | assert len(embedding.values) == 768 10 | assert embedding.statistics.truncated is False 11 | assert embedding.statistics.token_count == 4 12 | 13 | 14 | @pytest.mark.asyncio 15 | async def test_text_embedding_model_async_get_embeddings(): 16 | model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001") 17 | embeddings = await model.async_get_embeddings(["What is life?"]) 18 | for embedding in embeddings: 19 | assert len(embedding.values) == 768 20 | assert embedding.statistics.truncated is False 21 | assert embedding.statistics.token_count == 4 22 | -------------------------------------------------------------------------------- /src/replit/tests/ai/modelfarm/test_embeddings.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from replit.ai.modelfarm import AsyncModelfarm, Modelfarm 3 | 4 | # module 
level constants 5 | CONTENT = ["1 + 1 = "] 6 | 7 | MODEL = "textembedding-gecko" 8 | 9 | 10 | def test_embed_model_embed(client: Modelfarm) -> None: 11 | response = client.embeddings.create(input=CONTENT, model=MODEL) 12 | assert len(response.data) == 1 13 | 14 | embedding = response.data[0] 15 | 16 | assert len(embedding.embedding) == 768 17 | 18 | assert embedding.metadata is not None 19 | assert embedding.metadata["truncated"] is False 20 | assert embedding.metadata["tokenCountMetadata"]["unbilledTokens"] == 4 21 | 22 | 23 | @pytest.mark.asyncio 24 | async def test_embed_model_async_embed(async_client: AsyncModelfarm) -> None: 25 | response = await async_client.embeddings.create(input=CONTENT, model=MODEL) 26 | assert len(response.data) == 1 27 | 28 | embedding = response.data[0] 29 | 30 | assert len(embedding.embedding) == 768 31 | 32 | assert embedding.metadata is not None 33 | assert embedding.metadata["truncated"] is False 34 | assert embedding.metadata["tokenCountMetadata"]["unbilledTokens"] == 4 35 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/google/structs.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class GoogleCitation(BaseModel): 7 | startIndex: int 8 | endIndex: int 9 | url: str 10 | title: str 11 | license: str 12 | # Documented as "Possible formats are YYYY, YYYY-MM, YYYY-MM-DD." 13 | publicationDate: str 14 | 15 | 16 | class GoogleCitationMetadata(BaseModel): 17 | citations: List[GoogleCitation] = [] 18 | 19 | 20 | class GoogleSafetyAttributes(BaseModel): 21 | blocked: bool = False 22 | categories: List[str] = [] 23 | scores: List[float] = [] 24 | 25 | 26 | class GooglePredictionMetadata(BaseModel): 27 | safetyAttributes: Optional[GoogleSafetyAttributes] = None 28 | citationMetadata: Optional[GoogleCitationMetadata] = None 29 | 30 | 31 | class TextGenerationResponse(BaseModel): 32 | """ 33 | Class representing the response from the text generation model. 34 | 35 | Attributes: 36 | is_blocked (bool): Flag indicating whether the output was blocked due to 37 | content safety filters. 38 | raw_prediction_response (Dict[str, Any]): Raw response from the AI model. 39 | safety_attributes (Dict[str, float]): Dictionary with safety attributes of 40 | the generated text. 41 | text (str): Generated text. 
42 | """ 43 | 44 | is_blocked: bool 45 | raw_prediction_response: Dict[str, Any] 46 | safety_attributes: Optional[Dict[str, float]] = None 47 | text: str 48 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "replit.ai" 3 | version = "1.0.0" 4 | description = "A library for interacting with AI features of replit" 5 | authors = ["Repl.it "] 6 | license = "ISC" 7 | readme = "README.md" 8 | repository = "https://github.com/replit/replit-ai-python" 9 | homepage = "https://github.com/replit/replit-ai-python" 10 | documentation = "https://replit-ai-python.readthedocs.org/" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: ISC License (ISCL)", 14 | "Operating System :: OS Independent", 15 | ] 16 | packages = [ 17 | { include = "replit", from = "src" } 18 | ] 19 | 20 | [tool.poetry.dependencies] 21 | python = ">=3.8, <4.0" 22 | requests = "^2.31.0" 23 | pydantic = "^2.3.0" 24 | aiohttp = "^3.8.5" 25 | pytest-asyncio = "^0.21.1" 26 | pyseto = "^1.7.3" 27 | google-api-python-client = "^2.98.0" 28 | 29 | 30 | [tool.pyright] 31 | # https://github.com/microsoft/pyright/blob/main/docs/configuration.md 32 | useLibraryCodeForTypes = true 33 | 34 | 35 | [tool.poetry.dev-dependencies] 36 | Flask = "^2.2.0" 37 | pytest = "^7.4.1" 38 | 39 | 40 | [build-system] 41 | requires = ["poetry>=0.12"] 42 | build-backend = "poetry.masonry.api" 43 | 44 | [tool.poetry.scripts] 45 | replit = "replit.__main__:cli" 46 | 47 | [tool.mypy] 48 | exclude = [ 49 | "_pb2.py$", # Generated code 50 | ] 51 | 52 | [tool.ruff] 53 | # https://beta.ruff.rs/docs/configuration/ 54 | select = ['E', 'W', 'F', 'I', 'B', 'C4', 'ARG', 'SIM'] 55 | ignore = ['W291', 'W292', 'W293'] 56 | extend-exclude = ["*_pb2.py"] 57 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/identity/goval/api/features/features_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: replit/goval/api/features/features.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import symbol_database as _symbol_database 8 | from google.protobuf.internal import builder as _builder 9 | 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 16 | b'\n(replit/goval/api/features/features.proto\x12\x19replit.goval.api.features"\x05\n\x03Gpu"\t\n\x07\x42oosted"\x8c\x01\n\x07\x46\x65\x61ture\x12-\n\x03gpu\x18\x01 \x01(\x0b\x32\x1e.replit.goval.api.features.GpuH\x00\x12\x35\n\x07\x62oosted\x18\x02 \x01(\x0b\x32".replit.goval.api.features.BoostedH\x00\x12\x10\n\x08required\x18\x03 \x01(\x08\x42\t\n\x07\x66\x65\x61tureB&Z$github.com/replit/goval/api/featuresb\x06proto3' 17 | ) 18 | 19 | _globals = globals() 20 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 21 | _builder.BuildTopDescriptorsAndMessages( 22 | DESCRIPTOR, "replit.goval.api.features.features_pb2", _globals 23 | ) 24 | if _descriptor._USE_C_DESCRIPTORS == False: 25 | DESCRIPTOR._options = None 26 | DESCRIPTOR._serialized_options = b"Z$github.com/replit/goval/api/features" 27 | _globals["_GPU"]._serialized_start = 71 28 | _globals["_GPU"]._serialized_end = 76 29 | _globals["_BOOSTED"]._serialized_start = 78 30 | _globals["_BOOSTED"]._serialized_end = 87 31 | _globals["_FEATURE"]._serialized_start = 90 32 | _globals["_FEATURE"]._serialized_end = 230 33 | # @@protoc_insertion_point(module_scope) 34 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/structs/chat.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | from pydantic import BaseModel 4 | from typing_extensions import Required, TypedDict 5 | 6 | from .google import GoogleMetadata 7 | from .shared import Usage 8 | 9 | ## 10 | # Request params 11 | ## 12 | 13 | 14 | class ChatCompletionMessageRequestParam(TypedDict, total=False): 15 | role: Required[str] 16 | content: Optional[str] 17 | tool_calls: Optional[List] 18 | tool_call_id: Optional[str] 19 | 20 | 21 | ## 22 | # Response models 23 | ## 24 | 25 | 26 | class FunctionCall(BaseModel): 27 | name: str 28 | arguments: str 29 | 30 | 31 | class ToolCall(BaseModel): 32 | id: str 33 | type: str 34 | function: FunctionCall 35 | 36 | 37 | class ChoiceMessage(BaseModel): 38 | content: Optional[str] = None 39 | role: Optional[str] = None 40 | tool_calls: Optional[List[ToolCall]] = None 41 | 42 | 43 | class BaseChoice(BaseModel): 44 | index: int 45 | finish_reason: Optional[str] = None 46 | metadata: Optional[Dict[str, Any]] = None 47 | 48 | 49 | class Choice(BaseChoice): 50 | message: ChoiceMessage 51 | 52 | 53 | class ChoiceStream(BaseChoice): 54 | delta: ChoiceMessage 55 | 56 | 57 | class BaseChatCompletionResponse(BaseModel): 58 | id: str 59 | choices: List[BaseChoice] 60 | model: str 61 | created: Optional[int] 62 | object: Optional[str] = None 63 | usage: Optional[Usage] = None 64 | metadata: Optional[GoogleMetadata] = None 65 | 66 | 67 | class ChatCompletionResponse(BaseChatCompletionResponse): 68 | choices: List[Choice] 69 | object: Optional[str] = "chat.completion" 70 | 71 | 72 | class ChatCompletionStreamChunkResponse(BaseChatCompletionResponse): 73 | choices: List[ChoiceStream] 74 | object: 
Optional[str] = "chat.completion.chunk" 75 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/google/language_models/text_embedding_model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | from replit.ai.modelfarm import AsyncModelfarm, Modelfarm 5 | from replit.ai.modelfarm.structs.embeddings import Embedding, EmbeddingModelResponse 6 | 7 | 8 | @dataclass 9 | class TextEmbeddingStatistics: 10 | token_count: float 11 | truncated: bool 12 | 13 | 14 | @dataclass 15 | class TextEmbedding: 16 | statistics: TextEmbeddingStatistics 17 | values: List[float] 18 | 19 | 20 | class TextEmbeddingModel: 21 | 22 | def __init__(self, model_id: str): 23 | self.underlying_model = model_id 24 | self._client = Modelfarm() 25 | self._async_client = AsyncModelfarm() 26 | 27 | @staticmethod 28 | def from_pretrained(model_id: str) -> "TextEmbeddingModel": 29 | return TextEmbeddingModel(model_id) 30 | 31 | # this model only takes in the content parameter and nothing else 32 | def get_embeddings(self, content: List[str]) -> List[TextEmbedding]: 33 | # since this model only takes the content param, we don't pass kwargs 34 | response = self._client.embeddings.create(input=content, 35 | model=self.underlying_model) 36 | return self.__ready_response(response) 37 | 38 | async def async_get_embeddings(self, 39 | content: List[str]) -> List[TextEmbedding]: 40 | # since this model only takes the content param, we don't pass kwargs 41 | response = await self._async_client.embeddings.create( 42 | input=content, model=self.underlying_model) 43 | return self.__ready_response(response) 44 | 45 | def __ready_response( 46 | self, response: EmbeddingModelResponse) -> List[TextEmbedding]: 47 | 48 | def transform_response(x: Embedding) -> TextEmbedding: 49 | metadata = x.metadata or {} 50 | token_metadata = metadata.get("tokenCountMetadata", {}) 51 | tokenCount: int = token_metadata.get( 52 | "unbilledTokens", 0) + token_metadata.get("billableTokens", 0) 53 | stats = TextEmbeddingStatistics(tokenCount, metadata["truncated"]) 54 | return TextEmbedding(stats, x.embedding) 55 | 56 | return [transform_response(x) for x in response.data] 57 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/identity/sign.py: -------------------------------------------------------------------------------- 1 | """This library allows signing identity tokens from Replit.""" 2 | 3 | import base64 4 | 5 | import pyseto 6 | from replit.ai.modelfarm.identity import verify 7 | from replit.ai.modelfarm.identity.goval.api import signing_pb2 8 | 9 | 10 | class SigningAuthority: 11 | """A class to generate tokens that prove identity. 12 | 13 | This class proves the identity of one repl (your own) against another repl 14 | (the audience). Use this to prevent the target repl from spoofing your own 15 | identity by forwarding the token. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | marshaled_private_key: str, 21 | marshaled_identity: str, 22 | replid: str, 23 | pubkey_source: verify.PubKeySource = verify.read_public_key_from_env, 24 | ) -> None: 25 | """Creates a new SigningAuthority. 26 | 27 | Args: 28 | marshaled_private_key: The private key, in PASERK format. 29 | marshaled_identity: The PASETO of the Repl identity. 30 | replid: The ID of the source Repl. 31 | pubkey_source: The PubKeySource to get the public key. 
32 | """ 33 | self.identity = verify.verify_identity_token( 34 | marshaled_identity, replid, pubkey_source 35 | ) 36 | self.signing_authority = verify.get_signing_authority(marshaled_identity) 37 | self.private_key = pyseto.Key.from_paserk(marshaled_private_key) 38 | 39 | def sign(self, audience: str) -> str: 40 | """Generates a new token that can be given to the provided audience. 41 | 42 | This is resistant against forwarding, so that the recipient cannot 43 | forward this token to another repl and claim it came directly from you. 44 | 45 | Args: 46 | audience: The audience that the token will be signed for. 47 | 48 | Returns: 49 | The encoded token in PASETO format. 50 | """ 51 | identity = signing_pb2.GovalReplIdentity() 52 | identity.CopyFrom(self.identity) 53 | identity.aud = audience 54 | 55 | encoded_identity = identity.SerializeToString() 56 | encoded_cert = self.signing_authority.SerializeToString() 57 | 58 | return pyseto.encode( 59 | self.private_key, 60 | base64.b64encode(encoded_identity), 61 | base64.b64encode(encoded_cert), 62 | ).decode("utf-8") 63 | -------------------------------------------------------------------------------- /src/replit/tests/ai/modelfarm/google/language_models/test_text_generation_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from replit.ai.modelfarm.google.language_models import ( 3 | TextGenerationModel, 4 | TextGenerationResponse, 5 | ) 6 | 7 | TEST_PARAMETERS = { 8 | # Temperature controls the degree of randomness in token selection. 9 | "temperature": 0.5, 10 | # Token limit determines the maximum amount of text output. 11 | "max_output_tokens": 256, 12 | # Tokens are selected from most probable to least until the sum of 13 | # their probabilities equals the top_p value. 14 | "top_p": 0.8, 15 | # A top_k of 1 means the selected token is the most probable among all tokens. 
16 | "top_k": 40, 17 | } 18 | 19 | 20 | def test_text_generation_model_predict(): 21 | model = TextGenerationModel.from_pretrained("text-bison@001") 22 | response = model.predict( 23 | "Give me ten interview questions for the role of program manager.", 24 | **TEST_PARAMETERS, 25 | ) 26 | validate_response(response) 27 | 28 | 29 | @pytest.mark.asyncio 30 | async def test_text_generation_model_async_predict(): 31 | model = TextGenerationModel.from_pretrained("text-bison@001") 32 | response = await model.async_predict( 33 | "Give me ten interview questions for the role of program manager.", 34 | **TEST_PARAMETERS, 35 | ) 36 | validate_response(response) 37 | 38 | 39 | def test_text_generation_model_predict_streaming(): 40 | model = TextGenerationModel.from_pretrained("text-bison@001") 41 | response = model.predict_streaming( 42 | "Give me 100 interview questions for the role of program manager.", 43 | **TEST_PARAMETERS, 44 | ) 45 | result = list(response) 46 | assert len(result) > 1 47 | for x in result: 48 | validate_response(x) 49 | 50 | 51 | @pytest.mark.asyncio 52 | async def test_text_generation_model_async_predict_streaming(): 53 | model = TextGenerationModel.from_pretrained("text-bison@001") 54 | response = model.async_predict_streaming( 55 | "Give me 100 interview questions for the role of program manager.", 56 | **TEST_PARAMETERS, 57 | ) 58 | result = [] 59 | async for x in response: 60 | result.append(x) 61 | assert len(result) > 1 62 | 63 | for x in result: 64 | validate_response(x) 65 | 66 | 67 | def validate_response(response: TextGenerationResponse): 68 | assert len(response.text) > 1 69 | assert response.is_blocked is False 70 | assert response.safety_attributes is not None 71 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/replit_identity_token_manager.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from typing import Optional 5 | 6 | import requests 7 | from replit.ai.modelfarm.config import get_config 8 | from replit.ai.modelfarm.identity.sign import SigningAuthority 9 | 10 | 11 | class MissingEnvironmentVariable(Exception): 12 | pass 13 | 14 | 15 | class ReplitIdentityTokenManager: 16 | 17 | def __init__(self, token_timeout: int = 300): 18 | """Initializes a new instance of ReplitIdentityTokenManager 19 | 20 | Args: 21 | token_timeout (int): The timeout in seconds for the token. 22 | Default is 300 seconds. 23 | """ 24 | self.token_timeout = token_timeout 25 | self.last_update: Optional[float] = None 26 | self.token: Optional[str] = None 27 | self.__update_token() 28 | 29 | def get_token(self) -> Optional[str]: 30 | """Returns the token, updates if the current token has expired. 31 | 32 | Returns: 33 | str: The token. 34 | """ 35 | if (self.last_update is None 36 | or self.last_update + self.token_timeout < time.time()): 37 | self.__update_token() 38 | return self.token 39 | 40 | def __update_token(self): 41 | """Updates the token and the last_updated time.""" 42 | self.token = self.get_new_token() 43 | self.last_update = time.time() 44 | 45 | def get_new_token(self) -> str: 46 | """Gets the most recent token. 47 | 48 | Returns: 49 | str: The most recent token. 50 | """ 51 | if self.__in_deployment(): 52 | return self.get_deployment_token() 53 | return self.get_interactive_token() 54 | 55 | def get_deployment_token(self) -> str: 56 | """Fetches deployment token from hostindpid1. 57 | 58 | Returns: 59 | str: Deployment token. 
60 | """ 61 | response = requests.post( 62 | "http://localhost:1105/getIdentityToken", 63 | json={"audience": get_config().audience}, 64 | ) 65 | return json.loads(response.content)["identityToken"] 66 | 67 | @classmethod 68 | def get_env_var(cls, var: str) -> str: 69 | if var in os.environ: 70 | return os.environ[var] 71 | raise MissingEnvironmentVariable( 72 | f"Did not find the environment variable: {var}") 73 | 74 | def get_interactive_token(self) -> str: 75 | """Generates and returns an identity token" 76 | 77 | Returns: 78 | str: Interactive token. 79 | """ 80 | gsa = SigningAuthority( 81 | marshaled_private_key=self.get_env_var("REPL_IDENTITY_KEY"), 82 | marshaled_identity=self.get_env_var("REPL_IDENTITY"), 83 | replid=self.get_env_var("REPL_ID"), 84 | ) 85 | signed_token = gsa.sign(audience=get_config().audience) 86 | return signed_token 87 | 88 | def __in_deployment(self) -> bool: 89 | """Determines if in deployment environement. 90 | 91 | Returns: 92 | bool: True if in the deployment environment, False otherwise. 93 | """ 94 | return "REPLIT_DEPLOYMENT" in os.environ 95 | -------------------------------------------------------------------------------- /src/replit/tests/ai/modelfarm/test_identity.py: -------------------------------------------------------------------------------- 1 | """Tests for replit.identity.""" 2 | 3 | import json 4 | import os 5 | from unittest.mock import patch 6 | 7 | import pyseto 8 | from replit.ai.modelfarm.identity import verify 9 | from replit.ai.modelfarm.identity.sign import SigningAuthority 10 | from replit.ai.modelfarm.replit_identity_token_manager import ReplitIdentityTokenManager 11 | 12 | PUBLIC_KEY = "on0FkSmEC+ce40V9Vc4QABXSx6TXo+lhp99b6Ka0gro=" 13 | 14 | # This token should be valid for 100y. 15 | # Generated with: 16 | # ``` 17 | # go run ./cmd/goval_keypairgen/ -eternal -sample-token \ 18 | # -identity -gen-prefix dev -gen-id identity -issuer conman \ 19 | # -replid=test -shortlived=false 20 | # ``` 21 | # in goval. 
22 | IDENTITY_PRIVATE_KEY = "k2.secret.6sHU27WoRIaspIOVaShpuZM33ozpfFyI2THfO8fmSX6xiA_Duh4ac5g76Y5bParclsalaOCTaCs6gZowhYivVQ" # noqa: E501,B950,S106 # line too long 23 | IDENTITY_TOKEN = "v2.public.Q2dSMFpYTjBJZ1IwWlhOMHDnO17Eg43zucAMSAHnCS4C1wn4QUCCOcr-Pggw5SV1KnbOXq8RcQE5if6pMcbJ6lmRWcdoHq5CV9jqyRrUlwo.R0FFaUJtTnZibTFoYmhLckFuWXlMbkIxWW14cFl5NVJNbVF6VTFkd2VHRnRXbmRrTVd4U1lXMDViRm96YkVKU1ZrNUZVVmRzYTA1VmQzSlRSVlp2VWtaa2RGbFZVa3BSVmtwMlVUQmtRbFpYUmtOYU1qbEdXa1ZrVjJWdFVrUlRWRVpvWld0c01Wa3dhRmRoVjBwSVlrZHdUV0pyTldGWGFrWkRUVEEwZVU5WGVGTk5hbFpSVmpGVk5HUkhTbFpQVm1oc1lXdHdORlJVUW5kaFZrbDZVV3hvYUdKWFVubFVWekZyWlZaUmVVOVZhRnBXVkVaTFZtcENjMlZWTVZkV1ZEQkRUbVpoYkhkMk5EUm9SRkZQTFVKWlJDMURWSEUxYVdJeFQzVlVlamxIWW5WTlFVVnFURFExUVVwclZXNW9kR2hxVFZOVVRtOVZSRVphWDBsaVUyTjFjekoxWW05aVowNU1MV2RRVlRGRmVVOTFUVzlHTGxJd1JrWmhWVXAwVkc1YWFXSlVSbTlaYldSMlVteHdTRlpxU2xCaGExVTU" # noqa: E501,B950,S106 # line too long 24 | 25 | 26 | def setup_pub_key(): 27 | return json.dumps({"dev:1": PUBLIC_KEY}) 28 | 29 | 30 | def test_read_public_key_from_env() -> None: 31 | """Test read_public_key_from_env.""" 32 | with patch.dict(os.environ, {"REPL_PUBKEYS": setup_pub_key()}): 33 | pubkey = verify.read_public_key_from_env("dev:1", "goval") 34 | assert isinstance(pubkey, pyseto.versions.v2.V2Public) 35 | 36 | 37 | def test_signing_authority() -> None: 38 | """Test SigningAuthority.""" 39 | with patch.dict(os.environ, {"REPL_PUBKEYS": setup_pub_key()}): 40 | gsa = SigningAuthority( 41 | marshaled_private_key=IDENTITY_PRIVATE_KEY, 42 | marshaled_identity=IDENTITY_TOKEN, 43 | replid="test", 44 | ) 45 | signed_token = gsa.sign("audience") 46 | 47 | verify.verify_identity_token( 48 | identity_token=signed_token, 49 | audience="audience", 50 | ) 51 | 52 | 53 | def test_verify_identity_token() -> None: 54 | """Test verify_identity_token.""" 55 | with patch.dict(os.environ, {"REPL_PUBKEYS": setup_pub_key()}): 56 | verify.verify_identity_token( 57 | identity_token=IDENTITY_TOKEN, 58 | audience="test", 59 | ) 60 | 61 | 62 | def test_get_interactive_token() -> None: 63 | with patch.dict( 64 | os.environ, 65 | { 66 | "REPL_PUBKEYS": setup_pub_key(), 67 | "REPL_IDENTITY": IDENTITY_TOKEN, 68 | "REPL_IDENTITY_KEY": IDENTITY_PRIVATE_KEY, 69 | "REPL_ID": "test", 70 | }, 71 | ): 72 | replit_identity_token_manager = ReplitIdentityTokenManager() 73 | signed_token = replit_identity_token_manager.get_interactive_token() 74 | 75 | verify.verify_identity_token( 76 | identity_token=signed_token, 77 | audience="modelfarm@replit.com", 78 | ) 79 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/embeddings.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Dict, Optional 2 | 3 | from replit.ai.modelfarm.structs.embeddings import ( 4 | EmbeddingModelResponse, 5 | InputParameter, 6 | ) 7 | 8 | if TYPE_CHECKING: 9 | from replit.ai.modelfarm import AsyncModelfarm, Modelfarm 10 | 11 | 12 | class Embeddings: 13 | _client: "Modelfarm" 14 | 15 | def __init__(self, client: "Modelfarm") -> None: 16 | self._client = client 17 | 18 | def create( 19 | self, 20 | *, 21 | input: InputParameter, 22 | model: str, 23 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 24 | **kwargs: Any, 25 | ) -> EmbeddingModelResponse: 26 | """ 27 | Makes a prediction based on the input and parameters. 28 | 29 | Args: 30 | input (InputParameter): The input(s) to embed. 31 | model (str): The name of the model. 
32 | 33 | Returns: 34 | EmbeddingModelResponse: The response from the model. 35 | """ 36 | response = self._client._post( 37 | "/v1beta2/embeddings", 38 | payload=_build_request_payload( 39 | input, 40 | model, 41 | provider_extra_parameters, 42 | **kwargs, 43 | ), 44 | ) 45 | self._client._check_response(response) 46 | return EmbeddingModelResponse(**response.json()) 47 | 48 | 49 | class AsyncEmbeddings: 50 | _client: "AsyncModelfarm" 51 | 52 | def __init__(self, client: "AsyncModelfarm") -> None: 53 | self._client = client 54 | 55 | async def create( 56 | self, 57 | *, 58 | input: InputParameter, 59 | model: str, 60 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 61 | **kwargs: Any, 62 | ) -> EmbeddingModelResponse: 63 | """ 64 | Makes an asynchronous embedding generation based on the input 65 | and parameters. 66 | 67 | Args: 68 | input (EmbeddingInput): The input(s) to embed. 69 | model (str): The name of the model. 70 | 71 | Returns: 72 | EmbeddingModelResponse: The response from the model. 73 | """ 74 | async with self._client._post( 75 | "/v1beta2/embeddings", 76 | payload=_build_request_payload( 77 | input, 78 | model, 79 | provider_extra_parameters, 80 | **kwargs, 81 | ), 82 | ) as response: 83 | await self._client._check_response(response) 84 | return EmbeddingModelResponse(**await response.json()) 85 | 86 | 87 | def _build_request_payload( 88 | input: InputParameter, 89 | model: str, 90 | provider_extra_parameters: Optional[Dict[str, Any]], 91 | **kwargs: Any, 92 | ) -> Dict[str, Any]: 93 | """ 94 | Builds the request payload. 95 | 96 | Args: 97 | input (InputParameter): The input(s) to embed. 98 | model (str): The name of the model to use. 99 | 100 | Returns: 101 | Dict[str, Any]: The request payload. 102 | """ 103 | 104 | params = { 105 | "model": model, 106 | "input": input, 107 | "provider_extra_parameters": provider_extra_parameters, 108 | **kwargs, 109 | } 110 | 111 | return {k: v for k, v in params.items() if v is not None} 112 | -------------------------------------------------------------------------------- /src/replit/tests/ai/modelfarm/google/preview/language_model/test_chat_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from replit.ai.modelfarm.google.preview.language_models import ( 3 | ChatModel, 4 | InputOutputTextPair, 5 | ) 6 | from replit.ai.modelfarm.google.structs import TextGenerationResponse 7 | 8 | parameters = { 9 | # Temperature controls the degree of randomness in token selection. 10 | "temperature": 0.5, 11 | # Token limit determines the maximum amount of text output. 12 | "max_output_tokens": 256, 13 | # Tokens are selected from most probable to least until the sum of 14 | # their probabilities equals the top_p value. 15 | "top_p": 0.95, 16 | # A top_k of 1 means the selected token is the most probable among all tokens. 17 | "top_k": 40, 18 | } 19 | 20 | 21 | def test_chat_model_send_message(): 22 | chat_model = ChatModel.from_pretrained("chat-bison@001") 23 | 24 | chat = chat_model.start_chat( 25 | context=("My name is Miles. 
You are an astronomer, " 26 | "knowledgeable about the solar system."), 27 | examples=[ 28 | InputOutputTextPair( 29 | input_text="How many moons does Mars have?", 30 | output_text="The planet Mars has two moons, Phobos and Deimos.", 31 | ), 32 | ], 33 | ) 34 | 35 | response = chat.send_message( 36 | "How many planets are there in the solar system?", **parameters) 37 | validate_response(response) 38 | 39 | 40 | @pytest.mark.asyncio 41 | async def test_chat_model_async_send_message(): 42 | chat_model = ChatModel.from_pretrained("chat-bison@001") 43 | 44 | chat = chat_model.start_chat( 45 | context=("My name is Miles. You are an astronomer, " 46 | "knowledgeable about the solar system."), 47 | examples=[ 48 | InputOutputTextPair( 49 | input_text="How many moons does Mars have?", 50 | output_text="The planet Mars has two moons, Phobos and Deimos.", 51 | ), 52 | ], 53 | ) 54 | 55 | response = await chat.async_send_message( 56 | "How many planets are there in the solar system?", **parameters) 57 | validate_response(response) 58 | 59 | 60 | def test_chat_model_send_message_stream(): 61 | chat_model = ChatModel.from_pretrained("chat-bison@001") 62 | 63 | chat = chat_model.start_chat( 64 | context=("My name is Miles. You are an astronomer, " 65 | "knowledgeable about the solar system."), 66 | examples=[ 67 | InputOutputTextPair( 68 | input_text="How many moons does Mars have?", 69 | output_text="The planet Mars has two moons, Phobos and Deimos.", 70 | ), 71 | ], 72 | ) 73 | 74 | responses = list( 75 | chat.send_message_stream("Name as many different stars as you can.", 76 | **parameters)) 77 | assert len(responses) > 1 78 | 79 | assert any(len(res.text) > 1 for res in responses) 80 | assert all(response.is_blocked is False for response in responses) 81 | assert all(response.safety_attributes is not None 82 | for response in responses) 83 | 84 | 85 | @pytest.mark.asyncio 86 | async def test_chat_model_async_send_message_stream(): 87 | chat_model = ChatModel.from_pretrained("chat-bison@001") 88 | 89 | chat = chat_model.start_chat( 90 | context=("My name is Miles. 
You are an astronomer, " 91 | "knowledgeable about the solar system."), 92 | examples=[ 93 | InputOutputTextPair( 94 | input_text="How many moons does Mars have?", 95 | output_text="The planet Mars has two moons, Phobos and Deimos.", 96 | ), 97 | ], 98 | ) 99 | responses = [ 100 | res async for res in chat.async_send_message_stream( 101 | "Name as many different stars as you can.", **parameters) 102 | ] 103 | 104 | assert len(responses) > 1 105 | 106 | assert any(len(res.text) > 1 for res in responses) 107 | assert all(response.is_blocked is False for response in responses) 108 | assert all(response.safety_attributes is not None 109 | for response in responses) 110 | 111 | 112 | def validate_response(response: TextGenerationResponse): 113 | assert len(response.text) > 1 114 | assert response.is_blocked is False 115 | assert response.safety_attributes is not None 116 | -------------------------------------------------------------------------------- /src/replit/tests/ai/modelfarm/test_completions.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import pytest 4 | from replit.ai.modelfarm import AsyncModelfarm, Modelfarm 5 | from replit.ai.modelfarm.exceptions import BadRequestException 6 | 7 | # module level constants 8 | PROMPT = ["1 + 1 = "] 9 | LONG_PROMPT = [ 10 | "A very long answer to the question of what is the meaning of life is " 11 | ] 12 | VALID_KWARGS = { 13 | "top_p": 0.1, 14 | "stop": ["\n"], 15 | "n": 5, 16 | "provider_extra_parameters": { 17 | "top_k": 20, 18 | }, 19 | } 20 | # stream_complete endpoint does not support the candidateCount arg 21 | VALID_GEN_STREAM_KWARGS = { 22 | "max_tokens": 128, 23 | "temperature": 0, 24 | "top_p": 0.1, 25 | "provider_extra_parameters": { 26 | "top_k": 20, 27 | }, 28 | } 29 | INVALID_KWARGS: Dict[str, Any] = { 30 | "invalid_parameter": 0.5, 31 | } 32 | 33 | MODEL = "text-bison" 34 | 35 | 36 | def test_completion_model_complete(client: Modelfarm) -> None: 37 | response = client.completions.create( 38 | prompt=PROMPT, 39 | model=MODEL, 40 | **VALID_KWARGS, 41 | ) 42 | 43 | assert len(response.choices) >= 1 44 | 45 | choice = response.choices[0] 46 | 47 | assert "2" in choice.text 48 | 49 | choice_metadata = choice.metadata 50 | assert choice_metadata is not None 51 | assert choice_metadata["safetyAttributes"]["blocked"] is False 52 | 53 | 54 | def test_completion_model_complete_invalid_parameter( 55 | client: Modelfarm) -> None: 56 | with pytest.raises(BadRequestException): 57 | client.completions.create(prompt=PROMPT, model=MODEL, **INVALID_KWARGS) 58 | 59 | 60 | @pytest.mark.asyncio 61 | async def test_completion_model_async_complete( 62 | async_client: AsyncModelfarm) -> None: 63 | response = await async_client.completions.create(prompt=PROMPT, 64 | model=MODEL, 65 | **VALID_KWARGS) 66 | 67 | assert len(response.choices) >= 1 68 | 69 | choice = response.choices[0] 70 | 71 | assert "2" in choice.text 72 | 73 | choice_metadata = choice.metadata 74 | 75 | assert choice_metadata is not None 76 | assert choice_metadata["safetyAttributes"]["blocked"] is False 77 | 78 | 79 | @pytest.mark.asyncio 80 | async def test_completion_model_async_complete_invalid_parameter( 81 | async_client: AsyncModelfarm) -> None: 82 | with pytest.raises(BadRequestException): 83 | await async_client.completions.create(prompt=PROMPT, 84 | model=MODEL, 85 | **INVALID_KWARGS) 86 | 87 | 88 | def test_completion_model_stream_complete(client: Modelfarm) -> None: 89 | responses = list( 90 | 
client.completions.create(prompt=LONG_PROMPT, 91 | model=MODEL, 92 | stream=True, 93 | **VALID_GEN_STREAM_KWARGS)) 94 | 95 | assert len(responses) > 1 96 | for response in responses: 97 | assert len(response.choices) == 1 98 | choice = response.choices[0] 99 | assert len(choice.text) > 0 100 | 101 | 102 | def test_completion_model_stream_complete_invalid_parameter( 103 | client: Modelfarm) -> None: 104 | with pytest.raises(BadRequestException): 105 | list( 106 | client.completions.create(prompt=PROMPT, 107 | model=MODEL, 108 | stream=True, 109 | **INVALID_KWARGS)) 110 | 111 | 112 | @pytest.mark.asyncio 113 | async def test_completion_model_async_stream_complete( 114 | async_client: AsyncModelfarm) -> None: 115 | responses = [ 116 | res async for res in await async_client.completions.create( 117 | prompt=LONG_PROMPT, 118 | model=MODEL, 119 | stream=True, 120 | **VALID_GEN_STREAM_KWARGS) 121 | ] 122 | 123 | assert len(responses) > 1 124 | for response in responses: 125 | assert len(response.choices) == 1 126 | choice = response.choices[0] 127 | assert len(choice.text) > 0 128 | 129 | 130 | @pytest.mark.asyncio 131 | async def test_completion_model_async_stream_complete_invalid_parameter( 132 | async_client: AsyncModelfarm) -> None: 133 | with pytest.raises(BadRequestException): 134 | async for _ in await async_client.completions.create( 135 | prompt=LONG_PROMPT, model=MODEL, stream=True, 136 | **INVALID_KWARGS): 137 | pass 138 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/identity/goval/api/repl/repl_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: replit/goval/api/repl/repl.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import symbol_database as _symbol_database 8 | from google.protobuf.internal import builder as _builder 9 | 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | from replit.ai.modelfarm.identity.goval.api.features import ( 16 | features_pb2 as replit_dot_goval_dot_api_dot_features_dot_features__pb2, 17 | ) 18 | 19 | 20 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 21 | b'\n replit/goval/api/repl/repl.proto\x12\x15replit.goval.api.repl\x1a(replit/goval/api/features/features.proto"C\n\x07\x42uckets\x12\x11\n\tsnapshots\x18\x01 \x01(\t\x12\x10\n\x08metadata\x18\x02 \x01(\t\x12\x13\n\x0b\x64isk_blocks\x18\x03 \x01(\t"\x8b\x02\n\x04Repl\x12\n\n\x02id\x18\x01 \x01(\t\x12\x10\n\x08language\x18\x02 \x01(\t\x12\x0e\n\x06\x62ucket\x18\x03 \x01(\t\x12\x0c\n\x04slug\x18\x04 \x01(\t\x12\x0c\n\x04user\x18\x05 \x01(\t\x12\x12\n\nsourceRepl\x18\x06 \x01(\t\x12\x10\n\x08\x64\x61tabase\x18\x07 \x01(\t\x12/\n\x07\x62uckets\x18\x08 \x01(\x0b\x32\x1e.replit.goval.api.repl.Buckets\x12.\n\x07user_id\x18\t \x01(\x0b\x32\x1d.replit.goval.api.repl.UserId\x12\x0f\n\x07is_team\x18\n \x01(\x08\x12\r\n\x05roles\x18\x0b \x03(\t\x12\x12\n\nlog_fields\x18\r \x01(\t"M\n\x06UserId\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x37\n\x0b\x65nvironment\x18\x02 \x01(\x0e\x32".replit.goval.api.repl.Environment"\xa7\x02\n\x0eResourceLimits\x12\x0b\n\x03net\x18\x01 \x01(\x08\x12\x0e\n\x06memory\x18\x02 \x01(\x03\x12\x0f\n\x07threads\x18\x03 \x01(\x01\x12\x0e\n\x06shares\x18\x04 
\x01(\x01\x12\x0c\n\x04\x64isk\x18\x05 \x01(\x03\x12\x14\n\x0cminimum_disk\x18\n \x01(\x03\x12\x14\n\x0cscratch_disk\x18\t \x01(\x03\x12@\n\x05\x63\x61\x63he\x18\x06 \x01(\x0e\x32\x31.replit.goval.api.repl.ResourceLimits.Cachability\x12\x17\n\x0frestrictNetwork\x18\x07 \x01(\x08\x12\x15\n\rpreventWakeup\x18\x08 \x01(\x08"+\n\x0b\x43\x61\x63hability\x12\x08\n\x04NONE\x10\x00\x12\x08\n\x04USER\x10\x01\x12\x08\n\x04REPL\x10\x02"%\n\x0bPermissions\x12\x16\n\x0etoggleAlwaysOn\x18\x01 \x01(\x08"\xab\x03\n\x08Metadata\x12)\n\x04repl\x18\x07 \x01(\x0b\x32\x1b.replit.goval.api.repl.Repl\x12=\n\x0eresourceLimits\x18\n \x01(\x0b\x32%.replit.goval.api.repl.ResourceLimits\x12H\n\x19interactiveResourceLimits\x18\x11 \x01(\x0b\x32%.replit.goval.api.repl.ResourceLimits\x12\x37\n\x0bpersistence\x18\x06 \x01(\x0e\x32".replit.goval.api.repl.Persistence\x12\r\n\x05\x66lags\x18\x0e \x03(\t\x12\x37\n\x0bpermissions\x18\x0f \x01(\x0b\x32".replit.goval.api.repl.Permissions\x12\x34\n\x08\x66\x65\x61tures\x18\x10 \x03(\x0b\x32".replit.goval.api.features.Feature\x12\x34\n\nbuild_info\x18\x12 \x01(\x0b\x32 .replit.goval.api.repl.BuildInfo"W\n\tBuildInfo\x12\x15\n\rdeployment_id\x18\x01 \x01(\t\x12\x0b\n\x03url\x18\x02 \x01(\t\x12\x10\n\x08\x62uild_id\x18\x03 \x01(\t\x12\x14\n\x0cmachine_tier\x18\x04 \x01(\t*.\n\x0b\x45nvironment\x12\x0f\n\x0b\x44\x45VELOPMENT\x10\x00\x12\x0e\n\nPRODUCTION\x10\x01*E\n\x0bPersistence\x12\x0e\n\nPERSISTENT\x10\x00\x12\r\n\tEPHEMERAL\x10\x01\x12\x08\n\x04NONE\x10\x02\x12\r\n\tREAD_ONLY\x10\x03\x42"Z github.com/replit/goval/api/replb\x06proto3' 22 | ) 23 | 24 | _globals = globals() 25 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 26 | _builder.BuildTopDescriptorsAndMessages( 27 | DESCRIPTOR, "replit.goval.api.repl.repl_pb2", _globals 28 | ) 29 | if _descriptor._USE_C_DESCRIPTORS == False: 30 | DESCRIPTOR._options = None 31 | DESCRIPTOR._serialized_options = b"Z github.com/replit/goval/api/repl" 32 | _globals["_ENVIRONMENT"]._serialized_start = 1375 33 | _globals["_ENVIRONMENT"]._serialized_end = 1421 34 | _globals["_PERSISTENCE"]._serialized_start = 1423 35 | _globals["_PERSISTENCE"]._serialized_end = 1492 36 | _globals["_BUCKETS"]._serialized_start = 101 37 | _globals["_BUCKETS"]._serialized_end = 168 38 | _globals["_REPL"]._serialized_start = 171 39 | _globals["_REPL"]._serialized_end = 438 40 | _globals["_USERID"]._serialized_start = 440 41 | _globals["_USERID"]._serialized_end = 517 42 | _globals["_RESOURCELIMITS"]._serialized_start = 520 43 | _globals["_RESOURCELIMITS"]._serialized_end = 815 44 | _globals["_RESOURCELIMITS_CACHABILITY"]._serialized_start = 772 45 | _globals["_RESOURCELIMITS_CACHABILITY"]._serialized_end = 815 46 | _globals["_PERMISSIONS"]._serialized_start = 817 47 | _globals["_PERMISSIONS"]._serialized_end = 854 48 | _globals["_METADATA"]._serialized_start = 857 49 | _globals["_METADATA"]._serialized_end = 1284 50 | _globals["_BUILDINFO"]._serialized_start = 1286 51 | _globals["_BUILDINFO"]._serialized_end = 1373 52 | # @@protoc_insertion_point(module_scope) 53 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/identity/goval/api/signing_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: replit/goval/api/signing.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import symbol_database as _symbol_database 8 | from google.protobuf.internal import builder as _builder 9 | 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 16 | from replit.ai.modelfarm.identity.goval.api import ( 17 | client_pb2 as replit_dot_goval_dot_api_dot_client__pb2, 18 | ) 19 | from replit.ai.modelfarm.identity.goval.api.repl import ( 20 | repl_pb2 as replit_dot_goval_dot_api_dot_repl_dot_repl__pb2, 21 | ) 22 | 23 | 24 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 25 | b'\n\x1ereplit/goval/api/signing.proto\x12\x10replit.goval.api\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1dreplit/goval/api/client.proto\x1a replit/goval/api/repl/repl.proto"\x89\x01\n\x15GovalSigningAuthority\x12\x10\n\x06key_id\x18\x01 \x01(\tH\x00\x12\x15\n\x0bsigned_cert\x18\x02 \x01(\tH\x00\x12/\n\x07version\x18\x03 \x01(\x0e\x32\x1e.replit.goval.api.TokenVersion\x12\x0e\n\x06issuer\x18\x04 \x01(\tB\x06\n\x04\x63\x65rt"\xbc\x01\n\x10\x43\x65rtificateClaim\x12\x10\n\x06replid\x18\x01 \x01(\tH\x00\x12\x0e\n\x04user\x18\x02 \x01(\tH\x00\x12\x11\n\x07user_id\x18\x07 \x01(\x03H\x00\x12\x11\n\x07\x63luster\x18\x04 \x01(\tH\x00\x12\x14\n\nsubcluster\x18\x05 \x01(\tH\x00\x12\x14\n\ndeployment\x18\x06 \x01(\x08H\x00\x12+\n\x04\x66lag\x18\x03 \x01(\x0e\x32\x1b.replit.goval.api.FlagClaimH\x00\x42\x07\n\x05\x63laim"\xa4\x01\n\tGovalCert\x12\'\n\x03iat\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\'\n\x03\x65xp\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x32\n\x06\x63laims\x18\x03 \x03(\x0b\x32".replit.goval.api.CertificateClaim\x12\x11\n\tpublicKey\x18\x04 \x01(\t"\xe8\x01\n\nGovalToken\x12\'\n\x03iat\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\'\n\x03\x65xp\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0e\n\x06replid\x18\x03 \x01(\t\x12\x31\n\nrepl_token\x18\x04 \x01(\x0b\x32\x1b.replit.goval.api.ReplTokenH\x00\x12<\n\rrepl_identity\x18\x05 \x01(\x0b\x32#.replit.goval.api.GovalReplIdentityH\x00\x42\x07\n\x05Token"\xe7\x02\n\x11GovalReplIdentity\x12\x0e\n\x06replid\x18\x01 \x01(\t\x12\x0c\n\x04user\x18\x02 \x01(\t\x12\x0c\n\x04slug\x18\x03 \x01(\t\x12\x0b\n\x03\x61ud\x18\x04 \x01(\t\x12\x11\n\tephemeral\x18\x05 \x01(\x08\x12\x14\n\x0coriginReplid\x18\x06 \x01(\t\x12\x0f\n\x07user_id\x18\x07 \x01(\x03\x12\x34\n\nbuild_info\x18\x08 \x01(\x0b\x32 .replit.goval.api.repl.BuildInfo\x12\x0f\n\x07is_team\x18\t \x01(\x08\x12\r\n\x05roles\x18\n \x03(\t\x12?\n\x0binteractive\x18\x0b \x01(\x0b\x32(.replit.goval.api.ReplRuntimeInteractiveH\x00\x12=\n\ndeployment\x18\x0c \x01(\x0b\x32\'.replit.goval.api.ReplRuntimeDeploymentH\x00\x42\t\n\x07runtime"=\n\x16ReplRuntimeInteractive\x12\x0f\n\x07\x63luster\x18\x01 \x01(\t\x12\x12\n\nsubcluster\x18\x02 
\x01(\t"\x17\n\x15ReplRuntimeDeployment*9\n\x0cTokenVersion\x12\x13\n\x0f\x42\x41RE_REPL_TOKEN\x10\x00\x12\x14\n\x10TYPE_AWARE_TOKEN\x10\x01*\xe3\x01\n\tFlagClaim\x12\x14\n\x10MINT_GOVAL_TOKEN\x10\x00\x12\x1a\n\x16SIGN_INTERMEDIATE_CERT\x10\x01\x12\x0c\n\x08IDENTITY\x10\x05\x12\x0f\n\x0bGHOSTWRITER\x10\x06\x12\x12\n\x0eRENEW_IDENTITY\x10\x07\x12\x0c\n\x08RENEW_KV\x10\x08\x12\x0f\n\x0b\x44\x45PLOYMENTS\x10\n\x12\x0e\n\nANY_REPLID\x10\x02\x12\x0c\n\x08\x41NY_USER\x10\x03\x12\x0f\n\x0b\x41NY_USER_ID\x10\x0b\x12\x0f\n\x0b\x41NY_CLUSTER\x10\x04\x12\x12\n\x0e\x41NY_SUBCLUSTER\x10\tB\x1dZ\x1bgithub.com/replit/goval/apib\x06proto3' 26 | ) 27 | 28 | _globals = globals() 29 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 30 | _builder.BuildTopDescriptorsAndMessages( 31 | DESCRIPTOR, "replit.goval.api.signing_pb2", _globals 32 | ) 33 | if _descriptor._USE_C_DESCRIPTORS == False: 34 | DESCRIPTOR._options = None 35 | DESCRIPTOR._serialized_options = b"Z\033github.com/replit/goval/api" 36 | _globals["_TOKENVERSION"]._serialized_start = 1333 37 | _globals["_TOKENVERSION"]._serialized_end = 1390 38 | _globals["_FLAGCLAIM"]._serialized_start = 1393 39 | _globals["_FLAGCLAIM"]._serialized_end = 1620 40 | _globals["_GOVALSIGNINGAUTHORITY"]._serialized_start = 151 41 | _globals["_GOVALSIGNINGAUTHORITY"]._serialized_end = 288 42 | _globals["_CERTIFICATECLAIM"]._serialized_start = 291 43 | _globals["_CERTIFICATECLAIM"]._serialized_end = 479 44 | _globals["_GOVALCERT"]._serialized_start = 482 45 | _globals["_GOVALCERT"]._serialized_end = 646 46 | _globals["_GOVALTOKEN"]._serialized_start = 649 47 | _globals["_GOVALTOKEN"]._serialized_end = 881 48 | _globals["_GOVALREPLIDENTITY"]._serialized_start = 884 49 | _globals["_GOVALREPLIDENTITY"]._serialized_end = 1243 50 | _globals["_REPLRUNTIMEINTERACTIVE"]._serialized_start = 1245 51 | _globals["_REPLRUNTIMEINTERACTIVE"]._serialized_end = 1306 52 | _globals["_REPLRUNTIMEDEPLOYMENT"]._serialized_start = 1308 53 | _globals["_REPLRUNTIMEDEPLOYMENT"]._serialized_end = 1331 54 | # @@protoc_insertion_point(module_scope) 55 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/google/language_models/text_generation_model.py: -------------------------------------------------------------------------------- 1 | from typing import AsyncIterator, Iterator 2 | 3 | from replit.ai.modelfarm import AsyncModelfarm, Modelfarm 4 | from replit.ai.modelfarm.google.structs import TextGenerationResponse 5 | from replit.ai.modelfarm.google.utils import ready_parameters 6 | from replit.ai.modelfarm.structs.completions import CompletionModelResponse 7 | 8 | 9 | class TextGenerationModel: 10 | """ 11 | Class representing a Google completion model. 12 | 13 | Methods: 14 | from_pretrained - Loads a pretrained model using its identifier 15 | predict - completes a human-like text given an initial prompt. 16 | async_predict - Async version of the predict method. 17 | """ 18 | 19 | def __init__(self, model_id: str): 20 | """Constructor method to initialize a text generation model.""" 21 | self.underlying_model = model_id 22 | self._client = Modelfarm() 23 | self._async_client = AsyncModelfarm() 24 | 25 | @staticmethod 26 | def from_pretrained(model_id: str) -> "TextGenerationModel": 27 | """ 28 | Creates a Tokenizer from a pretrained model. 29 | 30 | Args: 31 | model_id (str): The identifier of the pretrained model. 32 | 33 | Returns: 34 | The TextGenerationModel class instance. 
35 | """ 36 | return TextGenerationModel(model_id) 37 | 38 | def predict(self, prompt: str, **kwargs) -> TextGenerationResponse: 39 | """ 40 | completes a human-like text given an initial prompt. 41 | 42 | Args: 43 | prompt (str): The initial text to start the generation. 44 | 45 | Returns: 46 | TextGenerationResponse: The model's response containing the completed text. 47 | """ 48 | parameters = ready_parameters(kwargs) 49 | response = self._client.completions.create(prompt=prompt, 50 | model=self.underlying_model, 51 | stream=False, 52 | **parameters) 53 | return self.__ready_response(response) 54 | 55 | def predict_streaming(self, prompt: str, 56 | **kwargs) -> Iterator[TextGenerationResponse]: 57 | """ 58 | completes a human-like text given an initial prompt. 59 | 60 | Args: 61 | prompt (str): The initial text to start the generation. 62 | 63 | Returns: 64 | TextGenerationResponse: The model's response containing the completed text. 65 | """ 66 | parameters = ready_parameters(kwargs) 67 | response = self._client.completions.create(prompt=prompt, 68 | model=self.underlying_model, 69 | stream=True, 70 | **parameters) 71 | for x in response: 72 | yield self.__ready_response(x) 73 | 74 | async def async_predict(self, prompt: str, 75 | **kwargs) -> TextGenerationResponse: 76 | """ 77 | Async version of the predict method. Equivalent to the predict method, 78 | but suited for asynchronous programming. 79 | 80 | Args: 81 | prompt (str): The initial text to start the generation. 82 | 83 | Returns: 84 | TextGenerationResponse: The model's response containing the completed text. 85 | """ 86 | parameters = ready_parameters(kwargs) 87 | response = await self._async_client.completions.create( 88 | prompt=prompt, 89 | model=self.underlying_model, 90 | stream=False, 91 | **parameters) 92 | return self.__ready_response(response) 93 | 94 | async def async_predict_streaming( 95 | self, prompt: str, 96 | **kwargs) -> AsyncIterator[TextGenerationResponse]: 97 | """ 98 | Async version of the predict method. Equivalent to the predict method, 99 | but suited for asynchronous programming. 100 | 101 | Args: 102 | prompt (str): The initial text to start the generation. 103 | 104 | Returns: 105 | TextGenerationResponse: The model's response containing the completed text. 106 | """ 107 | parameters = ready_parameters(kwargs) 108 | response = await self._async_client.completions.create( 109 | prompt=prompt, 110 | model=self.underlying_model, 111 | stream=True, 112 | **parameters) 113 | async for x in response: 114 | yield self.__ready_response(x) 115 | 116 | def __ready_response( 117 | self, response: CompletionModelResponse) -> TextGenerationResponse: 118 | """ 119 | Transforms Completion Model's response into a readily usable format. 120 | 121 | Args: 122 | response (CompletionModelResponse): The original response from 123 | the underlying model. 124 | 125 | Returns: 126 | TextGenerationResponse: The transformed response. 
127 | """ 128 | choice = response.choices[0] 129 | safetyAttributes = choice.metadata[ 130 | "safetyAttributes"] if choice.metadata else {} 131 | safetyCategories = dict( 132 | zip(safetyAttributes["categories"], 133 | safetyAttributes["scores"], 134 | strict=True)) 135 | 136 | return TextGenerationResponse( 137 | is_blocked=safetyAttributes["blocked"], 138 | raw_prediction_response=choice.model_dump(), 139 | safety_attributes=safetyCategories, 140 | text=choice.text, 141 | ) 142 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/identity/goval/api/client_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: replit/goval/api/client.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import symbol_database as _symbol_database 8 | from google.protobuf.internal import builder as _builder 9 | 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 16 | from replit.ai.modelfarm.identity.goval.api.repl import ( 17 | repl_pb2 as replit_dot_goval_dot_api_dot_repl_dot_repl__pb2, 18 | ) 19 | from replit.ai.modelfarm.identity.goval.api.features import ( 20 | features_pb2 as replit_dot_goval_dot_api_dot_features_dot_features__pb2, 21 | ) 22 | 23 | 24 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( 25 | b'\n\x1dreplit/goval/api/client.proto\x12\x10replit.goval.api\x1a\x1fgoogle/protobuf/timestamp.proto\x1a replit/goval/api/repl/repl.proto\x1a(replit/goval/api/features/features.proto"\xdb\x07\n\tReplToken\x12\'\n\x03iat\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\'\n\x03\x65xp\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04salt\x18\x03 \x01(\t\x12\x0f\n\x07\x63luster\x18\x04 \x01(\t\x12\x37\n\x0bpersistence\x18\x06 \x01(\x0e\x32".replit.goval.api.repl.Persistence\x12+\n\x04repl\x18\x07 \x01(\x0b\x32\x1b.replit.goval.api.repl.ReplH\x00\x12\x30\n\x02id\x18\x08 \x01(\x0b\x32".replit.goval.api.ReplToken.ReplIDH\x00\x12\x46\n\tclassroom\x18\t \x01(\x0b\x32-.replit.goval.api.ReplToken.ClassroomMetadataB\x02\x18\x01H\x00\x12=\n\x0eresourceLimits\x18\n \x01(\x0b\x32%.replit.goval.api.repl.ResourceLimits\x12H\n\x19interactiveResourceLimits\x18\x11 \x01(\x0b\x32%.replit.goval.api.repl.ResourceLimits\x12\x36\n\x06\x66ormat\x18\x0c \x01(\x0e\x32&.replit.goval.api.ReplToken.WireFormat\x12\x38\n\tpresenced\x18\r \x01(\x0b\x32%.replit.goval.api.ReplToken.Presenced\x12\r\n\x05\x66lags\x18\x0e \x03(\t\x12\x37\n\x0bpermissions\x18\x0f \x01(\x0b\x32".replit.goval.api.repl.Permissions\x12\x34\n\x08\x66\x65\x61tures\x18\x10 \x03(\x0b\x32".replit.goval.api.features.Feature\x12\x34\n\nbuild_info\x18\x12 \x01(\x0b\x32 .replit.goval.api.repl.BuildInfo\x1a\x31\n\x11\x43lassroomMetadata\x12\n\n\x02id\x18\x01 \x01(\t\x12\x10\n\x08language\x18\x02 \x01(\t\x1a(\n\x06ReplID\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\nsourceRepl\x18\x02 \x01(\t\x1a\x31\n\tPresenced\x12\x10\n\x08\x62\x65\x61rerID\x18\x01 \x01(\r\x12\x12\n\nbearerName\x18\x02 
\x01(\t"2\n\nWireFormat\x12\x0c\n\x08PROTOBUF\x10\x00\x12\x0c\n\x04JSON\x10\x01\x1a\x02\x08\x01\x12\x08\n\x04PID2\x10\x02\x42\n\n\x08metadata".\n\x0eTLSCertificate\x12\x0e\n\x06\x64omain\x18\x01 \x01(\t\x12\x0c\n\x04\x63\x65rt\x18\x02 \x01(\x0c"\xc0\x02\n\x0cReplTransfer\x12)\n\x04repl\x18\x01 \x01(\x0b\x32\x1b.replit.goval.api.repl.Repl\x12\x39\n\nreplLimits\x18\x02 \x01(\x0b\x32%.replit.goval.api.repl.ResourceLimits\x12\x39\n\nuserLimits\x18\x03 \x01(\x0b\x32%.replit.goval.api.repl.ResourceLimits\x12\x15\n\rcustomDomains\x18\x04 \x03(\t\x12\x36\n\x0c\x63\x65rtificates\x18\x05 \x03(\x0b\x32 .replit.goval.api.TLSCertificate\x12\r\n\x05\x66lags\x18\x06 \x03(\t\x12\x31\n\x08metadata\x18\x07 \x01(\x0b\x32\x1f.replit.goval.api.repl.Metadata"H\n\x10\x41llowReplRequest\x12\x34\n\x0creplTransfer\x18\x01 \x01(\x0b\x32\x1e.replit.goval.api.ReplTransfer"l\n\x0f\x43lusterMetadata\x12\n\n\x02id\x18\x01 \x01(\t\x12\x11\n\tconmanURL\x18\x02 \x01(\t\x12\x0c\n\x04gurl\x18\x03 \x01(\t\x12\r\n\x05proxy\x18\x05 \x01(\t\x12\x11\n\tcontinent\x18\x07 \x01(\tJ\x04\x08\x04\x10\x05J\x04\x08\x06\x10\x07"y\n\x10\x45victReplRequest\x12:\n\x0f\x63lusterMetadata\x18\x01 \x01(\x0b\x32!.replit.goval.api.ClusterMetadata\x12\r\n\x05token\x18\x02 \x01(\t\x12\x0c\n\x04user\x18\x03 \x01(\t\x12\x0c\n\x04slug\x18\x04 \x01(\t"I\n\x11\x45victReplResponse\x12\x34\n\x0creplTransfer\x18\x01 \x01(\x0b\x32\x1e.replit.goval.api.ReplTransferB\x1dZ\x1bgithub.com/replit/goval/apib\x06proto3' 26 | ) 27 | 28 | _globals = globals() 29 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 30 | _builder.BuildTopDescriptorsAndMessages( 31 | DESCRIPTOR, "replit.goval.api.client_pb2", _globals 32 | ) 33 | if _descriptor._USE_C_DESCRIPTORS == False: 34 | DESCRIPTOR._options = None 35 | DESCRIPTOR._serialized_options = b"Z\033github.com/replit/goval/api" 36 | _REPLTOKEN_WIREFORMAT.values_by_name["JSON"]._options = None 37 | _REPLTOKEN_WIREFORMAT.values_by_name["JSON"]._serialized_options = b"\010\001" 38 | _REPLTOKEN.fields_by_name["classroom"]._options = None 39 | _REPLTOKEN.fields_by_name["classroom"]._serialized_options = b"\030\001" 40 | _globals["_REPLTOKEN"]._serialized_start = 161 41 | _globals["_REPLTOKEN"]._serialized_end = 1148 42 | _globals["_REPLTOKEN_CLASSROOMMETADATA"]._serialized_start = 942 43 | _globals["_REPLTOKEN_CLASSROOMMETADATA"]._serialized_end = 991 44 | _globals["_REPLTOKEN_REPLID"]._serialized_start = 993 45 | _globals["_REPLTOKEN_REPLID"]._serialized_end = 1033 46 | _globals["_REPLTOKEN_PRESENCED"]._serialized_start = 1035 47 | _globals["_REPLTOKEN_PRESENCED"]._serialized_end = 1084 48 | _globals["_REPLTOKEN_WIREFORMAT"]._serialized_start = 1086 49 | _globals["_REPLTOKEN_WIREFORMAT"]._serialized_end = 1136 50 | _globals["_TLSCERTIFICATE"]._serialized_start = 1150 51 | _globals["_TLSCERTIFICATE"]._serialized_end = 1196 52 | _globals["_REPLTRANSFER"]._serialized_start = 1199 53 | _globals["_REPLTRANSFER"]._serialized_end = 1519 54 | _globals["_ALLOWREPLREQUEST"]._serialized_start = 1521 55 | _globals["_ALLOWREPLREQUEST"]._serialized_end = 1593 56 | _globals["_CLUSTERMETADATA"]._serialized_start = 1595 57 | _globals["_CLUSTERMETADATA"]._serialized_end = 1703 58 | _globals["_EVICTREPLREQUEST"]._serialized_start = 1705 59 | _globals["_EVICTREPLREQUEST"]._serialized_end = 1826 60 | _globals["_EVICTREPLRESPONSE"]._serialized_start = 1828 61 | _globals["_EVICTREPLRESPONSE"]._serialized_end = 1901 62 | # @@protoc_insertion_point(module_scope) 63 | 
-------------------------------------------------------------------------------- /src/replit/tests/ai/modelfarm/test_chat_completions.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from typing import Any, Dict, List 3 | 4 | import pytest 5 | from replit.ai.modelfarm import AsyncModelfarm, Modelfarm 6 | from replit.ai.modelfarm.exceptions import BadRequestException 7 | from replit.ai.modelfarm.structs.chat import ChatCompletionMessageRequestParam 8 | 9 | # module level constants 10 | 11 | MODEL = "chat-bison" 12 | 13 | MESSAGES: List[ChatCompletionMessageRequestParam] = [ 14 | { 15 | "role": "USER", 16 | "content": "What is the meaning of life?", 17 | }, 18 | ] 19 | 20 | # kwargs for different endpoints and cases 21 | 22 | VALID_KWARGS = { 23 | "top_p": 0.1, 24 | "stop": ["\n"], 25 | "n": 3, 26 | "provider_extra_parameters": { 27 | "top_k": 20, 28 | } 29 | } 30 | 31 | INVALID_KWARGS: Dict[str, Any] = { 32 | "invalid_parameter": 0.5, 33 | } 34 | 35 | # the streaming endpoint does not support the n (candidate count) arg 36 | VALID_GEN_STREAM_KWARGS = { 37 | "max_tokens": 128, 38 | "temperature": 0, 39 | "top_p": 0.1, 40 | "provider_extra_parameters": { 41 | "top_k": 20, 42 | }, 43 | } 44 | 45 | 46 | def test_chat_model_chat(client: Modelfarm) -> None: 47 | response = client.chat.completions.create( 48 | messages=MESSAGES, 49 | model=MODEL, 50 | **VALID_KWARGS, 51 | ) 52 | 53 | assert len(response.choices) >= 1 54 | 55 | choice = response.choices[0] 56 | 57 | assert len(choice.message.content) > 10 58 | 59 | choice_metadata = choice.metadata 60 | assert choice_metadata["safetyAttributes"]["blocked"] is False 61 | 62 | 63 | def test_chat_model_chat_no_kwargs(client: Modelfarm) -> None: 64 | response = client.chat.completions.create(messages=MESSAGES, model=MODEL) 65 | 66 | assert len(response.choices) == 1 67 | 68 | choice = response.choices[0] 69 | 70 | assert choice.message.content is not None 71 | assert len(choice.message.content) > 10 72 | 73 | choice_metadata = choice.metadata 74 | assert choice_metadata is not None 75 | assert choice_metadata["safetyAttributes"]["blocked"] is False 76 | 77 | 78 | def test_chat_model_chat_invalid_parameter(client: Modelfarm) -> None: 79 | with pytest.raises(BadRequestException): 80 | client.chat.completions.create( 81 | messages=MESSAGES, 82 | model=MODEL, 83 | **INVALID_KWARGS, 84 | ) 85 | 86 | 87 | @pytest.mark.asyncio 88 | async def test_chat_model_async_chat(async_client: AsyncModelfarm) -> None: 89 | response = await async_client.chat.completions.create( 90 | messages=MESSAGES, 91 | model=MODEL, 92 | **VALID_KWARGS, 93 | ) 94 | 95 | assert len(response.choices) >= 1 96 | 97 | choice = response.choices[0] 98 | 99 | assert len(choice.message.content) > 10 100 | 101 | choice_metadata = choice.metadata 102 | assert choice_metadata["safetyAttributes"]["blocked"] is False 103 | 104 | 105 | @pytest.mark.asyncio 106 | async def test_chat_model_async_chat_invalid_parameter( 107 | async_client: AsyncModelfarm) -> None: 108 | with pytest.raises(BadRequestException): 109 | await async_client.chat.completions.create(messages=MESSAGES, 110 | model=MODEL, 111 | **INVALID_KWARGS) 112 | 113 | 114 | def test_chat_model_stream_chat(client: Modelfarm) -> None: 115 | responses = list( 116 | client.chat.completions.create(messages=MESSAGES, 117 | model=MODEL, 118 | stream=True, 119 | **VALID_GEN_STREAM_KWARGS)) 120 | 121 | assert len(responses) > 1 122 | for response in responses: 123 | assert
len(response.choices) == 1 124 | choice = response.choices[0] 125 | assert choice.delta.content is not None 126 | assert len(choice.delta.content) >= 1 127 | 128 | 129 | def test_chat_model_stream_chat_invalid_parameter(client: Modelfarm) -> None: 130 | with pytest.raises(BadRequestException): 131 | list( 132 | client.chat.completions.create(messages=MESSAGES, 133 | model=MODEL, 134 | stream=True, 135 | **INVALID_KWARGS)) 136 | 137 | 138 | def test_chat_model_stream_chat_raises_with_choice_count_param( 139 | client: Modelfarm) -> None: 140 | """ 141 | Test that streaming raises an exception if n (the choice count) is specified. 142 | """ 143 | with pytest.raises(BadRequestException): 144 | list( 145 | client.chat.completions.create(messages=MESSAGES, 146 | model=MODEL, 147 | stream=True, 148 | n=5)) 149 | 150 | 151 | @pytest.mark.asyncio 152 | async def test_chat_model_async_stream_chat( 153 | async_client: AsyncModelfarm) -> None: 154 | responses = [ 155 | res async for res in await async_client.chat.completions.create( 156 | messages=MESSAGES, 157 | model=MODEL, 158 | stream=True, 159 | **VALID_GEN_STREAM_KWARGS) 160 | ] 161 | 162 | assert len(responses) > 1 163 | for response in responses: 164 | assert len(response.choices) == 1 165 | 166 | choice = response.choices[0] 167 | assert choice.delta.content is not None 168 | assert len(choice.delta.content) >= 1 169 | 170 | 171 | @pytest.mark.asyncio 172 | async def test_chat_model_async_stream_chat_invalid_parameter( 173 | async_client: AsyncModelfarm) -> None: 174 | with pytest.raises(BadRequestException): 175 | async for _ in await async_client.chat.completions.create( 176 | messages=MESSAGES, 177 | model=MODEL, 178 | stream=True, 179 | **INVALID_KWARGS, 180 | ): 181 | pass 182 | 183 | 184 | def test_chat_model_stream_chat_no_duplicates(client: Modelfarm) -> None: 185 | # synchronous streaming call 186 | responses = client.chat.completions.create( 187 | messages=MESSAGES, 188 | model=MODEL, 189 | stream=True, 190 | ) 191 | counter = Counter() 192 | for response in responses: 193 | counter[response.choices[0].delta.content] += 1 194 | 195 | for content, count in counter.items(): 196 | if count > 1: 197 | print(counter, content) 198 | assert count == 1 199 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/client.py: -------------------------------------------------------------------------------- 1 | import json 2 | from contextlib import asynccontextmanager 3 | from typing import Any, AsyncGenerator, AsyncIterator, Dict, Iterator, Optional 4 | 5 | import aiohttp 6 | import requests 7 | from aiohttp import ClientResponse 8 | from requests import JSONDecodeError, Response 9 | 10 | from .chat_completions import AsyncChat, Chat 11 | from .completions import AsyncCompletions, Completions 12 | from .config import get_config 13 | from .embeddings import AsyncEmbeddings, Embeddings 14 | from .exceptions import BadRequestException, InvalidResponseException 15 | from .replit_identity_token_manager import ReplitIdentityTokenManager 16 | 17 | 18 | class BaseModelfarm: 19 | 20 | def __init__(self, base_url: Optional[str] = None) -> None: 21 | """ 22 | Initializes a new instance of the BaseModelfarm class. 23 | """ 24 | self.base_url = base_url or get_config().rootUrl 25 | self.auth = ReplitIdentityTokenManager() 26 | 27 | def _get_auth_headers(self) -> Dict[str, str]: 28 | """ 29 | Gets authentication headers required for API requests.
30 | 31 | Returns: 32 | dict: A dictionary containing the Authorization header. 33 | """ 34 | token = self.auth.get_token() 35 | return {"Authorization": f"Bearer {token}"} 36 | 37 | 38 | class Modelfarm(BaseModelfarm): 39 | chat: Chat 40 | embeddings: Embeddings 41 | completions: Completions 42 | 43 | def __init__( 44 | self, 45 | base_url: Optional[str] = None, 46 | ) -> None: 47 | """ 48 | Initializes a new instance of the Modelfarm class. 49 | """ 50 | super().__init__(base_url) 51 | 52 | self.chat = Chat(self) 53 | self.embeddings = Embeddings(self) 54 | self.completions = Completions(self) 55 | 56 | def _post( 57 | self, 58 | path: str, 59 | payload: Optional[Dict[str, Any]] = None, 60 | stream: Optional[bool] = None, 61 | **kwargs, 62 | ) -> Response: 63 | return requests.post( 64 | url=self.base_url + path, 65 | headers=self._get_auth_headers(), 66 | json=payload, 67 | stream=stream, 68 | **kwargs, 69 | ) 70 | 71 | def _check_response(self, response: Response) -> None: 72 | """ 73 | Validates a response from the server. 74 | 75 | Parameters: 76 | response: The server response to check. 77 | 78 | Raises: 79 | InvalidResponseException: If the response is not valid JSON. 80 | BadRequestException: If the response contains a 400 status code. 81 | """ 82 | try: 83 | rjson = response.json() 84 | except JSONDecodeError as e: 85 | raise InvalidResponseException( 86 | f"Invalid response: {response.text}") from e 87 | 88 | if response.status_code == 400: 89 | raise BadRequestException(rjson["detail"]) 90 | if response.status_code != 200: 91 | if "detail" in rjson: 92 | raise InvalidResponseException(rjson["detail"]) 93 | raise InvalidResponseException(rjson) 94 | 95 | def _check_streaming_response(self, response: Response) -> None: 96 | """ 97 | Validates a streaming response from the server. 98 | 99 | Parameters: 100 | response: The server's streaming response to check. 101 | """ 102 | if response.status_code == 200: 103 | return 104 | self._check_response(response) 105 | 106 | def _parse_streaming_response(self, response: Response) -> Iterator[Any]: 107 | """ 108 | Parses a streaming response from the server. 109 | 110 | Parameters: 111 | response: The server's streaming response to parse. 112 | 113 | Yields: 114 | JSON objects extracted from the streaming response. 115 | """ 116 | buffer = b"" 117 | decoder = json.JSONDecoder() 118 | for chunk in response.iter_content(chunk_size=128): 119 | buffer += chunk 120 | buffer_str = buffer.decode("utf-8") 121 | 122 | start_idx = 0 123 | # Iteratively parse JSON objects 124 | while start_idx < len(buffer_str): 125 | try: 126 | # Load JSON object 127 | result = decoder.raw_decode(buffer_str, start_idx) 128 | json_obj, end_idx = result 129 | 130 | yield json_obj 131 | # Update start index for next iteration 132 | start_idx = end_idx 133 | except json.JSONDecodeError: 134 | break 135 | buffer = buffer_str[start_idx:].encode("utf-8")  # keep unparsed tail; slice the str, not the bytes: start_idx is a character offset 136 | 137 | 138 | class AsyncModelfarm(BaseModelfarm): 139 | chat: AsyncChat 140 | embeddings: AsyncEmbeddings 141 | completions: AsyncCompletions 142 | 143 | def __init__( 144 | self, 145 | base_url: Optional[str] = None, 146 | ) -> None: 147 | """ 148 | Initializes a new instance of the AsyncModelfarm class.
149 | """ 150 | super().__init__(base_url) 151 | 152 | self.chat = AsyncChat(self) 153 | self.embeddings = AsyncEmbeddings(self) 154 | self.completions = AsyncCompletions(self) 155 | 156 | @asynccontextmanager 157 | async def _post( 158 | self, 159 | path: str, 160 | payload: Optional[Dict[str, Any]] = None, 161 | timeout: float = 15, 162 | **kwargs, 163 | ) -> AsyncGenerator[ClientResponse, None]: 164 | async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( 165 | total=timeout)) as session, session.post( 166 | url=self.base_url + path, 167 | headers=self._get_auth_headers(), 168 | json=payload, 169 | **kwargs) as response: 170 | yield response 171 | 172 | async def _check_response(self, response: ClientResponse) -> None: 173 | """ 174 | Validates an asynchronous response from the server. 175 | 176 | Parameters: 177 | response: The asynchronous server response to check. 178 | 179 | Raises: 180 | InvalidResponseException: If the response is not valid JSON. 181 | BadRequestException: If the response contains a 400 status code. 182 | """ 183 | try: 184 | rjson = await response.json() 185 | except JSONDecodeError as e: 186 | raise InvalidResponseException( 187 | f"Invalid response: {response.text}") from e 188 | 189 | if response.status == 400: 190 | raise BadRequestException(rjson["detail"]) 191 | if response.status != 200: 192 | if "detail" in rjson: 193 | raise InvalidResponseException(rjson["detail"]) 194 | raise InvalidResponseException(rjson) 195 | 196 | async def _check_streaming_response(self, 197 | response: ClientResponse) -> None: 198 | """ 199 | Validates an asynchronous streaming response from the server. 200 | 201 | Parameters: 202 | response: The server's asynchronous streaming response to check. 203 | """ 204 | if response.status == 200: 205 | return 206 | await self._check_response(response) 207 | 208 | async def _parse_streaming_response( 209 | self, response: ClientResponse) -> AsyncIterator[Any]: 210 | """ 211 | Asynchronously parses a streaming response from the server. 212 | 213 | Parameters: 214 | response: The server's asynchronous streaming response to parse. 215 | 216 | Yields: 217 | JSON objects extracted from the streaming response. 
218 | """ 219 | buffer = b"" 220 | decoder = json.JSONDecoder() 221 | while True: 222 | chunk = await response.content.read(128) 223 | if not chunk: 224 | break 225 | buffer += chunk 226 | buffer_str = buffer.decode("utf-8") 227 | 228 | start_idx = 0 229 | # Iteratively parse JSON objects 230 | while start_idx < len(buffer_str): 231 | try: 232 | # Load JSON object 233 | result = decoder.raw_decode(buffer_str, start_idx) 234 | json_obj, end_idx = result 235 | 236 | yield json_obj 237 | 238 | # Update start index for next iteration 239 | start_idx = end_idx 240 | except json.JSONDecodeError: 241 | break 242 | buffer = buffer[start_idx:] 243 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/google/preview/language_models/chat_model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, List, Optional, Union 3 | 4 | from replit.ai.modelfarm import AsyncModelfarm, Modelfarm 5 | from replit.ai.modelfarm.google.structs import TextGenerationResponse 6 | from replit.ai.modelfarm.google.utils import ready_parameters 7 | from replit.ai.modelfarm.structs.chat import ( 8 | ChatCompletionMessageRequestParam, 9 | ChatCompletionResponse, 10 | ChatCompletionStreamChunkResponse, 11 | ) 12 | 13 | USER_AUTHOR = "user" 14 | MODEL_AUTHOR = "bot" 15 | 16 | 17 | @dataclass 18 | class InputOutputTextPair: 19 | input_text: str 20 | output_text: str 21 | 22 | 23 | @dataclass 24 | class ChatMessage: 25 | content: str 26 | author: str 27 | 28 | 29 | class ChatSession: 30 | context: Optional[str] 31 | examples: List[InputOutputTextPair] 32 | message_history: List[ChatMessage] 33 | underlying_model: str 34 | parameters: Dict[str, Any] 35 | 36 | _client: Modelfarm 37 | _async_client: AsyncModelfarm 38 | 39 | def __init__( 40 | self, 41 | underlying_model, 42 | context=None, 43 | examples: Optional[List[InputOutputTextPair]] = None, 44 | message_history: Optional[List[ChatMessage]] = None, 45 | parameters: Optional[Dict[str, Any]] = None, 46 | ) -> None: 47 | self.context = context 48 | self.examples = examples or [] 49 | self.message_history = message_history or [] 50 | self.underlying_model = underlying_model 51 | self.parameters = parameters or {} 52 | 53 | self._client = Modelfarm() 54 | self._async_client = AsyncModelfarm() 55 | 56 | def send_message(self, message: str, **kwargs): 57 | self.add_user_message(message) 58 | predictParams = dict( 59 | **self.parameters, **kwargs, **{ 60 | "context": self.context, 61 | "examples": self.__build_chat_examples_from_io(), 62 | }) 63 | response = self._client.chat.completions.create( 64 | model=self.underlying_model, 65 | messages=self.__build_replit_messages_from_history(), 66 | stream=False, 67 | **ready_parameters(predictParams), 68 | ) 69 | self.add_model_message(self.__get_response_content(response)) 70 | return self.__ready_response(response) 71 | 72 | async def async_send_message(self, message: str, **kwargs): 73 | self.add_user_message(message) 74 | predictParams = dict( 75 | **self.parameters, **kwargs, **{ 76 | "context": self.context, 77 | "examples": self.__build_chat_examples_from_io(), 78 | }) 79 | response = await self._async_client.chat.completions.create( 80 | model=self.underlying_model, 81 | messages=self.__build_replit_messages_from_history(), 82 | stream=False, 83 | **ready_parameters(predictParams), 84 | ) 85 | self.add_model_message(self.__get_response_content(response)) 86 | return 
self.__ready_response(response) 87 | 88 | def send_message_stream(self, message: str, **kwargs): 89 | self.add_user_message(message) 90 | predictParams = dict( 91 | **self.parameters, **kwargs, **{ 92 | "context": self.context, 93 | "examples": self.__build_chat_examples_from_io(), 94 | }) 95 | response = self._client.chat.completions.create( 96 | model=self.underlying_model, 97 | messages=self.__build_replit_messages_from_history(), 98 | stream=True, 99 | **ready_parameters(predictParams), 100 | ) 101 | message = "" 102 | for chunk in response: 103 | transformedResponse = self.__ready_response(chunk) 104 | message += transformedResponse.text 105 | yield transformedResponse 106 | self.add_model_message(message) 107 | 108 | async def async_send_message_stream(self, message: str, **kwargs): 109 | self.add_user_message(message) 110 | predictParams = dict( 111 | **self.parameters, **kwargs, **{ 112 | "context": self.context, 113 | "examples": self.__build_chat_examples_from_io(), 114 | }) 115 | response = await self._async_client.chat.completions.create( 116 | model=self.underlying_model, 117 | messages=self.__build_replit_messages_from_history(), 118 | stream=True, 119 | **ready_parameters(predictParams), 120 | ) 121 | message = "" 122 | async for chunk in response: 123 | transformedResponse = self.__ready_response(chunk) 124 | message += transformedResponse.text 125 | yield transformedResponse 126 | self.add_model_message(message) 127 | 128 | def add_user_message(self, message: str): 129 | chatMessage = ChatMessage(content=message, author=USER_AUTHOR) 130 | self.message_history.append(chatMessage) 131 | 132 | def add_model_message(self, message: str): 133 | chatMessage = ChatMessage(content=message, author=MODEL_AUTHOR) 134 | self.message_history.append(chatMessage) 135 | 136 | def __build_chat_examples_from_io(self) -> List[Dict[str, Dict]]: 137 | return [{ 138 | "input": { 139 | "content": io.input_text, 140 | "author": "" 141 | }, 142 | "output": { 143 | "content": io.output_text, 144 | "author": "" 145 | }, 146 | } for io in self.examples] 147 | 148 | def __build_replit_messages_from_history( 149 | self) -> List[ChatCompletionMessageRequestParam]: 150 | return [ 151 | self.__build_replit_message_from_google_chat_message(x) 152 | for x in self.message_history 153 | ] 154 | 155 | @staticmethod 156 | def __build_replit_message_from_google_chat_message( 157 | msg: ChatMessage, ) -> ChatCompletionMessageRequestParam: 158 | return {"content": msg.content, "role": msg.author} 159 | 160 | def __get_response_content( 161 | self, response: Union[ChatCompletionResponse, 162 | ChatCompletionStreamChunkResponse] 163 | ) -> str: 164 | if isinstance(response, ChatCompletionResponse): 165 | return response.choices[0].message.content or "" 166 | return response.choices[0].delta.content or "" 167 | 168 | def __ready_response( 169 | self, response: Union[ChatCompletionResponse, 170 | ChatCompletionStreamChunkResponse] 171 | ) -> TextGenerationResponse: 172 | """ 173 | Transforms a chat model's response into a readily usable format. 174 | 175 | Args: 176 | response (ChatCompletionResponse or ChatCompletionStreamChunkResponse): 177 | The original response from the underlying model. 178 | 179 | Returns: 180 | TextGenerationResponse: The transformed response.
181 | """ 182 | choice = response.choices[0] 183 | text = self.__get_response_content(response) 184 | safetyAttributes = choice.metadata[ 185 | "safetyAttributes"] if choice.metadata else {} 186 | safetyCategories = dict( 187 | zip(safetyAttributes["categories"], 188 | safetyAttributes["scores"], 189 | strict=True)) if safetyAttributes else {} 190 | return TextGenerationResponse( 191 | is_blocked=safetyAttributes["blocked"], 192 | raw_prediction_response=choice.model_dump(), 193 | safety_attributes=safetyCategories, 194 | text=text, 195 | ) 196 | 197 | 198 | class ChatModel: 199 | 200 | def __init__(self, model_id: str): 201 | self.underlying_model = model_id 202 | 203 | @staticmethod 204 | def from_pretrained(model_id: str) -> "ChatModel": 205 | return ChatModel(model_id) 206 | 207 | def start_chat( 208 | self, 209 | context: Optional[str] = "", 210 | examples: Optional[List[InputOutputTextPair]] = None, 211 | message_history: Optional[List[ChatMessage]] = None, 212 | ) -> ChatSession: 213 | chat_session = ChatSession(self.underlying_model, context, examples 214 | or [], message_history or []) 215 | return chat_session 216 | 217 | 218 | # def get_embeddings(self, content: List[str]) -> List[TextEmbedding]: 219 | # request = self.__ready_input(content) 220 | # response = self.underlying_model.embed(request, {}) 221 | # return self.__ready_response(response) 222 | 223 | # async def async_get_embeddings(self, 224 | # content: List[str]) -> List[TextEmbedding]: 225 | # request = self.__ready_input(content) 226 | # response = await self.underlying_model.aembed(request, {}) 227 | # return self.__ready_response(response) 228 | 229 | # def __ready_input(self, content: List[str]) -> List[Dict[str, Any]]: 230 | # return [{'content': x} for x in content] 231 | 232 | # def __ready_response( 233 | # self, response: EmbeddingModelResponse) -> List[TextEmbedding]: 234 | 235 | # def transform_response(x): 236 | # metadata = x.tokenCountMetadata 237 | # tokenCount = metadata.unbilledTokens + metadata.billableTokens 238 | # stats = TextEmbeddingStatistics(tokenCount, x.truncated) 239 | # return TextEmbedding(stats, x.values) 240 | 241 | # return [transform_response(x) for x in response.embeddings] 242 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/completions.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | Any, 4 | AsyncIterator, 5 | Dict, 6 | Iterator, 7 | Literal, 8 | Optional, 9 | Union, 10 | overload, 11 | ) 12 | 13 | from replit.ai.modelfarm.structs.completions import ( 14 | CompletionModelResponse, 15 | PromptParameter, 16 | ) 17 | 18 | if TYPE_CHECKING: 19 | from replit.ai.modelfarm import AsyncModelfarm, Modelfarm 20 | 21 | 22 | class Completions: 23 | _client: "Modelfarm" 24 | 25 | def __init__(self, client: "Modelfarm") -> None: 26 | self._client = client 27 | 28 | @overload 29 | def create( 30 | self, 31 | *, 32 | model: str, 33 | prompt: PromptParameter, 34 | stream: Literal[True], 35 | max_tokens: Optional[int] = 1024, 36 | temperature: float = 0.2, 37 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 38 | **kwargs: Any, 39 | ) -> Iterator[CompletionModelResponse]: 40 | ... 
41 | 42 | @overload 43 | def create( 44 | self, 45 | *, 46 | model: str, 47 | prompt: PromptParameter, 48 | stream: Literal[False] = False, 49 | max_tokens: Optional[int] = 1024, 50 | temperature: float = 0.2, 51 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 52 | **kwargs: Any, 53 | ) -> CompletionModelResponse: 54 | ... 55 | 56 | def create( 57 | self, 58 | *, 59 | model: str, 60 | prompt: PromptParameter, 61 | stream: bool = False, 62 | max_tokens: Optional[int] = 1024, 63 | temperature: float = 0.2, 64 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 65 | **kwargs: Any, 66 | ) -> Union[CompletionModelResponse, Iterator[CompletionModelResponse]]: 67 | """ 68 | Makes a generation based on the prompt and parameters. 69 | 70 | Args: 71 | model (str): The name of the model to use. 72 | prompt (PromptParameter): The prompt(s) to generate completion for. 73 | stream (bool): Whether to stream the responses. Defaults to False. 74 | max_tokens (int): The maximum number of tokens to generate. 75 | Defaults to 1024. 76 | temperature (float): The temperature of the generation. Defaults to 0.2. 77 | provider_extra_parameters (Optional[Dict[str, Any]]): Extra parameters 78 | of the specific provider. Defaults to None. 79 | 80 | Returns: 81 | If stream is True, returns an iterator of CompletionModelResponse. 82 | Otherwise, returns a CompletionModelResponse. 83 | 84 | """ 85 | if stream: 86 | return self.__completion_stream( 87 | model=model, 88 | prompt=prompt, 89 | max_tokens=max_tokens, 90 | temperature=temperature, 91 | provider_extra_parameters=provider_extra_parameters, 92 | **kwargs, 93 | ) 94 | return self.__completion( 95 | model=model, 96 | prompt=prompt, 97 | max_tokens=max_tokens, 98 | temperature=temperature, 99 | provider_extra_parameters=provider_extra_parameters, 100 | **kwargs, 101 | ) 102 | 103 | def __completion( 104 | self, 105 | model: str, 106 | prompt: PromptParameter, 107 | max_tokens: Optional[int], 108 | temperature: float, 109 | provider_extra_parameters: Optional[Dict[str, Any]], 110 | **kwargs: Any, 111 | ) -> CompletionModelResponse: 112 | """ 113 | Makes a generation based on prompt(s) and parameters.
114 | """ 115 | response = self._client._post( 116 | "/v1beta2/completions", 117 | payload=_build_request_payload( 118 | model=model, 119 | prompt=prompt, 120 | max_tokens=max_tokens, 121 | temperature=temperature, 122 | stream=False, 123 | provider_extra_parameters=provider_extra_parameters, 124 | **kwargs, 125 | ), 126 | ) 127 | self._client._check_response(response) 128 | return CompletionModelResponse(**response.json()) 129 | 130 | def __completion_stream( 131 | self, 132 | model: str, 133 | prompt: PromptParameter, 134 | max_tokens: Optional[int], 135 | temperature: float, 136 | provider_extra_parameters: Optional[Dict[str, Any]], 137 | **kwargs: Any, 138 | ) -> Iterator[CompletionModelResponse]: 139 | """ 140 | Create a stream of CompletionModelResponse 141 | """ 142 | response = self._client._post( 143 | "/v1beta2/completions", 144 | payload=_build_request_payload( 145 | model=model, 146 | prompt=prompt, 147 | max_tokens=max_tokens, 148 | temperature=temperature, 149 | stream=True, 150 | provider_extra_parameters=provider_extra_parameters, 151 | **kwargs, 152 | ), 153 | stream=True, 154 | ) 155 | self._client._check_streaming_response(response) 156 | for chunk in self._client._parse_streaming_response(response): 157 | yield CompletionModelResponse(**chunk) 158 | 159 | 160 | class AsyncCompletions: 161 | _client: "AsyncModelfarm" 162 | 163 | def __init__(self, client: "AsyncModelfarm") -> None: 164 | self._client = client 165 | 166 | @overload 167 | async def create( 168 | self, 169 | *, 170 | model: str, 171 | prompt: PromptParameter, 172 | stream: Literal[True], 173 | max_tokens: Optional[int] = 1024, 174 | temperature: float = 0.2, 175 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 176 | **kwargs: Any, 177 | ) -> AsyncIterator[CompletionModelResponse]: 178 | ... 179 | 180 | @overload 181 | async def create( 182 | self, 183 | *, 184 | model: str, 185 | prompt: PromptParameter, 186 | stream: Literal[False] = False, 187 | max_tokens: Optional[int] = 1024, 188 | temperature: float = 0.2, 189 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 190 | **kwargs: Any, 191 | ) -> CompletionModelResponse: 192 | ... 193 | 194 | async def create( 195 | self, 196 | *, 197 | model: str, 198 | prompt: PromptParameter, 199 | stream: bool = False, 200 | max_tokens: Optional[int] = 1024, 201 | temperature: float = 0.2, 202 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 203 | **kwargs: Any, 204 | ) -> Union[CompletionModelResponse, 205 | AsyncIterator[CompletionModelResponse]]: 206 | """ 207 | Makes a generation based on the messages and parameters. 208 | 209 | Args: 210 | model (str): The name of the model to use. 211 | prompt (PromptParameter): The prompt(s) to generate completion for. 212 | stream (bool): Whether to stream the responses. Defaults to False. 213 | max_tokens (int): The maximum number of tokens to generate. 214 | Defaults to 1024. 215 | temperature (float): The temperature of the generation. Defaults to 0.2. 216 | provider_extra_parameters (Optional[Dict[str, Any]]): Extra parameters 217 | of the speficic provider. Defaults to None. 218 | 219 | Returns: 220 | If stream is True, returns an iterator of CompletionModelResponse. 221 | Otherwise, returns a CompletionModelResponse. 
222 | 223 | """ 224 | if stream: 225 | return self.__completion_stream( 226 | model=model, 227 | prompt=prompt, 228 | max_tokens=max_tokens, 229 | temperature=temperature, 230 | provider_extra_parameters=provider_extra_parameters, 231 | **kwargs, 232 | ) 233 | return await self.__completion( 234 | model=model, 235 | prompt=prompt, 236 | max_tokens=max_tokens, 237 | temperature=temperature, 238 | provider_extra_parameters=provider_extra_parameters, 239 | **kwargs, 240 | ) 241 | 242 | async def __completion( 243 | self, 244 | model: str, 245 | prompt: PromptParameter, 246 | max_tokens: Optional[int], 247 | temperature: float, 248 | provider_extra_parameters: Optional[Dict[str, Any]], 249 | **kwargs: Any, 250 | ) -> CompletionModelResponse: 251 | """ 252 | Makes a generation based on the prompt(s) and parameters. 253 | """ 254 | async with self._client._post( 255 | "/v1beta2/completions", 256 | payload=_build_request_payload( 257 | model=model, 258 | prompt=prompt, 259 | max_tokens=max_tokens, 260 | temperature=temperature, 261 | stream=False, 262 | provider_extra_parameters=provider_extra_parameters, 263 | **kwargs, 264 | ), 265 | ) as response: 266 | await self._client._check_response(response) 267 | return CompletionModelResponse(**await response.json()) 268 | 269 | async def __completion_stream( 270 | self, 271 | model: str, 272 | prompt: PromptParameter, 273 | max_tokens: Optional[int], 274 | temperature: float, 275 | provider_extra_parameters: Optional[Dict[str, Any]], 276 | **kwargs: Any, 277 | ) -> AsyncIterator[CompletionModelResponse]: 278 | """ 279 | Create a stream of CompletionModelResponse 280 | """ 281 | async with self._client._post( 282 | "/v1beta2/completions", 283 | payload=_build_request_payload( 284 | model=model, 285 | prompt=prompt, 286 | max_tokens=max_tokens, 287 | temperature=temperature, 288 | stream=True, 289 | provider_extra_parameters=provider_extra_parameters, 290 | **kwargs, 291 | ), 292 | ) as response: 293 | await self._client._check_streaming_response(response) 294 | async for chunk in self._client._parse_streaming_response( 295 | response): 296 | yield CompletionModelResponse(**chunk) 297 | 298 | 299 | def _build_request_payload( 300 | model: str, 301 | prompt: PromptParameter, 302 | max_tokens: Optional[int], 303 | temperature: float, 304 | stream: bool, 305 | provider_extra_parameters: Optional[Dict[str, Any]], 306 | **kwargs: Any, 307 | ) -> Dict[str, Any]: 308 | """ 309 | Builds the request payload. 310 | 311 | Args: 312 | model (str): The name of the model to use. 313 | prompt (PromptParameter): The prompt(s) to generate completion for. 314 | max_tokens (int): The maximum number of tokens to generate. 315 | temperature (float): The temperature of the generation. 316 | provider_extra_parameters (Optional[Dict[str, Any]]): Extra parameters 317 | of the specific provider. 318 | 319 | Returns: 320 | Dict[str, Any]: The request payload.
321 | """ 322 | 323 | payload = { 324 | "model": model, 325 | "prompt": prompt, 326 | "max_tokens": max_tokens, 327 | "temperature": temperature, 328 | "stream": stream, 329 | "provider_extra_parameters": provider_extra_parameters, 330 | **kwargs, 331 | } 332 | 333 | # Drop any keys with a value of None 334 | payload = {k: v for k, v in payload.items() if v is not None} 335 | return payload 336 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/chat_completions.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | Any, 4 | AsyncIterator, 5 | Dict, 6 | Iterator, 7 | List, 8 | Literal, 9 | Optional, 10 | Union, 11 | overload, 12 | ) 13 | 14 | from replit.ai.modelfarm.structs.chat import ( 15 | ChatCompletionMessageRequestParam, 16 | ChatCompletionResponse, 17 | ChatCompletionStreamChunkResponse, 18 | ) 19 | 20 | if TYPE_CHECKING: 21 | from replit.ai.modelfarm import AsyncModelfarm, Modelfarm 22 | 23 | 24 | class Completions: 25 | _client: "Modelfarm" 26 | 27 | def __init__(self, client: "Modelfarm") -> None: 28 | """ 29 | Initializes a new instance of the Completions class. 30 | """ 31 | self._client = client 32 | 33 | @overload 34 | def create( 35 | self, 36 | *, 37 | messages: List[ChatCompletionMessageRequestParam], 38 | model: str, 39 | stream: Literal[True], 40 | max_tokens: Optional[int] = 1024, 41 | temperature: float = 0.2, 42 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 43 | **kwargs: Any, 44 | ) -> Iterator[ChatCompletionStreamChunkResponse]: 45 | ... 46 | 47 | @overload 48 | def create( 49 | self, 50 | *, 51 | messages: List[ChatCompletionMessageRequestParam], 52 | model: str, 53 | stream: Literal[False] = False, 54 | max_tokens: Optional[int] = 1024, 55 | temperature: float = 0.2, 56 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 57 | **kwargs: Any, 58 | ) -> ChatCompletionResponse: 59 | ... 60 | 61 | def create( 62 | self, 63 | *, 64 | messages: List[ChatCompletionMessageRequestParam], 65 | model: str, 66 | stream: bool = False, 67 | max_tokens: Optional[int] = 1024, 68 | temperature: float = 0.2, 69 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 70 | **kwargs: Any, 71 | ) -> Union[ChatCompletionResponse, 72 | Iterator[ChatCompletionStreamChunkResponse]]: 73 | """ 74 | Makes a generation based on the messages and parameters. 75 | 76 | Args: 77 | messages (List[ChatCompletionMessageRequestParam]): The list of messages 78 | in the conversation so far. 79 | model (str): The name of the model to use. 80 | stream (bool): Whether to stream the responses. Defaults to False. 81 | max_tokens (int): The maximum number of tokens to generate. 82 | Defaults to 1024. 83 | temperature (float): The temperature of the generation. Defaults to 0.2. 84 | provider_extra_parameters (Optional[Dict[str, Any]]): Extra parameters 85 | of the speficic provider. Defaults to None. 86 | 87 | Returns: 88 | If stream is True, returns an iterator of ChatCompletionStreamChunkResponse. 89 | Otherwise, returns a ChatCompletionResponse. 
90 | 91 | """ 92 | if stream: 93 | return self.__chat_stream( 94 | messages=messages, 95 | model=model, 96 | max_tokens=max_tokens, 97 | temperature=temperature, 98 | provider_extra_parameters=provider_extra_parameters, 99 | **kwargs, 100 | ) 101 | return self.__chat( 102 | messages=messages, 103 | model=model, 104 | max_tokens=max_tokens, 105 | temperature=temperature, 106 | provider_extra_parameters=provider_extra_parameters, 107 | **kwargs, 108 | ) 109 | 110 | def __chat( 111 | self, 112 | messages: List[ChatCompletionMessageRequestParam], 113 | model: str, 114 | max_tokens: Optional[int], 115 | temperature: float, 116 | provider_extra_parameters: Optional[Dict[str, Any]], 117 | **kwargs: Any, 118 | ) -> ChatCompletionResponse: 119 | response = self._client._post( 120 | "/v1beta2/chat/completions", 121 | payload=_build_request_payload( 122 | messages=messages, 123 | model=model, 124 | max_tokens=max_tokens, 125 | temperature=temperature, 126 | stream=False, 127 | provider_extra_parameters=provider_extra_parameters, 128 | **kwargs, 129 | ), 130 | ) 131 | self._client._check_response(response) 132 | return ChatCompletionResponse(**response.json()) 133 | 134 | def __chat_stream( 135 | self, 136 | messages: List[ChatCompletionMessageRequestParam], 137 | model: str, 138 | max_tokens: Optional[int], 139 | temperature: float, 140 | provider_extra_parameters: Optional[Dict[str, Any]], 141 | **kwargs: Any, 142 | ) -> Iterator[ChatCompletionStreamChunkResponse]: 143 | """ 144 | Create a stream of ChatCompletionStreamChunkResponse 145 | """ 146 | response = self._client._post( 147 | "/v1beta2/chat/completions", 148 | payload=_build_request_payload( 149 | messages=messages, 150 | model=model, 151 | max_tokens=max_tokens, 152 | temperature=temperature, 153 | stream=True, 154 | provider_extra_parameters=provider_extra_parameters, 155 | **kwargs, 156 | ), 157 | stream=True, 158 | ) 159 | self._client._check_streaming_response(response) 160 | for chunk in self._client._parse_streaming_response(response): 161 | yield ChatCompletionStreamChunkResponse(**chunk) 162 | 163 | 164 | class AsyncCompletions: 165 | _client: "AsyncModelfarm" 166 | 167 | def __init__(self, client: "AsyncModelfarm") -> None: 168 | """ 169 | Initializes a new instance of the AsyncCompletions class. 170 | """ 171 | self._client = client 172 | 173 | @overload 174 | async def create( 175 | self, 176 | messages: List[ChatCompletionMessageRequestParam], 177 | model: str, 178 | stream: Literal[True], 179 | max_tokens: Optional[int] = 1024, 180 | temperature: float = 0.2, 181 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 182 | **kwargs: Any, 183 | ) -> AsyncIterator[ChatCompletionStreamChunkResponse]: 184 | ... 185 | 186 | @overload 187 | async def create( 188 | self, 189 | messages: List[ChatCompletionMessageRequestParam], 190 | model: str, 191 | stream: Literal[False] = False, 192 | max_tokens: Optional[int] = 1024, 193 | temperature: float = 0.2, 194 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 195 | **kwargs: Any, 196 | ) -> ChatCompletionResponse: 197 | ...
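# As in the synchronous Completions class above, these @overload stubs tie the
# return type to the `stream` flag; the implementation follows. A usage sketch
# ("chat-bison" is the model name used by this repo's tests):
#
#     client = AsyncModelfarm()
#     resp = await client.chat.completions.create(
#         messages=[{"role": "USER", "content": "Hi"}], model="chat-bison")
#     async for chunk in await client.chat.completions.create(
#             messages=[{"role": "USER", "content": "Hi"}],
#             model="chat-bison", stream=True):
#         ...  # each chunk is a ChatCompletionStreamChunkResponse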
198 | 199 | async def create( 200 | self, 201 | messages: List[ChatCompletionMessageRequestParam], 202 | model: str, 203 | stream: bool = False, 204 | max_tokens: Optional[int] = 1024, 205 | temperature: float = 0.2, 206 | provider_extra_parameters: Optional[Dict[str, Any]] = None, 207 | **kwargs: Any, 208 | ) -> Union[ChatCompletionResponse, 209 | AsyncIterator[ChatCompletionStreamChunkResponse]]: 210 | """ 211 | Makes a generation based on the messages and parameters. 212 | 213 | Args: 214 | messages (List[ChatCompletionMessageRequestParam]): The list of messages 215 | in the conversation so far. 216 | model (str): The name of the model to use. 217 | stream (bool): Whether to stream the responses. Defaults to False. 218 | max_tokens (int): The maximum number of tokens to generate. 219 | Defaults to 1024. 220 | temperature (float): The temperature of the generation. Defaults to 0.2. 221 | provider_extra_parameters (Optional[Dict[str, Any]]): Extra parameters 222 | of the specific provider. Defaults to None. 223 | 224 | Returns: 225 | If stream is True, returns an iterator of ChatCompletionStreamChunkResponse. 226 | Otherwise, returns a ChatCompletionResponse. 227 | 228 | """ 229 | if stream: 230 | return self.__chat_stream( 231 | messages=messages, 232 | model=model, 233 | max_tokens=max_tokens, 234 | temperature=temperature, 235 | provider_extra_parameters=provider_extra_parameters, 236 | **kwargs, 237 | ) 238 | return await self.__chat( 239 | messages=messages, 240 | model=model, 241 | max_tokens=max_tokens, 242 | temperature=temperature, 243 | provider_extra_parameters=provider_extra_parameters, 244 | **kwargs, 245 | ) 246 | 247 | async def __chat( 248 | self, 249 | messages: List[ChatCompletionMessageRequestParam], 250 | model: str, 251 | max_tokens: Optional[int], 252 | temperature: float, 253 | provider_extra_parameters: Optional[Dict[str, Any]], 254 | **kwargs: Any, 255 | ) -> ChatCompletionResponse: 256 | async with self._client._post( 257 | "/v1beta2/chat/completions", 258 | payload=_build_request_payload( 259 | messages=messages, 260 | model=model, 261 | max_tokens=max_tokens, 262 | temperature=temperature, 263 | stream=False, 264 | provider_extra_parameters=provider_extra_parameters, 265 | **kwargs, 266 | ), 267 | ) as response: 268 | await self._client._check_response(response) 269 | return ChatCompletionResponse(**await response.json()) 270 | 271 | async def __chat_stream( 272 | self, 273 | messages: List[ChatCompletionMessageRequestParam], 274 | model: str, 275 | max_tokens: Optional[int], 276 | temperature: float, 277 | provider_extra_parameters: Optional[Dict[str, Any]], 278 | **kwargs: Any, 279 | ) -> AsyncIterator[ChatCompletionStreamChunkResponse]: 280 | """ 281 | Create a stream of ChatCompletionStreamChunkResponse 282 | """ 283 | async with self._client._post( 284 | "/v1beta2/chat/completions", 285 | payload=_build_request_payload( 286 | messages=messages, 287 | model=model, 288 | max_tokens=max_tokens, 289 | temperature=temperature, 290 | stream=True, 291 | provider_extra_parameters=provider_extra_parameters, 292 | **kwargs, 293 | ), 294 | ) as response: 295 | await self._client._check_streaming_response(response) 296 | async for chunk in self._client._parse_streaming_response( 297 | response): 298 | yield ChatCompletionStreamChunkResponse(**chunk) 299 | 300 | 301 | class Chat: 302 | completions: Completions 303 | 304 | def __init__(self, client: "Modelfarm") -> None: 305 | """ 306 | Initializes a new instance of the Chat class.
307 | """ 308 | self.completions = Completions(client) 309 | 310 | 311 | class AsyncChat: 312 | completions: AsyncCompletions 313 | 314 | def __init__(self, client: "AsyncModelfarm") -> None: 315 | """ 316 | Initializes a new instance of the AsyncChat class. 317 | """ 318 | self.completions = AsyncCompletions(client) 319 | 320 | 321 | def _build_request_payload( 322 | messages: List[ChatCompletionMessageRequestParam], 323 | model: str, 324 | max_tokens: Optional[int], 325 | temperature: float, 326 | stream: bool, 327 | provider_extra_parameters: Optional[Dict[str, Any]], 328 | **kwargs: Any, 329 | ) -> Dict[str, Any]: 330 | """ 331 | Builds the request payload. 332 | 333 | Args: 334 | messages (List[ChatCompletionMessageRequestParam]): The list of messages 335 | in the conversation so far. 336 | model (str): The name of the model to use. 337 | max_tokens (int): The maximum number of tokens to generate. 338 | temperature (float): The temperature of the generation. 339 | provider_extra_parameters (Optional[Dict[str, Any]]): Extra parameters 340 | of the speficic provider. 341 | 342 | Returns: 343 | Dict[str, Any]: The request payload. 344 | """ 345 | 346 | payload = { 347 | "model": model, 348 | "messages": messages, 349 | "max_tokens": max_tokens, 350 | "temperature": temperature, 351 | "stream": stream, 352 | "provider_extra_parameters": provider_extra_parameters, 353 | **kwargs, 354 | } 355 | 356 | # Drop any keys with a value of None 357 | payload = {k: v for k, v in payload.items() if v is not None} 358 | return payload 359 | -------------------------------------------------------------------------------- /src/replit/ai/modelfarm/identity/verify.py: -------------------------------------------------------------------------------- 1 | """Identity verification.""" 2 | 3 | import base64 4 | import dataclasses 5 | import datetime 6 | import json 7 | import os 8 | from typing import Callable, Dict, Optional, Set, Tuple, cast 9 | 10 | import pyseto 11 | from replit.ai.modelfarm.identity.exceptions import VerifyError 12 | from replit.ai.modelfarm.identity.goval.api import signing_pb2 13 | 14 | PubKeySource = Callable[[str, str], pyseto.KeyInterface] 15 | 16 | 17 | @dataclasses.dataclass 18 | class _MessageClaims: 19 | """Claims from a signing_pb2.GovalCert.""" 20 | 21 | repls: Set[str] = dataclasses.field(default_factory=set) 22 | users: Set[str] = dataclasses.field(default_factory=set) 23 | user_ids: Set[int] = dataclasses.field(default_factory=set) 24 | clusters: Set[str] = dataclasses.field(default_factory=set) 25 | subclusters: Set[str] = dataclasses.field(default_factory=set) 26 | flags: Set[int] = dataclasses.field(default_factory=set) 27 | 28 | 29 | def _parse_claims(cert: signing_pb2.GovalCert) -> _MessageClaims: 30 | """Parses claims from a signing_pb2.GovalCert. 31 | 32 | Args: 33 | cert: The certificate 34 | 35 | Returns: 36 | The parsed _MessageClaims. 
37 |     """
38 |     claims = _MessageClaims()
39 | 
40 |     for claim in cert.claims:
41 |         if claim.WhichOneof("claim") == "replid":
42 |             claims.repls.add(claim.replid)
43 |         elif claim.WhichOneof("claim") == "user":
44 |             claims.users.add(claim.user)
45 |         elif claim.WhichOneof("claim") == "user_id":
46 |             claims.user_ids.add(claim.user_id)
47 |         elif claim.WhichOneof("claim") == "cluster":
48 |             claims.clusters.add(claim.cluster)
49 |         elif claim.WhichOneof("claim") == "subcluster":
50 |             claims.subclusters.add(claim.subcluster)
51 |         elif claim.WhichOneof("claim") == "flag":
52 |             claims.flags.add(claim.flag)
53 | 
54 |     return claims
55 | 
56 | 
57 | def get_signing_authority(token: str) -> signing_pb2.GovalSigningAuthority:
58 |     """Gets the signing authority from a token.
59 | 
60 |     Args:
61 |         token: The token in PASETO format.
62 | 
63 |     Returns:
64 |         The parsed GovalSigningAuthority.
65 | 
66 |     Raises:
67 |         VerifyError: If there's any problem verifying the token.
68 |     """
69 |     # The library does not allow just grabbing the footer to know what key to
70 |     # use, so we need to manually extract that.
71 |     token_parts = token.split(".")
72 |     if len(token_parts) != 4:
73 |         raise VerifyError("token is not correctly PASETO-encoded")
74 |     version, purpose, raw_payload, raw_footer = token_parts
75 |     if version != "v2":
76 |         raise VerifyError(f"only v2 is supported: {version}")
77 |     if purpose != "public":
78 |         raise VerifyError(f'only "public" purpose is supported: {purpose}')
79 | 
80 |     return signing_pb2.GovalSigningAuthority.FromString(
81 |         base64.b64decode(base64.urlsafe_b64decode(raw_footer + "==")))
82 | 
83 | 
84 | def _verify_raw_claims(
85 |     replid: Optional[str] = None,
86 |     user: Optional[str] = None,
87 |     user_id: Optional[int] = None,
88 |     cluster: Optional[str] = None,
89 |     subcluster: Optional[str] = None,
90 |     claims: Optional[_MessageClaims] = None,
91 |     deployment: bool = False,
92 | ) -> None:
93 |     if claims is None:
94 |         return
95 | 
96 |     any_replid = signing_pb2.FlagClaim.ANY_REPLID in claims.flags
97 |     any_user = signing_pb2.FlagClaim.ANY_USER in claims.flags
98 |     any_user_id = signing_pb2.FlagClaim.ANY_USER_ID in claims.flags
99 |     any_cluster = signing_pb2.FlagClaim.ANY_CLUSTER in claims.flags
100 |     any_subcluster = signing_pb2.FlagClaim.ANY_SUBCLUSTER in claims.flags
101 |     deployments = signing_pb2.FlagClaim.DEPLOYMENTS in claims.flags
102 | 
103 |     if not any_replid and replid is not None and replid not in claims.repls:
104 |         raise VerifyError(
105 |             f"not authorized (replid), got {replid!r}, want {claims.repls!r}")
106 |     if not any_user and user is not None and user not in claims.users:
107 |         raise VerifyError(
108 |             f"not authorized (user), got {user!r}, want {claims.users!r}")
109 |     if not any_user_id and user_id is not None and user_id not in claims.user_ids:
110 |         raise VerifyError(
111 |             f"not authorized (user_id), got {user_id!r}, want {claims.user_ids!r}"
112 |         )
113 |     if not any_cluster and cluster is not None and cluster not in claims.clusters:
114 |         raise VerifyError(
115 |             f"not authorized (cluster), got {cluster!r}, want {claims.clusters!r}"
116 |         )
117 |     if (not any_subcluster and subcluster is not None
118 |             and subcluster not in claims.subclusters):
119 |         raise VerifyError(f"not authorized (subcluster), "
120 |                           f"got {subcluster!r}, want {claims.subclusters!r}")
121 |     if not deployments and deployment:
122 |         raise VerifyError("not authorized (deployment)")
123 | 
124 | 
125 | def _verify_claims(
126 |     iat: datetime.datetime,
127 |     exp: datetime.datetime,
128 |     replid: Optional[str] = None,
129 |     user: Optional[str] = None,
130 |     user_id: Optional[int] = None,
131 |     cluster: Optional[str] = None,
132 |     subcluster: Optional[str] = None,
133 |     deployment: bool = False,
134 |     claims: Optional[_MessageClaims] = None,
135 | ) -> None:
136 |     now = datetime.datetime.utcnow()
137 |     if iat > now:
138 |         raise VerifyError(f"not valid for {iat - now}")
139 |     if exp < now:
140 |         raise VerifyError(f"expired {now - exp} ago")
141 | 
142 |     _verify_raw_claims(
143 |         replid=replid,
144 |         user=user,
145 |         user_id=user_id,
146 |         cluster=cluster,
147 |         subcluster=subcluster,
148 |         deployment=deployment,
149 |         claims=claims,
150 |     )
151 | 
152 | 
153 | class Verifier:
154 |     """Provides verification of tokens."""
155 | 
156 |     def __init__(self) -> None:
157 |         pass
158 | 
159 |     def verify_chain(
160 |         self,
161 |         token: str,
162 |         pubkey_source: PubKeySource,
163 |     ) -> Tuple[bytes, Optional[signing_pb2.GovalCert]]:
164 |         """Verifies that the token and its signing chain are valid."""
165 |         gsa = get_signing_authority(token)
166 | 
167 |         if gsa.key_id != "":
168 |             # If it's signed directly with a root key, grab the pubkey and
169 |             # verify it.
170 |             return (
171 |                 self.verify_token_with_keyid(token, gsa.key_id, gsa.issuer,
172 |                                              pubkey_source),
173 |                 None,
174 |             )
175 | 
176 |         if gsa.signed_cert != "":
177 |             # If it's signed by another token, verify the other token first.
178 |             signing_bytes, skip_level_cert = self.verify_chain(
179 |                 gsa.signed_cert, pubkey_source)
180 | 
181 |             # Make sure the two parent certs agree.
182 |             signing_cert = self.verify_cert(signing_bytes, skip_level_cert)
183 | 
184 |             # Now verify this token using the parent cert.
185 |             return self.verify_token_with_cert(token,
186 |                                                signing_cert), signing_cert
187 | 
188 |         raise VerifyError(f"Invalid signing authority: {gsa}")
189 | 
190 |     def verify_token_with_keyid(
191 |         self,
192 |         token: str,
193 |         key_id: str,
194 |         issuer: str,
195 |         pubkey_source: PubKeySource,
196 |     ) -> bytes:
197 |         """Verifies that the token is valid and signed by the keyid."""
198 |         pubkey = pubkey_source(key_id, issuer)
199 |         return self.verify_token(token, pubkey)
200 | 
201 |     def verify_token_with_cert(
202 |         self,
203 |         token: str,
204 |         cert: signing_pb2.GovalCert,
205 |     ) -> bytes:
206 |         """Verifies that the token is valid and signed by the cert."""
207 |         pubkey = pyseto.Key.from_paserk(cert.publicKey)
208 |         return self.verify_token(token, pubkey)
209 | 
210 |     def verify_cert(
211 |             self, encoded_cert: bytes,
212 |             signing_cert: Optional[signing_pb2.GovalCert]
213 |     ) -> signing_pb2.GovalCert:
214 |         """Verifies that the certificate is valid."""
215 |         cert = signing_pb2.GovalCert.FromString(encoded_cert)
216 | 
217 |         # Verify that the cert is valid.
218 |         _verify_claims(
219 |             iat=cert.iat.ToDatetime(),
220 |             exp=cert.exp.ToDatetime(),
221 |             claims=_parse_claims(cert),
222 |         )
223 | 
224 |         # If there is no parent cert (i.e. this cert was signed directly by a
225 |         # root key), there's nothing else to verify.
226 |         if signing_cert:
227 |             claims = _parse_claims(signing_cert)
228 |             if signing_pb2.FlagClaim.SIGN_INTERMEDIATE_CERT not in claims.flags:
229 |                 raise VerifyError(
230 |                     "signing cert does not have authority to sign intermediate certs"
231 |                 )
232 | 
233 |             # Verify that the cert's claims agree with its signer.
234 |             authorized_claims: Set[str] = set()
235 |             any_replid = False
236 |             any_user = False
237 |             any_user_id = False
238 |             any_cluster = False
239 |             any_subcluster = False
240 |             for claim in signing_cert.claims:
241 |                 authorized_claims.add(str(claim))
242 |                 if claim.WhichOneof("claim") == "flag":
243 |                     if claim.flag == signing_pb2.FlagClaim.ANY_REPLID:
244 |                         any_replid = True
245 |                     elif claim.flag == signing_pb2.FlagClaim.ANY_USER:
246 |                         any_user = True
247 |                     elif claim.flag == signing_pb2.FlagClaim.ANY_USER_ID:
248 |                         any_user_id = True
249 |                     elif claim.flag == signing_pb2.FlagClaim.ANY_CLUSTER:
250 |                         any_cluster = True
251 |                     elif claim.flag == signing_pb2.FlagClaim.ANY_SUBCLUSTER:
252 |                         any_subcluster = True
253 | 
254 |             # Check the claims of the cert being verified, not the signer's,
255 |             # against what the signer authorizes.
256 |             for claim in cert.claims:
257 |                 if claim.WhichOneof("claim") == "replid" and any_replid:
258 |                     continue
259 |                 if claim.WhichOneof("claim") == "user" and any_user:
260 |                     continue
261 |                 if claim.WhichOneof("claim") == "user_id" and any_user_id:
262 |                     continue
263 |                 if claim.WhichOneof("claim") == "cluster" and any_cluster:
264 |                     continue
265 |                 if claim.WhichOneof(
266 |                         "claim") == "subcluster" and any_subcluster:
267 |                     continue
268 |                 if str(claim) not in authorized_claims:
269 |                     raise VerifyError(
270 |                         f"signing cert does not authorize claim {claim}")
271 | 
272 |         return cert
273 | 
274 |     def verify_token(
275 |         self,
276 |         token: str,
277 |         pubkey: pyseto.KeyInterface,
278 |     ) -> bytes:
279 |         """Verifies that the token is valid."""
280 |         decoded = pyseto.decode(pubkey, token)
281 |         return base64.b64decode(decoded.payload)
282 | 
283 | 
284 | def read_public_key_from_env(keyid: str, _issuer: str) -> pyseto.KeyInterface:
285 |     """Provides a PubKeySource that reads public keys from the environment.
286 | 
287 |     Args:
288 |         keyid: The ID of the public key used to sign a token.
289 |         _issuer: The name of the issuer of the certificate (unused).
290 | 
291 |     Returns:
292 |         The public key corresponding to the key id.
293 |     """
294 |     pubkeys_env = os.getenv("REPL_PUBKEYS")
295 |     if pubkeys_env is None:
296 |         raise VerifyError("REPL_PUBKEYS is not set in the environment")
297 |     pubkeys = cast(Dict[str, str], json.loads(pubkeys_env))
298 |     key = base64.b64decode(pubkeys[keyid])
299 |     return pyseto.Key.from_asymmetric_key_params(version=2, x=key)
300 | 
301 | 
302 | def verify_identity_token(
303 |     identity_token: str,
304 |     audience: str,
305 |     pubkey_source: PubKeySource = read_public_key_from_env,
306 | ) -> signing_pb2.GovalReplIdentity:
307 |     """Verifies a Repl Identity token.
308 | 
309 |     Args:
310 |         identity_token: The Identity token.
311 |         audience: The audience that the token was signed for.
312 |         pubkey_source: The PubKeySource to get the public key.
313 | 
314 |     Returns:
315 |         The parsed and verified signing_pb2.GovalReplIdentity.
316 | 
317 |     Raises:
318 |         VerifyError: If there's any problem verifying the token.
319 |     """
320 |     v = Verifier()
321 |     raw_goval_token, goval_cert = v.verify_chain(identity_token, pubkey_source)
322 |     repl_identity = signing_pb2.GovalReplIdentity.FromString(raw_goval_token)
323 | 
324 |     # Verify that the identity token's claims are valid.
325 |     if repl_identity.aud != audience:
326 |         raise VerifyError(
327 |             f"not authorized (audience), got {repl_identity.aud!r}, want {audience!r}"
328 |         )
329 |     deployment: bool = False
330 |     cluster: Optional[str] = None
331 |     subcluster: Optional[str] = None
332 |     if repl_identity.WhichOneof("runtime") == "deployment":
333 |         deployment = True
334 |     elif repl_identity.WhichOneof("runtime") == "interactive":
335 |         cluster = repl_identity.interactive.cluster
336 |         subcluster = repl_identity.interactive.subcluster
337 |     # goval_cert is None when the token was signed directly by a root key,
338 |     # in which case there is no cert whose claims need checking.
339 |     if goval_cert is not None:
340 |         _verify_claims(
341 |             iat=goval_cert.iat.ToDatetime(),
342 |             exp=goval_cert.exp.ToDatetime(),
343 |             cluster=cluster,
344 |             subcluster=subcluster,
345 |             deployment=deployment,
346 |             claims=_parse_claims(goval_cert),
347 |         )
348 |     return repl_identity
349 | 
--------------------------------------------------------------------------------
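
A minimal usage sketch of the chat-completions surface defined in
chat_completions.py above. This is illustrative only: the AsyncModelfarm
client is assumed to be importable from replit.ai.modelfarm, the model name
"chat-bison" is a placeholder, and the message dict shape and the printed
response structure are assumptions rather than verified API guarantees.

    import asyncio

    from replit.ai.modelfarm import AsyncModelfarm


    async def main() -> None:
        client = AsyncModelfarm()

        # Non-streaming: create() resolves to a single ChatCompletionResponse.
        response = await client.chat.completions.create(
            messages=[{"role": "user", "content": "Say hello."}],
            model="chat-bison",  # placeholder model name
        )
        print(response)

        # Streaming: create(stream=True) resolves to an async iterator of
        # ChatCompletionStreamChunkResponse chunks that can be consumed
        # with `async for`.
        stream = await client.chat.completions.create(
            messages=[{"role": "user", "content": "Count to three."}],
            model="chat-bison",
            stream=True,
        )
        async for chunk in stream:
            print(chunk)


    asyncio.run(main())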
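The None-dropping step at the end of _build_request_payload means unset
options are omitted from the request body entirely rather than serialized as
nulls, while falsy-but-set values such as False survive. A quick illustration
of the module-internal helper (placeholder model name):

    payload = _build_request_payload(
        messages=[{"role": "user", "content": "hi"}],
        model="chat-bison",
        max_tokens=None,                 # dropped: value is None
        temperature=0.2,
        stream=False,                    # kept: False is not None
        provider_extra_parameters=None,  # dropped: value is None
    )
    assert payload == {
        "model": "chat-bison",
        "messages": [{"role": "user", "content": "hi"}],
        "temperature": 0.2,
        "stream": False,
    }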
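And a sketch of end-to-end verification with verify_identity_token from
verify.py. The REPL_IDENTITY environment variable and the audience string are
assumptions for illustration (the default pubkey_source reads REPL_PUBKEYS
from the environment), and the replid field access on the returned
GovalReplIdentity is assumed from the proto definitions:

    import os

    from replit.ai.modelfarm.identity.exceptions import VerifyError
    from replit.ai.modelfarm.identity.verify import verify_identity_token

    # Assumed to hold the PASETO identity token inside a Repl.
    token = os.environ["REPL_IDENTITY"]
    try:
        identity = verify_identity_token(token, audience="my-service")
        print("verified repl:", identity.replid)
    except VerifyError as err:
        print("verification failed:", err)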