├── app
│   ├── api
│   │   ├── __init__.py
│   │   ├── schemas
│   │   │   ├── __init__.py
│   │   │   ├── admin.py
│   │   │   ├── cached_user.py
│   │   │   ├── stripe.py
│   │   │   ├── user.py
│   │   │   ├── forge_api_key.py
│   │   │   ├── statistic.py
│   │   │   ├── provider_key.py
│   │   │   └── anthropic.py
│   │   └── routes
│   │       ├── admin.py
│   │       ├── __init__.py
│   │       ├── stripe.py
│   │       ├── auth.py
│   │       ├── stats.py
│   │       └── health.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── logger.py
│   │   ├── security.py
│   │   └── database.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── admin_users.py
│   │   ├── base.py
│   │   ├── stripe.py
│   │   ├── wallet.py
│   │   ├── provider_key.py
│   │   ├── usage_tracker.py
│   │   ├── user.py
│   │   ├── api_request_log.py
│   │   ├── forge_api_key.py
│   │   └── pricing.py
│   ├── exceptions
│   │   ├── __init__.py
│   │   └── exceptions.py
│   ├── services
│   │   ├── __init__.py
│   │   └── providers
│   │       ├── __init__.py
│   │       ├── perplexity_adapter.py
│   │       ├── zai_adapter.py
│   │       ├── zhipu_adapter.py
│   │       ├── tensorblock_adapter.py
│   │       ├── alibaba_adapter.py
│   │       ├── gemini_openai_adapter.py
│   │       └── fireworks_adapter.py
│   ├── __init__.py
│   └── utils
│       ├── serialization.py
│       └── translator.py
├── assets
│   └── service.jpg
├── tests
│   ├── cache
│   │   └── __init__.py
│   ├── __init__.py
│   ├── mock_testing
│   │   ├── examples
│   │   │   ├── __init__.py
│   │   │   └── test_with_mocks.py
│   │   ├── __init__.py
│   │   ├── run_mock_tests.py
│   │   └── mock_openai.py
│   ├── performance
│   │   ├── __init__.py
│   │   └── README.md
│   ├── unit_tests
│   │   ├── assets
│   │   │   ├── openai
│   │   │   │   ├── embeddings_response.json
│   │   │   │   ├── chat_completion_response_1.json
│   │   │   │   ├── list_models.json
│   │   │   │   └── responses_response_1.json
│   │   │   ├── anthropic
│   │   │   │   ├── chat_completion_response_1.json
│   │   │   │   ├── list_models.json
│   │   │   │   └── chat_completion_streaming_response_1.json
│   │   │   └── google
│   │   │       ├── chat_completion_response_1.json
│   │   │       ├── chat_completion_streaming_response_1.json
│   │   │       └── list_models.json
│   │   ├── test_vertex_adapter.py
│   │   └── test_security.py
│   ├── run_tests.py
│   ├── README.md
│   ├── test_token_counting_stream.py
│   └── frontend_simulation.py
├── run_unit_tests.sh
├── k8s
│   ├── service.yaml
│   └── deployment.yaml
├── alembic
│   ├── versions
│   │   ├── 6a92c2663fa5_base.py
│   │   ├── f5522808aba9_.py
│   │   ├── 683fc811a969_add_billable_column_to_usage_tracker.py
│   │   ├── f45e46b231f3_add_admin_users.py
│   │   ├── 831fc2cf16ee_enable_soft_deletion_for_provider_keys_.py
│   │   ├── b206e9a941e3_add_cost_tracking_to_usage_tracker_table.py
│   │   ├── 40e4b59f754d_add_stripe_payment_table.py
│   │   ├── 9daf34d338f7_update_model_mapping_type_for_.py
│   │   ├── b5d4363a9f62_add_clerk_user_id_to_users.py
│   │   ├── a58395ea1b22_add_balance_system.py
│   │   ├── 4a685a55c5cd_create_usage_tracker_table.py
│   │   ├── 39bcedfae4fe_add_model_default_pricing.py
│   │   ├── 08cc005a4bc8_create_forge_api_key_provider_scope_.py
│   │   ├── 4a82fb8af123_create_forge_api_keys_table.py
│   │   ├── ca1ac51334ec_create_usage_stats_table.py
│   │   ├── 0ce4eeae965f_drop_usage_stats_table.py
│   │   ├── initial_migration.py
│   │   ├── c50fd7be794c_add_endpoint_column_to_usage_stats.py
│   │   ├── b38aad374524_create_api_request_log_table.py
│   │   └── c9f3e548adef_add_lambda_model_pricing.py
│   ├── script.py.mako
│   └── env.py
├── docker-entrypoint.sh
├── .pre-commit-config.yaml
├── .gitignore
├── tools
│   ├── data
│   │   └── lambda_model_pricing_init.csv
│   └── diagnostics
│       ├── run_test_clean.py
│       ├── fix_model_mapping.py
│       ├── enable_request_logging.py
│       ├── clear_cache.py
│       ├── check_dotenv.py
│       └── check_db_keys.py
├── docker-compose.yml
├── reset_alembic.py
├── cli_tools
│   └── config.ini
├── run.py
├── .github
│   ├── PULL_REQUEST_TEMPLATE
│   │   └── pull_request_template.md
│   └── workflows
│       └── docker-build.yml
├── Dockerfile
├── .env.example
├── alembic.ini
├── pyproject.toml
├── docs
│   └── PERFORMANCE_OPTIMIZATIONS.md
└── LICENSE
/app/api/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/api/schemas/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/exceptions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/services/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/app/__init__.py:
--------------------------------------------------------------------------------
1 | # Forge middleware service
2 |
--------------------------------------------------------------------------------
/assets/service.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TensorBlock/forge/HEAD/assets/service.jpg
--------------------------------------------------------------------------------
/tests/cache/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for the caching functionality in Forge.
3 | """
4 |
--------------------------------------------------------------------------------
/app/services/providers/__init__.py:
--------------------------------------------------------------------------------
1 | # Marks the providers directory as a Python package
2 |
--------------------------------------------------------------------------------
/run_unit_tests.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash -xe
2 | uv run pytest -vv tests/unit_tests --cov=app --cov-report=xml
3 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # This file is intentionally left empty to make the tests directory a proper Python package
2 |
--------------------------------------------------------------------------------
/tests/mock_testing/examples/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Mock testing examples.
3 |
4 | This package contains example tests that demonstrate using the mock client
5 | in different testing scenarios.
6 | """
7 |
--------------------------------------------------------------------------------
/tests/performance/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Performance test module for Forge API.
3 |
4 | This package contains tests to measure and analyze the performance
5 | of the Forge middleware service.
6 | """
7 |
8 | __version__ = "0.1.0"
9 |
--------------------------------------------------------------------------------
/tests/mock_testing/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Forge Mock Testing Package.
3 |
4 | This package provides tools for testing Forge with mock providers
5 | that simulate API responses without making actual API calls.
6 | """
7 |
8 | # Export key components for easy importing
9 |
--------------------------------------------------------------------------------
/k8s/service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: forge-service
5 | labels:
6 | app: forge
7 | spec:
8 | type: LoadBalancer
9 | ports:
10 | - port: 80
11 | targetPort: 8000
12 | protocol: TCP
13 | name: http
14 | selector:
15 | app: forge
16 |
--------------------------------------------------------------------------------
/app/models/admin_users.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import Column, Integer, ForeignKey
2 | from sqlalchemy.orm import relationship
3 |
4 | from app.models.base import Base
5 |
6 | class AdminUsers(Base):
7 | __tablename__ = "admin_users"
8 |
9 | user_id = Column(Integer, ForeignKey("users.id"), primary_key=True)
10 | user = relationship("User", back_populates="admin_users")
11 |
--------------------------------------------------------------------------------
/app/api/schemas/admin.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, field_validator
2 |
3 | class AddBalanceRequest(BaseModel):
4 | user_id: int | None = None
5 | email: str | None = None
6 | amount: int # in cents
7 |
8 | @field_validator("amount")
9 |     def validate_amount(cls, value: int):
10 |         if value < 100:
11 |             raise ValueError("Amount must be at least 100 cents")
12 | return value
13 |
--------------------------------------------------------------------------------
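A minimal usage sketch for the AddBalanceRequest validator above (illustrative only, not a file in this repository; the IDs and amounts are placeholders):

    from pydantic import ValidationError
    from app.api.schemas.admin import AddBalanceRequest

    AddBalanceRequest(user_id=1, amount=250)       # accepted: 250 cents
    try:
        AddBalanceRequest(user_id=1, amount=50)    # rejected: below the 100-cent minimum
    except ValidationError as exc:
        print(exc)

--------------------------------------------------------------------------------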
/alembic/versions/6a92c2663fa5_base.py:
--------------------------------------------------------------------------------
1 | """base
2 |
3 | Revision ID: 6a92c2663fa5
4 | Revises: f5522808aba9
5 | Create Date: 2025-04-29 16:02:26.813233
6 |
7 | """
8 |
9 |
10 | # revision identifiers, used by Alembic.
11 | revision = "6a92c2663fa5"
12 | down_revision = "f5522808aba9"
13 | branch_labels = None
14 | depends_on = None
15 |
16 |
17 | def upgrade() -> None:
18 | pass
19 |
20 |
21 | def downgrade() -> None:
22 | pass
23 |
--------------------------------------------------------------------------------
/alembic/versions/f5522808aba9_.py:
--------------------------------------------------------------------------------
1 | """empty message
2 |
3 | Revision ID: f5522808aba9
4 | Revises: 4a82fb8af123
5 | Create Date: 2025-04-29 16:02:04.550374
6 |
7 | """
8 |
9 |
10 | # revision identifiers, used by Alembic.
11 | revision = "f5522808aba9"
12 | down_revision = "4a82fb8af123"
13 | branch_labels = None
14 | depends_on = None
15 |
16 |
17 | def upgrade() -> None:
18 | pass
19 |
20 |
21 | def downgrade() -> None:
22 | pass
23 |
--------------------------------------------------------------------------------
/app/services/providers/perplexity_adapter.py:
--------------------------------------------------------------------------------
1 | from .openai_adapter import OpenAIAdapter
2 | 
3 | PERPLEXITY_MODELS = [
4 |     "sonar",
5 |     "sonar-reasoning-pro",
6 |     "sonar-reasoning",
7 |     "sonar-pro",
8 |     "sonar-deep-research",
9 | ]
10 | 
11 | class PerplexityAdapter(OpenAIAdapter):
12 |     """Adapter for Perplexity API"""
13 | 
14 |     async def list_models(self, api_key: str) -> list[str]:
15 |         return PERPLEXITY_MODELS
16 | 
--------------------------------------------------------------------------------
/tests/unit_tests/assets/openai/embeddings_response.json:
--------------------------------------------------------------------------------
1 | {
2 | "object": "list",
3 | "data": [
4 | {
5 | "object": "embedding",
6 | "index": 0,
7 | "embedding": [
8 | -0.006929283495992422,
9 | -0.005336422007530928,
10 | -4.547132266452536e-05,
11 | -0.024047503247857094
12 | ]
13 | }
14 | ],
15 | "model": "text-embedding-ada-002",
16 | "usage": {
17 | "prompt_tokens": 8,
18 | "total_tokens": 8
19 | }
20 | }
--------------------------------------------------------------------------------
/app/services/providers/zai_adapter.py:
--------------------------------------------------------------------------------
1 | from .openai_adapter import OpenAIAdapter
2 |
3 | # https://docs.z.ai/api-reference/llm/chat-completion#body-model
4 | ZAI_MODELS = [
5 | "glm-4.5",
6 | "glm-4.5-air",
7 | "glm-4.5-x",
8 | "glm-4.5-airx",
9 | "glm-4.5-flash",
10 | "glm-4-32b-0414-128k",
11 | ]
12 |
13 |
14 | class ZAIAdapter(OpenAIAdapter):
15 | """Adapter for Zai API"""
16 |
17 | async def list_models(self, api_key: str) -> list[str]:
18 | return ZAI_MODELS
--------------------------------------------------------------------------------
/app/models/base.py:
--------------------------------------------------------------------------------
1 | from datetime import UTC
2 | from datetime import datetime
3 |
4 | from sqlalchemy import Column, DateTime, Integer
5 |
6 | from app.core.database import Base
7 |
8 |
9 | class BaseModel(Base):
10 | __abstract__ = True
11 |
12 | id = Column(Integer, primary_key=True, index=True)
13 |     created_at = Column(DateTime, default=lambda: datetime.now(UTC))
14 |     updated_at = Column(DateTime, default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
15 | deleted_at = Column(DateTime(timezone=True), nullable=True)
16 |
--------------------------------------------------------------------------------
/app/services/providers/zhipu_adapter.py:
--------------------------------------------------------------------------------
1 | from .openai_adapter import OpenAIAdapter
2 | ZHIPU_MODELS = [
3 | "glm-4-plus",
4 | "glm-4-0520",
5 | "glm-4",
6 | "glm-4-air",
7 | "glm-4-airx",
8 | "glm-4-long",
9 | "glm-4-flash",
10 | "glm-4v-plus-0111",
11 | "glm-4v-flash"
12 | "glm-z1-air",
13 | "glm-z1-airx",
14 | "glm-z1-flash",
15 | ]
16 |
17 | class ZhipuAdapter(OpenAIAdapter):
18 | """Adapter for Zhipu API"""
19 |
20 | async def list_models(self, api_key: str) -> list[str]:
21 | return ZHIPU_MODELS
22 |
--------------------------------------------------------------------------------
/docker-entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # Exit if any command fails
4 | set -e
5 |
6 | # Set correct permissions for the logs directory at runtime.
7 | # This ensures the 'nobody' user can write to the volume.
8 | chown -R nobody:nogroup /app/logs
9 |
10 | # Run Alembic migrations
11 | echo "Running database migrations..."
12 | if ! alembic upgrade head; then
13 |     echo "⚠️ Warning: Alembic migration failed. Continuing startup anyway."
14 | fi
15 |
16 | # Use gosu to drop from root to the 'nobody' user and run the main command
17 | exec gosu nobody "$@"
18 |
19 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.4.0
4 | hooks:
5 | - id: trailing-whitespace
6 | - id: end-of-file-fixer
7 | - id: check-yaml
8 | - id: check-added-large-files
9 |
10 | - repo: https://github.com/astral-sh/ruff-pre-commit
11 | rev: v0.2.0
12 | hooks:
13 | # Run the linter
14 | - id: ruff
15 | args: [--fix]
16 | types_or: [python, pyi, jupyter]
17 | # Run the formatter
18 | - id: ruff-format
19 | types_or: [python, pyi, jupyter]
20 |
--------------------------------------------------------------------------------
/tests/unit_tests/test_vertex_adapter.py:
--------------------------------------------------------------------------------
1 | from app.services.providers.vertex_adapter import VertexAdapter
2 |
3 |
4 | def test_vertex_adapter_base_url_global():
5 | config = {"publisher": "anthropic", "location": "global"}
6 | adapter = VertexAdapter("vertex", None, config)
7 | assert adapter._base_url == "https://aiplatform.googleapis.com"
8 |
9 |
10 | def test_vertex_adapter_base_url_region():
11 | config = {"publisher": "anthropic", "location": "us-east1"}
12 | adapter = VertexAdapter("vertex", None, config)
13 | assert adapter._base_url == "https://us-east1-aiplatform.googleapis.com"
--------------------------------------------------------------------------------
/alembic/script.py.mako:
--------------------------------------------------------------------------------
1 | """${message}
2 |
3 | Revision ID: ${up_revision}
4 | Revises: ${down_revision | comma,n}
5 | Create Date: ${create_date}
6 |
7 | """
8 | from alembic import op
9 | import sqlalchemy as sa
10 | ${imports if imports else ""}
11 |
12 | # revision identifiers, used by Alembic.
13 | revision = ${repr(up_revision)}
14 | down_revision = ${repr(down_revision)}
15 | branch_labels = ${repr(branch_labels)}
16 | depends_on = ${repr(depends_on)}
17 |
18 |
19 | def upgrade() -> None:
20 | ${upgrades if upgrades else "pass"}
21 |
22 |
23 | def downgrade() -> None:
24 | ${downgrades if downgrades else "pass"}
25 |
--------------------------------------------------------------------------------
/alembic/versions/683fc811a969_add_billable_column_to_usage_tracker.py:
--------------------------------------------------------------------------------
1 | """add billable column to usage_tracker
2 |
3 | Revision ID: 683fc811a969
4 | Revises: 40e4b59f754d
5 | Create Date: 2025-09-05 10:48:09.623668
6 |
7 | """
8 | from alembic import op
9 | import sqlalchemy as sa
10 |
11 |
12 | # revision identifiers, used by Alembic.
13 | revision = '683fc811a969'
14 | down_revision = '40e4b59f754d'
15 | branch_labels = None
16 | depends_on = None
17 |
18 |
19 | def upgrade() -> None:
20 | op.add_column('usage_tracker', sa.Column('billable', sa.Boolean(), nullable=False, server_default='FALSE'))
21 |
22 |
23 | def downgrade() -> None:
24 | op.drop_column('usage_tracker', 'billable')
25 |
--------------------------------------------------------------------------------
/app/api/schemas/cached_user.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | from pydantic import BaseModel, ConfigDict, EmailStr
4 |
5 |
6 | class CachedUser(BaseModel):
7 | """
8 | A Pydantic model representing the data of a User object to be stored in the cache.
9 | This avoids caching the full SQLAlchemy ORM object, which prevents DetachedInstanceError
10 | and improves performance by removing the need for `db.merge()`.
11 | """
12 |
13 | id: int
14 | email: EmailStr
15 | username: str
16 | is_active: bool
17 | clerk_user_id: str | None = None
18 | created_at: datetime
19 | updated_at: datetime
20 | model_config = ConfigDict(from_attributes=True)
21 |
--------------------------------------------------------------------------------
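A minimal sketch of how CachedUser can stand in for the ORM object in a cache, as the docstring above describes (illustrative only; the dict-based `cache` is a stand-in for whatever cache backend Forge actually uses):

    from app.api.schemas.cached_user import CachedUser

    cache: dict[str, str] = {}

    def store_user(user_orm) -> None:
        # Convert the SQLAlchemy User into a plain Pydantic payload before caching.
        cached = CachedUser.model_validate(user_orm)   # relies on from_attributes=True
        cache[f"user:{cached.id}"] = cached.model_dump_json()

    def load_user(user_id: int) -> CachedUser | None:
        raw = cache.get(f"user:{user_id}")
        return CachedUser.model_validate_json(raw) if raw is not None else None

--------------------------------------------------------------------------------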
/app/models/stripe.py:
--------------------------------------------------------------------------------
1 | from app.models.base import Base
2 | from sqlalchemy import Column, String, Integer, ForeignKey, JSON, DateTime
3 | from datetime import datetime, UTC
4 |
5 | class StripePayment(Base):
6 | __tablename__ = "stripe_payment"
7 |
8 | id = Column(String, primary_key=True)
9 | user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
10 | amount = Column(Integer, nullable=False)
11 | currency = Column(String(3), nullable=False)
12 | status = Column(String, nullable=False)
13 | raw_data = Column(JSON, nullable=True)
14 |     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
15 |     updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
--------------------------------------------------------------------------------
/tests/unit_tests/assets/anthropic/chat_completion_response_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "id": "msg_01Kpzcjm6jUdJ5RY5Dxt3rV8",
3 | "type": "message",
4 | "role": "assistant",
5 | "model": "claude-sonnet-4-20250514",
6 | "content": [
7 | {
8 | "type": "text",
9 | "text": "Hello! I'm doing well, thank you for asking. I'm here and ready to help with whatever you'd like to discuss or work on. How are you doing today?"
10 | }
11 | ],
12 | "stop_reason": "end_turn",
13 | "stop_sequence": null,
14 | "usage": {
15 | "input_tokens": 13,
16 | "cache_creation_input_tokens": 0,
17 | "cache_read_input_tokens": 0,
18 | "output_tokens": 39,
19 | "service_tier": "standard"
20 | }
21 | }
--------------------------------------------------------------------------------
/alembic/versions/f45e46b231f3_add_admin_users.py:
--------------------------------------------------------------------------------
1 | """add admin users
2 |
3 | Revision ID: f45e46b231f3
4 | Revises: 683fc811a969
5 | Create Date: 2025-09-11 13:14:17.066592
6 |
7 | """
8 | from alembic import op
9 | import sqlalchemy as sa
10 |
11 |
12 | # revision identifiers, used by Alembic.
13 | revision = 'f45e46b231f3'
14 | down_revision = '683fc811a969'
15 | branch_labels = None
16 | depends_on = None
17 |
18 |
19 | def upgrade() -> None:
20 | op.create_table(
21 | "admin_users",
22 | sa.Column("user_id", sa.Integer(), nullable=False),
23 | sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
24 | sa.PrimaryKeyConstraint("user_id"),
25 | )
26 |
27 |
28 | def downgrade() -> None:
29 | op.drop_table("admin_users")
30 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # Environment
28 | .env
29 | .venv
30 | env/
31 | ENV/
32 | env.bak/
33 | venv.bak/
34 |
35 | # Database
36 | *.db
37 | *.sqlite3
38 |
39 | # IDE
40 | .idea/
41 | .vscode/
42 | *.swp
43 | *.swo
44 |
45 | # Logs
46 | *.log
47 |
48 | # VScode settings
49 | .vscode-upload.json
50 |
51 | .ruff_cache/
52 |
53 | # macOS
54 | .DS_Store
55 |
56 | # Performance test results
57 | tests/performance/results/
58 | .coverage*
59 | coverage*
60 |
61 | # Local test script
62 | tools/tmp/*
63 |
--------------------------------------------------------------------------------
/app/models/wallet.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, UTC
2 | from sqlalchemy import Column, BigInteger, CHAR, DECIMAL, Boolean, DateTime, ForeignKey
3 | from sqlalchemy.orm import relationship
4 | from .base import Base
5 |
6 | class Wallet(Base):
7 | __tablename__ = "wallets"
8 |
9 | account_id = Column(BigInteger, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True)
10 | currency = Column(CHAR(3), nullable=False, default='USD')
11 | balance = Column(DECIMAL(20, 6), nullable=False, default=0)
12 | blocked = Column(Boolean, nullable=False, default=False)
13 | version = Column(BigInteger, nullable=False, default=0)
14 |     created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(UTC))
15 |     updated_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(UTC))
16 |
17 | user = relationship("User", back_populates="wallet")
18 |
--------------------------------------------------------------------------------
/app/services/providers/tensorblock_adapter.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from .azure_adapter import AzureAdapter
4 |
5 | TENSORBLOCK_MODELS = [
6 | "gpt-4.1-mini",
7 | "gpt-4.1-nano",
8 | "gpt-4o-mini",
9 | "o3-mini",
10 | "text-embedding-3-large",
11 | "text-embedding-3-small",
12 | "text-embedding-ada-002",
13 | ]
14 |
15 |
16 | class TensorblockAdapter(AzureAdapter):
17 | """Adapter for Tensorblock API"""
18 |
19 | def __init__(self, provider_name: str, base_url: str, config: dict[str, Any]):
20 | super().__init__(provider_name, base_url, config)
21 |
22 | def get_mapped_model(self, model: str) -> str:
23 | """Get the Azure-specific model name"""
24 | # For TensorBlock, we use the model name as-is since it's already in the correct format
25 | return model
26 |
27 | async def list_models(self, api_key: str) -> list[str]:
28 | return TENSORBLOCK_MODELS
29 |
--------------------------------------------------------------------------------
/tools/data/lambda_model_pricing_init.csv:
--------------------------------------------------------------------------------
1 | provider_name,model_name,input_token_price,output_token_price
2 | lambda,deepseek-r1-0528,0.0005,0.00218
3 | lambda,deepseek-v3-0324,0.00034,0.00088
4 | lambda,qwen-3-32b,0.0001,0.0003
5 | lambda,llama-4-maverick-17b-128e-instruct-fp8,0.00018,0.0006
6 | lambda,llama-4-scout-17b-16e-instruct,0.00008,0.0003
7 | lambda,llama-3.1-8b-instruct,0.000025,0.00004
8 | lambda,llama-3.1-70b-instruct,0.00012,0.0003
9 | lambda,llama-3.1-405b-instruct,0.0008,0.0008
10 | lambda,deepseek-llama3.3-70b,0.0002,0.0006
11 | lambda,llama-3.3-70b-instruct,0.00012,0.0003
12 | lambda,llama-3.2-3b-instruct,0.000015,0.000025
13 | lambda,hermes-3-llama-3.1-8b,0.000025,0.00004
14 | lambda,hermes-3-llama-3.1-70b (fp8),0.00012,0.0003
15 | lambda,hermes-3-llama-3.1-405b (fp8),0.0008,0.0008
16 | lambda,lfm-40b,0.00015,0.00015
17 | lambda,llama3.1-nemotron-70b-instruct,0.00012,0.0003
18 | lambda,qwen2.5-coder-32b,0.00007,0.00016
19 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 | app:
3 | build: .
4 | ports:
5 | - "8000:8000"
6 | environment:
7 | - DATABASE_URL=postgresql://forge:forge@db:5432/forge
8 | - PORT=8000
9 | - DEBUG_CACHE=false # Control cache debugging
10 | - FORGE_DEBUG_LOGGING=false # Control application logging level
11 | depends_on:
12 | - db
13 | volumes:
14 | - .:/app
15 | - forge_logs:/app/logs # Persist logs
16 | networks:
17 | - forge-network
18 |
19 | db:
20 | image: postgres:14
21 | environment:
22 | - POSTGRES_USER=forge
23 | - POSTGRES_PASSWORD=forge
24 | - POSTGRES_DB=forge
25 | volumes:
26 | - postgres_data:/var/lib/postgresql/data
27 | ports:
28 | - "5432:5432"
29 | networks:
30 | - forge-network
31 |
32 | volumes:
33 | postgres_data:
34 | forge_logs: # Add volume for logs
35 |
36 | networks:
37 | forge-network:
38 | driver: bridge
39 |
--------------------------------------------------------------------------------
/alembic/versions/831fc2cf16ee_enable_soft_deletion_for_provider_keys_.py:
--------------------------------------------------------------------------------
1 | """enable soft deletion for provider keys and api keys
2 |
3 | Revision ID: 831fc2cf16ee
4 | Revises: 4a685a55c5cd
5 | Create Date: 2025-08-02 17:50:12.224293
6 |
7 | """
8 | from alembic import op
9 | import sqlalchemy as sa
10 |
11 |
12 | # revision identifiers, used by Alembic.
13 | revision = '831fc2cf16ee'
14 | down_revision = '4a685a55c5cd'
15 | branch_labels = None
16 | depends_on = None
17 |
18 |
19 | def upgrade() -> None:
20 | op.add_column('provider_keys', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
21 | op.add_column('forge_api_keys', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
22 | op.alter_column('forge_api_keys', 'key', nullable=True)
23 |
24 |
25 | def downgrade() -> None:
26 | op.drop_column('provider_keys', 'deleted_at')
27 | op.drop_column('forge_api_keys', 'deleted_at')
28 | op.alter_column('forge_api_keys', 'key', nullable=False)
29 |
--------------------------------------------------------------------------------
/alembic/versions/b206e9a941e3_add_cost_tracking_to_usage_tracker_table.py:
--------------------------------------------------------------------------------
1 | """add cost tracking to usage_tracker table
2 |
3 | Revision ID: b206e9a941e3
4 | Revises: 1876c1c4bc96
5 | Create Date: 2025-08-11 18:19:08.581296
6 |
7 | """
8 | from alembic import op
9 | import sqlalchemy as sa
10 |
11 |
12 | # revision identifiers, used by Alembic.
13 | revision = 'b206e9a941e3'
14 | down_revision = '1876c1c4bc96'
15 | branch_labels = None
16 | depends_on = None
17 |
18 |
19 | def upgrade() -> None:
20 | op.add_column('usage_tracker', sa.Column('cost', sa.DECIMAL(precision=12, scale=8), nullable=True))
21 | op.add_column('usage_tracker', sa.Column('currency', sa.String(length=3), nullable=True))
22 | op.add_column('usage_tracker', sa.Column('pricing_source', sa.String(length=255), nullable=True))
23 |
24 |
25 | def downgrade() -> None:
26 | op.drop_column('usage_tracker', 'cost')
27 | op.drop_column('usage_tracker', 'currency')
28 | op.drop_column('usage_tracker', 'pricing_source')
29 |
--------------------------------------------------------------------------------
/tools/diagnostics/run_test_clean.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Script to run the test with a clean environment.
4 | """
5 |
6 | import os
7 | import subprocess
8 | import sys
9 |
10 |
11 | def main():
12 | """Run the test in a clean environment."""
13 | # Find the project root directory (where this script is located)
14 | script_dir = os.path.dirname(
15 | os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
16 | )
17 | os.chdir(script_dir) # Change to project root directory
18 |
19 | # Create a clean environment without the problematic variable
20 | clean_env = os.environ.copy()
21 | if "FORGE_API_KEY" in clean_env:
22 | del clean_env["FORGE_API_KEY"]
23 |
24 | # Run the test with the clean environment
25 | result = subprocess.run(
26 | [sys.executable, "tests/frontend_simulation.py"], env=clean_env, check=False
27 | )
28 |
29 | print(f"\nTest completed with exit code: {result.returncode}")
30 |
31 |
32 | if __name__ == "__main__":
33 | main()
34 |
--------------------------------------------------------------------------------
/tests/unit_tests/assets/anthropic/list_models.json:
--------------------------------------------------------------------------------
1 | {
2 | "data": [
3 | {
4 | "type": "model",
5 | "id": "claude-opus-4-20250514",
6 | "display_name": "Claude Opus 4",
7 | "created_at": "2025-05-22T00:00:00Z"
8 | },
9 | {
10 | "type": "model",
11 | "id": "claude-sonnet-4-20250514",
12 | "display_name": "Claude Sonnet 4",
13 | "created_at": "2025-05-22T00:00:00Z"
14 | },
15 | {
16 | "type": "model",
17 | "id": "claude-3-7-sonnet-20250219",
18 | "display_name": "Claude Sonnet 3.7",
19 | "created_at": "2025-02-24T00:00:00Z"
20 | },
21 | {
22 | "type": "model",
23 | "id": "claude-3-5-sonnet-20241022",
24 | "display_name": "Claude Sonnet 3.5 (New)",
25 | "created_at": "2024-10-22T00:00:00Z"
26 | }
27 | ],
28 | "has_more": false,
29 | "first_id": "claude-opus-4-20250514",
30 | "last_id": "claude-3-opus-20240229"
31 | }
--------------------------------------------------------------------------------
/tests/unit_tests/assets/google/chat_completion_response_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "candidates": [
3 | {
4 | "content": {
5 | "parts": [
6 | {
7 | "text": "I am doing well, thank you for asking. How are you today?\n"
8 | }
9 | ],
10 | "role": "model"
11 | },
12 | "finishReason": "STOP",
13 | "avgLogprobs": -0.2357906699180603
14 | }
15 | ],
16 | "usageMetadata": {
17 | "promptTokenCount": 6,
18 | "candidatesTokenCount": 16,
19 | "totalTokenCount": 22,
20 | "promptTokensDetails": [
21 | {
22 | "modality": "TEXT",
23 | "tokenCount": 6
24 | }
25 | ],
26 | "candidatesTokensDetails": [
27 | {
28 | "modality": "TEXT",
29 | "tokenCount": 16
30 | }
31 | ]
32 | },
33 | "modelVersion": "gemini-1.5-pro-002",
34 | "responseId": "zCFeaOSZBvmSmNAP-sXGiQU"
35 | }
--------------------------------------------------------------------------------
/app/utils/serialization.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Any, Dict
3 |
4 | def serialize_dict(data: Dict[str, Any]) -> str:
5 | """
6 | Serialize a dictionary to a string using JSON.
7 |
8 | Args:
9 | data: The dictionary to serialize
10 |
11 | Returns:
12 | A JSON string representation of the dictionary
13 |
14 | Raises:
15 | TypeError: If the input is not a dictionary
16 |         TypeError: If the dictionary contains values that are not JSON serializable
17 | """
18 | if not isinstance(data, dict):
19 | raise TypeError("Input must be a dictionary")
20 |
21 | return json.dumps(data)
22 |
23 | def deserialize_dict(serialized_data: str) -> Dict[str, Any]:
24 | """
25 | Deserialize a string back into a dictionary using JSON.
26 |
27 | Args:
28 | serialized_data: The JSON string to deserialize
29 |
30 | Returns:
31 | The deserialized dictionary
32 |
33 | Raises:
34 | ValueError: If the input string is not valid JSON
35 | """
36 | return json.loads(serialized_data)
--------------------------------------------------------------------------------
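A short round-trip example for the helpers above (illustrative only; the payload is a placeholder):

    from app.utils.serialization import serialize_dict, deserialize_dict

    payload = {"model": "gpt-4o-mini", "input_tokens": 13}
    raw = serialize_dict(payload)           # '{"model": "gpt-4o-mini", "input_tokens": 13}'
    assert deserialize_dict(raw) == payload
    # serialize_dict(["not", "a", "dict"])  # would raise TypeError

--------------------------------------------------------------------------------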
/reset_alembic.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | import sys
4 |
5 | from dotenv import load_dotenv
6 | from sqlalchemy import create_engine
7 | from sqlalchemy.sql import text
8 |
9 | load_dotenv()
10 |
11 | # Get connection string from environment variable
12 | conn_string = os.environ.get("DATABASE_URL")
13 | if not conn_string:
14 | print("Please set the DATABASE_URL environment variable")
15 | sys.exit(1)
16 |
17 | engine = create_engine(conn_string)
18 |
19 | with engine.connect() as connection:
20 | try:
21 | connection.execute(text("DROP TABLE IF EXISTS alembic_version"))
22 | connection.execute(
23 | text("CREATE TABLE alembic_version (version_num VARCHAR(32) PRIMARY KEY)")
24 | )
25 | connection.execute(text("INSERT INTO alembic_version VALUES ('6a92c2663fa5')"))
26 | connection.commit()
27 | print("Successfully reset alembic version to 6a92c2663fa5")
28 | except Exception as e:
29 | print(f"Error resetting alembic version: {e}")
30 | print("Error: No version found in alembic_version table")
31 | sys.exit(1)
32 |
--------------------------------------------------------------------------------
/app/core/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger configuration using loguru.
3 | """
4 |
5 | import sys
6 | from pathlib import Path
7 |
8 | from loguru import logger
9 |
10 | # Remove default handler
11 | logger.remove()
12 |
13 | # Add console handler with custom format
14 | logger.add(
15 | sys.stderr,
16 | format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}",
17 | level="DEBUG",
18 | enqueue=True, # Thread-safe logging
19 | backtrace=True, # Detailed traceback
20 | diagnose=True, # Enable exception diagnosis
21 | )
22 |
23 | # Add file handler for debugging
24 | log_path = Path("logs")
25 | log_path.mkdir(exist_ok=True)
26 |
27 | logger.add(
28 | "logs/forge_{time}.log",
29 | rotation="1 day", # Create new file daily
30 | retention="1 week", # Keep logs for 1 week
31 | compression="zip", # Compress rotated logs
32 | level="DEBUG",
33 | enqueue=True,
34 | backtrace=True,
35 | diagnose=True,
36 | )
37 |
38 | # Export logger instance
39 | get_logger = logger.bind
40 |
--------------------------------------------------------------------------------
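A minimal usage sketch for the logger module above (illustrative only; the module name and field values are placeholders). Because get_logger is loguru's logger.bind, the keyword arguments are attached as extra context on every record emitted through the returned logger:

    from app.core.logger import get_logger

    logger = get_logger(name="payments")
    logger.info("checkout session created")
    logger.bind(user_id=42).warning("balance below threshold")

--------------------------------------------------------------------------------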
/cli_tools/config.ini:
--------------------------------------------------------------------------------
1 | [api]
2 | base_url = http://localhost:8000/v1
3 | timeout = 30
4 | retry_attempts = 3
5 |
6 | [ui]
7 | default_theme = dark
8 | show_clock = true
9 | show_notifications = true
10 | auto_refresh_interval = 30
11 |
12 | [display]
13 | table_page_size = 15
14 | max_log_lines = 100
15 | truncate_long_text = true
16 | max_text_length = 50
17 |
18 | [keybindings]
19 | # Custom keybindings (override defaults)
20 | quit = "q"
21 | toggle_theme = "f1"
22 | refresh_all = "f2"
23 | new_api_key = "ctrl+n"
24 | new_provider = "ctrl+p"
25 | help = "f12"
26 |
27 | [notifications]
28 | # Notification settings
29 | show_success = true
30 | show_errors = true
31 | show_warnings = true
32 | auto_dismiss_time = 5
33 |
34 | [development]
35 | # Development and debugging options
36 | debug_mode = false
37 | log_api_requests = false
38 | mock_api_responses = false
39 |
40 | # Color scheme customization (optional)
41 | [colors]
42 | # Uncomment and modify to customize colors
43 | # primary = "#007acc"
44 | # secondary = "#6c757d"
45 | # success = "#28a745"
46 | # warning = "#ffc107"
47 | # error = "#dc3545"
48 | # accent = "#17a2b8"
49 |
--------------------------------------------------------------------------------
/k8s/deployment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: forge-app
5 | labels:
6 | app: forge
7 | spec:
8 | replicas: 2
9 | selector:
10 | matchLabels:
11 | app: forge
12 | template:
13 | metadata:
14 | labels:
15 | app: forge
16 | spec:
17 | containers:
18 | - name: forge
19 | image: tensorblockai/forge:latest
20 | ports:
21 | - containerPort: 8000
22 | envFrom:
23 | - secretRef:
24 | name: forge-secrets
25 | env:
26 | - name: PORT
27 | value: "8000"
28 | resources:
29 | requests:
30 | cpu: "100m"
31 | memory: "256Mi"
32 | limits:
33 | cpu: "500m"
34 | memory: "512Mi"
35 | livenessProbe:
36 | httpGet:
37 | path: /health
38 | port: 8000
39 | initialDelaySeconds: 30
40 | periodSeconds: 10
41 | readinessProbe:
42 | httpGet:
43 | path: /health
44 | port: 8000
45 | initialDelaySeconds: 5
46 | periodSeconds: 5
47 |
--------------------------------------------------------------------------------
/app/services/providers/alibaba_adapter.py:
--------------------------------------------------------------------------------
1 | from .openai_adapter import OpenAIAdapter
2 |
3 | ALIBABA_MODELS = [
4 | "qwen-max",
5 | "qwen-max-latest",
6 | "qwen-max-2025-01-25",
7 | "qwen-plus",
8 | "qwen-plus-latest",
9 | "qwen-plus-2025-04-28",
10 | "qwen-plus-2025-01-25",
11 | "qwen-turbo",
12 | "qwen-turbo-latest",
13 | "qwen-turbo-2025-04-28",
14 | "qwen-turbo-2024-11-01",
15 | "qwq-32b",
16 | "qwen3-235b-a22b",
17 | "qwen3-32b",
18 | "qwen3-30b-a3b",
19 | "qwen3-14b",
20 | "qwen3-8b",
21 | "qwen3-4b",
22 | "qwen3-1.7b",
23 | "qwen3-0.6b",
24 | "qwen2.5-14b-instruct-1m",
25 | "qwen2.5-7b-instruct-1m",
26 | "qwen2.5-72b-instruct",
27 | "qwen2.5-32b-instruct",
28 | "qwen2.5-14b-instruct",
29 | "qwen2.5-7b-instruct",
30 | "qwen2-72b-instruct",
31 | "qwen2-7b-instruct",
32 | "qwen1.5-110b-chat",
33 | "qwen1.5-72b-chat",
34 | "qwen1.5-32b-chat",
35 | "qwen1.5-14b-chat",
36 | "qwen1.5-7b-chat",
37 | ]
38 |
39 | class AlibabaAdapter(OpenAIAdapter):
40 | """Adapter for Alibaba API"""
41 |
42 | async def list_models(self, api_key: str) -> list[str]:
43 | return ALIBABA_MODELS
--------------------------------------------------------------------------------
/alembic/versions/40e4b59f754d_add_stripe_payment_table.py:
--------------------------------------------------------------------------------
1 | """add stripe payment table
2 |
3 | Revision ID: 40e4b59f754d
4 | Revises: a58395ea1b22
5 | Create Date: 2025-09-02 20:52:29.183031
6 |
7 | """
8 | from alembic import op
9 | import sqlalchemy as sa
10 | from datetime import datetime, UTC
11 |
12 |
13 | # revision identifiers, used by Alembic.
14 | revision = '40e4b59f754d'
15 | down_revision = 'a58395ea1b22'
16 | branch_labels = None
17 | depends_on = None
18 |
19 |
20 | def upgrade() -> None:
21 | op.create_table('stripe_payment',
22 | sa.Column('id', sa.String, primary_key=True),
23 | sa.Column('user_id', sa.Integer, sa.ForeignKey('users.id'), nullable=False),
24 | sa.Column('amount', sa.Integer, nullable=False),
25 | sa.Column('currency', sa.String(3), nullable=False),
26 | sa.Column('status', sa.String, nullable=False),
27 | sa.Column('raw_data', sa.JSON, nullable=True),
28 | sa.Column('created_at', sa.DateTime(timezone=True), default=datetime.now(UTC)),
29 | sa.Column('updated_at', sa.DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)),
30 | )
31 |
32 |
33 | def downgrade() -> None:
34 | op.drop_table('stripe_payment')
35 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def main() -> None:
5 | """Launch Gunicorn the same way we do in the Dockerfile for local dev.
6 |
7 | This avoids importing the FastAPI app in the parent process, so no DB
8 | connections are opened before Gunicorn forks its workers (prevents SSL
9 | errors when using Railway/PostgreSQL).
10 | """
11 |
12 | host = os.getenv("HOST", "0.0.0.0")
13 | port = os.getenv("PORT", "8000")
14 | reload = os.getenv("RELOAD", "false").lower() == "true"
15 |
16 | # Optional: let caller override the number of Gunicorn workers
17 | workers_env = os.getenv("WORKERS") # e.g. WORKERS=10
18 |
19 | cmd = [
20 | "gunicorn",
21 | "app.main:app",
22 | "-k",
23 | "uvicorn.workers.UvicornWorker",
24 | "--bind",
25 | f"{host}:{port}",
26 | "--log-level",
27 | "info",
28 | ]
29 |
30 | if reload:
31 | cmd.append("--reload")
32 |
33 | # Inject --workers flag if WORKERS env var is set
34 | if workers_env and workers_env.isdigit():
35 | cmd.extend(["--workers", workers_env])
36 |
37 | # Replace the current process with Gunicorn.
38 | os.execvp(cmd[0], cmd)
39 |
40 |
41 | if __name__ == "__main__":
42 | main()
43 |
--------------------------------------------------------------------------------
/alembic/versions/9daf34d338f7_update_model_mapping_type_for_.py:
--------------------------------------------------------------------------------
1 | """update model_mapping type for ProviderKey table
2 |
3 | Revision ID: 9daf34d338f7
4 | Revises: 08cc005a4bc8
5 | Create Date: 2025-07-18 21:32:48.791253
6 |
7 | """
8 |
9 | from alembic import op
10 | import sqlalchemy as sa
11 | from sqlalchemy.dialects import postgresql
12 |
13 |
14 | # revision identifiers, used by Alembic.
15 | revision = "9daf34d338f7"
16 | down_revision = "08cc005a4bc8"
17 | branch_labels = None
18 | depends_on = None
19 |
20 |
21 | def upgrade() -> None:
22 | # Change model_mapping column from String to JSON
23 | op.alter_column(
24 | "provider_keys",
25 | "model_mapping",
26 | existing_type=sa.String(),
27 | type_=postgresql.JSON(astext_type=sa.Text()),
28 | existing_nullable=True,
29 | postgresql_using="model_mapping::json",
30 | )
31 |
32 |
33 | def downgrade() -> None:
34 | # Change model_mapping column from JSON back to String
35 | op.alter_column(
36 | "provider_keys",
37 | "model_mapping",
38 | existing_type=postgresql.JSON(astext_type=sa.Text()),
39 | type_=sa.String(),
40 | existing_nullable=True,
41 | postgresql_using="model_mapping::text",
42 | )
43 |
--------------------------------------------------------------------------------
/tests/unit_tests/assets/openai/chat_completion_response_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "id": "chatcmpl-BmuFSt7sCoTGzmCCEEHt6vS1YIEdO",
3 | "object": "chat.completion",
4 | "created": 1750995662,
5 | "model": "gpt-4o-mini-2024-07-18",
6 | "choices": [
7 | {
8 | "index": 0,
9 | "message": {
10 | "role": "assistant",
11 | "content": "Hello! I'm just a program, so I don't have feelings, but I'm here and ready to help you. How can I assist you today?",
12 | "refusal": null,
13 | "annotations": []
14 | },
15 | "logprobs": null,
16 | "finish_reason": "stop"
17 | }
18 | ],
19 | "usage": {
20 | "prompt_tokens": 13,
21 | "completion_tokens": 29,
22 | "total_tokens": 42,
23 | "prompt_tokens_details": {
24 | "cached_tokens": 0,
25 | "audio_tokens": 0
26 | },
27 | "completion_tokens_details": {
28 | "reasoning_tokens": 0,
29 | "audio_tokens": 0,
30 | "accepted_prediction_tokens": 0,
31 | "rejected_prediction_tokens": 0
32 | }
33 | },
34 | "service_tier": "default",
35 | "system_fingerprint": "fp_34a54ae93c"
36 | }
--------------------------------------------------------------------------------
/app/models/provider_key.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import Column, ForeignKey, Integer, String, JSON, Boolean
2 | from sqlalchemy.orm import relationship
3 |
4 | from app.models.forge_api_key import forge_api_key_provider_scope_association
5 |
6 | from .base import BaseModel
7 |
8 |
9 | class ProviderKey(BaseModel):
10 | __tablename__ = "provider_keys"
11 |
12 | provider_name = Column(String, index=True) # e.g., "openai", "anthropic", etc.
13 | encrypted_api_key = Column(String)
14 | user_id = Column(Integer, ForeignKey("users.id"))
15 | user = relationship("User", back_populates="provider_keys")
16 |
17 | # Additional metadata specific to the provider
18 | base_url = Column(
19 | String, nullable=True
20 | ) # Allow custom base URLs for some providers
21 | model_mapping = Column(JSON, nullable=True) # JSON dict for model name mappings
22 | billable = Column(Boolean, nullable=False, default=False)
23 |
24 | # Relationship to ForgeApiKeys that have this provider key in their scope
25 | scoped_forge_api_keys = relationship(
26 | "ForgeApiKey",
27 | secondary=forge_api_key_provider_scope_association,
28 | back_populates="allowed_provider_keys",
29 | lazy="selectin",
30 | )
31 | usage_tracker = relationship("UsageTracker", back_populates="provider_key")
32 |
--------------------------------------------------------------------------------
/alembic/versions/b5d4363a9f62_add_clerk_user_id_to_users.py:
--------------------------------------------------------------------------------
1 | """add_clerk_user_id_to_users
2 |
3 | Revision ID: b5d4363a9f62
4 | Revises: 0ce4eeae965f
5 | Create Date: 2025-04-19 10:43:37.264983
6 |
7 | """
8 |
9 | import sqlalchemy as sa
10 |
11 | from alembic import op
12 |
13 | # revision identifiers, used by Alembic.
14 | revision = "b5d4363a9f62"
15 | down_revision = "0ce4eeae965f"
16 | branch_labels = None
17 | depends_on = None
18 |
19 |
20 | def upgrade() -> None:
21 | # Add clerk_user_id column to users table
22 | op.add_column("users", sa.Column("clerk_user_id", sa.String(), nullable=True))
23 | op.create_index(
24 | op.f("ix_users_clerk_user_id"), "users", ["clerk_user_id"], unique=True
25 | )
26 |
27 | # SQLite doesn't support ALTER COLUMN directly, use batch operations instead
28 | with op.batch_alter_table("users") as batch_op:
29 | batch_op.alter_column(
30 | "hashed_password", existing_type=sa.String(), nullable=True
31 | )
32 |
33 |
34 | def downgrade() -> None:
35 | # Remove clerk_user_id column
36 | op.drop_index(op.f("ix_users_clerk_user_id"), table_name="users")
37 | op.drop_column("users", "clerk_user_id")
38 |
39 | # Make hashed_password required again
40 | with op.batch_alter_table("users") as batch_op:
41 | batch_op.alter_column(
42 | "hashed_password", existing_type=sa.String(), nullable=False
43 | )
44 |
--------------------------------------------------------------------------------
/app/utils/translator.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import aiohttp
3 | from http import HTTPStatus
4 |
5 | async def download_image_url(logger, image_url: str) -> str:
6 | """
7 | Download an image from a URL and return the base64 encoded string
8 | """
9 |
10 | # if the image url is a data url, return it as is
11 | if image_url.startswith("data:"):
12 | return image_url
13 |
14 | async with aiohttp.ClientSession() as session:
15 | async with session.head(image_url) as response:
16 | if response.status != HTTPStatus.OK:
17 | error_text = await response.text()
18 | log_error_msg = f"Failed to fetch file metadata from URL: {error_text}"
19 | logger.error(log_error_msg)
20 | raise RuntimeError(log_error_msg)
21 |
22 | mime_type = response.headers.get("Content-Type", "")
23 | file_size = int(response.headers.get("Content-Length", 0))
24 | if file_size > 10 * 1024 * 1024:
25 | log_error_msg = f"Image file size is too large: {file_size} bytes"
26 | logger.error(log_error_msg)
27 | raise RuntimeError(log_error_msg)
28 |
29 | async with session.get(image_url) as response:
30 | # return format is data:mime_type;base64,base64_data
31 | return f"data:{mime_type};base64,{base64.b64encode(await response.read()).decode('utf-8')}"
32 |
--------------------------------------------------------------------------------
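A minimal sketch of calling the coroutine above from synchronous code (illustrative only; the image URL is a placeholder and the logger comes from app.core.logger):

    import asyncio

    from app.core.logger import get_logger
    from app.utils.translator import download_image_url

    logger = get_logger(name="translator_example")
    data_url = asyncio.run(download_image_url(logger, "https://example.com/cat.png"))
    print(data_url[:30])  # "data:image/png;base64,..." for a PNG response

--------------------------------------------------------------------------------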
/app/services/providers/gemini_openai_adapter.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from app.core.logger import get_logger
4 | from .openai_adapter import OpenAIAdapter
5 |
6 | # Configure logging
7 | logger = get_logger(name="gemini_openai_adapter")
8 |
9 |
10 | class GeminiOpenAIAdapter(OpenAIAdapter):
11 | """Adapter for Google Gemini via the OpenAI-compatible endpoint
12 |
13 | Google now exposes Gemini models behind an OpenAI-compatible REST surface at
14 | https://generativelanguage.googleapis.com/v1beta/openai/…
15 |
16 | We can therefore reuse all logic in OpenAIAdapter. The only wrinkle is that
17 | callers might supply the base_url without the trailing `/openai` segment,
18 | so we normalise it here.
19 | """
20 |
21 | def __init__(
22 | self,
23 | provider_name: str,
24 | base_url: str | None,
25 | config: dict[str, Any] | None = None,
26 | ):
27 | # Default base URL if none supplied
28 | if not base_url:
29 | base_url = "https://generativelanguage.googleapis.com/v1beta"
30 |
31 | # Ensure the URL ends with the OpenAI compatibility suffix
32 | base_url = base_url.rstrip("/")
33 | if not base_url.endswith("/openai"):
34 | base_url = f"{base_url}/openai"
35 |
36 | logger.debug(f"Initialised GeminiOpenAIAdapter with base_url={base_url}")
37 |
38 | super().__init__(provider_name, base_url, config=config or {})
--------------------------------------------------------------------------------
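A small illustration of the base-URL normalisation described in the docstring above (illustrative only; it assumes the adapter can be constructed directly, as the unit tests do for the sibling VertexAdapter):

    from app.services.providers.gemini_openai_adapter import GeminiOpenAIAdapter

    # No base_url supplied: the Google default is used and "/openai" is appended,
    # yielding https://generativelanguage.googleapis.com/v1beta/openai
    GeminiOpenAIAdapter("gemini", None)

    # A caller-supplied URL without the suffix (or with a trailing slash) is
    # normalised to the same ".../v1beta/openai" form.
    GeminiOpenAIAdapter("gemini", "https://generativelanguage.googleapis.com/v1beta/")

--------------------------------------------------------------------------------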
/app/api/schemas/stripe.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, field_validator
2 | from typing import List, Literal
3 |
4 | # https://docs.stripe.com/api/checkout/sessions/create
5 | class StripeCheckoutSessionLineItemPriceDataProductData(BaseModel):
6 | name: str
7 | description: str | None = None
8 | images: List[str] | None = None
9 |
10 | class StripeCheckoutSessionLineItemPriceData(BaseModel):
11 | currency: str
12 | product_data: StripeCheckoutSessionLineItemPriceDataProductData
13 | tax_behavior: str = "inclusive"
14 | unit_amount: int
15 |
16 | class StripeCheckoutSessionLineItem(BaseModel):
17 | price_data: StripeCheckoutSessionLineItemPriceData
18 | quantity: int
19 |
20 | class CreateCheckoutSessionRequest(BaseModel):
21 | line_items: List[StripeCheckoutSessionLineItem]
22 | # Only allow payment mode for now
23 | mode: Literal["payment"] = "payment"
24 | # Attach the session_id to the success_url
25 | # https://docs.stripe.com/payments/checkout/custom-success-page?payment-ui=stripe-hosted&utm_source=chatgpt.com#success-url
26 | success_url: str | None = None
27 | return_url: str | None = None
28 | cancel_url: str | None = None
29 | ui_mode: str = "hosted"
30 |
31 | @field_validator("success_url")
32 | @classmethod
33 | def append_session_id_to_success_url(cls, value: str):
34 | if value is None:
35 | return None
36 | return value.rstrip("/") + "?session_id={CHECKOUT_SESSION_ID}"
--------------------------------------------------------------------------------
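A minimal example of the success_url rewriting performed by the validator above (illustrative only; the product name, URL, and amount are placeholders):

    from app.api.schemas.stripe import (
        CreateCheckoutSessionRequest,
        StripeCheckoutSessionLineItem,
        StripeCheckoutSessionLineItemPriceData,
        StripeCheckoutSessionLineItemPriceDataProductData,
    )

    req = CreateCheckoutSessionRequest(
        line_items=[
            StripeCheckoutSessionLineItem(
                quantity=1,
                price_data=StripeCheckoutSessionLineItemPriceData(
                    currency="usd",
                    unit_amount=500,  # smallest currency unit, i.e. $5.00
                    product_data=StripeCheckoutSessionLineItemPriceDataProductData(name="Credits"),
                ),
            )
        ],
        success_url="https://example.com/billing/success/",
    )
    assert req.success_url == "https://example.com/billing/success?session_id={CHECKOUT_SESSION_ID}"

--------------------------------------------------------------------------------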
/app/api/schemas/user.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | from pydantic import BaseModel, ConfigDict, EmailStr, Field
4 |
5 | from app.api.schemas.forge_api_key import ForgeApiKeyMasked
6 |
7 | # Constants
8 | VISIBLE_API_KEY_CHARS = 4
9 |
10 |
11 | class UserBase(BaseModel):
12 | email: EmailStr
13 | username: str
14 |
15 |
16 | class UserCreate(UserBase):
17 | password: str
18 |
19 |
20 | class UserUpdate(BaseModel):
21 | email: EmailStr | None = None
22 | username: str | None = None
23 | password: str | None = None
24 |
25 |
26 | class UserInDB(UserBase):
27 | id: int
28 | is_active: bool
29 | created_at: datetime
30 | updated_at: datetime
31 | model_config = ConfigDict(from_attributes=True)
32 |
33 |
34 | class User(UserInDB):
35 | forge_api_keys: list[str] | None = None
36 |
37 |
38 | class MaskedUser(UserInDB):
39 | is_admin: bool = False
40 | forge_api_keys: list[str] | None = Field(
41 | description="List of all API keys with all but last 4 digits masked",
42 | default=None,
43 | )
44 |
45 | @classmethod
46 | def mask_api_key(cls, api_key: str | None) -> str | None:
47 | if not api_key:
48 | return None
49 | return ForgeApiKeyMasked.mask_api_key(api_key)
50 |
51 | model_config = ConfigDict(from_attributes=True)
52 |
53 |
54 | class Token(BaseModel):
55 | access_token: str
56 | token_type: str
57 |
58 |
59 | class TokenData(BaseModel):
60 | username: str | None = None
61 |
--------------------------------------------------------------------------------
/alembic/versions/a58395ea1b22_add_balance_system.py:
--------------------------------------------------------------------------------
1 | """add balance system
2 |
3 | Revision ID: a58395ea1b22
4 | Revises: c9f3e548adef
5 | Create Date: 2025-08-20 22:00:45.743308
6 |
7 | """
8 | from alembic import op
9 | import sqlalchemy as sa
10 |
11 |
12 | # revision identifiers, used by Alembic.
13 | revision = 'a58395ea1b22'
14 | down_revision = 'c9f3e548adef'
15 | branch_labels = None
16 | depends_on = None
17 |
18 | def upgrade() -> None:
19 | op.create_table(
20 | 'wallets',
21 | sa.Column('account_id', sa.BigInteger(), nullable=False),
22 | sa.Column('currency', sa.CHAR(length=3), nullable=False, server_default='USD'),
23 | sa.Column('balance', sa.DECIMAL(precision=20, scale=6), nullable=False, server_default='0'),
24 | sa.Column('blocked', sa.Boolean(), nullable=False, server_default='FALSE'),
25 | sa.Column('version', sa.BigInteger(), nullable=False, server_default='0'),
26 | sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
27 | sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
28 | sa.PrimaryKeyConstraint('account_id'),
29 | sa.ForeignKeyConstraint(['account_id'], ['users.id'], ondelete='CASCADE')
30 | )
31 | op.add_column('provider_keys', sa.Column('billable', sa.Boolean(), nullable=False, server_default='FALSE'))
32 |
33 | def downgrade() -> None:
34 | op.drop_table('wallets')
35 | op.drop_column('provider_keys', 'billable')
36 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ## Description
2 |
3 |
4 | ## Type of Change
5 |
6 | - [ ] Bug fix (non-breaking change which fixes an issue)
7 | - [ ] New feature (non-breaking change which adds functionality)
8 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
9 | - [ ] Documentation update
10 | - [ ] Refactoring (no functional changes)
11 | - [ ] Performance improvement
12 | - [ ] Testing improvement
13 |
14 | ## How Has This Been Tested?
15 |
16 | - [ ] Unit tests
17 | - [ ] Integration tests
18 | - [ ] Manual tests (please describe)
19 |
20 | ## Checklist:
21 |
22 | - [ ] My code follows the code style of this project
23 | - [ ] My changes generate no new warnings
24 | - [ ] I have added tests that prove my fix is effective or that my feature works
25 | - [ ] New and existing unit tests pass locally with my changes
26 | - [ ] Any dependent changes have been merged and published in downstream modules
27 | - [ ] I have updated the documentation accordingly
28 | - [ ] I have checked that my changes don't break backward compatibility
29 | - [ ] I have ensured all automated tests pass (checks are green)
30 |
31 | ## Additional Context
32 |
33 |
--------------------------------------------------------------------------------
/app/models/usage_tracker.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from datetime import UTC
3 | import uuid
4 |
5 | from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, DECIMAL, Boolean
6 | from sqlalchemy.orm import relationship
7 | from sqlalchemy.dialects.postgresql import UUID
8 | from .base import Base
9 |
10 | class UsageTracker(Base):
11 | __tablename__ = "usage_tracker"
12 |
13 | id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
14 | user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
15 | provider_key_id = Column(Integer, ForeignKey("provider_keys.id", ondelete="CASCADE"), nullable=False)
16 | forge_key_id = Column(Integer, ForeignKey("forge_api_keys.id", ondelete="CASCADE"), nullable=False)
17 | model = Column(String, nullable=True)
18 | endpoint = Column(String, nullable=True)
19 |     created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.datetime.now(UTC))
20 | updated_at = Column(DateTime(timezone=True), nullable=True)
21 | input_tokens = Column(Integer, nullable=True)
22 | output_tokens = Column(Integer, nullable=True)
23 | cached_tokens = Column(Integer, nullable=True)
24 | reasoning_tokens = Column(Integer, nullable=True)
25 | cost = Column(DECIMAL(12, 8), nullable=True)
26 | currency = Column(String(3), nullable=True)
27 | pricing_source = Column(String(255), nullable=True)
28 | billable = Column(Boolean, nullable=False, default=False)
29 |
30 | provider_key = relationship("ProviderKey", back_populates="usage_tracker")
31 |
--------------------------------------------------------------------------------
/app/models/user.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | from sqlalchemy import Boolean, Column, DateTime, Integer, String
4 | from sqlalchemy.orm import relationship
5 |
6 | from app.models.base import Base
7 |
8 | # UsageStats model is removed, so no related imports needed
9 |
10 |
11 | class User(Base):
12 | """User model for storing user information"""
13 |
14 | __tablename__ = "users"
15 |
16 | id = Column(Integer, primary_key=True, index=True)
17 | email = Column(String, unique=True, index=True, nullable=False)
18 | username = Column(String, unique=True, index=True, nullable=False)
19 | hashed_password = Column(String, nullable=False)
20 | is_active = Column(Boolean, default=True)
21 | # is_admin = Column(Boolean, default=False) # Add if implementing admin role
22 | clerk_user_id = Column(String, unique=True, nullable=True) # Add Clerk user ID
23 | created_at = Column(DateTime, default=datetime.datetime.utcnow)
24 | updated_at = Column(
25 | DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow
26 | )
27 |
28 | # Relationships
29 | api_keys = relationship(
30 | "ForgeApiKey", back_populates="user", cascade="all, delete-orphan"
31 | )
32 | provider_keys = relationship(
33 | "ProviderKey", back_populates="user", cascade="all, delete-orphan"
34 | )
35 | wallet = relationship("Wallet", back_populates="user", uselist=False)
36 | # Optional: Add relationship to ApiRequestLog if needed
37 | # api_logs = relationship("ApiRequestLog")
38 | admin_users = relationship("AdminUsers", back_populates="user", uselist=False)
--------------------------------------------------------------------------------
/app/models/api_request_log.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | from sqlalchemy import Column, DateTime, Float, ForeignKey, Index, Integer, String
4 |
5 | from .base import Base
6 |
7 |
8 | class ApiRequestLog(Base):
9 | __tablename__ = "api_request_log"
10 |
11 | id = Column(Integer, primary_key=True)
12 |     # Link to user; nullable so logs survive user deletion (SET NULL) and unauthenticated requests can still be logged
13 | user_id = Column(
14 | Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True, index=True
15 | )
16 | provider_name = Column(String, nullable=False, index=True)
17 | model = Column(String, nullable=False, index=True)
18 | endpoint = Column(String, nullable=False, index=True) # e.g., 'chat/completions'
19 | request_timestamp = Column(
20 | DateTime, default=datetime.datetime.utcnow, nullable=False, index=True
21 | ) # Time of the request
22 | input_tokens = Column(Integer, default=0)
23 | output_tokens = Column(Integer, default=0)
24 | total_tokens = Column(Integer, default=0) # Calculated: input + output
25 | cost = Column(Float, default=0.0) # Estimated cost if available
26 | # Optional: Add status_code and duration_ms later if needed
27 |
28 | # Relationship (optional, useful if querying logs through User object)
29 | # user = relationship("User")
30 |
31 | # Define indices for common query patterns
32 | __table_args__ = (
33 | Index("ix_api_request_log_user_time", "user_id", "request_timestamp"),
34 | # Add other indices as needed, e.g., Index('ix_api_request_log_provider_model', 'provider_name', 'model')
35 | )
36 |
--------------------------------------------------------------------------------
/tools/diagnostics/fix_model_mapping.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Utility script to fix model mappings in the database.
4 | Specifically for fixing the gpt-4o to mock-gpt-4o mapping issue.
5 | """
6 |
7 | import asyncio
8 | import os
9 | import sys
10 | from pathlib import Path
11 |
12 | # Add the project root to the Python path before importing app modules
13 | script_dir = Path(__file__).resolve().parent.parent.parent
14 | sys.path.insert(0, str(script_dir))
15 | 
16 | # Change to the project root directory
17 | os.chdir(script_dir)
18 | 
19 | from app.core.database import get_async_db  # noqa: E402
20 |
21 |
22 | async def fix_model_mappings():
23 | """Fix model mappings by clearing caches"""
24 | print("\n🔧 FIXING MODEL MAPPINGS")
25 | print("======================")
26 |
27 | # Get DB session
28 | async with get_async_db() as db:
29 | pass
30 |
31 | # Clear all caches to ensure changes take effect
32 | print("🔄 Invalidating provider service cache for all users")
33 | from app.core.cache import provider_service_cache, user_cache
34 |
35 | provider_service_cache.clear()
36 | user_cache.clear()
37 | print("✅ All caches cleared")
38 |
39 | print("\n✅ Model mapping fix complete.")
40 | return True
41 |
42 |
43 | async def main():
44 | """Main entry point"""
45 | if await fix_model_mappings():
46 | print(
47 | "\n✅ Model mappings have been fixed. Use check_model_mappings.py to verify."
48 | )
49 | sys.exit(0)
50 | else:
51 | print("\n❌ Failed to fix model mappings.")
52 | sys.exit(1)
53 |
54 |
55 | if __name__ == "__main__":
56 | asyncio.run(main())
57 |
--------------------------------------------------------------------------------
/.github/workflows/docker-build.yml:
--------------------------------------------------------------------------------
1 | name: Build and Publish Docker Image
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | release_tag:
7 | description: 'Release Tag (leave empty for auto-generated timestamp)'
8 | required: false
9 | default: ''
10 | debug_logging:
11 | description: 'Enable debug logging in the image'
12 | type: boolean
13 | required: false
14 | default: false
15 |
16 | jobs:
17 | build-and-publish:
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - name: Checkout code
22 | uses: actions/checkout@v3
23 |
24 | - name: Set up Docker Buildx
25 | uses: docker/setup-buildx-action@v2
26 |
27 | - name: Login to Docker Hub
28 | uses: docker/login-action@v2
29 | with:
30 | username: ${{ secrets.DOCKER_USERNAME }}
31 | password: ${{ secrets.DOCKER_PASSWORD }}
32 |
33 | - name: Get timestamp
34 | id: timestamp
35 | run: echo "timestamp=$(date +'%Y%m%d-%H%M%S')" >> $GITHUB_OUTPUT
36 |
37 | - name: Determine tag
38 | id: tag
39 | run: |
40 | if [ -z "${{ github.event.inputs.release_tag }}" ]; then
41 | echo "tag=${{ steps.timestamp.outputs.timestamp }}" >> $GITHUB_OUTPUT
42 | else
43 | echo "tag=${{ github.event.inputs.release_tag }}" >> $GITHUB_OUTPUT
44 | fi
45 |
46 | - name: Build and push
47 | uses: docker/build-push-action@v4
48 | with:
49 | context: .
50 | push: true
51 | tags: |
52 | tensorblockai/forge:latest
53 | tensorblockai/forge:${{ steps.tag.outputs.tag }}
54 | cache-from: type=gha
55 | cache-to: type=gha,mode=max
56 |
--------------------------------------------------------------------------------
/tests/unit_tests/assets/openai/list_models.json:
--------------------------------------------------------------------------------
1 | {
2 | "object": "list",
3 | "data": [
4 | {
5 | "id": "gpt-4-0613",
6 | "object": "model",
7 | "created": 1686588896,
8 | "owned_by": "openai"
9 | },
10 | {
11 | "id": "gpt-4",
12 | "object": "model",
13 | "created": 1687882411,
14 | "owned_by": "openai"
15 | },
16 | {
17 | "id": "gpt-3.5-turbo",
18 | "object": "model",
19 | "created": 1677610602,
20 | "owned_by": "openai"
21 | },
22 | {
23 | "id": "o4-mini-deep-research",
24 | "object": "model",
25 | "created": 1749685485,
26 | "owned_by": "system"
27 | },
28 | {
29 | "id": "o3-deep-research",
30 | "object": "model",
31 | "created": 1749840121,
32 | "owned_by": "system"
33 | },
34 | {
35 | "id": "davinci-002",
36 | "object": "model",
37 | "created": 1692634301,
38 | "owned_by": "system"
39 | },
40 | {
41 | "id": "dall-e-3",
42 | "object": "model",
43 | "created": 1698785189,
44 | "owned_by": "system"
45 | },
46 | {
47 | "id": "dall-e-2",
48 | "object": "model",
49 | "created": 1698798177,
50 | "owned_by": "system"
51 | },
52 | {
53 | "id": "gpt-3.5-turbo-1106",
54 | "object": "model",
55 | "created": 1698959748,
56 | "owned_by": "system"
57 | }
58 | ]
59 | }
--------------------------------------------------------------------------------
/alembic/versions/4a685a55c5cd_create_usage_tracker_table.py:
--------------------------------------------------------------------------------
1 | """create usage tracker table
2 |
3 | Revision ID: 4a685a55c5cd
4 | Revises: 9daf34d338f7
5 | Create Date: 2025-08-02 12:29:07.955645
6 |
7 | """
8 | from alembic import op
9 | import sqlalchemy as sa
10 | from sqlalchemy.dialects.postgresql import UUID
11 | import uuid
12 |
13 |
14 | # revision identifiers, used by Alembic.
15 | revision = '4a685a55c5cd'
16 | down_revision = '9daf34d338f7'
17 | branch_labels = None
18 | depends_on = None
19 |
20 |
21 | def upgrade() -> None:
22 | op.create_table(
23 | "usage_tracker",
24 | sa.Column("id", UUID(as_uuid=True), nullable=False, primary_key=True, default=uuid.uuid4),
25 | sa.Column("user_id", sa.Integer(), nullable=False),
26 | sa.Column("provider_key_id", sa.Integer(), nullable=False),
27 | sa.Column("forge_key_id", sa.Integer(), nullable=False),
28 | sa.Column("model", sa.String(), nullable=True),
29 | sa.Column("endpoint", sa.String(), nullable=True),
30 | sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
31 | sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
32 | sa.Column("input_tokens", sa.Integer(), nullable=True),
33 | sa.Column("output_tokens", sa.Integer(), nullable=True),
34 | sa.Column("cached_tokens", sa.Integer(), nullable=True),
35 | sa.Column("reasoning_tokens", sa.Integer(), nullable=True),
36 |
37 | sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
38 | sa.ForeignKeyConstraint(["provider_key_id"], ["provider_keys.id"], ondelete="CASCADE"),
39 | sa.ForeignKeyConstraint(["forge_key_id"], ["forge_api_keys.id"], ondelete="CASCADE"),
40 | )
41 |
42 |
43 | def downgrade() -> None:
44 | op.drop_table("usage_tracker")
45 |
--------------------------------------------------------------------------------
/alembic/versions/39bcedfae4fe_add_model_default_pricing.py:
--------------------------------------------------------------------------------
1 | """add model default pricing
2 |
3 | Revision ID: 39bcedfae4fe
4 | Revises: b206e9a941e3
5 | Create Date: 2025-08-14 18:31:20.897283
6 |
7 | """
8 | from alembic import op
9 | import sqlalchemy as sa
10 |
11 |
12 | # revision identifiers, used by Alembic.
13 | revision = '39bcedfae4fe'
14 | down_revision = 'b206e9a941e3'
15 | branch_labels = None
16 | depends_on = None
17 |
18 |
19 | def upgrade() -> None:
20 | op.add_column('fallback_pricing', sa.Column('model_name', sa.String(), nullable=True))
21 | connection = op.get_bind()
22 | connection.execute(sa.text("""
23 | insert into fallback_pricing (provider_name, model_name, input_token_price, output_token_price, cached_token_price, currency, created_at, updated_at, effective_date, end_date, fallback_type, description)
24 | select distinct on (model_name)
25 | provider_name as provider_name,
26 | model_name as model_name,
27 | input_token_price,
28 | output_token_price,
29 | cached_token_price,
30 | currency,
31 | created_at,
32 | updated_at,
33 | effective_date,
34 | end_date,
35 | 'model_default' as fallback_type,
36 | null as description
37 | from model_pricing
38 | order by model_name,
39 | case when provider_name = 'openai' then 1
40 | when provider_name = 'anthropic' then 2
41 | else 3 end
42 | """))
43 |
44 |
45 | def downgrade() -> None:
46 | connection = op.get_bind()
47 | connection.execute(sa.text("""
48 | delete from fallback_pricing where fallback_type = 'model_default'
49 | """))
50 | op.drop_column('fallback_pricing', 'model_name')
51 |
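The DISTINCT ON / ORDER BY pair above keeps exactly one fallback row per model_name, preferring an openai price, then anthropic, then any other provider. A minimal Python sketch of the same selection, using made-up rows purely for illustration:

    rows = [
        {"model_name": "gpt-4", "provider_name": "azure"},
        {"model_name": "gpt-4", "provider_name": "openai"},
        {"model_name": "claude-3", "provider_name": "anthropic"},
    ]
    rank = {"openai": 1, "anthropic": 2}  # everything else ranks 3

    best = {}
    for row in sorted(rows, key=lambda r: (r["model_name"], rank.get(r["provider_name"], 3))):
        best.setdefault(row["model_name"], row)  # first row per model wins, like DISTINCT ON
    print({m: r["provider_name"] for m, r in best.items()})
    # {'claude-3': 'anthropic', 'gpt-4': 'openai'}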
--------------------------------------------------------------------------------
/tests/unit_tests/assets/openai/responses_response_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "id": "resp_0a379700213743ce0068db3cded47481a29d3552eea69e6939",
3 | "object": "response",
4 | "created_at": 1759198430,
5 | "status": "completed",
6 | "background": false,
7 | "billing": {"payer": "developer"},
8 | "error": null,
9 | "incomplete_details": null,
10 | "instructions": null,
11 | "max_output_tokens": null,
12 | "max_tool_calls": null,
13 | "model": "gpt-4o-mini-2024-07-18",
14 | "output": [
15 | {
16 | "id": "msg_0a379700213743ce0068db3cdf654481a29d5a27842a0e095f",
17 | "type": "message",
18 | "status": "completed",
19 | "content": [
20 | {
21 | "type": "output_text",
22 | "annotations": [],
23 | "logprobs": [],
24 | "text": "Hello! I'm doing well, thank you. How about you?"
25 | }
26 | ],
27 | "role": "assistant"
28 | }
29 | ],
30 | "parallel_tool_calls": true,
31 | "previous_response_id": null,
32 | "prompt_cache_key": null,
33 | "reasoning": {"effort": null, "summary": null},
34 | "safety_identifier": null,
35 | "service_tier": "default",
36 | "store": true,
37 | "temperature": 1.0,
38 | "text": {"format": {"type": "text"}, "verbosity": "medium"},
39 | "tool_choice": "auto",
40 | "tools": [],
41 | "top_logprobs": 0,
42 | "top_p": 1.0,
43 | "truncation": "disabled",
44 | "usage": {
45 | "input_tokens": 13,
46 | "input_tokens_details": {"cached_tokens": 0},
47 | "output_tokens": 14,
48 | "output_tokens_details": {"reasoning_tokens": 0},
49 | "total_tokens": 27
50 | },
51 | "user": null,
52 | "metadata": {}
53 | }
54 |
--------------------------------------------------------------------------------
/app/api/schemas/forge_api_key.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | from pydantic import BaseModel, ConfigDict, Field
4 |
5 |
6 | class ForgeApiKeyBase(BaseModel):
7 | """Base schema for ForgeApiKey"""
8 |
9 | name: str | None = None
10 |
11 |
12 | class ForgeApiKeyCreate(ForgeApiKeyBase):
13 | """Schema for creating a ForgeApiKey"""
14 |
15 | allowed_provider_key_ids: list[int] | None = Field(default=None)
16 |
17 |
18 | class ForgeApiKeyResponse(ForgeApiKeyBase):
19 | """Schema for ForgeApiKey response"""
20 |
21 | id: int
22 | key: str
23 | is_active: bool
24 | created_at: datetime
25 | last_used_at: datetime | None = None
26 | allowed_provider_key_ids: list[int] = Field(default_factory=list)
27 | model_config = ConfigDict(from_attributes=True)
28 |
29 |
30 | class ForgeApiKeyMasked(ForgeApiKeyBase):
31 | """Schema for ForgeApiKey with masked key for display"""
32 |
33 | id: int
34 | key: str
35 | is_active: bool
36 | created_at: datetime
37 | last_used_at: datetime | None = None
38 | allowed_provider_key_ids: list[int] = Field(default_factory=list)
39 |
40 | @staticmethod
41 | def mask_api_key(key: str) -> str:
42 | """Mask the API key for display"""
43 | if not key:
44 | return ""
45 | if key.startswith("forge-"):
46 | prefix = "forge-"
47 | key_part = key[len(prefix) :]
48 | return f"{prefix}{'*' * (len(key_part) - 4)}{key_part[-4:]}"
49 | return f"{'*' * (len(key) - 4)}{key[-4:]}"
50 |
51 | model_config = ConfigDict(from_attributes=True)
52 |
53 |
54 | class ForgeApiKeyUpdate(BaseModel):
55 | """Schema for updating a ForgeApiKey"""
56 |
57 | name: str | None = None
58 | is_active: bool | None = None
59 | allowed_provider_key_ids: list[int] | None = None
60 |
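The masking rule above keeps the forge- prefix and the last four characters and replaces everything in between with asterisks. A minimal usage sketch (the sample key is made up):

    from app.api.schemas.forge_api_key import ForgeApiKeyMasked

    sample_key = "forge-" + "a" * 32 + "1234"           # 36 characters after the prefix
    masked = ForgeApiKeyMasked.mask_api_key(sample_key)
    print(masked)                                        # "forge-" + 32 asterisks + "1234"
    assert masked.startswith("forge-") and masked.endswith("1234")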
--------------------------------------------------------------------------------
/alembic/versions/08cc005a4bc8_create_forge_api_key_provider_scope_.py:
--------------------------------------------------------------------------------
1 | """create_forge_api_key_provider_scope_association_table
2 |
3 | Revision ID: 08cc005a4bc8
4 | Revises: 6a92c2663fa5
5 | Create Date: 2025-05-16 13:53:06.169215
6 |
7 | """
8 | import sqlalchemy as sa
9 |
10 | from alembic import op
11 |
12 | # revision identifiers, used by Alembic.
13 | revision = "08cc005a4bc8"
14 | down_revision = "6a92c2663fa5"
15 | branch_labels = None
16 | depends_on = None
17 |
18 |
19 | def upgrade() -> None:
20 | # ### commands auto generated by Alembic - please adjust! ###
21 | op.create_table(
22 | "forge_api_key_provider_scope_association",
23 | sa.Column("forge_api_key_id", sa.Integer(), nullable=False),
24 | sa.Column("provider_key_id", sa.Integer(), nullable=False),
25 | sa.ForeignKeyConstraint(
26 | ["forge_api_key_id"],
27 | ["forge_api_keys.id"],
28 | name=op.f(
29 | "fk_forge_api_key_provider_scope_association_forge_api_key_id_forge_api_keys"
30 | ),
31 | ondelete="CASCADE",
32 | ),
33 | sa.ForeignKeyConstraint(
34 | ["provider_key_id"],
35 | ["provider_keys.id"],
36 | name=op.f(
37 | "fk_forge_api_key_provider_scope_association_provider_key_id_provider_keys"
38 | ),
39 | ondelete="CASCADE",
40 | ),
41 | sa.PrimaryKeyConstraint(
42 | "forge_api_key_id",
43 | "provider_key_id",
44 | name=op.f("pk_forge_api_key_provider_scope_association"),
45 | ),
46 | )
47 | # ### end Alembic commands ###
48 |
49 |
50 | def downgrade() -> None:
51 | # ### commands auto generated by Alembic - please adjust! ###
52 | op.drop_table("forge_api_key_provider_scope_association")
53 | # ### end Alembic commands ###
54 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.12-slim
2 |
3 | WORKDIR /app
4 |
5 | # Define build-time arguments
6 | ARG ARG_CLERK_JWT_PUBLIC_KEY
7 | ARG ARG_CLERK_API_KEY
8 | ARG ARG_CLERK_API_URL
9 | ARG ARG_DEBUG_LOGGING=false
10 |
11 | # Set runtime environment variables from build-time arguments
12 | # WARNING: These values will be baked into the image. For sensitive data, prefer runtime injection.
13 | ENV CLERK_JWT_PUBLIC_KEY=${ARG_CLERK_JWT_PUBLIC_KEY}
14 | ENV CLERK_API_KEY=${ARG_CLERK_API_KEY}
15 | # Example with a default value if ARG is not set for CLERK_API_URL
16 | ENV CLERK_API_URL=${ARG_CLERK_API_URL:-https://api.clerk.dev/v1}
17 | ENV FORGE_DEBUG_LOGGING=${ARG_DEBUG_LOGGING}
18 |
19 | # Database connection optimization environment variables
20 | # These settings optimize for PostgreSQL connection limits
21 | ENV DB_POOL_SIZE=3
22 | ENV DB_MAX_OVERFLOW=2
23 | ENV DB_POOL_TIMEOUT=30
24 | ENV DB_POOL_RECYCLE=1800
25 | ENV DB_POOL_PRE_PING=true
26 |
27 | # Reduced worker count to manage database connections
28 | # With 5 workers: at most 50 connections (5 × 3 × 2 engines = 30 pooled, plus 5 × 2 × 2 = 20 overflow)
29 | ENV WORKERS=5
30 |
31 | # Install system dependencies (PostgreSQL client)
32 | RUN apt-get update && apt-get install -y \
33 | postgresql-client \
34 | && rm -rf /var/lib/apt/lists/*
35 |
36 | RUN mkdir -p /app/logs && \
37 | chown -R nobody:nogroup /app/logs && \
38 | chmod -R 777 /app/logs
39 |
40 | # Copy project files
41 | COPY . .
42 |
43 | # Install dependencies using pip
44 | RUN pip install -e .
45 |
46 | # Switch to non-root user for security
47 | USER nobody
48 |
49 | # Expose port
50 | EXPOSE 8000
51 |
52 | # Use environment variable for workers count and optimize for database connections
53 | CMD ["sh", "-c", "gunicorn app.main:app -k uvicorn.workers.UvicornWorker --workers ${WORKERS:-5} --bind 0.0.0.0:8000 --timeout 120 --max-requests 1000 --max-requests-jitter 100"]
54 |
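As a quick check of the connection budget described in the comments above, a minimal Python sketch (the two-engines-per-worker figure is taken from the comment, not derived from the application code):

    workers, engines_per_worker = 5, 2
    pool_size, max_overflow = 3, 2                        # DB_POOL_SIZE, DB_MAX_OVERFLOW

    pooled = workers * engines_per_worker * pool_size     # 30 steady-state connections
    burst = workers * engines_per_worker * max_overflow   # 20 extra connections under load
    print(pooled + burst)                                  # 50, which must stay below the PostgreSQL limit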
--------------------------------------------------------------------------------
/tests/unit_tests/assets/google/chat_completion_streaming_response_1.json:
--------------------------------------------------------------------------------
1 | [
2 | "[{\n",
3 | " \"candidates\": [\n",
4 | " {\n",
5 | " \"content\": {\n",
6 | " \"parts\": [\n",
7 | " {\n",
8 | " \"text\": \"I\"\n",
9 | " }\n",
10 | " ],\n",
11 | " \"role\": \"model\"\n",
12 | " }\n",
13 | " }\n",
14 | " ],\n",
15 | " \"usageMetadata\": {\n",
16 | " \"promptTokenCount\": 6,\n",
17 | " \"totalTokenCount\": 6,\n",
18 | " \"promptTokensDetails\": [\n",
19 | " {\n",
20 | " \"modality\": \"TEXT\",\n",
21 | " \"tokenCount\": 6\n",
22 | " }\n",
23 | " ]\n",
24 | " },\n",
25 | " \"modelVersion\": \"gemini-1.5-pro-002\",\n",
26 | " \"responseId\": \"ASVeaJ69O57ImNAPh7Sh4A0\"\n",
27 | "}\n",
28 | ",\r\n",
29 | "{\n",
30 | " \"candidates\": [\n",
31 | " {\n",
32 | " \"content\": {\n",
33 | " \"parts\": [\n",
34 | " {\n",
35 | " \"text\": \" am doing well, thank you for asking. How are you today?\\n\"\n",
36 | " }\n",
37 | " ],\n",
38 | " \"role\": \"model\"\n",
39 | " },\n",
40 | " \"finishReason\": \"STOP\"\n",
41 | " }\n",
42 | " ],\n",
43 | " \"usageMetadata\": {\n",
44 | " \"promptTokenCount\": 6,\n",
45 | " \"candidatesTokenCount\": 16,\n",
46 | " \"totalTokenCount\": 22,\n",
47 | " \"promptTokensDetails\": [\n",
48 | " {\n",
49 | " \"modality\": \"TEXT\",\n",
50 | " \"tokenCount\": 6\n",
51 | " }\n",
52 | " ],\n",
53 | " \"candidatesTokensDetails\": [\n",
54 | " {\n",
55 | " \"modality\": \"TEXT\",\n",
56 | " \"tokenCount\": 16\n",
57 | " }\n",
58 | " ]\n",
59 | " },\n",
60 | " \"modelVersion\": \"gemini-1.5-pro-002\",\n",
61 | " \"responseId\": \"ASVeaJ69O57ImNAPh7Sh4A0\"\n",
62 | "}\n",
63 | "]"
64 | ]
--------------------------------------------------------------------------------
/app/core/security.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime, timedelta
3 |
4 | from cryptography.fernet import Fernet
5 | from dotenv import load_dotenv
6 | from jose import jwt
7 | from passlib.context import CryptContext
8 |
9 | load_dotenv()
10 |
11 | # Password hashing
12 | pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
13 |
14 | # JWT settings
15 | SECRET_KEY = os.getenv("SECRET_KEY", "your_secret_key_here")
16 | ALGORITHM = os.getenv("ALGORITHM", "HS256")
17 | ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "60"))
18 |
19 | # Encryption for API keys
20 | ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY", Fernet.generate_key().decode())
21 | # Fernet requires the key as bytes; encode it if it was provided as a string
22 | if isinstance(ENCRYPTION_KEY, str):
23 | fernet = Fernet(ENCRYPTION_KEY.encode())
24 | else:
25 | fernet = Fernet(ENCRYPTION_KEY)
26 |
27 |
28 | def verify_password(plain_password, hashed_password):
29 | return pwd_context.verify(plain_password, hashed_password)
30 |
31 |
32 | def get_password_hash(password):
33 | return pwd_context.hash(password)
34 |
35 |
36 | def create_access_token(data: dict, expires_delta: timedelta | None = None):
37 | to_encode = data.copy()
38 | if expires_delta:
39 | expire = datetime.utcnow() + expires_delta
40 | else:
41 | expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
42 | to_encode.update({"exp": expire})
43 | encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
44 | return encoded_jwt
45 |
46 |
47 | def encrypt_api_key(api_key: str) -> str:
48 | """Encrypt an API key"""
49 | return fernet.encrypt(api_key.encode()).decode()
50 |
51 |
52 | def decrypt_api_key(encrypted_api_key: str) -> str:
53 | """Decrypt an API key"""
54 | return fernet.decrypt(encrypted_api_key.encode()).decode()
55 |
56 |
57 | def generate_forge_api_key() -> str:
58 | """
59 | Generate a unique Forge API key.
60 | """
61 | import secrets
62 |
63 | return f"forge-{secrets.token_hex(18)}"
64 |
--------------------------------------------------------------------------------
/app/models/forge_api_key.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Table
4 | from sqlalchemy.orm import relationship
5 |
6 | from app.models.base import Base
7 |
8 | # Association Table for ForgeApiKey and ProviderKey
9 | forge_api_key_provider_scope_association = Table(
10 | "forge_api_key_provider_scope_association",
11 | Base.metadata,
12 | Column(
13 | "forge_api_key_id",
14 | Integer,
15 | ForeignKey("forge_api_keys.id", ondelete="CASCADE"),
16 | primary_key=True,
17 | ),
18 | Column(
19 | "provider_key_id",
20 | Integer,
21 | ForeignKey("provider_keys.id", ondelete="CASCADE"),
22 | primary_key=True,
23 | ),
24 | )
25 |
26 |
27 | class ForgeApiKey(Base):
28 | """Model for storing multiple Forge API keys per user"""
29 |
30 | __tablename__ = "forge_api_keys"
31 |
32 | id = Column(Integer, primary_key=True, index=True)
33 | key = Column(String, unique=True, index=True, nullable=False)
34 | name = Column(String, nullable=True) # Optional name/description for the key
35 | user_id = Column(
36 | Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False
37 | )
38 | is_active = Column(Boolean, default=True)
39 | created_at = Column(DateTime, default=datetime.datetime.utcnow)
40 | last_used_at = Column(DateTime, nullable=True)
41 | deleted_at = Column(DateTime(timezone=True), nullable=True)
42 |
43 | # Relationship to user
44 | user = relationship("User", back_populates="api_keys")
45 |
46 | # Relationship to allowed ProviderKeys (scope)
47 | allowed_provider_keys = relationship(
48 | "ProviderKey",
49 | secondary=forge_api_key_provider_scope_association,
50 | back_populates="scoped_forge_api_keys",
51 | lazy="selectin", # Use selectin loading for efficiency when accessing this relationship
52 | )
53 |
54 | # Optionally, we could have a relationship to enabled provider keys
55 | # enabled_provider_keys = relationship("EnabledProviderKey", back_populates="api_key")
56 |
--------------------------------------------------------------------------------
/tools/diagnostics/enable_request_logging.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Enable request logging by setting LOG_LEVEL environment variable.
4 | This script modifies the .env file to adjust the logging level for the Forge server.
5 | """
6 |
7 | import os
8 | import sys
9 | from pathlib import Path
10 |
11 | # Add the project root to the Python path
12 | script_dir = Path(__file__).resolve().parent.parent.parent
13 | sys.path.insert(0, str(script_dir))
14 |
15 | # Change to the project root directory
16 | os.chdir(script_dir)
17 |
18 |
19 | def enable_request_logging():
20 | """Enable request logging by setting LOG_LEVEL=debug in .env file"""
21 | env_file = Path(".env")
22 |
23 | # Read existing .env file
24 | if env_file.exists():
25 | with open(env_file) as f:
26 | lines = f.readlines()
27 | else:
28 | lines = []
29 |
30 | # Check if LOG_LEVEL already exists
31 | log_level_exists = False
32 | for i, line in enumerate(lines):
33 | if line.strip().startswith("LOG_LEVEL="):
34 | lines[i] = "LOG_LEVEL=debug\n"
35 | log_level_exists = True
36 | break
37 |
38 | # Add LOG_LEVEL if it doesn't exist
39 | if not log_level_exists:
40 | lines.append("\n# Set logging level for server\nLOG_LEVEL=debug\n")
41 |
42 | # Write back to .env file
43 | with open(env_file, "w") as f:
44 | f.writelines(lines)
45 |
46 | print("✅ Request logging enabled in .env file")
47 | print("🔍 Log level set to 'debug' to show detailed request/response information")
48 | print("ℹ️ Restart your server for changes to take effect")
49 | print("\nTo see all requests in the server logs, restart with:")
50 | print(" python run.py")
51 |
52 | return True
53 |
54 |
55 | def main():
56 | """Main entry point"""
57 | if enable_request_logging():
58 | print("\n✨ Your server will now show detailed request logs")
59 | sys.exit(0)
60 | else:
61 | print("\n❌ Failed to enable request logging")
62 | sys.exit(1)
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
67 |
--------------------------------------------------------------------------------
/alembic/versions/4a82fb8af123_create_forge_api_keys_table.py:
--------------------------------------------------------------------------------
1 | """create_forge_api_keys_table
2 |
3 | Revision ID: 4a82fb8af123
4 | Revises: b5d4363a9f62
5 | Create Date: 2023-05-15 12:00:00.000000
6 |
7 | """
8 | import sqlalchemy as sa
9 |
10 | from alembic import op
11 |
12 | # revision identifiers, used by Alembic.
13 | revision = "4a82fb8af123"
14 | down_revision = "b5d4363a9f62" # Update this to point to your latest migration
15 | branch_labels = None
16 | depends_on = None
17 |
18 |
19 | def upgrade():
20 | # Create forge_api_keys table
21 | op.create_table(
22 | "forge_api_keys",
23 | sa.Column("id", sa.Integer(), nullable=False),
24 | sa.Column("key", sa.String(), nullable=False),
25 | sa.Column("name", sa.String(), nullable=True),
26 | sa.Column("user_id", sa.Integer(), nullable=False),
27 | sa.Column("is_active", sa.Boolean(), default=True),
28 | sa.Column("created_at", sa.DateTime(), default=sa.func.now()),
29 | sa.Column("last_used_at", sa.DateTime(), nullable=True),
30 | sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
31 | sa.PrimaryKeyConstraint("id"),
32 | )
33 |
34 | # Create indexes
35 | op.create_index(
36 | op.f("ix_forge_api_keys_id"), "forge_api_keys", ["id"], unique=False
37 | )
38 | op.create_index(
39 | op.f("ix_forge_api_keys_key"), "forge_api_keys", ["key"], unique=True
40 | )
41 |
42 | # Migrate existing keys
43 | # This will create a new entry in forge_api_keys for each user's existing forge_api_key
44 | op.execute(
45 | """
46 | INSERT INTO forge_api_keys (key, name, user_id, is_active, created_at)
47 | SELECT forge_api_key, 'Legacy API Key', id, is_active, created_at
48 | FROM users
49 | WHERE forge_api_key IS NOT NULL
50 | """
51 | )
52 |
53 |
54 | def downgrade():
55 | # Drop the table
56 | op.drop_index(op.f("ix_forge_api_keys_key"), table_name="forge_api_keys")
57 | op.drop_index(op.f("ix_forge_api_keys_id"), table_name="forge_api_keys")
58 | op.drop_table("forge_api_keys")
59 |
--------------------------------------------------------------------------------
/tools/diagnostics/clear_cache.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Utility script to clear all cache entries.
4 | This is useful for testing and diagnosing cache-related issues.
5 | """
6 |
7 | import os
8 | import sys
9 | from pathlib import Path
10 |
11 | # Add the project root to the Python path before importing app modules
12 | script_dir = Path(__file__).resolve().parent.parent.parent
13 | sys.path.insert(0, str(script_dir))
14 | 
15 | # Change to the project root directory
16 | os.chdir(script_dir)
17 | 
18 | from app.core.cache import provider_service_cache, user_cache  # noqa: E402
19 |
20 |
21 | def clear_caches():
22 | """Clear all cache entries and print statistics"""
23 | print("🔄 Clearing caches...")
24 |
25 | # Get cache stats before clearing
26 | user_stats_before = user_cache.stats()
27 | provider_stats_before = provider_service_cache.stats()
28 |
29 | # Clear caches
30 | user_cache.clear()
31 | provider_service_cache.clear()
32 |
33 | # Get cache stats after clearing
34 | user_stats_after = user_cache.stats()
35 | provider_stats_after = provider_service_cache.stats()
36 |
37 | # Print results
38 | print("\n✅ Cache clearing complete!")
39 | print("\n📊 User Cache Statistics:")
40 | print(
41 | f" Before: {user_stats_before['entries']} entries, {user_stats_before['hits']} hits, {user_stats_before['misses']} misses"
42 | )
43 | print(f" After: {user_stats_after['entries']} entries")
44 |
45 | print("\n📊 Provider Service Cache Statistics:")
46 | print(
47 | f" Before: {provider_stats_before['entries']} entries, {provider_stats_before['hits']} hits, {provider_stats_before['misses']} misses"
48 | )
49 | print(f" After: {provider_stats_after['entries']} entries")
50 |
51 | print("\n🔍 Cache hit rates:")
52 | print(f" User Cache: {user_stats_before['hit_rate']:.2%}")
53 | print(f" Provider Cache: {provider_stats_before['hit_rate']:.2%}")
54 |
55 | return True
56 |
57 |
58 | def main():
59 | """Main entry point"""
60 | if clear_caches():
61 | print("\n✨ All caches have been cleared successfully.")
62 | sys.exit(0)
63 | else:
64 | print("\n❌ Failed to clear caches.")
65 | sys.exit(1)
66 |
67 |
68 | if __name__ == "__main__":
69 | main()
70 |
--------------------------------------------------------------------------------
/tests/unit_tests/assets/anthropic/chat_completion_streaming_response_1.json:
--------------------------------------------------------------------------------
1 | [
2 | "event: message_start\n",
3 | "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_011R37jSCigty4xP95pkskoU\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-20250514\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":13,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":0,\"output_tokens\":1,\"service_tier\":\"standard\"}} }\n",
4 | "\n",
5 | "event: content_block_start\n",
6 | "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"} }\n",
7 | "\n",
8 | "event: ping\n",
9 | "data: {\"type\": \"ping\"}\n",
10 | "\n",
11 | "event: content_block_delta\n",
12 | "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"} }\n",
13 | "\n",
14 | "event: content_block_delta\n",
15 | "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"! I'm doing well, thank you\"} }\n",
16 | "\n",
17 | "event: content_block_delta\n",
18 | "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" for asking. I'm here and ready to help with\"} }\n",
19 | "\n",
20 | "event: content_block_delta\n",
21 | "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" whatever you'd like to discuss\"} }\n",
22 | "\n",
23 | "event: content_block_delta\n",
24 | "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" or work on. How\"} }\n",
25 | "\n",
26 | "event: content_block_delta\n",
27 | "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\" are you doing today?\"} }\n",
28 | "\n",
29 | "event: content_block_stop\n",
30 | "data: {\"type\":\"content_block_stop\",\"index\":0 }\n",
31 | "\n",
32 | "event: message_delta\n",
33 | "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":39} }\n",
34 | "\n",
35 | "event: message_stop\n",
36 | "data: {\"type\":\"message_stop\" }\n",
37 | "\n"
38 | ]
--------------------------------------------------------------------------------
/tools/diagnostics/check_dotenv.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Script to check how dotenv loads values from the .env file."""
3 |
4 | import os
5 | import sys
6 | from pathlib import Path
7 |
8 | from dotenv import find_dotenv, load_dotenv
9 |
10 |
11 | def check_env_file(file_path):
12 | """Check if a file exists and print its contents."""
13 | path = Path(file_path)
14 | if path.exists():
15 | print(f"File exists: {path.absolute()}")
16 | print("File contents:")
17 | with open(path) as f:
18 | for line in f:
19 | if line.strip() and "=" in line:
20 | key, value = line.strip().split("=", 1)
21 | if key == "FORGE_API_KEY":
22 | print(f"{key}={value[:8]}...")
23 | else:
24 | print(line.strip())
25 | else:
26 | print(f"File does not exist: {path.absolute()}")
27 |
28 |
29 | def main():
30 | """Main function to check environment loading."""
31 | # Find the project root directory
32 | script_dir = os.path.dirname(
33 | os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
34 | )
35 | os.chdir(script_dir) # Change to project root directory
36 |
37 | print("Python version:", sys.version)
38 | print("Current working directory:", os.getcwd())
39 |
40 | # Check local .env file
41 | print("\nChecking .env file in current directory:")
42 | check_env_file(".env")
43 |
44 | # Check find_dotenv
45 | print("\nUsing find_dotenv to locate .env file:")
46 | found_dotenv = find_dotenv()
47 | print(f"Found .env at: {found_dotenv}")
48 | if found_dotenv:
49 | check_env_file(found_dotenv)
50 |
51 | # Try loading with load_dotenv
52 | print("\nLoading environment variables with load_dotenv:")
53 | load_dotenv(verbose=True)
54 |
55 | # Check if variable was loaded
56 | api_key = os.getenv("FORGE_API_KEY", "")
57 | if api_key:
58 | print(f"FORGE_API_KEY loaded: {api_key[:8]}...")
59 | else:
60 | print("FORGE_API_KEY not loaded")
61 |
62 | # Check all environment variables (not just from .env)
63 | print("\nAll environment variables containing 'FORGE':")
64 | for key, value in os.environ.items():
65 | if "FORGE" in key:
66 | print(
67 | f"{key}={value[:8]}..." if key == "FORGE_API_KEY" else f"{key}={value}"
68 | )
69 |
70 |
71 | if __name__ == "__main__":
72 | main()
73 |
--------------------------------------------------------------------------------
/app/api/routes/admin.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, Depends, HTTPException
2 | from sqlalchemy import select
3 | from decimal import Decimal
4 | from pydantic import BaseModel
5 | import uuid
6 |
7 | from app.api.dependencies import get_current_active_admin_user_from_clerk
8 | from app.core.database import get_async_db
9 | from sqlalchemy.ext.asyncio import AsyncSession
10 | from app.models.user import User
11 | from app.models.stripe import StripePayment
12 | from app.core.logger import get_logger
13 | from app.api.schemas.admin import AddBalanceRequest
14 | from app.services.wallet_service import WalletService
15 |
16 | logger = get_logger(name="admin")
17 | router = APIRouter()
18 |
19 | class AddBalanceResponse(BaseModel):
20 | balance: Decimal
21 | blocked: bool
22 |
23 |
24 | @router.post("/add-balance")
25 | async def add_balance(
26 | add_balance_request: AddBalanceRequest,
27 | current_user: User = Depends(get_current_active_admin_user_from_clerk),
28 | db: AsyncSession = Depends(get_async_db),
29 | ):
30 | """Add balance to a user"""
31 | user_id = add_balance_request.user_id
32 | email = add_balance_request.email
33 | amount = add_balance_request.amount
34 |
35 |     # Look up the user by whichever identifier was provided (a plain Python `or`
36 |     # would short-circuit to a bool here, which is not a valid SQL filter).
37 |     query = select(User)
38 |     if user_id is not None:
39 |         query = query.where(User.id == user_id)
40 |     if email is not None:
41 |         query = query.where(User.email == email)
42 |     user = (await db.execute(query)).scalar_one_or_none()
43 |     if not user:
44 |         raise HTTPException(status_code=404, detail="User not found")
45 | 
46 |     amount_decimal = Decimal(amount) / Decimal(100)  # amount appears to be in cents; avoid float rounding
47 | result = await WalletService.adjust(db, user.id, amount_decimal, f"Admin {current_user.id} added balance for user {user.id}")
48 | if not result.get("success"):
49 | raise HTTPException(status_code=400, detail=f"Failed to add balance for user {user.id}: {result.get('reason')}")
50 |
51 | # add the amount to the user's stripe payment
52 | stripe_payment = StripePayment(
53 | id=f"tb_admin_{uuid.uuid4().hex}",
54 | user_id=user.id,
55 | amount=amount,
56 | currency="USD",
57 | status="completed",
58 | raw_data={"reason": f"Admin {current_user.id} added balance for user {user.id}"},
59 | )
60 | db.add(stripe_payment)
61 | await db.commit()
62 | logger.info(f"Added balance {amount_decimal} for user {user.id} by admin {current_user.id}")
63 |
64 | return AddBalanceResponse(balance=result.get("balance"), blocked=result.get("blocked"))
65 |
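Since the request amount is divided by 100 before crediting the wallet, it is presumably expressed in cents (minor units). A minimal sketch of the difference between exact Decimal division and converting a float result:

    from decimal import Decimal

    amount = 1999                           # e.g. $19.99 expressed in cents
    print(Decimal(amount) / Decimal(100))   # 19.99, exact
    print(Decimal(amount / 100.0))          # 19.9899999999999984..., a float artifact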
--------------------------------------------------------------------------------
/alembic/versions/ca1ac51334ec_create_usage_stats_table.py:
--------------------------------------------------------------------------------
1 | """Create usage_stats table
2 |
3 | Revision ID: ca1ac51334ec
4 | Revises: initial_migration
5 | Create Date: 2024-08-05 21:50:30.028279
6 |
7 | """
8 |
9 | import sqlalchemy as sa
10 |
11 | from alembic import op
12 |
13 | # revision identifiers, used by Alembic.
14 | revision = "ca1ac51334ec"
15 | down_revision = "initial_migration" # Set the correct previous migration
16 | branch_labels = None
17 | depends_on = None
18 |
19 |
20 | def upgrade() -> None:
21 | # ### commands auto generated by Alembic - please adjust! ###
22 | op.create_table(
23 | "usage_stats",
24 | sa.Column("id", sa.Integer(), nullable=False),
25 | sa.Column("user_id", sa.Integer(), nullable=True),
26 | sa.Column("provider_name", sa.String(), nullable=True),
27 | sa.Column("model", sa.String(), nullable=True),
28 | sa.Column("request_count", sa.Integer(), nullable=True),
29 | sa.Column("success_count", sa.Integer(), nullable=True),
30 | sa.Column("prompt_tokens", sa.Integer(), nullable=True),
31 | sa.Column("completion_tokens", sa.Integer(), nullable=True),
32 | sa.Column("error_count", sa.Integer(), nullable=True),
33 | sa.Column("timestamp", sa.DateTime(), nullable=True),
34 | sa.ForeignKeyConstraint(
35 | ["user_id"], ["users.id"], name=op.f("fk_usage_stats_user_id_users")
36 | ),
37 | sa.PrimaryKeyConstraint("id", name=op.f("pk_usage_stats")),
38 | )
39 | op.create_index(op.f("ix_usage_stats_id"), "usage_stats", ["id"], unique=False)
40 | op.create_index(
41 | op.f("ix_usage_stats_user_id"), "usage_stats", ["user_id"], unique=False
42 | )
43 | op.create_index(
44 | op.f("ix_usage_stats_provider_name"),
45 | "usage_stats",
46 | ["provider_name"],
47 | unique=False,
48 | )
49 | op.create_index(
50 | op.f("ix_usage_stats_model"), "usage_stats", ["model"], unique=False
51 | )
52 |
53 | # ### end Alembic commands ###
54 |
55 |
56 | def downgrade() -> None:
57 | # ### commands auto generated by Alembic - please adjust! ###
58 | op.drop_index(op.f("ix_usage_stats_model"), table_name="usage_stats")
59 | op.drop_index(op.f("ix_usage_stats_provider_name"), table_name="usage_stats")
60 | op.drop_index(op.f("ix_usage_stats_user_id"), table_name="usage_stats")
61 | op.drop_index(op.f("ix_usage_stats_id"), table_name="usage_stats")
62 | op.drop_table("usage_stats")
63 | # ### end Alembic commands ###
64 |
--------------------------------------------------------------------------------
/tests/mock_testing/run_mock_tests.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Run all mock tests from a single script.
4 | """
5 |
6 | import asyncio
7 | import os
8 | import subprocess
9 | import sys
10 |
11 | # Add the parent directory to the path
12 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
13 |
14 |
15 | def print_header(text):
16 | """Print a header with decoration"""
17 | print("\n" + "=" * 80)
18 | print(f" {text} ".center(80, "="))
19 | print("=" * 80)
20 |
21 |
22 | async def run_mock_client_tests():
23 | """Run the mock client tests"""
24 | print_header("Running Mock Client Tests")
25 |
26 | # Get the path to the test_mock_client.py script
27 | script_path = os.path.join(os.path.dirname(__file__), "test_mock_client.py")
28 |
29 | # Run the script as a subprocess
30 | result = subprocess.run(
31 | ["python", script_path], capture_output=True, text=True, check=False
32 | )
33 |
34 | # Print the output
35 | if result.stdout:
36 | print(result.stdout)
37 | if result.stderr:
38 | print(result.stderr, file=sys.stderr)
39 |
40 | return result.returncode == 0
41 |
42 |
43 | async def run_example_tests():
44 | """Run the example tests"""
45 | print_header("Running Example Tests")
46 |
47 | # Get the path to the test_with_mocks.py script
48 | script_path = os.path.join(
49 | os.path.dirname(__file__), "examples", "test_with_mocks.py"
50 | )
51 |
52 | # Run the script as a subprocess with unittest discover
53 | result = subprocess.run(
54 | ["python", "-m", "unittest", script_path],
55 | capture_output=True,
56 | text=True,
57 | check=False,
58 | )
59 |
60 | # Print the output
61 | if result.stdout:
62 | print(result.stdout)
63 | if result.stderr:
64 | print(result.stderr, file=sys.stderr)
65 |
66 | return result.returncode == 0
67 |
68 |
69 | async def main():
70 | """Run all tests"""
71 | client_tests_ok = await run_mock_client_tests()
72 | example_tests_ok = await run_example_tests()
73 |
74 | print_header("Summary")
75 | print(f"Mock Client Tests: {'PASSED' if client_tests_ok else 'FAILED'}")
76 | print(f"Example Tests: {'PASSED' if example_tests_ok else 'FAILED'}")
77 |
78 | if client_tests_ok and example_tests_ok:
79 | print("\n✅ All mock tests passed!")
80 | return 0
81 | else:
82 | print("\n❌ Some tests failed.")
83 | return 1
84 |
85 |
86 | if __name__ == "__main__":
87 | exit_code = asyncio.run(main())
88 | sys.exit(exit_code)
89 |
--------------------------------------------------------------------------------
/alembic/versions/0ce4eeae965f_drop_usage_stats_table.py:
--------------------------------------------------------------------------------
1 | """Drop usage_stats table
2 |
3 | Revision ID: 0ce4eeae965f
4 | Revises: b38aad374524
5 | Create Date: 2025-04-08 21:34:42.905629
6 |
7 | """
8 |
9 | import sqlalchemy as sa
10 |
11 | from alembic import op
12 |
13 | # revision identifiers, used by Alembic.
14 | revision = "0ce4eeae965f"
15 | down_revision = "b38aad374524"
16 | branch_labels = None
17 | depends_on = None
18 |
19 |
20 | def upgrade() -> None:
21 | # ### commands auto generated by Alembic - please adjust! ###
22 | with op.batch_alter_table("usage_stats", schema=None) as batch_op:
23 | batch_op.drop_index("ix_usage_stats_id")
24 | batch_op.drop_index("ix_usage_stats_model")
25 | batch_op.drop_index("ix_usage_stats_provider_name")
26 |
27 | op.drop_table("usage_stats")
28 | # ### end Alembic commands ###
29 |
30 |
31 | def downgrade() -> None:
32 | # ### commands auto generated by Alembic - please adjust! ###
33 | op.create_table(
34 | "usage_stats",
35 | sa.Column("id", sa.INTEGER(), nullable=False),
36 | sa.Column("user_id", sa.INTEGER(), nullable=False),
37 | sa.Column("provider_name", sa.VARCHAR(), nullable=False),
38 | sa.Column("model", sa.VARCHAR(), nullable=False),
39 | sa.Column("prompt_tokens", sa.INTEGER(), nullable=True),
40 | sa.Column("completion_tokens", sa.INTEGER(), nullable=True),
41 | sa.Column("total_tokens", sa.INTEGER(), nullable=True),
42 | sa.Column("request_count", sa.INTEGER(), nullable=True),
43 | sa.Column("success_count", sa.INTEGER(), nullable=True),
44 | sa.Column("error_count", sa.INTEGER(), nullable=True),
45 | sa.Column("timestamp", sa.DATETIME(), nullable=True),
46 | sa.Column("endpoint", sa.VARCHAR(), nullable=False),
47 | sa.Column("input_tokens", sa.INTEGER(), nullable=True),
48 | sa.Column("output_tokens", sa.INTEGER(), nullable=True),
49 | sa.Column("requests_count", sa.INTEGER(), nullable=True),
50 | sa.Column("cost", sa.FLOAT(), nullable=True),
51 | sa.Column("created_at", sa.DATETIME(), nullable=True),
52 | sa.ForeignKeyConstraint(
53 | ["user_id"],
54 | ["users.id"],
55 | ),
56 | sa.PrimaryKeyConstraint("id"),
57 | )
58 | with op.batch_alter_table("usage_stats", schema=None) as batch_op:
59 | batch_op.create_index(
60 | "ix_usage_stats_provider_name", ["provider_name"], unique=False
61 | )
62 | batch_op.create_index("ix_usage_stats_model", ["model"], unique=False)
63 | batch_op.create_index("ix_usage_stats_id", ["id"], unique=False)
64 |
65 | # ### end Alembic commands ###
66 |
--------------------------------------------------------------------------------
/app/api/schemas/statistic.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, field_validator
2 | from datetime import datetime
3 | import re
4 | import decimal
5 |
6 | from app.api.schemas.forge_api_key import ForgeApiKeyMasked
7 |
8 | def mask_forge_name_or_key(v: str) -> str:
9 |     # If the value is an actual forge key (forge- followed by 36 hex characters), mask it
10 |     if re.fullmatch(r"forge-[0-9a-f]{36}", v):
11 | return ForgeApiKeyMasked.mask_api_key(v)
12 | # Otherwise, return the original value (user customized name)
13 | return v
14 |
15 | class UsageRealtimeItem(BaseModel):
16 | timestamp: datetime | str
17 | forge_key: str
18 | provider_name: str
19 | model_name: str
20 | tokens: int
21 | input_tokens: int
22 | output_tokens: int
23 | cached_tokens: int
24 | duration: float
25 | cost: decimal.Decimal
26 | billable: bool
27 |
28 | @field_validator('forge_key')
29 | @classmethod
30 | def mask_forge_key(cls, v: str) -> str:
31 | return mask_forge_name_or_key(v)
32 |
33 | @field_validator('timestamp')
34 | @classmethod
35 | def convert_timestamp_to_iso(cls, v: datetime | str) -> str:
36 | if isinstance(v, str):
37 | return v
38 | return v.isoformat()
39 |
40 | class UsageRealtimeResponse(BaseModel):
41 | total: int
42 | items: list[UsageRealtimeItem]
43 | page_size: int
44 | page_index: int
45 |
46 | class UsageSummaryBreakdown(BaseModel):
47 | forge_key: str
48 | tokens: int
49 | cost: decimal.Decimal
50 | charged_cost: decimal.Decimal
51 | input_tokens: int
52 | output_tokens: int
53 | cached_tokens: int
54 |
55 | @field_validator('forge_key')
56 | @classmethod
57 | def mask_forge_key(cls, v: str) -> str:
58 | return mask_forge_name_or_key(v)
59 |
60 |
61 | class UsageSummaryResponse(BaseModel):
62 | time_point: datetime | str
63 | breakdown: list[UsageSummaryBreakdown]
64 | total_tokens: int
65 | total_cost: decimal.Decimal
66 | total_charged_cost: decimal.Decimal
67 | total_input_tokens: int
68 | total_output_tokens: int
69 | total_cached_tokens: int
70 |
71 | @field_validator('time_point')
72 | @classmethod
73 | def convert_timestamp_to_iso(cls, v: datetime | str) -> str:
74 | if isinstance(v, str):
75 | return v
76 | return v.isoformat()
77 |
78 |
79 | class ForgeKeysUsageSummaryResponse(BaseModel):
80 | forge_key: str
81 | tokens: int
82 | cost: decimal.Decimal
83 | charged_cost: decimal.Decimal
84 | input_tokens: int
85 | output_tokens: int
86 | cached_tokens: int
87 |
88 | @field_validator('forge_key')
89 | @classmethod
90 | def mask_forge_key(cls, v: str) -> str:
91 | return mask_forge_name_or_key(v)
--------------------------------------------------------------------------------
/alembic/versions/initial_migration.py:
--------------------------------------------------------------------------------
1 | """initial migration
2 |
3 | Revision ID: initial_migration
4 | Revises:
5 | Create Date: 2023-05-01 00:00:00.000000
6 |
7 | """
8 |
9 | import sqlalchemy as sa
10 |
11 | from alembic import op
12 |
13 | # revision identifiers, used by Alembic.
14 | revision = "initial_migration"
15 | down_revision = None
16 | branch_labels = None
17 | depends_on = None
18 |
19 |
20 | def upgrade() -> None:
21 | # Create users table
22 | op.create_table(
23 | "users",
24 | sa.Column("id", sa.Integer(), nullable=False),
25 | sa.Column("email", sa.String(), nullable=True),
26 | sa.Column("username", sa.String(), nullable=True),
27 | sa.Column("hashed_password", sa.String(), nullable=True),
28 | sa.Column("is_active", sa.Boolean(), nullable=True),
29 | sa.Column("forge_api_key", sa.String(), nullable=True),
30 | sa.Column("created_at", sa.DateTime(), nullable=True),
31 | sa.Column("updated_at", sa.DateTime(), nullable=True),
32 | sa.PrimaryKeyConstraint("id"),
33 | )
34 | op.create_index(op.f("ix_users_email"), "users", ["email"], unique=True)
35 | op.create_index(
36 | op.f("ix_users_forge_api_key"), "users", ["forge_api_key"], unique=True
37 | )
38 | op.create_index(op.f("ix_users_id"), "users", ["id"], unique=False)
39 | op.create_index(op.f("ix_users_username"), "users", ["username"], unique=True)
40 |
41 | # Create provider_keys table
42 | op.create_table(
43 | "provider_keys",
44 | sa.Column("id", sa.Integer(), nullable=False),
45 | sa.Column("provider_name", sa.String(), nullable=True),
46 | sa.Column("encrypted_api_key", sa.String(), nullable=True),
47 | sa.Column("user_id", sa.Integer(), nullable=True),
48 | sa.Column("base_url", sa.String(), nullable=True),
49 | sa.Column("model_mapping", sa.String(), nullable=True),
50 | sa.Column("created_at", sa.DateTime(), nullable=True),
51 | sa.Column("updated_at", sa.DateTime(), nullable=True),
52 | sa.ForeignKeyConstraint(
53 | ["user_id"],
54 | ["users.id"],
55 | ),
56 | sa.PrimaryKeyConstraint("id"),
57 | )
58 | op.create_index(op.f("ix_provider_keys_id"), "provider_keys", ["id"], unique=False)
59 | op.create_index(
60 | op.f("ix_provider_keys_provider_name"),
61 | "provider_keys",
62 | ["provider_name"],
63 | unique=False,
64 | )
65 |
66 |
67 | def downgrade() -> None:
68 | op.drop_index(op.f("ix_provider_keys_provider_name"), table_name="provider_keys")
69 | op.drop_index(op.f("ix_provider_keys_id"), table_name="provider_keys")
70 | op.drop_table("provider_keys")
71 | op.drop_index(op.f("ix_users_username"), table_name="users")
72 | op.drop_index(op.f("ix_users_id"), table_name="users")
73 | op.drop_index(op.f("ix_users_forge_api_key"), table_name="users")
74 | op.drop_index(op.f("ix_users_email"), table_name="users")
75 | op.drop_table("users")
76 |
--------------------------------------------------------------------------------
/tests/unit_tests/test_security.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from unittest import IsolatedAsyncioTestCase as TestCase
4 |
5 | # Add the parent directory to the path so Python can find the app module
6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 |
8 | # Patch bcrypt version detection to avoid warnings
9 | import bcrypt
10 |
11 | if not hasattr(bcrypt, "__about__"):
12 | import types
13 |
14 | bcrypt.__about__ = types.ModuleType("__about__")
15 | bcrypt.__about__.__version__ = (
16 | bcrypt.__version__ if hasattr(bcrypt, "__version__") else "3.2.0"
17 | )
18 |
19 | from jose import jwt
20 |
21 | from app.core.security import (
22 | ALGORITHM,
23 | SECRET_KEY,
24 | create_access_token,
25 | decrypt_api_key,
26 | encrypt_api_key,
27 | generate_forge_api_key,
28 | get_password_hash,
29 | verify_password,
30 | )
31 |
32 |
33 | class TestSecurity(TestCase):
34 | """Test the security utilities"""
35 |
36 | async def test_password_hashing(self):
37 | """Test password hashing and verification"""
38 | password = "test_password123"
39 | hashed = get_password_hash(password)
40 |
41 | # Verify the hash is different from the original password
42 | self.assertNotEqual(password, hashed)
43 |
44 | # Verify the password against the hash
45 | self.assertTrue(verify_password(password, hashed))
46 |
47 | # Verify wrong password fails
48 | self.assertFalse(verify_password("wrong_password", hashed))
49 |
50 | async def test_jwt_token(self):
51 | """Test JWT token creation and verification"""
52 | data = {"sub": "testuser"}
53 | token = create_access_token(data)
54 |
55 | # Decode and verify the token
56 | payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
57 |
58 | # Check that original data is preserved
59 | self.assertEqual(payload["sub"], "testuser")
60 |
61 | # Check that expiration is added
62 | self.assertIn("exp", payload)
63 |
64 | async def test_api_key_encryption(self):
65 | """Test API key encryption and decryption"""
66 | original_key = "sk-123456789abcdef"
67 |
68 | # Encrypt the key
69 | encrypted = encrypt_api_key(original_key)
70 |
71 | # Verify encrypted key is different
72 | self.assertNotEqual(original_key, encrypted)
73 |
74 | # Decrypt and verify
75 | decrypted = decrypt_api_key(encrypted)
76 | self.assertEqual(original_key, decrypted)
77 |
78 | async def test_forge_api_key_generation(self):
79 | """Test Forge API key generation"""
80 | key1 = generate_forge_api_key()
81 | key2 = generate_forge_api_key()
82 |
83 | # Check the format
84 | self.assertTrue(key1.startswith("forge-"))
85 |
86 | # Check uniqueness
87 | self.assertNotEqual(key1, key2)
88 |
89 | # Check length
90 | self.assertEqual(
91 | len(key1), 42
92 | ) # "forge-" (6) + checksum (4) + base_key (32 hex chars)
93 |
--------------------------------------------------------------------------------
/alembic/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | from logging.config import fileConfig
3 |
4 | from dotenv import load_dotenv
5 | from sqlalchemy import engine_from_config, pool
6 |
7 | from alembic import context
8 | from app.models.api_request_log import ApiRequestLog # noqa
9 | from app.models.base import Base
10 | from app.models.provider_key import ProviderKey # noqa
11 |
12 | # Explicitly import models for autogenerate to find them
13 | from app.models.user import User # noqa
14 |
15 | # Load environment variables from .env file
16 | load_dotenv()
17 |
18 | # this is the Alembic Config object, which provides
19 | # access to the values within the .ini file in use.
20 | config = context.config
21 |
22 | # Override the SQLAlchemy URL with environment variable
23 | database_url = os.getenv("DATABASE_URL")
24 | if not database_url:
25 | raise ValueError("DATABASE_URL environment variable is not set")
26 | config.set_main_option("sqlalchemy.url", database_url)
27 |
28 | # Interpret the config file for Python logging.
29 | # This line sets up loggers basically.
30 | if config.config_file_name is not None:
31 | fileConfig(config.config_file_name)
32 |
33 | # add your model's MetaData object here
34 | # for 'autogenerate' support
35 | # from myapp import mymodel
36 | # target_metadata = mymodel.Base.metadata
37 | target_metadata = Base.metadata
38 |
39 | # other values from the config, defined by the needs of env.py,
40 | # can be acquired:
41 | # my_important_option = config.get_main_option("my_important_option")
42 | # ... etc.
43 |
44 |
45 | def run_migrations_offline() -> None:
46 | """Run migrations in 'offline' mode.
47 |
48 | This configures the context with just a URL
49 | and not an Engine, though an Engine is acceptable
50 | here as well. By skipping the Engine creation
51 | we don't even need a DBAPI to be available.
52 |
53 | Calls to context.execute() here emit the given string to the
54 | script output.
55 |
56 | """
57 | url = config.get_main_option("sqlalchemy.url")
58 | context.configure(
59 | url=url,
60 | target_metadata=target_metadata,
61 | literal_binds=True,
62 | dialect_opts={"paramstyle": "named"},
63 | render_as_batch=True,
64 | )
65 |
66 | with context.begin_transaction():
67 | context.run_migrations()
68 |
69 |
70 | def run_migrations_online() -> None:
71 | """Run migrations in 'online' mode.
72 |
73 | In this scenario we need to create an Engine
74 | and associate a connection with the context.
75 |
76 | """
77 | connectable = engine_from_config(
78 | config.get_section(config.config_ini_section, {}),
79 | prefix="sqlalchemy.",
80 | poolclass=pool.NullPool,
81 | )
82 |
83 | with connectable.connect() as connection:
84 | context.configure(
85 | connection=connection,
86 | target_metadata=target_metadata,
87 | render_as_batch=True,
88 | )
89 |
90 | with context.begin_transaction():
91 | context.run_migrations()
92 |
93 |
94 | if context.is_offline_mode():
95 | run_migrations_offline()
96 | else:
97 | run_migrations_online()
98 |
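# ---------------------------------------------------------------------------
# Illustrative usage (editor's note): the two code paths above map onto the
# usual Alembic commands, for example:
#
#     alembic revision --autogenerate -m "describe change"  # uses target_metadata above
#     alembic upgrade head                                   # online mode: connects via DATABASE_URL
#     alembic upgrade head --sql                             # offline mode: emits SQL without a DB connection
# ---------------------------------------------------------------------------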
--------------------------------------------------------------------------------
/app/models/pricing.py:
--------------------------------------------------------------------------------
1 | # app/models/pricing.py
2 | import datetime
3 | from datetime import UTC
4 | from sqlalchemy import Column, DateTime, String, DECIMAL, Index
5 |
6 | from .base import BaseModel
7 |
8 |
9 | class ModelPricing(BaseModel):
10 | """
11 | Store pricing information for specific models with temporal support
12 | """
13 | __tablename__ = "model_pricing"
14 |
15 | provider_name = Column(String, nullable=False, index=True)
16 | model_name = Column(String, nullable=False, index=True)
17 |
18 | # Temporal fields for price changes over time
19 |     effective_date = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.datetime.now(UTC))
20 | end_date = Column(DateTime(timezone=True), nullable=True) # NULL means currently active
21 |
22 | # Pricing per 1K tokens (using DECIMAL for precision)
23 | input_token_price = Column(DECIMAL(12, 8), nullable=False) # Price per 1K input tokens
24 | output_token_price = Column(DECIMAL(12, 8), nullable=False) # Price per 1K output tokens
25 | cached_token_price = Column(DECIMAL(12, 8), nullable=False, default=0) # Price per 1K cached tokens
26 |
27 | # Metadata
28 | currency = Column(String(3), nullable=False, default='USD')
29 |
30 | # Indexes for efficient querying
31 | __table_args__ = (
32 | # Index for finding active pricing for a model
33 | Index('ix_model_pricing_active', 'provider_name', 'model_name', 'effective_date', 'end_date'),
34 | # Index for temporal queries
35 | Index('ix_model_pricing_temporal', 'effective_date', 'end_date'),
36 |         # Unique index: a provider/model pair cannot have two price rows with the same effective_date
37 | Index('ix_model_pricing_unique_period', 'provider_name', 'model_name', 'effective_date', unique=True),
38 | )
39 |
40 |
41 | class FallbackPricing(BaseModel):
42 | """
43 | Store fallback pricing for providers and global defaults
44 | """
45 | __tablename__ = "fallback_pricing"
46 |
47 | provider_name = Column(String, nullable=True, index=True) # NULL for global fallback
48 | model_name = Column(String, nullable=True, index=True) # NULL for global fallback
49 | fallback_type = Column(String(20), nullable=False, index=True) # 'model_default', 'provider_default', 'global_default'
50 |
51 | # Temporal fields
52 |     effective_date = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.datetime.now(UTC))
53 | end_date = Column(DateTime(timezone=True), nullable=True)
54 |
55 | # Pricing per 1K tokens
56 | input_token_price = Column(DECIMAL(12, 8), nullable=False)
57 | output_token_price = Column(DECIMAL(12, 8), nullable=False)
58 | cached_token_price = Column(DECIMAL(12, 8), nullable=False, default=0)
59 |
60 | # Metadata
61 | currency = Column(String(3), nullable=False, default='USD')
62 | description = Column(String(255), nullable=True) # Optional description
63 |
64 | # Indexes
65 | __table_args__ = (
66 | Index('ix_fallback_pricing_active', 'provider_name', 'fallback_type', 'effective_date', 'end_date'),
67 | Index('ix_fallback_pricing_type', 'fallback_type', 'effective_date'),
68 | )
69 |
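# ---------------------------------------------------------------------------
# Illustrative helper (editor's sketch, not part of the original module): shows
# how the temporal columns above are typically queried. The "active" price for a
# provider/model at time `at` is the newest row whose effective_date has passed
# and whose end_date is NULL or still in the future. The function name and its
# placement here are assumptions for illustration only.
# ---------------------------------------------------------------------------
def select_active_model_pricing(provider_name: str, model_name: str, at: datetime.datetime):
    """Build a SELECT statement for the ModelPricing row active at ``at``."""
    from sqlalchemy import or_, select  # local import keeps the sketch self-contained

    return (
        select(ModelPricing)
        .where(
            ModelPricing.provider_name == provider_name,
            ModelPricing.model_name == model_name,
            ModelPricing.effective_date <= at,
            or_(ModelPricing.end_date.is_(None), ModelPricing.end_date > at),
        )
        .order_by(ModelPricing.effective_date.desc())
        .limit(1)
    )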
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
1 | # Forge Environment Variables
2 | #
3 | # This file provides a template for the environment variables needed to run the Forge application.
4 | # Copy this file to .env and fill in the appropriate values for your development environment.
5 | # Do not commit the .env file to your version control system.
6 |
7 | # -----------------------------------------------------------------------------
8 | # DATABASE SETTINGS
9 | # -----------------------------------------------------------------------------
10 | # PostgreSQL connection string.
11 | # Example for local development with Docker: postgresql://forge:forge@localhost:5432/forge
12 | DATABASE_URL=postgresql://user:password@host:port/dbname
13 |
14 | # -----------------------------------------------------------------------------
15 | # REDIS CACHE SETTINGS (Optional)
16 | # -----------------------------------------------------------------------------
17 | # URL for your Redis instance. If not provided, an in-memory cache will be used.
18 | # REDIS_URL=redis://localhost:6379/0
19 | # Prefix for all Redis keys to avoid collisions.
20 | REDIS_PREFIX=forge
21 |
22 | # -----------------------------------------------------------------------------
23 | # SECURITY & AUTHENTICATION
24 | # -----------------------------------------------------------------------------
25 | # Secret key for signing JWTs. Generate a strong, random key.
26 | # Use `openssl rand -hex 32` to generate a key.
27 | SECRET_KEY=
28 | # Secret key for encrypting sensitive data like provider API keys.
29 | # Generate with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
30 | ENCRYPTION_KEY=
31 | # Algorithm for JWT signing.
32 | ALGORITHM=HS256
33 | # Access token expiration time in minutes.
34 | ACCESS_TOKEN_EXPIRE_MINUTES=60
35 |
36 | # -----------------------------------------------------------------------------
37 | # CLERK AUTHENTICATION INTEGRATION (OPTIONAL)
38 | # -----------------------------------------------------------------------------
39 | # Your Clerk JWT public key (remove -----BEGIN PUBLIC KEY----- and -----END PUBLIC KEY----- headers/footers and newlines).
40 | CLERK_JWT_PUBLIC_KEY=
41 | # Your Clerk API key.
42 | CLERK_API_KEY=
43 | # Your Clerk API URL.
44 | CLERK_API_URL=https://api.clerk.com/v1
45 | # Your Clerk webhook signing secret.
46 | CLERK_WEBHOOK_SECRET=
47 |
48 | # -----------------------------------------------------------------------------
49 | # APPLICATION & DEVELOPMENT SETTINGS
50 | # -----------------------------------------------------------------------------
51 | # Set to "production" or "development".
52 | ENVIRONMENT=development
53 | # Set to "true" to enable debug-level logging.
54 | FORGE_DEBUG_LOGGING=false
55 | # Host for the application.
56 | HOST=0.0.0.0
57 | # Port for the application.
58 | PORT=8000
59 | # Set to "true" to enable auto-reloading for development.
60 | RELOAD=false
61 | # Number of Gunicorn worker processes.
62 | WORKERS=4
63 | # Set to "true" to force the use of in-memory cache, even if REDIS_URL is set.
64 | # Useful for testing environments.
65 | FORCE_MEMORY_CACHE=false
66 | # Set to "true" to enable verbose cache debugging output.
67 | DEBUG_CACHE=false
68 | # Your application's domain, used for JWT audience validation.
69 | APP_DOMAIN=localhost
70 |
--------------------------------------------------------------------------------
/app/api/routes/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import AsyncGenerator
2 | import json
3 | from fastapi import HTTPException
4 | from starlette.responses import StreamingResponse
5 | from app.exceptions.exceptions import ProviderAPIException
6 |
7 | async def wrap_streaming_response_with_error_handling(
8 | logger, async_gen: AsyncGenerator[bytes, None]
9 | ) -> StreamingResponse:
10 | """
11 | Wraps an async generator to catch and properly handle errors in streaming responses.
12 | Returns a StreamingResponse that will:
13 | - Return proper HTTP error status if error occurs before first chunk
14 | - Send error as SSE event if error occurs mid-stream
15 |
16 | Args:
17 | logger: Logger instance for error logging
18 | async_gen: The async generator producing the stream chunks
19 |
20 | Returns:
21 | StreamingResponse with proper error handling
22 |
23 | Raises:
24 | HTTPException: If error occurs before streaming starts
25 | """
26 |
27 | # Try to get the first chunk BEFORE creating StreamingResponse
28 | # This allows us to catch immediate errors and return proper HTTP status
29 | try:
30 | first_chunk = await async_gen.__anext__()
31 | except StopAsyncIteration:
32 | # Empty stream
33 | logger.error("Empty stream response")
34 | raise HTTPException(status_code=500, detail="Empty stream response")
35 | except ProviderAPIException as e:
36 | logger.error(f"Provider API error: {str(e)}")
37 | raise HTTPException(status_code=e.error_code, detail=e.error_message) from e
38 | except Exception as e:
39 | # Convert other exceptions to HTTPException
40 | logger.error(f"Error before streaming started: {str(e)}")
41 | raise HTTPException(status_code=500, detail=str(e)) from e
42 |
43 | # Success! Now create generator that replays first chunk + rest
44 | async def response_generator():
45 | # Yield the first chunk we already got
46 | yield first_chunk
47 |
48 | try:
49 | # Continue with the rest of the stream
50 | async for chunk in async_gen:
51 | yield chunk
52 | except Exception as e:
53 | # Error occurred mid-stream - HTTP status already sent
54 | # Send error as SSE event to inform the client
55 | logger.error(f"Error during streaming: {str(e)}")
56 |
57 | error_message = str(e)
58 | error_event = {
59 | "error": {
60 | "message": error_message,
61 | "type": "stream_error",
62 | "code": "provider_error"
63 | }
64 | }
65 | yield f"data: {json.dumps(error_event)}\n\n".encode()
66 |
67 | # Send [DONE] to properly close the stream
68 | yield b"data: [DONE]\n\n"
69 |
70 | # Set appropriate headers for streaming
71 | headers = {
72 | "Content-Type": "text/event-stream",
73 | "Cache-Control": "no-cache",
74 | "Connection": "keep-alive",
75 | "X-Accel-Buffering": "no", # Prevent Nginx buffering
76 | }
77 |
78 | return StreamingResponse(
79 | response_generator(),
80 | media_type="text/event-stream",
81 | headers=headers
82 | )
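# ---------------------------------------------------------------------------
# Illustrative usage (editor's sketch): a streaming route passes its provider
# generator through the wrapper instead of building a StreamingResponse
# directly. `provider_service.stream_chat_completion` is a hypothetical name
# used only for this example.
#
#     @router.post("/chat/completions")
#     async def chat_completions(payload: dict):
#         stream = provider_service.stream_chat_completion(payload)
#         return await wrap_streaming_response_with_error_handling(logger, stream)
# ---------------------------------------------------------------------------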
--------------------------------------------------------------------------------
/alembic.ini:
--------------------------------------------------------------------------------
1 | # A generic, single database configuration.
2 |
3 | [alembic]
4 | # path to migration scripts
5 | script_location = alembic
6 |
7 | # template used to generate migration files
8 | # file_template = %%(rev)s_%%(slug)s
9 |
10 | # sys.path path, will be prepended to sys.path if present.
11 | # defaults to the current working directory.
12 | prepend_sys_path = .
13 |
14 | # timezone to use when rendering the date within the migration file
15 | # as well as the filename.
16 | # If specified, requires the python-dateutil library that can be
17 | # installed by adding `alembic[tz]` to the pip requirements
18 | # string value is passed to dateutil.tz.gettz()
19 | # leave blank for localtime
20 | # timezone =
21 |
22 | # max length of characters to apply to the
23 | # "slug" field
24 | # truncate_slug_length = 40
25 |
26 | # set to 'true' to run the environment during
27 | # the 'revision' command, regardless of autogenerate
28 | # revision_environment = false
29 |
30 | # set to 'true' to allow .pyc and .pyo files without
31 | # a source .py file to be detected as revisions in the
32 | # versions/ directory
33 | # sourceless = false
34 |
35 | # version location specification; This defaults
36 | # to alembic/versions. When using multiple version
37 | # directories, initial revisions must be specified with --version-path.
38 | # The path separator used here should be the separator specified by "version_path_separator" below.
39 | # version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
40 |
41 | # version path separator; As mentioned above, this is the character used to split
42 | # version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
43 | # If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
44 | # Valid values for version_path_separator are:
45 | #
46 | # version_path_separator = :
47 | # version_path_separator = ;
48 | # version_path_separator = space
49 | version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
50 |
51 | # the output encoding used when revision files
52 | # are written from script.py.mako
53 | # output_encoding = utf-8
54 |
55 |
56 | [post_write_hooks]
57 | # post_write_hooks defines scripts or Python functions that are run
58 | # on newly generated revision scripts. See the documentation for further
59 | # detail and examples
60 |
61 | # format using "black" - use the console_scripts runner, against the "black" entrypoint
62 | # hooks = black
63 | # black.type = console_scripts
64 | # black.entrypoint = black
65 | # black.options = -l 79 REVISION_SCRIPT_FILENAME
66 |
67 | # Logging configuration
68 | [loggers]
69 | keys = root,sqlalchemy,alembic
70 |
71 | [handlers]
72 | keys = console
73 |
74 | [formatters]
75 | keys = generic
76 |
77 | [logger_root]
78 | level = WARN
79 | handlers = console
80 | qualname =
81 |
82 | [logger_sqlalchemy]
83 | level = WARN
84 | handlers =
85 | qualname = sqlalchemy.engine
86 |
87 | [logger_alembic]
88 | level = INFO
89 | handlers =
90 | qualname = alembic
91 |
92 | [handler_console]
93 | class = StreamHandler
94 | args = (sys.stderr,)
95 | level = NOTSET
96 | formatter = generic
97 |
98 | [formatter_generic]
99 | format = %(levelname)-5.5s [%(name)s] %(message)s
100 | datefmt = %H:%M:%S
101 |
--------------------------------------------------------------------------------
/alembic/versions/c50fd7be794c_add_endpoint_column_to_usage_stats.py:
--------------------------------------------------------------------------------
1 | """Add endpoint column to usage_stats
2 |
3 | Revision ID: c50fd7be794c
4 | Revises: ca1ac51334ec
5 | Create Date: 2025-04-08 20:46:46.713120
6 |
7 | """
8 |
9 | import sqlalchemy as sa
10 |
11 | from alembic import op
12 |
13 | # revision identifiers, used by Alembic.
14 | revision = "c50fd7be794c"
15 | down_revision = "ca1ac51334ec"
16 | branch_labels = None
17 | depends_on = None
18 |
19 |
20 | def upgrade() -> None:
21 | # ### commands auto generated by Alembic - please adjust! ###
22 | with op.batch_alter_table("usage_stats", schema=None) as batch_op:
23 | batch_op.add_column(sa.Column("endpoint", sa.String(), nullable=False))
24 | batch_op.add_column(sa.Column("input_tokens", sa.Integer(), nullable=True))
25 | batch_op.add_column(sa.Column("output_tokens", sa.Integer(), nullable=True))
26 | batch_op.add_column(sa.Column("requests_count", sa.Integer(), nullable=True))
27 | batch_op.add_column(sa.Column("cost", sa.Float(), nullable=True))
28 | batch_op.add_column(sa.Column("created_at", sa.DateTime(), nullable=True))
29 |
30 | # Drop the old FK using its specific name
31 | batch_op.drop_constraint(
32 | op.f("fk_usage_stats_user_id_users"), type_="foreignkey"
33 | )
34 | # Create the new FK with ON DELETE CASCADE
35 | batch_op.create_foreign_key(
36 | op.f("fk_usage_stats_user_id_users"), # Keep the same conventional name
37 | "users",
38 | ["user_id"],
39 | ["id"],
40 | ondelete="CASCADE",
41 | )
42 |
43 | batch_op.drop_column("completion_tokens")
44 | batch_op.drop_column("timestamp")
45 | batch_op.drop_column("error_count")
46 | batch_op.drop_column("prompt_tokens")
47 | batch_op.drop_column("success_count")
48 | batch_op.drop_column("request_count")
49 |
50 | # ### end Alembic commands ###
51 |
52 |
53 | def downgrade() -> None:
54 | # ### commands auto generated by Alembic - please adjust! ###
55 | with op.batch_alter_table("usage_stats", schema=None) as batch_op:
56 | batch_op.add_column(sa.Column("request_count", sa.INTEGER(), nullable=True))
57 | batch_op.add_column(sa.Column("success_count", sa.INTEGER(), nullable=True))
58 | batch_op.add_column(sa.Column("prompt_tokens", sa.INTEGER(), nullable=True))
59 | batch_op.add_column(sa.Column("error_count", sa.INTEGER(), nullable=True))
60 | batch_op.add_column(sa.Column("timestamp", sa.DATETIME(), nullable=True))
61 | batch_op.add_column(sa.Column("completion_tokens", sa.INTEGER(), nullable=True))
62 |
63 | # Drop the ON DELETE CASCADE FK
64 | batch_op.drop_constraint(
65 | op.f("fk_usage_stats_user_id_users"), type_="foreignkey"
66 | )
67 | # Recreate the original FK without ON DELETE CASCADE
68 | batch_op.create_foreign_key(
69 | op.f("fk_usage_stats_user_id_users"), "users", ["user_id"], ["id"]
70 | )
71 |
72 | batch_op.drop_column("created_at")
73 | batch_op.drop_column("cost")
74 | batch_op.drop_column("requests_count")
75 | batch_op.drop_column("output_tokens")
76 | batch_op.drop_column("input_tokens")
77 | batch_op.drop_column("endpoint")
78 |
79 | # ### end Alembic commands ###
80 |
--------------------------------------------------------------------------------
/app/api/routes/stripe.py:
--------------------------------------------------------------------------------
1 | import os
2 | from fastapi import APIRouter, Depends, Request
3 | from app.api.dependencies import get_current_active_user_from_clerk, get_current_active_user
4 | from app.api.schemas.stripe import CreateCheckoutSessionRequest
5 | from app.models.user import User
6 | from app.models.stripe import StripePayment
7 | from app.core.database import get_async_db
8 | from sqlalchemy.ext.asyncio import AsyncSession
9 | import stripe
10 | from app.core.logger import get_logger
11 | from sqlalchemy import select
12 | from fastapi import HTTPException
13 |
14 | logger = get_logger(name="stripe")
15 |
16 | STRIPE_API_KEY = os.getenv("STRIPE_API_KEY")
17 | stripe.api_key = STRIPE_API_KEY
18 |
19 | router = APIRouter()
20 |
21 | @router.post("/create-checkout-session/clerk")
22 | async def stripe_create_checkout_session_clerk(request: Request, create_checkout_session_request: CreateCheckoutSessionRequest, user: User = Depends(get_current_active_user_from_clerk), db: AsyncSession = Depends(get_async_db)):
23 | return await stripe_create_checkout_session(request, create_checkout_session_request, user, db)
24 |
25 |
26 | @router.post("/create-checkout-session")
27 | async def stripe_create_checkout_session(request: Request, create_checkout_session_request: CreateCheckoutSessionRequest, user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_async_db)):
28 | """
29 | Create a checkout session for a user.
30 | """
31 | logger.info(f"Creating checkout session for user {user.id}")
32 | session = await stripe.checkout.Session.create_async(
33 | metadata={
34 | "user_id": user.id,
35 | },
36 | **create_checkout_session_request.model_dump(exclude_none=True),
37 | )
38 | stripe_payment = StripePayment(
39 | id=session.id,
40 | user_id=user.id,
41 | status=session.status,
42 | currency=session.currency.upper(),
43 | amount=session.amount_total,
44 | # store the whole session as raw_data
45 | raw_data=dict(session),
46 | )
47 | db.add(stripe_payment)
48 | await db.commit()
49 |
50 | return {
51 | 'session_id': session.id,
52 | 'url': session.url,
53 | }
54 |
55 | @router.get("/checkout-session")
56 | async def stripe_get_checkout_session(session_id: str, user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_async_db)):
57 | result = await db.execute(
58 | select(
59 | StripePayment
60 | )
61 | .where(StripePayment.id == session_id, StripePayment.user_id == user.id)
62 | )
63 | stripe_payment = result.scalar_one_or_none()
64 | if not stripe_payment:
65 | raise HTTPException(status_code=404, detail="Stripe payment not found")
66 |
67 | return {
68 | 'id': stripe_payment.id,
69 | 'status': stripe_payment.status,
70 | 'currency': stripe_payment.currency,
71 | 'amount': stripe_payment.amount / 100.0 if stripe_payment.currency == "USD" else stripe_payment.amount,
72 | 'created_at': stripe_payment.created_at,
73 | 'updated_at': stripe_payment.updated_at,
74 | }
75 |
76 | @router.get("/checkout-session/clerk")
77 | async def stripe_get_checkout_session_clerk(session_id: str, user: User = Depends(get_current_active_user_from_clerk), db: AsyncSession = Depends(get_async_db)):
78 | return await stripe_get_checkout_session(session_id, user, db)
79 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "forge"
7 | version = "0.1.0"
8 | description = "Forge is an open-source middleware service that simplifies AI model provider management. It allows you to use multiple AI providers (OpenAI, Anthropic, etc.) through a single, unified API. By storing your provider API keys securely, Forge generates a unified key that works across all your AI applications."
9 | requires-python = ">=3.12"
10 | authors = [
11 | {name = "TensorBlock", email = "contact@tensorblock.co"},
12 | ]
13 | readme = "README.md"
14 | license = {text = "MIT"}
15 | dependencies = [
16 | "fastapi>=0.95.0",
17 | "uvicorn>=0.22.0",
18 | "pydantic>=2.0.0",
19 | "python-jose>=3.3.0",
20 | "stripe",
21 | "passlib>=1.7.4",
22 | "python-multipart>=0.0.5",
23 | "sqlalchemy[asyncio]>=2.0.0",
24 | "alembic>=1.10.4",
25 | "aiohttp>=3.8.4",
26 | "cryptography>=40.0.0",
27 | "bcrypt==3.2.2",
28 | "python-dotenv>=1.0.0",
29 | "email-validator>=2.0.0",
30 | "requests>=2.28.0",
31 | "svix>=1.13.0",
32 | "psycopg2-binary>=2.9.9",
33 | "asyncpg>=0.29.0",
34 | "boto3>=1.0.0",
35 | "gunicorn>=20.0.0",
36 | "redis>=4.6.0", # sync & async clients used by shared cache
37 | "loguru>=0.7.0",
38 | "aiobotocore~=2.0",
39 | "tiktoken>=0.5.0", # for token counting in Claude Code support
40 | "google-generativeai>=0.3.0",
41 | "google-genai>=0.3.0",
42 | ]
43 |
44 | [project.optional-dependencies]
45 | dev = [
46 | "pytest>=7.0.0",
47 | "pytest-cov>=4.0.0",
48 | "pytest-asyncio>=0.26.0",
49 | "ruff>=0.2.0",
50 | "pre-commit>=3.0.0",
51 | "flake8",
52 | "black==22.6.0",
53 | "isort",
54 | "mypy",
55 | ]
56 |
57 | [tool.pytest.ini_options]
58 | filterwarnings = [
59 | "ignore:.*'crypt' is deprecated.*:DeprecationWarning:passlib.*"
60 | ]
61 | asyncio_default_fixture_loop_scope = "function"
62 | asyncio_mode = "auto"
63 |
64 | [tool.hatch.build.targets.wheel]
65 | packages = ["app"]
66 |
67 | [tool.hatch.build.targets.sdist]
68 | include = [
69 | "app",
70 | "alembic",
71 | "alembic.ini",
72 | "README.md",
73 | "LICENSE",
74 | ]
75 |
76 | [tool.ruff]
77 | # Same as Black
78 | indent-width = 4
79 | line-length = 88
80 | target-version = "py312"
81 |
82 | [tool.ruff.lint]
83 | # Enable Pyflakes ('F'), pycodestyle ('E'), and import sorting ('I')
84 | select = ["E", "F", "I", "N", "W", "B", "C4", "PL", "SIM", "UP"]
85 | ignore = [
86 | "B008", # Function call in argument defaults
87 | "E501", # Line too long
88 | "PLR0912", # Too many branches
89 | "PLR0915", # Too many statements
90 | "PLR0913", # Too many arguments
91 | "B904",
92 | "PLR0911",
93 | "SIM118",
94 | "PLW2901",
95 | "SIM117",
96 | "PLR2004",
97 | ]
98 | # Allow autofix for all enabled rules (when `--fix` is passed)
99 | fixable = ["ALL"]
100 | unfixable = []
101 |
102 | [tool.ruff.format]
103 | quote-style = "double"
104 | indent-style = "space"
105 | line-ending = "auto"
106 | skip-magic-trailing-comma = false
107 |
108 | # Exclude a variety of commonly ignored directories
109 | [tool.ruff.lint.isort]
110 | known-first-party = ["app"]
111 |
112 | [tool.pytest]
113 | testpaths = ["tests"]
114 | python_files = "test_*.py"
115 |
--------------------------------------------------------------------------------
/tests/unit_tests/assets/google/list_models.json:
--------------------------------------------------------------------------------
1 | {
2 | "models": [
3 | {
4 | "name": "models/embedding-gecko-001",
5 | "version": "001",
6 | "displayName": "Embedding Gecko",
7 | "description": "Obtain a distributed representation of a text.",
8 | "inputTokenLimit": 1024,
9 | "outputTokenLimit": 1,
10 | "supportedGenerationMethods": [
11 | "embedText",
12 | "countTextTokens"
13 | ]
14 | },
15 | {
16 | "name": "models/gemini-1.0-pro-vision-latest",
17 | "version": "001",
18 | "displayName": "Gemini 1.0 Pro Vision",
19 | "description": "The original Gemini 1.0 Pro Vision model version which was optimized for image understanding. Gemini 1.0 Pro Vision was deprecated on July 12, 2024. Move to a newer Gemini version.",
20 | "inputTokenLimit": 12288,
21 | "outputTokenLimit": 4096,
22 | "supportedGenerationMethods": [
23 | "generateContent",
24 | "countTokens"
25 | ],
26 | "temperature": 0.4,
27 | "topP": 1,
28 | "topK": 32
29 | },
30 | {
31 | "name": "models/gemini-pro-vision",
32 | "version": "001",
33 | "displayName": "Gemini 1.0 Pro Vision",
34 | "description": "The original Gemini 1.0 Pro Vision model version which was optimized for image understanding. Gemini 1.0 Pro Vision was deprecated on July 12, 2024. Move to a newer Gemini version.",
35 | "inputTokenLimit": 12288,
36 | "outputTokenLimit": 4096,
37 | "supportedGenerationMethods": [
38 | "generateContent",
39 | "countTokens"
40 | ],
41 | "temperature": 0.4,
42 | "topP": 1,
43 | "topK": 32
44 | },
45 | {
46 | "name": "models/gemini-1.5-pro-latest",
47 | "version": "001",
48 | "displayName": "Gemini 1.5 Pro Latest",
49 | "description": "Alias that points to the most recent production (non-experimental) release of Gemini 1.5 Pro, our mid-size multimodal model that supports up to 2 million tokens.",
50 | "inputTokenLimit": 2000000,
51 | "outputTokenLimit": 8192,
52 | "supportedGenerationMethods": [
53 | "generateContent",
54 | "countTokens"
55 | ],
56 | "temperature": 1,
57 | "topP": 0.95,
58 | "topK": 40,
59 | "maxTemperature": 2
60 | },
61 | {
62 | "name": "models/gemini-1.5-pro-002",
63 | "version": "002",
64 | "displayName": "Gemini 1.5 Pro 002",
65 | "description": "Stable version of Gemini 1.5 Pro, our mid-size multimodal model that supports up to 2 million tokens, released in September of 2024.",
66 | "inputTokenLimit": 2000000,
67 | "outputTokenLimit": 8192,
68 | "supportedGenerationMethods": [
69 | "generateContent",
70 | "countTokens",
71 | "createCachedContent"
72 | ],
73 | "temperature": 1,
74 | "topP": 0.95,
75 | "topK": 40,
76 | "maxTemperature": 2
77 | }
78 | ],
79 | "nextPageToken": "Chltb2RlbHMvZ2VtaW5pLTEuNS1wcm8tMDAy"
80 | }
--------------------------------------------------------------------------------
/tests/run_tests.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import sys
4 | import unittest
5 |
6 | # Add parent directory to path to make imports work
7 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
8 |
9 | # Patch bcrypt version detection to avoid warnings
10 | import bcrypt
11 |
12 | if not hasattr(bcrypt, "__about__"):
13 | import types
14 |
15 | bcrypt.__about__ = types.ModuleType("__about__")
16 | bcrypt.__about__.__version__ = (
17 | bcrypt.__version__ if hasattr(bcrypt, "__version__") else "3.2.0"
18 | )
19 |
20 | # Import the provider service tests
21 | # (cache tests are not imported here; they need special async setup and run separately)
22 | from tests.unit_tests.test_provider_service import TestProviderService
23 |
24 | # Import the test modules
25 | from tests.unit_tests.test_security import TestSecurity
26 |
27 | # Import the image related tests
28 | from tests.unit_tests.test_provider_service_images import TestProviderServiceImages
29 |
30 |
31 | # Define test suites
32 | def security_suite():
33 | suite = unittest.TestSuite()
34 | suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestSecurity))
35 | return suite
36 |
37 |
38 | def provider_service_suite():
39 | suite = unittest.TestSuite()
40 | suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestProviderService))
41 | return suite
42 |
43 | def provider_service_images_suite():
44 | suite = unittest.TestSuite()
45 | suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestProviderServiceImages))
46 | return suite
47 |
48 |
49 | if __name__ == "__main__":
50 | runner = unittest.TextTestRunner(verbosity=2)
51 |
52 | tests_to_run = []
53 |
54 | print("Running security tests...")
55 | result_security = runner.run(security_suite())
56 | tests_to_run.append(result_security)
57 |
58 | print("\nRunning provider service tests...")
59 | result_provider = runner.run(provider_service_suite())
60 | tests_to_run.append(result_provider)
61 |
62 | print("\nRunning provider service images tests...")
63 | result_provider_images = runner.run(provider_service_images_suite())
64 | tests_to_run.append(result_provider_images)
65 |
66 | # Integration tests require a running server
67 | print("\nFor integration tests, make sure the server is running and then execute:")
68 | print("python tests/integration_test.py")
69 |
70 | # Cache tests
71 | print("\nTo run cache tests:")
72 | print("python tests/cache/test_sync_cache.py # For sync cache tests")
73 | print("python tests/cache/test_async_cache.py # For async cache tests")
74 |
75 | # Frontend simulation tests require a valid Forge API key
76 | print(
77 | "\nTo simulate a frontend application, set your Forge API key in the .env file and run:"
78 | )
79 | print("python tests/frontend_simulation.py")
80 |
81 | # Mock provider tests
82 | print("\nTo run all mock tests at once:")
83 | print("python tests/mock_testing/run_mock_tests.py")
84 |
85 | print("\nOr to run individual mock tests:")
86 | print("python tests/mock_testing/test_mock_client.py")
87 | print("# For interactive testing:")
88 | print("python tests/mock_testing/test_mock_client.py --interactive")
89 |
90 | print("\nFor examples of using mocks in your tests, see:")
91 | print("python tests/mock_testing/examples/test_with_mocks.py")
92 |
93 | print("\nSee tests/mock_testing/README.md for more information on mock testing.")
94 |
95 | # Exit with error code if any tests failed
96 | if any(not test.wasSuccessful() for test in tests_to_run):
97 | sys.exit(1)
--------------------------------------------------------------------------------
/docs/PERFORMANCE_OPTIMIZATIONS.md:
--------------------------------------------------------------------------------
1 | # Performance Optimizations for Forge
2 |
3 | This document outlines the performance optimizations implemented in the Forge application to enhance scalability, reduce database load, and improve response times.
4 |
5 | ## Key Optimizations
6 |
7 | ### 1. Multi-Level Caching
8 |
9 | We've implemented several caching mechanisms to reduce redundant operations:
10 |
11 | #### User Authentication Cache
12 | - Each API request requires user authentication via their API key
13 | - Instead of querying the database for every request, users are now cached for 5 minutes
14 | - Significant reduction in database reads for high-traffic users
15 |
16 | #### Provider Service Caching
17 | - ProviderService instances are now reused across requests for the same user
18 | - Cached for 10 minutes to balance memory usage with performance
19 | - Avoids redundant instantiation and database reads
20 |
21 | #### Provider API Key Caching
22 | - API keys are now lazy-loaded and decrypted only when needed
23 | - Once decrypted, they remain available in the cached service instance
24 | - Eliminates repeated decryption operations which are computationally expensive
25 |
26 | #### Model List Caching
27 | - Models returned by providers are cached for 1 hour
28 | - Avoids repeated API calls to list models, which can be rate-limited by providers
29 | - Each provider/configuration has its own cache entry
30 |
31 | ### 2. Lazy Loading
32 |
33 | - Provider keys are not loaded during service initialization
34 | - Keys are only loaded and decrypted when needed for a request
35 | - If a request doesn't need access to provider keys, no decryption occurs (see the sketch below)
36 |
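For illustration, a minimal sketch of this lazy-loading pattern, assuming the `decrypt_api_key` helper from `app.core.security` (the class and attribute names here are illustrative, not the real `ProviderService` API):

```python
from app.core.security import decrypt_api_key


class LazyProviderKeys:
    """Holds encrypted provider keys and decrypts each one only on first use."""

    def __init__(self, encrypted_keys: dict[str, str]):
        self._encrypted = encrypted_keys      # loaded from the DB, still encrypted
        self._decrypted: dict[str, str] = {}  # filled in lazily

    def get(self, provider_name: str) -> str:
        if provider_name not in self._decrypted:
            # Decrypt on first access; later calls reuse the plaintext
            self._decrypted[provider_name] = decrypt_api_key(self._encrypted[provider_name])
        return self._decrypted[provider_name]
```
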
37 | ### 3. Cache Invalidation
38 |
39 | We've implemented targeted cache invalidation to ensure data consistency:
40 |
41 | - User cache is invalidated when:
42 | - User details are updated
43 | - User's Forge API key is reset
44 |
45 | - Provider service cache is invalidated when:
46 | - Provider keys are added
47 | - Provider keys are updated
48 | - Provider keys are deleted
49 | - User details change
50 |
51 | ### 4. Error Resilience
52 |
53 | - Model listing now has error handling to prevent failures if one provider API is down
54 | - Batch processing of concurrent API requests to avoid overwhelming provider APIs
55 |
56 | ## Implementation Details
57 |
58 | ### Cache Implementation
59 |
60 | A simple in-memory cache with time-to-live (TTL) expiration, sketched below:
61 | - `app/core/cache.py` contains the cache implementation
62 | - Each cache entry has its own expiration time
63 | - Cache entries are automatically removed when accessed after expiration
64 |
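For illustration, a minimal cache along these lines (the actual code in `app/core/cache.py` may differ in details):

```python
import time


class TTLCache:
    """Tiny in-memory cache where every entry carries its own expiry time."""

    def __init__(self):
        self._store: dict[str, tuple[float, object]] = {}

    def set(self, key: str, value: object, ttl_seconds: float) -> None:
        self._store[key] = (time.monotonic() + ttl_seconds, value)

    def get(self, key: str):
        entry = self._store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if time.monotonic() >= expires_at:
            # Expired entries are dropped when they are next accessed
            del self._store[key]
            return None
        return value
```
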
65 | ### Adapter Caching
66 |
67 | Provider adapters are cached at the class level:
68 | - Adapters are stateless, so a single instance can be shared across all requests
69 | - Reduces memory usage and avoids repeated instantiation
70 |
71 | ### Service Reuse Pattern
72 |
73 | The ProviderService now uses a factory pattern, sketched below:
74 | - `get_instance()` method returns cached instances when available
75 | - Creates new instances only when needed
76 |
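A minimal sketch of the factory, reusing the `TTLCache` sketch above (the real `get_instance()` signature and cache keys may differ):

```python
class ProviderService:
    _instances = TTLCache()  # class-level cache shared across requests

    def __init__(self, user_id: int):
        self.user_id = user_id

    @classmethod
    def get_instance(cls, user_id: int) -> "ProviderService":
        cached = cls._instances.get(f"provider_service:{user_id}")
        if cached is not None:
            return cached
        service = cls(user_id)
        # Cache for 10 minutes, matching the TTL described above
        cls._instances.set(f"provider_service:{user_id}", service, ttl_seconds=600)
        return service
```
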
77 | ## Performance Impact
78 |
79 | These optimizations significantly reduce:
80 | - Database queries per request
81 | - Decryption operations
82 | - Memory usage
83 | - API calls to providers
84 |
85 | This results in:
86 | - Lower response times
87 | - Higher maximum throughput
88 | - Reduced database load
89 | - Better scalability
90 |
91 | ## Future Improvements
92 |
93 | Potential future optimizations:
94 | - Distributed caching (Redis) for multi-server deployments
95 | - More granular cache expiration policies
96 | - Request batching for similar consecutive requests
97 | - Advanced monitoring of cache hit/miss rates
98 |
--------------------------------------------------------------------------------
/tools/diagnostics/check_db_keys.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Script to directly check the database for API keys.
4 | """
5 |
6 | import os
7 |
8 | import psycopg2
9 | from dotenv import load_dotenv
10 |
11 |
12 | def check_database_connection():
13 | """Check if database is accessible."""
14 | try:
15 | # Load environment variables
16 | load_dotenv()
17 |
18 | # Get database URL from environment
19 | db_url = os.getenv(
20 | "DATABASE_URL", "postgresql://user:password@localhost:5432/forge"
21 | )
22 |
23 | # Try to connect to PostgreSQL
24 | conn = psycopg2.connect(db_url)
25 | return conn
26 | except Exception as e:
27 | print(f"Database connection error: {e}")
28 | print("Please ensure:")
29 | print("1. PostgreSQL is running")
30 | print("2. Database 'forge' exists")
31 | print("3. User has proper permissions")
32 | print("4. Connection details in .env are correct")
33 | return None
34 |
35 |
36 | def main():
37 | """Check the database for API keys."""
38 | conn = check_database_connection()
39 | if not conn:
40 | return
41 |
42 | try:
43 | cursor = conn.cursor()
44 |
45 | # Get table names
46 | cursor.execute(
47 | """
48 | SELECT table_name
49 | FROM information_schema.tables
50 | WHERE table_schema = 'public'
51 | """
52 | )
53 | tables = cursor.fetchall()
54 | print(f"Tables in database: {[table[0] for table in tables]}")
55 |
56 | # Check if users table exists
57 | if ("users",) in tables:
58 | # Count users
59 | cursor.execute("SELECT COUNT(*) FROM users")
60 | user_count = cursor.fetchone()[0]
61 | print(f"Number of users in database: {user_count}")
62 |
63 | # Get user details
64 | cursor.execute("SELECT id, username, forge_api_key FROM users")
65 | users = cursor.fetchall()
66 |
67 | print("\nUsers in database:")
68 | for user_id, username, api_key in users:
69 | print(f"User {user_id} ({username}): API Key = {api_key}")
70 |
71 | # Check against the expected key
72 | expected_key = "forge-1ea5812207aa309110b4122f38d7be34"
73 | if api_key == expected_key:
74 | print(" ✓ Matches expected key")
75 | else:
76 | print(" ✗ Does not match expected key")
77 | else:
78 | print("'users' table not found in database")
79 |
80 | # Check if provider_keys table exists
81 | if ("provider_keys",) in tables:
82 | cursor.execute("SELECT COUNT(*) FROM provider_keys")
83 | key_count = cursor.fetchone()[0]
84 | print(f"\nNumber of provider keys in database: {key_count}")
85 |
86 | # Get provider key details
87 | cursor.execute("SELECT id, provider_name, user_id FROM provider_keys")
88 | keys = cursor.fetchall()
89 |
90 | print("\nProvider keys in database:")
91 | for key_id, provider_name, user_id in keys:
92 | print(f"Key {key_id}: Provider = {provider_name}, User ID = {user_id}")
93 |
94 | cursor.close()
95 | conn.close()
96 |
97 | except psycopg2.Error as e:
98 | print(f"Database error: {e}")
99 | if conn:
100 | conn.close()
101 |
102 |
103 | if __name__ == "__main__":
104 | main()
105 |
--------------------------------------------------------------------------------
/tests/performance/README.md:
--------------------------------------------------------------------------------
1 | # Forge Performance Tests
2 |
3 | This directory contains performance tests for evaluating the Forge middleware service's performance characteristics.
4 |
5 | ## Overview
6 |
7 | These tests measure key performance metrics such as:
8 |
9 | - Response latency (p50, p90, p99)
10 | - Throughput (requests per second)
11 | - Scalability under varying loads
12 | - Resource utilization
13 | - Provider-specific performance metrics
14 |
15 | Performance tests use the mock provider by default to ensure consistent and reproducible results without external dependencies.
16 |
17 | ## Prerequisites
18 |
19 | To run these tests, you need:
20 |
21 | 1. A running instance of Forge
22 | 2. Python packages:
23 | - locust
24 | - pytest
25 | - pytest-benchmark
26 | - aiohttp
27 | - pandas (for results analysis)
28 | - matplotlib (for visualizations)
29 |
30 | Install performance testing dependencies:
31 |
32 | ```bash
33 | pip install locust pytest pytest-benchmark aiohttp pandas matplotlib
34 | ```
35 |
36 | ## Usage
37 |
38 | ### Single-run benchmarks:
39 |
40 | ```bash
41 | # Run all performance tests
42 | pytest tests/performance/
43 |
44 | # Run specific performance test category
45 | pytest tests/performance/test_latency.py
46 | pytest tests/performance/test_throughput.py
47 | pytest tests/performance/test_providers.py
48 | ```
49 |
50 | ### Load testing with Locust:
51 |
52 | ```bash
53 | # Start Locust web interface
54 | cd tests/performance
55 | locust -f locustfile.py
56 |
57 | # Run headless with 10 users, spawn rate of 1 user/sec, for 1 minute
58 | locust -f locustfile.py --headless -u 10 -r 1 --run-time 1m
59 | ```
60 |
61 | ## About Mock Providers
62 |
63 | Performance tests use mock providers rather than real API calls to:
64 |
65 | - Eliminate external dependencies for consistent results
66 | - Avoid rate limits and costs associated with real API calls
67 | - Ensure reproducible test results
68 | - Provide consistent response patterns and timing
69 |
70 | This approach focuses on measuring Forge's own performance characteristics rather than the performance of external services.
71 |
72 | ## Baseline Results
73 |
74 | The baseline performance results are stored in `baseline_results/` and represent reference numbers for comparing future changes.
75 |
76 | ## Output Analysis
77 |
78 | Test results are saved to:
79 | - JSON reports in `results/json/`
80 | - CSV files in `results/csv/`
81 | - Graphs in `results/graphs/`
82 |
83 | Use the analysis utilities to compare results:
84 |
85 | ```bash
86 | # Compare specific test runs
87 | python tests/performance/analyze_results.py --baseline baseline_results/latest.json --new results/json/new_run.json
88 |
89 | # Generate charts for all test results
90 | python tests/performance/analyze_results.py
91 |
92 | # Generate an interactive HTML dashboard with all test results
93 | python tests/performance/analyze_results.py --dashboard
94 | ```
95 |
96 | The HTML dashboard provides a comprehensive view of all performance test results in a single page with:
97 | - Interactive tabs for different test categories
98 | - Visual representations of all metrics
99 | - Summary statistics for each test
100 | - Test results grouped by type (latency, throughput, streaming, etc.)
101 |
102 | After generating the dashboard, open it in any web browser:
103 | ```bash
104 | # On macOS
105 | open tests/performance/results/graphs/performance_dashboard.html
106 |
107 | # On Linux
108 | xdg-open tests/performance/results/graphs/performance_dashboard.html
109 |
110 | # On Windows
111 | start tests/performance/results/graphs/performance_dashboard.html
112 | ```
113 |
114 | ## Interpreting Results
115 |
116 | When evaluating performance changes:
117 | - Latency: Lower is better
118 | - Throughput: Higher is better
119 | - Error rate: Should remain at or near 0%
120 |
--------------------------------------------------------------------------------
/app/api/schemas/provider_key.py:
--------------------------------------------------------------------------------
1 | import json
2 | from datetime import datetime
3 |
4 | from pydantic import BaseModel, ConfigDict, computed_field, Field, field_validator
5 |
6 | from app.core.logger import get_logger
7 | from app.core.security import decrypt_api_key
8 | from app.services.providers.adapter_factory import ProviderAdapterFactory
9 |
10 | logger = get_logger(name="provider_key")
11 |
12 |
13 | class ProviderKeyBase(BaseModel):
14 | provider_name: str = Field(..., min_length=1)
15 | api_key: str
16 | base_url: str | None = None
17 | model_mapping: dict[str, str] | None = None
18 | config: dict[str, str] | None = None
19 |
20 |
21 | class ProviderKeyCreate(ProviderKeyBase):
22 | pass
23 |
24 |
25 | class ProviderKeyUpdate(BaseModel):
26 | api_key: str | None = None
27 | config: dict[str, str] | None = None
28 | base_url: str | None = None
29 | model_mapping: dict[str, str] | None = None
30 |
31 |
32 | class ProviderKeyInDBBase(BaseModel):
33 | id: int
34 | provider_name: str
35 | user_id: int
36 | base_url: str | None = None
37 | model_mapping: dict[str, str] | None = None
38 | created_at: datetime
39 | updated_at: datetime
40 | encrypted_api_key: str
41 | model_config = ConfigDict(from_attributes=True)
42 |
43 | @field_validator("model_mapping", mode="before")
44 | @classmethod
45 | def parse_model_mapping(cls, v):
46 | """Parse JSON string to dictionary for model_mapping field."""
47 | if v is None:
48 | return None
49 | if isinstance(v, str):
50 | try:
51 | return json.loads(v)
52 | except json.JSONDecodeError:
53 | logger.warning(f"Failed to parse model_mapping JSON: {v}")
54 | return {}
55 | return v
56 |
57 |
58 | class ProviderKey(ProviderKeyInDBBase):
59 | @computed_field
60 | @property
61 | def api_key(self) -> str | None:
62 | """Masked API key for responses."""
63 | if self.encrypted_api_key:
64 | decrypted_value = decrypt_api_key(self.encrypted_api_key)
65 | provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls(
66 | self.provider_name
67 | )
68 | try:
69 | api_key, _ = provider_adapter_cls.deserialize_api_key_config(
70 | decrypted_value
71 | )
72 | return provider_adapter_cls.mask_api_key(api_key)
73 | except Exception as e:
74 | logger.error(
75 | f"Error deserializing API key for provider {self.provider_name}: {e}"
76 | )
77 | return None
78 | return None
79 |
80 | @computed_field
81 | @property
82 | def config(self) -> dict[str, str] | None:
83 | """Masked config for responses."""
84 | if self.encrypted_api_key:
85 | decrypted_value = decrypt_api_key(self.encrypted_api_key)
86 | provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls(
87 | self.provider_name
88 | )
89 | try:
90 | _, config = provider_adapter_cls.deserialize_api_key_config(
91 | decrypted_value
92 | )
93 | return provider_adapter_cls.mask_config(config)
94 | except Exception as e:
95 | logger.error(
96 | f"Error deserializing config for provider {self.provider_name}: {e}"
97 | )
98 | return None
99 | return None
100 |
101 |
102 | class ProviderKeyUpsertItem(BaseModel):
103 | provider_name: str = Field(..., min_length=1)
104 | api_key: str | None = None
105 | base_url: str | None = None
106 | model_mapping: dict[str, str] | None = None
107 | config: dict[str, str] | None = None
108 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | TensorBlock Forge is licensed under the Apache License, Version 2.0. In addition to the standard terms of the Apache License 2.0, your use of TensorBlock Forge is subject to the following Supplemental Terms:
2 |
3 | -------------------------------------------------------------------------------
4 | 1. Commercial Use Restrictions
5 | -------------------------------------------------------------------------------
6 |
7 | The standard open-source license granted under Apache 2.0 does not apply to the following categories of usage. If you intend to use TensorBlock Forge in any of the following ways, you must obtain prior written commercial authorization from the TensorBlock team:
8 |
9 | 1. Modifications and Derivative Works
10 | Use of modified versions of TensorBlock Forge or derivative works thereof, including but not limited to changes to application branding, codebase, features, user interface, or data structures, for commercial purposes.
11 |
12 | 2. Enterprise Deployment
13 | Use of TensorBlock Forge within a commercial organization or in services provided to third-party enterprise clients.
14 |
15 | 3. Public Cloud-Based Services
16 | Operation of public cloud services built upon TensorBlock Forge (e.g., SaaS platforms, hosted APIs, multi-tenant web services) accessible to general users.
17 |
18 | 4. Bundled Distribution with Hardware
19 | Distribution of TensorBlock Forge bundled, integrated, or pre-installed on hardware products intended for commercial sale.
20 |
21 | -------------------------------------------------------------------------------
22 | 2. Contributor Agreement
23 | -------------------------------------------------------------------------------
24 |
25 | By contributing to TensorBlock Forge in the form of code, documentation, or other assets, you acknowledge and agree to the following terms:
26 |
27 | 1. Right to Re-License
28 | You grant the maintainers of TensorBlock Forge a perpetual, irrevocable, worldwide, non-exclusive license to use, reproduce, modify, sublicense, and distribute your contributions under both open-source and commercial licenses.
29 |
30 | 2. Commercial Use of Contributions
31 | Your contributions may be incorporated into commercial versions of TensorBlock Forge or other related products offered by the TensorBlock team, including but not limited to proprietary deployments, enterprise features, and hosted services.
32 |
33 | 3. License Policy Adjustments
34 | TensorBlock reserves the right to modify the licensing structure of TensorBlock Forge in the future, in accordance with project governance and development objectives. Any such change shall not affect your rights under the version of the license in effect at the time of your contribution, unless explicitly agreed otherwise.
35 |
36 | -------------------------------------------------------------------------------
37 | 3. General Provisions
38 | -------------------------------------------------------------------------------
39 |
40 | 1. Interpretation
41 | The interpretation and enforcement of these Supplemental Terms shall be at the sole discretion of the TensorBlock development team.
42 |
43 | 2. Amendments
44 | These Supplemental Terms may be updated from time to time to reflect changes in project governance, usage scenarios, or business needs. Updates will be communicated via the software interface or official project communication channels. Continued use of the software after such notice constitutes acceptance of the updated terms.
45 |
46 | -------------------------------------------------------------------------------
47 | Except as explicitly modified by these Supplemental Terms, all other rights and obligations regarding your use of TensorBlock Forge are governed by the Apache License, Version 2.0.
48 |
49 | You may access the full Apache 2.0 license text at:
50 | http://www.apache.org/licenses/LICENSE-2.0
51 |
52 | For commercial licensing inquiries, please contact the TensorBlock team at:
53 | contact@tensorblock.co
54 | -------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/app/services/providers/fireworks_adapter.py:
--------------------------------------------------------------------------------
1 | from .openai_adapter import OpenAIAdapter
2 | from typing import Any
3 |
4 |
5 | class FireworksAdapter(OpenAIAdapter):
6 | """Adapter for Fireworks API"""
7 |
8 | FIREWORKS_MODEL_MAPPING = {
9 | "Llama 4 Maverick Instruct (Basic)": "accounts/fireworks/models/llama4-maverick-instruct-basic",
10 | "Llama 4 Scout Instruct (Basic)": "accounts/fireworks/models/llama4-scout-instruct-basic",
11 | "Llama 3.1 405B Instruct": "accounts/fireworks/models/llama-v3p1-405b-instruct",
12 | "Llama 3.1 8B Instruct": "accounts/fireworks/models/llama-v3p1-8b-instruct",
13 | "Llama 3.3 70B Instruct": "accounts/fireworks/models/llama-v3p3-70b-instruct",
14 | "Llama 3.2 90B Vision Instruct": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
15 | "DeepSeek V3": "accounts/fireworks/models/deepseek-v3",
16 | "DeepSeek R1 (Fast)": "accounts/fireworks/models/deepseek-r1",
17 | "DeepSeek R1 (Basic)": "accounts/fireworks/models/deepseek-r1-basic",
18 | "Deepseek V3 03-24": "accounts/fireworks/models/deepseek-v3-0324",
19 | "Qwen Qwq 32b Preview": "accounts/fireworks/models/qwen-qwq-32b-preview",
20 | "Phi 3.5 Vision Instruct": "accounts/fireworks/models/phi-3-vision-128k-instruct",
21 | "Firesearch Ocr V6": "accounts/fireworks/models/firesearch-ocr-v6",
22 | "Yi-Large": "accounts/yi-01-ai/models/yi-large",
23 | "Llama V3p1 405b Instruct Long": "accounts/fireworks/models/llama-v3p1-405b-instruct-long",
24 | "Llama Guard 3 8b": "accounts/fireworks/models/llama-guard-3-8b",
25 | "Dobby-Unhinged-Llama-3.3-70B": "accounts/sentientfoundation/models/dobby-unhinged-llama-3-3-70b-new",
26 | "Mixtral MoE 8x22B Instruct": "accounts/fireworks/models/mixtral-8x22b-instruct",
27 | "Qwen2.5 72B Instruct": "accounts/fireworks/models/qwen2p5-72b-instruct",
28 | "QwQ-32B": "accounts/fireworks/models/qwq-32b",
29 | "Qwen2 VL 72B Instruct": "accounts/fireworks/models/qwen2-vl-72b-instruct",
30 | }
31 |
32 | def __init__(
33 | self,
34 | provider_name: str,
35 | base_url: str,
36 | config: dict[str, Any] | None = None,
37 | ):
38 | self._base_url = base_url
39 | super().__init__(provider_name, base_url, config=config)
40 |
41 | async def list_models(self, api_key: str) -> list[str]:
42 | """List all models (verbosely) supported by the provider"""
43 | # Check cache first
44 | cached_models = self.get_cached_models(api_key, self._base_url)
45 | if cached_models is not None:
46 | return cached_models
47 |
48 | # If not in cache, make API call
49 | # headers = {
50 | # "Authorization": f"Bearer {api_key}",
51 | # "Content-Type": "application/json",
52 | # }
53 |         # The Fireworks list-models API requires an account ID; we use the official "fireworks" account
54 |         # https://docs.fireworks.ai/api-reference/list-models
55 |         # Note that the inference API and the list-models API do not share the same base URL
56 | # url = "https://api.fireworks.ai/v1/accounts/fireworks/models"
57 |
58 | # async with (
59 | # aiohttp.ClientSession() as session,
60 | # session.get(url, headers=headers, params={"pageSize": 200}) as response,
61 | # ):
62 | # if response.status != HTTPStatus.OK:
63 | # error_text = await response.text()
64 | # raise ValueError(f"Fireworks API error: {error_text}")
65 | # resp = await response.json()
66 | # self.FIREWORKS_MODEL_MAPPING = {
67 | # d["displayName"]: d["name"] for d in resp["models"]
68 | # }
69 | # return [d["name"] for d in resp["models"]]
70 |
71 |         # TODO: the Fireworks API does not currently support listing all serverless models,
72 |         # so we simply hardcode them here
73 | return list(self.FIREWORKS_MODEL_MAPPING.values())
74 |
--------------------------------------------------------------------------------
/tests/mock_testing/mock_openai.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for patching OpenAI's client with our mock client.
3 | This allows tests to run without real API calls.
4 | """
5 |
6 | import importlib.util
7 | import os
8 | import sys
9 | from contextlib import contextmanager
10 | from unittest.mock import patch
11 | # Add the parent directory to the path
12 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
13 |
14 | # Import the mock client
15 | from app.services.providers.mock_provider import MockClient
16 |
17 |
18 | # Create a mock openai module that can be used when the real one is not installed
19 | class MockOpenAIModule:
20 | def __init__(self):
21 | self.OpenAI = MockClient
22 | self.AsyncOpenAI = MockClient
23 |
24 |
25 | # Check if openai is installed
26 | if importlib.util.find_spec("openai") is None:
27 | # Create a mock openai module
28 | mock_openai = MockOpenAIModule()
29 | # Add it to sys.modules so import statements work
30 | sys.modules["openai"] = mock_openai
31 | # Import the mock module
32 | import openai as _mock_openai # noqa: F401
33 |
34 |
35 | @contextmanager
36 | def enable_mock_openai():
37 |     """
38 |     Enable mocking of OpenAI's client for testing.
39 |     This function returns a context manager that can be used with 'with'.
40 | 
41 |     Example:
42 |         with enable_mock_openai():
43 |             client = openai.OpenAI(api_key="fake-key")
44 |             # All calls to client will use the mock implementation
45 |     """
46 |     # Create the patcher (sync client only; MockPatch below also covers AsyncOpenAI)
47 |     openai_client_patcher = patch("openai.OpenAI", return_value=MockClient())
48 | 
49 |     # Start the patcher
50 |     mock_openai_client = openai_client_patcher.start()
51 |     try:
52 |         yield mock_openai_client
53 |     finally:
54 |         # Stop the patcher
55 |         openai_client_patcher.stop()
56 |
57 |
58 | def patch_with_mock():
59 | """Decorator to patch OpenAI with mock client for a test function"""
60 |
61 | def decorator(func):
62 | async def wrapper(*args, **kwargs):
63 | # Create patchers
64 | openai_client_patcher = patch("openai.OpenAI", return_value=MockClient())
65 | async_openai_client_patcher = patch(
66 | "openai.AsyncOpenAI", return_value=MockClient()
67 | )
68 |
69 | # Start the patchers
70 | openai_client_patcher.start()
71 | async_openai_client_patcher.start()
72 |
73 | try:
74 | # Run the function
75 | result = await func(*args, **kwargs)
76 | return result
77 | finally:
78 | # Stop the patchers
79 | openai_client_patcher.stop()
80 | async_openai_client_patcher.stop()
81 |
82 | return wrapper
83 |
84 | return decorator
85 |
86 |
87 | class MockPatch:
88 | """
89 | Class-based helper for patching OpenAI in tests.
90 | This can be used as a context manager or standalone.
91 | """
92 |
93 | def __init__(self):
94 | self.openai_client_patcher = patch("openai.OpenAI", return_value=MockClient())
95 | self.async_openai_client_patcher = patch(
96 | "openai.AsyncOpenAI", return_value=MockClient()
97 | )
98 | self.patched = False
99 |
100 | def start(self):
101 | """Start patching OpenAI"""
102 | if not self.patched:
103 | self.openai_client_patcher.start()
104 | self.async_openai_client_patcher.start()
105 | self.patched = True
106 | return self
107 |
108 | def stop(self):
109 | """Stop patching OpenAI"""
110 | if self.patched:
111 | self.openai_client_patcher.stop()
112 | self.async_openai_client_patcher.stop()
113 | self.patched = False
114 |
115 | def __enter__(self):
116 | """Enter context manager"""
117 | return self.start()
118 |
119 | def __exit__(self, exc_type, exc_val, exc_tb):
120 | """Exit context manager"""
121 | self.stop()
122 |
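A short usage sketch for `MockPatch` (illustrative only, showing both the context-manager and the standalone styles):

```python
from tests.mock_testing.mock_openai import MockPatch

# As a context manager: openai.OpenAI / openai.AsyncOpenAI are patched for the block.
with MockPatch():
    ...  # code under test constructs clients and gets MockClient instances

# Standalone: handy in unittest setUp/tearDown.
mock_patch = MockPatch().start()
try:
    ...  # code under test
finally:
    mock_patch.stop()
```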
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | # Forge Integration Tests
2 |
3 | This directory contains tests for Forge, including unit tests and integration tests.
4 |
5 | ## Integration Tests
6 |
7 | The main integration test script is:
8 |
9 | - `integration_test.py` - Supports both local development with real API calls and mock mode for CI/CD environments
10 |
11 | ### Prerequisites
12 |
13 | Before running integration tests, make sure:
14 |
15 | 1. The Forge server is running (`python run.py` from the project root)
16 | 2. Required Python packages are installed:
17 | ```
18 | pip install requests python-dotenv
19 | ```
20 |
21 | ### Running Integration Tests
22 |
23 | You can run the integration tests in different modes:
24 |
25 | #### With Real API Calls
26 |
27 | For local development with actual API calls to external providers:
28 |
29 | ```bash
30 | # Run the full integration test (requires API keys)
31 | python tests/integration_test.py
32 | ```
33 |
34 | This test will register a user, add provider keys, and test completions with real API calls. You'll need:
35 |
36 | - OpenAI API key (set as OPENAI_API_KEY in .env)
37 | - (Optional) Anthropic API key (set as ANTHROPIC_API_KEY in .env)
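As a rough sketch, the test scripts pick these keys up from the project-root `.env` via `python-dotenv` (the exact loading code lives in the scripts themselves):

```python
import os

from dotenv import load_dotenv

# Load .env from the project root so the provider keys become environment variables.
load_dotenv()

openai_key = os.getenv("OPENAI_API_KEY")
anthropic_key = os.getenv("ANTHROPIC_API_KEY")  # optional
if not openai_key:
    raise SystemExit("OPENAI_API_KEY is not set in .env")
```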
38 |
39 | #### With Mocked Responses
40 |
41 | For CI environments or when you don't want to make real API calls:
42 |
43 | ```bash
44 | # Use CI testing mode
45 | CI_TESTING=true python tests/integration_test.py
46 |
47 | # Or use the SKIP_API_CALLS flag (same effect)
48 | SKIP_API_CALLS=true python tests/integration_test.py
49 | ```
50 |
51 | In mock mode:
52 | - External API calls are replaced with mock responses
53 | - The server connection is still verified
54 | - User registration and management features are tested
55 | - API interactions use the mock provider to simulate responses
56 | - The mock provider returns predefined responses for testing
57 |
58 | ### Mock Provider
59 |
60 | The integration test uses the mock provider from `app/services/providers/mock_provider.py` when running in mock mode. This provider:
61 |
62 | - Simulates API responses without making actual API calls
63 | - Provides mock models similar to those from real providers
64 | - Returns consistent, predictable responses for testing
65 | - Can be used with the `CI_TESTING=true` or `SKIP_API_CALLS=true` flag
66 |
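A minimal sketch of driving the mock client directly, mirroring `tests/mock_testing/examples/test_with_mocks.py` (the `chat_completions_create` interface is the one used there):

```python
import asyncio

from tests.mock_testing.mock_openai import MockPatch


async def main():
    import openai  # imported after mock_openai so a stub module is injected if openai is absent

    # While the patch is active, openai.OpenAI() returns the MockClient,
    # so no real API calls are made.
    with MockPatch():
        client = openai.OpenAI(api_key="fake-key")
        response = await client.chat_completions_create(
            model="mock-only-gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello"}],
        )
        print(response.choices[0].message.content)


asyncio.run(main())
```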
67 | ### GitHub Actions Integration
68 |
69 | The test is automatically run in GitHub Actions workflows defined in `.github/workflows/tests.yml`. The workflow:
70 |
71 | 1. Sets up a test environment
72 | 2. Starts the Forge server
73 | 3. Runs the integration tests in CI mode
74 | 4. Ensures no actual API calls are made to external services
75 |
76 | ## Unit Tests
77 |
78 | Run individual unit tests with:
79 |
80 | ```bash
81 | python -m unittest tests/test_security.py
82 | python -m unittest tests/test_provider_service.py
83 | ```
84 |
85 | Run all unit tests with:
86 |
87 | ```bash
88 | python -m unittest discover tests
89 | ```
90 |
91 | ## Cache Tests
92 |
93 | The cache test directory contains tests for both synchronous and asynchronous caching functionality:
94 |
95 | - `test_sync_cache.py` - Tests the synchronous in-memory caching with ProviderService instances
96 | - `test_async_cache.py` - Tests the async-compatible cache implementation for future distributed caching
97 |
98 | ### Running Cache Tests
99 |
100 | ```bash
101 | # Run synchronous cache tests
102 | python tests/cache/test_sync_cache.py
103 |
104 | # Run asynchronous cache tests
105 | python tests/cache/test_async_cache.py
106 | ```
107 |
108 | These tests verify:
109 | - Cache hit/miss behavior
110 | - Performance improvements from caching
111 | - Proper instance reuse with singleton pattern
112 | - AsyncCache compatibility with asyncio patterns
113 |
114 | The async tests are important for validating Forge's readiness for distributed caching solutions like Redis or Amazon ElastiCache, as outlined in `docs/DISTRIBUTED_CACHE_MIGRATION.md`.
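The hit/miss and instance-reuse checks follow the usual pattern; here is a simplified, hypothetical sketch (the real assertions live in the test files listed above, and `get_provider_service` is a stand-in name, not an actual helper in the repo):

```python
import time


def check_cache_behavior(get_provider_service):
    # Hypothetical helper: singleton reuse means repeated lookups for the
    # same user must return the exact same instance.
    first = get_provider_service(user_id=1)
    second = get_provider_service(user_id=1)
    assert first is second

    # A cache hit should be measurably faster than the initial miss.
    start = time.perf_counter()
    get_provider_service(user_id=1)
    print(f"cache hit took {time.perf_counter() - start:.6f}s")
```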
115 |
116 | ## Test Coverage Report
117 |
118 | Generate a test coverage report with:
119 |
120 | ```bash
121 | pytest tests/test_*.py --cov=app --cov-report=xml
122 | ```
123 |
--------------------------------------------------------------------------------
/app/api/routes/auth.py:
--------------------------------------------------------------------------------
1 | from datetime import timedelta
2 | from typing import Any
3 |
4 | from fastapi import APIRouter, Depends, HTTPException, status
5 | from fastapi.security import OAuth2PasswordRequestForm
6 | from sqlalchemy import select
7 | from sqlalchemy.ext.asyncio import AsyncSession
8 |
9 | from app.api.routes.users import create_user as create_user_endpoint_logic
10 | from app.api.schemas.user import Token, User, UserCreate
11 | from app.core.database import get_async_db
12 | from app.core.logger import get_logger
13 | from app.core.security import (
14 | ACCESS_TOKEN_EXPIRE_MINUTES,
15 | create_access_token,
16 | verify_password,
17 | )
18 | from app.models.user import User as UserModel
19 |
20 | logger = get_logger(name="auth")
21 |
22 | router = APIRouter()
23 |
24 |
25 | @router.post("/register", response_model=User)
26 | async def register(
27 | user_in: UserCreate,
28 | db: AsyncSession = Depends(get_async_db),
29 | ) -> Any:
30 | """
31 | Register new user. This will create the user but will not automatically create a Forge API key.
32 | Users should use the /api-keys/ endpoint to create their keys after registration.
33 | """
34 | # Call the user creation logic from users.py
35 | # This handles checks for existing email/username and password hashing.
36 | try:
37 | db_user = await create_user_endpoint_logic(user_in=user_in, db=db)
38 |     except HTTPException:  # Propagate HTTPExceptions (like 400 for existing user) unchanged
39 |         raise
40 |     except Exception as e:  # Catch any other unexpected errors during user creation
41 |         # Log the unexpected error before returning a generic 500
42 |         logger.error(
43 |             f"Unexpected error during create_user_endpoint_logic call: {e}"
44 |         )
45 | raise HTTPException(
46 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
47 | detail="An unexpected error occurred during user registration.",
48 | )
49 |
50 | # Prepare the response.
51 | # create_user_endpoint_logic returns a UserModel instance.
52 | # The User Pydantic model has from_attributes = True.
53 | try:
54 | # For Pydantic v2+ (which is likely if using FastAPI > 0.100.0)
55 | pydantic_user = User.model_validate(db_user)
56 | # If using Pydantic v1, you would use:
57 | # pydantic_user = User.from_orm(db_user)
58 | except Exception as e_pydantic:
59 | # Log this validation error to understand what went wrong if it fails
60 | logger.error(
61 | f"Error during Pydantic model_validate in /auth/register: {e_pydantic}"
62 | )
63 | logger.error(
64 | f"SQLAlchemy User object was: {db_user.__dict__ if db_user else 'None'}"
65 | )
66 | raise HTTPException(
67 | status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
68 | detail="Error processing user data after creation.",
69 | )
70 |
71 | pydantic_user.forge_api_keys = [] # Explicitly set to empty list as no key is auto-generated
72 |
73 | return pydantic_user
74 |
75 |
76 | @router.post("/token", response_model=Token)
77 | async def login_for_access_token(
78 | db: AsyncSession = Depends(get_async_db),
79 | form_data: OAuth2PasswordRequestForm = Depends()
80 | ) -> Any:
81 | """
82 | Get an access token for future API requests.
83 | """
84 | result = await db.execute(
85 | select(UserModel).filter(UserModel.username == form_data.username)
86 | )
87 | user = result.scalar_one_or_none()
88 |
89 | if not user or not verify_password(form_data.password, user.hashed_password):
90 | raise HTTPException(
91 | status_code=status.HTTP_401_UNAUTHORIZED,
92 | detail="Incorrect username or password",
93 | headers={"WWW-Authenticate": "Bearer"},
94 | )
95 |
96 | access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
97 | access_token = create_access_token(
98 | data={"sub": user.username}, expires_delta=access_token_expires
99 | )
100 |
101 | return {"access_token": access_token, "token_type": "bearer"}
102 |
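A hedged usage sketch of the token flow (assuming these routes are mounted under an `/auth` prefix and the server listens on `http://localhost:8000`; adjust both to the actual application setup):

```python
import requests

BASE_URL = "http://localhost:8000"

# OAuth2PasswordRequestForm expects form-encoded username/password fields.
resp = requests.post(
    f"{BASE_URL}/auth/token",
    data={"username": "alice", "password": "s3cret"},
)
resp.raise_for_status()
token = resp.json()["access_token"]

# The bearer token can then authorize subsequent requests.
headers = {"Authorization": f"Bearer {token}"}
```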
--------------------------------------------------------------------------------
/tests/mock_testing/examples/test_with_mocks.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | Example tests that demonstrate using the mock client in different scenarios.
4 | """
5 |
6 | import asyncio
7 | import os
8 | import sys
9 | import unittest
10 |
11 | # Add the parent directory to the path
12 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
13 |
14 | # Import our mock utilities
15 | from tests.mock_testing.mock_openai import MockPatch, patch_with_mock
16 |
17 |
18 | # This would be your application code that uses OpenAI
19 | class MyAppService:
20 | """Example application service that uses OpenAI"""
21 |
22 | def __init__(self, api_key=None):
23 | """Initialize with an API key"""
24 | self.api_key = api_key or os.getenv("OPENAI_API_KEY")
25 |
26 | # In a real app, you'd import OpenAI here
27 | # But we'll do it in the methods to allow for patching
28 |
29 | async def generate_response(self, user_message: str) -> str:
30 | """Generate a response to a user message"""
31 | import openai
32 |
33 | client = openai.OpenAI(api_key=self.api_key)
34 |
35 | response = await client.chat_completions_create(
36 | model="mock-only-gpt-3.5-turbo",
37 | messages=[
38 | {"role": "system", "content": "You are a helpful assistant."},
39 | {"role": "user", "content": user_message},
40 | ],
41 | )
42 |
43 | return response.choices[0].message.content
44 |
45 | async def get_available_models(self) -> list:
46 | """Get a list of available models"""
47 | import openai
48 |
49 | client = openai.OpenAI(api_key=self.api_key)
50 |
51 | models = await client.models_list()
52 | # Handle both object-style and dict-style responses
53 | if hasattr(models, "data") and hasattr(models.data[0], "id"):
54 | return [model.id for model in models.data]
55 | else:
56 | return [model["id"] for model in models.data]
57 |
58 |
59 | # Now let's write tests for our application using the mock client
60 | class AsyncTestCase(unittest.TestCase):
61 | """Base class for async test cases"""
62 |
63 | def run_async(self, coro):
64 | """Run a coroutine in an event loop"""
65 | return asyncio.run(coro)
66 |
67 |
68 | class TestMyAppWithMocks(AsyncTestCase):
69 | """Test the MyApp class with mocked OpenAI client"""
70 |
71 | def setUp(self):
72 | """Set up the test"""
73 | self.service = MyAppService(api_key="test-key")
74 |
75 | def test_generate_response(self):
76 | """Test generating a response with the mock client"""
77 |
78 | @patch_with_mock()
79 | async def _async_test():
80 | response = await self.service.generate_response("Hello, how are you?")
81 |
82 | # The mock response will be predictable
83 | self.assertIsNotNone(response)
84 | self.assertIn("You asked", response) # Our mock adds this prefix
85 | return response
86 |
87 | self.run_async(_async_test())
88 |
89 | def test_get_models(self):
90 | """Test getting available models with the mock client"""
91 |
92 | @patch_with_mock()
93 | async def _async_test():
94 | models = await self.service.get_available_models()
95 |
96 | # The mock returns specific models
97 | self.assertIn("mock-gpt-3.5-turbo", models)
98 | self.assertIn("mock-gpt-4", models)
99 | return models
100 |
101 | self.run_async(_async_test())
102 |
103 | def test_with_context_manager(self):
104 | """Test using the mock client with a context manager"""
105 |
106 | async def _async_test():
107 | # Using the context manager approach
108 | mock_patch = MockPatch()
109 | with mock_patch:
110 | response = await self.service.generate_response("Tell me about testing")
111 | self.assertIsNotNone(response)
112 | self.assertIn("You asked", response)
113 | return response
114 |
115 | self.run_async(_async_test())
116 |
117 |
118 | # If we run this directly, run all the tests
119 | if __name__ == "__main__":
120 | unittest.main()
121 |
--------------------------------------------------------------------------------
/app/api/schemas/anthropic.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Literal, Optional, Union
2 | from pydantic import BaseModel, Field, field_validator
3 |
4 | from app.core.logger import get_logger
5 |
6 | logger = get_logger(name="anthropic_schemas")
7 |
8 | # Content Block Models
9 | class ContentBlockText(BaseModel):
10 | type: Literal["text"]
11 | text: str
12 |
13 |
14 | class ContentBlockImageSource(BaseModel):
15 | type: str
16 | media_type: str
17 | data: str
18 |
19 |
20 | class ContentBlockImage(BaseModel):
21 | type: Literal["image"]
22 | source: ContentBlockImageSource
23 |
24 |
25 | class ContentBlockToolUse(BaseModel):
26 | type: Literal["tool_use"]
27 | id: str
28 | name: str
29 | input: Dict[str, Any]
30 |
31 |
32 | class ContentBlockToolResult(BaseModel):
33 | type: Literal["tool_result"]
34 | tool_use_id: str
35 | content: Union[str, List[Dict[str, Any]], List[Any]]
36 | is_error: Optional[bool] = None
37 |
38 |
39 | ContentBlock = Union[
40 | ContentBlockText, ContentBlockImage, ContentBlockToolUse, ContentBlockToolResult
41 | ]
42 |
43 |
44 | # System Content
45 | class SystemContent(BaseModel):
46 | type: Literal["text"]
47 | text: str
48 |
49 |
50 | # Message Model
51 | class AnthropicMessage(BaseModel):
52 | role: Literal["user", "assistant"]
53 | content: Union[str, List[ContentBlock]]
54 |
55 |
56 | # Tool Models
57 | class Tool(BaseModel):
58 | name: str
59 | description: Optional[str] = None
60 | input_schema: Dict[str, Any] = Field(..., alias="input_schema")
61 |
62 |
63 | class ToolChoice(BaseModel):
64 | type: Literal["auto", "any", "tool"]
65 | name: Optional[str] = None
66 |
67 |
68 | # Main Request Model
69 | class AnthropicMessagesRequest(BaseModel):
70 | model: str
71 | max_tokens: int
72 | messages: List[AnthropicMessage]
73 | system: Optional[Union[str, List[SystemContent]]] = None
74 | stop_sequences: Optional[List[str]] = None
75 | stream: Optional[bool] = False
76 | temperature: Optional[float] = None
77 | top_p: Optional[float] = None
78 | top_k: Optional[int] = None
79 | metadata: Optional[Dict[str, Any]] = None
80 | tools: Optional[List[Tool]] = None
81 | tool_choice: Optional[ToolChoice] = None
82 |
83 | @field_validator("top_k")
84 | def check_top_k(cls, v: Optional[int]) -> Optional[int]:
85 | if v is not None:
86 | logger.warning(
87 | f"Parameter 'top_k' provided by client but is not directly supported by the OpenAI Chat Completions API and will be ignored. Value: {v}"
88 | )
89 | return v
90 |
91 |
92 | # Token Count Request/Response
93 | class TokenCountRequest(BaseModel):
94 | model: str
95 | messages: List[AnthropicMessage]
96 | system: Optional[Union[str, List[SystemContent]]] = None
97 | tools: Optional[List[Tool]] = None
98 |
99 |
100 | class TokenCountResponse(BaseModel):
101 | input_tokens: int
102 |
103 |
104 | # Usage Model
105 | class Usage(BaseModel):
106 | input_tokens: int
107 | output_tokens: int
108 |
109 |
110 | # Error Models
111 | class AnthropicErrorType:
112 | INVALID_REQUEST = "invalid_request_error"
113 | AUTHENTICATION = "authentication_error"
114 | PERMISSION = "permission_error"
115 | NOT_FOUND = "not_found_error"
116 | RATE_LIMIT = "rate_limit_error"
117 | API_ERROR = "api_error"
118 | OVERLOADED = "overloaded_error"
119 | REQUEST_TOO_LARGE = "request_too_large_error"
120 |
121 |
122 | class AnthropicErrorDetail(BaseModel):
123 | type: str
124 | message: str
125 | provider: Optional[str] = None
126 | provider_message: Optional[str] = None
127 | provider_code: Optional[Union[str, int]] = None
128 |
129 |
130 | class AnthropicErrorResponse(BaseModel):
131 | type: Literal["error"] = "error"
132 | error: AnthropicErrorDetail
133 |
134 |
135 | # Response Model
136 | class AnthropicMessagesResponse(BaseModel):
137 | id: str
138 | type: Literal["message"] = "message"
139 | role: Literal["assistant"] = "assistant"
140 | model: str
141 | content: List[ContentBlock]
142 | stop_reason: Optional[
143 | Literal["end_turn", "max_tokens", "stop_sequence", "tool_use", "error"]
144 | ] = None
145 | stop_sequence: Optional[str] = None
146 | usage: Usage
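An illustrative payload that validates against `AnthropicMessagesRequest` (the model id below is a placeholder, not a real mapping in this repo):

```python
from app.api.schemas.anthropic import AnthropicMessagesRequest

request = AnthropicMessagesRequest.model_validate(
    {
        "model": "claude-3-example",  # placeholder model id
        "max_tokens": 256,
        "messages": [{"role": "user", "content": "Hello"}],
        "system": "You are a helpful assistant.",
        "stream": False,
    }
)
print(request.model, request.max_tokens)
```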
--------------------------------------------------------------------------------
/app/api/routes/stats.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 | from typing import Any
3 |
4 | from fastapi import APIRouter, Depends, HTTPException, Query
5 | from sqlalchemy.ext.asyncio import AsyncSession
6 |
7 | from app.api.dependencies import (
8 | get_current_active_user_from_clerk,
9 | get_current_user,
10 | get_async_db,
11 | get_user_by_api_key,
12 | )
13 | from app.models.user import User
14 | from app.services.usage_stats_service import UsageStatsService
15 |
16 | router = APIRouter()
17 |
18 |
19 | # http://localhost:8000/stats/?start_date=2025-04-09&end_date=2025-04-09
20 | # http://localhost:8000/stats/?model=gpt-3.5-turbo
21 | @router.get("/", response_model=list[dict[str, Any]])
22 | async def get_user_stats(
23 | current_user: User = Depends(get_user_by_api_key),
24 | provider: str | None = Query(None, description="Filter stats by provider name"),
25 | model: str | None = Query(None, description="Filter stats by model name"),
26 | start_date: date | None = Query(
27 | None, description="Start date for filtering (YYYY-MM-DD)"
28 | ),
29 | end_date: date | None = Query(
30 | None, description="End date for filtering (YYYY-MM-DD)"
31 | ),
32 | db: AsyncSession = Depends(get_async_db),
33 | ):
34 | """
35 | Get aggregated usage statistics for the current user, queried from request logs.
36 |
37 | Allows filtering by provider, model, and date range (inclusive).
38 | """
39 | # Note: Service layer now handles aggregation and filtering
40 | # We pass the query parameters directly to the service method
41 | stats = await UsageStatsService.get_user_stats(
42 | db=db,
43 | user_id=current_user.id,
44 | provider=provider,
45 | model=model,
46 | start_date=start_date,
47 | end_date=end_date,
48 | )
49 | return stats
50 |
51 |
52 | # http://localhost:8000/stats/clerk/?start_date=2025-04-09&end_date=2025-04-09
53 | # http://localhost:8000/stats/clerk/?model=gpt-3.5-turbo
54 | @router.get("/clerk", response_model=list[dict[str, Any]])
55 | async def get_user_stats_clerk(
56 | current_user: User = Depends(get_current_active_user_from_clerk),
57 | provider: str | None = Query(None, description="Filter stats by provider name"),
58 | model: str | None = Query(None, description="Filter stats by model name"),
59 | start_date: date | None = Query(
60 | None, description="Start date for filtering (YYYY-MM-DD)"
61 | ),
62 | end_date: date | None = Query(
63 | None, description="End date for filtering (YYYY-MM-DD)"
64 | ),
65 | db: AsyncSession = Depends(get_async_db),
66 | ):
67 | """
68 | Get aggregated usage statistics for the current user, queried from request logs.
69 |
70 | Allows filtering by provider, model, and date range (inclusive).
71 | """
72 | # Note: Service layer now handles aggregation and filtering
73 | # We pass the query parameters directly to the service method
74 | stats = await UsageStatsService.get_user_stats(
75 | db=db,
76 | user_id=current_user.id,
77 | provider=provider,
78 | model=model,
79 | start_date=start_date,
80 | end_date=end_date,
81 | )
82 | return stats
83 |
84 |
85 | @router.get("/admin", response_model=list[dict[str, Any]])
86 | async def get_all_stats(
87 | current_user: User = Depends(get_current_user),
88 | provider: str | None = Query(None, description="Filter stats by provider name"),
89 | model: str | None = Query(None, description="Filter stats by model name"),
90 | start_date: date | None = Query(
91 | None, description="Start date for filtering (YYYY-MM-DD)"
92 | ),
93 | end_date: date | None = Query(
94 | None, description="End date for filtering (YYYY-MM-DD)"
95 | ),
96 | db: AsyncSession = Depends(get_async_db),
97 | ):
98 | """
99 | Get aggregated usage statistics for all users, queried from request logs.
100 |
101 | Only accessible to admin users. Allows filtering.
102 | """
103 | # Check if user is an admin
104 | if not getattr(current_user, "is_admin", False):
105 | raise HTTPException(
106 | status_code=403, detail="Not authorized to access admin statistics"
107 | )
108 |
109 | stats = await UsageStatsService.get_all_stats(
110 | db=db, provider=provider, model=model, start_date=start_date, end_date=end_date
111 | )
112 | return stats
113 |
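A small sketch of querying the user stats endpoint with a Forge API key, mirroring the example URLs in the comments above (the `X-API-Key` header matches what the test scripts in this repo send):

```python
import requests

headers = {"X-API-Key": "your-forge-api-key"}
params = {"model": "gpt-3.5-turbo", "start_date": "2025-04-09", "end_date": "2025-04-09"}

resp = requests.get("http://localhost:8000/stats/", headers=headers, params=params)
resp.raise_for_status()
for row in resp.json():
    print(row)
```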
--------------------------------------------------------------------------------
/app/exceptions/exceptions.py:
--------------------------------------------------------------------------------
1 | class BaseForgeException(Exception):
2 | pass
3 |
4 | class InvalidProviderException(BaseForgeException):
5 | """Exception raised when a provider is invalid."""
6 |
7 | def __init__(self, identifier: str):
8 | self.identifier = identifier
9 | super().__init__(f"Provider {identifier} is invalid or failed to extract provider info from model_id {identifier}")
10 |
11 |
12 | class ProviderAuthenticationException(BaseForgeException):
13 | """Exception raised when a provider authentication fails."""
14 |
15 | def __init__(self, provider_name: str, error: Exception):
16 | self.provider_name = provider_name
17 | self.error = error
18 | super().__init__(f"Provider {provider_name} authentication failed: {error}")
19 |
20 |
21 | class BaseInvalidProviderSetupException(BaseForgeException):
22 | """Exception raised when a provider setup is invalid."""
23 |
24 | def __init__(self, provider_name: str, error: Exception):
25 | self.provider_name = provider_name
26 | self.error = error
27 | super().__init__(f"Provider {provider_name} setup is invalid: {error}")
28 |
29 | class InvalidProviderConfigException(BaseInvalidProviderSetupException):
30 | """Exception raised when a provider config is invalid."""
31 |
32 | def __init__(self, provider_name: str, error: Exception):
33 | super().__init__(provider_name, error)
34 |
35 | class InvalidProviderAPIKeyException(BaseInvalidProviderSetupException):
36 | """Exception raised when a provider API key is invalid."""
37 |
38 | def __init__(self, provider_name: str, error: Exception):
39 | super().__init__(provider_name, error)
40 |
41 | class ProviderAPIException(BaseForgeException):
42 | """Exception raised when a provider API error occurs."""
43 |
44 | def __init__(self, provider_name: str, error_code: int, error_message: str):
45 | """Initialize the exception and persist error details for downstream handling.
46 |
47 | Many parts of the codebase (e.g. the Claude Code routes) rely on the
48 | presence of the ``error_code`` and ``error_message`` attributes to
49 | construct a well-formed error response. Without setting these instance
50 | attributes an ``AttributeError`` is raised when the exception is caught
51 | and introspected, masking the original provider failure. Persisting the
52 | values here guarantees that the original error information is available
53 | to any error-handling middleware.
54 | """
55 | self.provider_name = provider_name
56 | self.error_code = error_code
57 | self.error_message = error_message
58 |
59 | # Compose the base exception message for logging / debugging purposes.
60 | super().__init__(
61 | f"Provider {provider_name} API error: {error_code} {error_message}"
62 | )
63 |
64 |
65 | class BaseInvalidRequestException(BaseForgeException):
66 | """Exception raised when a request is invalid."""
67 |
68 | def __init__(self, provider_name: str, error: Exception):
69 | self.provider_name = provider_name
70 | self.error = error
71 | super().__init__(f"Provider {provider_name} request is invalid: {error}")
72 |
73 | class InvalidCompletionRequestException(BaseInvalidRequestException):
74 | """Exception raised when a completion request is invalid."""
75 |
76 | def __init__(self, provider_name: str, error: Exception):
77 | self.provider_name = provider_name
78 | self.error = error
79 | super().__init__(self.provider_name, self.error)
80 |
81 | class InvalidEmbeddingsRequestException(BaseInvalidRequestException):
82 | """Exception raised when a embeddings request is invalid."""
83 |
84 | def __init__(self, provider_name: str, error: Exception):
85 | self.provider_name = provider_name
86 | self.error = error
87 | super().__init__(self.provider_name, self.error)
88 |
89 | class BaseInvalidForgeKeyException(BaseForgeException):
90 | """Exception raised when a Forge key is invalid."""
91 |
92 | def __init__(self, error: Exception):
93 | self.error = error
94 | super().__init__(f"Forge key is invalid: {error}")
95 |
96 |
97 | class InvalidForgeKeyException(BaseInvalidForgeKeyException):
98 | """Exception raised when a Forge key is invalid."""
99 | def __init__(self, error: Exception):
100 | super().__init__(error)
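As a minimal illustration (not the repo's actual middleware), an error handler can rely on the attributes persisted by `ProviderAPIException` to surface the original provider failure:

```python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from app.exceptions.exceptions import ProviderAPIException

app = FastAPI()


@app.exception_handler(ProviderAPIException)
async def provider_api_exception_handler(request: Request, exc: ProviderAPIException):
    # error_code / error_message / provider_name are set in the constructor above.
    return JSONResponse(
        status_code=exc.error_code,
        content={
            "error": {
                "provider": exc.provider_name,
                "message": exc.error_message,
            }
        },
    )
```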
--------------------------------------------------------------------------------
/app/api/routes/health.py:
--------------------------------------------------------------------------------
1 | """
2 | Health check and monitoring endpoints for production deployments.
3 | """
4 | from datetime import datetime, timezone
5 | from fastapi import APIRouter, HTTPException
6 | from sqlalchemy import text
7 |
8 | from app.core.database import get_connection_info, get_db_session
9 | from app.core.logger import get_logger
10 |
11 | logger = get_logger(name="health")
12 | router = APIRouter()
13 |
14 |
15 | @router.get("/health")
16 | async def health_check():
17 | """
18 | Basic health check endpoint.
19 | Returns 200 if the service is running.
20 | """
21 | return {"status": "healthy", "service": "forge"}
22 |
23 |
24 | @router.get("/health/database")
25 | async def database_health_check():
26 | """
27 | Database health check endpoint.
28 | Returns detailed information about database connectivity and pool status.
29 | """
30 | try:
31 | # Test database connection
32 | async with get_db_session() as session:
33 | result = await session.execute(text("SELECT 1"))
34 | result.scalar()
35 |
36 | # Get connection pool information
37 | pool_info = get_connection_info()
38 |
39 | # Calculate connection usage
40 | sync_pool = pool_info['sync_engine']
41 | async_pool = pool_info['async_engine']
42 |
43 | sync_usage = sync_pool['checked_out'] / (pool_info['pool_size'] + pool_info['max_overflow']) * 100
44 | async_usage = async_pool['checked_out'] / (pool_info['pool_size'] + pool_info['max_overflow']) * 100
45 |
46 | return {
47 | "status": "healthy",
48 | "database": "connected",
49 | "connection_pools": {
50 | "sync": {
51 | "checked_out": sync_pool['checked_out'],
52 | "checked_in": sync_pool['checked_in'],
53 | "size": sync_pool['size'],
54 | "usage_percent": round(sync_usage, 1)
55 | },
56 | "async": {
57 | "checked_out": async_pool['checked_out'],
58 | "checked_in": async_pool['checked_in'],
59 | "size": async_pool['size'],
60 | "usage_percent": round(async_usage, 1)
61 | }
62 | },
63 | "configuration": {
64 | "pool_size": pool_info['pool_size'],
65 | "max_overflow": pool_info['max_overflow'],
66 | "pool_timeout": pool_info['pool_timeout'],
67 | "pool_recycle": pool_info['pool_recycle']
68 | }
69 | }
70 |
71 | except Exception as e:
72 | logger.error(f"Database health check failed: {e}")
73 | raise HTTPException(
74 | status_code=503,
75 | detail={
76 | "status": "unhealthy",
77 | "database": "disconnected",
78 | "error": str(e)
79 | }
80 | )
81 |
82 |
83 | @router.get("/health/detailed")
84 | async def detailed_health_check():
85 | """
86 | Detailed health check including all service components.
87 | """
88 | try:
89 | # Test database
90 | async with get_db_session() as session:
91 | db_result = await session.execute(text("SELECT version()"))
92 | db_version = db_result.scalar()
93 |
94 | pool_info = get_connection_info()
95 |
96 | return {
97 | "status": "healthy",
98 | "timestamp": "2025-01-21T19:15:00Z", # This would be dynamic in real implementation
99 | "service": "forge",
100 | "version": "0.1.0",
101 | "database": {
102 | "status": "connected",
103 | "version": db_version,
104 | "pool_status": pool_info
105 | },
106 | "environment": {
107 | "workers": pool_info.get('workers', 'unknown'),
108 | "pool_size": pool_info['pool_size'],
109 | "max_overflow": pool_info['max_overflow']
110 | }
111 | }
112 |
113 | except Exception as e:
114 | logger.error(f"Detailed health check failed: {e}")
115 | raise HTTPException(
116 | status_code=503,
117 | detail={
118 | "status": "unhealthy",
119 | "error": str(e),
120 | "timestamp": "2025-01-21T19:15:00Z"
121 | }
122 | )
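For reference, a simple polling sketch (assuming the service listens on `http://localhost:8000` and the router is mounted at the root path):

```python
import requests

base = "http://localhost:8000"

print(requests.get(f"{base}/health").json())

resp = requests.get(f"{base}/health/database")
if resp.status_code == 200:
    pools = resp.json()["connection_pools"]
    print("async pool usage:", pools["async"]["usage_percent"], "%")
else:
    print("database unhealthy:", resp.json())
```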
--------------------------------------------------------------------------------
/alembic/versions/b38aad374524_create_api_request_log_table.py:
--------------------------------------------------------------------------------
1 | """Create api_request_log table
2 |
3 | Revision ID: b38aad374524
4 | Revises: c50fd7be794c
5 | Create Date: 2025-04-08 21:44:00.759874
6 |
7 | """
8 |
9 | import sqlalchemy as sa
10 |
11 | from alembic import op
12 |
13 | # revision identifiers, used by Alembic.
14 | revision = "b38aad374524"
15 | down_revision = "c50fd7be794c"
16 | branch_labels = None
17 | depends_on = None
18 |
19 |
20 | def upgrade() -> None:
21 | # ### commands auto generated by Alembic - START ###
22 | op.create_table(
23 | "api_request_log",
24 | sa.Column("id", sa.Integer(), nullable=False),
25 | sa.Column("user_id", sa.Integer(), nullable=True),
26 | sa.Column("provider_name", sa.String(), nullable=False),
27 | sa.Column("model", sa.String(), nullable=False),
28 | sa.Column("endpoint", sa.String(), nullable=False),
29 | sa.Column("request_timestamp", sa.DateTime(), nullable=False),
30 | sa.Column("input_tokens", sa.Integer(), nullable=True),
31 | sa.Column("output_tokens", sa.Integer(), nullable=True),
32 | sa.Column("total_tokens", sa.Integer(), nullable=True),
33 | sa.Column("cost", sa.Float(), nullable=True),
34 | sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="SET NULL"),
35 | sa.PrimaryKeyConstraint("id"),
36 | )
37 | op.create_index(
38 | "ix_api_request_log_user_time",
39 | "api_request_log",
40 | ["user_id", "request_timestamp"],
41 | unique=False,
42 | )
43 | op.create_index(
44 | op.f("ix_api_request_log_endpoint"),
45 | "api_request_log",
46 | ["endpoint"],
47 | unique=False,
48 | )
49 | op.create_index(
50 | op.f("ix_api_request_log_model"), "api_request_log", ["model"], unique=False
51 | )
52 | op.create_index(
53 | op.f("ix_api_request_log_provider_name"),
54 | "api_request_log",
55 | ["provider_name"],
56 | unique=False,
57 | )
58 | op.create_index(
59 | op.f("ix_api_request_log_request_timestamp"),
60 | "api_request_log",
61 | ["request_timestamp"],
62 | unique=False,
63 | )
64 | op.create_index(
65 | op.f("ix_api_request_log_user_id"), "api_request_log", ["user_id"], unique=False
66 | )
67 | # op.drop_constraint(None, 'usage_stats', type_='foreignkey') # REMOVED
68 | # op.create_foreign_key(None, 'usage_stats', 'users', ['user_id'], ['id'], ondelete='CASCADE') # REMOVED
69 | # op.drop_column('usage_stats', 'completion_tokens') # REMOVED
70 | # op.drop_column('usage_stats', 'timestamp') # REMOVED
71 | # op.drop_column('usage_stats', 'error_count') # REMOVED
72 | # op.drop_column('usage_stats', 'prompt_tokens') # REMOVED
73 | # op.drop_column('usage_stats', 'success_count') # REMOVED
74 | # op.drop_column('usage_stats', 'request_count') # REMOVED
75 | # ### end Alembic commands ###
76 |
77 |
78 | def downgrade() -> None:
79 | # ### commands auto generated by Alembic - START ###
80 | # op.add_column('usage_stats', sa.Column('request_count', sa.INTEGER(), nullable=True)) # REMOVED
81 | # op.add_column('usage_stats', sa.Column('success_count', sa.INTEGER(), nullable=True)) # REMOVED
82 | # op.add_column('usage_stats', sa.Column('prompt_tokens', sa.INTEGER(), nullable=True)) # REMOVED
83 | # op.add_column('usage_stats', sa.Column('error_count', sa.INTEGER(), nullable=True)) # REMOVED
84 | # op.add_column('usage_stats', sa.Column('timestamp', sa.DATETIME(), nullable=True)) # REMOVED
85 | # op.add_column('usage_stats', sa.Column('completion_tokens', sa.INTEGER(), nullable=True)) # REMOVED
86 | # op.drop_constraint(None, 'usage_stats', type_='foreignkey') # REMOVED
87 | # op.create_foreign_key(None, 'usage_stats', 'users', ['user_id'], ['id']) # REMOVED
88 | op.drop_index(op.f("ix_api_request_log_user_id"), table_name="api_request_log")
89 | op.drop_index(
90 | op.f("ix_api_request_log_request_timestamp"), table_name="api_request_log"
91 | )
92 | op.drop_index(
93 | op.f("ix_api_request_log_provider_name"), table_name="api_request_log"
94 | )
95 | op.drop_index(op.f("ix_api_request_log_model"), table_name="api_request_log")
96 | op.drop_index(op.f("ix_api_request_log_endpoint"), table_name="api_request_log")
97 | op.drop_index("ix_api_request_log_user_time", table_name="api_request_log")
98 | op.drop_table("api_request_log")
99 | # ### end Alembic commands ###
100 |
--------------------------------------------------------------------------------
/app/core/database.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dotenv import load_dotenv
4 | from contextlib import asynccontextmanager
5 | from sqlalchemy import create_engine
6 | from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
7 | from sqlalchemy.orm import declarative_base, sessionmaker
8 |
9 | load_dotenv()
10 |
11 | # Production-optimized connection pool settings
12 | # With 10 Gunicorn workers, this allows max 60 connections total (10 workers × 3 pool_size × 2 engines)
13 | # Plus 40 overflow connections (10 workers × 2 max_overflow × 2 engines) = 100 max connections
14 | POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "3")) # Reduced from 5 to 3
15 | MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "2")) # Reduced from 10 to 2
16 | MAX_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30"))
17 | POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "1800")) # 30 minutes
18 | POOL_PRE_PING = os.getenv("DB_POOL_PRE_PING", "true").lower() == "true"
19 |
20 | SQLALCHEMY_DATABASE_URL = os.getenv("DATABASE_URL")
21 | if not SQLALCHEMY_DATABASE_URL:
22 | raise ValueError("DATABASE_URL environment variable is not set")
23 |
24 | # Sync engine and session
25 | engine = create_engine(
26 | SQLALCHEMY_DATABASE_URL,
27 | pool_size=POOL_SIZE,
28 | max_overflow=MAX_OVERFLOW,
29 | pool_timeout=MAX_TIMEOUT,
30 | pool_recycle=POOL_RECYCLE,
31 | pool_pre_ping=POOL_PRE_PING, # Enables connection health checks
32 | echo=False,
33 | )
34 | SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
35 |
36 |
37 | # Sync dependency
38 | def get_db():
39 | db = SessionLocal()
40 | try:
41 | yield db
42 | finally:
43 | db.close()
44 |
45 |
46 | # Async engine and session (new)
47 | # Convert the DATABASE_URL to async format if it's using psycopg2
48 | ASYNC_DATABASE_URL = SQLALCHEMY_DATABASE_URL
49 | if SQLALCHEMY_DATABASE_URL.startswith("postgresql://"):
50 | ASYNC_DATABASE_URL = SQLALCHEMY_DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://")
51 | elif SQLALCHEMY_DATABASE_URL.startswith("postgresql+psycopg2://"):
52 | ASYNC_DATABASE_URL = SQLALCHEMY_DATABASE_URL.replace("postgresql+psycopg2://", "postgresql+asyncpg://")
53 |
54 | async_engine = create_async_engine(
55 | ASYNC_DATABASE_URL,
56 | pool_size=POOL_SIZE,
57 | max_overflow=MAX_OVERFLOW,
58 | pool_timeout=MAX_TIMEOUT,
59 | pool_recycle=POOL_RECYCLE,
60 | pool_pre_ping=POOL_PRE_PING, # Enables connection health checks
61 | echo=False,
62 | )
63 |
64 | AsyncSessionLocal = async_sessionmaker(
65 | bind=async_engine,
66 | class_=AsyncSession,
67 | expire_on_commit=False,
68 | autoflush=False,
69 | )
70 |
71 | Base = declarative_base()
72 |
73 |
74 | # Async dependency
75 | async def get_async_db():
76 | async with AsyncSessionLocal() as session:
77 | try:
78 | yield session
79 | except Exception:
80 | # Rollback on any exception, but handle potential session state issues
81 | try:
82 | await session.rollback()
83 | except Exception:
84 | # If rollback fails (e.g., session already closed), ignore it
85 | # The context manager will handle session cleanup
86 | pass
87 | raise
88 |
89 |
90 | @asynccontextmanager
91 | async def get_db_session():
92 | """Async context manager for database sessions"""
93 | async with AsyncSessionLocal() as session:
94 | try:
95 | yield session
96 | except Exception:
97 | # Rollback on any exception, but handle potential session state issues
98 | try:
99 | await session.rollback()
100 | except Exception:
101 | # If rollback fails (e.g., session already closed), ignore it
102 | # The context manager will handle session cleanup
103 | pass
104 | raise
105 |
106 |
107 | def get_connection_info():
108 | """Get current connection pool information for monitoring"""
109 | return {
110 | "pool_size": POOL_SIZE,
111 | "max_overflow": MAX_OVERFLOW,
112 | "pool_timeout": MAX_TIMEOUT,
113 | "pool_recycle": POOL_RECYCLE,
114 | "sync_engine": {
115 | "pool": engine.pool,
116 | "checked_out": engine.pool.checkedout(),
117 | "checked_in": engine.pool.checkedin(),
118 | "size": engine.pool.size(),
119 | },
120 | "async_engine": {
121 | "pool": async_engine.pool,
122 | "checked_out": async_engine.pool.checkedout(),
123 | "checked_in": async_engine.pool.checkedin(),
124 | "size": async_engine.pool.size(),
125 | }
126 | }
--------------------------------------------------------------------------------
/tests/test_token_counting_stream.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import asyncio
4 | import json
5 | from http import HTTPStatus
6 |
7 | import aiohttp
8 |
9 |
10 | async def test_token_counting_stream(api_key: str, model: str):
11 | """
12 | Test the token counting in streaming mode.
13 | """
14 | url = "http://localhost:8000/v1/chat/completions"
15 | headers = {"Content-Type": "application/json", "X-API-Key": api_key}
16 |
17 | payload = {
18 | "model": model,
19 | "messages": [
20 | {"role": "system", "content": "You are a helpful assistant."},
21 | {
22 | "role": "user",
23 | "content": "Write a short story about a robot learning to count. Keep it to 3 paragraphs.",
24 | },
25 | ],
26 | "stream": True,
27 | }
28 |
29 | print(f"\n[INFO] Sending streaming request to {url}")
30 | print(f"[INFO] Model: {model}")
31 |
32 | tokens_seen = 0
33 | chunk_count = 0
34 |
35 | # Use single with statement for session and post request
36 | async with (
37 | aiohttp.ClientSession() as session,
38 | session.post(url, headers=headers, json=payload) as response,
39 | ):
40 | print(f"[INFO] Response status: {response.status}")
41 |
42 | if response.status != HTTPStatus.OK:
43 | error_text = await response.text()
44 | print(f"[ERROR] {error_text}")
45 | return
46 |
47 | async for line_bytes in response.content:
48 | line = line_bytes.decode("utf-8").strip()
49 |
50 | if not line:
51 | continue
52 |
53 | if line.startswith("data:"):
54 | chunk_count += 1
55 | data_str = line[5:].strip()
56 |
57 | if data_str == "[DONE]":
58 | print("\n[INFO] Streaming completed")
59 | print(f"[INFO] Total chunks: {chunk_count}")
60 | break
61 |
62 | try:
63 | data = json.loads(data_str)
64 |
65 | # Extract content from the chunk
66 | content = ""
67 | if (
68 | "choices" in data
69 | and data["choices"]
70 | and "delta" in data["choices"][0]
71 | ):
72 | content = data["choices"][0]["delta"].get("content", "")
73 |
74 | # Keep track of content length as a proxy for tokens
75 | tokens_seen += len(content) / 4 # Simple approximation
76 |
77 | # Check if we have usage information
78 | if "usage" in data:
79 | usage = data.get("usage", {})
80 | print(f"\n[TOKEN INFO] Reported usage: {usage}")
81 |
82 | if chunk_count % 10 == 1:
83 | print(f"\n[CHUNK {chunk_count}] Content: '{content}'")
84 | print(
85 | f"[TOKEN ESTIMATE] Approximately {int(tokens_seen)} tokens so far"
86 | )
87 |
88 | except json.JSONDecodeError:
89 | print(f"[WARNING] Could not parse: {data_str}")
90 |
91 | print("\n[SUMMARY]")
92 | print(f"Total chunks received: {chunk_count}")
93 | print(f"Estimated token count: {int(tokens_seen)}")
94 |
95 | # After streaming completes, check the usage statistics
96 | print("\n[INFO] Checking usage statistics...")
97 | url = "http://localhost:8000/stats/"
98 |
99 | # Use single with statement for session and get request
100 | async with (
101 | aiohttp.ClientSession() as session,
102 | session.get(url, headers=headers) as response,
103 | ):
104 | if response.status == HTTPStatus.OK:
105 | usage_data = await response.json()
106 | print("\n[USAGE STATS]")
107 | print(json.dumps(usage_data, indent=2))
108 | else:
109 | print(f"[ERROR] Failed to get usage stats: {response.status}")
110 | error_text = await response.text()
111 | print(error_text)
112 |
113 |
114 | def main():
115 | parser = argparse.ArgumentParser(
116 | description="Test token counting in streaming mode"
117 | )
118 | parser.add_argument("--api-key", required=True, help="API key for authentication")
119 | parser.add_argument(
120 | "--model", default="gpt-3.5-turbo", help="Model to use for the test"
121 | )
122 |
123 | args = parser.parse_args()
124 |
125 | asyncio.run(test_token_counting_stream(args.api_key, args.model))
126 |
127 |
128 | if __name__ == "__main__":
129 | main()
130 |
--------------------------------------------------------------------------------
/tests/frontend_simulation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | from http import HTTPStatus
5 |
6 | import requests
7 | from dotenv import dotenv_values, load_dotenv
8 |
9 | # Load directly from .env file, bypassing system environment variables
10 | env_values = dotenv_values(".env")
11 | FORGE_API_KEY = env_values.get("FORGE_API_KEY", "")
12 |
13 | # Fall back to standard dotenv loading if not found directly
14 | if not FORGE_API_KEY:
15 | load_dotenv()
16 | FORGE_API_KEY = os.getenv("FORGE_API_KEY", "")
17 |
18 | # Configuration
19 | FORGE_API_URL = env_values.get(
20 | "FORGE_API_URL", os.getenv("FORGE_API_URL", "http://localhost:8000")
21 | )
22 |
23 | if not FORGE_API_KEY:
24 | print("Error: FORGE_API_KEY environment variable is not set.")
25 | print("Please set it to your Forge API key and try again.")
26 | sys.exit(1)
27 |
28 |
29 | def chat_completion(messages, model="mock-only-gpt-3.5-turbo"):
30 | """Send a chat completion request to Forge"""
31 | url = f"{FORGE_API_URL}/chat/completions"
32 | headers = {"X-API-KEY": FORGE_API_KEY}
33 | data = {"model": model, "messages": messages, "temperature": 0.7}
34 |
35 | try:
36 | start_time = time.time()
37 | response = requests.post(url, json=data, headers=headers)
38 | end_time = time.time()
39 |
40 | if response.status_code == HTTPStatus.OK:
41 | result = response.json()
42 | print(f"Request completed in {end_time - start_time:.2f} seconds")
43 | return {
44 | "success": True,
45 | "response": result["choices"][0]["message"]["content"],
46 | "model": result["model"],
47 | "response_time": end_time - start_time,
48 | }
49 | else:
50 | return {
51 | "success": False,
52 | "error": f"API error: {response.status_code}",
53 | "details": response.text,
54 | }
55 | except Exception as e:
56 | return {"success": False, "error": "Request failed", "details": str(e)}
57 |
58 |
59 | def simulate_conversation():
60 | """Simulate a conversation using the Forge API"""
61 | messages = []
62 |
63 | # Initial system message
64 | system_message = {
65 | "role": "system",
66 | "content": "You are a helpful and friendly AI assistant.",
67 | }
68 | messages.append(system_message)
69 |
70 | # First user message
71 | first_user_message = {
72 | "role": "user",
73 | "content": "Hello! Can you tell me what Forge is?",
74 | }
75 | messages.append(first_user_message)
76 |
77 | print("\n--- User: Hello! Can you tell me what Forge is?")
78 | result = chat_completion(messages)
79 |
80 | if result["success"]:
81 | print(f"\n--- AI ({result['model']}): {result['response']}")
82 | messages.append({"role": "assistant", "content": result["response"]})
83 | else:
84 | print(f"\nError: {result['error']}")
85 | print(f"Details: {result.get('details', 'No details available')}")
86 | return
87 |
88 | # Second user message
89 | second_user_message = {
90 | "role": "user",
91 | "content": "How can I use Forge with other frontend applications?",
92 | }
93 | messages.append(second_user_message)
94 |
95 | print("\n--- User: How can I use Forge with other frontend applications?")
96 | result = chat_completion(messages)
97 |
98 | if result["success"]:
99 | print(f"\n--- AI ({result['model']}): {result['response']}")
100 | messages.append({"role": "assistant", "content": result["response"]})
101 | else:
102 | print(f"\nError: {result['error']}")
103 | print(f"Details: {result.get('details', 'No details available')}")
104 | return
105 |
106 | # Try a different model if available
107 | third_user_message = {
108 | "role": "user",
109 | "content": "Can you summarize our conversation so far?",
110 | }
111 | messages.append(third_user_message)
112 |
113 | print("\n--- User: Can you summarize our conversation so far?")
114 |
115 | # Try with a different model if available
116 | alternative_model = "mock-only-gpt-4" # Will fallback if not available
117 | result = chat_completion(messages, model=alternative_model)
118 |
119 | if result["success"]:
120 | print(f"\n--- AI ({result['model']}): {result['response']}")
121 | else:
122 | print(f"\nError: {result['error']}")
123 | print(f"Details: {result.get('details', 'No details available')}")
124 |
125 |
126 | if __name__ == "__main__":
127 | print("🔄 Simulating a frontend application using Forge API")
128 | print(f"🔌 Connecting to Forge at {FORGE_API_URL}")
129 | print(f"🔑 Using Forge API Key: {FORGE_API_KEY[:8]}...")
130 |
131 | simulate_conversation()
132 |
--------------------------------------------------------------------------------
/alembic/versions/c9f3e548adef_add_lambda_model_pricing.py:
--------------------------------------------------------------------------------
1 | """add lambda model pricing
2 |
3 | Revision ID: c9f3e548adef
4 | Revises: 39bcedfae4fe
5 | Create Date: 2025-08-25 19:53:57.606298
6 |
7 | """
8 | from csv import DictReader
9 | from decimal import Decimal
10 | import os
11 | from alembic import op
12 | import sqlalchemy as sa
13 | from datetime import datetime, timedelta, UTC
14 |
15 |
16 | # revision identifiers, used by Alembic.
17 | revision = 'c9f3e548adef'
18 | down_revision = '39bcedfae4fe'
19 | branch_labels = None
20 | depends_on = None
21 |
22 |
23 | def upgrade() -> None:
24 | # insert the lambda data
25 | effective_date = datetime.now(UTC) - timedelta(days=1)
26 | csv_path = os.path.join(os.path.dirname(__file__), "..", "..", "tools", "data", "lambda_model_pricing_init.csv")
27 | with open(csv_path, "r") as f:
28 | reader = DictReader(f)
29 | rows_to_insert = []
30 | for row in reader:
31 | rows_to_insert.append({
32 | "provider_name": row["provider_name"],
33 | "model_name": row["model_name"],
34 | "effective_date": effective_date,
35 | "input_token_price": Decimal(str(row["input_token_price"])).normalize(),
36 | "output_token_price": Decimal(str(row["output_token_price"])).normalize(),
37 | "price_source": "manual"
38 | })
39 |
40 | if rows_to_insert:
41 | connection = op.get_bind()
42 | connection.execute(
43 | sa.text("""
44 | INSERT INTO model_pricing (provider_name, model_name, effective_date, input_token_price, output_token_price, cached_token_price, price_source)
45 | VALUES (:provider_name, :model_name, :effective_date, :input_token_price, :output_token_price, :input_token_price, 'manual')
46 | """),
47 | rows_to_insert,
48 | )
49 | connection.execute(
50 | sa.text("""
51 | INSERT INTO fallback_pricing (provider_name, model_name, effective_date, input_token_price, output_token_price, cached_token_price, fallback_type)
52 | VALUES (:provider_name, :model_name, :effective_date, :input_token_price, :output_token_price, :input_token_price, 'model_default')
53 | """),
54 | rows_to_insert,
55 | )
56 |
57 | # Fix the cached_token_price for all the other models
58 | csv_path = os.path.join(os.path.dirname(__file__), "..", "..", "tools", "data", "model_pricing_init.csv")
59 | with open(csv_path, "r") as f:
60 | reader = DictReader(f)
61 | rows_to_update = []
62 | for row in reader:
63 | rows_to_update.append({
64 | "provider_name": row["provider_name"],
65 | "model_name": row["model_name"],
66 | "input_token_price": Decimal(str(row["input_token_price"])).normalize(),
67 | "output_token_price": Decimal(str(row["output_token_price"])).normalize(),
68 | "cached_token_price": Decimal(str(row["cached_token_price"])).normalize(),
69 | })
70 |
71 | if rows_to_update:
72 | connection = op.get_bind()
73 | connection.execute(
74 | sa.text("""
75 | update model_pricing set cached_token_price = :cached_token_price
76 | where provider_name = :provider_name and model_name = :model_name
77 | """),
78 | rows_to_update,
79 | )
80 | connection.execute(
81 | sa.text("""
82 | update fallback_pricing set cached_token_price = :cached_token_price
83 | where provider_name = :provider_name and model_name = :model_name
84 | """),
85 | rows_to_update,
86 | )
87 |
88 | # backfill the cached_token_price for all the other models
89 | connection = op.get_bind()
90 | connection.execute(
91 | sa.text("""
92 | with updated_model_pricing as (
93 | update model_pricing set cached_token_price = input_token_price
94 | where cached_token_price = 0
95 | )
96 | update fallback_pricing set cached_token_price = input_token_price
97 | where cached_token_price = 0
98 | """),
99 | rows_to_insert,
100 | )
101 |
102 | # remove the default value for cached_token_price
103 | op.alter_column('model_pricing', 'cached_token_price', server_default=None)
104 | op.alter_column('fallback_pricing', 'cached_token_price', server_default=None)
105 |
106 |
107 | def downgrade() -> None:
108 | connection = op.get_bind()
109 | connection.execute(
110 | sa.text("""
111 | DELETE FROM model_pricing WHERE provider_name = 'lambda'
112 | """),
113 | )
114 | connection.execute(
115 | sa.text("""
116 | DELETE FROM fallback_pricing WHERE provider_name = 'lambda'
117 | """),
118 | )
119 |
--------------------------------------------------------------------------------