├── .github └── workflows │ ├── catalogs.yml │ ├── release.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── pyproject.toml └── src ├── gpuhunt ├── __init__.py ├── __main__.py ├── _internal │ ├── __init__.py │ ├── catalog.py │ ├── constraints.py │ ├── default.py │ ├── models.py │ ├── storage.py │ └── utils.py ├── providers │ ├── __init__.py │ ├── aws.py │ ├── azure.py │ ├── cloudrift.py │ ├── cudo.py │ ├── datacrunch.py │ ├── gcp.py │ ├── lambdalabs.py │ ├── nebius.py │ ├── oci.py │ ├── runpod.py │ ├── tensordock.py │ ├── vastai.py │ └── vultr.py ├── resources │ ├── __init__.py │ └── tpu_pricing.json ├── scripts │ ├── __init__.py │ └── catalog_v1 │ │ ├── __init__.py │ │ └── __main__.py └── version.py ├── integrity_tests ├── __init__.py ├── conftest.py ├── test_all.py ├── test_aws.py ├── test_azure.py ├── test_cloudrift.py ├── test_datacrunch.py ├── test_gcp.py ├── test_lambdalabs.py ├── test_nebius.py ├── test_oci.py ├── test_runpod.py └── test_vastai.py └── tests ├── __init__.py ├── _internal ├── __init__.py ├── test_catalog.py ├── test_constraints.py ├── test_models.py └── test_utils.py ├── providers ├── __init__.py ├── test_cudo.py ├── test_datacrunch.py ├── test_oci.py ├── test_providers.py ├── test_tensordock.py └── test_vultr.py └── scripts ├── __init__.py └── test_catalog_v1.py /.github/workflows/catalogs.yml: -------------------------------------------------------------------------------- 1 | name: Collect and publish catalogs 2 | run-name: Collect and publish catalogs${{ inputs.staging && ' (staging)' || '' }} 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | staging: 7 | description: Staging 8 | type: boolean 9 | default: true 10 | schedule: 11 | - cron: '5 * * * *' # Run every hour at HH:05 12 | 13 | env: 14 | PIP_DISABLE_PIP_VERSION_CHECK: on 15 | PIP_DEFAULT_TIMEOUT: 10 16 | PIP_PROGRESS_BAR: off 17 | PYTHON_VERSION: 3.13 18 | 19 | jobs: 20 | catalog-aws: 21 | name: Collect AWS catalog 22 | runs-on: ubuntu-latest 
23 | steps: 24 | - uses: actions/checkout@v4 25 | - uses: actions/setup-python@v5 26 | with: 27 | python-version: ${{ env.PYTHON_VERSION }} 28 | - name: Install dependencies 29 | run: | 30 | pip install pip -U 31 | pip install -e '.[aws]' 32 | - name: Collect catalog 33 | working-directory: src 34 | env: 35 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} 36 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 37 | run: python -m gpuhunt aws --output ../aws.csv 38 | - uses: actions/upload-artifact@v4 39 | with: 40 | name: catalogs-aws 41 | path: aws.csv 42 | retention-days: 1 43 | 44 | catalog-azure: 45 | name: Collect Azure catalog 46 | runs-on: ubuntu-latest 47 | permissions: 48 | id-token: write 49 | contents: read 50 | steps: 51 | - uses: actions/checkout@v4 52 | - uses: azure/login@v2 53 | with: 54 | creds: '{"clientId":"${{ secrets.AZURE_CLIENT_ID }}","clientSecret":"${{ secrets.AZURE_CLIENT_SECRET }}","subscriptionId":"${{ secrets.AZURE_SUBSCRIPTION_ID }}","tenantId":"${{ secrets.AZURE_TENANT_ID }}"}' 55 | - uses: actions/setup-python@v5 56 | with: 57 | python-version: ${{ env.PYTHON_VERSION }} 58 | - name: Install dependencies 59 | run: | 60 | pip install pip -U 61 | pip install -e '.[azure]' 62 | - name: Collect catalog 63 | working-directory: src 64 | env: 65 | AZURE_SUBSCRIPTION_ID: ${{ secrets.AZURE_SUBSCRIPTION_ID }} 66 | run: python -m gpuhunt azure --output ../azure.csv 67 | - uses: actions/upload-artifact@v4 68 | with: 69 | name: catalogs-azure 70 | path: azure.csv 71 | retention-days: 1 72 | 73 | catalog-datacrunch: 74 | name: Collect DataCrunch catalog 75 | runs-on: ubuntu-latest 76 | steps: 77 | - uses: actions/checkout@v4 78 | - uses: actions/setup-python@v5 79 | with: 80 | python-version: ${{ env.PYTHON_VERSION }} 81 | - name: Install dependencies 82 | run: | 83 | pip install pip -U 84 | pip install -e '.[datacrunch]' 85 | - name: Collect catalog 86 | working-directory: src 87 | env: 88 | DATACRUNCH_CLIENT_ID: ${{ 
secrets.DATACRUNCH_CLIENT_ID }} 89 | DATACRUNCH_CLIENT_SECRET: ${{ secrets.DATACRUNCH_CLIENT_SECRET }} 90 | run: python -m gpuhunt datacrunch --output ../datacrunch.csv 91 | - uses: actions/upload-artifact@v4 92 | with: 93 | name: catalogs-datacrunch 94 | path: datacrunch.csv 95 | retention-days: 1 96 | 97 | catalog-gcp: 98 | name: Collect GCP catalog 99 | runs-on: ubuntu-latest 100 | permissions: 101 | contents: read 102 | id-token: write 103 | steps: 104 | - uses: actions/checkout@v4 105 | - uses: google-github-actions/auth@v2 106 | with: 107 | workload_identity_provider: 'projects/531508670106/locations/global/workloadIdentityPools/github-identity-pool/providers/github-id-provider' 108 | service_account: 'dstack-gpu-pricing-ci@dstack.iam.gserviceaccount.com' 109 | create_credentials_file: true 110 | - uses: actions/setup-python@v5 111 | with: 112 | python-version: ${{ env.PYTHON_VERSION }} 113 | - name: Install dependencies 114 | run: | 115 | pip install pip -U 116 | pip install -e '.[gcp]' 117 | - name: Collect catalog 118 | working-directory: src 119 | env: 120 | GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }} 121 | run: python -m gpuhunt gcp --output ../gcp.csv 122 | - uses: actions/upload-artifact@v4 123 | with: 124 | name: catalogs-gcp 125 | path: gcp.csv 126 | retention-days: 1 127 | 128 | catalog-lambdalabs: 129 | name: Collect LambdaLabs catalog 130 | runs-on: ubuntu-latest 131 | permissions: 132 | contents: read 133 | id-token: write 134 | steps: 135 | - uses: actions/checkout@v4 136 | - uses: actions/setup-python@v5 137 | with: 138 | python-version: ${{ env.PYTHON_VERSION }} 139 | - name: Install dependencies 140 | run: | 141 | pip install pip -U 142 | pip install -e '.' 
143 | - name: Collect catalog 144 | working-directory: src 145 | env: 146 | LAMBDALABS_TOKEN: ${{ secrets.LAMBDALABS_TOKEN }} 147 | run: python -m gpuhunt lambdalabs --output ../lambdalabs.csv 148 | - uses: actions/upload-artifact@v4 149 | with: 150 | name: catalogs-lambdalabs 151 | path: lambdalabs.csv 152 | retention-days: 1 153 | 154 | catalog-nebius: 155 | name: Collect Nebius catalog 156 | runs-on: ubuntu-latest 157 | steps: 158 | - uses: actions/checkout@v4 159 | - uses: actions/setup-python@v5 160 | with: 161 | python-version: ${{ env.PYTHON_VERSION }} 162 | - name: Install dependencies 163 | run: | 164 | pip install pip -U 165 | pip install -e '.[nebius]' 166 | - name: Collect catalog 167 | working-directory: src 168 | run: | 169 | echo "${{ secrets.NEBIUS_PRIVATE_KEY }}" > "${NEBIUS_PRIVATE_KEY_FILE}" 170 | python -m gpuhunt nebius --output ../nebius.csv 171 | env: 172 | NEBIUS_SERVICE_ACCOUNT_ID: ${{ secrets.NEBIUS_SERVICE_ACCOUNT_ID }} 173 | NEBIUS_PUBLIC_KEY_ID: ${{ secrets.NEBIUS_PUBLIC_KEY_ID }} 174 | NEBIUS_PRIVATE_KEY_FILE: /tmp/nebius.key 175 | - uses: actions/upload-artifact@v4 176 | with: 177 | name: catalogs-nebius 178 | path: nebius.csv 179 | retention-days: 1 180 | 181 | catalog-oci: 182 | name: Collect OCI catalog 183 | runs-on: ubuntu-latest 184 | steps: 185 | - uses: actions/checkout@v4 186 | - uses: actions/setup-python@v5 187 | with: 188 | python-version: ${{ env.PYTHON_VERSION }} 189 | - name: Install dependencies 190 | run: | 191 | pip install pip -U 192 | pip install -e '.[oci]' 193 | - name: Collect catalog 194 | working-directory: src 195 | run: python -m gpuhunt oci --output ../oci.csv 196 | env: 197 | OCI_CLI_USER: ${{ secrets.OCI_CLI_USER }} 198 | OCI_CLI_KEY_CONTENT: ${{ secrets.OCI_CLI_KEY_CONTENT }} 199 | OCI_CLI_FINGERPRINT: ${{ secrets.OCI_CLI_FINGERPRINT }} 200 | OCI_CLI_TENANCY: ${{ secrets.OCI_CLI_TENANCY }} 201 | OCI_CLI_REGION: ${{ secrets.OCI_CLI_REGION }} 202 | - uses: actions/upload-artifact@v4 203 | with: 204 | name: 
catalogs-oci 205 | path: oci.csv 206 | retention-days: 1 207 | 208 | catalog-runpod: 209 | name: Collect Runpod catalog 210 | runs-on: ubuntu-latest 211 | steps: 212 | - uses: actions/checkout@v4 213 | - uses: actions/setup-python@v5 214 | with: 215 | python-version: ${{ env.PYTHON_VERSION }} 216 | - name: Install dependencies 217 | run: | 218 | pip install pip -U 219 | pip install -e '.' 220 | - name: Collect catalog 221 | working-directory: src 222 | run: python -m gpuhunt runpod --output ../runpod.csv 223 | - uses: actions/upload-artifact@v4 224 | with: 225 | name: catalogs-runpod 226 | path: runpod.csv 227 | retention-days: 1 228 | 229 | catalog-cloudrift: 230 | name: Collect CloudRift catalog 231 | runs-on: ubuntu-latest 232 | steps: 233 | - uses: actions/checkout@v4 234 | - uses: actions/setup-python@v5 235 | with: 236 | python-version: ${{ env.PYTHON_VERSION }} 237 | - name: Install dependencies 238 | run: | 239 | pip install pip -U 240 | pip install -e '.' 241 | - name: Collect catalog 242 | working-directory: src 243 | run: python -m gpuhunt cloudrift --output ../cloudrift.csv 244 | - uses: actions/upload-artifact@v4 245 | with: 246 | name: catalogs-cloudrift 247 | path: cloudrift.csv 248 | retention-days: 1 249 | 250 | test-catalog: 251 | name: Test catalogs integrity 252 | needs: 253 | - catalog-aws 254 | - catalog-azure 255 | - catalog-datacrunch 256 | - catalog-gcp 257 | - catalog-lambdalabs 258 | - catalog-nebius 259 | - catalog-oci 260 | - catalog-runpod 261 | - catalog-cloudrift 262 | runs-on: ubuntu-latest 263 | steps: 264 | - uses: actions/checkout@v4 265 | - uses: actions/setup-python@v5 266 | with: 267 | python-version: ${{ env.PYTHON_VERSION }} 268 | - name: Install dependencies 269 | run: | 270 | pip install pip -U 271 | pip install '.[all]' 272 | pip install pytest 273 | - uses: actions/download-artifact@v4 274 | with: 275 | pattern: catalogs-* 276 | merge-multiple: true 277 | - name: Run integrity tests 278 | env: 279 | CATALOG_DIR: . 
280 | run: pytest src/integrity_tests 281 | 282 | publish-catalog: 283 | name: Publish catalogs 284 | needs: [ test-catalog ] 285 | runs-on: ubuntu-latest 286 | steps: 287 | - uses: actions/checkout@v4 288 | - uses: actions/setup-python@v5 289 | with: 290 | python-version: ${{ env.PYTHON_VERSION }} 291 | - name: Install gpuhunt 292 | run: pip install . 293 | - name: Install AWS CLI 294 | run: pip install awscli 295 | - uses: actions/download-artifact@v4 296 | with: 297 | pattern: catalogs-* 298 | merge-multiple: true 299 | path: v2/ 300 | - name: Build legacy v1 catalogs 301 | run: | 302 | mkdir v1 303 | for catalog_path in $(find v2/*.csv); do 304 | file=$(basename "$catalog_path") 305 | python -m gpuhunt.scripts.catalog_v1 --input "v2/$file" --output "v1/$file" 306 | done 307 | - name: Write version 308 | run: echo "$(date +%Y%m%d)-${{ github.run_number }}" | tee v2/version | tee v1/version 309 | - name: Package catalogs 310 | run: | 311 | zip -j v2/catalog.zip v2/*.csv v2/version 312 | zip -j v1/catalog.zip v1/*.csv v1/version 313 | - name: Upload to S3 314 | env: 315 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} 316 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 317 | BUCKET: s3://dstack-gpu-pricing${{ github.event_name == 'workflow_dispatch' && inputs.staging && '/stgn' || '' }} 318 | run: | 319 | VERSION=$(cat v2/version) 320 | aws s3 cp v2/catalog.zip "$BUCKET/v2/$VERSION/catalog.zip" --acl public-read 321 | aws s3 cp v1/catalog.zip "$BUCKET/v1/$VERSION/catalog.zip" --acl public-read 322 | echo $VERSION | aws s3 cp - "$BUCKET/v2/version" --acl public-read 323 | echo $VERSION | aws s3 cp - "$BUCKET/v1/version" --acl public-read 324 | aws s3 cp "$BUCKET/v2/$VERSION/catalog.zip" "$BUCKET/v2/latest/catalog.zip" --acl public-read 325 | aws s3 cp "$BUCKET/v1/$VERSION/catalog.zip" "$BUCKET/v1/latest/catalog.zip" --acl public-read 326 | -------------------------------------------------------------------------------- 
/.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | env: 9 | PIP_DISABLE_PIP_VERSION_CHECK: on 10 | PIP_DEFAULT_TIMEOUT: 10 11 | PIP_PROGRESS_BAR: off 12 | 13 | jobs: 14 | quality-control: 15 | uses: ./.github/workflows/test.yml 16 | 17 | pypi-upload: 18 | needs: [ quality-control ] 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v4 22 | - uses: actions/setup-python@v5 23 | with: 24 | python-version: 3.13 25 | - name: Install dependencies 26 | run: | 27 | pip install pip -U 28 | pip install build wheel twine 29 | - name: Upload pypi package 30 | run: | 31 | VERSION=${GITHUB_REF#refs/tags/} 32 | echo "__version__ = \"$VERSION\"" > src/gpuhunt/version.py 33 | python -m build 34 | python -m twine upload --repository pypi --username ${{ secrets.PYPI_USERNAME }} --password ${{ secrets.PYPI_PASSWORD }} dist/* 35 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test package 2 | 3 | on: 4 | push: 5 | branches: 6 | - "main" 7 | pull_request: 8 | branches: 9 | - "main" 10 | workflow_dispatch: 11 | workflow_call: 12 | 13 | env: 14 | PIP_DISABLE_PIP_VERSION_CHECK: on 15 | PIP_DEFAULT_TIMEOUT: 10 16 | PIP_PROGRESS_BAR: off 17 | 18 | jobs: 19 | python-lint: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - uses: actions/checkout@v4 23 | - uses: actions/setup-python@v5 24 | with: 25 | python-version: 3.13 26 | - run: python -m pip install pre-commit 27 | - run: pre-commit run -a --show-diff-on-failure 28 | 29 | python-test: 30 | needs: [python-lint] 31 | runs-on: ${{ matrix.os }} 32 | strategy: 33 | matrix: 34 | os: [ubuntu-latest] 35 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 36 | steps: 37 | - uses: actions/checkout@v4 38 | - name: Set up Python ${{ 
matrix.python-version }} 39 | uses: actions/setup-python@v5 40 | with: 41 | python-version: ${{ matrix.python-version }} 42 | - name: Install dependencies 43 | run: | 44 | python -m pip install --upgrade pip 45 | pip install '.[all,dev]' 46 | - name: Run doctest 47 | run: | 48 | IGNORE= 49 | if [[ "${{ matrix.python-version }}" == "3.9" ]]; then 50 | IGNORE="--ignore src/gpuhunt/providers/nebius.py" 51 | fi 52 | pytest --doctest-modules src/gpuhunt $IGNORE 53 | - name: Run pytest 54 | run: | 55 | pytest src/tests 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | *.pyc 3 | .DS_Store 4 | /.idea/ 5 | /venv/ 6 | /.venv/ 7 | /build/ 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.5.3 # Should match pyproject.toml 4 | hooks: 5 | - id: ruff 6 | name: ruff common 7 | args: ['--fix'] 8 | - id: ruff-format 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![](https://img.shields.io/pypi/v/gpuhunt)](https://pypi.org/project/gpuhunt/) 2 | 3 | Easy access to GPU pricing data for major cloud providers: AWS, Azure, GCP, etc. 4 | The catalog includes details about prices, locations, CPUs, RAM, GPUs, and spots (interruptible instances). 5 | 6 | ## Usage 7 | 8 | ```python 9 | import gpuhunt 10 | 11 | items = gpuhunt.query( 12 | min_memory=16, 13 | min_cpu=8, 14 | min_gpu_count=1, 15 | max_price=1.0, 16 | ) 17 | 18 | print(*items, sep="\n") 19 | ``` 20 | 21 | List of all available filters: 22 | 23 | * `provider`: name of the provider to filter by. 
If not specified, all providers will be used. One or many 24 | * `cpu_arch`: CPU architecture, one of: `x86`, `arm` 25 | * `min_cpu`: minimum number of CPUs 26 | * `max_cpu`: maximum number of CPUs 27 | * `min_memory`: minimum amount of RAM in GB 28 | * `max_memory`: maximum amount of RAM in GB 29 | * `min_gpu_count`: minimum number of GPUs 30 | * `max_gpu_count`: maximum number of GPUs 31 | * `gpu_vendor`: GPU/accelerator vendor, one of: `nvidia`, `amd`, `google`, `intel` 32 | * `gpu_name`: name of the GPU to filter by. If not specified, all GPUs will be used. One or many 33 | * `min_gpu_memory`: minimum amount of GPU VRAM in GB for each GPU 34 | * `max_gpu_memory`: maximum amount of GPU VRAM in GB for each GPU 35 | * `min_total_gpu_memory`: minimum amount of GPU VRAM in GB for all GPUs combined 36 | * `max_total_gpu_memory`: maximum amount of GPU VRAM in GB for all GPUs combined 37 | * `min_disk_size`: minimum disk size in GB (not fully supported) 38 | * `max_disk_size`: maximum disk size in GB (not fully supported) 39 | * `min_price`: minimum price per hour in USD 40 | * `max_price`: maximum price per hour in USD 41 | * `min_compute_capability`: minimum compute capability of the GPU 42 | * `max_compute_capability`: maximum compute capability of the GPU 43 | * `spot`: if `False`, only ondemand offers will be returned. 
If `True`, only spot offers will be returned 44 | 45 | ## Advanced usage 46 | 47 | ```python 48 | from gpuhunt import Catalog 49 | 50 | catalog = Catalog() 51 | catalog.load(version="20240508") 52 | items = catalog.query() 53 | 54 | print(*items, sep="\n") 55 | ``` 56 | 57 | ## Supported providers 58 | 59 | * AWS 60 | * Azure 61 | * CloudRift 62 | * Cudo Compute 63 | * DataCrunch 64 | * GCP 65 | * LambdaLabs 66 | * Nebius 67 | * OCI 68 | * RunPod 69 | * TensorDock 70 | * Vast AI 71 | * Vultr 72 | 73 | ## See also 74 | 75 | * [dstack](https://github.com/dstackai/dstack) 76 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "gpuhunt" 7 | authors = [ 8 | { name = "dstack GmbH" }, 9 | ] 10 | description = "A catalog of GPU pricing for different cloud providers" 11 | readme = "README.md" 12 | requires-python = ">=3.9" 13 | classifiers = [ 14 | "Programming Language :: Python :: 3", 15 | "License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)", 16 | "Operating System :: OS Independent", 17 | ] 18 | keywords = ["gpu", "cloud", "pricing"] 19 | dynamic = ["version"] 20 | dependencies = [ 21 | "requests", 22 | "typing-extensions" 23 | ] 24 | 25 | [project.urls] 26 | "GitHub" = "https://github.com/dstackai/gpuhunt" 27 | "Issues" = "https://github.com/dstackai/gpuhunt/issues" 28 | 29 | [project.optional-dependencies] 30 | aws = [ 31 | "boto3" 32 | ] 33 | azure = [ 34 | "azure-mgmt-compute", 35 | "azure-identity" 36 | ] 37 | gcp = [ 38 | "google-cloud-billing", 39 | "google-cloud-compute", 40 | "google-cloud-tpu" 41 | ] 42 | # Nebius requires Python 3.10.
On 3.9: 43 | # `pip install gpuhunt[nebius]` is expected to fail 44 | # `pip install gpuhunt[all]` is expected to ignore Nebius 45 | nebius = [ 46 | "nebius>=0.2.18,<0.3", 47 | ] 48 | maybe_nebius = [ 49 | 'nebius>=0.2.18,<0.3; python_version>="3.10"', 50 | ] 51 | oci = [ 52 | "oci", 53 | "pydantic>=1.10.10,<2.0.0", 54 | ] 55 | datacrunch = [ 56 | "datacrunch" 57 | ] 58 | all = ["gpuhunt[aws,azure,datacrunch,gcp,maybe_nebius,oci]"] 59 | dev = [ 60 | "pre-commit", 61 | "pytest~=7.0", 62 | "pytest-mock", 63 | "ruff==0.5.3", # Should match .pre-commit-config.yaml 64 | "requests-mock", 65 | ] 66 | 67 | [tool.setuptools.dynamic] 68 | version = {attr = "gpuhunt.version.__version__"} 69 | 70 | [tool.ruff] 71 | line-length = 99 72 | 73 | [tool.ruff.lint] 74 | select = ['E', 'F', 'I' ,'Q', 'W', 'UP', 'PGH', 'FLY', 'S113'] 75 | ignore = [ 76 | 'E501', 77 | 'E712', 78 | ] 79 | 80 | [tool.ruff.lint.isort] 81 | known-first-party = ["gpuhunt"] 82 | -------------------------------------------------------------------------------- /src/gpuhunt/__init__.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F401 2 | import warnings 3 | 4 | from gpuhunt._internal.catalog import Catalog 5 | from gpuhunt._internal.constraints import ( 6 | KNOWN_ACCELERATORS, 7 | KNOWN_AMD_GPUS, 8 | KNOWN_INTEL_ACCELERATORS, 9 | KNOWN_NVIDIA_GPUS, 10 | KNOWN_TENSTORRENT_ACCELERATORS, 11 | KNOWN_TPUS, 12 | correct_gpu_memory_gib, 13 | is_nvidia_superchip, 14 | matches, 15 | ) 16 | from gpuhunt._internal.default import default_catalog, query 17 | from gpuhunt._internal.models import ( 18 | AcceleratorInfo, 19 | AcceleratorVendor, 20 | AMDGPUInfo, 21 | CatalogItem, 22 | CPUArchitecture, 23 | IntelAcceleratorInfo, 24 | NvidiaGPUInfo, 25 | QueryFilter, 26 | RawCatalogItem, 27 | TenstorrentAcceleratorInfo, 28 | TPUInfo, 29 | ) 30 | 31 | # Deprecated aliases 32 | GPUInfo: type[NvidiaGPUInfo] 33 | KNOWN_GPUS: list[NvidiaGPUInfo] 34 | 35 | 36 | def 
_warn_renamed(old: str, new: str) -> None: 37 | warnings.warn( 38 | f"{old} has been renamed to {new}, the old name is deprecated and will be removed.", 39 | DeprecationWarning, 40 | stacklevel=2, 41 | ) 42 | 43 | 44 | def __getattr__(name): 45 | if name == "GPUInfo": 46 | _warn_renamed("GPUInfo", "NvidiaGPUInfo") 47 | return NvidiaGPUInfo 48 | if name == "KNOWN_GPUS": 49 | _warn_renamed("KNOWN_GPUS", "KNOWN_NVIDIA_GPUS") 50 | return KNOWN_NVIDIA_GPUS 51 | raise AttributeError(f"module {__name__!r} has no attribute {name!r}") 52 | -------------------------------------------------------------------------------- /src/gpuhunt/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | import gpuhunt._internal.storage as storage 6 | from gpuhunt._internal.utils import configure_logging 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser(prog="python3 -m gpuhunt") 11 | parser.add_argument( 12 | "provider", 13 | choices=[ 14 | "aws", 15 | "azure", 16 | "cloudrift", 17 | "cudo", 18 | "datacrunch", 19 | "gcp", 20 | "lambdalabs", 21 | "nebius", 22 | "oci", 23 | "runpod", 24 | "tensordock", 25 | "vastai", 26 | "vultr", 27 | ], 28 | ) 29 | parser.add_argument("--output", required=True) 30 | parser.add_argument("--no-filter", action="store_true") 31 | args = parser.parse_args() 32 | configure_logging() 33 | 34 | if args.provider == "aws": 35 | from gpuhunt.providers.aws import AWSProvider 36 | 37 | provider = AWSProvider(os.getenv("AWS_CACHE_PATH")) 38 | elif args.provider == "azure": 39 | from gpuhunt.providers.azure import AzureProvider 40 | 41 | provider = AzureProvider(os.getenv("AZURE_SUBSCRIPTION_ID")) 42 | elif args.provider == "cudo": 43 | from gpuhunt.providers.cudo import CudoProvider 44 | 45 | provider = CudoProvider() 46 | elif args.provider == "cloudrift": 47 | from gpuhunt.providers.cloudrift import CloudRiftProvider 48 | 49 | provider = CloudRiftProvider() 50 
| elif args.provider == "datacrunch": 51 | from gpuhunt.providers.datacrunch import DataCrunchProvider 52 | 53 | provider = DataCrunchProvider( 54 | os.getenv("DATACRUNCH_CLIENT_ID"), os.getenv("DATACRUNCH_CLIENT_SECRET") 55 | ) 56 | elif args.provider == "gcp": 57 | from gpuhunt.providers.gcp import GCPProvider 58 | 59 | provider = GCPProvider(os.getenv("GCP_PROJECT_ID")) 60 | elif args.provider == "lambdalabs": 61 | from gpuhunt.providers.lambdalabs import LambdaLabsProvider 62 | 63 | provider = LambdaLabsProvider(os.getenv("LAMBDALABS_TOKEN")) 64 | elif args.provider == "nebius": 65 | from nebius.base.service_account.pk_file import Reader as PKReader 66 | 67 | from gpuhunt.providers.nebius import NebiusProvider 68 | 69 | provider = NebiusProvider( 70 | credentials=PKReader( 71 | filename=os.getenv("NEBIUS_PRIVATE_KEY_FILE"), 72 | public_key_id=os.getenv("NEBIUS_PUBLIC_KEY_ID"), 73 | service_account_id=os.getenv("NEBIUS_SERVICE_ACCOUNT_ID"), 74 | ), 75 | ) 76 | elif args.provider == "oci": 77 | from gpuhunt.providers.oci import OCICredentials, OCIProvider 78 | 79 | provider = OCIProvider( 80 | OCICredentials( 81 | user=os.getenv("OCI_CLI_USER"), 82 | key_content=os.getenv("OCI_CLI_KEY_CONTENT"), 83 | fingerprint=os.getenv("OCI_CLI_FINGERPRINT"), 84 | tenancy=os.getenv("OCI_CLI_TENANCY"), 85 | region=os.getenv("OCI_CLI_REGION"), 86 | ) 87 | ) 88 | elif args.provider == "runpod": 89 | from gpuhunt.providers.runpod import RunpodProvider 90 | 91 | provider = RunpodProvider() 92 | elif args.provider == "tensordock": 93 | from gpuhunt.providers.tensordock import TensorDockProvider 94 | 95 | provider = TensorDockProvider() 96 | elif args.provider == "vastai": 97 | from gpuhunt.providers.vastai import VastAIProvider 98 | 99 | provider = VastAIProvider() 100 | elif args.provider == "vultr": 101 | from gpuhunt.providers.vultr import VultrProvider 102 | 103 | provider = VultrProvider() 104 | else: 105 | exit(f"Unknown provider {args.provider}") 106 | 107 | 
logging.info("Fetching offers for %s", args.provider) 108 | offers = provider.get() 109 | if not args.no_filter: 110 | offers = provider.filter(offers) 111 | storage.dump(offers, args.output) 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /src/gpuhunt/_internal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dstackai/gpuhunt/fcc8f33dc037503956d7272c91e730e6584ec492/src/gpuhunt/_internal/__init__.py -------------------------------------------------------------------------------- /src/gpuhunt/_internal/catalog.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import dataclasses 3 | import heapq 4 | import io 5 | import logging 6 | import os 7 | import time 8 | import urllib.request 9 | import zipfile 10 | from collections.abc import Container 11 | from concurrent.futures import ThreadPoolExecutor, wait 12 | from pathlib import Path 13 | from typing import Optional, Union 14 | 15 | import gpuhunt._internal.constraints as constraints 16 | from gpuhunt._internal.models import AcceleratorVendor, CatalogItem, CPUArchitecture, QueryFilter 17 | from gpuhunt._internal.utils import parse_compute_capability 18 | from gpuhunt.providers import AbstractProvider 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | VERSION_URL = "https://dstack-gpu-pricing.s3.eu-west-1.amazonaws.com/v2/version" 23 | CATALOG_URL = "https://dstack-gpu-pricing.s3.eu-west-1.amazonaws.com/v2/{version}/catalog.zip" 24 | OFFLINE_PROVIDERS = [ 25 | "aws", 26 | "azure", 27 | "datacrunch", 28 | "gcp", 29 | "lambdalabs", 30 | "nebius", 31 | "oci", 32 | "runpod", 33 | "cloudrift", 34 | ] 35 | ONLINE_PROVIDERS = ["cudo", "tensordock", "vastai", "vultr"] 36 | RELOAD_INTERVAL = 15 * 60 # 15 minutes 37 | 38 | 39 | class Catalog: 40 | def __init__(self, balance_resources: bool = 
True, auto_reload: bool = True): 41 | """ 42 | Args: 43 | balance_resources: increase min resources to better match the chosen GPU 44 | auto_reload: if `True`, `query()` (re)loads the catalog from the S3 bucket when it has not been loaded yet or the last load is older than `RELOAD_INTERVAL` (15 minutes) 45 | """ 46 | self.catalog = None # io.BytesIO holding the downloaded catalog zip; set by load() 47 | self.loaded_at = None # time.monotonic() of the last load(); compared against RELOAD_INTERVAL in query() 48 | self.providers: list[AbstractProvider] = [] 49 | self.balance_resources = balance_resources 50 | self.auto_reload = auto_reload 51 | 52 | def query( 53 | self, 54 | *, 55 | provider: Optional[Union[str, list[str]]] = None, 56 | cpu_arch: Optional[Union[CPUArchitecture, str]] = None, 57 | min_cpu: Optional[int] = None, 58 | max_cpu: Optional[int] = None, 59 | min_memory: Optional[float] = None, 60 | max_memory: Optional[float] = None, 61 | min_gpu_count: Optional[int] = None, 62 | max_gpu_count: Optional[int] = None, 63 | gpu_vendor: Optional[Union[AcceleratorVendor, str]] = None, 64 | gpu_name: Optional[Union[str, list[str]]] = None, 65 | min_gpu_memory: Optional[float] = None, 66 | max_gpu_memory: Optional[float] = None, 67 | min_total_gpu_memory: Optional[float] = None, 68 | max_total_gpu_memory: Optional[float] = None, 69 | min_disk_size: Optional[int] = None, 70 | max_disk_size: Optional[int] = None, 71 | min_price: Optional[float] = None, 72 | max_price: Optional[float] = None, 73 | min_compute_capability: Optional[Union[str, tuple[int, int]]] = None, 74 | max_compute_capability: Optional[Union[str, tuple[int, int]]] = None, 75 | spot: Optional[bool] = None, 76 | allowed_flags: Optional[Container[str]] = None, 77 | ) -> list[CatalogItem]: 78 | """ 79 | Query the catalog for matching offers 80 | 81 | Args: 82 | provider: name of the provider to filter by. If not specified, all providers will be used 83 | cpu_arch: CPU architecture to filter by. 
If not specified, all architectures will be used 84 | min_cpu: minimum number of CPUs 85 | max_cpu: maximum number of CPUs 86 | min_memory: minimum amount of RAM in GB 87 | max_memory: maximum amount of RAM in GB 88 | min_gpu_count: minimum number of GPUs 89 | max_gpu_count: maximum number of GPUs 90 | gpu_vendor: accelerator vendor to filter by. If not specified, all vendors will be used 91 | gpu_name: name of the GPU to filter by. If not specified, all GPUs will be used 92 | min_gpu_memory: minimum amount of GPU VRAM in GB for each GPU 93 | max_gpu_memory: maximum amount of GPU VRAM in GB for each GPU 94 | min_total_gpu_memory: minimum amount of GPU VRAM in GB for all GPUs combined 95 | max_total_gpu_memory: maximum amount of GPU VRAM in GB for all GPUs combined 96 | min_disk_size: minimum disk size in GB 97 | max_disk_size: maximum disk size in GB 98 | min_price: minimum price per hour in USD 99 | max_price: maximum price per hour in USD 100 | min_compute_capability: minimum compute capability of the GPU 101 | max_compute_capability: maximum compute capability of the GPU 102 | spot: if `False`, only ondemand offers will be returned. If `True`, only spot offers will be returned 103 | allowed_flags: only offers with all flags allowed will be returned. 
`None` allows all flags 104 | 105 | Returns: 106 | list of matching offers 107 | """ 108 | if self.auto_reload and ( 109 | self.loaded_at is None or time.monotonic() - self.loaded_at > RELOAD_INTERVAL 110 | ): 111 | self.load() 112 | 113 | query_filter = QueryFilter( 114 | provider=[provider] if isinstance(provider, str) else provider, 115 | cpu_arch=CPUArchitecture.cast(cpu_arch) if cpu_arch else None, 116 | min_cpu=min_cpu, 117 | max_cpu=max_cpu, 118 | min_memory=min_memory, 119 | max_memory=max_memory, 120 | min_gpu_count=min_gpu_count, 121 | max_gpu_count=max_gpu_count, 122 | gpu_vendor=AcceleratorVendor.cast(gpu_vendor) if gpu_vendor else None, 123 | gpu_name=[gpu_name] if isinstance(gpu_name, str) else gpu_name, 124 | min_gpu_memory=min_gpu_memory, 125 | max_gpu_memory=max_gpu_memory, 126 | min_total_gpu_memory=min_total_gpu_memory, 127 | max_total_gpu_memory=max_total_gpu_memory, 128 | min_disk_size=min_disk_size, 129 | max_disk_size=max_disk_size, 130 | min_price=min_price, 131 | max_price=max_price, 132 | min_compute_capability=parse_compute_capability(min_compute_capability), 133 | max_compute_capability=parse_compute_capability(max_compute_capability), 134 | spot=spot, 135 | allowed_flags=allowed_flags, 136 | ) 137 | 138 | if query_filter.provider is not None: 139 | # validate providers 140 | for p in query_filter.provider: 141 | if p.lower() not in OFFLINE_PROVIDERS + ONLINE_PROVIDERS: 142 | raise ValueError(f"Unknown provider: {p}") 143 | else: 144 | query_filter.provider = OFFLINE_PROVIDERS + list( 145 | set(p.NAME for p in self.providers if p.NAME in ONLINE_PROVIDERS) 146 | ) 147 | 148 | # fetch providers 149 | with ThreadPoolExecutor(max_workers=8) as executor: 150 | futures = [] 151 | 152 | for provider_name in ONLINE_PROVIDERS: 153 | if provider_name in map(str.lower, query_filter.provider): 154 | futures.append( 155 | executor.submit( 156 | self._get_online_provider_items, 157 | provider_name, 158 | query_filter, 159 | ) 160 | ) 161 | 162 | for 
provider_name in OFFLINE_PROVIDERS: 163 | if provider_name in map(str.lower, query_filter.provider): 164 | futures.append( 165 | executor.submit( 166 | self._get_offline_provider_items, 167 | provider_name, 168 | query_filter, 169 | ) 170 | ) 171 | 172 | completed, _ = wait(futures) 173 | # The merge preserves provider-specific order, picking the cheapest offer at each step. 174 | # The final list is not strictly sorted by the price. 175 | items = list(heapq.merge(*[f.result() for f in completed], key=lambda i: i.price)) 176 | return items 177 | 178 | def load(self, version: Optional[str] = None): 179 | """ 180 | Fetch the catalog from the S3 bucket 181 | 182 | Args: 183 | version: specific version of the catalog to download. If not specified, the latest version will be used 184 | """ 185 | catalog_url = os.getenv("GPUHUNT_CATALOG_URL") 186 | if catalog_url is None: 187 | if version is None: 188 | version = self.get_latest_version() 189 | catalog_url = CATALOG_URL.format(version=version) 190 | logger.debug("Downloading catalog %s...", version) 191 | with urllib.request.urlopen(catalog_url) as f: 192 | self.loaded_at = time.monotonic() 193 | self.catalog = io.BytesIO(f.read()) 194 | 195 | @staticmethod 196 | def get_latest_version() -> str: 197 | """ 198 | Get the latest version of the catalog from the S3 bucket 199 | """ 200 | with urllib.request.urlopen(VERSION_URL) as f: 201 | return f.read().decode("utf-8").strip() 202 | 203 | def add_provider(self, provider: AbstractProvider): 204 | """ 205 | Add provider for querying offers 206 | 207 | Args: 208 | provider: provider to add 209 | """ 210 | self.providers.append(provider) 211 | 212 | def _get_offline_provider_items( 213 | self, provider_name: str, query_filter: QueryFilter 214 | ) -> list[CatalogItem]: 215 | logger.debug("Loading items for offline provider %s", provider_name) 216 | items = [] 217 | # Set this env var to use a local catalog instead of the s3 catalog 218 | catalog_dir = 
os.getenv("GPUHUNT_CATALOG_DIR") 219 | if catalog_dir is not None: 220 | with open(Path(catalog_dir) / f"{provider_name}.csv", "rb") as csv_file: 221 | reader = csv.DictReader(io.TextIOWrapper(csv_file, "utf-8")) 222 | for row in reader: 223 | item = CatalogItem.from_dict(row, provider=provider_name) 224 | if constraints.matches(item, query_filter): 225 | items.append(item) 226 | else: 227 | if self.catalog is None: 228 | logger.warning("Catalog not loaded") 229 | return items 230 | with zipfile.ZipFile(self.catalog) as zip_file: 231 | with zip_file.open(f"{provider_name}.csv", "r") as csv_file: 232 | reader = csv.DictReader(io.TextIOWrapper(csv_file, "utf-8")) 233 | for row in reader: 234 | item = CatalogItem.from_dict(row, provider=provider_name) 235 | if constraints.matches(item, query_filter): 236 | items.append(item) 237 | return items 238 | 239 | def _get_online_provider_items( 240 | self, provider_name: str, query_filter: QueryFilter 241 | ) -> list[CatalogItem]: 242 | logger.debug("Loading items for online provider %s", provider_name) 243 | items = [] 244 | found = False 245 | for provider in self.providers: 246 | if provider.NAME != provider_name: 247 | continue 248 | found = True 249 | for i in provider.get( 250 | query_filter=query_filter, balance_resources=self.balance_resources 251 | ): 252 | item = CatalogItem(provider=provider_name, **dataclasses.asdict(i)) 253 | if constraints.matches(item, query_filter): 254 | items.append(item) 255 | if not found: 256 | raise ValueError(f"Provider is not loaded: {provider_name}") 257 | return items 258 | -------------------------------------------------------------------------------- /src/gpuhunt/_internal/constraints.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Optional, TypeVar, Union 3 | 4 | from gpuhunt._internal.models import ( 5 | AcceleratorVendor, 6 | AMDArchitecture, 7 | AMDGPUInfo, 8 | CatalogItem, 9 | IntelAcceleratorInfo, 10 | 
NvidiaGPUInfo, 11 | QueryFilter, 12 | TenstorrentAcceleratorInfo, 13 | TPUInfo, 14 | ) 15 | 16 | # v5litepod = v5e 17 | _TPU_VERSIONS = ["v2", "v3", "v4", "v5p", "v5litepod", "v6e"] 18 | 19 | 20 | Comparable = TypeVar("Comparable", bound=Union[int, float, tuple[int, int]]) 21 | 22 | 23 | def is_between(value: Comparable, left: Optional[Comparable], right: Optional[Comparable]) -> bool: 24 | if is_below(value, left) or is_above(value, right): 25 | return False 26 | return True 27 | 28 | 29 | def is_below(value: Comparable, limit: Optional[Comparable]) -> bool: 30 | if limit is not None and value < limit: 31 | return True 32 | return False 33 | 34 | 35 | def is_above(value: Comparable, limit: Optional[Comparable]) -> bool: 36 | if limit is not None and value > limit: 37 | return True 38 | return False 39 | 40 | 41 | def matches(i: CatalogItem, q: QueryFilter) -> bool: 42 | """ 43 | Check if the catalog item matches the filters 44 | 45 | Args: 46 | i: catalog item 47 | q: filters 48 | 49 | Returns: 50 | whether the catalog item matches the filters 51 | """ 52 | if q.provider is not None and i.provider.lower() not in map(str.lower, q.provider): 53 | return False 54 | if not is_between(i.price, q.min_price, q.max_price): 55 | return False 56 | if q.spot is not None and i.spot != q.spot: 57 | return False 58 | if q.cpu_arch and q.cpu_arch != i.cpu_arch: 59 | return False 60 | if not is_between(i.cpu, q.min_cpu, q.max_cpu): 61 | return False 62 | if not is_between(i.memory, q.min_memory, q.max_memory): 63 | return False 64 | if q.gpu_vendor and q.gpu_vendor != i.gpu_vendor: 65 | return False 66 | if not is_between(i.gpu_count, q.min_gpu_count, q.max_gpu_count): 67 | return False 68 | if q.gpu_name is not None: 69 | if i.gpu_name is None: 70 | return False 71 | if i.gpu_name.lower() not in map(str.lower, q.gpu_name): 72 | return False 73 | if q.min_compute_capability is not None or q.max_compute_capability is not None: 74 | if i.gpu_vendor != AcceleratorVendor.NVIDIA: 75 | 
return False 76 | if not i.gpu_name: 77 | return False 78 | cc = get_compute_capability(i.gpu_name) 79 | if not cc or not is_between(cc, q.min_compute_capability, q.max_compute_capability): 80 | return False 81 | if not is_between(i.gpu_memory if i.gpu_count > 0 else 0, q.min_gpu_memory, q.max_gpu_memory): 82 | return False 83 | if not is_between( 84 | (i.gpu_count * i.gpu_memory) if i.gpu_count > 0 else 0, 85 | q.min_total_gpu_memory, 86 | q.max_total_gpu_memory, 87 | ): 88 | return False 89 | if i.disk_size is not None: 90 | if not is_between(i.disk_size, q.min_disk_size, q.max_disk_size): 91 | return False 92 | if q.allowed_flags is not None: 93 | if any(flag not in q.allowed_flags for flag in i.flags): 94 | return False 95 | return True 96 | 97 | 98 | def get_compute_capability(gpu_name: str) -> Optional[tuple[int, int]]: 99 | for gpu in KNOWN_NVIDIA_GPUS: 100 | if gpu.name.lower() == gpu_name.lower(): 101 | return gpu.compute_capability 102 | return None 103 | 104 | 105 | def correct_gpu_memory_gib(gpu_name: str, memory_mib: float) -> int: 106 | """ 107 | Round to whole number of gibibytes and attempt correcting the reported GPU 108 | memory size if the actual memory size for that GPU is known and the 109 | difference between the reported and the known memory is within a heuristic 110 | threshold. 111 | 112 | This is useful for cases when nvidia-smi or cloud providers report the GPU 113 | memory imprecisely. 
114 | """ 115 | 116 | memory_gib = memory_mib / 1024 117 | known_memories_gib = {gpu.memory for gpu in KNOWN_ACCELERATORS if gpu.name == gpu_name} 118 | if known_memories_gib: 119 | closest_known_memory_gib = min(known_memories_gib, key=lambda x: abs(x - memory_gib)) 120 | difference_gib = abs(closest_known_memory_gib - memory_gib) 121 | if difference_gib / closest_known_memory_gib < 0.07: 122 | return closest_known_memory_gib 123 | return round(memory_gib) 124 | 125 | 126 | def is_nvidia_superchip(gpu_name: str) -> bool: 127 | """ 128 | Check if the given NVIDIA GPU is actually a so-called "superchip" combining GPU with ARM CPU, 129 | such as: 130 | * GH200 (Grace + Hopper) 131 | * GB10, GB200 (Grace + Blackwell) 132 | """ 133 | return re.match(r"^g[bh]\d+", gpu_name.lower()) is not None 134 | 135 | 136 | KNOWN_NVIDIA_GPUS: list[NvidiaGPUInfo] = [ 137 | NvidiaGPUInfo(name="A10", memory=24, compute_capability=(8, 6)), 138 | NvidiaGPUInfo(name="A16", memory=16, compute_capability=(8, 6)), 139 | NvidiaGPUInfo(name="A40", memory=48, compute_capability=(8, 6)), 140 | NvidiaGPUInfo(name="A100", memory=40, compute_capability=(8, 0)), 141 | NvidiaGPUInfo(name="A100", memory=80, compute_capability=(8, 0)), 142 | NvidiaGPUInfo(name="A10G", memory=24, compute_capability=(8, 6)), 143 | NvidiaGPUInfo(name="A4000", memory=16, compute_capability=(8, 6)), 144 | NvidiaGPUInfo(name="A4500", memory=20, compute_capability=(8, 6)), 145 | NvidiaGPUInfo(name="A5000", memory=24, compute_capability=(8, 6)), 146 | NvidiaGPUInfo(name="A6000", memory=48, compute_capability=(8, 6)), 147 | NvidiaGPUInfo(name="H100", memory=80, compute_capability=(9, 0)), 148 | NvidiaGPUInfo(name="H100NVL", memory=94, compute_capability=(9, 0)), 149 | NvidiaGPUInfo(name="L4", memory=24, compute_capability=(8, 9)), 150 | NvidiaGPUInfo(name="L40", memory=48, compute_capability=(8, 9)), 151 | NvidiaGPUInfo(name="L40S", memory=48, compute_capability=(8, 9)), 152 | NvidiaGPUInfo(name="P100", memory=16, 
compute_capability=(6, 0)), 153 | NvidiaGPUInfo(name="RTX3060", memory=8, compute_capability=(8, 6)), 154 | NvidiaGPUInfo(name="RTX3060", memory=12, compute_capability=(8, 6)), 155 | NvidiaGPUInfo(name="RTX3060Ti", memory=8, compute_capability=(8, 6)), 156 | NvidiaGPUInfo(name="RTX3070Ti", memory=8, compute_capability=(8, 6)), 157 | NvidiaGPUInfo(name="RTX3080", memory=10, compute_capability=(8, 6)), 158 | NvidiaGPUInfo(name="RTX3080Ti", memory=12, compute_capability=(8, 6)), 159 | NvidiaGPUInfo(name="RTX3090", memory=24, compute_capability=(8, 6)), 160 | NvidiaGPUInfo(name="RTX4090", memory=24, compute_capability=(8, 9)), 161 | NvidiaGPUInfo(name="RTX6000", memory=24, compute_capability=(7, 5)), 162 | NvidiaGPUInfo(name="RTX2000Ada", memory=16, compute_capability=(8, 9)), 163 | NvidiaGPUInfo(name="RTX4000Ada", memory=20, compute_capability=(8, 9)), 164 | NvidiaGPUInfo(name="RTX6000Ada", memory=48, compute_capability=(8, 9)), 165 | NvidiaGPUInfo(name="T4", memory=16, compute_capability=(7, 5)), 166 | NvidiaGPUInfo(name="V100", memory=16, compute_capability=(7, 0)), 167 | NvidiaGPUInfo(name="V100", memory=32, compute_capability=(7, 0)), 168 | NvidiaGPUInfo(name="GH200", memory=96, compute_capability=(9, 0)), 169 | ] 170 | 171 | KNOWN_AMD_GPUS: list[AMDGPUInfo] = [ 172 | AMDGPUInfo(name="MI100", memory=32, architecture=AMDArchitecture.CDNA), 173 | AMDGPUInfo(name="MI210", memory=64, architecture=AMDArchitecture.CDNA2), 174 | AMDGPUInfo(name="MI250", memory=128, architecture=AMDArchitecture.CDNA2), 175 | AMDGPUInfo(name="MI250X", memory=128, architecture=AMDArchitecture.CDNA2), 176 | AMDGPUInfo(name="MI300A", memory=128, architecture=AMDArchitecture.CDNA3), 177 | AMDGPUInfo(name="MI300X", memory=192, architecture=AMDArchitecture.CDNA3), 178 | AMDGPUInfo(name="MI308X", memory=128, architecture=AMDArchitecture.CDNA3), 179 | AMDGPUInfo(name="MI325X", memory=288, architecture=AMDArchitecture.CDNA3), 180 | ] 181 | 182 | KNOWN_TPUS: list[TPUInfo] = [TPUInfo(name=version, 
memory=0) for version in _TPU_VERSIONS] 183 | 184 | KNOWN_INTEL_ACCELERATORS: list[IntelAcceleratorInfo] = [ 185 | IntelAcceleratorInfo(name="Gaudi", memory=32), # HL-205 186 | IntelAcceleratorInfo(name="Gaudi2", memory=96), # HL-225 187 | IntelAcceleratorInfo(name="Gaudi3", memory=128), 188 | ] 189 | 190 | KNOWN_TENSTORRENT_ACCELERATORS: list[TenstorrentAcceleratorInfo] = [ 191 | TenstorrentAcceleratorInfo(name="n150", memory=12), 192 | TenstorrentAcceleratorInfo(name="n300", memory=24), 193 | ] 194 | 195 | KNOWN_ACCELERATORS: list[ 196 | Union[NvidiaGPUInfo, AMDGPUInfo, TPUInfo, IntelAcceleratorInfo, TenstorrentAcceleratorInfo] 197 | ] = ( 198 | KNOWN_NVIDIA_GPUS 199 | + KNOWN_AMD_GPUS 200 | + KNOWN_TPUS 201 | + KNOWN_INTEL_ACCELERATORS 202 | + KNOWN_TENSTORRENT_ACCELERATORS 203 | ) 204 | -------------------------------------------------------------------------------- /src/gpuhunt/_internal/default.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import importlib 3 | import logging 4 | from typing import Callable, TypeVar 5 | 6 | from typing_extensions import Concatenate, ParamSpec 7 | 8 | from gpuhunt._internal.catalog import Catalog 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | @functools.lru_cache 14 | def default_catalog() -> Catalog: 15 | """ 16 | Returns: 17 | the latest catalog with all available providers loaded 18 | """ 19 | catalog = Catalog() 20 | catalog.load() 21 | for module, provider in [ 22 | ("gpuhunt.providers.tensordock", "TensorDockProvider"), 23 | ("gpuhunt.providers.vastai", "VastAIProvider"), 24 | ("gpuhunt.providers.cudo", "CudoProvider"), 25 | ("gpuhunt.providers.vultr", "VultrProvider"), 26 | ]: 27 | try: 28 | module = importlib.import_module(module) 29 | provider = getattr(module, provider)() 30 | catalog.add_provider(provider) 31 | except ImportError: 32 | logger.warning("Failed to import provider %s", provider) 33 | return catalog 34 | 35 | 36 | P = 
ParamSpec("P") 37 | R = TypeVar("R") 38 | Method = Callable[P, R] 39 | CatalogMethod = Callable[Concatenate[Catalog, P], R] 40 | 41 | 42 | def with_signature(method: CatalogMethod) -> Callable[[Method], Method]: 43 | """ 44 | Returns: 45 | decorator to add the signature of the Catalog method to the decorated method 46 | """ 47 | 48 | def decorator(func: Method) -> Method: 49 | @functools.wraps(func) 50 | def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: 51 | return func(*args, **kwargs) 52 | 53 | return wrapper 54 | 55 | return decorator 56 | 57 | 58 | @with_signature(Catalog.query) 59 | def query(*args: P.args, **kwargs: P.kwargs) -> R: 60 | """ 61 | Query the `default_catalog`. 62 | See `Catalog.query` for more details on parameters 63 | 64 | Returns: 65 | (List[CatalogItem]): the result of the query 66 | """ 67 | return default_catalog().query(*args, **kwargs) 68 | -------------------------------------------------------------------------------- /src/gpuhunt/_internal/models.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from collections.abc import Container 3 | from dataclasses import asdict, dataclass, field, fields 4 | from typing import ( 5 | ClassVar, 6 | Optional, 7 | Union, 8 | ) 9 | 10 | from gpuhunt._internal.utils import empty_as_none 11 | 12 | 13 | def bool_loader(x: Union[bool, str]) -> bool: 14 | if isinstance(x, bool): 15 | return x 16 | return x.lower() == "true" 17 | 18 | 19 | class AMDArchitecture(enum.Enum): 20 | CDNA = "CDNA" 21 | CDNA2 = "CDNA2" 22 | CDNA3 = "CDNA3" 23 | 24 | @classmethod 25 | def cast(cls, value: Union["AMDArchitecture", str]) -> "AMDArchitecture": 26 | if isinstance(value, AMDArchitecture): 27 | return value 28 | return cls(value.upper()) 29 | 30 | 31 | class AcceleratorVendor(str, enum.Enum): 32 | NVIDIA = "nvidia" 33 | AMD = "amd" 34 | GOOGLE = "google" 35 | INTEL = "intel" 36 | TENSTORRENT = "tenstorrent" 37 | 38 | @classmethod 39 | def cast(cls, value: 
Union["AcceleratorVendor", str]) -> "AcceleratorVendor": 40 | if isinstance(value, AcceleratorVendor): 41 | return value 42 | return cls(value.lower()) 43 | 44 | 45 | class CPUArchitecture(str, enum.Enum): 46 | X86 = "x86" # x86-64 extension support implied 47 | ARM = "arm" # AArch64 (ARM64) execution state support implied 48 | 49 | @classmethod 50 | def cast(cls, value: Union["CPUArchitecture", str]) -> "CPUArchitecture": 51 | if isinstance(value, CPUArchitecture): 52 | return value 53 | return cls(value.lower()) 54 | 55 | 56 | @dataclass 57 | class RawCatalogItem: 58 | """ 59 | An item stored in the catalog. 60 | See `CatalogItem` for field descriptions. 61 | """ 62 | 63 | instance_name: Optional[str] 64 | location: Optional[str] 65 | price: Optional[float] 66 | cpu: Optional[int] 67 | memory: Optional[float] 68 | gpu_count: Optional[int] 69 | gpu_name: Optional[str] 70 | gpu_memory: Optional[float] 71 | spot: Optional[bool] 72 | disk_size: Optional[float] 73 | gpu_vendor: Optional[str] = None 74 | flags: list[str] = field(default_factory=list) 75 | cpu_arch: Optional[str] = None 76 | 77 | def __post_init__(self) -> None: 78 | self._process_gpu_vendor() 79 | self._process_cpu_arch() 80 | 81 | def _process_gpu_vendor(self) -> None: 82 | # This heuristic will be required indefinitely since we support historical catalogs. 83 | is_tpu = False 84 | gpu_name = self.gpu_name 85 | if gpu_name and gpu_name.startswith("tpu-"): 86 | is_tpu = True 87 | self.gpu_name = gpu_name[4:] 88 | gpu_vendor = self.gpu_vendor 89 | if gpu_vendor is None: 90 | if not self.gpu_count: 91 | # None or 0 92 | return 93 | if is_tpu: 94 | self.gpu_vendor = AcceleratorVendor.GOOGLE.value 95 | else: 96 | self.gpu_vendor = AcceleratorVendor.NVIDIA.value 97 | elif isinstance(gpu_vendor, AcceleratorVendor): 98 | self.gpu_vendor = gpu_vendor.value 99 | 100 | def _process_cpu_arch(self) -> None: 101 | # This heuristic will be required indefinitely since we support historical catalogs. 
102 | cpu_arch = self.cpu_arch 103 | if cpu_arch is None: 104 | self.cpu_arch = CPUArchitecture.X86.value 105 | elif isinstance(cpu_arch, CPUArchitecture): 106 | self.cpu_arch = cpu_arch.value 107 | 108 | @staticmethod 109 | def from_dict(v: dict) -> "RawCatalogItem": 110 | return RawCatalogItem( 111 | instance_name=empty_as_none(v.get("instance_name")), 112 | location=empty_as_none(v.get("location")), 113 | price=empty_as_none(v.get("price"), loader=float), 114 | cpu_arch=empty_as_none(v.get("cpu_arch")), 115 | cpu=empty_as_none(v.get("cpu"), loader=int), 116 | memory=empty_as_none(v.get("memory"), loader=float), 117 | gpu_vendor=empty_as_none(v.get("gpu_vendor")), 118 | gpu_count=empty_as_none(v.get("gpu_count"), loader=int), 119 | gpu_name=empty_as_none(v.get("gpu_name")), 120 | gpu_memory=empty_as_none(v.get("gpu_memory"), loader=float), 121 | spot=empty_as_none(v.get("spot"), loader=bool_loader), 122 | disk_size=empty_as_none(v.get("disk_size"), loader=float), 123 | flags=v.get("flags", "").split(), 124 | ) 125 | 126 | def dict(self) -> dict[str, Union[str, int, float, bool, None]]: 127 | return { 128 | **asdict(self), 129 | "flags": " ".join(self.flags), 130 | } 131 | 132 | 133 | @dataclass 134 | class CatalogItem: 135 | """ 136 | An item returned by `Catalog.query`. 137 | Attributes: 138 | instance_name: name of the instance 139 | location: region or zone 140 | price: $ per hour 141 | cpu_arch: CPU instruction set architecture 142 | cpu: number of CPUs 143 | memory: amount of RAM in GB 144 | gpu_vendor: GPU/accelerator vendor 145 | gpu_count: number of GPUs 146 | gpu_name: name of the GPU 147 | gpu_memory: amount of GPU VRAM in GB for each GPU 148 | spot: whether the instance is a spot instance 149 | provider: name of the provider 150 | disk_size: size of disk in GB 151 | flags: list of flags. If a catalog item breaks existing dstack versions, 152 | add a flag to hide the item from those versions. 
Newer dstack versions 153 | will have to request this flag explicitly to get the catalog item. 154 | If you are adding a new provider, leave the flags empty. 155 | Flag names should be in kebab-case. 156 | """ 157 | 158 | instance_name: str 159 | location: str 160 | price: float 161 | cpu: int 162 | memory: float 163 | gpu_count: int 164 | gpu_name: Optional[str] 165 | gpu_memory: Optional[float] 166 | spot: bool 167 | disk_size: Optional[float] 168 | provider: str 169 | gpu_vendor: Optional[AcceleratorVendor] = None 170 | flags: list[str] = field(default_factory=list) 171 | cpu_arch: Optional[CPUArchitecture] = None 172 | 173 | def __post_init__(self) -> None: 174 | self._process_gpu_vendor() 175 | self._process_cpu_arch() 176 | 177 | def _process_gpu_vendor(self) -> None: 178 | # This heuristic is only required until we update all providers to always set the vendor. 179 | gpu_vendor = self.gpu_vendor 180 | if gpu_vendor is None: 181 | if not self.gpu_count: 182 | # None or 0 183 | return 184 | # GCPProvider already sets gpu_vendor, and all other providers only support Nvidia 185 | self.gpu_vendor = AcceleratorVendor.NVIDIA 186 | else: 187 | # This cast to the enum is always required since RawCatalogItem.gpu_vendor 188 | # is a string field (for (de)serialization purposes). 189 | self.gpu_vendor = AcceleratorVendor.cast(gpu_vendor) 190 | 191 | def _process_cpu_arch(self) -> None: 192 | # This heuristic is only required until we update all providers to always set the arch. 193 | cpu_arch = self.cpu_arch 194 | if cpu_arch is None: 195 | self.cpu_arch = CPUArchitecture.X86 196 | else: 197 | self.cpu_arch = CPUArchitecture.cast(cpu_arch) 198 | 199 | @staticmethod 200 | def from_dict(v: dict, *, provider: Optional[str] = None) -> "CatalogItem": 201 | return CatalogItem(provider=provider, **asdict(RawCatalogItem.from_dict(v))) 202 | 203 | 204 | @dataclass 205 | class QueryFilter: 206 | """ 207 | Attributes: 208 | provider: name of the provider to filter by. 
If not specified, all providers will be used 209 | cpu_arch: CPU architecture. If not specified, all architectures will be used 210 | min_cpu: minimum number of CPUs 211 | max_cpu: maximum number of CPUs 212 | min_memory: minimum amount of RAM in GB 213 | max_memory: maximum amount of RAM in GB 214 | min_gpu_count: minimum number of GPUs 215 | max_gpu_count: maximum number of GPUs 216 | gpu_vendor: accelerator vendor to filter by. If not specified, all vendors will be used 217 | gpu_name: name of the GPU to filter by. If not specified, all GPUs will be used 218 | min_gpu_memory: minimum amount of GPU VRAM in GB for each GPU 219 | max_gpu_memory: maximum amount of GPU VRAM in GB for each GPU 220 | min_total_gpu_memory: minimum amount of GPU VRAM in GB for all GPUs combined 221 | max_total_gpu_memory: maximum amount of GPU VRAM in GB for all GPUs combined 222 | min_disk_size: minimum disk size in GB 223 | max_disk_size: maximum disk size in GB 224 | min_price: minimum price per hour in USD 225 | max_price: maximum price per hour in USD 226 | min_compute_capability: minimum compute capability of the GPU 227 | max_compute_capability: maximum compute capability of the GPU 228 | spot: if `False`, only ondemand offers will be returned. If `True`, only spot offers will be returned 229 | allowed_flags: only offers with all flags allowed will be returned. 
`None` allows all flags 230 | """ 231 | 232 | provider: Optional[list[str]] = None # strings can have mixed case 233 | cpu_arch: Optional[CPUArchitecture] = None 234 | min_cpu: Optional[int] = None 235 | max_cpu: Optional[int] = None 236 | min_memory: Optional[float] = None 237 | max_memory: Optional[float] = None 238 | min_gpu_count: Optional[int] = None 239 | max_gpu_count: Optional[int] = None 240 | gpu_vendor: Optional[AcceleratorVendor] = None 241 | gpu_name: Optional[list[str]] = None # strings can have mixed case 242 | min_gpu_memory: Optional[float] = None 243 | max_gpu_memory: Optional[float] = None 244 | min_total_gpu_memory: Optional[float] = None 245 | max_total_gpu_memory: Optional[float] = None 246 | min_disk_size: Optional[int] = None 247 | max_disk_size: Optional[int] = None 248 | min_price: Optional[float] = None 249 | max_price: Optional[float] = None 250 | min_compute_capability: Optional[tuple[int, int]] = None 251 | max_compute_capability: Optional[tuple[int, int]] = None 252 | spot: Optional[bool] = None 253 | allowed_flags: Optional[Container[str]] = None 254 | 255 | def __repr__(self) -> str: 256 | """ 257 | >>> QueryFilter() 258 | QueryFilter() 259 | >>> QueryFilter(min_cpu=4) 260 | QueryFilter(min_cpu=4) 261 | >>> QueryFilter(max_price=1.2, min_cpu=4) 262 | QueryFilter(min_cpu=4, max_price=1.2) 263 | """ 264 | kv = ", ".join( 265 | f"{f.name}={value}" 266 | for f in fields(self) 267 | if (value := getattr(self, f.name)) is not None 268 | ) 269 | return f"QueryFilter({kv})" 270 | 271 | 272 | @dataclass 273 | class AcceleratorInfo: 274 | vendor: ClassVar[AcceleratorVendor] 275 | name: str 276 | memory: int 277 | 278 | 279 | @dataclass 280 | class NvidiaGPUInfo(AcceleratorInfo): 281 | vendor = AcceleratorVendor.NVIDIA 282 | compute_capability: tuple[int, int] 283 | 284 | 285 | @dataclass 286 | class AMDGPUInfo(AcceleratorInfo): 287 | vendor = AcceleratorVendor.AMD 288 | architecture: AMDArchitecture 289 | 290 | 291 | @dataclass 292 | class 
TPUInfo(AcceleratorInfo): 293 | vendor = AcceleratorVendor.GOOGLE 294 | 295 | 296 | @dataclass 297 | class IntelAcceleratorInfo(AcceleratorInfo): 298 | vendor = AcceleratorVendor.INTEL 299 | 300 | 301 | @dataclass 302 | class TenstorrentAcceleratorInfo(AcceleratorInfo): 303 | vendor = AcceleratorVendor.TENSTORRENT 304 | -------------------------------------------------------------------------------- /src/gpuhunt/_internal/storage.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import dataclasses 3 | from typing import TypeVar 4 | 5 | from gpuhunt._internal.models import RawCatalogItem 6 | 7 | CATALOG_V1_FIELDS = [ 8 | "instance_name", 9 | "location", 10 | "price", 11 | "cpu", 12 | "memory", 13 | "gpu_count", 14 | "gpu_name", 15 | "gpu_memory", 16 | "spot", 17 | "disk_size", 18 | "gpu_vendor", 19 | ] 20 | T = TypeVar("T", bound=RawCatalogItem) 21 | 22 | 23 | def dump(items: list[T], path: str, *, cls: type[T] = RawCatalogItem): 24 | with open(path, "w", newline="") as f: 25 | writer = csv.DictWriter(f, fieldnames=[field.name for field in dataclasses.fields(cls)]) 26 | writer.writeheader() 27 | for item in items: 28 | writer.writerow(item.dict()) 29 | 30 | 31 | def convert_catalog_v2_to_v1(path_v2: str, path_v1: str) -> None: 32 | with open(path_v2) as f_v2, open(path_v1, "w") as f_v1: 33 | reader = csv.DictReader(f_v2) 34 | writer = csv.DictWriter(f_v1, fieldnames=CATALOG_V1_FIELDS, extrasaction="ignore") 35 | writer.writeheader() 36 | for row in reader: 37 | if not row.get("flags"): 38 | writer.writerow(row) 39 | -------------------------------------------------------------------------------- /src/gpuhunt/_internal/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from typing import Callable, Optional, Union 4 | 5 | 6 | def configure_logging() -> None: 7 | logging.basicConfig( 8 | level=logging.INFO, 9 | stream=sys.stdout, 10 | 
format="%(asctime)s %(levelname)s %(message)s", 11 | ) 12 | 13 | 14 | def empty_as_none(value: Optional[str], loader: Optional[Callable] = None): 15 | if value is None or value == "": 16 | return None 17 | if loader is not None: 18 | return loader(value) 19 | return value 20 | 21 | 22 | def parse_compute_capability( 23 | value: Optional[Union[str, tuple[int, int]]], 24 | ) -> Optional[tuple[int, int]]: 25 | if isinstance(value, str): 26 | major, minor = value.split(".") 27 | return int(major), int(minor) 28 | return value 29 | 30 | 31 | def to_camel_case(snake_case: str) -> str: 32 | words = snake_case.split("_") 33 | words = list(filter(None, words)) 34 | words[1:] = [word[:1].upper() + word[1:] for word in words[1:]] 35 | return "".join(words) 36 | -------------------------------------------------------------------------------- /src/gpuhunt/providers/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional 3 | 4 | from gpuhunt._internal.models import QueryFilter, RawCatalogItem 5 | 6 | 7 | class AbstractProvider(ABC): 8 | """ 9 | Abstract class for cloud provider implementations. 10 | 11 | Attributes: 12 | NAME: (class variable) The name of the provider. 13 | """ 14 | 15 | NAME: str = "abstract" # Override in subclasses 16 | 17 | @abstractmethod 18 | def get( 19 | self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True 20 | ) -> list[RawCatalogItem]: 21 | """ 22 | Return a list of available instance offers. Offers should be ordered by priority. In most 23 | cases - by price, ascending. 24 | 25 | Args: 26 | query_filter: Set of filters requested by the user. Only used with online providers. 27 | Filters are safe to ignore, as they are also enforced by `gpuhunt` after calling 28 | `get`. However, they can be used to reduce the number or size of API requests if 29 | the provider's API supports filtering by GPU, RAM, region, etc. 
30 | balance_resources: Whether the instance resources (CPU, RAM, disk) should be 31 | adjusted to better match the GPU. Only used with online providers. Only relevant 32 | to cloud providers that allow configuring instance CPU, RAM, and disk. 33 | """ 34 | 35 | pass 36 | 37 | @classmethod 38 | def filter(cls, offers: list[RawCatalogItem]) -> list[RawCatalogItem]: 39 | """ 40 | Return a subset of offers that should be stored in the catalog. 41 | 42 | Only used with offline providers. Only implement this method if there are reasons to omit 43 | some offers from the catalog. 44 | """ 45 | 46 | return offers 47 | -------------------------------------------------------------------------------- /src/gpuhunt/providers/aws.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import csv 3 | import datetime 4 | import logging 5 | import os 6 | import re 7 | import tempfile 8 | from collections import defaultdict 9 | from collections.abc import Iterable 10 | from concurrent.futures import ThreadPoolExecutor, as_completed 11 | from typing import Optional 12 | 13 | import boto3 14 | import requests 15 | from botocore.exceptions import ClientError, EndpointConnectionError 16 | 17 | from gpuhunt._internal.models import QueryFilter, RawCatalogItem 18 | from gpuhunt.providers import AbstractProvider 19 | 20 | logger = logging.getLogger(__name__) 21 | ec2_pricing_url = ( 22 | "https://pricing.us-east-1.amazonaws.com/offers/v1.0/aws/AmazonEC2/current/index.csv" 23 | ) 24 | disclaimer_rows_skip = 5 25 | # https://aws.amazon.com/ec2/previous-generation/ 26 | previous_generation_families = [ 27 | "a1.", 28 | "c1.", 29 | "c3.", 30 | "c4.", 31 | "g2.", 32 | "g3.", 33 | "g3s.", 34 | "i2.", 35 | "m1.", 36 | "m2.", 37 | "m3.", 38 | "p2.", 39 | "r3.", 40 | "r4.", 41 | "t1.", 42 | "cr1.", 43 | "hs1.", 44 | ] 45 | pricing_filters = { 46 | "TermType": ["OnDemand"], 47 | "Tenancy": ["Shared"], 48 | "Operating System": ["Linux"], 49 | 
"CapacityStatus": ["Used"], 50 | "Unit": ["Hrs"], 51 | "Currency": ["USD"], 52 | "Pre Installed S/W": ["", "NA"], 53 | "MarketOption": ["OnDemand"], 54 | } 55 | describe_instances_limit = 100 56 | 57 | 58 | class AWSProvider(AbstractProvider): 59 | """ 60 | AWSProvider parses Bulk API index file for AmazonEC2 in all regions and fills missing GPU details 61 | 62 | Required IAM permissions: 63 | * `ec2:DescribeInstanceTypes` 64 | """ 65 | 66 | NAME = "aws" 67 | 68 | def __init__(self, cache_path: Optional[str] = None): 69 | if cache_path: 70 | self.cache_path = cache_path 71 | else: 72 | self.temp_dir = tempfile.TemporaryDirectory() 73 | self.cache_path = self.temp_dir.name + "/index.csv" 74 | # todo aws creds 75 | self.preview_gpus = { 76 | "p4de.24xlarge": ("A100", 80.0), 77 | } 78 | 79 | def get( 80 | self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True 81 | ) -> list[RawCatalogItem]: 82 | if not os.path.exists(self.cache_path): 83 | logger.info("Downloading EC2 prices to %s", self.cache_path) 84 | with requests.get(ec2_pricing_url, stream=True, timeout=20) as r: 85 | r.raise_for_status() 86 | with open(self.cache_path, "wb") as f: 87 | for chunk in r.iter_content(chunk_size=8192): 88 | f.write(chunk) 89 | 90 | offers = [] 91 | with open(self.cache_path, newline="") as f: 92 | for _ in range(disclaimer_rows_skip): 93 | f.readline() 94 | reader: Iterable[dict[str, str]] = csv.DictReader(f) 95 | for row in reader: 96 | if self.skip(row): 97 | continue 98 | offer = RawCatalogItem( 99 | instance_name=row["Instance Type"], 100 | location=row["Region Code"], 101 | price=float(row["PricePerUnit"]), 102 | cpu=int(row["vCPU"]), 103 | memory=parse_memory(row["Memory"]), 104 | gpu_vendor=None, 105 | gpu_count=parse_optional_count(row["GPU"]), 106 | spot=False, 107 | gpu_name=None, 108 | gpu_memory=None, 109 | disk_size=None, 110 | ) 111 | offers.append(offer) 112 | self.fill_gpu_details(offers) 113 | offers = self.add_spots(offers) 114 | return 
sorted(offers, key=lambda i: i.price) 115 | 116 | def skip(self, row: dict[str, str]) -> bool: 117 | if any(row["Instance Type"].startswith(family) for family in previous_generation_families): 118 | return True 119 | for key, values in pricing_filters.items(): 120 | if row[key] not in values: 121 | return True 122 | return False 123 | 124 | def fill_gpu_details(self, offers: list[RawCatalogItem]): 125 | regions = defaultdict(list) 126 | for offer in offers: 127 | if offer.gpu_count > 0 and offer.instance_name not in self.preview_gpus: 128 | regions[offer.location].append(offer.instance_name) 129 | 130 | gpus = copy.deepcopy(self.preview_gpus) 131 | while regions: 132 | region = max(regions, key=lambda r: len(regions[r])) 133 | instance_types = regions.pop(region) 134 | 135 | client = boto3.client("ec2", region_name=region) 136 | paginator = client.get_paginator("describe_instance_types") 137 | for offset in range(0, len(instance_types), describe_instances_limit): 138 | logger.info("Fetching GPU details for %s (offset=%s)", region, offset) 139 | pages = paginator.paginate( 140 | InstanceTypes=instance_types[offset : offset + describe_instances_limit] 141 | ) 142 | for page in pages: 143 | for i in page["InstanceTypes"]: 144 | gpu = i["GpuInfo"]["Gpus"][0] 145 | gpus[i["InstanceType"]] = ( 146 | gpu["Name"], 147 | _get_gpu_memory_gib(gpu["Name"], gpu["MemoryInfo"]["SizeInMiB"]), 148 | ) 149 | 150 | regions = { 151 | region: left 152 | for region, names in regions.items() 153 | if (left := [i for i in names if i not in instance_types]) 154 | } 155 | 156 | for offer in offers: 157 | if offer.gpu_count > 0: 158 | offer.gpu_name, offer.gpu_memory = gpus[offer.instance_name] 159 | 160 | def _add_spots_worker( 161 | self, region: str, instance_types: set[str] 162 | ) -> dict[tuple[str, str], float]: 163 | spot_prices = dict() 164 | logger.info("Fetching spot prices for %s", region) 165 | try: 166 | client = boto3.client("ec2", region_name=region) # todo creds 167 | pages = 
client.get_paginator("describe_spot_price_history").paginate( 168 | Filters=[ 169 | { 170 | "Name": "product-description", 171 | "Values": ["Linux/UNIX"], 172 | } 173 | ], 174 | InstanceTypes=list(instance_types), 175 | StartTime=datetime.datetime.utcnow(), 176 | ) 177 | 178 | instance_prices = defaultdict(list) 179 | for page in pages: 180 | for item in page["SpotPriceHistory"]: 181 | instance_prices[item["InstanceType"]].append(float(item["SpotPrice"])) 182 | for ( 183 | instance_type, 184 | zone_prices, 185 | ) in instance_prices.items(): # reduce zone prices to a single value 186 | spot_prices[(instance_type, region)] = min(zone_prices) 187 | except (ClientError, EndpointConnectionError): 188 | return {} 189 | return spot_prices 190 | 191 | def add_spots(self, offers: list[RawCatalogItem]) -> list[RawCatalogItem]: 192 | region_instances = defaultdict(set) 193 | for offer in offers: 194 | region_instances[offer.location].add(offer.instance_name) 195 | 196 | spot_prices = dict() 197 | with ThreadPoolExecutor(max_workers=8) as executor: 198 | future_to_region = {} 199 | for region, instance_types in region_instances.items(): 200 | future = executor.submit(self._add_spots_worker, region, instance_types) 201 | future_to_region[future] = region 202 | for future in as_completed(future_to_region): 203 | spot_prices.update(future.result()) 204 | 205 | spot_offers = [] 206 | for offer in offers: 207 | if (price := spot_prices.get((offer.instance_name, offer.location))) is None: 208 | continue 209 | spot_offer = copy.deepcopy(offer) 210 | spot_offer.spot = True 211 | spot_offer.price = price 212 | spot_offers.append(spot_offer) 213 | return offers + spot_offers 214 | 215 | @classmethod 216 | def filter(cls, offers: list[RawCatalogItem]) -> list[RawCatalogItem]: 217 | return [ 218 | i 219 | for i in offers 220 | if any( 221 | i.instance_name.startswith(family) 222 | for family in [ 223 | "m7i.", 224 | "c7i.", 225 | "r7i.", 226 | "t3.", 227 | "t2.small", 228 | "c5.", 229 | 
"m5.", 230 | "p5.", 231 | "p5e.", 232 | "p4d.", 233 | "p4de.", 234 | "p3.", 235 | "g6.", 236 | "g6e.", 237 | "gr6.", 238 | "g5.", 239 | "g4dn.", 240 | ] 241 | ) 242 | ] 243 | 244 | 245 | def _get_gpu_memory_gib(gpu_name: str, reported_memory_mib: int) -> float: 246 | """ 247 | Fixes L4 memory size misreported by AWS API 248 | """ 249 | 250 | if gpu_name != "L4": 251 | return reported_memory_mib / 1024 252 | 253 | if reported_memory_mib not in (22888, 91553, 183105): 254 | logger.warning( 255 | "The L4 memory size reported by AWS changed. " 256 | "Please check that it is now correct and remove the hardcoded size if it is." 257 | ) 258 | return 24 259 | 260 | 261 | def parse_memory(s: str) -> float: 262 | r = re.match(r"^([0-9.]+) GiB$", s) 263 | return float(r.group(1)) 264 | 265 | 266 | def parse_optional_count(s: str) -> int: 267 | if not s: 268 | return 0 269 | return int(s) 270 | -------------------------------------------------------------------------------- /src/gpuhunt/providers/azure.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | import os 5 | import re 6 | import time 7 | from collections import namedtuple 8 | from collections.abc import Iterable 9 | from queue import Queue 10 | from threading import Thread 11 | from typing import Optional 12 | 13 | import requests 14 | import requests.adapters 15 | from azure.core.credentials import TokenCredential 16 | from azure.identity import DefaultAzureCredential 17 | from azure.mgmt.compute import ComputeManagementClient 18 | 19 | from gpuhunt._internal.models import QueryFilter, RawCatalogItem 20 | from gpuhunt.providers import AbstractProvider 21 | 22 | logger = logging.getLogger(__name__) 23 | prices_url = "https://prices.azure.com/api/retail/prices" 24 | retail_prices_page_size = 1000 25 | prices_version = "2023-01-01-preview" 26 | prices_filters = [ 27 | "serviceName eq 'Virtual Machines'", 28 | "priceType eq 
'Consumption'", 29 | "contains(productName, 'Windows') eq false", 30 | "contains(productName, 'Dedicated') eq false", 31 | "contains(meterName, 'Low Priority') eq false", # retires in 2025 32 | ] 33 | VMSeries = namedtuple("VMSeries", ["pattern", "gpu_name", "gpu_memory"]) 34 | gpu_vm_series = [ 35 | VMSeries(r"NC(\d+)ads_A100_v4", "A100", 80.0), # NC A100 v4-series [A100 80GB] 36 | VMSeries(r"NC(\d+)ads_A10_v4", "A10", 24.0), # NC A10 v4-series [A10] 37 | VMSeries(r"NC(\d+)as_T4_v3", "T4", 16.0), # NCasT4_v3-series [T4] 38 | VMSeries(r"NC(\d+)r?s_v3", "V100", 16.0), # NCv3-series [V100 16GB] 39 | VMSeries(r"ND(\d+)amsr_A100_v4", "A100", 80.0), # NDm A100 v4-series [8xA100 80GB] 40 | VMSeries(r"ND(\d+)asr_v4", "A100", 40.0), # ND A100 v4-series [8xA100 40GB] 41 | VMSeries(r"ND(\d+)rs_v2", "V100", 32.0), # NDv2-series [8xV100 32GB] 42 | VMSeries(r"NG(\d+)adm?s_V620_v1", "V620", None), # NGads V620-series [V620] # todo 43 | VMSeries(r"NV(\d+)adm?s_A10_v5", "A10", 24.0), # NVadsA10 v5-series [A10] 44 | VMSeries(r"NV(\d+)as_v4", "MI25", None), # NVv4-series [MI25] # todo 45 | VMSeries(r"NV(\d+)s_v3", "M60", None), # NVv3-series [M60] # todo 46 | ] 47 | # https://learn.microsoft.com/en-us/azure/virtual-machines/sizes-previous-gen 48 | retired_vm_series = [ 49 | r"Basic_A(\d+)", 50 | r"Standard_A(\d+)", 51 | r"Standard_D(\d+)", 52 | r"Standard_DC(\d+)s", 53 | r"Standard_DS(\d+)", 54 | r"Standard_F(\d+)", 55 | r"Standard_F(\d+)s", 56 | r"Standard_G(\d+)", 57 | r"Standard_GS(\d+)", 58 | r"Standard_L(\d+)s", 59 | r"Standard_NC(\d+)r?", 60 | r"Standard_NC(\d+)r?s_v2", 61 | r"Standard_ND(\d+)r?s", 62 | r"Standard_NV(\d+)", 63 | r"Standard_NV(\d+)s_v2", 64 | ] 65 | 66 | 67 | class AzureProvider(AbstractProvider): 68 | NAME = "azure" 69 | 70 | def __init__( 71 | self, 72 | subscription_id: str, 73 | credential: Optional[TokenCredential] = None, 74 | cache_dir: Optional[str] = None, 75 | ): 76 | self.cache_dir = cache_dir 77 | self.client = ComputeManagementClient( 78 | 
credential=credential or DefaultAzureCredential(), 79 | subscription_id=subscription_id, 80 | ) 81 | 82 | def get_pages(self, threads: int = 8) -> Iterable[list[dict]]: 83 | q = Queue() 84 | workers = [ 85 | Thread(target=self._get_pages_worker, args=(q, threads, i), daemon=True) 86 | for i in range(threads) 87 | ] 88 | for worker in workers: 89 | worker.start() 90 | 91 | exited = 0 92 | while exited < threads: 93 | page = q.get() 94 | if page is None: 95 | exited += 1 96 | else: 97 | yield page 98 | q.task_done() 99 | 100 | def _get_pages_worker(self, q: Queue, stride: int, worker_id: int): 101 | page_id = worker_id 102 | session = requests.Session() 103 | session.mount("https://", requests.adapters.HTTPAdapter(max_retries=3)) 104 | try: 105 | while True: 106 | cached_page = None 107 | if self.cache_dir is not None: 108 | cached_page = os.path.join(self.cache_dir, f"{page_id:04}.json") 109 | if cached_page is not None and os.path.exists(cached_page): 110 | with open(cached_page) as f: 111 | data = json.load(f) 112 | else: 113 | logger.info("Worker %s fetches pricing page %s", worker_id, page_id) 114 | res = session.get( 115 | prices_url, 116 | params={ 117 | "$filter": " and ".join(prices_filters), 118 | "$skip": page_id * retail_prices_page_size, 119 | }, 120 | ) 121 | if res.status_code == 429: 122 | logger.warning("Worker %s got 429: sleep 3 & retry", worker_id) 123 | time.sleep(3) 124 | continue 125 | res.raise_for_status() 126 | if cached_page is not None: 127 | with open(cached_page, "w") as f: 128 | f.write(res.text) 129 | data = res.json() 130 | if not data["Items"]: 131 | logger.info("Worker %s exited", worker_id) 132 | return 133 | q.put(data["Items"]) 134 | page_id += stride 135 | except Exception as e: 136 | logger.exception("Worker %s failed: %s", worker_id, e) 137 | finally: 138 | q.put(None) 139 | 140 | def get( 141 | self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True 142 | ) -> list[RawCatalogItem]: 143 | offers = [] 
144 | for page in self.get_pages(): 145 | for item in page: 146 | if is_retired(item["armSkuName"]): 147 | continue 148 | if not item["armSkuName"]: 149 | continue 150 | price = float(item["retailPrice"]) 151 | if math.isclose(price, 0): 152 | continue 153 | offer = RawCatalogItem( 154 | instance_name=item["armSkuName"], 155 | location=item["armRegionName"], 156 | price=price, 157 | spot="Spot" in item["meterName"], 158 | cpu=None, 159 | memory=None, 160 | gpu_vendor=None, 161 | gpu_count=None, 162 | gpu_name=None, 163 | gpu_memory=None, 164 | disk_size=None, 165 | ) 166 | offers.append(offer) 167 | offers = self.fill_details(offers) 168 | return sorted(offers, key=lambda i: i.price) 169 | 170 | def fill_details(self, offers: list[RawCatalogItem]) -> list[RawCatalogItem]: 171 | logger.info("Fetching instance details") 172 | instances = {} 173 | resources = self.client.resource_skus.list() 174 | for resource in resources: 175 | if resource.resource_type != "virtualMachines": 176 | continue 177 | if is_retired(resource.name): 178 | continue 179 | capabilities = {pair.name: pair.value for pair in resource.capabilities} 180 | cpu = capabilities.get("vCPUs") 181 | memory = capabilities.get("MemoryGB") 182 | if not cpu: 183 | logger.warning("Instance CPU is missing: %s", resource.name) 184 | continue 185 | if not memory: 186 | logger.warning("Instance memory is missing: %s", resource.name) 187 | continue 188 | gpu_count, gpu_name, gpu_memory = 0, None, None 189 | if "GPUs" in capabilities: 190 | gpu_count = int(capabilities["GPUs"]) 191 | gpu_name, gpu_memory = get_gpu_name_memory(resource.name) 192 | if gpu_name is None and gpu_count: 193 | logger.warning("Can't parse VM name: %s", resource.name) 194 | continue 195 | instances[resource.name] = RawCatalogItem( 196 | instance_name=resource.name, 197 | cpu=cpu, 198 | memory=float(memory), 199 | gpu_vendor=None, 200 | gpu_count=gpu_count, 201 | gpu_name=gpu_name, 202 | gpu_memory=gpu_memory, 203 | location=None, 204 | 
price=None, 205 | spot=None, 206 | disk_size=None, 207 | ) 208 | with_details = [] 209 | for offer in offers: 210 | if (resources := instances.get(offer.instance_name)) is None: 211 | continue 212 | offer.cpu = resources.cpu 213 | offer.memory = resources.memory 214 | offer.gpu_count = resources.gpu_count 215 | offer.gpu_name = resources.gpu_name 216 | offer.gpu_memory = resources.gpu_memory 217 | with_details.append(offer) 218 | return with_details 219 | 220 | @classmethod 221 | def filter(cls, offers: list[RawCatalogItem]) -> list[RawCatalogItem]: 222 | vm_series = [ 223 | VMSeries(r"D(\d+)s_v6", None, None), # Dsv6-series 224 | VMSeries( 225 | r"E(2|4|8|16|20|32|48|64|96)s_v6", None, None 226 | ), # Esv6-series (E128 and E192i are not yet GA) 227 | VMSeries(r"F(\d+)s_v2", None, None), # Fsv2-series 228 | VMSeries(r"NC(\d+)s_v3", "V100", 16 * 1024), # NCv3-series [V100 16GB] 229 | VMSeries(r"NC(\d+)as_T4_v3", "T4", 16 * 1024), # NCasT4_v3-series [T4] 230 | VMSeries(r"ND(\d+)rs_v2", "V100", 32 * 1024), # NDv2-series [8xV100 32GB] 231 | VMSeries(r"NV(\d+)adm?s_A10_v5", "A10", 24 * 1024), # NVadsA10 v5-series [A10] 232 | VMSeries(r"NC(\d+)ads_A100_v4", "A100", 80 * 1024), # NC A100 v4-series [A100 80GB] 233 | VMSeries(r"ND(\d+)asr_v4", "A100", 40 * 1024), # ND A100 v4-series [8xA100 40GB] 234 | VMSeries( 235 | r"ND(\d+)amsr_A100_v4", "A100", 80 * 1024 236 | ), # NDm A100 v4-series [8xA100 80GB] 237 | # The deprecated series are collected for older dstack versions 238 | VMSeries( 239 | r"D(\d+)s_v3", None, None 240 | ), # Dsv3-series (deprecated in favor of Dsv6-series) 241 | VMSeries( 242 | r"E(\d+)i?s_v4", None, None 243 | ), # Esv4-series (deprecated in favor of Esv6-series) 244 | VMSeries( 245 | r"E(\d+)-(\d+)s_v4", None, None 246 | ), # Esv4-series (constrained vCPU, deprecated in favor of Esv6-series) 247 | ] 248 | vm_series_pattern = re.compile( 249 | f"^Standard_({'|'.join(series.pattern for series in vm_series)})$" 250 | ) 251 | return [i for i in offers if 
vm_series_pattern.match(i.instance_name)] 252 | 253 | 254 | def get_gpu_name_memory(vm_name: str) -> tuple[Optional[str], Optional[float]]: 255 | for pattern, gpu_name, gpu_memory in gpu_vm_series: 256 | m = re.match(f"^Standard_{pattern}$", vm_name) 257 | if m is None: 258 | continue 259 | if gpu_name == "A10" and vm_name.endswith("_v4"): 260 | gpu_memory = gpu_memory * min(1.0, int(m.group(1)) / 16) 261 | elif gpu_name == "A10" and vm_name.endswith("_v5"): 262 | gpu_memory = gpu_memory * min(1.0, int(m.group(1)) / 36) 263 | 264 | return gpu_name, gpu_memory 265 | return None, None 266 | 267 | 268 | def is_retired(name: str) -> bool: 269 | if re.match(f"^({'|'.join(retired_vm_series)})$", name): 270 | return True 271 | return False 272 | -------------------------------------------------------------------------------- /src/gpuhunt/providers/cloudrift.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, Union 3 | 4 | import requests 5 | 6 | from gpuhunt import QueryFilter, RawCatalogItem 7 | from gpuhunt.providers import AbstractProvider 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | CLOUDRIFT_SERVER_ADDRESS = "https://api.cloudrift.ai" 12 | CLOUDRIFT_API_VERSION = "2025-03-21" 13 | 14 | 15 | class CloudRiftProvider(AbstractProvider): 16 | NAME = "cloudrift" 17 | 18 | def get( 19 | self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True 20 | ) -> list[RawCatalogItem]: 21 | instance_types = self._get_instance_types() 22 | instance_types = [ 23 | inst for instance in instance_types for inst in generate_instances(instance) 24 | ] 25 | return sorted(instance_types, key=lambda x: x.price) 26 | 27 | def _get_instance_types(self): 28 | request_data = {"selector": {"ByServiceAndLocation": {"services": ["vm"]}}} 29 | response_data = _make_request("instance-types/list", request_data) 30 | return response_data["instance_types"] 31 | 32 | 33 | def 
generate_instances(instance) -> list[RawCatalogItem]: 34 | instance_gpu_brand = instance["brand_short"] 35 | dstack_gpu_name = GPU_MAP.get(instance_gpu_brand) 36 | if dstack_gpu_name is None: 37 | logger.warning(f"Failed to find GPU name matching '{instance_gpu_brand}'") 38 | return [] 39 | 40 | instance_types = [] 41 | for variant in instance["variants"]: 42 | for location, _count in variant["available_nodes_per_dc"].items(): 43 | raw = RawCatalogItem( 44 | instance_name=variant["name"], 45 | location=location, 46 | spot=False, 47 | price=variant["cost_per_hour"] / 100, 48 | cpu=variant["cpu_count"], 49 | memory=variant["dram"] / 1024**3, 50 | disk_size=variant["disk"] / 1024**3, 51 | gpu_count=variant["gpu_count"], 52 | gpu_name=dstack_gpu_name, 53 | gpu_memory=round(variant["vram"] / 1024**3), 54 | ) 55 | instance_types.append(raw) 56 | 57 | return instance_types 58 | 59 | 60 | GPU_MAP = { 61 | r"RTX 4090": "RTX4090", 62 | r"RTX 5090": "RTX5090", 63 | r"RTX 6000 Pro": "RTX6000PRO", 64 | } 65 | 66 | 67 | def _make_request(endpoint: str, request_data: dict) -> Union[dict, str, None]: 68 | response = requests.request( 69 | "POST", 70 | f"{CLOUDRIFT_SERVER_ADDRESS}/api/v1/{endpoint}", 71 | json={"version": CLOUDRIFT_API_VERSION, "data": request_data}, 72 | timeout=5.0, 73 | ) 74 | if not response.ok: 75 | response.raise_for_status() 76 | try: 77 | response_json = response.json() 78 | if isinstance(response_json, str): 79 | return response_json 80 | return response_json["data"] 81 | except requests.exceptions.JSONDecodeError: 82 | return None 83 | -------------------------------------------------------------------------------- /src/gpuhunt/providers/datacrunch.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import itertools 3 | import logging 4 | import re 5 | from collections.abc import Iterable 6 | from typing import Optional 7 | 8 | from datacrunch import DataCrunchClient 9 | from 
datacrunch.instance_types.instance_types import InstanceType 10 | 11 | from gpuhunt import QueryFilter, RawCatalogItem 12 | from gpuhunt.providers import AbstractProvider 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | AMD_RX7900XTX = "RX7900XTX" 17 | ALL_AMD_GPUS = [ 18 | AMD_RX7900XTX, 19 | ] 20 | 21 | 22 | class DataCrunchProvider(AbstractProvider): 23 | NAME = "datacrunch" 24 | 25 | def __init__(self, client_id: str, client_secret: str) -> None: 26 | self.datacrunch_client = DataCrunchClient(client_id, client_secret) 27 | 28 | def get( 29 | self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True 30 | ) -> list[RawCatalogItem]: 31 | instance_types = self._get_instance_types() 32 | locations = self._get_locations() 33 | 34 | spots = (True, False) 35 | location_codes = [loc["code"] for loc in locations] 36 | instances = generate_instances(spots, location_codes, instance_types) 37 | 38 | return sorted(instances, key=lambda x: x.price) 39 | 40 | def _get_instance_types(self) -> list[InstanceType]: 41 | return self.datacrunch_client.instance_types.get() 42 | 43 | def _get_locations(self) -> list[dict]: 44 | return self.datacrunch_client.locations.get() 45 | 46 | @classmethod 47 | def filter(cls, offers: list[RawCatalogItem]) -> list[RawCatalogItem]: 48 | return [o for o in offers if o.gpu_name not in ALL_AMD_GPUS] # skip AMD GPU 49 | 50 | 51 | def generate_instances( 52 | spots: Iterable[bool], location_codes: Iterable[str], instance_types: Iterable[InstanceType] 53 | ) -> list[RawCatalogItem]: 54 | instances = [] 55 | for spot, location, instance in itertools.product(spots, location_codes, instance_types): 56 | item = transform_instance(copy.copy(instance), spot, location) 57 | if item is None: 58 | continue 59 | instances.append(RawCatalogItem.from_dict(item)) 60 | return instances 61 | 62 | 63 | def transform_instance(instance: InstanceType, spot: bool, location: str) -> Optional[dict]: 64 | gpu_memory = 0 65 | gpu_count = 
instance.gpu["number_of_gpus"] 66 | gpu_name = None 67 | 68 | if instance.gpu["number_of_gpus"]: 69 | gpu_memory = instance.gpu_memory["size_in_gigabytes"] / instance.gpu["number_of_gpus"] 70 | gpu_name = get_gpu_name(instance.gpu["description"]) 71 | 72 | if gpu_count and gpu_name is None: 73 | logger.warning( 74 | "Failed to get GPU name from description: '%s'", instance.gpu["description"] 75 | ) 76 | return None 77 | 78 | raw = dict( 79 | instance_name=instance.instance_type, 80 | location=location, 81 | spot=spot, 82 | price=instance.spot_price_per_hour if spot else instance.price_per_hour, 83 | cpu=instance.cpu["number_of_cores"], 84 | memory=instance.memory["size_in_gigabytes"], 85 | gpu_count=gpu_count, 86 | gpu_name=gpu_name, 87 | gpu_memory=gpu_memory, 88 | ) 89 | return raw 90 | 91 | 92 | GPU_MAP = { 93 | r"\d+x B200 SXM6 180GB": "B200", 94 | r"\d+x H200 SXM5 141GB": "H200", 95 | r"\d+x H100 SXM5 80GB": "H100", 96 | r"\d+x A100 SXM4 80GB": "A100", 97 | r"\d+x A100 SXM4 40GB": "A100", 98 | r"\d+x RTX6000 Ada 48GB": "RTX6000Ada", 99 | r"\d+x RTX A6000 48GB": "A6000", 100 | r"\d+x Tesla V100 16GB": "V100", 101 | r"\d+x L40S 48GB": "L40S", 102 | r"\d+x AMD 7900XTX": AMD_RX7900XTX, 103 | } 104 | 105 | 106 | def get_gpu_name(name: str) -> Optional[str]: 107 | for regex, gpu_name in GPU_MAP.items(): 108 | if re.fullmatch(regex, name): 109 | return gpu_name 110 | return None 111 | -------------------------------------------------------------------------------- /src/gpuhunt/providers/lambdalabs.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import re 4 | from typing import Optional 5 | 6 | from requests import Session 7 | 8 | from gpuhunt._internal.constraints import is_nvidia_superchip 9 | from gpuhunt._internal.models import ( 10 | CPUArchitecture, 11 | QueryFilter, 12 | RawCatalogItem, 13 | ) 14 | from gpuhunt.providers import AbstractProvider 15 | 16 | logger = 
logging.getLogger(__name__) 17 | INSTANCE_TYPES_URL = "https://cloud.lambdalabs.com/api/v1/instance-types" 18 | IMAGES_URL = "https://cloud.lambdalabs.com/api/v1/images" 19 | TIMEOUT = 10 20 | 21 | FLAG_ARM = "lambda-arm" 22 | 23 | 24 | class LambdaLabsProvider(AbstractProvider): 25 | NAME = "lambdalabs" 26 | 27 | def __init__(self, token: str): 28 | self.session = Session() 29 | self.session.headers.update({"Authorization": f"Bearer {token}"}) 30 | 31 | def get( 32 | self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True 33 | ) -> list[RawCatalogItem]: 34 | offers = [] 35 | resp = self.session.get(INSTANCE_TYPES_URL, timeout=TIMEOUT) 36 | resp.raise_for_status() 37 | data = resp.json()["data"] 38 | for instance in data.values(): 39 | instance = instance["instance_type"] 40 | logger.info(instance["name"]) 41 | description = instance["description"] 42 | result = parse_description(description) 43 | if result is None: 44 | logger.warning("Can't parse GPU info from description: %s", description) 45 | continue 46 | gpu_count, gpu_name, gpu_memory = result 47 | flags: list[str] = [] 48 | cpu_arch = CPUArchitecture.X86 49 | if is_nvidia_superchip(gpu_name): 50 | cpu_arch = CPUArchitecture.ARM 51 | flags.append(FLAG_ARM) 52 | offer = RawCatalogItem( 53 | instance_name=instance["name"], 54 | price=instance["price_cents_per_hour"] / 100, 55 | cpu_arch=cpu_arch.value, 56 | cpu=instance["specs"]["vcpus"], 57 | memory=float(instance["specs"]["memory_gib"]) * 1.074, 58 | gpu_vendor=None, 59 | gpu_count=gpu_count, 60 | gpu_name=gpu_name, 61 | gpu_memory=gpu_memory, 62 | spot=False, 63 | location=None, 64 | disk_size=float(instance["specs"]["storage_gib"]) * 1.074, 65 | flags=flags, 66 | ) 67 | offers.append(offer) 68 | offers = self.add_regions(offers) 69 | return sorted(offers, key=lambda i: i.price) 70 | 71 | def add_regions(self, offers: list[RawCatalogItem]) -> list[RawCatalogItem]: 72 | # TODO: we don't know which regions are actually available for 
each instance type 73 | region_offers = [] 74 | for region in self.list_regions(): 75 | for offer in offers: 76 | offer = copy.deepcopy(offer) 77 | offer.location = region 78 | region_offers.append(offer) 79 | return region_offers 80 | 81 | def list_regions(self) -> list[str]: 82 | resp = self.session.get(IMAGES_URL, timeout=TIMEOUT) 83 | resp.raise_for_status() 84 | regions = set() 85 | for image in resp.json()["data"]: 86 | regions.add(image["region"]["name"]) 87 | return sorted(regions) 88 | 89 | 90 | def parse_description(v: str) -> Optional[tuple[int, str, float]]: 91 | """Returns gpus count, gpu name, and GPU memory""" 92 | r = re.match(r"^(\d)x (?:Tesla )?(.+) \((\d+) GB", v) 93 | if r is None: 94 | return None 95 | count, gpu_name, gpu_memory = r.groups() 96 | return int(count), gpu_name.replace(" ", ""), float(gpu_memory) 97 | -------------------------------------------------------------------------------- /src/gpuhunt/providers/nebius.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | from nebius.aio.channel import Credentials 7 | from nebius.api.nebius.compute.v1 import ( 8 | ListPlatformsRequest, 9 | ListPlatformsResponse, 10 | PlatformServiceClient, 11 | Preset, 12 | ) 13 | from nebius.api.nebius.iam.v1 import ( 14 | ListProjectsRequest, 15 | ListTenantsRequest, 16 | ProjectServiceClient, 17 | TenantServiceClient, 18 | ) 19 | from nebius.sdk import SDK 20 | 21 | from gpuhunt._internal.models import AcceleratorVendor, QueryFilter, RawCatalogItem 22 | from gpuhunt.providers import AbstractProvider 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | @dataclass 28 | class PlatformGPU: 29 | name: str 30 | memory_gib: int 31 | price_hour: float 32 | vendor: AcceleratorVendor = AcceleratorVendor.NVIDIA 33 | 34 | 35 | @dataclass 36 | class Platform: 37 | name: str 38 | gpu: Optional[PlatformGPU] 39 | 
cpu_price_hour: float 40 | memory_gib_price_hour: float 41 | 42 | 43 | # Until Nebius provides a pricing API, prices are taken from 44 | # https://docs.nebius.com/compute/resources/pricing 45 | PLATFORMS = [ 46 | Platform( 47 | # NVIDIA® H100 NVLink with Intel Sapphire Rapids 48 | name="gpu-h100-sxm", 49 | gpu=PlatformGPU( 50 | name="H100", 51 | memory_gib=80, 52 | price_hour=2.118, 53 | ), 54 | cpu_price_hour=0.012, 55 | memory_gib_price_hour=0.0032, 56 | ), 57 | Platform( 58 | # NVIDIA® H200 NVLink with Intel Sapphire Rapids 59 | name="gpu-h200-sxm", 60 | gpu=PlatformGPU( 61 | name="H200", 62 | memory_gib=141, 63 | price_hour=2.668, 64 | ), 65 | cpu_price_hour=0.012, 66 | memory_gib_price_hour=0.0032, 67 | ), 68 | Platform( 69 | # NVIDIA® L40S PCIe with Intel Ice Lake 70 | name="gpu-l40s-a", 71 | gpu=PlatformGPU( 72 | name="L40S", 73 | memory_gib=48, 74 | price_hour=1.35, 75 | ), 76 | cpu_price_hour=0.012, 77 | memory_gib_price_hour=0.0032, 78 | ), 79 | Platform( 80 | # NVIDIA® L40S PCIe with AMD Epyc Genoa 81 | name="gpu-l40s-d", 82 | gpu=PlatformGPU( 83 | name="L40S", 84 | memory_gib=48, 85 | price_hour=1.35, 86 | ), 87 | cpu_price_hour=0.01, 88 | memory_gib_price_hour=0.0032, 89 | ), 90 | Platform( 91 | # Non-GPU AMD EPYC Genoa 92 | name="cpu-d3", 93 | gpu=None, 94 | cpu_price_hour=0.012, 95 | memory_gib_price_hour=0.0032, 96 | ), 97 | Platform( 98 | # Non-GPU Intel Ice Lake 99 | name="cpu-e2", 100 | gpu=None, 101 | cpu_price_hour=0.012, 102 | memory_gib_price_hour=0.0032, 103 | ), 104 | ] 105 | PLATFORMS_MAP = {p.name: p for p in PLATFORMS} 106 | TIMEOUT = 7 107 | 108 | 109 | class NebiusProvider(AbstractProvider): 110 | NAME = "nebius" 111 | 112 | def __init__(self, credentials: Credentials) -> None: 113 | self.credentials = credentials 114 | 115 | def get( 116 | self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True 117 | ) -> list[RawCatalogItem]: 118 | items: list[RawCatalogItem] = [] 119 | sdk = SDK(credentials=self.credentials) 
120 | try: 121 | for region_code, region_name in get_regions_map(sdk).items(): 122 | for platform in list_platforms(sdk, region_code).items: 123 | known_platform_details = PLATFORMS_MAP.get(platform.metadata.name) 124 | if known_platform_details is None: 125 | logger.warning(f"Unknown platform: {platform.metadata.name}") 126 | continue 127 | for preset in platform.spec.presets: 128 | items.append(make_item(known_platform_details, preset, region_name)) 129 | finally: 130 | sdk.sync_close(timeout=TIMEOUT) 131 | items.sort(key=lambda i: i.price) 132 | return items 133 | 134 | 135 | def get_regions_map(sdk: SDK) -> dict[str, str]: 136 | """ 137 | Returns: 138 | `{"e00": "eu-north1", "e01": "eu-west1", ...}` 139 | """ 140 | tenants = TenantServiceClient(sdk).list(ListTenantsRequest(), timeout=TIMEOUT).wait() 141 | if len(tenants.items) != 1: 142 | raise ValueError(f"Expected to find 1 tenant, found {(len(tenants.items))}") 143 | projects = ( 144 | ProjectServiceClient(sdk) 145 | .list(ListProjectsRequest(parent_id=tenants.items[0].metadata.id), timeout=TIMEOUT) 146 | .wait() 147 | ) 148 | result = {} 149 | for project in projects.items: 150 | match = re.match(r"^project-([a-z]\d\d)", project.metadata.id) 151 | if match is None: 152 | logger.error(f"Could not parse project id {project.metadata.id!r}") 153 | continue 154 | result[match.group(1)] = project.status.region 155 | return result 156 | 157 | 158 | def list_platforms(sdk: SDK, region_code: str) -> ListPlatformsResponse: 159 | req = ListPlatformsRequest( 160 | page_size=999, 161 | parent_id=f"project-{region_code}public-images", 162 | ) 163 | return PlatformServiceClient(sdk).list(req, timeout=TIMEOUT).wait() 164 | 165 | 166 | def make_item(platform: Platform, preset: Preset, region: str) -> RawCatalogItem: 167 | item = RawCatalogItem( 168 | instance_name=f"{platform.name} {preset.name}", 169 | location=region, 170 | price=( 171 | preset.resources.vcpu_count * platform.cpu_price_hour 172 | + 
preset.resources.memory_gibibytes * platform.memory_gib_price_hour 173 | ), 174 | cpu=preset.resources.vcpu_count, 175 | memory=preset.resources.memory_gibibytes, 176 | gpu_count=0, 177 | gpu_name=None, 178 | gpu_memory=None, 179 | gpu_vendor=None, 180 | spot=False, 181 | disk_size=None, 182 | ) 183 | if platform.gpu is not None: 184 | item.gpu_count = preset.resources.gpu_count 185 | item.gpu_name = platform.gpu.name 186 | item.gpu_memory = platform.gpu.memory_gib 187 | item.gpu_vendor = platform.gpu.vendor.value 188 | item.price += item.gpu_count * platform.gpu.price_hour 189 | item.price = round(item.price, 8) # fix floating point precision errors 190 | return item 191 | -------------------------------------------------------------------------------- /src/gpuhunt/providers/oci.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import re 4 | from collections.abc import Iterable 5 | from dataclasses import asdict, dataclass 6 | from typing import Annotated, Optional 7 | 8 | import oci 9 | from oci.identity.models import Region 10 | from pydantic import BaseModel, Field 11 | from requests import Session 12 | from typing_extensions import TypedDict 13 | 14 | from gpuhunt._internal.constraints import KNOWN_NVIDIA_GPUS 15 | from gpuhunt._internal.models import QueryFilter, RawCatalogItem 16 | from gpuhunt._internal.utils import to_camel_case 17 | from gpuhunt.providers import AbstractProvider 18 | 19 | logger = logging.getLogger(__name__) 20 | COST_ESTIMATOR_URL_TEMPLATE = "https://www.oracle.com/a/ocom/docs/cloudestimator2/data/{resource}" 21 | COST_ESTIMATOR_REQUEST_TIMEOUT = 10 22 | 23 | 24 | class OCICredentials(TypedDict): 25 | user: Optional[str] 26 | key_content: Optional[str] 27 | fingerprint: Optional[str] 28 | tenancy: Optional[str] 29 | region: Optional[str] 30 | 31 | 32 | class OCIProvider(AbstractProvider): 33 | NAME = "oci" 34 | 35 | def __init__(self, credentials: OCICredentials): 
36 | self.api_client = oci.identity.IdentityClient( 37 | credentials if all(credentials.values()) else oci.config.from_file() 38 | ) 39 | self.cost_estimator = CostEstimator() 40 | 41 | def get( 42 | self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True 43 | ) -> list[RawCatalogItem]: 44 | shapes = self.cost_estimator.get_shapes() 45 | products = self.cost_estimator.get_products() 46 | regions: list[Region] = self.api_client.list_regions().data 47 | 48 | result = [] 49 | 50 | for shape in shapes.items: 51 | if ( 52 | shape.hidden 53 | or shape.status != "ACTIVE" 54 | or shape.shape_type.value not in ("vm", "bm") 55 | or shape.sub_type.value not in ("standard", "gpu", "optimized") 56 | or ".A1." in shape.name 57 | ): 58 | continue 59 | 60 | try: 61 | resources = shape_to_resources(shape, products) 62 | except CostEstimatorDataError as e: 63 | logger.warning( 64 | "Skipping shape %s due to unexpected Cost Estimator data: %s", shape.name, e 65 | ) 66 | continue 67 | 68 | on_demand_item = RawCatalogItem( 69 | instance_name=shape.name, 70 | location=None, 71 | price=resources.total_price(), 72 | cpu=resources.cpu.vcpus, 73 | memory=resources.memory.gbs, 74 | gpu_vendor=None, 75 | gpu_count=resources.gpu.units_count, 76 | gpu_name=resources.gpu.name, 77 | gpu_memory=resources.gpu.unit_memory_gb, 78 | spot=False, 79 | disk_size=None, 80 | ) 81 | item_variations = [on_demand_item] 82 | if shape.allow_preemptible: 83 | item_variations.append(self._make_spot_item(on_demand_item)) 84 | for item in item_variations: 85 | result.extend(self._duplicate_item_in_regions(item, regions)) 86 | 87 | return sorted(result, key=lambda i: i.price) 88 | 89 | @staticmethod 90 | def _make_spot_item(item: RawCatalogItem) -> RawCatalogItem: 91 | item = copy.deepcopy(item) 92 | item.spot = True 93 | # > Preemptible capacity costs 50% less than on-demand capacity 94 | # https://docs.oracle.com/en-us/iaas/Content/Compute/Concepts/preemptible.htm#howitworks__billing 95 | 
item.price *= 0.5 96 | item.flags.append("oci-spot") 97 | return item 98 | 99 | @staticmethod 100 | def _duplicate_item_in_regions( 101 | item: RawCatalogItem, regions: Iterable[Region] 102 | ) -> list[RawCatalogItem]: 103 | result = [] 104 | for region in regions: 105 | regional_item = copy.deepcopy(item) 106 | regional_item.location = region.name 107 | result.append(regional_item) 108 | return result 109 | 110 | 111 | class CostEstimatorTypeField(BaseModel): 112 | value: str 113 | 114 | 115 | class CostEstimatorShapeProduct(BaseModel): 116 | type: CostEstimatorTypeField 117 | part_number: str 118 | qty: Optional[int] 119 | 120 | class Config: 121 | alias_generator = to_camel_case 122 | 123 | 124 | class CostEstimatorShape(BaseModel): 125 | name: str 126 | hidden: bool 127 | status: str 128 | allow_preemptible: bool 129 | bundle_memory_qty: Optional[int] 130 | gpu_qty: Optional[int] 131 | gpu_memory_qty: Optional[int] 132 | processor_type: CostEstimatorTypeField 133 | shape_type: CostEstimatorTypeField 134 | sub_type: CostEstimatorTypeField 135 | products: list[CostEstimatorShapeProduct] 136 | 137 | class Config: 138 | alias_generator = to_camel_case 139 | 140 | def is_arm_cpu(self): 141 | is_ampere_gpu = self.sub_type.value == "gpu" and ( 142 | "GPU4" in self.name or "GPU.A10" in self.name 143 | ) 144 | # the data says A10 and A100 GPU instances are ARM, but they are not 145 | return self.processor_type.value == "arm" and not is_ampere_gpu 146 | 147 | def get_gpu_unit_memory_gb(self) -> Optional[float]: 148 | if self.gpu_memory_qty and self.gpu_qty: 149 | return self.gpu_memory_qty / self.gpu_qty 150 | return None 151 | 152 | 153 | class CostEstimatorShapeList(BaseModel): 154 | items: list[CostEstimatorShape] 155 | 156 | 157 | class CostEstimatorPrice(BaseModel): 158 | model: str 159 | value: float 160 | 161 | 162 | class CostEstimatorPriceLocalization(BaseModel): 163 | currency_code: str 164 | prices: list[CostEstimatorPrice] 165 | 166 | class Config: 167 | 
alias_generator = to_camel_case 168 | 169 | 170 | class CostEstimatorProduct(BaseModel): 171 | part_number: str 172 | billing_model: str 173 | price_type: Annotated[str, Field(alias="pricetype")] 174 | currency_code_localizations: list[CostEstimatorPriceLocalization] 175 | 176 | class Config: 177 | alias_generator = to_camel_case 178 | 179 | def find_price_l10n(self, currency_code: str) -> Optional[CostEstimatorPriceLocalization]: 180 | return next( 181 | filter( 182 | lambda price: price.currency_code == currency_code, 183 | self.currency_code_localizations, 184 | ), 185 | None, 186 | ) 187 | 188 | 189 | class CostEstimatorProductList(BaseModel): 190 | items: list[CostEstimatorProduct] 191 | 192 | def find(self, part_number: str) -> Optional[CostEstimatorProduct]: 193 | return next(filter(lambda product: product.part_number == part_number, self.items), None) 194 | 195 | 196 | class CostEstimator: 197 | def __init__(self): 198 | self.session = Session() 199 | 200 | def get_shapes(self) -> CostEstimatorShapeList: 201 | return self._get("shapes.json", CostEstimatorShapeList) 202 | 203 | def get_products(self) -> CostEstimatorProductList: 204 | return self._get("products.json", CostEstimatorProductList) 205 | 206 | def _get(self, resource: str, ResponseModel: type[BaseModel]): 207 | url = COST_ESTIMATOR_URL_TEMPLATE.format(resource=resource) 208 | resp = self.session.get(url, timeout=COST_ESTIMATOR_REQUEST_TIMEOUT) 209 | resp.raise_for_status() 210 | return ResponseModel.parse_raw(resp.content) 211 | 212 | 213 | class CostEstimatorDataError(Exception): 214 | pass 215 | 216 | 217 | @dataclass 218 | class CPUConfiguration: 219 | vcpus: int 220 | price: float 221 | 222 | 223 | @dataclass 224 | class MemoryConfiguration: 225 | gbs: int 226 | price: float 227 | 228 | 229 | @dataclass 230 | class GPUConfiguration: 231 | units_count: int 232 | unit_memory_gb: Optional[float] 233 | name: Optional[str] 234 | price: float 235 | 236 | def __post_init__(self): 237 | d = 
asdict(self) 238 | if any(d.values()) and not all(d.values()): 239 | raise CostEstimatorDataError(f"Incomplete GPU parameters: {self}") 240 | 241 | 242 | @dataclass 243 | class ResourcesConfiguration: 244 | cpu: CPUConfiguration 245 | memory: MemoryConfiguration 246 | gpu: GPUConfiguration 247 | 248 | def total_price(self) -> float: 249 | return self.cpu.price + self.memory.price + self.gpu.price 250 | 251 | 252 | def shape_to_resources( 253 | shape: CostEstimatorShape, products: CostEstimatorProductList 254 | ) -> ResourcesConfiguration: 255 | cpu = None 256 | gpu = GPUConfiguration(units_count=0, unit_memory_gb=None, name=None, price=0.0) 257 | memory: Optional[MemoryConfiguration] = None 258 | if shape.bundle_memory_qty is not None: 259 | memory = MemoryConfiguration(gbs=shape.bundle_memory_qty, price=0.0) 260 | 261 | for product in shape.products: 262 | if product.qty is None: 263 | raise CostEstimatorDataError("Product quantity not found") 264 | product_details = products.find(product.part_number) 265 | if product_details is None: 266 | raise CostEstimatorDataError(f"Could not find product {product.part_number!r}") 267 | product_price = get_product_price_usd_per_hour(product_details) 268 | 269 | if product.type.value == "ocpu": 270 | vcpus = product.qty if shape.is_arm_cpu() else product.qty * 2 271 | if shape.gpu_qty: 272 | gpu = GPUConfiguration( 273 | units_count=shape.gpu_qty, 274 | unit_memory_gb=shape.get_gpu_unit_memory_gb(), 275 | name=get_gpu_name(shape.name), 276 | price=product_price * shape.gpu_qty, 277 | ) 278 | cpu = CPUConfiguration(vcpus=vcpus, price=0.0) 279 | else: 280 | cpu = CPUConfiguration(vcpus=vcpus, price=product_price * product.qty) 281 | 282 | elif product.type.value == "memory": 283 | memory = MemoryConfiguration(gbs=product.qty, price=product_price * product.qty) 284 | 285 | else: 286 | raise CostEstimatorDataError(f"Unknown product type {product.type.value!r}") 287 | 288 | if cpu is None: 289 | raise CostEstimatorDataError("No 
ocpu product") 290 | if memory is None: 291 | raise CostEstimatorDataError("No memory product") 292 | 293 | return ResourcesConfiguration(cpu, memory, gpu) 294 | 295 | 296 | def get_product_price_usd_per_hour(product: CostEstimatorProduct) -> float: 297 | if product.billing_model != "UCM": 298 | raise CostEstimatorDataError( 299 | f"Billing model for product {product.part_number!r} is {product.billing_model!r}" 300 | ) 301 | if product.price_type != "HOUR": 302 | raise CostEstimatorDataError( 303 | f"Price type for product {product.part_number!r} is {product.price_type!r}" 304 | ) 305 | price_l10n = product.find_price_l10n("USD") 306 | if price_l10n is None: 307 | raise CostEstimatorDataError(f"No USD price for product {product.part_number!r}") 308 | if len(price_l10n.prices) != 1: 309 | raise CostEstimatorDataError( 310 | f"Product {product.part_number!r} has {len(price_l10n.prices)} USD prices" 311 | ) 312 | price = price_l10n.prices[0] 313 | if price.model != "PAY_AS_YOU_GO": 314 | raise CostEstimatorDataError( 315 | f"Pricing model for product {product.part_number!r} is {price.model!r}" 316 | ) 317 | return price.value 318 | 319 | 320 | def get_gpu_name(shape_name: str) -> Optional[str]: 321 | parts = re.split(r"[\.-]", shape_name.upper()) 322 | 323 | if "GPU4" in parts: 324 | return "A100" 325 | if "GPU3" in parts: 326 | return "V100" 327 | if "GPU2" in parts: 328 | return "P100" 329 | 330 | if "GPU" in parts: 331 | gpu_name_index = parts.index("GPU") + 1 332 | if gpu_name_index < len(parts): 333 | gpu_name = parts[gpu_name_index] 334 | 335 | for gpu in KNOWN_NVIDIA_GPUS: 336 | if gpu.name.upper() == gpu_name: 337 | return gpu.name 338 | return None 339 | -------------------------------------------------------------------------------- /src/gpuhunt/providers/runpod.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | from concurrent.futures import ThreadPoolExecutor 4 | from typing import 
Optional 5 | 6 | import requests 7 | from requests import RequestException 8 | 9 | from gpuhunt._internal.constraints import KNOWN_AMD_GPUS 10 | from gpuhunt._internal.models import AcceleratorVendor, QueryFilter, RawCatalogItem 11 | from gpuhunt.providers import AbstractProvider 12 | 13 | logger = logging.getLogger(__name__) 14 | API_URL = "https://api.runpod.io/graphql" 15 | 16 | 17 | class RunpodProvider(AbstractProvider): 18 | NAME = "runpod" 19 | 20 | def get( 21 | self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True 22 | ) -> list[RawCatalogItem]: 23 | offers = self.fetch_offers() 24 | return sorted(offers, key=lambda i: i.price) 25 | 26 | @staticmethod 27 | def fetch_offers() -> list[RawCatalogItem]: 28 | query_variables = build_query_variables() 29 | 30 | with ThreadPoolExecutor(max_workers=10) as executor: 31 | futures = [ 32 | executor.submit(get_pods, query_variable) for query_variable in query_variables 33 | ] 34 | pods_by_query = [] 35 | for future in futures: 36 | try: 37 | pods_by_query.append(future.result()) 38 | except RequestException as e: 39 | logger.exception("Failed to get pods data: %s", e) 40 | 41 | catalog_items = [] 42 | for query_variable, pods in zip(query_variables, pods_by_query): 43 | for pod in pods: 44 | catalog_items.extend(make_catalog_items(query_variable, pod)) 45 | return catalog_items 46 | 47 | @classmethod 48 | def filter(cls, offers: list[RawCatalogItem]) -> list[RawCatalogItem]: 49 | return [ 50 | o 51 | for o in offers 52 | if o.location 53 | not in [ 54 | "AR", # network problems, unusable 55 | ] 56 | ] 57 | 58 | 59 | def gpu_vendor_and_name(gpu_id: str) -> Optional[tuple[AcceleratorVendor, str]]: 60 | if not gpu_id: 61 | return None 62 | return GPU_MAP.get(gpu_id) 63 | 64 | 65 | def build_query_variables() -> list[dict]: 66 | """Prepare different combinations of API query filters to cover all available GPUs.""" 67 | 68 | gpu_types = make_request({"query": gpu_types_query, "variables": {}}) 69 
| data_centers = [dc["id"] for dc in gpu_types["data"]["dataCenters"] if dc["listed"]] 70 | max_gpu_count = max(gpu["maxGpuCount"] for gpu in gpu_types["data"]["gpuTypes"]) 71 | 72 | variables = [] 73 | for gpu_count in range(1, max_gpu_count + 1): 74 | # Secure cloud is queryable by datacenter ID 75 | for dc_id in data_centers: 76 | variables.append( 77 | { 78 | "secureCloud": True, 79 | "dataCenterId": dc_id, 80 | "gpuCount": gpu_count, 81 | "minDisk": None, 82 | "minMemoryInGb": None, 83 | "minVcpuCount": None, 84 | } 85 | ) 86 | # Community cloud is queryable by country code 87 | for country_code in gpu_types["data"]["countryCodes"]: 88 | variables.append( 89 | { 90 | "secureCloud": False, 91 | "countryCode": country_code, 92 | "gpuCount": gpu_count, 93 | "minDisk": None, 94 | "minMemoryInGb": None, 95 | "minVcpuCount": None, 96 | } 97 | ) 98 | 99 | return variables 100 | 101 | 102 | def get_pods(query_variable: dict) -> list[dict]: 103 | resp = make_request( 104 | { 105 | "query": query_pod_types, 106 | "variables": {"lowestPriceInput": query_variable}, 107 | } 108 | ) 109 | return resp["data"]["gpuTypes"] 110 | 111 | 112 | def make_catalog_items(query_variable: dict, pod: dict) -> list[RawCatalogItem]: 113 | if pod["lowestPrice"]["stockStatus"] is None: 114 | return [] 115 | listed_gpu_vendor_and_name = gpu_vendor_and_name(pod["id"]) 116 | if listed_gpu_vendor_and_name is None: 117 | logger.warning(f"{pod['id']} missing in runpod GPU_MAP") 118 | return [] 119 | if query_variable["secureCloud"]: 120 | location = query_variable["dataCenterId"] 121 | on_demand_gpu_price = pod["securePrice"] 122 | spot_gpu_price = pod["secureSpotPrice"] 123 | else: 124 | location = query_variable["countryCode"] 125 | on_demand_gpu_price = pod["communityPrice"] 126 | spot_gpu_price = pod["communitySpotPrice"] 127 | item_template = RawCatalogItem( 128 | instance_name=pod["id"], 129 | location=location, 130 | price=None, # set below 131 | cpu=pod["lowestPrice"]["minVcpu"], 132 | 
memory=pod["lowestPrice"]["minMemory"], 133 | gpu_vendor=listed_gpu_vendor_and_name[0], 134 | gpu_count=query_variable["gpuCount"], 135 | gpu_name=listed_gpu_vendor_and_name[1], 136 | gpu_memory=pod["memoryInGb"], 137 | spot=None, # set below 138 | disk_size=None, 139 | ) 140 | items = [] 141 | if on_demand_gpu_price: 142 | item = copy.deepcopy(item_template) 143 | item.spot = False 144 | item.price = item.gpu_count * on_demand_gpu_price 145 | items.append(item) 146 | if spot_gpu_price: 147 | item = copy.deepcopy(item_template) 148 | item.spot = True 149 | item.price = item.gpu_count * spot_gpu_price 150 | items.append(item) 151 | return items 152 | 153 | 154 | def make_request(payload: dict): 155 | resp = requests.post(API_URL, json=payload, timeout=10) 156 | resp.raise_for_status() 157 | return resp.json() 158 | 159 | 160 | def get_gpu_map() -> dict[str, tuple[AcceleratorVendor, str]]: 161 | payload_gpus = { 162 | "query": "query GpuTypes { gpuTypes { id manufacturer displayName memoryInGb } }" 163 | } 164 | response = make_request(payload_gpus) 165 | gpu_map: dict[str, tuple[AcceleratorVendor, str]] = {} 166 | for gpu_type in response["data"]["gpuTypes"]: 167 | try: 168 | vendor = AcceleratorVendor.cast(gpu_type["manufacturer"]) 169 | except ValueError: 170 | continue 171 | gpu_name = get_gpu_name(vendor, gpu_type["displayName"]) 172 | if gpu_name: 173 | gpu_map[gpu_type["id"]] = (vendor, gpu_name) 174 | return gpu_map 175 | 176 | 177 | def get_gpu_name(vendor: AcceleratorVendor, name: str) -> Optional[str]: 178 | if vendor == AcceleratorVendor.NVIDIA: 179 | return get_nvidia_gpu_name(name) 180 | if vendor == AcceleratorVendor.AMD: 181 | return get_amd_gpu_name(name) 182 | return None 183 | 184 | 185 | def get_nvidia_gpu_name(name: str) -> Optional[str]: 186 | if "V100" in name: 187 | return "V100" 188 | if name == "H100 NVL": 189 | return "H100NVL" 190 | if name.startswith(("A", "L", "H")): 191 | gpu_name, _, _ = name.partition(" ") 192 | return gpu_name 193 | 
if name.startswith("RTX A"): 194 | return name.lstrip("RTX ").replace(" ", "") 195 | if name.startswith("RTX"): 196 | return name.replace(" ", "") 197 | return None 198 | 199 | 200 | def get_amd_gpu_name(name: str) -> Optional[str]: 201 | for gpu in KNOWN_AMD_GPUS: 202 | if gpu.name == name: 203 | return name 204 | return None 205 | 206 | 207 | GPU_MAP = get_gpu_map() 208 | 209 | gpu_types_query = """ 210 | query GpuTypes { 211 | countryCodes 212 | dataCenters { 213 | id 214 | name 215 | listed 216 | __typename 217 | } 218 | gpuTypes { 219 | maxGpuCount 220 | maxGpuCount 221 | maxGpuCountCommunityCloud 222 | maxGpuCountSecureCloud 223 | minPodGpuCount 224 | id 225 | displayName 226 | memoryInGb 227 | secureCloud 228 | communityCloud 229 | __typename 230 | } 231 | } 232 | """ 233 | 234 | query_pod_types = """ 235 | query GpuTypes($lowestPriceInput: GpuLowestPriceInput, $gpuTypesInput: GpuTypeFilter) { 236 | gpuTypes(input: $gpuTypesInput) { 237 | lowestPrice(input: $lowestPriceInput) { 238 | minimumBidPrice 239 | uninterruptablePrice 240 | minVcpu 241 | minMemory 242 | stockStatus 243 | compliance 244 | countryCode 245 | __typename 246 | } 247 | maxGpuCount 248 | id 249 | displayName 250 | memoryInGb 251 | securePrice 252 | secureSpotPrice 253 | communityPrice 254 | communitySpotPrice 255 | oneMonthPrice 256 | threeMonthPrice 257 | sixMonthPrice 258 | secureSpotPrice 259 | __typename 260 | } 261 | } 262 | """ 263 | -------------------------------------------------------------------------------- /src/gpuhunt/providers/tensordock.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from math import ceil 3 | from typing import Optional, TypeVar, Union 4 | 5 | import requests 6 | 7 | from gpuhunt._internal.constraints import get_compute_capability, is_between 8 | from gpuhunt._internal.models import QueryFilter, RawCatalogItem 9 | from gpuhunt.providers import AbstractProvider 10 | 11 | logger = 
logging.getLogger(__name__) 12 | 13 | # https://documenter.getpostman.com/view/20973002/2s8YzMYRDc#2b4a3db0-c162-438c-aae4-6a88afc96fdb 14 | marketplace_hostnodes_url = "https://marketplace.tensordock.com/api/v0/client/deploy/hostnodes" 15 | marketplace_gpus = { 16 | "a100-pcie-80gb": "A100", 17 | "geforcegtx1070-pcie-8gb": "GTX1070", 18 | "geforcertx3060-pcie-12gb": "RTX3060", 19 | "geforcertx3060ti-pcie-8gb": "RTX3060Ti", 20 | "geforcertx3060tilhr-pcie-8gb": "RTX3060TiLHR", 21 | "geforcertx3070-pcie-8gb": "RTX3070", 22 | "geforcertx3070ti-pcie-8gb": "RTX3070Ti", 23 | "geforcertx3080-pcie-10gb": "RTX3080", 24 | "geforcertx3080ti-pcie-12gb": "RTX3080Ti", 25 | "geforcertx3090-pcie-24gb": "RTX3090", 26 | "geforcertx4090-pcie-24gb": "RTX4090", 27 | "l40-pcie-48gb": "L40", 28 | "rtxa4000-pcie-16gb": "A4000", 29 | "rtxa5000-pcie-24gb": "A5000", 30 | "rtxa6000-pcie-48gb": "A6000", 31 | "v100-pcie-16gb": "V100", 32 | } 33 | 34 | RAM_PER_VRAM = 2 35 | RAM_PER_CORE = 6 36 | CPU_DIV = 2 # has to be even 37 | RAM_DIV = 2 # has to be even 38 | 39 | 40 | class TensorDockProvider(AbstractProvider): 41 | NAME = "tensordock" 42 | 43 | def get( 44 | self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True 45 | ) -> list[RawCatalogItem]: 46 | logger.info("Fetching TensorDock offers") 47 | 48 | hostnodes = requests.get(marketplace_hostnodes_url, timeout=10).json()["hostnodes"] 49 | offers = [] 50 | for hostnode, details in hostnodes.items(): 51 | location = details["location"]["country"].lower().replace(" ", "") 52 | if query_filter is not None: 53 | offers += self.optimize_offers( 54 | query_filter, 55 | details["specs"], 56 | hostnode, 57 | location, 58 | balance_resources=balance_resources, 59 | ) 60 | else: # pick maximum possible configuration 61 | for gpu_name, gpu in details["specs"]["gpu"].items(): 62 | if gpu["amount"] == 0: 63 | continue 64 | offers.append( 65 | RawCatalogItem( 66 | instance_name=hostnode, 67 | location=location, 68 | price=round( 69 
| sum( 70 | details["specs"][key]["price"] 71 | * details["specs"][key]["amount"] 72 | for key in ("cpu", "ram", "storage") 73 | ) 74 | + gpu["amount"] * gpu["price"], 75 | 5, 76 | ), 77 | cpu=round_down(details["specs"]["cpu"]["amount"], 2), 78 | memory=float(round_down(details["specs"]["ram"]["amount"], 2)), 79 | gpu_vendor=None, 80 | gpu_count=gpu["amount"], 81 | gpu_name=convert_gpu_name(gpu_name), 82 | gpu_memory=float(gpu["vram"]), 83 | spot=False, 84 | disk_size=float(details["specs"]["storage"]["amount"]), 85 | ) 86 | ) 87 | return sorted(offers, key=lambda i: i.price) 88 | 89 | @staticmethod 90 | def optimize_offers( 91 | q: QueryFilter, 92 | specs: dict, 93 | instance_name: str, 94 | location: str, 95 | balance_resources: bool = True, 96 | ) -> list[RawCatalogItem]: 97 | """ 98 | Picks the best offer for the given query filter 99 | Doesn't respect max values, additional filtering is required 100 | 101 | Args: 102 | q: query filter 103 | specs: hostnode specs 104 | instance_name: hostnode `instance_name` 105 | location: hostnode `location` 106 | balance_resources: if True, will override query filter min values 107 | """ 108 | offers = [] 109 | for gpu_model, gpu_info in specs["gpu"].items(): 110 | # filter by single gpu characteristics 111 | if not is_between(gpu_info["vram"], q.min_gpu_memory, q.max_gpu_memory): 112 | continue 113 | gpu_name = convert_gpu_name(gpu_model) 114 | if q.gpu_name is not None and gpu_name.lower() not in map(str.lower, q.gpu_name): 115 | continue 116 | cc = get_compute_capability(gpu_name) 117 | if not cc or not is_between(cc, q.min_compute_capability, q.max_compute_capability): 118 | continue 119 | 120 | for gpu_count in range(1, gpu_info["amount"] + 1): # try all possible gpu counts 121 | if not is_between(gpu_count, q.min_gpu_count, q.max_gpu_count): 122 | continue 123 | total_gpu_memory = gpu_count * gpu_info["vram"] 124 | if not is_between( 125 | total_gpu_memory, q.min_total_gpu_memory, q.max_total_gpu_memory 126 | ): 127 | 
continue 128 | 129 | # we can't take 100% of CPU/RAM/storage if we don't take all GPUs 130 | multiplier = 0.75 if gpu_count < gpu_info["amount"] else 1 131 | available_memory = round_down(multiplier * specs["ram"]["amount"], RAM_DIV) 132 | available_cpu = round_down(multiplier * specs["cpu"]["amount"], CPU_DIV) 133 | available_disk = int(multiplier * specs["storage"]["amount"]) 134 | 135 | memory = None 136 | if q.min_memory is not None: 137 | if q.min_memory > available_memory: 138 | continue 139 | memory = round_up( 140 | max_none( 141 | q.min_memory, 142 | gpu_count, # 1 GB per GPU at least 143 | q.min_cpu, # 1 GB per CPU at least 144 | ), 145 | RAM_DIV, 146 | ) 147 | if memory is None or balance_resources: 148 | memory = max_none( 149 | memory, 150 | min_none( 151 | available_memory, 152 | round_up(RAM_PER_VRAM * total_gpu_memory, RAM_DIV), 153 | round_down(q.max_memory, RAM_DIV), # can be None 154 | ), 155 | ) 156 | 157 | cpu = None 158 | if q.min_cpu is not None: 159 | if q.min_cpu > available_cpu: 160 | continue 161 | # 1 CPU per GPU at least 162 | cpu = round_up(max(q.min_cpu, gpu_count), CPU_DIV) 163 | if cpu is None or balance_resources: 164 | cpu = max_none( 165 | cpu, 166 | min_none( 167 | available_cpu, 168 | round_up(ceil(memory / RAM_PER_CORE), CPU_DIV), 169 | round_down(q.max_cpu, CPU_DIV), # can be None 170 | ), 171 | ) 172 | 173 | disk_size = None 174 | if q.min_disk_size is not None: 175 | if q.min_disk_size > available_disk: 176 | continue 177 | disk_size = q.min_disk_size 178 | if disk_size is None or balance_resources: 179 | disk_size = max_none( 180 | disk_size, 181 | min_none( 182 | available_disk, 183 | max(memory, total_gpu_memory), 184 | q.max_disk_size, # can be None 185 | ), 186 | ) 187 | 188 | price = round( 189 | memory * specs["ram"]["price"] 190 | + cpu * specs["cpu"]["price"] 191 | + disk_size * specs["storage"]["price"] 192 | + gpu_count * gpu_info["price"], 193 | 5, 194 | ) 195 | 196 | offer = RawCatalogItem( 197 | 
instance_name=instance_name, 198 | location=location, 199 | price=price, 200 | cpu=cpu, 201 | memory=float(memory), 202 | gpu_vendor=None, 203 | gpu_name=gpu_name, 204 | gpu_count=gpu_count, 205 | gpu_memory=float(gpu_info["vram"]), 206 | spot=False, 207 | disk_size=disk_size, 208 | ) 209 | offers.append(offer) 210 | break # stop increasing gpu count 211 | return offers 212 | 213 | 214 | def round_up(value: Optional[Union[int, float]], step: int) -> Optional[int]: 215 | if value is None: 216 | return None 217 | return round_down(value + step - 1, step) 218 | 219 | 220 | def round_down(value: Optional[Union[int, float]], step: int) -> Optional[int]: 221 | if value is None: 222 | return None 223 | return value // step * step 224 | 225 | 226 | T = TypeVar("T", bound=Union[int, float]) 227 | 228 | 229 | def min_none(*args: Optional[T]) -> T: 230 | return min(v for v in args if v is not None) 231 | 232 | 233 | def max_none(*args: Optional[T]) -> T: 234 | return max(v for v in args if v is not None) 235 | 236 | 237 | def convert_gpu_name(model: str) -> str: 238 | """ 239 | >>> convert_gpu_name("geforcegtx1070-pcie-8gb") 240 | 'GTX1070' 241 | >>> convert_gpu_name("geforcertx1111ti-pcie-13gb") 242 | 'RTX1111Ti' 243 | >>> convert_gpu_name("a100-pcie-40gb") 244 | 'A100' 245 | """ 246 | if model in marketplace_gpus: 247 | return marketplace_gpus[model] 248 | model = model.split("-")[0] 249 | prefix = "geforce" 250 | if model.startswith(prefix): 251 | model = model[len(prefix) :] 252 | return model.upper().replace("TI", "Ti") 253 | -------------------------------------------------------------------------------- /src/gpuhunt/providers/vastai.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import re 4 | from collections import defaultdict 5 | from typing import Any, Literal, Optional, Union 6 | 7 | import requests 8 | 9 | from gpuhunt._internal.constraints import correct_gpu_memory_gib 10 | from 
gpuhunt._internal.models import QueryFilter, RawCatalogItem 11 | from gpuhunt.providers import AbstractProvider 12 | 13 | logger = logging.getLogger(__name__) 14 | bundles_url = "https://console.vast.ai/api/v0/bundles/" 15 | kilo = 1000 16 | # Maximum number of offers to fetch when GPU name mapping fails. 17 | Operators = Literal["lt", "lte", "eq", "gte", "gt"] 18 | FilterValue = Union[int, float, str, bool] 19 | 20 | 21 | class VastAIProvider(AbstractProvider): 22 | NAME = "vastai" 23 | 24 | def __init__(self, extra_filters: Optional[dict[str, dict[Operators, FilterValue]]] = None): 25 | self.extra_filters = extra_filters 26 | 27 | def get( 28 | self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True 29 | ) -> list[RawCatalogItem]: 30 | filters: dict[str, Any] = self.make_filters(query_filter or QueryFilter()) 31 | if self.extra_filters: 32 | for key, constraints in self.extra_filters.items(): 33 | for op, value in constraints.items(): 34 | filters[key][op] = value 35 | resp = requests.post(bundles_url, json=filters, timeout=10) 36 | resp.raise_for_status() 37 | data = resp.json() 38 | 39 | instance_offers = [] 40 | for offer in data["offers"]: 41 | cpu_cores = offer["cpu_cores"] 42 | # although this is not stated in the docs, the value can be None 43 | if cpu_cores: 44 | memory = float( 45 | int(offer["cpu_ram"] * offer["cpu_cores_effective"] / cpu_cores / kilo) 46 | ) 47 | else: 48 | memory = 0.0 49 | disk_size = query_filter and query_filter.min_disk_size or offer["disk_space"] 50 | if not self.satisfies_filters(offer, filters): 51 | logger.warning("Offer %s does not satisfy filters", offer["id"]) 52 | continue 53 | gpu_name = get_dstack_gpu_name(offer["gpu_name"]) 54 | gpu_memory = correct_gpu_memory_gib(gpu_name, offer["gpu_ram"]) 55 | ondemand_offer = RawCatalogItem( 56 | instance_name=str(offer["id"]), 57 | location=get_location(offer["geolocation"]), 58 | # storage_cost is $/gb/month 59 | price=round( 60 | offer["dph_base"] + 
disk_size * offer["storage_cost"] / 30 / 24, 61 | 5, 62 | ), 63 | cpu=int(offer["cpu_cores_effective"]), 64 | memory=memory, 65 | gpu_vendor=None, 66 | gpu_count=offer["num_gpus"], 67 | gpu_name=gpu_name, 68 | gpu_memory=float(gpu_memory), 69 | spot=False, 70 | disk_size=disk_size, 71 | ) 72 | instance_offers.append(ondemand_offer) 73 | 74 | if offer.get("min_bid"): 75 | spot_offer = copy.deepcopy(ondemand_offer) 76 | spot_offer.price = round(offer["min_bid"], 5) 77 | spot_offer.spot = True 78 | instance_offers.append(spot_offer) 79 | return instance_offers 80 | 81 | @staticmethod 82 | def make_filters(q: QueryFilter) -> dict[str, dict[Operators, FilterValue]]: 83 | filters = defaultdict(dict) 84 | if q.min_cpu is not None: 85 | filters["cpu_cores"]["gte"] = q.min_cpu 86 | if q.max_cpu is not None: 87 | filters["cpu_cores"]["lte"] = q.max_cpu 88 | if q.min_memory is not None: 89 | filters["cpu_ram"]["gte"] = q.min_memory * kilo 90 | if q.max_memory is not None: 91 | filters["cpu_ram"]["lte"] = q.max_memory * kilo 92 | if q.min_gpu_count is not None: 93 | filters["num_gpus"]["gte"] = q.min_gpu_count 94 | if q.max_gpu_count is not None: 95 | filters["num_gpus"]["lte"] = q.max_gpu_count 96 | if q.gpu_name: 97 | vastai_gpu_names = [] 98 | for g in q.gpu_name: 99 | vastai_gpu_names.extend(get_vastai_gpu_names(g)) 100 | if vastai_gpu_names: 101 | filters["gpu_name"]["in"] = vastai_gpu_names 102 | else: 103 | # If GPU name mapping fails, fetch all offers (to filter locally) 104 | filters["limit"] = 3000 105 | # See correct_gpu_memory_gib in gpuhunt/_internal/constraints.py 106 | if q.min_gpu_memory is not None: 107 | filters["gpu_ram"]["gte"] = q.min_gpu_memory * 1024 * 0.93 108 | if q.max_gpu_memory is not None: 109 | filters["gpu_ram"]["lte"] = q.max_gpu_memory * 1024 * 1.07 110 | if q.min_disk_size is not None: 111 | filters["disk_space"]["gte"] = q.min_disk_size 112 | if q.max_disk_size is not None: 113 | filters["disk_space"]["lte"] = q.max_disk_size 114 | if 
q.min_price is not None: 115 | filters["dph_total"]["gte"] = q.min_price 116 | if q.max_price is not None: 117 | filters["dph_total"]["lte"] = q.max_price 118 | # TODO(egor-s): add compute capability info for all GPUs 119 | if q.min_compute_capability is not None: 120 | filters["compute_capability"]["gte"] = compute_cap(q.min_compute_capability) 121 | if q.max_compute_capability is not None: 122 | filters["compute_capability"]["lte"] = compute_cap(q.max_compute_capability) 123 | filters["rentable"]["eq"] = True 124 | filters["rented"]["eq"] = False 125 | filters["order"] = [["score", "desc"]] 126 | return filters 127 | 128 | @staticmethod 129 | def satisfies_filters(offer: dict, filters: dict[str, dict[Operators, FilterValue]]) -> bool: 130 | for key in filters: 131 | if key not in offer: 132 | continue 133 | for op, value in filters[key].items(): 134 | if op == "lt" and offer[key] >= value: 135 | return False 136 | if op == "lte" and offer[key] > value: 137 | return False 138 | if op == "eq" and offer[key] != value: 139 | return False 140 | if op == "gte" and offer[key] < value: 141 | return False 142 | if op == "gt" and offer[key] <= value: 143 | return False 144 | return True 145 | 146 | 147 | GPU_MAPPING = { 148 | "L40S": ["L40S"], 149 | "L40": ["L40"], 150 | "A10": ["A10"], 151 | "A40": ["A40"], 152 | "L4": ["L4"], 153 | "A100X": ["A100X"], 154 | "H200": ["H200"], 155 | "H200NVL": ["H200 NVL"], 156 | "P100": ["Tesla P100"], 157 | "T4": ["Tesla T4"], 158 | "P4": ["Tesla P4"], 159 | "P40": ["Tesla P40"], 160 | "V100": ["Tesla V100"], 161 | "A100": ["A100 PCIE", "A100 SXM4"], 162 | "A800PCIE": ["A800 PCIE"], 163 | "H100": ["H100 PCIE", "H100 SXM"], 164 | "H100NVL": ["H100 NVL"], 165 | } 166 | 167 | GPU_MAPPING_RULES = { 168 | r"^RTX(\d{4}\D?)$": r"RTX \1", # RTX4090 -> RTX 4090, RTX4090S -> RTX 4090S 169 | r"^QRTX(\d{4})$": r"Q RTX \1", # QRTX8000 -> Q RTX 8000 170 | r"^RTX(\d{4})Ada$": r"RTX \1Ada", # RTX4090Ada -> RTX 4090Ada 171 | r"^RTX(\d{4}\D?)Ti$": r"RTX 
\1 Ti", # RTX4090Ti -> RTX 4090 Ti 172 | r"^A(\d{4})": r"RTX A\1", # A5000 -> RTX A5000 173 | } 174 | 175 | 176 | def get_vastai_gpu_names(gpu_name: str) -> list[str]: 177 | if gpu_name in GPU_MAPPING: 178 | return GPU_MAPPING[gpu_name] 179 | for pattern, replacement in GPU_MAPPING_RULES.items(): 180 | if re.match(pattern, gpu_name): 181 | return [re.sub(pattern, replacement, gpu_name)] 182 | return [] 183 | 184 | 185 | def get_dstack_gpu_name(gpu_name: str) -> str: 186 | """ 187 | Convert VastAI GPU names to a standardized format using essential heuristics 188 | """ 189 | gpu_name = gpu_name.replace("RTX A", "A").replace("Tesla ", "") 190 | if gpu_name.startswith("A100 "): 191 | return "A100" 192 | if gpu_name.startswith("H100 ") and "NVL" not in gpu_name: 193 | return "H100" 194 | return gpu_name.replace(" ", "") 195 | 196 | 197 | def get_location(location: Optional[str]) -> str: 198 | if location is None: 199 | return "" 200 | try: 201 | city, country = location.replace(", ", ",").split(",") 202 | location = f"{country}-{city}" 203 | except ValueError: 204 | pass 205 | return location.lower().replace(" ", "") 206 | 207 | 208 | def compute_cap(cc: tuple[int, int]) -> str: 209 | """ 210 | >>> compute_cap((7, 0)) 211 | '700' 212 | >>> compute_cap((7, 5)) 213 | '750' 214 | """ 215 | major, minor = cc 216 | return f"{major}{str(minor).ljust(2, '0')}" 217 | -------------------------------------------------------------------------------- /src/gpuhunt/providers/vultr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Optional, cast 3 | 4 | import requests 5 | from requests import Response 6 | 7 | from gpuhunt import QueryFilter, RawCatalogItem 8 | from gpuhunt._internal.constraints import KNOWN_AMD_GPUS, KNOWN_NVIDIA_GPUS, is_nvidia_superchip 9 | from gpuhunt._internal.models import AcceleratorVendor, CPUArchitecture 10 | from gpuhunt.providers import AbstractProvider 11 | 12 | logger = 
logging.getLogger(__name__) 13 | 14 | API_URL = "https://api.vultr.com/v2" 15 | 16 | 17 | class VultrProvider(AbstractProvider): 18 | NAME = "vultr" 19 | 20 | def get( 21 | self, query_filter: Optional[QueryFilter] = None, balance_resources: bool = True 22 | ) -> list[RawCatalogItem]: 23 | offers = fetch_offers() 24 | return sorted(offers, key=lambda i: i.price) 25 | 26 | 27 | def fetch_offers() -> list[RawCatalogItem]: 28 | """Fetch plans with types: 29 | 1. Cloud GPU (vcg), 30 | 2. Bare Metal (vbm), 31 | 3. and other CPU plans, including: 32 | Cloud Compute (vc2), 33 | High Frequency Compute (vhf), 34 | High Performance (vhp), 35 | All optimized Cloud Types (voc)""" 36 | bare_metal_plans_response = _make_request("GET", "/plans-metal?per_page=500") 37 | other_plans_response = _make_request("GET", "/plans?type=all&per_page=500") 38 | return convert_response_to_raw_catalog_items(bare_metal_plans_response, other_plans_response) 39 | 40 | 41 | def convert_response_to_raw_catalog_items( 42 | bare_metal_plans_response: Response, other_plans_response: Response 43 | ) -> list[RawCatalogItem]: 44 | catalog_items = [] 45 | 46 | bare_metal_plans = bare_metal_plans_response.json()["plans_metal"] 47 | other_plans = other_plans_response.json()["plans"] 48 | 49 | for plan in bare_metal_plans: 50 | for location in plan["locations"]: 51 | catalog_item = get_bare_metal_plans(plan, location) 52 | if catalog_item: 53 | catalog_items.append(catalog_item) 54 | 55 | for plan in other_plans: 56 | for location in plan["locations"]: 57 | catalog_item = get_instance_plans(plan, location) 58 | if catalog_item: 59 | catalog_items.append(catalog_item) 60 | 61 | return catalog_items 62 | 63 | 64 | def get_bare_metal_plans(plan: dict, location: str) -> Optional[RawCatalogItem]: 65 | cpu_arch = CPUArchitecture.X86 66 | gpu_count, gpu_name, gpu_memory, gpu_vendor = 0, None, None, None 67 | if "gpu" in plan["id"]: 68 | if plan["id"] not in BARE_METAL_GPU_DETAILS: 69 | logger.warning("Skipping 
unknown GPU plan %s", plan["id"]) 70 | return None 71 | gpu_count, gpu_name, gpu_memory = BARE_METAL_GPU_DETAILS[plan["id"]] 72 | if is_nvidia_superchip(gpu_name): 73 | cpu_arch = CPUArchitecture.ARM 74 | gpu_vendor = get_gpu_vendor(gpu_name) 75 | if gpu_vendor is None: 76 | logger.warning("Unknown GPU vendor for plan %s, skipping", plan["id"]) 77 | return None 78 | return RawCatalogItem( 79 | instance_name=plan["id"], 80 | location=location, 81 | price=plan["hourly_cost"], 82 | cpu_arch=cpu_arch.value, 83 | cpu=plan["cpu_threads"], 84 | memory=plan["ram"] / 1024, 85 | gpu_count=gpu_count, 86 | gpu_name=gpu_name, 87 | gpu_memory=gpu_memory, 88 | gpu_vendor=gpu_vendor, 89 | spot=False, 90 | disk_size=plan["disk"], 91 | ) 92 | 93 | 94 | def get_instance_plans(plan: dict, location: str) -> Optional[RawCatalogItem]: 95 | cpu_arch = CPUArchitecture.X86 96 | plan_type = plan["type"] 97 | if plan_type in ["vc2", "vhf", "vhp", "voc"]: 98 | return RawCatalogItem( 99 | instance_name=plan["id"], 100 | location=location, 101 | price=plan["hourly_cost"], 102 | cpu_arch=cpu_arch.value, 103 | cpu=plan["vcpu_count"], 104 | memory=plan["ram"] / 1024, 105 | gpu_count=0, 106 | gpu_name=None, 107 | gpu_memory=None, 108 | gpu_vendor=None, 109 | spot=False, 110 | disk_size=plan["disk"], 111 | ) 112 | elif plan_type == "vcg": 113 | gpu_type = cast(Optional[str], plan.get("gpu_type")) 114 | if not gpu_type: 115 | logger.warning("Missing gpu_type for plan %s, skipping", plan["id"]) 116 | return None 117 | if "_" not in gpu_type: 118 | logger.warning( 119 | "Failed to parse gpu_type %s for plan %s, skipping", gpu_type, plan["id"] 120 | ) 121 | return None 122 | gpu_name = gpu_type.split("_")[1] 123 | gpu_vendor = get_gpu_vendor(gpu_name) 124 | if not gpu_vendor: 125 | logger.warning( 126 | "Failed to detect GPU vendor %s for plan %s, skipping", gpu_type, plan["id"] 127 | ) 128 | return None 129 | gpu_memory = get_gpu_memory(gpu_name) 130 | if not gpu_memory: 131 | logger.warning( 132 | 
"Failed to detect GPU memory %s for plan %s, skipping", gpu_type, plan["id"] 133 | ) 134 | return None 135 | gpu_memory_total = cast(int, plan["gpu_vram_gb"]) 136 | # For fractional GPU, gpu_count=1 137 | gpu_count = max(1, gpu_memory_total // gpu_memory) 138 | if is_nvidia_superchip(gpu_name): 139 | cpu_arch = CPUArchitecture.ARM 140 | return RawCatalogItem( 141 | instance_name=plan["id"], 142 | location=location, 143 | price=plan["hourly_cost"], 144 | cpu_arch=cpu_arch.value, 145 | cpu=plan["vcpu_count"], 146 | memory=plan["ram"] / 1024, 147 | gpu_count=gpu_count, 148 | gpu_name=gpu_name, 149 | gpu_memory=gpu_memory_total / gpu_count, 150 | gpu_vendor=gpu_vendor, 151 | spot=False, 152 | disk_size=plan["disk"], 153 | ) 154 | return None 155 | 156 | 157 | def get_gpu_memory(gpu_name: str) -> Optional[int]: 158 | if gpu_name.upper() == "A100": 159 | return 80 # VULTR A100 instances have 80GB 160 | for gpu in KNOWN_NVIDIA_GPUS: 161 | if gpu.name.upper() == gpu_name.upper(): 162 | return gpu.memory 163 | 164 | for gpu in KNOWN_AMD_GPUS: 165 | if gpu.name.upper() == gpu_name.upper(): 166 | return gpu.memory 167 | logger.warning(f"Unknown GPU {gpu_name}") 168 | return None 169 | 170 | 171 | def get_gpu_vendor(gpu_name: Optional[str]) -> Optional[str]: 172 | if gpu_name is None: 173 | return None 174 | for gpu in KNOWN_NVIDIA_GPUS: 175 | if gpu.name.upper() == gpu_name.upper(): 176 | return AcceleratorVendor.NVIDIA.value 177 | for gpu in KNOWN_AMD_GPUS: 178 | if gpu.name.upper() == gpu_name.upper(): 179 | return AcceleratorVendor.AMD.value 180 | return None 181 | 182 | 183 | def _make_request(method: str, path: str, data: Any = None) -> Response: 184 | response = requests.request( 185 | method=method, 186 | url=API_URL + path, 187 | json=data, 188 | timeout=30, 189 | ) 190 | response.raise_for_status() 191 | return response 192 | 193 | 194 | BARE_METAL_GPU_DETAILS = { 195 | "vbm-48c-1024gb-4-a100-gpu": (4, "A100", 80), 196 | "vbm-112c-2048gb-8-h100-gpu": (8, "H100", 80), 
197 | "vbm-112c-2048gb-8-a100-gpu": (8, "A100", 80), 198 | "vbm-64c-2048gb-8-l40-gpu": (8, "L40S", 48), 199 | "vbm-72c-480gb-gh200-gpu": (1, "GH200", 96), 200 | "vbm-256c-2048gb-8-mi300x-gpu": (8, "MI300X", 192), 201 | } 202 | -------------------------------------------------------------------------------- /src/gpuhunt/resources/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dstackai/gpuhunt/fcc8f33dc037503956d7272c91e730e6584ec492/src/gpuhunt/resources/__init__.py -------------------------------------------------------------------------------- /src/gpuhunt/resources/tpu_pricing.json: -------------------------------------------------------------------------------- 1 | { 2 | "TPU v6e": { 3 | "us-east1": { 4 | "Location": "South Carolina", 5 | "On Demand (USD)": 2.7000, 6 | "Spot (USD)": 1.8900 7 | }, 8 | "us-east5": { 9 | "Location": "Ohio", 10 | "On Demand (USD)": 2.7000, 11 | "Spot (USD)": 1.8900 12 | }, 13 | "europe-west4": { 14 | "Location": "Netherlands", 15 | "On Demand (USD)": 2.9700, 16 | "Spot (USD)": 2.0800 17 | 18 | }, 19 | "asia-northeast1": { 20 | "Location": "Tokio", 21 | "On Demand (USD)": 3.2400, 22 | "Spot (USD)": 2.2700 23 | } 24 | }, 25 | "TPU v5p": { 26 | "us-east5": { 27 | "Location": "Columbus", 28 | "On Demand (USD)": 4.2000, 29 | "Spot (USD)": 2.4150 30 | }, 31 | "us-east1": { 32 | "Location": "South Carolina", 33 | "On Demand (USD)": 4.2000, 34 | "Spot (USD)": 2.1000 35 | } 36 | }, 37 | "TPU v5e": { 38 | "us-central1": { 39 | "Location": "Iowa", 40 | "On Demand (USD)": 1.2000, 41 | "Spot (USD)": 0.6300 42 | }, 43 | "us-east5": { 44 | "Location": "Ohio", 45 | "On Demand (USD)": 1.2000, 46 | "Spot (USD)": 0.6000 47 | }, 48 | "us-south1": { 49 | "Location": "Dallas", 50 | "On Demand (USD)": 1.416, 51 | "Spot (USD)": 0.708 52 | }, 53 | "us-west1": { 54 | "Location": "Oregon", 55 | "On Demand (USD)": 1.2000, 56 | "Spot (USD)": 0.6300 57 | }, 58 | "us-west4": { 59 | 
"Location": "Nevada", 60 | "On Demand (USD)": 1.2000, 61 | "Spot (USD)": 0.5100 62 | }, 63 | "europe-west1": { 64 | "Location": "Belgium", 65 | "On Demand (USD)": 1.3213, 66 | "Spot (USD)": 0.66066 67 | }, 68 | "europe-west4": { 69 | "Location": "Netherlands", 70 | "On Demand (USD)": 1.5600, 71 | "Spot (USD)": 0.663 72 | }, 73 | "asia-southeast1": { 74 | "Location": "Singapore", 75 | "On Demand (USD)": 1.5600, 76 | "Spot (USD)": 0.7800 77 | } 78 | }, 79 | "TPU v4 pod": { 80 | "us-central2": { 81 | "Location": "Oklahoma", 82 | "On Demand (USD)": 3.2200, 83 | "Spot (USD)": 3.686 84 | } 85 | }, 86 | "TPU v3 pod": { 87 | "europe-west4": { 88 | "Location": "Netherlands", 89 | "On Demand (USD)": 2.0000, 90 | "Spot (USD)": 2.6400 91 | } 92 | }, 93 | "TPU v3 device": { 94 | "europe-west4": { 95 | "Location": "Netherlands", 96 | "On Demand (USD)": 2.2000, 97 | "Spot (USD)": 2.2440 98 | } 99 | }, 100 | "TPU v2 pod": { 101 | "us-central1": { 102 | "Location": "Iowa", 103 | "On Demand (USD)": 1.5000, 104 | "Spot (USD)": 1.5300 105 | }, 106 | "europe-west4": { 107 | "Location": "Netherlands", 108 | "On Demand (USD)": 1.5000, 109 | "Spot (USD)": 1.5300 110 | } 111 | }, 112 | "TPU v2 device": { 113 | "asia-east1": { 114 | "Location": "Taiwan", 115 | "On Demand (USD)": 1.3050, 116 | "Spot (USD)": 0.3915 117 | }, 118 | "europe-west4": { 119 | "Location": "Netherlands", 120 | "On Demand (USD)": 1.2375, 121 | "Spot (USD)": 1.26225 122 | } 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /src/gpuhunt/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dstackai/gpuhunt/fcc8f33dc037503956d7272c91e730e6584ec492/src/gpuhunt/scripts/__init__.py -------------------------------------------------------------------------------- /src/gpuhunt/scripts/catalog_v1/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dstackai/gpuhunt/fcc8f33dc037503956d7272c91e730e6584ec492/src/gpuhunt/scripts/catalog_v1/__init__.py -------------------------------------------------------------------------------- /src/gpuhunt/scripts/catalog_v1/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from collections.abc import Sequence 4 | from pathlib import Path 5 | from textwrap import dedent 6 | from typing import Optional 7 | 8 | from gpuhunt._internal import storage 9 | from gpuhunt._internal.utils import configure_logging 10 | 11 | 12 | def main(args: Optional[Sequence[str]] = None): 13 | configure_logging() 14 | parser = argparse.ArgumentParser( 15 | description=dedent( 16 | """ 17 | Convert a v2 catalog to a v1 catalog. Legacy v1 catalogs are used by older 18 | gpuhunt versions that do not respect the `flags` field. Any catalog items 19 | with flags are filtered out when converting to v1. 20 | """ 21 | ) 22 | ) 23 | parser.add_argument("--input", type=Path, required=True, help="The v2 catalog file to read") 24 | parser.add_argument("--output", type=Path, required=True, help="The v1 catalog file to write") 25 | args = parser.parse_args(args) 26 | storage.convert_catalog_v2_to_v1(path_v2=args.input, path_v1=args.output) 27 | logging.info("Converted %s -> %s", args.input, args.output) 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /src/gpuhunt/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.0" 2 | -------------------------------------------------------------------------------- /src/integrity_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dstackai/gpuhunt/fcc8f33dc037503956d7272c91e730e6584ec492/src/integrity_tests/__init__.py 
-------------------------------------------------------------------------------- /src/integrity_tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import pytest 5 | 6 | 7 | @pytest.fixture 8 | def catalog_dir() -> Path: 9 | return Path(os.environ["CATALOG_DIR"]) 10 | -------------------------------------------------------------------------------- /src/integrity_tests/test_all.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | files = sorted(Path(os.environ["CATALOG_DIR"]).glob("*.csv")) 8 | 9 | 10 | def catalog_name(catalog) -> str: 11 | return catalog.name 12 | 13 | 14 | class TestAllCatalogs: 15 | @pytest.fixture(params=files, ids=catalog_name) 16 | def catalog(self, request): 17 | yield request.param 18 | 19 | def test_non_zero_cost(self, catalog): 20 | reader = csv.DictReader(catalog.open()) 21 | for row in reader: 22 | assert float(row["price"]) != pytest.approx(0), str(row) 23 | -------------------------------------------------------------------------------- /src/integrity_tests/test_aws.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture 7 | def data(catalog_dir: Path) -> str: 8 | return (catalog_dir / "aws.csv").read_text() 9 | 10 | 11 | class TestAWSCatalog: 12 | def test_m5_large_regions(self, data: str): 13 | instance = "m5.large" 14 | regions = [ 15 | "af-south-1", 16 | "ap-east-1", 17 | "ap-northeast-1", 18 | "ap-northeast-2", 19 | "ap-northeast-3", 20 | "ap-south-1", 21 | "ap-south-2", 22 | "ap-southeast-1", 23 | "ap-southeast-2", 24 | "ap-southeast-3", 25 | "ap-southeast-4", 26 | "ca-central-1", 27 | "eu-central-1", 28 | "eu-central-2", 29 | "eu-north-1", 30 | "eu-south-1", 31 | "eu-south-2", 32 | "eu-west-1", 33 | 
"eu-west-2", 34 | "eu-west-3", 35 | "il-central-1", 36 | "me-central-1", 37 | "me-south-1", 38 | "sa-east-1", 39 | "us-east-1", 40 | "us-east-2", 41 | "us-gov-east-1", 42 | "us-gov-west-1", 43 | "us-west-1", 44 | "us-west-2", 45 | "us-west-2-lax-1", 46 | ] 47 | assert all(f"\n{instance},{i}," in data for i in regions) 48 | 49 | def test_spots_presented(self, data: str): 50 | assert ",True," in data 51 | 52 | def test_gpu_presented(self, data: str): 53 | gpus = [ 54 | # AWS pricing csv does not include H200 (p5e.) offers. 55 | # TODO: Add CapacityBlocks offers to support H200. 56 | # "H200", 57 | "H100", 58 | "A100", 59 | "A10G", 60 | "T4", 61 | "V100", 62 | "L40S", 63 | "L4", 64 | ] 65 | assert all(f",{i}," in data for i in gpus) 66 | -------------------------------------------------------------------------------- /src/integrity_tests/test_azure.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from pathlib import Path 3 | 4 | import pytest 5 | 6 | 7 | @pytest.fixture 8 | def data_rows(catalog_dir: Path) -> list[dict]: 9 | with open(catalog_dir / "azure.csv") as f: 10 | return list(csv.DictReader(f)) 11 | 12 | 13 | class TestAzureCatalog: 14 | def test_standard_d2s_v3_locations(self, data_rows: list[dict]): 15 | expected_locations = { 16 | "attatlanta1", 17 | "attdallas1", 18 | "attdetroit1", 19 | "attnewyork1", 20 | "australiacentral", 21 | "australiacentral2", 22 | "australiaeast", 23 | "australiasoutheast", 24 | "australiasoutheast", 25 | "brazilsouth", 26 | "brazilsoutheast", 27 | "canadacentral", 28 | "canadaeast", 29 | "centralindia", 30 | "centralus", 31 | "eastasia", 32 | "eastus", 33 | "eastus2", 34 | "francecentral", 35 | "francesouth", 36 | "germanynorth", 37 | "germanywestcentral", 38 | "indonesiacentral", 39 | "israelcentral", 40 | "italynorth", 41 | "japaneast", 42 | "japanwest", 43 | "jioindiacentral", 44 | "jioindiawest", 45 | "koreacentral", 46 | "koreasouth", 47 | "malaysiawest", 48 | 
"mexicocentral", 49 | "newzealandnorth", 50 | "northcentralus", 51 | "northeurope", 52 | "norwayeast", 53 | "norwaywest", 54 | "polandcentral", 55 | "qatarcentral", 56 | "sgxsingapore1", 57 | "southafricanorth", 58 | "southafricawest", 59 | "southcentralus", 60 | "southeastasia", 61 | "southindia", 62 | "spaincentral", 63 | "swedencentral", 64 | "swedensouth", 65 | "switzerlandnorth", 66 | "switzerlandwest", 67 | "uaecentral", 68 | "uaenorth", 69 | "uksouth", 70 | "ukwest", 71 | "usgovarizona", 72 | "usgovtexas", 73 | "usgovvirginia", 74 | "westcentralus", 75 | "westeurope", 76 | "westindia", 77 | "westus", 78 | "westus2", 79 | "westus3", 80 | } 81 | locations = set( 82 | row["location"] for row in data_rows if row["instance_name"] == "Standard_D2s_v3" 83 | ) 84 | missing = expected_locations - locations 85 | assert not missing 86 | 87 | def test_spots_presented(self, data_rows: list[dict]): 88 | assert any(row["spot"] == "True" for row in data_rows) 89 | 90 | def test_ondemand_presented(self, data_rows: list[dict]): 91 | assert any(row["spot"] == "False" for row in data_rows) 92 | 93 | def test_gpu_presented(self, data_rows: list[dict]): 94 | expected_gpus = { 95 | "A100", 96 | "A10", 97 | "T4", 98 | "V100", 99 | } 100 | gpus = set(row["gpu_name"] for row in data_rows if row["gpu_name"]) 101 | assert expected_gpus == gpus 102 | 103 | def test_both_a100_presented(self, data_rows: list[dict]): 104 | expected_gpu_memory = {"40.0", "80.0"} 105 | gpu_memory = set(row["gpu_memory"] for row in data_rows if row["gpu_name"] == "A100") 106 | assert expected_gpu_memory == gpu_memory 107 | -------------------------------------------------------------------------------- /src/integrity_tests/test_cloudrift.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from operator import itemgetter 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | 8 | @pytest.fixture 9 | def data_rows(catalog_dir: Path) -> list[dict]: 10 | with 
open(catalog_dir / "cloudrift.csv") as f: 11 | return list(csv.DictReader(f)) 12 | 13 | 14 | # TODO: Add RTX5090 and RTX6000PRO and others after evaluation 15 | @pytest.mark.parametrize("gpu", ["RTX4090"]) 16 | def test_gpu_present(gpu: str, data_rows: list[dict]): 17 | assert gpu in map(itemgetter("gpu_name"), data_rows) 18 | 19 | 20 | # TODO: Add 3, 4, 5, ... 8 21 | @pytest.mark.parametrize("gpu_count", [1, 2]) 22 | def test_gpu_count_present(gpu_count: int, data_rows: list[dict]): 23 | assert str(gpu_count) in map(itemgetter("gpu_count"), data_rows) 24 | 25 | 26 | @pytest.mark.parametrize("location", ["us-east-nc-nr-1"]) 27 | def test_location_is_present(location: str, data_rows: list[dict]): 28 | assert location in map(itemgetter("location"), data_rows) 29 | 30 | 31 | def test_non_zero_price(data_rows: list[dict]): 32 | assert all(float(p) > 0 for p in map(itemgetter("price"), data_rows)) 33 | -------------------------------------------------------------------------------- /src/integrity_tests/test_datacrunch.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from collections import Counter 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | from gpuhunt.providers.datacrunch import ALL_AMD_GPUS, GPU_MAP 8 | 9 | 10 | @pytest.fixture 11 | def data_rows(catalog_dir: Path) -> list[dict]: 12 | file = catalog_dir / "datacrunch.csv" 13 | reader = csv.DictReader(file.open()) 14 | return list(reader) 15 | 16 | 17 | def select_row(rows, name: str) -> list[str]: 18 | return [r[name] for r in rows if r[name]] 19 | 20 | 21 | def test_locations(data_rows): 22 | expected = { 23 | "FIN-01", 24 | "FIN-02", 25 | "FIN-02", 26 | "ICE-01", 27 | } 28 | locations = select_row(data_rows, "location") 29 | missing = expected - set(locations) 30 | assert not missing 31 | 32 | count = Counter(locations) 33 | for loc in expected: 34 | assert count[loc] > 1 35 | 36 | 37 | def test_spot(data_rows): 38 | spots = select_row(data_rows, 
"spot") 39 | 40 | expected = set(("True", "False")) 41 | assert set(spots) == expected 42 | 43 | count = Counter(spots) 44 | for spot_key in ("True", "False"): 45 | assert count[spot_key] > 1 46 | 47 | 48 | def test_price(data_rows): 49 | prices = select_row(data_rows, "price") 50 | assert min(float(p) for p in prices) > 0 51 | 52 | 53 | def test_gpu_present(data_rows): 54 | refs = [name for name in GPU_MAP.values() if name not in ALL_AMD_GPUS] 55 | gpus = select_row(data_rows, "gpu_name") 56 | assert set(gpus) == set(refs) 57 | -------------------------------------------------------------------------------- /src/integrity_tests/test_gcp.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture 7 | def data(catalog_dir: Path) -> str: 8 | return (catalog_dir / "gcp.csv").read_text() 9 | 10 | 11 | class TestGCPCatalog: 12 | def test_e2_highcpu_2_zones(self, data: str): 13 | zones = [ 14 | "asia-east1-a", 15 | "asia-east1-b", 16 | "asia-east1-c", 17 | "asia-east2-a", 18 | "asia-east2-b", 19 | "asia-east2-c", 20 | "asia-northeast1-a", 21 | "asia-northeast1-b", 22 | "asia-northeast1-c", 23 | "asia-northeast2-a", 24 | "asia-northeast2-b", 25 | "asia-northeast2-c", 26 | "asia-northeast3-a", 27 | "asia-northeast3-b", 28 | "asia-northeast3-c", 29 | "asia-south1-a", 30 | "asia-south1-b", 31 | "asia-south1-c", 32 | "asia-south2-a", 33 | "asia-south2-b", 34 | "asia-south2-c", 35 | "asia-southeast1-a", 36 | "asia-southeast1-b", 37 | "asia-southeast1-c", 38 | "asia-southeast2-a", 39 | "asia-southeast2-b", 40 | "asia-southeast2-c", 41 | "australia-southeast1-a", 42 | "australia-southeast1-b", 43 | "australia-southeast1-c", 44 | "australia-southeast2-a", 45 | "australia-southeast2-b", 46 | "australia-southeast2-c", 47 | "europe-central2-a", 48 | "europe-central2-b", 49 | "europe-central2-c", 50 | "europe-north1-a", 51 | "europe-north1-b", 52 | "europe-north1-c", 53 | 
"europe-southwest1-a", 54 | "europe-southwest1-b", 55 | "europe-southwest1-c", 56 | "europe-west1-b", 57 | "europe-west1-c", 58 | "europe-west1-d", 59 | "europe-west10-a", 60 | "europe-west10-b", 61 | "europe-west10-c", 62 | "europe-west12-a", 63 | "europe-west12-b", 64 | "europe-west12-c", 65 | "europe-west2-a", 66 | "europe-west2-b", 67 | "europe-west2-c", 68 | "europe-west3-a", 69 | "europe-west3-b", 70 | "europe-west3-c", 71 | "europe-west4-a", 72 | "europe-west4-b", 73 | "europe-west4-c", 74 | "europe-west6-a", 75 | "europe-west6-b", 76 | "europe-west6-c", 77 | "europe-west8-a", 78 | "europe-west8-b", 79 | "europe-west8-c", 80 | "europe-west9-a", 81 | "europe-west9-b", 82 | "europe-west9-c", 83 | "me-central1-a", 84 | "me-central1-b", 85 | "me-central1-c", 86 | "me-central2-a", 87 | "me-central2-b", 88 | "me-central2-c", 89 | "me-west1-a", 90 | "me-west1-b", 91 | "me-west1-c", 92 | "northamerica-northeast1-a", 93 | "northamerica-northeast1-b", 94 | "northamerica-northeast1-c", 95 | "northamerica-northeast2-a", 96 | "northamerica-northeast2-b", 97 | "northamerica-northeast2-c", 98 | "southamerica-east1-a", 99 | "southamerica-east1-b", 100 | "southamerica-east1-c", 101 | "southamerica-west1-a", 102 | "southamerica-west1-b", 103 | "southamerica-west1-c", 104 | "us-central1-a", 105 | "us-central1-b", 106 | "us-central1-c", 107 | "us-central1-f", 108 | "us-east1-b", 109 | "us-east1-c", 110 | "us-east1-d", 111 | "us-east4-a", 112 | "us-east4-b", 113 | "us-east4-c", 114 | "us-east5-a", 115 | "us-east5-b", 116 | "us-east5-c", 117 | "us-south1-a", 118 | "us-south1-b", 119 | "us-south1-c", 120 | "us-west1-a", 121 | "us-west1-b", 122 | "us-west1-c", 123 | "us-west2-a", 124 | "us-west2-b", 125 | "us-west2-c", 126 | "us-west3-a", 127 | "us-west3-b", 128 | "us-west3-c", 129 | "us-west4-a", 130 | "us-west4-b", 131 | "us-west4-c", 132 | ] 133 | assert all(f"\ne2-highcpu-2,{i}," in data for i in zones) 134 | 135 | def test_spots_presented(self, data: str): 136 | assert 
",True," in data 137 | 138 | def test_ondemand_presented(self, data: str): 139 | assert ",False," in data 140 | 141 | def test_gpu_presented(self, data: str): 142 | gpus = [ 143 | "H100", 144 | "A100", 145 | "L4", 146 | "T4", 147 | "V100", 148 | "P100", 149 | ] 150 | assert all(f",{i}," in data for i in gpus) 151 | 152 | def test_tpu_presented(self, data: str): 153 | gpus = [ 154 | "v2", 155 | "v3", 156 | "v5litepod", 157 | "v5p", 158 | ] 159 | assert all(gpu in data for gpu in gpus) 160 | 161 | def test_both_a100_presented(self, data: str): 162 | assert ",A100,40.0," in data 163 | assert ",A100,80.0," in data 164 | -------------------------------------------------------------------------------- /src/integrity_tests/test_lambdalabs.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from operator import itemgetter 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | 8 | @pytest.fixture 9 | def data_rows(catalog_dir: Path) -> list[dict]: 10 | with open(catalog_dir / "lambdalabs.csv") as f: 11 | return list(csv.DictReader(f)) 12 | 13 | 14 | @pytest.mark.parametrize("gpu", ["A10", "A100", "H100"]) 15 | def test_gpu_present(gpu: str, data_rows: list[dict]): 16 | assert gpu in map(itemgetter("gpu_name"), data_rows) 17 | 18 | 19 | def test_on_demand_present(data_rows: list[dict]): 20 | assert "False" in map(itemgetter("spot"), data_rows) 21 | 22 | 23 | def test_spot_not_present(data_rows: list[dict]): 24 | assert "True" not in map(itemgetter("spot"), data_rows) 25 | 26 | 27 | def test_locations(data_rows: list[dict]): 28 | expected_locations = { 29 | "asia-northeast-1", 30 | "asia-northeast-2", 31 | "asia-south-1", 32 | "australia-east-1", 33 | "europe-central-1", 34 | "me-west-1", 35 | "us-east-1", 36 | "us-east-2", 37 | "us-east-3", 38 | "us-midwest-1", 39 | "us-south-1", 40 | "us-south-2", 41 | "us-south-3", 42 | "us-west-1", 43 | "us-west-2", 44 | "us-west-3", 45 | } 46 | locations = 
set(map(itemgetter("location"), data_rows)) 47 | missing = expected_locations - locations 48 | assert not missing 49 | 50 | 51 | def test_non_zero_price(data_rows: list[dict]): 52 | assert all(float(p) > 0 for p in map(itemgetter("price"), data_rows)) 53 | -------------------------------------------------------------------------------- /src/integrity_tests/test_nebius.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from operator import itemgetter 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | 8 | @pytest.fixture 9 | def data_rows(catalog_dir: Path) -> list[dict]: 10 | with open(catalog_dir / "nebius.csv") as f: 11 | return list(csv.DictReader(f)) 12 | 13 | 14 | @pytest.mark.parametrize("gpu", ["L40S", "H100", "H200", ""]) 15 | def test_gpu_present(gpu: str, data_rows: list[dict]): 16 | assert gpu in map(itemgetter("gpu_name"), data_rows) 17 | 18 | 19 | def test_on_demand_present(data_rows: list[dict]): 20 | assert "False" in map(itemgetter("spot"), data_rows) 21 | 22 | 23 | def test_spot_not_present(data_rows: list[dict]): 24 | assert "True" not in map(itemgetter("spot"), data_rows) 25 | 26 | 27 | @pytest.mark.parametrize("location", ["eu-north1", "eu-west1"]) 28 | def test_location_present(location: str, data_rows: list[dict]): 29 | assert location in map(itemgetter("location"), data_rows) 30 | 31 | 32 | def test_non_zero_price(data_rows: list[dict]): 33 | assert all(float(p) > 0 for p in map(itemgetter("price"), data_rows)) 34 | -------------------------------------------------------------------------------- /src/integrity_tests/test_oci.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from operator import itemgetter 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | 8 | @pytest.fixture 9 | def data_rows(catalog_dir: Path) -> list[dict]: 10 | with open(catalog_dir / "oci.csv") as f: 11 | return list(csv.DictReader(f)) 12 | 13 | 14 | 
@pytest.mark.parametrize("gpu", ["P100", "V100", "A10", "A100", ""]) 15 | def test_gpu_present(gpu: str, data_rows: list[dict]): 16 | assert gpu in map(itemgetter("gpu_name"), data_rows) 17 | 18 | 19 | def test_on_demand_present(data_rows: list[dict]): 20 | assert "False" in map(itemgetter("spot"), data_rows) 21 | 22 | 23 | def test_spot_present(data_rows: list[dict]): 24 | assert "True" in map(itemgetter("spot"), data_rows) 25 | 26 | 27 | def test_spots_contain_flag(data_rows: list[dict]): 28 | for row in data_rows: 29 | assert (row["spot"] == "True") == ("oci-spot" in row["flags"]), row 30 | 31 | 32 | @pytest.mark.parametrize("prefix", ["VM.Standard", "BM.Standard", "VM.GPU", "BM.GPU"]) 33 | def test_family_present(prefix: str, data_rows: list[dict]): 34 | assert any(name.startswith(prefix) for name in map(itemgetter("instance_name"), data_rows)) 35 | 36 | 37 | def test_quantity_decreases_as_query_complexity_increases(data_rows: list[dict]): 38 | zero_or_one_gpu = list(filter(lambda row: int(row["gpu_count"]) in (0, 1), data_rows)) 39 | zero_gpu = list(filter(lambda row: int(row["gpu_count"]) == 0, data_rows)) 40 | one_gpu = list(filter(lambda row: int(row["gpu_count"]) == 1, data_rows)) 41 | 42 | assert len(data_rows) > len(zero_or_one_gpu) 43 | assert len(zero_or_one_gpu) > len(zero_gpu) 44 | assert len(zero_gpu) > len(one_gpu) 45 | -------------------------------------------------------------------------------- /src/integrity_tests/test_runpod.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from collections import Counter 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | from gpuhunt.providers.runpod import GPU_MAP 8 | 9 | 10 | @pytest.fixture 11 | def data_rows(catalog_dir: Path) -> list[dict]: 12 | file = catalog_dir / "runpod.csv" 13 | reader = csv.DictReader(file.open()) 14 | return list(reader) 15 | 16 | 17 | def select_row(rows, name: str) -> list[str]: 18 | return [r[name] for r in rows if 
r[name]] 19 | 20 | 21 | def test_locations(data_rows): 22 | expected = { 23 | # Secure cloud 24 | "CA-MTL-1", 25 | "CA-MTL-2", 26 | "CA-MTL-3", 27 | "EU-NL-1", 28 | "EU-RO-1", 29 | "EU-SE-1", 30 | "EUR-IS-1", 31 | "EUR-IS-2", 32 | "US-TX-3", 33 | # Community cloud 34 | "CA", 35 | "CZ", 36 | "FR", 37 | "US", 38 | } 39 | locations = set(select_row(data_rows, "location")) 40 | # Assert most are present. Some may be missing due to low availability 41 | assert len(expected - locations) <= 3 42 | 43 | 44 | def test_spot(data_rows): 45 | spots = select_row(data_rows, "spot") 46 | 47 | expected = set(("True", "False")) 48 | assert set(spots) == expected 49 | 50 | count = Counter(spots) 51 | for spot_key in ("True", "False"): 52 | assert count[spot_key] > 1 53 | 54 | 55 | def test_price(data_rows): 56 | prices = select_row(data_rows, "price") 57 | assert min(float(p) for p in prices) > 0 58 | 59 | 60 | def test_gpu_present(data_rows): 61 | refs = set(name for _, name in GPU_MAP.values()) 62 | gpus = set(select_row(data_rows, "gpu_name")) 63 | assert len(refs & gpus) > 7 64 | -------------------------------------------------------------------------------- /src/integrity_tests/test_vastai.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import requests 4 | 5 | from gpuhunt.providers.vastai import get_dstack_gpu_name, get_vastai_gpu_names 6 | 7 | 8 | def test_real_world_vastai_offers(): 9 | """Test that GPU name conversion works correctly for all real-world offers from Vast.ai.""" 10 | # Get all offers from Vast.ai 11 | response = requests.post( 12 | "https://cloud.vast.ai/api/v0/bundles/", json={"limit": 3000}, timeout=10 13 | ) 14 | response.raise_for_status() 15 | offers = response.json()["offers"] 16 | 17 | # Track unique GPU names and their conversions 18 | unique_gpu_names = set() 19 | conversion_issues = set() 20 | 21 | for offer in offers: 22 | vastai_gpu_name = offer["gpu_name"] 23 | if not 
vastai_gpu_name: 24 | continue 25 | 26 | unique_gpu_names.add(vastai_gpu_name) 27 | 28 | # Convert to dstack format and back 29 | dstack_name = get_dstack_gpu_name(vastai_gpu_name) 30 | vastai_names = get_vastai_gpu_names(dstack_name) 31 | 32 | # Check if the original name is in the converted back list 33 | if vastai_gpu_name not in vastai_names: 34 | conversion_issues.add(vastai_gpu_name) 35 | 36 | # Print statistics about mismatched GPUs 37 | if conversion_issues: 38 | warning_msg = f"Found {len(conversion_issues)} GPU names without valid mapping:\n" 39 | for gpu_name in sorted(conversion_issues): 40 | warning_msg += f"- {gpu_name}\n" 41 | warnings.warn(warning_msg) 42 | -------------------------------------------------------------------------------- /src/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dstackai/gpuhunt/fcc8f33dc037503956d7272c91e730e6584ec492/src/tests/__init__.py -------------------------------------------------------------------------------- /src/tests/_internal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dstackai/gpuhunt/fcc8f33dc037503956d7272c91e730e6584ec492/src/tests/_internal/__init__.py -------------------------------------------------------------------------------- /src/tests/_internal/test_catalog.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from unittest.mock import Mock 3 | 4 | import gpuhunt._internal.catalog as internal_catalog 5 | from gpuhunt import Catalog, CatalogItem, RawCatalogItem 6 | from gpuhunt.providers.tensordock import TensorDockProvider 7 | from gpuhunt.providers.vastai import VastAIProvider 8 | 9 | 10 | class TestQuery: 11 | def test_query_merge(self): 12 | catalog = Catalog(balance_resources=False, auto_reload=False) 13 | 14 | tensordock = TensorDockProvider() 15 | 
tensordock.get = Mock(return_value=[catalog_item(price=1), catalog_item(price=3)]) 16 | catalog.add_provider(tensordock) 17 | 18 | vastai = VastAIProvider() 19 | vastai.get = Mock(return_value=[catalog_item(price=2), catalog_item(price=1)]) 20 | catalog.add_provider(vastai) 21 | 22 | assert catalog.query(provider=["tensordock", "vastai"]) == [ 23 | catalog_item(provider="tensordock", price=1), 24 | catalog_item(provider="vastai", price=2), 25 | catalog_item(provider="vastai", price=1), 26 | catalog_item(provider="tensordock", price=3), 27 | ] 28 | 29 | def test_no_providers_some_not_loaded(self): 30 | catalog = Catalog(balance_resources=False, auto_reload=False) 31 | 32 | tensordock = TensorDockProvider() 33 | tensordock.get = Mock(return_value=[catalog_item(price=1)]) 34 | catalog.add_provider(tensordock) 35 | 36 | internal_catalog.OFFLINE_PROVIDERS = [] 37 | assert catalog.query() == [ 38 | catalog_item(provider="tensordock", price=1), 39 | ] 40 | 41 | def test_provider_filter(self): 42 | catalog = Catalog(balance_resources=False, auto_reload=False) 43 | catalog.add_provider(tensordock := TensorDockProvider()) 44 | catalog.add_provider(vastai := VastAIProvider()) 45 | 46 | tensordock_offers = [catalog_item(price=1)] 47 | vastai_offers = [catalog_item(price=2), catalog_item(price=3)] 48 | 49 | tensordock.get = Mock(return_value=tensordock_offers) 50 | vastai.get = Mock(return_value=vastai_offers) 51 | 52 | assert len(catalog.query(provider="tensordock")) == 1 53 | assert len(catalog.query(provider="Tensordock")) == 1 54 | assert len(catalog.query(provider="vastai")) == 2 55 | assert len(catalog.query(provider="VastAI")) == 2 56 | assert len(catalog.query(provider=["tensordock", "VastAI"])) == 3 57 | 58 | def test_gpu_name_filter(self): 59 | catalog = Catalog(balance_resources=False, auto_reload=False) 60 | catalog.add_provider(tensordock := TensorDockProvider()) 61 | 62 | tensordock.get = Mock( 63 | return_value=[ 64 | catalog_item(gpu_name="A10"), 65 | 
catalog_item(gpu_name="A100"), 66 | catalog_item(gpu_name="a100"), 67 | ] 68 | ) 69 | 70 | assert len(catalog.query(gpu_name="V100")) == 0 71 | assert len(catalog.query(gpu_name="A10")) == 1 72 | assert len(catalog.query(gpu_name="a10")) == 1 73 | assert len(catalog.query(gpu_name="A100")) == 2 74 | assert len(catalog.query(gpu_name="a100")) == 2 75 | assert len(catalog.query(gpu_name=["a10", "A100"])) == 3 76 | 77 | 78 | def catalog_item(**kwargs) -> Union[CatalogItem, RawCatalogItem]: 79 | values = dict( 80 | instance_name="instance", 81 | cpu=1, 82 | memory=1, 83 | gpu_vendor="nvidia", 84 | gpu_count=1, 85 | gpu_name="gpu", 86 | gpu_memory=1, 87 | location="location", 88 | price=1, 89 | spot=False, 90 | disk_size=None, 91 | ) 92 | values.update(kwargs) 93 | if "provider" in values: 94 | return CatalogItem(**values) 95 | return RawCatalogItem(**values) 96 | -------------------------------------------------------------------------------- /src/tests/_internal/test_constraints.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from gpuhunt import CatalogItem, QueryFilter 6 | from gpuhunt._internal.constraints import correct_gpu_memory_gib, matches 7 | from gpuhunt._internal.models import AcceleratorVendor 8 | 9 | 10 | @pytest.fixture 11 | def item() -> CatalogItem: 12 | return CatalogItem( 13 | instance_name="large", 14 | location="us-east-1", 15 | price=1.2, 16 | cpu=16, 17 | memory=64.0, 18 | gpu_vendor=AcceleratorVendor.NVIDIA, 19 | gpu_count=1, 20 | gpu_name="A100", 21 | gpu_memory=40.0, 22 | spot=False, 23 | provider="aws", 24 | disk_size=None, 25 | ) 26 | 27 | 28 | @pytest.fixture 29 | def cpu_items() -> list[CatalogItem]: 30 | datacrunch = CatalogItem( 31 | instance_name="CPU.120V.480G", 32 | location="ICE-01", 33 | price=3.0, 34 | cpu=120, 35 | memory=480.0, 36 | gpu_vendor=None, 37 | gpu_count=0, 38 | gpu_name=None, 39 | gpu_memory=0.0, 40 | spot=False, 41 | 
provider="datacrunch", 42 | disk_size=None, 43 | ) 44 | aws = CatalogItem( 45 | instance_name="c5.12xlarge", 46 | location="us-east-2", 47 | price=2.04, 48 | cpu=48, 49 | memory=96.0, 50 | gpu_vendor=None, 51 | gpu_count=0, 52 | gpu_name=None, 53 | gpu_memory=None, 54 | spot=False, 55 | provider="aws", 56 | disk_size=None, 57 | ) 58 | return [datacrunch, aws] 59 | 60 | 61 | class TestMatches: 62 | def test_empty(self, item: CatalogItem): 63 | assert matches(item, QueryFilter()) 64 | 65 | def test_cpu(self, item: CatalogItem): 66 | assert matches(item, QueryFilter(min_cpu=16)) 67 | assert matches(item, QueryFilter(max_cpu=16)) 68 | assert not matches(item, QueryFilter(min_cpu=32)) 69 | assert not matches(item, QueryFilter(max_cpu=8)) 70 | 71 | def test_memory(self, item: CatalogItem): 72 | assert matches(item, QueryFilter(min_memory=64.0)) 73 | assert matches(item, QueryFilter(max_memory=64.0)) 74 | assert not matches(item, QueryFilter(min_memory=128.0)) 75 | assert not matches(item, QueryFilter(max_memory=32.0)) 76 | 77 | def test_gpu_vendor_nvidia(self, item: CatalogItem): 78 | assert matches(item, QueryFilter(gpu_vendor=AcceleratorVendor.NVIDIA)) 79 | assert not matches(item, QueryFilter(gpu_vendor=AcceleratorVendor.AMD)) 80 | 81 | def test_gpu_vendor_amd(self, item: CatalogItem): 82 | item.gpu_vendor = AcceleratorVendor.AMD 83 | assert matches(item, QueryFilter(gpu_vendor=AcceleratorVendor.AMD)) 84 | assert not matches(item, QueryFilter(gpu_vendor=AcceleratorVendor.NVIDIA)) 85 | 86 | def test_gpu_count(self, item: CatalogItem): 87 | assert matches(item, QueryFilter(min_gpu_count=1)) 88 | assert matches(item, QueryFilter(max_gpu_count=1)) 89 | assert not matches(item, QueryFilter(min_gpu_count=2)) 90 | assert not matches(item, QueryFilter(max_gpu_count=0)) 91 | 92 | def test_gpu_memory(self, item: CatalogItem): 93 | assert matches(item, QueryFilter(min_gpu_memory=40.0)) 94 | assert matches(item, QueryFilter(max_gpu_memory=40.0)) 95 | assert not matches(item, 
QueryFilter(min_gpu_memory=80.0)) 96 | assert not matches(item, QueryFilter(max_gpu_memory=20.0)) 97 | 98 | def test_gpu_name(self, item: CatalogItem): 99 | assert matches(item, QueryFilter(gpu_name=["a100"])) 100 | assert matches(item, QueryFilter(gpu_name=["A100"])) 101 | assert not matches(item, QueryFilter(gpu_name=["A10"])) 102 | 103 | def test_gpu_name_with_filter_setattr(self, item: CatalogItem): 104 | q = QueryFilter() 105 | q.gpu_name = ["a100"] 106 | assert matches(item, q) 107 | q.gpu_name = ["A100"] 108 | assert matches(item, q) 109 | q.gpu_name = ["A10"] 110 | assert not matches(item, q) 111 | 112 | def test_total_gpu_memory(self, item: CatalogItem): 113 | assert matches(item, QueryFilter(min_total_gpu_memory=40.0)) 114 | assert matches(item, QueryFilter(max_total_gpu_memory=40.0)) 115 | assert not matches(item, QueryFilter(min_total_gpu_memory=80.0)) 116 | assert not matches(item, QueryFilter(max_total_gpu_memory=20.0)) 117 | 118 | def test_price(self, item: CatalogItem): 119 | assert matches(item, QueryFilter(min_price=1.2)) 120 | assert matches(item, QueryFilter(max_price=1.2)) 121 | assert not matches(item, QueryFilter(min_price=1.3)) 122 | assert not matches(item, QueryFilter(max_price=1.1)) 123 | 124 | def test_spot(self, item: CatalogItem): 125 | assert matches(item, QueryFilter(spot=False)) 126 | assert not matches(item, QueryFilter(spot=True)) 127 | 128 | def test_compute_capability(self, item: CatalogItem): 129 | assert matches(item, QueryFilter(min_compute_capability=(8, 0))) 130 | assert matches(item, QueryFilter(max_compute_capability=(8, 0))) 131 | assert not matches(item, QueryFilter(min_compute_capability=(8, 1))) 132 | assert not matches(item, QueryFilter(max_compute_capability=(7, 9))) 133 | 134 | def test_compute_capability_not_nvidia(self, item: CatalogItem): 135 | item.gpu_vendor = AcceleratorVendor.AMD 136 | assert not matches(item, QueryFilter(min_compute_capability=(8, 0))) 137 | assert not matches(item, 
QueryFilter(max_compute_capability=(8, 0))) 138 | 139 | def test_ti_gpu(self): 140 | item = CatalogItem( 141 | instance_name="large", 142 | location="us-east-1", 143 | price=1.2, 144 | cpu=16, 145 | memory=64.0, 146 | gpu_count=1, 147 | gpu_vendor=AcceleratorVendor.NVIDIA, 148 | gpu_name="RTX3060Ti", # case-sensitive 149 | gpu_memory=8.0, 150 | spot=False, 151 | provider="aws", 152 | disk_size=None, 153 | ) 154 | assert matches(item, QueryFilter(gpu_name=["RTX3060TI"])) 155 | 156 | def test_provider(self, cpu_items): 157 | assert matches(cpu_items[0], QueryFilter(provider=["datacrunch"])) 158 | assert matches(cpu_items[0], QueryFilter(provider=["DataCrunch"])) 159 | assert not matches(cpu_items[0], QueryFilter(provider=["aws"])) 160 | 161 | assert matches(cpu_items[1], QueryFilter(provider=["aws"])) 162 | assert matches(cpu_items[1], QueryFilter(provider=["AWS"])) 163 | assert not matches(cpu_items[1], QueryFilter(provider=["datacrunch"])) 164 | 165 | def test_provider_with_filter_setattr(self, cpu_items): 166 | q = QueryFilter() 167 | q.provider = ["datacrunch"] 168 | assert matches(cpu_items[0], q) 169 | q.provider = ["DataCrunch"] 170 | assert matches(cpu_items[0], q) 171 | q.provider = ["aws"] 172 | assert not matches(cpu_items[0], q) 173 | 174 | @pytest.mark.parametrize( 175 | ("item_flags", "query_allowed_flags", "should_match"), 176 | [ 177 | pytest.param([], ["a"], True, id="matches-if-no-flags"), 178 | pytest.param([], None, True, id="matches-if-no-flags-and-all-flags-allowed"), 179 | pytest.param(["a", "b"], None, True, id="matches-if-has-flags-and-all-flags-allowed"), 180 | pytest.param( 181 | ["a", "b"], ["a", "b", "c"], True, id="matches-if-has-flags-all-of-which-allowed" 182 | ), 183 | pytest.param(["a", "b"], ["a"], False, id="not-matches-if-some-flags-not-allowed"), 184 | ], 185 | ) 186 | def test_flags( 187 | self, 188 | item: CatalogItem, 189 | item_flags: list[str], 190 | query_allowed_flags: Optional[list[str]], 191 | should_match: bool, 192 | ) 
-> None: 193 | item.flags = item_flags 194 | q = QueryFilter(allowed_flags=query_allowed_flags) 195 | assert matches(item, q) == should_match 196 | 197 | 198 | @pytest.mark.parametrize( 199 | ("gpu_name", "memory_mib", "expected_memory_gib"), 200 | [ 201 | ("H100NVL", 95830.0, 94), 202 | ("L40S", 46068.0, 48), 203 | ("A10G", 23028.0, 24), 204 | ("A10", 4096.0, 4), 205 | ("unknown", 8200.1, 8), 206 | ], 207 | ) 208 | def test_correct_gpu_memory(gpu_name: str, memory_mib: float, expected_memory_gib: int) -> None: 209 | assert correct_gpu_memory_gib(gpu_name, memory_mib) == expected_memory_gib 210 | -------------------------------------------------------------------------------- /src/tests/_internal/test_models.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import pytest 4 | 5 | from gpuhunt._internal.constraints import KNOWN_AMD_GPUS 6 | from gpuhunt._internal.models import ( 7 | AcceleratorVendor, 8 | AMDArchitecture, 9 | CatalogItem, 10 | CPUArchitecture, 11 | Optional, 12 | RawCatalogItem, 13 | ) 14 | 15 | NVIDIA = AcceleratorVendor.NVIDIA 16 | GOOGLE = AcceleratorVendor.GOOGLE 17 | AMD = AcceleratorVendor.AMD 18 | 19 | 20 | @pytest.mark.parametrize( 21 | ["gpu_count", "gpu_vendor", "gpu_name", "expected_gpu_vendor", "expected_gpu_name"], 22 | [ 23 | pytest.param(None, None, None, None, None, id="none-gpu"), 24 | pytest.param(0, None, None, None, None, id="zero-gpu"), 25 | pytest.param(1, None, "A100", "nvidia", "A100", id="one-gpu"), 26 | pytest.param(1, None, "tpu-v3", "google", "v3", id="one-tpu-vendor-not-set"), 27 | pytest.param(1, "google", "tpu-v5p", "google", "v5p", id="one-tpu-vendor-is-set"), 28 | pytest.param(1, AMD, "MI300X", "amd", "MI300X", id="cast-enum-to-string"), 29 | ], 30 | ) 31 | def test_raw_catalog_item_gpu_vendor_heuristic( 32 | gpu_count: Optional[int], 33 | gpu_vendor: Union[AcceleratorVendor, str, None], 34 | gpu_name: Optional[str], 35 | expected_gpu_vendor: 
Optional[str], 36 | expected_gpu_name: Optional[str], 37 | ): 38 | dct = {} 39 | if gpu_vendor is not None: 40 | dct["gpu_vendor"] = gpu_vendor 41 | if gpu_count is not None: 42 | dct["gpu_count"] = gpu_count 43 | if gpu_name is not None: 44 | dct["gpu_name"] = gpu_name 45 | 46 | item = RawCatalogItem.from_dict(dct) 47 | 48 | assert item.gpu_vendor == expected_gpu_vendor 49 | assert item.gpu_name == expected_gpu_name 50 | 51 | 52 | @pytest.mark.parametrize( 53 | ["gpu_count", "gpu_vendor", "gpu_name", "expected_gpu_vendor"], 54 | [ 55 | pytest.param(None, None, None, None, id="none-gpu"), 56 | pytest.param(0, None, None, None, id="zero-gpu"), 57 | pytest.param(1, None, None, NVIDIA, id="one-gpu-no-name"), 58 | pytest.param(1, None, "v3", NVIDIA, id="one-gpu-with-any-name"), 59 | pytest.param(1, "amd", "MI300X", AMD, id="cast-string-to-enum"), 60 | ], 61 | ) 62 | def test_catalog_item_gpu_vendor_heuristic( 63 | gpu_count: Optional[int], 64 | gpu_vendor: Union[AcceleratorVendor, str, None], 65 | gpu_name: Optional[str], 66 | expected_gpu_vendor: Optional[AcceleratorVendor], 67 | ): 68 | item = CatalogItem( 69 | instance_name="test-instance", 70 | location="eu-west-1", 71 | price=1.0, 72 | cpu=1, 73 | memory=32.0, 74 | gpu_vendor=gpu_vendor, 75 | gpu_count=gpu_count, 76 | gpu_name=gpu_name, 77 | gpu_memory=8.0, 78 | spot=False, 79 | disk_size=100.0, 80 | provider="test", 81 | ) 82 | 83 | assert item.gpu_vendor == expected_gpu_vendor 84 | 85 | 86 | @pytest.mark.parametrize( 87 | ["cpu_arch", "expected_cpu_arch"], 88 | [ 89 | pytest.param(None, CPUArchitecture.X86, id="non-set"), 90 | pytest.param(CPUArchitecture.X86, CPUArchitecture.X86, id="enum"), 91 | pytest.param("ARM", CPUArchitecture.ARM, id="cast-string-to-enum"), 92 | ], 93 | ) 94 | def test_catalog_item_cpu_arch_heuristic( 95 | cpu_arch: Union[CPUArchitecture, str, None], 96 | expected_cpu_arch: CPUArchitecture, 97 | ): 98 | item = CatalogItem( 99 | instance_name="test-instance", 100 | location="eu-west-1", 
101 | price=1.0, 102 | cpu_arch=cpu_arch, 103 | cpu=1, 104 | memory=32.0, 105 | gpu_count=0, 106 | gpu_name=None, 107 | gpu_memory=8.0, 108 | spot=False, 109 | disk_size=100.0, 110 | provider="test", 111 | ) 112 | 113 | assert item.cpu_arch == expected_cpu_arch 114 | 115 | 116 | @pytest.mark.parametrize( 117 | ["model", "architecture", "expected_memory"], 118 | [ 119 | pytest.param("MI325X", AMDArchitecture.CDNA3, 288, id="MI325X"), 120 | pytest.param("MI308X", AMDArchitecture.CDNA3, 128, id="MI308X"), 121 | pytest.param("MI300X", AMDArchitecture.CDNA3, 192, id="MI300X"), 122 | pytest.param("MI300A", AMDArchitecture.CDNA3, 128, id="MI300A"), 123 | pytest.param("MI250X", AMDArchitecture.CDNA2, 128, id="MI250X"), 124 | pytest.param("MI250", AMDArchitecture.CDNA2, 128, id="MI250"), 125 | pytest.param("MI210", AMDArchitecture.CDNA2, 64, id="MI210"), 126 | pytest.param("MI100", AMDArchitecture.CDNA, 32, id="MI100"), 127 | ], 128 | ) 129 | def test_amd_gpu_architecture(model: str, architecture: AMDArchitecture, expected_memory: int): 130 | for gpu in KNOWN_AMD_GPUS: 131 | if gpu.name == model: 132 | assert gpu.architecture == architecture 133 | assert gpu.memory == expected_memory 134 | return 135 | # If we get here, the test should fail since we could not find the GPU in our known list. 
136 | assert False 137 | 138 | 139 | def test_raw_catalog_item_to_from_dict() -> None: 140 | item = RawCatalogItem( 141 | instance_name="test-instance", 142 | location="eu-west-1", 143 | price=1.0, 144 | cpu_arch=CPUArchitecture.ARM, 145 | cpu=1, 146 | memory=32.0, 147 | gpu_vendor=AcceleratorVendor.NVIDIA, 148 | gpu_count=1, 149 | gpu_name="A10", 150 | gpu_memory=24.0, 151 | spot=False, 152 | disk_size=100.0, 153 | flags=["f1", "f2", "f3"], 154 | ) 155 | item_dict = item.dict() 156 | assert item_dict == { 157 | "instance_name": "test-instance", 158 | "location": "eu-west-1", 159 | "price": 1.0, 160 | "cpu_arch": "arm", 161 | "cpu": 1, 162 | "memory": 32.0, 163 | "gpu_vendor": "nvidia", 164 | "gpu_count": 1, 165 | "gpu_name": "A10", 166 | "gpu_memory": 24.0, 167 | "spot": False, 168 | "disk_size": 100.0, 169 | "flags": "f1 f2 f3", 170 | } 171 | assert RawCatalogItem.from_dict(item_dict) == item 172 | -------------------------------------------------------------------------------- /src/tests/_internal/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from gpuhunt._internal.utils import to_camel_case 4 | 5 | 6 | @pytest.mark.parametrize( 7 | ["before", "after"], 8 | [ 9 | ["spam_ham_eggs", "spamHamEggs"], 10 | ["spam__ham__eggs", "spamHamEggs"], 11 | ["__spam_ham_eggs__", "spamHamEggs"], 12 | ["spamHam_eggs", "spamHamEggs"], 13 | ["spamHamEggs", "spamHamEggs"], 14 | ["SpamHam_eggs", "SpamHamEggs"], 15 | ["spam", "spam"], 16 | ["", ""], 17 | ], 18 | ) 19 | def test_to_camel_case(before, after): 20 | assert to_camel_case(before) == after 21 | -------------------------------------------------------------------------------- /src/tests/providers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dstackai/gpuhunt/fcc8f33dc037503956d7272c91e730e6584ec492/src/tests/providers/__init__.py 
-------------------------------------------------------------------------------- /src/tests/providers/test_cudo.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import gpuhunt._internal.catalog as internal_catalog 4 | from gpuhunt import Catalog 5 | from gpuhunt.providers.cudo import ( 6 | CudoProvider, 7 | get_balanced_disk_size, 8 | get_balanced_memory, 9 | get_memory, 10 | gpu_name, 11 | ) 12 | 13 | 14 | @pytest.fixture 15 | def machine_types() -> list[dict]: 16 | return [ 17 | { 18 | "dataCenterId": "br-saopaulo-1", 19 | "machineType": "cascade-lake", 20 | "cpuModel": "Cascadelake-Server-noTSX", 21 | "gpuModel": "RTX 3080", 22 | "gpuModelId": "nvidia-rtx-3080", 23 | "minVcpuPerMemoryGib": 0.25, 24 | "maxVcpuPerMemoryGib": 1, 25 | "minVcpuPerGpu": 1, 26 | "maxVcpuPerGpu": 13, 27 | "vcpuPriceHr": {"value": "0.002500"}, 28 | "memoryGibPriceHr": {"value": "0.003800"}, 29 | "gpuPriceHr": {"value": "0.05"}, 30 | "minStorageGibPriceHr": {"value": "0.00013"}, 31 | "ipv4PriceHr": {"value": "0.005500"}, 32 | "maxVcpuFree": 76, 33 | "totalVcpuFree": 377, 34 | "maxMemoryGibFree": 227, 35 | "totalMemoryGibFree": 1132, 36 | "maxGpuFree": 5, 37 | "totalGpuFree": 24, 38 | "maxStorageGibFree": 42420, 39 | "totalStorageGibFree": 42420, 40 | }, 41 | { 42 | "dataCenterId": "no-luster-1", 43 | "machineType": "epyc-rome-rtx-a5000", 44 | "cpuModel": "EPYC-Rome", 45 | "gpuModel": "RTX A5000", 46 | "gpuModelId": "nvidia-rtx-a5000", 47 | "minVcpuPerMemoryGib": 0.259109, 48 | "maxVcpuPerMemoryGib": 1.036437, 49 | "minVcpuPerGpu": 1, 50 | "maxVcpuPerGpu": 16, 51 | "vcpuPriceHr": {"value": "0.002100"}, 52 | "memoryGibPriceHr": {"value": "0.003400"}, 53 | "gpuPriceHr": {"value": "0.520000"}, 54 | "minStorageGibPriceHr": {"value": "0.000107"}, 55 | "ipv4PriceHr": {"value": "0.003500"}, 56 | "renewableEnergy": False, 57 | "maxVcpuFree": 116, 58 | "totalVcpuFree": 208, 59 | "maxMemoryGibFree": 219, 60 | "totalMemoryGibFree": 390, 61 
| "maxGpuFree": 4, 62 | "totalGpuFree": 7, 63 | "maxStorageGibFree": 1170, 64 | "totalStorageGibFree": 1170, 65 | }, 66 | ] 67 | 68 | 69 | def test_get_offers_with_query_filter(mocker, machine_types): 70 | catalog = Catalog(balance_resources=False, auto_reload=False) 71 | cudo = CudoProvider() 72 | cudo.list_vm_machine_types = mocker.Mock(return_value=machine_types) 73 | internal_catalog.ONLINE_PROVIDERS = ["cudo"] 74 | internal_catalog.OFFLINE_PROVIDERS = [] 75 | catalog.add_provider(cudo) 76 | query_result = catalog.query(provider=["cudo"], min_gpu_count=1, max_gpu_count=1) 77 | assert len(query_result) >= 1, "No offers found" 78 | 79 | 80 | def test_get_offers_for_gpu_name(mocker, machine_types): 81 | catalog = Catalog(balance_resources=True, auto_reload=False) 82 | cudo = CudoProvider() 83 | cudo.list_vm_machine_types = mocker.Mock(return_value=machine_types) 84 | internal_catalog.ONLINE_PROVIDERS = ["cudo"] 85 | internal_catalog.OFFLINE_PROVIDERS = [] 86 | catalog.add_provider(cudo) 87 | query_result = catalog.query(provider=["cudo"], min_gpu_count=1, gpu_name=["A5000"]) 88 | assert len(query_result) >= 1, "No offers found" 89 | 90 | 91 | def test_get_offers_for_gpu_memory(mocker, machine_types): 92 | catalog = Catalog(balance_resources=True, auto_reload=False) 93 | cudo = CudoProvider() 94 | cudo.list_vm_machine_types = mocker.Mock(return_value=machine_types) 95 | internal_catalog.ONLINE_PROVIDERS = ["cudo"] 96 | internal_catalog.OFFLINE_PROVIDERS = [] 97 | catalog.add_provider(cudo) 98 | query_result = catalog.query(provider=["cudo"], min_gpu_count=1, min_gpu_memory=16) 99 | assert len(query_result) >= 1, "No offers found" 100 | 101 | 102 | def test_get_offers_for_compute_capability(mocker, machine_types): 103 | catalog = Catalog(balance_resources=True, auto_reload=False) 104 | cudo = CudoProvider() 105 | cudo.list_vm_machine_types = mocker.Mock(return_value=machine_types) 106 | internal_catalog.ONLINE_PROVIDERS = ["cudo"] 107 | 
internal_catalog.OFFLINE_PROVIDERS = [] 108 | catalog.add_provider(cudo) 109 | query_result = catalog.query(provider=["cudo"], min_gpu_count=1, min_compute_capability=(8, 6)) 110 | assert len(query_result) >= 1, "No offers found" 111 | 112 | 113 | def test_get_offers_no_query_filter(mocker, machine_types): 114 | catalog = Catalog(balance_resources=True, auto_reload=False) 115 | cudo = CudoProvider() 116 | cudo.list_vm_machine_types = mocker.Mock(return_value=machine_types) 117 | internal_catalog.ONLINE_PROVIDERS = ["cudo"] 118 | internal_catalog.OFFLINE_PROVIDERS = [] 119 | catalog.add_provider(cudo) 120 | query_result = catalog.query(provider=["cudo"]) 121 | assert len(query_result) >= 1, "No offers found" 122 | 123 | 124 | def test_optimize_offers_2(mocker, machine_types): 125 | catalog = Catalog(balance_resources=True, auto_reload=False) 126 | cudo = CudoProvider() 127 | cudo.list_vm_machine_types = mocker.Mock(return_value=machine_types[0:1]) 128 | internal_catalog.ONLINE_PROVIDERS = ["cudo"] 129 | internal_catalog.OFFLINE_PROVIDERS = [] 130 | catalog.add_provider(cudo) 131 | query_result = catalog.query( 132 | provider=["cudo"], min_cpu=2, min_gpu_count=1, max_gpu_count=1, min_memory=8 133 | ) 134 | machine_type = machine_types[0] 135 | balance_resource = True 136 | available_disk = machine_type["maxStorageGibFree"] 137 | gpu_memory = get_memory(gpu_name(machine_type["gpuModel"])) 138 | max_memory = None 139 | max_disk_size = None 140 | min_disk_size = None 141 | 142 | assert len(query_result) >= 1 143 | 144 | for config in query_result: 145 | min_cpus_for_memory = machine_type["minVcpuPerMemoryGib"] * config.cpu 146 | max_cpus_for_memory = machine_type["maxVcpuPerMemoryGib"] * config.memory 147 | min_cpus_for_gpu = machine_type["minVcpuPerGpu"] * config.gpu_count 148 | assert config.cpu >= min_cpus_for_memory, ( 149 | f"VM config does not meet the minimum CPU:Memory requirement. 
Required minimum CPUs: " 150 | f"{min_cpus_for_memory}, Found: {config.cpu}" 151 | ) 152 | assert config.cpu <= max_cpus_for_memory, ( 153 | f"VM config exceeds the maximum CPU:Memory allowance. Allowed maximum CPUs: " 154 | f"{max_cpus_for_memory}, Found: {config.cpu}" 155 | ) 156 | assert config.cpu >= min_cpus_for_gpu, ( 157 | f"VM config does not meet the minimum CPU:GPU requirement. " 158 | f"Required minimum CPUs: {min_cpus_for_gpu}, Found: {config.cpu}" 159 | ) 160 | # Perform the balance resource checks if balance_resource is True 161 | if balance_resource: 162 | expected_memory = get_balanced_memory(config.gpu_count, gpu_memory, max_memory) 163 | expected_disk_size = get_balanced_disk_size( 164 | available_disk, 165 | config.memory, 166 | config.gpu_count * gpu_memory, 167 | max_disk_size, 168 | min_disk_size, 169 | ) 170 | 171 | assert config.memory == expected_memory, ( 172 | f"Memory allocation does not match the expected balanced memory. " 173 | f"Expected: {expected_memory}, Found: {config.memory}" 174 | ) 175 | assert config.disk_size == expected_disk_size, ( 176 | f"Disk size allocation does not match the expected balanced disk size. 
" 177 | f"Expected: {expected_disk_size}, Found: {config.disk_size}" 178 | ) 179 | -------------------------------------------------------------------------------- /src/tests/providers/test_oci.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from gpuhunt.providers.oci import get_gpu_name 4 | 5 | 6 | @pytest.mark.parametrize( 7 | ("shape_name", "gpu_name"), 8 | [ 9 | ("VM.GPU.A10.2", "A10"), 10 | ("BM.GPU.A100-v2.8", "A100"), 11 | ("BM.GPU4.8", "A100"), 12 | ("VM.GPU3.4", "V100"), 13 | ("VM.GPU2.1", "P100"), 14 | ("BM.GPU.H100.8", "H100"), 15 | ("VM.Standard2.8", None), 16 | ("VM.Notgpu.A10", None), 17 | ], 18 | ) 19 | def test_get_gpu_name(shape_name, gpu_name): 20 | assert get_gpu_name(shape_name) == gpu_name 21 | -------------------------------------------------------------------------------- /src/tests/providers/test_providers.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import inspect 3 | import pkgutil 4 | import sys 5 | 6 | import pytest 7 | 8 | import gpuhunt.providers 9 | from gpuhunt._internal.catalog import OFFLINE_PROVIDERS, ONLINE_PROVIDERS 10 | 11 | 12 | @pytest.fixture() 13 | def providers(): 14 | """List of all provider classes""" 15 | members = [] 16 | for module_info in pkgutil.walk_packages(gpuhunt.providers.__path__): 17 | if sys.version_info < (3, 10) and module_info.name == "nebius": 18 | continue 19 | module = importlib.import_module( 20 | f".{module_info.name}", 21 | package="gpuhunt.providers", 22 | ) 23 | for _, member in inspect.getmembers(module): 24 | if not inspect.isclass(member): 25 | continue 26 | if member.__name__.islower(): 27 | continue # skip builtins to avoid CPython bug #89489 in `issubclass` below 28 | if not issubclass(member, gpuhunt.providers.AbstractProvider): 29 | continue 30 | if member.__name__ == "AbstractProvider": 31 | continue 32 | members.append(member) 33 | assert members 34 | return 
members 35 | 36 | 37 | def test_catalog_providers_is_unique(): 38 | CATALOG_PROVIDERS = OFFLINE_PROVIDERS + ONLINE_PROVIDERS 39 | assert len(set(CATALOG_PROVIDERS)) == len(CATALOG_PROVIDERS) 40 | 41 | 42 | def test_all_providers_have_a_names(providers): 43 | names = [p.NAME for p in providers] 44 | assert gpuhunt.providers.AbstractProvider.NAME not in names 45 | assert len(set(names)) == len(names) 46 | 47 | 48 | def test_catalog_providers(providers): 49 | CATALOG_PROVIDERS = OFFLINE_PROVIDERS + ONLINE_PROVIDERS 50 | if sys.version_info < (3, 10): 51 | CATALOG_PROVIDERS = [p for p in CATALOG_PROVIDERS if p != "nebius"] 52 | names = [p.NAME for p in providers] 53 | assert set(CATALOG_PROVIDERS) == set(names) 54 | assert len(CATALOG_PROVIDERS) == len(names) 55 | -------------------------------------------------------------------------------- /src/tests/providers/test_tensordock.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from gpuhunt import QueryFilter 4 | from gpuhunt._internal.models import RawCatalogItem 5 | from gpuhunt.providers.tensordock import TensorDockProvider 6 | 7 | 8 | @pytest.fixture 9 | def specs() -> dict: 10 | return { 11 | "cpu": {"amount": 256, "price": 0.003, "type": "Intel Xeon Platinum 8352Y"}, 12 | "gpu": { 13 | "l40-pcie-48gb": { 14 | "amount": 8, 15 | "gtx": False, 16 | "pcie": True, 17 | "price": 1.05, 18 | "rtx": False, 19 | "vram": 48, 20 | } 21 | }, 22 | "ram": {"amount": 1495, "price": 0.002}, 23 | "storage": {"amount": 10252, "price": 5e-05}, 24 | } 25 | 26 | 27 | class TestTensorDockMinimalConfiguration: 28 | def test_no_requirements(self, specs: dict): 29 | offers = TensorDockProvider.optimize_offers(QueryFilter(), specs, "", "") 30 | assert offers == make_offers(specs, cpu=16, memory=96, disk_size=96, gpu_count=1) 31 | 32 | def test_min_cpu_no_balance(self, specs: dict): 33 | offers = TensorDockProvider.optimize_offers( 34 | QueryFilter(min_cpu=4), specs, "", "", 
balance_resources=False 35 | ) 36 | assert offers == make_offers(specs, cpu=4, memory=96, disk_size=96, gpu_count=1) 37 | 38 | def test_min_cpu(self, specs: dict): 39 | offers = TensorDockProvider.optimize_offers(QueryFilter(min_cpu=4), specs, "", "") 40 | assert offers == make_offers(specs, cpu=16, memory=96, disk_size=96, gpu_count=1) 41 | 42 | def test_too_many_min_cpu(self, specs: dict): 43 | offers = TensorDockProvider.optimize_offers(QueryFilter(min_cpu=1000), specs, "", "") 44 | assert offers == [] 45 | 46 | def test_min_memory_no_balance(self, specs: dict): 47 | offers = TensorDockProvider.optimize_offers( 48 | QueryFilter(min_memory=3), specs, "", "", balance_resources=False 49 | ) 50 | assert offers == make_offers(specs, cpu=2, memory=4, disk_size=48, gpu_count=1) 51 | 52 | def test_min_memory(self, specs: dict): 53 | offers = TensorDockProvider.optimize_offers(QueryFilter(min_memory=3), specs, "", "") 54 | assert offers == make_offers(specs, cpu=16, memory=96, disk_size=96, gpu_count=1) 55 | 56 | def test_too_large_min_memory(self, specs: dict): 57 | offers = TensorDockProvider.optimize_offers(QueryFilter(min_memory=2000), specs, "", "") 58 | assert offers == [] 59 | 60 | def test_min_gpu_count(self, specs: dict): 61 | offers = TensorDockProvider.optimize_offers(QueryFilter(min_gpu_count=2), specs, "", "") 62 | assert offers == make_offers(specs, cpu=32, memory=192, disk_size=192, gpu_count=2) 63 | 64 | def test_min_no_gpu(self, specs: dict): 65 | offers = TensorDockProvider.optimize_offers(QueryFilter(max_gpu_count=0), specs, "", "") 66 | assert offers == [] 67 | 68 | def test_min_total_gpu_memory(self, specs: dict): 69 | offers = TensorDockProvider.optimize_offers( 70 | QueryFilter(min_total_gpu_memory=100), specs, "", "" 71 | ) 72 | assert offers == make_offers(specs, cpu=48, memory=288, disk_size=288, gpu_count=3) 73 | 74 | def test_controversial_gpu(self, specs: dict): 75 | offers = TensorDockProvider.optimize_offers( 76 | 
QueryFilter(min_total_gpu_memory=100, max_gpu_count=2), specs, "", "" 77 | ) 78 | assert offers == [] 79 | 80 | def test_all_cpu_all_gpu(self, specs: dict): 81 | offers = TensorDockProvider.optimize_offers( 82 | QueryFilter(min_cpu=256, min_gpu_count=1), specs, "", "" 83 | ) 84 | assert offers == make_offers(specs, cpu=256, memory=768, disk_size=768, gpu_count=8) 85 | 86 | 87 | def make_offers( 88 | specs: dict, cpu: int, memory: float, disk_size: float, gpu_count: int 89 | ) -> list[RawCatalogItem]: 90 | gpu = list(specs["gpu"].values())[0] 91 | price = cpu * specs["cpu"]["price"] 92 | price += memory * specs["ram"]["price"] 93 | price += disk_size * specs["storage"]["price"] 94 | price += gpu_count * gpu["price"] 95 | return [ 96 | RawCatalogItem( 97 | instance_name="", 98 | location="", 99 | price=round(price, 5), 100 | cpu=cpu, 101 | memory=memory, 102 | gpu_vendor=None, 103 | gpu_count=gpu_count, 104 | gpu_name="L40", 105 | gpu_memory=gpu["vram"], 106 | spot=False, 107 | disk_size=disk_size, 108 | ) 109 | ] 110 | -------------------------------------------------------------------------------- /src/tests/providers/test_vultr.py: -------------------------------------------------------------------------------- 1 | import gpuhunt._internal.catalog as internal_catalog 2 | from gpuhunt import Catalog 3 | from gpuhunt.providers.vultr import VultrProvider, fetch_offers 4 | 5 | bare_metal = { 6 | "plans_metal": [ 7 | { 8 | "id": "vbm-256c-2048gb-8-mi300x-gpu", 9 | "physical_cpus": 2, 10 | "cpu_count": 128, 11 | "cpu_cores": 128, 12 | "cpu_threads": 256, 13 | "cpu_model": "EPYC 9534", 14 | "cpu_mhz": 2450, 15 | "ram": 2321924, 16 | "disk": 3576, 17 | "disk_count": 8, 18 | "bandwidth": 10240, 19 | "monthly_cost": 11773.44, 20 | "hourly_cost": 17.52, 21 | "monthly_cost_preemptible": 9891.84, 22 | "hourly_cost_preemptible": 14.72, 23 | "type": "NVMe", 24 | "locations": ["ord"], 25 | }, 26 | { 27 | "id": "vbm-112c-2048gb-8-h100-gpu", 28 | "physical_cpus": 2, 29 | 
"cpu_count": 112, 30 | "cpu_cores": 112, 31 | "cpu_threads": 224, 32 | "cpu_model": "Platinum 8480+", 33 | "cpu_mhz": 2000, 34 | "ram": 2097152, 35 | "disk": 960, 36 | "disk_count": 2, 37 | "bandwidth": 15360, 38 | "monthly_cost": 16074.24, 39 | "hourly_cost": 23.92, 40 | "monthly_cost_preemptible": 12364.8, 41 | "hourly_cost_preemptible": 18.4, 42 | "type": "NVMe", 43 | "locations": ["sea"], 44 | }, 45 | ] 46 | } 47 | 48 | vm_instances = { 49 | "plans": [ 50 | { 51 | "id": "vcg-a100-1c-6g-4vram", 52 | "vcpu_count": 1, 53 | "ram": 6144, 54 | "disk": 70, 55 | "disk_count": 1, 56 | "bandwidth": 1024, 57 | "monthly_cost": 90, 58 | "hourly_cost": 0.123, 59 | "type": "vcg", 60 | "locations": ["ewr"], 61 | "gpu_vram_gb": 4, 62 | "gpu_type": "NVIDIA_A100", 63 | }, 64 | { 65 | "id": "vcg-a100-12c-120g-80vram", 66 | "vcpu_count": 12, 67 | "ram": 122880, 68 | "disk": 1400, 69 | "disk_count": 1, 70 | "bandwidth": 10240, 71 | "monthly_cost": 1750, 72 | "hourly_cost": 2.397, 73 | "type": "vcg", 74 | "locations": ["ewr"], 75 | "gpu_vram_gb": 80, 76 | "gpu_type": "NVIDIA_A100", 77 | }, 78 | { 79 | "id": "vcg-a100-6c-60g-40vram", 80 | "vcpu_count": 12, 81 | "ram": 61440, 82 | "disk": 1400, 83 | "disk_count": 1, 84 | "bandwidth": 10240, 85 | "monthly_cost": 800, 86 | "hourly_cost": 1.397, 87 | "type": "vcg", 88 | "locations": ["ewr"], 89 | "gpu_vram_gb": 40, 90 | "gpu_type": "NVIDIA_A100", 91 | }, 92 | ] 93 | } 94 | 95 | 96 | def test_fetch_offers(requests_mock): 97 | # Mocking the responses for the API endpoints 98 | requests_mock.get("https://api.vultr.com/v2/plans-metal?per_page=500", json=bare_metal) 99 | requests_mock.get("https://api.vultr.com/v2/plans?type=all&per_page=500", json=vm_instances) 100 | 101 | # Fetch offers and verify results 102 | assert len(fetch_offers()) == 5 103 | catalog = Catalog(balance_resources=False, auto_reload=False) 104 | vultr = VultrProvider() 105 | internal_catalog.ONLINE_PROVIDERS = ["vultr"] 106 | internal_catalog.OFFLINE_PROVIDERS = [] 107 | 
catalog.add_provider(vultr) 108 | assert len(catalog.query(provider=["vultr"], min_gpu_count=1, max_gpu_count=1)) == 3 109 | assert len(catalog.query(provider=["vultr"], min_gpu_memory=80, max_gpu_count=1)) == 1 110 | assert len(catalog.query(provider=["vultr"], gpu_vendor="amd")) == 1 111 | assert len(catalog.query(provider=["vultr"], gpu_name="MI300X")) == 1 112 | -------------------------------------------------------------------------------- /src/tests/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dstackai/gpuhunt/fcc8f33dc037503956d7272c91e730e6584ec492/src/tests/scripts/__init__.py -------------------------------------------------------------------------------- /src/tests/scripts/test_catalog_v1.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from textwrap import dedent 3 | 4 | import pytest 5 | 6 | from gpuhunt.scripts.catalog_v1.__main__ import main 7 | 8 | 9 | @pytest.mark.parametrize( 10 | ("v2_catalog", "v1_catalog"), 11 | [ 12 | pytest.param( 13 | dedent( 14 | """ 15 | instance_name,location,price,cpu,memory,gpu_count,gpu_name,gpu_memory,spot,disk_size,gpu_vendor,flags 16 | i1,us,0.5,30,240,1,A10,24.0,False,,nvidia, 17 | i2,us,1.0,30,240,2,A10,24.0,False,,nvidia,f1 18 | i3,us,2.0,30,240,4,A10,24.0,False,,nvidia,f1 f2 19 | i4,us,4.0,30,240,8,A10,24.0,False,,nvidia, 20 | """ 21 | ).lstrip(), 22 | dedent( 23 | """ 24 | instance_name,location,price,cpu,memory,gpu_count,gpu_name,gpu_memory,spot,disk_size,gpu_vendor 25 | i1,us,0.5,30,240,1,A10,24.0,False,,nvidia 26 | i4,us,4.0,30,240,8,A10,24.0,False,,nvidia 27 | """ 28 | ).lstrip(), 29 | id="filters-out-offers-with-flags", 30 | ), 31 | pytest.param( 32 | dedent( 33 | """ 34 | new_column,instance_name,location,price,cpu,memory,gpu_count,gpu_name,gpu_memory,spot,disk_size,gpu_vendor,flags 35 | ???,i1,us,0.5,30,240,1,A10,24.0,False,,nvidia, 36 | 
???,i2,us,1.0,30,240,2,A10,24.0,False,,nvidia, 37 | """ 38 | ).lstrip(), 39 | dedent( 40 | """ 41 | instance_name,location,price,cpu,memory,gpu_count,gpu_name,gpu_memory,spot,disk_size,gpu_vendor 42 | i1,us,0.5,30,240,1,A10,24.0,False,,nvidia 43 | i2,us,1.0,30,240,2,A10,24.0,False,,nvidia 44 | """ 45 | ).lstrip(), 46 | id="removes-extra-columns", 47 | ), 48 | ], 49 | ) 50 | def test_main(tmp_path: Path, v2_catalog: str, v1_catalog: str) -> None: 51 | (tmp_path / "v1").mkdir() 52 | (tmp_path / "v2").mkdir() 53 | v1_file = tmp_path / "v1" / "catalog.csv" 54 | v2_file = tmp_path / "v2" / "catalog.csv" 55 | v2_file.write_text(v2_catalog) 56 | main(["--input", str(v2_file), "--output", str(v1_file)]) 57 | assert v1_file.read_text() == v1_catalog 58 | --------------------------------------------------------------------------------