├── .github ├── pull_request_template.md └── workflows │ ├── codebuild-ci.yml │ └── security-monitoring.yml ├── .gitignore ├── .gitmodules ├── CHANGELOG.md ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── THIRD-PARTY ├── doc ├── conf.py └── index.rst ├── examples └── basic-job-example-config.yaml ├── helm_chart ├── HyperPodHelmChart │ ├── .helmignore │ ├── Chart.yaml │ ├── charts │ │ ├── cluster-role-and-bindings │ │ │ ├── Chart.yaml │ │ │ ├── templates │ │ │ │ └── cluster-level-config.yaml │ │ │ └── values.yaml │ │ ├── deep-health-check │ │ │ ├── Chart.yaml │ │ │ ├── templates │ │ │ │ ├── aws-hyperpod-namespace.yaml │ │ │ │ └── deep-health-check-rbac.yaml │ │ │ └── values.yaml │ │ ├── health-monitoring-agent │ │ │ ├── Chart.yaml │ │ │ ├── templates │ │ │ │ └── health-monitoring-agent.yaml │ │ │ └── values.yaml │ │ ├── hyperpod-patching │ │ │ ├── Chart.yaml │ │ │ ├── templates │ │ │ │ ├── clusterrole.yaml │ │ │ │ └── clusterrolebinding.yaml │ │ │ └── values.yaml │ │ ├── job-auto-restart │ │ │ ├── Chart.yaml │ │ │ ├── templates │ │ │ │ └── job-auto-restart-rbac.yaml │ │ │ └── values.yaml │ │ ├── mlflow │ │ │ ├── Chart.yaml │ │ │ ├── templates │ │ │ │ └── service-account.yaml │ │ │ └── values.yaml │ │ ├── mpi-operator │ │ │ ├── Chart.yaml │ │ │ ├── templates │ │ │ │ ├── _helpers.tpl │ │ │ │ ├── deployment.yaml │ │ │ │ ├── mpijob-crd.yaml │ │ │ │ └── rbac.yaml │ │ │ └── values.yaml │ │ ├── namespaced-role-and-bindings │ │ │ ├── Chart.yaml │ │ │ ├── templates │ │ │ │ └── namespace-level-role.yaml │ │ │ └── values.yaml │ │ ├── neuron-device-plugin │ │ │ ├── Chart.yaml │ │ │ ├── templates │ │ │ │ ├── k8s-neuron-device-plugin-rbac.yaml │ │ │ │ └── k8s-neuron-device-plugin.yaml │ │ │ └── values.yaml │ │ ├── storage │ │ │ ├── Chart.yaml │ │ │ ├── templates │ │ │ │ ├── persistent-volume-claim.yaml │ │ │ │ └── persistent-volume.yaml │ │ │ └── values.yaml │ │ ├── team-role-and-bindings │ │ │ ├── Chart.yaml │ │ │ ├── templates │ 
│ │ │ ├── team-level-cluster-role.yaml │ │ │ │ └── team-level-namespaced-role.yaml │ │ │ └── values.yaml │ │ └── training-operators │ │ │ ├── Chart.yaml │ │ │ ├── crds │ │ │ ├── mxjobs.kubeflow.org-CustomResourceDefinition.yaml │ │ │ ├── paddlejobs.kubeflow.org-CustomResourceDefinition.yaml │ │ │ ├── pytorchjobs.kubeflow.org-CustomResourceDefinition.yaml │ │ │ ├── tfjobs.kubeflow.org-CustomResourceDefinition.yaml │ │ │ └── xgboostjobs.kubeflow.org-CustomResourceDefinition.yaml │ │ │ ├── templates │ │ │ ├── ClusterRole │ │ │ │ ├── kubeflow-training-admin-ClusterRole.yaml │ │ │ │ ├── kubeflow-training-edit-ClusterRole.yaml │ │ │ │ ├── kubeflow-training-view-ClusterRole.yaml │ │ │ │ └── training-operator-ClusterRole.yaml │ │ │ ├── ClusterRoleBinding │ │ │ │ └── training-operator-ClusterRoleBinding.yaml │ │ │ ├── Deployment │ │ │ │ └── training-operator-kubeflow-Deployment.yaml │ │ │ ├── Service │ │ │ │ └── training-operator-kubeflow-Service.yaml │ │ │ ├── ServiceAccount │ │ │ │ └── training-operator-kubeflow-ServiceAccount.yaml │ │ │ ├── _helpers.tpl │ │ │ └── kubeflow-namespace.yaml │ │ │ └── values.yaml │ └── values.yaml ├── get_helm.sh ├── install_dependencies.sh └── readme.md ├── pyproject.toml ├── setup.cfg ├── setup.py ├── src └── hyperpod_cli │ ├── __init__.py │ ├── cli.py │ ├── clients │ ├── __init__.py │ └── kubernetes_client.py │ ├── commands │ ├── __init__.py │ ├── cluster.py │ ├── job.py │ └── pod.py │ ├── constants │ ├── __init__.py │ ├── command_constants.py │ ├── exception_constants.py │ ├── hyperpod_instance_types.py │ ├── kueue_constants.py │ └── pytorch_constants.py │ ├── py.typed │ ├── service │ ├── __init__.py │ ├── cancel_training_job.py │ ├── discover_namespaces.py │ ├── exec_command.py │ ├── get_logs.py │ ├── get_namespaces.py │ ├── get_training_job.py │ ├── list_pods.py │ ├── list_training_jobs.py │ └── self_subject_access_review.py │ ├── telemetry │ ├── __init__.py │ └── user_agent.py │ ├── templates │ ├── __init__.py │ └── 
k8s_pytorch_job_template.py │ ├── utils.py │ └── validators │ ├── __init__.py │ ├── cluster_validator.py │ ├── job_validator.py │ └── validator.py ├── test ├── __init__.py ├── integration_tests │ ├── __init__.py │ ├── abstract_integration_tests.py │ ├── charts │ │ └── hp-node-auth.yaml │ ├── cloudformation │ │ └── resources.yaml │ ├── data │ │ ├── basicJob.yaml │ │ └── basicJobWithQuota.yaml │ ├── lifecycle_script │ │ └── on_create_noop.sh │ └── test_happy_case.py └── unit_tests │ ├── clients │ └── test_kubernetes_client.py │ ├── service │ ├── test_cancel_training_job_service.py │ ├── test_discover_namespaces.py │ ├── test_exec_command_service.py │ ├── test_get_logs_service.py │ ├── test_get_namespaces.py │ ├── test_get_training_job_service.py │ ├── test_list_pods_service.py │ ├── test_list_training_jobs_service.py │ └── test_self_subject_access_review.py │ ├── test_cluster.py │ ├── test_hyperpod_cli.py │ ├── test_job.py │ ├── test_pod.py │ ├── test_utils.py │ └── validators │ ├── test_cluster_validatory.py │ ├── test_job_validator.py │ └── test_validator.py └── tox.ini /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | # PR Approval Steps 2 | 3 | ## For Requester 4 | 5 | 1. Description 6 | - [ ] Check the PR title and description for clarity. It should describe the changes made and the reason behind them. 7 | - [ ] Ensure that the PR follows the contribution guidelines, if applicable. 8 | 2. Security requirements 9 | - [ ] Ensure that a Pull Request (PR) does not expose passwords and other sensitive information by using git-secrets and upload relevant evidence: https://github.com/awslabs/git-secrets 10 | - [ ] Ensure commit has GitHub Commit Signature 11 | 3. Manual review 12 | 1. Click on the Files changed tab to see the code changes. Review the changes thoroughly: 13 | - [ ] Code Quality: Check for coding standards, naming conventions, and readability. 
14 | - [ ] Functionality: Ensure that the changes meet the requirements and that all necessary code paths are tested. 15 | - [ ] Security: Check for any security issues or vulnerabilities. 16 | - [ ] Documentation: Confirm that any necessary documentation (code comments, README updates, etc.) has been updated. 17 | 4. Check for Merge Conflicts: 18 | - [ ] Verify if there are any merge conflicts with the base branch. GitHub will usually highlight this. If there are conflicts, you should resolve them. 19 | 20 | ## For Reviewer 21 | 22 | 1. Go through `For Requester` section to double check each item. 23 | 2. Request Changes or Approve the PR: 24 | 1. If the PR is ready to be merged, click Review changes and select Approve. 25 | 2. If changes are required, select Request changes and provide feedback. Be constructive and clear in your feedback. 26 | 3. Merging the PR 27 | 1. Check the Merge Method: 28 | 1. Decide on the appropriate merge method based on your repository's guidelines (e.g., Squash and merge, Rebase and merge, or Merge). 29 | 2. Merge the PR: 30 | 1. Click the Merge pull request button. 31 | 2. Confirm the merge by clicking Confirm merge. 
32 | 33 | -------------------------------------------------------------------------------- /.github/workflows/codebuild-ci.yml: -------------------------------------------------------------------------------- 1 | name: PR Checks 2 | on: 3 | pull_request_target: 4 | branches: 5 | - "main*" 6 | 7 | concurrency: 8 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.head_ref }} 9 | cancel-in-progress: true 10 | 11 | permissions: 12 | id-token: write # This is required for requesting the JWT 13 | 14 | jobs: 15 | collab-check: 16 | runs-on: ubuntu-latest 17 | outputs: 18 | approval-env: ${{ steps.collab-check.outputs.result }} 19 | steps: 20 | - name: Collaborator Check 21 | uses: actions/github-script@v7 22 | id: collab-check 23 | with: 24 | github-token: ${{ secrets.COLLAB_CHECK_TOKEN }} 25 | result-encoding: string 26 | script: | 27 | try { 28 | const res = await github.rest.repos.checkCollaborator({ 29 | owner: context.repo.owner, 30 | repo: context.repo.repo, 31 | username: "${{ github.event.pull_request.user.login }}", 32 | }); 33 | console.log("Verifed ${{ github.event.pull_request.user.login }} is a repo collaborator. Auto Approving PR Checks.") 34 | return res.status == "204" ? "auto-approve" : "manual-approval" 35 | } catch (error) { 36 | console.log("${{ github.event.pull_request.user.login }} is not a collaborator. Requiring Manual Approval to run PR Checks.") 37 | return "manual-approval" 38 | } 39 | wait-for-approval: 40 | runs-on: ubuntu-latest 41 | needs: [collab-check] 42 | environment: ${{ needs.collab-check.outputs.approval-env }} 43 | steps: 44 | - run: echo "Workflow Approved! Starting PR Checks." 
45 | unit-tests: 46 | runs-on: ubuntu-latest 47 | needs: [wait-for-approval] 48 | strategy: 49 | matrix: 50 | python-version: ["38", "39", "310", "311"] 51 | steps: 52 | - name: Configure AWS Credentials 53 | uses: aws-actions/configure-aws-credentials@v3 54 | with: 55 | role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} 56 | aws-region: us-west-2 57 | role-duration-seconds: 3600 58 | - name: Run Unit Tests in Python ${{ matrix.python-version }} 59 | uses: aws-actions/aws-codebuild-run-build@v1 60 | with: 61 | project-name: ${{ secrets.UNIT_TEST_PROJECT_PREFIX }}${{ matrix.python-version }} 62 | source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}' 63 | integration-tests: 64 | runs-on: ubuntu-latest 65 | needs: [wait-for-approval] 66 | strategy: 67 | fail-fast: false 68 | matrix: 69 | python-version: ["38", "39", "310", "311"] 70 | steps: 71 | - name: Configure AWS Credentials 72 | uses: aws-actions/configure-aws-credentials@v3 73 | with: 74 | role-to-assume: ${{ secrets.CI_AWS_ROLE_ARN }} 75 | aws-region: us-west-2 76 | role-duration-seconds: 3600 77 | - name: Run Integration Tests in Python ${{ matrix.python-version }} 78 | uses: aws-actions/aws-codebuild-run-build@v1 79 | with: 80 | project-name: ${{ secrets.INTEGRATION_TEST_PROJECT_PREFIX }}${{ matrix.python-version }} 81 | source-version-override: 'refs/pull/${{ github.event.pull_request.number }}/head^{${{ github.event.pull_request.head.sha }}}' 82 | -------------------------------------------------------------------------------- /.github/workflows/security-monitoring.yml: -------------------------------------------------------------------------------- 1 | name: Security Monitoring 2 | 3 | on: 4 | schedule: 5 | - cron: '0 16 * * *' 6 | 7 | concurrency: 8 | group: ${{ github.workflow }}-${{ github.run_id }} 9 | cancel-in-progress: true 10 | 11 | permissions: 12 | id-token: write 13 | 14 | jobs: 15 | check-dependabot-alerts: 16 | runs-on: 
ubuntu-latest 17 | outputs: 18 | dependabot_alert_status: ${{ steps.check-dependabot-alerts.outputs.dependabot_alert_status }} 19 | steps: 20 | - name: Check for dependabot alerts 21 | id: check-dependabot-alerts 22 | uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea 23 | with: 24 | github-token: ${{ secrets.GH_PAT }} 25 | script: | 26 | async function checkAlerts() { 27 | const owner = '${{ github.repository_owner }}'; 28 | const repo = '${{ github.event.repository.name }}'; 29 | 30 | const dependabotAlerts = await github.rest.dependabot.listAlertsForRepo({ 31 | owner, 32 | repo, 33 | headers: { 34 | 'accept': 'applications/vnd.github+json' 35 | } 36 | }); 37 | const activeDependabotAlerts = dependabotAlerts.data.filter(alert => alert.state === 'open'); 38 | core.setOutput('dependabot_alert_status', activeDependabotAlerts.length > 0 ? '1': '0'); 39 | } 40 | await checkAlerts(); 41 | 42 | check-code-scanning-alerts: 43 | runs-on: ubuntu-latest 44 | outputs: 45 | code_scanning_alert_status: ${{ steps.check-code-scanning-alerts.outputs.code_scanning_alert_status }} 46 | steps: 47 | - name: Check for security alerts 48 | id: check-code-scanning-alerts 49 | uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea 50 | with: 51 | github-token: ${{ secrets.GH_PAT }} 52 | script: | 53 | async function checkAlerts() { 54 | const owner = '${{ github.repository_owner }}'; 55 | const repo = '${{ github.event.repository.name }}'; 56 | const ref = 'refs/heads/main'; 57 | 58 | const codeScanningAlerts = await github.rest.codeScanning.listAlertsForRepo({ 59 | owner, 60 | repo, 61 | ref: ref 62 | }); 63 | const activeCodeScanningAlerts = codeScanningAlerts.data.filter(alert => alert.state === 'open'); 64 | core.setOutput('code_scanning_alert_status', activeCodeScanningAlerts.length > 0 ? 
'1': '0'); 65 | } 66 | await checkAlerts(); 67 | 68 | put-metric-data: 69 | runs-on: ubuntu-latest 70 | needs: [check-dependabot-alerts, check-code-scanning-alerts] 71 | steps: 72 | - name: Configure AWS Credentials 73 | uses: aws-actions/configure-aws-credentials@12e3392609eaaceb7ae6191b3f54bbcb85b5002b 74 | with: 75 | role-to-assume: ${{ secrets.MONITORING_ROLE_ARN }} 76 | aws-region: us-west-2 77 | - name: Put Dependabot Alert Metric Data 78 | run: | 79 | if [ "${{ needs.check-dependabot-alerts.outputs.dependabot_alert_status }}" == "1" ]; then 80 | aws cloudwatch put-metric-data --metric-name DependabotAlert --namespace SecurityMonitoringMetrics --value 1 --unit Count --dimensions ProjectName=sagemaker-hyperpod-cli 81 | else 82 | aws cloudwatch put-metric-data --metric-name DependabotAlert --namespace SecurityMonitoringMetrics --value 0 --unit Count --dimensions ProjectName=sagemaker-hyperpod-cli 83 | fi 84 | - name: Put Code Scanning Alert Metric Data 85 | run: | 86 | if [ "${{ needs.check-code-scanning-alerts.outputs.code_scanning_alert_status }}" == "1" ]; then 87 | aws cloudwatch put-metric-data --metric-name CodeScanningAlert --namespace SecurityMonitoringMetrics --value 1 --unit Count --dimensions ProjectName=sagemaker-hyperpod-cli 88 | else 89 | aws cloudwatch put-metric-data --metric-name CodeScanningAlert --namespace SecurityMonitoringMetrics --value 0 --unit Count --dimensions ProjectName=sagemaker-hyperpod-cli 90 | fi 91 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *# 3 | *.swp 4 | 5 | *.DS_Store 6 | 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | *.egg-info/ 11 | 12 | /.coverage 13 | /.coverage.* 14 | /.cache 15 | /.pytest_cache 16 | /.mypy_cache 17 | 18 | /doc/_apidoc/ 19 | /build 20 | 21 | # Ignore all contents of result and results directories 22 | /result/ 23 | /results/ 
-------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "src/hyperpod_cli/sagemaker_hyperpod_recipes"] 2 | path = src/hyperpod_cli/sagemaker_hyperpod_recipes 3 | url = https://github.com/aws/sagemaker-hyperpod-recipes.git 4 | branch = release-1.3.3 5 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## v2.0.0 (2024-12-04) 4 | 5 | ### Features 6 | 7 | - feature: The HyperPod CLI now support ([Hyperpod recipes](https://github.com/aws/sagemaker-hyperpod-recipes.git)). The HyperPod recipes enable customers to get started training and fine-tuning popular publicly-available foundation models like Llama 3.1 405B in minutes. Learn more ([here](https://github.com/aws/sagemaker-hyperpod-recipes.git)). 8 | 9 | ## v1.0.0 (2024-09-09) 10 | 11 | ### Features 12 | 13 | - feature: Add support for SageMaker HyperPod CLI 14 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @aws/sagemaker-hyperpod-cli 2 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 
5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 3 | documentation, we greatly value feedback and contributions from our community. 4 | 5 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 6 | information to effectively respond to your bug report or contribution. 7 | 8 | 9 | ## Reporting Bugs/Feature Requests 10 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 11 | 12 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 13 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 14 | 15 | * A reproducible test case or series of steps 16 | * The version of our code being used 17 | * Any modifications you've made relevant to the bug 18 | * Anything unusual about your environment or deployment 19 | 20 | 21 | ## Contributing via Pull Requests 22 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 23 | 24 | 1. You are working against the latest source on the *main* branch. 25 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 26 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 27 | 28 | To send us a pull request, please: 29 | 30 | 1. Fork the repository. 31 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 32 | 3. Ensure local tests pass. 33 | 4. 
Commit to your fork using clear commit messages. Also, you must [sign your commit using GPG](https://docs.github.com/en/authentication/managing-commit-signature-verification/about-commit-signature-verification). 34 | 5. Install [git-secrets](https://github.com/awslabs/git-secrets), run secret scan to make sure no secrets are accidentally logged in your commit and commit messages. 35 | 5. Send us a pull request, answering any default questions in the pull request interface. 36 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 37 | 38 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 39 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 40 | 41 | 42 | ## Run the Unit Tests 43 | Install tox using `pip install tox` 44 | Run the following tox command and verify that all code checks and unit tests pass: `tox`. It will run test for Python 3.8, 3.9, 3.10, 3.11 45 | You can also run a test suit with a single python version with the following command: `tox -e py311/py310/py39/py38` 46 | You can also just run unit tests with command: `tox -e unit` 47 | 48 | 49 | ## Finding contributions to work on 50 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 51 | 52 | 53 | ## Code of Conduct 54 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 55 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 56 | opensource-codeofconduct@amazon.com with any additional questions or comments. 
57 | 58 | 59 | ## Security issue notifications 60 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 61 | 62 | 63 | ## Licensing 64 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 65 | 66 | 67 | ## Troubleshooting 68 | ### pytest 'start_job' related Unit Test failures 69 | - Underlying, the Click CLIRunner unit test tool has some issue with UTF-8 encoding 70 | - To resolve this, temporary use ```LC_ALL=C pytest test/unit_tests``` to complete the Unit tests 71 | 72 | ### pytest 'NoSuchModule' errors 73 | - double check whether ```pytest``` is running the same venv or global environment as your project by ```which pytest```. 74 | - If ```pytest``` is running in venv, ensure ```hyperpod``` is installed in venv. 75 | - Otherwise, if ```pytest``` is running globally, ensure ```hyperpod``` is installed globally. -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | """Sphinx configuration.""" 2 | 3 | import datetime 4 | import os 5 | import shutil 6 | 7 | 8 | def run_apidoc(app): 9 | """Generate doc stubs using sphinx-apidoc.""" 10 | module_dir = os.path.join(app.srcdir, "../src/") 11 | output_dir = os.path.join(app.srcdir, "_apidoc") 12 | excludes = [] 13 | 14 | # Ensure that any stale apidoc files are cleaned up first. 
15 | if os.path.exists(output_dir): 16 | shutil.rmtree(output_dir) 17 | 18 | cmd = [ 19 | "--separate", 20 | "--module-first", 21 | "--doc-project=API Reference", 22 | "-o", 23 | output_dir, 24 | module_dir, 25 | ] 26 | cmd.extend(excludes) 27 | 28 | try: 29 | from sphinx.ext import apidoc # Sphinx >= 1.7 30 | 31 | apidoc.main(cmd) 32 | except ImportError: 33 | from sphinx import apidoc # Sphinx < 1.7 34 | 35 | cmd.insert(0, apidoc.__file__) 36 | apidoc.main(cmd) 37 | 38 | 39 | def setup(app): 40 | """Register our sphinx-apidoc hook.""" 41 | app.connect("builder-inited", run_apidoc) 42 | 43 | 44 | # Sphinx configuration below. 45 | project = "SageMaker HyperPod CLI" 46 | 47 | # Example configuration for intersphinx: refer to the Python standard library. 48 | intersphinx_mapping = {"python": ("http://docs.python.org/", None)} 49 | 50 | extensions = [ 51 | "sphinx.ext.autodoc", 52 | "sphinx.ext.intersphinx", 53 | "sphinx.ext.napoleon", 54 | "sphinx.ext.todo", 55 | "sphinx.ext.viewcode", 56 | ] 57 | 58 | source_suffix = ".rst" 59 | master_doc = "index" 60 | 61 | autoclass_content = "class" 62 | autodoc_member_order = "bysource" 63 | default_role = "py:obj" 64 | 65 | html_theme = "haiku" 66 | htmlhelp_basename = "{}doc".format(project) 67 | 68 | napoleon_use_rtype = False 69 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | HyperpodCLI 2 | ======================= 3 | 4 | Please replace this text with a short description of your package. 5 | 6 | .. 
toctree:: 7 | 8 | _apidoc/modules 9 | 10 | 11 | Indices and tables 12 | __________________ 13 | 14 | * :ref:`genindex` 15 | * :ref:`modindex` 16 | * :ref:`search` 17 | -------------------------------------------------------------------------------- /examples/basic-job-example-config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | 14 | # Hydra configurations. SageMaker Hyperpod CLI use Hydra to manage start-job YAML files. 15 | defaults: 16 | - override hydra/job_logging: stdout 17 | 18 | hydra: 19 | run: 20 | dir: . 21 | output_subdir: null 22 | 23 | training_cfg: 24 | # entry_script: Required. Path to the entry script of training/fine-tuning, this path should be inside container. 25 | # Mapping to '--entry-script' argument in 'start-job' command. 26 | entry_script: /opt/pytorch-mnist/mnist.py 27 | # script_args: Optional. List of script arguments. Mapping to '--script-args' argument in 'start-job' command. 28 | # Example of usage: 29 | # script_args: 30 | # - --max_context_width: 4096 31 | # - --num_layers: 32 32 | script_args: [] 33 | run: 34 | # name: Required. Current Training Job name. Mapping to '--job-name' argument in 'start-job' command. 35 | name: hyperpod-cli-test 36 | # nodes: Required. Number of nodes to use for current training. Mapping to '--node-count' argument in 'start-job' command. 
37 | nodes: 2 38 | # ntasks_per_node: Optional. Number of devices to use per node. 39 | # For CPU instances, default value will be 8; For GPU or TRN instances, default value 40 | # will be the accelerator cores number on the instance. Mapping to '--tasks-per-node' argument in 'start-job' command. 41 | ntasks_per_node: 1 42 | cluster: 43 | # cluster_type: Required. Currently, only support k8s cluster type. 44 | cluster_type: k8s 45 | # instance_type: Required. SageMaker Hyperpod supported instance type only. 46 | # Mapping to '--instance-type' argument in 'start-job' command. 47 | instance_type: ml.c5.2xlarge 48 | # cluster_config: Required. Fields related to cluster configuration for each Training job run. 49 | cluster_config: 50 | # annotations: Optional. The annotations attached to the Training job. 51 | # To use SageMaker Hyperpod Job Auto Resume feature, set annotations like following example: 52 | # annotations: 53 | # sagemaker.amazonaws.com/enable-job-auto-resume: True 54 | # sagemaker.amazonaws.com/job-max-retry-count: 1 55 | annotations: null 56 | # service_account_name: Optional. The name of service account associated with the namespace. 57 | # Mapping to '--service-account-name' argument in 'start-job' command. 58 | service_account_name: null 59 | # persistent_volume_claims: Optional. The persistent volume claims, usually used to mount FSx. 60 | # Mapping to '--persistent-volume-claims' argument in 'start-job' command. 61 | persistent_volume_claims: null 62 | # namespace: Optional. The namespace to submit job. If not specify, Training job will submit to 63 | # the current namespace from Kubernetes context. 64 | # Mapping to '--namespace' argument in 'start-job' command. 65 | namespace: kubeflow 66 | # custom_labels: Optional. Used to specify the name of the queue, which is created by the cluster admin users. 67 | # The priority class label is mapped to '--priority' argument in 'start-job' command if your scheduler type is 'SageMaker'. 
68 | # custom_labels: 69 | # kueue.x-k8s.io/queue-name: low-priority-queue2 70 | # kueue.x-k8s.io/priority-class: sample-priority 71 | custom_labels: null 72 | # priority_class_name: Optional. The priority for the job, which is created by the cluster admin users. 73 | # Mapping to '--priority' argument in 'start-job' command. 74 | priority_class_name: null 75 | # volumes: Optional. Used to mount temp path to container. Mapping to '--volumes' argument in 'start-job' command. 76 | # Example of usage: 77 | # volumes: 78 | # - volumeName: v1 79 | # hostPath: /data 80 | # mountPath: /data 81 | volumes: null 82 | # labal_selector: Optional. Defines Kubernetes node affinity to select nodes with labels. Following 83 | # config will choose SageMaker HyperPod health labels and prefer nodes with SageMaker Hyperpod burn-in 84 | # test passed label. Mapping to '--label-selector' argument in 'start-job' command. 85 | label_selector: 86 | required: 87 | sagemaker.amazonaws.com/node-health-status: 88 | - Schedulable 89 | preferred: 90 | sagemaker.amazonaws.com/deep-health-check-status: 91 | - Passed 92 | weights: 93 | - 100 94 | # pullPolicy: Required. Kubernetes PyTorchJob pull policy to pull container, can be Always, IfNotPresent and Never. 95 | # Mapping to '--pull-policy' argument in 'start-job' command. 96 | pullPolicy: IfNotPresent 97 | # restartPolicy: Required. Kubernetes PyTorchJob restart policy. Can be OnFailure, Always or Never. 98 | # To use SageMaker Hyperpod AutoResume functionality, please set it to OnFailure. 99 | # Mapping to '--restart-policy' argument in 'start-job' command. 100 | restartPolicy: OnFailure 101 | # scheduler_type: Optional. Used to decide which type of scheduler to use. Default value is 'SageMaker' which makes the job 102 | # only scheduled on queues created via SageMaker. 
This requires local queue name filled in custom label 'kueue.x-k8s.io/queue-name' 103 | # Another valid value is 'Kueue', with this option, queue name and namespace has to be manually filled out. 104 | # scheduler_type: SageMaker 105 | scheduler_type: Kueue 106 | # base_results_dir: Optional. Location to store the results, checkpoints and logs. 107 | # Mapping to '--results-dir' argument in 'start-job' command. 108 | base_results_dir: ./result 109 | # container: Required. Docker image to be used for Training Job 110 | # Mapping to '--image' argument in 'start-job' command. 111 | container: docker.io/kubeflowkatib/pytorch-mnist-cpu:v1beta1-bc09cfd 112 | 113 | # env_vars: Optional. Environment variables passed to the training job. 114 | # Mapping to '--environment' argument in 'start-job' command. 115 | env_vars: 116 | NCCL_DEBUG: INFO # Logging level for NCCL. Set to "INFO" for debug information 117 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/.helmignore: -------------------------------------------------------------------------------- 1 | # Patterns to ignore when building packages. 2 | # This supports shell glob matching, relative path matching, and 3 | # negation (prefixed with !). Only one pattern per line. 4 | .DS_Store 5 | # Common VCS dirs 6 | .git/ 7 | .gitignore 8 | .bzr/ 9 | .bzrignore 10 | .hg/ 11 | .hgignore 12 | .svn/ 13 | # Common backup files 14 | *.swp 15 | *.bak 16 | *.tmp 17 | *.orig 18 | *~ 19 | # Various IDEs 20 | .project 21 | .idea/ 22 | *.tmproj 23 | .vscode/ 24 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: hyperpod-helm-chart 3 | description: A Helm chart for Kubernetes 4 | 5 | # A chart can be either an 'application' or a 'library' chart. 
6 | # 7 | # Application charts are a collection of templates that can be packaged into versioned archives 8 | # to be deployed. 9 | # 10 | # Library charts provide useful utilities or functions for the chart developer. They're included as 11 | # a dependency of application charts to inject those utilities and functions into the rendering 12 | # pipeline. Library charts do not define any templates and therefore cannot be deployed. 13 | type: application 14 | 15 | # This is the chart version. This version number should be incremented each time you make changes 16 | # to the chart and its templates, including the app version. 17 | # Versions are expected to follow Semantic Versioning (https://semver.org/) 18 | version: 0.1.0 19 | 20 | # This is the version number of the application being deployed. This version number should be 21 | # incremented each time you make changes to the application. Versions are not expected to 22 | # follow Semantic Versioning. They should reflect the version the application is using. 23 | # It is recommended to use it with quotes. 
24 | appVersion: "1.16.0" 25 | 26 | dependencies: 27 | - name: training-operators 28 | version: "0.1.0" 29 | repository: "file://charts/training-operators" 30 | - name: mlflow 31 | version: "0.1.0" 32 | repository: "file://charts/mlflow" 33 | condition: mlflow.enabled 34 | - name: nvidia-device-plugin 35 | version: "0.16.1" 36 | repository: https://nvidia.github.io/k8s-device-plugin 37 | condition: nvidia-device-plugin.devicePlugin.enabled 38 | - name: aws-efa-k8s-device-plugin 39 | version: "0.5.3" 40 | repository: https://aws.github.io/eks-charts/ 41 | condition: aws-efa-k8s-device-plugin.devicePlugin.enabled 42 | - name: neuron-device-plugin 43 | version: "0.1.0" 44 | repository: "file://charts/neuron-device-plugin" 45 | condition: neuron-device-plugin.devicePlugin.enabled 46 | - name: storage 47 | version: "0.1.0" 48 | repository: "file://charts/storage" 49 | condition: storage.enabled 50 | - name: health-monitoring-agent 51 | version: "0.1.0" 52 | repository: "file://charts/health-monitoring-agent" 53 | condition: health-monitoring-agent.enabled 54 | - name: mpi-operator 55 | version: "0.1.0" 56 | repository: "file://charts/mpi-operator" 57 | condition: mpi-operator.enabled 58 | - name: deep-health-check 59 | version: "0.1.0" 60 | repository: "file://charts/deep-health-check" 61 | condition: deep-health-check.enabled 62 | - name: job-auto-restart 63 | version: "0.1.0" 64 | repository: "file://charts/job-auto-restart" 65 | condition: job-auto-restart.enabled 66 | - name: cluster-role-and-bindings 67 | version: "0.1.0" 68 | repository: "file://charts/cluster-role-and-bindings" 69 | condition: cluster-role-and-bindings.enabled 70 | - name: namespaced-role-and-bindings 71 | version: "0.1.0" 72 | repository: "file://charts/namespaced-role-and-bindings" 73 | condition: namespaced-role-and-bindings.enabled 74 | - name: team-role-and-bindings 75 | version: "0.1.0" 76 | repository: "file://charts/team-role-and-bindings" 77 | condition: team-role-and-bindings.enabled 78 
| - name: hyperpod-patching 79 | version: "0.1.0" 80 | repository: "file://charts/hyperpod-patching" 81 | condition: hyperpod-patching.enabled 82 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/cluster-role-and-bindings/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: cluster-role-and-bindings 3 | version: 0.1.0 4 | appVersion: 1.0 5 | description: A Helm chart for setting up Hyperpod cluster role and role bindings for cluster role -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/cluster-role-and-bindings/templates/cluster-level-config.yaml: -------------------------------------------------------------------------------- 1 | kind: ClusterRole 2 | apiVersion: rbac.authorization.k8s.io/v1 3 | metadata: 4 | name: {{ .Values.roleName }} 5 | rules: 6 | - apiGroups: [""] 7 | resources: ["pods"] 8 | verbs: ["list"] 9 | - apiGroups: [""] 10 | resources: ["nodes"] 11 | verbs: ["list"] 12 | --- 13 | apiVersion: rbac.authorization.k8s.io/v1 14 | kind: ClusterRoleBinding 15 | metadata: 16 | name: {{ .Values.roleName }}-binding 17 | subjects: 18 | - kind: Group 19 | name: {{ .Values.roleName }}-cluster-level 20 | apiGroup: rbac.authorization.k8s.io 21 | roleRef: 22 | kind: ClusterRole 23 | name: {{ .Values.roleName }} # this must match the name of the Role or ClusterRole you wish to bind to 24 | apiGroup: rbac.authorization.k8s.io -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/cluster-role-and-bindings/values.yaml: -------------------------------------------------------------------------------- 1 | roleName: "hyperpod-scientist-user-cluster-role" -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/deep-health-check/Chart.yaml: 
-------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: deep-health-check 3 | version: 0.1.0 4 | appVersion: 1.0 5 | description: A Helm chart for setting up Hyperpod deep health check related permissions -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/deep-health-check/templates/aws-hyperpod-namespace.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Namespace 3 | metadata: 4 | name: aws-hyperpod 5 | labels: 6 | name: aws-hyperpod 7 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/deep-health-check/templates/deep-health-check-rbac.yaml: -------------------------------------------------------------------------------- 1 | # rbac.yaml 2 | # service account 3 | --- 4 | apiVersion: v1 5 | kind: ServiceAccount 6 | metadata: 7 | name: deep-health-check-service-account 8 | namespace: {{ .Values.namespace }} 9 | --- 10 | kind: ClusterRole 11 | apiVersion: rbac.authorization.k8s.io/v1 12 | metadata: 13 | name: deep-health-check-service-account-role 14 | rules: 15 | - apiGroups: 16 | - "" 17 | resources: 18 | - nodes 19 | verbs: 20 | - get 21 | - list 22 | - apiGroups: 23 | - "" 24 | resources: 25 | - pods 26 | verbs: 27 | - get 28 | - list 29 | - patch 30 | --- 31 | kind: ClusterRoleBinding 32 | apiVersion: rbac.authorization.k8s.io/v1 33 | metadata: 34 | name: deep-health-check-service-account-role-binding 35 | roleRef: 36 | apiGroup: rbac.authorization.k8s.io 37 | kind: ClusterRole 38 | name: deep-health-check-service-account-role 39 | subjects: 40 | - kind: ServiceAccount 41 | name: deep-health-check-service-account 42 | namespace: {{ .Values.namespace }} 43 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/deep-health-check/values.yaml: 
-------------------------------------------------------------------------------- 1 | namespace: "aws-hyperpod" 2 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/health-monitoring-agent/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: health-monitoring-agent 3 | version: 0.1.0 4 | appVersion: 1.0 5 | description: A Helm chart for setting up Hyperpod health-monitoring-agent related permissions 6 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/health-monitoring-agent/templates/health-monitoring-agent.yaml: -------------------------------------------------------------------------------- 1 | # rbac.yaml 2 | --- 3 | kind: ClusterRole 4 | apiVersion: rbac.authorization.k8s.io/v1 5 | metadata: 6 | name: health-monitoring-agent 7 | rules: 8 | - apiGroups: 9 | - "" 10 | resources: 11 | - nodes 12 | verbs: 13 | - get 14 | - apiGroups: 15 | - "" 16 | resources: 17 | - nodes 18 | - nodes/status 19 | verbs: 20 | - patch 21 | - apiGroups: 22 | - "" 23 | - events.k8s.io 24 | resources: 25 | - events 26 | verbs: 27 | - create 28 | - patch 29 | - update 30 | --- 31 | apiVersion: v1 32 | kind: ServiceAccount 33 | metadata: 34 | name: health-monitoring-agent 35 | namespace: {{ .Values.namespace }} 36 | --- 37 | kind: ClusterRoleBinding 38 | apiVersion: rbac.authorization.k8s.io/v1 39 | metadata: 40 | name: health-monitoring-agent 41 | namespace: {{ .Values.namespace }} 42 | roleRef: 43 | apiGroup: rbac.authorization.k8s.io 44 | kind: ClusterRole 45 | name: health-monitoring-agent 46 | subjects: 47 | - kind: ServiceAccount 48 | name: health-monitoring-agent 49 | namespace: {{ .Values.namespace }} 50 | --- 51 | apiVersion: apps/v1 52 | kind: DaemonSet 53 | metadata: 54 | name: health-monitoring-agent 55 | namespace: {{ .Values.namespace }} 56 | labels: 57 | app: 
health-monitoring-agent 58 | spec: 59 | selector: 60 | matchLabels: 61 | app: health-monitoring-agent 62 | template: 63 | metadata: 64 | labels: 65 | app: health-monitoring-agent 66 | spec: 67 | affinity: 68 | nodeAffinity: 69 | requiredDuringSchedulingIgnoredDuringExecution: 70 | nodeSelectorTerms: 71 | - matchExpressions: 72 | - key: node.kubernetes.io/instance-type 73 | operator: In 74 | values: 75 | - ml.p5en.48xlarge 76 | - ml.p5e.48xlarge 77 | - ml.p5.48xlarge 78 | - ml.p4d.24xlarge 79 | - ml.p4de.24xlarge 80 | - ml.g5.xlarge 81 | - ml.g5.2xlarge 82 | - ml.g5.4xlarge 83 | - ml.g5.8xlarge 84 | - ml.g5.12xlarge 85 | - ml.g5.16xlarge 86 | - ml.g5.24xlarge 87 | - ml.g5.48xlarge 88 | - ml.inf2.xlarge 89 | - ml.inf2.8xlarge 90 | - ml.inf2.24xlarge 91 | - ml.inf2.48xlarge 92 | - ml.trn1.32xlarge 93 | - ml.trn1n.32xlarge 94 | - ml.g6.xlarge 95 | - ml.g6.2xlarge 96 | - ml.g6.4xlarge 97 | - ml.g6.8xlarge 98 | - ml.g6.16xlarge 99 | - ml.g6.12xlarge 100 | - ml.g6.24xlarge 101 | - ml.g6.48xlarge 102 | - ml.gr6.4xlarge 103 | - ml.gr6.8xlarge 104 | - ml.g6e.xlarge 105 | - ml.g6e.2xlarge 106 | - ml.g6e.4xlarge 107 | - ml.g6e.8xlarge 108 | - ml.g6e.16xlarge 109 | - ml.g6e.12xlarge 110 | - ml.g6e.24xlarge 111 | - ml.g6e.48xlarge 112 | - ml.trn2.48xlarge 113 | containers: 114 | - name: health-monitoring-agent 115 | args: 116 | - --enable-k8s-exporter=false 117 | - --config.system-log-monitor=/config/system-message-monitor.json 118 | image: {{ .Values.hmaimage }} 119 | resources: 120 | limits: 121 | cpu: 500m 122 | memory: 512Mi 123 | requests: 124 | cpu: 500m 125 | memory: 512Mi 126 | imagePullPolicy: IfNotPresent 127 | securityContext: 128 | runAsUser: 1000 129 | runAsGroup: 2000 130 | env: 131 | - name: NODE_NAME 132 | valueFrom: 133 | fieldRef: 134 | fieldPath: spec.nodeName 135 | - name: NODE_IP 136 | valueFrom: 137 | fieldRef: 138 | fieldPath: status.hostIP 139 | volumeMounts: 140 | - name: log 141 | mountPath: /var/log 142 | - name: kmsg 143 | mountPath: /dev/kmsg 144 | 
readOnly: true 145 | # Make sure node problem detector is in the same timezone 146 | # with the host. 147 | - name: localtime 148 | mountPath: /etc/localtime 149 | readOnly: true 150 | serviceAccountName: health-monitoring-agent 151 | volumes: 152 | - name: log 153 | # Config `log` to your system log directory 154 | hostPath: 155 | path: /var/log/ 156 | - name: kmsg 157 | hostPath: 158 | path: /dev/kmsg 159 | - name: localtime 160 | hostPath: 161 | path: /etc/localtime 162 | tolerations: 163 | - effect: NoSchedule 164 | operator: Exists 165 | - effect: NoExecute 166 | operator: Exists 167 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/health-monitoring-agent/values.yaml: -------------------------------------------------------------------------------- 1 | namespace: "aws-hyperpod" 2 | hmaimage: "905418368575.dkr.ecr.us-west-2.amazonaws.com/hyperpod-health-monitoring-agent:1.0.448.0_1.0.115.0" 3 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/hyperpod-patching/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: hyperpod-patching 3 | description: A subchart for RBAC used by HyperPod patching workflows 4 | version: 0.1.0 5 | appVersion: "1.0" 6 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/hyperpod-patching/templates/clusterrole.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRole 3 | metadata: 4 | name: hyperpod-patching 5 | rules: 6 | - apiGroups: [""] 7 | resources: ["pods"] 8 | verbs: ["list"] 9 | - apiGroups: [""] 10 | resources: ["pods/eviction"] 11 | verbs: ["create"] 12 | -------------------------------------------------------------------------------- 
/helm_chart/HyperPodHelmChart/charts/hyperpod-patching/templates/clusterrolebinding.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRoleBinding 3 | metadata: 4 | name: hyperpod-patching 5 | subjects: 6 | - kind: User 7 | name: hyperpod-service-linked-role 8 | apiGroup: rbac.authorization.k8s.io 9 | roleRef: 10 | kind: ClusterRole 11 | name: hyperpod-patching 12 | apiGroup: rbac.authorization.k8s.io 13 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/hyperpod-patching/values.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/sagemaker-hyperpod-cli/84ba2448caaa47c82fa006373f13bbe97a4bb7e6/helm_chart/HyperPodHelmChart/charts/hyperpod-patching/values.yaml -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/job-auto-restart/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: job-auto-restart 3 | version: 0.1.0 4 | appVersion: 1.0 5 | description: A Helm chart for setting up Hyperpod training job auto restart related permissions -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/job-auto-restart/templates/job-auto-restart-rbac.yaml: -------------------------------------------------------------------------------- 1 | # rbac.yaml 2 | --- 3 | kind: ClusterRole 4 | apiVersion: rbac.authorization.k8s.io/v1 5 | metadata: 6 | name: job-auto-restart 7 | rules: 8 | - apiGroups: 9 | - "" 10 | resources: 11 | - nodes 12 | - nodes/status 13 | - pods 14 | - pods/status 15 | verbs: 16 | - get 17 | - list 18 | - watch 19 | - apiGroups: 20 | - "" 21 | resources: 22 | - pods 23 | verbs: 24 | - delete 25 | - deletecollection 26 | - apiGroups: 27 | 
- "" 28 | resources: 29 | - nodes 30 | - nodes/status 31 | verbs: 32 | - patch 33 | - apiGroups: 34 | - "" 35 | - events.k8s.io 36 | resources: 37 | - events 38 | verbs: 39 | - create 40 | - apiGroups: 41 | - kubeflow.org 42 | resources: 43 | - pytorchjobs 44 | - pytorchjobs/status 45 | verbs: 46 | - get 47 | - list 48 | - watch 49 | - delete 50 | - patch 51 | - update 52 | - describe 53 | - apiGroups: 54 | - coordination.k8s.io 55 | resources: 56 | - leases 57 | verbs: 58 | - get 59 | - create 60 | - update 61 | - delete 62 | --- 63 | apiVersion: v1 64 | kind: ServiceAccount 65 | metadata: 66 | name: job-auto-restart 67 | namespace: {{ .Values.namespace }} 68 | --- 69 | kind: ClusterRoleBinding 70 | apiVersion: rbac.authorization.k8s.io/v1 71 | metadata: 72 | name: job-auto-restart 73 | roleRef: 74 | apiGroup: rbac.authorization.k8s.io 75 | kind: ClusterRole 76 | name: job-auto-restart 77 | subjects: 78 | - kind: User 79 | name: hyperpod-service-linked-role 80 | namespace: {{ .Values.namespace }} 81 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/job-auto-restart/values.yaml: -------------------------------------------------------------------------------- 1 | namespace: "aws-hyperpod" -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/mlflow/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: mlflow 3 | version: 0.1.0 4 | appVersion: 1.0 5 | description: A Helm chart for deploying MLflow 6 | keywords: 7 | - mlflow 8 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/mlflow/templates/service-account.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ServiceAccount 3 | metadata: 4 | name: {{ .Values.mlflow.serviceAccount.name }} 
5 | namespace: {{ .Values.mlflow.serviceAccount.namespace }} 6 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/mlflow/values.yaml: -------------------------------------------------------------------------------- 1 | mlflow: 2 | serviceAccount: 3 | name: "mlflow-service-account1" 4 | namespace: "kubeflow" 5 | roleARN: "arn:aws:iam::555555555555:role/hyperpod-mlflow-role" 6 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/mpi-operator/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: mpi-operator 3 | version: 0.1.0 4 | appVersion: "0.5.0" 5 | description: A Helm chart for deploying Kubeflow MPI operator 6 | keywords: 7 | - kubeflow 8 | - mpi 9 | - operator 10 | 11 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/mpi-operator/templates/_helpers.tpl: -------------------------------------------------------------------------------- 1 | {{/* 2 | Expand the name of the chart. 3 | */}} 4 | {{- define "mpi-operator.name" -}} 5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} 6 | {{- end }} 7 | 8 | {{/* 9 | */}} 10 | {{- define "mpi-operator.fullname" -}} 11 | {{- if .Values.fullnameOverride }} 12 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} 13 | {{- else }} 14 | {{- $name := default .Chart.Name .Values.nameOverride }} 15 | {{- if contains $name .Release.Name }} 16 | {{- .Release.Name | trunc 63 | trimSuffix "-" }} 17 | {{- else }} 18 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} 19 | {{- end }} 20 | {{- end }} 21 | {{- end }} 22 | 23 | {{/* 24 | Create chart name and version as used by the chart label. 
25 | */}} 26 | {{- define "mpi-operator.chart" -}} 27 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} 28 | {{- end }} 29 | 30 | {{/* 31 | Common labels 32 | */}} 33 | {{- define "mpi-operator.labels" -}} 34 | helm.sh/chart: {{ include "mpi-operator.chart" . }} 35 | {{ include "mpi-operator.selectorLabels" . }} 36 | {{- if .Chart.AppVersion }} 37 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} 38 | {{- end }} 39 | app.kubernetes.io/managed-by: {{ .Release.Service }} 40 | {{- if .Values.extraLabels }} 41 | {{ toYaml .Values.extraLabels }} 42 | {{- end }} 43 | {{- end }} 44 | 45 | {{/* 46 | Selector labels 47 | */}} 48 | {{- define "mpi-operator.selectorLabels" -}} 49 | app.kubernetes.io/name: {{ include "mpi-operator.name" . }} 50 | app.kubernetes.io/instance: {{ .Release.Name }} 51 | {{- end }} 52 | 53 | {{/* 54 | Create the name of the service account to use 55 | */}} 56 | {{- define "mpi-operator.serviceAccountName" -}} 57 | {{- if .Values.serviceAccount.create }} 58 | {{- default (include "mpi-operator.fullname" .) .Values.serviceAccount.name }} 59 | {{- else }} 60 | {{- default "default" .Values.serviceAccount.name }} 61 | {{- end }} 62 | {{- end }} -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/mpi-operator/templates/deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: {{ include "mpi-operator.fullname" . }} 5 | labels: 6 | app: mpi-operator 7 | app.kubernetes.io/component: mpijob 8 | kustomize.component: mpi-operator 9 | {{- include "mpi-operator.labels" . 
| nindent 4 }} 10 | spec: 11 | replicas: {{ .Values.mpiOperator.replicas }} 12 | selector: 13 | matchLabels: 14 | app: mpi-operator 15 | app.kubernetes.io/component: mpijob 16 | app.kubernetes.io/name: mpi-operator 17 | kustomize.component: mpi-operator 18 | {{- include "mpi-operator.selectorLabels" . | nindent 6 }} 19 | template: 20 | metadata: 21 | labels: 22 | app: mpi-operator 23 | app.kubernetes.io/component: mpijob 24 | app.kubernetes.io/name: mpi-operator 25 | kustomize.component: mpi-operator 26 | {{- include "mpi-operator.labels" . | nindent 8 }} 27 | {{- include "mpi-operator.selectorLabels" . | nindent 8 }} 28 | annotations: 29 | sidecar.istio.io/inject: "false" 30 | spec: 31 | {{- if .Values.mpiOperator.affinity }} 32 | affinity: 33 | {{ toYaml .Values.mpiOperator.affinity | indent 8 }} 34 | {{- end }} 35 | {{- if .Values.mpiOperator.nodeSelector }} 36 | nodeSelector: 37 | {{ toYaml .Values.mpiOperator.nodeSelector | indent 8 }} 38 | {{- end }} 39 | {{- if .Values.mpiOperator.tolerations }} 40 | tolerations: 41 | {{ toYaml .Values.mpiOperator.tolerations | indent 8 }} 42 | {{- end }} 43 | {{- if .Values.mpiOperator.topologySpreadConstraints }} 44 | topologySpreadConstraints: 45 | {{ toYaml .Values.mpiOperator.topologySpreadConstraints | indent 8 }} 46 | {{- end }} 47 | containers: 48 | - args: 49 | {{- toYaml .Values.mpiOperator.additionalArgs | nindent 8 }} 50 | env: 51 | - name: KUBERNETES_CLUSTER_DOMAIN 52 | value: "cluster.local" 53 | image: {{ .Values.mpiOperator.image.repository }}:{{ .Values.mpiOperator.image.tag | default .Chart.AppVersion }} 54 | name: mpi-operator 55 | imagePullPolicy: {{ .Values.mpiOperator.imagePullPolicy | default "IfNotPresent" }} 56 | resources: {} 57 | serviceAccountName: {{ include "mpi-operator.fullname" . 
}} -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/mpi-operator/templates/rbac.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Namespace 3 | metadata: 4 | labels: 5 | app: mpi-operator 6 | app.kubernetes.io/component: mpijob 7 | app.kubernetes.io/name: mpi-operator 8 | kustomize.component: mpi-operator 9 | name: mpi-operator 10 | --- 11 | apiVersion: v1 12 | kind: ServiceAccount 13 | metadata: 14 | name: {{ include "mpi-operator.fullname" . }} 15 | labels: 16 | app: mpi-operator 17 | app.kubernetes.io/component: mpijob 18 | kustomize.component: mpi-operator 19 | {{- include "mpi-operator.labels" . | nindent 4 }} 20 | annotations: 21 | {{- toYaml .Values.mpiOperator.serviceAccount.annotations | nindent 4 }} 22 | --- 23 | apiVersion: rbac.authorization.k8s.io/v1 24 | kind: ClusterRole 25 | metadata: 26 | name: {{ include "mpi-operator.fullname" . }}-kubeflow-mpijobs-admin 27 | labels: 28 | app: mpi-operator 29 | app.kubernetes.io/component: mpijob 30 | kustomize.component: mpi-operator 31 | rbac.authorization.kubeflow.org/aggregate-to-kubeflow-admin: "true" 32 | {{- include "mpi-operator.labels" . | nindent 4 }} 33 | aggregationRule: 34 | clusterRoleSelectors: 35 | - matchLabels: 36 | rbac.authorization.kubeflow.org/aggregate-to-kubeflow-mpijobs-admin: "true" 37 | rules: [] 38 | --- 39 | apiVersion: rbac.authorization.k8s.io/v1 40 | kind: ClusterRole 41 | metadata: 42 | name: {{ include "mpi-operator.fullname" . }}-kubeflow-mpijobs-edit 43 | labels: 44 | app: mpi-operator 45 | app.kubernetes.io/component: mpijob 46 | kustomize.component: mpi-operator 47 | rbac.authorization.kubeflow.org/aggregate-to-kubeflow-edit: "true" 48 | rbac.authorization.kubeflow.org/aggregate-to-kubeflow-mpijobs-admin: "true" 49 | {{- include "mpi-operator.labels" . 
| nindent 4 }} 50 | rules: 51 | - apiGroups: 52 | - kubeflow.org 53 | resources: 54 | - mpijobs 55 | - mpijobs/status 56 | verbs: 57 | - get 58 | - list 59 | - watch 60 | - create 61 | - delete 62 | - deletecollection 63 | - patch 64 | - update 65 | --- 66 | apiVersion: rbac.authorization.k8s.io/v1 67 | kind: ClusterRole 68 | metadata: 69 | name: {{ include "mpi-operator.fullname" . }}-kubeflow-mpijobs-view 70 | labels: 71 | app: mpi-operator 72 | app.kubernetes.io/component: mpijob 73 | kustomize.component: mpi-operator 74 | rbac.authorization.kubeflow.org/aggregate-to-kubeflow-view: "true" 75 | {{- include "mpi-operator.labels" . | nindent 4 }} 76 | rules: 77 | - apiGroups: 78 | - kubeflow.org 79 | resources: 80 | - mpijobs 81 | - mpijobs/status 82 | verbs: 83 | - get 84 | - list 85 | - watch 86 | --- 87 | apiVersion: rbac.authorization.k8s.io/v1 88 | kind: ClusterRole 89 | metadata: 90 | name: {{ include "mpi-operator.fullname" . }} 91 | labels: 92 | app: mpi-operator 93 | app.kubernetes.io/component: mpijob 94 | kustomize.component: mpi-operator 95 | {{- include "mpi-operator.labels" . 
| nindent 4 }} 96 | rules: 97 | - apiGroups: 98 | - "" 99 | resources: 100 | - configmaps 101 | - secrets 102 | - services 103 | verbs: 104 | - create 105 | - list 106 | - watch 107 | - update 108 | - apiGroups: 109 | - "" 110 | resources: 111 | - pods 112 | verbs: 113 | - create 114 | - get 115 | - list 116 | - watch 117 | - delete 118 | - update 119 | - patch 120 | - apiGroups: 121 | - "" 122 | resources: 123 | - pods/exec 124 | verbs: 125 | - create 126 | - apiGroups: 127 | - "" 128 | resources: 129 | - endpoints 130 | verbs: 131 | - create 132 | - get 133 | - update 134 | - apiGroups: 135 | - "" 136 | resources: 137 | - events 138 | verbs: 139 | - create 140 | - patch 141 | - apiGroups: 142 | - batch 143 | resources: 144 | - jobs 145 | verbs: 146 | - create 147 | - list 148 | - update 149 | - watch 150 | - apiGroups: 151 | - apiextensions.k8s.io 152 | resources: 153 | - customresourcedefinitions 154 | verbs: 155 | - create 156 | - get 157 | - apiGroups: 158 | - kubeflow.org 159 | resources: 160 | - mpijobs 161 | - mpijobs/finalizers 162 | - mpijobs/status 163 | verbs: 164 | - '*' 165 | - apiGroups: 166 | - coordination.k8s.io 167 | resources: 168 | - leases 169 | verbs: 170 | - '*' 171 | - apiGroups: 172 | - scheduling.incubator.k8s.io 173 | - scheduling.sigs.dev 174 | - scheduling.volcano.sh 175 | resources: 176 | - queues 177 | - podgroups 178 | verbs: 179 | - '*' 180 | - apiGroups: 181 | - scheduling.x-k8s.io 182 | resources: 183 | - podgroups 184 | verbs: 185 | - '*' 186 | - apiGroups: 187 | - scheduling.k8s.io 188 | resources: 189 | - priorityclasses 190 | verbs: 191 | - get 192 | - list 193 | - watch 194 | --- 195 | apiVersion: rbac.authorization.k8s.io/v1 196 | kind: ClusterRoleBinding 197 | metadata: 198 | name: {{ include "mpi-operator.fullname" . }} 199 | labels: 200 | app: mpi-operator 201 | app.kubernetes.io/component: mpijob 202 | kustomize.component: mpi-operator 203 | {{- include "mpi-operator.labels" . 
| nindent 4 }} 204 | roleRef: 205 | apiGroup: rbac.authorization.k8s.io 206 | kind: ClusterRole 207 | name: '{{ include "mpi-operator.fullname" . }}' 208 | subjects: 209 | - kind: ServiceAccount 210 | name: '{{ include "mpi-operator.fullname" . }}' 211 | namespace: '{{ .Release.Namespace }}' -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/mpi-operator/values.yaml: -------------------------------------------------------------------------------- 1 | mpiOperator: 2 | additionalArgs: 3 | - -alsologtostderr 4 | # Increase default logging, default is very terse 5 | - -stderrthreshold=INFO 6 | - -v=4 7 | image: 8 | repository: mpioperator/mpi-operator 9 | tag: "0.5" 10 | replicas: 1 11 | serviceAccount: 12 | annotations: {} 13 | 14 | ## Node labels for pod assignment 15 | ## Ref: https://kubernetes.io/docs/user-guide/node-selection/ 16 | nodeSelector: {} 17 | 18 | ## Affinity settings for pod assignment 19 | ## Ref: https://kubernetes.io/docs/concepts/configuration/assign-pod-node/ 20 | affinity: {} 21 | 22 | ## Tolerations for pod assignment 23 | ## Ref: https://kubernetes.io/docs/concepts/configuration/taint-and-toleration/ 24 | tolerations: 25 | - key: sagemaker.amazonaws.com/node-health-status 26 | operator: "Equal" 27 | value: "Unschedulable" 28 | effect: "NoSchedule" 29 | 30 | ## Topology spread constraints for pod assignment 31 | ## Ref: https://kubernetes.io/docs/concepts/workloads/pods/pod-topology-spread-constraints/ 32 | topologySpreadConstraints: [] 33 | 34 | ## Image Pull Policy for operator deployment 35 | imagePullPolicy: IfNotPresent 36 | 37 | ## Apply extra labels to all created resources 38 | extraLabels: {} -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/namespaced-role-and-bindings/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: 
namespaced-role-and-bindings 3 | version: 0.1.0 4 | appVersion: 1.0 5 | description: A Helm chart for setting up Hyperpod role and role bindings for cluster role in a namespace -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/namespaced-role-and-bindings/templates/namespace-level-role.yaml: -------------------------------------------------------------------------------- 1 | kind: Role 2 | apiVersion: rbac.authorization.k8s.io/v1 3 | metadata: 4 | namespace: {{ .Values.namespace }} 5 | name: {{ .Values.roleName }} 6 | ### 7 | # 1) add/list/describe/delete pods 8 | # 2) get/list/watch/create/patch/update/delete/describe kubeflow pytorch job 9 | # 3) get pod log 10 | ### 11 | rules: 12 | - apiGroups: [""] 13 | resources: ["pods"] 14 | verbs: ["create", "get"] 15 | - apiGroups: [""] 16 | resources: ["nodes"] 17 | verbs: ["get", "list"] 18 | - apiGroups: [""] 19 | resources: ["pods/log"] 20 | verbs: ["get", "list"] 21 | - apiGroups: [""] 22 | resources: ["pods/exec"] 23 | verbs: ["get", "create"] 24 | - apiGroups: ["kubeflow.org"] 25 | resources: ["pytorchjobs", "pytorchjobs/status"] 26 | verbs: ["get", "list", "create", "delete", "update", "describe"] 27 | - apiGroups: [""] 28 | resources: ["configmaps"] 29 | verbs: ["create", "update", "get", "delete", "list"] 30 | - apiGroups: [""] 31 | resources: ["secrets"] 32 | verbs: ["create", "get", "list", "delete"] 33 | --- 34 | apiVersion: rbac.authorization.k8s.io/v1 35 | kind: RoleBinding 36 | metadata: 37 | namespace: kubeflow 38 | name: {{ .Values.roleName }}-binding 39 | subjects: 40 | - kind: Group 41 | name: {{ .Values.roleName }}-{{ .Values.namespace }}-level 42 | apiGroup: rbac.authorization.k8s.io 43 | roleRef: 44 | kind: Role 45 | name: {{ .Values.roleName }} # this must match the name of the Role or ClusterRole you wish to bind to 46 | apiGroup: rbac.authorization.k8s.io 
-------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/namespaced-role-and-bindings/values.yaml: -------------------------------------------------------------------------------- 1 | namespace: "kubeflow" 2 | roleName: "hyperpod-scientist-user-namespace-level-role" -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/neuron-device-plugin/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | name: neuron-device-plugin 3 | description: A Helm chart for Neuron device plugin. 4 | version: v0.1.0 5 | appVersion: "v2.19.16" 6 | home: https://github.com/aws-neuron/ 7 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/neuron-device-plugin/templates/k8s-neuron-device-plugin-rbac.yaml: -------------------------------------------------------------------------------- 1 | # rbac.yaml 2 | --- 3 | kind: ClusterRole 4 | apiVersion: rbac.authorization.k8s.io/v1 5 | metadata: 6 | name: neuron-device-plugin 7 | rules: 8 | - apiGroups: 9 | - "" 10 | resources: 11 | - nodes 12 | verbs: 13 | - get 14 | - list 15 | - watch 16 | - apiGroups: 17 | - "" 18 | resources: 19 | - events 20 | verbs: 21 | - create 22 | - patch 23 | - apiGroups: 24 | - "" 25 | resources: 26 | - pods 27 | verbs: 28 | - update 29 | - patch 30 | - get 31 | - list 32 | - watch 33 | - apiGroups: 34 | - "" 35 | resources: 36 | - nodes/status 37 | verbs: 38 | - patch 39 | - update 40 | --- 41 | apiVersion: v1 42 | kind: ServiceAccount 43 | metadata: 44 | name: neuron-device-plugin 45 | namespace: {{ .Values.namespace }} 46 | --- 47 | kind: ClusterRoleBinding 48 | apiVersion: rbac.authorization.k8s.io/v1 49 | metadata: 50 | name: neuron-device-plugin 51 | namespace: {{ .Values.namespace }} 52 | roleRef: 53 | apiGroup: rbac.authorization.k8s.io 54 | kind: ClusterRole 55 
| name: neuron-device-plugin 56 | subjects: 57 | - kind: ServiceAccount 58 | name: neuron-device-plugin 59 | namespace: {{ .Values.namespace }} -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/neuron-device-plugin/templates/k8s-neuron-device-plugin.yaml: -------------------------------------------------------------------------------- 1 | # https://kubernetes.io/docs/concepts/extend-kubernetes/compute-storage-net/device-plugins/ 2 | apiVersion: apps/v1 3 | kind: DaemonSet 4 | metadata: 5 | name: neuron-device-plugin-daemonset 6 | namespace: {{ .Values.namespace }} 7 | spec: 8 | selector: 9 | matchLabels: 10 | name: neuron-device-plugin-ds 11 | updateStrategy: 12 | type: RollingUpdate 13 | template: 14 | metadata: 15 | # Uncomment the annotation below if k8s version is 1.13 or lower 16 | # annotations: 17 | # scheduler.alpha.kubernetes.io/critical-pod: "" 18 | labels: 19 | name: neuron-device-plugin-ds 20 | spec: 21 | serviceAccount: neuron-device-plugin 22 | tolerations: 23 | - key: CriticalAddonsOnly 24 | operator: Exists 25 | - key: aws.amazon.com/neuron 26 | operator: Exists 27 | effect: NoSchedule 28 | - key: sagemaker.amazonaws.com/node-health-status 29 | operator: Equal 30 | value: Unschedulable 31 | effect: NoSchedule 32 | # Mark this pod as a critical add-on; when enabled, the critical add-on 33 | # scheduler reserves resources for critical add-on pods so that they can 34 | # be rescheduled after a failure. 
35 | # See https://kubernetes.io/docs/tasks/administer-cluster/guaranteed-scheduling-critical-addon-pods/ 36 | priorityClassName: "system-node-critical" 37 | affinity: 38 | nodeAffinity: 39 | requiredDuringSchedulingIgnoredDuringExecution: 40 | nodeSelectorTerms: 41 | - matchExpressions: 42 | - key: "node.kubernetes.io/instance-type" 43 | operator: In 44 | values: 45 | - inf1.xlarge 46 | - inf1.2xlarge 47 | - inf1.6xlarge 48 | - inf1.24xlarge 49 | - inf2.xlarge 50 | - inf2.8xlarge 51 | - inf2.24xlarge 52 | - inf2.48xlarge 53 | - trn1.2xlarge 54 | - trn1.32xlarge 55 | - trn1n.32xlarge 56 | - ml.inf2.xlarge 57 | - ml.inf2.8xlarge 58 | - ml.inf2.24xlarge 59 | - ml.inf2.48xlarge 60 | - ml.trn1.2xlarge 61 | - ml.trn1.32xlarge 62 | - ml.trn1n.32xlarge 63 | - ml.trn2.48xlarge 64 | containers: 65 | # Find all neuron-device-plugin images at https://gallery.ecr.aws/neuron/neuron-device-plugin 66 | - image: public.ecr.aws/neuron/neuron-device-plugin:2.19.16.0 67 | imagePullPolicy: Always 68 | name: neuron-device-plugin 69 | env: 70 | - name: KUBECONFIG 71 | value: /etc/kubernetes/kubelet.conf 72 | - name: NODE_NAME 73 | valueFrom: 74 | fieldRef: 75 | fieldPath: spec.nodeName 76 | securityContext: 77 | allowPrivilegeEscalation: false 78 | capabilities: 79 | drop: ["ALL"] 80 | volumeMounts: 81 | - name: device-plugin 82 | mountPath: /var/lib/kubelet/device-plugins 83 | - name: infa-map 84 | mountPath: /run 85 | volumes: 86 | - name: device-plugin 87 | hostPath: 88 | path: /var/lib/kubelet/device-plugins 89 | - name: infa-map 90 | hostPath: 91 | path: /run 92 | 93 | 94 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/neuron-device-plugin/values.yaml: -------------------------------------------------------------------------------- 1 | namespace: "kube-system" -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/storage/Chart.yaml: 
-------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: storage 3 | version: 0.1.0 4 | appVersion: 1.0 5 | description: A Helm chart for deploying storage configuration in K8 6 | keywords: 7 | - Fsx 8 | - EBS 9 | 10 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/storage/templates/persistent-volume-claim.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolumeClaim 3 | metadata: 4 | name: {{ .Values.persistentVolumeClaim.name }} 5 | spec: 6 | accessModes: 7 | {{- range .Values.persistentVolumeClaim.accessModes }} 8 | - {{ . }} 9 | {{- end }} 10 | storageClassName: {{ .Values.persistentVolumeClaim.storageClassName }} 11 | resources: 12 | requests: 13 | storage: {{ .Values.persistentVolumeClaim.requests.storage }} -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/storage/templates/persistent-volume.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: PersistentVolume 3 | metadata: 4 | name: {{ .Values.persistentVolume.name }} 5 | spec: 6 | capacity: 7 | storage: {{ .Values.persistentVolume.capacity }} 8 | volumeMode: {{ .Values.persistentVolume.volumeMode }} 9 | accessModes: 10 | {{- range .Values.persistentVolume.accessModes }} 11 | - {{ . }} 12 | {{- end }} 13 | mountOptions: 14 | {{- range .Values.persistentVolume.mountOptions }} 15 | - {{ . 
}} 16 | {{- end }} 17 | persistentVolumeReclaimPolicy: {{ .Values.persistentVolume.reclaimPolicy }} 18 | csi: 19 | driver: {{ .Values.persistentVolume.csi.driver }} 20 | volumeHandle: {{ .Values.persistentVolume.csi.volumeHandle }} 21 | volumeAttributes: 22 | dnsname: {{ .Values.persistentVolume.csi.volumeAttributes.dnsname }} 23 | mountname: {{ .Values.persistentVolume.csi.volumeAttributes.mountname }} -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/storage/values.yaml: -------------------------------------------------------------------------------- 1 | persistentVolume: 2 | name: fsx-pv 3 | capacity: 1200Gi 4 | volumeMode: Filesystem 5 | accessModes: 6 | - ReadWriteMany 7 | mountOptions: 8 | - flock 9 | reclaimPolicy: Retain 10 | csi: 11 | driver: fsx.csi.aws.com 12 | volumeHandle: ??? 13 | volumeAttributes: 14 | dnsname: ??? 15 | mountname: fsx 16 | 17 | persistentVolumeClaim: 18 | name: fsx-claim 19 | storageClassName: fsx-sc 20 | accessModes: 21 | - ReadWriteMany 22 | requests: 23 | storage: 1200Gi -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/team-role-and-bindings/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: team-role-and-bindings 3 | version: 0.1.0 4 | appVersion: 1.0 5 | description: This chart installs the namespaced and cluster roles and bindings for team members -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/team-role-and-bindings/templates/team-level-cluster-role.yaml: -------------------------------------------------------------------------------- 1 | {{- $target_id := .Values.computeQuotaTarget.targetId | required ".Values.computeQuotaTarget.targetId is required."
-}} 2 | kind: ClusterRole 3 | apiVersion: rbac.authorization.k8s.io/v1 4 | metadata: 5 | name: hyperpod-ns-{{ $target_id }}-cluster-role 6 | rules: 7 | - apiGroups: [ "" ] 8 | resources: [ "namespaces" ] 9 | verbs: ["list", "get"] 10 | - apiGroups: [ "kueue.x-k8s.io" ] 11 | resources: [ "clusterqueues" ] 12 | resourceNames: ["hyperpod-ns-{{ $target_id }}-clusterqueue"] 13 | verbs: [ "list", "get", "watch" ] 14 | - apiGroups: [ "kueue.x-k8s.io" ] 15 | resources: [ "resourceflavors", "workloadpriorityclasses" ] 16 | verbs: [ "list", "get", "watch" ] 17 | --- 18 | apiVersion: rbac.authorization.k8s.io/v1 19 | kind: ClusterRoleBinding 20 | metadata: 21 | name: hyperpod-ns-{{ $target_id }}-cluster-role-binding 22 | subjects: 23 | - kind: Group 24 | name: hyperpod-ns-{{ $target_id }}-namespace-level-group 25 | apiGroup: rbac.authorization.k8s.io 26 | roleRef: 27 | kind: ClusterRole 28 | name: hyperpod-ns-{{ $target_id }}-cluster-role 29 | apiGroup: rbac.authorization.k8s.io -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/team-role-and-bindings/templates/team-level-namespaced-role.yaml: -------------------------------------------------------------------------------- 1 | {{- $target_id := .Values.computeQuotaTarget.targetId | required ".Values.computeQuotaTarget.targetId is required."
-}} 2 | kind: Role 3 | apiVersion: rbac.authorization.k8s.io/v1 4 | metadata: 5 | namespace: hyperpod-ns-{{ $target_id }} 6 | name: hyperpod-ns-{{ $target_id }}-namespace-role 7 | rules: 8 | - apiGroups: [""] 9 | resources: ["namespaces"] 10 | verbs: ["get", "list"] 11 | - apiGroups: ["batch"] 12 | resources: ["jobs"] 13 | verbs: ["get", "list", "watch", "create", "update", "patch", "delete"] 14 | - apiGroups: [""] 15 | resources: ["pods"] 16 | verbs: ["create", "get", "watch", "list"] 17 | - apiGroups: [""] 18 | resources: ["nodes", "nodes/status", "pods/status"] 19 | verbs: ["get", "list", "watch"] 20 | - apiGroups: [""] 21 | resources: ["pods/log"] 22 | verbs: ["get", "list"] 23 | - apiGroups: ["kubeflow.org"] 24 | resources: ["pytorchjobs", "pytorchjobs/status"] 25 | verbs: ["get", "list", "watch", "create", "delete", "update"] 26 | - apiGroups: [""] 27 | resources: ["serviceaccounts"] 28 | verbs: ["get", "list"] 29 | - apiGroups: [""] 30 | resources: ["configmaps"] 31 | verbs: ["create", "update", "get", "delete"] 32 | - apiGroups: [""] 33 | resources: ["secrets"] 34 | verbs: ["create", "get", "list", "delete"] 35 | - apiGroups: [ "apiextensions.k8s.io" ] 36 | resources: [ "customresourcedefinitions" ] 37 | verbs: ["list", "get"] 38 | - apiGroups: [ "kueue.x-k8s.io" ] 39 | resources: [ "workloads", "workloads/status" ] 40 | verbs: ["get", "list", "watch", "patch"] 41 | --- 42 | apiVersion: rbac.authorization.k8s.io/v1 43 | kind: RoleBinding 44 | metadata: 45 | namespace: hyperpod-ns-{{ $target_id }} 46 | name: hyperpod-ns-{{ $target_id }}-namespace-role-binding 47 | subjects: 48 | - kind: Group 49 | name: hyperpod-ns-{{ $target_id }}-namespace-level-group 50 | apiGroup: rbac.authorization.k8s.io 51 | roleRef: 52 | kind: Role 53 | name: hyperpod-ns-{{ $target_id }}-namespace-role 54 | apiGroup: rbac.authorization.k8s.io 55 | --------------------------------------------------------------------------------
/helm_chart/HyperPodHelmChart/charts/team-role-and-bindings/values.yaml: -------------------------------------------------------------------------------- 1 | computeQuotaTarget: 2 | targetId: null -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/training-operators/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: training-operators 3 | version: 0.1.0 4 | appVersion: v1.7.0 5 | description: A Helm chart for deploying Kubeflow Training Operators 6 | keywords: 7 | - kubeflow 8 | - training 9 | - operators 10 | - pytorchjob -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/training-operators/templates/ClusterRole/kubeflow-training-admin-ClusterRole.yaml: -------------------------------------------------------------------------------- 1 | aggregationRule: 2 | clusterRoleSelectors: 3 | - matchLabels: 4 | rbac.authorization.kubeflow.org/aggregate-to-kubeflow-training-admin: 'true' 5 | apiVersion: rbac.authorization.k8s.io/v1 6 | kind: ClusterRole 7 | metadata: 8 | labels: 9 | rbac.authorization.kubeflow.org/aggregate-to-kubeflow-admin: 'true' 10 | name: kubeflow-training-admin 11 | rules: [] 12 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/training-operators/templates/ClusterRole/kubeflow-training-edit-ClusterRole.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRole 3 | metadata: 4 | labels: 5 | rbac.authorization.kubeflow.org/aggregate-to-kubeflow-edit: 'true' 6 | rbac.authorization.kubeflow.org/aggregate-to-kubeflow-training-admin: 'true' 7 | name: kubeflow-training-edit 8 | rules: 9 | - apiGroups: 10 | - kubeflow.org 11 | resources: 12 | - mpijobs 13 | - tfjobs 14 | - pytorchjobs 15 | - 
mxjobs 16 | - xgboostjobs 17 | - paddlejobs 18 | verbs: 19 | - create 20 | - delete 21 | - get 22 | - list 23 | - patch 24 | - update 25 | - watch 26 | - apiGroups: 27 | - kubeflow.org 28 | resources: 29 | - mpijobs/status 30 | - tfjobs/status 31 | - pytorchjobs/status 32 | - mxjobs/status 33 | - xgboostjobs/status 34 | - paddlejobs/status 35 | verbs: 36 | - get 37 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/training-operators/templates/ClusterRole/kubeflow-training-view-ClusterRole.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRole 3 | metadata: 4 | labels: 5 | rbac.authorization.kubeflow.org/aggregate-to-kubeflow-view: 'true' 6 | name: kubeflow-training-view 7 | rules: 8 | - apiGroups: 9 | - kubeflow.org 10 | resources: 11 | - mpijobs 12 | - tfjobs 13 | - pytorchjobs 14 | - mxjobs 15 | - xgboostjobs 16 | - paddlejobs 17 | verbs: 18 | - get 19 | - list 20 | - watch 21 | - apiGroups: 22 | - kubeflow.org 23 | resources: 24 | - mpijobs/status 25 | - tfjobs/status 26 | - pytorchjobs/status 27 | - mxjobs/status 28 | - xgboostjobs/status 29 | - paddlejobs/status 30 | verbs: 31 | - get 32 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/training-operators/templates/ClusterRole/training-operator-ClusterRole.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRole 3 | metadata: 4 | labels: 5 | app: training-operator 6 | name: training-operator 7 | rules: 8 | - apiGroups: 9 | - kubeflow.org 10 | resources: 11 | - mpijobs 12 | - tfjobs 13 | - mxjobs 14 | - pytorchjobs 15 | - xgboostjobs 16 | - paddlejobs 17 | - mpijobs/status 18 | - tfjobs/status 19 | - pytorchjobs/status 20 | - mxjobs/status 21 | - xgboostjobs/status 22 | - 
paddlejobs/status 23 | - mpijobs/finalizers 24 | - tfjobs/finalizers 25 | - pytorchjobs/finalizers 26 | - mxjobs/finalizers 27 | - xgboostjobs/finalizers 28 | - paddlejobs/finalizers 29 | verbs: 30 | - create 31 | - delete 32 | - get 33 | - list 34 | - patch 35 | - update 36 | - watch 37 | - apiGroups: 38 | - '' 39 | resources: 40 | - pods 41 | - services 42 | - endpoints 43 | - events 44 | verbs: 45 | - '*' 46 | - apiGroups: 47 | - apps 48 | - extensions 49 | resources: 50 | - deployments 51 | verbs: 52 | - '*' 53 | - apiGroups: 54 | - '' 55 | resources: 56 | - pods/exec 57 | verbs: 58 | - create 59 | - apiGroups: 60 | - rbac.authorization.k8s.io 61 | resources: 62 | - roles 63 | - rolebindings 64 | verbs: 65 | - create 66 | - list 67 | - watch 68 | - update 69 | - apiGroups: 70 | - '' 71 | resources: 72 | - configmaps 73 | - secrets 74 | - serviceaccounts 75 | verbs: 76 | - create 77 | - list 78 | - watch 79 | - update 80 | - apiGroups: 81 | - scheduling.volcano.sh 82 | resources: 83 | - podgroups 84 | verbs: 85 | - '*' 86 | - apiGroups: 87 | - autoscaling 88 | resources: 89 | - horizontalpodautoscalers 90 | verbs: 91 | - '*' 92 | - apiGroups: 93 | - scheduling.sigs.k8s.io 94 | resources: 95 | - podgroups 96 | verbs: 97 | - '*' 98 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/training-operators/templates/ClusterRoleBinding/training-operator-ClusterRoleBinding.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRoleBinding 3 | metadata: 4 | labels: 5 | app: training-operator 6 | name: training-operator 7 | roleRef: 8 | apiGroup: rbac.authorization.k8s.io 9 | kind: ClusterRole 10 | name: training-operator 11 | subjects: 12 | - kind: ServiceAccount 13 | name: training-operator 14 | namespace: kubeflow 15 | -------------------------------------------------------------------------------- 
/helm_chart/HyperPodHelmChart/charts/training-operators/templates/Deployment/training-operator-kubeflow-Deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | labels: 5 | control-plane: kubeflow-training-operator 6 | name: {{ include "training-operator.fullname" . }} 7 | namespace: kubeflow 8 | spec: 9 | replicas: 1 10 | selector: 11 | matchLabels: 12 | control-plane: kubeflow-training-operator 13 | template: 14 | metadata: 15 | annotations: 16 | sidecar.istio.io/inject: 'false' 17 | labels: 18 | control-plane: kubeflow-training-operator 19 | spec: 20 | containers: 21 | - command: 22 | - /manager 23 | {{- if .Values.schemes }} 24 | {{- range .Values.schemes }} 25 | - --enable-scheme={{ . }} 26 | {{- end }} 27 | {{- end }} 28 | env: 29 | - name: MY_POD_NAMESPACE 30 | valueFrom: 31 | fieldRef: 32 | fieldPath: metadata.namespace 33 | - name: MY_POD_NAME 34 | valueFrom: 35 | fieldRef: 36 | fieldPath: metadata.name 37 | image: {{ .Values.image.repository }}:{{ .Values.image.tag }} 38 | livenessProbe: 39 | httpGet: 40 | path: /healthz 41 | port: 8081 42 | initialDelaySeconds: 15 43 | periodSeconds: 20 44 | timeoutSeconds: 3 45 | name: training-operator 46 | ports: 47 | - containerPort: 8080 48 | readinessProbe: 49 | httpGet: 50 | path: /readyz 51 | port: 8081 52 | initialDelaySeconds: 10 53 | periodSeconds: 15 54 | timeoutSeconds: 3 55 | securityContext: 56 | allowPrivilegeEscalation: false 57 | serviceAccountName: training-operator 58 | terminationGracePeriodSeconds: 10 -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/training-operators/templates/Service/training-operator-kubeflow-Service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | annotations: 5 | prometheus.io/path: /metrics 6 | prometheus.io/port: 
'8080' 7 | prometheus.io/scrape: 'true' 8 | labels: 9 | app: training-operator 10 | name: training-operator 11 | namespace: kubeflow 12 | spec: 13 | ports: 14 | - name: monitoring-port 15 | port: 8080 16 | targetPort: 8080 17 | selector: 18 | control-plane: kubeflow-training-operator 19 | type: ClusterIP 20 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/training-operators/templates/ServiceAccount/training-operator-kubeflow-ServiceAccount.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ServiceAccount 3 | metadata: 4 | labels: 5 | app: training-operator 6 | name: training-operator 7 | namespace: kubeflow 8 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/training-operators/templates/_helpers.tpl: -------------------------------------------------------------------------------- 1 | {{/* 2 | Expand the name of the chart. 3 | */}} 4 | {{- define "training-operator.name" -}} 5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} 6 | {{- end }} 7 | 8 | {{/* 9 | Create a default fully qualified app name. 10 | We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). 11 | If release name contains chart name it will be used as a full name. 12 | */}} 13 | {{- define "training-operator.fullname" -}} 14 | {{- if .Values.fullnameOverride }} 15 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} 16 | {{- else }} 17 | {{- $name := default .Chart.Name .Values.nameOverride }} 18 | {{- if contains $name .Release.Name }} 19 | {{- .Release.Name | trunc 63 | trimSuffix "-" }} 20 | {{- else }} 21 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} 22 | {{- end }} 23 | {{- end }} 24 | {{- end }} 25 | 26 | {{/* 27 | Create chart name and version as used by the chart label. 
28 | */}} 29 | {{- define "training-operator.chart" -}} 30 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} 31 | {{- end }} 32 | 33 | {{/* 34 | Common labels 35 | */}} 36 | {{- define "training-operator.labels" -}} 37 | helm.sh/chart: {{ include "training-operator.chart" . }} 38 | {{ include "training-operator.selectorLabels" . }} 39 | {{- if .Chart.AppVersion }} 40 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} 41 | {{- end }} 42 | app.kubernetes.io/managed-by: {{ .Release.Service }} 43 | {{- end }} 44 | 45 | {{/* 46 | Selector labels 47 | */}} 48 | {{- define "training-operator.selectorLabels" -}} 49 | app.kubernetes.io/name: {{ include "training-operator.name" . }} 50 | app.kubernetes.io/instance: {{ .Release.Name }} 51 | {{- end }} 52 | 53 | {{/* 54 | Create the name of the service account to use 55 | */}} 56 | {{- define "training-operator.serviceAccountName" -}} 57 | {{- if .Values.serviceAccount.create }} 58 | {{- default (include "training-operator.fullname" .) 
.Values.serviceAccount.name }} 59 | {{- else }} 60 | {{- default "default" .Values.serviceAccount.name }} 61 | {{- end }} 62 | {{- end }} 63 | -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/training-operators/templates/kubeflow-namespace.yaml: -------------------------------------------------------------------------------- 1 | {{- if .Values.createNamespace -}} 2 | --- 3 | apiVersion: v1 4 | kind: Namespace 5 | metadata: 6 | name: kubeflow 7 | {{- end -}} -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/charts/training-operators/values.yaml: -------------------------------------------------------------------------------- 1 | ## Training operator container repo and tag 2 | image: 3 | repository: kubeflow/training-operator 4 | tag: "v1-855e096" 5 | createNamespace: true 6 | schemes: 7 | - pytorchjob 8 | - tfjob 9 | - mxjob 10 | - xgboostjob 11 | - paddlejob 12 | # disabled to avoid the conflicting mpijobs.kubeflow.org with mpi-operator 13 | # - mpijob -------------------------------------------------------------------------------- /helm_chart/HyperPodHelmChart/values.yaml: -------------------------------------------------------------------------------- 1 | # Default values for HyperPodHelmChart. 2 | # This is a YAML-formatted file. 3 | # Declare variables to be passed into your templates. 4 | 5 | replicaCount: 1 6 | 7 | image: 8 | repository: nginx 9 | pullPolicy: IfNotPresent 10 | # Overrides the image tag whose default is the chart appVersion. 11 | tag: "" 12 | 13 | imagePullSecrets: [] 14 | nameOverride: "" 15 | fullnameOverride: "" 16 | 17 | serviceAccount: 18 | # Specifies whether a service account should be created 19 | create: true 20 | # Automatically mount a ServiceAccount's API credentials? 21 | automount: true 22 | # Annotations to add to the service account 23 | annotations: {} 24 | # The name of the service account to use. 
25 | # If not set and create is true, a name is generated using the fullname template 26 | name: "" 27 | 28 | podAnnotations: {} 29 | podLabels: {} 30 | 31 | podSecurityContext: {} 32 | # fsGroup: 2000 33 | 34 | securityContext: {} 35 | # capabilities: 36 | # drop: 37 | # - ALL 38 | # readOnlyRootFilesystem: true 39 | # runAsNonRoot: true 40 | # runAsUser: 1000 41 | 42 | service: 43 | type: ClusterIP 44 | port: 80 45 | 46 | ingress: 47 | enabled: false 48 | className: "" 49 | annotations: {} 50 | # kubernetes.io/ingress.class: nginx 51 | # kubernetes.io/tls-acme: "true" 52 | hosts: 53 | - host: chart-example.local 54 | paths: 55 | - path: / 56 | pathType: ImplementationSpecific 57 | tls: [] 58 | # - secretName: chart-example-tls 59 | # hosts: 60 | # - chart-example.local 61 | 62 | resources: {} 63 | # We usually recommend not to specify default resources and to leave this as a conscious 64 | # choice for the user. This also increases chances charts run on environments with little 65 | # resources, such as Minikube. If you do want to specify resources, uncomment the following 66 | # lines, adjust them as necessary, and remove the curly braces after 'resources:'. 67 | # limits: 68 | # cpu: 100m 69 | # memory: 128Mi 70 | # requests: 71 | # cpu: 100m 72 | # memory: 128Mi 73 | 74 | livenessProbe: 75 | httpGet: 76 | path: / 77 | port: http 78 | readinessProbe: 79 | httpGet: 80 | path: / 81 | port: http 82 | 83 | autoscaling: 84 | enabled: false 85 | minReplicas: 1 86 | maxReplicas: 100 87 | targetCPUUtilizationPercentage: 80 88 | # targetMemoryUtilizationPercentage: 80 89 | 90 | # Additional volumes on the output Deployment definition. 91 | volumes: [] 92 | # - name: foo 93 | # secret: 94 | # secretName: mysecret 95 | # optional: false 96 | 97 | # Additional volumeMounts on the output Deployment definition. 
98 | volumeMounts: [] 99 | # - name: foo 100 | # mountPath: "/etc/foo" 101 | # readOnly: true 102 | 103 | nodeSelector: {} 104 | 105 | tolerations: [] 106 | 107 | affinity: {} 108 | 109 | namespace: 110 | create: true 111 | name: aws-hyperpod 112 | 113 | mlflow: 114 | enabled: false 115 | 116 | trainingOperators: 117 | enabled: true 118 | 119 | storage: 120 | enabled: false 121 | 122 | cluster-role-and-bindings: 123 | enabled: false 124 | 125 | namespaced-role-and-bindings: 126 | enabled: false 127 | 128 | team-role-and-bindings: 129 | enabled: false 130 | 131 | nvidia-device-plugin: 132 | devicePlugin: 133 | enabled: true 134 | allowDefaultNamespace: true 135 | namespaceOverride: "kube-system" 136 | affinity: 137 | nodeAffinity: 138 | requiredDuringSchedulingIgnoredDuringExecution: 139 | nodeSelectorTerms: 140 | - matchExpressions: 141 | - key: node.kubernetes.io/instance-type 142 | operator: In 143 | values: 144 | - ml.g5.xlarge 145 | - ml.g5.2xlarge 146 | - ml.g5.4xlarge 147 | - ml.g5.8xlarge 148 | - ml.g5.12xlarge 149 | - ml.g5.16xlarge 150 | - ml.g5.24xlarge 151 | - ml.g5.48xlarge 152 | - ml.g6.xlarge 153 | - ml.g6.2xlarge 154 | - ml.g6.4xlarge 155 | - ml.g6.8xlarge 156 | - ml.g6.16xlarge 157 | - ml.g6.12xlarge 158 | - ml.g6.24xlarge 159 | - ml.g6.48xlarge 160 | - ml.g6e.xlarge 161 | - ml.g6e.2xlarge 162 | - ml.g6e.4xlarge 163 | - ml.g6e.8xlarge 164 | - ml.g6e.12xlarge 165 | - ml.g6e.16xlarge 166 | - ml.g6e.24xlarge 167 | - ml.g6e.48xlarge 168 | - ml.gr6.4xlarge 169 | - ml.gr6.8xlarge 170 | - ml.p4d.24xlarge 171 | - ml.p4de.24xlarge 172 | - ml.p5.48xlarge 173 | - ml.p5e.48xlarge 174 | - ml.p5en.48xlarge 175 | tolerations: 176 | - key: nvidia.com/gpu 177 | operator: Exists 178 | effect: NoSchedule 179 | - key: sagemaker.amazonaws.com/node-health-status 180 | operator: Equal 181 | value: Unschedulable 182 | effect: NoSchedule 183 | 184 | neuron-device-plugin: 185 | devicePlugin: 186 | enabled: true 187 | 188 | aws-efa-k8s-device-plugin: 189 | devicePlugin: 190 | 
enabled: true 191 | supportedInstanceLabels: 192 | values: 193 | - ml.c5n.9xlarge 194 | - ml.c5n.18xlarge 195 | - ml.g5.8xlarge 196 | - ml.g5.12xlarge 197 | - ml.g5.16xlarge 198 | - ml.g5.24xlarge 199 | - ml.g5.48xlarge 200 | - ml.g6.8xlarge 201 | - ml.g6.12xlarge 202 | - ml.g6.16xlarge 203 | - ml.g6.24xlarge 204 | - ml.g6.48xlarge 205 | - ml.g6e.8xlarge 206 | - ml.g6e.12xlarge 207 | - ml.g6e.16xlarge 208 | - ml.g6e.24xlarge 209 | - ml.g6e.48xlarge 210 | - ml.gr6.8xlarge 211 | - ml.i3en.large 212 | - ml.i3en.xlarge 213 | - ml.i3en.2xlarge 214 | - ml.i3en.3xlarge 215 | - ml.i3en.6xlarge 216 | - ml.i3en.12xlarge 217 | - ml.i3en.24xlarge 218 | - ml.m7i.large 219 | - ml.m7i.xlarge 220 | - ml.m7i.2xlarge 221 | - ml.m7i.4xlarge 222 | - ml.m7i.8xlarge 223 | - ml.m7i.12xlarge 224 | - ml.m7i.16xlarge 225 | - ml.m7i.24xlarge 226 | - ml.m7i.48xlarge 227 | - ml.p4d.24xlarge 228 | - ml.p4de.24xlarge 229 | - ml.p5.48xlarge 230 | - ml.p5e.48xlarge 231 | - ml.p5en.48xlarge 232 | - ml.r7i.large 233 | - ml.r7i.xlarge 234 | - ml.r7i.2xlarge 235 | - ml.r7i.4xlarge 236 | - ml.r7i.8xlarge 237 | - ml.r7i.12xlarge 238 | - ml.r7i.16xlarge 239 | - ml.r7i.24xlarge 240 | - ml.r7i.48xlarge 241 | - ml.trn1.32xlarge 242 | - ml.trn1n.32xlarge 243 | - ml.trn2.48xlarge 244 | tolerations: 245 | - key: CriticalAddonsOnly 246 | operator: Exists 247 | - effect: NoSchedule 248 | key: aws.amazon.com/efa 249 | operator: Exists 250 | - key: sagemaker.amazonaws.com/node-health-status 251 | operator: "Equal" 252 | value: "Unschedulable" 253 | effect: "NoSchedule" 254 | 255 | mpi-operator: 256 | enabled: true 257 | health-monitoring-agent: 258 | enabled: true 259 | deep-health-check: 260 | enabled: true 261 | job-auto-restart: 262 | enabled: true 263 | hyperpod-patching: 264 | enabled: true 265 | -------------------------------------------------------------------------------- /helm_chart/install_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 
| 3 | # Function to check the status of the last command and exit if it failed 4 | check_status() { 5 | if [ $? -ne 0 ]; then 6 | echo "An error occurred during the installation. Exiting." 7 | exit 1 8 | fi 9 | } 10 | 11 | check_status 12 | 13 | # Add additional dependencies below as needed 14 | 15 | # Example: Installing an additional dependency (replace with actual URL) 16 | # echo "Installing Example Dependency..." 17 | # kubectl apply -f "https://example.com/path/to/dependency.yaml" 18 | # check_status 19 | 20 | # Add more dependencies here 21 | # echo "Installing Another Dependency..." 22 | # kubectl apply -f "https://example.com/path/to/another-dependency.yaml" 23 | # check_status 24 | 25 | echo "All dependencies installed successfully." 26 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 88 3 | 4 | [tool.isort] 5 | known_first_party = ["hyperpod_cli"] 6 | 7 | # required for compatibility with black: 8 | profile = "black" 9 | 10 | # To maintain consistency with other settings 11 | line_length = 100 12 | 13 | [tool.mypy] 14 | # See https://mypy.readthedocs.io/en/latest/config_file.html for more mypy options. 15 | 16 | # Enables the type-checker on the interior of functions without type annotations. 17 | check_untyped_defs = true 18 | 19 | # Displaying specific error codes makes it easier to silence specific errors 20 | # See also https://mypy.readthedocs.io/en/latest/error_codes.html 21 | show_error_codes = true 22 | 23 | # Show source code snippets and location markers in error messages 24 | pretty = true 25 | 26 | # Suppresses errors about packages which do not implement type-hint sharing. 27 | # See also https://mypy.readthedocs.io/en/latest/running_mypy.html#missing-imports. 
28 | ignore_missing_imports = true 29 | 30 | # Force override for yaml package as the package installed name is pyyaml 31 | # See https://github.com/python/mypy/issues/10632 32 | [[tool.mypy.overrides]] 33 | module = ["yaml","tabulate"] 34 | ignore_missing_imports = true 35 | 36 | [tool.ruff] 37 | # Exclude a variety of commonly ignored directories. 38 | exclude = [ 39 | "custom_launcher" 40 | ] 41 | 42 | # Same as Black. 43 | line-length = 88 44 | indent-width = 4 45 | 46 | # Assume Python 3.8 47 | target-version = "py38" 48 | 49 | [tool.ruff.lint] 50 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 51 | # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or 52 | # McCabe complexity (`C901`) by default. 53 | select = ["E4", "E7", "E9", "F"] 54 | ignore = [] 55 | 56 | # Allow fix for all enabled rules (when `--fix`) is provided. 57 | fixable = ["ALL"] 58 | unfixable = [] 59 | 60 | # Allow unused variables when underscore-prefixed. 61 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 62 | 63 | [tool.ruff.format] 64 | # Like Black, use double quotes for strings. 65 | quote-style = "double" 66 | 67 | # Like Black, indent with spaces, rather than tabs. 68 | indent-style = "space" 69 | 70 | # Like Black, respect magic trailing commas. 71 | skip-magic-trailing-comma = false 72 | 73 | # Like Black, automatically detect the appropriate line ending. 74 | line-ending = "auto" 75 | 76 | # Enable auto-formatting of code examples in docstrings. Markdown, 77 | # reStructuredText code/literal blocks and doctests are all supported. 78 | # 79 | # This is currently disabled by default, but it is planned for this 80 | # to be opt-out in the future. 81 | docstring-code-format = false 82 | 83 | # Set the line length limit used when formatting code snippets in 84 | # docstrings. 85 | # 86 | # This only has an effect when the `docstring-code-format` setting is 87 | # enabled. 
88 | docstring-code-line-length = "dynamic" -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = hyperpodcli 3 | version = 1.0 4 | description = SageMaker HyperPod CLI 5 | long_description = file: README.md 6 | long_description_content_type = text/markdown 7 | 8 | [options] 9 | zip_safe = True 10 | include_package_data = True 11 | package_dir = 12 | = src 13 | packages = find: 14 | 15 | [options.packages.find] 16 | where = src 17 | exclude = 18 | test 19 | 20 | [options.package_data] 21 | hyperpod_cli = 22 | py.typed 23 | 24 | 25 | [options.entry_points] 26 | # declare your scripts 27 | # If you want to create any Python executables in bin/, define them here. 28 | # This is a three-step process: 29 | # 30 | # 1. Create the function you want to run on the CLI in src/hyperpod_cli/cli.py 31 | # For convenience I usually recommend calling it main() 32 | # 33 | # 2. Uncomment this section of the setup.cfg options; this will create 34 | # bin/HyperpodCLI (which you can obviously change!) as a script 35 | # that will call your main() function, above. 36 | # 37 | # console_scripts = 38 | # HyperpodCLI = hyperpod_cli.cli:main 39 | # 40 | # 3. Uncomment the Python interpreter and Python-setuptools in the 41 | # dependencies section of your Config. This is necessary to guarantee the 42 | # presence of a runtime interpreter and for the script generated by 43 | # setuptools to find its function. 44 | 45 | # Configuration for pytest; enable coverage for HyperpodCLI, emit 46 | # XML, HTML, and terminal reports. 
47 | 48 | [tool:pytest] 49 | xfail_strict = true 50 | addopts = 51 | --verbose 52 | --ignore=build/private 53 | --cov hyperpod_cli 54 | --cov-config setup.cfg 55 | --cov-report term-missing 56 | --cov-report html:build/hyperpod-documentation/coverage 57 | --cov-report xml:build/hyperpod-documentation/coverage/coverage.xml 58 | # show the slowest 5 tests at the end 59 | --durations=5 60 | # Default to colorful output 61 | --color=yes 62 | # Uncomment to enforce a minimum code coverage threshold. 63 | # --cov-fail-under 50 64 | testpaths = test 65 | looponfailroots = src test 66 | 67 | [coverage:run] 68 | branch = true 69 | parallel = true 70 | 71 | omit = 72 | # omit anything in a .local directory anywhere 73 | */custom_launcher/* 74 | 75 | [coverage:paths] 76 | source = 77 | src/ 78 | build/lib/*/site-packages/ 79 | 80 | [coverage:html] 81 | directory = build/hyperpod-documentation/coverage 82 | 83 | [coverage:xml] 84 | output = build/hyperpod-documentation/coverage/coverage.xml 85 | 86 | [flake8] 87 | ignore = 88 | # Not pep8, black adds whitespace before ':' 89 | E203, 90 | # Not pep8, black adds line break before binary operator 91 | W503, 92 | # Once `bb format` is done with things, the only remaining long lines do not 93 | # matter; we can ignore them. 94 | E501, 95 | max_line_length = 100 96 | # Uncomment to enforce a maximum cyclomatic complexity - more info https://en.wikipedia.org/wiki/Cyclomatic_complexity 97 | # max_complexity=10 98 | 99 | [build_sphinx] 100 | warning-is-error = 1 101 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. 
# A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import os
import subprocess

from setuptools import find_packages, setup

# Pull in the sagemaker_hyperpod_recipes git submodule before packaging so its
# files can be registered as data_files below.
# NOTE(review): the return code is deliberately not checked; when building
# outside of a git checkout this silently falls through to setup() with an
# empty submodule directory -- confirm this best-effort behavior is intended.
subprocess.call(
    [
        "git",
        "submodule",
        "update",
        "--init",
        "--recursive",
        "--remote",
    ]
)

# Declare your non-python data files:
# Files underneath configuration/ will be copied into the build preserving the
# subdirectory structure if they exist.
#
# Walk the recipes submodule and register every file keyed by its path
# relative to the submodule root, so the directory layout is preserved in the
# installed distribution.
sagemaker_hyperpod_recipes = []
for root, dirs, files in os.walk(
    "src/hyperpod_cli/sagemaker_hyperpod_recipes"
):
    sagemaker_hyperpod_recipes.append(
        (
            os.path.relpath(
                root,
                "src/hyperpod_cli/sagemaker_hyperpod_recipes",
            ),
            [os.path.join(root, f) for f in files],
        )
    )

setup(
    data_files=sagemaker_hyperpod_recipes,
    name="hyperpod",
    version="2.0.0",
    packages=find_packages(where="src", exclude=("test",)),
    install_requires=[
        "click==8.1.7",
        "awscli>=1.34.9",
        "awscli-cwlogs>=1.4.6",
        "boto3>=1.35.3,<2.0",
        # Fix: the specifier previously contained a stray trailing space
        # ("botocore>=1.35.6 "), which is not a clean PEP 508 requirement
        # string for stricter tooling.
        "botocore>=1.35.6",
        "kubernetes==30.1.0",
        "pyyaml==6.0.2",
        "ratelimit==2.2.1",
        "tabulate==0.9.0",
        # NeMo framework required packages:
        # https://github.com/NVIDIA/NeMo-Framework-Launcher/blob/23.11/requirements.txt
        "hydra-core==1.3.2",
        "omegaconf==2.3",
        "pynvml==11.4.1",
        "requests==2.32.3",
        "tqdm==4.66.5",
        "zstandard==0.15.2",
        # Test dependencies
        "pytest==8.3.2",
        "pytest-cov==5.0.0",
        "pytest-order==1.3.0",
        "tox==4.18.0",
        "ruff==0.6.2",
        "hera-workflows==5.16.3",
    ],
    entry_points={
        "console_scripts": [
            "hyperpod=hyperpod_cli.cli:cli",
        ],
    },
    check_format=True,
    # Enable type checking
    test_mypy=True,
    # Enable linting at build time
    test_flake8=True,
)
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/__init__.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# Implement your code here.
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/cli.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import click
import importlib.metadata

from hyperpod_cli.commands.cluster import (
    connect_cluster,
    get_clusters,
)
from hyperpod_cli.commands.job import (
    cancel_job,
    get_job,
    list_jobs,
    list_pods,
    patch_job,
    start_job,
)
from hyperpod_cli.commands.pod import (
    exec,
    get_log,
)

# Top-level help text shown for `hyperpod --help`.
# Fix: corrected user-facing typos ("preform" -> "perform",
# "users local terminal" -> "user's local terminal").
HELP_TEXT = """
Find more information at: https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod.html

Basic Commands:
  * get-clusters        Get clusters information for HyperPod EKS clusters.
  * connect-cluster     Creates a connection from user's local terminal to the HyperPod cluster
                        allowing user to start and perform other basic operations with training jobs.
  * start-job           Start a training job from a file on HyperPod cluster.
  * get-job             Show details of a specific training job submitted on HyperPod cluster.
  * list-jobs           List training job on a HyperPod cluster.
  * cancel-job          Cancel training job on a HyperPod cluster.
  * patch-job           Patch a job with specific operation on a HyperPod cluster.
Troubleshooting and Debugging Commands:
  * get-log             Get logs for a pod of training job running on HyperPod cluster.
  * list-pods           List all pods associated with a training job on HyperPod cluster.
  * exec                Execute a command on a pod of a training job on HyperPod cluster.

Usage:
  hyperpod [command] [options]

Use "hyperpod --help" for more information about a given command.
"""

# Version is read from installed package metadata so it stays in sync with setup().
VERSION = importlib.metadata.version("hyperpod")


class HyperPodCommandGroup(click.Group):
    """Click group that replaces the auto-generated help with HELP_TEXT."""

    def format_help(self, ctx, formatter):
        # Bypass click's formatter entirely; the curated text above is printed as-is.
        click.echo(HELP_TEXT)


@click.group(cls=HyperPodCommandGroup)
@click.version_option(version=VERSION)
def cli():
    """Entry point for the `hyperpod` console script (see setup.py entry_points)."""
    pass


# Register all subcommands on the root group.
cli.add_command(get_clusters)
cli.add_command(connect_cluster)
cli.add_command(start_job)
cli.add_command(get_job)
cli.add_command(list_jobs)
cli.add_command(cancel_job)
cli.add_command(patch_job)
cli.add_command(exec)
cli.add_command(list_pods)
cli.add_command(get_log)

if __name__ == "__main__":
    cli()
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/clients/__init__.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/commands/__init__.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License.
# A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/commands/pod.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import sys
import logging
from typing import Optional

import click

from hyperpod_cli.service.exec_command import (
    ExecCommand,
)
from hyperpod_cli.service.get_logs import GetLogs
from hyperpod_cli.utils import (
    setup_logger,
    set_logging_level,
)

# Module-level logger; verbosity is raised per-invocation via --debug.
logger = setup_logger(__name__)


@click.command()
@click.option(
    "--job-name",
    type=click.STRING,
    required=True,
    help="Required. The name of the job to get the log for.",
)
@click.option(
    "--pod",
    "-p",
    type=click.STRING,
    required=True,
    help="Required. The name of the pod to get the log from.",
)
@click.option(
    "--namespace",
    "-n",
    type=click.STRING,
    required=False,
    help="Optional. The namespace to get the log from. If not provided, the CLI will get the log from the pod in the namespace set by the user while connecting to the cluster. If provided, and the user has access to the namespace, the CLI will get the log from the pod in the specified namespace.",
)
@click.option(
    "--debug",
    is_flag=True,
    help="Enable debug mode",
)
def get_log(
    job_name: str,
    pod: str,
    namespace: Optional[str],
    debug: bool,
):
    """Get the log of the specified training job.

    Fetches the log of one pod of a training job and echoes it, then
    best-effort appends a CloudWatch Container Insights link.

    Exits the process with a message (via sys.exit) if log retrieval fails;
    a failure to build the CloudWatch link only prints a warning.
    """
    if debug:
        set_logging_level(logger, logging.DEBUG)

    get_logs_service = GetLogs()

    try:
        logger.debug("Getting logs for the training job")
        result = get_logs_service.get_training_job_logs(
            job_name, pod, namespace=namespace
        )
        click.echo(result)
    except Exception as e:
        # Any failure here is fatal for the command; the message carries context.
        sys.exit(
            f"Unexpected error happens when trying to get logs for training job {job_name} : {e}"
        )

    # The CloudWatch link is supplementary: failure to generate it must not
    # fail the command, so it is handled in its own try block.
    try:
        cloudwatch_link = get_logs_service.generate_cloudwatch_link(pod, namespace=namespace)
        if cloudwatch_link:
            click.echo(cloudwatch_link)
    except Exception as e:
        click.echo(f"WARNING: Failed to generate container insights cloudwatch link: {e}")

def _exec_command_required_option_pod_and_all_pods():
    # Factory returning a click.Command subclass that enforces, at invoke
    # time, that exactly one of --pod / --all-pods was supplied. Done via a
    # custom class because click cannot express mutually-exclusive required
    # options declaratively.
    class OptionRequiredClass(click.Command):
        def invoke(self, ctx):
            pod = ctx.params["pod"]
            all_pods = ctx.params["all_pods"]
            if not pod and not all_pods:
                raise click.ClickException(
                    "With job-name name must specify option --pod or --all-pods"
                )
            if pod and all_pods:
                raise click.ClickException(
                    "With job-name name must specify only one option --pod or --all-pods"
                )
            super(OptionRequiredClass, self).invoke(ctx)

    return OptionRequiredClass


@click.command(
    cls=_exec_command_required_option_pod_and_all_pods(),
    # Unknown options are passed through so they can be forwarded as part of
    # the bash command rather than rejected by click.
    context_settings={
        "ignore_unknown_options": True,
        "allow_extra_args": False,
    },
)
@click.option(
    "--job-name",
    type=click.STRING,
    required=True,
    help="Required. The name of the job to execute the command on.",
)
@click.option(
    "--namespace",
    "-n",
    type=click.STRING,
    nargs=1,
    required=False,
    help="Optional. The namespace to execute the command in. If not provided, the CLI will try to execute the command in the pod in the namespace set by the user while connecting to the cluster. If provided, and the user has access to the namespace, the CLI will execute the command in the pod from the specified namespace.",
)
@click.option(
    "--pod",
    "-p",
    type=click.STRING,
    nargs=1,
    required=False,
    help="Optional. The name of the pod to execute the command in. You must provide either `--pod` or `--all-pods`.",
)
@click.option(
    "--all-pods",
    type=click.BOOL,
    is_flag=True,
    required=False,
    help="Optional. If set, the command will be executed in all pods associated with the job. You must provide either `--pod` or `--all-pods`.",
)
@click.argument(
    "bash_command",
    nargs=-1,
    # UNPROCESSED keeps the trailing command verbatim (no click type coercion).
    type=click.UNPROCESSED,
)
@click.option(
    "--debug",
    is_flag=True,
    help="Enable debug mode",
)
def exec(
    job_name: str,
    namespace: Optional[str],
    pod: Optional[str],
    all_pods: Optional[bool],
    debug: bool,
    bash_command: tuple,
):
    """Execute a bash command in the specified job.

    Runs bash_command in one pod (--pod) or every pod (--all-pods) of the
    job and echoes the output. Exits the process with a message on failure.
    """
    if debug:
        set_logging_level(logger, logging.DEBUG)

    exec_command_service = ExecCommand()

    try:
        logger.debug("Executing command for the training job")
        result = exec_command_service.exec_command(
            job_name,
            pod,
            namespace,
            all_pods,
            bash_command,
        )
        click.echo(result)
    except Exception as e:
        sys.exit(
            f"Unexpected error happens when trying to exec command for pod {pod} : {e}"
        )
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/constants/__init__.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/constants/command_constants.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from enum import Enum
from pathlib import Path

GENERATED_LAUNCHER_CONFIG_FILE_PATH = "/tmp/"
HYPERPOD_KUBERNETES_JOB_PREFIX = "hyperpod-k8s-job"
HYPERPOD_CLUSTER_CONTEXT_FILE_NAME = "hyperpod_current_context.json"
# Default node-affinity terms: require schedulable nodes, prefer nodes that
# passed the deep health check ("weights" pairs with the preferred terms).
NODE_AFFINITY_DICT = {
    "required": {"sagemaker.amazonaws.com/node-health-status": ["Schedulable"]},
    "preferred": {"sagemaker.amazonaws.com/deep-health-check-status": ["Passed"]},
    "weights": [100],
}
# Stricter variant: only nodes that passed the deep health check are eligible.
DEEP_HEALTH_CHECK_PASSED_ONLY_NODE_AFFINITY_DICT = {
    "required": {
        "sagemaker.amazonaws.com/deep-health-check-status": ["Passed"],
    },
}
# Kueue label/annotation keys attached to submitted jobs.
KUEUE_QUEUE_NAME_LABEL_KEY = "kueue.x-k8s.io/queue-name"
KUEUE_WORKLOAD_PRIORITY_CLASS_LABEL_KEY = "kueue.x-k8s.io/priority-class"
KUEUE_JOB_UID_LABEL_KEY = "kueue.x-k8s.io/job-uid"
HYPERPOD_AUTO_RESUME_ANNOTATION_KEY = "sagemaker.amazonaws.com/enable-job-auto-resume"
HYPERPOD_MAX_RETRY_ANNOTATION_KEY = "sagemaker.amazonaws.com/job-max-retry-count"
ENV_VARS_DICT = {"NCCL_DEBUG": "INFO"}
# SageMaker/Kubernetes node label keys used when describing cluster nodes.
SAGEMAKER_HYPERPOD_NAME_LABEL = "sagemaker.amazonaws.com/cluster-name"
HP_HEALTH_STATUS_LABEL = "sagemaker.amazonaws.com/node-health-status"
INSTANCE_TYPE_LABEL = "node.kubernetes.io/instance-type"
DEEP_HEALTH_CHECK_STATUS_LABEL = "sagemaker.amazonaws.com/deep-health-check-status"
SAGEMAKER_MANAGED_QUEUE_LABEL = "sagemaker.amazonaws.com/sagemaker-managed-queue"
SAGEMAKER_QUOTA_ALLOCATION_LABEL = "sagemaker.amazonaws.com/quota-allocation-id"
TEMP_KUBE_CONFIG_FILE = "/tmp/kubeconfig"
HYPERPOD_NAMESPACE_PREFIX = "hyperpod-ns-"
SAGEMAKER_MANAGED_LOCAL_QUEUE_SUFFIX = "-localqueue"
SAGEMAKER_MANAGED_CLUSTER_QUEUE_SUFFIX = "-clusterqueue"
# Absolute path of the bundled sagemaker_hyperpod_recipes submodule.
SAGEMAKER_TRAINING_LAUNCHER_DIR = str(Path(__file__).parent.parent / "sagemaker_hyperpod_recipes")
NVIDIA_GPU_RESOURCE_LIMIT_KEY = "nvidia.com/gpu"
AVAILABLE_ACCELERATOR_DEVICES_KEY = "AvailableAcceleratorDevices"
TOTAL_ACCELERATOR_DEVICES_KEY = "TotalAcceleratorDevices"
USER_NAME_LABEL_KEY = "sagemaker.user/created-by"

class PullPolicy(Enum):
    """Kubernetes container image pull policies."""
    ALWAYS = "Always"
    IF_NOT_PRESENT = "IfNotPresent"
    NEVER = "Never"


class RestartPolicy(Enum):
    """Pod restart policies supported for training jobs."""
    ALWAYS = "Always"
    ON_FAILURE = "OnFailure"
    NEVER = "Never"
    EXIT_CODE = "ExitCode"


class Orchestrator(Enum):
    """Cluster orchestrators supported by the CLI."""
    EKS = "eks"


class OutputFormat(Enum):
    """Output formats for list/describe style commands."""
    JSON = "json"
    TABLE = "table"


class PersistentVolumeClaim:
    """Pairs a Kubernetes PVC name with the container path it is mounted at."""

    # claim_name: name of the PersistentVolumeClaim resource
    # mount_path: path inside the container where the claim is mounted
    claim_name: str
    mount_path: str

    def __init__(self, claim_name, mount_path):
        self.claim_name = claim_name
        self.mount_path = mount_path


class SchedulerType(Enum):
    """Job schedulers a HyperPod cluster may use."""
    KUEUE = "Kueue"
    SAGEMAKER = "SageMaker"
    NONE = "None"

    # Fix: declared as @staticmethod. The originals were plain functions
    # without `self`, which worked when called on the class
    # (SchedulerType.get_values()) but raised TypeError when accessed via a
    # member (SchedulerType.SAGEMAKER.get_values()). @staticmethod keeps the
    # class-level call working and makes member-level access safe too.
    @staticmethod
    def get_values():
        """Return the string values of all scheduler types."""
        return [scheduler_type.value for scheduler_type in SchedulerType]

    @staticmethod
    def get_default():
        """Return the scheduler used when none is specified."""
        return SchedulerType.SAGEMAKER

class JobPatchType(Enum):
    """Operations supported by the patch-job command."""
    SUSPEND = "suspend"
    UNSUSPEND = "unsuspend"

    # Same @staticmethod fix as SchedulerType.get_values above.
    @staticmethod
    def get_values():
        """Return the string values of all patch types."""
        return [type.value for type in JobPatchType]



class Volume:
    """Describes a hostPath volume: its name, host path, and mount path."""

    # volume_name: Kubernetes volume name
    # host_path:   path on the node to expose
    # mount_path:  path inside the container
    volume_name: str
    host_path: str
    mount_path: str

    def __init__(self, volume_name, host_path, mount_path):
        self.host_path = host_path
        self.mount_path = mount_path
        self.volume_name = volume_name
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/constants/exception_constants.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# HTTP status code used to detect "not found" responses from the K8s API.
RESOURCE_NOT_FOUND_CODE = 404
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/constants/hyperpod_instance_types.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from enum import Enum


class HyperpodInstanceType(Enum):
    """EC2 instance types supported for SageMaker HyperPod cluster nodes.

    Each member's value is the `ml.`-prefixed instance type string as used by
    the SageMaker API and the `node.kubernetes.io/instance-type` node label.
    """

    # c5 — compute optimized
    ML_C5_LARGE = "ml.c5.large"
    ML_C5_XLARGE = "ml.c5.xlarge"
    ML_C5_2XLARGE = "ml.c5.2xlarge"
    ML_C5_4XLARGE = "ml.c5.4xlarge"
    ML_C5_9XLARGE = "ml.c5.9xlarge"
    ML_C5_12XLARGE = "ml.c5.12xlarge"
    ML_C5_18XLARGE = "ml.c5.18xlarge"
    ML_C5_24XLARGE = "ml.c5.24xlarge"
    # c5n — compute optimized, enhanced networking
    ML_C5N_LARGE = "ml.c5n.large"
    ML_C5N_2XLARGE = "ml.c5n.2xlarge"
    ML_C5N_4XLARGE = "ml.c5n.4xlarge"
    ML_C5N_9XLARGE = "ml.c5n.9xlarge"
    ML_C5N_18XLARGE = "ml.c5n.18xlarge"
    # g5 / g6 / g6e / gr6 — GPU instances
    ML_G5_XLARGE = "ml.g5.xlarge"
    ML_G5_2XLARGE = "ml.g5.2xlarge"
    ML_G5_4XLARGE = "ml.g5.4xlarge"
    ML_G5_8XLARGE = "ml.g5.8xlarge"
    ML_G5_12XLARGE = "ml.g5.12xlarge"
    ML_G5_16XLARGE = "ml.g5.16xlarge"
    ML_G5_24XLARGE = "ml.g5.24xlarge"
    ML_G5_48XLARGE = "ml.g5.48xlarge"
    ML_G6_XLARGE = "ml.g6.xlarge"
    ML_G6_2XLARGE = "ml.g6.2xlarge"
    ML_G6_4XLARGE = "ml.g6.4xlarge"
    ML_G6_8XLARGE = "ml.g6.8xlarge"
    ML_G6_12XLARGE = "ml.g6.12xlarge"
    ML_G6_16XLARGE = "ml.g6.16xlarge"
    ML_G6_24XLARGE = "ml.g6.24xlarge"
    ML_G6_48XLARGE = "ml.g6.48xlarge"
    ML_G6E_XLARGE = "ml.g6e.xlarge"
    ML_G6E_2XLARGE = "ml.g6e.2xlarge"
    ML_G6E_4XLARGE = "ml.g6e.4xlarge"
    ML_G6E_8XLARGE = "ml.g6e.8xlarge"
    ML_G6E_12XLARGE = "ml.g6e.12xlarge"
    ML_G6E_16XLARGE = "ml.g6e.16xlarge"
    ML_G6E_24XLARGE = "ml.g6e.24xlarge"
    ML_G6E_48XLARGE = "ml.g6e.48xlarge"
    ML_GR6_4XLARGE = "ml.gr6.4xlarge"
    ML_GR6_8XLARGE = "ml.gr6.8xlarge"
    # i3en — storage optimized
    ML_I3EN_LARGE = "ml.i3en.large"
    ML_I3EN_XLARGE = "ml.i3en.xlarge"
    ML_I3EN_2XLARGE = "ml.i3en.2xlarge"
    ML_I3EN_3XLARGE = "ml.i3en.3xlarge"
    ML_I3EN_6XLARGE = "ml.i3en.6xlarge"
    ML_I3EN_12XLARGE = "ml.i3en.12xlarge"
    ML_I3EN_24XLARGE = "ml.i3en.24xlarge"
    # m5 / m7i — general purpose
    ML_M5_LARGE = "ml.m5.large"
    ML_M5_XLARGE = "ml.m5.xlarge"
    ML_M5_2XLARGE = "ml.m5.2xlarge"
    ML_M5_4XLARGE = "ml.m5.4xlarge"
    ML_M5_8XLARGE = "ml.m5.8xlarge"
    ML_M5_12XLARGE = "ml.m5.12xlarge"
    ML_M5_16XLARGE = "ml.m5.16xlarge"
    ML_M5_24XLARGE = "ml.m5.24xlarge"
    ML_M7I_LARGE = "ml.m7i.large"
    ML_M7I_XLARGE = "ml.m7i.xlarge"
    ML_M7I_2XLARGE = "ml.m7i.2xlarge"
    ML_M7I_4XLARGE = "ml.m7i.4xlarge"
    ML_M7I_8XLARGE = "ml.m7i.8xlarge"
    ML_M7I_12XLARGE = "ml.m7i.12xlarge"
    ML_M7I_16XLARGE = "ml.m7i.16xlarge"
    ML_M7I_24XLARGE = "ml.m7i.24xlarge"
    ML_M7I_48XLARGE = "ml.m7i.48xlarge"
    # p4 / p5 — accelerated training (NVIDIA A100/H100 class)
    ML_P4D_24XLARGE = "ml.p4d.24xlarge"
    ML_P4DE_24XLARGE = "ml.p4de.24xlarge"
    ML_P5_48XLARGE = "ml.p5.48xlarge"
    ML_P5E_48XLARGE = "ml.p5e.48xlarge"
    ML_P5EN_48XLARGE = "ml.p5en.48xlarge"
    # r7i — memory optimized
    ML_R7I_LARGE = "ml.r7i.large"
    ML_R7I_XLARGE = "ml.r7i.xlarge"
    ML_R7I_2XLARGE = "ml.r7i.2xlarge"
    ML_R7I_4XLARGE = "ml.r7i.4xlarge"
    ML_R7I_8XLARGE = "ml.r7i.8xlarge"
    ML_R7I_12XLARGE = "ml.r7i.12xlarge"
    ML_R7I_16XLARGE = "ml.r7i.16xlarge"
    ML_R7I_24XLARGE = "ml.r7i.24xlarge"
    ML_R7I_48XLARGE = "ml.r7i.48xlarge"
    # t3 — burstable
    ML_T3_MEDIUM = "ml.t3.medium"
    ML_T3_LARGE = "ml.t3.large"
    ML_T3_XLARGE = "ml.t3.xlarge"
    ML_T3_2XLARGE = "ml.t3.2xlarge"
    # trn — AWS Trainium
    ML_TRN1_32XLARGE = "ml.trn1.32xlarge"
    ML_TRN1N_32XLARGE = "ml.trn1n.32xlarge"
    ML_TRN2_48XLARGE = "ml.trn2.48xlarge"
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/constants/kueue_constants.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file.
# This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# Kueue custom-resource coordinates used with the Kubernetes CustomObjects API.
KUEUE_CUSTOM_OBJECT_GROUP = "kueue.x-k8s.io"
WORKLOAD_CUSTOM_OBJECT_PLURAL = "workloads"
KUEUE_CUSTOM_OBJECT_VERSION = "v1beta1"
WORKLOAD_PRIORITY_CLASS_CUSTOM_OBJECT_PLURAL = "workloadpriorityclasses"
CLUSTER_QUEUE_PRIORITY_CLASS_CUSTOM_OBJECT_PLURAL = "clusterqueues"
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/constants/pytorch_constants.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# Kubeflow PyTorchJob custom-resource coordinates used with the
# Kubernetes CustomObjects API.
PYTORCH_CUSTOM_OBJECT_GROUP = "kubeflow.org"
PYTORCH_CUSTOM_OBJECT_PLURAL = "pytorchjobs"
PYTORCH_CUSTOM_OBJECT_VERSION = "v1"
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/py.typed:
# --------------------------------------------------------------------------------
# Marker file that indicates this package supports typing
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/service/__init__.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc.
# or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/service/cancel_training_job.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from typing import Optional
import subprocess

from hyperpod_cli.clients.kubernetes_client import (
    KubernetesClient,
)
from kubernetes.client.rest import ApiException

from hyperpod_cli.constants.pytorch_constants import (
    PYTORCH_CUSTOM_OBJECT_GROUP,
    PYTORCH_CUSTOM_OBJECT_PLURAL
)
from hyperpod_cli.service.discover_namespaces import DiscoverNamespaces
from kubernetes.client import (
    V1ResourceAttributes
)

class CancelTrainingJob:
    """Service that deletes a training job and cleans up its helm release."""

    def __init__(self):
        pass

    def cancel_training_job(self, job_name: str, namespace: Optional[str]):
        """
        Cancel training job provided by the user in the specified namespace.
        If namespace is not provided job is canceled from the default namespace in user context
        """
        client = KubernetesClient()

        # Resolve the namespace automatically when the caller did not give one:
        # discover the single namespace the user may delete PyTorchJobs in.
        target_namespace = namespace
        if not target_namespace:
            delete_job_attributes = V1ResourceAttributes(
                verb="delete",
                group=PYTORCH_CUSTOM_OBJECT_GROUP,
                resource=PYTORCH_CUSTOM_OBJECT_PLURAL,
            )
            target_namespace = DiscoverNamespaces().discover_accessible_namespace(
                delete_job_attributes
            )

        try:
            outcome = client.delete_training_job(
                job_name=job_name, namespace=target_namespace
            )
        except ApiException as err:
            raise RuntimeError(f"Unexpected API error: {err.reason} ({err.status})")

        if outcome.get("status") == "Success":
            # Best-effort cleanup of the helm release that backed the job;
            # output is captured and the exit status is intentionally ignored.
            subprocess.run(
                [
                    "helm",
                    "uninstall",
                    job_name,
                    "--namespace",
                    target_namespace,
                ],
                capture_output=True,
                text=True,
            )
            return None
        return outcome
# --------------------------------------------------------------------------------
# /src/hyperpod_cli/service/discover_namespaces.py:
# --------------------------------------------------------------------------------
# Copyright Amazon.com, Inc. or its affiliates.
# fix: "import concurrent" alone does not guarantee the "futures" submodule is
# loaded; import it explicitly so concurrent.futures.* lookups cannot fail.
import concurrent.futures
import sys
import copy
from kubernetes.client.rest import ApiException
from hyperpod_cli.clients.kubernetes_client import KubernetesClient
from hyperpod_cli.service.get_namespaces import GetNamespaces
from hyperpod_cli.service.self_subject_access_review import SelfSubjectAccessReview
from hyperpod_cli.utils import setup_logger


logger = setup_logger(__name__)


class DiscoverNamespaces:
    """Discovers the single namespace the caller is permitted to operate in."""

    def __init__(self):
        return

    def discover_accessible_namespace(self, resource_attributes_template, only_sm_managed=True):
        """
        Discover the accessible namespaces.

        Returns the namespace from the current kube context when one is set;
        otherwise probes candidate namespaces with SelfSubjectAccessReview and
        returns the single accessible one. Exits the process when zero or more
        than one namespace is accessible, or on a 403 from the API.

        Args:
            resource_attributes_template: V1ResourceAttributes describing the
                verb/group/resource to check access for (namespace is filled in
                per candidate).
            only_sm_managed: When True, only SageMaker-managed namespaces are
                considered as candidates.
        """
        k8s_client = KubernetesClient()
        context_namespace = k8s_client.get_current_context_namespace()

        # If the namespace is explicitly set by user, take the namespace from the config instead
        # of discovering automatically
        if context_namespace is not None:
            return context_namespace

        try:
            if only_sm_managed:
                namespaces = GetNamespaces().get_sagemaker_managed_namespaces()
            else:
                namespaces = GetNamespaces().get_namespaces()

            discovered_namespaces = self.get_namespaces_by_checking_access_permission(
                namespaces,
                resource_attributes_template,
            )

            if len(discovered_namespaces) == 0:
                logger.error("Found no accessible namespaces. Please ask for cluster admin for assistance or specify value of namespace explicitly in the command.")
                sys.exit(1)
            if len(discovered_namespaces) > 1:
                logger.error(f"Found more than 1 accessible namespaces {discovered_namespaces}. Please specify value of namespace explicitly in the command.")
                sys.exit(1)

            if len(discovered_namespaces) == 1:
                logger.info(f"Found accessible namespace: {discovered_namespaces}")
                return discovered_namespaces[0]
        except ApiException as e:
            if e.status == 403:
                logger.error("Got access denied error when discovering accessible namespaces, please specify the '--namespace' parameter in the command and try again.")
                sys.exit(1)
            # fix: non-403 API errors were previously swallowed, silently
            # returning None; re-raise so callers see the real failure.
            raise

    def get_namespaces_by_checking_access_permission(
        self,
        namespaces,
        resource_attributes_template,
        max_workers=10,
    ):
        """
        Get the accessible namespaces by performing the SelfSubjectAccessReview. For each of the namespace,
        check if user has specified permission to it, if the answer is NO, the namespace will be skipped.

        Performing access check can take quite long if the number of namespaces is large. Thus the implementation
        leverages the multi-threading to ensure that multiple access check can be performed in parallel.
        """
        subject_access_review = SelfSubjectAccessReview()
        accessible_namespaces = list()
        resource_attributes = list()

        for namespace in namespaces:
            # Deep-copy the template so each request carries its own
            # attribute object with only the namespace field differing.
            resource_attribute = copy.deepcopy(resource_attributes_template)
            resource_attribute.namespace = namespace
            resource_attributes.append(resource_attribute)

        # Multi-thread the self subject access review to improve the performance
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(
                    subject_access_review.self_subject_access_review, resource_attribute
                ) for resource_attribute in resource_attributes
            }

            for future in concurrent.futures.as_completed(futures):
                # fix: the former "except Exception as e: raise(e)" wrapper was
                # a no-op; future.result() already propagates worker exceptions.
                response = future.result()
                if response.status.allowed:
                    accessible_namespaces.append(
                        response.spec.resource_attributes.namespace
                    )

        return accessible_namespaces
from typing import Optional

from hyperpod_cli.clients.kubernetes_client import (
    KubernetesClient,
)
from hyperpod_cli.service.list_pods import (
    ListPods,
)

from kubernetes.client.rest import ApiException


class ExecCommand:
    """Service that runs a bash command inside pod(s) of a training job."""

    def __init__(self):
        return

    def exec_command(
        self,
        job_name: str,
        pod_name: Optional[str],
        namespace: Optional[str],
        all_pods: Optional[bool],
        bash_command: tuple,
    ):
        """Execute *bash_command* on one pod, or on every pod of the job.

        Returns the command output; for all_pods the per-pod outputs are
        concatenated as "<pod>\\n<output>\\n\\n" segments.
        """
        bash_command_str: str = " ".join(bash_command)

        k8s_client = KubernetesClient()
        list_pods_service = ListPods()

        if not namespace:
            namespace = k8s_client.get_current_context_namespace()

        pods_for_training_job = list_pods_service.list_pods_for_training_job(
            job_name, namespace, False
        )

        try:
            if all_pods:
                # Build one segment per pod and join them at the end.
                segments = []
                for pod in pods_for_training_job:
                    pod_output = k8s_client.exec_command_on_pod(
                        pod,
                        namespace,
                        bash_command_str,
                    )
                    segments.append(f"{pod}\n{pod_output}\n\n")
                return "".join(segments)
            if pod_name not in pods_for_training_job:
                raise RuntimeError(
                    f"Given pod name {pod_name} is not associated with training job {job_name} in namespace {namespace}"
                )
            return k8s_client.exec_command_on_pod(
                pod_name,
                namespace,
                bash_command_str,
            )
        except ApiException as e:
            raise RuntimeError(f"Unexpected API error: {e.reason} ({e.status})")
from typing import Optional
import boto3

from hyperpod_cli.clients.kubernetes_client import (
    KubernetesClient,
)
from hyperpod_cli.service.discover_namespaces import DiscoverNamespaces
from hyperpod_cli.service.list_pods import (
    ListPods,
)
from hyperpod_cli.utils import (
    get_eks_cluster_name,
    get_hyperpod_cluster_region,
    validate_region_and_cluster_name,
)
from kubernetes.client.rest import ApiException
from kubernetes.client import V1ResourceAttributes
import re

# NOTE(review): the name contains a lowercase-'l' typo ("ClOUDWATCH"); kept
# as-is because other modules may import it by this exact name.
AMAZON_ClOUDWATCH_OBSERVABILITY = "amazon-cloudwatch-observability"
# fix: raw string so the "\/" sequences are regex escapes, not (invalid)
# Python string escapes that warn on modern interpreters. Pattern unchanged.
CONTAINER_INSIGHTS_LOG_REGEX_PATTERN = r"https:\/\/([a-z0-9-]+).console.aws.amazon.com\/cloudwatch\/home\?region=([a-z0-9-]+)#logsV2:log-groups\/log-group\/\$252Faws\$252Fcontainerinsights\$252F([a-zA-Z0-9-]+)\$252Fapplication\/log-events\/([a-z0-9-]+)-application.var.log.containers.([a-z0-9-]+)_([a-z0-9-]+)_([a-z0-9-]+)-([a-z0-9-]+).log"


class GetLogs:
    """Service that fetches pod logs and builds CloudWatch Container Insights links."""

    def __init__(self):
        return

    def get_training_job_logs(
        self,
        job_name: str,
        pod_name: str,
        namespace: Optional[str],
    ):
        """
        Get logs for the pod associated with the training job.

        Raises:
            RuntimeError: When the pod does not belong to the job, or on a
                Kubernetes API failure.
        """
        k8s_client = KubernetesClient()
        list_pods_service = ListPods()

        if not namespace:
            resource_attributes_template = V1ResourceAttributes(
                verb="get",
                group="",
                resource="pods",
                subresource="log",
            )
            namespace = DiscoverNamespaces().discover_accessible_namespace(
                resource_attributes_template
            )

        try:
            pods_for_training_job = list_pods_service.list_pods_for_training_job(
                job_name, namespace, False
            )
            if pod_name not in pods_for_training_job:
                raise RuntimeError(
                    f"Given pod name {pod_name} is not associated with training job {job_name} in namespace {namespace}"
                )
            return k8s_client.get_logs_for_pod(pod_name, namespace)
        except ApiException as e:
            raise RuntimeError(f"Unexpected API error: {e.reason} ({e.status})")

    def generate_cloudwatch_link(
        self,
        pod_name: str,
        namespace: Optional[str],
    ):
        """Build the CloudWatch Container Insights link message for a pod.

        Returns None when the Container Insights addon is not enabled on the
        EKS cluster.

        Raises:
            ValueError: When the cluster name/region or the generated URL
                fails validation.
        """
        eks_cluster_name = get_eks_cluster_name()
        region = get_hyperpod_cluster_region()

        if self.is_container_insights_addon_enabled(eks_cluster_name):
            k8s_client = KubernetesClient()

            # pod_details is a V1Pod object
            pod_details = k8s_client.get_pod_details(pod_name, namespace)

            # get node name (None if the pod is not scheduled yet)
            if pod_details.spec and pod_details.spec.node_name:
                node_name = pod_details.spec.node_name
            else:
                node_name = None

            # get container name — first container only; assumes the training
            # container is listed first (TODO confirm for multi-container pods)
            if pod_details.spec and pod_details.spec.containers and pod_details.spec.containers[0].name:
                container_name = pod_details.spec.containers[0].name
            else:
                container_name = None

            # get container_id
            if pod_details.status and pod_details.status.container_statuses and pod_details.status.container_statuses[0].container_id:
                full_container_id = pod_details.status.container_statuses[0].container_id

                # full_container_id has format "containerd://xxxxxxxxxx"
                container_id = full_container_id[13:] if full_container_id.startswith('containerd://') else None
            else:
                container_id = None

            # Cloudwatch container insight log groups should have the same pod log as API response
            cloudwatch_url = self.get_log_url(eks_cluster_name, region, node_name, pod_name, namespace, container_name, container_id)

            if not validate_region_and_cluster_name(region, eks_cluster_name):
                raise ValueError('Eks cluster name or cluster region is invalid.')

            if not re.match(CONTAINER_INSIGHTS_LOG_REGEX_PATTERN, cloudwatch_url):
                raise ValueError("Failed to validate cloudwatch log url. Please make sure pod's node name, container name and container id are valid")

            cloudwatch_link = f'The pod cloudwatch log stream link is {cloudwatch_url}'
        else:
            cloudwatch_link = None

        return cloudwatch_link

    def get_log_url(self, eks_cluster_name, region, node_name, pod_name, namespace, container_name, container_id):
        """Assemble the Container Insights console URL for one pod's log stream."""
        console_prefix = f'https://{region}.console.aws.amazon.com/cloudwatch/home?region={region}#'
        log_group_prefix = f'logsV2:log-groups/log-group/$252Faws$252Fcontainerinsights$252F{eks_cluster_name}$252Fapplication/log-events/'
        log_stream = f'{node_name}-application.var.log.containers.{pod_name}_{namespace}_{container_name}-{container_id}.log'

        return console_prefix + log_group_prefix + log_stream

    def is_container_insights_addon_enabled(self, eks_cluster_name):
        """Return True when the Container Insights addon is installed on the cluster."""
        response = boto3.client("eks").list_addons(clusterName=eks_cluster_name, maxResults=50)
        # "in" already yields a bool; no need for an explicit True/False branch.
        return AMAZON_ClOUDWATCH_OBSERVABILITY in response.get('addons', [])
from hyperpod_cli.clients.kubernetes_client import (
    KubernetesClient,
)
from kubernetes.client.rest import ApiException

from hyperpod_cli.constants.command_constants import SAGEMAKER_MANAGED_QUEUE_LABEL


LIMIT_PER_REQUEST = 200

class GetNamespaces:
    """Service that lists cluster namespaces, following pagination tokens."""

    def __init__(self):
        return

    def get_namespaces(self, label_selector=None):
        """Return the names of all namespaces matching *label_selector*."""
        core_v1_api = KubernetesClient().get_core_v1_api()
        all_namespaces = []
        continue_token = None

        while True:
            # Build the call kwargs once; only pass _continue when the
            # previous page returned a continuation token.
            request_kwargs = {
                "limit": LIMIT_PER_REQUEST,
                "label_selector": label_selector,
            }
            if continue_token:
                request_kwargs["_continue"] = continue_token
            response = core_v1_api.list_namespace(**request_kwargs)

            all_namespaces.extend(ns.metadata.name for ns in response.items)
            continue_token = response.metadata._continue

            if not continue_token:
                break

        return all_namespaces

    def get_sagemaker_managed_namespaces(self):
        """Return namespaces carrying the SageMaker managed-queue label."""
        return self.get_namespaces(SAGEMAKER_MANAGED_QUEUE_LABEL + "=true")
from typing import Optional

import json

from hyperpod_cli.clients.kubernetes_client import (
    KubernetesClient,
)

from hyperpod_cli import utils
from kubernetes.client.rest import ApiException
from kubernetes.client import (
    V1ResourceAttributes
)

from hyperpod_cli.constants.pytorch_constants import PYTORCH_CUSTOM_OBJECT_GROUP, PYTORCH_CUSTOM_OBJECT_PLURAL
from hyperpod_cli.service.discover_namespaces import DiscoverNamespaces

class GetTrainingJob:
    """Service that describes a training job as a formatted JSON string."""

    def __init__(self):
        return

    def get_training_job(
        self,
        job_name: str,
        namespace: Optional[str],
        verbose: Optional[bool],
    ):
        """
        Describe training job provided by the user in the specified namespace.
        If namespace is not provided job is described from the default namespace in user context
        """
        k8s_client = KubernetesClient()
        if not namespace:
            # Discover the single namespace where the caller may "get"
            # PyTorchJob custom objects.
            resource_attributes_template = V1ResourceAttributes(
                verb="get",
                group=PYTORCH_CUSTOM_OBJECT_GROUP,
                resource=PYTORCH_CUSTOM_OBJECT_PLURAL,
            )
            namespace = DiscoverNamespaces().discover_accessible_namespace(
                resource_attributes_template
            )

        try:
            result = k8s_client.get_job(
                job_name=job_name,
                namespace=namespace,
            )
        except ApiException as e:
            raise RuntimeError(f"Unexpected API error: {e.reason} ({e.status})")

        if verbose:
            return self._format_verbose_output(result)
        return self._format_output_to_keep_needed_fields(result)

    def _format_output_to_keep_needed_fields(self, output):
        """Condense the raw job object to name/namespace/labels/status fields."""
        result = {}
        if output:
            metadata = output.get("metadata")
            if metadata:
                result = {
                    "Name": metadata.get("name"),
                    "Namespace": metadata.get("namespace"),
                    "Label": metadata.get("labels"),
                    "CreationTimestamp": metadata.get("creationTimestamp"),
                }
            result["Status"] = output.get("status")
            result["ConsoleURL"] = utils.get_cluster_console_url()
        return json.dumps(result, indent=1, sort_keys=False)

    def _format_verbose_output(self, output):
        """Full description including annotations, spec and API metadata."""
        result = {}
        if output:
            metadata = output.get("metadata")
            if metadata:
                result = {
                    "Name": metadata.get("name"),
                    "Namespace": metadata.get("namespace"),
                    "Label": metadata.get("labels"),
                    "Annotations": metadata.get("annotations"),
                    "Metadata": {
                        "CreationTimestamp": metadata.get("creationTimestamp"),
                        "Generation": metadata.get("generation"),
                        "ResourceVersion": metadata.get("resourceVersion"),
                        "UID": metadata.get("uid"),
                    },
                }
            result["Kind"] = output.get("kind")
            result["ApiVersion"] = output.get("apiVersion")
            result["Spec"] = output.get("spec")
            result["Status"] = output.get("status")
            result["ConsoleURL"] = utils.get_cluster_console_url()
        return json.dumps(result, indent=1, sort_keys=False)
from collections import defaultdict
from typing import List, Optional
import json

from kubernetes.client import V1Pod, V1PodList

from hyperpod_cli.clients.kubernetes_client import (
    KubernetesClient,
)
from kubernetes.client.rest import ApiException
from kubernetes.client import (
    V1ResourceAttributes
)

from hyperpod_cli.constants.command_constants import NVIDIA_GPU_RESOURCE_LIMIT_KEY
from hyperpod_cli.service.discover_namespaces import DiscoverNamespaces

# Resource key for AWS Neuron devices; named here for consistency with
# NVIDIA_GPU_RESOURCE_LIMIT_KEY (no shared constant is imported for it).
NEURON_DEVICE_RESOURCE_KEY = "aws.amazon.com/neurondevice"

class ListPods:
    """Service that lists pods of training jobs and their accelerator requests."""

    def __init__(self):
        return

    def list_pods_for_training_job(
        self,
        job_name: str,
        namespace: Optional[str],
        pretty: Optional[bool],
    ):
        """
        List pods associated with a training job.

        Args:
            job_name: Training job whose pods to list.
            namespace: Namespace to search; auto-discovered when falsy.
            pretty: When True return a JSON string, otherwise a list of names.

        Raises:
            RuntimeError: When the Kubernetes API call fails.
        """
        k8s_client = KubernetesClient()

        if not namespace:
            resource_attributes_template = V1ResourceAttributes(
                verb="list",
                group="",
                resource="pods",
            )
            namespace = DiscoverNamespaces().discover_accessible_namespace(
                resource_attributes_template
            )

        label_filter = f"training.kubeflow.org/job-name={job_name}"

        try:
            _pods: V1PodList = k8s_client.list_pods_with_labels(namespace, label_filter)
        except ApiException as e:
            raise RuntimeError(f"Unexpected API error: {e.reason} ({e.status})")

        if pretty:
            return self._generate_list_pods_output(_pods)
        return self._generate_pods_list(_pods)

    def list_pods_and_get_requested_resources_group_by_node_name(
        self,
    ):
        """
        List pods for all namespaces and initialized by kubeflow.
        Group by the node_name of the pod and the value is the
        accelerator devices resources requested
        """
        k8s_client = KubernetesClient()

        # Label *presence* filter (no value): any pod created by a Kubeflow
        # training job. Plain string — nothing to interpolate (was an f-string
        # with no placeholders).
        label_filter = "training.kubeflow.org/job-name"
        pods = k8s_client.list_pods_in_all_namespaces_with_labels(label_filter)

        # Dictionary to hold total GPU/Neuron requests per node
        accelerator_devices_requests_by_node = defaultdict(int)

        # Loop through each pod and aggregate the GPU resource requests per node
        for pod in pods:
            node_name = pod.spec.node_name
            # Unscheduled pods have no node_name assigned; skip them.
            if node_name:
                for container in pod.spec.containers:
                    if container.resources and container.resources.requests:
                        gpu_request = container.resources.requests.get(
                            NVIDIA_GPU_RESOURCE_LIMIT_KEY
                        )
                        neuron_request = container.resources.requests.get(
                            NEURON_DEVICE_RESOURCE_KEY
                        )
                        if gpu_request:
                            accelerator_devices_requests_by_node[node_name] += int(
                                gpu_request
                            )
                        if neuron_request:
                            accelerator_devices_requests_by_node[node_name] += int(
                                neuron_request
                            )

        return accelerator_devices_requests_by_node

    def _generate_list_pods_output(self, pods: V1PodList) -> Optional[str]:
        """Render pods as a JSON document with name/namespace/status/time."""
        output_pods = {"pods": []}
        # A non-empty list is truthy; the extra len(...) > 0 was redundant.
        if pods.items:
            _pod: V1Pod
            for _pod in pods.items:
                if _pod.metadata and _pod.metadata.name and _pod.metadata.namespace:
                    name = _pod.metadata.name
                    namespace = _pod.metadata.namespace
                    status = None
                    creation_timestamp = None
                    if _pod.status and _pod.status.phase:
                        status = _pod.status.phase
                    if _pod.metadata.creation_timestamp:
                        creation_timestamp = str(_pod.metadata.creation_timestamp)
                    output_pods["pods"].append(
                        {
                            "PodName": name,
                            "Namespace": namespace,
                            "Status": status,
                            "CreationTime": creation_timestamp,
                        }
                    )

        return json.dumps(output_pods, indent=1, sort_keys=False)

    def _generate_pods_list(self, pods: V1PodList) -> List:
        """Return just the pod names."""
        output_pods = []
        if pods.items:
            for _pod in pods.items:
                if _pod.metadata and _pod.metadata.name:
                    output_pods.append(_pod.metadata.name)
        return output_pods
from hyperpod_cli.clients.kubernetes_client import (
    KubernetesClient,
)
from kubernetes.client.rest import ApiException
from kubernetes.client import (
    V1SelfSubjectAccessReview,
    V1SelfSubjectAccessReviewSpec,
)


class SelfSubjectAccessReview:
    """Thin wrapper around the Kubernetes SelfSubjectAccessReview API."""

    def __init__(self):
        return

    def self_subject_access_review(
        self,
        resource_attributes=None,
        non_resource_attributes=None,
        local_vars_configuration=None,
    ):
        """Submit a SelfSubjectAccessReview and return the API response."""
        auth_v1_api = KubernetesClient().get_auth_v1_api()

        # Build the review spec first, then wrap it into the request object.
        review_spec = V1SelfSubjectAccessReviewSpec(
            resource_attributes=resource_attributes,
            non_resource_attributes=non_resource_attributes,
            local_vars_configuration=local_vars_configuration,
        )
        access_review = V1SelfSubjectAccessReview(spec=review_spec)

        return auth_v1_api.create_self_subject_access_review(body=access_review)
13 | -------------------------------------------------------------------------------- /src/hyperpod_cli/telemetry/user_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | """Placeholder docstring""" 14 | 15 | from __future__ import absolute_import 16 | 17 | import importlib.metadata 18 | 19 | 20 | CLI_PREFIX = "AWS-SageMaker-Hyperpod-CLI" 21 | 22 | 23 | def get_user_agent_extra_suffix(): 24 | """Get the user agent extra suffix string specific to SageMaker Hyperpod CLI 25 | 26 | Adhers to new boto recommended User-Agent 2.0 header format 27 | 28 | Returns: 29 | str: The user agent extra suffix string to be appended 30 | """ 31 | suffix = "cli/{}#{}".format( 32 | CLI_PREFIX, 33 | importlib.metadata.version("hyperpod"), 34 | ) 35 | 36 | return suffix 37 | -------------------------------------------------------------------------------- /src/hyperpod_cli/templates/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. 
# Hydra-style YAML template rendered when submitting a PyTorch training job to
# a Kubernetes HyperPod cluster. "???" marks values the caller must fill in.
# The template text itself is part of runtime behavior and must stay unchanged
# (see the first line of the template).
KUBERNETES_PYTORCH_JOB_TEMPLATE = """### Please keep template file unchanged ###
defaults:
  - override hydra/job_logging: stdout

hydra:
  run:
    dir: .
  output_subdir: null

training_cfg:
  entry_script: ??? # Path to the entry script of training/fine-tuning. This path should be inside container or relative path in git repo
  script_args: ??? # Entry script arguments
  run:
    nodes: ??? # Number of nodes to use for current training
    ntasks_per_node: ??? # Number of tasks per node
cluster:
  cluster_type: k8s # currently k8s only
  instance_type: ???
  cluster_config:
    namespace: ??? # the namespace to submit job
    custom_labels: ???
    service_account_name: null
    annotations: ???
    priority_class_name: ???
    # Create k8s NodeAffinity to select nodes to deploy jobs which matches required and preferred labels
    # Structure:
    # label_selector:
    #   required:
    #   preferred:
    #   weights:
    # Example:
    # label_selector:
    #   required:
    #     example-label-key:
    #       - expected-label-value-1
    #       - expected-label-value-2
    #   preferred:
    #     preferred-label-key:
    #       - preferred-label-value-1
    #       - preferred-label-value-2
    #   weights:
    #     - 100
    label_selector: ???
    # persistent volume, usually used to mount FSx
    persistent_volume_claims: null
    pullPolicy: ??? # policy to pull container, can be Always, IfNotPresent and Never
    restartPolicy: ??? # PyTorchJob restart policy
    # temp volume, usually used to mount temp directory
    # volumes, used to mount temp path to container
    # example:
    # volumes:
    #   - volumeName: data1
    #     hostPath: "/data"
    #     mountPath: "/data"
    volumes: null

base_results_dir: ??? # Location to store the results, checkpoints and logs.
container: ??? # container to use

env_vars:
  NCCL_DEBUG: INFO # Logging level for NCCL. Set to "INFO" for debug information
"""
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utility helpers for the HyperPod CLI.

Covers EKS ARN parsing, logger construction, the SageMaker client factory,
persistence of the current HyperPod cluster context, and validation of the
console URL derived from that context.
"""
import json
import logging
import re


def _hyperpod_context_file_path() -> str:
    """Return the path of the cached HyperPod cluster-context JSON file.

    The constants are imported lazily so importing this module stays cheap
    and avoids import-cycle risk with the constants module.
    """
    from hyperpod_cli.constants.command_constants import (
        GENERATED_LAUNCHER_CONFIG_FILE_PATH,
        HYPERPOD_CLUSTER_CONTEXT_FILE_NAME,
    )

    return GENERATED_LAUNCHER_CONFIG_FILE_PATH + HYPERPOD_CLUSTER_CONTEXT_FILE_NAME


def get_name_from_arn(arn: str) -> str:
    """Parse the EKS cluster name from an EKS cluster ARN.

    Args:
        arn (str): The ARN of the EKS cluster.

    Returns:
        str: The cluster name embedded in the ARN.

    Raises:
        RuntimeError: If the ARN does not look like an EKS cluster ARN.
    """
    # "aws[\w-]*" also accepts the aws-cn / aws-us-gov partitions, which the
    # previous hard-coded "arn:aws:" prefix rejected.
    pattern = r"arn:aws[\w-]*:eks:[\w-]+:\d+:cluster/([\w-]+)"
    match = re.match(pattern, arn)
    if match:
        return match.group(1)
    raise RuntimeError("cannot get EKS cluster name")


def setup_logger(
    name: str,
    logging_level: int = logging.ERROR,
) -> logging.Logger:
    """Set up (or reconfigure) a named logger with one console handler.

    Args:
        name (str): The name of the logger.
        logging_level (int): The logging level for logger and handler.

    Returns:
        logging.Logger: The configured logger instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging_level)

    if not logger.handlers:
        # Attach the console handler only the first time this name is
        # configured; the original unconditionally appended one, so calling
        # setup_logger twice produced duplicate log lines.
        console_handler = logging.StreamHandler()
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    # Keep handler levels in sync with the requested level on every call.
    for handler in logger.handlers:
        handler.setLevel(logging_level)

    return logger


def set_logging_level(
    logger: logging.Logger,
    logging_level: int,
) -> None:
    """Set the level on the logger and on every attached handler.

    Iterating over all handlers (instead of indexing ``handlers[0]``) avoids
    an IndexError on a logger that has no handler attached.
    """
    logger.setLevel(logging_level)
    for handler in logger.handlers:
        handler.setLevel(logging_level)


def get_sagemaker_client(
    session: "boto3.Session", config: "botocore.config.Config" = None
) -> "botocore.client.BaseClient":
    """Create a SageMaker client from the given boto3 session.

    Annotations are strings so boto3/botocore are not imported at module
    load time; the call itself only relies on the session's ``client()``.
    """
    return session.client(
        service_name="sagemaker",
        config=config,
    )


def store_current_hyperpod_context(data) -> None:
    """Serialize the describe-cluster payload to the cluster-context file."""
    with open(_hyperpod_context_file_path(), "w") as hyperpod_current_context:
        hyperpod_current_context.write(json.dumps(data, indent=4, default=str))


def _retrieve_current_hyperpod_context():
    """Load and return the cached HyperPod cluster context as a dict."""
    with open(_hyperpod_context_file_path(), "r") as file:
        return json.load(file)


def _validate_link(console_url: str) -> bool:
    """Return True if console_url looks like a SageMaker cluster console URL."""
    # Raw string with escaped dots: the original non-raw pattern relied on
    # "\/" escapes (DeprecationWarning on modern Python) and let "." match
    # any character.
    pattern = (
        r"https://([a-z0-9-]+)\.console\.aws\.amazon\.com/sagemaker/"
        r"home\?region=([a-z0-9-]+)#/cluster-management/([a-zA-Z0-9-]+)"
    )
    return re.match(pattern, console_url) is not None


def validate_region_and_cluster_name(region, cluster_name) -> bool:
    """Validate the region and cluster name used to build a console URL.

    A region must look like ``<2 letters>-<4..9 letters>-<digit>`` (e.g.
    ``us-west-2``); a cluster name is 1-63 characters of ``[a-zA-Z0-9-]``.
    ``fullmatch`` replaces the original prefix matches, which let trailing
    junk through (e.g. ``abc$`` passed as a cluster name).
    NOTE(review): multi-segment regions such as ``us-gov-west-1`` are still
    rejected, matching the original three-segment rule — confirm intended.
    """
    region_ok = re.fullmatch(r"[a-z]{2}-[a-z]{4,9}-[0-9]", region) is not None
    name_ok = re.fullmatch(r"[a-zA-Z0-9-]{1,63}", cluster_name) is not None
    return region_ok and name_ok


def get_cluster_console_url():
    """Build the SageMaker console URL for the cached cluster, or None."""
    hyperpod_context_cluster = _retrieve_current_hyperpod_context()
    if not (
        hyperpod_context_cluster
        and hyperpod_context_cluster.get("ClusterArn")
        and hyperpod_context_cluster.get("ClusterName")
    ):
        return None

    # ARN layout: arn:partition:service:region:account:resource
    region = hyperpod_context_cluster.get("ClusterArn").split(":")[3]
    cluster_name = hyperpod_context_cluster.get("ClusterName")

    console_url = (
        f"https://{region}.console.aws.amazon.com/sagemaker/"
        f"home?region={region}#/cluster-management/{cluster_name}"
    )
    if _validate_link(console_url) and validate_region_and_cluster_name(
        region, cluster_name
    ):
        return console_url
    return None


def get_eks_cluster_name():
    """Return the EKS cluster name from the cached context ('' if absent)."""
    hyperpod_context_cluster = _retrieve_current_hyperpod_context()
    eks_cluster_arn = (
        hyperpod_context_cluster.get("Orchestrator", {})
        .get("Eks", {})
        .get("ClusterArn", "")
    )
    return eks_cluster_arn.split("cluster/")[-1]


def get_hyperpod_cluster_region():
    """Return the region segment of the cached cluster's ARN.

    NOTE(review): raises AttributeError when "ClusterArn" is missing from
    the cached context — matches the original behavior; confirm acceptable.
    """
    hyperpod_context_cluster = _retrieve_current_hyperpod_context()
    return hyperpod_context_cluster.get("ClusterArn").split(":")[3]
from typing import Optional

import boto3
from botocore.exceptions import ClientError

from hyperpod_cli.utils import setup_logger
from hyperpod_cli.validators.validator import (
    Validator,
)

logger = setup_logger(__name__)


class ClusterValidator(Validator):
    """Validator for SageMaker HyperPod clusters orchestrated by EKS."""

    def __init__(self):
        super().__init__()

    def validate_cluster_and_get_eks_arn(
        self, cluster_name: str, sm_client: boto3.client
    ) -> Optional[str]:
        """Describe the HyperPod cluster and return its EKS orchestrator ARN.

        Args:
            cluster_name: Name of the SageMaker HyperPod cluster.
            sm_client: A SageMaker boto3 client.

        Returns:
            The EKS cluster ARN, or None when the cluster does not exist,
            is not EKS-orchestrated, or the describe call fails.
        """
        try:
            hp_cluster_details = sm_client.describe_cluster(ClusterName=cluster_name)
            if (
                "Orchestrator" not in hp_cluster_details
                or "Eks" not in hp_cluster_details["Orchestrator"]
            ):
                # Fixed log grammar ("Orchestrator not exist or is not Eks").
                logger.warning(
                    f"HyperPod cluster {cluster_name} has no Orchestrator or it is not EKS."
                )
                return None
            return hp_cluster_details["Orchestrator"]["Eks"]["ClusterArn"]
        except ClientError as e:
            if e.response["Error"]["Code"] == "ResourceNotFoundException":
                logger.error(f"HyperPod cluster {cluster_name} not found.")
            else:
                logger.error(f"Validate HyperPod cluster {cluster_name} error: {e}")
            return None
        except Exception as e:
            # Broad catch is deliberate: validation is best-effort and every
            # caller treats None as "could not validate".
            logger.error(f"Validate HyperPod cluster {cluster_name} error: {e}")
            return None
# See the License for the specific language governing permissions and
# limitations under the License.
import boto3
from botocore.exceptions import (
    ClientError,
    NoCredentialsError,
    PartialCredentialsError,
)

from hyperpod_cli.utils import setup_logger

logger = setup_logger(__name__)


class Validator:
    """Base class for HyperPod CLI validators."""

    def __init__(self):
        return

    def validate(self):
        """Abstract validate method to be implemented in subclasses."""
        # Bug fix: the original *returned* a NotImplementedError instance
        # instead of raising it, so calling the unimplemented base method
        # silently succeeded (and the returned instance is truthy).
        raise NotImplementedError()

    def validate_aws_credential(self, session: boto3.Session) -> bool:
        """Check that the session resolves to valid AWS credentials.

        Args:
            session: The boto3 session whose credentials are checked.

        Returns:
            bool: True if credentials resolve and STS accepts them,
            False otherwise.
        """
        try:
            # Check if credentials are available
            credentials = session.get_credentials()
            if not credentials:
                logger.error(
                    "No AWS credentials found. Please configure your AWS credentials."
                )
                return False

            # Call STS get_caller_identity so the credentials are actually
            # validated server-side, not just present locally.
            sts = session.client("sts")
            sts.get_caller_identity()

            logger.debug("AWS credentials are valid.")
            return True
        except (NoCredentialsError, PartialCredentialsError) as e:
            logger.error(f"No AWS credentials or partial AWS credentials provided: {e}")
            return False
        except ClientError as e:
            error_code = e.response["Error"]["Code"]
            if error_code == "ExpiredToken":
                logger.error(
                    "AWS credentials have expired. Please refresh your AWS "
                    "credentials"
                )
            else:
                logger.error(f"Get credentials AWS client error: {e}")
            return False
        except Exception as e:
            # Best-effort check: any unexpected failure is reported as
            # "credentials not valid" rather than crashing the CLI.
            logger.error(f"Unexpected error to get AWS credentials: {e}")
            return False
This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | apiVersion: v1 14 | kind: Namespace 15 | metadata: 16 | name: hyperpod 17 | labels: 18 | name: hyperpod 19 | --- 20 | kind: ClusterRole 21 | apiVersion: rbac.authorization.k8s.io/v1 22 | metadata: 23 | name: hyperpod-node-manager-role 24 | ### 25 | # 1) add/list/describe/delete nodes 26 | # 2) add/delete/update labels 27 | # 3) cordon 28 | # 4) receive k8s events 29 | # 5) receive pod status change 30 | # 6) receive node status change 31 | # 7) get/list/watch/create/patch/update/delete/describe kubeflow pytroch job 32 | # 8) get pod log 33 | # 9) get/list/watch/create/patch/update/delete batch job 34 | ### 35 | rules: 36 | - resources: ["nodes"] 37 | verbs: ["*"] 38 | apiGroups: [""] 39 | # cloud controller permission reference 40 | # https://kubernetes.io/docs/concepts/architecture/cloud-controller/#authorization 41 | - apiGroups: [""] 42 | resources: ["nodes/status"] 43 | verbs: ["patch"] 44 | - apiGroups: [""] 45 | resources: ["events"] 46 | verbs: ["create", "patch", "update"] 47 | - apiGroups: [""] 48 | resources: ["services"] 49 | verbs: ["list", "patch", "update", "watch"] 50 | - apiGroups: [""] 51 | resources: ["serviceaccounts"] 52 | verbs: ["create"] 53 | - apiGroups: [""] 54 | resources: ["persistentvolumes"] 55 | verbs: ["get", "list", "watch", "update"] 56 | - apiGroups: [""] 57 | resources: ["endpoints"] 58 | verbs: ["get", "list", "watch", "create", "update"] 59 | # reference for csr approver permissions: https://github.com/postfinance/kubelet-csr-approver/blob/c5ca70db40ca5002e9d7c047eb7126049b97dbf6/deploy/k8s/clusterrole.yaml 60 | - apiGroups: ["certificates.k8s.io"] 61 | resources: ["certificatesigningrequests"] 62 | verbs: ["get", "list", "watch"] 63 | - apiGroups: ["certificates.k8s.io"] 64 | resources: 
["certificatesigningrequests/approval"] 65 | verbs: ["update"] 66 | - apiGroups: ["certificates.k8s.io"] 67 | resources: ["signers"] 68 | resourceNames: ["kubernetes.io/kubelet-serving"] 69 | verbs: ["approve"] 70 | - apiGroups: ["authorization.k8s.io"] 71 | resources: ["subjectaccessreviews"] 72 | verbs: ["create"] 73 | # training job watcher permissions 74 | - apiGroups: [""] 75 | resources: ["nodes", "nodes/status", "pods", "pods/status"] 76 | verbs: ["get", "list", "watch"] 77 | - apiGroups: [""] 78 | resources: ["pods"] 79 | verbs: ["delete", "deletecollection"] 80 | - apiGroups: [""] 81 | resources: ["pods/log"] 82 | verbs: ["get", "list"] 83 | - apiGroups: [""] 84 | resources: ["nodes", "nodes/status"] 85 | verbs: ["patch"] 86 | - apiGroups: ["", "events.k8s.io"] 87 | resources: ["events"] 88 | verbs: ["create", "patch", "update"] 89 | - apiGroups: ["kubeflow.org"] 90 | resources: ["pytorchjobs", "pytorchjobs/status"] 91 | verbs: ["get", "list", "watch", "delete", "patch", "update", "describe"] 92 | - apiGroups: ["batch"] 93 | resources: ["jobs"] 94 | verbs: ["get", "list", "watch", "create", "delete", "patch", "update", "describe"] 95 | --- 96 | apiVersion: rbac.authorization.k8s.io/v1 97 | # This role binding allows "jane" to read pods in the "default" namespace. 98 | # You need to already have a Role named "pod-reader" in that namespace. 
99 | kind: ClusterRoleBinding 100 | metadata: 101 | name: hyperpod-nodes 102 | namespace: kube-system 103 | subjects: 104 | # You can specify more than one "subject" 105 | - kind: Group 106 | name: hyperpod-node-manager # "name" is case sensitive 107 | apiGroup: rbac.authorization.k8s.io 108 | roleRef: 109 | # "roleRef" specifies the binding to a Role / ClusterRole 110 | kind: ClusterRole #this must be Role or ClusterRole 111 | name: hyperpod-node-manager-role # this must match the name of the Role or ClusterRole you wish to bind to 112 | apiGroup: rbac.authorization.k8s.io 113 | --- 114 | apiVersion: v1 115 | kind: ConfigMap 116 | metadata: 117 | name: aws-auth 118 | namespace: kube-system 119 | data: 120 | mapRoles: | 121 | - groups: 122 | - system:nodes 123 | - system:bootstrapers 124 | rolearn: SAGEMAKER_EXECUTION_ROLE 125 | username: system:node:hyperpod-{{SessionName}} 126 | - groups: 127 | - hyperpod-node-manager 128 | rolearn: SAGEMAKER_SERVICE_ROLE 129 | username: sagemaker-service 130 | mapUsers: | 131 | [] 132 | 133 | --- 134 | apiVersion: v1 135 | kind: ServiceAccount 136 | metadata: 137 | name: health-monitor 138 | namespace: hyperpod 139 | 140 | --- 141 | apiVersion: rbac.authorization.k8s.io/v1 142 | kind: ClusterRoleBinding 143 | metadata: 144 | name: health-monitor-binding 145 | roleRef: 146 | apiGroup: rbac.authorization.k8s.io 147 | kind: ClusterRole 148 | name: system:health-monitor 149 | subjects: 150 | - kind: ServiceAccount 151 | name: health-monitor 152 | namespace: hyperpod 153 | 154 | --- 155 | apiVersion: rbac.authorization.k8s.io/v1 156 | kind: ClusterRole 157 | metadata: 158 | labels: 159 | kubernetes.io/bootstrapping: rbac-defaults 160 | name: system:health-monitor 161 | rules: 162 | - apiGroups: 163 | - "" 164 | resources: 165 | - nodes 166 | verbs: 167 | - get 168 | - apiGroups: 169 | - "" 170 | resources: 171 | - nodes 172 | - nodes/status 173 | verbs: 174 | - patch 175 | - apiGroups: 176 | - "" 177 | - events.k8s.io 178 | 
resources: 179 | - events 180 | verbs: 181 | - create 182 | - patch 183 | - update 184 | 185 | --- 186 | apiVersion: v1 187 | kind: ServiceAccount 188 | metadata: 189 | name: burnin-test 190 | namespace: hyperpod 191 | 192 | --- 193 | apiVersion: rbac.authorization.k8s.io/v1 194 | kind: ClusterRole 195 | metadata: 196 | name: burnin-test 197 | rules: 198 | - apiGroups: 199 | - "" 200 | resources: 201 | - nodes 202 | verbs: 203 | - get 204 | - list 205 | - apiGroups: 206 | - "" 207 | resources: 208 | - pods 209 | verbs: 210 | - get 211 | - list 212 | 213 | --- 214 | apiVersion: rbac.authorization.k8s.io/v1 215 | kind: ClusterRoleBinding 216 | metadata: 217 | name: burnin-role-binding 218 | subjects: 219 | - kind: ServiceAccount 220 | name: burnin-test 221 | namespace: hyperpod 222 | roleRef: 223 | kind: ClusterRole 224 | name: burnin-test 225 | apiGroup: rbac.authorization.k8s.io -------------------------------------------------------------------------------- /test/integration_tests/cloudformation/resources.yaml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | AWSTemplateFormatVersion: '2010-09-09' 14 | Description: This template deploys a VPC, with three public and private subnets spread 15 | across three Availability Zones. It deploys an internet gateway, with a default 16 | route on the public subnets. 
It deploys a NAT gateway in each AZ, 17 | and default routes for them in the private subnets. 18 | 19 | Parameters: 20 | EKSClusterRoleArn: 21 | Description: Role used for creating eks cluster 22 | Type: String 23 | 24 | SubnetId1: 25 | Description: Subnets to attach EKS cluster to 26 | Type: String 27 | 28 | SubnetId2: 29 | Description: Subnets to attach EKS cluster to 30 | Type: String 31 | 32 | SecurityGroupId: 33 | Description: Security group to attach EKS cluster to 34 | Type: AWS::EC2::SecurityGroup::Id 35 | 36 | ClusterName: 37 | Description: EKS Cluster Name 38 | Type: String 39 | Default: 'hyperpod-eks' 40 | 41 | KubernetesVersion: 42 | Description: Kubernetes version to use for EKS cluster 43 | Type: String 44 | Default: '1.29' 45 | 46 | NetworkType: 47 | Description: IP version to use for EKS cluster 48 | Type: String 49 | Default: "ipv4" 50 | AllowedValues: 51 | - ipv4 52 | - ipv6 53 | ConstraintDescription: "Must be either ipv4 or ipv6" 54 | 55 | Resources: 56 | 57 | EKSCluster: 58 | Type: 'AWS::EKS::Cluster' 59 | Properties: 60 | Name: !Ref ClusterName 61 | Version: !Ref KubernetesVersion 62 | RoleArn: !Ref EKSClusterRoleArn 63 | AccessConfig: 64 | # For now, HyperPod requires config map to work 65 | AuthenticationMode: API_AND_CONFIG_MAP 66 | Logging: 67 | ClusterLogging: 68 | EnabledTypes: 69 | - Type: api 70 | - Type: audit 71 | - Type: authenticator 72 | - Type: controllerManager 73 | - Type: scheduler 74 | ResourcesVpcConfig: 75 | SubnetIds: 76 | - !Ref SubnetId1 77 | - !Ref SubnetId2 78 | SecurityGroupIds: 79 | - !Ref SecurityGroupId 80 | KubernetesNetworkConfig: 81 | IpFamily: !Ref NetworkType 82 | 83 | VpcCNIAddOn: 84 | Type: 'AWS::EKS::Addon' 85 | Properties: 86 | AddonName: vpc-cni 87 | ClusterName: !Ref EKSCluster 88 | ResolveConflicts: OVERWRITE 89 | 90 | KubeProxyAddOn: 91 | Type: 'AWS::EKS::Addon' 92 | Properties: 93 | AddonName: kube-proxy 94 | ClusterName: !Ref EKSCluster 95 | ResolveConflicts: OVERWRITE 96 | 97 | CoreDNSAddOn: 98 | 
Type: 'AWS::EKS::Addon' 99 | Properties: 100 | AddonName: coredns 101 | ClusterName: !Ref EKSCluster 102 | ResolveConflicts: OVERWRITE 103 | 104 | PodIdentityAddOn: 105 | Type: 'AWS::EKS::Addon' 106 | Properties: 107 | AddonName: eks-pod-identity-agent 108 | ClusterName: !Ref EKSCluster 109 | ResolveConflicts: OVERWRITE 110 | 111 | Outputs: 112 | 113 | ClusterArn: 114 | Description: The ARN of the EKS cluster 115 | Value: !GetAtt EKSCluster.Arn 116 | 117 | ClusterName: 118 | Description: The name of the EKS cluster 119 | Value: !Ref EKSCluster -------------------------------------------------------------------------------- /test/integration_tests/data/basicJob.yaml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | defaults: 14 | - override hydra/job_logging: stdout 15 | 16 | hydra: 17 | run: 18 | dir: . 
19 | output_subdir: null 20 | 21 | training_cfg: 22 | entry_script: /opt/pytorch-mnist/mnist.py 23 | script_args: [] 24 | run: 25 | name: ${JOB_NAME} # Current run name 26 | nodes: 1 # Number of nodes to use for current training 27 | ntasks_per_node: 1 # Number of devices to use per node 28 | cluster: 29 | cluster_type: k8s # currently k8s only 30 | instance_type: ml.c5.2xlarge 31 | cluster_config: 32 | # name of service account associated with the namespace 33 | service_account_name: null 34 | # persistent volume, usually used to mount FSx 35 | persistent_volume_claims: null 36 | namespace: kubeflow 37 | # required node affinity to select nodes with HyperPod 38 | # labels and passed health check if burn-in enabled 39 | label_selector: 40 | required: 41 | sagemaker.amazonaws.com/node-health-status: 42 | - Schedulable 43 | preferred: 44 | sagemaker.amazonaws.com/deep-health-check-status: 45 | - Passed 46 | weights: 47 | - 100 48 | pullPolicy: IfNotPresent # policy to pull container, can be Always, IfNotPresent and Never 49 | restartPolicy: OnFailure # restart policy 50 | scheduler_type: None 51 | 52 | base_results_dir: ./result # Location to store the results, checkpoints and logs. 53 | container: docker.io/kubeflowkatib/pytorch-mnist-cpu:v1beta1-bc09cfd # container to use 54 | 55 | env_vars: 56 | NCCL_DEBUG: INFO # Logging level for NCCL. Set to "INFO" for debug information -------------------------------------------------------------------------------- /test/integration_tests/data/basicJobWithQuota.yaml: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. 
This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 13 | defaults: 14 | - override hydra/job_logging: stdout 15 | 16 | hydra: 17 | run: 18 | dir: . 19 | output_subdir: null 20 | 21 | training_cfg: 22 | entry_script: /opt/pytorch-mnist/mnist.py 23 | script_args: [] 24 | run: 25 | name: hyperpod-cli-test-with-quota # Current run name 26 | nodes: 1 # Number of nodes to use for current training 27 | ntasks_per_node: 1 # Number of devices to use per node 28 | cluster: 29 | cluster_type: k8s # currently k8s only 30 | instance_type: ml.c5.2xlarge 31 | cluster_config: 32 | # name of service account associated with the namespace 33 | service_account_name: null 34 | # persistent volume, usually used to mount FSx 35 | persistent_volume_claims: null 36 | # required node affinity to select nodes with HyperPod 37 | # labels and passed health check if burn-in enabled 38 | label_selector: 39 | required: 40 | sagemaker.amazonaws.com/node-health-status: 41 | - Schedulable 42 | preferred: 43 | sagemaker.amazonaws.com/deep-health-check-status: 44 | - Passed 45 | weights: 46 | - 100 47 | pullPolicy: IfNotPresent # policy to pull container, can be Always, IfNotPresent and Never 48 | restartPolicy: OnFailure # restart policy 49 | scheduler_type: SageMaker 50 | base_results_dir: ./result # Location to store the results, checkpoints and logs. 51 | container: docker.io/kubeflowkatib/pytorch-mnist-cpu:v1beta1-bc09cfd # container to use 52 | 53 | env_vars: 54 | NCCL_DEBUG: INFO # Logging level for NCCL. Set to "INFO" for debug information -------------------------------------------------------------------------------- /test/integration_tests/lifecycle_script/on_create_noop.sh: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
#!/bin/bash
# NOTE(review): the shebang above is not on line 1 (the license header
# precedes it), so it is inert; the script must be invoked via an explicit
# shell — confirm that is how the lifecycle runner calls it.

# No-op "on create" HyperPod lifecycle script: it only records start/stop
# markers in the provisioning log.
set -ex

# All output is appended to the standard HyperPod provisioning log.
LOG_FILE="/var/log/provision/provisioning.log"
mkdir -p "/var/log/provision"
touch $LOG_FILE

# Function to log messages to stdout and the provisioning log file
logger() {
  echo "$@" | tee -a $LOG_FILE
}

logger "[start] on_create.sh"
logger "no more steps to run"
logger "[stop] on_create.sh"
import unittest
from unittest import mock
from unittest.mock import MagicMock

from hyperpod_cli.clients.kubernetes_client import (
    KubernetesClient,
)
from hyperpod_cli.service.cancel_training_job import (
    CancelTrainingJob,
)

from kubernetes.client.rest import ApiException

class CancelTrainingJobTest(unittest.TestCase):
    """Unit tests for CancelTrainingJob.cancel_training_job."""

    def setUp(self):
        # Real service under test; the Kubernetes client is replaced by a mock.
        self.mock_cancel_training_job = CancelTrainingJob()
        self.mock_k8s_client = MagicMock(spec=KubernetesClient)

    @mock.patch("subprocess.run")
    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    def test_cancel_training_job_with_namespace(
        self,
        mock_kubernetes_client: mock.Mock,
        mock_subprocess_run: mock.Mock,
    ):
        # With an explicit namespace the job is deleted and cancel returns None.
        mock_kubernetes_client.return_value = self.mock_k8s_client
        self.mock_k8s_client.delete_training_job.return_value = {"status": "Success"}
        result = self.mock_cancel_training_job.cancel_training_job(
            "sample-job", "namespace"
        )
        self.assertIsNone(result)
        mock_subprocess_run.assert_called_once()

    @mock.patch("subprocess.run")
    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    def test_cancel_training_job_without_namespace(
        self,
        mock_kubernetes_client: mock.Mock,
        mock_subprocess_run: mock.Mock,
    ):
        # Without a namespace the current kube-context namespace is used.
        mock_kubernetes_client.return_value = self.mock_k8s_client
        self.mock_k8s_client.get_current_context_namespace.return_value = "namespace"
        self.mock_k8s_client.delete_training_job.return_value = {"status": "Success"}
        result = self.mock_cancel_training_job.cancel_training_job("sample-job", None)
        self.assertIsNone(result)
        mock_subprocess_run.assert_called_once()

    @mock.patch("subprocess.run")
    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    def test_cancel_training_job_api_exception(
        self,
        mock_kubernetes_client: mock.Mock,
        mock_subprocess_run: mock.Mock,
    ):
        # A Kubernetes ApiException surfaces as RuntimeError and no
        # subprocess call is made.
        mock_kubernetes_client.return_value = self.mock_k8s_client
        self.mock_k8s_client.get_current_context_namespace.return_value = "namespace"
        self.mock_k8s_client.delete_training_job.side_effect = ApiException(
            status="Failed", reason="unexpected"
        )
        with self.assertRaises(RuntimeError):
            self.mock_cancel_training_job.cancel_training_job("sample-job", None)
        mock_subprocess_run.assert_not_called()

    @mock.patch("hyperpod_cli.service.discover_namespaces.DiscoverNamespaces.discover_accessible_namespace")
    @mock.patch("subprocess.run")
    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    def test_cancel_training_job_auto_discover_namespace(
        self,
        mock_kubernetes_client: mock.Mock,
        mock_subprocess_run: mock.Mock,
        mock_discover_accessible_namespace: mock.Mock,
    ):
        # When even the kube-context has no namespace, the service falls back
        # to namespace auto-discovery.
        mock_kubernetes_client.return_value = self.mock_k8s_client
        mock_discover_accessible_namespace.return_value = "discovered-namespace"
        self.mock_k8s_client.get_current_context_namespace.return_value = None
        self.mock_k8s_client.delete_training_job.return_value = {"status": "Success"}
        result = self.mock_cancel_training_job.cancel_training_job("sample-job", None)
        self.assertIsNone(result)
        mock_subprocess_run.assert_called_once()
# A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import unittest
from unittest import mock
from unittest.mock import MagicMock

from hyperpod_cli.clients.kubernetes_client import (
    KubernetesClient,
)
from hyperpod_cli.service.exec_command import (
    ExecCommand,
)
from hyperpod_cli.service.list_pods import (
    ListPods,
)

from kubernetes.client.rest import ApiException


class ExecCommandServiceTest(unittest.TestCase):
    """Unit tests for ExecCommand.exec_command."""

    def setUp(self):
        # Real service under test; pod listing and the Kubernetes client are
        # replaced by mocks.
        self.mock_exec_command = ExecCommand()
        self.mock_list_pods_service = MagicMock(spec=ListPods)
        self.mock_k8s_client = MagicMock(spec=KubernetesClient)

    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods.list_pods_for_training_job")
    def test_exec_with_pod_without_namespace(
        self,
        mock_list_training_job_pods_service_with_list_pods: mock.Mock,
        mock_list_training_job_pods_service: mock.Mock,
        mock_kubernetes_client: mock.Mock,
    ):
        # No namespace given: the current kube-context namespace is used and
        # the command output from the pod is returned.
        mock_kubernetes_client.return_value = self.mock_k8s_client
        self.mock_k8s_client.get_current_context_namespace.return_value = "kubeflow"
        self.mock_k8s_client.exec_command_on_pod.return_value = (
            "Fri Aug 9 06:16:05 UTC 2024"
        )
        mock_list_training_job_pods_service.return_value = self.mock_list_pods_service
        mock_list_training_job_pods_service_with_list_pods.return_value = [
            "sample-job-master-0"
        ]
        result = self.mock_exec_command.exec_command(
            "sample-job",
            "sample-job-master-0",
            None,
            False,
            (
                "date",
            ),
        )
        self.assertIn("Fri Aug 9 06:16:05 UTC 2024", result)

    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods.list_pods_for_training_job")
    def test_exec_with_pod_with_namespace(
        self,
        mock_list_training_job_pods_service_with_list_pods: mock.Mock,
        mock_list_training_job_pods_service: mock.Mock,
        mock_kubernetes_client: mock.Mock,
    ):
        # Explicit namespace: command runs on the named pod and output is
        # returned unchanged.
        mock_kubernetes_client.return_value = self.mock_k8s_client
        self.mock_k8s_client.exec_command_on_pod.return_value = (
            "Fri Aug 9 06:16:05 UTC 2024"
        )
        mock_list_training_job_pods_service.return_value = self.mock_list_pods_service
        mock_list_training_job_pods_service_with_list_pods.return_value = [
            "sample-job-master-0"
        ]
        result = self.mock_exec_command.exec_command(
            "sample-job",
            "sample-job-master-0",
            "kubeflow",
            False,
            (
                "date",
            ),
        )
        self.assertIn("Fri Aug 9 06:16:05 UTC 2024", result)

    @mock.patch("hyperpod_cli.service.list_pods.ListPods")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods.list_pods_for_training_job")
    def test_exec_with_pod_with_namespace_unknown_pod(
        self,
        mock_list_training_job_pods_service_with_list_pods: mock.Mock,
        mock_list_training_job_pods_service: mock.Mock,
    ):
        # The requested pod is not among the job's pods: RuntimeError.
        mock_list_training_job_pods_service.return_value = self.mock_list_pods_service
        mock_list_training_job_pods_service_with_list_pods.return_value = [
            "sample-job-master-1"
        ]
        with self.assertRaises(RuntimeError):
            self.mock_exec_command.exec_command(
                "sample-job",
                "sample-job-master-0",
                "kubeflow",
                False,
                (
                    "date",
                ),
            )

    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods.list_pods_for_training_job")
    def test_exec_with_pod_with_namespace_all_pod(
        self,
        mock_list_training_job_pods_service_with_list_pods: mock.Mock,
        mock_list_training_job_pods_service: mock.Mock,
        mock_kubernetes_client: mock.Mock,
    ):
        # all-pods flag set (pod name None): command runs on every listed pod.
        mock_kubernetes_client.return_value = self.mock_k8s_client
        self.mock_k8s_client.exec_command_on_pod.return_value = (
            "Fri Aug 9 06:16:05 UTC 2024"
        )
        mock_list_training_job_pods_service.return_value = self.mock_list_pods_service
        mock_list_training_job_pods_service_with_list_pods.return_value = [
            "sample-job-master-0"
        ]
        result = self.mock_exec_command.exec_command(
            "sample-job",
            None,
            "kubeflow",
            True,
            (
                "date",
            ),
        )
        self.assertIn("Fri Aug 9 06:16:05 UTC 2024", result)

    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods.list_pods_for_training_job")
    def test_exec_with_pod_with_namespace_all_pod_api_exception(
        self,
        mock_list_training_job_pods_service_with_list_pods: mock.Mock,
        mock_list_training_job_pods_service: mock.Mock,
        mock_kubernetes_client: mock.Mock,
    ):
        # A Kubernetes ApiException during exec surfaces as RuntimeError.
        mock_kubernetes_client.return_value = self.mock_k8s_client
        self.mock_k8s_client.exec_command_on_pod.side_effect = ApiException(
            status="Failed", reason="unexpected"
        )
        mock_list_training_job_pods_service.return_value = self.mock_list_pods_service
        mock_list_training_job_pods_service_with_list_pods.return_value = [
            "sample-job-master-0"
        ]
        with self.assertRaises(RuntimeError):
            self.mock_exec_command.exec_command(
                "sample-job",
                None,
                "kubeflow",
                True,
                (
                    "date",
                ),
            )
class TestGetLogs(unittest.TestCase):
    """Unit tests for the GetLogs service."""

    def setUp(self):
        # Service under test plus shared mocks for its collaborators.
        self.mock_get_logs = GetLogs()
        self.mock_list_pods_service = MagicMock(spec=ListPods)
        self.mock_k8s_client = MagicMock(spec=KubernetesClient)

    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods.list_pods_for_training_job")
    def test_get_logs_with_namespace(
        self,
        mock_list_pods_fn: mock.Mock,
        mock_list_pods_cls: mock.Mock,
        mock_k8s_new: mock.Mock,
    ):
        """Logs are fetched when an explicit namespace is supplied."""
        mock_k8s_new.return_value = self.mock_k8s_client
        mock_list_pods_cls.return_value = self.mock_list_pods_service
        mock_list_pods_fn.return_value = ["test-pod"]
        self.mock_k8s_client.get_logs_for_pod.return_value = "test logs"

        result = self.mock_get_logs.get_training_job_logs(
            "sample-job", "test-pod", "kubeflow"
        )

        self.assertIn("test logs", result)

    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods.list_pods_for_training_job")
    def test_get_logs_without_namespace(
        self,
        mock_list_pods_fn: mock.Mock,
        mock_list_pods_cls: mock.Mock,
        mock_k8s_new: mock.Mock,
    ):
        """Namespace omitted: the current kube-context namespace is used."""
        mock_k8s_new.return_value = self.mock_k8s_client
        mock_list_pods_cls.return_value = self.mock_list_pods_service
        self.mock_k8s_client.get_current_context_namespace.return_value = "kubeflow"
        mock_list_pods_fn.return_value = ["test-pod"]
        self.mock_k8s_client.get_logs_for_pod.return_value = "test logs"

        result = self.mock_get_logs.get_training_job_logs(
            "sample-job", "test-pod", None
        )

        self.assertIn("test logs", result)

    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods.list_pods_for_training_job")
    def test_get_logs_pod_not_found_for_job(
        self,
        mock_list_pods_fn: mock.Mock,
        mock_list_pods_cls: mock.Mock,
        mock_k8s_new: mock.Mock,
    ):
        """Requesting a pod the job does not own raises RuntimeError."""
        mock_k8s_new.return_value = self.mock_k8s_client
        mock_list_pods_cls.return_value = self.mock_list_pods_service
        self.mock_k8s_client.get_current_context_namespace.return_value = "kubeflow"
        mock_list_pods_fn.return_value = ["test-pod"]
        self.mock_k8s_client.get_logs_for_pod.return_value = "test logs"

        with self.assertRaises(RuntimeError):
            self.mock_get_logs.get_training_job_logs("sample-job", "test-pod1", None)

    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods.list_pods_for_training_job")
    def test_get_logs_without_namespace_api_exception(
        self,
        mock_list_pods_fn: mock.Mock,
        mock_list_pods_cls: mock.Mock,
        mock_k8s_new: mock.Mock,
    ):
        """A Kubernetes ApiException from the log fetch surfaces as RuntimeError."""
        mock_k8s_new.return_value = self.mock_k8s_client
        mock_list_pods_cls.return_value = self.mock_list_pods_service
        self.mock_k8s_client.get_current_context_namespace.return_value = "kubeflow"
        mock_list_pods_fn.return_value = ["test-pod"]
        self.mock_k8s_client.get_logs_for_pod.side_effect = ApiException(
            status="Failed", reason="unexpected"
        )

        with self.assertRaises(RuntimeError):
            self.mock_get_logs.get_training_job_logs("sample-job", "test-pod", None)

    @mock.patch("hyperpod_cli.service.discover_namespaces.DiscoverNamespaces.discover_accessible_namespace")
    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods")
    @mock.patch("hyperpod_cli.service.list_pods.ListPods.list_pods_for_training_job")
    def test_get_logs_auto_discover_namespace(
        self,
        mock_list_pods_fn: mock.Mock,
        mock_list_pods_cls: mock.Mock,
        mock_k8s_new: mock.Mock,
        mock_discover_namespace: mock.Mock,
    ):
        """No context namespace at all: one is auto-discovered."""
        mock_k8s_new.return_value = self.mock_k8s_client
        mock_list_pods_cls.return_value = self.mock_list_pods_service
        self.mock_k8s_client.get_current_context_namespace.return_value = None
        mock_discover_namespace.return_value = "discovered-namespace"
        mock_list_pods_fn.return_value = ["test-pod"]
        self.mock_k8s_client.get_logs_for_pod.return_value = "test logs"

        result = self.mock_get_logs.get_training_job_logs(
            "sample-job", "test-pod", None
        )

        self.assertIn("test logs", result)

    def test_get_log_url(self):
        """The CloudWatch Container Insights log-events URL is built correctly."""
        eks_cluster_name = 'eks_cluster_name'
        region = 'us-west-2'
        node_name = 'node_name'
        pod_name = 'pod_name'
        namespace = 'namespace'
        container_name = 'container_name'
        container_id = 'container_id'

        result_url = self.mock_get_logs.get_log_url(
            eks_cluster_name,
            region,
            node_name,
            pod_name,
            namespace,
            container_name,
            container_id,
        )

        self.assertEqual(
            result_url,
            'https://us-west-2.console.aws.amazon.com/cloudwatch/home?region=us-west-2#logsV2:log-groups/log-group/$252Faws$252Fcontainerinsights$252Feks_cluster_name$252Fapplication/log-events/node_name-application.var.log.containers.pod_name_namespace_container_name-container_id.log'
        )
class TestGetNamespaces(unittest.TestCase):
    """Unit tests for the GetNamespaces service."""

    def setUp(self):
        self.mock_k8s_client = MagicMock(spec=KubernetesClient)
        self.mock_get_namespaces = GetNamespaces()

    @staticmethod
    def _fake_namespace(name):
        # Build a mock V1Namespace whose metadata.name is `name`.
        ns = MagicMock()
        ns.metadata.name = name
        return ns

    @staticmethod
    def _fake_page(names, continue_token):
        # Build a mock V1NamespaceList page with an optional continue token.
        page = MagicMock()
        page.items = [TestGetNamespaces._fake_namespace(n) for n in names]
        page.metadata._continue = continue_token
        return page

    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    def test_get_namespaces_success(self, mock_kubernetes_client):
        """A single page of namespaces is returned as a list of names."""
        mock_core_v1_api = MagicMock()
        mock_kubernetes_client.return_value = self.mock_k8s_client
        self.mock_k8s_client.get_core_v1_api.return_value = mock_core_v1_api

        mock_core_v1_api.list_namespace.return_value = self._fake_page(
            ["namespace1", "namespace2"], None
        )

        namespaces = self.mock_get_namespaces.get_namespaces()

        self.assertEqual(namespaces, ["namespace1", "namespace2"])

    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    def test_get_namespaces_pagination(self, mock_kubernetes_client):
        """A continue token triggers a second list_namespace call."""
        mock_kubernetes_client.return_value = self.mock_k8s_client
        mock_core_v1_api = MagicMock()
        self.mock_k8s_client.get_core_v1_api.return_value = mock_core_v1_api

        mock_core_v1_api.list_namespace.side_effect = [
            self._fake_page(["namespace1"], "next-token"),
            self._fake_page(["namespace2"], None),
        ]

        namespaces = self.mock_get_namespaces.get_namespaces()

        self.assertEqual(namespaces, ["namespace1", "namespace2"])
        self.assertEqual(mock_core_v1_api.list_namespace.call_count, 2)

    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    def test_get_sagemaker_managed_namespaces(self, mock_kubernetes_client):
        """Managed namespaces are listed with the SageMaker label selector."""
        mock_kubernetes_client.return_value = self.mock_k8s_client
        mock_core_v1_api = MagicMock()
        self.mock_k8s_client.get_core_v1_api.return_value = mock_core_v1_api

        mock_core_v1_api.list_namespace.side_effect = [
            self._fake_page(["namespace1"], None)
        ]

        namespaces = self.mock_get_namespaces.get_sagemaker_managed_namespaces()

        self.assertEqual(namespaces, ["namespace1"])
        mock_core_v1_api.list_namespace.assert_called_once_with(
            limit=200,
            label_selector="sagemaker.amazonaws.com/sagemaker-managed-queue=true"
        )


if __name__ == "__main__":
    unittest.main()
class TestSelfSubjectAccessReview(unittest.TestCase):
    """Unit tests for the SelfSubjectAccessReview service."""

    @mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
    def test_self_subject_access_review_success(self, mock_kubernetes_client):
        """An allowed review is returned and the API receives a proper body."""
        # Wire the mocked client so get_auth_v1_api() hands back our API mock.
        mock_auth_v1_api = MagicMock(spec=AuthorizationV1Api)
        mock_kubernetes_client().get_auth_v1_api.return_value = mock_auth_v1_api

        # The API responds with an access review whose status allows the action.
        mock_response = MagicMock(spec=V1SelfSubjectAccessReview)
        mock_response.status.allowed = True
        mock_auth_v1_api.create_self_subject_access_review.return_value = mock_response

        resource_attributes = {
            "namespace": "test-namespace",
            "verb": "create",
            "resource": "pods"
        }

        service = SelfSubjectAccessReview()
        response = service.self_subject_access_review(
            resource_attributes=resource_attributes
        )

        # The action must be reported as allowed.
        self.assertTrue(response.status.allowed)

        # Exactly one API call, carrying a V1SelfSubjectAccessReview body
        # whose spec echoes the requested namespace.
        mock_auth_v1_api.create_self_subject_access_review.assert_called_once()
        args, kwargs = mock_auth_v1_api.create_self_subject_access_review.call_args
        self.assertIsInstance(kwargs['body'], V1SelfSubjectAccessReview)
        self.assertEqual(
            kwargs['body'].spec.resource_attributes['namespace'], 'test-namespace'
        )


if __name__ == "__main__":
    unittest.main()
13 | from click.testing import CliRunner 14 | 15 | from hyperpod_cli.cli import cli 16 | 17 | 18 | def test_hyperpod_cli_importable(): 19 | import hyperpod_cli # noqa: F401 20 | 21 | 22 | def test_cli_init(): 23 | runner = CliRunner() 24 | result = runner.invoke(cli) 25 | assert result.exit_code == 0 26 | assert not result.exception 27 | 28 | 29 | def test_cli_help(): 30 | runner = CliRunner() 31 | result = runner.invoke(cli, ["--help"]) 32 | assert result.exit_code == 0 33 | assert "Basic Commands:" in result.output 34 | -------------------------------------------------------------------------------- /test/unit_tests/validators/test_cluster_validatory.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). You 4 | # may not use this file except in compliance with the License. A copy of 5 | # the License is located at 6 | # 7 | # http://aws.amazon.com/apache2.0/ 8 | # 9 | # or in the "license" file accompanying this file. This file is 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 11 | # ANY KIND, either express or implied. See the License for the specific 12 | # language governing permissions and limitations under the License. 
class TestClusterValidator(unittest.TestCase):
    """Unit tests for ClusterValidator.validate_cluster_and_get_eks_arn."""

    def setUp(self):
        self.validator = ClusterValidator()
        self.mock_sm_client = MagicMock()

    @patch("boto3.client")
    def test_validate_cluster_and_get_eks_arn_success(self, mock_boto3_client):
        """An EKS-orchestrated cluster yields its EKS cluster ARN."""
        cluster_name = "my-cluster"
        eks_arn = "arn:aws:eks:us-west-2:123456789012:cluster/my-cluster"
        mock_sm_client = MagicMock()
        mock_sm_client.describe_cluster.return_value = {
            "Orchestrator": {"Eks": {"ClusterArn": eks_arn}}
        }
        mock_boto3_client.return_value = mock_sm_client

        result = self.validator.validate_cluster_and_get_eks_arn(
            cluster_name, mock_sm_client
        )

        self.assertEqual(result, eks_arn)

    @patch("boto3.client")
    def test_validate_cluster_and_get_eks_arn_non_eks_cluster(self, mock_boto3_client):
        """A non-EKS orchestrator yields None."""
        cluster_name = "my-cluster"
        mock_sm_client = MagicMock()
        mock_sm_client.describe_cluster.return_value = {
            "Orchestrator": "SomeOtherOrchestrator"
        }
        mock_boto3_client.return_value = mock_sm_client

        result = self.validator.validate_cluster_and_get_eks_arn(
            cluster_name, mock_sm_client
        )

        self.assertIsNone(result)

    @patch("boto3.client")
    def test_validate_cluster_and_get_eks_arn_resource_not_found(
        self, mock_boto3_client
    ):
        """ResourceNotFoundException is swallowed and None is returned."""
        cluster_name = "my-cluster"
        mock_sm_client = MagicMock()
        mock_sm_client.describe_cluster.side_effect = ClientError(
            error_response={"Error": {"Code": "ResourceNotFoundException"}},
            operation_name="DescribeCluster",
        )
        mock_boto3_client.return_value = mock_sm_client

        result = self.validator.validate_cluster_and_get_eks_arn(
            cluster_name, mock_sm_client
        )

        self.assertIsNone(result)

    @patch("boto3.client")
    def test_validate_cluster_and_get_eks_arn_other_client_error(
        self, mock_boto3_client
    ):
        """Any other ClientError also yields None rather than raising."""
        cluster_name = "my-cluster"
        mock_sm_client = MagicMock()
        mock_sm_client.describe_cluster.side_effect = ClientError(
            error_response={"Error": {"Code": "SomeOtherError"}},
            operation_name="DescribeCluster",
        )
        mock_boto3_client.return_value = mock_sm_client

        result = self.validator.validate_cluster_and_get_eks_arn(
            cluster_name, mock_sm_client
        )

        self.assertIsNone(result)

    @patch("boto3.client")
    def test_validate_cluster_and_get_eks_arn_other_exception(self, mock_boto3_client):
        """A non-boto exception also yields None rather than raising."""
        cluster_name = "my-cluster"
        mock_sm_client = MagicMock()
        mock_sm_client.describe_cluster.side_effect = Exception("Some other exception")
        mock_boto3_client.return_value = mock_sm_client

        result = self.validator.validate_cluster_and_get_eks_arn(
            cluster_name, mock_sm_client
        )

        self.assertIsNone(result)
class TestValidator(unittest.TestCase):
    """Unit tests for the base Validator, chiefly validate_aws_credential."""

    def setUp(self):
        self.validator = Validator()

    def test_validate_need_implement(self):
        # Smoke test: the base validate() is callable without error.
        self.validator.validate()

    @patch("boto3.Session")
    def test_validate_aws_credential_success(self, mock_session):
        """Credentials present and STS call succeeds -> True."""
        mock_session.get_credentials.return_value = True
        mock_session.client.return_value = MagicMock()

        self.assertTrue(self.validator.validate_aws_credential(mock_session))

    @patch("boto3.Session")
    def test_validate_aws_credential_no_credentials(self, mock_session):
        """No credentials available -> False."""
        mock_session.get_credentials.return_value = None

        self.assertFalse(self.validator.validate_aws_credential(mock_session))

    @patch("boto3.Session")
    def test_validate_aws_credential_no_credentials_error(self, mock_session):
        """NoCredentialsError while loading credentials -> False."""
        mock_session.get_credentials.side_effect = NoCredentialsError()

        self.assertFalse(self.validator.validate_aws_credential(mock_session))

    @patch("boto3.Session")
    def test_validate_aws_credential_expired_token_error(self, mock_session):
        """STS reports ExpiredToken -> False."""
        mock_sts_client = MagicMock()
        mock_sts_client.get_caller_identity.side_effect = ClientError(
            {"Error": {"Code": "ExpiredToken"}},
            "operation",
        )
        mock_session.client.return_value = mock_sts_client

        self.assertFalse(self.validator.validate_aws_credential(mock_session))

    @patch("boto3.Session")
    def test_validate_aws_credential_other_client_error(self, mock_session):
        """Any other ClientError from STS -> False."""
        mock_sts_client = MagicMock()
        mock_sts_client.get_caller_identity.side_effect = ClientError(
            {"Error": {"Code": "SomeOtherError"}},
            "operation",
        )
        mock_session.client.return_value = mock_sts_client

        self.assertFalse(self.validator.validate_aws_credential(mock_session))

    @patch("boto3.Session")
    def test_validate_aws_credential_unexpected_error(self, mock_session):
        """A non-boto exception from STS -> False."""
        mock_sts_client = MagicMock()
        mock_sts_client.get_caller_identity.side_effect = Exception("Unexpected")
        mock_session.client.return_value = mock_sts_client

        self.assertFalse(self.validator.validate_aws_credential(mock_session))