├── app
│   ├── api
│   │   ├── __init__.py
│   │   ├── schemas
│   │   │   ├── __init__.py
│   │   │   ├── admin.py
│   │   │   ├── cached_user.py
│   │   │   ├── stripe.py
│   │   │   ├── user.py
│   │   │   ├── forge_api_key.py
│   │   │   ├── statistic.py
│   │   │   ├── provider_key.py
│   │   │   └── anthropic.py
│   │   └── routes
│   │       ├── admin.py
│   │       ├── __init__.py
│   │       ├── stripe.py
│   │       ├── auth.py
│   │       ├── stats.py
│   │       └── health.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── logger.py
│   │   ├── security.py
│   │   └── database.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── admin_users.py
│   │   ├── base.py
│   │   ├── stripe.py
│   │   ├── wallet.py
│   │   ├── provider_key.py
│   │   ├── usage_tracker.py
│   │   ├── user.py
│   │   ├── api_request_log.py
│   │   ├── forge_api_key.py
│   │   └── pricing.py
│   ├── exceptions
│   │   ├── __init__.py
│   │   └── exceptions.py
│   ├── services
│   │   ├── __init__.py
│   │   └── providers
│   │       ├── __init__.py
│   │       ├── perplexity_adapter.py
│   │       ├── zai_adapter.py
│   │       ├── zhipu_adapter.py
│   │       ├── tensorblock_adapter.py
│   │       ├── alibaba_adapter.py
│   │       ├── gemini_openai_adapter.py
│   │       └── fireworks_adapter.py
│   ├── __init__.py
│   └── utils
│       ├── serialization.py
│       └── translator.py
├── assets
│   └── service.jpg
├── tests
│   ├── cache
│   │   └── __init__.py
│   ├── __init__.py
│   ├── mock_testing
│   │   ├── examples
│   │   │   ├── __init__.py
│   │   │   └── test_with_mocks.py
│   │   ├── __init__.py
│   │   ├── run_mock_tests.py
│   │   └── mock_openai.py
│   ├── performance
│   │   ├── __init__.py
│   │   └── README.md
│   ├── unit_tests
│   │   ├── assets
│   │   │   ├── openai
│   │   │   │   ├── embeddings_response.json
│   │   │   │   ├── chat_completion_response_1.json
│   │   │   │   ├── list_models.json
│   │   │   │   └── responses_response_1.json
│   │   │   ├── anthropic
│   │   │   │   ├── chat_completion_response_1.json
│   │   │   │   ├── list_models.json
│   │   │   │   └── chat_completion_streaming_response_1.json
│   │   │   └── google
│   │   │       ├── chat_completion_response_1.json
│   │   │       ├── chat_completion_streaming_response_1.json
│   │   │       └── list_models.json
│   │   ├── test_vertex_adapter.py
│   │   └── test_security.py
│   ├── run_tests.py
│   ├── README.md
│   ├── test_token_counting_stream.py
│   └── frontend_simulation.py
├── run_unit_tests.sh
├── k8s
│   ├── service.yaml
│   └── deployment.yaml
├── alembic
│   ├── versions
│   │   ├── 6a92c2663fa5_base.py
│   │   ├── f5522808aba9_.py
│   │   ├── 683fc811a969_add_billable_column_to_usage_tracker.py
│   │   ├── f45e46b231f3_add_admin_users.py
│   │   ├── 831fc2cf16ee_enable_soft_deletion_for_provider_keys_.py
│   │   ├── b206e9a941e3_add_cost_tracking_to_usage_tracker_table.py
│   │   ├── 40e4b59f754d_add_stripe_payment_table.py
│   │   ├── 9daf34d338f7_update_model_mapping_type_for_.py
│   │   ├── b5d4363a9f62_add_clerk_user_id_to_users.py
│   │   ├── a58395ea1b22_add_balance_system.py
│   │   ├── 4a685a55c5cd_create_usage_tracker_table.py
│   │   ├── 39bcedfae4fe_add_model_default_pricing.py
│   │   ├── 08cc005a4bc8_create_forge_api_key_provider_scope_.py
│   │   ├── 4a82fb8af123_create_forge_api_keys_table.py
│   │   ├── ca1ac51334ec_create_usage_stats_table.py
│   │   ├── 0ce4eeae965f_drop_usage_stats_table.py
│   │   ├── initial_migration.py
│   │   ├── c50fd7be794c_add_endpoint_column_to_usage_stats.py
│   │   ├── b38aad374524_create_api_request_log_table.py
│   │   └── c9f3e548adef_add_lambda_model_pricing.py
│   ├── script.py.mako
│   └── env.py
├── docker-entrypoint.sh
├── .pre-commit-config.yaml
├── .gitignore
├── tools
│   ├── data
│   │   └── lambda_model_pricing_init.csv
│   └── diagnostics
│       ├── run_test_clean.py
│       ├── fix_model_mapping.py
│       ├── enable_request_logging.py
│       ├── clear_cache.py
│       ├── check_dotenv.py
│       └── check_db_keys.py
├── docker-compose.yml
├── reset_alembic.py
├── cli_tools
│   └── config.ini
├── run.py
├── .github
│   ├── PULL_REQUEST_TEMPLATE
│   │   └── pull_request_template.md
│   └── workflows
│       └── docker-build.yml
├── Dockerfile
├── .env.example
├── alembic.ini
├── pyproject.toml
├── docs
│   └── PERFORMANCE_OPTIMIZATIONS.md
└── LICENSE

/app/api/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/app/core/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/app/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/app/api/schemas/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/app/exceptions/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/app/services/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/app/__init__.py:
--------------------------------------------------------------------------------
1 | # Forge middleware service
2 | 
--------------------------------------------------------------------------------
/assets/service.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorBlock/forge/HEAD/assets/service.jpg
--------------------------------------------------------------------------------
/tests/cache/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the caching functionality in Forge.
3 | """
4 | 
--------------------------------------------------------------------------------
/app/services/providers/__init__.py:
--------------------------------------------------------------------------------
1 | # This file should be created to make the directory a package
2 | 
--------------------------------------------------------------------------------
/run_unit_tests.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash -xe
2 | uv run pytest -vv tests/unit_tests --cov=app --cov-report=xml
3 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # This file is intentionally left empty to make the tests directory a proper Python package
2 | 
--------------------------------------------------------------------------------
/tests/mock_testing/examples/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Mock testing examples.
3 | 
4 | This package contains example tests that demonstrate using the mock client
5 | in different testing scenarios.
6 | """
7 | 
--------------------------------------------------------------------------------
/tests/performance/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Performance test module for Forge API.
3 | 
4 | This package contains tests to measure and analyze the performance
5 | of the Forge middleware service
6 | """
7 | 
8 | __version__ = "0.1.0"
9 | 
--------------------------------------------------------------------------------
/tests/mock_testing/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Forge Mock Testing Package.
3 | 
4 | This package provides tools for testing Forge with mock providers
5 | that simulate API responses without making actual API calls.
6 | """
7 | 
8 | # Export key components for easy importing
9 | 
--------------------------------------------------------------------------------
/k8s/service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 |   name: forge-service
5 |   labels:
6 |     app: forge
7 | spec:
8 |   type: LoadBalancer
9 |   ports:
10 |     - port: 80
11 |       targetPort: 8000
12 |       protocol: TCP
13 |       name: http
14 |   selector:
15 |     app: forge
16 | 
--------------------------------------------------------------------------------
/app/models/admin_users.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import Column, Integer, ForeignKey
2 | from sqlalchemy.orm import relationship
3 | 
4 | from app.models.base import Base
5 | 
6 | class AdminUsers(Base):
7 |     __tablename__ = "admin_users"
8 | 
9 |     user_id = Column(Integer, ForeignKey("users.id"), primary_key=True)
10 |     user = relationship("User", back_populates="admin_users")
11 | 
--------------------------------------------------------------------------------
/app/api/schemas/admin.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, field_validator
2 | 
3 | class AddBalanceRequest(BaseModel):
4 |     user_id: int | None = None
5 |     email: str | None = None
6 |     amount: int  # in cents
7 | 
8 |     @field_validator("amount")
9 |     def validate_amount(cls, value: int):
10 |         if value < 100:
11 |             raise ValueError("Amount must be at least 100 cents")
12 |         return value
13 | 
--------------------------------------------------------------------------------
/alembic/versions/6a92c2663fa5_base.py:
--------------------------------------------------------------------------------
1 | """base
2 | 
3 | Revision ID: 6a92c2663fa5
4 | Revises: f5522808aba9
5 | Create Date: 2025-04-29 16:02:26.813233
6 | 
7 | """
8 | 
9 | 
10 | # revision identifiers, used by Alembic.
11 | revision = "6a92c2663fa5"
12 | down_revision = "f5522808aba9"
13 | branch_labels = None
14 | depends_on = None
15 | 
16 | 
17 | def upgrade() -> None:
18 |     pass
19 | 
20 | 
21 | def downgrade() -> None:
22 |     pass
23 | 
--------------------------------------------------------------------------------
/alembic/versions/f5522808aba9_.py:
--------------------------------------------------------------------------------
1 | """empty message
2 | 
3 | Revision ID: f5522808aba9
4 | Revises: 4a82fb8af123
5 | Create Date: 2025-04-29 16:02:04.550374
6 | 
7 | """
8 | 
9 | 
10 | # revision identifiers, used by Alembic.
11 | revision = "f5522808aba9"
12 | down_revision = "4a82fb8af123"
13 | branch_labels = None
14 | depends_on = None
15 | 
16 | 
17 | def upgrade() -> None:
18 |     pass
19 | 
20 | 
21 | def downgrade() -> None:
22 |     pass
23 | 
--------------------------------------------------------------------------------
/app/services/providers/perplexity_adapter.py:
--------------------------------------------------------------------------------
1 | from .openai_adapter import OpenAIAdapter
2 | 
3 | PERPLEXITY_MODELS = [
4 |     "sonar",
5 |     "sonar-reasoning-pro",
6 |     "sonar-reasoning",
7 |     "sonar-pro",
8 |     "sonar-deep-research",
9 | ]
10 | 
11 | class PerplexityAdapter(OpenAIAdapter):
12 |     """Adapter for Perplexity API"""
13 | 
14 |     async def list_models(self, api_key: str) -> list[str]:
15 |         return PERPLEXITY_MODELS
16 | 
--------------------------------------------------------------------------------
/tests/unit_tests/assets/openai/embeddings_response.json:
--------------------------------------------------------------------------------
1 | {
2 |   "object": "list",
3 |   "data": [
4 |     {
5 |       "object": "embedding",
6 |       "index": 0,
7 |       "embedding": [
8 |         -0.006929283495992422,
9 |         -0.005336422007530928,
10 |         -4.547132266452536e-05,
11 |         -0.024047503247857094
12 |       ]
13 |     }
14 |   ],
15 |   "model": "text-embedding-ada-002",
16 |   "usage": {
17 |     "prompt_tokens": 8,
18 |     "total_tokens": 8
19 |   }
20 | }
--------------------------------------------------------------------------------
/app/services/providers/zai_adapter.py:
--------------------------------------------------------------------------------
1 | from .openai_adapter import OpenAIAdapter
2 | 
3 | # https://docs.z.ai/api-reference/llm/chat-completion#body-model
4 | ZAI_MODELS = [
5 |     "glm-4.5",
6 |     "glm-4.5-air",
7 |     "glm-4.5-x",
8 |     "glm-4.5-airx",
9 |     "glm-4.5-flash",
10 |     "glm-4-32b-0414-128k",
11 | ]
12 | 
13 | 
14 | class ZAIAdapter(OpenAIAdapter):
15 |     """Adapter for Zai API"""
16 | 
17 |     async def list_models(self, api_key: str) -> list[str]:
18 |         return ZAI_MODELS
--------------------------------------------------------------------------------
/app/models/base.py:
--------------------------------------------------------------------------------
1 | from datetime import UTC
2 | from datetime import datetime
3 | 
4 | from sqlalchemy import Column, DateTime, Integer
5 | 
6 | from app.core.database import Base
7 | 
8 | 
9 | class BaseModel(Base):
10 |     __abstract__ = True
11 | 
12 |     id = Column(Integer, primary_key=True, index=True)
13 |     created_at = Column(DateTime, default=datetime.utcnow)
14 |     updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
15 |     deleted_at = Column(DateTime(timezone=True), nullable=True)
16 | 
--------------------------------------------------------------------------------
/app/services/providers/zhipu_adapter.py:
--------------------------------------------------------------------------------
1 | from .openai_adapter import OpenAIAdapter
2 | ZHIPU_MODELS = [
3 |     "glm-4-plus",
4 |     "glm-4-0520",
5 |     "glm-4",
6 |     "glm-4-air",
7 |     "glm-4-airx",
8 |     "glm-4-long",
9 |     "glm-4-flash",
10 |     "glm-4v-plus-0111",
11 |     "glm-4v-flash",
12 |     "glm-z1-air",
13 |     "glm-z1-airx",
14 |     "glm-z1-flash",
15 | ]
16 | 
17 | class ZhipuAdapter(OpenAIAdapter):
18 |     """Adapter for Zhipu API"""
19 | 
20 |     async def list_models(self, api_key: str) -> list[str]:
21 |         return ZHIPU_MODELS
22 | 
--------------------------------------------------------------------------------
/docker-entrypoint.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Exit if any command fails 4 | set -e 5 | 6 | # Set correct permissions for the logs directory at runtime. 7 | # This ensures the 'nobody' user can write to the volume. 8 | chown -R nobody:nogroup /app/logs 9 | 10 | # Run Alembic migrations 11 | echo "Running database migrations..." 12 | if ! alembic upgrade head; then 13 | echo "⚠️ Warning: Alembic migration failed. Continuing without shutdown." 14 | fi 15 | 16 | # Use gosu to drop from root to the 'nobody' user and run the main command 17 | exec gosu nobody "$@" 18 | 19 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-yaml 8 | - id: check-added-large-files 9 | 10 | - repo: https://github.com/astral-sh/ruff-pre-commit 11 | rev: v0.2.0 12 | hooks: 13 | # Run the linter 14 | - id: ruff 15 | args: [--fix] 16 | types_or: [python, pyi, jupyter] 17 | # Run the formatter 18 | - id: ruff-format 19 | types_or: [python, pyi, jupyter] 20 | -------------------------------------------------------------------------------- /tests/unit_tests/test_vertex_adapter.py: -------------------------------------------------------------------------------- 1 | from app.services.providers.vertex_adapter import VertexAdapter 2 | 3 | 4 | def test_vertex_adapter_base_url_global(): 5 | config = {"publisher": "anthropic", "location": "global"} 6 | adapter = VertexAdapter("vertex", None, config) 7 | assert adapter._base_url == "https://aiplatform.googleapis.com" 8 | 9 | 10 | def test_vertex_adapter_base_url_region(): 11 | config = {"publisher": "anthropic", "location": "us-east1"} 12 | adapter = VertexAdapter("vertex", None, config) 13 | assert adapter._base_url == "https://us-east1-aiplatform.googleapis.com" -------------------------------------------------------------------------------- /alembic/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | ${imports if imports else ""} 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = ${repr(up_revision)} 14 | down_revision = ${repr(down_revision)} 15 | branch_labels = ${repr(branch_labels)} 16 | depends_on = ${repr(depends_on)} 17 | 18 | 19 | def upgrade() -> None: 20 | ${upgrades if upgrades else "pass"} 21 | 22 | 23 | def downgrade() -> None: 24 | ${downgrades if downgrades else "pass"} 25 | -------------------------------------------------------------------------------- /alembic/versions/683fc811a969_add_billable_column_to_usage_tracker.py: -------------------------------------------------------------------------------- 1 | """add billable column to usage_tracker 2 | 3 | Revision ID: 683fc811a969 4 | Revises: 40e4b59f754d 5 | Create Date: 2025-09-05 10:48:09.623668 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 
13 | revision = '683fc811a969' 14 | down_revision = '40e4b59f754d' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade() -> None: 20 | op.add_column('usage_tracker', sa.Column('billable', sa.Boolean(), nullable=False, server_default='FALSE')) 21 | 22 | 23 | def downgrade() -> None: 24 | op.drop_column('usage_tracker', 'billable') 25 | -------------------------------------------------------------------------------- /app/api/schemas/cached_user.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from pydantic import BaseModel, ConfigDict, EmailStr 4 | 5 | 6 | class CachedUser(BaseModel): 7 | """ 8 | A Pydantic model representing the data of a User object to be stored in the cache. 9 | This avoids caching the full SQLAlchemy ORM object, which prevents DetachedInstanceError 10 | and improves performance by removing the need for `db.merge()`. 11 | """ 12 | 13 | id: int 14 | email: EmailStr 15 | username: str 16 | is_active: bool 17 | clerk_user_id: str | None = None 18 | created_at: datetime 19 | updated_at: datetime 20 | model_config = ConfigDict(from_attributes=True) 21 | -------------------------------------------------------------------------------- /app/models/stripe.py: -------------------------------------------------------------------------------- 1 | from app.models.base import Base 2 | from sqlalchemy import Column, String, Integer, ForeignKey, JSON, DateTime 3 | from datetime import datetime, UTC 4 | 5 | class StripePayment(Base): 6 | __tablename__ = "stripe_payment" 7 | 8 | id = Column(String, primary_key=True) 9 | user_id = Column(Integer, ForeignKey('users.id'), nullable=False) 10 | amount = Column(Integer, nullable=False) 11 | currency = Column(String(3), nullable=False) 12 | status = Column(String, nullable=False) 13 | raw_data = Column(JSON, nullable=True) 14 | created_at = Column(DateTime(timezone=True), default=datetime.now(UTC)) 15 | updated_at = Column(DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)) -------------------------------------------------------------------------------- /tests/unit_tests/assets/anthropic/chat_completion_response_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "msg_01Kpzcjm6jUdJ5RY5Dxt3rV8", 3 | "type": "message", 4 | "role": "assistant", 5 | "model": "claude-sonnet-4-20250514", 6 | "content": [ 7 | { 8 | "type": "text", 9 | "text": "Hello! I'm doing well, thank you for asking. I'm here and ready to help with whatever you'd like to discuss or work on. How are you doing today?" 10 | } 11 | ], 12 | "stop_reason": "end_turn", 13 | "stop_sequence": null, 14 | "usage": { 15 | "input_tokens": 13, 16 | "cache_creation_input_tokens": 0, 17 | "cache_read_input_tokens": 0, 18 | "output_tokens": 39, 19 | "service_tier": "standard" 20 | } 21 | } -------------------------------------------------------------------------------- /alembic/versions/f45e46b231f3_add_admin_users.py: -------------------------------------------------------------------------------- 1 | """add admin users 2 | 3 | Revision ID: f45e46b231f3 4 | Revises: 683fc811a969 5 | Create Date: 2025-09-11 13:14:17.066592 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 
13 | revision = 'f45e46b231f3' 14 | down_revision = '683fc811a969' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade() -> None: 20 | op.create_table( 21 | "admin_users", 22 | sa.Column("user_id", sa.Integer(), nullable=False), 23 | sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), 24 | sa.PrimaryKeyConstraint("user_id"), 25 | ) 26 | 27 | 28 | def downgrade() -> None: 29 | op.drop_table("admin_users") 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # Environment 28 | .env 29 | .venv 30 | env/ 31 | ENV/ 32 | env.bak/ 33 | venv.bak/ 34 | 35 | # Database 36 | *.db 37 | *.sqlite3 38 | 39 | # IDE 40 | .idea/ 41 | .vscode/ 42 | *.swp 43 | *.swo 44 | 45 | # Logs 46 | *.log 47 | 48 | # VScode settings 49 | .vscode-upload.json 50 | 51 | .ruff_cache/ 52 | 53 | # macOS 54 | .DS_Store 55 | 56 | # Performance test results 57 | tests/performance/results/ 58 | .coverage* 59 | coverage* 60 | 61 | # Local test script 62 | tools/tmp/* 63 | -------------------------------------------------------------------------------- /app/models/wallet.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, UTC 2 | from sqlalchemy import Column, BigInteger, CHAR, DECIMAL, Boolean, DateTime, ForeignKey 3 | from sqlalchemy.orm import relationship 4 | from .base import Base 5 | 6 | class Wallet(Base): 7 | __tablename__ = "wallets" 8 | 9 | account_id = Column(BigInteger, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True) 10 | currency = Column(CHAR(3), nullable=False, default='USD') 11 | balance = Column(DECIMAL(20, 6), nullable=False, default=0) 12 | blocked = Column(Boolean, nullable=False, default=False) 13 | version = Column(BigInteger, nullable=False, default=0) 14 | created_at = Column(DateTime(timezone=True), nullable=False, default=datetime.now(UTC)) 15 | updated_at = Column(DateTime(timezone=True), nullable=False, default=datetime.now(UTC)) 16 | 17 | user = relationship("User", back_populates="wallet") 18 | -------------------------------------------------------------------------------- /app/services/providers/tensorblock_adapter.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from .azure_adapter import AzureAdapter 4 | 5 | TENSORBLOCK_MODELS = [ 6 | "gpt-4.1-mini", 7 | "gpt-4.1-nano", 8 | "gpt-4o-mini", 9 | "o3-mini", 10 | "text-embedding-3-large", 11 | "text-embedding-3-small", 12 | "text-embedding-ada-002", 13 | ] 14 | 15 | 16 | class TensorblockAdapter(AzureAdapter): 17 | """Adapter for Tensorblock API""" 18 | 19 | def __init__(self, provider_name: str, base_url: str, config: dict[str, Any]): 20 | super().__init__(provider_name, base_url, config) 21 | 22 | def get_mapped_model(self, model: str) -> str: 23 | """Get the Azure-specific model name""" 24 | # For TensorBlock, we use the model name as-is since it's already in the correct format 25 | return model 26 | 27 | async def list_models(self, api_key: str) -> list[str]: 28 
| return TENSORBLOCK_MODELS 29 | -------------------------------------------------------------------------------- /tools/data/lambda_model_pricing_init.csv: -------------------------------------------------------------------------------- 1 | provider_name,model_name,input_token_price,output_token_price 2 | lambda,deepseek-r1-0528,0.0005,0.00218 3 | lambda,deepseek-v3-0324,0.00034,0.00088 4 | lambda,qwen-3-32b,0.0001,0.0003 5 | lambda,llama-4-maverick-17b-128e-instruct-fp8,0.00018,0.0006 6 | lambda,llama-4-scout-17b-16e-instruct,0.00008,0.0003 7 | lambda,llama-3.1-8b-instruct,0.000025,0.00004 8 | lambda,llama-3.1-70b-instruct,0.00012,0.0003 9 | lambda,llama-3.1-405b-instruct,0.0008,0.0008 10 | lambda,deepseek-llama3.3-70b,0.0002,0.0006 11 | lambda,llama-3.3-70b-instruct,0.00012,0.0003 12 | lambda,llama-3.2-3b-instruct,0.000015,0.000025 13 | lambda,hermes-3-llama-3.1-8b,0.000025,0.00004 14 | lambda,hermes-3-llama-3.1-70b (fp8),0.00012,0.0003 15 | lambda,hermes-3-llama-3.1-405b (fp8),0.0008,0.0008 16 | lambda,lfm-40b,0.00015,0.00015 17 | lambda,llama3.1-nemotron-70b-instruct,0.00012,0.0003 18 | lambda,qwen2.5-coder-32b,0.00007,0.00016 19 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | app: 3 | build: . 4 | ports: 5 | - "8000:8000" 6 | environment: 7 | - DATABASE_URL=postgresql://forge:forge@db:5432/forge 8 | - PORT=8000 9 | - DEBUG_CACHE=false # Control cache debugging 10 | - FORGE_DEBUG_LOGGING=false # Control application logging level 11 | depends_on: 12 | - db 13 | volumes: 14 | - .:/app 15 | - forge_logs:/app/logs # Persist logs 16 | networks: 17 | - forge-network 18 | 19 | db: 20 | image: postgres:14 21 | environment: 22 | - POSTGRES_USER=forge 23 | - POSTGRES_PASSWORD=forge 24 | - POSTGRES_DB=forge 25 | volumes: 26 | - postgres_data:/var/lib/postgresql/data 27 | ports: 28 | - "5432:5432" 29 | networks: 30 | - forge-network 31 | 32 | volumes: 33 | postgres_data: 34 | forge_logs: # Add volume for logs 35 | 36 | networks: 37 | forge-network: 38 | driver: bridge 39 | -------------------------------------------------------------------------------- /alembic/versions/831fc2cf16ee_enable_soft_deletion_for_provider_keys_.py: -------------------------------------------------------------------------------- 1 | """enable soft deletion for provider keys and api keys 2 | 3 | Revision ID: 831fc2cf16ee 4 | Revises: 4a685a55c5cd 5 | Create Date: 2025-08-02 17:50:12.224293 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 
13 | revision = '831fc2cf16ee' 14 | down_revision = '4a685a55c5cd' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade() -> None: 20 | op.add_column('provider_keys', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True)) 21 | op.add_column('forge_api_keys', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True)) 22 | op.alter_column('forge_api_keys', 'key', nullable=True) 23 | 24 | 25 | def downgrade() -> None: 26 | op.drop_column('provider_keys', 'deleted_at') 27 | op.drop_column('forge_api_keys', 'deleted_at') 28 | op.alter_column('forge_api_keys', 'key', nullable=False) 29 | -------------------------------------------------------------------------------- /alembic/versions/b206e9a941e3_add_cost_tracking_to_usage_tracker_table.py: -------------------------------------------------------------------------------- 1 | """add cost tracking to usage_tracker table 2 | 3 | Revision ID: b206e9a941e3 4 | Revises: 1876c1c4bc96 5 | Create Date: 2025-08-11 18:19:08.581296 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = 'b206e9a941e3' 14 | down_revision = '1876c1c4bc96' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade() -> None: 20 | op.add_column('usage_tracker', sa.Column('cost', sa.DECIMAL(precision=12, scale=8), nullable=True)) 21 | op.add_column('usage_tracker', sa.Column('currency', sa.String(length=3), nullable=True)) 22 | op.add_column('usage_tracker', sa.Column('pricing_source', sa.String(length=255), nullable=True)) 23 | 24 | 25 | def downgrade() -> None: 26 | op.drop_column('usage_tracker', 'cost') 27 | op.drop_column('usage_tracker', 'currency') 28 | op.drop_column('usage_tracker', 'pricing_source') 29 | -------------------------------------------------------------------------------- /tools/diagnostics/run_test_clean.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Script to run the test with a clean environment. 
4 | """ 5 | 6 | import os 7 | import subprocess 8 | import sys 9 | 10 | 11 | def main(): 12 | """Run the test in a clean environment.""" 13 | # Find the project root directory (where this script is located) 14 | script_dir = os.path.dirname( 15 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 16 | ) 17 | os.chdir(script_dir) # Change to project root directory 18 | 19 | # Create a clean environment without the problematic variable 20 | clean_env = os.environ.copy() 21 | if "FORGE_API_KEY" in clean_env: 22 | del clean_env["FORGE_API_KEY"] 23 | 24 | # Run the test with the clean environment 25 | result = subprocess.run( 26 | [sys.executable, "tests/frontend_simulation.py"], env=clean_env, check=False 27 | ) 28 | 29 | print(f"\nTest completed with exit code: {result.returncode}") 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /tests/unit_tests/assets/anthropic/list_models.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "type": "model", 5 | "id": "claude-opus-4-20250514", 6 | "display_name": "Claude Opus 4", 7 | "created_at": "2025-05-22T00:00:00Z" 8 | }, 9 | { 10 | "type": "model", 11 | "id": "claude-sonnet-4-20250514", 12 | "display_name": "Claude Sonnet 4", 13 | "created_at": "2025-05-22T00:00:00Z" 14 | }, 15 | { 16 | "type": "model", 17 | "id": "claude-3-7-sonnet-20250219", 18 | "display_name": "Claude Sonnet 3.7", 19 | "created_at": "2025-02-24T00:00:00Z" 20 | }, 21 | { 22 | "type": "model", 23 | "id": "claude-3-5-sonnet-20241022", 24 | "display_name": "Claude Sonnet 3.5 (New)", 25 | "created_at": "2024-10-22T00:00:00Z" 26 | } 27 | ], 28 | "has_more": false, 29 | "first_id": "claude-opus-4-20250514", 30 | "last_id": "claude-3-opus-20240229" 31 | } -------------------------------------------------------------------------------- /tests/unit_tests/assets/google/chat_completion_response_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "candidates": [ 3 | { 4 | "content": { 5 | "parts": [ 6 | { 7 | "text": "I am doing well, thank you for asking. How are you today?\n" 8 | } 9 | ], 10 | "role": "model" 11 | }, 12 | "finishReason": "STOP", 13 | "avgLogprobs": -0.2357906699180603 14 | } 15 | ], 16 | "usageMetadata": { 17 | "promptTokenCount": 6, 18 | "candidatesTokenCount": 16, 19 | "totalTokenCount": 22, 20 | "promptTokensDetails": [ 21 | { 22 | "modality": "TEXT", 23 | "tokenCount": 6 24 | } 25 | ], 26 | "candidatesTokensDetails": [ 27 | { 28 | "modality": "TEXT", 29 | "tokenCount": 16 30 | } 31 | ] 32 | }, 33 | "modelVersion": "gemini-1.5-pro-002", 34 | "responseId": "zCFeaOSZBvmSmNAP-sXGiQU" 35 | } -------------------------------------------------------------------------------- /app/utils/serialization.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict 3 | 4 | def serialize_dict(data: Dict[str, Any]) -> str: 5 | """ 6 | Serialize a dictionary to a string using JSON. 
7 | 8 | Args: 9 | data: The dictionary to serialize 10 | 11 | Returns: 12 | A JSON string representation of the dictionary 13 | 14 | Raises: 15 | TypeError: If the input is not a dictionary 16 | ValueError: If the dictionary contains non-serializable values 17 | """ 18 | if not isinstance(data, dict): 19 | raise TypeError("Input must be a dictionary") 20 | 21 | return json.dumps(data) 22 | 23 | def deserialize_dict(serialized_data: str) -> Dict[str, Any]: 24 | """ 25 | Deserialize a string back into a dictionary using JSON. 26 | 27 | Args: 28 | serialized_data: The JSON string to deserialize 29 | 30 | Returns: 31 | The deserialized dictionary 32 | 33 | Raises: 34 | ValueError: If the input string is not valid JSON 35 | """ 36 | return json.loads(serialized_data) -------------------------------------------------------------------------------- /reset_alembic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | 5 | from dotenv import load_dotenv 6 | from sqlalchemy import create_engine 7 | from sqlalchemy.sql import text 8 | 9 | load_dotenv() 10 | 11 | # Get connection string from environment variable 12 | conn_string = os.environ.get("DATABASE_URL") 13 | if not conn_string: 14 | print("Please set the DATABASE_URL environment variable") 15 | sys.exit(1) 16 | 17 | engine = create_engine(conn_string) 18 | 19 | with engine.connect() as connection: 20 | try: 21 | connection.execute(text("DROP TABLE IF EXISTS alembic_version")) 22 | connection.execute( 23 | text("CREATE TABLE alembic_version (version_num VARCHAR(32) PRIMARY KEY)") 24 | ) 25 | connection.execute(text("INSERT INTO alembic_version VALUES ('6a92c2663fa5')")) 26 | connection.commit() 27 | print("Successfully reset alembic version to 6a92c2663fa5") 28 | except Exception as e: 29 | print(f"Error resetting alembic version: {e}") 30 | print("Error: No version found in alembic_version table") 31 | sys.exit(1) 32 | -------------------------------------------------------------------------------- /app/core/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger configuration using loguru. 
3 | """ 4 | 5 | import sys 6 | from pathlib import Path 7 | 8 | from loguru import logger 9 | 10 | # Remove default handler 11 | logger.remove() 12 | 13 | # Add console handler with custom format 14 | logger.add( 15 | sys.stderr, 16 | format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}", 17 | level="DEBUG", 18 | enqueue=True, # Thread-safe logging 19 | backtrace=True, # Detailed traceback 20 | diagnose=True, # Enable exception diagnosis 21 | ) 22 | 23 | # Add file handler for debugging 24 | log_path = Path("logs") 25 | log_path.mkdir(exist_ok=True) 26 | 27 | logger.add( 28 | "logs/forge_{time}.log", 29 | rotation="1 day", # Create new file daily 30 | retention="1 week", # Keep logs for 1 week 31 | compression="zip", # Compress rotated logs 32 | level="DEBUG", 33 | enqueue=True, 34 | backtrace=True, 35 | diagnose=True, 36 | ) 37 | 38 | # Export logger instance 39 | get_logger = logger.bind 40 | -------------------------------------------------------------------------------- /cli_tools/config.ini: -------------------------------------------------------------------------------- 1 | [api] 2 | base_url = http://localhost:8000/v1 3 | timeout = 30 4 | retry_attempts = 3 5 | 6 | [ui] 7 | default_theme = dark 8 | show_clock = true 9 | show_notifications = true 10 | auto_refresh_interval = 30 11 | 12 | [display] 13 | table_page_size = 15 14 | max_log_lines = 100 15 | truncate_long_text = true 16 | max_text_length = 50 17 | 18 | [keybindings] 19 | # Custom keybindings (override defaults) 20 | quit = "q" 21 | toggle_theme = "f1" 22 | refresh_all = "f2" 23 | new_api_key = "ctrl+n" 24 | new_provider = "ctrl+p" 25 | help = "f12" 26 | 27 | [notifications] 28 | # Notification settings 29 | show_success = true 30 | show_errors = true 31 | show_warnings = true 32 | auto_dismiss_time = 5 # 33 | 34 | [development] 35 | # Development and debugging options 36 | debug_mode = false 37 | log_api_requests = false 38 | mock_api_responses = false 39 | 40 | # Color scheme customization (optional) 41 | [colors] 42 | # Uncomment and modify to customize colors 43 | # primary = "#007acc" 44 | # secondary = "#6c757d" 45 | # success = "#28a745" 46 | # warning = "#ffc107" 47 | # error = "#dc3545" 48 | # accent = "#17a2b8" 49 | -------------------------------------------------------------------------------- /k8s/deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: forge-app 5 | labels: 6 | app: forge 7 | spec: 8 | replicas: 2 9 | selector: 10 | matchLabels: 11 | app: forge 12 | template: 13 | metadata: 14 | labels: 15 | app: forge 16 | spec: 17 | containers: 18 | - name: forge 19 | image: tensorblockai/forge:latest 20 | ports: 21 | - containerPort: 8000 22 | envFrom: 23 | - secretRef: 24 | name: forge-secrets 25 | env: 26 | - name: PORT 27 | value: "8000" 28 | resources: 29 | requests: 30 | cpu: "100m" 31 | memory: "256Mi" 32 | limits: 33 | cpu: "500m" 34 | memory: "512Mi" 35 | livenessProbe: 36 | httpGet: 37 | path: /health 38 | port: 8000 39 | initialDelaySeconds: 30 40 | periodSeconds: 10 41 | readinessProbe: 42 | httpGet: 43 | path: /health 44 | port: 8000 45 | initialDelaySeconds: 5 46 | periodSeconds: 5 47 | -------------------------------------------------------------------------------- /app/services/providers/alibaba_adapter.py: -------------------------------------------------------------------------------- 1 | from .openai_adapter import OpenAIAdapter 2 | 3 | 
ALIBABA_MODELS = [ 4 | "qwen-max", 5 | "qwen-max-latest", 6 | "qwen-max-2025-01-25", 7 | "qwen-plus", 8 | "qwen-plus-latest", 9 | "qwen-plus-2025-04-28", 10 | "qwen-plus-2025-01-25", 11 | "qwen-turbo", 12 | "qwen-turbo-latest", 13 | "qwen-turbo-2025-04-28", 14 | "qwen-turbo-2024-11-01", 15 | "qwq-32b", 16 | "qwen3-235b-a22b", 17 | "qwen3-32b", 18 | "qwen3-30b-a3b", 19 | "qwen3-14b", 20 | "qwen3-8b", 21 | "qwen3-4b", 22 | "qwen3-1.7b", 23 | "qwen3-0.6b", 24 | "qwen2.5-14b-instruct-1m", 25 | "qwen2.5-7b-instruct-1m", 26 | "qwen2.5-72b-instruct", 27 | "qwen2.5-32b-instruct", 28 | "qwen2.5-14b-instruct", 29 | "qwen2.5-7b-instruct", 30 | "qwen2-72b-instruct", 31 | "qwen2-7b-instruct", 32 | "qwen1.5-110b-chat", 33 | "qwen1.5-72b-chat", 34 | "qwen1.5-32b-chat", 35 | "qwen1.5-14b-chat", 36 | "qwen1.5-7b-chat", 37 | ] 38 | 39 | class AlibabaAdapter(OpenAIAdapter): 40 | """Adapter for Alibaba API""" 41 | 42 | async def list_models(self, api_key: str) -> list[str]: 43 | return ALIBABA_MODELS -------------------------------------------------------------------------------- /alembic/versions/40e4b59f754d_add_stripe_payment_table.py: -------------------------------------------------------------------------------- 1 | """add stripe payment table 2 | 3 | Revision ID: 40e4b59f754d 4 | Revises: a58395ea1b22 5 | Create Date: 2025-09-02 20:52:29.183031 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | from datetime import datetime, UTC 11 | 12 | 13 | # revision identifiers, used by Alembic. 14 | revision = '40e4b59f754d' 15 | down_revision = 'a58395ea1b22' 16 | branch_labels = None 17 | depends_on = None 18 | 19 | 20 | def upgrade() -> None: 21 | op.create_table('stripe_payment', 22 | sa.Column('id', sa.String, primary_key=True), 23 | sa.Column('user_id', sa.Integer, sa.ForeignKey('users.id'), nullable=False), 24 | sa.Column('amount', sa.Integer, nullable=False), 25 | sa.Column('currency', sa.String(3), nullable=False), 26 | sa.Column('status', sa.String, nullable=False), 27 | sa.Column('raw_data', sa.JSON, nullable=True), 28 | sa.Column('created_at', sa.DateTime(timezone=True), default=datetime.now(UTC)), 29 | sa.Column('updated_at', sa.DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)), 30 | ) 31 | 32 | 33 | def downgrade() -> None: 34 | op.drop_table('stripe_payment') 35 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def main() -> None: 5 | """Launch Gunicorn the same way we do in the Dockerfile for local dev. 6 | 7 | This avoids importing the FastAPI app in the parent process, so no DB 8 | connections are opened before Gunicorn forks its workers (prevents SSL 9 | errors when using Railway/PostgreSQL). 10 | """ 11 | 12 | host = os.getenv("HOST", "0.0.0.0") 13 | port = os.getenv("PORT", "8000") 14 | reload = os.getenv("RELOAD", "false").lower() == "true" 15 | 16 | # Optional: let caller override the number of Gunicorn workers 17 | workers_env = os.getenv("WORKERS") # e.g. 
WORKERS=10 18 | 19 | cmd = [ 20 | "gunicorn", 21 | "app.main:app", 22 | "-k", 23 | "uvicorn.workers.UvicornWorker", 24 | "--bind", 25 | f"{host}:{port}", 26 | "--log-level", 27 | "info", 28 | ] 29 | 30 | if reload: 31 | cmd.append("--reload") 32 | 33 | # Inject --workers flag if WORKERS env var is set 34 | if workers_env and workers_env.isdigit(): 35 | cmd.extend(["--workers", workers_env]) 36 | 37 | # Replace the current process with Gunicorn. 38 | os.execvp(cmd[0], cmd) 39 | 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /alembic/versions/9daf34d338f7_update_model_mapping_type_for_.py: -------------------------------------------------------------------------------- 1 | """update model_mapping type for ProviderKey table 2 | 3 | Revision ID: 9daf34d338f7 4 | Revises: 08cc005a4bc8 5 | Create Date: 2025-07-18 21:32:48.791253 6 | 7 | """ 8 | 9 | from alembic import op 10 | import sqlalchemy as sa 11 | from sqlalchemy.dialects import postgresql 12 | 13 | 14 | # revision identifiers, used by Alembic. 15 | revision = "9daf34d338f7" 16 | down_revision = "08cc005a4bc8" 17 | branch_labels = None 18 | depends_on = None 19 | 20 | 21 | def upgrade() -> None: 22 | # Change model_mapping column from String to JSON 23 | op.alter_column( 24 | "provider_keys", 25 | "model_mapping", 26 | existing_type=sa.String(), 27 | type_=postgresql.JSON(astext_type=sa.Text()), 28 | existing_nullable=True, 29 | postgresql_using="model_mapping::json", 30 | ) 31 | 32 | 33 | def downgrade() -> None: 34 | # Change model_mapping column from JSON back to String 35 | op.alter_column( 36 | "provider_keys", 37 | "model_mapping", 38 | existing_type=postgresql.JSON(astext_type=sa.Text()), 39 | type_=sa.String(), 40 | existing_nullable=True, 41 | postgresql_using="model_mapping::text", 42 | ) 43 | -------------------------------------------------------------------------------- /tests/unit_tests/assets/openai/chat_completion_response_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "chatcmpl-BmuFSt7sCoTGzmCCEEHt6vS1YIEdO", 3 | "object": "chat.completion", 4 | "created": 1750995662, 5 | "model": "gpt-4o-mini-2024-07-18", 6 | "choices": [ 7 | { 8 | "index": 0, 9 | "message": { 10 | "role": "assistant", 11 | "content": "Hello! I'm just a program, so I don't have feelings, but I'm here and ready to help you. 
How can I assist you today?", 12 | "refusal": null, 13 | "annotations": [] 14 | }, 15 | "logprobs": null, 16 | "finish_reason": "stop" 17 | } 18 | ], 19 | "usage": { 20 | "prompt_tokens": 13, 21 | "completion_tokens": 29, 22 | "total_tokens": 42, 23 | "prompt_tokens_details": { 24 | "cached_tokens": 0, 25 | "audio_tokens": 0 26 | }, 27 | "completion_tokens_details": { 28 | "reasoning_tokens": 0, 29 | "audio_tokens": 0, 30 | "accepted_prediction_tokens": 0, 31 | "rejected_prediction_tokens": 0 32 | } 33 | }, 34 | "service_tier": "default", 35 | "system_fingerprint": "fp_34a54ae93c" 36 | } -------------------------------------------------------------------------------- /app/models/provider_key.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, ForeignKey, Integer, String, JSON, Boolean 2 | from sqlalchemy.orm import relationship 3 | 4 | from app.models.forge_api_key import forge_api_key_provider_scope_association 5 | 6 | from .base import BaseModel 7 | 8 | 9 | class ProviderKey(BaseModel): 10 | __tablename__ = "provider_keys" 11 | 12 | provider_name = Column(String, index=True) # e.g., "openai", "anthropic", etc. 13 | encrypted_api_key = Column(String) 14 | user_id = Column(Integer, ForeignKey("users.id")) 15 | user = relationship("User", back_populates="provider_keys") 16 | 17 | # Additional metadata specific to the provider 18 | base_url = Column( 19 | String, nullable=True 20 | ) # Allow custom base URLs for some providers 21 | model_mapping = Column(JSON, nullable=True) # JSON dict for model name mappings 22 | billable = Column(Boolean, nullable=False, default=False) 23 | 24 | # Relationship to ForgeApiKeys that have this provider key in their scope 25 | scoped_forge_api_keys = relationship( 26 | "ForgeApiKey", 27 | secondary=forge_api_key_provider_scope_association, 28 | back_populates="allowed_provider_keys", 29 | lazy="selectin", 30 | ) 31 | usage_tracker = relationship("UsageTracker", back_populates="provider_key") 32 | -------------------------------------------------------------------------------- /alembic/versions/b5d4363a9f62_add_clerk_user_id_to_users.py: -------------------------------------------------------------------------------- 1 | """add_clerk_user_id_to_users 2 | 3 | Revision ID: b5d4363a9f62 4 | Revises: 0ce4eeae965f 5 | Create Date: 2025-04-19 10:43:37.264983 6 | 7 | """ 8 | 9 | import sqlalchemy as sa 10 | 11 | from alembic import op 12 | 13 | # revision identifiers, used by Alembic. 
14 | revision = "b5d4363a9f62" 15 | down_revision = "0ce4eeae965f" 16 | branch_labels = None 17 | depends_on = None 18 | 19 | 20 | def upgrade() -> None: 21 | # Add clerk_user_id column to users table 22 | op.add_column("users", sa.Column("clerk_user_id", sa.String(), nullable=True)) 23 | op.create_index( 24 | op.f("ix_users_clerk_user_id"), "users", ["clerk_user_id"], unique=True 25 | ) 26 | 27 | # SQLite doesn't support ALTER COLUMN directly, use batch operations instead 28 | with op.batch_alter_table("users") as batch_op: 29 | batch_op.alter_column( 30 | "hashed_password", existing_type=sa.String(), nullable=True 31 | ) 32 | 33 | 34 | def downgrade() -> None: 35 | # Remove clerk_user_id column 36 | op.drop_index(op.f("ix_users_clerk_user_id"), table_name="users") 37 | op.drop_column("users", "clerk_user_id") 38 | 39 | # Make hashed_password required again 40 | with op.batch_alter_table("users") as batch_op: 41 | batch_op.alter_column( 42 | "hashed_password", existing_type=sa.String(), nullable=False 43 | ) 44 | -------------------------------------------------------------------------------- /app/utils/translator.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import aiohttp 3 | from http import HTTPStatus 4 | 5 | async def download_image_url(logger, image_url: str) -> str: 6 | """ 7 | Download an image from a URL and return the base64 encoded string 8 | """ 9 | 10 | # if the image url is a data url, return it as is 11 | if image_url.startswith("data:"): 12 | return image_url 13 | 14 | async with aiohttp.ClientSession() as session: 15 | async with session.head(image_url) as response: 16 | if response.status != HTTPStatus.OK: 17 | error_text = await response.text() 18 | log_error_msg = f"Failed to fetch file metadata from URL: {error_text}" 19 | logger.error(log_error_msg) 20 | raise RuntimeError(log_error_msg) 21 | 22 | mime_type = response.headers.get("Content-Type", "") 23 | file_size = int(response.headers.get("Content-Length", 0)) 24 | if file_size > 10 * 1024 * 1024: 25 | log_error_msg = f"Image file size is too large: {file_size} bytes" 26 | logger.error(log_error_msg) 27 | raise RuntimeError(log_error_msg) 28 | 29 | async with session.get(image_url) as response: 30 | # return format is data:mime_type;base64,base64_data 31 | return f"data:{mime_type};base64,{base64.b64encode(await response.read()).decode('utf-8')}" 32 | -------------------------------------------------------------------------------- /app/services/providers/gemini_openai_adapter.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from app.core.logger import get_logger 4 | from .openai_adapter import OpenAIAdapter 5 | 6 | # Configure logging 7 | logger = get_logger(name="gemini_openai_adapter") 8 | 9 | 10 | class GeminiOpenAIAdapter(OpenAIAdapter): 11 | """Adapter for Google Gemini via the OpenAI-compatible endpoint 12 | 13 | Google now exposes Gemini models behind an OpenAI-compatible REST surface at 14 | https://generativelanguage.googleapis.com/v1beta/openai/… 15 | 16 | We can therefore reuse all logic in OpenAIAdapter. The only wrinkle is that 17 | callers might supply the base_url without the trailing `/openai` segment, 18 | so we normalise it here. 
19 | """ 20 | 21 | def __init__( 22 | self, 23 | provider_name: str, 24 | base_url: str | None, 25 | config: dict[str, Any] | None = None, 26 | ): 27 | # Default base URL if none supplied 28 | if not base_url: 29 | base_url = "https://generativelanguage.googleapis.com/v1beta" 30 | 31 | # Ensure the URL ends with the OpenAI compatibility suffix 32 | base_url = base_url.rstrip("/") 33 | if not base_url.endswith("/openai"): 34 | base_url = f"{base_url}/openai" 35 | 36 | logger.debug(f"Initialised GeminiOpenAIAdapter with base_url={base_url}") 37 | 38 | super().__init__(provider_name, base_url, config=config or {}) -------------------------------------------------------------------------------- /app/api/schemas/stripe.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, field_validator 2 | from typing import List, Literal 3 | 4 | # https://docs.stripe.com/api/checkout/sessions/create 5 | class StripeCheckoutSessionLineItemPriceDataProductData(BaseModel): 6 | name: str 7 | description: str | None = None 8 | images: List[str] | None = None 9 | 10 | class StripeCheckoutSessionLineItemPriceData(BaseModel): 11 | currency: str 12 | product_data: StripeCheckoutSessionLineItemPriceDataProductData 13 | tax_behavior: str = "inclusive" 14 | unit_amount: int 15 | 16 | class StripeCheckoutSessionLineItem(BaseModel): 17 | price_data: StripeCheckoutSessionLineItemPriceData 18 | quantity: int 19 | 20 | class CreateCheckoutSessionRequest(BaseModel): 21 | line_items: List[StripeCheckoutSessionLineItem] 22 | # Only allow payment mode for now 23 | mode: Literal["payment"] = "payment" 24 | # Attach the session_id to the success_url 25 | # https://docs.stripe.com/payments/checkout/custom-success-page?payment-ui=stripe-hosted&utm_source=chatgpt.com#success-url 26 | success_url: str | None = None 27 | return_url: str | None = None 28 | cancel_url: str | None = None 29 | ui_mode: str = "hosted" 30 | 31 | @field_validator("success_url") 32 | @classmethod 33 | def append_session_id_to_success_url(cls, value: str): 34 | if value is None: 35 | return None 36 | return value.rstrip("/") + "?session_id={CHECKOUT_SESSION_ID}" -------------------------------------------------------------------------------- /app/api/schemas/user.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from pydantic import BaseModel, ConfigDict, EmailStr, Field 4 | 5 | from app.api.schemas.forge_api_key import ForgeApiKeyMasked 6 | 7 | # Constants 8 | VISIBLE_API_KEY_CHARS = 4 9 | 10 | 11 | class UserBase(BaseModel): 12 | email: EmailStr 13 | username: str 14 | 15 | 16 | class UserCreate(UserBase): 17 | password: str 18 | 19 | 20 | class UserUpdate(BaseModel): 21 | email: EmailStr | None = None 22 | username: str | None = None 23 | password: str | None = None 24 | 25 | 26 | class UserInDB(UserBase): 27 | id: int 28 | is_active: bool 29 | created_at: datetime 30 | updated_at: datetime 31 | model_config = ConfigDict(from_attributes=True) 32 | 33 | 34 | class User(UserInDB): 35 | forge_api_keys: list[str] | None = None 36 | 37 | 38 | class MaskedUser(UserInDB): 39 | is_admin: bool = False 40 | forge_api_keys: list[str] | None = Field( 41 | description="List of all API keys with all but last 4 digits masked", 42 | default=None, 43 | ) 44 | 45 | @classmethod 46 | def mask_api_key(cls, api_key: str | None) -> str | None: 47 | if not api_key: 48 | return None 49 | return ForgeApiKeyMasked.mask_api_key(api_key) 50 | 
51 | model_config = ConfigDict(from_attributes=True) 52 | 53 | 54 | class Token(BaseModel): 55 | access_token: str 56 | token_type: str 57 | 58 | 59 | class TokenData(BaseModel): 60 | username: str | None = None 61 | -------------------------------------------------------------------------------- /alembic/versions/a58395ea1b22_add_balance_system.py: -------------------------------------------------------------------------------- 1 | """add balance system 2 | 3 | Revision ID: a58395ea1b22 4 | Revises: c9f3e548adef 5 | Create Date: 2025-08-20 22:00:45.743308 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = 'a58395ea1b22' 14 | down_revision = 'c9f3e548adef' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | def upgrade() -> None: 19 | op.create_table( 20 | 'wallets', 21 | sa.Column('account_id', sa.BigInteger(), nullable=False), 22 | sa.Column('currency', sa.CHAR(length=3), nullable=False, server_default='USD'), 23 | sa.Column('balance', sa.DECIMAL(precision=20, scale=6), nullable=False, server_default='0'), 24 | sa.Column('blocked', sa.Boolean(), nullable=False, server_default='FALSE'), 25 | sa.Column('version', sa.BigInteger(), nullable=False, server_default='0'), 26 | sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), 27 | sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), 28 | sa.PrimaryKeyConstraint('account_id'), 29 | sa.ForeignKeyConstraint(['account_id'], ['users.id'], ondelete='CASCADE') 30 | ) 31 | op.add_column('provider_keys', sa.Column('billable', sa.Boolean(), nullable=False, server_default='FALSE')) 32 | 33 | def downgrade() -> None: 34 | op.drop_table('wallets') 35 | op.drop_column('provider_keys', 'billable') 36 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | ## Type of Change 5 | 6 | - [ ] Bug fix (non-breaking change which fixes an issue) 7 | - [ ] New feature (non-breaking change which adds functionality) 8 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 9 | - [ ] Documentation update 10 | - [ ] Refactoring (no functional changes) 11 | - [ ] Performance improvement 12 | - [ ] Testing improvement 13 | 14 | ## How Has This Been Tested? 
15 | 16 | - [ ] Unit tests 17 | - [ ] Integration tests 18 | - [ ] Manual tests (please describe) 19 | 20 | ## Checklist: 21 | 22 | - [ ] My code follows the code style of this project 23 | - [ ] My changes generate no new warnings 24 | - [ ] I have added tests that prove my fix is effective or that my feature works 25 | - [ ] New and existing unit tests pass locally with my changes 26 | - [ ] Any dependent changes have been merged and published in downstream modules 27 | - [ ] I have updated the documentation accordingly 28 | - [ ] I have checked that my changes don't break backward compatibility 29 | - [ ] I have ensured all automated tests pass (checks are green) 30 | 31 | ## Additional Context 32 | 33 | -------------------------------------------------------------------------------- /app/models/usage_tracker.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from datetime import UTC 3 | import uuid 4 | 5 | from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, DECIMAL, Boolean 6 | from sqlalchemy.orm import relationship 7 | from sqlalchemy.dialects.postgresql import UUID 8 | from .base import Base 9 | 10 | class UsageTracker(Base): 11 | __tablename__ = "usage_tracker" 12 | 13 | id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) 14 | user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False) 15 | provider_key_id = Column(Integer, ForeignKey("provider_keys.id", ondelete="CASCADE"), nullable=False) 16 | forge_key_id = Column(Integer, ForeignKey("forge_api_keys.id", ondelete="CASCADE"), nullable=False) 17 | model = Column(String, nullable=True) 18 | endpoint = Column(String, nullable=True) 19 | created_at = Column(DateTime(timezone=True), nullable=False, default=datetime.datetime.now(UTC)) 20 | updated_at = Column(DateTime(timezone=True), nullable=True) 21 | input_tokens = Column(Integer, nullable=True) 22 | output_tokens = Column(Integer, nullable=True) 23 | cached_tokens = Column(Integer, nullable=True) 24 | reasoning_tokens = Column(Integer, nullable=True) 25 | cost = Column(DECIMAL(12, 8), nullable=True) 26 | currency = Column(String(3), nullable=True) 27 | pricing_source = Column(String(255), nullable=True) 28 | billable = Column(Boolean, nullable=False, default=False) 29 | 30 | provider_key = relationship("ProviderKey", back_populates="usage_tracker") 31 | -------------------------------------------------------------------------------- /app/models/user.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from sqlalchemy import Boolean, Column, DateTime, Integer, String 4 | from sqlalchemy.orm import relationship 5 | 6 | from app.models.base import Base 7 | 8 | # UsageStats model is removed, so no related imports needed 9 | 10 | 11 | class User(Base): 12 | """User model for storing user information""" 13 | 14 | __tablename__ = "users" 15 | 16 | id = Column(Integer, primary_key=True, index=True) 17 | email = Column(String, unique=True, index=True, nullable=False) 18 | username = Column(String, unique=True, index=True, nullable=False) 19 | hashed_password = Column(String, nullable=False) 20 | is_active = Column(Boolean, default=True) 21 | # is_admin = Column(Boolean, default=False) # Add if implementing admin role 22 | clerk_user_id = Column(String, unique=True, nullable=True) # Add Clerk user ID 23 | created_at = Column(DateTime, default=datetime.datetime.utcnow) 24 | updated_at = Column( 25 | DateTime, 
default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow 26 | ) 27 | 28 | # Relationships 29 | api_keys = relationship( 30 | "ForgeApiKey", back_populates="user", cascade="all, delete-orphan" 31 | ) 32 | provider_keys = relationship( 33 | "ProviderKey", back_populates="user", cascade="all, delete-orphan" 34 | ) 35 | wallet = relationship("Wallet", back_populates="user", uselist=False) 36 | # Optional: Add relationship to ApiRequestLog if needed 37 | # api_logs = relationship("ApiRequestLog") 38 | admin_users = relationship("AdminUsers", back_populates="user", uselist=False) -------------------------------------------------------------------------------- /app/models/api_request_log.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from sqlalchemy import Column, DateTime, Float, ForeignKey, Index, Integer, String 4 | 5 | from .base import Base 6 | 7 | 8 | class ApiRequestLog(Base): 9 | __tablename__ = "api_request_log" 10 | 11 | id = Column(Integer, primary_key=True) 12 | # Link to user, allow null if user deleted or request is unauthenticated? Let's make it nullable. 13 | user_id = Column( 14 | Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True 15 | ) 16 | provider_name = Column(String, nullable=False, index=True) 17 | model = Column(String, nullable=False, index=True) 18 | endpoint = Column(String, nullable=False, index=True) # e.g., 'chat/completions' 19 | request_timestamp = Column( 20 | DateTime, default=datetime.datetime.utcnow, nullable=False, index=True 21 | ) # Time of the request 22 | input_tokens = Column(Integer, default=0) 23 | output_tokens = Column(Integer, default=0) 24 | total_tokens = Column(Integer, default=0) # Calculated: input + output 25 | cost = Column(Float, default=0.0) # Estimated cost if available 26 | # Optional: Add status_code and duration_ms later if needed 27 | 28 | # Relationship (optional, useful if querying logs through User object) 29 | # user = relationship("User") 30 | 31 | # Define indices for common query patterns 32 | __table_args__ = ( 33 | Index("ix_api_request_log_user_time", "user_id", "request_timestamp"), 34 | # Add other indices as needed, e.g., Index('ix_api_request_log_provider_model', 'provider_name', 'model') 35 | ) 36 | -------------------------------------------------------------------------------- /tools/diagnostics/fix_model_mapping.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Utility script to fix model mappings in the database. 4 | Specifically for fixing the gpt-4o to mock-gpt-4o mapping issue. 
5 | """ 6 | 7 | import asyncio 8 | import os 9 | import sys 10 | from pathlib import Path 11 | 12 | from app.core.database import get_async_db 13 | 14 | # Add the project root to the Python path 15 | script_dir = Path(__file__).resolve().parent.parent.parent 16 | sys.path.insert(0, str(script_dir)) 17 | 18 | # Change to the project root directory 19 | os.chdir(script_dir) 20 | 21 | 22 | async def fix_model_mappings(): 23 | """Fix model mappings by clearing caches""" 24 | print("\n🔧 FIXING MODEL MAPPINGS") 25 | print("======================") 26 | 27 | # Get DB session 28 | async with get_async_db() as db: 29 | pass 30 | 31 | # Clear all caches to ensure changes take effect 32 | print("🔄 Invalidating provider service cache for all users") 33 | from app.core.cache import provider_service_cache, user_cache 34 | 35 | provider_service_cache.clear() 36 | user_cache.clear() 37 | print("✅ All caches cleared") 38 | 39 | print("\n✅ Model mapping fix complete.") 40 | return True 41 | 42 | 43 | async def main(): 44 | """Main entry point""" 45 | if await fix_model_mappings(): 46 | print( 47 | "\n✅ Model mappings have been fixed. Use check_model_mappings.py to verify." 48 | ) 49 | sys.exit(0) 50 | else: 51 | print("\n❌ Failed to fix model mappings.") 52 | sys.exit(1) 53 | 54 | 55 | if __name__ == "__main__": 56 | asyncio.run(main()) 57 | -------------------------------------------------------------------------------- /.github/workflows/docker-build.yml: -------------------------------------------------------------------------------- 1 | name: Build and Publish Docker Image 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | release_tag: 7 | description: 'Release Tag (leave empty for auto-generated timestamp)' 8 | required: false 9 | default: '' 10 | debug_logging: 11 | description: 'Enable debug logging in the image' 12 | type: boolean 13 | required: false 14 | default: false 15 | 16 | jobs: 17 | build-and-publish: 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - name: Checkout code 22 | uses: actions/checkout@v3 23 | 24 | - name: Set up Docker Buildx 25 | uses: docker/setup-buildx-action@v2 26 | 27 | - name: Login to Docker Hub 28 | uses: docker/login-action@v2 29 | with: 30 | username: ${{ secrets.DOCKER_USERNAME }} 31 | password: ${{ secrets.DOCKER_PASSWORD }} 32 | 33 | - name: Get timestamp 34 | id: timestamp 35 | run: echo "timestamp=$(date +'%Y%m%d-%H%M%S')" >> $GITHUB_OUTPUT 36 | 37 | - name: Determine tag 38 | id: tag 39 | run: | 40 | if [ -z "${{ github.event.inputs.release_tag }}" ]; then 41 | echo "tag=${{ steps.timestamp.outputs.timestamp }}" >> $GITHUB_OUTPUT 42 | else 43 | echo "tag=${{ github.event.inputs.release_tag }}" >> $GITHUB_OUTPUT 44 | fi 45 | 46 | - name: Build and push 47 | uses: docker/build-push-action@v4 48 | with: 49 | context: . 
50 | push: true 51 | tags: | 52 | tensorblockai/forge:latest 53 | tensorblockai/forge:${{ steps.tag.outputs.tag }} 54 | cache-from: type=gha 55 | cache-to: type=gha,mode=max 56 | -------------------------------------------------------------------------------- /tests/unit_tests/assets/openai/list_models.json: -------------------------------------------------------------------------------- 1 | { 2 | "object": "list", 3 | "data": [ 4 | { 5 | "id": "gpt-4-0613", 6 | "object": "model", 7 | "created": 1686588896, 8 | "owned_by": "openai" 9 | }, 10 | { 11 | "id": "gpt-4", 12 | "object": "model", 13 | "created": 1687882411, 14 | "owned_by": "openai" 15 | }, 16 | { 17 | "id": "gpt-3.5-turbo", 18 | "object": "model", 19 | "created": 1677610602, 20 | "owned_by": "openai" 21 | }, 22 | { 23 | "id": "o4-mini-deep-research", 24 | "object": "model", 25 | "created": 1749685485, 26 | "owned_by": "system" 27 | }, 28 | { 29 | "id": "o3-deep-research", 30 | "object": "model", 31 | "created": 1749840121, 32 | "owned_by": "system" 33 | }, 34 | { 35 | "id": "davinci-002", 36 | "object": "model", 37 | "created": 1692634301, 38 | "owned_by": "system" 39 | }, 40 | { 41 | "id": "dall-e-3", 42 | "object": "model", 43 | "created": 1698785189, 44 | "owned_by": "system" 45 | }, 46 | { 47 | "id": "dall-e-2", 48 | "object": "model", 49 | "created": 1698798177, 50 | "owned_by": "system" 51 | }, 52 | { 53 | "id": "gpt-3.5-turbo-1106", 54 | "object": "model", 55 | "created": 1698959748, 56 | "owned_by": "system" 57 | } 58 | ] 59 | } -------------------------------------------------------------------------------- /alembic/versions/4a685a55c5cd_create_usage_tracker_table.py: -------------------------------------------------------------------------------- 1 | """create usage tracker table 2 | 3 | Revision ID: 4a685a55c5cd 4 | Revises: 9daf34d338f7 5 | Create Date: 2025-08-02 12:29:07.955645 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | from sqlalchemy.dialects.postgresql import UUID 11 | import uuid 12 | 13 | 14 | # revision identifiers, used by Alembic. 
15 | revision = '4a685a55c5cd' 16 | down_revision = '9daf34d338f7' 17 | branch_labels = None 18 | depends_on = None 19 | 20 | 21 | def upgrade() -> None: 22 | op.create_table( 23 | "usage_tracker", 24 | sa.Column("id", UUID(as_uuid=True), nullable=False, primary_key=True, default=uuid.uuid4), 25 | sa.Column("user_id", sa.Integer(), nullable=False), 26 | sa.Column("provider_key_id", sa.Integer(), nullable=False), 27 | sa.Column("forge_key_id", sa.Integer(), nullable=False), 28 | sa.Column("model", sa.String(), nullable=True), 29 | sa.Column("endpoint", sa.String(), nullable=True), 30 | sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), 31 | sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), 32 | sa.Column("input_tokens", sa.Integer(), nullable=True), 33 | sa.Column("output_tokens", sa.Integer(), nullable=True), 34 | sa.Column("cached_tokens", sa.Integer(), nullable=True), 35 | sa.Column("reasoning_tokens", sa.Integer(), nullable=True), 36 | 37 | sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), 38 | sa.ForeignKeyConstraint(["provider_key_id"], ["provider_keys.id"], ondelete="CASCADE"), 39 | sa.ForeignKeyConstraint(["forge_key_id"], ["forge_api_keys.id"], ondelete="CASCADE"), 40 | ) 41 | 42 | 43 | def downgrade() -> None: 44 | op.drop_table("usage_tracker") 45 | -------------------------------------------------------------------------------- /alembic/versions/39bcedfae4fe_add_model_default_pricing.py: -------------------------------------------------------------------------------- 1 | """add model default pricing 2 | 3 | Revision ID: 39bcedfae4fe 4 | Revises: b206e9a941e3 5 | Create Date: 2025-08-14 18:31:20.897283 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | 11 | 12 | # revision identifiers, used by Alembic. 
13 | revision = '39bcedfae4fe' 14 | down_revision = 'b206e9a941e3' 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade() -> None: 20 | op.add_column('fallback_pricing', sa.Column('model_name', sa.String(), nullable=True)) 21 | connection = op.get_bind() 22 | connection.execute(sa.text(""" 23 | insert into fallback_pricing (provider_name, model_name, input_token_price, output_token_price, cached_token_price, currency, created_at, updated_at, effective_date, end_date, fallback_type, description) 24 | select distinct on (model_name) 25 | provider_name as provider_name, 26 | model_name as model_name, 27 | input_token_price, 28 | output_token_price, 29 | cached_token_price, 30 | currency, 31 | created_at, 32 | updated_at, 33 | effective_date, 34 | end_date, 35 | 'model_default' as fallback_type, 36 | null as description 37 | from model_pricing 38 | order by model_name, 39 | case when provider_name = 'openai' then 1 40 | when provider_name = 'anthropic' then 2 41 | else 3 end 42 | """)) 43 | 44 | 45 | def downgrade() -> None: 46 | connection = op.get_bind() 47 | connection.execute(sa.text(""" 48 | delete from fallback_pricing where fallback_type = 'model_default' 49 | """)) 50 | op.drop_column('fallback_pricing', 'model_name') 51 | -------------------------------------------------------------------------------- /tests/unit_tests/assets/openai/responses_response_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "resp_0a379700213743ce0068db3cded47481a29d3552eea69e6939", 3 | "object": "response", 4 | "created_at": 1759198430, 5 | "status": "completed", 6 | "background": false, 7 | "billing": {"payer": "developer"}, 8 | "error": null, 9 | "incomplete_details": null, 10 | "instructions": null, 11 | "max_output_tokens": null, 12 | "max_tool_calls": null, 13 | "model": "gpt-4o-mini-2024-07-18", 14 | "output": [ 15 | { 16 | "id": "msg_0a379700213743ce0068db3cdf654481a29d5a27842a0e095f", 17 | "type": "message", 18 | "status": "completed", 19 | "content": [ 20 | { 21 | "type": "output_text", 22 | "annotations": [], 23 | "logprobs": [], 24 | "text": "Hello! I'm doing well, thank you. How about you?" 
25 | } 26 | ], 27 | "role": "assistant" 28 | } 29 | ], 30 | "parallel_tool_calls": true, 31 | "previous_response_id": null, 32 | "prompt_cache_key": null, 33 | "reasoning": {"effort": null, "summary": null}, 34 | "safety_identifier": null, 35 | "service_tier": "default", 36 | "store": true, 37 | "temperature": 1.0, 38 | "text": {"format": {"type": "text"}, "verbosity": "medium"}, 39 | "tool_choice": "auto", 40 | "tools": [], 41 | "top_logprobs": 0, 42 | "top_p": 1.0, 43 | "truncation": "disabled", 44 | "usage": { 45 | "input_tokens": 13, 46 | "input_tokens_details": {"cached_tokens": 0}, 47 | "output_tokens": 14, 48 | "output_tokens_details": {"reasoning_tokens": 0}, 49 | "total_tokens": 27 50 | }, 51 | "user": null, 52 | "metadata": {} 53 | } 54 | -------------------------------------------------------------------------------- /app/api/schemas/forge_api_key.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from pydantic import BaseModel, ConfigDict, Field 4 | 5 | 6 | class ForgeApiKeyBase(BaseModel): 7 | """Base schema for ForgeApiKey""" 8 | 9 | name: str | None = None 10 | 11 | 12 | class ForgeApiKeyCreate(ForgeApiKeyBase): 13 | """Schema for creating a ForgeApiKey""" 14 | 15 | allowed_provider_key_ids: list[int] | None = Field(default=None) 16 | 17 | 18 | class ForgeApiKeyResponse(ForgeApiKeyBase): 19 | """Schema for ForgeApiKey response""" 20 | 21 | id: int 22 | key: str 23 | is_active: bool 24 | created_at: datetime 25 | last_used_at: datetime | None = None 26 | allowed_provider_key_ids: list[int] = Field(default_factory=list) 27 | model_config = ConfigDict(from_attributes=True) 28 | 29 | 30 | class ForgeApiKeyMasked(ForgeApiKeyBase): 31 | """Schema for ForgeApiKey with masked key for display""" 32 | 33 | id: int 34 | key: str 35 | is_active: bool 36 | created_at: datetime 37 | last_used_at: datetime | None = None 38 | allowed_provider_key_ids: list[int] = Field(default_factory=list) 39 | 40 | @staticmethod 41 | def mask_api_key(key: str) -> str: 42 | """Mask the API key for display""" 43 | if not key: 44 | return "" 45 | if key.startswith("forge-"): 46 | prefix = "forge-" 47 | key_part = key[len(prefix) :] 48 | return f"{prefix}{'*' * (len(key_part) - 4)}{key_part[-4:]}" 49 | return f"{'*' * (len(key) - 4)}{key[-4:]}" 50 | 51 | model_config = ConfigDict(from_attributes=True) 52 | 53 | 54 | class ForgeApiKeyUpdate(BaseModel): 55 | """Schema for updating a ForgeApiKey""" 56 | 57 | name: str | None = None 58 | is_active: bool | None = None 59 | allowed_provider_key_ids: list[int] | None = None 60 | -------------------------------------------------------------------------------- /alembic/versions/08cc005a4bc8_create_forge_api_key_provider_scope_.py: -------------------------------------------------------------------------------- 1 | """create_forge_api_key_provider_scope_association_table 2 | 3 | Revision ID: 08cc005a4bc8 4 | Revises: 6a92c2663fa5 5 | Create Date: 2025-05-16 13:53:06.169215 6 | 7 | """ 8 | import sqlalchemy as sa 9 | 10 | from alembic import op 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = "08cc005a4bc8" 14 | down_revision = "6a92c2663fa5" 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade() -> None: 20 | # ### commands auto generated by Alembic - please adjust! 
### 21 | op.create_table( 22 | "forge_api_key_provider_scope_association", 23 | sa.Column("forge_api_key_id", sa.Integer(), nullable=False), 24 | sa.Column("provider_key_id", sa.Integer(), nullable=False), 25 | sa.ForeignKeyConstraint( 26 | ["forge_api_key_id"], 27 | ["forge_api_keys.id"], 28 | name=op.f( 29 | "fk_forge_api_key_provider_scope_association_forge_api_key_id_forge_api_keys" 30 | ), 31 | ondelete="CASCADE", 32 | ), 33 | sa.ForeignKeyConstraint( 34 | ["provider_key_id"], 35 | ["provider_keys.id"], 36 | name=op.f( 37 | "fk_forge_api_key_provider_scope_association_provider_key_id_provider_keys" 38 | ), 39 | ondelete="CASCADE", 40 | ), 41 | sa.PrimaryKeyConstraint( 42 | "forge_api_key_id", 43 | "provider_key_id", 44 | name=op.f("pk_forge_api_key_provider_scope_association"), 45 | ), 46 | ) 47 | # ### end Alembic commands ### 48 | 49 | 50 | def downgrade() -> None: 51 | # ### commands auto generated by Alembic - please adjust! ### 52 | op.drop_table("forge_api_key_provider_scope_association") 53 | # ### end Alembic commands ### 54 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.12-slim 2 | 3 | WORKDIR /app 4 | 5 | # Define build-time arguments 6 | ARG ARG_CLERK_JWT_PUBLIC_KEY 7 | ARG ARG_CLERK_API_KEY 8 | ARG ARG_CLERK_API_URL 9 | ARG ARG_DEBUG_LOGGING=false 10 | 11 | # Set runtime environment variables from build-time arguments 12 | # WARNING: These values will be baked into the image. For sensitive data, prefer runtime injection. 13 | ENV CLERK_JWT_PUBLIC_KEY=${ARG_CLERK_JWT_PUBLIC_KEY} 14 | ENV CLERK_API_KEY=${ARG_CLERK_API_KEY} 15 | # Example with a default value if ARG is not set for CLERK_API_URL 16 | ENV CLERK_API_URL=${ARG_CLERK_API_URL:-https://api.clerk.dev/v1} 17 | ENV FORGE_DEBUG_LOGGING=${ARG_DEBUG_LOGGING} 18 | 19 | # Database connection optimization environment variables 20 | # These settings optimize for PostgreSQL connection limits 21 | ENV DB_POOL_SIZE=3 22 | ENV DB_MAX_OVERFLOW=2 23 | ENV DB_POOL_TIMEOUT=30 24 | ENV DB_POOL_RECYCLE=1800 25 | ENV DB_POOL_PRE_PING=true 26 | 27 | # Reduced worker count to manage database connections 28 | # With 5 workers: max 60 connections (5 × 3 × 2 engines + 5 × 2 × 2 overflow = 50 connections) 29 | ENV WORKERS=5 30 | 31 | # Install system dependencies including PostgreSQL client and gosu for user privilege management 32 | RUN apt-get update && apt-get install -y \ 33 | postgresql-client \ 34 | && rm -rf /var/lib/apt/lists/* 35 | 36 | RUN mkdir -p /app/logs && \ 37 | chown -R nobody:nogroup /app/logs && \ 38 | chmod -R 777 /app/logs 39 | 40 | # Copy project files 41 | COPY . . 42 | 43 | # Install dependencies using pip 44 | RUN pip install -e . 
45 | 46 | # Switch to non-root user for security 47 | USER nobody 48 | 49 | # Expose port 50 | EXPOSE 8000 51 | 52 | # Use environment variable for workers count and optimize for database connections 53 | CMD ["sh", "-c", "gunicorn app.main:app -k uvicorn.workers.UvicornWorker --workers ${WORKERS:-5} --bind 0.0.0.0:8000 --timeout 120 --max-requests 1000 --max-requests-jitter 100"] 54 | -------------------------------------------------------------------------------- /tests/unit_tests/assets/google/chat_completion_streaming_response_1.json: -------------------------------------------------------------------------------- 1 | [ 2 | "[{\n", 3 | " \"candidates\": [\n", 4 | " {\n", 5 | " \"content\": {\n", 6 | " \"parts\": [\n", 7 | " {\n", 8 | " \"text\": \"I\"\n", 9 | " }\n", 10 | " ],\n", 11 | " \"role\": \"model\"\n", 12 | " }\n", 13 | " }\n", 14 | " ],\n", 15 | " \"usageMetadata\": {\n", 16 | " \"promptTokenCount\": 6,\n", 17 | " \"totalTokenCount\": 6,\n", 18 | " \"promptTokensDetails\": [\n", 19 | " {\n", 20 | " \"modality\": \"TEXT\",\n", 21 | " \"tokenCount\": 6\n", 22 | " }\n", 23 | " ]\n", 24 | " },\n", 25 | " \"modelVersion\": \"gemini-1.5-pro-002\",\n", 26 | " \"responseId\": \"ASVeaJ69O57ImNAPh7Sh4A0\"\n", 27 | "}\n", 28 | ",\r\n", 29 | "{\n", 30 | " \"candidates\": [\n", 31 | " {\n", 32 | " \"content\": {\n", 33 | " \"parts\": [\n", 34 | " {\n", 35 | " \"text\": \" am doing well, thank you for asking. How are you today?\\n\"\n", 36 | " }\n", 37 | " ],\n", 38 | " \"role\": \"model\"\n", 39 | " },\n", 40 | " \"finishReason\": \"STOP\"\n", 41 | " }\n", 42 | " ],\n", 43 | " \"usageMetadata\": {\n", 44 | " \"promptTokenCount\": 6,\n", 45 | " \"candidatesTokenCount\": 16,\n", 46 | " \"totalTokenCount\": 22,\n", 47 | " \"promptTokensDetails\": [\n", 48 | " {\n", 49 | " \"modality\": \"TEXT\",\n", 50 | " \"tokenCount\": 6\n", 51 | " }\n", 52 | " ],\n", 53 | " \"candidatesTokensDetails\": [\n", 54 | " {\n", 55 | " \"modality\": \"TEXT\",\n", 56 | " \"tokenCount\": 16\n", 57 | " }\n", 58 | " ]\n", 59 | " },\n", 60 | " \"modelVersion\": \"gemini-1.5-pro-002\",\n", 61 | " \"responseId\": \"ASVeaJ69O57ImNAPh7Sh4A0\"\n", 62 | "}\n", 63 | "]" 64 | ] -------------------------------------------------------------------------------- /app/core/security.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime, timedelta 3 | 4 | from cryptography.fernet import Fernet 5 | from dotenv import load_dotenv 6 | from jose import jwt 7 | from passlib.context import CryptContext 8 | 9 | load_dotenv() 10 | 11 | # Password hashing 12 | pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") 13 | 14 | # JWT settings 15 | SECRET_KEY = os.getenv("SECRET_KEY", "your_secret_key_here") 16 | ALGORITHM = os.getenv("ALGORITHM", "HS256") 17 | ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "60")) 18 | 19 | # Encryption for API keys 20 | ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", Fernet.generate_key().decode()) 21 | # Initialize with a direct approach to avoid indentation issues 22 | if isinstance(ENCRYPTION_KEY, str): 23 | fernet = Fernet(ENCRYPTION_KEY.encode()) 24 | else: 25 | fernet = Fernet(ENCRYPTION_KEY) 26 | 27 | 28 | def verify_password(plain_password, hashed_password): 29 | return pwd_context.verify(plain_password, hashed_password) 30 | 31 | 32 | def get_password_hash(password): 33 | return pwd_context.hash(password) 34 | 35 | 36 | def create_access_token(data: dict, expires_delta: timedelta | None = 
None): 37 | to_encode = data.copy() 38 | if expires_delta: 39 | expire = datetime.utcnow() + expires_delta 40 | else: 41 | expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) 42 | to_encode.update({"exp": expire}) 43 | encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) 44 | return encoded_jwt 45 | 46 | 47 | def encrypt_api_key(api_key: str) -> str: 48 | """Encrypt an API key""" 49 | return fernet.encrypt(api_key.encode()).decode() 50 | 51 | 52 | def decrypt_api_key(encrypted_api_key: str) -> str: 53 | """Decrypt an API key""" 54 | return fernet.decrypt(encrypted_api_key.encode()).decode() 55 | 56 | 57 | def generate_forge_api_key() -> str: 58 | """ 59 | Generate a unique Forge API key. 60 | """ 61 | import secrets 62 | 63 | return f"forge-{secrets.token_hex(18)}" 64 | -------------------------------------------------------------------------------- /app/models/forge_api_key.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Table 4 | from sqlalchemy.orm import relationship 5 | 6 | from app.models.base import Base 7 | 8 | # Association Table for ForgeApiKey and ProviderKey 9 | forge_api_key_provider_scope_association = Table( 10 | "forge_api_key_provider_scope_association", 11 | Base.metadata, 12 | Column( 13 | "forge_api_key_id", 14 | Integer, 15 | ForeignKey("forge_api_keys.id", ondelete="CASCADE"), 16 | primary_key=True, 17 | ), 18 | Column( 19 | "provider_key_id", 20 | Integer, 21 | ForeignKey("provider_keys.id", ondelete="CASCADE"), 22 | primary_key=True, 23 | ), 24 | ) 25 | 26 | 27 | class ForgeApiKey(Base): 28 | """Model for storing multiple Forge API keys per user""" 29 | 30 | __tablename__ = "forge_api_keys" 31 | 32 | id = Column(Integer, primary_key=True, index=True) 33 | key = Column(String, unique=True, index=True, nullable=False) 34 | name = Column(String, nullable=True) # Optional name/description for the key 35 | user_id = Column( 36 | Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False 37 | ) 38 | is_active = Column(Boolean, default=True) 39 | created_at = Column(DateTime, default=datetime.datetime.utcnow) 40 | last_used_at = Column(DateTime, nullable=True) 41 | deleted_at = Column(DateTime(timezone=True), nullable=True) 42 | 43 | # Relationship to user 44 | user = relationship("User", back_populates="api_keys") 45 | 46 | # Relationship to allowed ProviderKeys (scope) 47 | allowed_provider_keys = relationship( 48 | "ProviderKey", 49 | secondary=forge_api_key_provider_scope_association, 50 | back_populates="scoped_forge_api_keys", 51 | lazy="selectin", # Use selectin loading for efficiency when accessing this relationship 52 | ) 53 | 54 | # Optionally, we could have a relationship to enabled provider keys 55 | # enabled_provider_keys = relationship("EnabledProviderKey", back_populates="api_key") 56 | -------------------------------------------------------------------------------- /tools/diagnostics/enable_request_logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Enable request logging by setting LOG_LEVEL environment variable. 4 | This script modifies the .env file to adjust the logging level for the Forge server. 
5 | """ 6 | 7 | import os 8 | import sys 9 | from pathlib import Path 10 | 11 | # Add the project root to the Python path 12 | script_dir = Path(__file__).resolve().parent.parent.parent 13 | sys.path.insert(0, str(script_dir)) 14 | 15 | # Change to the project root directory 16 | os.chdir(script_dir) 17 | 18 | 19 | def enable_request_logging(): 20 | """Enable request logging by setting LOG_LEVEL=debug in .env file""" 21 | env_file = Path(".env") 22 | 23 | # Read existing .env file 24 | if env_file.exists(): 25 | with open(env_file) as f: 26 | lines = f.readlines() 27 | else: 28 | lines = [] 29 | 30 | # Check if LOG_LEVEL already exists 31 | log_level_exists = False 32 | for i, line in enumerate(lines): 33 | if line.strip().startswith("LOG_LEVEL="): 34 | lines[i] = "LOG_LEVEL=debug\n" 35 | log_level_exists = True 36 | break 37 | 38 | # Add LOG_LEVEL if it doesn't exist 39 | if not log_level_exists: 40 | lines.append("\n# Set logging level for server\nLOG_LEVEL=debug\n") 41 | 42 | # Write back to .env file 43 | with open(env_file, "w") as f: 44 | f.writelines(lines) 45 | 46 | print("✅ Request logging enabled in .env file") 47 | print("🔍 Log level set to 'debug' to show detailed request/response information") 48 | print("ℹ️ Restart your server for changes to take effect") 49 | print("\nTo see all requests in the server logs, restart with:") 50 | print(" python run.py") 51 | 52 | return True 53 | 54 | 55 | def main(): 56 | """Main entry point""" 57 | if enable_request_logging(): 58 | print("\n✨ Your server will now show detailed request logs") 59 | sys.exit(0) 60 | else: 61 | print("\n❌ Failed to enable request logging") 62 | sys.exit(1) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /alembic/versions/4a82fb8af123_create_forge_api_keys_table.py: -------------------------------------------------------------------------------- 1 | """create_forge_api_keys_table 2 | 3 | Revision ID: 4a82fb8af123 4 | Revises: b5d4363a9f62 5 | Create Date: 2023-05-15 12:00:00.000000 6 | 7 | """ 8 | import sqlalchemy as sa 9 | 10 | from alembic import op 11 | 12 | # revision identifiers, used by Alembic. 
13 | revision = "4a82fb8af123" 14 | down_revision = "b5d4363a9f62" # Update this to point to your latest migration 15 | branch_labels = None 16 | depends_on = None 17 | 18 | 19 | def upgrade(): 20 | # Create forge_api_keys table 21 | op.create_table( 22 | "forge_api_keys", 23 | sa.Column("id", sa.Integer(), nullable=False), 24 | sa.Column("key", sa.String(), nullable=False), 25 | sa.Column("name", sa.String(), nullable=True), 26 | sa.Column("user_id", sa.Integer(), nullable=False), 27 | sa.Column("is_active", sa.Boolean(), default=True), 28 | sa.Column("created_at", sa.DateTime(), default=sa.func.now()), 29 | sa.Column("last_used_at", sa.DateTime(), nullable=True), 30 | sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), 31 | sa.PrimaryKeyConstraint("id"), 32 | ) 33 | 34 | # Create indexes 35 | op.create_index( 36 | op.f("ix_forge_api_keys_id"), "forge_api_keys", ["id"], unique=False 37 | ) 38 | op.create_index( 39 | op.f("ix_forge_api_keys_key"), "forge_api_keys", ["key"], unique=True 40 | ) 41 | 42 | # Migrate existing keys 43 | # This will create a new entry in forge_api_keys for each user's existing forge_api_key 44 | op.execute( 45 | """ 46 | INSERT INTO forge_api_keys (key, name, user_id, is_active, created_at) 47 | SELECT forge_api_key, 'Legacy API Key', id, is_active, created_at 48 | FROM users 49 | WHERE forge_api_key IS NOT NULL 50 | """ 51 | ) 52 | 53 | 54 | def downgrade(): 55 | # Drop the table 56 | op.drop_index(op.f("ix_forge_api_keys_key"), table_name="forge_api_keys") 57 | op.drop_index(op.f("ix_forge_api_keys_id"), table_name="forge_api_keys") 58 | op.drop_table("forge_api_keys") 59 | -------------------------------------------------------------------------------- /tools/diagnostics/clear_cache.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Utility script to clear all cache entries. 4 | This is useful for testing and diagnosing cache-related issues. 
5 | """ 6 | 7 | import os 8 | import sys 9 | from pathlib import Path 10 | 11 | from app.core.cache import provider_service_cache, user_cache 12 | 13 | # Add the project root to the Python path 14 | script_dir = Path(__file__).resolve().parent.parent.parent 15 | sys.path.insert(0, str(script_dir)) 16 | 17 | # Change to the project root directory 18 | os.chdir(script_dir) 19 | 20 | 21 | def clear_caches(): 22 | """Clear all cache entries and print statistics""" 23 | print("🔄 Clearing caches...") 24 | 25 | # Get cache stats before clearing 26 | user_stats_before = user_cache.stats() 27 | provider_stats_before = provider_service_cache.stats() 28 | 29 | # Clear caches 30 | user_cache.clear() 31 | provider_service_cache.clear() 32 | 33 | # Get cache stats after clearing 34 | user_stats_after = user_cache.stats() 35 | provider_stats_after = provider_service_cache.stats() 36 | 37 | # Print results 38 | print("\n✅ Cache clearing complete!") 39 | print("\n📊 User Cache Statistics:") 40 | print( 41 | f" Before: {user_stats_before['entries']} entries, {user_stats_before['hits']} hits, {user_stats_before['misses']} misses" 42 | ) 43 | print(f" After: {user_stats_after['entries']} entries") 44 | 45 | print("\n📊 Provider Service Cache Statistics:") 46 | print( 47 | f" Before: {provider_stats_before['entries']} entries, {provider_stats_before['hits']} hits, {provider_stats_before['misses']} misses" 48 | ) 49 | print(f" After: {provider_stats_after['entries']} entries") 50 | 51 | print("\n🔍 Cache hit rates:") 52 | print(f" User Cache: {user_stats_before['hit_rate']:.2%}") 53 | print(f" Provider Cache: {provider_stats_before['hit_rate']:.2%}") 54 | 55 | return True 56 | 57 | 58 | def main(): 59 | """Main entry point""" 60 | if clear_caches(): 61 | print("\n✨ All caches have been cleared successfully.") 62 | sys.exit(0) 63 | else: 64 | print("\n❌ Failed to clear caches.") 65 | sys.exit(1) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /tests/unit_tests/assets/anthropic/chat_completion_streaming_response_1.json: -------------------------------------------------------------------------------- 1 | [ 2 | "event: message_start\n", 3 | "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_011R37jSCigty4xP95pkskoU\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-20250514\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":13,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":0,\"output_tokens\":1,\"service_tier\":\"standard\"}} }\n", 4 | "\n", 5 | "event: content_block_start\n", 6 | "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"} }\n", 7 | "\n", 8 | "event: ping\n", 9 | "data: {\"type\": \"ping\"}\n", 10 | "\n", 11 | "event: content_block_delta\n", 12 | "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"} }\n", 13 | "\n", 14 | "event: content_block_delta\n", 15 | "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"! I'm doing well, thank you\"} }\n", 16 | "\n", 17 | "event: content_block_delta\n", 18 | "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" for asking. 
I'm here and ready to help with\"} }\n", 19 | "\n", 20 | "event: content_block_delta\n", 21 | "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" whatever you'd like to discuss\"} }\n", 22 | "\n", 23 | "event: content_block_delta\n", 24 | "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" or work on. How\"} }\n", 25 | "\n", 26 | "event: content_block_delta\n", 27 | "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" are you doing today?\"} }\n", 28 | "\n", 29 | "event: content_block_stop\n", 30 | "data: {\"type\":\"content_block_stop\",\"index\":0 }\n", 31 | "\n", 32 | "event: message_delta\n", 33 | "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":39} }\n", 34 | "\n", 35 | "event: message_stop\n", 36 | "data: {\"type\":\"message_stop\" }\n", 37 | "\n" 38 | ] -------------------------------------------------------------------------------- /tools/diagnostics/check_dotenv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Script to check how dotenv loads values from the .env file.""" 3 | 4 | import os 5 | import sys 6 | from pathlib import Path 7 | 8 | from dotenv import find_dotenv, load_dotenv 9 | 10 | 11 | def check_env_file(file_path): 12 | """Check if a file exists and print its contents.""" 13 | path = Path(file_path) 14 | if path.exists(): 15 | print(f"File exists: {path.absolute()}") 16 | print("File contents:") 17 | with open(path) as f: 18 | for line in f: 19 | if line.strip() and "=" in line: 20 | key, value = line.strip().split("=", 1) 21 | if key == "FORGE_API_KEY": 22 | print(f"{key}={value[:8]}...") 23 | else: 24 | print(line.strip()) 25 | else: 26 | print(f"File does not exist: {path.absolute()}") 27 | 28 | 29 | def main(): 30 | """Main function to check environment loading.""" 31 | # Find the project root directory 32 | script_dir = os.path.dirname( 33 | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 34 | ) 35 | os.chdir(script_dir) # Change to project root directory 36 | 37 | print("Python version:", sys.version) 38 | print("Current working directory:", os.getcwd()) 39 | 40 | # Check local .env file 41 | print("\nChecking .env file in current directory:") 42 | check_env_file(".env") 43 | 44 | # Check find_dotenv 45 | print("\nUsing find_dotenv to locate .env file:") 46 | found_dotenv = find_dotenv() 47 | print(f"Found .env at: {found_dotenv}") 48 | if found_dotenv: 49 | check_env_file(found_dotenv) 50 | 51 | # Try loading with load_dotenv 52 | print("\nLoading environment variables with load_dotenv:") 53 | load_dotenv(verbose=True) 54 | 55 | # Check if variable was loaded 56 | api_key = os.getenv("FORGE_API_KEY", "") 57 | if api_key: 58 | print(f"FORGE_API_KEY loaded: {api_key[:8]}...") 59 | else: 60 | print("FORGE_API_KEY not loaded") 61 | 62 | # Check all environment variables (not just from .env) 63 | print("\nAll environment variables containing 'FORGE':") 64 | for key, value in os.environ.items(): 65 | if "FORGE" in key: 66 | print( 67 | f"{key}={value[:8]}..." 
if key == "FORGE_API_KEY" else f"{key}={value}" 68 | ) 69 | 70 | 71 | if __name__ == "__main__": 72 | main() 73 | -------------------------------------------------------------------------------- /app/api/routes/admin.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, Depends, HTTPException 2 | from sqlalchemy import select 3 | from decimal import Decimal 4 | from pydantic import BaseModel 5 | import uuid 6 | 7 | from app.api.dependencies import get_current_active_admin_user_from_clerk 8 | from app.core.database import get_async_db 9 | from sqlalchemy.ext.asyncio import AsyncSession 10 | from app.models.user import User 11 | from app.models.stripe import StripePayment 12 | from app.core.logger import get_logger 13 | from app.api.schemas.admin import AddBalanceRequest 14 | from app.services.wallet_service import WalletService 15 | 16 | logger = get_logger(name="admin") 17 | router = APIRouter() 18 | 19 | class AddBalanceResponse(BaseModel): 20 | balance: Decimal 21 | blocked: bool 22 | 23 | 24 | @router.post("/add-balance") 25 | async def add_balance( 26 | add_balance_request: AddBalanceRequest, 27 | current_user: User = Depends(get_current_active_admin_user_from_clerk), 28 | db: AsyncSession = Depends(get_async_db), 29 | ): 30 | """Add balance to a user""" 31 | user_id = add_balance_request.user_id 32 | email = add_balance_request.email 33 | amount = add_balance_request.amount 34 | 35 | result = await db.execute( 36 | select(User) 37 | .where( 38 | user_id is None or User.id == user_id, 39 | email is None or User.email == email, 40 | ) 41 | ) 42 | user = result.scalar_one_or_none() 43 | if not user: 44 | raise HTTPException(status_code=404, detail="User not found") 45 | 46 | amount_decimal = Decimal(amount / 100.0) 47 | result = await WalletService.adjust(db, user.id, amount_decimal, f"Admin {current_user.id} added balance for user {user.id}") 48 | if not result.get("success"): 49 | raise HTTPException(status_code=400, detail=f"Failed to add balance for user {user.id}: {result.get('reason')}") 50 | 51 | # add the amount to the user's stripe payment 52 | stripe_payment = StripePayment( 53 | id=f"tb_admin_{uuid.uuid4().hex}", 54 | user_id=user.id, 55 | amount=amount, 56 | currency="USD", 57 | status="completed", 58 | raw_data={"reason": f"Admin {current_user.id} added balance for user {user.id}"}, 59 | ) 60 | db.add(stripe_payment) 61 | await db.commit() 62 | logger.info(f"Added balance {amount_decimal} for user {user.id} by admin {current_user.id}") 63 | 64 | return AddBalanceResponse(balance=result.get("balance"), blocked=result.get("blocked")) 65 | -------------------------------------------------------------------------------- /alembic/versions/ca1ac51334ec_create_usage_stats_table.py: -------------------------------------------------------------------------------- 1 | """Create usage_stats table 2 | 3 | Revision ID: ca1ac51334ec 4 | Revises: initial_migration 5 | Create Date: 2024-08-05 21:50:30.028279 6 | 7 | """ 8 | 9 | import sqlalchemy as sa 10 | 11 | from alembic import op 12 | 13 | # revision identifiers, used by Alembic. 14 | revision = "ca1ac51334ec" 15 | down_revision = "initial_migration" # Set the correct previous migration 16 | branch_labels = None 17 | depends_on = None 18 | 19 | 20 | def upgrade() -> None: 21 | # ### commands auto generated by Alembic - please adjust! 
### 22 | op.create_table( 23 | "usage_stats", 24 | sa.Column("id", sa.Integer(), nullable=False), 25 | sa.Column("user_id", sa.Integer(), nullable=True), 26 | sa.Column("provider_name", sa.String(), nullable=True), 27 | sa.Column("model", sa.String(), nullable=True), 28 | sa.Column("request_count", sa.Integer(), nullable=True), 29 | sa.Column("success_count", sa.Integer(), nullable=True), 30 | sa.Column("prompt_tokens", sa.Integer(), nullable=True), 31 | sa.Column("completion_tokens", sa.Integer(), nullable=True), 32 | sa.Column("error_count", sa.Integer(), nullable=True), 33 | sa.Column("timestamp", sa.DateTime(), nullable=True), 34 | sa.ForeignKeyConstraint( 35 | ["user_id"], ["users.id"], name=op.f("fk_usage_stats_user_id_users") 36 | ), 37 | sa.PrimaryKeyConstraint("id", name=op.f("pk_usage_stats")), 38 | ) 39 | op.create_index(op.f("ix_usage_stats_id"), "usage_stats", ["id"], unique=False) 40 | op.create_index( 41 | op.f("ix_usage_stats_user_id"), "usage_stats", ["user_id"], unique=False 42 | ) 43 | op.create_index( 44 | op.f("ix_usage_stats_provider_name"), 45 | "usage_stats", 46 | ["provider_name"], 47 | unique=False, 48 | ) 49 | op.create_index( 50 | op.f("ix_usage_stats_model"), "usage_stats", ["model"], unique=False 51 | ) 52 | 53 | # ### end Alembic commands ### 54 | 55 | 56 | def downgrade() -> None: 57 | # ### commands auto generated by Alembic - please adjust! ### 58 | op.drop_index(op.f("ix_usage_stats_model"), table_name="usage_stats") 59 | op.drop_index(op.f("ix_usage_stats_provider_name"), table_name="usage_stats") 60 | op.drop_index(op.f("ix_usage_stats_user_id"), table_name="usage_stats") 61 | op.drop_index(op.f("ix_usage_stats_id"), table_name="usage_stats") 62 | op.drop_table("usage_stats") 63 | # ### end Alembic commands ### 64 | -------------------------------------------------------------------------------- /tests/mock_testing/run_mock_tests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Run all mock tests from a single script. 
4 | """ 5 | 6 | import asyncio 7 | import os 8 | import subprocess 9 | import sys 10 | 11 | # Add the parent directory to the path 12 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 13 | 14 | 15 | def print_header(text): 16 | """Print a header with decoration""" 17 | print("\n" + "=" * 80) 18 | print(f" {text} ".center(80, "=")) 19 | print("=" * 80) 20 | 21 | 22 | async def run_mock_client_tests(): 23 | """Run the mock client tests""" 24 | print_header("Running Mock Client Tests") 25 | 26 | # Get the path to the test_mock_client.py script 27 | script_path = os.path.join(os.path.dirname(__file__), "test_mock_client.py") 28 | 29 | # Run the script as a subprocess 30 | result = subprocess.run( 31 | ["python", script_path], capture_output=True, text=True, check=False 32 | ) 33 | 34 | # Print the output 35 | if result.stdout: 36 | print(result.stdout) 37 | if result.stderr: 38 | print(result.stderr, file=sys.stderr) 39 | 40 | return result.returncode == 0 41 | 42 | 43 | async def run_example_tests(): 44 | """Run the example tests""" 45 | print_header("Running Example Tests") 46 | 47 | # Get the path to the test_with_mocks.py script 48 | script_path = os.path.join( 49 | os.path.dirname(__file__), "examples", "test_with_mocks.py" 50 | ) 51 | 52 | # Run the script as a subprocess with unittest discover 53 | result = subprocess.run( 54 | ["python", "-m", "unittest", script_path], 55 | capture_output=True, 56 | text=True, 57 | check=False, 58 | ) 59 | 60 | # Print the output 61 | if result.stdout: 62 | print(result.stdout) 63 | if result.stderr: 64 | print(result.stderr, file=sys.stderr) 65 | 66 | return result.returncode == 0 67 | 68 | 69 | async def main(): 70 | """Run all tests""" 71 | client_tests_ok = await run_mock_client_tests() 72 | example_tests_ok = await run_example_tests() 73 | 74 | print_header("Summary") 75 | print(f"Mock Client Tests: {'PASSED' if client_tests_ok else 'FAILED'}") 76 | print(f"Example Tests: {'PASSED' if example_tests_ok else 'FAILED'}") 77 | 78 | if client_tests_ok and example_tests_ok: 79 | print("\n✅ All mock tests passed!") 80 | return 0 81 | else: 82 | print("\n❌ Some tests failed.") 83 | return 1 84 | 85 | 86 | if __name__ == "__main__": 87 | exit_code = asyncio.run(main()) 88 | sys.exit(exit_code) 89 | -------------------------------------------------------------------------------- /alembic/versions/0ce4eeae965f_drop_usage_stats_table.py: -------------------------------------------------------------------------------- 1 | """Drop usage_stats table 2 | 3 | Revision ID: 0ce4eeae965f 4 | Revises: b38aad374524 5 | Create Date: 2025-04-08 21:34:42.905629 6 | 7 | """ 8 | 9 | import sqlalchemy as sa 10 | 11 | from alembic import op 12 | 13 | # revision identifiers, used by Alembic. 14 | revision = "0ce4eeae965f" 15 | down_revision = "b38aad374524" 16 | branch_labels = None 17 | depends_on = None 18 | 19 | 20 | def upgrade() -> None: 21 | # ### commands auto generated by Alembic - please adjust! ### 22 | with op.batch_alter_table("usage_stats", schema=None) as batch_op: 23 | batch_op.drop_index("ix_usage_stats_id") 24 | batch_op.drop_index("ix_usage_stats_model") 25 | batch_op.drop_index("ix_usage_stats_provider_name") 26 | 27 | op.drop_table("usage_stats") 28 | # ### end Alembic commands ### 29 | 30 | 31 | def downgrade() -> None: 32 | # ### commands auto generated by Alembic - please adjust! 
### 33 | op.create_table( 34 | "usage_stats", 35 | sa.Column("id", sa.INTEGER(), nullable=False), 36 | sa.Column("user_id", sa.INTEGER(), nullable=False), 37 | sa.Column("provider_name", sa.VARCHAR(), nullable=False), 38 | sa.Column("model", sa.VARCHAR(), nullable=False), 39 | sa.Column("prompt_tokens", sa.INTEGER(), nullable=True), 40 | sa.Column("completion_tokens", sa.INTEGER(), nullable=True), 41 | sa.Column("total_tokens", sa.INTEGER(), nullable=True), 42 | sa.Column("request_count", sa.INTEGER(), nullable=True), 43 | sa.Column("success_count", sa.INTEGER(), nullable=True), 44 | sa.Column("error_count", sa.INTEGER(), nullable=True), 45 | sa.Column("timestamp", sa.DATETIME(), nullable=True), 46 | sa.Column("endpoint", sa.VARCHAR(), nullable=False), 47 | sa.Column("input_tokens", sa.INTEGER(), nullable=True), 48 | sa.Column("output_tokens", sa.INTEGER(), nullable=True), 49 | sa.Column("requests_count", sa.INTEGER(), nullable=True), 50 | sa.Column("cost", sa.FLOAT(), nullable=True), 51 | sa.Column("created_at", sa.DATETIME(), nullable=True), 52 | sa.ForeignKeyConstraint( 53 | ["user_id"], 54 | ["users.id"], 55 | ), 56 | sa.PrimaryKeyConstraint("id"), 57 | ) 58 | with op.batch_alter_table("usage_stats", schema=None) as batch_op: 59 | batch_op.create_index( 60 | "ix_usage_stats_provider_name", ["provider_name"], unique=False 61 | ) 62 | batch_op.create_index("ix_usage_stats_model", ["model"], unique=False) 63 | batch_op.create_index("ix_usage_stats_id", ["id"], unique=False) 64 | 65 | # ### end Alembic commands ### 66 | -------------------------------------------------------------------------------- /app/api/schemas/statistic.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, field_validator 2 | from datetime import datetime 3 | import re 4 | import decimal 5 | 6 | from app.api.schemas.forge_api_key import ForgeApiKeyMasked 7 | 8 | def mask_forge_name_or_key(v: str) -> str: 9 | # If the forge key is a valid forge key, mask it 10 | if re.match(r"forge-\w{18}", v): 11 | return ForgeApiKeyMasked.mask_api_key(v) 12 | # Otherwise, return the original value (user customized name) 13 | return v 14 | 15 | class UsageRealtimeItem(BaseModel): 16 | timestamp: datetime | str 17 | forge_key: str 18 | provider_name: str 19 | model_name: str 20 | tokens: int 21 | input_tokens: int 22 | output_tokens: int 23 | cached_tokens: int 24 | duration: float 25 | cost: decimal.Decimal 26 | billable: bool 27 | 28 | @field_validator('forge_key') 29 | @classmethod 30 | def mask_forge_key(cls, v: str) -> str: 31 | return mask_forge_name_or_key(v) 32 | 33 | @field_validator('timestamp') 34 | @classmethod 35 | def convert_timestamp_to_iso(cls, v: datetime | str) -> str: 36 | if isinstance(v, str): 37 | return v 38 | return v.isoformat() 39 | 40 | class UsageRealtimeResponse(BaseModel): 41 | total: int 42 | items: list[UsageRealtimeItem] 43 | page_size: int 44 | page_index: int 45 | 46 | class UsageSummaryBreakdown(BaseModel): 47 | forge_key: str 48 | tokens: int 49 | cost: decimal.Decimal 50 | charged_cost: decimal.Decimal 51 | input_tokens: int 52 | output_tokens: int 53 | cached_tokens: int 54 | 55 | @field_validator('forge_key') 56 | @classmethod 57 | def mask_forge_key(cls, v: str) -> str: 58 | return mask_forge_name_or_key(v) 59 | 60 | 61 | class UsageSummaryResponse(BaseModel): 62 | time_point: datetime | str 63 | breakdown: list[UsageSummaryBreakdown] 64 | total_tokens: int 65 | total_cost: decimal.Decimal 66 | total_charged_cost: 
decimal.Decimal 67 | total_input_tokens: int 68 | total_output_tokens: int 69 | total_cached_tokens: int 70 | 71 | @field_validator('time_point') 72 | @classmethod 73 | def convert_timestamp_to_iso(cls, v: datetime | str) -> str: 74 | if isinstance(v, str): 75 | return v 76 | return v.isoformat() 77 | 78 | 79 | class ForgeKeysUsageSummaryResponse(BaseModel): 80 | forge_key: str 81 | tokens: int 82 | cost: decimal.Decimal 83 | charged_cost: decimal.Decimal 84 | input_tokens: int 85 | output_tokens: int 86 | cached_tokens: int 87 | 88 | @field_validator('forge_key') 89 | @classmethod 90 | def mask_forge_key(cls, v: str) -> str: 91 | return mask_forge_name_or_key(v) -------------------------------------------------------------------------------- /alembic/versions/initial_migration.py: -------------------------------------------------------------------------------- 1 | """initial migration 2 | 3 | Revision ID: initial_migration 4 | Revises: 5 | Create Date: 2023-05-01 00:00:00.000000 6 | 7 | """ 8 | 9 | import sqlalchemy as sa 10 | 11 | from alembic import op 12 | 13 | # revision identifiers, used by Alembic. 14 | revision = "initial_migration" 15 | down_revision = None 16 | branch_labels = None 17 | depends_on = None 18 | 19 | 20 | def upgrade() -> None: 21 | # Create users table 22 | op.create_table( 23 | "users", 24 | sa.Column("id", sa.Integer(), nullable=False), 25 | sa.Column("email", sa.String(), nullable=True), 26 | sa.Column("username", sa.String(), nullable=True), 27 | sa.Column("hashed_password", sa.String(), nullable=True), 28 | sa.Column("is_active", sa.Boolean(), nullable=True), 29 | sa.Column("forge_api_key", sa.String(), nullable=True), 30 | sa.Column("created_at", sa.DateTime(), nullable=True), 31 | sa.Column("updated_at", sa.DateTime(), nullable=True), 32 | sa.PrimaryKeyConstraint("id"), 33 | ) 34 | op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True) 35 | op.create_index( 36 | op.f("ix_users_forge_api_key"), "users", ["forge_api_key"], unique=True 37 | ) 38 | op.create_index(op.f("ix_users_id"), "users", ["id"], unique=False) 39 | op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True) 40 | 41 | # Create provider_keys table 42 | op.create_table( 43 | "provider_keys", 44 | sa.Column("id", sa.Integer(), nullable=False), 45 | sa.Column("provider_name", sa.String(), nullable=True), 46 | sa.Column("encrypted_api_key", sa.String(), nullable=True), 47 | sa.Column("user_id", sa.Integer(), nullable=True), 48 | sa.Column("base_url", sa.String(), nullable=True), 49 | sa.Column("model_mapping", sa.String(), nullable=True), 50 | sa.Column("created_at", sa.DateTime(), nullable=True), 51 | sa.Column("updated_at", sa.DateTime(), nullable=True), 52 | sa.ForeignKeyConstraint( 53 | ["user_id"], 54 | ["users.id"], 55 | ), 56 | sa.PrimaryKeyConstraint("id"), 57 | ) 58 | op.create_index(op.f("ix_provider_keys_id"), "provider_keys", ["id"], unique=False) 59 | op.create_index( 60 | op.f("ix_provider_keys_provider_name"), 61 | "provider_keys", 62 | ["provider_name"], 63 | unique=False, 64 | ) 65 | 66 | 67 | def downgrade() -> None: 68 | op.drop_index(op.f("ix_provider_keys_provider_name"), table_name="provider_keys") 69 | op.drop_index(op.f("ix_provider_keys_id"), table_name="provider_keys") 70 | op.drop_table("provider_keys") 71 | op.drop_index(op.f("ix_users_username"), table_name="users") 72 | op.drop_index(op.f("ix_users_id"), table_name="users") 73 | op.drop_index(op.f("ix_users_forge_api_key"), table_name="users") 74 | 
op.drop_index(op.f("ix_users_email"), table_name="users") 75 | op.drop_table("users") 76 | -------------------------------------------------------------------------------- /tests/unit_tests/test_security.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from unittest import IsolatedAsyncioTestCase as TestCase 4 | 5 | # Add the parent directory to the path so Python can find the app module 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | 8 | # Patch bcrypt version detection to avoid warnings 9 | import bcrypt 10 | 11 | if not hasattr(bcrypt, "__about__"): 12 | import types 13 | 14 | bcrypt.__about__ = types.ModuleType("__about__") 15 | bcrypt.__about__.__version__ = ( 16 | bcrypt.__version__ if hasattr(bcrypt, "__version__") else "3.2.0" 17 | ) 18 | 19 | from jose import jwt 20 | 21 | from app.core.security import ( 22 | ALGORITHM, 23 | SECRET_KEY, 24 | create_access_token, 25 | decrypt_api_key, 26 | encrypt_api_key, 27 | generate_forge_api_key, 28 | get_password_hash, 29 | verify_password, 30 | ) 31 | 32 | 33 | class TestSecurity(TestCase): 34 | """Test the security utilities""" 35 | 36 | async def test_password_hashing(self): 37 | """Test password hashing and verification""" 38 | password = "test_password123" 39 | hashed = get_password_hash(password) 40 | 41 | # Verify the hash is different from the original password 42 | self.assertNotEqual(password, hashed) 43 | 44 | # Verify the password against the hash 45 | self.assertTrue(verify_password(password, hashed)) 46 | 47 | # Verify wrong password fails 48 | self.assertFalse(verify_password("wrong_password", hashed)) 49 | 50 | async def test_jwt_token(self): 51 | """Test JWT token creation and verification""" 52 | data = {"sub": "testuser"} 53 | token = create_access_token(data) 54 | 55 | # Decode and verify the token 56 | payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) 57 | 58 | # Check that original data is preserved 59 | self.assertEqual(payload["sub"], "testuser") 60 | 61 | # Check that expiration is added 62 | self.assertIn("exp", payload) 63 | 64 | async def test_api_key_encryption(self): 65 | """Test API key encryption and decryption""" 66 | original_key = "sk-123456789abcdef" 67 | 68 | # Encrypt the key 69 | encrypted = encrypt_api_key(original_key) 70 | 71 | # Verify encrypted key is different 72 | self.assertNotEqual(original_key, encrypted) 73 | 74 | # Decrypt and verify 75 | decrypted = decrypt_api_key(encrypted) 76 | self.assertEqual(original_key, decrypted) 77 | 78 | async def test_forge_api_key_generation(self): 79 | """Test Forge API key generation""" 80 | key1 = generate_forge_api_key() 81 | key2 = generate_forge_api_key() 82 | 83 | # Check the format 84 | self.assertTrue(key1.startswith("forge-")) 85 | 86 | # Check uniqueness 87 | self.assertNotEqual(key1, key2) 88 | 89 | # Check length 90 | self.assertEqual( 91 | len(key1), 42 92 | ) # "forge-" (6) + checksum (4) + base_key (32 hex chars) 93 | -------------------------------------------------------------------------------- /alembic/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from logging.config import fileConfig 3 | 4 | from dotenv import load_dotenv 5 | from sqlalchemy import engine_from_config, pool 6 | 7 | from alembic import context 8 | from app.models.api_request_log import ApiRequestLog # noqa 9 | from app.models.base import Base 10 | from app.models.provider_key import ProviderKey # 
noqa 11 | 12 | # Explicitly import models for autogenerate to find them 13 | from app.models.user import User # noqa 14 | 15 | # Load environment variables from .env file 16 | load_dotenv() 17 | 18 | # this is the Alembic Config object, which provides 19 | # access to the values within the .ini file in use. 20 | config = context.config 21 | 22 | # Override the SQLAlchemy URL with environment variable 23 | database_url = os.getenv("DATABASE_URL") 24 | if not database_url: 25 | raise ValueError("DATABASE_URL environment variable is not set") 26 | config.set_main_option("sqlalchemy.url", database_url) 27 | 28 | # Interpret the config file for Python logging. 29 | # This line sets up loggers basically. 30 | if config.config_file_name is not None: 31 | fileConfig(config.config_file_name) 32 | 33 | # add your model's MetaData object here 34 | # for 'autogenerate' support 35 | # from myapp import mymodel 36 | # target_metadata = mymodel.Base.metadata 37 | target_metadata = Base.metadata 38 | 39 | # other values from the config, defined by the needs of env.py, 40 | # can be acquired: 41 | # my_important_option = config.get_main_option("my_important_option") 42 | # ... etc. 43 | 44 | 45 | def run_migrations_offline() -> None: 46 | """Run migrations in 'offline' mode. 47 | 48 | This configures the context with just a URL 49 | and not an Engine, though an Engine is acceptable 50 | here as well. By skipping the Engine creation 51 | we don't even need a DBAPI to be available. 52 | 53 | Calls to context.execute() here emit the given string to the 54 | script output. 55 | 56 | """ 57 | url = config.get_main_option("sqlalchemy.url") 58 | context.configure( 59 | url=url, 60 | target_metadata=target_metadata, 61 | literal_binds=True, 62 | dialect_opts={"paramstyle": "named"}, 63 | render_as_batch=True, 64 | ) 65 | 66 | with context.begin_transaction(): 67 | context.run_migrations() 68 | 69 | 70 | def run_migrations_online() -> None: 71 | """Run migrations in 'online' mode. 72 | 73 | In this scenario we need to create an Engine 74 | and associate a connection with the context. 
75 | 76 | """ 77 | connectable = engine_from_config( 78 | config.get_section(config.config_ini_section, {}), 79 | prefix="sqlalchemy.", 80 | poolclass=pool.NullPool, 81 | ) 82 | 83 | with connectable.connect() as connection: 84 | context.configure( 85 | connection=connection, 86 | target_metadata=target_metadata, 87 | render_as_batch=True, 88 | ) 89 | 90 | with context.begin_transaction(): 91 | context.run_migrations() 92 | 93 | 94 | if context.is_offline_mode(): 95 | run_migrations_offline() 96 | else: 97 | run_migrations_online() 98 | -------------------------------------------------------------------------------- /app/models/pricing.py: -------------------------------------------------------------------------------- 1 | # app/models/pricing.py 2 | import datetime 3 | from datetime import UTC 4 | from sqlalchemy import Column, DateTime, String, DECIMAL, Index 5 | 6 | from .base import BaseModel 7 | 8 | 9 | class ModelPricing(BaseModel): 10 | """ 11 | Store pricing information for specific models with temporal support 12 | """ 13 | __tablename__ = "model_pricing" 14 | 15 | provider_name = Column(String, nullable=False, index=True) 16 | model_name = Column(String, nullable=False, index=True) 17 | 18 | # Temporal fields for price changes over time 19 | effective_date = Column(DateTime(timezone=True), nullable=False, default=datetime.datetime.now(UTC)) 20 | end_date = Column(DateTime(timezone=True), nullable=True) # NULL means currently active 21 | 22 | # Pricing per 1K tokens (using DECIMAL for precision) 23 | input_token_price = Column(DECIMAL(12, 8), nullable=False) # Price per 1K input tokens 24 | output_token_price = Column(DECIMAL(12, 8), nullable=False) # Price per 1K output tokens 25 | cached_token_price = Column(DECIMAL(12, 8), nullable=False, default=0) # Price per 1K cached tokens 26 | 27 | # Metadata 28 | currency = Column(String(3), nullable=False, default='USD') 29 | 30 | # Indexes for efficient querying 31 | __table_args__ = ( 32 | # Index for finding active pricing for a model 33 | Index('ix_model_pricing_active', 'provider_name', 'model_name', 'effective_date', 'end_date'), 34 | # Index for temporal queries 35 | Index('ix_model_pricing_temporal', 'effective_date', 'end_date'), 36 | # Unique constraint for overlapping periods (business rule enforcement) 37 | Index('ix_model_pricing_unique_period', 'provider_name', 'model_name', 'effective_date', unique=True), 38 | ) 39 | 40 | 41 | class FallbackPricing(BaseModel): 42 | """ 43 | Store fallback pricing for providers and global defaults 44 | """ 45 | __tablename__ = "fallback_pricing" 46 | 47 | provider_name = Column(String, nullable=True, index=True) # NULL for global fallback 48 | model_name = Column(String, nullable=True, index=True) # NULL for global fallback 49 | fallback_type = Column(String(20), nullable=False, index=True) # 'model_default', 'provider_default', 'global_default' 50 | 51 | # Temporal fields 52 | effective_date = Column(DateTime(timezone=True), nullable=False, default=datetime.datetime.now(UTC)) 53 | end_date = Column(DateTime(timezone=True), nullable=True) 54 | 55 | # Pricing per 1K tokens 56 | input_token_price = Column(DECIMAL(12, 8), nullable=False) 57 | output_token_price = Column(DECIMAL(12, 8), nullable=False) 58 | cached_token_price = Column(DECIMAL(12, 8), nullable=False, default=0) 59 | 60 | # Metadata 61 | currency = Column(String(3), nullable=False, default='USD') 62 | description = Column(String(255), nullable=True) # Optional description 63 | 64 | # Indexes 65 | __table_args__ = ( 66 | 
Index('ix_fallback_pricing_active', 'provider_name', 'fallback_type', 'effective_date', 'end_date'), 67 | Index('ix_fallback_pricing_type', 'fallback_type', 'effective_date'), 68 | ) 69 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # Forge Environment Variables 2 | # 3 | # This file provides a template for the environment variables needed to run the Forge application. 4 | # Copy this file to .env and fill in the appropriate values for your development environment. 5 | # Do not commit the .env file to your version control system. 6 | 7 | # ----------------------------------------------------------------------------- 8 | # DATABASE SETTINGS 9 | # ----------------------------------------------------------------------------- 10 | # PostgreSQL connection string. 11 | # Example for local development with Docker: postgresql://forge:forge@localhost:5432/forge 12 | DATABASE_URL=postgresql://user:password@host:port/dbname 13 | 14 | # ----------------------------------------------------------------------------- 15 | # REDIS CACHE SETTINGS (Optional) 16 | # ----------------------------------------------------------------------------- 17 | # URL for your Redis instance. If not provided, an in-memory cache will be used. 18 | # REDIS_URL=redis://localhost:6379/0 19 | # Prefix for all Redis keys to avoid collisions. 20 | REDIS_PREFIX=forge 21 | 22 | # ----------------------------------------------------------------------------- 23 | # SECURITY & AUTHENTICATION 24 | # ----------------------------------------------------------------------------- 25 | # Secret key for signing JWTs. Generate a strong, random key. 26 | # Use `openssl rand -hex 32` to generate a key. 27 | SECRET_KEY= 28 | # Secret key for encrypting sensitive data like provider API keys. 29 | # Generate with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())" 30 | ENCRYPTION_KEY= 31 | # Algorithm for JWT signing. 32 | ALGORITHM=HS256 33 | # Access token expiration time in minutes. 34 | ACCESS_TOKEN_EXPIRE_MINUTES=60 35 | 36 | # ----------------------------------------------------------------------------- 37 | # CLERK AUTHENTICATION INTEGRATION (OPTIONAL) 38 | # ----------------------------------------------------------------------------- 39 | # Your Clerk JWT public key (remove -----BEGIN PUBLIC KEY----- and -----END PUBLIC KEY----- headers/footers and newlines). 40 | CLERK_JWT_PUBLIC_KEY= 41 | # Your Clerk API key. 42 | CLERK_API_KEY= 43 | # Your Clerk API URL. 44 | CLERK_API_URL=https://api.clerk.com/v1 45 | # Your Clerk webhook signing secret. 46 | CLERK_WEBHOOK_SECRET= 47 | 48 | # ----------------------------------------------------------------------------- 49 | # APPLICATION & DEVELOPMENT SETTINGS 50 | # ----------------------------------------------------------------------------- 51 | # Set to "production" or "development". 52 | ENVIRONMENT=development 53 | # Set to "true" to enable debug-level logging. 54 | FORGE_DEBUG_LOGGING=false 55 | # Host for the application. 56 | HOST=0.0.0.0 57 | # Port for the application. 58 | PORT=8000 59 | # Set to "true" to enable auto-reloading for development. 60 | RELOAD=false 61 | # Number of Gunicorn worker processes. 62 | WORKERS=4 63 | # Set to "true" to force the use of in-memory cache, even if REDIS_URL is set. 64 | # Useful for testing environments. 
65 | FORCE_MEMORY_CACHE=false 66 | # Set to "true" to enable verbose cache debugging output. 67 | DEBUG_CACHE=false 68 | # Your application's domain, used for JWT audience validation. 69 | APP_DOMAIN=localhost 70 | -------------------------------------------------------------------------------- /app/api/routes/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import AsyncGenerator 2 | import json 3 | from fastapi import HTTPException 4 | from starlette.responses import StreamingResponse 5 | from app.exceptions.exceptions import ProviderAPIException 6 | 7 | async def wrap_streaming_response_with_error_handling( 8 | logger, async_gen: AsyncGenerator[bytes, None] 9 | ) -> StreamingResponse: 10 | """ 11 | Wraps an async generator to catch and properly handle errors in streaming responses. 12 | Returns a StreamingResponse that will: 13 | - Return proper HTTP error status if error occurs before first chunk 14 | - Send error as SSE event if error occurs mid-stream 15 | 16 | Args: 17 | logger: Logger instance for error logging 18 | async_gen: The async generator producing the stream chunks 19 | 20 | Returns: 21 | StreamingResponse with proper error handling 22 | 23 | Raises: 24 | HTTPException: If error occurs before streaming starts 25 | """ 26 | 27 | # Try to get the first chunk BEFORE creating StreamingResponse 28 | # This allows us to catch immediate errors and return proper HTTP status 29 | try: 30 | first_chunk = await async_gen.__anext__() 31 | except StopAsyncIteration: 32 | # Empty stream 33 | logger.error("Empty stream response") 34 | raise HTTPException(status_code=500, detail="Empty stream response") 35 | except ProviderAPIException as e: 36 | logger.error(f"Provider API error: {str(e)}") 37 | raise HTTPException(status_code=e.error_code, detail=e.error_message) from e 38 | except Exception as e: 39 | # Convert other exceptions to HTTPException 40 | logger.error(f"Error before streaming started: {str(e)}") 41 | raise HTTPException(status_code=500, detail=str(e)) from e 42 | 43 | # Success! Now create generator that replays first chunk + rest 44 | async def response_generator(): 45 | # Yield the first chunk we already got 46 | yield first_chunk 47 | 48 | try: 49 | # Continue with the rest of the stream 50 | async for chunk in async_gen: 51 | yield chunk 52 | except Exception as e: 53 | # Error occurred mid-stream - HTTP status already sent 54 | # Send error as SSE event to inform the client 55 | logger.error(f"Error during streaming: {str(e)}") 56 | 57 | error_message = str(e) 58 | error_event = { 59 | "error": { 60 | "message": error_message, 61 | "type": "stream_error", 62 | "code": "provider_error" 63 | } 64 | } 65 | yield f"data: {json.dumps(error_event)}\n\n".encode() 66 | 67 | # Send [DONE] to properly close the stream 68 | yield b"data: [DONE]\n\n" 69 | 70 | # Set appropriate headers for streaming 71 | headers = { 72 | "Content-Type": "text/event-stream", 73 | "Cache-Control": "no-cache", 74 | "Connection": "keep-alive", 75 | "X-Accel-Buffering": "no", # Prevent Nginx buffering 76 | } 77 | 78 | return StreamingResponse( 79 | response_generator(), 80 | media_type="text/event-stream", 81 | headers=headers 82 | ) -------------------------------------------------------------------------------- /alembic.ini: -------------------------------------------------------------------------------- 1 | # A generic, single database configuration. 
2 | 3 | [alembic] 4 | # path to migration scripts 5 | script_location = alembic 6 | 7 | # template used to generate migration files 8 | # file_template = %%(rev)s_%%(slug)s 9 | 10 | # sys.path path, will be prepended to sys.path if present. 11 | # defaults to the current working directory. 12 | prepend_sys_path = . 13 | 14 | # timezone to use when rendering the date within the migration file 15 | # as well as the filename. 16 | # If specified, requires the python-dateutil library that can be 17 | # installed by adding `alembic[tz]` to the pip requirements 18 | # string value is passed to dateutil.tz.gettz() 19 | # leave blank for localtime 20 | # timezone = 21 | 22 | # max length of characters to apply to the 23 | # "slug" field 24 | # truncate_slug_length = 40 25 | 26 | # set to 'true' to run the environment during 27 | # the 'revision' command, regardless of autogenerate 28 | # revision_environment = false 29 | 30 | # set to 'true' to allow .pyc and .pyo files without 31 | # a source .py file to be detected as revisions in the 32 | # versions/ directory 33 | # sourceless = false 34 | 35 | # version location specification; This defaults 36 | # to alembic/versions. When using multiple version 37 | # directories, initial revisions must be specified with --version-path. 38 | # The path separator used here should be the separator specified by "version_path_separator" below. 39 | # version_locations = %(here)s/bar:%(here)s/bat:alembic/versions 40 | 41 | # version path separator; As mentioned above, this is the character used to split 42 | # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. 43 | # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. 44 | # Valid values for version_path_separator are: 45 | # 46 | # version_path_separator = : 47 | # version_path_separator = ; 48 | # version_path_separator = space 49 | version_path_separator = os # Use os.pathsep. Default configuration used for new projects. 50 | 51 | # the output encoding used when revision files 52 | # are written from script.py.mako 53 | # output_encoding = utf-8 54 | 55 | 56 | [post_write_hooks] 57 | # post_write_hooks defines scripts or Python functions that are run 58 | # on newly generated revision scripts. 
See the documentation for further 59 | # detail and examples 60 | 61 | # format using "black" - use the console_scripts runner, against the "black" entrypoint 62 | # hooks = black 63 | # black.type = console_scripts 64 | # black.entrypoint = black 65 | # black.options = -l 79 REVISION_SCRIPT_FILENAME 66 | 67 | # Logging configuration 68 | [loggers] 69 | keys = root,sqlalchemy,alembic 70 | 71 | [handlers] 72 | keys = console 73 | 74 | [formatters] 75 | keys = generic 76 | 77 | [logger_root] 78 | level = WARN 79 | handlers = console 80 | qualname = 81 | 82 | [logger_sqlalchemy] 83 | level = WARN 84 | handlers = 85 | qualname = sqlalchemy.engine 86 | 87 | [logger_alembic] 88 | level = INFO 89 | handlers = 90 | qualname = alembic 91 | 92 | [handler_console] 93 | class = StreamHandler 94 | args = (sys.stderr,) 95 | level = NOTSET 96 | formatter = generic 97 | 98 | [formatter_generic] 99 | format = %(levelname)-5.5s [%(name)s] %(message)s 100 | datefmt = %H:%M:%S 101 | -------------------------------------------------------------------------------- /alembic/versions/c50fd7be794c_add_endpoint_column_to_usage_stats.py: -------------------------------------------------------------------------------- 1 | """Add endpoint column to usage_stats 2 | 3 | Revision ID: c50fd7be794c 4 | Revises: ca1ac51334ec 5 | Create Date: 2025-04-08 20:46:46.713120 6 | 7 | """ 8 | 9 | import sqlalchemy as sa 10 | 11 | from alembic import op 12 | 13 | # revision identifiers, used by Alembic. 14 | revision = "c50fd7be794c" 15 | down_revision = "ca1ac51334ec" 16 | branch_labels = None 17 | depends_on = None 18 | 19 | 20 | def upgrade() -> None: 21 | # ### commands auto generated by Alembic - please adjust! ### 22 | with op.batch_alter_table("usage_stats", schema=None) as batch_op: 23 | batch_op.add_column(sa.Column("endpoint", sa.String(), nullable=False)) 24 | batch_op.add_column(sa.Column("input_tokens", sa.Integer(), nullable=True)) 25 | batch_op.add_column(sa.Column("output_tokens", sa.Integer(), nullable=True)) 26 | batch_op.add_column(sa.Column("requests_count", sa.Integer(), nullable=True)) 27 | batch_op.add_column(sa.Column("cost", sa.Float(), nullable=True)) 28 | batch_op.add_column(sa.Column("created_at", sa.DateTime(), nullable=True)) 29 | 30 | # Drop the old FK using its specific name 31 | batch_op.drop_constraint( 32 | op.f("fk_usage_stats_user_id_users"), type_="foreignkey" 33 | ) 34 | # Create the new FK with ON DELETE CASCADE 35 | batch_op.create_foreign_key( 36 | op.f("fk_usage_stats_user_id_users"), # Keep the same conventional name 37 | "users", 38 | ["user_id"], 39 | ["id"], 40 | ondelete="CASCADE", 41 | ) 42 | 43 | batch_op.drop_column("completion_tokens") 44 | batch_op.drop_column("timestamp") 45 | batch_op.drop_column("error_count") 46 | batch_op.drop_column("prompt_tokens") 47 | batch_op.drop_column("success_count") 48 | batch_op.drop_column("request_count") 49 | 50 | # ### end Alembic commands ### 51 | 52 | 53 | def downgrade() -> None: 54 | # ### commands auto generated by Alembic - please adjust! 
### 55 | with op.batch_alter_table("usage_stats", schema=None) as batch_op: 56 | batch_op.add_column(sa.Column("request_count", sa.INTEGER(), nullable=True)) 57 | batch_op.add_column(sa.Column("success_count", sa.INTEGER(), nullable=True)) 58 | batch_op.add_column(sa.Column("prompt_tokens", sa.INTEGER(), nullable=True)) 59 | batch_op.add_column(sa.Column("error_count", sa.INTEGER(), nullable=True)) 60 | batch_op.add_column(sa.Column("timestamp", sa.DATETIME(), nullable=True)) 61 | batch_op.add_column(sa.Column("completion_tokens", sa.INTEGER(), nullable=True)) 62 | 63 | # Drop the ON DELETE CASCADE FK 64 | batch_op.drop_constraint( 65 | op.f("fk_usage_stats_user_id_users"), type_="foreignkey" 66 | ) 67 | # Recreate the original FK without ON DELETE CASCADE 68 | batch_op.create_foreign_key( 69 | op.f("fk_usage_stats_user_id_users"), "users", ["user_id"], ["id"] 70 | ) 71 | 72 | batch_op.drop_column("created_at") 73 | batch_op.drop_column("cost") 74 | batch_op.drop_column("requests_count") 75 | batch_op.drop_column("output_tokens") 76 | batch_op.drop_column("input_tokens") 77 | batch_op.drop_column("endpoint") 78 | 79 | # ### end Alembic commands ### 80 | -------------------------------------------------------------------------------- /app/api/routes/stripe.py: -------------------------------------------------------------------------------- 1 | import os 2 | from fastapi import APIRouter, Depends, Request 3 | from app.api.dependencies import get_current_active_user_from_clerk, get_current_active_user 4 | from app.api.schemas.stripe import CreateCheckoutSessionRequest 5 | from app.models.user import User 6 | from app.models.stripe import StripePayment 7 | from app.core.database import get_async_db 8 | from sqlalchemy.ext.asyncio import AsyncSession 9 | import stripe 10 | from app.core.logger import get_logger 11 | from sqlalchemy import select 12 | from fastapi import HTTPException 13 | 14 | logger = get_logger(name="stripe") 15 | 16 | STRIPE_API_KEY = os.getenv("STRIPE_API_KEY") 17 | stripe.api_key = STRIPE_API_KEY 18 | 19 | router = APIRouter() 20 | 21 | @router.post("/create-checkout-session/clerk") 22 | async def stripe_create_checkout_session_clerk(request: Request, create_checkout_session_request: CreateCheckoutSessionRequest, user: User = Depends(get_current_active_user_from_clerk), db: AsyncSession = Depends(get_async_db)): 23 | return await stripe_create_checkout_session(request, create_checkout_session_request, user, db) 24 | 25 | 26 | @router.post("/create-checkout-session") 27 | async def stripe_create_checkout_session(request: Request, create_checkout_session_request: CreateCheckoutSessionRequest, user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_async_db)): 28 | """ 29 | Create a checkout session for a user. 
30 | """ 31 | logger.info(f"Creating checkout session for user {user.id}") 32 | session = await stripe.checkout.Session.create_async( 33 | metadata={ 34 | "user_id": user.id, 35 | }, 36 | **create_checkout_session_request.model_dump(exclude_none=True), 37 | ) 38 | stripe_payment = StripePayment( 39 | id=session.id, 40 | user_id=user.id, 41 | status=session.status, 42 | currency=session.currency.upper(), 43 | amount=session.amount_total, 44 | # store the whole session as raw_data 45 | raw_data=dict(session), 46 | ) 47 | db.add(stripe_payment) 48 | await db.commit() 49 | 50 | return { 51 | 'session_id': session.id, 52 | 'url': session.url, 53 | } 54 | 55 | @router.get("/checkout-session") 56 | async def stripe_get_checkout_session(session_id: str, user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_async_db)): 57 | result = await db.execute( 58 | select( 59 | StripePayment 60 | ) 61 | .where(StripePayment.id == session_id, StripePayment.user_id == user.id) 62 | ) 63 | stripe_payment = result.scalar_one_or_none() 64 | if not stripe_payment: 65 | raise HTTPException(status_code=404, detail="Stripe payment not found") 66 | 67 | return { 68 | 'id': stripe_payment.id, 69 | 'status': stripe_payment.status, 70 | 'currency': stripe_payment.currency, 71 | 'amount': stripe_payment.amount / 100.0 if stripe_payment.currency == "USD" else stripe_payment.amount, 72 | 'created_at': stripe_payment.created_at, 73 | 'updated_at': stripe_payment.updated_at, 74 | } 75 | 76 | @router.get("/checkout-session/clerk") 77 | async def stripe_get_checkout_session_clerk(session_id: str, user: User = Depends(get_current_active_user_from_clerk), db: AsyncSession = Depends(get_async_db)): 78 | return await stripe_get_checkout_session(session_id, user, db) 79 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "forge" 7 | version = "0.1.0" 8 | description = "Forge is an open-source middleware service that simplifies AI model provider management. It allows you to use multiple AI providers (OpenAI, Anthropic, etc.) through a single, unified API. By storing your provider API keys securely, Forge generates a unified key that works across all your AI applications." 
9 | requires-python = ">=3.12" 10 | authors = [ 11 | {name = "TensorBlock", email = "contact@tensorblock.co"}, 12 | ] 13 | readme = "README.md" 14 | license = {text = "MIT"} 15 | dependencies = [ 16 | "fastapi>=0.95.0", 17 | "uvicorn>=0.22.0", 18 | "pydantic>=2.0.0", 19 | "python-jose>=3.3.0", 20 | "stripe", 21 | "passlib>=1.7.4", 22 | "python-multipart>=0.0.5", 23 | "sqlalchemy[asyncio]>=2.0.0", 24 | "alembic>=1.10.4", 25 | "aiohttp>=3.8.4", 26 | "cryptography>=40.0.0", 27 | "bcrypt==3.2.2", 28 | "python-dotenv>=1.0.0", 29 | "email-validator>=2.0.0", 30 | "requests>=2.28.0", 31 | "svix>=1.13.0", 32 | "psycopg2-binary>=2.9.9", 33 | "asyncpg>=0.29.0", 34 | "boto3>=1.0.0", 35 | "gunicorn>=20.0.0", 36 | "redis>=4.6.0", # sync & async clients used by shared cache 37 | "loguru>=0.7.0", 38 | "aiobotocore~=2.0", 39 | "tiktoken>=0.5.0", # for token counting in Claude Code support 40 | "google-generativeai>=0.3.0", 41 | "google-genai>=0.3.0", 42 | ] 43 | 44 | [project.optional-dependencies] 45 | dev = [ 46 | "pytest>=7.0.0", 47 | "pytest-cov>=4.0.0", 48 | "pytest-asyncio>=0.26.0", 49 | "ruff>=0.2.0", 50 | "pre-commit>=3.0.0", 51 | "flake8", 52 | "black==22.6.0", 53 | "isort", 54 | "mypy", 55 | ] 56 | 57 | [tool.pytest.ini_options] 58 | filterwarnings = [ 59 | "ignore:.*'crypt' is deprecated.*:DeprecationWarning:passlib.*" 60 | ] 61 | asyncio_default_fixture_loop_scope = "function" 62 | asyncio_mode = "auto" 63 | 64 | [tool.hatch.build.targets.wheel] 65 | packages = ["app"] 66 | 67 | [tool.hatch.build.targets.sdist] 68 | include = [ 69 | "app", 70 | "alembic", 71 | "alembic.ini", 72 | "README.md", 73 | "LICENSE", 74 | ] 75 | 76 | [tool.ruff] 77 | # Same as Black 78 | indent-width = 4 79 | line-length = 88 80 | target-version = "py312" 81 | 82 | [tool.ruff.lint] 83 | # Enable Pyflakes ('F'), pycodestyle ('E'), and import sorting ('I') 84 | select = ["E", "F", "I", "N", "W", "B", "C4", "PL", "SIM", "UP"] 85 | ignore = [ 86 | "B008", # Function call in argument defaults 87 | "E501", # Line too long 88 | "PLR0912", # Too many branches 89 | "PLR0915", # Too many statements 90 | "PLR0913", # Too many arguments 91 | "B904", 92 | "PLR0911", 93 | "SIM118", 94 | "PLW2901", 95 | "SIM117", 96 | "PLR2004", 97 | ] 98 | # Allow autofix for all enabled rules (when `--fix` is passed) 99 | fixable = ["ALL"] 100 | unfixable = [] 101 | 102 | [tool.ruff.format] 103 | quote-style = "double" 104 | indent-style = "space" 105 | line-ending = "auto" 106 | skip-magic-trailing-comma = false 107 | 108 | # Exclude a variety of commonly ignored directories 109 | [tool.ruff.lint.isort] 110 | known-first-party = ["app"] 111 | 112 | [tool.pytest] 113 | testpaths = ["tests"] 114 | python_files = "test_*.py" 115 | -------------------------------------------------------------------------------- /tests/unit_tests/assets/google/list_models.json: -------------------------------------------------------------------------------- 1 | { 2 | "models": [ 3 | { 4 | "name": "models/embedding-gecko-001", 5 | "version": "001", 6 | "displayName": "Embedding Gecko", 7 | "description": "Obtain a distributed representation of a text.", 8 | "inputTokenLimit": 1024, 9 | "outputTokenLimit": 1, 10 | "supportedGenerationMethods": [ 11 | "embedText", 12 | "countTextTokens" 13 | ] 14 | }, 15 | { 16 | "name": "models/gemini-1.0-pro-vision-latest", 17 | "version": "001", 18 | "displayName": "Gemini 1.0 Pro Vision", 19 | "description": "The original Gemini 1.0 Pro Vision model version which was optimized for image understanding. 
Gemini 1.0 Pro Vision was deprecated on July 12, 2024. Move to a newer Gemini version.", 20 | "inputTokenLimit": 12288, 21 | "outputTokenLimit": 4096, 22 | "supportedGenerationMethods": [ 23 | "generateContent", 24 | "countTokens" 25 | ], 26 | "temperature": 0.4, 27 | "topP": 1, 28 | "topK": 32 29 | }, 30 | { 31 | "name": "models/gemini-pro-vision", 32 | "version": "001", 33 | "displayName": "Gemini 1.0 Pro Vision", 34 | "description": "The original Gemini 1.0 Pro Vision model version which was optimized for image understanding. Gemini 1.0 Pro Vision was deprecated on July 12, 2024. Move to a newer Gemini version.", 35 | "inputTokenLimit": 12288, 36 | "outputTokenLimit": 4096, 37 | "supportedGenerationMethods": [ 38 | "generateContent", 39 | "countTokens" 40 | ], 41 | "temperature": 0.4, 42 | "topP": 1, 43 | "topK": 32 44 | }, 45 | { 46 | "name": "models/gemini-1.5-pro-latest", 47 | "version": "001", 48 | "displayName": "Gemini 1.5 Pro Latest", 49 | "description": "Alias that points to the most recent production (non-experimental) release of Gemini 1.5 Pro, our mid-size multimodal model that supports up to 2 million tokens.", 50 | "inputTokenLimit": 2000000, 51 | "outputTokenLimit": 8192, 52 | "supportedGenerationMethods": [ 53 | "generateContent", 54 | "countTokens" 55 | ], 56 | "temperature": 1, 57 | "topP": 0.95, 58 | "topK": 40, 59 | "maxTemperature": 2 60 | }, 61 | { 62 | "name": "models/gemini-1.5-pro-002", 63 | "version": "002", 64 | "displayName": "Gemini 1.5 Pro 002", 65 | "description": "Stable version of Gemini 1.5 Pro, our mid-size multimodal model that supports up to 2 million tokens, released in September of 2024.", 66 | "inputTokenLimit": 2000000, 67 | "outputTokenLimit": 8192, 68 | "supportedGenerationMethods": [ 69 | "generateContent", 70 | "countTokens", 71 | "createCachedContent" 72 | ], 73 | "temperature": 1, 74 | "topP": 0.95, 75 | "topK": 40, 76 | "maxTemperature": 2 77 | } 78 | ], 79 | "nextPageToken": "Chltb2RlbHMvZ2VtaW5pLTEuNS1wcm8tMDAy" 80 | } -------------------------------------------------------------------------------- /tests/run_tests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import sys 4 | import unittest 5 | 6 | # Add parent directory to path to make imports work 7 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 8 | 9 | # Patch bcrypt version detection to avoid warnings 10 | import bcrypt 11 | 12 | if not hasattr(bcrypt, "__about__"): 13 | import types 14 | 15 | bcrypt.__about__ = types.ModuleType("__about__") 16 | bcrypt.__about__.__version__ = ( 17 | bcrypt.__version__ if hasattr(bcrypt, "__version__") else "3.2.0" 18 | ) 19 | 20 | # Import the cache test files 21 | # These are run separately as they need special async setup 22 | from tests.unit_tests.test_provider_service import TestProviderService 23 | 24 | # Import the test modules 25 | from tests.unit_tests.test_security import TestSecurity 26 | 27 | # Import the image related tests 28 | from tests.unit_tests.test_provider_service_images import TestProviderServiceImages 29 | 30 | 31 | # Define test suites 32 | def security_suite(): 33 | suite = unittest.TestSuite() 34 | suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestSecurity)) 35 | return suite 36 | 37 | 38 | def provider_service_suite(): 39 | suite = unittest.TestSuite() 40 | suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestProviderService)) 41 | return suite 42 | 43 | def 
provider_service_images_suite(): 44 | suite = unittest.TestSuite() 45 | suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestProviderServiceImages)) 46 | return suite 47 | 48 | 49 | if __name__ == "__main__": 50 | runner = unittest.TextTestRunner(verbosity=2) 51 | 52 | tests_to_run = [] 53 | 54 | print("Running security tests...") 55 | result_security = runner.run(security_suite()) 56 | tests_to_run.append(result_security) 57 | 58 | print("\nRunning provider service tests...") 59 | result_provider = runner.run(provider_service_suite()) 60 | tests_to_run.append(result_provider) 61 | 62 | print("\nRunning provider service images tests...") 63 | result_provider_images = runner.run(provider_service_images_suite()) 64 | tests_to_run.append(result_provider_images) 65 | 66 | # Integration tests require a running server 67 | print("\nFor integration tests, make sure the server is running and then execute:") 68 | print("python tests/integration_test.py") 69 | 70 | # Cache tests 71 | print("\nTo run cache tests:") 72 | print("python tests/cache/test_sync_cache.py # For sync cache tests") 73 | print("python tests/cache/test_async_cache.py # For async cache tests") 74 | 75 | # Frontend simulation tests require a valid Forge API key 76 | print( 77 | "\nTo simulate a frontend application, set your Forge API key in the .env file and run:" 78 | ) 79 | print("python tests/frontend_simulation.py") 80 | 81 | # Mock provider tests 82 | print("\nTo run all mock tests at once:") 83 | print("python tests/mock_testing/run_mock_tests.py") 84 | 85 | print("\nOr to run individual mock tests:") 86 | print("python tests/mock_testing/test_mock_client.py") 87 | print("# For interactive testing:") 88 | print("python tests/mock_testing/test_mock_client.py --interactive") 89 | 90 | print("\nFor examples of using mocks in your tests, see:") 91 | print("python tests/mock_testing/examples/test_with_mocks.py") 92 | 93 | print("\nSee tests/mock_testing/README.md for more information on mock testing.") 94 | 95 | # Exit with error code if any tests failed 96 | if any(not test.wasSuccessful() for test in tests_to_run): 97 | sys.exit(1) -------------------------------------------------------------------------------- /docs/PERFORMANCE_OPTIMIZATIONS.md: -------------------------------------------------------------------------------- 1 | # Performance Optimizations for Forge 2 | 3 | This document outlines the performance optimizations implemented in the Forge application to enhance scalability, reduce database load, and improve response times. 4 | 5 | ## Key Optimizations 6 | 7 | ### 1. 
Multi-Level Caching 8 | 9 | We've implemented several caching mechanisms to reduce redundant operations: 10 | 11 | #### User Authentication Cache 12 | - Each API request requires user authentication via their API key 13 | - Instead of querying the database for every request, users are now cached for 5 minutes 14 | - Significant reduction in database reads for high-traffic users 15 | 16 | #### Provider Service Caching 17 | - ProviderService instances are now reused across requests for the same user 18 | - Cached for 10 minutes to balance memory usage with performance 19 | - Avoids redundant instantiation and database reads 20 | 21 | #### Provider API Key Caching 22 | - API keys are now lazy-loaded and decrypted only when needed 23 | - Once decrypted, they remain available in the cached service instance 24 | - Eliminates repeated decryption operations which are computationally expensive 25 | 26 | #### Model List Caching 27 | - Models returned by providers are cached for 1 hour 28 | - Avoids repeated API calls to list models, which can be rate-limited by providers 29 | - Each provider/configuration has its own cache entry 30 | 31 | ### 2. Lazy Loading 32 | 33 | - Provider keys are not loaded during service initialization 34 | - Keys are only loaded and decrypted when needed for a request 35 | - If a request doesn't need access to provider keys, no decryption occurs 36 | 37 | ### 3. Cache Invalidation 38 | 39 | We've implemented targeted cache invalidation to ensure data consistency: 40 | 41 | - User cache is invalidated when: 42 | - User details are updated 43 | - User's Forge API key is reset 44 | 45 | - Provider service cache is invalidated when: 46 | - Provider keys are added 47 | - Provider keys are updated 48 | - Provider keys are deleted 49 | - User details change 50 | 51 | ### 4. 
Error Resilience 52 | 53 | - Model listing now has error handling to prevent failures if one provider API is down 54 | - Batch processing of concurrent API requests to avoid overwhelming provider APIs 55 | 56 | ## Implementation Details 57 | 58 | ### Cache Implementation 59 | 60 | A simple in-memory cache with time-to-live (TTL) expiration: 61 | - `app/core/cache.py` contains the cache implementation 62 | - Each cache entry has its own expiration time 63 | - Cache entries are automatically removed when accessed after expiration 64 | 65 | ### Adapter Caching 66 | 67 | Provider adapters are cached at the class level: 68 | - Adapters are stateless and can be reused across all requests 69 | - Improves memory usage by sharing adapter instances 70 | 71 | ### Service Reuse Pattern 72 | 73 | The ProviderService now uses a factory pattern: 74 | - `get_instance()` method returns cached instances when available 75 | - Creates new instances only when needed 76 | 77 | ## Performance Impact 78 | 79 | These optimizations significantly reduce: 80 | - Database queries per request 81 | - Decryption operations 82 | - Memory usage 83 | - API calls to providers 84 | 85 | This results in: 86 | - Lower response times 87 | - Higher maximum throughput 88 | - Reduced database load 89 | - Better scalability 90 | 91 | ## Future Improvements 92 | 93 | Potential future optimizations: 94 | - Distributed caching (Redis) for multi-server deployments 95 | - More granular cache expiration policies 96 | - Request batching for similar consecutive requests 97 | - Advanced monitoring of cache hit/miss rates 98 | -------------------------------------------------------------------------------- /tools/diagnostics/check_db_keys.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Script to directly check the database for API keys. 4 | """ 5 | 6 | import os 7 | 8 | import psycopg2 9 | from dotenv import load_dotenv 10 | 11 | 12 | def check_database_connection(): 13 | """Check if database is accessible.""" 14 | try: 15 | # Load environment variables 16 | load_dotenv() 17 | 18 | # Get database URL from environment 19 | db_url = os.getenv( 20 | "DATABASE_URL", "postgresql://user:password@localhost:5432/forge" 21 | ) 22 | 23 | # Try to connect to PostgreSQL 24 | conn = psycopg2.connect(db_url) 25 | return conn 26 | except Exception as e: 27 | print(f"Database connection error: {e}") 28 | print("Please ensure:") 29 | print("1. PostgreSQL is running") 30 | print("2. Database 'forge' exists") 31 | print("3. User has proper permissions") 32 | print("4. 
Connection details in .env are correct") 33 | return None 34 | 35 | 36 | def main(): 37 | """Check the database for API keys.""" 38 | conn = check_database_connection() 39 | if not conn: 40 | return 41 | 42 | try: 43 | cursor = conn.cursor() 44 | 45 | # Get table names 46 | cursor.execute( 47 | """ 48 | SELECT table_name 49 | FROM information_schema.tables 50 | WHERE table_schema = 'public' 51 | """ 52 | ) 53 | tables = cursor.fetchall() 54 | print(f"Tables in database: {[table[0] for table in tables]}") 55 | 56 | # Check if users table exists 57 | if ("users",) in tables: 58 | # Count users 59 | cursor.execute("SELECT COUNT(*) FROM users") 60 | user_count = cursor.fetchone()[0] 61 | print(f"Number of users in database: {user_count}") 62 | 63 | # Get user details 64 | cursor.execute("SELECT id, username, forge_api_key FROM users") 65 | users = cursor.fetchall() 66 | 67 | print("\nUsers in database:") 68 | for user_id, username, api_key in users: 69 | print(f"User {user_id} ({username}): API Key = {api_key}") 70 | 71 | # Check against the expected key 72 | expected_key = "forge-1ea5812207aa309110b4122f38d7be34" 73 | if api_key == expected_key: 74 | print(" ✓ Matches expected key") 75 | else: 76 | print(" ✗ Does not match expected key") 77 | else: 78 | print("'users' table not found in database") 79 | 80 | # Check if provider_keys table exists 81 | if ("provider_keys",) in tables: 82 | cursor.execute("SELECT COUNT(*) FROM provider_keys") 83 | key_count = cursor.fetchone()[0] 84 | print(f"\nNumber of provider keys in database: {key_count}") 85 | 86 | # Get provider key details 87 | cursor.execute("SELECT id, provider_name, user_id FROM provider_keys") 88 | keys = cursor.fetchall() 89 | 90 | print("\nProvider keys in database:") 91 | for key_id, provider_name, user_id in keys: 92 | print(f"Key {key_id}: Provider = {provider_name}, User ID = {user_id}") 93 | 94 | cursor.close() 95 | conn.close() 96 | 97 | except psycopg2.Error as e: 98 | print(f"Database error: {e}") 99 | if conn: 100 | conn.close() 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /tests/performance/README.md: -------------------------------------------------------------------------------- 1 | # Forge Performance Tests 2 | 3 | This directory contains performance tests for evaluating the Forge middleware service's performance characteristics. 4 | 5 | ## Overview 6 | 7 | These tests measure key performance metrics such as: 8 | 9 | - Response latency (p50, p90, p99) 10 | - Throughput (requests per second) 11 | - Scalability under varying loads 12 | - Resource utilization 13 | - Provider-specific performance metrics 14 | 15 | Performance tests use the mock provider by default to ensure consistent and reproducible results without external dependencies. 16 | 17 | ## Prerequisites 18 | 19 | To run these tests, you need: 20 | 21 | 1. A running instance of Forge 22 | 2. 
Python packages: 23 | - locust 24 | - pytest 25 | - pytest-benchmark 26 | - aiohttp 27 | - pandas (for results analysis) 28 | - matplotlib (for visualizations) 29 | 30 | Install performance testing dependencies: 31 | 32 | ```bash 33 | pip install locust pytest pytest-benchmark aiohttp pandas matplotlib 34 | ``` 35 | 36 | ## Usage 37 | 38 | ### Single-run benchmarks: 39 | 40 | ```bash 41 | # Run all performance tests 42 | pytest tests/performance/ 43 | 44 | # Run specific performance test category 45 | pytest tests/performance/test_latency.py 46 | pytest tests/performance/test_throughput.py 47 | pytest tests/performance/test_providers.py 48 | ``` 49 | 50 | ### Load testing with Locust: 51 | 52 | ```bash 53 | # Start Locust web interface 54 | cd tests/performance 55 | locust -f locustfile.py 56 | 57 | # Run headless with 10 users, spawn rate of 1 user/sec, for 1 minute 58 | locust -f locustfile.py --headless -u 10 -r 1 --run-time 1m 59 | ``` 60 | 61 | ## About Mock Providers 62 | 63 | Performance tests use mock providers rather than real API calls to: 64 | 65 | - Eliminate external dependencies for consistent results 66 | - Avoid rate limits and costs associated with real API calls 67 | - Ensure reproducible test results 68 | - Provide consistent response patterns and timing 69 | 70 | This approach focuses on measuring Forge's own performance characteristics rather than the performance of external services. 71 | 72 | ## Baseline Results 73 | 74 | The baseline performance results are stored in `baseline_results/` and represent reference numbers for comparing future changes. 75 | 76 | ## Output Analysis 77 | 78 | Test results are saved to: 79 | - JSON reports in `results/json/` 80 | - CSV files in `results/csv/` 81 | - Graphs in `results/graphs/` 82 | 83 | Use the analysis utilities to compare results: 84 | 85 | ```bash 86 | # Compare specific test runs 87 | python tests/performance/analyze_results.py --baseline baseline_results/latest.json --new results/json/new_run.json 88 | 89 | # Generate charts for all test results 90 | python tests/performance/analyze_results.py 91 | 92 | # Generate an interactive HTML dashboard with all test results 93 | python tests/performance/analyze_results.py --dashboard 94 | ``` 95 | 96 | The HTML dashboard provides a comprehensive view of all performance test results in a single page with: 97 | - Interactive tabs for different test categories 98 | - Visual representations of all metrics 99 | - Summary statistics for each test 100 | - Test results grouped by type (latency, throughput, streaming, etc.) 
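
If you prefer to inspect the raw reports directly rather than through `analyze_results.py`, a rough pandas sketch like the one below can help. This is only an illustration: it assumes each file under `results/json/` is a single JSON document and makes no claims about the exact schema the test suite writes.

```python
import glob
import json
from pathlib import Path

import pandas as pd

# Collect every JSON report produced by the performance tests.
# Path assumes the script is run from the repository root.
records = []
for path in glob.glob("tests/performance/results/json/*.json"):
    with open(path) as fh:
        report = json.load(fh)
    # json_normalize flattens nested keys into columns; the resulting columns
    # depend on the real report schema, so treat this as exploratory only.
    frame = pd.json_normalize(report)
    frame["source_file"] = Path(path).name
    records.append(frame)

if records:
    results = pd.concat(records, ignore_index=True)
    print(results.columns.tolist())  # discover which metrics are available
    print(results.head())            # quick look at the flattened metrics
else:
    print("No JSON reports found - run the performance tests first.")
```
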
101 | 102 | After generating the dashboard, open it in any web browser: 103 | ```bash 104 | # On macOS 105 | open tests/performance/results/graphs/performance_dashboard.html 106 | 107 | # On Linux 108 | xdg-open tests/performance/results/graphs/performance_dashboard.html 109 | 110 | # On Windows 111 | start tests/performance/results/graphs/performance_dashboard.html 112 | ``` 113 | 114 | ## Interpreting Results 115 | 116 | When evaluating performance changes: 117 | - Latency: Lower is better 118 | - Throughput: Higher is better 119 | - Error rate: Should remain at or near 0% 120 | -------------------------------------------------------------------------------- /app/api/schemas/provider_key.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datetime import datetime 3 | 4 | from pydantic import BaseModel, ConfigDict, computed_field, Field, field_validator 5 | 6 | from app.core.logger import get_logger 7 | from app.core.security import decrypt_api_key 8 | from app.services.providers.adapter_factory import ProviderAdapterFactory 9 | 10 | logger = get_logger(name="provider_key") 11 | 12 | 13 | class ProviderKeyBase(BaseModel): 14 | provider_name: str = Field(..., min_length=1) 15 | api_key: str 16 | base_url: str | None = None 17 | model_mapping: dict[str, str] | None = None 18 | config: dict[str, str] | None = None 19 | 20 | 21 | class ProviderKeyCreate(ProviderKeyBase): 22 | pass 23 | 24 | 25 | class ProviderKeyUpdate(BaseModel): 26 | api_key: str | None = None 27 | config: dict[str, str] | None = None 28 | base_url: str | None = None 29 | model_mapping: dict[str, str] | None = None 30 | 31 | 32 | class ProviderKeyInDBBase(BaseModel): 33 | id: int 34 | provider_name: str 35 | user_id: int 36 | base_url: str | None = None 37 | model_mapping: dict[str, str] | None = None 38 | created_at: datetime 39 | updated_at: datetime 40 | encrypted_api_key: str 41 | model_config = ConfigDict(from_attributes=True) 42 | 43 | @field_validator("model_mapping", mode="before") 44 | @classmethod 45 | def parse_model_mapping(cls, v): 46 | """Parse JSON string to dictionary for model_mapping field.""" 47 | if v is None: 48 | return None 49 | if isinstance(v, str): 50 | try: 51 | return json.loads(v) 52 | except json.JSONDecodeError: 53 | logger.warning(f"Failed to parse model_mapping JSON: {v}") 54 | return {} 55 | return v 56 | 57 | 58 | class ProviderKey(ProviderKeyInDBBase): 59 | @computed_field 60 | @property 61 | def api_key(self) -> str | None: 62 | """Masked API key for responses.""" 63 | if self.encrypted_api_key: 64 | decrypted_value = decrypt_api_key(self.encrypted_api_key) 65 | provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls( 66 | self.provider_name 67 | ) 68 | try: 69 | api_key, _ = provider_adapter_cls.deserialize_api_key_config( 70 | decrypted_value 71 | ) 72 | return provider_adapter_cls.mask_api_key(api_key) 73 | except Exception as e: 74 | logger.error( 75 | f"Error deserializing API key for provider {self.provider_name}: {e}" 76 | ) 77 | return None 78 | return None 79 | 80 | @computed_field 81 | @property 82 | def config(self) -> dict[str, str] | None: 83 | """Masked config for responses.""" 84 | if self.encrypted_api_key: 85 | decrypted_value = decrypt_api_key(self.encrypted_api_key) 86 | provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls( 87 | self.provider_name 88 | ) 89 | try: 90 | _, config = provider_adapter_cls.deserialize_api_key_config( 91 | decrypted_value 92 | ) 93 | return 
provider_adapter_cls.mask_config(config) 94 | except Exception as e: 95 | logger.error( 96 | f"Error deserializing config for provider {self.provider_name}: {e}" 97 | ) 98 | return None 99 | return None 100 | 101 | 102 | class ProviderKeyUpsertItem(BaseModel): 103 | provider_name: str = Field(..., min_length=1) 104 | api_key: str | None = None 105 | base_url: str | None = None 106 | model_mapping: dict[str, str] | None = None 107 | config: dict[str, str] | None = None 108 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | TensorBlock Forge is licensed under the Apache License, Version 2.0. In addition to the standard terms of the Apache License 2.0, your use of TensorBlock Forge is subject to the following Supplemental Terms: 2 | 3 | ------------------------------------------------------------------------------- 4 | 1. Commercial Use Restrictions 5 | ------------------------------------------------------------------------------- 6 | 7 | The standard open-source license granted under Apache 2.0 does not apply to the following categories of usage. If you intend to use TensorBlock Forge in any of the following ways, you must obtain prior written commercial authorization from the TensorBlock team: 8 | 9 | 1. Modifications and Derivative Works 10 | Use of modified versions of TensorBlock Forge or derivative works thereof, including but not limited to changes to application branding, codebase, features, user interface, or data structures, for commercial purposes. 11 | 12 | 2. Enterprise Deployment 13 | Use of TensorBlock Forge within a commercial organization or in services provided to third-party enterprise clients. 14 | 15 | 3. Public Cloud-Based Services 16 | Operation of public cloud services built upon TensorBlock Forge (e.g., SaaS platforms, hosted APIs, multi-tenant web services) accessible to general users. 17 | 18 | 4. Bundled Distribution with Hardware 19 | Distribution of TensorBlock Forge bundled, integrated, or pre-installed on hardware products intended for commercial sale. 20 | 21 | ------------------------------------------------------------------------------- 22 | 2. Contributor Agreement 23 | ------------------------------------------------------------------------------- 24 | 25 | By contributing to TensorBlock Forge in the form of code, documentation, or other assets, you acknowledge and agree to the following terms: 26 | 27 | 1. Right to Re-License 28 | You grant the maintainers of TensorBlock Forge a perpetual, irrevocable, worldwide, non-exclusive license to use, reproduce, modify, sublicense, and distribute your contributions under both open-source and commercial licenses. 29 | 30 | 2. Commercial Use of Contributions 31 | Your contributions may be incorporated into commercial versions of TensorBlock Forge or other related products offered by the TensorBlock team, including but not limited to proprietary deployments, enterprise features, and hosted services. 32 | 33 | 3. License Policy Adjustments 34 | TensorBlock reserves the right to modify the licensing structure of TensorBlock Forge in the future, in accordance with project governance and development objectives. Any such change shall not affect your rights under the version of the license in effect at the time of your contribution, unless explicitly agreed otherwise. 35 | 36 | ------------------------------------------------------------------------------- 37 | 3. 
General Provisions 38 | ------------------------------------------------------------------------------- 39 | 40 | 1. Interpretation 41 | The interpretation and enforcement of these Supplemental Terms shall be at the sole discretion of the TensorBlock development team. 42 | 43 | 2. Amendments 44 | These Supplemental Terms may be updated from time to time to reflect changes in project governance, usage scenarios, or business needs. Updates will be communicated via the software interface or official project communication channels. Continued use of the software after such notice constitutes acceptance of the updated terms. 45 | 46 | ------------------------------------------------------------------------------- 47 | Except as explicitly modified by these Supplemental Terms, all other rights and obligations regarding your use of TensorBlock Forge are governed by the Apache License, Version 2.0. 48 | 49 | You may access the full Apache 2.0 license text at: 50 | http://www.apache.org/licenses/LICENSE-2.0 51 | 52 | For commercial licensing inquiries, please contact the TensorBlock team at: 53 | contact@tensorblock.co 54 | ------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /app/services/providers/fireworks_adapter.py: -------------------------------------------------------------------------------- 1 | from .openai_adapter import OpenAIAdapter 2 | from typing import Any 3 | 4 | 5 | class FireworksAdapter(OpenAIAdapter): 6 | """Adapter for Fireworks API""" 7 | 8 | FIREWORKS_MODEL_MAPPING = { 9 | "Llama 4 Maverick Instruct (Basic)": "accounts/fireworks/models/llama4-maverick-instruct-basic", 10 | "Llama 4 Scout Instruct (Basic)": "accounts/fireworks/models/llama4-scout-instruct-basic", 11 | "Llama 3.1 405B Instruct": "accounts/fireworks/models/llama-v3p1-405b-instruct", 12 | "Llama 3.1 8B Instruct": "accounts/fireworks/models/llama-v3p1-8b-instruct", 13 | "Llama 3.3 70B Instruct": "accounts/fireworks/models/llama-v3p3-70b-instruct", 14 | "Llama 3.2 90B Vision Instruct": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct", 15 | "DeepSeek V3": "accounts/fireworks/models/deepseek-v3", 16 | "DeepSeek R1 (Fast)": "accounts/fireworks/models/deepseek-r1", 17 | "DeepSeek R1 (Basic)": "accounts/fireworks/models/deepseek-r1-basic", 18 | "Deepseek V3 03-24": "accounts/fireworks/models/deepseek-v3-0324", 19 | "Qwen Qwq 32b Preview": "accounts/fireworks/models/qwen-qwq-32b-preview", 20 | "Phi 3.5 Vision Instruct": "accounts/fireworks/models/phi-3-vision-128k-instruct", 21 | "Firesearch Ocr V6": "accounts/fireworks/models/firesearch-ocr-v6", 22 | "Yi-Large": "accounts/yi-01-ai/models/yi-large", 23 | "Llama V3p1 405b Instruct Long": "accounts/fireworks/models/llama-v3p1-405b-instruct-long", 24 | "Llama Guard 3 8b": "accounts/fireworks/models/llama-guard-3-8b", 25 | "Dobby-Unhinged-Llama-3.3-70B": "accounts/sentientfoundation/models/dobby-unhinged-llama-3-3-70b-new", 26 | "Mixtral MoE 8x22B Instruct": "accounts/fireworks/models/mixtral-8x22b-instruct", 27 | "Qwen2.5 72B Instruct": "accounts/fireworks/models/qwen2p5-72b-instruct", 28 | "QwQ-32B": "accounts/fireworks/models/qwq-32b", 29 | "Qwen2 VL 72B Instruct": "accounts/fireworks/models/qwen2-vl-72b-instruct", 30 | } 31 | 32 | def __init__( 33 | self, 34 | provider_name: str, 35 | base_url: str, 36 | config: dict[str, Any] | None = None, 37 | ): 38 | self._base_url = base_url 39 | super().__init__(provider_name, base_url, config=config) 40 | 41 | 
async def list_models(self, api_key: str) -> list[str]: 42 | """List all models (verbosely) supported by the provider""" 43 | # Check cache first 44 | cached_models = self.get_cached_models(api_key, self._base_url) 45 | if cached_models is not None: 46 | return cached_models 47 | 48 | # If not in cache, make API call 49 | # headers = { 50 | # "Authorization": f"Bearer {api_key}", 51 | # "Content-Type": "application/json", 52 | # } 53 | # fireworks models requires a account id, which we use the official account fireworks 54 | # https://docs.fireworks.ai/api-reference/list-models 55 | # And fireworks inference api and list models api doesn't share the same base url 56 | # url = "https://api.fireworks.ai/v1/accounts/fireworks/models" 57 | 58 | # async with ( 59 | # aiohttp.ClientSession() as session, 60 | # session.get(url, headers=headers, params={"pageSize": 200}) as response, 61 | # ): 62 | # if response.status != HTTPStatus.OK: 63 | # error_text = await response.text() 64 | # raise ValueError(f"Fireworks API error: {error_text}") 65 | # resp = await response.json() 66 | # self.FIREWORKS_MODEL_MAPPING = { 67 | # d["displayName"]: d["name"] for d in resp["models"] 68 | # } 69 | # return [d["name"] for d in resp["models"]] 70 | 71 | # TODO: currently fireworks api doesn't support list all the serverless models 72 | # We simply hardcode the models here 73 | return list(self.FIREWORKS_MODEL_MAPPING.values()) 74 | -------------------------------------------------------------------------------- /tests/mock_testing/mock_openai.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for patching OpenAI's client with our mock client. 3 | This allows tests to run without real API calls. 4 | """ 5 | 6 | import importlib.util 7 | import os 8 | import sys 9 | from unittest.mock import patch 10 | 11 | # Add the parent directory to the path 12 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 13 | 14 | # Import the mock client 15 | from app.services.providers.mock_provider import MockClient 16 | 17 | 18 | # Create a mock openai module that can be used when the real one is not installed 19 | class MockOpenAIModule: 20 | def __init__(self): 21 | self.OpenAI = MockClient 22 | self.AsyncOpenAI = MockClient 23 | 24 | 25 | # Check if openai is installed 26 | if importlib.util.find_spec("openai") is None: 27 | # Create a mock openai module 28 | mock_openai = MockOpenAIModule() 29 | # Add it to sys.modules so import statements work 30 | sys.modules["openai"] = mock_openai 31 | # Import the mock module 32 | import openai as _mock_openai # noqa: F401 33 | 34 | 35 | def enable_mock_openai(): 36 | """ 37 | Enable mocking of OpenAI's client for testing. 38 | This function returns a context manager that can be used with 'with'. 
39 | 40 | Example: 41 | with enable_mock_openai(): 42 | client = openai.OpenAI(api_key="fake-key") 43 | # All calls to client will use the mock implementation 44 | """ 45 | # Create patchers 46 | openai_client_patcher = patch("openai.OpenAI", return_value=MockClient()) 47 | 48 | # Start the patchers 49 | mock_openai_client = openai_client_patcher.start() 50 | 51 | try: 52 | yield mock_openai_client 53 | finally: 54 | # Stop the patchers 55 | openai_client_patcher.stop() 56 | 57 | 58 | def patch_with_mock(): 59 | """Decorator to patch OpenAI with mock client for a test function""" 60 | 61 | def decorator(func): 62 | async def wrapper(*args, **kwargs): 63 | # Create patchers 64 | openai_client_patcher = patch("openai.OpenAI", return_value=MockClient()) 65 | async_openai_client_patcher = patch( 66 | "openai.AsyncOpenAI", return_value=MockClient() 67 | ) 68 | 69 | # Start the patchers 70 | openai_client_patcher.start() 71 | async_openai_client_patcher.start() 72 | 73 | try: 74 | # Run the function 75 | result = await func(*args, **kwargs) 76 | return result 77 | finally: 78 | # Stop the patchers 79 | openai_client_patcher.stop() 80 | async_openai_client_patcher.stop() 81 | 82 | return wrapper 83 | 84 | return decorator 85 | 86 | 87 | class MockPatch: 88 | """ 89 | Class-based helper for patching OpenAI in tests. 90 | This can be used as a context manager or standalone. 91 | """ 92 | 93 | def __init__(self): 94 | self.openai_client_patcher = patch("openai.OpenAI", return_value=MockClient()) 95 | self.async_openai_client_patcher = patch( 96 | "openai.AsyncOpenAI", return_value=MockClient() 97 | ) 98 | self.patched = False 99 | 100 | def start(self): 101 | """Start patching OpenAI""" 102 | if not self.patched: 103 | self.openai_client_patcher.start() 104 | self.async_openai_client_patcher.start() 105 | self.patched = True 106 | return self 107 | 108 | def stop(self): 109 | """Stop patching OpenAI""" 110 | if self.patched: 111 | self.openai_client_patcher.stop() 112 | self.async_openai_client_patcher.stop() 113 | self.patched = False 114 | 115 | def __enter__(self): 116 | """Enter context manager""" 117 | return self.start() 118 | 119 | def __exit__(self, exc_type, exc_val, exc_tb): 120 | """Exit context manager""" 121 | self.stop() 122 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Forge Integration Tests 2 | 3 | This directory contains tests for Forge, including unit tests and integration tests. 4 | 5 | ## Integration Tests 6 | 7 | The main integration test script is: 8 | 9 | - `integration_test.py` - Supports both local development with real API calls and mock mode for CI/CD environments 10 | 11 | ### Prerequisites 12 | 13 | Before running integration tests, make sure: 14 | 15 | 1. The Forge server is running (`python run.py` from the project root) 16 | 2. Required Python packages are installed: 17 | ``` 18 | pip install requests python-dotenv 19 | ``` 20 | 21 | ### Running Integration Tests 22 | 23 | You can run the integration tests in different modes: 24 | 25 | #### With Real API Calls 26 | 27 | For local development with actual API calls to external providers: 28 | 29 | ```bash 30 | # Run the full integration test (requires API keys) 31 | python tests/integration_test.py 32 | ``` 33 | 34 | This test will register a user, add provider keys, and test completions with real API calls. 
You'll need: 35 | 36 | - OpenAI API key (set as OPENAI_API_KEY in .env) 37 | - (Optional) Anthropic API key (set as ANTHROPIC_API_KEY in .env) 38 | 39 | #### With Mocked Responses 40 | 41 | For CI environments or when you don't want to make real API calls: 42 | 43 | ```bash 44 | # Use CI testing mode 45 | CI_TESTING=true python tests/integration_test.py 46 | 47 | # Or use the SKIP_API_CALLS flag (same effect) 48 | SKIP_API_CALLS=true python tests/integration_test.py 49 | ``` 50 | 51 | In mock mode: 52 | - External API calls are replaced with mock responses 53 | - The server connection is still verified 54 | - User registration and management features are tested 55 | - API interactions use the mock provider to simulate responses 56 | - The mock provider returns predefined responses for testing 57 | 58 | ### Mock Provider 59 | 60 | The integration test uses the mock provider from `app/services/providers/mock_provider.py` when running in mock mode. This provider: 61 | 62 | - Simulates API responses without making actual API calls 63 | - Provides mock models similar to those from real providers 64 | - Returns consistent, predictable responses for testing 65 | - Can be used with the `CI_TESTING=true` or `SKIP_API_CALLS=true` flag 66 | 67 | ### GitHub Actions Integration 68 | 69 | The test is automatically run in GitHub Actions workflows defined in `.github/workflows/tests.yml`. The workflow: 70 | 71 | 1. Sets up a test environment 72 | 2. Starts the Forge server 73 | 3. Runs the integration tests in CI mode 74 | 4. Ensures no actual API calls are made to external services 75 | 76 | ## Unit Tests 77 | 78 | Run individual unit tests with: 79 | 80 | ```bash 81 | python -m unittest tests/test_security.py 82 | python -m unittest tests/test_provider_service.py 83 | ``` 84 | 85 | Run all unit tests with: 86 | 87 | ```bash 88 | python -m unittest discover tests 89 | ``` 90 | 91 | ## Cache Tests 92 | 93 | The cache test directory contains tests for both synchronous and asynchronous caching functionality: 94 | 95 | - `test_sync_cache.py` - Tests the synchronous in-memory caching with ProviderService instances 96 | - `test_async_cache.py` - Tests the async-compatible cache implementation for future distributed caching 97 | 98 | ### Running Cache Tests 99 | 100 | ```bash 101 | # Run synchronous cache tests 102 | python tests/cache/test_sync_cache.py 103 | 104 | # Run asynchronous cache tests 105 | python tests/cache/test_async_cache.py 106 | ``` 107 | 108 | These tests verify: 109 | - Cache hit/miss behavior 110 | - Performance improvements from caching 111 | - Proper instance reuse with singleton pattern 112 | - AsyncCache compatibility with asyncio patterns 113 | 114 | The async tests are important for validating Forge's readiness for distributed caching solutions like Redis or AWS ElasticCache, as outlined in `docs/DISTRIBUTED_CACHE_MIGRATION.md`. 
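For illustration, the hit/miss and timing behaviour these tests exercise looks roughly like the sketch below. This is a minimal, self-contained example: the `AsyncCache` class here is only a stand-in with an assumed async `get`/`set` interface, not the actual implementation shipped in `app`.

```python
# Hypothetical sketch -- the real AsyncCache implementation in Forge may differ.
import asyncio
import time


class AsyncCache:
    """Stand-in cache with the assumed async get/set interface."""

    def __init__(self):
        self._store = {}

    async def get(self, key):
        return self._store.get(key)

    async def set(self, key, value, ttl=None):
        self._store[key] = value


async def main():
    cache = AsyncCache()

    # Cold path: cache miss, so we "fetch" from the provider and populate the cache.
    start = time.perf_counter()
    models = await cache.get("openai:models")
    if models is None:
        await asyncio.sleep(0.1)  # simulate a slow provider call
        models = ["gpt-3.5-turbo", "gpt-4"]
        await cache.set("openai:models", models, ttl=300)
    cold = time.perf_counter() - start

    # Warm path: cache hit, no provider call needed.
    start = time.perf_counter()
    assert await cache.get("openai:models") == models
    warm = time.perf_counter() - start

    print(f"cold lookup: {cold:.3f}s, warm lookup: {warm:.6f}s")


asyncio.run(main())
```

The warm lookup should be dramatically faster than the cold one, which is the kind of speedup the performance assertions in the cache tests check for.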
115 | 116 | ## Test Coverage Report 117 | 118 | Generate a test coverage report with: 119 | 120 | ```bash 121 | pytest tests/test_*.py --cov=app --cov-report=xml 122 | ``` 123 | -------------------------------------------------------------------------------- /app/api/routes/auth.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from typing import Any 3 | 4 | from fastapi import APIRouter, Depends, HTTPException, status 5 | from fastapi.security import OAuth2PasswordRequestForm 6 | from sqlalchemy import select 7 | from sqlalchemy.ext.asyncio import AsyncSession 8 | 9 | from app.api.routes.users import create_user as create_user_endpoint_logic 10 | from app.api.schemas.user import Token, User, UserCreate 11 | from app.core.database import get_async_db 12 | from app.core.logger import get_logger 13 | from app.core.security import ( 14 | ACCESS_TOKEN_EXPIRE_MINUTES, 15 | create_access_token, 16 | verify_password, 17 | ) 18 | from app.models.user import User as UserModel 19 | 20 | logger = get_logger(name="auth") 21 | 22 | router = APIRouter() 23 | 24 | 25 | @router.post("/register", response_model=User) 26 | async def register( 27 | user_in: UserCreate, 28 | db: AsyncSession = Depends(get_async_db), 29 | ) -> Any: 30 | """ 31 | Register new user. This will create the user but will not automatically create a Forge API key. 32 | Users should use the /api-keys/ endpoint to create their keys after registration. 33 | """ 34 | # Call the user creation logic from users.py 35 | # This handles checks for existing email/username and password hashing. 36 | try: 37 | db_user = await create_user_endpoint_logic(user_in=user_in, db=db) 38 | except HTTPException as e: # Propagate HTTPExceptions (like 400 for existing user) 39 | raise e 40 | except Exception as e: # Catch any other unexpected errors during user creation 41 | # Log this error e 42 | logger.error( 43 | f"Unexpected error during create_user_endpoint_logic call: {e}" 44 | ) # Added more logging 45 | raise HTTPException( 46 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 47 | detail="An unexpected error occurred during user registration.", 48 | ) 49 | 50 | # Prepare the response. 51 | # create_user_endpoint_logic returns a UserModel instance. 52 | # The User Pydantic model has from_attributes = True. 53 | try: 54 | # For Pydantic v2+ (which is likely if using FastAPI > 0.100.0) 55 | pydantic_user = User.model_validate(db_user) 56 | # If using Pydantic v1, you would use: 57 | # pydantic_user = User.from_orm(db_user) 58 | except Exception as e_pydantic: 59 | # Log this validation error to understand what went wrong if it fails 60 | logger.error( 61 | f"Error during Pydantic model_validate in /auth/register: {e_pydantic}" 62 | ) 63 | logger.error( 64 | f"SQLAlchemy User object was: {db_user.__dict__ if db_user else 'None'}" 65 | ) 66 | raise HTTPException( 67 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 68 | detail="Error processing user data after creation.", 69 | ) 70 | 71 | pydantic_user.forge_api_keys = [] # Explicitly set to empty list as no key is auto-generated 72 | 73 | return pydantic_user 74 | 75 | 76 | @router.post("/token", response_model=Token) 77 | async def login_for_access_token( 78 | db: AsyncSession = Depends(get_async_db), 79 | form_data: OAuth2PasswordRequestForm = Depends() 80 | ) -> Any: 81 | """ 82 | Get an access token for future API requests. 
83 | """ 84 | result = await db.execute( 85 | select(UserModel).filter(UserModel.username == form_data.username) 86 | ) 87 | user = result.scalar_one_or_none() 88 | 89 | if not user or not verify_password(form_data.password, user.hashed_password): 90 | raise HTTPException( 91 | status_code=status.HTTP_401_UNAUTHORIZED, 92 | detail="Incorrect username or password", 93 | headers={"WWW-Authenticate": "Bearer"}, 94 | ) 95 | 96 | access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) 97 | access_token = create_access_token( 98 | data={"sub": user.username}, expires_delta=access_token_expires 99 | ) 100 | 101 | return {"access_token": access_token, "token_type": "bearer"} 102 | -------------------------------------------------------------------------------- /tests/mock_testing/examples/test_with_mocks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Example tests that demonstrate using the mock client in different scenarios. 4 | """ 5 | 6 | import asyncio 7 | import os 8 | import sys 9 | import unittest 10 | 11 | # Add the parent directory to the path 12 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) 13 | 14 | # Import our mock utilities 15 | from tests.mock_testing.mock_openai import MockPatch, patch_with_mock 16 | 17 | 18 | # This would be your application code that uses OpenAI 19 | class MyAppService: 20 | """Example application service that uses OpenAI""" 21 | 22 | def __init__(self, api_key=None): 23 | """Initialize with an API key""" 24 | self.api_key = api_key or os.getenv("OPENAI_API_KEY") 25 | 26 | # In a real app, you'd import OpenAI here 27 | # But we'll do it in the methods to allow for patching 28 | 29 | async def generate_response(self, user_message: str) -> str: 30 | """Generate a response to a user message""" 31 | import openai 32 | 33 | client = openai.OpenAI(api_key=self.api_key) 34 | 35 | response = await client.chat_completions_create( 36 | model="mock-only-gpt-3.5-turbo", 37 | messages=[ 38 | {"role": "system", "content": "You are a helpful assistant."}, 39 | {"role": "user", "content": user_message}, 40 | ], 41 | ) 42 | 43 | return response.choices[0].message.content 44 | 45 | async def get_available_models(self) -> list: 46 | """Get a list of available models""" 47 | import openai 48 | 49 | client = openai.OpenAI(api_key=self.api_key) 50 | 51 | models = await client.models_list() 52 | # Handle both object-style and dict-style responses 53 | if hasattr(models, "data") and hasattr(models.data[0], "id"): 54 | return [model.id for model in models.data] 55 | else: 56 | return [model["id"] for model in models.data] 57 | 58 | 59 | # Now let's write tests for our application using the mock client 60 | class AsyncTestCase(unittest.TestCase): 61 | """Base class for async test cases""" 62 | 63 | def run_async(self, coro): 64 | """Run a coroutine in an event loop""" 65 | return asyncio.run(coro) 66 | 67 | 68 | class TestMyAppWithMocks(AsyncTestCase): 69 | """Test the MyApp class with mocked OpenAI client""" 70 | 71 | def setUp(self): 72 | """Set up the test""" 73 | self.service = MyAppService(api_key="test-key") 74 | 75 | def test_generate_response(self): 76 | """Test generating a response with the mock client""" 77 | 78 | @patch_with_mock() 79 | async def _async_test(): 80 | response = await self.service.generate_response("Hello, how are you?") 81 | 82 | # The mock response will be predictable 83 | self.assertIsNotNone(response) 84 | self.assertIn("You 
asked", response) # Our mock adds this prefix 85 | return response 86 | 87 | self.run_async(_async_test()) 88 | 89 | def test_get_models(self): 90 | """Test getting available models with the mock client""" 91 | 92 | @patch_with_mock() 93 | async def _async_test(): 94 | models = await self.service.get_available_models() 95 | 96 | # The mock returns specific models 97 | self.assertIn("mock-gpt-3.5-turbo", models) 98 | self.assertIn("mock-gpt-4", models) 99 | return models 100 | 101 | self.run_async(_async_test()) 102 | 103 | def test_with_context_manager(self): 104 | """Test using the mock client with a context manager""" 105 | 106 | async def _async_test(): 107 | # Using the context manager approach 108 | mock_patch = MockPatch() 109 | with mock_patch: 110 | response = await self.service.generate_response("Tell me about testing") 111 | self.assertIsNotNone(response) 112 | self.assertIn("You asked", response) 113 | return response 114 | 115 | self.run_async(_async_test()) 116 | 117 | 118 | # If we run this directly, run all the tests 119 | if __name__ == "__main__": 120 | unittest.main() 121 | -------------------------------------------------------------------------------- /app/api/schemas/anthropic.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Literal, Optional, Union 2 | from pydantic import BaseModel, Field, field_validator 3 | 4 | from app.core.logger import get_logger 5 | 6 | logger = get_logger(name="anthropic_schemas") 7 | 8 | # Content Block Models 9 | class ContentBlockText(BaseModel): 10 | type: Literal["text"] 11 | text: str 12 | 13 | 14 | class ContentBlockImageSource(BaseModel): 15 | type: str 16 | media_type: str 17 | data: str 18 | 19 | 20 | class ContentBlockImage(BaseModel): 21 | type: Literal["image"] 22 | source: ContentBlockImageSource 23 | 24 | 25 | class ContentBlockToolUse(BaseModel): 26 | type: Literal["tool_use"] 27 | id: str 28 | name: str 29 | input: Dict[str, Any] 30 | 31 | 32 | class ContentBlockToolResult(BaseModel): 33 | type: Literal["tool_result"] 34 | tool_use_id: str 35 | content: Union[str, List[Dict[str, Any]], List[Any]] 36 | is_error: Optional[bool] = None 37 | 38 | 39 | ContentBlock = Union[ 40 | ContentBlockText, ContentBlockImage, ContentBlockToolUse, ContentBlockToolResult 41 | ] 42 | 43 | 44 | # System Content 45 | class SystemContent(BaseModel): 46 | type: Literal["text"] 47 | text: str 48 | 49 | 50 | # Message Model 51 | class AnthropicMessage(BaseModel): 52 | role: Literal["user", "assistant"] 53 | content: Union[str, List[ContentBlock]] 54 | 55 | 56 | # Tool Models 57 | class Tool(BaseModel): 58 | name: str 59 | description: Optional[str] = None 60 | input_schema: Dict[str, Any] = Field(..., alias="input_schema") 61 | 62 | 63 | class ToolChoice(BaseModel): 64 | type: Literal["auto", "any", "tool"] 65 | name: Optional[str] = None 66 | 67 | 68 | # Main Request Model 69 | class AnthropicMessagesRequest(BaseModel): 70 | model: str 71 | max_tokens: int 72 | messages: List[AnthropicMessage] 73 | system: Optional[Union[str, List[SystemContent]]] = None 74 | stop_sequences: Optional[List[str]] = None 75 | stream: Optional[bool] = False 76 | temperature: Optional[float] = None 77 | top_p: Optional[float] = None 78 | top_k: Optional[int] = None 79 | metadata: Optional[Dict[str, Any]] = None 80 | tools: Optional[List[Tool]] = None 81 | tool_choice: Optional[ToolChoice] = None 82 | 83 | @field_validator("top_k") 84 | def check_top_k(cls, v: Optional[int]) -> Optional[int]: 85 | 
if v is not None: 86 | logger.warning( 87 | f"Parameter 'top_k' provided by client but is not directly supported by the OpenAI Chat Completions API and will be ignored. Value: {v}" 88 | ) 89 | return v 90 | 91 | 92 | # Token Count Request/Response 93 | class TokenCountRequest(BaseModel): 94 | model: str 95 | messages: List[AnthropicMessage] 96 | system: Optional[Union[str, List[SystemContent]]] = None 97 | tools: Optional[List[Tool]] = None 98 | 99 | 100 | class TokenCountResponse(BaseModel): 101 | input_tokens: int 102 | 103 | 104 | # Usage Model 105 | class Usage(BaseModel): 106 | input_tokens: int 107 | output_tokens: int 108 | 109 | 110 | # Error Models 111 | class AnthropicErrorType: 112 | INVALID_REQUEST = "invalid_request_error" 113 | AUTHENTICATION = "authentication_error" 114 | PERMISSION = "permission_error" 115 | NOT_FOUND = "not_found_error" 116 | RATE_LIMIT = "rate_limit_error" 117 | API_ERROR = "api_error" 118 | OVERLOADED = "overloaded_error" 119 | REQUEST_TOO_LARGE = "request_too_large_error" 120 | 121 | 122 | class AnthropicErrorDetail(BaseModel): 123 | type: str 124 | message: str 125 | provider: Optional[str] = None 126 | provider_message: Optional[str] = None 127 | provider_code: Optional[Union[str, int]] = None 128 | 129 | 130 | class AnthropicErrorResponse(BaseModel): 131 | type: Literal["error"] = "error" 132 | error: AnthropicErrorDetail 133 | 134 | 135 | # Response Model 136 | class AnthropicMessagesResponse(BaseModel): 137 | id: str 138 | type: Literal["message"] = "message" 139 | role: Literal["assistant"] = "assistant" 140 | model: str 141 | content: List[ContentBlock] 142 | stop_reason: Optional[ 143 | Literal["end_turn", "max_tokens", "stop_sequence", "tool_use", "error"] 144 | ] = None 145 | stop_sequence: Optional[str] = None 146 | usage: Usage -------------------------------------------------------------------------------- /app/api/routes/stats.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | from typing import Any 3 | 4 | from fastapi import APIRouter, Depends, HTTPException, Query 5 | from sqlalchemy.ext.asyncio import AsyncSession 6 | 7 | from app.api.dependencies import ( 8 | get_current_active_user_from_clerk, 9 | get_current_user, 10 | get_async_db, 11 | get_user_by_api_key, 12 | ) 13 | from app.models.user import User 14 | from app.services.usage_stats_service import UsageStatsService 15 | 16 | router = APIRouter() 17 | 18 | 19 | # http://localhost:8000/stats/?start_date=2025-04-09&end_date=2025-04-09 20 | # http://localhost:8000/stats/?model=gpt-3.5-turbo 21 | @router.get("/", response_model=list[dict[str, Any]]) 22 | async def get_user_stats( 23 | current_user: User = Depends(get_user_by_api_key), 24 | provider: str | None = Query(None, description="Filter stats by provider name"), 25 | model: str | None = Query(None, description="Filter stats by model name"), 26 | start_date: date | None = Query( 27 | None, description="Start date for filtering (YYYY-MM-DD)" 28 | ), 29 | end_date: date | None = Query( 30 | None, description="End date for filtering (YYYY-MM-DD)" 31 | ), 32 | db: AsyncSession = Depends(get_async_db), 33 | ): 34 | """ 35 | Get aggregated usage statistics for the current user, queried from request logs. 36 | 37 | Allows filtering by provider, model, and date range (inclusive). 
38 | """ 39 | # Note: Service layer now handles aggregation and filtering 40 | # We pass the query parameters directly to the service method 41 | stats = await UsageStatsService.get_user_stats( 42 | db=db, 43 | user_id=current_user.id, 44 | provider=provider, 45 | model=model, 46 | start_date=start_date, 47 | end_date=end_date, 48 | ) 49 | return stats 50 | 51 | 52 | # http://localhost:8000/stats/clerk/?start_date=2025-04-09&end_date=2025-04-09 53 | # http://localhost:8000/stats/clerk/?model=gpt-3.5-turbo 54 | @router.get("/clerk", response_model=list[dict[str, Any]]) 55 | async def get_user_stats_clerk( 56 | current_user: User = Depends(get_current_active_user_from_clerk), 57 | provider: str | None = Query(None, description="Filter stats by provider name"), 58 | model: str | None = Query(None, description="Filter stats by model name"), 59 | start_date: date | None = Query( 60 | None, description="Start date for filtering (YYYY-MM-DD)" 61 | ), 62 | end_date: date | None = Query( 63 | None, description="End date for filtering (YYYY-MM-DD)" 64 | ), 65 | db: AsyncSession = Depends(get_async_db), 66 | ): 67 | """ 68 | Get aggregated usage statistics for the current user, queried from request logs. 69 | 70 | Allows filtering by provider, model, and date range (inclusive). 71 | """ 72 | # Note: Service layer now handles aggregation and filtering 73 | # We pass the query parameters directly to the service method 74 | stats = await UsageStatsService.get_user_stats( 75 | db=db, 76 | user_id=current_user.id, 77 | provider=provider, 78 | model=model, 79 | start_date=start_date, 80 | end_date=end_date, 81 | ) 82 | return stats 83 | 84 | 85 | @router.get("/admin", response_model=list[dict[str, Any]]) 86 | async def get_all_stats( 87 | current_user: User = Depends(get_current_user), 88 | provider: str | None = Query(None, description="Filter stats by provider name"), 89 | model: str | None = Query(None, description="Filter stats by model name"), 90 | start_date: date | None = Query( 91 | None, description="Start date for filtering (YYYY-MM-DD)" 92 | ), 93 | end_date: date | None = Query( 94 | None, description="End date for filtering (YYYY-MM-DD)" 95 | ), 96 | db: AsyncSession = Depends(get_async_db), 97 | ): 98 | """ 99 | Get aggregated usage statistics for all users, queried from request logs. 100 | 101 | Only accessible to admin users. Allows filtering. 
102 | """ 103 | # Check if user is an admin 104 | if not getattr(current_user, "is_admin", False): 105 | raise HTTPException( 106 | status_code=403, detail="Not authorized to access admin statistics" 107 | ) 108 | 109 | stats = await UsageStatsService.get_all_stats( 110 | db=db, provider=provider, model=model, start_date=start_date, end_date=end_date 111 | ) 112 | return stats 113 | -------------------------------------------------------------------------------- /app/exceptions/exceptions.py: -------------------------------------------------------------------------------- 1 | class BaseForgeException(Exception): 2 | pass 3 | 4 | class InvalidProviderException(BaseForgeException): 5 | """Exception raised when a provider is invalid.""" 6 | 7 | def __init__(self, identifier: str): 8 | self.identifier = identifier 9 | super().__init__(f"Provider {identifier} is invalid or failed to extract provider info from model_id {identifier}") 10 | 11 | 12 | class ProviderAuthenticationException(BaseForgeException): 13 | """Exception raised when a provider authentication fails.""" 14 | 15 | def __init__(self, provider_name: str, error: Exception): 16 | self.provider_name = provider_name 17 | self.error = error 18 | super().__init__(f"Provider {provider_name} authentication failed: {error}") 19 | 20 | 21 | class BaseInvalidProviderSetupException(BaseForgeException): 22 | """Exception raised when a provider setup is invalid.""" 23 | 24 | def __init__(self, provider_name: str, error: Exception): 25 | self.provider_name = provider_name 26 | self.error = error 27 | super().__init__(f"Provider {provider_name} setup is invalid: {error}") 28 | 29 | class InvalidProviderConfigException(BaseInvalidProviderSetupException): 30 | """Exception raised when a provider config is invalid.""" 31 | 32 | def __init__(self, provider_name: str, error: Exception): 33 | super().__init__(provider_name, error) 34 | 35 | class InvalidProviderAPIKeyException(BaseInvalidProviderSetupException): 36 | """Exception raised when a provider API key is invalid.""" 37 | 38 | def __init__(self, provider_name: str, error: Exception): 39 | super().__init__(provider_name, error) 40 | 41 | class ProviderAPIException(BaseForgeException): 42 | """Exception raised when a provider API error occurs.""" 43 | 44 | def __init__(self, provider_name: str, error_code: int, error_message: str): 45 | """Initialize the exception and persist error details for downstream handling. 46 | 47 | Many parts of the codebase (e.g. the Claude Code routes) rely on the 48 | presence of the ``error_code`` and ``error_message`` attributes to 49 | construct a well-formed error response. Without setting these instance 50 | attributes an ``AttributeError`` is raised when the exception is caught 51 | and introspected, masking the original provider failure. Persisting the 52 | values here guarantees that the original error information is available 53 | to any error-handling middleware. 54 | """ 55 | self.provider_name = provider_name 56 | self.error_code = error_code 57 | self.error_message = error_message 58 | 59 | # Compose the base exception message for logging / debugging purposes. 
60 | super().__init__( 61 | f"Provider {provider_name} API error: {error_code} {error_message}" 62 | ) 63 | 64 | 65 | class BaseInvalidRequestException(BaseForgeException): 66 | """Exception raised when a request is invalid.""" 67 | 68 | def __init__(self, provider_name: str, error: Exception): 69 | self.provider_name = provider_name 70 | self.error = error 71 | super().__init__(f"Provider {provider_name} request is invalid: {error}") 72 | 73 | class InvalidCompletionRequestException(BaseInvalidRequestException): 74 | """Exception raised when a completion request is invalid.""" 75 | 76 | def __init__(self, provider_name: str, error: Exception): 77 | self.provider_name = provider_name 78 | self.error = error 79 | super().__init__(self.provider_name, self.error) 80 | 81 | class InvalidEmbeddingsRequestException(BaseInvalidRequestException): 82 | """Exception raised when a embeddings request is invalid.""" 83 | 84 | def __init__(self, provider_name: str, error: Exception): 85 | self.provider_name = provider_name 86 | self.error = error 87 | super().__init__(self.provider_name, self.error) 88 | 89 | class BaseInvalidForgeKeyException(BaseForgeException): 90 | """Exception raised when a Forge key is invalid.""" 91 | 92 | def __init__(self, error: Exception): 93 | self.error = error 94 | super().__init__(f"Forge key is invalid: {error}") 95 | 96 | 97 | class InvalidForgeKeyException(BaseInvalidForgeKeyException): 98 | """Exception raised when a Forge key is invalid.""" 99 | def __init__(self, error: Exception): 100 | super().__init__(error) -------------------------------------------------------------------------------- /app/api/routes/health.py: -------------------------------------------------------------------------------- 1 | """ 2 | Health check and monitoring endpoints for production deployments. 3 | """ 4 | 5 | from fastapi import APIRouter, HTTPException 6 | from sqlalchemy import text 7 | 8 | from app.core.database import get_connection_info, get_db_session 9 | from app.core.logger import get_logger 10 | 11 | logger = get_logger(name="health") 12 | router = APIRouter() 13 | 14 | 15 | @router.get("/health") 16 | async def health_check(): 17 | """ 18 | Basic health check endpoint. 19 | Returns 200 if the service is running. 20 | """ 21 | return {"status": "healthy", "service": "forge"} 22 | 23 | 24 | @router.get("/health/database") 25 | async def database_health_check(): 26 | """ 27 | Database health check endpoint. 28 | Returns detailed information about database connectivity and pool status. 
29 | """ 30 | try: 31 | # Test database connection 32 | async with get_db_session() as session: 33 | result = await session.execute(text("SELECT 1")) 34 | result.scalar() 35 | 36 | # Get connection pool information 37 | pool_info = get_connection_info() 38 | 39 | # Calculate connection usage 40 | sync_pool = pool_info['sync_engine'] 41 | async_pool = pool_info['async_engine'] 42 | 43 | sync_usage = sync_pool['checked_out'] / (pool_info['pool_size'] + pool_info['max_overflow']) * 100 44 | async_usage = async_pool['checked_out'] / (pool_info['pool_size'] + pool_info['max_overflow']) * 100 45 | 46 | return { 47 | "status": "healthy", 48 | "database": "connected", 49 | "connection_pools": { 50 | "sync": { 51 | "checked_out": sync_pool['checked_out'], 52 | "checked_in": sync_pool['checked_in'], 53 | "size": sync_pool['size'], 54 | "usage_percent": round(sync_usage, 1) 55 | }, 56 | "async": { 57 | "checked_out": async_pool['checked_out'], 58 | "checked_in": async_pool['checked_in'], 59 | "size": async_pool['size'], 60 | "usage_percent": round(async_usage, 1) 61 | } 62 | }, 63 | "configuration": { 64 | "pool_size": pool_info['pool_size'], 65 | "max_overflow": pool_info['max_overflow'], 66 | "pool_timeout": pool_info['pool_timeout'], 67 | "pool_recycle": pool_info['pool_recycle'] 68 | } 69 | } 70 | 71 | except Exception as e: 72 | logger.error(f"Database health check failed: {e}") 73 | raise HTTPException( 74 | status_code=503, 75 | detail={ 76 | "status": "unhealthy", 77 | "database": "disconnected", 78 | "error": str(e) 79 | } 80 | ) 81 | 82 | 83 | @router.get("/health/detailed") 84 | async def detailed_health_check(): 85 | """ 86 | Detailed health check including all service components. 87 | """ 88 | try: 89 | # Test database 90 | async with get_db_session() as session: 91 | db_result = await session.execute(text("SELECT version()")) 92 | db_version = db_result.scalar() 93 | 94 | pool_info = get_connection_info() 95 | 96 | return { 97 | "status": "healthy", 98 | "timestamp": "2025-01-21T19:15:00Z", # This would be dynamic in real implementation 99 | "service": "forge", 100 | "version": "0.1.0", 101 | "database": { 102 | "status": "connected", 103 | "version": db_version, 104 | "pool_status": pool_info 105 | }, 106 | "environment": { 107 | "workers": pool_info.get('workers', 'unknown'), 108 | "pool_size": pool_info['pool_size'], 109 | "max_overflow": pool_info['max_overflow'] 110 | } 111 | } 112 | 113 | except Exception as e: 114 | logger.error(f"Detailed health check failed: {e}") 115 | raise HTTPException( 116 | status_code=503, 117 | detail={ 118 | "status": "unhealthy", 119 | "error": str(e), 120 | "timestamp": "2025-01-21T19:15:00Z" 121 | } 122 | ) -------------------------------------------------------------------------------- /alembic/versions/b38aad374524_create_api_request_log_table.py: -------------------------------------------------------------------------------- 1 | """Create api_request_log table 2 | 3 | Revision ID: b38aad374524 4 | Revises: c50fd7be794c 5 | Create Date: 2025-04-08 21:44:00.759874 6 | 7 | """ 8 | 9 | import sqlalchemy as sa 10 | 11 | from alembic import op 12 | 13 | # revision identifiers, used by Alembic. 
14 | revision = "b38aad374524" 15 | down_revision = "c50fd7be794c" 16 | branch_labels = None 17 | depends_on = None 18 | 19 | 20 | def upgrade() -> None: 21 | # ### commands auto generated by Alembic - START ### 22 | op.create_table( 23 | "api_request_log", 24 | sa.Column("id", sa.Integer(), nullable=False), 25 | sa.Column("user_id", sa.Integer(), nullable=True), 26 | sa.Column("provider_name", sa.String(), nullable=False), 27 | sa.Column("model", sa.String(), nullable=False), 28 | sa.Column("endpoint", sa.String(), nullable=False), 29 | sa.Column("request_timestamp", sa.DateTime(), nullable=False), 30 | sa.Column("input_tokens", sa.Integer(), nullable=True), 31 | sa.Column("output_tokens", sa.Integer(), nullable=True), 32 | sa.Column("total_tokens", sa.Integer(), nullable=True), 33 | sa.Column("cost", sa.Float(), nullable=True), 34 | sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="SET NULL"), 35 | sa.PrimaryKeyConstraint("id"), 36 | ) 37 | op.create_index( 38 | "ix_api_request_log_user_time", 39 | "api_request_log", 40 | ["user_id", "request_timestamp"], 41 | unique=False, 42 | ) 43 | op.create_index( 44 | op.f("ix_api_request_log_endpoint"), 45 | "api_request_log", 46 | ["endpoint"], 47 | unique=False, 48 | ) 49 | op.create_index( 50 | op.f("ix_api_request_log_model"), "api_request_log", ["model"], unique=False 51 | ) 52 | op.create_index( 53 | op.f("ix_api_request_log_provider_name"), 54 | "api_request_log", 55 | ["provider_name"], 56 | unique=False, 57 | ) 58 | op.create_index( 59 | op.f("ix_api_request_log_request_timestamp"), 60 | "api_request_log", 61 | ["request_timestamp"], 62 | unique=False, 63 | ) 64 | op.create_index( 65 | op.f("ix_api_request_log_user_id"), "api_request_log", ["user_id"], unique=False 66 | ) 67 | # op.drop_constraint(None, 'usage_stats', type_='foreignkey') # REMOVED 68 | # op.create_foreign_key(None, 'usage_stats', 'users', ['user_id'], ['id'], ondelete='CASCADE') # REMOVED 69 | # op.drop_column('usage_stats', 'completion_tokens') # REMOVED 70 | # op.drop_column('usage_stats', 'timestamp') # REMOVED 71 | # op.drop_column('usage_stats', 'error_count') # REMOVED 72 | # op.drop_column('usage_stats', 'prompt_tokens') # REMOVED 73 | # op.drop_column('usage_stats', 'success_count') # REMOVED 74 | # op.drop_column('usage_stats', 'request_count') # REMOVED 75 | # ### end Alembic commands ### 76 | 77 | 78 | def downgrade() -> None: 79 | # ### commands auto generated by Alembic - START ### 80 | # op.add_column('usage_stats', sa.Column('request_count', sa.INTEGER(), nullable=True)) # REMOVED 81 | # op.add_column('usage_stats', sa.Column('success_count', sa.INTEGER(), nullable=True)) # REMOVED 82 | # op.add_column('usage_stats', sa.Column('prompt_tokens', sa.INTEGER(), nullable=True)) # REMOVED 83 | # op.add_column('usage_stats', sa.Column('error_count', sa.INTEGER(), nullable=True)) # REMOVED 84 | # op.add_column('usage_stats', sa.Column('timestamp', sa.DATETIME(), nullable=True)) # REMOVED 85 | # op.add_column('usage_stats', sa.Column('completion_tokens', sa.INTEGER(), nullable=True)) # REMOVED 86 | # op.drop_constraint(None, 'usage_stats', type_='foreignkey') # REMOVED 87 | # op.create_foreign_key(None, 'usage_stats', 'users', ['user_id'], ['id']) # REMOVED 88 | op.drop_index(op.f("ix_api_request_log_user_id"), table_name="api_request_log") 89 | op.drop_index( 90 | op.f("ix_api_request_log_request_timestamp"), table_name="api_request_log" 91 | ) 92 | op.drop_index( 93 | op.f("ix_api_request_log_provider_name"), table_name="api_request_log" 94 | ) 95 | 
op.drop_index(op.f("ix_api_request_log_model"), table_name="api_request_log") 96 | op.drop_index(op.f("ix_api_request_log_endpoint"), table_name="api_request_log") 97 | op.drop_index("ix_api_request_log_user_time", table_name="api_request_log") 98 | op.drop_table("api_request_log") 99 | # ### end Alembic commands ### 100 | -------------------------------------------------------------------------------- /app/core/database.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dotenv import load_dotenv 4 | from contextlib import asynccontextmanager 5 | from sqlalchemy import create_engine 6 | from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession 7 | from sqlalchemy.orm import declarative_base, sessionmaker 8 | 9 | load_dotenv() 10 | 11 | # Production-optimized connection pool settings 12 | # With 10 Gunicorn workers, this allows max 60 connections total (10 workers × 3 pool_size × 2 engines) 13 | # Plus 40 overflow connections (10 workers × 2 max_overflow × 2 engines) = 100 max connections 14 | POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "3")) # Reduced from 5 to 3 15 | MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "2")) # Reduced from 10 to 2 16 | MAX_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30")) 17 | POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "1800")) # 30 minutes 18 | POOL_PRE_PING = os.getenv("DB_POOL_PRE_PING", "true").lower() == "true" 19 | 20 | SQLALCHEMY_DATABASE_URL = os.getenv("DATABASE_URL") 21 | if not SQLALCHEMY_DATABASE_URL: 22 | raise ValueError("DATABASE_URL environment variable is not set") 23 | 24 | # Sync engine and session 25 | engine = create_engine( 26 | SQLALCHEMY_DATABASE_URL, 27 | pool_size=POOL_SIZE, 28 | max_overflow=MAX_OVERFLOW, 29 | pool_timeout=MAX_TIMEOUT, 30 | pool_recycle=POOL_RECYCLE, 31 | pool_pre_ping=POOL_PRE_PING, # Enables connection health checks 32 | echo=False, 33 | ) 34 | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) 35 | 36 | 37 | # Sync dependency 38 | def get_db(): 39 | db = SessionLocal() 40 | try: 41 | yield db 42 | finally: 43 | db.close() 44 | 45 | 46 | # Async engine and session (new) 47 | # Convert the DATABASE_URL to async format if it's using psycopg2 48 | ASYNC_DATABASE_URL = SQLALCHEMY_DATABASE_URL 49 | if SQLALCHEMY_DATABASE_URL.startswith("postgresql://"): 50 | ASYNC_DATABASE_URL = SQLALCHEMY_DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://") 51 | elif SQLALCHEMY_DATABASE_URL.startswith("postgresql+psycopg2://"): 52 | ASYNC_DATABASE_URL = SQLALCHEMY_DATABASE_URL.replace("postgresql+psycopg2://", "postgresql+asyncpg://") 53 | 54 | async_engine = create_async_engine( 55 | ASYNC_DATABASE_URL, 56 | pool_size=POOL_SIZE, 57 | max_overflow=MAX_OVERFLOW, 58 | pool_timeout=MAX_TIMEOUT, 59 | pool_recycle=POOL_RECYCLE, 60 | pool_pre_ping=POOL_PRE_PING, # Enables connection health checks 61 | echo=False, 62 | ) 63 | 64 | AsyncSessionLocal = async_sessionmaker( 65 | bind=async_engine, 66 | class_=AsyncSession, 67 | expire_on_commit=False, 68 | autoflush=False, 69 | ) 70 | 71 | Base = declarative_base() 72 | 73 | 74 | # Async dependency 75 | async def get_async_db(): 76 | async with AsyncSessionLocal() as session: 77 | try: 78 | yield session 79 | except Exception: 80 | # Rollback on any exception, but handle potential session state issues 81 | try: 82 | await session.rollback() 83 | except Exception: 84 | # If rollback fails (e.g., session already closed), ignore it 85 | # The context manager will handle 
session cleanup 86 | pass 87 | raise 88 | 89 | 90 | @asynccontextmanager 91 | async def get_db_session(): 92 | """Async context manager for database sessions""" 93 | async with AsyncSessionLocal() as session: 94 | try: 95 | yield session 96 | except Exception: 97 | # Rollback on any exception, but handle potential session state issues 98 | try: 99 | await session.rollback() 100 | except Exception: 101 | # If rollback fails (e.g., session already closed), ignore it 102 | # The context manager will handle session cleanup 103 | pass 104 | raise 105 | 106 | 107 | def get_connection_info(): 108 | """Get current connection pool information for monitoring""" 109 | return { 110 | "pool_size": POOL_SIZE, 111 | "max_overflow": MAX_OVERFLOW, 112 | "pool_timeout": MAX_TIMEOUT, 113 | "pool_recycle": POOL_RECYCLE, 114 | "sync_engine": { 115 | "pool": engine.pool, 116 | "checked_out": engine.pool.checkedout(), 117 | "checked_in": engine.pool.checkedin(), 118 | "size": engine.pool.size(), 119 | }, 120 | "async_engine": { 121 | "pool": async_engine.pool, 122 | "checked_out": async_engine.pool.checkedout(), 123 | "checked_in": async_engine.pool.checkedin(), 124 | "size": async_engine.pool.size(), 125 | } 126 | } -------------------------------------------------------------------------------- /tests/test_token_counting_stream.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import asyncio 4 | import json 5 | from http import HTTPStatus 6 | 7 | import aiohttp 8 | 9 | 10 | async def test_token_counting_stream(api_key: str, model: str): 11 | """ 12 | Test the token counting in streaming mode. 13 | """ 14 | url = "http://localhost:8000/v1/chat/completions" 15 | headers = {"Content-Type": "application/json", "X-API-Key": api_key} 16 | 17 | payload = { 18 | "model": model, 19 | "messages": [ 20 | {"role": "system", "content": "You are a helpful assistant."}, 21 | { 22 | "role": "user", 23 | "content": "Write a short story about a robot learning to count. 
Keep it to 3 paragraphs.", 24 | }, 25 | ], 26 | "stream": True, 27 | } 28 | 29 | print(f"\n[INFO] Sending streaming request to {url}") 30 | print(f"[INFO] Model: {model}") 31 | 32 | tokens_seen = 0 33 | chunk_count = 0 34 | 35 | # Use single with statement for session and post request 36 | async with ( 37 | aiohttp.ClientSession() as session, 38 | session.post(url, headers=headers, json=payload) as response, 39 | ): 40 | print(f"[INFO] Response status: {response.status}") 41 | 42 | if response.status != HTTPStatus.OK: 43 | error_text = await response.text() 44 | print(f"[ERROR] {error_text}") 45 | return 46 | 47 | async for line_bytes in response.content: 48 | line = line_bytes.decode("utf-8").strip() 49 | 50 | if not line: 51 | continue 52 | 53 | if line.startswith("data:"): 54 | chunk_count += 1 55 | data_str = line[5:].strip() 56 | 57 | if data_str == "[DONE]": 58 | print("\n[INFO] Streaming completed") 59 | print(f"[INFO] Total chunks: {chunk_count}") 60 | break 61 | 62 | try: 63 | data = json.loads(data_str) 64 | 65 | # Extract content from the chunk 66 | content = "" 67 | if ( 68 | "choices" in data 69 | and data["choices"] 70 | and "delta" in data["choices"][0] 71 | ): 72 | content = data["choices"][0]["delta"].get("content", "") 73 | 74 | # Keep track of content length as a proxy for tokens 75 | tokens_seen += len(content) / 4 # Simple approximation 76 | 77 | # Check if we have usage information 78 | if "usage" in data: 79 | usage = data.get("usage", {}) 80 | print(f"\n[TOKEN INFO] Reported usage: {usage}") 81 | 82 | if chunk_count % 10 == 1: 83 | print(f"\n[CHUNK {chunk_count}] Content: '{content}'") 84 | print( 85 | f"[TOKEN ESTIMATE] Approximately {int(tokens_seen)} tokens so far" 86 | ) 87 | 88 | except json.JSONDecodeError: 89 | print(f"[WARNING] Could not parse: {data_str}") 90 | 91 | print("\n[SUMMARY]") 92 | print(f"Total chunks received: {chunk_count}") 93 | print(f"Estimated token count: {int(tokens_seen)}") 94 | 95 | # After streaming completes, check the usage statistics 96 | print("\n[INFO] Checking usage statistics...") 97 | url = "http://localhost:8000/stats/" 98 | 99 | # Use single with statement for session and get request 100 | async with ( 101 | aiohttp.ClientSession() as session, 102 | session.get(url, headers=headers) as response, 103 | ): 104 | if response.status == HTTPStatus.OK: 105 | usage_data = await response.json() 106 | print("\n[USAGE STATS]") 107 | print(json.dumps(usage_data, indent=2)) 108 | else: 109 | print(f"[ERROR] Failed to get usage stats: {response.status}") 110 | error_text = await response.text() 111 | print(error_text) 112 | 113 | 114 | def main(): 115 | parser = argparse.ArgumentParser( 116 | description="Test token counting in streaming mode" 117 | ) 118 | parser.add_argument("--api-key", required=True, help="API key for authentication") 119 | parser.add_argument( 120 | "--model", default="gpt-3.5-turbo", help="Model to use for the test" 121 | ) 122 | 123 | args = parser.parse_args() 124 | 125 | asyncio.run(test_token_counting_stream(args.api_key, args.model)) 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /tests/frontend_simulation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | from http import HTTPStatus 5 | 6 | import requests 7 | from dotenv import dotenv_values, load_dotenv 8 | 9 | # Load directly from .env file, bypassing system environment variables 
10 | env_values = dotenv_values(".env") 11 | FORGE_API_KEY = env_values.get("FORGE_API_KEY", "") 12 | 13 | # Fall back to standard dotenv loading if not found directly 14 | if not FORGE_API_KEY: 15 | load_dotenv() 16 | FORGE_API_KEY = os.getenv("FORGE_API_KEY", "") 17 | 18 | # Configuration 19 | FORGE_API_URL = env_values.get( 20 | "FORGE_API_URL", os.getenv("FORGE_API_URL", "http://localhost:8000") 21 | ) 22 | 23 | if not FORGE_API_KEY: 24 | print("Error: FORGE_API_KEY environment variable is not set.") 25 | print("Please set it to your Forge API key and try again.") 26 | sys.exit(1) 27 | 28 | 29 | def chat_completion(messages, model="mock-only-gpt-3.5-turbo"): 30 | """Send a chat completion request to Forge""" 31 | url = f"{FORGE_API_URL}/chat/completions" 32 | headers = {"X-API-KEY": FORGE_API_KEY} 33 | data = {"model": model, "messages": messages, "temperature": 0.7} 34 | 35 | try: 36 | start_time = time.time() 37 | response = requests.post(url, json=data, headers=headers) 38 | end_time = time.time() 39 | 40 | if response.status_code == HTTPStatus.OK: 41 | result = response.json() 42 | print(f"Request completed in {end_time - start_time:.2f} seconds") 43 | return { 44 | "success": True, 45 | "response": result["choices"][0]["message"]["content"], 46 | "model": result["model"], 47 | "response_time": end_time - start_time, 48 | } 49 | else: 50 | return { 51 | "success": False, 52 | "error": f"API error: {response.status_code}", 53 | "details": response.text, 54 | } 55 | except Exception as e: 56 | return {"success": False, "error": "Request failed", "details": str(e)} 57 | 58 | 59 | def simulate_conversation(): 60 | """Simulate a conversation using the Forge API""" 61 | messages = [] 62 | 63 | # Initial system message 64 | system_message = { 65 | "role": "system", 66 | "content": "You are a helpful and friendly AI assistant.", 67 | } 68 | messages.append(system_message) 69 | 70 | # First user message 71 | first_user_message = { 72 | "role": "user", 73 | "content": "Hello! Can you tell me what Forge is?", 74 | } 75 | messages.append(first_user_message) 76 | 77 | print("\n--- User: Hello! 
Can you tell me what Forge is?") 78 | result = chat_completion(messages) 79 | 80 | if result["success"]: 81 | print(f"\n--- AI ({result['model']}): {result['response']}") 82 | messages.append({"role": "assistant", "content": result["response"]}) 83 | else: 84 | print(f"\nError: {result['error']}") 85 | print(f"Details: {result.get('details', 'No details available')}") 86 | return 87 | 88 | # Second user message 89 | second_user_message = { 90 | "role": "user", 91 | "content": "How can I use Forge with other frontend applications?", 92 | } 93 | messages.append(second_user_message) 94 | 95 | print("\n--- User: How can I use Forge with other frontend applications?") 96 | result = chat_completion(messages) 97 | 98 | if result["success"]: 99 | print(f"\n--- AI ({result['model']}): {result['response']}") 100 | messages.append({"role": "assistant", "content": result["response"]}) 101 | else: 102 | print(f"\nError: {result['error']}") 103 | print(f"Details: {result.get('details', 'No details available')}") 104 | return 105 | 106 | # Try a different model if available 107 | third_user_message = { 108 | "role": "user", 109 | "content": "Can you summarize our conversation so far?", 110 | } 111 | messages.append(third_user_message) 112 | 113 | print("\n--- User: Can you summarize our conversation so far?") 114 | 115 | # Try with a different model if available 116 | alternative_model = "mock-only-gpt-4" # Will fallback if not available 117 | result = chat_completion(messages, model=alternative_model) 118 | 119 | if result["success"]: 120 | print(f"\n--- AI ({result['model']}): {result['response']}") 121 | else: 122 | print(f"\nError: {result['error']}") 123 | print(f"Details: {result.get('details', 'No details available')}") 124 | 125 | 126 | if __name__ == "__main__": 127 | print("🔄 Simulating a frontend application using Forge API") 128 | print(f"🔌 Connecting to Forge at {FORGE_API_URL}") 129 | print(f"🔑 Using Forge API Key: {FORGE_API_KEY[:8]}...") 130 | 131 | simulate_conversation() 132 | -------------------------------------------------------------------------------- /alembic/versions/c9f3e548adef_add_lambda_model_pricing.py: -------------------------------------------------------------------------------- 1 | """add lambda model pricing 2 | 3 | Revision ID: c9f3e548adef 4 | Revises: 39bcedfae4fe 5 | Create Date: 2025-08-25 19:53:57.606298 6 | 7 | """ 8 | from csv import DictReader 9 | from decimal import Decimal 10 | import os 11 | from alembic import op 12 | import sqlalchemy as sa 13 | from datetime import datetime, timedelta, UTC 14 | 15 | 16 | # revision identifiers, used by Alembic. 
17 | revision = 'c9f3e548adef' 18 | down_revision = '39bcedfae4fe' 19 | branch_labels = None 20 | depends_on = None 21 | 22 | 23 | def upgrade() -> None: 24 | # insert the lambda data 25 | effective_date = datetime.now(UTC) - timedelta(days=1) 26 | csv_path = os.path.join(os.path.dirname(__file__), "..", "..", "tools", "data", "lambda_model_pricing_init.csv") 27 | with open(csv_path, "r") as f: 28 | reader = DictReader(f) 29 | rows_to_insert = [] 30 | for row in reader: 31 | rows_to_insert.append({ 32 | "provider_name": row["provider_name"], 33 | "model_name": row["model_name"], 34 | "effective_date": effective_date, 35 | "input_token_price": Decimal(str(row["input_token_price"])).normalize(), 36 | "output_token_price": Decimal(str(row["output_token_price"])).normalize(), 37 | "price_source": "manual" 38 | }) 39 | 40 | if rows_to_insert: 41 | connection = op.get_bind() 42 | connection.execute( 43 | sa.text(""" 44 | INSERT INTO model_pricing (provider_name, model_name, effective_date, input_token_price, output_token_price, cached_token_price, price_source) 45 | VALUES (:provider_name, :model_name, :effective_date, :input_token_price, :output_token_price, :input_token_price, 'manual') 46 | """), 47 | rows_to_insert, 48 | ) 49 | connection.execute( 50 | sa.text(""" 51 | INSERT INTO fallback_pricing (provider_name, model_name, effective_date, input_token_price, output_token_price, cached_token_price, fallback_type) 52 | VALUES (:provider_name, :model_name, :effective_date, :input_token_price, :output_token_price, :input_token_price, 'model_default') 53 | """), 54 | rows_to_insert, 55 | ) 56 | 57 | # Fix the cached_token_price for all the other models 58 | csv_path = os.path.join(os.path.dirname(__file__), "..", "..", "tools", "data", "model_pricing_init.csv") 59 | with open(csv_path, "r") as f: 60 | reader = DictReader(f) 61 | rows_to_update = [] 62 | for row in reader: 63 | rows_to_update.append({ 64 | "provider_name": row["provider_name"], 65 | "model_name": row["model_name"], 66 | "input_token_price": Decimal(str(row["input_token_price"])).normalize(), 67 | "output_token_price": Decimal(str(row["output_token_price"])).normalize(), 68 | "cached_token_price": Decimal(str(row["cached_token_price"])).normalize(), 69 | }) 70 | 71 | if rows_to_update: 72 | connection = op.get_bind() 73 | connection.execute( 74 | sa.text(""" 75 | update model_pricing set cached_token_price = :cached_token_price 76 | where provider_name = :provider_name and model_name = :model_name 77 | """), 78 | rows_to_update, 79 | ) 80 | connection.execute( 81 | sa.text(""" 82 | update fallback_pricing set cached_token_price = :cached_token_price 83 | where provider_name = :provider_name and model_name = :model_name 84 | """), 85 | rows_to_update, 86 | ) 87 | 88 | # backfill the cached_token_price for all the other models 89 | connection = op.get_bind() 90 | connection.execute( 91 | sa.text(""" 92 | with updated_model_pricing as ( 93 | update model_pricing set cached_token_price = input_token_price 94 | where cached_token_price = 0 95 | ) 96 | update fallback_pricing set cached_token_price = input_token_price 97 | where cached_token_price = 0 98 | """), 99 | rows_to_insert, 100 | ) 101 | 102 | # remove the default value for cached_token_price 103 | op.alter_column('model_pricing', 'cached_token_price', server_default=None) 104 | op.alter_column('fallback_pricing', 'cached_token_price', server_default=None) 105 | 106 | 107 | def downgrade() -> None: 108 | connection = op.get_bind() 109 | connection.execute( 110 | sa.text(""" 111 
| DELETE FROM model_pricing WHERE provider_name = 'lambda' 112 | """), 113 | ) 114 | connection.execute( 115 | sa.text(""" 116 | DELETE FROM fallback_pricing WHERE provider_name = 'lambda' 117 | """), 118 | ) 119 | --------------------------------------------------------------------------------
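As a rough follow-up to the pricing migration above, a standalone check along the following lines can confirm that the lambda rows were seeded and that no `cached_token_price` values were left at zero after the backfill. This script is illustrative only and not part of the repository; it assumes `DATABASE_URL` points at the database Alembic just migrated.

```python
# Hypothetical post-migration sanity check; assumes DATABASE_URL points at the
# database that Alembic just migrated. Not part of the repository.
import os

from sqlalchemy import create_engine, text

engine = create_engine(os.environ["DATABASE_URL"])

with engine.connect() as conn:
    lambda_rows = conn.execute(
        text("SELECT count(*) FROM model_pricing WHERE provider_name = 'lambda'")
    ).scalar()
    zero_cached = conn.execute(
        text(
            "SELECT count(*) FROM model_pricing "
            "WHERE cached_token_price = 0 OR cached_token_price IS NULL"
        )
    ).scalar()

print(f"lambda pricing rows: {lambda_rows}")
print(f"rows missing cached_token_price: {zero_cached}")  # expected to be 0
```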