├── tests ├── tests_v2 │ ├── test_endpoints │ │ ├── fixtures │ │ │ └── create_from │ │ │ │ └── json │ │ │ │ ├── good_document.json │ │ │ │ ├── unexpected_document.json │ │ │ │ └── invalid_document.json │ │ ├── __init__.py │ │ ├── test_systems.py │ │ ├── test_accounts.py │ │ ├── test_resource_budgets.py │ │ ├── test_manifests.py │ │ ├── test_announcements.py │ │ ├── test_api_tokens.py │ │ ├── test_jobs.py │ │ ├── test_services.py │ │ ├── test_node_instructions.py │ │ ├── test_marketplace_demo_tags.py │ │ ├── test_workers.py │ │ ├── test_cloud_init.py │ │ ├── test_breakout.py │ │ ├── test_user_configs.py │ │ └── test_topologies.py │ ├── __init__.py │ ├── test_typing.py │ └── test_air_api.py ├── __init__.py ├── logger.py ├── test_logger.py ├── test_topology_file.py ├── topology_file.py ├── test_login.py ├── test_job.py ├── test_demo.py ├── test_interface.py ├── test_marketplace.py ├── test_resource_budget.py ├── test_exceptions.py ├── test_ssh_key.py ├── test_account.py ├── test_capacity.py ├── test_simulation_interface.py ├── test_token.py ├── test_link.py ├── test_permission.py ├── test_user_preference.py ├── test_fleet.py └── test_node.py ├── mypy.ini ├── unit_test.sh ├── .gitignore ├── air_sdk ├── __init__.py ├── const.py ├── logger.py ├── v2 │ ├── endpoints │ │ ├── systems.py │ │ ├── accounts.py │ │ ├── marketplace_demo_tags.py │ │ ├── api_tokens.py │ │ ├── announcements.py │ │ ├── resource_budgets.py │ │ ├── jobs.py │ │ ├── services.py │ │ ├── node_instructions.py │ │ ├── topologies.py │ │ ├── organizations.py │ │ ├── breakouts.py │ │ ├── user_configs.py │ │ ├── workers.py │ │ ├── __init__.py │ │ ├── cloud_inits.py │ │ ├── manifests.py │ │ └── mixins.py │ ├── exceptions.py │ ├── air_json_encoder.py │ ├── utils.py │ ├── client.py │ └── typing.py ├── exceptions.py ├── capacity.py ├── login.py ├── topology_file.py ├── user_preference.py ├── marketplace.py ├── demo.py ├── interface.py ├── job.py ├── resource_budget.py ├── ssh_key.py ├── account.py ├── 
simulation_interface.py ├── token.py ├── link.py ├── userconfig.py ├── node.py └── util.py ├── .ruff.toml ├── .vscode └── settings.json ├── .pre-commit-config.yaml └── pyproject.toml /tests/tests_v2/test_endpoints/fixtures/create_from/json/good_document.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/fixtures/create_from/json/unexpected_document.json: -------------------------------------------------------------------------------- 1 | [] -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/fixtures/create_from/json/invalid_document.json: -------------------------------------------------------------------------------- 1 | invalid document -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.8 3 | namespace_packages = True 4 | 5 | strict = True 6 | disallow_untyped_defs = False -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | -------------------------------------------------------------------------------- /tests/tests_v2/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | -------------------------------------------------------------------------------- /unit_test.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: MIT 4 | 5 | python3 -m coverage run -m pytest 6 | coverage report -m 7 | coverage html 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .coverage 2 | __pycache__/ 3 | htmlcov/ 4 | .DS_Store 5 | .pytest_cache 6 | dist/ 7 | air_sdk.egg-info/ 8 | .idea/ 9 | 10 | # The lines below should only be present in the main-github branch 11 | ngci_build/ 12 | .pulse-trufflehog-allowlist.json 13 | -------------------------------------------------------------------------------- /air_sdk/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """Exposes the AIR API client module""" 5 | 6 | from .air_api import * # noqa: F403 7 | from . import v2 # noqa: F401 8 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_systems.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | from air_sdk.v2.endpoints import System 4 | 5 | 6 | class TestSystemEndpointApi: 7 | def test_factory(self, run_api_not_implemented_test, system_factory): 8 | run_api_not_implemented_test(System, system_factory) 9 | -------------------------------------------------------------------------------- /.ruff.toml: -------------------------------------------------------------------------------- 1 | # Maximum allowed length of a line of python code. 2 | line-length = 110 3 | 4 | [lint] 5 | # The full list of rules is here: https://docs.astral.sh/ruff/rules/ 6 | extend-select = [ 7 | # Additional flake8 8 | 'T', 9 | ] 10 | # Ignore the following codes for all files. 11 | ignore = [] 12 | 13 | [format] 14 | # Use single quotes for non-triple-quoted strings. 15 | quote-style = 'single' 16 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_accounts.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | from air_sdk.v2.endpoints import Account 4 | 5 | 6 | class TestAccountEndpointApi: 7 | def test_factory(self, run_api_not_implemented_test, account_factory): 8 | run_api_not_implemented_test(Account, account_factory) 9 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_resource_budgets.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | from air_sdk.v2.endpoints import ResourceBudget 4 | 5 | 6 | class TestResourceBudgetEndpointApi: 7 | def test_factory(self, run_api_not_implemented_test, resource_budget_factory): 8 | run_api_not_implemented_test(ResourceBudget, resource_budget_factory) 9 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[python]": { 3 | "editor.formatOnSave": true, 4 | "editor.codeActionsOnSave": { 5 | "source.fixAll": "explicit" 6 | }, 7 | "editor.defaultFormatter": "charliermarsh.ruff" 8 | }, 9 | "ruff.lint.args": ["--config=.ruff.toml"], 10 | "ruff.format.args": [ 11 | "--config=.ruff.toml" 12 | ], 13 | "python.testing.unittestEnabled": false, 14 | "python.testing.pytestEnabled": true 15 | } -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.5.3 # Ruff version 4 | hooks: 5 | # Run the linter, and fix issues that are safely-fixable. 6 | - id: ruff 7 | args: [ --fix ] 8 | # Run the formatter. 9 | - id: ruff-format 10 | 11 | - repo: local 12 | hooks: 13 | - id: mypy 14 | name: mypy 15 | entry: mypy air_sdk/v2 --config-file mypy.ini 16 | language: system 17 | types: [ python ] 18 | pass_filenames: false 19 | -------------------------------------------------------------------------------- /air_sdk/const.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Constants shared throughout the SDK. 
6 | """ 7 | 8 | ALLOWED_HOSTS = [ 9 | 'air.nvidia.com', 10 | 'staging.air.nvidia.com', 11 | 'air.cumulusnetworks.com', 12 | 'staging.air.cumulusnetworks.com', 13 | ] 14 | 15 | DEFAULT_API_URL = 'https://air.nvidia.com/api/' 16 | 17 | 18 | DEFAULT_CONNECT_TIMEOUT = 16 # seconds 19 | DEFAULT_READ_TIMEOUT = 61 # seconds 20 | DEFAULT_PAGINATION_PAGE_SIZE = 200 # Objects per paginated response 21 | -------------------------------------------------------------------------------- /air_sdk/logger.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Logger module 6 | """ 7 | 8 | import logging 9 | import re 10 | 11 | 12 | def _redact(record): 13 | """Redact any strings in the log message that match a sensitive pattern""" 14 | sensitive_patterns = [r'(password[\'\"]:\s?[\'\"]).*([\'\"])'] 15 | for pattern in sensitive_patterns: 16 | record.msg = re.sub(pattern, r'\g<1>***\g<2>', record.msg) 17 | return record 18 | 19 | 20 | air_sdk_logger = logging.getLogger('air_sdk') # pylint: disable=invalid-name 21 | air_sdk_logger.addHandler(logging.NullHandler()) 22 | air_sdk_logger.addFilter(_redact) 23 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/systems.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | from dataclasses import dataclass 5 | from typing import Any 6 | 7 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi, PrimaryKey, ApiNotImplementedMixin 8 | 9 | 10 | @dataclass(eq=False) 11 | class System(ApiNotImplementedMixin, AirModel): 12 | id: str 13 | 14 | @classmethod 15 | def get_model_api(cls): 16 | return SystemEndpointApi 17 | 18 | 19 | class SystemEndpointApi(BaseEndpointApi[System]): 20 | API_PATH = 'systems' # A placeholder 21 | model = System 22 | 23 | def get(self, pk: PrimaryKey, **params: Any) -> System: 24 | return self.load_model({'id': str(pk)}) 25 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/accounts.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | from dataclasses import dataclass 5 | from typing import Any 6 | 7 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi, PrimaryKey, ApiNotImplementedMixin 8 | 9 | 10 | @dataclass(eq=False) 11 | class Account(ApiNotImplementedMixin, AirModel): 12 | id: str 13 | 14 | @classmethod 15 | def get_model_api(cls): 16 | return AccountEndpointApi 17 | 18 | 19 | class AccountEndpointApi(BaseEndpointApi[Account]): 20 | API_PATH = 'accounts' # A placeholder 21 | model = Account 22 | 23 | def get(self, pk: PrimaryKey, **params: Any) -> Account: 24 | return self.load_model({'id': str(pk)}) 25 | -------------------------------------------------------------------------------- /air_sdk/v2/exceptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | SDK-specific exceptions 6 | """ 7 | 8 | from typing import Optional 9 | 10 | 11 | class AirError(Exception): 12 | def __init__( 13 | self, 14 | message: str = 'An error occurred within the air_sdk.v2.AirApi', 15 | status_code: Optional[int] = None, 16 | ): 17 | self.status_code = status_code 18 | super().__init__(message) 19 | 20 | 21 | class AirModelAttributeError(AirError): 22 | def __init__( 23 | self, 24 | message: str = 'An error occurred while accessing an AirModel attribute.', 25 | status_code: Optional[int] = None, 26 | ): 27 | self.message = message 28 | super().__init__(message=self.message, status_code=status_code) 29 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/marketplace_demo_tags.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | from dataclasses import dataclass 5 | 6 | from air_sdk.v2.endpoints import mixins 7 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi 8 | 9 | 10 | @dataclass(eq=False) 11 | class MarketplaceDemoTag(AirModel): 12 | name: str 13 | 14 | @property 15 | def primary_key_field(self) -> str: 16 | return 'name' 17 | 18 | @classmethod 19 | def get_model_api(cls): 20 | """ 21 | Returns the respective `AirModelAPI` type for this model. 
22 | """ 23 | return MarketplaceDemoTagsEndpointApi 24 | 25 | 26 | class MarketplaceDemoTagsEndpointApi( 27 | mixins.ListApiMixin[MarketplaceDemoTag], BaseEndpointApi[MarketplaceDemoTag] 28 | ): 29 | API_PATH = 'marketplace-demo-tags' 30 | 31 | model = MarketplaceDemoTag 32 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "air-sdk" 3 | version = "2.21.1" 4 | description = "Python SDK for interacting with NVIDIA Air" 5 | license = "MIT" 6 | classifiers = ["Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent"] 7 | homepage = "https://github.com/NVIDIA/air_sdk" 8 | authors = ["NVIDIA Air "] 9 | readme = "README.md" 10 | 11 | [tool.poetry.urls] 12 | "Homepage" = "https://github.com/NVIDIA/air_sdk" 13 | "Bug Tracker" = "https://github.com/NVIDIA/air_sdk/issues" 14 | 15 | [tool.poetry.dependencies] 16 | python = "^3.8" 17 | python-dateutil = "^2.9.0" 18 | requests = "^2.32.4" 19 | 20 | [tool.poetry.group.dev.dependencies] 21 | coverage = "^7.5.4" 22 | faker = "^26.0.0" 23 | mypy = "1.10.1" 24 | pre-commit = "^2.21" 25 | pytest = "^7.4.4" 26 | requests-mock = "^1.12.1" 27 | ruff = "0.5.3" 28 | 29 | [tool.coverage.run] 30 | omit = [ 31 | "./tests/*", 32 | "__init__.py" 33 | ] 34 | 35 | [build-system] 36 | requires = ["poetry-core>=1.0.0"] 37 | build-backend = "poetry.core.masonry.api" 38 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_manifests.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT 3 | 4 | import faker 5 | import pytest 6 | 7 | faker.Faker.seed(0) 8 | fake = faker.Faker() 9 | 10 | 11 | class TestManifestEndpointApi: 12 | def test_list(self, api, run_list_test, manifest_factory): 13 | run_list_test(api.manifests, manifest_factory) 14 | 15 | def test_refresh(self, api, run_refresh_test, manifest_factory): 16 | run_refresh_test(api.manifests, manifest_factory) 17 | 18 | def test_delete(self, api, run_delete_test, manifest_factory): 19 | run_delete_test(api.manifests, manifest_factory) 20 | 21 | 22 | class TestManifestModelRelations: 23 | def test_owner_access(self, api, manifest_factory): 24 | manifest = manifest_factory(api) 25 | owner = manifest.owner 26 | assert owner.__fk_resolved__ is False 27 | assert owner.id is not None 28 | assert owner.__fk_resolved__ is True 29 | with pytest.raises(NotImplementedError): 30 | owner.refresh() 31 | -------------------------------------------------------------------------------- /air_sdk/v2/air_json_encoder.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | 5 | from datetime import datetime 6 | from json import JSONEncoder 7 | from typing import Any 8 | from uuid import UUID 9 | 10 | from air_sdk.v2.utils import datetime_to_iso_string 11 | 12 | 13 | class AirJSONEncoder(JSONEncoder): 14 | """`JSONEncoder` with Air-specific encoding logic.""" 15 | 16 | def default(self, o: Any) -> Any: 17 | """Encode `datetime` (ISO string), `UUID`, and `AirModel` (by `__pk__`) objects.""" 18 | from air_sdk.v2.air_model import AirModel 19 | 20 | if isinstance(o, datetime): 21 | return datetime_to_iso_string(o) 22 | if isinstance(o, UUID): 23 | return str(o) 24 | if isinstance(o, AirModel): 25 | pk = getattr(o, '__pk__', None) 26 | if pk is None: 27 | raise ValueError( 28 | f'The `{o.__class__.__name__}` provided is not JSON serializable: __pk__ is None' 29 | ) 30 | return str(pk) 31 | return super().default(o) 32 | -------------------------------------------------------------------------------- /tests/logger.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for logger.py 6 | """ 7 | # pylint: disable=missing-function-docstring,missing-class-docstring,no-self-use,unused-argument 8 | 9 | import logging 10 | from unittest import TestCase 11 | from unittest.mock import MagicMock 12 | 13 | from ..air_sdk.logger import air_sdk_logger, _redact 14 | 15 | 16 | class TestLogger(TestCase): 17 | def test_logger(self): 18 | self.assertEqual(air_sdk_logger.name, 'air_sdk') 19 | self.assertEqual(len(air_sdk_logger.handlers), 1) 20 | self.assertIsInstance(air_sdk_logger.handlers[0], logging.NullHandler) 21 | self.assertEqual(len(air_sdk_logger.filters), 1) 22 | self.assertEqual(air_sdk_logger.filters[0], _redact) 23 | 24 | def test_redact(self): 25 | record = MagicMock() 26 | record.msg = '{"password": "abc123"}' 27 | 28 | self.assertEqual(_redact(record).msg, '{"password": "***"}') 29 | 30 | def test_redact_no_op(self): 31 | msg = 'foo' 32 | record = MagicMock() 33 | record.msg = msg 34 | 35 | self.assertEqual(_redact(record).msg, msg) 36 | -------------------------------------------------------------------------------- /tests/test_logger.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for logger.py 6 | """ 7 | # pylint: disable=missing-function-docstring,missing-class-docstring,no-self-use,unused-argument 8 | 9 | import logging 10 | from unittest import TestCase 11 | from unittest.mock import MagicMock 12 | 13 | from air_sdk.logger import air_sdk_logger, _redact 14 | 15 | 16 | class TestLogger(TestCase): 17 | def test_logger(self): 18 | self.assertEqual(air_sdk_logger.name, 'air_sdk') 19 | self.assertEqual(len(air_sdk_logger.handlers), 1) 20 | self.assertIsInstance(air_sdk_logger.handlers[0], logging.NullHandler) 21 | self.assertEqual(len(air_sdk_logger.filters), 1) 22 | self.assertEqual(air_sdk_logger.filters[0], _redact) 23 | 24 | def test_redact(self): 25 | record = MagicMock() 26 | record.msg = '{"password": "abc123"}' 27 | 28 | self.assertEqual(_redact(record).msg, '{"password": "***"}') 29 | 30 | def test_redact_no_op(self): 31 | msg = 'foo' 32 | record = MagicMock() 33 | record.msg = msg 34 | 35 | self.assertEqual(_redact(record).msg, msg) 36 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/api_tokens.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | from dataclasses import dataclass, field 5 | from datetime import datetime 6 | from typing import Optional 7 | 8 | from air_sdk.v2.endpoints import mixins 9 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi 10 | from air_sdk.v2.utils import validate_payload_types 11 | 12 | 13 | @dataclass(eq=False) 14 | class ApiToken(AirModel): 15 | id: str = field(repr=False) 16 | name: str 17 | created: datetime = field(repr=False) 18 | expiry: Optional[datetime] 19 | 20 | @classmethod 21 | def get_model_api(cls): 22 | """ 23 | Returns the respective `AirModelAPI` type for this model. 
24 | """ 25 | return ApiTokenEndpointApi 26 | 27 | 28 | class ApiTokenEndpointApi( 29 | mixins.ListApiMixin[ApiToken], 30 | mixins.CreateApiMixin[ApiToken], 31 | mixins.DeleteApiMixin, 32 | BaseEndpointApi[ApiToken], 33 | ): 34 | API_PATH = 'api-tokens' 35 | model = ApiToken 36 | 37 | @validate_payload_types 38 | def create(self, name: str, expiry: Optional[datetime] = None) -> ApiToken: 39 | return super().create(name=name, expiry=expiry) 40 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_announcements.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | import faker 5 | import pytest 6 | 7 | faker.Faker.seed(0) 8 | fake = faker.Faker() 9 | 10 | 11 | class TestAnnouncementEndpointApi: 12 | def test_list(self, api, run_list_test, announcement_factory): 13 | run_list_test(api.announcements, announcement_factory) 14 | 15 | def test_delete(self, api, run_delete_test, announcement_factory): 16 | run_delete_test(api.announcements, announcement_factory) 17 | 18 | @pytest.mark.parametrize( 19 | 'payload,is_valid', 20 | ( 21 | ({}, False), 22 | ({'severity': None}, False), 23 | ({'severity': None, 'message': None}, False), 24 | ({'severity': fake.slug()}, False), 25 | ({'severity': fake.slug(), 'message': None}, False), 26 | ({'severity': fake.slug(), 'message': fake.text()}, True), 27 | ({'severity': None, 'message': fake.text()}, False), 28 | ), 29 | ) 30 | def test_create(self, api, announcement_factory, run_create_test_case, payload, is_valid): 31 | run_create_test_case(api.announcements, announcement_factory, payload, is_valid) 32 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_api_tokens.py: 
-------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | import faker 5 | import pytest 6 | 7 | faker.Faker.seed(0) 8 | fake = faker.Faker() 9 | 10 | 11 | class TestApiTokenEndpointApi: 12 | def test_list(self, api, run_list_test, api_token_factory): 13 | run_list_test(api.api_tokens, api_token_factory) 14 | 15 | def test_delete(self, api, run_delete_test, api_token_factory): 16 | run_delete_test(api.api_tokens, api_token_factory) 17 | 18 | @pytest.mark.parametrize( 19 | 'payload,is_valid', 20 | ( 21 | ({'name': None}, False), 22 | ({}, False), 23 | ({'random_key': fake.slug()}, False), 24 | ({'name': fake.slug()}, True), 25 | ({'name': fake.slug(), 'random_key': None}, False), 26 | ({'name': fake.slug(), 'expiry': None}, True), 27 | ({'name': fake.slug(), 'expiry': fake.date_time()}, True), 28 | ), 29 | ) 30 | def test_create(self, api, api_token_factory, run_create_test_case, payload, is_valid): 31 | """This tests that the data provided is properly validated and used.""" 32 | run_create_test_case(api.api_tokens, api_token_factory, payload, is_valid) 33 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/announcements.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | from dataclasses import dataclass, field 5 | from datetime import datetime 6 | from typing import Optional 7 | 8 | from air_sdk.v2.endpoints import mixins 9 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi 10 | from air_sdk.v2.utils import validate_payload_types 11 | 12 | 13 | @dataclass(eq=False) 14 | class Announcement(AirModel): 15 | id: str = field(repr=False) 16 | severity: str 17 | created: datetime 18 | modified: datetime = field(repr=False) 19 | message: Optional[str] = field(repr=False) 20 | 21 | @classmethod 22 | def get_model_api(cls): 23 | """ 24 | Returns the respective `AirModelAPI` type for this model. 25 | """ 26 | return AnnouncementEndpointApi 27 | 28 | 29 | class AnnouncementEndpointApi( 30 | mixins.ListApiMixin[Announcement], 31 | mixins.CreateApiMixin[Announcement], 32 | mixins.DeleteApiMixin, 33 | BaseEndpointApi[Announcement], 34 | ): 35 | API_PATH = 'announcements' 36 | model = Announcement 37 | 38 | @validate_payload_types 39 | def create(self, severity: str, message: str) -> Announcement: 40 | return super().create(severity=severity, message=message) 41 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_jobs.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | import pytest 5 | import faker 6 | 7 | faker.Faker.seed(0) 8 | fake = faker.Faker() 9 | 10 | 11 | class TestJobEndpointApi: 12 | def test_list(self, api, run_list_test, job_factory): 13 | run_list_test(api.jobs, job_factory) 14 | 15 | def test_job_worker_field(self, api, job_factory, worker_factory): 16 | assert job_factory(api).worker.__pk__ is not None 17 | assert job_factory(api, worker=None).worker is None 18 | assert job_factory(api, worker=worker_factory(api)).worker.__pk__ is not None 19 | 20 | def test_refresh(self, api, run_refresh_test, job_factory): 21 | run_refresh_test(api.jobs, job_factory) 22 | 23 | @pytest.mark.parametrize( 24 | 'payload,is_valid', 25 | ( 26 | ({}, False), 27 | ({'fake_param': fake.slug()}, False), 28 | ({'state': fake.pyint()}, False), 29 | ({'state': None}, False), 30 | ({'state': False}, False), 31 | ({'state': 'COMPLETE'}, True), 32 | ({'state': 'FAILED'}, True), 33 | ({'state': 'WORKING'}, True), 34 | ), 35 | ) 36 | def test_update(self, api, run_update_patch_test, job_factory, payload, is_valid): 37 | run_update_patch_test(api.jobs, job_factory, payload, is_valid) 38 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/resource_budgets.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | from dataclasses import dataclass, field 5 | from typing import Any, Optional 6 | 7 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi, PrimaryKey, ApiNotImplementedMixin 8 | 9 | 10 | @dataclass(eq=False) 11 | class ResourceBudget(ApiNotImplementedMixin, AirModel): 12 | id: str 13 | cpu: Optional[int] = field(repr=False) 14 | cpu_used: Optional[int] = field(repr=False) 15 | image_uploads: Optional[int] = field(repr=False) 16 | image_uploads_used: Optional[int] = field(repr=False) 17 | memory: Optional[int] = field(repr=False) 18 | memory_used: Optional[int] = field(repr=False) 19 | simulations: Optional[int] = field(repr=False) 20 | simulations_used: Optional[int] = field(repr=False) 21 | storage: Optional[int] = field(repr=False) 22 | storage_used: Optional[int] = field(repr=False) 23 | userconfigs: Optional[int] = field(repr=False) 24 | userconfigs_used: Optional[int] = field(repr=False) 25 | 26 | @classmethod 27 | def get_model_api(cls): 28 | return ResourceBudgetEndpointApi 29 | 30 | 31 | class ResourceBudgetEndpointApi(BaseEndpointApi[ResourceBudget]): 32 | API_PATH = 'resource_budgets' # A placeholder 33 | model = ResourceBudget 34 | 35 | def get(self, pk: PrimaryKey, **params: Any) -> ResourceBudget: 36 | return self.load_model({'id': str(pk)}) 37 | -------------------------------------------------------------------------------- /tests/test_topology_file.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for topology_file.py 6 | """ 7 | 8 | from unittest import TestCase 9 | from unittest.mock import MagicMock, patch 10 | 11 | from air_sdk import topology_file 12 | 13 | 14 | class TestTopologyFile(TestCase): 15 | def setUp(self): 16 | self.model = topology_file.TopologyFile(MagicMock()) 17 | self.model.id = 'abc123' 18 | 19 | def test_repr(self): 20 | self.assertEqual(str(self.model), f'') 21 | 22 | 23 | class TestTopologyFileApi(TestCase): 24 | def setUp(self): 25 | self.client = MagicMock() 26 | self.client.api_url = 'http://testserver/api/v1' 27 | self.api = topology_file.TopologyFileApi(self.client) 28 | 29 | def test_init(self): 30 | self.assertEqual(self.api.client, self.client) 31 | self.assertEqual(self.api.url, 'http://testserver/api/v2/topology-file/') 32 | 33 | @patch('air_sdk.util.raise_if_invalid_response') 34 | def test_get(self, mock_raise): 35 | file_id = 'abc123' 36 | kwargs = {'foo': 'bar'} 37 | self.client.get.return_value.json.return_value = {'test': 'success'} 38 | 39 | res = self.api.get(file_id, **kwargs) 40 | self.client.get.assert_called_with(f'{self.api.url}{file_id}/', params=kwargs) 41 | mock_raise.assert_called_with(self.client.get.return_value) 42 | self.assertIsInstance(res, topology_file.TopologyFile) 43 | -------------------------------------------------------------------------------- /tests/topology_file.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for topology_file.py 6 | """ 7 | 8 | from unittest import TestCase 9 | from unittest.mock import MagicMock, patch 10 | 11 | from ..air_sdk import topology_file 12 | 13 | 14 | class TestTopologyFile(TestCase): 15 | def setUp(self): 16 | self.model = topology_file.TopologyFile(MagicMock()) 17 | self.model.id = 'abc123' 18 | 19 | def test_repr(self): 20 | self.assertEqual(str(self.model), f'') 21 | 22 | 23 | class TestTopologyFileApi(TestCase): 24 | def setUp(self): 25 | self.client = MagicMock() 26 | self.client.api_url = 'http://testserver/api/v1' 27 | self.api = topology_file.TopologyFileApi(self.client) 28 | 29 | def test_init(self): 30 | self.assertEqual(self.api.client, self.client) 31 | self.assertEqual(self.api.url, 'http://testserver/api/v2/topology-file/') 32 | 33 | @patch('air_sdk.air_sdk.util.raise_if_invalid_response') 34 | def test_get(self, mock_raise): 35 | file_id = 'abc123' 36 | kwargs = {'foo': 'bar'} 37 | self.client.get.return_value.json.return_value = {'test': 'success'} 38 | 39 | res = self.api.get(file_id, **kwargs) 40 | self.client.get.assert_called_with(f'{self.api.url}{file_id}/', params=kwargs) 41 | mock_raise.assert_called_with(self.client.get.return_value) 42 | self.assertIsInstance(res, topology_file.TopologyFile) 43 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/jobs.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | from dataclasses import dataclass, field 5 | from datetime import datetime 6 | from typing import Optional 7 | 8 | from air_sdk.v2.endpoints import mixins 9 | from air_sdk.v2.endpoints.simulations import Simulation 10 | from air_sdk.v2.endpoints.workers import Worker 11 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi 12 | from air_sdk.v2.utils import validate_payload_types 13 | 14 | 15 | @dataclass(eq=False) 16 | class Job(AirModel): 17 | id: str 18 | category: str 19 | state: str 20 | created: datetime = field(repr=False) 21 | last_updated: datetime = field(repr=False) 22 | notes: Optional[str] = field(repr=False) 23 | data: Optional[str] = field(repr=False) 24 | simulation: Optional[Simulation] = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 25 | worker: Optional[Worker] = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 26 | 27 | @classmethod 28 | def get_model_api(cls): 29 | """ 30 | Returns the respective `AirModelAPI` type for this model. 31 | """ 32 | return JobEndpointApi 33 | 34 | @validate_payload_types 35 | def update(self, state: str) -> None: 36 | super().update(state=state)  # only `state` can be changed through this typed wrapper 37 | 38 | 39 | class JobEndpointApi( 40 | mixins.ListApiMixin[Job], 41 | mixins.GetApiMixin[Job], 42 | mixins.PatchApiMixin[Job], 43 | BaseEndpointApi[Job], 44 | ): 45 | API_PATH = 'jobs' 46 | model = Job 47 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/services.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT 3 | 4 | 5 | from dataclasses import field, dataclass 6 | from datetime import datetime 7 | from typing import Optional, Union 8 | 9 | from air_sdk.v2.endpoints import mixins 10 | from air_sdk.v2.endpoints.interfaces import Interface 11 | from air_sdk.v2.endpoints.simulations import Simulation 12 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi, PrimaryKey 13 | from air_sdk.v2.utils import validate_payload_types 14 | 15 | 16 | @dataclass(eq=False) 17 | class Service(AirModel): 18 | id: str 19 | name: str 20 | created: datetime 21 | modified: datetime 22 | dest_port: int 23 | src_port: int 24 | service_type: str 25 | interface: Interface = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 26 | simulation: Simulation = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 27 | host: Optional[str] 28 | 29 | @classmethod 30 | def get_model_api(cls): 31 | """ 32 | Returns the respective `AirModelAPI` type for this model. 33 | """ 34 | return ServiceEndpointApi 35 | 36 | 37 | class ServiceEndpointApi( 38 | mixins.ListApiMixin[Service], 39 | mixins.CreateApiMixin[Service], 40 | mixins.DeleteApiMixin, 41 | BaseEndpointApi[Service], 42 | ): 43 | API_PATH = 'simulations/nodes/interfaces/services/'  # NOTE(review): trailing slash, unlike sibling paths such as 'jobs'; confirm the URL joining normalizes it 44 | model = Service 45 | 46 | @validate_payload_types 47 | def create( 48 | self, name: str, dest_port: int, interface: Union[Interface, PrimaryKey], service_type: str = 'ssh' 49 | ) -> Service: 50 | return super().create( 51 | name=name, 52 | dest_port=dest_port, 53 | interface=interface, 54 | service_type=service_type, 55 | ) 56 | -------------------------------------------------------------------------------- /air_sdk/exceptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Custom exceptions for the AIR SDK 6 | """ 7 | 8 | 9 | class AirError(Exception): 10 | """ 11 | Base exception class. All custom exceptions should inherit from this class. 12 | """ 13 | 14 | def __init__(self, message='', status_code=None): 15 | self.status_code = status_code 16 | super().__init__(message) 17 | 18 | 19 | class AirAuthorizationError(AirError): 20 | """Raised when authorization with the API fails.""" 21 | 22 | def __init__(self, message: str = 'An error occurred when authorizing the Air API', status_code=None): 23 | self.message = message 24 | super().__init__(message=self.message, status_code=status_code) 25 | 26 | 27 | class AirUnexpectedResponse(AirError): 28 | """Raised when the API returns an unexpected response.""" 29 | 30 | def __init__(self, message='', status_code=None): 31 | self.message = 'Received an unexpected response from the Air API' 32 | if status_code: 33 | self.message += f' ({status_code})' 34 | self.message += f': {message}' 35 | super().__init__(message=self.message, status_code=status_code) 36 | 37 | 38 | class AirForbiddenError(AirError): 39 | """Raised when an API call returns a 403 Forbidden error""" 40 | 41 | def __init__(self, message='Received 403 Forbidden.
Please call AirApi.authorize().'): 42 | self.message = message 43 | super().__init__(message=self.message, status_code=403) 44 | 45 | 46 | class AirObjectDeleted(AirError): 47 | """Raised when accessing a previously instantiated object that has since been deleted""" 48 | 49 | def __init__(self, cls, message=''): 50 | self.message = message 51 | if not self.message: 52 | self.message = f'{cls} object has been deleted and should no longer be referenced' 53 | super().__init__(message=self.message) 54 | -------------------------------------------------------------------------------- /tests/test_login.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for login.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import login 13 | 14 | 15 | class TestLogin(TestCase): 16 | def setUp(self): 17 | self.model = login.Login(MagicMock()) 18 | self.model.id = 'abc123' 19 | 20 | def test_init_(self): 21 | self.assertFalse(self.model._deletable) 22 | self.assertFalse(self.model._updatable) 23 | 24 | def test_repr(self): 25 | self.assertEqual(str(self.model), f'') 26 | 27 | def test_repr_deleted(self): 28 | self.model._deleted = True 29 | self.assertTrue('Deleted Object' in str(self.model)) 30 | 31 | 32 | class TestLoginApi(TestCase): 33 | def setUp(self): 34 | self.client = MagicMock() 35 | self.client.api_url = 'http://testserver/api' 36 | self.api = login.LoginApi(self.client) 37 | 38 | def test_init_(self): 39 | self.assertEqual(self.api.client, self.client) 40 | self.assertEqual(self.api.url, 'http://testserver/api/login/') 41 | 42 | @patch('air_sdk.login.LoginApi.list') 43 | def test_get(self, mock_list): 44 | res = self.api.get(foo='bar') 45
| mock_list.assert_called_with(foo='bar') 46 | self.assertEqual(res, mock_list.return_value) 47 | 48 | @patch('air_sdk.util.raise_if_invalid_response') 49 | def test_list(self, mock_raise): 50 | self.client.get.return_value.json.return_value = {'id': 'abc'} 51 | res = self.api.list(foo='bar') 52 | self.client.get.assert_called_with(f'{self.client.api_url}/login/', params={'foo': 'bar'}) 53 | mock_raise.assert_called_with(self.client.get.return_value) 54 | self.assertIsInstance(res, login.Login) 55 | self.assertEqual(res.id, 'abc') 56 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_services.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | 5 | import pytest 6 | import faker 7 | 8 | faker.Faker.seed(0) 9 | fake = faker.Faker() 10 | 11 | 12 | class TestServiceEndpointApi: 13 | def test_list(self, api, run_list_test, service_factory): 14 | run_list_test(api.services, service_factory) 15 | 16 | def test_delete(self, api, run_delete_test, service_factory): 17 | run_delete_test(api.services, service_factory) 18 | 19 | @pytest.mark.parametrize( 20 | 'payload,is_valid', 21 | ( 22 | ({}, False), 23 | ({'name': None, 'dest_port': None, 'interface': None}, False), 24 | ({'name': fake.slug(), 'dest_port': fake.slug(), 'interface': fake.slug()}, False), 25 | ( 26 | { 27 | 'name': fake.slug(), 28 | 'dest_port': fake.slug(), 29 | 'interface': fake.slug(), 30 | 'service_type': None, 31 | }, 32 | False, 33 | ), 34 | ( 35 | { 36 | 'name': fake.slug(), 37 | 'dest_port': fake.slug(), 38 | 'interface': fake.slug(), 39 | 'service_type': fake.pyint(), 40 | }, 41 | False, 42 | ), 43 | ({'name': fake.slug(), 'dest_port': fake.pyint(), 'interface': fake.uuid4(cast_to=str)}, True), 44 | ( 45 | { 46 | 'name': fake.slug(), 47 | 'dest_port':
fake.pyint(), 48 | 'interface': fake.uuid4(cast_to=str), 49 | 'service_type': fake.slug(), 50 | }, 51 | True, 52 | ), 53 | ), 54 | ) 55 | def test_create(self, api, service_factory, run_create_test_case, payload, is_valid): 56 | """This tests that the data provided is properly validated and used.""" 57 | run_create_test_case(api.services, service_factory, payload, is_valid) 58 | -------------------------------------------------------------------------------- /air_sdk/capacity.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Capacity module 6 | """ 7 | 8 | from . import util 9 | from .air_model import AirModel 10 | 11 | 12 | class Capacity(AirModel): 13 | """ 14 | View platform capacity 15 | 16 | ### json 17 | Returns a JSON string representation of the capacity 18 | 19 | ### refresh 20 | Syncs the capacity with all values returned by the API 21 | """ 22 | 23 | _deletable = False 24 | _updatable = False 25 | 26 | def __repr__(self): 27 | if self._deleted or getattr(self, 'copies', None) is None: 28 | return super().__repr__() 29 | return f'' 30 | 31 | 32 | class CapacityApi: 33 | """High-level interface for the Capacity API""" 34 | 35 | def __init__(self, client): 36 | self.client = client 37 | self.url = self.client.api_url + '/capacity/' 38 | 39 | @util.deprecated('CapacityApi.get()') 40 | def get_capacity(self, simulation=None, simulation_id=None): # pylint: disable=missing-function-docstring 41 | if not simulation and not simulation_id: 42 | raise ValueError('Must pass a simulation or simulation_id argument') 43 | sim_id = simulation_id or simulation.id 44 | return self.get(simulation_id=sim_id) 45 | 46 | def get(self, simulation_id, **kwargs): 47 | """ 48 | Get current platform capacity for a [`Simulation`](/docs/simulation) 49 | 50 | Arguments: 51 | simulation_id
(str | `Simulation`): Simulation or ID 52 | 53 | Returns: 54 | [`Capacity`](/docs/capacity) 55 | 56 | Raises: 57 | [`AirUnexpectedResponse`](/docs/exceptions) - API did not return a 200 OK 58 | or valid response JSON 59 | 60 | Example: 61 | ``` 62 | >>> air.capacity.get(simulation) 63 | 64 | ``` 65 | """ 66 | if isinstance(simulation_id, AirModel): 67 | simulation_id = simulation_id.id 68 | url = f'{self.url}{simulation_id}/' 69 | res = self.client.get(url, params=kwargs) 70 | util.raise_if_invalid_response(res) 71 | return Capacity(self, **res.json()) 72 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/node_instructions.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | from __future__ import annotations 5 | from dataclasses import dataclass, field 6 | from datetime import datetime 7 | from http import HTTPStatus 8 | from typing import Literal, Optional 9 | 10 | from air_sdk.v2.endpoints import mixins 11 | from air_sdk.v2.air_model import AirModel, PrimaryKey, BaseEndpointApi 12 | from air_sdk.v2.endpoints.nodes import Node 13 | from air_sdk.util import raise_if_invalid_response 14 | from air_sdk.v2.utils import validate_payload_types 15 | 16 | 17 | @dataclass(eq=False) 18 | class NodeInstruction(AirModel): 19 | id: str = field(repr=False) 20 | node: Node = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 21 | instruction: str 22 | state: str 23 | created: datetime = field(repr=False) 24 | modified: datetime = field(repr=False) 25 | 26 | @classmethod 27 | def get_model_api(cls): 28 | """ 29 | Returns the respective `AirModelAPI` type for this model.
30 | """ 31 | return NodeInstructionsEndpointApi 32 | 33 | @property 34 | def primary_key_field(self) -> str: 35 | return 'node' 36 | 37 | @property 38 | def __pk__(self) -> PrimaryKey: 39 | return getattr(self, self.primary_key_field).__pk__ # type: ignore 40 | 41 | 42 | class NodeInstructionsEndpointApi(BaseEndpointApi[NodeInstruction]): 43 | API_PATH = 'simulations/nodes/{id}/instructions' # Placeholder: create() fills `{id}` via self.url.format(id=str(pk)) 44 | model = NodeInstruction 45 | 46 | @validate_payload_types 47 | def create( 48 | self, 49 | pk: PrimaryKey, 50 | executor: Literal['init', 'file', 'shell'], 51 | data: str, 52 | monitor: Optional[str] = None, 53 | ) -> NodeInstruction: 54 | params = { 55 | 'executor': executor, 56 | 'data': data, 57 | } 58 | if monitor: 59 | params['monitor'] = monitor 60 | 61 | detail_url = self.url.format(id=str(pk)) 62 | response = self.__api__.client.post(detail_url, data=mixins.serialize_payload(params)) 63 | raise_if_invalid_response(response, status_code=HTTPStatus.CREATED) 64 | return self.load_model(response.json()) 65 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/topologies.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2018-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT 3 | 4 | 5 | from dataclasses import dataclass 6 | from typing import Any, TypedDict, cast 7 | 8 | from air_sdk.v2.endpoints.mixins import serialize_payload 9 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi, DataDict 10 | from air_sdk.util import raise_if_invalid_response 11 | from air_sdk.v2.utils import join_urls 12 | 13 | 14 | class TopologyParseResponse(TypedDict): 15 | source_format: str 16 | destination_format: str 17 | topology_data: str 18 | 19 | 20 | @dataclass(eq=False) 21 | class Topology(AirModel): 22 | @classmethod 23 | def get_model_api(cls): 24 | """ 25 | Returns the respective `AirModelAPI` type for this model. 26 | """ 27 | return TopologyEndpointApi 28 | 29 | 30 | class TopologyEndpointApi(BaseEndpointApi[Topology]): 31 | API_PATH = 'topology' 32 | PARSE_PATH = 'parse' 33 | model = Topology 34 | 35 | def __init__(self, *args: Any, **kwargs: Any): 36 | super().__init__(*args, **kwargs) 37 | self.url_v1 = join_urls(self.__api__.client.base_url.replace('/api/v2/', '/api/v1/'), self.API_PATH) 38 | 39 | def parse(self, source_format: str, destination_format: str, topology_data: str) -> TopologyParseResponse: 40 | """ 41 | Parse topology data from one format to another. 
42 | 43 | Args: 44 | source_format (str): The format of the input topology data 45 | destination_format (str): The desired output format 46 | topology_data (str): The topology data to parse 47 | 48 | Returns: 49 | TopologyParseResponse: The parsed topology data containing source_format, destination_format, 50 | topology_data 51 | """ 52 | payload: DataDict = { 53 | 'source_format': source_format, 54 | 'destination_format': destination_format, 55 | 'topology_data': topology_data, 56 | } 57 | parse_detail_url = join_urls(self.__api__.topologies.url_v1, self.PARSE_PATH) 58 | response = self.__api__.client.post(parse_detail_url, data=serialize_payload(payload)) 59 | raise_if_invalid_response(response) 60 | return cast(TopologyParseResponse, response.json()) 61 | -------------------------------------------------------------------------------- /air_sdk/login.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Login module 6 | """ 7 | 8 | from . import util 9 | from .air_model import AirModel 10 | 11 | 12 | class Login(AirModel): 13 | """ 14 | View login information 15 | 16 | ### json 17 | Returns a JSON string representation of the login info 18 | 19 | ### refresh 20 | Syncs the login info with all values returned by the API 21 | """ 22 | 23 | _deletable = False 24 | _updatable = False 25 | 26 | def __repr__(self): 27 | if self._deleted: 28 | return super().__repr__() 29 | return f'' 30 | 31 | 32 | class LoginApi: 33 | """High-level interface for the Login API""" 34 | 35 | def __init__(self, client): 36 | self.client = client 37 | self.url = self.client.api_url + '/login/' 38 | 39 | def get(self, **kwargs): 40 | """ 41 | Get login information or start an OAuth request. This is equivalent to `login.list()`. 
42 | 43 | Arguments: 44 | kwargs (dict, optional): All other optional keyword arguments are applied as query 45 | parameters/filters 46 | 47 | Returns: 48 | [`Login`](/docs/login) 49 | 50 | Raises: 51 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 52 | or valid response JSON 53 | 54 | Example: 55 | ``` 56 | >>> air.login.get() 57 | 58 | ``` 59 | """ 60 | return self.list(**kwargs) 61 | 62 | def list(self, **kwargs): 63 | """ 64 | Get login information or start an OAuth request. This is equivalent to `login.get()`. 65 | 66 | Arguments: 67 | kwargs (dict, optional): All other optional keyword arguments are applied as query 68 | parameters/filters 69 | 70 | Returns: 71 | [`Login`](/docs/login) 72 | 73 | Raises: 74 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 75 | or valid response JSON 76 | 77 | Example: 78 | ``` 79 | >>> air.login.get() 80 | 81 | ``` 82 | """ 83 | res = self.client.get(f'{self.url}', params=kwargs) 84 | util.raise_if_invalid_response(res) 85 | return Login(self, **res.json()) 86 | -------------------------------------------------------------------------------- /air_sdk/topology_file.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Topologyfile module 6 | """ 7 | 8 | from . 
import util 9 | from .air_model import AirModel 10 | 11 | 12 | class TopologyFile(AirModel): 13 | """A text file that describes a network topology.""" 14 | 15 | _deletable = False 16 | _updatable = False 17 | 18 | def __repr__(self): 19 | return f'' 20 | 21 | 22 | class TopologyFileApi: 23 | """High-level interface for the TopologyFile API""" 24 | 25 | def __init__(self, client): 26 | self.client = client 27 | self.url = self.client.api_url.replace('v1', 'v2') + '/topology-file/' 28 | 29 | def get(self, file_id: str, **kwargs) -> TopologyFile: 30 | """ 31 | Get an existing topology file 32 | 33 | Arguments: 34 | file_id (str): TopologyFile ID 35 | kwargs (dict, optional): All other optional keyword arguments are applied as query 36 | parameters/filters 37 | 38 | Returns: 39 | TopologyFile 40 | 41 | Raises: 42 | AirUnexpectedResponse - API did not return a 200 OK or valid response JSON 43 | 44 | Example: 45 | ``` 46 | >>> air.topology_files.get('5cec8f3b-f449-47a3-a6ee-c5b81bf92ccf') 47 | 48 | ``` 49 | """ 50 | url = f'{self.url}{file_id}/' 51 | res = self.client.get(url, params=kwargs) 52 | util.raise_if_invalid_response(res) 53 | return TopologyFile(self, **res.json()) 54 | 55 | # TODO: v2 list APIs require pagination handling 56 | # def list(self, **kwargs) -> 'list[TopologyFile]': 57 | # """ 58 | # List existing TopologyFiles 59 | 60 | # Arguments: 61 | # kwargs (dict, optional): All other optional keyword arguments are applied as query 62 | # parameters/filters 63 | 64 | # Returns: 65 | # list[TopologyFile] 66 | 67 | # Raises: 68 | # AirUnexpectedResponse - API did not return a 200 OK or valid response JSON 69 | 70 | # Example: 71 | # ``` 72 | # >>> air.topology_files.list() 73 | # [, ] 74 | # ``` 75 | # """ 76 | -------------------------------------------------------------------------------- /tests/test_job.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for job.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import job 13 | 14 | 15 | class TestJob(TestCase): 16 | def setUp(self): 17 | self.model = job.Job(MagicMock()) 18 | self.model.id = 'abc123' 19 | self.model.category = 'START' 20 | 21 | def test_init_(self): 22 | self.assertFalse(self.model._deletable) 23 | self.assertTrue(self.model._updatable) 24 | 25 | def test_repr(self): 26 | self.assertEqual(str(self.model), f'') 27 | 28 | def test_repr_deleted(self): 29 | self.model._deleted = True 30 | self.assertTrue('Deleted Object' in str(self.model)) 31 | 32 | 33 | class TestJobApi(TestCase): 34 | def setUp(self): 35 | self.client = MagicMock() 36 | self.client.api_url = 'http://testserver/api' 37 | self.api = job.JobApi(self.client) 38 | 39 | def test_init_(self): 40 | self.assertEqual(self.api.client, self.client) 41 | self.assertEqual(self.api.url, 'http://testserver/api/job/') 42 | 43 | @patch('air_sdk.util.raise_if_invalid_response') 44 | def test_get(self, mock_raise): 45 | self.client.get.return_value.json.return_value = {'test': 'success'} 46 | res = self.api.get('abc123', foo='bar') 47 | self.client.get.assert_called_with(f'{self.client.api_url}/job/abc123/', params={'foo': 'bar'}) 48 | mock_raise.assert_called_with(self.client.get.return_value) 49 | self.assertIsInstance(res, job.Job) 50 | self.assertEqual(res.test, 'success') 51 | 52 | @patch('air_sdk.util.raise_if_invalid_response') 53 | def test_list(self, mock_raise): 54 | self.client.get.return_value.json.return_value = [{'id': 'abc'}, {'id': 'xyz'}] 55 | res = self.api.list(foo='bar') 56 | self.client.get.assert_called_with(f'{self.client.api_url}/job/', params={'foo': 'bar'}) 57 | mock_raise.assert_called_with(self.client.get.return_value, 
data_type=list) 58 | self.assertEqual(len(res), 2) 59 | self.assertIsInstance(res[0], job.Job) 60 | self.assertEqual(res[0].id, 'abc') 61 | self.assertEqual(res[1].id, 'xyz') 62 | -------------------------------------------------------------------------------- /tests/test_demo.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for demo.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import demo 13 | 14 | 15 | class TestDemo(TestCase): 16 | def setUp(self): 17 | self.model = demo.Demo(MagicMock()) 18 | self.model.id = 'abc123' 19 | self.model.name = 'test' 20 | 21 | def test_init_(self): 22 | self.assertFalse(self.model._deletable) 23 | self.assertFalse(self.model._updatable) 24 | 25 | def test_repr(self): 26 | self.assertEqual(str(self.model), f"") 27 | 28 | def test_repr_deleted(self): 29 | self.model._deleted = True 30 | self.assertTrue('Deleted Object' in str(self.model)) 31 | 32 | 33 | class TestDemoApi(TestCase): 34 | def setUp(self): 35 | self.client = MagicMock() 36 | self.client.api_url = 'http://testserver/api' 37 | self.api = demo.DemoApi(self.client) 38 | 39 | def test_init_(self): 40 | self.assertEqual(self.api.client, self.client) 41 | self.assertEqual(self.api.url, 'http://testserver/api/demo/') 42 | 43 | @patch('air_sdk.util.raise_if_invalid_response') 44 | def test_get(self, mock_raise): 45 | self.client.get.return_value.json.return_value = {'test': 'success'} 46 | res = self.api.get('abc123', foo='bar') 47 | self.client.get.assert_called_with(f'{self.client.api_url}/demo/abc123/', params={'foo': 'bar'}) 48 | mock_raise.assert_called_with(self.client.get.return_value) 49 | self.assertIsInstance(res, demo.Demo) 
50 | self.assertEqual(res.test, 'success') 51 | 52 | @patch('air_sdk.util.raise_if_invalid_response') 53 | def test_list(self, mock_raise): 54 | self.client.get.return_value.json.return_value = [{'id': 'abc'}, {'id': 'xyz'}] 55 | res = self.api.list(foo='bar') 56 | self.client.get.assert_called_with(f'{self.client.api_url}/demo/', params={'foo': 'bar'}) 57 | mock_raise.assert_called_with(self.client.get.return_value, data_type=list) 58 | self.assertEqual(len(res), 2) 59 | self.assertIsInstance(res[0], demo.Demo) 60 | self.assertEqual(res[0].id, 'abc') 61 | self.assertEqual(res[1].id, 'xyz') 62 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/organizations.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | from __future__ import annotations 5 | 6 | from dataclasses import dataclass, field 7 | from typing import Iterator, List, Any 8 | from air_sdk.v2.endpoints import mixins 9 | from air_sdk.v2.endpoints.resource_budgets import ResourceBudget 10 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi, PrimaryKey 11 | from air_sdk.v2.utils import validate_payload_types 12 | 13 | 14 | @dataclass(eq=False) 15 | class Organization(AirModel): 16 | id: str 17 | name: str 18 | member_count: int 19 | resource_budget: ResourceBudget = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 20 | 21 | def list_members(self) -> Iterator[OrganizationMember]: 22 | """Returns an iterator of organization members.""" 23 | members_api = OrganizationMembersEndpointApi(self.__api__) 24 | return members_api.list_organization_members(self.id) 25 | 26 | @classmethod 27 | def get_model_api(cls): 28 | """ 29 | Returns the respective `AirModelAPI` type for this model. 
30 | """ 31 | return OrganizationEndpointApi 32 | 33 | 34 | class OrganizationEndpointApi( 35 | mixins.ListApiMixin[Organization], 36 | mixins.GetApiMixin[Organization], 37 | BaseEndpointApi[Organization], 38 | ): 39 | API_PATH = 'organizations' 40 | model = Organization 41 | 42 | 43 | @dataclass(eq=False) 44 | class OrganizationMember(AirModel): 45 | id: str 46 | username: str 47 | roles: List[str] 48 | resource_budget: ResourceBudget = field(metadata=AirModel.FIELD_FOREIGN_KEY) 49 | 50 | @classmethod 51 | def get_model_api(cls): 52 | """ 53 | Returns the respective `AirModelAPI` type for this model. 54 | """ 55 | return OrganizationMembersEndpointApi 56 | 57 | 58 | class OrganizationMembersEndpointApi( 59 | mixins.ListApiMixin[OrganizationMember], 60 | BaseEndpointApi[OrganizationMember], 61 | ): 62 | API_PATH = 'organization/{id}/members' 63 | model = OrganizationMember 64 | 65 | @validate_payload_types 66 | def list_organization_members( 67 | self, organization: PrimaryKey, **params: Any 68 | ) -> Iterator[OrganizationMember]: 69 | """Return an iterator of organization member instances for a specific organization.""" 70 | self.url = self.url.format(id=organization) 71 | return super().list(**params) 72 | -------------------------------------------------------------------------------- /air_sdk/user_preference.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | UserPreference module 6 | """ 7 | 8 | import json 9 | import re 10 | 11 | from .air_model import AirModel 12 | from . import util 13 | 14 | 15 | class UserPreference(AirModel): 16 | """ 17 | A collection of your user preferences, which may be global for your account or specific to a single 18 | simulation. 
19 | 20 | ### json 21 | Returns a JSON string representation of the preferences object 22 | 23 | ### refresh 24 | Syncs the key with all values returned by the API 25 | """ 26 | 27 | _deletable = False 28 | 29 | def __init__(self, api, **kwargs): 30 | _model = kwargs.pop('_model', None) 31 | _version_override = kwargs.pop('_version_override', None) 32 | super().__init__(api, **kwargs) 33 | self._model = _model 34 | self._version_override = _version_override 35 | self._url = self._build_url() 36 | 37 | def __repr__(self): 38 | return json.dumps(self.preferences) 39 | 40 | def __setattr__(self, name, value): 41 | if not getattr(self, '_url', None): 42 | return super().__setattr__(name, value) 43 | if self.preferences.get(name) != value: 44 | res = self._api.client.patch(self._url, json={name: value}) 45 | util.raise_if_invalid_response(res) 46 | self.preferences[name] = value 47 | return None 48 | 49 | def _build_url(self): 50 | url = self._api.url 51 | if self._model: 52 | url += f'{self._model.id}/' 53 | url += 'preferences/' 54 | if self._version_override: 55 | url = re.sub(r'/v\d/', f'/v{self._version_override}/', url) 56 | return url 57 | 58 | def refresh(self): 59 | """Syncs the object with all values returned by the API""" 60 | instance = self._model or self._api 61 | self._load(**instance.preferences().__dict__) 62 | 63 | def update(self, **kwargs): 64 | """ 65 | Update the object with the provided data 66 | 67 | Arguments: 68 | kwargs (dict, optional): All optional keyword arguments are applied as key/value 69 | pairs in the request's JSON payload 70 | """ 71 | self.refresh() 72 | self.preferences.update(kwargs) 73 | res = self._api.client.put(self._url, json=self.__dict__) 74 | util.raise_if_invalid_response(res) 75 | -------------------------------------------------------------------------------- /tests/test_interface.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for interface.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import interface 13 | 14 | 15 | class TestInterface(TestCase): 16 | def setUp(self): 17 | self.model = interface.Interface(MagicMock()) 18 | self.model.id = 'abc123' 19 | self.model.name = 'eth0' 20 | 21 | def test_init_(self): 22 | self.assertFalse(self.model._deletable) 23 | self.assertFalse(self.model._updatable) 24 | 25 | def test_repr(self): 26 | self.assertEqual(str(self.model), f'') 27 | 28 | def test_repr_deleted(self): 29 | self.model._deleted = True 30 | self.assertTrue('Deleted Object' in str(self.model)) 31 | 32 | 33 | class TestInterfaceApi(TestCase): 34 | def setUp(self): 35 | self.client = MagicMock() 36 | self.client.api_url = 'http://testserver/api' 37 | self.api = interface.InterfaceApi(self.client) 38 | 39 | def test_init_(self): 40 | self.assertEqual(self.api.client, self.client) 41 | self.assertEqual(self.api.url, 'http://testserver/api/interface/') 42 | 43 | @patch('air_sdk.util.raise_if_invalid_response') 44 | def test_get(self, mock_raise): 45 | self.client.get.return_value.json.return_value = {'test': 'success'} 46 | res = self.api.get('abc123', foo='bar') 47 | self.client.get.assert_called_with(f'{self.client.api_url}/interface/abc123/', params={'foo': 'bar'}) 48 | mock_raise.assert_called_with(self.client.get.return_value) 49 | self.assertIsInstance(res, interface.Interface) 50 | self.assertEqual(res.test, 'success') 51 | 52 | @patch('air_sdk.util.raise_if_invalid_response') 53 | def test_list(self, mock_raise): 54 | self.client.get.return_value.json.return_value = [{'id': 'abc'}, {'id': 'xyz'}] 55 | res = self.api.list(foo='bar') 56 | self.client.get.assert_called_with(f'{self.client.api_url}/interface/', params={'foo':
'bar'}) 57 | mock_raise.assert_called_with(self.client.get.return_value, data_type=list)  # list() is expected to validate the response as a JSON list 58 | self.assertEqual(len(res), 2) 59 | self.assertIsInstance(res[0], interface.Interface) 60 | self.assertEqual(res[0].id, 'abc') 61 | self.assertEqual(res[1].id, 'xyz') 62 | -------------------------------------------------------------------------------- /tests/test_marketplace.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for marketplace.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import marketplace 13 | 14 | 15 | class TestMarketplace(TestCase): 16 | def setUp(self): 17 | self.model = marketplace.Marketplace(MagicMock()) 18 | self.model.id = 'abc123' 19 | self.model.name = 'test' 20 | 21 | def test_init_(self): 22 | self.assertFalse(self.model._deletable) 23 | self.assertFalse(self.model._updatable) 24 | 25 | def test_repr(self): 26 | self.assertEqual(str(self.model), f"") 27 | 28 | def test_repr_deleted(self): 29 | self.model._deleted = True 30 | self.assertTrue('Deleted Object' in str(self.model)) 31 | 32 | 33 | class TestMarketplaceApi(TestCase): 34 | def setUp(self): 35 | self.client = MagicMock() 36 | self.client.api_url = 'http://testserver/api' 37 | self.api = marketplace.MarketplaceApi(self.client) 38 | 39 | def test_init_(self): 40 | self.assertEqual(self.api.client, self.client) 41 | self.assertEqual(self.api.url, 'http://testserver/api/marketplace/demo/')  # marketplace resources are served under /marketplace/demo/ 42 | 43 | @patch('air_sdk.util.raise_if_invalid_response') 44 | def test_get(self, mock_raise): 45 | self.client.get.return_value.json.return_value = {'test': 'success'} 46 | res = self.api.get('abc123', foo='bar') 47 | self.client.get.assert_called_with( 48 |
f'{self.client.api_url}/marketplace/demo/abc123/', params={'foo': 'bar'} 49 | ) 50 | mock_raise.assert_called_with(self.client.get.return_value) 51 | self.assertIsInstance(res, marketplace.Marketplace) 52 | self.assertEqual(res.test, 'success') 53 | 54 | @patch('air_sdk.util.raise_if_invalid_response') 55 | def test_list(self, mock_raise): 56 | self.client.get.return_value.json.return_value = [{'id': 'abc'}, {'id': 'xyz'}] 57 | res = self.api.list(foo='bar') 58 | self.client.get.assert_called_with(f'{self.client.api_url}/marketplace/demo/', params={'foo': 'bar'}) 59 | mock_raise.assert_called_with(self.client.get.return_value, data_type=list)  # list() is expected to validate the response as a JSON list 60 | self.assertEqual(len(res), 2) 61 | self.assertIsInstance(res[0], marketplace.Marketplace) 62 | self.assertEqual(res[0].id, 'abc') 63 | self.assertEqual(res[1].id, 'xyz') 64 | -------------------------------------------------------------------------------- /tests/test_resource_budget.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for resource_budget.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import resource_budget 13 | 14 | 15 | class TestResourceBudget(TestCase): 16 | def setUp(self): 17 | self.model = resource_budget.ResourceBudget(MagicMock()) 18 | self.model.id = 'abc123' 19 | self.model.category = 'START' 20 | 21 | def test_init_(self): 22 | self.assertFalse(self.model._deletable) 23 | self.assertTrue(self.model._updatable) 24 | 25 | def test_repr(self): 26 | self.assertEqual(str(self.model), f'') 27 | 28 | def test_repr_deleted(self): 29 | self.model._deleted = True 30 | self.assertTrue('Deleted Object' in str(self.model)) 31 | 32 | 33 | class TestResourceBudgetApi(TestCase): 34 | def setUp(self): 35 | self.client = MagicMock() 36 | self.client.api_url = 'http://testserver/api' 37 | self.api = resource_budget.ResourceBudgetApi(self.client) 38 | 39 | def test_init_(self): 40 | self.assertEqual(self.api.client, self.client) 41 | self.assertEqual(self.api.url, 'http://testserver/api/resource-budget/') 42 | 43 | @patch('air_sdk.util.raise_if_invalid_response') 44 | def test_get(self, mock_raise): 45 | self.client.get.return_value.json.return_value = {'test': 'success'} 46 | res = self.api.get('abc123', foo='bar') 47 | self.client.get.assert_called_with( 48 | f'{self.client.api_url}/resource-budget/abc123/', params={'foo': 'bar'} 49 | ) 50 | mock_raise.assert_called_with(self.client.get.return_value) 51 | self.assertIsInstance(res, resource_budget.ResourceBudget) 52 | self.assertEqual(res.test, 'success') 53 | 54 | @patch('air_sdk.util.raise_if_invalid_response') 55 | def test_list(self, mock_raise): 56 | self.client.get.return_value.json.return_value = [{'id': 'abc'}, {'id': 'xyz'}] 57 | res = self.api.list(foo='bar') 58 | 
self.client.get.assert_called_with(f'{self.client.api_url}/resource-budget/', params={'foo': 'bar'}) 59 | mock_raise.assert_called_with(self.client.get.return_value, data_type=list) 60 | self.assertEqual(len(res), 2) 61 | self.assertIsInstance(res[0], resource_budget.ResourceBudget) 62 | self.assertEqual(res[0].id, 'abc') 63 | self.assertEqual(res[1].id, 'xyz') 64 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_node_instructions.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | import json 4 | from http import HTTPStatus 5 | 6 | import pytest 7 | import faker 8 | 9 | from air_sdk.v2.endpoints.mixins import serialize_payload 10 | 11 | faker.Faker.seed(0) 12 | fake = faker.Faker() 13 | 14 | 15 | class TestNodeInstructionsEndpointApi: 16 | @pytest.mark.parametrize( 17 | 'payload,is_valid', 18 | ( 19 | ({}, False), 20 | ( 21 | { 22 | 'pk': fake.uuid4(cast_to=str), 23 | 'executor': fake.slug(), 24 | 'data': fake.slug(), 25 | 'monitor': fake.slug(), 26 | }, 27 | False, 28 | ), 29 | ( 30 | { 31 | 'pk': fake.uuid4(cast_to=str), 32 | 'executor': 'init', 33 | 'data': fake.slug(), 34 | 'monitor': fake.slug(), 35 | }, 36 | True, 37 | ), 38 | ( 39 | { 40 | 'pk': fake.uuid4(cast_to=str), 41 | 'executor': 'init', 42 | 'data': fake.slug(), 43 | }, 44 | True, 45 | ), 46 | ), 47 | ) 48 | def test_create(self, setup_mock_responses, api, node_instruction_factory, payload, is_valid): 49 | """This tests that the data provided is properly validated and used.""" 50 | if is_valid: 51 | processed_payload = json.loads(serialize_payload(payload)) 52 | expected_inst = node_instruction_factory(api.node_instructions.__api__, **processed_payload) 53 | setup_mock_responses( 54 | { 55 | ('POST', api.node_instructions.url.format(id=payload['pk'])): { 
56 | 'json': json.loads(expected_inst.json()), 57 | 'status_code': HTTPStatus.CREATED, 58 | } 59 | } 60 | ) 61 | inst = api.node_instructions.create(**payload) 62 | # Verify that the returned instance and the expected instance are equal 63 | assert inst == expected_inst 64 | # Verify that the returned instance and the expected instance are not the same 65 | assert inst is not expected_inst 66 | else: 67 | with pytest.raises(Exception) as err: 68 | api.node_instructions.create(**payload) 69 | assert err.type in (TypeError, ValueError) 70 | -------------------------------------------------------------------------------- /air_sdk/v2/utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | import inspect 4 | from dataclasses import Field, fields, is_dataclass 5 | from datetime import datetime, timezone 6 | from functools import wraps 7 | from typing import Optional, get_type_hints, TypeVar, Callable, Any, cast 8 | from urllib.parse import ParseResult, urlparse 9 | from uuid import UUID 10 | 11 | from air_sdk.v2.typing import type_check 12 | 13 | F = TypeVar('F', bound=Callable[..., Any]) 14 | 15 | 16 | def join_urls(*args: str) -> str: 17 | return '/'.join(frag.strip('/') for frag in args) + '/' 18 | 19 | 20 | def iso_string_to_datetime(iso: str) -> Optional[datetime]: 21 | try: 22 | return datetime.fromisoformat(iso.replace('Z', '+00:00')) 23 | except ValueError: 24 | return None 25 | 26 | 27 | def datetime_to_iso_string(date: datetime) -> str: 28 | return date.astimezone(tz=timezone.utc).isoformat().replace('+00:00', 'Z') 29 | 30 | 31 | def to_uuid(uuid: str) -> Optional[UUID]: 32 | try: 33 | return UUID(uuid, version=4) 34 | except ValueError: 35 | return None 36 | 37 | 38 | def to_url(url: str) -> Optional[ParseResult]: 39 | try: 40 | parsed_url = urlparse(url) 41 | return parsed_url if 
all((parsed_url.scheme, parsed_url.netloc, parsed_url.path)) else None 42 | except AttributeError: 43 | return None 44 | 45 | 46 | def is_dunder(name: str) -> bool: 47 | delimiter = '__' 48 | return name.startswith(delimiter) and name.endswith(delimiter) 49 | 50 | 51 | def as_field(class_or_instance: object, name: str) -> Optional[Field]: # type: ignore[type-arg] 52 | if is_dataclass(class_or_instance): 53 | try: 54 | return next(field for field in fields(class_or_instance) if field.name == name) 55 | except StopIteration: 56 | pass 57 | return None 58 | 59 | 60 | def validate_payload_types(func: F) -> F: 61 | """A wrapper for validating the type of payload during create.""" 62 | 63 | @wraps(func) 64 | def wrapper(*args, **kwargs): 65 | hints = get_type_hints(func) 66 | 67 | sig = inspect.signature(func) 68 | bound_args = sig.bind(*args, **kwargs) 69 | bound_args.apply_defaults() 70 | 71 | for name, value in bound_args.arguments.items(): 72 | if name in hints: 73 | expected_type = hints[name] 74 | if not type_check(value, expected_type): 75 | raise TypeError(f"Argument '{name}' must be {expected_type}, got {type(value)}") 76 | 77 | return func(*args, **kwargs) 78 | 79 | return cast(F, wrapper) 80 | -------------------------------------------------------------------------------- /air_sdk/marketplace.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Marketplace Demo module 6 | """ 7 | 8 | from . 
import util 9 | from .air_model import AirModel 10 | 11 | 12 | class Marketplace(AirModel): 13 | """ 14 | Manage marketplace demos 15 | 16 | """ 17 | 18 | _updatable = False 19 | _deletable = False 20 | 21 | def __repr__(self): 22 | if self._deleted or not self.name: 23 | return super().__repr__() 24 | return f"" 25 | 26 | 27 | class MarketplaceApi: 28 | """High-level interface for the Marketplace API""" 29 | 30 | def __init__(self, client): 31 | self.client = client 32 | self.url = self.client.api_url + '/marketplace/demo/' 33 | 34 | def list(self, **kwargs): 35 | # pylint: disable=line-too-long 36 | """ 37 | List existing keys 38 | 39 | Arguments: 40 | kwargs (dict, optional): All other optional keyword arguments are applied as query 41 | parameters/filters 42 | 43 | Returns: 44 | list 45 | 46 | Raises: 47 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 48 | or valid response JSON 49 | 50 | Example: 51 | ``` 52 | >>> air.marketplace.list() 53 | [] 54 | ``` 55 | """ # pylint: enable=line-too-long 56 | res = self.client.get(f'{self.url}', params=kwargs) 57 | util.raise_if_invalid_response(res, data_type=list) 58 | return [Marketplace(self, **key) for key in res.json()] 59 | 60 | def get(self, demo_id, **kwargs): 61 | """ 62 | Get an existing marketplace demo 63 | 64 | Arguments: 65 | demo_id (str): Demo ID 66 | kwargs (dict, optional): All other optional keyword arguments are applied as query 67 | parameters/filters 68 | 69 | Returns: 70 | [`Demo`](/docs/marketplace) 71 | 72 | Raises: 73 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 74 | or valid response JSON 75 | 76 | Example: 77 | ``` 78 | >>> air.marketplace.get('3dadd54d-583c-432e-9383-a2b0b1d7f551') 79 | 80 | ``` 81 | """ 82 | url = f'{self.url}{demo_id}/' 83 | res = self.client.get(url, params=kwargs) 84 | util.raise_if_invalid_response(res) 85 | return Marketplace(self, **res.json()) 86 | 
-------------------------------------------------------------------------------- /tests/test_exceptions.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for exceptions.py 6 | """ 7 | # pylint: disable=missing-function-docstring,missing-class-docstring 8 | 9 | from unittest import TestCase 10 | 11 | from air_sdk import exceptions 12 | 13 | 14 | class TestAirError(TestCase): 15 | def test_init(self): 16 | err = exceptions.AirError('test', 200) 17 | self.assertEqual(str(err), 'test') 18 | self.assertEqual(err.status_code, 200) 19 | 20 | 21 | class TestAirAuthorizationError(TestCase): 22 | def test_init(self): 23 | err = exceptions.AirAuthorizationError('test', 200) 24 | self.assertEqual(err.message, 'test') 25 | self.assertEqual(err.status_code, 200) 26 | self.assertIsInstance(err, exceptions.AirError) 27 | 28 | def test_init_default(self): 29 | err = exceptions.AirAuthorizationError(status_code=200) 30 | self.assertEqual(err.message, 'An error occurred when authorizing the Air API') 31 | self.assertEqual(err.status_code, 200) 32 | 33 | 34 | class TestAirUnexpectedResponse(TestCase): 35 | def test_init(self): 36 | err = exceptions.AirUnexpectedResponse('test') 37 | self.assertEqual(err.message, 'Received an unexpected response from the Air API: test') 38 | self.assertIsNone(err.status_code) 39 | self.assertIsInstance(err, exceptions.AirError) 40 | 41 | def test_init_status_code(self): 42 | err = exceptions.AirUnexpectedResponse('test', status_code=200) 43 | self.assertEqual(err.message, 'Received an unexpected response from the Air API (200): test') 44 | self.assertEqual(err.status_code, 200) 45 | 46 | 47 | class TestAirForbiddenError(TestCase): 48 | def test_init(self): 49 | err = exceptions.AirForbiddenError('test') 50 | self.assertEqual(err.message, 'test') 51 | 
self.assertEqual(err.status_code, 403) 52 | self.assertIsInstance(err, exceptions.AirError) 53 | 54 | def test_init_default(self): 55 | err = exceptions.AirForbiddenError() 56 | self.assertEqual(err.message, 'Received 403 Forbidden. Please call AirApi.authorize().') 57 | self.assertEqual(err.status_code, 403) 58 | 59 | 60 | class TestAirObjectDeleted(TestCase): 61 | def test_init(self): 62 | err = exceptions.AirObjectDeleted('foo', 'test') 63 | self.assertEqual(err.message, 'test') 64 | self.assertIsInstance(err, exceptions.AirError) 65 | 66 | def test_init_default(self): 67 | err = exceptions.AirObjectDeleted('foo') 68 | self.assertEqual(err.message, 'foo object has been deleted and should no longer be referenced') 69 | -------------------------------------------------------------------------------- /air_sdk/demo.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Demo module 6 | """ 7 | 8 | from . 
import util 9 | from .air_model import AirModel 10 | 11 | 12 | class Demo(AirModel): 13 | """ 14 | View Demos 15 | ### json 16 | Returns a JSON string representation of the demo 17 | 18 | ### refresh 19 | Syncs the demo with all values returned by the API 20 | """ 21 | 22 | _deletable = False 23 | _updatable = False 24 | 25 | def __repr__(self): 26 | if self._deleted or not self.name: 27 | return super().__repr__() 28 | return f"" 29 | 30 | 31 | class DemoApi: 32 | """High-level interface for the Demo API""" 33 | 34 | def __init__(self, client): 35 | self.client = client 36 | self.url = self.client.api_url + '/demo/' 37 | 38 | def get(self, demo_id, **kwargs): 39 | """ 40 | Get an existing demo 41 | 42 | Arguments: 43 | dmeo_id (str): Demo ID 44 | kwargs (dict, optional): All other optional keyword arguments are applied as query 45 | parameters/filters 46 | 47 | Returns: 48 | [`Demo`](/docs/demo) 49 | 50 | Raises: 51 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 52 | or valid response JSON 53 | 54 | Example: 55 | ``` 56 | >>> air.demos.get('3dadd54d-583c-432e-9383-a2b0b1d7f551') 57 | 58 | ``` 59 | """ 60 | url = f'{self.url}{demo_id}/' 61 | res = self.client.get(url, params=kwargs) 62 | util.raise_if_invalid_response(res) 63 | return Demo(self, **res.json()) 64 | 65 | def list(self, **kwargs): 66 | # pylint: disable=line-too-long 67 | """ 68 | List existing demos 69 | 70 | Arguments: 71 | kwargs (dict, optional): All other optional keyword arguments are applied as query 72 | parameters/filters 73 | 74 | Returns: 75 | list 76 | 77 | Raises: 78 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 79 | or valid response JSON 80 | 81 | Example: 82 | ``` 83 | >>> air.demos.list() 84 | [, ] 85 | ``` 86 | """ # pylint: enable=line-too-long 87 | res = self.client.get(f'{self.url}', params=kwargs) 88 | util.raise_if_invalid_response(res, data_type=list) 89 | return [Demo(self, **demo) for demo in res.json()] 90 | 
-------------------------------------------------------------------------------- /air_sdk/v2/endpoints/breakouts.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | from __future__ import annotations 5 | 6 | from dataclasses import dataclass, field 7 | from datetime import datetime 8 | from typing import Any, List, Optional, Union 9 | 10 | from air_sdk.v2.endpoints import mixins 11 | from air_sdk.v2.endpoints.nodes import Node 12 | from air_sdk.v2.endpoints.interfaces import Interface 13 | from air_sdk.v2.air_model import AirModel, ApiNotImplementedMixin, BaseEndpointApi, PrimaryKey 14 | from air_sdk.v2.utils import validate_payload_types 15 | 16 | 17 | @dataclass(eq=False) 18 | class BreakoutInterface(ApiNotImplementedMixin, AirModel): 19 | """ 20 | Interface created from a breakout operation. 21 | Represents a simplified interface model returned in breakout responses. 
22 | """ 23 | 24 | id: str = field(repr=False) 25 | name: str 26 | 27 | @classmethod 28 | def get_model_api(cls): 29 | return BreakoutInterfaceEndpointApi 30 | 31 | 32 | class BreakoutInterfaceEndpointApi(BaseEndpointApi[BreakoutInterface]): 33 | API_PATH = 'interfaces' 34 | model = BreakoutInterface 35 | 36 | def get(self, pk: PrimaryKey, **params: Any) -> BreakoutInterface: 37 | return self.load_model({'id': str(pk)}) 38 | 39 | 40 | @dataclass(eq=False) 41 | class Breakout(AirModel): 42 | id: str = field(repr=False) 43 | node: Node = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 44 | name: str 45 | mac_address: Optional[str] = field(repr=False) 46 | split_count: int 47 | simulation_interfaces: List[BreakoutInterface] = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 48 | created: Optional[datetime] = field(repr=False) 49 | 50 | @classmethod 51 | def get_model_api(cls): 52 | """ 53 | Returns the respective `AirModelAPI` type for this model. 54 | """ 55 | return BreakoutEndpointApi 56 | 57 | 58 | class BreakoutEndpointApi( 59 | mixins.ListApiMixin[Breakout], 60 | mixins.CreateApiMixin[Breakout], 61 | mixins.GetApiMixin[Breakout], 62 | mixins.DeleteApiMixin, 63 | BaseEndpointApi[Breakout], 64 | ): 65 | API_PATH = 'simulations/nodes/interfaces/breakouts/' 66 | model = Breakout 67 | 68 | @validate_payload_types 69 | def create(self, interface: Union[Interface, PrimaryKey], split_count: int) -> Breakout: 70 | """ 71 | Create a new breakout configuration. 
72 | 73 | Args: 74 | interface: The interface to break out 75 | split_count: Number of splits to create (minimum value: 2) 76 | 77 | Returns: 78 | Breakout: The created breakout configuration 79 | """ 80 | return super().create(interface=interface, split_count=split_count) 81 | -------------------------------------------------------------------------------- /air_sdk/interface.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Interface module 6 | """ 7 | 8 | from . import util 9 | from .air_model import AirModel 10 | 11 | 12 | class Interface(AirModel): 13 | """ 14 | View an Interface 15 | 16 | ### json 17 | Returns a JSON string representation of the interface 18 | 19 | ### refresh 20 | Syncs the interface with all values returned by the API 21 | """ 22 | 23 | _deletable = False 24 | _updatable = False 25 | 26 | def __repr__(self): 27 | if self._deleted or not self.name: 28 | return super().__repr__() 29 | return f'' 30 | 31 | 32 | class InterfaceApi: 33 | """High-level interface for the Interface API""" 34 | 35 | def __init__(self, client): 36 | self.client = client 37 | self.url = self.client.api_url + '/interface/' 38 | 39 | def get(self, interface_id, **kwargs): 40 | """ 41 | Get an existing interface 42 | 43 | Arguments: 44 | interface_id (str): Interface ID 45 | kwargs (dict, optional): All other optional keyword arguments are applied as query 46 | parameters/filters 47 | 48 | Returns: 49 | [`Interface`](/docs/interface) 50 | 51 | Raises: 52 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 53 | or valid response JSON 54 | 55 | Example: 56 | ``` 57 | >>> air.interfaces.get('3dadd54d-583c-432e-9383-a2b0b1d7f551') 58 | 59 | ``` 60 | """ 61 | url = f'{self.url}{interface_id}/' 62 | res = self.client.get(url, params=kwargs) 63 | 
util.raise_if_invalid_response(res) 64 | return Interface(self, **res.json()) 65 | 66 | def list(self, **kwargs): 67 | # pylint: disable=line-too-long 68 | """ 69 | List existing interfaces 70 | 71 | Arguments: 72 | kwargs (dict, optional): All other optional keyword arguments are applied as query 73 | parameters/filters 74 | 75 | Returns: 76 | list 77 | 78 | Raises: 79 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 80 | or valid response JSON 81 | 82 | Example: 83 | ``` 84 | >>> air.interfaces.list() 85 | [, ] 86 | ``` 87 | """ 88 | # pylint: enable=line-too-long 89 | res = self.client.get(f'{self.url}', params=kwargs) 90 | util.raise_if_invalid_response(res, data_type=list) 91 | return [Interface(self, **interface) for interface in res.json()] 92 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_marketplace_demo_tags.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | 5 | import faker 6 | 7 | from air_sdk import v2 8 | 9 | faker.Faker.seed(0) 10 | fake = faker.Faker() 11 | 12 | 13 | class TestMarketplaceDemoTagsEndpointApi: 14 | def setup_method(self): 15 | self.AirApi = v2.AirApi 16 | self.api_url = 'https://air-fake-test.nvidia.com/api/' 17 | self.api = self.AirApi(api_url=self.api_url, authenticate=False) 18 | self.endpoint_url = v2.utils.join_urls(self.api_url, 'v2', self.api.marketplace_demo_tags.API_PATH) 19 | 20 | def test_list_single_inst(self, api, setup_mock_responses, marketplace_demo_tag_factory): 21 | """Ensure list requests work when there is only one page in the paginated response.""" 22 | # Set up mock client 23 | results = [marketplace_demo_tag_factory(api).dict()] 24 | expected_responses = { 25 | ('GET', self.endpoint_url): { 26 | 'json': {'previous': None, 'next': None, 'count': len(results), 'results': results}, 27 | 'status_code': 200, 28 | } 29 | } 30 | setup_mock_responses(expected_responses) 31 | # Test SDK 32 | 33 | tags = list(self.api.marketplace_demo_tags.list(limit=len(results))) 34 | assert len(tags) == 1 35 | assert isinstance(tags[0], self.api.marketplace_demo_tags.model) 36 | 37 | def test_pagination(self, api, setup_mock_responses, marketplace_demo_tag_factory): 38 | """Ensure multiple calls are made to collect paginated responses.""" 39 | first_tag = marketplace_demo_tag_factory(api) 40 | second_tag = marketplace_demo_tag_factory(api) 41 | page_size = 1 42 | first_url = self.endpoint_url + f'?limit={page_size}' 43 | second_url = first_url + '&offset=1' 44 | expected_responses = { 45 | ('GET', first_url): { 46 | 'json': { 47 | 'previous': None, 48 | 'next': second_url, 49 | 'count': page_size, 50 | 'results': [first_tag.dict()], 51 | }, 52 | 'status_code': 200, 53 | }, 54 | ('GET', second_url): { 55 | 'json': { 56 | 'previous': first_url, 57 | 'next': None, 58 | 'count': page_size, 59 | 'results': [second_tag.dict()], 60 | }, 61 | 'status_code': 200, 62 
| }, 63 | } 64 | setup_mock_responses(expected_responses) 65 | # Test SDK 66 | tags = list(self.api.marketplace_demo_tags.list(limit=page_size)) 67 | assert len(tags) == 2 68 | assert tags[0] == first_tag 69 | assert tags[1] == second_tag 70 | -------------------------------------------------------------------------------- /air_sdk/job.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Job module 6 | """ 7 | 8 | from . import util 9 | from .air_model import AirModel 10 | 11 | 12 | class Job(AirModel): 13 | """ 14 | Manage a Job 15 | 16 | ### json 17 | Returns a JSON string representation of the job 18 | 19 | ### refresh 20 | Syncs the job with all values returned by the API 21 | 22 | ### update 23 | Update the job with the provided data 24 | 25 | Arguments: 26 | kwargs (dict, optional): All optional keyword arguments are applied as key/value 27 | pairs in the request's JSON payload 28 | """ 29 | 30 | _deletable = False 31 | 32 | def __repr__(self): 33 | if self._deleted or not self.category: 34 | return super().__repr__() 35 | return f'' 36 | 37 | 38 | class JobApi: 39 | """High-level interface for the Job API""" 40 | 41 | def __init__(self, client): 42 | self.client = client 43 | self.url = self.client.api_url + '/job/' 44 | 45 | def get(self, job_id, **kwargs): 46 | """ 47 | Get an existing job 48 | 49 | Arguments: 50 | job_id (str): Job ID 51 | kwargs (dict, optional): All other optional keyword arguments are applied as query 52 | parameters/filters 53 | 54 | Returns: 55 | [`Job`](/docs/job) 56 | 57 | Raises: 58 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 59 | or valid response JSON 60 | 61 | Example: 62 | ``` 63 | >>> air.jobs.get('3dadd54d-583c-432e-9383-a2b0b1d7f551') 64 | 65 | ``` 66 | """ 67 | url = f'{self.url}{job_id}/' 68 | res = 
self.client.get(url, params=kwargs) 69 | util.raise_if_invalid_response(res) 70 | return Job(self, **res.json()) 71 | 72 | def list(self, **kwargs): 73 | # pylint: disable=line-too-long 74 | """ 75 | List existing jobs 76 | 77 | Arguments: 78 | kwargs (dict, optional): All other optional keyword arguments are applied as query 79 | parameters/filters 80 | 81 | Returns: 82 | list 83 | 84 | Raises: 85 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 86 | or valid response JSON 87 | 88 | Example: 89 | ``` 90 | >>> air.jobs.list() 91 | [, ] 92 | ``` 93 | """ # pylint: enable=line-too-long 94 | res = self.client.get(f'{self.url}', params=kwargs) 95 | util.raise_if_invalid_response(res, data_type=list) 96 | return [Job(self, **job) for job in res.json()] 97 | -------------------------------------------------------------------------------- /tests/test_ssh_key.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for ssh_key.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import ssh_key 13 | 14 | 15 | class TestSSHKey(TestCase): 16 | def setUp(self): 17 | self.model = ssh_key.SSHKey(MagicMock()) 18 | self.model.id = 'abc123' 19 | self.model.name = 'public' 20 | 21 | def test_init(self): 22 | self.assertTrue(self.model._deletable) 23 | self.assertFalse(self.model._updatable) 24 | 25 | def test_repr(self): 26 | self.assertEqual(str(self.model), f'') 27 | 28 | def test_repr_deleted(self): 29 | self.model._deleted = True 30 | self.assertTrue('Deleted Object' in str(self.model)) 31 | 32 | 33 | class TestSSHKeyApi(TestCase): 34 | def setUp(self): 35 | self.client = MagicMock() 36 | self.client.api_url = 'http://testserver/api' 37 | self.api = ssh_key.SSHKeyApi(self.client) 38 | 39 | def test_init_(self): 40 | self.assertEqual(self.api.client, self.client) 41 | self.assertEqual(self.api.url, 'http://testserver/api/sshkey/') 42 | 43 | @patch('air_sdk.util.raise_if_invalid_response') 44 | def test_list(self, mock_raise): 45 | self.client.get.return_value.json.return_value = [{'id': 'abc'}, {'id': 'xyz'}] 46 | res = self.api.list(foo='bar') 47 | self.client.get.assert_called_with(f'{self.client.api_url}/sshkey/', params={'foo': 'bar'}) 48 | mock_raise.assert_called_with(self.client.get.return_value, data_type=list) 49 | self.assertEqual(len(res), 2) 50 | self.assertIsInstance(res[0], ssh_key.SSHKey) 51 | self.assertEqual(res[0].id, 'abc') 52 | self.assertEqual(res[1].id, 'xyz') 53 | 54 | @patch('air_sdk.util.raise_if_invalid_response') 55 | def test_create(self, mock_raise): 56 | self.client.post.return_value.json.return_value = {'id': 'abc'} 57 | res = self.api.create(public_key='abc123', name='test') 58 | self.client.post.assert_called_with( 59 | f'{self.client.api_url}/sshkey/', 
json={'public_key': 'abc123', 'name': 'test'} 60 | ) 61 | mock_raise.assert_called_with(self.client.post.return_value, status_code=201) 62 | self.assertIsInstance(res, ssh_key.SSHKey) 63 | self.assertEqual(res.id, 'abc') 64 | 65 | def test_create_required_kwargs(self): 66 | with self.assertRaises(AttributeError) as err: 67 | self.api.create(name='test') 68 | self.assertTrue('requires public_key' in str(err.exception)) 69 | with self.assertRaises(AttributeError) as err: 70 | self.api.create(public_key='abc123') 71 | self.assertTrue('requires name' in str(err.exception)) 72 | -------------------------------------------------------------------------------- /tests/test_account.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for account.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import account, user_preference 13 | 14 | 15 | class TestAccount(TestCase): 16 | def setUp(self): 17 | self.model = account.Account(MagicMock()) 18 | self.model.username = 'foo' 19 | self.model.id = 'abc123' 20 | self.assertFalse(self.model._deletable) 21 | self.assertFalse(self.model._updatable) 22 | 23 | def test_init(self): 24 | self.assertFalse(self.model._deletable) 25 | self.assertFalse(self.model._updatable) 26 | 27 | def test_repr(self): 28 | self.assertEqual(str(self.model), f'') 29 | 30 | def test_repr_deleted(self): 31 | self.model._deleted = True 32 | self.assertTrue('Deleted Object' in str(self.model)) 33 | 34 | 35 | class TestAccountApi(TestCase): 36 | def setUp(self): 37 | self.mock_client = MagicMock() 38 | self.mock_client.api_url = 'http://testserver/api' 39 | self.mock_client.get.return_value.status_code = 200 40 | self.url = 
self.mock_client.api_url + '/account/' 41 | self.api = account.AccountApi(self.mock_client) 42 | 43 | def test_init(self): 44 | self.assertEqual(self.api.client, self.mock_client) 45 | self.assertEqual(self.api.url, self.url) 46 | 47 | def test_get(self): 48 | self.mock_client.get.return_value.json.return_value = {'foo': 'bar'} 49 | res = self.api.get('abc123', foo='bar') 50 | self.mock_client.get.assert_called_with(self.url + 'abc123/', params={'foo': 'bar'}) 51 | self.assertIsInstance(res, account.Account) 52 | self.assertEqual(res.foo, 'bar') 53 | 54 | def test_list(self): 55 | self.mock_client.get.return_value.json.return_value = [{'foo': 'bar'}] 56 | res = self.api.list(foo='bar') 57 | self.mock_client.get.assert_called_with(self.url, params={'foo': 'bar'}) 58 | self.assertEqual(len(res), 1) 59 | self.assertIsInstance(res[0], account.Account) 60 | self.assertEqual(res[0].foo, 'bar') 61 | 62 | @patch('air_sdk.account.util.raise_if_invalid_response') 63 | def test_preferences(self, mock_raise): 64 | prefs = {'foo': 'bar'} 65 | self.mock_client.get.return_value.json.return_value = {'preferences': prefs} 66 | kwargs = {'test': True} 67 | 68 | res = self.api.preferences(**kwargs) 69 | self.mock_client.get.assert_called_with(f'{self.url}preferences/', params=kwargs) 70 | mock_raise.assert_called_once_with(self.mock_client.get.return_value) 71 | self.assertIsInstance(res, user_preference.UserPreference) 72 | self.assertEqual(res.preferences, prefs) 73 | -------------------------------------------------------------------------------- /air_sdk/resource_budget.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | ResourceBudget module 6 | """ 7 | 8 | from . 
import util 9 | from .air_model import AirModel 10 | 11 | 12 | class ResourceBudget(AirModel): 13 | """ 14 | Manage a ResourceBudget 15 | 16 | ### json 17 | Returns a JSON string representation of the budget 18 | 19 | ### refresh 20 | Syncs the budget with all values returned by the API 21 | 22 | ### update 23 | Update the budget with the provided data 24 | 25 | Arguments: 26 | kwargs (dict, optional): All optional keyword arguments are applied as key/value 27 | pairs in the request's JSON payload 28 | """ 29 | 30 | _deletable = False 31 | 32 | def __repr__(self): 33 | if self._deleted: 34 | return super().__repr__() 35 | return f'' 36 | 37 | 38 | class ResourceBudgetApi: 39 | """High-level interface for the ResourceBudget API""" 40 | 41 | def __init__(self, client): 42 | self.client = client 43 | self.url = self.client.api_url + '/resource-budget/' 44 | 45 | def get(self, budget_id, **kwargs): 46 | """ 47 | Get an existing budget 48 | 49 | Arguments: 50 | budget_id (str): ResourceBudget ID 51 | kwargs (dict, optional): All other optional keyword arguments are applied as query 52 | parameters/filters 53 | 54 | Returns: 55 | [`ResourceBudget`](/docs/resourcebudget) 56 | 57 | Raises: 58 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 59 | or valid response JSON 60 | 61 | Example: 62 | ``` 63 | >>> air.resource_budgets.get('c604c262-396a-48a0-a8f6-31708c0cff82') 64 | 65 | ``` 66 | """ 67 | url = f'{self.url}{budget_id}/' 68 | res = self.client.get(url, params=kwargs) 69 | util.raise_if_invalid_response(res) 70 | return ResourceBudget(self, **res.json()) 71 | 72 | def list(self, **kwargs): 73 | # pylint: disable=line-too-long 74 | """ 75 | List existing budgets 76 | 77 | Arguments: 78 | kwargs (dict, optional): All other optional keyword arguments are applied as query 79 | parameters/filters 80 | 81 | Returns: 82 | list 83 | 84 | Raises: 85 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 86 | or valid response 
JSON 87 | 88 | Example: 89 | ``` 90 | >>> air.resource_budgets.list() 91 | [, ] 92 | ``` 93 | """ # pylint: enable=line-too-long 94 | res = self.client.get(f'{self.url}', params=kwargs) 95 | util.raise_if_invalid_response(res, data_type=list) 96 | return [ResourceBudget(self, **budget) for budget in res.json()] 97 | -------------------------------------------------------------------------------- /air_sdk/v2/client.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | from typing import Dict, Any, Optional 5 | from json import JSONDecodeError 6 | from urllib.parse import urlparse 7 | 8 | import requests # type: ignore[import-untyped] 9 | 10 | from air_sdk.v2.utils import join_urls 11 | from air_sdk import logger, const 12 | from air_sdk.air_api import AirAuthorizationError, _normalize_api_url # type: ignore[attr-defined] 13 | 14 | 15 | class Client(requests.Session): # type: ignore 16 | """A session client for managing the execution of API requests.""" 17 | 18 | def __init__(self, api_url: str): 19 | super().__init__() 20 | self.headers.update({'content-type': 'application/json', 'Authorization': None}) 21 | self.api_url = _normalize_api_url(api_url) 22 | self.base_url = join_urls(self.api_url, 'v2') 23 | self.connect_timeout = const.DEFAULT_CONNECT_TIMEOUT 24 | self.read_timeout = const.DEFAULT_READ_TIMEOUT 25 | self.pagination_page_size = const.DEFAULT_PAGINATION_PAGE_SIZE 26 | 27 | def rebuild_auth(self, prepared_request, response): 28 | """Allow credential sharing between nvidia.com and cumulusnetworks.com only""" 29 | if urlparse(prepared_request.url).hostname in const.ALLOWED_HOSTS: 30 | return 31 | super().rebuild_auth(prepared_request, response) 32 | 33 | def request(self, method: str, url: str, **kwargs: Dict[str, Any]) -> requests.Response: 34 | """Override request method to 
pass the timeout""" 35 | kwargs.setdefault('timeout', (self.connect_timeout, self.read_timeout)) # type: ignore[arg-type] 36 | return super().request(method, url, **kwargs) 37 | 38 | def authenticate( 39 | self, 40 | username: Optional[str] = None, 41 | password: Optional[str] = None, 42 | bearer_token: Optional[str] = None, 43 | ) -> None: 44 | if bearer_token is None and not (username and password): 45 | raise ValueError( 46 | 'Unable to authenticate client. Please provide one of the following: ' 47 | '(1) `username` and `password`; ' 48 | '(2) `bearer_token`' 49 | ) 50 | token = bearer_token or self.get_token(username, password) # type: ignore[arg-type] 51 | self.headers.update({'authorization': f'Bearer {token}'}) 52 | 53 | def get_token(self, username: str, password: str) -> str: 54 | login_url = join_urls(self.api_url, 'v1', 'login') 55 | response = self.post(login_url, json={'username': username, 'password': password}) 56 | try: 57 | token = response.json().get('token', None) 58 | if isinstance(token, str): 59 | return token 60 | logger.debug('AirApi.get_token :: Response JSON') # type: ignore 61 | logger.debug(response.json()) # type: ignore 62 | raise AirAuthorizationError('API did not provide a token for ' + username) 63 | except JSONDecodeError: 64 | raise AirAuthorizationError('API did not return a valid JSON response') 65 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/user_configs.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | from dataclasses import dataclass, field 5 | from typing import Optional, Union 6 | 7 | from air_sdk.v2.endpoints import mixins 8 | from air_sdk.v2.endpoints.accounts import Account 9 | from air_sdk.v2.endpoints.organizations import Organization 10 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi, DataDict, PrimaryKey 11 | from air_sdk.v2.endpoints.resource_budgets import ResourceBudget 12 | from air_sdk.v2.utils import validate_payload_types 13 | 14 | 15 | @dataclass(eq=False) 16 | class UserConfig(AirModel): 17 | id: str = field(repr=False) 18 | name: str 19 | kind: str 20 | owner: Optional[Account] = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 21 | owner_budget: Optional[ResourceBudget] = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 22 | organization: Optional[Organization] = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 23 | organization_budget: Optional[ResourceBudget] = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 24 | content: Optional[str] = field(metadata=AirModel.FIELD_LAZY, repr=False) 25 | 26 | @classmethod 27 | def get_model_api(cls): 28 | """ 29 | Returns the respective `AirModelAPI` type for this model. 
30 | """ 31 | return UserConfigEndpointApi 32 | 33 | @validate_payload_types 34 | def update(self, name: Optional[str] = None, content: Optional[str] = None) -> None: 35 | """Update specific fields of the user config..""" 36 | data = {'name': name, 'content': content} 37 | data = {key: value for (key, value) in data.items() if value is not None} 38 | super().update(**data) 39 | 40 | @validate_payload_types 41 | def full_update(self, name: str, content: str) -> None: 42 | """Update all fields of the cloud-init assignment.""" 43 | super().full_update( 44 | name=name, 45 | content=content, 46 | owner=self.owner, 47 | owner_budget=self.owner_budget, 48 | kind=self.kind, 49 | organization=self.organization, 50 | organization_budget=self.organization_budget, 51 | ) 52 | 53 | 54 | class UserConfigEndpointApi( 55 | mixins.ListApiMixin[UserConfig], 56 | mixins.CreateApiMixin[UserConfig], 57 | mixins.GetApiMixin[UserConfig], 58 | mixins.PatchApiMixin[UserConfig], 59 | mixins.PutApiMixin[UserConfig], 60 | mixins.DeleteApiMixin, 61 | BaseEndpointApi[UserConfig], 62 | ): 63 | API_PATH = 'userconfigs' 64 | model = UserConfig 65 | 66 | @validate_payload_types 67 | def create( 68 | self, 69 | name: str, 70 | kind: str, 71 | content: str, 72 | organization: Optional[Union[Organization, PrimaryKey]] = None, 73 | ) -> UserConfig: 74 | payload: DataDict = { 75 | 'name': name, 76 | 'kind': kind, 77 | 'content': content, 78 | 'organization': organization, 79 | } 80 | return super().create(**payload) 81 | -------------------------------------------------------------------------------- /air_sdk/ssh_key.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | SSH Key module 6 | """ 7 | 8 | from . 
import util 9 | from .air_model import AirModel 10 | 11 | 12 | class SSHKey(AirModel): 13 | """ 14 | Manage a SSH Key 15 | 16 | ### delete 17 | Delete the key. Once successful, the object should no longer be used and will raise 18 | [`AirDeletedObject`](/docs/exceptions) when referenced. 19 | 20 | Raises: 21 | [`AirUnexpectedResposne`](/docs/exceptions) - Delete failed 22 | 23 | ### json 24 | Returns a JSON string representation of the key 25 | 26 | ### refresh 27 | Syncs the key with all values returned by the API 28 | """ 29 | 30 | _updatable = False 31 | 32 | def __repr__(self): 33 | if self._deleted or not self.name: 34 | return super().__repr__() 35 | return f'' 36 | 37 | 38 | class SSHKeyApi: 39 | """High-level interface for the SSHKey API""" 40 | 41 | def __init__(self, client): 42 | self.client = client 43 | self.url = self.client.api_url + '/sshkey/' 44 | 45 | def list(self, **kwargs): 46 | # pylint: disable=line-too-long 47 | """ 48 | List existing keys 49 | 50 | Arguments: 51 | kwargs (dict, optional): All other optional keyword arguments are applied as query 52 | parameters/filters 53 | 54 | Returns: 55 | list 56 | 57 | Raises: 58 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 59 | or valid response JSON 60 | 61 | Example: 62 | ``` 63 | >>> air.ssh_keys.list() 64 | [, ] 65 | ``` 66 | """ # pylint: enable=line-too-long 67 | res = self.client.get(f'{self.url}', params=kwargs) 68 | util.raise_if_invalid_response(res, data_type=list) 69 | return [SSHKey(self, **key) for key in res.json()] 70 | 71 | @util.required_kwargs(['public_key', 'name']) 72 | def create(self, **kwargs): 73 | """ 74 | Add a new public key to your account 75 | 76 | Arguments: 77 | name (str): Descriptive name for the public key 78 | public_key (str): Public key 79 | kwargs (dict, optional): All other optional keyword arguments are applied as key/value 80 | pairs in the request's JSON payload 81 | 82 | Returns: 83 | [`SSHKey`](/docs/sshkey) 84 | 85 | Raises: 
86 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 87 | or valid response JSON 88 | 89 | Example: 90 | ``` 91 | >>> air.ssh_keys.create(name='my_pub_key', public_key='') 92 | 93 | ``` 94 | """ 95 | res = self.client.post(self.url, json=kwargs) 96 | util.raise_if_invalid_response(res, status_code=201) 97 | return SSHKey(self, **res.json()) 98 | -------------------------------------------------------------------------------- /tests/test_capacity.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for capacity.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring,unused-argument 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import capacity 13 | from air_sdk import simulation 14 | 15 | Simulation = simulation.Simulation 16 | 17 | 18 | class TestCapacity(TestCase): 19 | def setUp(self): 20 | self.model = capacity.Capacity(MagicMock()) 21 | self.model.copies = 30 22 | 23 | def test_init_(self): 24 | self.assertFalse(self.model._deletable) 25 | self.assertFalse(self.model._updatable) 26 | 27 | def test_repr(self): 28 | self.assertEqual(str(self.model), f'') 29 | 30 | def test_repr_zero(self): 31 | self.model.copies = 0 32 | self.assertEqual(str(self.model), f'') 33 | 34 | def test_repr_deleted(self): 35 | self.model._deleted = True 36 | self.assertTrue('Deleted Object' in str(self.model)) 37 | 38 | 39 | class TestCapacityApi(TestCase): 40 | def setUp(self): 41 | self.client = MagicMock() 42 | self.client.api_url = 'http://testserver/api' 43 | self.api = capacity.CapacityApi(self.client) 44 | 45 | def test_init_(self): 46 | self.assertEqual(self.api.client, self.client) 47 | self.assertEqual(self.api.url, 'http://testserver/api/capacity/') 48 | 49 | 
@patch('air_sdk.capacity.CapacityApi.get') 50 | def test_get_capacity_by_sim(self, mock_get): 51 | mock_sim = MagicMock() 52 | res = self.api.get_capacity(mock_sim) 53 | mock_get.assert_called_with(simulation_id=mock_sim.id) 54 | self.assertEqual(res, mock_get.return_value) 55 | 56 | @patch('air_sdk.capacity.CapacityApi.get') 57 | def test_get_capacity_by_id(self, mock_get): 58 | res = self.api.get_capacity(simulation_id='abc123') 59 | mock_get.assert_called_with(simulation_id='abc123') 60 | self.assertEqual(res, mock_get.return_value) 61 | 62 | def test_get_capacity_missing_param(self): 63 | with self.assertRaises(ValueError) as err: 64 | self.api.get_capacity() 65 | self.assertEqual(str(err.exception), 'Must pass a simulation or simulation_id argument') 66 | 67 | @patch('air_sdk.util.raise_if_invalid_response') 68 | def test_get(self, mock_raise): 69 | self.client.get.return_value.json.return_value = {'test': 'success'} 70 | res = self.api.get('abc123', foo='bar') 71 | self.client.get.assert_called_with(f'{self.client.api_url}/capacity/abc123/', params={'foo': 'bar'}) 72 | mock_raise.assert_called_with(self.client.get.return_value) 73 | self.assertIsInstance(res, capacity.Capacity) 74 | self.assertEqual(res.test, 'success') 75 | 76 | @patch('air_sdk.util.raise_if_invalid_response') 77 | def test_get_simulation(self, mock_raise): 78 | sim = Simulation(MagicMock()) 79 | sim.id = 'abc123' 80 | self.api.get(sim) 81 | self.client.get.assert_called_with(f'{self.client.api_url}/capacity/abc123/', params={}) 82 | -------------------------------------------------------------------------------- /tests/tests_v2/test_typing.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | import pytest 5 | from typing import TypedDict, Union, List, Dict 6 | from dataclasses import dataclass 7 | from air_sdk.v2 import typing 8 | 9 | import faker 10 | 11 | faker.Faker.seed(0) 12 | fake = faker.Faker() 13 | 14 | 15 | class ExampleTypedDict(TypedDict): 16 | field1: int 17 | field2: str 18 | 19 | 20 | class NestedTypedDict(TypedDict): 21 | field1: ExampleTypedDict 22 | field2: List[str] 23 | 24 | 25 | @dataclass 26 | class ExampleDataclass: 27 | field1: int 28 | field2: str 29 | 30 | 31 | class TestTypingMethods: 32 | @pytest.mark.parametrize( 33 | 'test_type,expected', 34 | [ 35 | (ExampleTypedDict, True), 36 | (ExampleDataclass, False), 37 | (int, False), 38 | ], 39 | ) 40 | def test_is_typed_dict(self, test_type, expected): 41 | assert typing.is_typed_dict(test_type) == expected 42 | 43 | @pytest.mark.parametrize( 44 | 'value,expected', 45 | [ 46 | ({'field1': fake.pyint(), 'field2': fake.slug()}, True), 47 | ({'field1': fake.pyint()}, False), 48 | ({'field1': str(fake.pyint()), 'field2': fake.slug()}, False), 49 | ], 50 | ) 51 | def test_type_check_typed_dict(self, value, expected): 52 | assert typing.type_check_typed_dict(value, ExampleTypedDict) == expected 53 | 54 | @pytest.mark.parametrize( 55 | 'value,expected', 56 | [ 57 | ( 58 | { 59 | 'field1': {'field1': fake.pyint(), 'field2': fake.slug()}, 60 | 'field2': [fake.slug(), fake.slug()], 61 | }, 62 | True, 63 | ), 64 | ( 65 | { 66 | 'field1': {'field1': fake.pyint(), 'field2': fake.pyint()}, 67 | 'field2': [fake.slug(), fake.slug()], 68 | }, 69 | False, 70 | ), 71 | ], 72 | ) 73 | def test_type_check_typed_dict_nested(self, value, expected): 74 | assert typing.type_check(value, NestedTypedDict) == expected 75 | 76 | @pytest.mark.parametrize( 77 | 'value,expected', 78 | [ 79 | ([1, 2, 3], True), 80 | ([1, 2, 'three'], False), 81 | ], 82 | ) 83 | def test_type_check_list(self, value, expected): 84 | assert typing.type_check(value, List[int]) == expected 
85 | 86 | @pytest.mark.parametrize( 87 | 'value,expected', 88 | [ 89 | (123, True), 90 | (fake.slug(), True), 91 | (123.45, False), 92 | ], 93 | ) 94 | def test_type_check_union(self, value, expected): 95 | assert typing.type_check(value, Union[int, str]) == expected 96 | 97 | @pytest.mark.parametrize( 98 | 'value,expected', 99 | [ 100 | ({'key': 'value'}, True), 101 | ({1: 'value'}, False), 102 | ], 103 | ) 104 | def test_type_check_dict(self, value, expected): 105 | assert typing.type_check(value, Dict[str, str]) == expected 106 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_workers.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | from http import HTTPStatus 4 | 5 | import pytest 6 | import faker 7 | 8 | from air_sdk.v2 import AirModelAttributeError 9 | from air_sdk.v2.utils import join_urls 10 | 11 | faker.Faker.seed(0) 12 | fake = faker.Faker() 13 | 14 | 15 | class TestWorkerEndpointApi: 16 | def test_list(self, api, run_list_test, worker_factory): 17 | run_list_test(api.workers, worker_factory) 18 | 19 | @pytest.mark.parametrize( 20 | 'contact,is_valid', 21 | ( 22 | (fake.email(), True), 23 | (fake.pylist(value_types=(str,)), True), 24 | (fake.pydict(value_types=(str,)), True), 25 | (None, False), 26 | (fake.pyint(), False), 27 | ), 28 | ) 29 | def test_contact_field_response(self, api, worker_factory, contact, is_valid): 30 | if is_valid: 31 | worker = worker_factory(api, contact=contact) 32 | assert worker.contact == contact 33 | else: 34 | with pytest.raises(AirModelAttributeError): 35 | worker_factory(api, contact=contact) 36 | 37 | def test_refresh(self, api, run_refresh_test, worker_factory): 38 | run_refresh_test(api.workers, worker_factory) 39 | 40 | def test_heartbeat(self, api, setup_mock_responses, 
worker_factory): 41 | worker = worker_factory(api) 42 | setup_mock_responses( 43 | {('GET', join_urls(api.workers.url, 'heartbeat')): {'status_code': HTTPStatus.OK}} 44 | ) 45 | worker.heartbeat() 46 | 47 | @pytest.mark.parametrize( 48 | 'payload,is_valid', 49 | ( 50 | ({}, True), 51 | ({'invalid_key': fake.slug()}, False), 52 | ({'airstrike_version': fake.slug()}, True), 53 | ({'architecture': fake.slug()}, True), 54 | ({'docker': fake.slug()}, True), 55 | ({'kernel': fake.slug()}, True), 56 | ({'libvirt': fake.slug()}, True), 57 | ({'operating_system': fake.slug()}, True), 58 | ({'proxy_image': fake.slug()}, True), 59 | ({'worker_version': fake.slug()}, True), 60 | ( 61 | { 62 | 'airstrike_version': fake.slug(), 63 | 'architecture': fake.slug(), 64 | 'docker': fake.slug(), 65 | 'kernel': fake.slug(), 66 | 'libvirt': fake.slug(), 67 | 'operating_system': fake.slug(), 68 | 'proxy_image': fake.slug(), 69 | 'worker_version': fake.slug(), 70 | }, 71 | True, 72 | ), 73 | ), 74 | ) 75 | def test_update_inventory(self, api, setup_mock_responses, worker_factory, payload, is_valid): 76 | worker = worker_factory(api) 77 | setup_mock_responses( 78 | {('PUT', join_urls(api.workers.url, 'inventory')): {'status_code': HTTPStatus.OK}} 79 | ) 80 | if is_valid: 81 | worker.update_inventory(**payload) 82 | else: 83 | with pytest.raises(Exception) as err: 84 | worker.update_inventory(**payload) 85 | assert err.type in (TypeError, ValueError) 86 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/workers.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | from dataclasses import dataclass, field 5 | from datetime import datetime 6 | from typing import Optional, Union, List, Dict 7 | 8 | from air_sdk.util import raise_if_invalid_response 9 | from air_sdk.v2.endpoints import mixins 10 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi 11 | from air_sdk.v2.endpoints.mixins import serialize_payload 12 | from air_sdk.v2.utils import join_urls, validate_payload_types 13 | 14 | 15 | @dataclass(eq=False) 16 | class Worker(AirModel): 17 | id: str = field(repr=False) 18 | fqdn: str 19 | cpu_arch: str 20 | available: bool = field(repr=False) 21 | capabilities: str = field(repr=False) 22 | cpu: int = field(repr=False) 23 | contact: Union[str, List[str], Dict[str, str]] = field(repr=False) 24 | created: datetime = field(repr=False) 25 | fleet: str = field(repr=False) 26 | gpu: int = field(repr=False) 27 | ip_address: str = field(repr=False) 28 | memory: int = field(repr=False) 29 | modified: datetime = field(repr=False) 30 | port_range: str = field(repr=False) 31 | registered: bool = field(repr=False) 32 | storage: int = field(repr=False) 33 | tunnel_port: int = field(repr=False) 34 | vgpu: int = field(repr=False) 35 | 36 | @classmethod 37 | def get_model_api(cls): 38 | """ 39 | Returns the respective `AirModelAPI` type for this model. 
40 | """ 41 | return WorkerEndpointApi 42 | 43 | def heartbeat(self) -> None: 44 | """Keeps the worker alive with via a heartbeat to the manager.""" 45 | heartbeat_url = join_urls(self.__api__.workers.url, 'heartbeat') 46 | response = self.__api__.client.get(heartbeat_url) 47 | raise_if_invalid_response(response, data_type=None) 48 | 49 | @validate_payload_types 50 | def update_inventory( 51 | self, 52 | airstrike_version: Optional[str] = None, 53 | architecture: Optional[str] = None, 54 | docker: Optional[str] = None, 55 | kernel: Optional[str] = None, 56 | libvirt: Optional[str] = None, 57 | operating_system: Optional[str] = None, 58 | proxy_image: Optional[str] = None, 59 | worker_version: Optional[str] = None, 60 | ) -> None: 61 | """Update the worker's inventory with the manager.""" 62 | data = { 63 | 'airstrike_version': airstrike_version, 64 | 'architecture': architecture, 65 | 'docker': docker, 66 | 'kernel': kernel, 67 | 'libvirt': libvirt, 68 | 'operating_system': operating_system, 69 | 'proxy_image': proxy_image, 70 | 'worker_version': worker_version, 71 | } 72 | inventory_url = join_urls(self.__api__.workers.url, 'inventory') 73 | response = self.__api__.client.put( 74 | inventory_url, 75 | data=serialize_payload({key: value for (key, value) in data.items() if value is not None}), 76 | ) 77 | raise_if_invalid_response(response, data_type=None) 78 | 79 | 80 | class WorkerEndpointApi( 81 | mixins.ListApiMixin[Worker], 82 | mixins.GetApiMixin[Worker], 83 | BaseEndpointApi[Worker], 84 | ): 85 | API_PATH = 'workers' 86 | model = Worker 87 | -------------------------------------------------------------------------------- /tests/test_simulation_interface.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for simulation_interface.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring,unused-argument 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import simulation_interface 13 | 14 | 15 | class TestSimulationInterface(TestCase): 16 | def setUp(self): 17 | self.model = simulation_interface.SimulationInterface(MagicMock()) 18 | self.model.id = 'abc123' 19 | 20 | def test_init_(self): 21 | self.assertFalse(self.model._deletable) 22 | self.assertTrue(self.model._updatable) 23 | 24 | def test_repr(self): 25 | self.assertEqual(str(self.model), f'') 26 | 27 | def test_repr_deleted(self): 28 | self.model._deleted = True 29 | self.assertTrue('Deleted Object' in str(self.model)) 30 | 31 | 32 | class TestSimulationInterfaceApi(TestCase): 33 | def setUp(self): 34 | self.client = MagicMock() 35 | self.client.api_url = 'http://testserver/api' 36 | self.api = simulation_interface.SimulationInterfaceApi(self.client) 37 | 38 | def test_init_(self): 39 | self.assertEqual(self.api.client, self.client) 40 | self.assertEqual(self.api.url, 'http://testserver/api/simulation-interface/') 41 | 42 | @patch('air_sdk.simulation_interface.SimulationInterfaceApi.list') 43 | def test_get_simulation_interfaces(self, mock_list): 44 | res = self.api.get_simulation_interfaces('abc123', 'xyz123') 45 | mock_list.assert_called_with(simulation='abc123', original='xyz123') 46 | self.assertEqual(res, mock_list.return_value) 47 | 48 | @patch('air_sdk.util.raise_if_invalid_response') 49 | def test_get(self, mock_raise): 50 | self.client.get.return_value.json.return_value = {'test': 'success'} 51 | res = self.api.get('abc123', foo='bar') 52 | self.client.get.assert_called_with( 53 | f'{self.client.api_url}/simulation-interface/abc123/', params={'foo': 'bar'} 54 | ) 55 | mock_raise.assert_called_with(self.client.get.return_value) 56 | self.assertIsInstance(res, 
simulation_interface.SimulationInterface) 57 | self.assertEqual(res.test, 'success') 58 | 59 | @patch('air_sdk.util.raise_if_invalid_response') 60 | def test_list(self, mock_raise): 61 | self.client.get.return_value.json.return_value = [{'id': 'abc'}, {'id': 'xyz'}] 62 | res = self.api.list(foo='bar') 63 | self.client.get.assert_called_with( 64 | f'{self.client.api_url}/simulation-interface/', params={'foo': 'bar'} 65 | ) 66 | mock_raise.assert_called_with(self.client.get.return_value, data_type=list) 67 | self.assertEqual(len(res), 2) 68 | self.assertIsInstance(res[0], simulation_interface.SimulationInterface) 69 | self.assertEqual(res[0].id, 'abc') 70 | self.assertEqual(res[1].id, 'xyz') 71 | 72 | @patch('air_sdk.util.raise_if_invalid_response') 73 | def test_list_interface(self, mock_raise): 74 | self.api.list(interface='test') 75 | self.client.get.assert_called_with( 76 | f'{self.client.api_url}/simulation-interface/', params={'original': 'test'} 77 | ) 78 | -------------------------------------------------------------------------------- /tests/test_token.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for token.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import token 13 | 14 | 15 | class TestAPIToken(TestCase): 16 | def setUp(self): 17 | self.model = token.Token(MagicMock()) 18 | # self.model.id = 'abc123' 19 | # self.model.token = 'abc123' 20 | self.model.name = 'public' 21 | 22 | def test_init_(self): 23 | self.assertTrue(self.model._deletable) 24 | self.assertFalse(self.model._updatable) 25 | 26 | def test_repr(self): 27 | self.assertEqual(str(self.model), f'') 28 | 29 | def test_repr_id(self): 30 | self.model.id = 'abc123' 31 | self.assertEqual(str(self.model), f'') 32 | 33 | def test_repr_token(self): 34 | self.model.token = 'abc123' 35 | self.assertEqual(str(self.model), f'') 36 | 37 | def test_repr_deleted(self): 38 | self.model._deleted = True 39 | self.assertTrue('Deleted Object' in str(self.model)) 40 | 41 | 42 | class TestTokenApi(TestCase): 43 | def setUp(self): 44 | self.client = MagicMock() 45 | self.client.api_url = 'http://testserver/api' 46 | self.api = token.TokenApi(self.client) 47 | 48 | def test_init_(self): 49 | self.assertEqual(self.api.client, self.client) 50 | self.assertEqual(self.api.url, 'http://testserver/api/api-token/') 51 | 52 | @patch('air_sdk.util.raise_if_invalid_response') 53 | def test_list(self, mock_raise): 54 | self.client.get.return_value.json.return_value = [{'id': 'abc'}, {'id': 'xyz'}] 55 | res = self.api.list(foo='bar') 56 | self.client.get.assert_called_with(f'{self.client.api_url}/api-token/', params={'foo': 'bar'}) 57 | mock_raise.assert_called_with(self.client.get.return_value, data_type=list) 58 | self.assertEqual(len(res), 2) 59 | self.assertIsInstance(res[0], token.Token) 60 | self.assertEqual(res[0].id, 'abc') 61 | self.assertEqual(res[1].id, 'xyz') 62 | 63 | @patch('air_sdk.util.raise_if_invalid_response') 
64 | def test_create(self, mock_raise): 65 | self.client.post.return_value.json.return_value = {'id': 'abc'} 66 | res = self.api.create(name='test') 67 | self.client.post.assert_called_with(f'{self.client.api_url}/api-token/', json={'name': 'test'}) 68 | mock_raise.assert_called_with(self.client.post.return_value, status_code=201) 69 | self.assertIsInstance(res, token.Token) 70 | self.assertEqual(res.id, 'abc') 71 | 72 | def test_create_required_kwargs(self): 73 | with self.assertRaises(AttributeError) as err: 74 | self.api.create() 75 | self.assertTrue('requires name' in str(err.exception)) 76 | 77 | @patch('air_sdk.util.raise_if_invalid_response') 78 | def test_delete(self, mock_raise): 79 | self.api.delete('abc') 80 | self.client.delete.assert_called_with(f'{self.client.api_url}/api-token/abc/', params={}) 81 | mock_raise.assert_called_with(self.client.delete.return_value, status_code=204, data_type=None) 82 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: MIT

# Every name below is imported only to be re-exported, so ruff's unused-import rule (F401)
# is disabled for this module.
# ruff:noqa: F401

# Public surface of the v2 endpoints package: each endpoint module contributes a model class
# and its matching `...EndpointApi` class; `mixins` is exported as a module.
__all__ = [
    'AccountEndpointApi',
    'Account',
    'AnnouncementEndpointApi',
    'Announcement',
    'ApiTokenEndpointApi',
    'ApiToken',
    'BreakoutEndpointApi',
    'Breakout',
    'CloudInitEndpointApi',
    'CloudInit',
    'FleetEndpointApi',
    'Fleet',
    'ImageEndpointApi',
    'Image',
    'InterfaceEndpointApi',
    'Interface',
    'JobEndpointApi',
    'Job',
    'ManifestEndpointApi',
    'Manifest',
    # `Link`/`LinkEndpointApi` are provided by the interfaces module (see imports below).
    'Link',
    'LinkEndpointApi',
    'MarketplaceDemoEndpointApi',
    'MarketplaceDemo',
    'MarketplaceDemoTagsEndpointApi',
    'MarketplaceDemoTag',
    'NodeEndpointApi',
    'Node',
    'OrganizationEndpointApi',
    'Organization',
    'ResourceBudgetEndpointApi',
    'ResourceBudget',
    'ServiceEndpointApi',
    'Service',
    'SimulationEndpointApi',
    'Simulation',
    'SystemEndpointApi',
    'System',
    'TopologyEndpointApi',
    'Topology',
    'UserConfigEndpointApi',
    'UserConfig',
    'WorkerEndpointApi',
    'Worker',
    'mixins',
    'OrganizationMember',
    'OrganizationMembersEndpointApi',
    'NodeInstructionsEndpointApi',
    'NodeInstruction',
]

# NOTE(review): endpoint modules import each other (e.g. cloud_inits imports nodes and
# user_configs), so confirm there are no circular-import issues before re-sorting these lines.
from air_sdk.v2.endpoints import mixins
from air_sdk.v2.endpoints.accounts import Account, AccountEndpointApi
from air_sdk.v2.endpoints.announcements import Announcement, AnnouncementEndpointApi
from air_sdk.v2.endpoints.api_tokens import ApiToken, ApiTokenEndpointApi
from air_sdk.v2.endpoints.breakouts import Breakout, BreakoutEndpointApi
from air_sdk.v2.endpoints.cloud_inits import CloudInit, CloudInitEndpointApi
from air_sdk.v2.endpoints.fleets import Fleet, FleetEndpointApi
from air_sdk.v2.endpoints.images import Image, ImageEndpointApi
from air_sdk.v2.endpoints.interfaces import Interface, InterfaceEndpointApi, Link, LinkEndpointApi
from air_sdk.v2.endpoints.jobs import Job, JobEndpointApi
from air_sdk.v2.endpoints.manifests import Manifest, ManifestEndpointApi
from air_sdk.v2.endpoints.marketplace_demo_tags import MarketplaceDemoTag, MarketplaceDemoTagsEndpointApi
from air_sdk.v2.endpoints.marketplace_demos import MarketplaceDemo, MarketplaceDemoEndpointApi
from air_sdk.v2.endpoints.nodes import Node, NodeEndpointApi
from air_sdk.v2.endpoints.organizations import (
    Organization,
    OrganizationEndpointApi,
    OrganizationMember,
    OrganizationMembersEndpointApi,
)
from air_sdk.v2.endpoints.resource_budgets import ResourceBudget, ResourceBudgetEndpointApi
from air_sdk.v2.endpoints.services import Service, ServiceEndpointApi
from air_sdk.v2.endpoints.simulations import Simulation, SimulationEndpointApi
from air_sdk.v2.endpoints.systems import System, SystemEndpointApi
from air_sdk.v2.endpoints.topologies import Topology, TopologyEndpointApi
from air_sdk.v2.endpoints.user_configs import UserConfig, UserConfigEndpointApi
from air_sdk.v2.endpoints.workers import Worker, WorkerEndpointApi
from air_sdk.v2.endpoints.node_instructions import NodeInstruction, NodeInstructionsEndpointApi
-------------------------------------------------------------------------------- /tests/test_link.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for link.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import link 13 | 14 | 15 | class TestLink(TestCase): 16 | def setUp(self): 17 | self.model = link.Link(MagicMock()) 18 | self.model.id = 'abc123' 19 | 20 | def test_init_(self): 21 | self.assertTrue(self.model._deletable) 22 | self.assertTrue(self.model._updatable) 23 | 24 | def test_repr(self): 25 | self.assertEqual(str(self.model), f'') 26 | 27 | def test_repr_deleted(self): 28 | self.model._deleted = True 29 | self.assertTrue('Deleted Object' in str(self.model)) 30 | 31 | 32 | class TestLinkApi(TestCase): 33 | def setUp(self): 34 | self.client = MagicMock() 35 | self.client.api_url = 'http://testserver/api' 36 | self.api = link.LinkApi(self.client) 37 | 38 | def test_init_(self): 39 | self.assertEqual(self.api.client, self.client) 40 | self.assertEqual(self.api.url, 'http://testserver/api/link/') 41 | 42 | @patch('air_sdk.util.raise_if_invalid_response') 43 | def test_get(self, mock_raise): 44 | self.client.get.return_value.json.return_value = {'test': 'success'} 45 | res = self.api.get('abc123', foo='bar') 46 | self.client.get.assert_called_with(f'{self.client.api_url}/link/abc123/', params={'foo': 'bar'}) 47 | mock_raise.assert_called_with(self.client.get.return_value) 48 | self.assertIsInstance(res, link.Link) 49 | self.assertEqual(res.test, 'success') 50 | 51 | @patch('air_sdk.util.raise_if_invalid_response') 52 | def test_list(self, mock_raise): 53 | self.client.get.return_value.json.return_value = [{'id': 'abc', 'interfaces': ['foo']}, {'id': 'xyz'}] 54 | res = self.api.list(foo='bar') 55 | self.client.get.assert_called_with(f'{self.client.api_url}/link/', params={'foo': 'bar'}) 56 | mock_raise.assert_called_with(self.client.get.return_value, data_type=list) 57 | self.assertEqual(len(res), 2) 58 | 
self.assertIsInstance(res[0], link.Link) 59 | self.assertEqual(res[0].id, 'abc') 60 | self.assertEqual(res[1].id, 'xyz') 61 | 62 | @patch('air_sdk.util.raise_if_invalid_response') 63 | def test_create(self, mock_raise): 64 | self.client.post.return_value.json.return_value = {'id': 'abc'} 65 | res = self.api.create(topology='abc123', interfaces=['def123']) 66 | self.client.post.assert_called_with( 67 | f'{self.client.api_url}/link/', json={'topology': 'abc123', 'interfaces': ['def123']} 68 | ) 69 | mock_raise.assert_called_with(self.client.post.return_value, status_code=201) 70 | self.assertIsInstance(res, link.Link) 71 | self.assertEqual(res.id, 'abc') 72 | 73 | def test_create_required_kwargs(self): 74 | with self.assertRaises(AttributeError) as err: 75 | self.api.create(interfaces=[]) 76 | self.assertTrue('requires topology' in str(err.exception)) 77 | with self.assertRaises(AttributeError) as err: 78 | self.api.create(topology='abc123') 79 | self.assertTrue('requires interfaces' in str(err.exception)) 80 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/cloud_inits.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | from __future__ import annotations 5 | 6 | from dataclasses import dataclass, field 7 | from http import HTTPStatus 8 | from typing import Any, Iterable, Optional, TypedDict, Union 9 | 10 | from air_sdk.util import raise_if_invalid_response 11 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi, PrimaryKey 12 | from air_sdk.v2.endpoints.mixins import serialize_payload 13 | from air_sdk.v2.endpoints.nodes import Node 14 | from air_sdk.v2.endpoints.user_configs import UserConfig 15 | from air_sdk.v2.utils import join_urls, validate_payload_types 16 | 17 | 18 | class CloudInitBulkAssignment(TypedDict, total=False): 19 | user_data: Optional[Union[UserConfig, PrimaryKey]] 20 | meta_data: Optional[Union[UserConfig, PrimaryKey]] 21 | simulation_node: Union[Node, PrimaryKey] 22 | 23 | 24 | @dataclass(eq=False) 25 | class CloudInit(AirModel): 26 | simulation_node: Node = field(metadata=AirModel.FIELD_FOREIGN_KEY) 27 | user_data: Optional[UserConfig] = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 28 | meta_data: Optional[UserConfig] = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 29 | user_data_name: Optional[str] 30 | meta_data_name: Optional[str] 31 | 32 | @classmethod 33 | def get_model_api(cls): 34 | """ 35 | Returns the respective `AirModelAPI` type for this model. 
36 | """ 37 | return CloudInitEndpointApi 38 | 39 | @property 40 | def primary_key_field(self) -> str: 41 | return 'simulation_node' 42 | 43 | @property 44 | def __pk__(self) -> PrimaryKey: 45 | return getattr(self, self.primary_key_field).__pk__ # type: ignore 46 | 47 | @validate_payload_types 48 | def full_update( 49 | self, 50 | user_data: Optional[Union[UserConfig, PrimaryKey]], 51 | meta_data: Optional[Union[UserConfig, PrimaryKey]], 52 | ) -> None: 53 | """Update all fields of the cloud-init assignment.""" 54 | super().update(user_data=user_data, meta_data=meta_data) 55 | 56 | 57 | class CloudInitEndpointApi(BaseEndpointApi[CloudInit]): 58 | API_PATH = 'simulations/nodes/{id}/cloud-init' # Placeholder 59 | BULK_API_PATH = 'simulations/nodes/cloud-init/bulk-assign' 60 | model = CloudInit 61 | 62 | def get(self, pk: PrimaryKey, **params: Any) -> CloudInit: 63 | detail_url = self.url.format(id=str(pk)) 64 | response = self.__api__.client.get(detail_url, params=params) 65 | raise_if_invalid_response(response) 66 | return self.load_model(response.json()) 67 | 68 | def patch(self, pk: PrimaryKey, **kwargs: Any) -> CloudInit: 69 | detail_url = self.url.format(id=str(pk)) 70 | response = self.__api__.client.patch(detail_url, data=serialize_payload(kwargs)) 71 | raise_if_invalid_response(response) 72 | return self.load_model(response.json()) 73 | 74 | @validate_payload_types 75 | def bulk_assign(self, assignments: Iterable[CloudInitBulkAssignment]) -> None: 76 | response = self.__api__.client.patch( 77 | join_urls(self.__api__.client.base_url, self.BULK_API_PATH), 78 | data=serialize_payload([dict(assignment) for assignment in assignments]), 79 | ) 80 | raise_if_invalid_response(response, status_code=HTTPStatus.NO_CONTENT, data_type=None) 81 | -------------------------------------------------------------------------------- /air_sdk/account.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 
2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Account module 6 | """ 7 | 8 | from . import user_preference, util 9 | from .air_model import AirModel 10 | 11 | 12 | class Account(AirModel): 13 | """ 14 | Manage an Account 15 | ### json 16 | Returns a JSON string representation of the account 17 | 18 | ### refresh 19 | Syncs the account with all values returned by the API 20 | """ 21 | 22 | _deletable = False 23 | _updatable = False 24 | 25 | def __repr__(self): 26 | if self._deleted or not self.username: 27 | return super().__repr__() 28 | return f'' 29 | 30 | 31 | class AccountApi: 32 | """High-level interface for the Account API""" 33 | 34 | def __init__(self, client): 35 | self.client = client 36 | self.url = self.client.api_url + '/account/' 37 | 38 | def get(self, account_id, **kwargs): 39 | """ 40 | Get an existing account 41 | 42 | Arguments: 43 | account_id (str): Account ID 44 | kwargs (dict, optional): All other optional keyword arguments are applied as query 45 | parameters/filters 46 | 47 | Returns: 48 | [`Account`](/docs/account) 49 | 50 | Raises: 51 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 52 | or valid response JSON 53 | 54 | Example: 55 | ``` 56 | >>> air.accounts.get('3dadd54d-583c-432e-9383-a2b0b1d7f551') 57 | 58 | ``` 59 | """ 60 | url = f'{self.url}{account_id}/' 61 | res = self.client.get(url, params=kwargs) 62 | util.raise_if_invalid_response(res) 63 | return Account(self, **res.json()) 64 | 65 | def list(self, **kwargs): 66 | # pylint: disable=line-too-long 67 | """ 68 | List existing accounts 69 | 70 | Arguments: 71 | kwargs (dict, optional): All other optional keyword arguments are applied as query 72 | parameters/filters 73 | 74 | Returns: 75 | list 76 | 77 | Raises: 78 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 79 | or valid response JSON 80 | 81 | Example: 82 | ``` 83 | >>> air.accounts.list() 84 | [, ] 
85 | ``` 86 | """ 87 | # pylint: enable=line-too-long 88 | res = self.client.get(f'{self.url}', params=kwargs) 89 | util.raise_if_invalid_response(res, data_type=list) 90 | return [Account(self, **account) for account in res.json()] 91 | 92 | def preferences(self, **kwargs): 93 | """ 94 | Returns your global account preferences 95 | 96 | Arguments: 97 | kwargs (dict, optional): All other optional keyword arguments are applied as query 98 | parameters/filters 99 | 100 | Returns: 101 | [`UserPreference`](/docs/userpreference) 102 | 103 | Example: 104 | ``` 105 | >>> air.accounts.preferences() 106 | {"show": true} 107 | ``` 108 | """ 109 | res = self.client.get(f'{self.url}preferences/', params=kwargs) 110 | util.raise_if_invalid_response(res) 111 | return user_preference.UserPreference(self, **res.json()) 112 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_cloud_init.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: MIT
"""Tests for the v2 cloud-init endpoint (`CloudInit` model and `CloudInitEndpointApi`)."""
import json
from http import HTTPStatus

import faker
import pytest

from air_sdk.v2.endpoints.mixins import serialize_payload
from air_sdk.v2.utils import join_urls

# Seeded once at import time: the parametrized cases below are generated while the class body
# executes, so the seed plus the order of `fake` calls fixes their values.
faker.Faker.seed(0)
fake = faker.Faker()


class TestCloudInit:
    """Validation and request behaviour of `CloudInit.full_update`."""

    @pytest.mark.parametrize(
        'payload,is_valid',
        (
            # Empty case
            ({}, False),
            # Clearing both assignments is allowed.
            ({'user_data': None, 'meta_data': None}, True),
            # A bool is neither a UserConfig nor a primary key.
            ({'user_data': fake.pybool(), 'meta_data': None}, False),
            # UUID primary keys, in every combination with None.
            ({'user_data': fake.uuid4(cast_to=None), 'meta_data': fake.uuid4(cast_to=None)}, True),
            ({'user_data': fake.uuid4(cast_to=None), 'meta_data': None}, True),
            ({'user_data': None, 'meta_data': fake.uuid4(cast_to=None)}, True),
            # Keyword arguments outside the full_update signature are rejected.
            (
                {
                    'user_data': fake.uuid4(cast_to=None),
                    'meta_data': fake.uuid4(cast_to=None),
                    'unexpected_field': fake.uuid4(cast_to=None),
                },
                False,
            ),
        ),
    )
    def test_full_update(
        self, setup_mock_responses, api, run_full_update_patch_test, cloud_init_factory, payload, is_valid
    ):
        # NOTE(review): `run_full_update_patch_test` is requested but never used in this test.
        endpoint_api = api.cloud_inits
        instance = cloud_init_factory(endpoint_api.__api__)
        if is_valid:
            # Expected state: the original instance's fields overlaid with the serialized payload.
            processed_payload = json.loads(serialize_payload(payload))
            updated_inst = cloud_init_factory(
                endpoint_api.__api__, **{**instance.dict(), **processed_payload}
            )
            # The detail URL is keyed by the node's primary key (CloudInit.__pk__).
            detail_url = endpoint_api.url.format(id=str(instance.__pk__))
            setup_mock_responses(
                {
                    ('PATCH', detail_url): {
                        'json': json.loads(instance.json()),
                        'status_code': HTTPStatus.OK,
                    },
                }
            )
            instance.full_update(**payload)
            assert instance == updated_inst
            assert instance is not updated_inst
        else:
            # Invalid payloads must fail client-side validation before any request is made.
            with pytest.raises(Exception) as err:
                instance.full_update(**payload)
            assert err.type in (TypeError, ValueError)


class TestCloudInitEndpointApi:
    """Behaviour of `CloudInitEndpointApi` methods that are not covered by the shared helpers."""

    def test_bulk_assign(self, api, mock_client, node_factory, user_config_factory):
        node_1, node_2 = node_factory(api), node_factory(api)
        user_data, meta_data = user_config_factory(api), user_config_factory(api)
        # The two entries deliberately provide different subsets of keys.
        assignments = [
            {'simulation_node': node_1.id, 'user_data': user_data.id},
            {'simulation_node': node_2.id, 'user_data': None, 'meta_data': meta_data.id},
        ]
        payload = serialize_payload(assignments)
        patch_called = False

        def _validate_payload(request, *args, **kwargs):
            """Makes sure a call was made and proper payload was provided."""
            nonlocal patch_called
            assert payload == request.text
            patch_called = True
            return None

        # The mocked endpoint runs `_validate_payload` to build its JSON body, which records the call.
        mock_client.register_uri(
            'PATCH',
            join_urls(api.client.base_url, api.cloud_inits.BULK_API_PATH),
            json=_validate_payload,
            status_code=HTTPStatus.NO_CONTENT,
        )
        api.cloud_inits.bulk_assign(assignments)
        assert patch_called is True
-------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_breakout.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT 3 | import json 4 | from http import HTTPStatus 5 | 6 | import pytest 7 | from air_sdk.exceptions import AirUnexpectedResponse 8 | import faker 9 | 10 | faker.Faker.seed(0) 11 | fake = faker.Faker() 12 | 13 | 14 | class TestBreakoutEndpointApi: 15 | def test_list(self, api, run_list_test, breakout_factory): 16 | run_list_test(api.breakouts, breakout_factory) 17 | 18 | def test_refresh(self, api, run_refresh_test, breakout_factory): 19 | run_refresh_test(api.breakouts, breakout_factory) 20 | 21 | def test_delete(self, api, run_delete_test, breakout_factory): 22 | run_delete_test(api.breakouts, breakout_factory) 23 | 24 | @pytest.mark.parametrize( 25 | 'payload,is_valid', 26 | ( 27 | # Invalid cases 28 | ({}, False), 29 | ({'interface': str(fake.uuid4())}, False), # Missing split_count 30 | ({'split_count': 4}, False), # Missing interface 31 | ( 32 | { 33 | 'interface': 'not-a-uuid', # Invalid UUID format 34 | 'split_count': 4, 35 | }, 36 | False, 37 | ), 38 | ( 39 | { 40 | 'interface': str(fake.uuid4()), 41 | 'split_count': fake.slug(), # Wrong type 42 | }, 43 | False, 44 | ), 45 | ({'interface': None, 'split_count': None}, False), 46 | ({'interface': None, 'split_count': fake.pyint()}, False), 47 | ({'interface': fake.uuid4(cast_to=str), 'split_count': None}, False), 48 | ({'interface': fake.uuid4(cast_to=str), 'split_count': 1}, False), 49 | # Valid cases 50 | ({'interface': fake.uuid4(cast_to=str), 'split_count': 2}, True), 51 | ({'interface': fake.uuid4(cast_to=str), 'split_count': 4}, True), 52 | # special marker for Interface instance 53 | ({'interface': 'USE_INTERFACE_INSTANCE', 'split_count': 4}, True), 54 | ), 55 | ) 56 | def test_create(self, api, setup_mock_responses, breakout_factory, interface_factory, payload, is_valid): 57 | """This tests that the data provided is properly validated and used.""" 58 | endpoint_api = api.breakouts 59 | 60 | if payload.get('interface') == 'USE_INTERFACE_INSTANCE': 61 | 
payload['interface'] = interface_factory(api) 62 | 63 | if is_valid: 64 | factory_kwargs = {k: v for k, v in payload.items() if v is not None} 65 | expected_inst = breakout_factory(endpoint_api.__api__, **factory_kwargs) 66 | setup_mock_responses( 67 | { 68 | ('POST', endpoint_api.url): { 69 | 'json': json.loads(expected_inst.json()), 70 | 'status_code': HTTPStatus.CREATED, 71 | } 72 | } 73 | ) 74 | inst = endpoint_api.create(**payload) 75 | assert inst == expected_inst 76 | assert inst is not expected_inst 77 | else: 78 | setup_mock_responses( 79 | { 80 | ('POST', endpoint_api.url): { 81 | 'json': {'error': 'Invalid request'}, 82 | 'status_code': HTTPStatus.BAD_REQUEST, 83 | } 84 | } 85 | ) 86 | with pytest.raises(Exception) as err: 87 | endpoint_api.create(**payload) 88 | assert err.type in (AirUnexpectedResponse, TypeError, ValueError) 89 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/manifests.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | from dataclasses import dataclass, field 5 | from typing import List, Optional, Union 6 | 7 | from air_sdk.v2.air_model import AirModel, BaseEndpointApi, DataDict, PrimaryKey 8 | from air_sdk.v2.endpoints import mixins 9 | from air_sdk.v2.endpoints.accounts import Account 10 | from air_sdk.v2.endpoints.images import Image 11 | from air_sdk.v2.endpoints.organizations import Organization 12 | 13 | 14 | @dataclass(eq=False) 15 | class Manifest(AirModel): 16 | id: str = field(repr=False) 17 | artifacts_directory: str = field(repr=False) 18 | artifacts_directory_max_size_gb: int = field(repr=False) 19 | boot_group: int = field(repr=False) 20 | configure_node_properties: DataDict = field(repr=False) 21 | configure_simulator: DataDict = field(repr=False) 22 | docker_run_parameters: DataDict = field(repr=False) 23 | emulation_type: str 24 | organization: Optional[Organization] = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 25 | owner: Optional[Account] = field(metadata=AirModel.FIELD_FOREIGN_KEY, repr=False) 26 | platform_information: DataDict = field(repr=False) 27 | simulation_engine_versions: List[str] = field(repr=False) 28 | simulator_image: Image = field(metadata=AirModel.FIELD_FOREIGN_KEY) 29 | simulator_resources: DataDict = field(repr=False) 30 | emulation_params: DataDict = field(repr=False) 31 | 32 | @classmethod 33 | def get_model_api(cls): 34 | """ 35 | Returns the respective `AirModelAPI` type for this model. 
36 | """ 37 | return ManifestEndpointApi 38 | 39 | 40 | class ManifestEndpointApi( 41 | mixins.ListApiMixin[Manifest], 42 | mixins.CreateApiMixin[Manifest], 43 | mixins.GetApiMixin[Manifest], 44 | mixins.PatchApiMixin[Manifest], 45 | mixins.PutApiMixin[Manifest], 46 | mixins.DeleteApiMixin, 47 | BaseEndpointApi[Manifest], 48 | ): 49 | API_PATH = 'manifests' 50 | model = Manifest 51 | 52 | def create( 53 | self, 54 | artifacts_directory: str, 55 | configure_node_properties: DataDict, 56 | configure_simulator: DataDict, 57 | docker_run_parameters: DataDict, 58 | emulation_type: str, 59 | organization: Union[Organization, PrimaryKey], 60 | simulator_image: Union[Image, PrimaryKey], 61 | simulator_resources: DataDict, 62 | artifacts_directory_max_size_gb: Optional[int] = None, 63 | boot_group: Optional[int] = None, 64 | platform_information: Optional[DataDict] = None, 65 | simulation_engine_versions: Optional[List[str]] = None, 66 | emulation_params: Optional[DataDict] = None, 67 | ) -> Manifest: 68 | payload: DataDict = { 69 | 'artifacts_directory': artifacts_directory, 70 | 'configure_node_properties': configure_node_properties, 71 | 'configure_simulator': configure_simulator, 72 | 'docker_run_parameters': docker_run_parameters, 73 | 'emulation_type': emulation_type, 74 | 'organization': organization, 75 | 'simulator_image': simulator_image, 76 | 'simulator_resources': simulator_resources, 77 | } 78 | for key, value in ( 79 | ('artifacts_directory_max_size_gb', artifacts_directory_max_size_gb), 80 | ('boot_group', boot_group), 81 | ('platform_information', platform_information), 82 | ('simulation_engine_versions', simulation_engine_versions), 83 | ('emulation_params', emulation_params), 84 | ): 85 | if value is not None: # Don't include optional None fields in payload. 
86 | payload[key] = value 87 | return super().create(**payload) 88 | -------------------------------------------------------------------------------- /air_sdk/simulation_interface.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | SimulationInterface module 6 | """ 7 | 8 | from . import util 9 | from .air_model import AirModel 10 | 11 | 12 | class SimulationInterface(AirModel): 13 | """ 14 | Manage a SimulationInterface 15 | 16 | ### json 17 | Returns a JSON string representation of the simulation interface 18 | 19 | ### refresh 20 | Syncs the simulation interface with all values returned by the API 21 | 22 | ### update 23 | Update the simulation interface with the provided data 24 | 25 | Arguments: 26 | kwargs (dict, optional): All optional keyword arguments are applied as key/value 27 | pairs in the request's JSON payload 28 | """ 29 | 30 | _deletable = False 31 | 32 | def __repr__(self): 33 | if self._deleted: 34 | return super().__repr__() 35 | return f'' 36 | 37 | 38 | class SimulationInterfaceApi: 39 | """High-level interface for the SimulationInterface API""" 40 | 41 | def __init__(self, client): 42 | self.client = client 43 | self.url = self.client.api_url + '/simulation-interface/' 44 | 45 | @util.deprecated('SimulationInterfaceApi.list()') 46 | def get_simulation_interfaces(self, simulation_id='', original_id=''): # pylint: disable=missing-function-docstring 47 | return self.list(simulation=simulation_id, original=original_id) 48 | 49 | def get(self, simulation_interface_id, **kwargs): 50 | """ 51 | Get an existing simulation interface 52 | 53 | Arguments: 54 | simulation_interface_id (str): SimulationInterface ID 55 | kwargs (dict, optional): All other optional keyword arguments are applied as query 56 | parameters/filters 57 | 58 | Returns: 59 | 
[`SimulationInterface`](/docs/simulationinterface) 60 | 61 | Raises: 62 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 63 | or valid response JSON 64 | 65 | Example: 66 | ``` 67 | >>> air.simulation_interfaces.get('3dadd54d-583c-432e-9383-a2b0b1d7f551') 68 | 69 | ``` 70 | """ 71 | url = f'{self.url}{simulation_interface_id}/' 72 | res = self.client.get(url, params=kwargs) 73 | util.raise_if_invalid_response(res) 74 | return SimulationInterface(self, **res.json()) 75 | 76 | def list(self, **kwargs): 77 | # pylint: disable=line-too-long 78 | """ 79 | List existing simulation interfaces 80 | 81 | Arguments: 82 | kwargs (dict, optional): All other optional keyword arguments are applied as query 83 | parameters/filters 84 | 85 | Returns: 86 | list 87 | 88 | Raises: 89 | [`AirUnexpectedResposne`](/docs/exceptions) - API did not return a 200 OK 90 | or valid response JSON 91 | 92 | Example: 93 | ``` 94 | >>> air.simulation_interfaces.list() 95 | [, ] 96 | ``` 97 | """ # pylint: enable=line-too-long 98 | if kwargs.get('interface'): 99 | kwargs['original'] = kwargs['interface'] 100 | del kwargs['interface'] 101 | res = self.client.get(f'{self.url}', params=kwargs) 102 | util.raise_if_invalid_response(res, data_type=list) 103 | return [SimulationInterface(self, **simulation_interface) for simulation_interface in res.json()] 104 | -------------------------------------------------------------------------------- /air_sdk/v2/typing.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | import inspect 4 | from typing import Any, List, Literal, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin 5 | import sys 6 | from dataclasses import is_dataclass 7 | from typing import ( 8 | ForwardRef, 9 | ) 10 | 11 | T = TypeVar('T') 12 | 13 | 14 | def union_args_are_optional(args: Tuple[Union[Any, Any], ...]) -> bool: 15 | return len(args) >= 2 and type(None) in args 16 | 17 | 18 | def is_union(type_: Type[Any]) -> bool: 19 | return get_origin(type_) == Union 20 | 21 | 22 | def is_optional_union(type_: Type[Any]) -> bool: 23 | return is_union(type_) and union_args_are_optional(get_args(type_)) 24 | 25 | 26 | def get_optional_arg(optional_type: Type[Optional[T]]) -> Type[T]: 27 | return next(arg for arg in get_args(optional_type) if arg is not type(None)) # type: ignore[no-any-return] 28 | 29 | 30 | def get_list_arg(list_type: Type[List[T]]) -> Type[T]: 31 | return get_args(list_type)[0] # type: ignore[no-any-return] 32 | 33 | 34 | def is_typed_dict(expected_type: Type[Any]) -> bool: 35 | """Determine if the `expected_type` provided is a subclass of TypedDict.""" 36 | return hasattr(expected_type, '__annotations__') and not is_dataclass(expected_type) 37 | 38 | 39 | def type_check_typed_dict(value: Any, expected_type: Type[Any]) -> bool: 40 | """Perform type checking when the expected_type is a subclass of TypedDict. 41 | 42 | This currently does not work if the expected_type is also a dataclass. 
43 | """ 44 | if not isinstance(value, dict): 45 | return False 46 | expected_keys = expected_type.__annotations__.keys() 47 | # Check all keys provided are defined within the expected_type TypedDict 48 | if not all(key in value for key in expected_keys): 49 | return False 50 | # Recursively check each key's value type 51 | return all(type_check(value[key], expected_type.__annotations__[key]) for key in expected_keys) 52 | 53 | 54 | def type_check(value: Any, expected_type: Type[Any]) -> bool: 55 | """Recursively check if the value matches the expected type.""" 56 | from air_sdk.v2.air_model import PrimaryKey # noqa 57 | 58 | if isinstance(expected_type, ForwardRef): 59 | if sys.version_info >= (3, 9): # Python 3.9+ requires a third `recursive_guard` arg 60 | expected_type = expected_type._evaluate(globals(), locals(), frozenset()) 61 | else: 62 | expected_type = expected_type._evaluate(globals(), locals()) 63 | origin = get_origin(expected_type) 64 | args = get_args(expected_type) 65 | 66 | if origin is None: # Base case 67 | if is_typed_dict(expected_type): 68 | return type_check_typed_dict(value, expected_type) 69 | if expected_type == Any: 70 | return True 71 | return isinstance(value, expected_type) 72 | 73 | if origin is Union: 74 | return any(type_check(value, arg) for arg in args) 75 | 76 | if origin is list: 77 | if not isinstance(value, list): 78 | return False 79 | if not args: # We're already a list, so if not args then we're good 80 | return True 81 | return all(type_check(item, args[0]) for item in value) 82 | 83 | if origin is dict: 84 | if not isinstance(value, dict): 85 | return False 86 | if not args: # We're already a dict, so if no args then we're good 87 | return True 88 | key_type, value_type = args 89 | return all(type_check(k, key_type) and type_check(v, value_type) for k, v in value.items()) 90 | 91 | if origin is Literal: 92 | return any(value == arg for arg in args) 93 | 94 | if inspect.isclass(origin): 95 | return isinstance(value, 
origin) 96 | 97 | return False 98 | -------------------------------------------------------------------------------- /tests/test_permission.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for permission.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import permission 13 | 14 | 15 | class TestPermission(TestCase): 16 | def setUp(self): 17 | self.model = permission.Permission(MagicMock()) 18 | self.model.id = 'abc123' 19 | 20 | def test_init_(self): 21 | self.assertTrue(self.model._deletable) 22 | self.assertFalse(self.model._updatable) 23 | 24 | def test_repr(self): 25 | self.assertEqual(str(self.model), f'') 26 | 27 | def test_repr_deleted(self): 28 | self.model._deleted = True 29 | self.assertTrue('Deleted Object' in str(self.model)) 30 | 31 | 32 | class TestPermissionApi(TestCase): 33 | def setUp(self): 34 | self.client = MagicMock() 35 | self.client.api_url = 'http://testserver/api' 36 | self.api = permission.PermissionApi(self.client) 37 | 38 | def test_init_(self): 39 | self.assertEqual(self.api.client, self.client) 40 | self.assertEqual(self.api.url, 'http://testserver/api/permission/') 41 | 42 | @patch('air_sdk.permission.PermissionApi.create') 43 | def test_create_permission(self, mock_create): 44 | res = self.api.create_permission('me@test.com', foo='bar') 45 | mock_create.assert_called_with(email='me@test.com', foo='bar') 46 | self.assertEqual(res, mock_create.return_value) 47 | 48 | @patch('air_sdk.util.raise_if_invalid_response') 49 | def test_get(self, mock_raise): 50 | self.client.get.return_value.json.return_value = {'test': 'success'} 51 | res = self.api.get('abc123', foo='bar') 52 | 
self.client.get.assert_called_with(f'{self.client.api_url}/permission/abc123/', params={'foo': 'bar'}) 53 | mock_raise.assert_called_with(self.client.get.return_value) 54 | self.assertIsInstance(res, permission.Permission) 55 | self.assertEqual(res.test, 'success') 56 | 57 | @patch('air_sdk.util.raise_if_invalid_response') 58 | def test_list(self, mock_raise): 59 | self.client.get.return_value.json.return_value = [{'id': 'abc'}, {'id': 'xyz'}] 60 | res = self.api.list(foo='bar') 61 | self.client.get.assert_called_with(f'{self.client.api_url}/permission/', params={'foo': 'bar'}) 62 | mock_raise.assert_called_with(self.client.get.return_value, data_type=list) 63 | self.assertEqual(len(res), 2) 64 | self.assertIsInstance(res[0], permission.Permission) 65 | self.assertEqual(res[0].id, 'abc') 66 | self.assertEqual(res[1].id, 'xyz') 67 | 68 | @patch('air_sdk.util.raise_if_invalid_response') 69 | def test_create(self, mock_raise): 70 | self.client.post.return_value.json.return_value = {'id': 'abc'} 71 | res = self.api.create(simulation='abc123', email='me@test.com') 72 | self.client.post.assert_called_with( 73 | f'{self.client.api_url}/permission/', json={'simulation': 'abc123', 'email': 'me@test.com'} 74 | ) 75 | mock_raise.assert_called_with(self.client.post.return_value, status_code=201) 76 | self.assertIsInstance(res, permission.Permission) 77 | self.assertEqual(res.id, 'abc') 78 | 79 | def test_create_required_kwargs(self): 80 | with self.assertRaises(AttributeError) as err: 81 | self.api.create(simulation='abc123') 82 | self.assertTrue('requires email' in str(err.exception)) 83 | with self.assertRaises(AttributeError) as err: 84 | self.api.create(email='me@test.com') 85 | msg = "requires one of the following: ('topology', 'simulation', 'subject_id')" 86 | self.assertTrue(msg in str(err.exception)) 87 | -------------------------------------------------------------------------------- /tests/test_user_preference.py: 
-------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for user_preference.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring,protected-access,unused-argument 9 | import json 10 | from unittest import TestCase 11 | from unittest.mock import MagicMock, patch 12 | 13 | from air_sdk import simulation, user_preference 14 | 15 | 16 | class TestUserPreference(TestCase): 17 | def setUp(self): 18 | self.account_api = MagicMock() 19 | self.account_api.url = 'http://testserver/v1/account/' 20 | self.account_prefs = user_preference.UserPreference(self.account_api, preferences={'foo': 'bar'}) 21 | self.sim_api = MagicMock() 22 | self.sim_api.url = 'http://testserver/v1/simulation/' 23 | self.sim = simulation.Simulation(self.sim_api) 24 | self.sim.id = 'abc123' 25 | self.sim_prefs = user_preference.UserPreference( 26 | self.sim_api, _model=self.sim, preferences={'foo': 'bar'} 27 | ) 28 | 29 | def test_init(self): 30 | self.assertFalse(self.account_prefs._deletable) 31 | self.assertTrue(self.account_prefs._updatable) 32 | 33 | def test_repr(self): 34 | self.assertEqual(repr(self.account_prefs), json.dumps(self.account_prefs.preferences)) 35 | 36 | @patch('air_sdk.user_preference.util.raise_if_invalid_response') 37 | def test_setattr(self, mock_raise): 38 | self.account_prefs.test = True 39 | self.account_api.client.patch.assert_called_once_with(self.account_prefs._url, json={'test': True}) 40 | mock_raise.assert_called_once_with(self.account_api.client.patch.return_value) 41 | self.assertTrue(self.account_prefs.preferences['test']) 42 | 43 | @patch('air_sdk.user_preference.util.raise_if_invalid_response') 44 | @patch('air_sdk.user_preference.UserPreference._build_url', return_value=None) 45 | def test_setattr_super(self, *args): 46 | 
self.account_prefs._url = None 47 | pref = user_preference.UserPreference(self.sim_api, preferences={}) 48 | 49 | pref.test = True 50 | self.sim_api.client.patch.assert_not_called() 51 | 52 | def test_build_url(self): 53 | self.assertEqual(self.account_prefs._build_url(), f'{self.account_api.url}preferences/') 54 | 55 | def test_build_url_model(self): 56 | self.assertEqual(self.sim_prefs._build_url(), f'{self.sim_api.url}{self.sim.id}/preferences/') 57 | 58 | def test_build_url_version_override(self): 59 | pref = user_preference.UserPreference(self.sim_api, _model=self.sim, _version_override='2') 60 | 61 | self.assertEqual( 62 | pref._build_url(), f'{self.sim_api.url.replace("v1", "v2")}{self.sim.id}/preferences/' 63 | ) 64 | 65 | @patch('air_sdk.user_preference.UserPreference._load') 66 | def test_refresh(self, mock_load): 67 | self.account_prefs.refresh() 68 | mock_load.assert_called_once_with(**self.account_api.preferences().__dict__) 69 | 70 | @patch('air_sdk.simulation.Simulation.preferences') 71 | @patch('air_sdk.user_preference.UserPreference._load') 72 | def test_refresh_model(self, mock_load, *args): 73 | self.sim_prefs.refresh() 74 | mock_load.assert_called_once_with(**self.sim.preferences().__dict__) 75 | 76 | @patch('air_sdk.user_preference.util.raise_if_invalid_response') 77 | @patch('air_sdk.user_preference.UserPreference.refresh') 78 | def test_update(self, mock_refresh, mock_raise): 79 | self.account_prefs.update(test=True) 80 | mock_refresh.assert_called_once() 81 | self.assertTrue(self.account_prefs.preferences['test']) 82 | self.account_api.client.put.assert_called_once_with( 83 | self.account_prefs._url, json=self.account_prefs.__dict__ 84 | ) 85 | mock_raise.assert_called_once_with(self.account_api.client.put.return_value) 86 | -------------------------------------------------------------------------------- /air_sdk/token.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright 
(c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Token module 6 | """ 7 | 8 | from . import util 9 | from .air_model import AirModel 10 | 11 | 12 | class Token(AirModel): 13 | """ 14 | View an API Token 15 | 16 | ### json 17 | Returns a JSON string representation of the token 18 | 19 | ### refresh 20 | Syncs the token with all values returned by the API 21 | """ 22 | 23 | # _deletable = False 24 | _updatable = False 25 | 26 | def __repr__(self): 27 | if self._deleted or not self.name: 28 | return super().__repr__() 29 | if hasattr(self, 'id'): 30 | return f'' 31 | if hasattr(self, 'token'): 32 | return f'' 33 | return f'' 34 | 35 | 36 | class TokenApi: 37 | """High-level interface for the Token API""" 38 | 39 | def __init__(self, client): 40 | self.client = client 41 | self.url = self.client.api_url + '/api-token/' 42 | 43 | def delete(self, token_id, **kwargs): 44 | """ 45 | Deletes an api token 46 | 47 | Arguments: 48 | token_id (str): Token ID 49 | kwargs (dict, optional): All other optional keyword arguments are applied as query 50 | parameters/filters 51 | 52 | 53 | Raises: 54 | [`AirUnexpectedResponse`](/docs/exceptions) - API did not return a 204 No Content 55 | or valid response JSON 56 | 57 | Example: 58 | ``` 59 | >>> air.api_tokens.delete('3dadd54d-583c-432e-9383-a2b0b1d7f551') 60 | ``` 61 | """ 62 | url = f'{self.url}{token_id}/' 63 | res = self.client.delete(url, params=kwargs) 64 | util.raise_if_invalid_response(res, status_code=204, data_type=None) 65 | 66 | def list(self, **kwargs): 67 | # pylint: disable=line-too-long 68 | """ 69 | List existing tokens 70 | 71 | Arguments: 72 | kwargs (dict, optional): All other optional keyword arguments are applied as query 73 | parameters/filters 74 | 75 | Returns: 76 | list 77 | 78 | Raises: 79 | [`AirUnexpectedResponse`](/docs/exceptions) - API did not return a 200 OK 80 | or valid response JSON 81 | 82 | Example: 83 | ``` 84 | >>>
air.api_tokens.list() 85 | [, ] 86 | ``` 87 | """ 88 | # pylint: enable=line-too-long 89 | res = self.client.get(f'{self.url}', params=kwargs) 90 | util.raise_if_invalid_response(res, data_type=list) 91 | return [Token(self, **token) for token in res.json()] 92 | 93 | @util.required_kwargs(['name']) 94 | def create(self, **kwargs): 95 | """ 96 | Add a new api token to your account 97 | 98 | Arguments: 99 | name (str): Descriptive name for the api token 100 | kwargs (dict, optional): All other optional keyword arguments are applied as key/value 101 | pairs in the request's JSON payload 102 | 103 | Returns: 104 | [`api-token`](/docs/api-token) 105 | 106 | Raises: 107 | [`AirUnexpectedResponse`](/docs/exceptions) - API did not return a 201 Created 108 | or valid response JSON 109 | 110 | Example: 111 | ``` 112 | >>> air.api_tokens.create(name='my_api_token') 113 | 114 | ``` 115 | """ 116 | res = self.client.post(self.url, json=kwargs) 117 | util.raise_if_invalid_response(res, status_code=201) 118 | return Token(self, **res.json()) 119 | -------------------------------------------------------------------------------- /tests/test_fleet.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for fleet.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import fleet, organization 13 | 14 | 15 | class TestFleet(TestCase): 16 | def setUp(self): 17 | self.api = MagicMock() 18 | self.model = fleet.Fleet(self.api) 19 | self.model.name = 'test.test' 20 | self.model.id = 'abc123' 21 | 22 | def test_init_(self): 23 | self.assertTrue(self.model._deletable) 24 | self.assertTrue(self.model._updatable) 25 | 26 | def test_repr(self): 27 | self.assertEqual(str(self.model), f'') 28 | 29 | def test_repr_deleted(self): 30 | self.model._deleted = True 31 | self.assertTrue('Deleted Object' in str(self.model)) 32 | 33 | 34 | class TestFleetApi(TestCase): 35 | def setUp(self): 36 | self.client = MagicMock() 37 | self.mock_api = MagicMock() 38 | self.client.api_url = 'http://testserver/api' 39 | self.api = fleet.FleetApi(self.client) 40 | self.org = organization.Organization(self.mock_api, id='xyz456', name='NVIDIA') 41 | 42 | def test_init_(self): 43 | self.assertEqual(self.api.client, self.client) 44 | self.assertEqual(self.api.url, 'http://testserver/api/fleet/') 45 | 46 | @patch('air_sdk.util.raise_if_invalid_response') 47 | def test_get(self, mock_raise): 48 | self.client.get.return_value.json.return_value = {'test': 'success'} 49 | res = self.api.get('abc123', foo='bar') 50 | self.client.get.assert_called_with(f'{self.client.api_url}/fleet/abc123/', params={'foo': 'bar'}) 51 | mock_raise.assert_called_with(self.client.get.return_value) 52 | self.assertIsInstance(res, fleet.Fleet) 53 | self.assertEqual(res.test, 'success') 54 | 55 | @patch('air_sdk.util.raise_if_invalid_response') 56 | def test_list(self, mock_raise): 57 | self.client.get.return_value.json.return_value = { 58 | 'count': 2, 59 | 'next': 'blabla', 60 | 'previous': None, 61 | 'results': [{'id': 'abc'}, {'id': 
'xyz'}], 62 | } 63 | res = self.api.list(foo='bar') 64 | self.client.get.assert_called_with(f'{self.client.api_url}/fleet/', params={'foo': 'bar'}) 65 | mock_raise.assert_called_with(self.client.get.return_value, data_type=dict) 66 | self.assertEqual(len(res), 2) 67 | self.assertIsInstance(res[0], fleet.Fleet) 68 | self.assertEqual(res[0].id, 'abc') 69 | self.assertEqual(res[1].id, 'xyz') 70 | 71 | @patch('air_sdk.util.raise_if_invalid_response') 72 | def test_create(self, mock_raise): 73 | self.client.post.return_value.json.return_value = {'id': 'abc'} 74 | res = self.api.create( 75 | name='test_fleet_2', prefix_length=65, organization=str(self.org.id), port_range=22 76 | ) 77 | self.client.post.assert_called_with( 78 | f'{self.client.api_url}/fleet/', 79 | json={ 80 | 'name': 'test_fleet_2', 81 | 'prefix_length': 65, 82 | 'organization': str(self.org.id), 83 | 'port_range': 22, 84 | }, 85 | ) 86 | mock_raise.assert_called_with(self.client.post.return_value, status_code=201) 87 | self.assertIsInstance(res, fleet.Fleet) 88 | self.assertEqual(res.id, 'abc') 89 | 90 | def test_create_required_kwargs(self): 91 | with self.assertRaises(AttributeError) as err: 92 | self.api.create(prefix_length=65, organization=str(self.org.id), port_range=22) 93 | self.assertTrue('requires name' in str(err.exception)) 94 | with self.assertRaises(AttributeError) as err: 95 | self.api.create(name='test_fleet_2', prefix_length=65, port_range=22) 96 | self.assertTrue('requires organization' in str(err.exception)) 97 | -------------------------------------------------------------------------------- /tests/test_node.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Tests for node.py 6 | """ 7 | 8 | # pylint: disable=missing-function-docstring,missing-class-docstring,unused-argument 9 | from unittest import TestCase 10 | from unittest.mock import MagicMock, patch 11 | 12 | from air_sdk import node 13 | 14 | 15 | class TestNode(TestCase): 16 | def setUp(self): 17 | self.model = node.Node(MagicMock()) 18 | self.model.id = 'abc123' 19 | self.model.name = 'server' 20 | 21 | def test_init_(self): 22 | self.assertTrue(self.model._deletable) 23 | self.assertTrue(self.model._updatable) 24 | self.assertListEqual(self.model._ignored_update_fields, ['interfaces']) 25 | 26 | def test_repr(self): 27 | self.assertEqual(str(self.model), f'') 28 | 29 | def test_repr_deleted(self): 30 | self.model._deleted = True 31 | self.assertTrue('Deleted Object' in str(self.model)) 32 | 33 | 34 | class TestNodeApi(TestCase): 35 | def setUp(self): 36 | self.client = MagicMock() 37 | self.client.api_url = 'http://testserver/api' 38 | self.api = node.NodeApi(self.client) 39 | 40 | def test_init_(self): 41 | self.assertEqual(self.api.client, self.client) 42 | self.assertEqual(self.api.url, 'http://testserver/api/node/') 43 | 44 | @patch('air_sdk.node.NodeApi.list') 45 | def test_get_nodes(self, mock_list): 46 | res = self.api.get_nodes(simulation_id='foo') 47 | mock_list.assert_called_with(simulation='foo') 48 | self.assertEqual(res, mock_list.return_value) 49 | 50 | @patch('air_sdk.util.raise_if_invalid_response') 51 | def test_get(self, mock_raise): 52 | self.client.get.return_value.json.return_value = {'test': 'success'} 53 | res = self.api.get('abc123', foo='bar') 54 | self.client.get.assert_called_with(f'{self.client.api_url}/node/abc123/', params={'foo': 'bar'}) 55 | mock_raise.assert_called_with(self.client.get.return_value) 56 | self.assertIsInstance(res, node.Node) 57 | self.assertEqual(res.test, 'success') 58 | 59 | @patch('air_sdk.util.raise_if_invalid_response') 60 | def 
test_get_simulation_id(self, mock_raise): 61 | self.client.get.return_value.json.return_value = {'test': 'success'} 62 | self.api.get('abc123', simulation_id='xyz123') 63 | self.client.get.assert_called_with( 64 | f'{self.client.api_url}/node/abc123/', params={'simulation': 'xyz123'} 65 | ) 66 | 67 | @patch('air_sdk.util.raise_if_invalid_response') 68 | def test_list(self, mock_raise): 69 | self.client.get.return_value.json.return_value = [{'id': 'abc'}, {'id': 'xyz'}] 70 | res = self.api.list(foo='bar') 71 | self.client.get.assert_called_with(f'{self.client.api_url}/node/', params={'foo': 'bar'}) 72 | mock_raise.assert_called_with(self.client.get.return_value, data_type=list) 73 | self.assertEqual(len(res), 2) 74 | self.assertIsInstance(res[0], node.Node) 75 | self.assertEqual(res[0].id, 'abc') 76 | self.assertEqual(res[1].id, 'xyz') 77 | 78 | @patch('air_sdk.util.raise_if_invalid_response') 79 | def test_create(self, mock_raise): 80 | self.client.post.return_value.json.return_value = {'id': 'abc'} 81 | res = self.api.create(topology='abc123', name='test') 82 | self.client.post.assert_called_with( 83 | f'{self.client.api_url}/node/', json={'topology': 'abc123', 'name': 'test'} 84 | ) 85 | mock_raise.assert_called_with(self.client.post.return_value, status_code=201) 86 | self.assertIsInstance(res, node.Node) 87 | self.assertEqual(res.id, 'abc') 88 | 89 | def test_create_required_kwargs(self): 90 | with self.assertRaises(AttributeError) as err: 91 | self.api.create(name='test') 92 | self.assertTrue('requires topology' in str(err.exception)) 93 | with self.assertRaises(AttributeError) as err: 94 | self.api.create(topology='abc123') 95 | self.assertTrue('requires name' in str(err.exception)) 96 | -------------------------------------------------------------------------------- /air_sdk/link.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | Link module 6 | """ 7 | 8 | from . import util 9 | from .air_model import AirModel 10 | 11 | 12 | class Link(AirModel): 13 | """ 14 | Manage a Link 15 | 16 | ### delete 17 | Delete the link. Once successful, the object should no longer be used and will raise 18 | [`AirObjectDeleted`](/docs/exceptions) when referenced. 19 | 20 | Raises: 21 | [`AirUnexpectedResponse`](/docs/exceptions) - Delete failed 22 | 23 | ### json 24 | Returns a JSON string representation of the link 25 | 26 | ### refresh 27 | Syncs the link with all values returned by the API 28 | 29 | ### update 30 | Update the link with the provided data 31 | 32 | Arguments: 33 | kwargs (dict, optional): All optional keyword arguments are applied as key/value 34 | pairs in the request's JSON payload 35 | """ 36 | 37 | def __repr__(self): 38 | if self._deleted: 39 | return super().__repr__() 40 | return f'' 41 | 42 | 43 | class LinkApi: 44 | """High-level interface for the Link API""" 45 | 46 | def __init__(self, client): 47 | self.client = client 48 | self.url = self.client.api_url + '/link/' 49 | 50 | def get(self, link_id, **kwargs): 51 | """ 52 | Get an existing link 53 | 54 | Arguments: 55 | link_id (str): Link ID 56 | kwargs (dict, optional): All other optional keyword arguments are applied as query 57 | parameters/filters 58 | 59 | Returns: 60 | [`Link`](/docs/link) 61 | 62 | Raises: 63 | [`AirUnexpectedResponse`](/docs/exceptions) - API did not return a 200 OK 64 | or valid response JSON 65 | 66 | Example: 67 | ``` 68 | >>> air.links.get('3dadd54d-583c-432e-9383-a2b0b1d7f551') 69 | 70 | ``` 71 | """ 72 | url = f'{self.url}{link_id}/' 73 | res = self.client.get(url, params=kwargs) 74 | util.raise_if_invalid_response(res) 75 | return Link(self, **res.json()) 76 | 77 | def list(self, **kwargs): 78 | """ 79 | List existing links 80 | 81 | Arguments: 82 | kwargs (dict, optional): All other optional keyword arguments are applied as
query 83 | parameters/filters 84 | 85 | Returns: 86 | list 87 | 88 | Raises: 89 | [`AirUnexpectedResponse`](/docs/exceptions) - API did not return a 200 OK 90 | or valid response JSON 91 | 92 | Example: 93 | ``` 94 | >>> air.links.list() 95 | [, ] 96 | ``` 97 | """ 98 | res = self.client.get(f'{self.url}', params=kwargs) 99 | util.raise_if_invalid_response(res, data_type=list) 100 | return [Link(self, **link) for link in res.json()] 101 | 102 | @util.required_kwargs(['topology', 'interfaces']) 103 | def create(self, **kwargs): 104 | # pylint: disable=line-too-long 105 | """ 106 | Create a new link 107 | 108 | Arguments: 109 | topology (str | `Topology`): `Topology` or ID 110 | interfaces (list): List of `Interface` objects or IDs 111 | kwargs (dict, optional): All other optional keyword arguments are applied as key/value 112 | pairs in the request's JSON payload 113 | 114 | Returns: 115 | [`Link`](/docs/link) 116 | 117 | Raises: 118 | [`AirUnexpectedResponse`](/docs/exceptions) - API did not return a 201 Created 119 | or valid response JSON 120 | 121 | Example: 122 | ``` 123 | >>> air.links.create(topology=topology, interfaces=[intf1, 'fd61e3d8-af2f-4735-8b1d-356ee6bf4abe']) 124 | 125 | ``` 126 | """ # pylint: enable=line-too-long 127 | res = self.client.post(self.url, json=kwargs) 128 | util.raise_if_invalid_response(res, status_code=201) 129 | return Link(self, **res.json()) 130 | -------------------------------------------------------------------------------- /air_sdk/v2/endpoints/mixins.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: MIT 3 | from __future__ import annotations 4 | 5 | import json 6 | from http import HTTPStatus 7 | from typing import Optional, Any, TypedDict, List, Callable, Iterator, TYPE_CHECKING, Dict, Generic 8 | 9 | from air_sdk.v2.air_json_encoder import AirJSONEncoder 10 | from air_sdk.v2.air_model import PrimaryKey, DataDict, TAirModel_co 11 | from air_sdk.v2.utils import join_urls 12 | from air_sdk.util import raise_if_invalid_response 13 | 14 | if TYPE_CHECKING: 15 | from air_sdk.v2 import AirApi 16 | 17 | 18 | def serialize_payload(data: Dict[str, Any] | List[Dict[str, Any]]) -> str: 19 | """Serialize the dictionary of values into json using the AirJSONEncoder.""" 20 | return json.dumps(data, indent=None, separators=(',', ':'), cls=AirJSONEncoder) 21 | 22 | 23 | class BaseApiMixin: 24 | """A base class for API Mixins. 25 | 26 | This is primarily used for type hinting. 27 | """ 28 | 29 | __api__: AirApi 30 | url: str 31 | load_model: Callable[[DataDict], TAirModel_co] 32 | 33 | 34 | class PaginatedResponseData(TypedDict): 35 | count: int 36 | next: Optional[str] 37 | previous: Optional[str] 38 | results: List[DataDict] 39 | 40 | 41 | class ListApiMixin(BaseApiMixin, Generic[TAirModel_co]): 42 | """Returns an iterable of model objects. 43 | 44 | Handles pagination in the background. 
45 | """ 46 | 47 | def list(self, **params: Any) -> Iterator[TAirModel_co]: 48 | """Return an iterator of model instances.""" 49 | url = self.url 50 | # Set up pagination 51 | next_url = None 52 | params.setdefault('limit', self.__api__.client.pagination_page_size) 53 | params = json.loads(serialize_payload(params)) # Accounts for UUIDs and AirModel params 54 | while url or next_url: 55 | if isinstance(next_url, str): 56 | response = self.__api__.client.get(next_url) 57 | else: 58 | response = self.__api__.client.get(url, params=params) 59 | raise_if_invalid_response(response) 60 | paginated_response_data: PaginatedResponseData = response.json() 61 | url = None # type: ignore[assignment] 62 | next_url = paginated_response_data['next'] 63 | for obj_data in paginated_response_data['results']: 64 | yield self.load_model(obj_data) 65 | 66 | 67 | class CreateApiMixin(BaseApiMixin, Generic[TAirModel_co]): 68 | def create(self, *args: Any, **kwargs: Any) -> TAirModel_co: 69 | response = self.__api__.client.post(self.url, data=serialize_payload(kwargs)) 70 | raise_if_invalid_response(response, status_code=HTTPStatus.CREATED) 71 | return self.load_model(response.json()) 72 | 73 | 74 | class GetApiMixin(BaseApiMixin, Generic[TAirModel_co]): 75 | def get(self, pk: PrimaryKey, **params: Any) -> TAirModel_co: 76 | detail_url = join_urls(self.url, str(pk)) 77 | response = self.__api__.client.get(detail_url, params=params) 78 | raise_if_invalid_response(response) 79 | return self.load_model(response.json()) 80 | 81 | 82 | class PutApiMixin(BaseApiMixin, Generic[TAirModel_co]): 83 | def put(self, pk: PrimaryKey, **kwargs: Any) -> TAirModel_co: 84 | response = self.__api__.client.put(join_urls(self.url, str(pk)), data=serialize_payload(kwargs)) 85 | raise_if_invalid_response(response, status_code=HTTPStatus.OK) 86 | return self.load_model(response.json()) 87 | 88 | 89 | class PatchApiMixin(BaseApiMixin, Generic[TAirModel_co]): 90 | def patch(self, pk: PrimaryKey, **kwargs: Any) -> 
TAirModel_co: 91 | response = self.__api__.client.patch(join_urls(self.url, str(pk)), data=serialize_payload(kwargs)) 92 | raise_if_invalid_response(response, status_code=HTTPStatus.OK) 93 | return self.load_model(response.json()) 94 | 95 | 96 | class DeleteApiMixin(BaseApiMixin): 97 | def delete(self, pk: PrimaryKey, **kwargs: Any) -> None: 98 | """Deletes the instances with the specified primary key.""" 99 | detail_url = join_urls(self.url, str(pk)) 100 | response = self.__api__.client.delete(detail_url, json=kwargs) 101 | raise_if_invalid_response(response, status_code=HTTPStatus.NO_CONTENT, data_type=None) 102 | -------------------------------------------------------------------------------- /air_sdk/userconfig.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | """ 5 | UserConfig model related module. 6 | """ 7 | 8 | import json 9 | import os 10 | from io import TextIOBase 11 | from pathlib import Path 12 | from typing import Dict, Optional, TextIO, Union 13 | 14 | from .air_model import AirModel, AirModelAPI 15 | from .exceptions import AirObjectDeleted 16 | from .organization import Organization 17 | 18 | 19 | class UserConfig(AirModel): 20 | """ 21 | Manage a UserConfig. 22 | 23 | ### delete 24 | Delete the UserConfig. Once successful, the object should no longer be used and will raise 25 | `AirDeletedObject` when referenced. 26 | 27 | Raises: 28 | `AirUnexpectedResponse` - Delete failed 29 | 30 | ### json 31 | Returns a JSON string representation of the UserConfig. 
32 | 33 | ### refresh 34 | Syncs the UserConfig with all values returned by the API 35 | 36 | ### update 37 | Update the UserConfig with the provided data 38 | 39 | Arguments: 40 | kwargs (dict, optional): All optional keyword arguments are applied as key/value 41 | pairs in the request's JSON payload 42 | """ 43 | 44 | KIND_CLOUD_INIT_USER_DATA = 'cloud-init-user-data' 45 | KIND_CLOUD_INIT_META_DATA = 'cloud-init-meta-data' 46 | VALID_KINDS = (KIND_CLOUD_INIT_USER_DATA, KIND_CLOUD_INIT_META_DATA) 47 | 48 | def __repr__(self): 49 | try: 50 | self_dict: Dict = json.loads(self.json()) 51 | if self_dict.get('name', False) and self_dict.get('kind', False): 52 | return f'' 53 | except AirObjectDeleted: 54 | pass 55 | 56 | return super().__repr__() 57 | 58 | 59 | class UserConfigAPI(AirModelAPI[UserConfig]): 60 | """High-level interface for the UserConfig API.""" 61 | 62 | API_VERSION = 2 63 | API_PATH = 'userconfigs' 64 | 65 | def create( 66 | self, 67 | name: str, 68 | kind: str, 69 | organization: Optional[Union[Organization, str]], 70 | content: Union[str, Path, TextIO], 71 | ) -> UserConfig: 72 | """ 73 | Create a new UserConfig. Content data can be provided as a plain string, path to an existing file or an open file handle. 
74 | Keep in mind that: 75 | - When passing a file path, it will be opened for reading using default encoding 76 | - When passing a file handle, it is assumed to be opened using a proper encoding and will be read from as-is 77 | - Due to a small size of UserConfig scripts, content will be loaded into memory in its entirety 78 | 79 | Arguments: 80 | name: UserConfig name 81 | kind: UserConfig kind, must be one of `UserConfig.VALID_KINDS` 82 | organization: Organization instance / ID to create the UserConfig in 83 | content: UserConfig data 84 | kwargs (dict, optional): All other optional keyword arguments are applied as key/value 85 | pairs in the request's JSON payload 86 | 87 | Raises: 88 | `AirUnexpectedResponse` - API did not return a 200 OK 89 | or valid response JSON 90 | `AttributeError` - provided content object is not one of the allowed types 91 | `FileNotFoundError` - when providing a path to a content file and the file is not present 92 | 93 | Example: 94 | ``` 95 | >>> air.user_configs.create(name='my-config', kind=air.user_configs.model.KIND_CLOUD_INIT_USER_DATA, organization=my_org, content="my-content") 96 | 97 | """ 98 | 99 | if isinstance(content, str): 100 | if os.path.exists(content): 101 | with open(content, 'r') as content_file: 102 | parsed_content = content_file.read() 103 | else: 104 | parsed_content = content 105 | elif isinstance(content, Path): 106 | with content.open('r') as content_file: 107 | parsed_content = content_file.read() 108 | elif isinstance(content, TextIOBase): 109 | parsed_content = content.read() 110 | else: 111 | raise AttributeError(f'Unexpected content type provided: `{type(content)}`') 112 | 113 | return super().create( 114 | name=name, 115 | kind=kind, 116 | organization=organization, 117 | content=parsed_content, 118 | ) 119 | -------------------------------------------------------------------------------- /tests/tests_v2/test_endpoints/test_user_configs.py: 
-------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: MIT 3 | 4 | import pytest 5 | import faker 6 | 7 | faker.Faker.seed(0) 8 | fake = faker.Faker() 9 | 10 | 11 | class TestUserConfigEndpointApi: 12 | def test_list(self, api, run_list_test, user_config_factory): 13 | run_list_test(api.user_configs, user_config_factory) 14 | 15 | def test_refresh(self, api, run_refresh_test, user_config_factory): 16 | run_refresh_test(api.user_configs, user_config_factory) 17 | 18 | def test_delete(self, api, run_delete_test, user_config_factory): 19 | run_delete_test(api.user_configs, user_config_factory) 20 | 21 | @pytest.mark.parametrize( 22 | 'payload,is_valid', 23 | ( 24 | ({}, False), 25 | ({'name': None, 'kind': None, 'content': None}, False), 26 | ({'name': fake.slug(), 'kind': fake.slug(), 'content': fake.binary()}, False), 27 | ({'name': fake.slug(), 'kind': fake.slug(), 'content': fake.slug()}, True), 28 | ( 29 | {'name': fake.slug(), 'kind': fake.slug(), 'content': fake.slug(), 'owner': fake.pyint()}, 30 | False, 31 | ), 32 | ( 33 | { 34 | 'name': fake.slug(), 35 | 'kind': fake.slug(), 36 | 'content': fake.slug(), 37 | }, 38 | True, 39 | ), 40 | ), 41 | ) 42 | def test_create(self, api, user_config_factory, run_create_test_case, payload, is_valid): 43 | """This tests that the data provided is properly validated and used.""" 44 | run_create_test_case(api.user_configs, user_config_factory, payload, is_valid) 45 | 46 | def test_update( 47 | self, 48 | api, 49 | run_update_patch_test, 50 | user_config_factory, 51 | ): 52 | cases = ( 53 | ({}, True), 54 | ({'name': None}, True), 55 | ({'name': fake.slug()}, True), 56 | ({'name': fake.pybool()}, False), 57 | ({'content': None}, True), 58 | ({'content': fake.text()}, True), 59 | ({'content': fake.pyint()}, False), 60 | ({'name': fake.slug(), 'content': 
fake.slug()}, True), 61 | ({'fake_field': None}, False), 62 | ) 63 | for payload, is_valid in cases: 64 | run_update_patch_test(api.user_configs, user_config_factory, payload, is_valid) 65 | 66 | def test_full_update( 67 | self, 68 | api, 69 | run_full_update_put_test, 70 | user_config_factory, 71 | organization_factory, 72 | ): 73 | cases = ( 74 | ({}, False), 75 | ({'name': None, 'content': None}, False), 76 | ({'name': None, 'content': fake.text()}, False), 77 | ({'name': fake.slug(), 'content': None}, False), 78 | ({'name': fake.slug(), 'content': fake.text()}, True), 79 | ({'name': fake.slug(), 'content': fake.text(), 'unexpected_field': fake.slug()}, False), 80 | ) 81 | for payload, is_valid in cases: 82 | run_full_update_put_test(api.user_configs, user_config_factory, payload, is_valid) 83 | 84 | 85 | class TestUserConfigModelRelations: 86 | def test_owner_access(self, api, user_config_factory): 87 | user_config = user_config_factory(api) 88 | owner = user_config.owner 89 | assert owner.__fk_resolved__ is False 90 | assert owner.id is not None 91 | assert owner.__fk_resolved__ is True 92 | with pytest.raises(NotImplementedError): 93 | owner.refresh() 94 | 95 | def test_owner_budget_access(self, api, user_config_factory): 96 | user_config = user_config_factory(api) 97 | owner_budget = user_config.owner_budget 98 | assert owner_budget.__fk_resolved__ is False 99 | assert owner_budget.id is not None 100 | assert owner_budget.__fk_resolved__ is True 101 | with pytest.raises(NotImplementedError): 102 | owner_budget.refresh() 103 | 104 | def test_organization_budget_access(self, api, user_config_factory): 105 | user_config = user_config_factory(api) 106 | organization_budget = user_config.organization_budget 107 | assert organization_budget.__fk_resolved__ is False 108 | assert organization_budget.id is not None 109 | assert organization_budget.__fk_resolved__ is True 110 | with pytest.raises(NotImplementedError): 111 | organization_budget.refresh() 112 | 
# --------------------------------------------------------------------------------
# /air_sdk/node.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

"""
Node module
"""

from . import util
from .air_model import AirModel


class Node(AirModel):
    """
    Manage a Node

    ### delete
    Delete the node. Once successful, the object should no longer be used and will raise
    [`AirDeletedObject`](/docs/exceptions) when referenced.

    Raises:
    [`AirUnexpectedResponse`](/docs/exceptions) - Delete failed

    ### json
    Returns a JSON string representation of the node

    ### refresh
    Syncs the node with all values returned by the API

    ### update
    Update the node with the provided data

    Arguments:
        kwargs (dict, optional): All optional keyword arguments are applied as key/value
            pairs in the request's JSON payload
    """

    # NOTE(review): presumably consumed by AirModel to skip these fields when building
    # update payloads -- confirm against AirModel.
    _ignored_update_fields = ['interfaces']

    def __repr__(self):
        # Deleted or unnamed nodes fall back to the generic AirModel representation.
        if self._deleted or not self.name:
            return super().__repr__()
        return f'<Node {self.name} {self.id}>'


class NodeApi:
    """High-level interface for the Node API"""

    def __init__(self, client):
        self.client = client
        self.url = self.client.api_url + '/node/'

    @util.deprecated('NodeApi.list()')
    def get_nodes(self, simulation_id=''):  # pylint: disable=missing-function-docstring
        return self.list(simulation=simulation_id)

    def get(self, node_id, **kwargs):
        """
        Get an existing node

        Arguments:
            node_id (str): Node ID
            simulation_id (str, optional): Alias for the `simulation` filter; when truthy it is
                sent to the API as `simulation`
            kwargs (dict, optional): All other optional keyword arguments are applied as query
                parameters/filters

        Returns:
            [`Node`](/docs/node)

        Raises:
            [`AirUnexpectedResponse`](/docs/exceptions) - API did not return a 200 OK
                or valid response JSON

        Example:
        ```
        >>> air.nodes.get('3dadd54d-583c-432e-9383-a2b0b1d7f551')
        <Node server 3dadd54d-583c-432e-9383-a2b0b1d7f551>
        ```
        """
        # Translate the legacy `simulation_id` keyword into the `simulation` query filter.
        if kwargs.get('simulation_id'):
            kwargs['simulation'] = kwargs.pop('simulation_id')
        url = f'{self.url}{node_id}/'
        res = self.client.get(url, params=kwargs)
        util.raise_if_invalid_response(res)
        return Node(self, **res.json())

    def list(self, **kwargs):
        # pylint: disable=line-too-long
        """
        List existing nodes

        Arguments:
            kwargs (dict, optional): All other optional keyword arguments are applied as query
                parameters/filters

        Returns:
            list

        Raises:
            [`AirUnexpectedResponse`](/docs/exceptions) - API did not return a 200 OK
                or valid response JSON

        Example:
        ```
        >>> air.nodes.list()
        [<Node server 3dadd54d-583c-432e-9383-a2b0b1d7f551>, <Node switch 5ba9a4f5-8c58-4a7b-9a1c-7e6d4c2b1f00>]
        ```
        """  # pylint: enable=line-too-long
        res = self.client.get(self.url, params=kwargs)
        # The list endpoint returns a JSON array, so validate against `list` rather than `dict`.
        util.raise_if_invalid_response(res, data_type=list)
        return [Node(self, **node) for node in res.json()]

    @util.required_kwargs(['name', 'topology'])
    def create(self, **kwargs):
        """
        Create a new node

        Arguments:
            name (str): Node name
            topology (str | `Topology`): `Topology` or ID
            kwargs (dict, optional): All other optional keyword arguments are applied as key/value
                pairs in the request's JSON payload

        Returns:
            [`Node`](/docs/node)

        Raises:
            [`AirUnexpectedResponse`](/docs/exceptions) - API did not return a 201 Created
                or valid response JSON

        Example:
        ```
        >>> air.nodes.create(name='server', topology=topology)
        <Node server 3dadd54d-583c-432e-9383-a2b0b1d7f551>
        ```
        """
        res = self.client.post(self.url, json=kwargs)
        util.raise_if_invalid_response(res, status_code=201)
        return Node(self, **res.json())
# --------------------------------------------------------------------------------
# /tests/tests_v2/test_endpoints/test_topologies.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

from http import HTTPStatus

import faker
import pytest

from air_sdk.v2.utils import join_urls
from air_sdk.v2.endpoints.topologies import TopologyEndpointApi
from air_sdk.exceptions import AirUnexpectedResponse

faker.Faker.seed(0)
fake = faker.Faker()


@pytest.mark.parametrize(
    'payload, is_valid',
    [
        (  # Valid case
            {
                'source_format': 'DOT',
                'destination_format': 'JSON',
                'topology_data': 'graph "Valid Topology" {\r\n "node-1" [ os="generic/ubuntu2204" cpu=1 ]}',
            },
            True,
        ),
        (  # Valid case: Multiple nodes with connections
            {
                'source_format': 'DOT',
                'destination_format': 'JSON',
                'topology_data': """
                graph "Multi Node Topology" {
                    "node-1" [ os="generic/ubuntu2204" cpu=2 ];
                    "node-2" [ os="generic/ubuntu2204" cpu=1 ];
                    "node-1" -- "node-2" [ bandwidth="1" ];
                }
                """,
            },
            True,
        ),
        (  # Valid case: Complex topology with multiple connections
            {
                'source_format': 'DOT',
                'destination_format': 'JSON',
                'topology_data': """
                graph "Complex Topology" {
                    "server-1" [ os="generic/ubuntu2204" cpu=4 memory="2048" ];
                    "server-2" [ os="generic/ubuntu2204" cpu=2 memory="2048" ];
                    "switch-1" [ type="switch" ];
                    "server-1" -- "switch-1" [ bandwidth="10Gbps" ];
                    "server-2" -- "switch-1" [ bandwidth="10Gbps" ];
                }
                """,
            },
            True,
        ),
        (  # Invalid case: Missing required field
            {
                'source_format': 'DOT',
                'destination_format': 'JSON',
            },
            False,
        ),
        (  # Invalid case: Malformed DOT syntax
            {
                'source_format': 'DOT',
                'destination_format': 'JSON',
                'topology_data': 'graph "Invalid Topology" { invalid syntax here }',
            },
            False,
        ),
        (  # Invalid case: Empty topology
            {
                'source_format': 'DOT',
                'destination_format': 'JSON',
                'topology_data': '',
            },
            False,
        ),
        (  # Invalid case: Wrong source format
            {
                'source_format': 'INVALID',
                'destination_format': 'JSON',
                'topology_data': 'graph "Valid Topology" { "node-1" }',
            },
            False,
        ),
    ],
)
def test_parse_conversion(setup_mock_responses, api, payload, is_valid):
    """Valid payloads round-trip through `parse`; invalid ones (mocked 400) must raise."""
    endpoint_api = api.topologies
    parse_url = join_urls(endpoint_api.url_v1, TopologyEndpointApi.PARSE_PATH)

    if is_valid:
        expected_response = {
            'source_format': payload['source_format'],
            'destination_format': payload['destination_format'],
            'topology_data': payload['topology_data'],
        }

        setup_mock_responses(
            {
                ('POST', parse_url): {
                    'json': expected_response,
                    'status_code': HTTPStatus.OK,
                }
            }
        )

        response = endpoint_api.parse(**payload)

        assert response['source_format'] == payload['source_format']
        assert response['destination_format'] == payload['destination_format']
        assert response['topology_data'] == payload['topology_data']

    else:
        setup_mock_responses(
            {
                ('POST', parse_url): {
                    'json': {'error': 'Invalid request'},
                    'status_code': HTTPStatus.BAD_REQUEST,
                }
            }
        )
        # Missing fields may fail client-side (TypeError/ValueError) before any request;
        # everything else should surface the mocked 400 as AirUnexpectedResponse.
        with pytest.raises(Exception) as err:
            endpoint_api.parse(**payload)

        assert err.type in (AirUnexpectedResponse, TypeError, ValueError)

# --------------------------------------------------------------------------------
# /tests/tests_v2/test_air_api.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
import logging
import uuid
from datetime import timedelta
from unittest.mock import patch

import faker
import pytest

from air_sdk import v2, const

faker.Faker.seed(0)
fake = faker.Faker()


class TestAirApi:
    def setup_method(self):
        self.AirApi = v2.AirApi
        self.username = str(uuid.uuid4())
        self.password = str(uuid.uuid4())
        self.api_url = f'https://air-{fake.slug()}.nvidia.com'
        self.v2_api_url = v2.utils.join_urls(self.api_url, 'api', 'v2')

    def test_client_setup_username_and_password(self):
        """Ensure the client's url and authentication were correctly set."""
        fake_jwt = str(uuid.uuid4())
        with patch('air_sdk.v2.client.Client.get_token') as mock_get_token:
            mock_get_token.return_value = fake_jwt
            api = self.AirApi(api_url=self.api_url, username=self.username, password=self.password)
            assert type(api.client) is v2.client.Client
            assert api.client.base_url == self.v2_api_url
            assert 'content-type' in api.client.headers.keys(), 'A default content-type should be set'
            # Ensure we have an authentication token
            mock_get_token.assert_called_once_with(self.username, self.password)
            assert (
                api.client.headers['Authorization'] == f'Bearer {fake_jwt}'
            ), 'Authorization token was not set'

    def test_client_setup_bearer_token(self):
        """Ensure we use the bearer_token provided."""
        bearer_token = str(uuid.uuid4())
        with patch('air_sdk.v2.client.Client.get_token') as mock_get_token:
            api = self.AirApi(api_url=self.api_url, bearer_token=bearer_token)
            mock_get_token.assert_not_called()
            assert type(api.client) is v2.client.Client
            assert api.client.base_url == self.v2_api_url
            assert 'content-type' in api.client.headers.keys(), 'A default content-type should be set'
            # Ensure we have an authentication token
            assert (
                api.client.headers['Authorization'] == f'Bearer {bearer_token}'
            ), 'Authorization token was not set'

    @pytest.mark.parametrize(
        'extra_kwargs', [{}, {'username': str(uuid.uuid4())}, {'password': str(uuid.uuid4())}]
    )
    def test_client_no_auth_credentials_provided(self, extra_kwargs):
        with pytest.raises(ValueError):
            self.AirApi(api_url=self.api_url, **extra_kwargs)

    def test_client_can_skip_authentication(self):
        api = self.AirApi(api_url=self.api_url, authenticate=False)
        assert type(api.client) is v2.client.Client
        assert api.client.headers['Authorization'] is None, 'A blank Authorization should be set.'

    def test_default_timeouts(self, mock_client, setup_mock_responses, paginated_response):
        """Ensure we set a default timeout for all requests."""
        api = self.AirApi(api_url=self.api_url, authenticate=False)
        endpoint = api.marketplace_demo_tags
        setup_mock_responses({('GET', endpoint.url): paginated_response})
        list(endpoint.list())
        assert mock_client.call_count == 1
        assert mock_client.request_history[0]._timeout == (
            const.DEFAULT_CONNECT_TIMEOUT,
            const.DEFAULT_READ_TIMEOUT,
        )

    def test_custom_timeouts(self, mock_client, setup_mock_responses, paginated_response):
        """Ensure clients can set a custom timeouts for read/connect if they desire."""
        api = self.AirApi(api_url=self.api_url, authenticate=False)
        custom_connect_timeout = fake.pyint()
        custom_read_timeout = fake.pyint()
        api.set_connect_timeout(timedelta(seconds=custom_connect_timeout))
        api.set_read_timeout(timedelta(seconds=custom_read_timeout))
        endpoint = api.marketplace_demo_tags
        setup_mock_responses({('GET', endpoint.url): paginated_response})
        list(endpoint.list())
        assert mock_client.call_count == 1
        assert mock_client.request_history[0]._timeout == (custom_connect_timeout, custom_read_timeout)

    def test_each_request_is_logged(self, caplog, mock_client, setup_mock_responses, paginated_response):
        """Ensure each request is logged as `<METHOD>: <url>`."""
        api = self.AirApi(api_url=self.api_url, authenticate=False)
        endpoint = api.marketplace_demo_tags
        setup_mock_responses({('GET', endpoint.url): paginated_response})
        with caplog.at_level(logging.DEBUG):
            assert list(endpoint.list()) == []
        # Compare the message directly: wrapping the comparison in str() made the assertion
        # vacuous, since str(False) == 'False' is a non-empty (truthy) string.
        assert caplog.records[-1].message == f'GET: {endpoint.url}'

# --------------------------------------------------------------------------------
# /air_sdk/util.py:
# --------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

"""
Helper utils
"""

import datetime
import functools
from json import JSONDecodeError
from urllib.parse import ParseResult
from requests import Response

from dateutil import parser as dateparser

from .exceptions import AirUnexpectedResponse
from .logger import air_sdk_logger as logger


def raise_if_invalid_response(res: Response, status_code=200, data_type=dict):
    """
    Validates that a given API response has the expected status code and JSON payload

    Arguments:
        res (requests.Response) - API response object
        status_code [int] - Expected status code (default: 200)
        data_type [type] - Expected type of the decoded JSON payload (default: dict).
            Pass a falsy value (e.g. `None`) to skip JSON decoding/validation entirely.

    Raises:
        AirUnexpectedResponse - Raised if an unexpected response is received from the API
    """
    if res.status_code != status_code:
        logger.debug(res.text)
        raise AirUnexpectedResponse(message=res.text, status_code=res.status_code)
    if not data_type:
        return
    try:
        json = res.json()
    except JSONDecodeError as err:
        raise AirUnexpectedResponse(message=res.text, status_code=res.status_code) from err
    if not isinstance(json, data_type):
        raise AirUnexpectedResponse(
            message=f'Expected API response to be of type {data_type}, got {type(json)}',
            status_code=res.status_code,
        )


def required_kwargs(required):
    """
    Decorator to enforce required kwargs for a function

    Arguments:
        required (str | tuple | list): Required keyword argument name(s). A non-list value is
            treated as a single requirement. A tuple entry means at least one of its names
            must be passed.

    Raises:
        AttributeError - Raised by the decorated function when a requirement is not met
    """
    if not isinstance(required, list):
        required = [required]

    def wrapper(method):
        # wraps() keeps the decorated method's __name__/__doc__ (used by help() and doc tooling).
        @functools.wraps(method)
        def wrapped(*args, **kwargs):
            for arg in required:
                if isinstance(arg, tuple):
                    if not any(option in kwargs for option in arg):
                        raise AttributeError(f'{method} requires one of the following: {arg}')
                elif arg not in kwargs:
                    raise AttributeError(f'{method} requires {arg}')
            return method(*args, **kwargs)

        return wrapped

    return wrapper


def deprecated(new=None):
    """
    Decorator to log a warning when calling a deprecated function

    Arguments:
        new (str, optional): Replacement to suggest in the warning message
    """

    def wrapper(method):
        # wraps() keeps the decorated method's __name__/__doc__ (used by help() and doc tooling).
        @functools.wraps(method)
        def wrapped(*args, **kwargs):
            msg = f'{method} has been deprecated and will be removed in a future release.'
            if new:
                msg += f' Use {new} instead.'
            logger.warning(msg)
            return method(*args, **kwargs)

        return wrapped

    return wrapper


def validate_timestamps(log_prefix, **kwargs):
    """
    Logs a warning if any provided timestamps are in the past

    Arguments:
        log_prefix (str): Prefix to be prepended to the logged warning(s)
        kwargs (dict): Timestamps to verify. Each value is stringified and parsed with
            `dateutil`; falsy values are skipped. Timezone-aware values are converted to
            local time before being compared with the (naive, local) current time.
    """
    now = datetime.datetime.now()
    for key, value in kwargs.items():
        if not value:
            continue
        timestamp = dateparser.parse(str(value))
        if timestamp.utcoffset() is not None:
            # Comparing an aware datetime with the naive `now` raises TypeError, so
            # normalize aware values (e.g. ISO strings ending in 'Z') to naive local time.
            timestamp = timestamp.astimezone().replace(tzinfo=None)
        if timestamp <= now:
            logger.warning(f'{log_prefix} with `{key}` in the past: {value} (now: {now})')


def is_datetime_str(value):
    """
    Checks to see if the string is a valid datetime format

    Arguments:
        value (str): String to test if valid datetime format

    Returns:
        datetime.datetime | bool: The parsed datetime when `value` is an ISO 8601 string
        accepted by `datetime.fromisoformat` (a trailing 'Z' is treated as UTC),
        otherwise `False`
    """
    if isinstance(value, str):
        try:
            return datetime.datetime.fromisoformat(value.replace('Z', '+00:00'))
        except ValueError:
            pass
    return False


def url_path_join(base: ParseResult, *segments: str, trailing_slash: bool = False) -> ParseResult:
    """
    Appends provided path segments (if any) to provided base URL.
    Appends or removes a trailing slash at the end of path as specified by `trailing_slash` argument.

    Examples:
    ```
    url = urlparse("https://example.com/a/b")

    url_path_join(url, "c", "d")  # https://example.com/a/b/c/d
    url_path_join(url, "c", "d", trailing_slash=True)  # https://example.com/a/b/c/d/
    url_path_join(url, trailing_slash=True)  # https://example.com/a/b/
    ```
    """

    resulting_path = base.path
    for segment in segments:
        # Join with exactly one '/' regardless of slashes already present on either side.
        resulting_path = f"{resulting_path.rstrip('/')}/{segment.lstrip('/')}"

    resulting_url = base._replace(path=resulting_path.rstrip('/') + ('/' if trailing_slash else ''))
    return resulting_url

# --------------------------------------------------------------------------------