├── .gitignore
├── LICENSE
├── MANIFEST.in
├── Readme.md
├── clean_and_build.py
├── example_running.png
├── project_structure
├── pyproject.toml
├── requirements.txt
├── setup.py
├── ssh_gpu_monitor
│   ├── __init__.py
│   ├── __main__.py
│   ├── config
│   │   └── config.yaml
│   ├── main.py
│   └── src
│       ├── __init__.py
│       ├── config_loader.py
│       └── table_display.py
└── todos.md

/.gitignore:
--------------------------------------------------------------------------------
.vscode
logs/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# Rope
.ropeproject

# Django stuff:
*.log
*.pot

# Sphinx documentation
docs/_build/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Alexander F. Spies

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
include README.md
include LICENSE
include requirements.txt
recursive-include ssh_gpu_monitor/config *.yaml
recursive-include ssh_gpu_monitor *.py

--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
# SSH GPU Monitor 🖥️
[![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

A fast, asynchronous GPU monitoring tool that provides real-time status of NVIDIA GPUs across multiple machines through SSH, with support for jump hosts and per-machine credentials.

![Example Output](example_running.png)

## ✨ Features

- **Real-time Monitoring**: Live updates of GPU status across multiple machines
- **Asynchronous Operation**: Fast, non-blocking checks using `asyncio` and `asyncssh`
- **Jump Host Support**: Access machines behind a bastion/jump host
- **Rich Display**: Beautiful terminal UI using the `rich` library
- **Flexible Configuration**:
  - YAML-based configuration
  - Per-machine SSH credentials
  - Pattern-based target generation
- **Robust Error Handling**: Graceful handling of network issues and timeouts

## 🚀 Installation & Usage

### Install from PyPI
```bash
pip install ssh-gpu-monitor
```

### Run the Monitor
After installation, you can run the monitor in several ways:

```bash
# Run using the command-line tool
ssh-gpu-monitor

# Or run as a Python module
python -m ssh_gpu_monitor

# Use a custom config file
ssh-gpu-monitor --config /path/to/your/config.yaml

# Get the default config path
ssh-gpu-monitor --get_config_path
```

### Configuration
1. Get the default config path:
```bash
ssh-gpu-monitor --get_config_path
```

2. Either:
   - Copy the default config to your preferred location and use `--config` to specify it
   - Modify the default config directly
   - Use command line options to override any config values (see below)

Example config file:
```yaml
ssh:
  username: "your_username"
  key_path: "~/.ssh/id_rsa"
  jump_host: "jump.example.com"
  timeout: 10

targets:
  individual:
    - "gpu-server1"
    - "gpu-server2"

display:
  refresh_rate: 5
```

## 📖 Configuration

### Basic Structure
```yaml
ssh:
  username: "default_user"      # Default username
  key_path: "~/.ssh/id_rsa"     # Default SSH key
  jump_host: "jump.example.com"
  timeout: 10                   # seconds

targets:
  # Individual machines
  individual:
    - host: "gpu-server1"
      username: "different_user"      # Optional override
      key_path: "~/.ssh/special_key"  # Optional override
    - "gpu-server2"                   # Uses default credentials

  # Pattern-based groups
  patterns:
    - prefix: "gpu"
      start: 1
      end: 30
      format: "{prefix}{number:02}"  # Results in gpu01, gpu02, etc.
      username: "gpu_user"           # Optional override
      key_path: "~/.ssh/gpu_key"     # Optional override

display:
  refresh_rate: 5  # seconds

debug:
  enabled: false
  log_dir: "logs"
  log_file: "gpu_checker.log"
  log_max_size: 1048576  # 1MB
  log_backup_count: 3
```

### Command Line Options
Override any configuration option via command line:
```bash
# Enable debug logging
python main.py --debug.enabled

# Override SSH settings
python main.py --ssh.username=other_user --ssh.key_path=~/.ssh/other_key

# Check specific targets
python main.py --targets gpu01 gpu02 special-server
```

## 🔧 Advanced Usage

### Custom Target Patterns
Generate targets using patterns:
```yaml
patterns:
  - prefix: "compute"
    start: 1
    end: 100
    format: "{prefix}-{number:03d}"  # compute-001, compute-002, etc.
```

### Per-Machine Credentials
Specify different credentials for specific machines:
```yaml
individual:
  - host: "special-gpu"
    username: "admin"
    key_path: "~/.ssh/admin_key"
```

### Debug Logging
Enable detailed logging for troubleshooting:
```yaml
debug:
  enabled: true
  log_dir: "logs"
  log_file: "debug.log"
```

## 🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.

## 📝 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## 🙏 Acknowledgments

### Original Contributors
Originally created as "some awful, brittle code to check GPU status of multiple machines at a given host address through an SSH jumpnode."
166 | 167 | Special thanks to: 168 | - @harrygcoppock and @minut1bc for their PRs on v1 169 | - [gpuobserver](https://github.com/pawni/gpuobserver) for earlier code concepts 170 | - [Stack Overflow answer](https://stackoverflow.com/a/36096801/7565759) for SSH connection handling insights 171 | 172 | ### Libraries 173 | - [Rich](https://github.com/Textualize/rich) for the beautiful terminal interface 174 | - [asyncssh](https://github.com/ronf/asyncssh) for async SSH support 175 | - [PyYAML](https://pyyaml.org/) for configuration management 176 | 177 | ## 🔍 Similar Projects 178 | The following projects are similar in spirit, but only support a single machine: 179 | - [nvidia-smi-tools](https://github.com/example/nvidia-smi-tools) 180 | - [gpu-monitor](https://github.com/example/gpu-monitor) 181 | 182 | ## ⚠️ Known Issues 183 | 184 | - SSH connection might timeout on very slow networks 185 | - Some older NVIDIA drivers might return incompatible XML formats 186 | 187 | ## 📊 Roadmap 188 | 189 | - [ ] Add support for AMD GPUs 190 | - [ ] Implement process name filtering 191 | - [ ] Add web interface 192 | - [ ] Support for custom SSH config files 193 | -------------------------------------------------------------------------------- /clean_and_build.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import shutil 4 | import subprocess 5 | import argparse 6 | from pathlib import Path 7 | 8 | # Try to use the standard library tomllib (Python 3.11+) 9 | # Otherwise fall back to tomli/toml 10 | try: 11 | import tomllib as tomli # for reading 12 | from tomllib import dump as toml_dump # will raise error since tomllib is read-only 13 | except ImportError: 14 | try: 15 | import tomli # for reading 16 | import toml # for writing 17 | toml_dump = toml.dump 18 | except ImportError: 19 | raise ImportError("Please install toml with: pip install toml") 20 | 21 | def increment_version(current_version: str, 
increment_type: str) -> str: 22 | """Increment version number based on semver.""" 23 | major, minor, patch = map(int, current_version.split('.')) 24 | assert increment_type in ['major', 'minor', 'patch'], "Invalid increment type, must be 'major', 'minor', or 'patch'" 25 | if increment_type == 'major': 26 | return f"{major + 1}.0.0" 27 | elif increment_type == 'minor': 28 | return f"{major}.{minor + 1}.0" 29 | elif increment_type == 'patch': # patch 30 | return f"{major}.{minor}.{patch + 1}" 31 | 32 | def update_version(new_version: str = None, increment: str = None): 33 | """Update version in pyproject.toml.""" 34 | try: 35 | # Read current pyproject.toml 36 | with open('pyproject.toml', 'rb') as f: # Open in binary mode for tomli 37 | try: 38 | config = tomli.load(f) 39 | except tomli.TOMLDecodeError as e: 40 | print(f"❌ Error in pyproject.toml: {str(e)}") 41 | print("\nPlease check your pyproject.toml file for duplicate entries.") 42 | print("Specifically look for duplicate declarations of 'tool.setuptools.packages.find'") 43 | raise SystemExit(1) 44 | 45 | # Get current version 46 | current_version = None 47 | if 'project' in config: 48 | current_version = config['project']['version'] 49 | elif 'tool' in config and 'poetry' in config['tool']: 50 | current_version = config['tool']['poetry']['version'] 51 | else: 52 | raise ValueError("Couldn't find version field in pyproject.toml") 53 | 54 | # Determine new version 55 | if increment: 56 | final_version = increment_version(current_version, increment) 57 | elif new_version: 58 | final_version = new_version 59 | else: 60 | raise ValueError("Either --version or --increment must be specified") 61 | 62 | # Update version in config 63 | if 'project' in config: 64 | config['project']['version'] = final_version 65 | elif 'tool' in config and 'poetry' in config['tool']: 66 | config['tool']['poetry']['version'] = final_version 67 | 68 | # Write back to file 69 | with open('pyproject.toml', 'w', encoding='utf-8') as f: 70 | 
toml_dump(config, f) 71 | 72 | print(f"✔ Updated version from {current_version} to {final_version}") 73 | return final_version 74 | except Exception as e: 75 | print(f"❌ Failed to update version: {e}") 76 | raise 77 | 78 | def clean_build_artifacts(): 79 | """Remove build artifacts and cache directories.""" 80 | dirs_to_remove = [ 81 | 'build', 82 | 'dist', 83 | '*.egg-info', 84 | '__pycache__', 85 | '.pytest_cache', 86 | '.coverage', 87 | '.tox', 88 | '.mypy_cache' 89 | ] 90 | 91 | for pattern in dirs_to_remove: 92 | for path in Path('.').rglob(pattern): 93 | if path.is_dir(): 94 | shutil.rmtree(path) 95 | else: 96 | path.unlink() 97 | print("✔ Cleaned build artifacts and caches") 98 | 99 | def build_package(): 100 | """Build the Python package.""" 101 | try: 102 | subprocess.run(['python', '-m', 'build'], check=True) 103 | print("✔ Built package successfully") 104 | except subprocess.CalledProcessError as e: 105 | print(f"❌ Build failed: {e}") 106 | raise 107 | 108 | def publish_to_pypi(): 109 | """Publish the package to PyPI.""" 110 | try: 111 | subprocess.run(['python', '-m', 'twine', 'upload', 'dist/*'], check=True) 112 | print("✔ Published to PyPI successfully") 113 | except subprocess.CalledProcessError as e: 114 | print(f"❌ Upload failed: {e}") 115 | raise 116 | 117 | def main(): 118 | parser = argparse.ArgumentParser(description='Build and publish Python package') 119 | version_group = parser.add_mutually_exclusive_group(required=True) 120 | version_group.add_argument('--version', '-v', help='New version number (e.g., 1.0.1)') 121 | version_group.add_argument('--increment', '-i', choices=['major', 'minor', 'patch'], 122 | help='Increment major, minor, or patch version') 123 | parser.add_argument('--no-publish', action='store_true', help='Skip publishing to PyPI') 124 | args = parser.parse_args() 125 | 126 | new_version = update_version(args.version, args.increment) 127 | print(f"Building version {new_version}") 128 | 129 | clean_build_artifacts() 130 | 
build_package() 131 | 132 | if not args.no_publish: 133 | publish_to_pypi() 134 | 135 | if __name__ == '__main__': 136 | main() 137 | -------------------------------------------------------------------------------- /example_running.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/afspies/gpu-monitor/6f53a1f991e1f9cef1829b49488d5211e3b2d34c/example_running.png -------------------------------------------------------------------------------- /project_structure: -------------------------------------------------------------------------------- 1 | ssh-gpu-monitor/ 2 | ├── pyproject.toml 3 | ├── setup.py 4 | ├── MANIFEST.in 5 | ├── README.md 6 | ├── LICENSE 7 | └── ssh_gpu_monitor/ 8 | ├── __init__.py 9 | ├── __main__.py 10 | ├── main.py 11 | ├── config_loader.py 12 | ├── table_display.py 13 | └── config/ 14 | └── config.yaml -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ "setuptools>=45", "wheel",] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "ssh-gpu-monitor" 7 | version = "1.0.2" 8 | description = "A fast, asynchronous GPU monitoring tool for multiple machines through SSH" 9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | classifiers = [ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Development Status :: 4 - Beta", "Environment :: Console", "Topic :: System :: Monitoring", "Topic :: System :: Systems Administration",] 12 | dependencies = [ "rich>=10.0.0", "asyncssh>=2.13.1", "pyyaml>=6.0.1", "pyOpenSSL==23.1.1", "cryptography==40.0.2",] 13 | [[project.authors]] 14 | name = "Alex Spies" 15 | email = 
"alex@afspies.com" 16 | 17 | [project.license] 18 | file = "LICENSE" 19 | 20 | [project.urls] 21 | Homepage = "https://github.com/afspies/gpu-monitor" 22 | "Bug Tracker" = "https://github.com/afspies/gpu-monitor/issues" 23 | 24 | [project.scripts] 25 | ssh-gpu-monitor = "ssh_gpu_monitor.__main__:main_entry" 26 | 27 | [tool.setuptools.package-data] 28 | ssh_gpu_monitor = [ "config/*.yaml",] 29 | 30 | [tool.setuptools.dynamic.readme] 31 | file = [ "README.md",] 32 | content-type = "text/markdown" 33 | 34 | [tool.setuptools.packages.find] 35 | where = [ ".",] 36 | include = [ "*",] 37 | exclude = [ "clean_and_build.py",] 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | rich>=10.0.0 2 | asyncssh>=2.13.1 3 | pyyaml>=6.0.1 4 | pyOpenSSL==23.1.1 5 | cryptography==40.0.2 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == "__main__": 4 | setup() -------------------------------------------------------------------------------- /ssh_gpu_monitor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/afspies/gpu-monitor/6f53a1f991e1f9cef1829b49488d5211e3b2d34c/ssh_gpu_monitor/__init__.py -------------------------------------------------------------------------------- /ssh_gpu_monitor/__main__.py: -------------------------------------------------------------------------------- 1 | """Main entry point for the SSH GPU Monitor.""" 2 | import asyncio 3 | from .main import main 4 | from .src.config_loader import load_config 5 | 6 | def main_entry(): 7 | """Entry point for the console script.""" 8 | config = load_config() 9 | asyncio.run(main(config)) 10 | 11 | if __name__ == '__main__': 12 | 
main_entry() -------------------------------------------------------------------------------- /ssh_gpu_monitor/config/config.yaml: -------------------------------------------------------------------------------- 1 | # Default configuration for GPU Checker 2 | # SSH Configuration 3 | ssh: 4 | username: "afs219" # Default username 5 | key_path: "~/.ssh/id_rsa" # Default key path 6 | jump_host: "shell4.doc.ic.ac.uk" 7 | timeout: 10 # seconds 8 | 9 | # Target Specification 10 | targets: 11 | # Individual targets with optional username and key_path override 12 | individual: 13 | - host: "spikesaurus" 14 | username: "afs219" # Override for this specific target 15 | key_path: "~/.ssh/spikesaurus" # Override key for this target 16 | - host: "animal" 17 | username: "alex" 18 | key_path: "~/.ssh/animalai" 19 | 20 | # Pattern-based targets with optional username and key_path override 21 | patterns: 22 | - prefix: "gpu" 23 | start: 1 24 | end: 37 25 | format: "{prefix}{number:02}" 26 | username: "afs219" # Override for all targets in this pattern 27 | - prefix: "ray" 28 | start: 1 29 | end: 27 30 | format: "{prefix}{number:02}" 31 | # No username or key_path specified, will use defaults 32 | 33 | # Display Configuration 34 | display: 35 | refresh_rate: 5 # seconds 36 | 37 | # Debug Configuration 38 | debug: 39 | enabled: true 40 | log_dir: "logs" 41 | log_file: "gpu_checker.log" 42 | log_max_size: 5242880 # 5MB in bytes 43 | log_backup_count: 2 44 | -------------------------------------------------------------------------------- /ssh_gpu_monitor/main.py: -------------------------------------------------------------------------------- 1 | """This module is used to check GPU machines information asynchronously.""" 2 | 3 | import asyncio 4 | import os 5 | from typing import NamedTuple, Dict, List, Any 6 | import xml.etree.ElementTree as ET 7 | import re 8 | import sys 9 | import logging 10 | from logging.handlers import RotatingFileHandler 11 | from rich.live import Live 12 | 
from rich.console import Console 13 | import signal 14 | from pathlib import Path 15 | 16 | # Add these imports at the top 17 | import asyncssh 18 | from .src.table_display import GPUTable 19 | from .src.config_loader import generate_targets, Target 20 | 21 | class GPUInfo(NamedTuple): 22 | """A class for representing a GPU machine's information.""" 23 | 24 | model: str 25 | num_procs: int 26 | gpu_util: str 27 | used_mem: str 28 | total_mem: str 29 | 30 | def __str__(self) -> str: 31 | return ( 32 | f'{self.model:26} | ' 33 | f'Free: {self.num_procs == 0!s:5} | ' 34 | f'Num Procs: {self.num_procs:2d} | ' 35 | f'GPU Util: {self.gpu_util:>3} % | ' 36 | f'Memory: {self.used_mem:>5} / {self.total_mem:>5} MiB' 37 | ) 38 | 39 | class AsyncGPUChecker: 40 | """A class for asynchronously checking GPU machines.""" 41 | 42 | def __init__(self, targets: List[Target]) -> None: 43 | self.targets = targets 44 | self.proc_filter = re.compile(r'.*') 45 | self.gpu_table = GPUTable() 46 | self.connections = {} 47 | self.jump_conn = None 48 | self.console = Console() 49 | self.running = True 50 | # Add semaphore for connection pooling 51 | self.connection_semaphore = asyncio.Semaphore(10) # Limit to 10 concurrent connections 52 | 53 | def signal_handler(self, signum, frame): 54 | """Handle Ctrl+C gracefully""" 55 | print("\nShutting down gracefully...") 56 | self.gpu_table.show_goodbye() # Show goodbye message 57 | self.running = False 58 | 59 | async def open_connection(self, target: Target): 60 | """Modified connection method with better error handling and rate limiting""" 61 | self.gpu_table.add_status(f"[{target.host}] Waiting for connection semaphore", "cyan") 62 | logging.debug(f"[{target.host}] Waiting for connection semaphore (current limit: 10)") 63 | async with self.connection_semaphore: # Use semaphore to limit concurrent connections 64 | try: 65 | self.gpu_table.add_status(f"[{target.host}] Attempting port forwarding", "yellow") 66 | logging.debug(f"[{target.host}] 
Attempting port forwarding through jump host") 67 | # Set up port forwarding 68 | try: 69 | listener = await asyncio.wait_for( 70 | self.jump_conn.forward_local_port('', 0, target.host, 22), 71 | timeout=SSH_TIMEOUT/2 72 | ) 73 | tunnel_port = listener.get_port() 74 | self.gpu_table.add_status(f"[{target.host}] Port forwarding established", "green") 75 | logging.debug(f"[{target.host}] Port forwarding established on port {tunnel_port}") 76 | except asyncio.TimeoutError: 77 | error_msg = f"Port forwarding timeout after {SSH_TIMEOUT/2} seconds" 78 | self.gpu_table.add_status(f"[{target.host}] {error_msg}", "red") 79 | logging.error(f"[{target.host}] {error_msg}") 80 | return {target.host: error_msg} 81 | except Exception as e: 82 | error_msg = f"Port forwarding error: {str(e)}" 83 | self.gpu_table.add_status(f"[{target.host}] {error_msg}", "red") 84 | logging.error(f"[{target.host}] {error_msg}", exc_info=True) 85 | return {target.host: error_msg} 86 | 87 | # Expand the key path 88 | key_path = os.path.expanduser(target.key_path) 89 | self.gpu_table.add_status(f"[{target.host}] Attempting SSH connection", "yellow") 90 | logging.debug(f"[{target.host}] Using SSH key: {key_path}") 91 | 92 | if not os.path.exists(key_path): 93 | error_msg = f"SSH key not found: {key_path}" 94 | self.gpu_table.add_status(f"[{target.host}] {error_msg}", "red") 95 | logging.error(f"[{target.host}] {error_msg}") 96 | return {target.host: error_msg} 97 | 98 | # Attempt SSH connection using target-specific username and key 99 | try: 100 | conn = await asyncio.wait_for( 101 | asyncssh.connect( 102 | 'localhost', 103 | port=tunnel_port, 104 | username=target.username, 105 | client_keys=[key_path], # Use target-specific key 106 | known_hosts=None, 107 | keepalive_interval=30, 108 | keepalive_count_max=5 109 | ), 110 | timeout=SSH_TIMEOUT/2 111 | ) 112 | 113 | self.connections[target.host] = conn 114 | self.gpu_table.add_status(f"[{target.host}] Connection established", "green") 115 | 
logging.debug(f"[{target.host}] Successfully established SSH connection") 116 | return {target.host: "Connected"} 117 | 118 | except asyncio.TimeoutError: 119 | error_msg = f"SSH connection timeout after {SSH_TIMEOUT/2} seconds" 120 | self.gpu_table.add_status(f"[{target.host}] {error_msg}", "red") 121 | logging.error(f"[{target.host}] {error_msg}") 122 | return {target.host: error_msg} 123 | except asyncssh.Error as e: 124 | error_msg = f"SSH Error: {str(e)}" 125 | self.gpu_table.add_status(f"[{target.host}] {error_msg}", "red") 126 | logging.error(f"[{target.host}] {error_msg}", exc_info=True) 127 | return {target.host: error_msg} 128 | except Exception as e: 129 | error_msg = f"Unexpected connection error: {str(e)}" 130 | self.gpu_table.add_status(f"[{target.host}] {error_msg}", "red") 131 | logging.error(f"[{target.host}] {error_msg}", exc_info=True) 132 | return {target.host: error_msg} 133 | 134 | except Exception as e: 135 | error_msg = f"Critical connection error: {str(e)}" 136 | self.gpu_table.add_status(f"[{target.host}] {error_msg}", "red") 137 | logging.error(f"[{target.host}] {error_msg}", exc_info=True) 138 | return {target.host: error_msg} 139 | 140 | async def check_single_target(self, target: Target) -> Dict[str, str]: 141 | """Check a single GPU target using an existing connection.""" 142 | logging.debug(f"[{target.host}] Starting GPU status check") 143 | try: 144 | conn = self.connections.get(target.host) 145 | if not conn: 146 | logging.warning(f"[{target.host}] No active connection found") 147 | return {target.host: "No connection"} 148 | 149 | logging.debug(f"[{target.host}] Running nvidia-smi command") 150 | result = await asyncio.wait_for( 151 | conn.run('nvidia-smi -q -x', check=True), 152 | timeout=SSH_TIMEOUT 153 | ) 154 | if result.exit_status != 0: 155 | logging.error(f"[{target.host}] Command failed with status {result.exit_status}: {result.stderr}") 156 | return {target.host: f"Command failed: {result.stderr}"} 157 | 158 | # Ensure 
we're working with a string 159 | output = result.stdout 160 | if isinstance(output, bytes): 161 | output = output.decode('utf-8') 162 | 163 | logging.debug(f"[{target.host}] Successfully received nvidia-smi output, length: {len(output)} bytes") 164 | return {target.host: self.parse_gpu_info(target.host, output)} 165 | except asyncio.TimeoutError: 166 | logging.error(f"[{target.host}] Query timeout after {SSH_TIMEOUT} seconds") 167 | return {target.host: "Timeout"} 168 | except asyncssh.Error as exc: 169 | logging.error(f"[{target.host}] SSH Error: {str(exc)}", exc_info=True) 170 | return {target.host: f"SSH Error: {str(exc)}"} 171 | except Exception as exc: 172 | logging.error(f"[{target.host}] Unexpected error: {str(exc)}", exc_info=True) 173 | return {target.host: f"Unexpected error: {str(exc)}"} 174 | 175 | def parse_gpu_info(self, machine_name: str, xml_output: str) -> str: 176 | """Parse the GPU info from XML output.""" 177 | logging.debug(f"[{machine_name}] Parsing GPU information from XML") 178 | try: 179 | # Ensure xml_output is a string 180 | if isinstance(xml_output, bytes): 181 | xml_output = xml_output.decode('utf-8') 182 | root = ET.fromstring(xml_output) 183 | 184 | gpu_infos = [] 185 | for i, gpu in enumerate(root.findall('gpu')): 186 | try: 187 | logging.debug(f"[{machine_name}] Parsing GPU {i}") 188 | model = gpu.find('product_name').text 189 | processes = gpu.find('processes') 190 | 191 | # More robust process counting 192 | num_procs = 0 193 | if processes is not None: 194 | for process in processes.findall('process_info'): 195 | proc_name = process.find('process_name') 196 | if proc_name is not None and proc_name.text is not None: 197 | if self.proc_filter.search(proc_name.text): 198 | num_procs += 1 199 | 200 | gpu_util = gpu.find('utilization').find('gpu_util').text.removesuffix(' %') 201 | memory_usage = gpu.find('fb_memory_usage') 202 | used_mem = memory_usage.find('used').text.removesuffix(' MiB') 203 | total_mem = 
memory_usage.find('total').text.removesuffix(' MiB') 204 | 205 | gpu_info = GPUInfo(model, num_procs, gpu_util, used_mem, total_mem) 206 | logging.debug(f"[{machine_name}] GPU {i} info: {gpu_info}") 207 | gpu_infos.append(gpu_info) 208 | except AttributeError as e: 209 | logging.error(f"[{machine_name}] Error parsing GPU {i} info: {str(e)}", exc_info=True) 210 | return f"Error parsing GPU info: {str(e)}" 211 | 212 | # Join multiple GPU infos with newlines 213 | return '\n'.join(map(str, gpu_infos)) 214 | except ET.ParseError as e: 215 | logging.error(f"[{machine_name}] XML parse error: {str(e)}", exc_info=True) 216 | return f"XML parse error: {str(e)}" 217 | except Exception as e: 218 | logging.error(f"[{machine_name}] Unexpected error parsing GPU info: {str(e)}", exc_info=True) 219 | return f"Parse error: {str(e)}" 220 | 221 | async def run(self) -> None: 222 | """Run the main loop of the GPU checker asynchronously.""" 223 | # Set up signal handler for graceful shutdown 224 | signal.signal(signal.SIGINT, self.signal_handler) 225 | 226 | print('\n-------------------------------------------------\n') 227 | self.gpu_table.add_status("Starting GPU checker...", "green") 228 | logging.info("Starting GPU checker") 229 | 230 | try: 231 | # Initialize table with "Connecting" status 232 | initial_data = {target.host: "Connecting" for target in self.targets} 233 | self.gpu_table.update_table(initial_data) 234 | self.gpu_table.add_status(f"Initialized with {len(self.targets)} targets", "blue") 235 | logging.info(f"Initialized table with {len(self.targets)} targets") 236 | 237 | # Create live display 238 | with Live(self.gpu_table.layout, console=self.console, refresh_per_second=4) as live: 239 | try: 240 | # Open jump host connection 241 | self.gpu_table.add_status(f"Connecting to jump host {JUMP_SHELL}", "yellow") 242 | logging.info(f"Connecting to jump host {JUMP_SHELL} as {USERNAME}") 243 | self.jump_conn = await asyncio.wait_for( 244 | asyncssh.connect( 245 | JUMP_SHELL, 
246 | username=USERNAME, 247 | client_keys=[SSH_KEY_PATH], 248 | keepalive_interval=30, 249 | keepalive_count_max=5 250 | ), 251 | timeout=SSH_TIMEOUT 252 | ) 253 | self.gpu_table.add_status("Successfully connected to jump host", "green") 254 | logging.info("Successfully connected to jump host") 255 | 256 | # Open connections to GPU servers in batches 257 | total_targets = len(self.targets) 258 | self.gpu_table.add_status(f"Opening connections to {total_targets} GPU servers", "blue") 259 | logging.info(f"Starting to open connections to {total_targets} GPU servers") 260 | 261 | batch_size = 10 262 | successful_connections = 0 263 | failed_connections = 0 264 | 265 | for i in range(0, total_targets, batch_size): 266 | batch = self.targets[i:i+batch_size] 267 | batch_num = i//batch_size + 1 268 | self.gpu_table.add_status(f"Processing batch {batch_num} ({len(batch)} targets)", "yellow") 269 | logging.info(f"Processing batch {batch_num} ({len(batch)} targets)") 270 | 271 | connection_tasks = [self.open_connection(target) for target in batch] 272 | connection_results = await asyncio.gather(*connection_tasks, return_exceptions=True) 273 | 274 | # Process connection results 275 | batch_successes = 0 276 | for target, result in zip(batch, connection_results): 277 | if isinstance(result, dict): 278 | self.gpu_table.update_table(result) 279 | batch_successes += 1 280 | successful_connections += 1 281 | else: 282 | failed_connections += 1 283 | error_msg = str(result) if result is not None else "Unknown error" 284 | self.gpu_table.update_table({target.host: f"Error: {error_msg}"}) 285 | 286 | progress = (i + len(batch)) / total_targets * 100 287 | self.gpu_table.add_status( 288 | f"Progress: {progress:.1f}% ({successful_connections} ok, {failed_connections} failed)", 289 | "blue" 290 | ) 291 | logging.info(f"Batch {batch_num} complete: {batch_successes}/{len(batch)} successful connections") 292 | live.update(self.gpu_table.layout) 293 | await asyncio.sleep(1) # Brief pause 
between batches 294 | 295 | self.gpu_table.add_status( 296 | f"Connection phase complete: {successful_connections} ok, {failed_connections} failed", 297 | "green" if failed_connections == 0 else "yellow" 298 | ) 299 | logging.info(f"Connection phase complete - Success: {successful_connections}, Failed: {failed_connections}") 300 | 301 | except asyncio.TimeoutError: 302 | error_msg = f"Timeout connecting to jump host after {SSH_TIMEOUT} seconds" 303 | self.gpu_table.add_status(error_msg, "red") 304 | logging.error(error_msg) 305 | return 306 | 307 | except Exception as e: 308 | error_msg = f"Error connecting to jump host: {str(e)}" 309 | self.gpu_table.add_status(error_msg, "red") 310 | logging.error(error_msg, exc_info=True) 311 | return 312 | 313 | # Start the query loop 314 | while self.running: 315 | try: 316 | query_tasks = [self.check_single_target(target) for target in self.targets] 317 | query_results = await asyncio.gather(*query_tasks, return_exceptions=True) 318 | 319 | successful_queries = 0 320 | failed_queries = 0 321 | 322 | for target, result in zip(self.targets, query_results): 323 | if isinstance(result, dict): 324 | self.gpu_table.update_table(result) 325 | successful_queries += 1 326 | else: 327 | failed_queries += 1 328 | error_msg = str(result) if result is not None else "Unknown error" 329 | self.gpu_table.update_table({target.host: f"Query Error: {error_msg}"}) 330 | 331 | self.gpu_table.add_status( 332 | f"Query complete: {successful_queries} ok, {failed_queries} failed", 333 | "green" if failed_queries == 0 else "yellow" 334 | ) 335 | logging.info(f"Query cycle complete - Success: {successful_queries}, Failed: {failed_queries}") 336 | 337 | live.update(self.gpu_table.layout) 338 | await asyncio.sleep(REFRESH_RATE) 339 | 340 | except Exception as e: 341 | error_msg = f"Error in query loop: {str(e)}" 342 | self.gpu_table.add_status(error_msg, "red") 343 | logging.error(error_msg, exc_info=True) 344 | await asyncio.sleep(REFRESH_RATE) # Wait 
before retrying 345 | 346 | except Exception as e: 347 | error_msg = f"Critical error in main loop: {str(e)}" 348 | self.gpu_table.add_status(error_msg, "red") 349 | logging.error(error_msg, exc_info=True) 350 | 351 | finally: 352 | # Cleanup 353 | logging.info("Closing connections") 354 | for conn in self.connections.values(): 355 | try: 356 | conn.close() 357 | except Exception:  # avoid bare except: don't swallow KeyboardInterrupt/SystemExit 358 | pass 359 | if self.jump_conn: 360 | try: 361 | self.jump_conn.close() 362 | except Exception: 363 | pass 364 | logging.info("GPU checker stopped") 365 | 366 | def _process_query_results(self, results): 367 | """Helper method to process query results and update the table""" 368 | data = {} 369 | for result in results: 370 | if isinstance(result, dict): 371 | data.update(result) 372 | else: 373 | logging.error(f"Query error: {str(result)}") 374 | self.gpu_table.update_table(data) 375 | 376 | def setup_logging(config: Dict) -> None: 377 | """Set up logging configuration based on debug config""" 378 | # Suppress all loggers initially 379 | logging.getLogger().setLevel(logging.CRITICAL) 380 | asyncssh_logger = logging.getLogger('asyncssh') 381 | asyncssh_logger.setLevel(logging.CRITICAL) 382 | 383 | if not config['debug']['enabled']: 384 | return 385 | 386 | try: 387 | # Log directory and file paths are now absolute from config_loader 388 | log_file = config['debug']['log_file'] 389 | log_dir = os.path.dirname(log_file) 390 | 391 | # Ensure log directory exists 392 | os.makedirs(log_dir, exist_ok=True) 393 | 394 | # Create a more detailed log format 395 | log_formatter = logging.Formatter( 396 | '%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s' 397 | ) 398 | 399 | # Set up file handler with rotation and immediate flush 400 | log_handler = RotatingFileHandler( 401 | log_file, 402 | maxBytes=config['debug']['log_max_size'], 403 | backupCount=config['debug']['log_backup_count'], 404 | mode='a' # Append mode 405 | ) 406 | log_handler.setFormatter(log_formatter) 407 |
log_handler.setLevel(logging.DEBUG) 408 | 409 | # Configure root logger 410 | root_logger = logging.getLogger() 411 | root_logger.setLevel(logging.DEBUG) # Set to DEBUG to capture all levels 412 | 413 | # Remove any existing handlers to prevent duplicate logging 414 | for handler in root_logger.handlers[:]: 415 | root_logger.removeHandler(handler) 416 | 417 | root_logger.addHandler(log_handler) 418 | 419 | # Add console handler for immediate feedback 420 | console_handler = logging.StreamHandler() 421 | console_handler.setFormatter(log_formatter) 422 | console_handler.setLevel(logging.DEBUG) 423 | root_logger.addHandler(console_handler) 424 | 425 | # Log initial debug information 426 | logging.debug(f"Log file initialized at: {log_file}") 427 | logging.debug(f"Debug configuration: {config['debug']}") 428 | 429 | # Force flush 430 | log_handler.flush() 431 | 432 | except Exception as e: 433 | print(f"Error setting up logging: {str(e)}") 434 | # Set up console logging as fallback 435 | logging.basicConfig( 436 | level=logging.DEBUG if config['debug']['enabled'] else logging.CRITICAL, 437 | format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s' 438 | ) 439 | 440 | async def main(config: Dict[str, Any]) -> None: 441 | """Main function for running GPU checker.""" 442 | # Remove config loading since it's now passed in 443 | setup_logging(config) 444 | 445 | # Expand SSH key path 446 | config['ssh']['key_path'] = os.path.expanduser(config['ssh']['key_path']) 447 | 448 | if config['debug']['enabled']: 449 | logging.info("GPU checker started in debug mode") 450 | 451 | if not os.path.exists(config['ssh']['key_path']): 452 | if config['debug']['enabled']: 453 | logging.error(f'SSH key not found at {config["ssh"]["key_path"]}') 454 | print('SSH key not found. 
Please check the provided path.') 455 | return 456 | 457 | # Generate target list 458 | targets = generate_targets(config) 459 | 460 | if config['debug']['enabled']: 461 | logging.info(f"Generated targets: {targets}") 462 | 463 | # Update global constants 464 | global USERNAME, SSH_KEY_PATH, JUMP_SHELL, SSH_TIMEOUT, REFRESH_RATE 465 | USERNAME = config['ssh']['username'] 466 | SSH_KEY_PATH = config['ssh']['key_path'] 467 | JUMP_SHELL = config['ssh']['jump_host'] 468 | SSH_TIMEOUT = config['ssh']['timeout'] 469 | REFRESH_RATE = config['display']['refresh_rate'] 470 | 471 | gpu_checker = AsyncGPUChecker(targets) 472 | await gpu_checker.run() 473 | 474 | if __name__ == '__main__': 475 | asyncio.run(main(load_config()))  # main() now takes a config dict; load it via src.config_loader when run directly 476 | -------------------------------------------------------------------------------- /ssh_gpu_monitor/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/afspies/gpu-monitor/6f53a1f991e1f9cef1829b49488d5211e3b2d34c/ssh_gpu_monitor/src/__init__.py -------------------------------------------------------------------------------- /ssh_gpu_monitor/src/config_loader.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | from pathlib import Path 4 | from typing import Any, Dict, List, NamedTuple, Optional 5 | 6 | def get_default_config_path() -> Path: 7 | """Return the default config path for the package.""" 8 | return Path(__file__).parent.parent / 'config' / 'config.yaml' 9 | 10 | def get_project_root() -> Path: 11 | """Return the root directory of the project.""" 12 | return Path(__file__).parent.parent.parent 13 | 14 | class Target(NamedTuple): 15 | """Represents a target with its associated username and key path""" 16 | host: str 17 | username: str 18 | key_path: str 19 | 20 | def generate_targets(config: Dict[str, Any]) -> List[Target]: 21 | """Generate list of targets from config specification.""" 22 | targets = []
23 | default_username = config['ssh']['username'] 24 | default_key_path = config['ssh']['key_path'] 25 | 26 | # Add individual targets 27 | if 'individual' in config['targets']: 28 | for target in config['targets']['individual']: 29 | if isinstance(target, str): 30 | # Simple string target uses defaults 31 | targets.append(Target(target, default_username, default_key_path)) 32 | else: 33 | # Dictionary target may have overrides 34 | targets.append(Target( 35 | target['host'], 36 | target.get('username', default_username), 37 | target.get('key_path', default_key_path) 38 | )) 39 | 40 | # Add pattern-based targets 41 | if 'patterns' in config['targets']: 42 | for pattern in config['targets']['patterns']: 43 | pattern_username = pattern.get('username', default_username) 44 | pattern_key_path = pattern.get('key_path', default_key_path) 45 | targets.extend([ 46 | Target( 47 | pattern['format'].format( 48 | prefix=pattern['prefix'], 49 | number=x 50 | ), 51 | pattern_username, 52 | pattern_key_path 53 | ) 54 | for x in range(pattern['start'], pattern['end'] + 1) 55 | ]) 56 | 57 | # Remove duplicates while preserving order 58 | seen = set() 59 | unique_targets = [] 60 | for target in targets: 61 | if target.host not in seen: 62 | seen.add(target.host) 63 | unique_targets.append(target) 64 | 65 | return unique_targets 66 | 67 | def parse_args() -> argparse.Namespace: 68 | """Parse command line arguments.""" 69 | parser = argparse.ArgumentParser(description='SSH GPU Monitor') 70 | 71 | # Core arguments 72 | parser.add_argument('--get_config_path', action='store_true', 73 | help='Print the default config path and exit') 74 | parser.add_argument('--config', '-c', type=str, 75 | help='Path to custom config file') 76 | 77 | # SSH configuration 78 | parser.add_argument('--ssh.username', type=str, help='Default SSH username') 79 | parser.add_argument('--ssh.key_path', type=str, help='SSH key path') 80 | parser.add_argument('--ssh.jump_host', type=str, help='Jump host') 81 | 
parser.add_argument('--ssh.timeout', type=int, help='SSH timeout in seconds') 82 | 83 | # Target configuration 84 | parser.add_argument('--targets', type=str, nargs='+', 85 | help='Override all targets with specified list (will use default username)') 86 | 87 | # Display configuration 88 | parser.add_argument('--display.refresh_rate', type=int, help='Refresh rate in seconds') 89 | 90 | # Debug configuration 91 | parser.add_argument('--debug.enabled', action='store_const', const=True, default=None, help='Enable debug mode')  # default=None so an absent flag doesn't override the YAML value in load_config 92 | parser.add_argument('--debug.log_dir', type=str, help='Log directory') 93 | parser.add_argument('--debug.log_file', type=str, help='Log file name') 94 | 95 | return parser.parse_args() 96 | 97 | def load_config(config_path: Optional[str] = None) -> Dict[str, Any]: 98 | """Load configuration from YAML file and override with command line arguments.""" 99 | args = parse_args() 100 | 101 | # Handle --get_config_path 102 | if args.get_config_path: 103 | print(get_default_config_path()) 104 | import sys 105 | sys.exit(0) 106 | 107 | # Use provided config path, CLI config path, or default 108 | if config_path is None: 109 | config_path = args.config if args.config else get_default_config_path() 110 | 111 | # Load YAML config 112 | with open(config_path, 'r') as f: 113 | config = yaml.safe_load(f) 114 | 115 | # Update config with command line arguments 116 | for arg, value in vars(args).items(): 117 | if value is not None and arg not in ['get_config_path', 'config']: 118 | if arg == 'targets': 119 | # Special handling for targets override 120 | config['targets'] = { 121 | 'individual': [{'host': t} for t in value], 122 | 'patterns': [] 123 | } 124 | else: 125 | # Split the argument name into sections 126 | sections = arg.split('.') 127 | 128 | # Navigate through the config dictionary 129 | current = config 130 | for section in sections[:-1]: 131 | current = current[section] 132 | 133 | # Set the value 134 | current[sections[-1]] = value 135 | 136 | # Make log paths absolute
using project root 137 | if config['debug']['enabled']: 138 | project_root = get_project_root() 139 | # Ensure log_dir is an absolute path 140 | log_dir = Path(config['debug']['log_dir']) 141 | if not log_dir.is_absolute(): 142 | log_dir = project_root / log_dir 143 | 144 | # Create log directory 145 | log_dir.mkdir(parents=True, exist_ok=True) 146 | 147 | # Update config with absolute paths 148 | config['debug']['log_dir'] = str(log_dir) 149 | config['debug']['log_file'] = str(log_dir / config['debug']['log_file']) 150 | 151 | # Print debug info about paths 152 | print(f"Log directory: {config['debug']['log_dir']}") 153 | print(f"Log file: {config['debug']['log_file']}") 154 | 155 | return config 156 | -------------------------------------------------------------------------------- /ssh_gpu_monitor/src/table_display.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from rich.console import Console 3 | from rich.table import Table 4 | from rich.live import Live 5 | from rich.align import Align 6 | from rich.panel import Panel 7 | from rich.layout import Layout 8 | from rich.text import Text 9 | from rich.style import Style 10 | from collections import deque 11 | from datetime import datetime 12 | 13 | class GPUTable: 14 | def __init__(self): 15 | self.console = Console() 16 | # Initialize dictionary to track max widths for each column 17 | self.max_widths = { 18 | "Hostname": 20, # Start with our minimum requirements 19 | "Status/Model": 30, 20 | "Free": 5, 21 | "Procs": 5, 22 | "GPU %": 6, 23 | "Memory": 15 24 | } 25 | # Add message queue for status updates 26 | self.messages = deque(maxlen=5) # Keep last 5 messages 27 | 28 | # Create layout 29 | self.layout = Layout() 30 | self.layout.split( 31 | Layout(name="status", size=7), 32 | Layout(name="table") 33 | ) 34 | 35 | # Add initial status message 36 | self.add_status("Initializing GPU checker...", "blue") 37 | 38 | # Create initial table 39 | 
self._create_table() 40 | self.data = {} 41 | self.layout["table"].update(self.table) 42 | 43 | def add_status(self, message: str, style: str = "white") -> None: 44 | """Add a status message to the display""" 45 | timestamp = datetime.now().strftime("%H:%M:%S") 46 | self.messages.append((timestamp, message, style)) 47 | self._update_layout() 48 | 49 | def _create_status_panel(self) -> Panel: 50 | """Create the status panel with messages""" 51 | if not self.messages: 52 | return Panel( 53 | Align.left("No status updates"), 54 | title="Status", 55 | border_style="blue", 56 | height=7 57 | ) 58 | 59 | messages = [] 60 | for timestamp, msg, style in self.messages: 61 | messages.append(Text.from_markup(f"[{timestamp}] [{style}]{msg}[/]")) 62 | 63 | return Panel( 64 | Align.left(Text("\n").join(messages)),  # Text.join keeps per-message styling (str() would strip it) 65 | title="Status", 66 | border_style="blue", 67 | height=7 68 | ) 69 | 70 | def _update_layout(self) -> None: 71 | """Update the layout with current status and table""" 72 | self.layout["status"].update(self._create_status_panel()) 73 | self.layout["table"].update(getattr(self, "table", ""))  # guard: add_status() runs in __init__ before the table exists 74 | 75 | def _create_table(self): 76 | """Create the table with proper columns""" 77 | self.raw_table = Table( 78 | title="GPU Status", 79 | show_header=True, 80 | header_style="bold magenta", 81 | border_style="blue", 82 | pad_edge=True, 83 | padding=(0, 1) 84 | ) 85 | 86 | # Use tracked max widths for columns 87 | self.raw_table.add_column("Hostname", style="cyan", no_wrap=True, min_width=self.max_widths["Hostname"]) 88 | self.raw_table.add_column("Status/Model", style="green", min_width=self.max_widths["Status/Model"]) 89 | self.raw_table.add_column("Free", style="yellow", justify="center", min_width=self.max_widths["Free"]) 90 | self.raw_table.add_column("Procs", style="red", justify="center", min_width=self.max_widths["Procs"]) 91 | self.raw_table.add_column("GPU %", style="blue", justify="right", min_width=self.max_widths["GPU %"]) 92 | self.raw_table.add_column("Memory",
style="green", justify="right", min_width=self.max_widths["Memory"]) 93 | 94 | self.table = Align.center(self.raw_table) 95 | self._update_layout() 96 | 97 | def update_max_widths(self, hostname: str, values: list[str]) -> None: 98 | """Update the maximum widths based on new values""" 99 | columns = ["Hostname", "Status/Model", "Free", "Procs", "GPU %", "Memory"] 100 | values = [hostname] + values # Add hostname to the values list 101 | 102 | for col, val in zip(columns, values): 103 | if val != "—": # Don't update for placeholder values 104 | self.max_widths[col] = max(self.max_widths[col], len(str(val))) 105 | 106 | def update_table(self, new_data: dict[str, str]) -> None: 107 | """Update the table with new data""" 108 | logging.debug(f"Updating table with data: {new_data}") 109 | self.data.update(new_data) 110 | 111 | # First pass: update maximum widths 112 | for hostname, status in self.data.items(): 113 | if status in ["Connecting", "Connected", "No connection"] or \ 114 | status.startswith(("Connection", "Error", "Timeout", "SSH Error", "Unexpected error", "Parse error", "XML parse error")): 115 | self.update_max_widths(hostname, [status, "—", "—", "—", "—"]) 116 | else: 117 | try: 118 | # Split multiple GPU entries 119 | gpu_entries = status.split('\n') 120 | for gpu_entry in gpu_entries: 121 | parts = [p.strip() for p in gpu_entry.split("|")] 122 | if len(parts) == 5: 123 | model = parts[0] 124 | is_free = parts[1].split(":")[1].strip() 125 | num_procs = parts[2].split(":")[1].strip() 126 | gpu_util = parts[3].split(":")[1].strip() 127 | memory = parts[4].split(":")[1].strip() 128 | self.update_max_widths(hostname, [model, is_free, num_procs, gpu_util, memory]) 129 | except (IndexError, ValueError) as e: 130 | logging.warning(f"Failed to update widths for {hostname}: {e}") 131 | 132 | # Recreate the table with updated widths 133 | self._create_table() 134 | 135 | # Second pass: add rows 136 | for hostname, status in sorted(self.data.items()): 137 | try: 138 
| if status in ["Connecting", "Connected", "No connection"] or \ 139 | status.startswith(("Connection", "Error", "Timeout", "SSH Error", "Unexpected error", "Parse error", "XML parse error")): 140 | self.raw_table.add_row( 141 | hostname, 142 | status, 143 | "—", 144 | "—", 145 | "—", 146 | "—" 147 | ) 148 | else: 149 | try: 150 | # Split multiple GPU entries 151 | gpu_entries = status.split('\n') 152 | for i, gpu_entry in enumerate(gpu_entries): 153 | parts = [p.strip() for p in gpu_entry.split("|")] 154 | if len(parts) != 5: 155 | raise ValueError(f"Invalid format: expected 5 parts, got {len(parts)}") 156 | 157 | model = parts[0] 158 | is_free = parts[1].split(":")[1].strip() 159 | num_procs = parts[2].split(":")[1].strip() 160 | gpu_util = parts[3].split(":")[1].strip() 161 | memory = parts[4].split(":")[1].strip() 162 | 163 | # Only show hostname on first GPU row 164 | display_hostname = hostname if i == 0 else "" 165 | 166 | self.raw_table.add_row( 167 | display_hostname, 168 | model, 169 | is_free, 170 | num_procs, 171 | gpu_util, 172 | memory 173 | ) 174 | except (IndexError, ValueError) as e: 175 | logging.error(f"Failed to parse GPU info for {hostname}: {str(e)}") 176 | self.raw_table.add_row( 177 | hostname, 178 | f"Parse error: {str(e)}", 179 | "—", 180 | "—", 181 | "—", 182 | "—" 183 | ) 184 | 185 | except Exception as e: 186 | logging.error(f"Error adding row for {hostname}: {str(e)}") 187 | self.raw_table.add_row( 188 | hostname, 189 | f"Error: {str(e)}", 190 | "—", 191 | "—", 192 | "—", 193 | "—" 194 | ) 195 | 196 | logging.debug("Table updated successfully") 197 | 198 | def show_goodbye(self): 199 | """Show goodbye message instead of table""" 200 | goodbye_msg = "Goodbye! 
👋" 201 | self.table = Align.center(Panel.fit(goodbye_msg)) 202 | self._update_layout() 203 | 204 | def get_live_table(self) -> Live: 205 | """Get a Live display context manager for the table""" 206 | return Live( 207 | self.layout, # Use the full layout instead of just the table 208 | refresh_per_second=4, 209 | console=self.console, 210 | vertical_overflow="visible" 211 | ) 212 | -------------------------------------------------------------------------------- /todos.md: -------------------------------------------------------------------------------- 1 | # SSH GPU Checker Revamp TODOs 2 | 3 | ## 1. Implement Asynchronous SSH Checks 4 | - [X] Refactor the main SSH checking function to be asynchronous 5 | - [X] Implement an async function to check a single target 6 | - [X] Create an async main function to run all checks concurrently 7 | - [X] Implement a mechanism to update the results table as checks complete 8 | 9 | ## 2. Enhance Display with Rich Library 10 | - [X] Design a new table layout with columns for hostname, status, GPU info, etc 11 | - [X] Implement color coding for different statuses (e.g., green for available, red for unavailable) 12 | - [X] Use font emphases (bold, italic) to highlight important information 13 | - [X] Implement live updating of the table as results come in 14 | 15 | ## 3. Restructure Project 16 | - [X] Create a new project structure: 17 | ``` 18 | ssh_gpu_checker/ 19 | ├── config/ 20 | │ └── config.yaml 21 | ├── src/ 22 | │ └── ssh_checker.py 23 | │ └── table_display.py 24 | ├── main.py 25 | └── requirements.txt 26 | ``` 27 | - [X] Move configuration to `config/config.yaml` 28 | - [X] Implement config loading function in `main.py` 29 | - [ ] Create `src/ssh_checker.py` for core SSH logic 30 | - [X] Create `src/table_display.py` for Rich table implementation 31 | - [ ] Update `main.py` to use the new structure and modules 32 | 33 | ## 4. 
Additional Improvements 34 | - [X] Add error handling and logging 35 | - [X] Implement command-line arguments for flexibility 36 | - [ ] Write unit tests for core functions 37 | - [X] Add documentation and comments throughout the code 38 | - [ ] Create a README.md with usage instructions and project overview 39 | --------------------------------------------------------------------------------
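For the "write unit tests" item in `todos.md`, a natural first test pins down the pattern-expansion and de-duplication rules of `generate_targets` in `src/config_loader.py`. The sketch below inlines a minimal copy of that logic so it runs stand-alone; the function name `expand_targets` and the hostnames/usernames are illustrative, and a real test would import `generate_targets` from the package instead:

```python
from typing import Any, Dict, List, NamedTuple

class Target(NamedTuple):
    host: str
    username: str
    key_path: str

def expand_targets(config: Dict[str, Any]) -> List[Target]:
    """Mirror of generate_targets: individual + pattern targets, dedup by host (first wins)."""
    default_user = config['ssh']['username']
    default_key = config['ssh']['key_path']
    targets: List[Target] = []
    # Individual targets: plain strings use defaults, dicts may override
    for t in config['targets'].get('individual', []):
        if isinstance(t, str):
            targets.append(Target(t, default_user, default_key))
        else:
            targets.append(Target(t['host'],
                                  t.get('username', default_user),
                                  t.get('key_path', default_key)))
    # Pattern targets: expand format string over the numeric range
    for p in config['targets'].get('patterns', []):
        for n in range(p['start'], p['end'] + 1):
            targets.append(Target(p['format'].format(prefix=p['prefix'], number=n),
                                  p.get('username', default_user),
                                  p.get('key_path', default_key)))
    # De-duplicate while preserving order
    seen: set = set()
    unique: List[Target] = []
    for t in targets:
        if t.host not in seen:
            seen.add(t.host)
            unique.append(t)
    return unique

def test_pattern_expansion_and_dedup():
    config = {
        'ssh': {'username': 'alice', 'key_path': '~/.ssh/id_ed25519'},
        'targets': {
            'individual': ['gpu01', {'host': 'gpu02', 'username': 'bob'}],
            'patterns': [{'format': '{prefix}{number:02d}', 'prefix': 'gpu',
                          'start': 1, 'end': 3}],
        },
    }
    result = expand_targets(config)
    # individual entries win over pattern duplicates; only gpu03 is new
    assert [t.host for t in result] == ['gpu01', 'gpu02', 'gpu03']
    assert result[1].username == 'bob'    # per-target override kept
    assert result[2].username == 'alice'  # pattern target falls back to default

test_pattern_expansion_and_dedup()
```

Note that an individual target listed before a pattern keeps its overrides even when the pattern would regenerate the same host, which is the behavior the de-duplication pass guarantees.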