├── .github └── workflows │ └── publish.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── docs ├── README.md ├── _version.md ├── cli.md ├── cluster_config.md ├── commands │ ├── cluster.md │ ├── database.md │ ├── index.md │ └── job.md ├── config.md ├── connection.md ├── executor.md ├── job.md ├── types.md └── utils.md ├── pyproject.toml ├── requirements.dev.txt ├── setup.py └── torch_submit ├── __init__.py ├── cli.py ├── commands ├── __init__.py ├── cluster.py ├── database.py └── job.py ├── config.py ├── connection.py ├── executor.py ├── job.py ├── types.py └── utils.py /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up Python 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install build 25 | - name: Build package 26 | run: python -m build 27 | - name: Publish package 28 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 29 | with: 30 | user: __token__ 31 | password: ${{ secrets.PYPI_API_TOKEN }} 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python cache files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | dist/ 8 | build/ 9 | *.egg-info/ 10 | *.egg 11 | 12 | # Virtual environments 13 | venv/ 14 | env/ 15 | .venv/ 16 | .env/ 17 | 18 | # IDEs and editors 19 | .vscode/ 20 | .idea/ 21 | *.swp 22 | *.swo 23 | 24 | # Operating system files 25 | .DS_Store 26 | Thumbs.db 27 | 28 | # Project-specific generated files 29 
| torch_submit/_version.py 30 | 31 | # Logs 32 | *.log 33 | 34 | # Test coverage 35 | .coverage 36 | htmlcov/ 37 | 38 | # Temporary files 39 | *.tmp 40 | *.bak 41 | 42 | # Jupyter Notebook 43 | .ipynb_checkpoints 44 | 45 | # pyenv 46 | .python-version 47 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/astral-sh/ruff-pre-commit 9 | rev: v0.5.1 10 | hooks: 11 | - id: ruff 12 | args: [ --fix ] 13 | - id: ruff-format 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024 Dream3D, Inc. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Torch Submit 2 | 3 | ## Introduction 4 | 5 | Torch Submit is a lightweight, easy-to-use tool for running distributed PyTorch jobs across multiple machines. It's designed for researchers and developers who: 6 | 7 | - Have access to a bunch of machines with IP addresses 8 | - Want to run distributed PyTorch jobs without the hassle 9 | - Don't have the time, energy, or patience to set up complex cluster management systems like SLURM or Kubernetes 10 | 11 | Under the hood, Torch Submit uses Fabric to copy your working directory to the remote addresses and TorchRun to execute the command. 12 | 13 | It's encouraged to read `torch_submit/executor.py` to understand how jobs are created and scheduled. 14 | 15 | ## Features 16 | 17 | - Simple cluster configuration: Just add your machines' IP addresses 18 | - Easy job submission: Run your PyTorch jobs with a single command 19 | - Job management: Submit, stop, restart, and monitor your jobs 20 | - Log tailing: Easily view the logs of your running jobs 21 | - Optuna Integration for parallel hyperparameter optimization 22 | 23 | ## Installation 24 | 25 | ```bash 26 | pip install torch-submit 27 | ``` 28 | 29 | or from source: 30 | 31 | ```bash 32 | pip install -e . --prefix ~/.local 33 | ``` 34 | 35 | ## Quick Start 36 | 37 | 1. Set up a cluster: 38 | ```bash 39 | torch-submit cluster create 40 | ``` 41 | Follow the interactive prompts to add your machines. 42 | 43 | 2. 
Submit a job: 44 | ```bash 45 | torch-submit job submit --cluster my_cluster -- <command> 46 | # for example: 47 | # torch-submit job submit --cluster my_cluster -- python train.py 48 | # torch-submit job submit --cluster my_cluster -- python -m main.train 49 | ``` 50 | 51 | 3. List running jobs: 52 | ```bash 53 | torch-submit job list 54 | ``` 55 | 56 | 4. Tail logs: 57 | ```bash 58 | torch-submit job logs <job_id> --tail 59 | ``` 60 | 61 | 5. Stop a job: 62 | ```bash 63 | torch-submit job stop <job_id> 64 | ``` 65 | 66 | 6. Restart a stopped job: 67 | ```bash 68 | torch-submit job restart <job_id> 69 | ``` 70 | 71 | ## Usage 72 | 73 | ### Cluster Management 74 | 75 | - Create a cluster: `torch-submit cluster create` 76 | - List clusters: `torch-submit cluster list` 77 | - Remove a cluster: `torch-submit cluster remove <cluster_name>` 78 | 79 | ### Job Management 80 | 81 | - Submit a job: `torch-submit job submit --cluster my_cluster -- <command>` 82 | - List jobs: `torch-submit job list` 83 | - Stop a job: `torch-submit job stop <job_id>` 84 | - Restart a job: `torch-submit job restart <job_id>` 85 | 86 | ### Log Management 87 | 88 | - Tail logs: `torch-submit job logs <job_id>` 89 | 90 | ### Optuna 91 | 92 | The Optuna executor requires setting a database connection. This can be done via `torch-submit db create`. This will create a new database within the specified connection called `torch_submit`. This database should be accessible to all machines in a cluster. Study name and storage info will be accessible to the job via "OPTUNA_STUDY_NAME" and "OPTUNA_STORAGE" environment variables. 93 | 94 | ## Configuration 95 | 96 | Torch Submit stores cluster configurations in `~/.cache/torch-submit/config.yaml`. You can manually edit this file if needed, but it's recommended to use the CLI commands for cluster management. 97 | 98 | ## Requirements 99 | 100 | - Python 3.7+ 101 | - PyTorch (for your actual jobs) 102 | - SSH access to all machines in your cluster 103 | 104 | ## Contributing 105 | 106 | We welcome contributions! 
Please see our Contributing Guide for more details. 107 | 108 | ## License 109 | 110 | Torch Submit is released under the MIT License. See the LICENSE file for more details. 111 | 112 | ## Support 113 | 114 | If you encounter any issues or have questions, please file an issue on our GitHub Issues page. 115 | 116 | Happy distributed training! 117 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Torch-submit Index 2 | 3 | > Auto-generated documentation index. 4 | 5 | A full list of `Torch-submit` project modules. 6 | 7 | - [Version](./_version.md#version) 8 | - [Cli](./cli.md#cli) 9 | - [Commands](commands/index.md#commands) 10 | - [Cluster](commands/cluster.md#cluster) 11 | - [Database](commands/database.md#database) 12 | - [Job](commands/job.md#job) 13 | - [Config](./config.md#config) 14 | - [Connection](./connection.md#connection) 15 | - [Executor](./executor.md#executor) 16 | - [Job](./job.md#job) 17 | - [Types](./types.md#types) 18 | - [Utils](./utils.md#utils) 19 | -------------------------------------------------------------------------------- /docs/_version.md: -------------------------------------------------------------------------------- 1 | # Version 2 | 3 | [Torch-submit Index](./README.md#torch-submit-index) / Version 4 | 5 | > Auto-generated documentation for [_version](../torch_submit/_version.py) module. 6 | 7 | #### Attributes 8 | 9 | - `TYPE_CHECKING` - file generated by setuptools_scm 10 | don't change, don't track in version control: False 11 | - [Version](#version) 12 | -------------------------------------------------------------------------------- /docs/cli.md: -------------------------------------------------------------------------------- 1 | # Cli 2 | 3 | [Torch-submit Index](./README.md#torch-submit-index) / Cli 4 | 5 | > Auto-generated documentation for [cli](../torch_submit/cli.py) module. 
6 | 7 | - [Cli](#cli) 8 | - [main](#main) 9 | - [version_callback](#version_callback) 10 | 11 | ## main 12 | 13 | [Show source in cli.py:20](../torch_submit/cli.py#L20) 14 | 15 | #### Signature 16 | 17 | ```python 18 | @app.callback() 19 | def main( 20 | version: bool = typer.Option( 21 | None, 22 | "--version", 23 | callback=version_callback, 24 | is_eager=True, 25 | help="Show the version and exit.", 26 | ) 27 | ): ... 28 | ``` 29 | 30 | 31 | 32 | ## version_callback 33 | 34 | [Show source in cli.py:14](../torch_submit/cli.py#L14) 35 | 36 | #### Signature 37 | 38 | ```python 39 | def version_callback(value: bool): ... 40 | ``` -------------------------------------------------------------------------------- /docs/cluster_config.md: -------------------------------------------------------------------------------- 1 | # ClusterConfig 2 | 3 | [Torch-submit Index](./README.md#torch-submit-index) / ClusterConfig 4 | 5 | > Auto-generated documentation for [cluster_config](../torch_submit/cluster_config.py) module. 
6 | 7 | - [ClusterConfig](#clusterconfig) 8 | - [Cluster](#cluster) 9 | - [ClusterConfig](#clusterconfig-1) 10 | - [ClusterConfig().add_cluster](#clusterconfig()add_cluster) 11 | - [ClusterConfig().add_worker_node](#clusterconfig()add_worker_node) 12 | - [ClusterConfig().get_cluster](#clusterconfig()get_cluster) 13 | - [ClusterConfig().list_clusters](#clusterconfig()list_clusters) 14 | - [ClusterConfig().load_config](#clusterconfig()load_config) 15 | - [ClusterConfig().remove_cluster](#clusterconfig()remove_cluster) 16 | - [ClusterConfig().remove_worker_node](#clusterconfig()remove_worker_node) 17 | - [ClusterConfig().save_config](#clusterconfig()save_config) 18 | - [ClusterConfig().update_cluster](#clusterconfig()update_cluster) 19 | - [Node](#node) 20 | - [Node.from_db](#nodefrom_db) 21 | - [Node().to_db](#node()to_db) 22 | 23 | ## Cluster 24 | 25 | [Show source in cluster_config.py:52](../torch_submit/cluster_config.py#L52) 26 | 27 | #### Signature 28 | 29 | ```python 30 | class Cluster: ... 31 | ``` 32 | 33 | 34 | 35 | ## ClusterConfig 36 | 37 | [Show source in cluster_config.py:57](../torch_submit/cluster_config.py#L57) 38 | 39 | #### Signature 40 | 41 | ```python 42 | class ClusterConfig: 43 | def __init__(self): ... 44 | ``` 45 | 46 | ### ClusterConfig().add_cluster 47 | 48 | [Show source in cluster_config.py:103](../torch_submit/cluster_config.py#L103) 49 | 50 | #### Signature 51 | 52 | ```python 53 | def add_cluster(self, name: str, head_node: Node, worker_nodes: List[Node]): ... 54 | ``` 55 | 56 | #### See also 57 | 58 | - [Node](#node) 59 | 60 | ### ClusterConfig().add_worker_node 61 | 62 | [Show source in cluster_config.py:120](../torch_submit/cluster_config.py#L120) 63 | 64 | #### Signature 65 | 66 | ```python 67 | def add_worker_node(self, cluster_name: str, worker_node: Node): ... 
68 | ``` 69 | 70 | #### See also 71 | 72 | - [Node](#node) 73 | 74 | ### ClusterConfig().get_cluster 75 | 76 | [Show source in cluster_config.py:112](../torch_submit/cluster_config.py#L112) 77 | 78 | #### Signature 79 | 80 | ```python 81 | def get_cluster(self, cluster_name: str) -> Cluster: ... 82 | ``` 83 | 84 | #### See also 85 | 86 | - [Cluster](#cluster) 87 | 88 | ### ClusterConfig().list_clusters 89 | 90 | [Show source in cluster_config.py:117](../torch_submit/cluster_config.py#L117) 91 | 92 | #### Signature 93 | 94 | ```python 95 | def list_clusters(self) -> List[str]: ... 96 | ``` 97 | 98 | ### ClusterConfig().load_config 99 | 100 | [Show source in cluster_config.py:63](../torch_submit/cluster_config.py#L63) 101 | 102 | #### Signature 103 | 104 | ```python 105 | def load_config(self): ... 106 | ``` 107 | 108 | ### ClusterConfig().remove_cluster 109 | 110 | [Show source in cluster_config.py:107](../torch_submit/cluster_config.py#L107) 111 | 112 | #### Signature 113 | 114 | ```python 115 | def remove_cluster(self, name: str): ... 116 | ``` 117 | 118 | ### ClusterConfig().remove_worker_node 119 | 120 | [Show source in cluster_config.py:126](../torch_submit/cluster_config.py#L126) 121 | 122 | #### Signature 123 | 124 | ```python 125 | def remove_worker_node(self, cluster_name: str, worker_node_ip: str): ... 126 | ``` 127 | 128 | ### ClusterConfig().save_config 129 | 130 | [Show source in cluster_config.py:75](../torch_submit/cluster_config.py#L75) 131 | 132 | #### Signature 133 | 134 | ```python 135 | def save_config(self): ... 136 | ``` 137 | 138 | ### ClusterConfig().update_cluster 139 | 140 | [Show source in cluster_config.py:136](../torch_submit/cluster_config.py#L136) 141 | 142 | #### Signature 143 | 144 | ```python 145 | def update_cluster(self, name: str, head_node: Node, worker_nodes: List[Node]): ... 
146 | ``` 147 | 148 | #### See also 149 | 150 | - [Node](#node) 151 | 152 | 153 | 154 | ## Node 155 | 156 | [Show source in cluster_config.py:9](../torch_submit/cluster_config.py#L9) 157 | 158 | #### Signature 159 | 160 | ```python 161 | class Node: ... 162 | ``` 163 | 164 | ### Node.from_db 165 | 166 | [Show source in cluster_config.py:22](../torch_submit/cluster_config.py#L22) 167 | 168 | #### Signature 169 | 170 | ```python 171 | @classmethod 172 | def from_db(cls, row: str): ... 173 | ``` 174 | 175 | ### Node().to_db 176 | 177 | [Show source in cluster_config.py:36](../torch_submit/cluster_config.py#L36) 178 | 179 | #### Signature 180 | 181 | ```python 182 | def to_db(self): ... 183 | ``` -------------------------------------------------------------------------------- /docs/commands/cluster.md: -------------------------------------------------------------------------------- 1 | # Cluster 2 | 3 | [Torch-submit Index](../README.md#torch-submit-index) / [Commands](./index.md#commands) / Cluster 4 | 5 | > Auto-generated documentation for [commands.cluster](../../torch_submit/commands/cluster.py) module. 6 | 7 | - [Cluster](#cluster) 8 | - [create_cluster](#create_cluster) 9 | - [edit_cluster](#edit_cluster) 10 | - [list_clusters](#list_clusters) 11 | - [remove_cluster](#remove_cluster) 12 | 13 | ## create_cluster 14 | 15 | [Show source in cluster.py:14](../../torch_submit/commands/cluster.py#L14) 16 | 17 | Interactively create a new cluster configuration. 18 | 19 | Prompts the user for cluster details such as name, head node, and worker nodes. 20 | Adds the new cluster configuration to the config. 21 | 22 | #### Signature 23 | 24 | ```python 25 | @app.command("create") 26 | def create_cluster(): ... 27 | ``` 28 | 29 | 30 | 31 | ## edit_cluster 32 | 33 | [Show source in cluster.py:129](../../torch_submit/commands/cluster.py#L129) 34 | 35 | Edit an existing cluster configuration. 
36 | 37 | Prompts the user for new cluster details and updates the specified cluster configuration in the config. 38 | 39 | #### Arguments 40 | 41 | - `name` *str* - The name of the cluster to edit. 42 | 43 | #### Signature 44 | 45 | ```python 46 | @app.command("edit") 47 | def edit_cluster(name: str): ... 48 | ``` 49 | 50 | 51 | 52 | ## list_clusters 53 | 54 | [Show source in cluster.py:77](../../torch_submit/commands/cluster.py#L77) 55 | 56 | List all available clusters. 57 | 58 | Retrieves the list of clusters from the config and displays them in a table format. 59 | 60 | #### Signature 61 | 62 | ```python 63 | @app.command("list") 64 | def list_clusters(): ... 65 | ``` 66 | 67 | 68 | 69 | ## remove_cluster 70 | 71 | [Show source in cluster.py:112](../../torch_submit/commands/cluster.py#L112) 72 | 73 | Remove a cluster configuration. 74 | 75 | Prompts the user for confirmation before removing the specified cluster configuration from the config. 76 | 77 | #### Arguments 78 | 79 | - `name` *str* - The name of the cluster to remove. 80 | 81 | #### Signature 82 | 83 | ```python 84 | @app.command("remove") 85 | def remove_cluster(name: str): ... 86 | ``` -------------------------------------------------------------------------------- /docs/commands/database.md: -------------------------------------------------------------------------------- 1 | # Database 2 | 3 | [Torch-submit Index](../README.md#torch-submit-index) / [Commands](./index.md#commands) / Database 4 | 5 | > Auto-generated documentation for [commands.database](../../torch_submit/commands/database.py) module. 6 | 7 | - [Database](#database) 8 | - [create_database](#create_database) 9 | - [edit_database](#edit_database) 10 | - [list_databases](#list_databases) 11 | - [remove_database](#remove_database) 12 | 13 | ## create_database 14 | 15 | [Show source in database.py:13](../../torch_submit/commands/database.py#L13) 16 | 17 | Interactively create a new database configuration. 
18 | 19 | Prompts the user for database details such as name, type, address, port, username, and password. 20 | Adds the new database configuration to the config. 21 | 22 | #### Signature 23 | 24 | ```python 25 | @app.command("create") 26 | def create_database(): ... 27 | ``` 28 | 29 | 30 | 31 | ## edit_database 32 | 33 | [Show source in database.py:82](../../torch_submit/commands/database.py#L82) 34 | 35 | Edit an existing database configuration. 36 | 37 | Prompts the user for new database details and updates the specified database configuration in the config. 38 | 39 | #### Arguments 40 | 41 | - `name` *str* - The name of the database to edit. 42 | 43 | #### Signature 44 | 45 | ```python 46 | @app.command("edit") 47 | def edit_database(name: str): ... 48 | ``` 49 | 50 | 51 | 52 | ## list_databases 53 | 54 | [Show source in database.py:34](../../torch_submit/commands/database.py#L34) 55 | 56 | List all available databases. 57 | 58 | Retrieves the list of databases from the config and displays them in a table format. 59 | 60 | #### Signature 61 | 62 | ```python 63 | @app.command("list") 64 | def list_databases(): ... 65 | ``` 66 | 67 | 68 | 69 | ## remove_database 70 | 71 | [Show source in database.py:65](../../torch_submit/commands/database.py#L65) 72 | 73 | Remove a database configuration. 74 | 75 | Prompts the user for confirmation before removing the specified database configuration from the config. 76 | 77 | #### Arguments 78 | 79 | - `name` *str* - The name of the database to remove. 80 | 81 | #### Signature 82 | 83 | ```python 84 | @app.command("remove") 85 | def remove_database(name: str): ... 
86 | ``` -------------------------------------------------------------------------------- /docs/commands/index.md: -------------------------------------------------------------------------------- 1 | # Commands 2 | 3 | [Torch-submit Index](../README.md#torch-submit-index) / Commands 4 | 5 | > Auto-generated documentation for [commands](../../torch_submit/commands/__init__.py) module. 6 | 7 | - [Commands](#commands) 8 | - [Modules](#modules) 9 | 10 | ## Modules 11 | 12 | - [Cluster](./cluster.md) 13 | - [Database](./database.md) 14 | - [Job](./job.md) -------------------------------------------------------------------------------- /docs/commands/job.md: -------------------------------------------------------------------------------- 1 | # Job 2 | 3 | [Torch-submit Index](../README.md#torch-submit-index) / [Commands](./index.md#commands) / Job 4 | 5 | > Auto-generated documentation for [commands.job](../../torch_submit/commands/job.py) module. 6 | 7 | - [Job](#job) 8 | - [delete_job](#delete_job) 9 | - [list_jobs](#list_jobs) 10 | - [print_logs](#print_logs) 11 | - [restart_job](#restart_job) 12 | - [stop_job](#stop_job) 13 | - [submit](#submit) 14 | 15 | ## delete_job 16 | 17 | [Show source in job.py:312](../../torch_submit/commands/job.py#L312) 18 | 19 | Delete a job. 20 | 21 | #### Arguments 22 | 23 | - `job_id` *str* - Job ID or name to delete or 'all' to delete all jobs. 24 | 25 | #### Signature 26 | 27 | ```python 28 | @app.command("delete") 29 | def delete_job( 30 | job_id: str = typer.Argument( 31 | ..., help="Job ID or name to delete or 'all' to delete all jobs" 32 | ) 33 | ): ... 34 | ``` 35 | 36 | 37 | 38 | ## list_jobs 39 | 40 | [Show source in job.py:194](../../torch_submit/commands/job.py#L194) 41 | 42 | List all submitted jobs. 43 | 44 | #### Signature 45 | 46 | ```python 47 | @app.command("list") 48 | def list_jobs(): ... 
49 | ``` 50 | 51 | 52 | 53 | ## print_logs 54 | 55 | [Show source in job.py:165](../../torch_submit/commands/job.py#L165) 56 | 57 | Tail the logs of a specific job. 58 | 59 | #### Arguments 60 | 61 | - `job_id` *str* - Job ID or name. 62 | - `tail` *bool* - Tail the logs. 63 | 64 | #### Signature 65 | 66 | ```python 67 | @app.command("logs") 68 | def print_logs( 69 | job_id: str = typer.Argument(..., help="Job ID or name"), 70 | tail: bool = typer.Option(False, help="Tail the logs"), 71 | ): ... 72 | ``` 73 | 74 | 75 | 76 | ## restart_job 77 | 78 | [Show source in job.py:262](../../torch_submit/commands/job.py#L262) 79 | 80 | Restart a stopped job. 81 | 82 | #### Arguments 83 | 84 | - `job_id` *str* - Job ID or name. 85 | 86 | #### Signature 87 | 88 | ```python 89 | @app.command("restart") 90 | def restart_job(job_id: str = typer.Argument(..., help="Job ID or name")): ... 91 | ``` 92 | 93 | 94 | 95 | ## stop_job 96 | 97 | [Show source in job.py:228](../../torch_submit/commands/job.py#L228) 98 | 99 | Stop a running job. 100 | 101 | #### Arguments 102 | 103 | - `job_id` *str* - Job ID or name. 104 | 105 | #### Signature 106 | 107 | ```python 108 | @app.command("stop") 109 | def stop_job(job_id: str = typer.Argument(..., help="Job ID or name")): ... 110 | ``` 111 | 112 | 113 | 114 | ## submit 115 | 116 | [Show source in job.py:28](../../torch_submit/commands/job.py#L28) 117 | 118 | Submit a new job to a specified cluster. 119 | 120 | #### Arguments 121 | 122 | - `cluster` *str* - Name of the cluster to use. 123 | - `name` *Optional[str]* - Job name (optional, will be auto-generated if not provided). 124 | - `working_dir` *str* - Path to working directory. 125 | - `max_restarts` *int* - Maximum number of restarts for the job. 126 | - `num_gpus` *Optional[int]* - Number of GPUs to use per node (optional, defaults to all available). 127 | - `command` *List[str]* - The command to run, e.g. 'python main.py'. 128 | - `tail` *bool* - Tail the logs after submitting the job. 
129 | - [Executor](../executor.md#executor) *Executor* - Executor to use. 130 | - `docker_image` *Optional[str]* - Docker image to use. 131 | - `database` *Optional[str]* - Database to use. 132 | - `runtime_env` *Optional[str]* - Runtime environment yaml file to use. 133 | 134 | #### Signature 135 | 136 | ```python 137 | @app.command("submit") 138 | def submit( 139 | cluster: str = typer.Option(..., help="Name of the cluster to use"), 140 | name: Optional[str] = typer.Option( 141 | None, help="Job name (optional, will be auto-generated if not provided)" 142 | ), 143 | working_dir: str = typer.Option("./", help="Path to working directory"), 144 | max_restarts: int = typer.Option(0, help="Maximum number of restarts for the job"), 145 | num_gpus: Optional[int] = typer.Option( 146 | None, help="Number of GPUs to use per node (optional, defaults to all available)" 147 | ), 148 | command: List[str] = typer.Argument( 149 | ..., help="The command to run, e.g. 'python main.py'" 150 | ), 151 | tail: bool = typer.Option(False, help="Tail the logs after submitting the job"), 152 | executor: Executor = typer.Option(Executor.TORCHRUN, help="Executor to use"), 153 | docker_image: Optional[str] = typer.Option(None, help="Docker image to use"), 154 | database: Optional[str] = typer.Option(None, help="Database to use"), 155 | runtime_env: Optional[str] = typer.Option( 156 | None, help="Runtime environment yaml file to use" 157 | ), 158 | ): ... 159 | ``` 160 | 161 | #### See also 162 | 163 | - [Executor](../types.md#executor) -------------------------------------------------------------------------------- /docs/config.md: -------------------------------------------------------------------------------- 1 | # Config 2 | 3 | [Torch-submit Index](./README.md#torch-submit-index) / Config 4 | 5 | > Auto-generated documentation for [config](../torch_submit/config.py) module. 
6 | 7 | - [Config](#config) 8 | - [Cluster](#cluster) 9 | - [Config](#config-1) 10 | - [Config().add_cluster](#config()add_cluster) 11 | - [Config().add_db](#config()add_db) 12 | - [Config().add_worker_node](#config()add_worker_node) 13 | - [Config().get_cluster](#config()get_cluster) 14 | - [Config().get_db](#config()get_db) 15 | - [Config().list_clusters](#config()list_clusters) 16 | - [Config().list_dbs](#config()list_dbs) 17 | - [Config().load_config](#config()load_config) 18 | - [Config().remove_cluster](#config()remove_cluster) 19 | - [Config().remove_db](#config()remove_db) 20 | - [Config().remove_worker_node](#config()remove_worker_node) 21 | - [Config().save_config](#config()save_config) 22 | - [Config().update_cluster](#config()update_cluster) 23 | - [Config().update_db](#config()update_db) 24 | - [Database](#database) 25 | - [Database().__eq__](#database()__eq__) 26 | - [Database().__hash__](#database()__hash__) 27 | - [Database().__post_init__](#database()__post_init__) 28 | - [Database().__str__](#database()__str__) 29 | - [Database.from_db](#databasefrom_db) 30 | - [Database().to_db](#database()to_db) 31 | - [Database().uri](#database()uri) 32 | - [DatabaseType](#databasetype) 33 | - [DatabaseType().connection_string](#databasetype()connection_string) 34 | - [Node](#node) 35 | - [Node().__eq__](#node()__eq__) 36 | - [Node().__hash__](#node()__hash__) 37 | - [Node().__post_init__](#node()__post_init__) 38 | - [Node().__str__](#node()__str__) 39 | - [Node.from_db](#nodefrom_db) 40 | - [Node().to_db](#node()to_db) 41 | 42 | ## Cluster 43 | 44 | [Show source in config.py:99](../torch_submit/config.py#L99) 45 | 46 | Represents a cluster of nodes. 47 | 48 | #### Attributes 49 | 50 | - `head_node` *Node* - The head node of the cluster. 51 | - `worker_nodes` *List[Node]* - A list of worker nodes in the cluster. 52 | 53 | #### Signature 54 | 55 | ```python 56 | class Cluster: ... 
57 | ``` 58 | 59 | 60 | 61 | ## Config 62 | 63 | [Show source in config.py:239](../torch_submit/config.py#L239) 64 | 65 | Manages the configuration for clusters and databases. 66 | 67 | This class handles loading, saving, and manipulating the configuration 68 | for clusters and databases used in the torch-submit system. 69 | 70 | #### Signature 71 | 72 | ```python 73 | class Config: 74 | def __init__(self): ... 75 | ``` 76 | 77 | ### Config().add_cluster 78 | 79 | [Show source in config.py:311](../torch_submit/config.py#L311) 80 | 81 | Add a new cluster to the configuration. 82 | 83 | #### Arguments 84 | 85 | - `name` *str* - The name of the cluster. 86 | - `head_node` *Node* - The head node of the cluster. 87 | - `worker_nodes` *List[Node]* - The list of worker nodes in the cluster. 88 | 89 | #### Signature 90 | 91 | ```python 92 | def add_cluster(self, name: str, head_node: Node, worker_nodes: List[Node]): ... 93 | ``` 94 | 95 | #### See also 96 | 97 | - [Node](#node) 98 | 99 | ### Config().add_db 100 | 101 | [Show source in config.py:407](../torch_submit/config.py#L407) 102 | 103 | Add a new database to the configuration. 104 | 105 | #### Arguments 106 | 107 | - `type` *DatabaseType* - The type of the database. 108 | - `name` *str* - The name of the database configuration. 109 | - `address` *str* - The address of the database server. 110 | - `port` *int* - The port number for the database connection. 111 | - `username` *str* - The username for database authentication. 112 | - `password` *str* - The password for database authentication. 113 | 114 | #### Signature 115 | 116 | ```python 117 | def add_db( 118 | self, 119 | type: DatabaseType, 120 | name: str, 121 | address: str, 122 | port: int, 123 | username: str, 124 | password: str, 125 | ): ... 
126 | ``` 127 | 128 | #### See also 129 | 130 | - [DatabaseType](#databasetype) 131 | 132 | ### Config().add_worker_node 133 | 134 | [Show source in config.py:356](../torch_submit/config.py#L356) 135 | 136 | Add a worker node to a cluster. 137 | 138 | #### Arguments 139 | 140 | - `cluster_name` *str* - The name of the cluster. 141 | - `worker_node` *Node* - The worker node to add. 142 | 143 | #### Raises 144 | 145 | - `ValueError` - If the cluster is not found in the configuration. 146 | 147 | #### Signature 148 | 149 | ```python 150 | def add_worker_node(self, cluster_name: str, worker_node: Node): ... 151 | ``` 152 | 153 | #### See also 154 | 155 | - [Node](#node) 156 | 157 | ### Config().get_cluster 158 | 159 | [Show source in config.py:332](../torch_submit/config.py#L332) 160 | 161 | Get a cluster by its name. 162 | 163 | #### Arguments 164 | 165 | - `cluster_name` *str* - The name of the cluster. 166 | 167 | #### Returns 168 | 169 | - [Cluster](#cluster) - The requested cluster. 170 | 171 | #### Raises 172 | 173 | - `ValueError` - If the cluster is not found in the configuration. 174 | 175 | #### Signature 176 | 177 | ```python 178 | def get_cluster(self, cluster_name: str) -> Cluster: ... 179 | ``` 180 | 181 | #### See also 182 | 183 | - [Cluster](#cluster) 184 | 185 | ### Config().get_db 186 | 187 | [Show source in config.py:442](../torch_submit/config.py#L442) 188 | 189 | Get a database configuration by its name. 190 | 191 | #### Arguments 192 | 193 | - `db_name` *str* - The name of the database configuration. 194 | 195 | #### Returns 196 | 197 | - [Database](#database) - The requested database configuration. 198 | 199 | #### Raises 200 | 201 | - `ValueError` - If the database configuration is not found. 202 | 203 | #### Signature 204 | 205 | ```python 206 | def get_db(self, db_name: str) -> Database: ... 
207 | ``` 208 | 209 | #### See also 210 | 211 | - [Database](#database) 212 | 213 | ### Config().list_clusters 214 | 215 | [Show source in config.py:348](../torch_submit/config.py#L348) 216 | 217 | Get a list of all cluster names. 218 | 219 | #### Returns 220 | 221 | - `List[str]` - A list of cluster names. 222 | 223 | #### Signature 224 | 225 | ```python 226 | def list_clusters(self) -> List[str]: ... 227 | ``` 228 | 229 | ### Config().list_dbs 230 | 231 | [Show source in config.py:458](../torch_submit/config.py#L458) 232 | 233 | Get a list of all database configuration names. 234 | 235 | #### Returns 236 | 237 | - `List[str]` - A list of database configuration names. 238 | 239 | #### Signature 240 | 241 | ```python 242 | def list_dbs(self) -> List[str]: ... 243 | ``` 244 | 245 | ### Config().load_config 246 | 247 | [Show source in config.py:253](../torch_submit/config.py#L253) 248 | 249 | Load the configuration from the YAML file. 250 | 251 | #### Signature 252 | 253 | ```python 254 | def load_config(self): ... 255 | ``` 256 | 257 | ### Config().remove_cluster 258 | 259 | [Show source in config.py:322](../torch_submit/config.py#L322) 260 | 261 | Remove a cluster from the configuration. 262 | 263 | #### Arguments 264 | 265 | - `name` *str* - The name of the cluster to remove. 266 | 267 | #### Signature 268 | 269 | ```python 270 | def remove_cluster(self, name: str): ... 271 | ``` 272 | 273 | ### Config().remove_db 274 | 275 | [Show source in config.py:432](../torch_submit/config.py#L432) 276 | 277 | Remove a database from the configuration. 278 | 279 | #### Arguments 280 | 281 | - `name` *str* - The name of the database configuration to remove. 282 | 283 | #### Signature 284 | 285 | ```python 286 | def remove_db(self, name: str): ... 287 | ``` 288 | 289 | ### Config().remove_worker_node 290 | 291 | [Show source in config.py:371](../torch_submit/config.py#L371) 292 | 293 | Remove a worker node from a cluster. 
294 | 295 | #### Arguments 296 | 297 | - `cluster_name` *str* - The name of the cluster. 298 | - `worker_node_ip` *str* - The IP address of the worker node to remove. 299 | 300 | #### Raises 301 | 302 | - `ValueError` - If the cluster is not found in the configuration. 303 | 304 | #### Signature 305 | 306 | ```python 307 | def remove_worker_node(self, cluster_name: str, worker_node_ip: str): ... 308 | ``` 309 | 310 | ### Config().save_config 311 | 312 | [Show source in config.py:270](../torch_submit/config.py#L270) 313 | 314 | Save the current configuration to the YAML file. 315 | 316 | #### Signature 317 | 318 | ```python 319 | def save_config(self): ... 320 | ``` 321 | 322 | ### Config().update_cluster 323 | 324 | [Show source in config.py:390](../torch_submit/config.py#L390) 325 | 326 | Update an existing cluster in the configuration. 327 | 328 | #### Arguments 329 | 330 | - `name` *str* - The name of the cluster to update. 331 | - `head_node` *Node* - The new head node for the cluster. 332 | - `worker_nodes` *List[Node]* - The new list of worker nodes for the cluster. 333 | 334 | #### Raises 335 | 336 | - `ValueError` - If the cluster is not found in the configuration. 337 | 338 | #### Signature 339 | 340 | ```python 341 | def update_cluster(self, name: str, head_node: Node, worker_nodes: List[Node]): ... 342 | ``` 343 | 344 | #### See also 345 | 346 | - [Node](#node) 347 | 348 | ### Config().update_db 349 | 350 | [Show source in config.py:466](../torch_submit/config.py#L466) 351 | 352 | Update an existing database configuration. 353 | 354 | #### Arguments 355 | 356 | - `type` *str* - The type of the database. 357 | - `name` *str* - The name of the database configuration to update. 358 | - `address` *str* - The new address of the database server. 359 | - `port` *int* - The new port number for the database connection. 360 | - `username` *str* - The new username for database authentication. 361 | - `password` *str* - The new password for database authentication. 
362 | 363 | #### Raises 364 | 365 | - `ValueError` - If the specified database configuration is not found. 366 | 367 | #### Signature 368 | 369 | ```python 370 | def update_db( 371 | self, type: str, name: str, address: str, port: int, username: str, password: str 372 | ): ... 373 | ``` 374 | 375 | 376 | 377 | ## Database 378 | 379 | [Show source in config.py:141](../torch_submit/config.py#L141) 380 | 381 | Represents a database configuration. 382 | 383 | #### Attributes 384 | 385 | - `address` *str* - The address of the database server. 386 | - `port` *int* - The port number for the database connection. 387 | - `username` *str* - The username for database authentication. 388 | password (str | None): The password for database authentication, if required. 389 | - `type` *DatabaseType* - The type of the database (e.g., PostgreSQL, MySQL). 390 | 391 | #### Signature 392 | 393 | ```python 394 | class Database: ... 395 | ``` 396 | 397 | ### Database().__eq__ 398 | 399 | [Show source in config.py:221](../torch_submit/config.py#L221) 400 | 401 | Check if two Database objects are equal. 402 | 403 | #### Arguments 404 | 405 | - `other` - Another object to compare with. 406 | 407 | #### Returns 408 | 409 | - `bool` - True if the objects are equal, False otherwise. 410 | 411 | #### Signature 412 | 413 | ```python 414 | def __eq__(self, other): ... 415 | ``` 416 | 417 | ### Database().__hash__ 418 | 419 | [Show source in config.py:207](../torch_submit/config.py#L207) 420 | 421 | Return a hash value for the Database object. 422 | 423 | #### Returns 424 | 425 | - `int` - A hash value based on the database attributes. 426 | 427 | #### Signature 428 | 429 | ```python 430 | def __hash__(self): ... 431 | ``` 432 | 433 | ### Database().__post_init__ 434 | 435 | [Show source in config.py:158](../torch_submit/config.py#L158) 436 | 437 | Initialize the Database object after creation. 438 | 439 | #### Signature 440 | 441 | ```python 442 | def __post_init__(self): ... 
443 | ``` 444 | 445 | ### Database().__str__ 446 | 447 | [Show source in config.py:199](../torch_submit/config.py#L199) 448 | 449 | Return a string representation of the Database object. 450 | 451 | #### Returns 452 | 453 | - `str` - A string representation of the Database object. 454 | 455 | #### Signature 456 | 457 | ```python 458 | def __str__(self): ... 459 | ``` 460 | 461 | ### Database.from_db 462 | 463 | [Show source in config.py:163](../torch_submit/config.py#L163) 464 | 465 | Create a Database object from a database row string. 466 | 467 | #### Arguments 468 | 469 | - `row` *str* - A string representation of the database data. 470 | 471 | #### Returns 472 | 473 | - [Database](#database) - A new Database object created from the row data. 474 | 475 | #### Signature 476 | 477 | ```python 478 | @classmethod 479 | def from_db(cls, row: str): ... 480 | ``` 481 | 482 | ### Database().to_db 483 | 484 | [Show source in config.py:191](../torch_submit/config.py#L191) 485 | 486 | Convert the Database object to a string representation for database storage. 487 | 488 | #### Returns 489 | 490 | - `str` - A string representation of the Database object. 491 | 492 | #### Signature 493 | 494 | ```python 495 | def to_db(self): ... 496 | ``` 497 | 498 | ### Database().uri 499 | 500 | [Show source in config.py:182](../torch_submit/config.py#L182) 501 | 502 | Get the database URI for SQLAlchemy connection. 503 | 504 | #### Returns 505 | 506 | - `str` - The database URI. 507 | 508 | #### Signature 509 | 510 | ```python 511 | @property 512 | def uri(self): ... 513 | ``` 514 | 515 | 516 | 517 | ## DatabaseType 518 | 519 | [Show source in config.py:111](../torch_submit/config.py#L111) 520 | 521 | Enumeration of supported database types. 522 | 523 | #### Attributes 524 | 525 | - `POSTGRES` - PostgreSQL database type. 526 | - `MYSQL` - MySQL database type. 527 | 528 | #### Signature 529 | 530 | ```python 531 | class DatabaseType(str, Enum): ... 
532 | ``` 533 | 534 | ### DatabaseType().connection_string 535 | 536 | [Show source in config.py:122](../torch_submit/config.py#L122) 537 | 538 | Get the SQLAlchemy connection string prefix for the database type. 539 | 540 | #### Returns 541 | 542 | - `str` - The connection string prefix. 543 | 544 | #### Raises 545 | 546 | - `ValueError` - If the database type is unknown. 547 | 548 | #### Signature 549 | 550 | ```python 551 | @property 552 | def connection_string(self): ... 553 | ``` 554 | 555 | 556 | 557 | ## Node 558 | 559 | [Show source in config.py:11](../torch_submit/config.py#L11) 560 | 561 | Represents a node in a cluster. 562 | 563 | #### Attributes 564 | 565 | - `public_ip` *str* - The public IP address of the node. 566 | - `private_ip` *Optional[str]* - The private IP address of the node, if available. 567 | - `num_gpus` *int* - The number of GPUs available on the node. 568 | - `nproc` *int* - The number of processes that can run on the node. 569 | - `ssh_user` *Optional[str]* - The SSH username for accessing the node, if available. 570 | - `ssh_pub_key_path` *Optional[str]* - The path to the SSH public key file, if available. 571 | 572 | #### Signature 573 | 574 | ```python 575 | class Node: ... 576 | ``` 577 | 578 | ### Node().__eq__ 579 | 580 | [Show source in config.py:84](../torch_submit/config.py#L84) 581 | 582 | Check if two Node objects are equal. 583 | 584 | #### Arguments 585 | 586 | - `other` - Another object to compare with. 587 | 588 | #### Returns 589 | 590 | - `bool` - True if the objects are equal, False otherwise. 591 | 592 | #### Signature 593 | 594 | ```python 595 | def __eq__(self, other): ... 596 | ``` 597 | 598 | ### Node().__hash__ 599 | 600 | [Show source in config.py:76](../torch_submit/config.py#L76) 601 | 602 | Return a hash value for the Node object. 603 | 604 | #### Returns 605 | 606 | - `int` - A hash value based on the public IP address. 607 | 608 | #### Signature 609 | 610 | ```python 611 | def __hash__(self): ... 
612 | ``` 613 | 614 | ### Node().__post_init__ 615 | 616 | [Show source in config.py:30](../torch_submit/config.py#L30) 617 | 618 | Initialize the Node object after creation. 619 | 620 | #### Signature 621 | 622 | ```python 623 | def __post_init__(self): ... 624 | ``` 625 | 626 | ### Node().__str__ 627 | 628 | [Show source in config.py:68](../torch_submit/config.py#L68) 629 | 630 | Return a string representation of the Node object. 631 | 632 | #### Returns 633 | 634 | - `str` - A string representation of the Node object. 635 | 636 | #### Signature 637 | 638 | ```python 639 | def __str__(self): ... 640 | ``` 641 | 642 | ### Node.from_db 643 | 644 | [Show source in config.py:38](../torch_submit/config.py#L38) 645 | 646 | Create a Node object from a database row string. 647 | 648 | #### Arguments 649 | 650 | - `row` *str* - A string representation of the node data. 651 | 652 | #### Returns 653 | 654 | - [Node](#node) - A new Node object created from the row data. 655 | 656 | #### Signature 657 | 658 | ```python 659 | @classmethod 660 | def from_db(cls, row: str): ... 661 | ``` 662 | 663 | ### Node().to_db 664 | 665 | [Show source in config.py:60](../torch_submit/config.py#L60) 666 | 667 | Convert the Node object to a string representation for database storage. 668 | 669 | #### Returns 670 | 671 | - `str` - A string representation of the Node object. 672 | 673 | #### Signature 674 | 675 | ```python 676 | def to_db(self): ... 677 | ``` -------------------------------------------------------------------------------- /docs/connection.md: -------------------------------------------------------------------------------- 1 | # Connection 2 | 3 | [Torch-submit Index](./README.md#torch-submit-index) / Connection 4 | 5 | > Auto-generated documentation for [connection](../torch_submit/connection.py) module. 
6 | 7 | - [Connection](#connection) 8 | - [NodeConnection](#nodeconnection) 9 | - [NodeConnection().__enter__](#nodeconnection()__enter__) 10 | - [NodeConnection().__exit__](#nodeconnection()__exit__) 11 | 12 | ## NodeConnection 13 | 14 | [Show source in connection.py:6](../torch_submit/connection.py#L6) 15 | 16 | A context manager for handling SSH connections to a node. 17 | 18 | #### Signature 19 | 20 | ```python 21 | class NodeConnection: 22 | def __init__(self, node: Node): ... 23 | ``` 24 | 25 | #### See also 26 | 27 | - [Node](./config.md#node) 28 | 29 | ### NodeConnection().__enter__ 30 | 31 | [Show source in connection.py:17](../torch_submit/connection.py#L17) 32 | 33 | Establish an SSH connection to the node. 34 | 35 | #### Returns 36 | 37 | - `Connection` - The established SSH connection. 38 | 39 | #### Signature 40 | 41 | ```python 42 | def __enter__(self): ... 43 | ``` 44 | 45 | ### NodeConnection().__exit__ 46 | 47 | [Show source in connection.py:35](../torch_submit/connection.py#L35) 48 | 49 | Close the SSH connection when exiting the context. 50 | 51 | #### Arguments 52 | 53 | - `exc_type` - The type of the exception that caused the context to be exited. 54 | - `exc_val` - The instance of the exception that caused the context to be exited. 55 | - `exc_tb` - A traceback object encapsulating the call stack at the point where the exception occurred. 56 | 57 | #### Signature 58 | 59 | ```python 60 | def __exit__(self, exc_type, exc_val, exc_tb): ... 61 | ``` -------------------------------------------------------------------------------- /docs/executor.md: -------------------------------------------------------------------------------- 1 | # Executor 2 | 3 | [Torch-submit Index](./README.md#torch-submit-index) / Executor 4 | 5 | > Auto-generated documentation for [executor](../torch_submit/executor.py) module. 
6 | 7 | - [Executor](#executor) 8 | - [BaseExecutor](#baseexecutor) 9 | - [BaseExecutor()._run_job](#baseexecutor()_run_job) 10 | - [BaseExecutor().cleanup](#baseexecutor()cleanup) 11 | - [BaseExecutor().execute](#baseexecutor()execute) 12 | - [BaseExecutor().get_command](#baseexecutor()get_command) 13 | - [DistributedExecutor](#distributedexecutor) 14 | - [DistributedExecutor().get_command](#distributedexecutor()get_command) 15 | - [DockerDistributedExecutor](#dockerdistributedexecutor) 16 | - [DockerDistributedExecutor().get_command](#dockerdistributedexecutor()get_command) 17 | - [JobExecutionManager](#jobexecutionmanager) 18 | - [JobExecutionManager.cancel_job](#jobexecutionmanagercancel_job) 19 | - [JobExecutionManager.submit_job](#jobexecutionmanagersubmit_job) 20 | - [OptunaExecutor](#optunaexecutor) 21 | - [OptunaExecutor().execute](#optunaexecutor()execute) 22 | - [OptunaExecutor().get_command](#optunaexecutor()get_command) 23 | - [TorchrunExecutor](#torchrunexecutor) 24 | - [TorchrunExecutor().get_command](#torchrunexecutor()get_command) 25 | - [WorkingDirectoryArchiver](#workingdirectoryarchiver) 26 | - [WorkingDirectoryArchiver().archive](#workingdirectoryarchiver()archive) 27 | 28 | ## BaseExecutor 29 | 30 | [Show source in executor.py:114](../torch_submit/executor.py#L114) 31 | 32 | Base class for executing jobs across a cluster. 33 | 34 | This class defines the structure for executing a job. Sub-classes must implement the get_command 35 | method, which generates the command to be executed on each node in the cluster. The execute method 36 | runs this command on each node, managing the setup and execution process. 37 | 38 | #### Methods 39 | 40 | - `get_command(rank: int)` - Abstract method to create the command for the given node rank. 41 | - `execute() -> Dict[Node, int]` - Executes the job command on each node in the cluster and returns 42 | a dictionary mapping nodes to their process IDs. 
43 | 44 | #### Signature 45 | 46 | ```python 47 | class BaseExecutor(ABC): 48 | def __init__(self, job: Job): ... 49 | ``` 50 | 51 | #### See also 52 | 53 | - [Job](./types.md#job) 54 | 55 | ### BaseExecutor()._run_job 56 | 57 | [Show source in executor.py:179](../torch_submit/executor.py#L179) 58 | 59 | Run the job on the specified node. 60 | 61 | This method changes the directory to the remote directory, runs the executor command 62 | along with the job command, and captures the process ID of the running job. 63 | 64 | #### Arguments 65 | 66 | - `conn` *Connection* - The connection object to the node. 67 | - `node_rank` *int* - The rank of the node in the cluster. 68 | - `env_vars` *Optional[Dict[str, str]]* - Environment variables to set for the job, if any. 69 | 70 | #### Returns 71 | 72 | - `int` - The process ID of the running job. 73 | 74 | #### Signature 75 | 76 | ```python 77 | def _run_job( 78 | self, conn: Connection, node_rank: int, env_vars: Optional[Dict[str, str]] = None 79 | ): ... 80 | ``` 81 | 82 | ### BaseExecutor().cleanup 83 | 84 | [Show source in executor.py:232](../torch_submit/executor.py#L232) 85 | 86 | Clean up the remote directories on all nodes. 87 | 88 | This method removes the remote directory created for the job on each node. 89 | If the cleanup fails on any node, a warning message is printed. 90 | 91 | #### Signature 92 | 93 | ```python 94 | def cleanup(self): ... 95 | ``` 96 | 97 | ### BaseExecutor().execute 98 | 99 | [Show source in executor.py:146](../torch_submit/executor.py#L146) 100 | 101 | Execute the job command on each node in the cluster. 102 | 103 | This method sets up the remote environment, copies the working directory, 104 | and runs the job command on each node in the cluster. It manages the setup 105 | and execution process, handling any exceptions that occur during execution. 106 | 107 | #### Returns 108 | 109 | - `Dict[Node, int]` - A dictionary mapping nodes to their process IDs. 
110 | 111 | #### Signature 112 | 113 | ```python 114 | def execute(self, env_vars: Optional[Dict[str, str]] = None) -> Dict[Node, int]: ... 115 | ``` 116 | 117 | #### See also 118 | 119 | - [Node](./config.md#node) 120 | 121 | ### BaseExecutor().get_command 122 | 123 | [Show source in executor.py:133](../torch_submit/executor.py#L133) 124 | 125 | Generate the command to be executed on the given node rank. 126 | 127 | #### Arguments 128 | 129 | - `rank` *int* - The rank of the node in the cluster. 130 | 131 | #### Returns 132 | 133 | - `str` - The command to be executed on the node. 134 | 135 | #### Signature 136 | 137 | ```python 138 | @abstractmethod 139 | def get_command(self, rank: int, env_vars: Optional[Dict[str, str]] = None): ... 140 | ``` 141 | 142 | 143 | 144 | ## DistributedExecutor 145 | 146 | [Show source in executor.py:249](../torch_submit/executor.py#L249) 147 | 148 | The DistributedExecutor is responsible for setting up the environment for running 149 | distributed PyTorch jobs. It ensures that the necessary environment variables are set 150 | for the torch distributed environment, including MASTER_ADDR, MASTER_PORT, WORLD_SIZE, 151 | and NODE_RANK. These variables are essential for coordinating the distributed training 152 | process across multiple nodes and GPUs. 153 | 154 | Exposes the following environment variables to the user script: 155 | - MASTER_ADDR: The address of the master node. 156 | - MASTER_PORT: The port on which the master node is listening. 157 | - WORLD_SIZE: The total number of processes participating in the job. 158 | - NODE_RANK: The rank of the current node. 159 | - LOCAL_WORLD_SIZE: The number of processes on the current node. 160 | 161 | #### Signature 162 | 163 | ```python 164 | class DistributedExecutor(BaseExecutor): 165 | def __init__(self, job: Job): ... 
166 | ``` 167 | 168 | #### See also 169 | 170 | - [BaseExecutor](#baseexecutor) 171 | - [Job](./types.md#job) 172 | 173 | ### DistributedExecutor().get_command 174 | 175 | [Show source in executor.py:269](../torch_submit/executor.py#L269) 176 | 177 | Constructs the command to run the job with the torch distributed environment variables set. 178 | 179 | This method sets up the necessary environment variables for a distributed torch run, including 180 | MASTER_ADDR, MASTER_PORT, WORLD_SIZE, and NODE_RANK. It then appends the user-provided command 181 | to these environment variables. 182 | 183 | #### Arguments 184 | 185 | - `rank` *int* - The rank of the current node. 186 | 187 | #### Returns 188 | 189 | - `str` - The full command to run the job with the necessary environment variables. 190 | 191 | #### Signature 192 | 193 | ```python 194 | def get_command(self, rank: int, env_vars: Optional[Dict[str, str]] = None): ... 195 | ``` 196 | 197 | 198 | 199 | ## DockerDistributedExecutor 200 | 201 | [Show source in executor.py:427](../torch_submit/executor.py#L427) 202 | 203 | EXPERIMENTAL: 204 | DockerDistributedExecutor is an executor that runs distributed jobs inside Docker containers. 205 | 206 | This executor extends the DistributedExecutor to provide Docker support, allowing the user to run 207 | distributed jobs in isolated Docker environments with GPU support. 208 | 209 | Exposes the following environment variables to the user script: 210 | - MASTER_ADDR: The address of the master node. 211 | - MASTER_PORT: The port on which the master node is listening. 212 | - WORLD_SIZE: The total number of processes participating in the job. 213 | - NODE_RANK: The rank of the current node. 214 | 215 | #### Signature 216 | 217 | ```python 218 | class DockerDistributedExecutor(DistributedExecutor): 219 | def __init__(self, job: Job): ... 
220 | ``` 221 | 222 | #### See also 223 | 224 | - [DistributedExecutor](#distributedexecutor) 225 | - [Job](./types.md#job) 226 | 227 | ### DockerDistributedExecutor().get_command 228 | 229 | [Show source in executor.py:445](../torch_submit/executor.py#L445) 230 | 231 | Constructs the command to run the job with the torch distributed environment variables set. 232 | 233 | This method sets up the necessary environment variables for a distributed torch run, including 234 | MASTER_ADDR, MASTER_PORT, WORLD_SIZE, and NODE_RANK. It then appends the user-provided command 235 | to these environment variables. 236 | 237 | #### Arguments 238 | 239 | - `rank` *int* - The rank of the current node. 240 | 241 | #### Returns 242 | 243 | - `str` - The full command to run the job with the necessary environment variables. 244 | 245 | #### Signature 246 | 247 | ```python 248 | def get_command(self, rank: int, env_vars: Optional[Dict[str, str]] = None): ... 249 | ``` 250 | 251 | 252 | 253 | ## JobExecutionManager 254 | 255 | [Show source in executor.py:486](../torch_submit/executor.py#L486) 256 | 257 | #### Signature 258 | 259 | ```python 260 | class JobExecutionManager: ... 261 | ``` 262 | 263 | ### JobExecutionManager.cancel_job 264 | 265 | [Show source in executor.py:501](../torch_submit/executor.py#L501) 266 | 267 | #### Signature 268 | 269 | ```python 270 | @staticmethod 271 | def cancel_job(job: Job): ... 272 | ``` 273 | 274 | #### See also 275 | 276 | - [Job](./types.md#job) 277 | 278 | ### JobExecutionManager.submit_job 279 | 280 | [Show source in executor.py:487](../torch_submit/executor.py#L487) 281 | 282 | #### Signature 283 | 284 | ```python 285 | @staticmethod 286 | def submit_job(job: Job): ... 
287 | ``` 288 | 289 | #### See also 290 | 291 | - [Job](./types.md#job) 292 | 293 | 294 | 295 | ## OptunaExecutor 296 | 297 | [Show source in executor.py:367](../torch_submit/executor.py#L367) 298 | 299 | The OptunaExecutor sets up and manages the execution of Optuna distributed optimization jobs. 300 | 301 | The head node runs a SQLite database for Optuna and exposes it to the cluster. Each node in the cluster 302 | runs a single Optuna process that will utilize all the GPUs available on that node. 303 | 304 | Exposes the following environment variables to the user script: 305 | - MASTER_ADDR: The address of the master node. 306 | - MASTER_PORT: The port on which the master node is listening. 307 | - WORLD_SIZE: The total number of processes participating in the job. 308 | - NODE_RANK: The rank of the current node. 309 | - STUDY_NAME: The name of the Optuna study (the job name). 310 | - DATABASE_URI: The URI of the database. 311 | 312 | #### Signature 313 | 314 | ```python 315 | class OptunaExecutor(DistributedExecutor): 316 | def __init__(self, job: Job): ... 317 | ``` 318 | 319 | #### See also 320 | 321 | - [DistributedExecutor](#distributedexecutor) 322 | - [Job](./types.md#job) 323 | 324 | ### OptunaExecutor().execute 325 | 326 | [Show source in executor.py:404](../torch_submit/executor.py#L404) 327 | 328 | Set up the database on the head node and then run the DistributedExecutor execute method. 329 | 330 | This method first sets up the SQLite database on the head node for Optuna. After the database 331 | is set up, it calls the execute method of the DistributedExecutor to run the job command on 332 | each node in the cluster. 333 | 334 | #### Returns 335 | 336 | - `Dict[Node,` *int]* - A dictionary mapping nodes to their process IDs. 337 | 338 | #### Signature 339 | 340 | ```python 341 | def execute(self) -> Dict[Node, int]: ... 
342 | ``` 343 | 344 | #### See also 345 | 346 | - [Node](./config.md#node) 347 | 348 | ### OptunaExecutor().get_command 349 | 350 | [Show source in executor.py:386](../torch_submit/executor.py#L386) 351 | 352 | #### Signature 353 | 354 | ```python 355 | def get_command(self, rank: int, env_vars: Optional[Dict[str, str]] = None): ... 356 | ``` 357 | 358 | 359 | 360 | ## TorchrunExecutor 361 | 362 | [Show source in executor.py:302](../torch_submit/executor.py#L302) 363 | 364 | #### Signature 365 | 366 | ```python 367 | class TorchrunExecutor(BaseExecutor): 368 | def __init__(self, job: Job): ... 369 | ``` 370 | 371 | #### See also 372 | 373 | - [BaseExecutor](#baseexecutor) 374 | - [Job](./types.md#job) 375 | 376 | ### TorchrunExecutor().get_command 377 | 378 | [Show source in executor.py:307](../torch_submit/executor.py#L307) 379 | 380 | Constructs the command to run the job with torchrun. 381 | 382 | This method sets up the necessary parameters for a torchrun command, including 383 | the number of nodes, the number of processes per node, the rendezvous backend, 384 | the rendezvous endpoint, the job ID, and the maximum number of restarts. 385 | 386 | #### Arguments 387 | 388 | - `rank` *int* - The rank of the current node. 389 | 390 | #### Returns 391 | 392 | - `str` - The full command to run the job with torchrun. 393 | 394 | #### Signature 395 | 396 | ```python 397 | def get_command(self, rank: int, env_vars: Optional[Dict[str, str]] = None): ... 398 | ``` 399 | 400 | 401 | 402 | ## WorkingDirectoryArchiver 403 | 404 | [Show source in executor.py:21](../torch_submit/executor.py#L21) 405 | 406 | A class to handle archiving of working directories for jobs. 407 | 408 | This class creates a zip archive of the specified working directory, including job metadata 409 | and excluding files specified in a .gitignore file. 410 | 411 | #### Attributes 412 | 413 | - `job_id` *str* - The ID of the job. 414 | - `job_name` *str* - The name of the job. 
415 | - `output_dir` *str* - The directory where the archive will be saved. 416 | 417 | #### Signature 418 | 419 | ```python 420 | class WorkingDirectoryArchiver: 421 | def __init__(self, job_id: str, job_name: str): ... 422 | ``` 423 | 424 | ### WorkingDirectoryArchiver().archive 425 | 426 | [Show source in executor.py:48](../torch_submit/executor.py#L48) 427 | 428 | Create a zip archive of the specified working directory. 429 | 430 | This method reads the .gitignore file in the working directory to determine which files 431 | to exclude from the archive. It also includes job metadata in the archive. 432 | 433 | #### Arguments 434 | 435 | - `working_dir` *str* - The path to the working directory to be archived. 436 | 437 | #### Returns 438 | 439 | - `str` - The path to the created zip archive. 440 | 441 | #### Signature 442 | 443 | ```python 444 | def archive(self, working_dir: str) -> str: ... 445 | ``` -------------------------------------------------------------------------------- /docs/job.md: -------------------------------------------------------------------------------- 1 | # Job 2 | 3 | [Torch-submit Index](./README.md#torch-submit-index) / Job 4 | 5 | > Auto-generated documentation for [job](../torch_submit/job.py) module. 
6 | 7 | - [Job](#job) 8 | - [JobManager](#jobmanager) 9 | - [JobManager().add_job](#jobmanager()add_job) 10 | - [JobManager().check_job_status](#jobmanager()check_job_status) 11 | - [JobManager().close](#jobmanager()close) 12 | - [JobManager().create_table](#jobmanager()create_table) 13 | - [JobManager().delete_all_jobs](#jobmanager()delete_all_jobs) 14 | - [JobManager().delete_job](#jobmanager()delete_job) 15 | - [JobManager().get_all_jobs_with_status](#jobmanager()get_all_jobs_with_status) 16 | - [JobManager().get_job](#jobmanager()get_job) 17 | - [JobManager().list_jobs](#jobmanager()list_jobs) 18 | - [JobManager().migrate_table](#jobmanager()migrate_table) 19 | - [JobManager().update_job_pids](#jobmanager()update_job_pids) 20 | - [JobManager().update_job_status](#jobmanager()update_job_status) 21 | 22 | ## JobManager 23 | 24 | [Show source in job.py:14](../torch_submit/job.py#L14) 25 | 26 | Manages job-related operations and database interactions. 27 | 28 | #### Signature 29 | 30 | ```python 31 | class JobManager: 32 | def __init__( 33 | self, db_path: str = os.path.expanduser("~/.cache/torch-submit/jobs.db") 34 | ): ... 35 | ``` 36 | 37 | ### JobManager().add_job 38 | 39 | [Show source in job.py:49](../torch_submit/job.py#L49) 40 | 41 | Add a new job to the database. 42 | 43 | #### Arguments 44 | 45 | - [Job](#job) *Job* - The job to be added. 46 | 47 | #### Signature 48 | 49 | ```python 50 | def add_job(self, job: Job): ... 51 | ``` 52 | 53 | #### See also 54 | 55 | - [Job](./types.md#job) 56 | 57 | ### JobManager().check_job_status 58 | 59 | [Show source in job.py:96](../torch_submit/job.py#L96) 60 | 61 | Check the current status of a job. 62 | 63 | #### Arguments 64 | 65 | - [Job](#job) *Job* - The job to check. 66 | 67 | #### Returns 68 | 69 | - `str` - The current status of the job. 70 | 71 | #### Raises 72 | 73 | - `RuntimeError` - If an unknown job status is encountered. 
74 | 75 | #### Signature 76 | 77 | ```python 78 | def check_job_status(self, job: Job) -> str: ... 79 | ``` 80 | 81 | #### See also 82 | 83 | - [Job](./types.md#job) 84 | 85 | ### JobManager().close 86 | 87 | [Show source in job.py:239](../torch_submit/job.py#L239) 88 | 89 | Close the database connection. 90 | 91 | #### Signature 92 | 93 | ```python 94 | def close(self): ... 95 | ``` 96 | 97 | ### JobManager().create_table 98 | 99 | [Show source in job.py:30](../torch_submit/job.py#L30) 100 | 101 | Create the jobs table if it doesn't exist. 102 | 103 | #### Signature 104 | 105 | ```python 106 | def create_table(self): ... 107 | ``` 108 | 109 | ### JobManager().delete_all_jobs 110 | 111 | [Show source in job.py:234](../torch_submit/job.py#L234) 112 | 113 | Delete all jobs from the database. 114 | 115 | #### Signature 116 | 117 | ```python 118 | def delete_all_jobs(self): ... 119 | ``` 120 | 121 | ### JobManager().delete_job 122 | 123 | [Show source in job.py:225](../torch_submit/job.py#L225) 124 | 125 | Delete a job from the database. 126 | 127 | #### Arguments 128 | 129 | - `job_id` *str* - The ID of the job to delete. 130 | 131 | #### Signature 132 | 133 | ```python 134 | def delete_job(self, job_id: str): ... 135 | ``` 136 | 137 | ### JobManager().get_all_jobs_with_status 138 | 139 | [Show source in job.py:173](../torch_submit/job.py#L173) 140 | 141 | Retrieve all jobs and update their statuses. 142 | 143 | #### Returns 144 | 145 | - `List[Job]` - A list of all jobs with updated statuses. 146 | 147 | #### Signature 148 | 149 | ```python 150 | def get_all_jobs_with_status(self) -> List[Job]: ... 151 | ``` 152 | 153 | #### See also 154 | 155 | - [Job](./types.md#job) 156 | 157 | ### JobManager().get_job 158 | 159 | [Show source in job.py:64](../torch_submit/job.py#L64) 160 | 161 | Retrieve a job by its ID or name. 162 | 163 | #### Arguments 164 | 165 | - `job_id_or_name` *str* - The ID or name of the job. 
166 | 167 | #### Returns 168 | 169 | - `Optional[Job]` - The retrieved job, or None if not found. 170 | 171 | #### Signature 172 | 173 | ```python 174 | def get_job(self, job_id_or_name: str) -> Optional[Job]: ... 175 | ``` 176 | 177 | #### See also 178 | 179 | - [Job](./types.md#job) 180 | 181 | ### JobManager().list_jobs 182 | 183 | [Show source in job.py:87](../torch_submit/job.py#L87) 184 | 185 | Retrieve all jobs from the database. 186 | 187 | #### Returns 188 | 189 | - `List[Job]` - A list of all jobs. 190 | 191 | #### Signature 192 | 193 | ```python 194 | def list_jobs(self) -> List[Job]: ... 195 | ``` 196 | 197 | #### See also 198 | 199 | - [Job](./types.md#job) 200 | 201 | ### JobManager().migrate_table 202 | 203 | [Show source in job.py:243](../torch_submit/job.py#L243) 204 | 205 | Perform any necessary database migrations. 206 | 207 | #### Signature 208 | 209 | ```python 210 | def migrate_table(self): ... 211 | ``` 212 | 213 | ### JobManager().update_job_pids 214 | 215 | [Show source in job.py:209](../torch_submit/job.py#L209) 216 | 217 | Update the process IDs for a job in the database. 218 | 219 | #### Arguments 220 | 221 | - `job_id` *str* - The ID of the job to update. 222 | pids (Dict[Node, int]): A dictionary mapping nodes to process IDs. 223 | 224 | #### Signature 225 | 226 | ```python 227 | def update_job_pids(self, job_id: str, pids: Dict[Node, int]): ... 228 | ``` 229 | 230 | #### See also 231 | 232 | - [Node](./config.md#node) 233 | 234 | ### JobManager().update_job_status 235 | 236 | [Show source in job.py:192](../torch_submit/job.py#L192) 237 | 238 | Update the status of a job in the database. 239 | 240 | #### Arguments 241 | 242 | - `job_id` *str* - The ID of the job to update. 243 | - `status` *JobStatus* - The new status of the job. 244 | 245 | #### Raises 246 | 247 | - `ValueError` - If an invalid job status is provided. 
248 | 249 | #### Signature 250 | 251 | ```python 252 | def update_job_status(self, job_id: str, status: JobStatus): ... 253 | ``` 254 | 255 | #### See also 256 | 257 | - [JobStatus](./types.md#jobstatus) -------------------------------------------------------------------------------- /docs/types.md: -------------------------------------------------------------------------------- 1 | # Types 2 | 3 | [Torch-submit Index](./README.md#torch-submit-index) / Types 4 | 5 | > Auto-generated documentation for [types](../torch_submit/types.py) module. 6 | 7 | - [Types](#types) 8 | - [Executor](#executor) 9 | - [Job](#job) 10 | - [Job().__post_init__](#job()__post_init__) 11 | - [Job().__str__](#job()__str__) 12 | - [Job.from_db](#jobfrom_db) 13 | - [Job().get_executor](#job()get_executor) 14 | - [Job().to_db](#job()to_db) 15 | - [JobStatus](#jobstatus) 16 | 17 | ## Executor 18 | 19 | [Show source in types.py:8](../torch_submit/types.py#L8) 20 | 21 | Enumeration of different types of executors. 22 | 23 | #### Signature 24 | 25 | ```python 26 | class Executor(str, Enum): ... 27 | ``` 28 | 29 | 30 | 31 | ## Job 32 | 33 | [Show source in types.py:27](../torch_submit/types.py#L27) 34 | 35 | A class representing a job to be executed. 36 | 37 | #### Attributes 38 | 39 | - `id` *str* - The ID of the job. 40 | - `name` *str* - The name of the job. 41 | - `status` *JobStatus* - The current status of the job. 42 | - `working_dir` *str* - The working directory for the job. 43 | - `nodes` *List[Node]* - The list of nodes assigned to the job. 44 | - `cluster` *str* - The cluster to which the job belongs. 45 | - `command` *str* - The command to be executed for the job. 46 | - `max_restarts` *int* - The maximum number of restarts allowed for the job. 47 | - `num_gpus` *Optional[int]* - The number of GPUs allocated for the job. 48 | - `pids` *Dict[Node, int]* - A dictionary mapping nodes to process IDs. 49 | - `executor` *[Executor](./executor.md#executor)* - The executor type for the job. 
50 | - `docker_image` *Optional[str]* - The Docker image to be used for the job. 51 | - `database` *Optional[Database]* - The database configuration for the job. 52 | - `optuna_port` *Optional[int]* - The port for Optuna executor. 53 | 54 | #### Signature 55 | 56 | ```python 57 | class Job: ... 58 | ``` 59 | 60 | ### Job().__post_init__ 61 | 62 | [Show source in types.py:63](../torch_submit/types.py#L63) 63 | 64 | Post-initialization checks for the Job class. 65 | 66 | #### Signature 67 | 68 | ```python 69 | def __post_init__(self): ... 70 | ``` 71 | 72 | ### Job().__str__ 73 | 74 | [Show source in types.py:161](../torch_submit/types.py#L161) 75 | 76 | Return a string representation of the Job instance. 77 | 78 | #### Returns 79 | 80 | - `str` - A string representation of the Job instance. 81 | 82 | #### Signature 83 | 84 | ```python 85 | def __str__(self): ... 86 | ``` 87 | 88 | ### Job.from_db 89 | 90 | [Show source in types.py:68](../torch_submit/types.py#L68) 91 | 92 | Create a Job instance from a database row. 93 | 94 | #### Arguments 95 | 96 | - `row` *Tuple* - A tuple representing a row from the database. 97 | 98 | #### Returns 99 | 100 | - [Job](#job) - A Job instance created from the database row. 101 | 102 | #### Signature 103 | 104 | ```python 105 | @classmethod 106 | def from_db(cls, row: Tuple) -> "Job": ... 107 | ``` 108 | 109 | ### Job().get_executor 110 | 111 | [Show source in types.py:129](../torch_submit/types.py#L129) 112 | 113 | Get the appropriate executor instance for the job. 114 | 115 | #### Returns 116 | 117 | An instance of the appropriate executor class. 118 | 119 | #### Raises 120 | 121 | - `ValueError` - If an unknown executor type is specified or if Docker image is not supported for the executor. 122 | 123 | #### Signature 124 | 125 | ```python 126 | def get_executor(self): ... 
127 | ``` 128 | 129 | ### Job().to_db 130 | 131 | [Show source in types.py:105](../torch_submit/types.py#L105) 132 | 133 | Convert the Job instance to a tuple for database storage. 134 | 135 | #### Returns 136 | 137 | - `Tuple` - A tuple representing the Job instance for database storage. 138 | 139 | #### Signature 140 | 141 | ```python 142 | def to_db(self) -> Tuple: ... 143 | ``` 144 | 145 | 146 | 147 | ## JobStatus 148 | 149 | [Show source in types.py:15](../torch_submit/types.py#L15) 150 | 151 | Enumeration of different job statuses. 152 | 153 | #### Signature 154 | 155 | ```python 156 | class JobStatus(str, Enum): ... 157 | ``` -------------------------------------------------------------------------------- /docs/utils.md: -------------------------------------------------------------------------------- 1 | # Utils 2 | 3 | [Torch-submit Index](./README.md#torch-submit-index) / Utils 4 | 5 | > Auto-generated documentation for [utils](../torch_submit/utils.py) module. 6 | 7 | - [Utils](#utils) 8 | - [generate_friendly_name](#generate_friendly_name) 9 | - [get_job_metadata](#get_job_metadata) 10 | 11 | ## generate_friendly_name 12 | 13 | [Show source in utils.py:6](../torch_submit/utils.py#L6) 14 | 15 | Generate a friendly, human-readable name for a job. 16 | 17 | This function creates a name by combining a random adjective, a random animal noun, 18 | and a random 4-digit number. The name format is 'adjective-noun-number'. 19 | 20 | #### Returns 21 | 22 | A friendly name string in the format 'adjective-noun-number'. 23 | 24 | #### Examples 25 | 26 | ```python 27 | >>> generate_friendly_name() 28 | 'happy-panda-3721' 29 | ``` 30 | 31 | #### Signature 32 | 33 | ```python 34 | def generate_friendly_name() -> str: ... 35 | ``` 36 | 37 | 38 | 39 | ## get_job_metadata 40 | 41 | [Show source in utils.py:58](../torch_submit/utils.py#L58) 42 | 43 | Retrieve job metadata from the '.torch/job.json' file. 
44 | 45 | This function attempts to read and parse the job metadata stored in the 46 | '.torch/job.json' file in the current working directory. 47 | 48 | #### Returns 49 | 50 | A dictionary containing job metadata if the file exists and can be parsed 51 | successfully, or None if the file is not found. 52 | 53 | #### Raises 54 | 55 | - `json.JSONDecodeError` - If the file exists but contains invalid JSON. 56 | 57 | #### Examples 58 | 59 | ```python 60 | >>> metadata = get_job_metadata() 61 | >>> if metadata: 62 | ... print(f"Job ID: {metadata.get('id')}") 63 | ... else: 64 | ... print("No job metadata found.") 65 | ``` 66 | 67 | #### Signature 68 | 69 | ```python 70 | def get_job_metadata() -> Optional[Dict[str, str]]: ... 71 | ``` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "torch-submit" 7 | authors = [ 8 | {name = "Tony Francis", email = "tony@dream3d.com"}, 9 | ] 10 | description = "A tool for submitting and managing distributed PyTorch jobs" 11 | readme = "README.md" 12 | requires-python = ">=3.7" 13 | keywords = ["pytorch", "distributed", "job submission"] 14 | license = {text = "MIT"} 15 | classifiers = [ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent", 19 | ] 20 | dependencies = [ 21 | "typer", 22 | "rich", 23 | "pyyaml", 24 | "fabric", 25 | "sqlalchemy", 26 | "psycopg2-binary", 27 | "optuna", 28 | "optuna-dashboard", 29 | ] 30 | dynamic = ["version"] 31 | 32 | [project.scripts] 33 | torch-submit = "torch_submit.cli:app" 34 | 35 | [tool.setuptools_scm] 36 | write_to = "torch_submit/_version.py" 37 | 38 | [tool.setuptools.packages.find] 39 | where = ["."] 40 | include = 
["torch_submit*"] 41 | exclude = ["tests*"] 42 | 43 | [project.urls] 44 | "Homepage" = "https://github.com/dream3d-ai/torch-submit" 45 | "Bug Tracker" = "https://github.com/dream3d-ai/torch-submit/issues" -------------------------------------------------------------------------------- /requirements.dev.txt: -------------------------------------------------------------------------------- 1 | ruff 2 | uv 3 | twine 4 | build 5 | setuptools_scm 6 | pre-commit -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="torch-submit", 5 | use_scm_version=True, 6 | setup_requires=['setuptools_scm'], 7 | packages=find_packages(), 8 | install_requires=[ 9 | "typer", 10 | "rich", 11 | "pyyaml", 12 | "fabric", 13 | "sqlalchemy", 14 | "psycopg2", 15 | "optuna", 16 | "optuna-dashboard", 17 | ], 18 | entry_points={ 19 | "console_scripts": [ 20 | "torch-submit=torch_submit.cli:app", 21 | ], 22 | }, 23 | ) -------------------------------------------------------------------------------- /torch_submit/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings(action="ignore", module=".*paramiko.*") 4 | 5 | from . 
@app.command("create")
def create_cluster():
    """
    Interactively create a new cluster configuration.

    Prompts the user for cluster details such as name, head node, and worker nodes.
    Adds the new cluster configuration to the config.
    """
    name = Prompt.ask("Enter cluster name")

    # Head node
    head_public_ip = Prompt.ask("Enter head node public IP")
    head_private_ip = Prompt.ask("Enter head node private IP (optional)", default="")
    head_num_gpus = int(Prompt.ask("Enter number of GPUs on head node", default="0"))
    head_nproc = int(Prompt.ask("Enter number of processes for head node", default="1"))
    ssh_user = Prompt.ask("Enter SSH user for head node (optional)", default="")
    ssh_port = int(Prompt.ask("Enter SSH port for head node (optional)", default="22"))
    ssh_pub_key_path = Prompt.ask(
        "Enter absolute path to SSH public key file (optional)", default=""
    )

    # Node field order: public_ip, private_ip, num_gpus, nproc,
    # ssh_user, ssh_pub_key_path, ssh_port.
    head_node = Node(
        head_public_ip,
        head_private_ip or None,
        head_num_gpus,
        head_nproc,
        ssh_user,
        ssh_pub_key_path,
        ssh_port,
    )

    # Worker nodes
    worker_nodes = []
    while Confirm.ask("Add a worker node?", default=False):
        worker_public_ip = Prompt.ask("Enter worker node public IP")
        worker_private_ip = Prompt.ask(
            "Enter worker node private IP (optional)", default=""
        )
        worker_num_gpus = int(
            Prompt.ask("Enter number of GPUs on worker node", default="0")
        )
        worker_nproc = int(
            Prompt.ask("Enter number of processes for worker node", default="1")
        )
        # Fixed: this prompt previously said "head node" (copy-paste error).
        worker_ssh_user = Prompt.ask(
            "Enter SSH user for worker node (optional)", default=""
        )
        worker_ssh_port = int(
            Prompt.ask("Enter SSH port for worker node (optional)", default="22")
        )
        worker_ssh_pub_key_path = Prompt.ask(
            "Enter absolute path to SSH public key file (optional)", default=""
        )

        worker_node = Node(
            worker_public_ip,
            worker_private_ip or None,
            worker_num_gpus,
            worker_nproc,
            worker_ssh_user,
            worker_ssh_pub_key_path,
            worker_ssh_port,
        )
        worker_nodes.append(worker_node)

    config.add_cluster(name, head_node, worker_nodes)
    console.print(f"Cluster [bold green]{name}[/bold green] created successfully.")
@app.command("edit")
def edit_cluster(name: str):
    """
    Edit an existing cluster configuration.

    Prompts the user for new cluster details and updates the specified cluster
    configuration in the config.

    Args:
        name (str): The name of the cluster to edit.
    """
    try:
        cluster = config.get_cluster(name)
    except ValueError:
        console.print(f"[bold red]Error:[/bold red] Cluster '{name}' not found.")
        raise typer.Exit(code=1)

    console.print(f"Editing cluster: [bold green]{name}[/bold green]")

    # Edit head node in place; each prompt defaults to the current value.
    head_node = cluster.head_node
    head_node.public_ip = typer.prompt("Head node public IP", default=head_node.public_ip)
    head_node.private_ip = typer.prompt("Head node private IP (optional)", default=head_node.private_ip or "")
    head_node.num_gpus = typer.prompt("Number of GPUs on head node", default=head_node.num_gpus, type=int)
    head_node.nproc = typer.prompt("Number of processes on head node", default=head_node.nproc, type=int)
    head_node.ssh_user = typer.prompt("SSH user for head node (optional)", default=head_node.ssh_user or "")
    head_node.ssh_pub_key_path = typer.prompt("SSH public key path for head node (optional)", default=head_node.ssh_pub_key_path or "")
    # Fixed: the SSH port was previously not editable at all.
    head_node.ssh_port = typer.prompt("SSH port for head node (optional)", default=head_node.ssh_port or 22, type=int)

    # Edit worker nodes
    worker_nodes = []
    for i, worker in enumerate(cluster.worker_nodes):
        console.print(f"\nEditing worker node {i+1}")
        public_ip = typer.prompt("Worker node public IP", default=worker.public_ip)
        private_ip = typer.prompt("Worker node private IP (optional)", default=worker.private_ip or "")
        num_gpus = typer.prompt("Number of GPUs on worker node", default=worker.num_gpus, type=int)
        nproc = typer.prompt("Number of processes on worker node", default=worker.nproc, type=int)
        ssh_user = typer.prompt("SSH user for worker node (optional)", default=worker.ssh_user or "")
        ssh_pub_key_path = typer.prompt("SSH public key path for worker node (optional)", default=worker.ssh_pub_key_path or "")
        ssh_port = typer.prompt("SSH port for worker node (optional)", default=worker.ssh_port or 22, type=int)

        # Fixed: Node() was previously called without ssh_port, which raised
        # TypeError because the Node dataclass has no default for that field.
        worker_node = Node(
            public_ip,
            private_ip or None,
            num_gpus,
            nproc,
            ssh_user,
            ssh_pub_key_path,
            ssh_port,
        )
        worker_nodes.append(worker_node)

        if not typer.confirm("Add another worker node?", default=False):
            break

    # Update the cluster configuration
    config.update_cluster(name, head_node, worker_nodes)
    console.print(f"Cluster [bold green]{name}[/bold green] updated successfully.")
@app.command("edit")
def edit_database(name: str):
    """
    Edit an existing database configuration.

    Prompts the user for new database details and updates the specified database
    configuration in the config.

    Args:
        name (str): The name of the database to edit.
    """
    try:
        database = config.get_db(name)
    except ValueError:
        console.print(f"[bold red]Error:[/bold red] Database '{name}' not found.")
        raise typer.Exit(code=1)

    console.print(f"Editing database: [bold green]{name}[/bold green]")

    # Edit database address and port; each prompt defaults to the current value.
    type = Prompt.ask("Enter database type (mysql, postgres)", default=database.type.value)
    address = Prompt.ask("Enter database address", default=database.address)
    # Fixed: Prompt.ask returns a string, so convert to int (as create_database
    # does); also pass the default as a string so the prompt renders it.
    port = int(Prompt.ask("Enter database port", default=str(database.port)))
    username = Prompt.ask("Enter database username", default=database.username)
    password = Prompt.ask("Enter database password (optional)", password=True, default=database.password)

    # Update the database configuration
    config.update_db(type, name, address, port, username, password)
    console.print(f"Database [bold green]{name}[/bold green] updated successfully.")
@app.command("submit")
def submit(
    cluster: str = typer.Option(..., help="Name of the cluster to use"),
    name: Optional[str] = typer.Option(
        None, help="Job name (optional, will be auto-generated if not provided)"
    ),
    working_dir: str = typer.Option("./", help="Path to working directory"),
    max_restarts: int = typer.Option(0, help="Maximum number of restarts for the job"),
    num_gpus: Optional[int] = typer.Option(
        None,
        help="Number of GPUs to use per node (optional, defaults to all available)",
    ),
    command: List[str] = typer.Argument(
        ..., help="The command to run, e.g. 'python main.py'"
    ),
    tail: bool = typer.Option(False, help="Tail the logs after submitting the job"),
    executor: Executor = typer.Option(Executor.TORCHRUN, help="Executor to use"),
    docker_image: Optional[str] = typer.Option(None, help="Docker image to use"),
    database: Optional[str] = typer.Option(None, help="Database to use"),
    runtime_env: Optional[str] = typer.Option(
        None, help="Runtime environment yaml file to use"
    ),
):
    """
    Submit a new job to a specified cluster.

    Args:
        cluster (str): Name of the cluster to use.
        name (Optional[str]): Job name (optional, will be auto-generated if not provided).
        working_dir (str): Path to working directory.
        max_restarts (int): Maximum number of restarts for the job.
        num_gpus (Optional[int]): Number of GPUs to use per node (optional, defaults to all available).
        command (List[str]): The command to run, e.g. 'python main.py'.
        tail (bool): Tail the logs after submitting the job.
        executor (Executor): Executor to use.
        docker_image (Optional[str]): Docker image to use.
        database (Optional[str]): Database to use.
        runtime_env (Optional[str]): Runtime environment yaml file to use.
    """
    # The optuna executor stores trial state in an external database, so one
    # must be configured and resolvable before anything is archived/submitted.
    if executor == Executor.OPTUNA:
        if not database:
            console.print(
                "[bold red]Error:[/bold red] Database is required for optuna executor"
            )
            raise typer.Exit(code=1)
        try:
            config.get_db(database)
        except ValueError:
            console.print(f"Could not find database {database}")
            raise typer.Exit(code=1)

    try:
        cluster_info = config.get_cluster(cluster)
    except ValueError as e:
        console.print(f"[bold red]Error:[/bold red] {str(e)}")
        raise typer.Exit(code=1)

    if num_gpus is not None and num_gpus > cluster_info.head_node.num_gpus:
        console.print(
            f"[bold red]Error:[/bold red] Requested GPUs ({num_gpus}) exceeds available GPUs on head node ({cluster_info.head_node.num_gpus})"
        )
        raise typer.Exit(code=1)

    if name is None:
        name = generate_friendly_name()

    working_dir = os.path.abspath(working_dir)

    job_id = str(uuid.uuid4())
    archiver = WorkingDirectoryArchiver(job_id=job_id, job_name=name)

    if runtime_env:
        console.print(
            f"Loading runtime environment variables from: [bold green]{runtime_env}[/bold green]"
        )
        with open(runtime_env, "r") as f:
            runtime_env_vars = yaml.load(f, Loader=yaml.FullLoader)
        assert all(
            isinstance(value, str) for value in runtime_env_vars.values()
        ), "All values in runtime_env must be strings"
    else:
        runtime_env_vars = None

    console.print("Archiving working directory...")
    archived_dir = archiver.archive(working_dir)
    console.print(
        f"Working directory archived to: [bold green]{archived_dir}[/bold green]"
    )

    nodes = [cluster_info.head_node] + cluster_info.worker_nodes
    job = Job(
        id=job_id,
        name=name,
        status=JobStatus.SUBMITTED,
        working_dir=archived_dir,
        nodes=nodes,
        cluster=cluster,
        command=" ".join(command),
        max_restarts=max_restarts,
        num_gpus=num_gpus,
        executor=executor,
        docker_image=docker_image,
        database=database,
        optuna_port=random.randint(8000, 9000) if executor == Executor.OPTUNA else None,
    )
    console.print("Submitting job...")
    job_manager.add_job(job)

    job_executor = job.get_executor()
    pids = job_executor.execute(runtime_env_vars)

    if all(pid is None for pid in pids.values()):
        job_manager.update_job_status(job_id, JobStatus.CRASHED)
        console.print(f"Job [bold red]{job_id}[/bold red] failed to start.")
        # Fixed: only attempt cleanup on nodes that actually reported a PID.
        # The old code interpolated None into `pkill -TERM -P None` for every
        # node, which at best did nothing and at worst killed the wrong thing.
        for node, pid in pids.items():
            if pid is not None:
                with NodeConnection(node) as c:
                    c.run(f"pkill -TERM -P {pid}", hide=True)
        raise typer.Exit(code=1)

    job_manager.update_job_status(job_id, JobStatus.RUNNING)
    job_manager.update_job_pids(job_id, pids)

    console.print(f"Job submitted with name: [bold green]{name}[/bold green]")
    console.print(f"Job ID: [bold blue]{job_id}[/bold blue]")
    console.print(f"Working directory: [bold blue]{working_dir}[/bold blue]")
    console.print(f"Command: [bold yellow]{' '.join(command)}[/bold yellow]")
    console.print(f"Max restarts: [bold cyan]{max_restarts}[/bold cyan]")
    console.print(
        f"GPUs per node: [bold magenta]{num_gpus or 'All available'}[/bold magenta]"
    )

    if tail:
        console.print("Tailing logs...")
        with NodeConnection(nodes[0]) as c:
            c.run(f"tail -f /tmp/torch_submit_job_{job.id}/output.log")
176 | """ 177 | job_manager = JobManager() 178 | job = job_manager.get_job(job_id) 179 | if job: 180 | console.print(f"Tailing logs for job [bold green]{job_id}[/bold green]") 181 | console.print("Press [bold red]Ctrl+C[/bold red] to stop") 182 | with NodeConnection(job.nodes[0]) as c: 183 | if tail: 184 | c.run(f"tail -f /tmp/torch_submit_job_{job.id}/output.log") 185 | else: 186 | result = c.run(f"cat /tmp/torch_submit_job_{job.id}/output.log") 187 | console.print(result.stdout) 188 | else: 189 | console.print( 190 | f"Job with ID [bold red]{job_id}[/bold red] not found", style="bold red" 191 | ) 192 | 193 | 194 | @app.command("list") 195 | def list_jobs(): 196 | """ 197 | List all submitted jobs. 198 | """ 199 | job_manager = JobManager() 200 | jobs = job_manager.get_all_jobs_with_status() 201 | 202 | table = Table() 203 | table.add_column("ID", style="cyan", no_wrap=True) 204 | table.add_column("Name", style="magenta") 205 | table.add_column("Status", style="green") 206 | table.add_column("Cluster", style="yellow") 207 | table.add_column("Nodes", style="blue") 208 | 209 | for job in jobs: 210 | status_style = { 211 | "started": "bold yellow", 212 | "running": "bold green", 213 | "crashed": "bold red", 214 | "stopping": "bold yellow", 215 | "stopped": "bold cyan", 216 | }.get(job.status, "bold white") 217 | 218 | table.add_row( 219 | job.id, 220 | job.name, 221 | f"[{status_style}]{job.status}[/{status_style}]", 222 | job.cluster, 223 | str(len(job.nodes)), 224 | ) 225 | console.print(table) 226 | 227 | 228 | @app.command("stop") 229 | def stop_job(job_id: str = typer.Argument(..., help="Job ID or name")): 230 | """ 231 | Stop a running job. 232 | 233 | Args: 234 | job_id (str): Job ID or name. 
@app.command("restart")
def restart_job(job_id: str = typer.Argument(..., help="Job ID or name")):
    """
    Restart a stopped job.

    Args:
        job_id (str): Job ID or name.
    """
    job = job_manager.get_job(job_id)
    if not job:
        console.print(
            f"Job with ID [bold red]{job_id}[/bold red] not found", style="bold red"
        )
        raise typer.Exit(code=1)

    if job.status != "stopped":
        console.print(
            f"Job [bold yellow]{job_id}[/bold yellow] is not stopped. Current status: {job.status}"
        )
        raise typer.Exit(code=1)

    console.print(f"Restarting job [bold yellow]{job_id}[/bold yellow]")

    try:
        cluster = config.get_cluster(job.cluster)
        script_name = job.command.split()[-1]
        script_path = os.path.join(f"/tmp/torch_submit_job_{job.id}", script_name)

        # Refuse to restart if the job's script is already running on any node.
        for node in [cluster.head_node] + cluster.worker_nodes:
            with NodeConnection(node) as c:
                result = c.run(f"pgrep -f '{script_path}'", warn=True)
                if result.ok:
                    # Fixed: previously referenced the undefined name `node_ip`,
                    # raising NameError whenever a running process was found.
                    console.print(
                        f"Job [bold yellow]{job_id}[/bold yellow] is already running on node {node.public_ip}"
                    )
                    raise typer.Exit(code=1)

        # If not running, restart the job
        executor = job.get_executor()
        pids = executor.execute()

        job_manager.update_job_status(job_id, JobStatus.RUNNING)
        job_manager.update_job_pids(job_id, pids)
        console.print(f"Job [bold green]{job_id}[/bold green] has been restarted")
    except typer.Exit:
        # Let the deliberate exit above propagate unchanged instead of being
        # swallowed and re-reported by the generic handler below.
        raise
    except Exception as e:
        console.print(f"[bold red]Error restarting job:[/bold red] {str(e)}")
        raise typer.Exit(code=1)
323 | """ 324 | job_manager = JobManager() 325 | jobs = job_manager.get_all_jobs_with_status() 326 | 327 | if not job_id == "all": 328 | jobs = [job for job in jobs if job.id == job_id] 329 | 330 | # Prepare a list of job IDs to be deleted 331 | job_ids_to_delete = [job.id for job in jobs] 332 | 333 | # If no jobs found, exit 334 | if not job_ids_to_delete: 335 | console.print("No jobs found to delete.") 336 | raise typer.Exit(code=1) 337 | 338 | # Show confirmation prompt 339 | if job_id == "all": 340 | message = f"Are you sure you want to delete all {len(job_ids_to_delete)} jobs?" 341 | else: 342 | message = f"Are you sure you want to delete job {job_id}?" 343 | 344 | if not typer.confirm(message): 345 | console.print("Operation cancelled.") 346 | raise typer.Exit(code=0) 347 | 348 | # Stop the job if it is still running 349 | for job in jobs: 350 | if job.status not in ["finished", "crashed", "stopped"]: 351 | try: 352 | stop_job(job.id) 353 | except typer.Exit: 354 | console.print(f"Failed to stop job [bold yellow]{job.id}[/bold yellow]") 355 | 356 | # Clean-up the job from remote (executor.cleanup) 357 | for job in jobs: 358 | try: 359 | executor = BaseExecutor(job) 360 | executor.cleanup() 361 | except Exception: 362 | pass 363 | 364 | for job in jobs: 365 | job_manager.delete_job(job.id) 366 | 367 | if job_id == "all": 368 | console.print( 369 | f"[bold green]Successfully deleted all {len(job_ids_to_delete)} jobs.[/bold green]" 370 | ) 371 | else: 372 | console.print(f"[bold green]Successfully deleted job {job_id}.[/bold green]") 373 | console.print("All specified jobs have been stopped and removed from the system.") 374 | -------------------------------------------------------------------------------- /torch_submit/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from enum import Enum 4 | from typing import Dict, List, Optional 5 | 6 | import yaml 7 | from 
@dataclass
class Node:
    """Represents a node in a cluster.

    Attributes:
        public_ip (str): The public IP address of the node.
        private_ip (Optional[str]): The private IP address of the node, if available.
        num_gpus (int): The number of GPUs available on the node.
        nproc (int): The number of processes that can run on the node.
        ssh_user (Optional[str]): The SSH username for accessing the node, if available.
        ssh_pub_key_path (Optional[str]): The path to the SSH public key file, if available.
        ssh_port (Optional[int]): The SSH port for accessing the node, if available.
    """

    public_ip: str
    private_ip: Optional[str]
    num_gpus: int
    nproc: int
    ssh_user: Optional[str]
    ssh_pub_key_path: Optional[str]
    ssh_port: Optional[int]

    def __post_init__(self):
        """Normalize falsy optional fields to None and coerce numerics to int."""
        self.private_ip = self.private_ip or None
        self.ssh_user = self.ssh_user or None
        self.ssh_pub_key_path = self.ssh_pub_key_path or None
        self.num_gpus = int(self.num_gpus)
        self.nproc = int(self.nproc)
        # Bug fix: ssh_port was left as whatever type was passed in (a string
        # when round-tripped through from_db); coerce it like the other
        # numeric fields.  Falsy values ("", 0, None) normalize to None.
        self.ssh_port = int(self.ssh_port) if self.ssh_port else None

    @classmethod
    def from_db(cls, row: str):
        """Create a Node object from a database row string.

        Args:
            row (str): Colon-separated node data as produced by to_db().

        Returns:
            Node: A new Node object created from the row data.
        """
        public_ip, private_ip, num_gpus, nproc, ssh_user, ssh_pub_key_path, ssh_port = row.split(
            ":"
        )
        return cls(
            public_ip,
            private_ip if private_ip != "None" else None,
            int(num_gpus),
            int(nproc),
            ssh_user if ssh_user != "None" else None,
            ssh_pub_key_path if ssh_pub_key_path != "None" else None,
            # Bug fix: the port was previously returned as a string.
            int(ssh_port) if ssh_port != "None" else None,
        )

    def to_db(self):
        """Convert the Node object to a string representation for database storage.

        Returns:
            str: A colon-separated string representation of the Node object.
        """
        return f"{self.public_ip}:{self.private_ip or 'None'}:{self.num_gpus}:{self.nproc}:{self.ssh_user or 'None'}:{self.ssh_pub_key_path or 'None'}:{self.ssh_port or 'None'}"

    def __str__(self):
        """Return a string representation of the Node object.

        Returns:
            str: A string representation of the Node object.
        """
        return f"Node(public_ip={self.public_ip}, private_ip={self.private_ip}, num_gpus={self.num_gpus}, nproc={self.nproc}, ssh_user={self.ssh_user}, ssh_pub_key_path={self.ssh_pub_key_path})"

    def __hash__(self):
        """Return a hash value for the Node object.

        Returns:
            int: A hash value based on the public IP address (consistent
            with __eq__, which compares only public_ip).
        """
        return hash(self.public_ip)

    def __eq__(self, other):
        """Check if two Node objects are equal (same public IP).

        Args:
            other: Another object to compare with.

        Returns:
            bool: True if the objects are equal, False otherwise.
        """
        if not isinstance(other, Node):
            return NotImplemented
        return self.public_ip == other.public_ip
133 | """ 134 | if self == DatabaseType.POSTGRES: 135 | return "postgresql+psycopg2" 136 | elif self == DatabaseType.MYSQL: 137 | return "mysql" 138 | else: 139 | raise ValueError(f"Unknown database type: {self}") 140 | 141 | 142 | @dataclass 143 | class Database: 144 | """Represents a database configuration. 145 | 146 | Attributes: 147 | address (str): The address of the database server. 148 | port (int): The port number for the database connection. 149 | username (str): The username for database authentication. 150 | password (str | None): The password for database authentication, if required. 151 | type (DatabaseType): The type of the database (e.g., PostgreSQL, MySQL). 152 | """ 153 | 154 | address: str 155 | port: int 156 | username: str 157 | password: str | None = None 158 | type: DatabaseType = DatabaseType.POSTGRES 159 | 160 | def __post_init__(self): 161 | """Initialize the Database object after creation.""" 162 | self.port = int(self.port) 163 | self.type = DatabaseType(self.type) 164 | 165 | @classmethod 166 | def from_db(cls, row: str): 167 | """Create a Database object from a database row string. 168 | 169 | Args: 170 | row (str): A string representation of the database data. 171 | 172 | Returns: 173 | Database: A new Database object created from the row data. 174 | """ 175 | address, port, username, password, type = row.split(":") 176 | return cls( 177 | address, 178 | int(port), 179 | username, 180 | password or None, 181 | DatabaseType(type), 182 | ) 183 | 184 | @property 185 | def uri(self): 186 | """Get the database URI for SQLAlchemy connection. 187 | 188 | Returns: 189 | str: The database URI. 190 | """ 191 | return f"{self.type.connection_string}://{self.username}:{self.password}@{self.address}:{self.port}/torch_submit" 192 | 193 | def to_db(self): 194 | """Convert the Database object to a string representation for database storage. 195 | 196 | Returns: 197 | str: A string representation of the Database object. 
198 | """ 199 | return f"{self.address}:{self.port}:{self.username}:{self.password or ''}:{self.type.value}" 200 | 201 | def __str__(self): 202 | """Return a string representation of the Database object. 203 | 204 | Returns: 205 | str: A string representation of the Database object. 206 | """ 207 | return f"Database(type={self.type}, address={self.address}, port={self.port}, username={self.username}, password=****)" 208 | 209 | def __hash__(self): 210 | """Return a hash value for the Database object. 211 | 212 | Returns: 213 | int: A hash value based on the database attributes. 214 | """ 215 | return ( 216 | hash(self.type) 217 | + hash(self.address) 218 | + hash(self.port) 219 | + hash(self.username) 220 | + hash(self.password) 221 | ) 222 | 223 | def __eq__(self, other): 224 | """Check if two Database objects are equal. 225 | 226 | Args: 227 | other: Another object to compare with. 228 | 229 | Returns: 230 | bool: True if the objects are equal, False otherwise. 231 | """ 232 | if not isinstance(other, Database): 233 | return NotImplemented 234 | return ( 235 | self.address == other.address 236 | and self.port == other.port 237 | and self.type == other.type 238 | ) 239 | 240 | 241 | class Config: 242 | """Manages the configuration for clusters and databases. 243 | 244 | This class handles loading, saving, and manipulating the configuration 245 | for clusters and databases used in the torch-submit system. 
246 | """ 247 | 248 | def __init__(self): 249 | """Initialize the Config object.""" 250 | self.config_path = os.path.expanduser("~/.cache/torch-submit/config.yaml") 251 | self.clusters: Dict[str, Cluster] = {} 252 | self.databases: Dict[str, Database] = {} 253 | self.load_config() 254 | 255 | def load_config(self): 256 | """Load the configuration from the YAML file.""" 257 | if not os.path.exists(self.config_path): 258 | return 259 | 260 | with open(self.config_path, "r") as f: 261 | config = yaml.safe_load(f) or {} 262 | 263 | for cluster_name, cluster_data in config.get("clusters", {}).items(): 264 | head_node = Node(**cluster_data["head_node"]) 265 | worker_nodes = [Node(**node) for node in cluster_data["worker_nodes"]] 266 | self.clusters[cluster_name] = Cluster(head_node, worker_nodes) 267 | 268 | for database_name, database_data in config.get("databases", {}).items(): 269 | database = Database(**database_data) 270 | self.databases[database_name] = database 271 | 272 | def save_config(self): 273 | """Save the current configuration to the YAML file.""" 274 | os.makedirs(os.path.dirname(self.config_path), exist_ok=True) 275 | config = {"clusters": {}, "databases": {}} 276 | 277 | for cluster_name, cluster in self.clusters.items(): 278 | config["clusters"][cluster_name] = { 279 | "head_node": { 280 | "public_ip": cluster.head_node.public_ip, 281 | "private_ip": cluster.head_node.private_ip or None, 282 | "num_gpus": cluster.head_node.num_gpus, 283 | "nproc": cluster.head_node.nproc, 284 | "ssh_user": cluster.head_node.ssh_user or None, 285 | "ssh_pub_key_path": cluster.head_node.ssh_pub_key_path or None, 286 | "ssh_port": cluster.head_node.ssh_port or None, 287 | }, 288 | "worker_nodes": [ 289 | { 290 | "public_ip": node.public_ip, 291 | "private_ip": node.private_ip or None, 292 | "num_gpus": node.num_gpus, 293 | "nproc": node.nproc, 294 | "ssh_user": node.ssh_user or None, 295 | "ssh_pub_key_path": node.ssh_pub_key_path or None, 296 | "ssh_port": node.ssh_port 
or None, 297 | } 298 | for node in cluster.worker_nodes 299 | ], 300 | } 301 | 302 | for database_name, database in self.databases.items(): 303 | config["databases"][database_name] = { 304 | "address": database.address, 305 | "port": database.port, 306 | "username": database.username, 307 | "password": database.password, 308 | "type": database.type.value, 309 | } 310 | 311 | with open(self.config_path, "w") as f: 312 | yaml.dump(config, f) 313 | 314 | # Cluster methods 315 | def add_cluster(self, name: str, head_node: Node, worker_nodes: List[Node]): 316 | """Add a new cluster to the configuration. 317 | 318 | Args: 319 | name (str): The name of the cluster. 320 | head_node (Node): The head node of the cluster. 321 | worker_nodes (List[Node]): The list of worker nodes in the cluster. 322 | """ 323 | self.clusters[name] = Cluster(head_node, worker_nodes) 324 | self.save_config() 325 | 326 | def remove_cluster(self, name: str): 327 | """Remove a cluster from the configuration. 328 | 329 | Args: 330 | name (str): The name of the cluster to remove. 331 | """ 332 | if name in self.clusters: 333 | del self.clusters[name] 334 | self.save_config() 335 | 336 | def get_cluster(self, cluster_name: str) -> Cluster: 337 | """Get a cluster by its name. 338 | 339 | Args: 340 | cluster_name (str): The name of the cluster. 341 | 342 | Returns: 343 | Cluster: The requested cluster. 344 | 345 | Raises: 346 | ValueError: If the cluster is not found in the configuration. 347 | """ 348 | if cluster_name not in self.clusters: 349 | raise ValueError(f"Cluster '{cluster_name}' not found in config") 350 | return self.clusters[cluster_name] 351 | 352 | def list_clusters(self) -> List[str]: 353 | """Get a list of all cluster names. 354 | 355 | Returns: 356 | List[str]: A list of cluster names. 357 | """ 358 | return list(self.clusters.keys()) 359 | 360 | def add_worker_node(self, cluster_name: str, worker_node: Node): 361 | """Add a worker node to a cluster. 
362 | 363 | Args: 364 | cluster_name (str): The name of the cluster. 365 | worker_node (Node): The worker node to add. 366 | 367 | Raises: 368 | ValueError: If the cluster is not found in the configuration. 369 | """ 370 | if cluster_name not in self.clusters: 371 | raise ValueError(f"Cluster '{cluster_name}' not found in config") 372 | self.clusters[cluster_name].worker_nodes.append(worker_node) 373 | self.save_config() 374 | 375 | def remove_worker_node(self, cluster_name: str, worker_node_ip: str): 376 | """Remove a worker node from a cluster. 377 | 378 | Args: 379 | cluster_name (str): The name of the cluster. 380 | worker_node_ip (str): The IP address of the worker node to remove. 381 | 382 | Raises: 383 | ValueError: If the cluster is not found in the configuration. 384 | """ 385 | if cluster_name not in self.clusters: 386 | raise ValueError(f"Cluster '{cluster_name}' not found in config") 387 | self.clusters[cluster_name].worker_nodes = [ 388 | node 389 | for node in self.clusters[cluster_name].worker_nodes 390 | if node.public_ip != worker_node_ip and node.private_ip != worker_node_ip 391 | ] 392 | self.save_config() 393 | 394 | def update_cluster(self, name: str, head_node: Node, worker_nodes: List[Node]): 395 | """Update an existing cluster in the configuration. 396 | 397 | Args: 398 | name (str): The name of the cluster to update. 399 | head_node (Node): The new head node for the cluster. 400 | worker_nodes (List[Node]): The new list of worker nodes for the cluster. 401 | 402 | Raises: 403 | ValueError: If the cluster is not found in the configuration. 
404 | """ 405 | if name not in self.clusters: 406 | raise ValueError(f"Cluster '{name}' not found in config") 407 | self.clusters[name] = Cluster(head_node, worker_nodes) 408 | self.save_config() 409 | 410 | # Database methods 411 | def add_db( 412 | self, 413 | type: DatabaseType, 414 | name: str, 415 | address: str, 416 | port: int, 417 | username: str, 418 | password: str, 419 | ): 420 | """Add a new database to the configuration. 421 | 422 | Args: 423 | type (DatabaseType): The type of the database. 424 | name (str): The name of the database configuration. 425 | address (str): The address of the database server. 426 | port (int): The port number for the database connection. 427 | username (str): The username for database authentication. 428 | password (str): The password for database authentication. 429 | """ 430 | self.databases[name] = Database(address, port, username, password, type) 431 | self.save_config() 432 | engine = create_engine(self.databases[name].uri.strip("/torch_submit")) 433 | with engine.connect() as conn: 434 | conn.execute(text("CREATE DATABASE IF NOT EXISTS torch_submit")) 435 | 436 | def remove_db(self, name: str): 437 | """Remove a database from the configuration. 438 | 439 | Args: 440 | name (str): The name of the database configuration to remove. 441 | """ 442 | if name in self.databases: 443 | del self.databases[name] 444 | self.save_config() 445 | 446 | def get_db(self, db_name: str) -> Database: 447 | """Get a database configuration by its name. 448 | 449 | Args: 450 | db_name (str): The name of the database configuration. 451 | 452 | Returns: 453 | Database: The requested database configuration. 454 | 455 | Raises: 456 | ValueError: If the database configuration is not found. 457 | """ 458 | if db_name not in self.databases: 459 | raise ValueError(f"Database '{db_name}' not found in config") 460 | return self.databases[db_name] 461 | 462 | def list_dbs(self) -> List[str]: 463 | """Get a list of all database configuration names. 
class NodeConnection:
    """Context manager that opens and closes an SSH connection to a Node."""

    def __init__(self, node: Node):
        """Store the node whose SSH connection this context manager owns.

        Args:
            node (Node): The Node object representing the remote machine.
        """
        self.node = node

    def __enter__(self):
        """Open the SSH connection and return it.

        Returns:
            Connection: The established SSH connection.
        """
        # NOTE(review): key_filename is given the *public* key path per the
        # field name; fabric expects a private key file here — confirm the
        # field actually stores the private key path.
        kwargs = (
            {"key_filename": self.node.ssh_pub_key_path}
            if self.node.ssh_pub_key_path
            else None
        )
        self.connection = Connection(
            self.node.public_ip,
            user=self.node.ssh_user,
            connect_kwargs=kwargs,
            port=self.node.ssh_port,
        )
        self.connection.open()
        return self.connection

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the SSH connection when exiting the context.

        Args:
            exc_type: The type of the exception that caused the exit, if any.
            exc_val: The exception instance, if any.
            exc_tb: The traceback for the exception, if any.
        """
        self.connection.close()
32 | """ 33 | 34 | def __init__(self, job_id: str, job_name: str): 35 | """ 36 | Initialize the WorkingDirectoryArchiver with job ID and job name. 37 | 38 | Args: 39 | job_id (str): The ID of the job. 40 | job_name (str): The name of the job. 41 | """ 42 | self.job_id = job_id 43 | self.job_name = job_name 44 | 45 | self.output_dir = os.path.expanduser(f"~/.cache/torch-submit/jobs/{job_id}") 46 | os.makedirs(self.output_dir, exist_ok=True) 47 | 48 | def archive(self, working_dir: str) -> str: 49 | """ 50 | Create a zip archive of the specified working directory. 51 | 52 | This method reads the .gitignore file in the working directory to determine which files 53 | to exclude from the archive. It also includes job metadata in the archive. 54 | 55 | Args: 56 | working_dir (str): The path to the working directory to be archived. 57 | 58 | Returns: 59 | str: The path to the created zip archive. 60 | """ 61 | archive_name = f"{os.path.basename(working_dir)}.zip" 62 | archive_path = os.path.join(self.output_dir, archive_name) 63 | 64 | gitignore_path = os.path.join(working_dir, ".gitignore") 65 | ignore_patterns = [] 66 | if os.path.exists(gitignore_path): 67 | with open(gitignore_path, "r") as gitignore_file: 68 | ignore_patterns = [ 69 | line.strip() 70 | for line in gitignore_file 71 | if line.strip() and not line.startswith("#") 72 | ] 73 | 74 | def should_ignore(path): 75 | """ 76 | Determine if a file should be ignored based on .gitignore patterns. 77 | 78 | Args: 79 | path (str): The path to the file. 80 | 81 | Returns: 82 | bool: True if the file should be ignored, False otherwise. 
83 | """ 84 | rel_path = os.path.relpath(path, working_dir) 85 | return any( 86 | rel_path.startswith(pattern) or fnmatch.fnmatch(rel_path, pattern) 87 | for pattern in ignore_patterns 88 | ) 89 | 90 | with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as zipf: 91 | # Write job metadata under .torch/job.json 92 | job_metadata = { 93 | "id": self.job_id, 94 | "name": self.job_name, 95 | } 96 | zipf.writestr(".torch_submit/job.json", json.dumps(job_metadata)) 97 | 98 | # Archive files 99 | for root, dirs, files in os.walk(working_dir): 100 | dirs[:] = [ 101 | d 102 | for d in dirs 103 | if d != "__pycache__" and not should_ignore(os.path.join(root, d)) 104 | ] 105 | for file in files: 106 | file_path = os.path.join(root, file) 107 | if not should_ignore(file_path): 108 | arcname = os.path.relpath(file_path, working_dir) 109 | zipf.write(file_path, arcname) 110 | 111 | return archive_path 112 | 113 | 114 | class BaseExecutor(ABC): 115 | """ 116 | Base class for executing jobs across a cluster. 117 | 118 | This class defines the structure for executing a job. Sub-classes must implement the get_command 119 | method, which generates the command to be executed on each node in the cluster. The execute method 120 | runs this command on each node, managing the setup and execution process. 121 | 122 | Methods: 123 | get_command(rank: int): Abstract method to create the command for the given node rank. 124 | execute() -> Dict[Node, int]: Executes the job command on each node in the cluster and returns 125 | a dictionary mapping nodes to their process IDs. 126 | """ 127 | 128 | def __init__(self, job: Job): 129 | self.job = job 130 | self.remote_dir = f"/tmp/torch_submit_job_{self.job.id}" 131 | self.cluster = Config().get_cluster(self.job.cluster) 132 | 133 | @abstractmethod 134 | def get_command(self, rank: int, env_vars: Optional[Dict[str, str]] = None): 135 | """ 136 | Generate the command to be executed on the given node rank. 
137 | 138 | Args: 139 | rank (int): The rank of the node in the cluster. 140 | 141 | Returns: 142 | str: The command to be executed on the node. 143 | """ 144 | ... 145 | 146 | def execute(self, env_vars: Optional[Dict[str, str]] = None) -> Dict[Node, int]: 147 | """ 148 | Execute the job command on each node in the cluster. 149 | 150 | This method sets up the remote environment, copies the working directory, 151 | and runs the job command on each node in the cluster. It manages the setup 152 | and execution process, handling any exceptions that occur during execution. 153 | 154 | Returns: 155 | Dict[Node, int]: A dictionary mapping nodes to their process IDs. 156 | """ 157 | pids = {} 158 | for rank, node in enumerate([self.cluster.head_node] + self.cluster.worker_nodes): 159 | try: 160 | with NodeConnection(node) as conn: 161 | self._setup_remote_env(conn) 162 | self._copy_working_dir(conn) 163 | pids[node] = self._run_job(conn, rank, env_vars) 164 | except Exception: 165 | console.print_exception() 166 | console.print(f"Error executing job on node {node.public_ip}") 167 | pids[node] = None 168 | return pids 169 | 170 | def _prepare_command(self, rank: int, env_vars: Optional[Dict[str, str]] = None): 171 | return ( 172 | f"cd {self.remote_dir} && " 173 | f"{self.get_command(rank, env_vars)} " 174 | f"{self.job.command}" 175 | ) 176 | 177 | def _run_job( 178 | self, 179 | conn: Connection, 180 | node_rank: int, 181 | env_vars: Optional[Dict[str, str]] = None, 182 | ): 183 | """ 184 | Run the job on the specified node. 185 | 186 | This method changes the directory to the remote directory, runs the provided torchrun command 187 | along with the job command, and captures the process ID of the running job. 188 | 189 | Args: 190 | conn (Connection): The connection object to the node. 191 | executor_command (str): The command with which to run the user-provided script. 192 | node_rank (int): The rank of the node in the cluster. 
193 | 194 | Returns: 195 | int: The process ID of the running job. 196 | """ 197 | console.print( 198 | f"[bold blue]Running job on {conn.host} (rank {node_rank})...[/bold blue]" 199 | ) 200 | full_command = self._prepare_command(node_rank, env_vars) 201 | conn.run( 202 | "source ~/.profile && " 203 | f"USE_TORCHSUBMIT=1 {full_command} > {self.remote_dir}/output.log 2>&1 & " 204 | f"pid=$!; " 205 | f"echo $pid > {self.remote_dir}/job.pid; " 206 | f"wait $pid; " 207 | f"echo $? > {self.remote_dir}/exit_code", 208 | disown=True, 209 | ) 210 | # Parse the PID from the job.pid file 211 | result = conn.run(f"cat {self.remote_dir}/job.pid", hide=True) 212 | pid = int(result.stdout.strip()) 213 | return pid 214 | 215 | def _setup_remote_env(self, conn: Connection): 216 | conn.run(f"mkdir -p {self.remote_dir}") 217 | 218 | def _copy_working_dir(self, conn: Connection): 219 | remote_zip_path = f"{self.remote_dir}/working_dir.zip" 220 | console.print(f"[bold blue]Copying working directory to {conn.host}...[/bold blue]") 221 | conn.put(self.job.working_dir, remote_zip_path) 222 | 223 | console.print(f"[bold blue]Unzipping working directory on {conn.host}...[/bold blue]") 224 | conn.run(f"unzip -q -o {remote_zip_path} -d {self.remote_dir}") 225 | console.print("[bold green]Working directory successfully synced.[/bold green]") 226 | 227 | def cleanup(self): 228 | """ 229 | Clean up the remote directories on all nodes. 230 | 231 | This method removes the remote directory created for the job on each node. 232 | If the cleanup fails on any node, a warning message is printed. 
233 | """ 234 | for node in self.job.nodes: 235 | try: 236 | with NodeConnection(node) as conn: 237 | conn.run(f"rm -rf {self.remote_dir}") 238 | except UnexpectedExit: 239 | console.print( 240 | f"[bold yellow]Warning: Could not clean up {self.remote_dir} on {node}[/bold yellow]" 241 | ) 242 | 243 | 244 | class DistributedExecutor(BaseExecutor): 245 | """ 246 | The DistributedExecutor is responsible for setting up the environment for running 247 | distributed PyTorch jobs. It ensures that the necessary environment variables are set 248 | for the torch distributed environment, including MASTER_ADDR, MASTER_PORT, WORLD_SIZE, 249 | and NODE_RANK. These variables are essential for coordinating the distributed training 250 | process across multiple nodes and GPUs. 251 | 252 | Exposes the following environment variables to the user script: 253 | - MASTER_ADDR: The address of the master node. 254 | - MASTER_PORT: The port on which the master node is listening. 255 | - WORLD_SIZE: The total number of processes participating in the job. 256 | - NODE_RANK: The rank of the current node. 257 | - LOCAL_WORLD_SIZE: The number of processes on the current node. 258 | """ 259 | 260 | def __init__(self, job: Job): 261 | super().__init__(job) 262 | self.port = random.randint(29400, 29499) 263 | 264 | def get_command(self, rank: int, env_vars: Optional[Dict[str, str]] = None): 265 | """ 266 | Constructs the command to run the job with the torch distributed environment variables set. 267 | 268 | This method sets up the necessary environment variables for a distributed torch run, including 269 | MASTER_ADDR, MASTER_PORT, WORLD_SIZE, and NODE_RANK. It then appends the user-provided command 270 | to these environment variables. 271 | 272 | Args: 273 | rank (int): The rank of the current node. 274 | 275 | Returns: 276 | str: The full command to run the job with the necessary environment variables. 
277 | """ 278 | head_node = self.cluster.head_node 279 | ip = head_node.private_ip or head_node.public_ip 280 | 281 | world_size = 0 282 | for node in self.cluster.worker_nodes + [self.cluster.head_node]: 283 | world_size += node.num_gpus 284 | 285 | formatted_env_vars = " ".join(f"{k}={v}" for k, v in env_vars.items()) 286 | 287 | return ( 288 | f"MASTER_ADDR={ip} " 289 | f"MASTER_PORT={self.port} " 290 | f"WORLD_SIZE={world_size} " 291 | f"NODE_RANK={rank} " 292 | f"LOCAL_WORLD_SIZE={self.cluster.worker_nodes[rank].num_gpus} " 293 | f"{formatted_env_vars} " 294 | ) 295 | 296 | 297 | class TorchrunExecutor(BaseExecutor): 298 | def __init__(self, job: Job): 299 | super().__init__(job) 300 | self.port = random.randint(29400, 29499) 301 | 302 | def get_command(self, rank: int, env_vars: Optional[Dict[str, str]] = None): 303 | """ 304 | Constructs the command to run the job with torchrun. 305 | 306 | This method sets up the necessary parameters for a torchrun command, including 307 | the number of nodes, the number of processes per node, the rendezvous backend, 308 | the rendezvous endpoint, the job ID, and the maximum number of restarts. 309 | 310 | Args: 311 | rank (int): The rank of the current node. 312 | 313 | Returns: 314 | str: The full command to run the job with torchrun. 
315 | """ 316 | nnodes = len(self.cluster.worker_nodes) + 1 # including head node 317 | 318 | # Determine nproc_per_node 319 | if rank == 0: # Head node 320 | if self.job.num_gpus is not None: 321 | nproc_per_node = self.job.num_gpus 322 | elif self.cluster.head_node.num_gpus is not None: 323 | nproc_per_node = self.cluster.head_node.num_gpus 324 | else: 325 | nproc_per_node = 1 # Default to 1 if no GPU information is available 326 | omp_num_threads = self.cluster.head_node.nproc // nproc_per_node 327 | else: # Worker node 328 | if self.job.num_gpus is not None: 329 | nproc_per_node = self.job.num_gpus 330 | elif self.cluster.worker_nodes[rank - 1].num_gpus is not None: 331 | nproc_per_node = self.cluster.worker_nodes[rank - 1].num_gpus 332 | else: 333 | nproc_per_node = 1 # Default to 1 if no GPU information is available 334 | omp_num_threads = self.cluster.worker_nodes[rank - 1].nproc // nproc_per_node 335 | 336 | # if len(self.cluster.worker_nodes) == 0: 337 | # rdzv_endpoint = f"localhost:{self.port}" 338 | # else: 339 | head_node = self.cluster.head_node 340 | master_ip = head_node.private_ip or head_node.public_ip 341 | master_port = self.port 342 | 343 | if env_vars: 344 | formatted_env_vars = " ".join(f"{k}={v}" for k, v in env_vars.items()) 345 | else: 346 | formatted_env_vars = "" 347 | 348 | return ( 349 | f"OMP_NUM_THREADS={omp_num_threads} " 350 | f"{formatted_env_vars} " 351 | f"nohup torchrun " 352 | f"--nnodes={nnodes} " 353 | f"--node_rank={rank} " 354 | f"--nproc-per-node={nproc_per_node} " 355 | f"--master_addr={master_ip} " 356 | f"--master_port={master_port} " 357 | f"--rdzv-id={self.job.id} " 358 | f"--max-restarts={self.job.max_restarts} " 359 | "--no-python" 360 | ) 361 | 362 | 363 | class OptunaExecutor(DistributedExecutor): 364 | """ 365 | The OptunaExecutor sets up and manages the execution of Optuna distributed optimization jobs. 366 | 367 | The head node runs a SQLite database for Optuna and exposes it to the cluster. 
Each node in the cluster 368 | runs a single Optuna process that will utilize all the GPUs available on that node. 369 | 370 | Exposes the following environment variables to the user script: 371 | - MASTER_ADDR: The address of the master node. 372 | - MASTER_PORT: The port on which the master node is listening. 373 | - WORLD_SIZE: The total number of processes participating in the job. 374 | - NODE_RANK: The rank of the current node. 375 | - STUDY_NAME: The name of the Optuna study (the job name). 376 | - DATABASE_URI: The URI of the database. 377 | """ 378 | 379 | def __init__(self, job: Job): 380 | super().__init__(job) 381 | 382 | def get_command(self, rank: int, env_vars: Optional[Dict[str, str]] = None): 383 | if rank == 0: 384 | world_size = self.cluster.head_node.num_gpus 385 | else: 386 | world_size = self.cluster.worker_nodes[rank - 1].num_gpus 387 | 388 | formatted_env_vars = " ".join(f"{k}={v}" for k, v in env_vars.items()) 389 | 390 | return ( 391 | f"MASTER_ADDR=localhost " 392 | f"MASTER_PORT={self.port} " 393 | f"WORLD_SIZE={world_size} " 394 | f"NODE_RANK={rank} " 395 | f"OPTUNA_STUDY_NAME={self.job.name} " 396 | f"OPTUNA_STORAGE={self.job.database.uri} " 397 | f"{formatted_env_vars} " 398 | ) 399 | 400 | def execute(self) -> Dict[Node, int]: 401 | """ 402 | Set up the database on the head node and then run the DistributedExecutor execute method. 403 | 404 | This method first sets up the SQLite database on the head node for Optuna. After the database 405 | is set up, it calls the execute method of the DistributedExecutor to run the job command on 406 | each node in the cluster. 407 | 408 | Returns: 409 | Dict[Node, int]: A dictionary mapping nodes to their process IDs. 
410 | """ 411 | optuna.create_study( 412 | study_name=self.job.name, 413 | storage=self.job.database.uri, 414 | ) 415 | with NodeConnection(self.cluster.head_node) as conn: 416 | conn.run(f"nohup optuna-dashboard --port {self.job.optuna_port} &") 417 | console.print( 418 | f"[bold blue]Optuna dashboard running on {self.cluster.head_node.public_ip}:{self.job.optuna_port}[/bold blue]" 419 | ) 420 | return super().execute() 421 | 422 | 423 | class DockerDistributedExecutor(DistributedExecutor): 424 | """ 425 | EXPERIMENTAL: 426 | DockerDistributedExecutor is an executor that runs distributed jobs inside Docker containers. 427 | 428 | This executor extends the DistributedExecutor to provide Docker support, allowing the user to run 429 | distributed jobs in isolated Docker environments with GPU support. 430 | 431 | Exposes the following environment variables to the user script: 432 | - MASTER_ADDR: The address of the master node. 433 | - MASTER_PORT: The port on which the master node is listening. 434 | - WORLD_SIZE: The total number of processes participating in the job. 435 | - NODE_RANK: The rank of the current node. 436 | """ 437 | 438 | def __init__(self, job: Job): 439 | super().__init__(job) 440 | 441 | def get_command(self, rank: int, env_vars: Optional[Dict[str, str]] = None): 442 | """ 443 | Constructs the command to run the job with the torch distributed environment variables set. 444 | 445 | This method sets up the necessary environment variables for a distributed torch run, including 446 | MASTER_ADDR, MASTER_PORT, WORLD_SIZE, and NODE_RANK. It then appends the user-provided command 447 | to these environment variables. 448 | 449 | Args: 450 | rank (int): The rank of the current node. 451 | 452 | Returns: 453 | str: The full command to run the job with the necessary environment variables. 
454 | """ 455 | head_node = self.cluster.head_node 456 | ip = head_node.private_ip or head_node.public_ip 457 | 458 | world_size = 0 459 | for node in self.cluster.worker_nodes + [self.cluster.head_node]: 460 | world_size += node.num_gpus 461 | 462 | formatted_env_vars = " ".join(f"-e {k}={v}" for k, v in env_vars.items()) 463 | 464 | return ( 465 | "docker run --rm" 466 | "--gpus all --runtime=nvidia " 467 | "--network host " 468 | f"-v {self.remote_dir}:{self.remote_dir} " 469 | f"-e MASTER_ADDR={ip} " 470 | f"-e MASTER_PORT={self.port} " 471 | f"-e WORLD_SIZE={world_size} " 472 | f"-e NODE_RANK={rank} " 473 | f"-e LOCAL_WORLD_SIZE={self.cluster.worker_nodes[rank].num_gpus} " 474 | f"{formatted_env_vars} " 475 | f"{self.job.docker_image} " 476 | ) 477 | 478 | def _prepare_command(self, rank: int): 479 | return f"{self.get_command(rank)} -- {self.job.command}" 480 | 481 | 482 | class JobExecutionManager: 483 | @staticmethod 484 | def submit_job(job: Job): 485 | executor = job.get_executor() 486 | try: 487 | executor.execute() 488 | console.print(f"[bold green]Job {job.id} submitted successfully[/bold green]") 489 | except Exception as e: 490 | console.print(f"[bold red]Error submitting job {job.id}:[/bold red] {str(e)}") 491 | executor.cleanup() 492 | 493 | @staticmethod 494 | def cancel_job(job: Job): 495 | executor = job.get_executor() 496 | try: 497 | executor.cleanup() 498 | console.print(f"[bold green]Job {job.id} cancelled successfully[/bold green]") 499 | except Exception as e: 500 | console.print(f"[bold red]Error cancelling job {job.id}:[/bold red] {str(e)}") 501 | -------------------------------------------------------------------------------- /torch_submit/job.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sqlite3 3 | from concurrent.futures import ThreadPoolExecutor 4 | from typing import Dict, List, Optional 5 | 6 | from rich.console import Console 7 | 8 | from .config import Node 9 | from 
.connection import NodeConnection
from .types import Job, JobStatus

console = Console()


class JobManager:
    """Manages job-related operations and database interactions."""

    def __init__(
        self, db_path: str = os.path.expanduser("~/.cache/torch-submit/jobs.db")
    ):
        """Initialize the JobManager.

        Args:
            db_path (str): Path to the SQLite database file.
        """
        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        self.conn = sqlite3.connect(db_path)
        self.create_table()
        self.migrate_table()

    def create_table(self):
        """Create the jobs table if it doesn't exist."""
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS jobs (
                id TEXT PRIMARY KEY,
                name TEXT,
                status TEXT,
                working_dir TEXT,
                nodes TEXT,
                cluster TEXT,
                command TEXT,
                max_restarts INTEGER DEFAULT 0,
                num_gpus INTEGER DEFAULT NULL,
                pids TEXT DEFAULT NULL,
                executor TEXT DEFAULT NULL,
                docker_image TEXT DEFAULT NULL,
                database TEXT DEFAULT NULL,
                optuna_port INTEGER DEFAULT NULL
            )
        """)

    def add_job(self, job: Job):
        """Add a new job to the database.

        Args:
            job (Job): The job to be added.
        """
        self.conn.execute(
            """
            INSERT INTO jobs (id, name, status, working_dir, nodes, cluster, command, max_restarts, num_gpus, pids, executor, docker_image, database, optuna_port)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
            job.to_db(),
        )
        self.conn.commit()

    def get_job(self, job_id_or_name: str) -> Optional[Job]:
        """Retrieve a job by its ID or name.

        Args:
            job_id_or_name (str): The ID or name of the job.

        Returns:
            Optional[Job]: The retrieved job, or None if not found.
        """
        # Try to get by id first
        cursor = self.conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id_or_name,))
        row = cursor.fetchone()
        if not row:
            # If not found by id, try to get by name
            cursor = self.conn.execute(
                "SELECT * FROM jobs WHERE name = ?", (job_id_or_name,)
            )
            row = cursor.fetchone()

        if not row:
            return None
        return Job.from_db(row)

    def list_jobs(self) -> List[Job]:
        """Retrieve all jobs from the database.

        Returns:
            List[Job]: A list of all jobs.
        """
        cursor = self.conn.execute("SELECT * FROM jobs")
        return [Job.from_db(row) for row in cursor.fetchall()]

    def check_job_status(self, job: Job) -> str:
        """Check the current status of a job.

        Args:
            job (Job): The job to check.

        Returns:
            str: The current status of the job.

        Raises:
            RuntimeError: If an unknown job status is encountered.
        """
        # Terminal states never change, so skip the remote probes entirely.
        if job.status in [JobStatus.STOPPED, JobStatus.FINISHED, JobStatus.CRASHED]:
            return job.status

        if job.status in [JobStatus.SUBMITTED, JobStatus.RUNNING, JobStatus.STOPPING]:
            node_statuses = []

            def check_node_status(node):
                # Probe one node over SSH: check whether the recorded PID for
                # this node is still alive, and if not, read the exit code the
                # launcher left behind to decide finished vs crashed.
                try:
                    with NodeConnection(node) as c:
                        result = c.run(
                            f"ps -p {job.pids[node]}",
                            warn=True,
                            hide=True,
                        )

                        if result.ok and job.status in [
                            JobStatus.SUBMITTED,
                            JobStatus.RUNNING,
                        ]:
                            return JobStatus.RUNNING
                        elif result.ok and job.status == JobStatus.STOPPING:
                            return JobStatus.STOPPING
                        elif not result.ok and job.status == JobStatus.STOPPING:
                            return JobStatus.STOPPED
                        elif not result.ok:
                            # Process is gone: "0" in exit_code.log means a
                            # clean finish, anything else a crash.
                            exit_code_result = c.run(
                                f"cat {job.working_dir}/exit_code.log",
                                warn=True,
                                hide=True,
                            )
                            if (
                                exit_code_result.ok
                                and exit_code_result.stdout.strip() == "0"
                            ):
                                return JobStatus.FINISHED
                            else:
                                return JobStatus.CRASHED
                        else:
                            raise RuntimeError(
                                f"Unknown job status: {job.status} for node {node}, {result.stdout}"
                            )

                except Exception as exc:
                    # Connection/command failure: report UNKNOWN rather than
                    # guessing, so aggregation below stays conservative.
                    console.print(f"Error checking node status: {exc}")
                    return JobStatus.UNKNOWN

            # Probe all nodes concurrently; one slow host shouldn't serialize.
            with ThreadPoolExecutor() as executor:
                node_statuses = list(executor.map(check_node_status, job.nodes))

            # Aggregate job status across all nodes
            if all(status == JobStatus.RUNNING for status in node_statuses):
                return JobStatus.RUNNING
            elif all(status == JobStatus.STOPPED for status in node_statuses):
                return JobStatus.STOPPED
            elif all(status == JobStatus.FINISHED for status in node_statuses):
                return JobStatus.FINISHED
            elif any(status == JobStatus.CRASHED for status in node_statuses):
                return JobStatus.CRASHED
            elif any(status == JobStatus.STOPPING for status in node_statuses):
                return JobStatus.STOPPING
            else:
                return JobStatus.UNKNOWN

        raise RuntimeError(f"Unknown job status: {job.status}")

    def get_all_jobs_with_status(self) -> List[Job]:
        """Retrieve all jobs and update their statuses.

        Returns:
            List[Job]: A list of all jobs with updated statuses.
        """
        jobs = self.list_jobs()
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(self.check_job_status, job) for job in jobs]
            for job, future in zip(jobs, futures):
                try:
                    new_status = future.result()
                    # Only touch the database when the status actually changed.
                    if new_status != job.status:
                        self.update_job_status(job.id, new_status)
                        job.status = new_status
                except Exception as exc:
                    print(f"Job {job.id} generated an exception: {exc}")
        return jobs

    def update_job_status(self, job_id: str, status: JobStatus):
        """Update the status of a job in the database.

        Args:
            job_id (str): The ID of the job to update.
            status (JobStatus): The new status of the job.

        Raises:
            ValueError: If an invalid job status is provided.
        """
        if not isinstance(status, JobStatus):
            raise ValueError(f"Invalid job status: {status}")
        self.conn.execute(
            "UPDATE jobs SET status = ? WHERE id = ?", (status.value, job_id)
        )
        self.conn.commit()

    def update_job_pids(self, job_id: str, pids: Dict[Node, int]):
        """Update the process IDs for a job in the database.

        Args:
            job_id (str): The ID of the job to update.
            pids (Dict[Node, int]): A dictionary mapping nodes to process IDs.
        """
        # Serialized as "public_ip:pid" pairs; Job.from_db parses this format.
        self.conn.execute(
            "UPDATE jobs SET pids = ? WHERE id = ?",
            (
                ",".join([f"{node.public_ip}:{pid}" for node, pid in pids.items()]),
                job_id,
            ),
        )
        self.conn.commit()

    def delete_job(self, job_id: str):
        """Delete a job from the database.

        Args:
            job_id (str): The ID of the job to delete.
233 | """ 234 | self.conn.execute("DELETE FROM jobs WHERE id = ?", (job_id,)) 235 | self.conn.commit() 236 | 237 | def delete_all_jobs(self): 238 | """Delete all jobs from the database.""" 239 | self.conn.execute("DELETE FROM jobs") 240 | self.conn.commit() 241 | 242 | def close(self): 243 | """Close the database connection.""" 244 | self.conn.close() 245 | 246 | def migrate_table(self): 247 | """Perform any necessary database migrations.""" 248 | # Add any necessary migration steps here 249 | pass 250 | -------------------------------------------------------------------------------- /torch_submit/types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from enum import Enum 3 | from typing import Dict, List, Optional, Tuple 4 | 5 | from .config import Database, Node 6 | 7 | 8 | class Executor(str, Enum): 9 | """Enumeration of different types of executors.""" 10 | 11 | TORCHRUN = "torchrun" 12 | DISTRIBUTED = "distributed" 13 | OPTUNA = "optuna" 14 | 15 | 16 | class JobStatus(str, Enum): 17 | """Enumeration of different job statuses.""" 18 | 19 | SUBMITTED = "submitted" 20 | RUNNING = "running" 21 | STOPPING = "stopping" 22 | STOPPED = "stopped" 23 | FINISHED = "finished" 24 | CRASHED = "crashed" 25 | UNKNOWN = "unknown" 26 | 27 | 28 | @dataclass 29 | class Job: 30 | """ 31 | A class representing a job to be executed. 32 | 33 | Attributes: 34 | id (str): The ID of the job. 35 | name (str): The name of the job. 36 | status (JobStatus): The current status of the job. 37 | working_dir (str): The working directory for the job. 38 | nodes (List[Node]): The list of nodes assigned to the job. 39 | cluster (str): The cluster to which the job belongs. 40 | command (str): The command to be executed for the job. 41 | max_restarts (int): The maximum number of restarts allowed for the job. 42 | num_gpus (Optional[int]): The number of GPUs allocated for the job. 
43 | pids (Dict[Node, int]): A dictionary mapping nodes to process IDs. 44 | executor (Executor): The executor type for the job. 45 | docker_image (Optional[str]): The Docker image to be used for the job. 46 | database (Optional[Database]): The database configuration for the job. 47 | optuna_port (Optional[int]): The port for Optuna executor. 48 | """ 49 | 50 | id: str 51 | name: str 52 | status: JobStatus 53 | working_dir: str 54 | nodes: List[Node] 55 | cluster: str 56 | command: str 57 | max_restarts: int = 0 58 | num_gpus: Optional[int] = None 59 | pids: Dict[Node, int] = field(default_factory=dict) 60 | executor: Executor = field(default_factory=Executor.TORCHRUN) 61 | docker_image: Optional[str] = None 62 | database: Optional[Database] = None 63 | optuna_port: Optional[int] = None 64 | 65 | def __post_init__(self): 66 | """Post-initialization checks for the Job class.""" 67 | if self.executor == Executor.OPTUNA and not self.optuna_port: 68 | raise ValueError("Optuna executor requires a port") 69 | 70 | @classmethod 71 | def from_db(cls, row: Tuple) -> "Job": 72 | """ 73 | Create a Job instance from a database row. 74 | 75 | Args: 76 | row (Tuple): A tuple representing a row from the database. 77 | 78 | Returns: 79 | Job: A Job instance created from the database row. 
80 | """ 81 | nodes = [Node.from_db(node) for node in row[4].split(",")] 82 | pids = {} 83 | if row[9]: 84 | for pair in row[9].split(","): 85 | node_ip, pid = pair.split(":") 86 | node = next((n for n in nodes if n.public_ip == node_ip), None) 87 | if node: 88 | pids[node] = int(pid) 89 | 90 | return cls( 91 | id=row[0], 92 | name=row[1], 93 | status=JobStatus(row[2]), 94 | working_dir=row[3], 95 | nodes=nodes, 96 | cluster=row[5], 97 | command=row[6], 98 | max_restarts=int(row[7]), 99 | num_gpus=int(row[8]) if row[8] else None, 100 | pids=pids, 101 | executor=Executor(row[10]), 102 | docker_image=row[11] or None, 103 | database=Database.from_db(row[12]) if row[12] else None, 104 | optuna_port=int(row[13]) if row[13] else None, 105 | ) 106 | 107 | def to_db(self) -> Tuple: 108 | """ 109 | Convert the Job instance to a tuple for database storage. 110 | 111 | Returns: 112 | Tuple: A tuple representing the Job instance for database storage. 113 | """ 114 | return ( 115 | self.id, 116 | self.name, 117 | self.status.value, 118 | self.working_dir, 119 | ",".join([node.to_db() for node in self.nodes]), 120 | self.cluster, 121 | self.command, 122 | self.max_restarts, 123 | self.num_gpus or "", 124 | ",".join([f"{k}:{v}" for k, v in self.pids.items()]), 125 | self.executor.value, 126 | self.docker_image or "", 127 | self.database.to_db() or "" if self.database else "", 128 | self.optuna_port or "", 129 | ) 130 | 131 | def get_executor(self): 132 | """ 133 | Get the appropriate executor instance for the job. 134 | 135 | Returns: 136 | An instance of the appropriate executor class. 137 | 138 | Raises: 139 | ValueError: If an unknown executor type is specified or if Docker image is not supported for the executor. 
140 | """ 141 | from .executor import ( 142 | DistributedExecutor, 143 | DockerDistributedExecutor, 144 | OptunaExecutor, 145 | TorchrunExecutor, 146 | ) 147 | 148 | if self.executor == Executor.TORCHRUN and self.docker_image: 149 | raise ValueError("Docker image is not supported for torchrun executor") 150 | elif self.executor == Executor.TORCHRUN: 151 | return TorchrunExecutor(self) 152 | elif self.executor == Executor.DISTRIBUTED and self.docker_image: 153 | return DockerDistributedExecutor(self) 154 | elif self.executor == Executor.DISTRIBUTED: 155 | return DistributedExecutor(self) 156 | elif self.executor == Executor.OPTUNA and self.docker_image: 157 | raise ValueError("Docker image is not supported for optuna executor") 158 | elif self.executor == Executor.OPTUNA: 159 | return OptunaExecutor(self) 160 | else: 161 | raise ValueError(f"Unknown executor: {self.executor}") 162 | 163 | def __str__(self): 164 | """ 165 | Return a string representation of the Job instance. 166 | 167 | Returns: 168 | str: A string representation of the Job instance. 169 | """ 170 | return ( 171 | f"Job(" 172 | f"id={self.id}, " 173 | f"name={self.name}, " 174 | f"status={self.status.value}, " 175 | f"working_dir={self.working_dir}, " 176 | f"nodes={self.nodes}, " 177 | f"cluster={self.cluster}, " 178 | f"command={self.command}, " 179 | f"max_restarts={self.max_restarts}, " 180 | f"num_gpus={self.num_gpus}, " 181 | f"pids={self.pids}, " 182 | f"executor={self.executor}, " 183 | f"docker_image={self.docker_image}, " 184 | f"database={self.database}, " 185 | f"optuna_port={self.optuna_port}" 186 | f")" 187 | ) 188 | -------------------------------------------------------------------------------- /torch_submit/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from typing import Dict, Optional 4 | 5 | 6 | def generate_friendly_name() -> str: 7 | """Generate a friendly, human-readable name for a job. 
8 | 9 | This function creates a name by combining a random adjective, a random animal noun, 10 | and a random 4-digit number. The name format is 'adjective-noun-number'. 11 | 12 | Returns: 13 | A friendly name string in the format 'adjective-noun-number'. 14 | 15 | Example: 16 | >>> generate_friendly_name() 17 | 'happy-panda-3721' 18 | """ 19 | adjectives = [ 20 | "happy", 21 | "sunny", 22 | "clever", 23 | "swift", 24 | "brave", 25 | "bright", 26 | "calm", 27 | "daring", 28 | "eager", 29 | "gentle", 30 | "jolly", 31 | "kind", 32 | "lively", 33 | "nice", 34 | "proud", 35 | "wise", 36 | ] 37 | nouns = [ 38 | "panda", 39 | "tiger", 40 | "eagle", 41 | "dolphin", 42 | "fox", 43 | "owl", 44 | "wolf", 45 | "bear", 46 | "hawk", 47 | "lion", 48 | "deer", 49 | "rabbit", 50 | "otter", 51 | "koala", 52 | "lynx", 53 | "raven", 54 | ] 55 | return f"{random.choice(adjectives)}-{random.choice(nouns)}-{random.randint(1000, 9999)}" 56 | 57 | 58 | def get_job_metadata() -> Optional[Dict[str, str]]: 59 | """Retrieve job metadata from the '.torch/job.json' file. 60 | 61 | This function attempts to read and parse the job metadata stored in the 62 | '.torch/job.json' file in the current working directory. 63 | 64 | Returns: 65 | A dictionary containing job metadata if the file exists and can be parsed 66 | successfully, or None if the file is not found. 67 | 68 | Raises: 69 | json.JSONDecodeError: If the file exists but contains invalid JSON. 70 | 71 | Example: 72 | >>> metadata = get_job_metadata() 73 | >>> if metadata: 74 | ... print(f"Job ID: {metadata.get('id')}") 75 | ... else: 76 | ... print("No job metadata found.") 77 | """ 78 | try: 79 | with open(".torch_submit/job.json", "r") as f: 80 | job_metadata = json.load(f) 81 | return job_metadata 82 | except FileNotFoundError: 83 | return None 84 | --------------------------------------------------------------------------------