├── MANIFEST.in ├── docs ├── .gitignore ├── _static │ ├── portal_gun.png │ └── reference_policy.json ├── Makefile ├── index.rst ├── install.rst ├── config.rst ├── overview.rst ├── conf.py ├── commands.rst └── portal_spec.rst ├── portal_gun ├── providers │ ├── __init__.py │ ├── aws │ │ ├── __init__.py │ │ ├── pretty_print.py │ │ ├── helpers.py │ │ └── aws_client.py │ ├── gcp │ │ ├── __init__.py │ │ ├── pretty_print.py │ │ ├── helpers.py │ │ └── gcp_client.py │ └── exceptions.py ├── configuration │ ├── __init__.py │ ├── schemas │ │ ├── __init__.py │ │ ├── provision.py │ │ ├── config.py │ │ ├── compute_aws.py │ │ ├── portal.py │ │ └── compute_gcp.py │ ├── constants.py │ └── draft.py ├── context_managers │ ├── __init__.py │ ├── no_print.py │ ├── print_scope.py │ ├── print_indent.py │ └── step.py ├── fabric │ ├── __init__.py │ └── operations.py ├── __init__.py ├── __main__.py ├── commands │ ├── exceptions.py │ ├── __init__.py │ ├── handlers │ │ ├── __init__.py │ │ ├── factory.py │ │ ├── base_handler.py │ │ ├── gcp_handler.py │ │ └── aws_handler.py │ ├── base_command.py │ ├── close_portal.py │ ├── open_portal.py │ ├── generate_portal_spec.py │ ├── show_portal_info.py │ ├── ssh.py │ ├── open_channel.py │ ├── helpers.py │ └── volume.py ├── main.py └── one_of_schema.py ├── setup.cfg ├── .gitignore ├── install ├── portal ├── portal_gun.py ├── requirements.txt ├── LICENSE ├── setup.py └── README.rst /MANIFEST.in: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _build 2 | -------------------------------------------------------------------------------- /portal_gun/providers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/portal_gun/configuration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /portal_gun/providers/aws/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /portal_gun/providers/gcp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /portal_gun/context_managers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_file = LICENSE -------------------------------------------------------------------------------- /portal_gun/fabric/__init__.py: -------------------------------------------------------------------------------- 1 | from portal_gun.fabric.operations import * 2 | -------------------------------------------------------------------------------- /portal_gun/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.0' 2 | __author__ = 'Vadim Fedorov' 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | env/ 3 | dist/ 4 | build/ 5 | *.egg-info/ 6 | *.json 7 | *.pyc 8 | -------------------------------------------------------------------------------- /install: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | virtualenv env 3 | source env/bin/activate 4 | pip install -r requirements.txt 
""" Executed when the module is called as a script (``python -m portal_gun``). """

from portal_gun.main import main

main()
class no_print:
    """
    Context manager that silences everything written to stdout.

    On entry ``sys.stdout`` is redirected to the null device. On exit the
    original stream is restored and the null-device handle is closed, so
    repeated use does not leak file descriptors.
    """

    def __init__(self):
        self._original_stdout = None
        self._devnull = None

    def __enter__(self):
        self._original_stdout = sys.stdout
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore stdout first so that it is usable even if closing fails
        sys.stdout = self._original_stdout

        if self._devnull is not None:
            self._devnull.close()
            self._devnull = None
-------------------------------------------------------------------------------- 1 | from sys import argv 2 | from os import path 3 | 4 | default_config_filename = 'config.json' 5 | 6 | # Paths where to look up for config file. Order reflects priority (from higher to lower) 7 | config_paths = [ 8 | path.join(path.dirname(path.abspath(argv[0])), default_config_filename), 9 | path.expanduser('~/.portal-gun/{}'.format(default_config_filename)) 10 | ] 11 | 12 | cloud_provider_env = 'PG_CLOUD_PROVIDER' 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | asn1crypto==0.24.0 2 | bcrypt==3.1.4 3 | boto3==1.5.18 4 | botocore==1.8.32 5 | cachetools==2.1.0 6 | cffi==1.11.4 7 | cryptography==2.1.4 8 | docutils==0.14 9 | enum34==1.1.6 10 | fabric==2.4.0 11 | google-api-python-client==1.6.7 12 | google-auth==1.4.1 13 | google-auth-httplib2==0.0.3 14 | httplib2==0.11.3 15 | idna==2.6 16 | invoke==1.2.0 17 | ipaddress==1.0.19 18 | jmespath==0.9.3 19 | marshmallow==3.0.0b8 20 | oauth2client==4.1.2 21 | paramiko==2.4.0 22 | pyasn1==0.4.2 23 | pyasn1-modules==0.2.1 24 | pycparser==2.18 25 | PyNaCl==1.2.1 26 | python-dateutil==2.6.1 27 | rsa==3.4.2 28 | s3transfer==0.1.12 29 | six==1.11.0 30 | uritemplate==3.0.0 31 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = PortalGun 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
@contextmanager
def print_scope(prologue, epilogue=None, indent=None):
    """
    Wrap the output of the enclosed ``with`` body between two lines.

    The prologue is printed first, everything printed by the body is
    indented via print_indent, and the epilogue (when given) is printed
    after the body has finished without raising.

    :param prologue: Line printed before the body runs.
    :param epilogue: Optional line printed after the body completes.
    :param indent: Indent size handed to print_indent; None selects its default.
    """

    print(prologue)

    indented = print_indent(indent)
    with indented:
        yield

    # Not reached when the body raises: the exception propagates from the yield
    if epilogue is None:
        return

    print(epilogue)
def generate_portal_spec(provider):
    """
    Build a draft portal specification for the given cloud provider.

    The work is delegated to the handler class registered for that provider.

    :param provider: Name of cloud provider
    :rtype: dict
    """
    return get_handler_class(provider).generate_portal_spec()
class BaseCommand(object):
    """
    Base class for all specific Commands.

    Concrete commands are discovered through ``BaseCommand.__subclasses__()``,
    so only direct subclasses that have already been imported are visible to
    fill_subparsers() and create_command().
    """

    def __init__(self, args):
        """
        :param args: Parsed command line arguments, kept for use by run().
        """
        self._args = args

    def run(self):
        """ Execute the command. Must be overridden by every subclass. """
        raise NotImplementedError('Every subclass of BaseCommand should implement run() method.')

    @staticmethod
    def cmd():
        """ Return the command name as typed on the command line. Must be overridden. """
        raise NotImplementedError('Every subclass of BaseCommand should implement static cmd() method.')

    @classmethod
    def add_subparser(cls, subparsers):
        """
        Register the argument sub-parser of the command. Must be overridden.

        :param subparsers: Object the command adds its own parser to.
        """
        raise NotImplementedError('Every subclass of BaseCommand should implement add_subparser() class method.')

    @staticmethod
    def fill_subparsers(subparsers):
        """ Add subparser for every known subclass of BaseCommand. """

        for cls in BaseCommand.__subclasses__():
            cls.add_subparser(subparsers)

    @staticmethod
    def create_command(cmd, args):
        """
        Factory method that creates instances of Commands.

        :param cmd: Command name to look up.
        :param args: Arguments passed to the constructor of the command.
        :return: Instance of the first subclass whose cmd() equals ``cmd``,
            or None when there is no such subclass.
        """

        for cls in BaseCommand.__subclasses__():
            if cls.cmd() == cmd:
                return cls(args)

        return None
Portal Gun documentation master file, created by 2 | sphinx-quickstart on Mon Apr 16 21:51:03 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Portal Gun 7 | ========== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | .. 14 | 15 | *Deep Learning on Amazon EC2 Spot Instances without the agonizing pain.* 16 | 17 | Release v\ |version|. 18 | 19 | ---- 20 | 21 | **Portal Gun** is a command line tool that automates repetitive tasks associated with the management of Spot Instances on Amazon EC2 service. 22 | 23 | Primarily it is intended to simplify usage of AWS Spot Instances for Deep Learning. This focus will further shape the future development. 24 | 25 | User Guide 26 | ========== 27 | 28 | .. toctree:: 29 | :maxdepth: 2 30 | 31 | overview 32 | install 33 | config 34 | portal_spec 35 | commands 36 | 37 | 38 | 39 | .. Indices and tables 40 | .. ================== 41 | 42 | .. * :ref:`genindex` 43 | .. * :ref:`modindex` 44 | .. * :ref:`search` 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Vadim Fedorov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/install.rst: -------------------------------------------------------------------------------- 1 | .. _install: 2 | 3 | ============ 4 | Installation 5 | ============ 6 | 7 | **Portal Gun** has the following external dependencies: 8 | 9 | - `boto3 `_ - to make requests to AWS; 10 | - `Fabric `_ - to execute commands over ssh; 11 | - `marshmallow `_ - for serialization. 12 | 13 | Note that Python 3 is not supported yet, because Fabric is Python 2 only. Migration to Python 3 should be made after the first stable release of `Fabric 2 `_. 14 | 15 | Install or upgrade from the PyPI 16 | ================================ 17 | 18 | It is **strongly recommended** to install Portal Gun in **a virtual Python environment**. For details about virtual environments see `virtualenv documentation `_. 
def print_volume(volume):
    """
    Print a human-readable summary of a GCP volume description.

    :param volume: Dict describing the volume. Keys 'id', 'name', 'sizeGb'
        and 'zone' are required; 'users' (list of strings, presumably
        resource URLs of the attached instances) and 'labels' (dict) are
        optional.
    :raises SystemExit: If one of the required keys is missing.
    """
    # Local import: sys.exit replaces the `exit` helper, which is injected by
    # the site module and is not guaranteed to exist in every interpreter.
    import sys

    fill_width = 20

    # Treat an explicit None the same way as a missing key
    users = volume.get('users') or []
    labels = volume.get('labels') or {}
    tags = ['{}:{}'.format(key, value) for key, value in labels.items()]

    state = 'in-use' if users else 'available'

    try:
        print('{:{fill}} {}'.format('Volume Id:', volume['id'], fill=fill_width))
        print('{:{fill}} {}'.format('Name:', volume['name'], fill=fill_width))
        print('{:{fill}} {}Gb'.format('Size:', volume['sizeGb'], fill=fill_width))
        print('{:{fill}} {}'.format('Availability Zone:', volume['zone'], fill=fill_width))
        print('{:{fill}} {}'.format('State:', state, fill=fill_width))
        for user in users:
            # Show only the last path segment; [-1] also copes with values
            # that contain no '/' at all
            print('{:{fill}} {}'.format('Attached to:', user.rsplit('/', 1)[-1], fill=fill_width))
        if tags:
            print('{:{fill}} {}'.format('User Tags:', ' '.join(tags), fill=fill_width))
        print('')
    except KeyError as e:
        sys.exit('Unexpected format of Volume. Key {} is missing'.format(e))
class ComputeAwsSchema(Schema):
    """
    Schema of the compute section of a portal specification for AWS.

    The provider name and the instance, auth and network sections are
    required; the list of provision actions is optional.
    """
    provider = fields.String(required=True)
    instance = fields.Nested(InstanceSchema, required=True)
    auth = fields.Nested(AuthSchema, required=True)
    network = fields.Nested(NetworkSchema, required=True)
    # Optional list of provision actions (each has a name and a dict of args)
    provision_actions = fields.Nested(ProvisionActionSchema, many=True)

    class Meta:
        # Keep fields in declaration order
        ordered = True
def print_volume(volume):
    """
    Print a human-readable summary of an AWS volume description.

    :param volume: Dict describing the volume. Keys 'VolumeId', 'Size',
        'AvailabilityZone', 'State' and 'Attachments' are required; 'Tags'
        (list of dicts with 'Key' and 'Value') is optional. Only the first
        attachment is printed.
    :raises SystemExit: If a required key is missing from the volume or
        from one of its tags.
    """
    # Local import: sys.exit replaces the `exit` helper, which is injected by
    # the site module and is not guaranteed to exist in every interpreter.
    import sys

    fill_width = 20

    try:
        tags = volume.get('Tags') or []

        # Look for Name tag
        name = next((tag['Value'] for tag in tags if tag['Key'] == 'Name'), '')

        # Transform tags to a different format (also drop Name tag)
        user_tags = ['{}:{}'.format(tag['Key'], tag['Value']) for tag in tags if tag['Key'] != 'Name']

        print('{:{fill}} {}'.format('Volume Id:', volume['VolumeId'], fill=fill_width))
        print('{:{fill}} {}'.format('Name:', name, fill=fill_width))
        print('{:{fill}} {}Gb'.format('Size:', volume['Size'], fill=fill_width))
        print('{:{fill}} {}'.format('Availability Zone:', volume['AvailabilityZone'], fill=fill_width))
        print('{:{fill}} {}'.format('State:', volume['State'], fill=fill_width))

        attachments = volume['Attachments']
        if attachments:
            first = attachments[0]
            print('{:{fill}} {}'.format('Attached to:', first['InstanceId'], fill=fill_width))
            print('{:{fill}} {}'.format('Attached as:', first['Device'], fill=fill_width))

        if user_tags:
            print('{:{fill}} {}'.format('User Tags:', ' '.join(user_tags), fill=fill_width))
        print('')
    except KeyError as e:
        sys.exit('Unexpected format of Volume. Key {} is missing'.format(e))
ComputeGcpSchema 6 | 7 | 8 | class ComputeSchema(OneOfSchema): 9 | type_field = 'provider' 10 | type_field_remove = False 11 | type_schemas = { 12 | 'aws': ComputeAwsSchema, 13 | 'gcp': ComputeGcpSchema 14 | } 15 | 16 | def get_obj_type(self, obj): 17 | # TODO: implement 18 | return 'aws' 19 | 20 | 21 | class PersistentVolumeSchema(Schema): 22 | volume_id = fields.String(required=True) 23 | device = fields.String(required=True) 24 | mount_point = fields.String(required=True) 25 | 26 | class Meta: 27 | ordered = True 28 | 29 | 30 | class ChannelSchema(Schema): 31 | direction = fields.String(required=True) 32 | local_path = fields.String(required=True) 33 | remote_path = fields.String(required=True) 34 | recursive = fields.Boolean() 35 | delay = fields.Float() 36 | 37 | class Meta: 38 | ordered = True 39 | 40 | 41 | class PortalSchema(Schema): 42 | compute = fields.Nested(ComputeSchema, required=True) 43 | persistent_volumes = fields.Nested(PersistentVolumeSchema, required=True, many=True) 44 | channels = fields.Nested(ChannelSchema, many=True) 45 | 46 | class Meta: 47 | ordered = True 48 | -------------------------------------------------------------------------------- /portal_gun/commands/handlers/factory.py: -------------------------------------------------------------------------------- 1 | from .base_handler import BaseHandler 2 | 3 | 4 | def get_handler_class(provider_name): 5 | """ 6 | Find appropriate Handler class for given cloud provider. 7 | :param provider_name: Name of cloud provider 8 | :rtype: type 9 | """ 10 | 11 | for cls in BaseHandler.__subclasses__(): 12 | if cls.provider_name() == provider_name: 13 | return cls 14 | 15 | raise Exception('Unknown cloud provider: {}'.format(provider_name)) 16 | 17 | 18 | def create_handler(provider_name, config): 19 | """ 20 | Factory method that creates instances of Handlers. 
class print_indent:
    """
    Context manager for implicit indentation of text printed to stdout and stderr.
    All output within the corresponding 'with' statement gets indented by a specified
    number of spaces. Nested uses accumulate, because each level wraps the stream
    installed by the level above it.
    """

    _default_indent = 4

    @staticmethod
    def set_default_indent(value):
        """
        Set the indent used when no explicit size is passed to the constructor.
        :param value: Non-negative int, number of spaces.
        """
        assert type(value) == int
        assert value >= 0
        print_indent._default_indent = value

    def __init__(self, indent=None):
        """
        :param indent: Non-negative int, number of spaces; None selects the default indent.
        """
        if indent is None:
            indent = print_indent._default_indent
        else:
            assert type(indent) == int
            assert indent >= 0

        self._indent = indent

    def __enter__(self):
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        sys.stdout = print_indent.Wrapper(sys.stdout, self._indent)
        sys.stderr = print_indent.Wrapper(sys.stderr, self._indent)

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        sys.stderr = self._original_stderr

    class Wrapper:
        """
        Stream proxy that inserts the indent at the beginning of every non-empty line.

        print() emits its text and its terminator through separate write() calls,
        so the wrapper remembers whether the next write starts a new line instead
        of prefixing every call.
        """

        def __init__(self, writable, indent):
            self._writable = writable
            self._indent = indent
            self._at_line_start = True

        def write(self, data):
            prefix = ' ' * self._indent
            pieces = data.split('\n')
            last = len(pieces) - 1
            out = []

            for position, piece in enumerate(pieces):
                if piece:
                    if self._at_line_start:
                        out.append(prefix)
                    # Tabs are expanded to the indent size, as before
                    out.append(piece.expandtabs(self._indent))
                    self._at_line_start = False
                if position != last:
                    # Every piece except the last one was terminated by '\n'
                    out.append('\n')
                    self._at_line_start = True

            self._writable.write(''.join(out))

        def flush(self):
            self._writable.flush()

        def __getattr__(self, name):
            # Delegate everything else (isatty, encoding, fileno, ...) to the
            # wrapped stream. Private names are excluded to avoid infinite
            # recursion while the wrapper is not fully initialised.
            if name.startswith('_'):
                raise AttributeError(name)
            return getattr(self._writable, name)
"ec2:DescribeSpotFleetRequests", 27 | "ec2:DescribeSpotInstanceRequests", 28 | "ec2:DescribeSpotFleetInstances", 29 | "ec2:DescribeSpotPriceHistory", 30 | "ec2:DescribeSpotFleetRequestHistory", 31 | "ec2:DescribeInstances", 32 | "ec2:DescribeInstanceStatus", 33 | "ec2:DescribeInstanceAttribute", 34 | "ec2:CreateTags", 35 | "ec2:DeleteTags", 36 | "ec2:DescribeTags" 37 | ], 38 | "Resource": "*" 39 | } 40 | ] 41 | } -------------------------------------------------------------------------------- /portal_gun/context_managers/step.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | 4 | 5 | class step(object): 6 | _message_width = 40 7 | _filling_character = ' ' 8 | 9 | @staticmethod 10 | def set_message_width(value): 11 | step._message_width = value 12 | 13 | @staticmethod 14 | def set_filling_character(value): 15 | step._filling_character = value 16 | 17 | def __init__(self, message, error_message=None, catch=None): 18 | """ 19 | :param message: Text to be printed for the step. 20 | :param error_message: Error message to override actually caught exception. 21 | :param catch: List of exceptions to catch (by default catch everything) and wrap in StepError. 
class StepError(Exception):
    """
    Error describing a failed step.
    Raised by the `step` context manager in place of an expected exception.
    """

    def __init__(self, message=None):
        """
        :param message: Human-readable description of the failure (may be None).
        """
        super(StepError, self).__init__(message)
        # Store the message explicitly: Exception has no `message` attribute in Python 3
        self.message = message

    def __str__(self):
        # An error without a message is rendered as an empty string rather than 'None'
        return '' if self.message is None else '{}'.format(self.message)

    def __repr__(self):
        # Representation intentionally consists of the bare message
        return self.__str__()
26 | if data['type'] == 'custom' and ('cpu' not in data or 'memory' not in data): 27 | raise ValidationError('For "custom" machine type fields "cpu" and "memory" are required') 28 | 29 | class Meta: 30 | ordered = True 31 | 32 | 33 | class AuthSchema(Schema): 34 | private_ssh_key = fields.String(required=True) 35 | public_ssh_key = fields.String(required=True) 36 | user = fields.String(required=True) 37 | group = fields.String(required=True) 38 | 39 | class Meta: 40 | ordered = True 41 | 42 | 43 | # class NetworkSchema(Schema): 44 | # security_group_id = fields.String(required=True) 45 | # subnet_id = fields.String() 46 | 47 | 48 | class ComputeGcpSchema(Schema): 49 | provider = fields.String(required=True) 50 | instance = fields.Nested(InstanceSchema, required=True) 51 | auth = fields.Nested(AuthSchema, required=True) 52 | # network = fields.Nested(NetworkSchema, required=True) 53 | provision_actions = fields.List(fields.Nested(ProvisionActionSchema)) 54 | 55 | class Meta: 56 | ordered = True 57 | -------------------------------------------------------------------------------- /portal_gun/configuration/draft.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from marshmallow import fields 4 | 5 | from portal_gun.one_of_schema import OneOfSchema # replace by the proper marshmallow-oneofschema package 6 | 7 | 8 | def generate_draft(schema, selectors=None): 9 | """ 10 | Generate draft config from a given schema. 
11 | :param schema: 12 | :param selectors: Dictionary that maps types of polymorphic schemas to names of particular schemas ({type: string}) 13 | :return: 14 | """ 15 | return _process_schema(schema, selectors) 16 | 17 | 18 | def _process_schema(schema, selectors=None): 19 | draft = OrderedDict() 20 | 21 | if isinstance(schema, OneOfSchema): 22 | # If given schema is polymorphic (has type OneOfSchema), 23 | # use given selectors to replace it by a particular schema 24 | try: 25 | type_field = schema.type_field 26 | schema_name = selectors[type(schema)] 27 | schema = schema.type_schemas[schema_name]() 28 | draft[type_field] = schema_name 29 | except KeyError: 30 | return draft 31 | 32 | for field_name, field in schema.fields.items(): 33 | draft[field_name] = _process_field(field, selectors) 34 | 35 | return draft 36 | 37 | 38 | def _process_field(field, selectors=None): 39 | if isinstance(field, fields.Nested): 40 | field_value = _process_schema(field.schema, selectors) 41 | if field.schema.many: 42 | field_value = [field_value] 43 | elif isinstance(field, fields.List): 44 | field_value = [_process_field(field.container, selectors)] 45 | else: 46 | field_value = _describe_field(field) 47 | 48 | return field_value 49 | 50 | 51 | def _describe_field(field): 52 | description = field.__class__.__name__.lower() 53 | requirement = 'required' if field.required else 'optional' 54 | return '{} ({})'.format(description, requirement) 55 | -------------------------------------------------------------------------------- /portal_gun/commands/generate_portal_spec.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os import path 3 | 4 | from portal_gun.commands.helpers import get_portal_name, get_provider_from_env, get_provider_from_user 5 | from .base_command import BaseCommand 6 | from .handlers import list_providers, describe_providers, generate_portal_spec 7 | 8 | 9 | class GeneratePortalSpecCommand(BaseCommand): 10 | 
"""
A setuptools based setup module.

See:
https://packaging.python.org/en/latest/distributing.html
https://github.com/pypa/sampleproject
"""

# Always prefer setuptools over distutils
from setuptools import setup, find_packages
# To use a consistent encoding
from codecs import open
from os import path
import re

here = path.abspath(path.dirname(__file__))

# Get version (single source of truth is portal_gun/__init__.py)
with open(path.join(here, 'portal_gun/__init__.py')) as f:
    # Raw string keeps regex escapes (\d, \.) from being treated as string escapes
    version_match = re.search(r"^__version__ = '(\d\.\d+\.\d+(\.?(dev|a|b|rc)\d?)?)'$", f.read(), re.M)

if version_match is None:
    # Fail with a clear message instead of AttributeError on None
    raise RuntimeError('Could not find a valid __version__ in portal_gun/__init__.py')

version = version_match.group(1)

# Get the long description from the README file
with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
    long_description = f.read()

setup(
    name='portal-gun',
    version=version,
    description=('A command line tool that automates repetitive tasks '
                 'associated with the management of Spot Instances on Amazon EC2 service.'),
    long_description=long_description,
    long_description_content_type='text/x-rst',
    url='https://github.com/Coderik/portal-gun',
    author='Vadim Fedorov',
    author_email='coderiks@gmail.com',
    license='MIT',
    classifiers=[
        # 3 - Alpha
        # 4 - Beta
        # 5 - Production/Stable
        'Development Status :: 4 - Beta',
        'Environment :: Console',
        'Intended Audience :: Science/Research',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 2',
        'Programming Language :: Python :: 2.7',
    ],
    packages=find_packages(exclude=()),
    install_requires=[
        'boto3>=1.5.18',
        'Fabric>=1.14.0',
        'marshmallow>=3.0.0b8',
        'google-auth>=1.4.0',
        'google-api-python-client>=1.6.7',
        'google-auth-httplib2'
    ],
    entry_points={  # Optional
        'console_scripts': [
            'portal=portal_gun.main:main',
        ],
    },
    project_urls={
        'Bug Reports': 'https://github.com/Coderik/portal-gun/issues',
        'Source': 'https://github.com/Coderik/portal-gun',
    },
    # keywords='words separated by whitespace',
    # package_data={
    #     'sample': ['package_data.dat'],
    # },
)
24 | .format(', '.join(cls.FIELDS))) 25 | 26 | def run(self): 27 | if self._args.field is not None: 28 | # Get value of the specified field and print it 29 | value = self.get_field(self._args.field) 30 | if value is not None: 31 | print(value) 32 | else: 33 | self.show_full_info() 34 | 35 | def get_field(self, field): 36 | # Ensure field name is valid 37 | if field not in self.FIELDS: 38 | return None 39 | 40 | with no_print(): 41 | # Find, parse and validate configs 42 | portal_name = get_portal_name(self._args.portal) 43 | portal_spec = get_portal_spec(portal_name) 44 | provider_name = get_provider_from_portal(portal_spec) 45 | provider_config = get_provider_config(self._args.config, provider_name) 46 | 47 | # Create appropriate command handler for given cloud provider 48 | handler = create_handler(provider_name, provider_config) 49 | 50 | return handler.get_portal_info_field(portal_spec, portal_name, field) 51 | 52 | def show_full_info(self): 53 | # Find, parse and validate configs 54 | with print_scope('Checking configuration:', 'Done.\n'): 55 | portal_name = get_portal_name(self._args.portal) 56 | portal_spec = get_portal_spec(portal_name) 57 | provider_name = get_provider_from_portal(portal_spec) 58 | provider_config = get_provider_config(self._args.config, provider_name) 59 | 60 | # Create appropriate command handler for given cloud provider 61 | handler = create_handler(provider_name, provider_config) 62 | 63 | handler.show_portal_info(portal_spec, portal_name) 64 | -------------------------------------------------------------------------------- /portal_gun/commands/ssh.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from portal_gun.commands.helpers import get_provider_config, get_portal_spec, get_portal_name, \ 4 | get_provider_from_portal 5 | from portal_gun.context_managers.no_print import no_print 6 | from .base_command import BaseCommand 7 | from .handlers import create_handler 8 | 9 | 10 | 
class SshCommand(BaseCommand): 11 | DEFAULT_TMUX_SESSION = 'portal' 12 | 13 | def __init__(self, args): 14 | BaseCommand.__init__(self, args) 15 | 16 | @staticmethod 17 | def cmd(): 18 | return 'ssh' 19 | 20 | @classmethod 21 | def add_subparser(cls, subparsers): 22 | parser = subparsers.add_parser(cls.cmd(), help='Connect to the remote host via ssh') 23 | parser.add_argument('portal', help='Name of portal') 24 | parser.add_argument('-t', '--tmux', dest='tmux', nargs='?', default=None, const=cls.DEFAULT_TMUX_SESSION, 25 | metavar='session', help='Automatically open tmux session upon connection. ' 26 | 'Default session name is `{}`.'.format(cls.DEFAULT_TMUX_SESSION)) 27 | 28 | def run(self): 29 | # Find, parse and validate configs 30 | with no_print(): 31 | portal_name = get_portal_name(self._args.portal) 32 | portal_spec = get_portal_spec(portal_name) 33 | provider_name = get_provider_from_portal(portal_spec) 34 | provider_config = get_provider_config(self._args.config, provider_name) 35 | 36 | # Create appropriate command handler for given cloud provider 37 | handler = create_handler(provider_name, provider_config) 38 | 39 | identity_file, user, host, disable_known_hosts = handler.get_ssh_params(portal_spec, portal_name) 40 | 41 | print('Connecting to the remote machine...') 42 | print('\tssh -i "{}" {}@{}'.format(identity_file, user, host).expandtabs(4)) 43 | 44 | # If needed, disable strict known-hosts check 45 | options = [] 46 | if disable_known_hosts: 47 | options = [ 48 | '-o', 49 | 'StrictHostKeyChecking=no' 50 | ] 51 | 52 | # If requested, configure a preamble (a set of commands to be run automatically after connection) 53 | preamble = [] 54 | if self._args.tmux is not None: 55 | preamble = [ 56 | '-t', 57 | '""tmux attach-session -t {sess} || tmux new-session -s {sess}""'.format(sess=self._args.tmux) 58 | ] 59 | print('Upon connection will open tmux session `{}`.'.format(self._args.tmux)) 60 | 61 | print('') 62 | 63 | # Ssh to remote host (effectively 
class BaseHandler(object):
    """
    Base class for all specific Handlers.
    Every method except the constructor raises NotImplementedError
    and has to be overridden by a provider-specific subclass.
    """
    def __init__(self, config):
        """
        :param config: Cloud provider config
        """
        self._config = config

    @staticmethod
    def provider_name():
        """ Short name uniquely identifying cloud provider """
        raise NotImplementedError('Every subclass of BaseHandler should implement static provider_name() method.')

    @staticmethod
    def provider_long_name():
        """ Human-readable descriptive name of cloud provider """
        raise NotImplementedError('Every subclass of BaseHandler should implement static provider_long_name() method.')

    @staticmethod
    def generate_portal_spec():
        """
        Generate draft of portal specification in dictionary format.
        :rtype: dict
        """
        raise NotImplementedError('Every subclass of BaseHandler should implement static generate_portal_spec() method.')

    def open_portal(self, portal_spec, portal_name):
        """ Open the portal described by the given specification. """
        raise NotImplementedError('Every subclass of BaseHandler should implement open_portal() method.')

    def close_portal(self, portal_spec, portal_name):
        """ Close the portal described by the given specification. """
        raise NotImplementedError('Every subclass of BaseHandler should implement close_portal() method.')

    def show_portal_info(self, portal_spec, portal_name):
        """ Show information about the portal. """
        raise NotImplementedError('Every subclass of BaseHandler should implement show_portal_info() method.')

    def get_portal_info_field(self, portal_spec, portal_name, field):
        """ Get value of a single named field of the portal information. """
        raise NotImplementedError('Every subclass of BaseHandler should implement get_portal_info_field() method.')

    def get_ssh_params(self, portal_spec, portal_name):
        """
        Get parameters for ssh connection
        :param portal_spec:
        :param portal_name:
        :return: (identity file, remote user, host, disable_known_hosts)
        :rtype: (str, str, str, bool)
        """
        raise NotImplementedError('Every subclass of BaseHandler should implement get_ssh_params() method.')

    def list_volumes(self, args):
        """ Handle `list` command of the volume group. """
        raise NotImplementedError('Every subclass of BaseHandler should implement list_volumes() method.')

    def create_volume(self, args):
        """ Handle `create` command of the volume group. """
        raise NotImplementedError('Every subclass of BaseHandler should implement create_volume() method.')

    def update_volume(self, args):
        """ Handle `update` command of the volume group. """
        raise NotImplementedError('Every subclass of BaseHandler should implement update_volume() method.')

    def delete_volume(self, args):
        """ Handle `delete` command of the volume group. """
        raise NotImplementedError('Every subclass of BaseHandler should implement delete_volume() method.')
Gun 3 | ========== 4 | 5 | Command line tool that automates routine tasks associated with the management of Spot Instances on Amazon EC2 service. 6 | 7 | Primarily it is intended to simplify usage of AWS Spot Instances for Deep Learning. This focus will further shape the future development. 8 | 9 | Documentation 10 | ============= 11 | 12 | Full documentation can be found at `http://portal-gun.readthedocs.io `_. 13 | 14 | Installation 15 | ============ 16 | 17 | It is **strongly recommended** to install Portal Gun in **a virtual Python environment**. 18 | 19 | To install the latest stable version from the PyPI:: 20 | 21 | $ pip install -U portal-gun 22 | 23 | To install the latest pre-release version from the PyPI:: 24 | 25 | $ pip install -U portal-gun --pre 26 | 27 | Refer to the documentation for details regarding `general configuration `_ 28 | and `portal specification `_. 29 | 30 | Basic Usage 31 | =========== 32 | 33 | 1. Persistent volumes 34 | --------------------- 35 | 36 | Use ``volume`` group of commands to work with EBS volumes. 37 | 38 | Create a new volume:: 39 | 40 | $ portal volume create 41 | 42 | List created volumes:: 43 | 44 | $ portal volume list 45 | 46 | Update previously created volume:: 47 | 48 | $ portal volume update [-n ] [-s ] 49 | 50 | Delete previously created volume:: 51 | 52 | $ portal volume delete 53 | 54 | 2. Portals 55 | ---------- 56 | 57 | Create draft specification for a new portal:: 58 | 59 | $ portal init 60 | 61 | Open a portal (request a new Spot Instance):: 62 | 63 | $ portal open 64 | 65 | Connect to the Spot Instance via ssh:: 66 | 67 | $ portal ssh 68 | 69 | Connect to the Spot Instance via ssh and attach to a tmux session (session name is optional):: 70 | 71 | $ portal ssh -t [] 72 | 73 | Close opened portal (cancel Spot Instance request):: 74 | 75 | $ portal close 76 | 77 | Get information about a portal:: 78 | 79 | $ portal info 80 | 81 | 82 | 3. 
def get_instance_name(portal_spec, portal_name):
    """
    Get name for a compute instance: explicit instance name from the portal spec,
    if present, otherwise the portal name. The name is normalized to lowercase letters,
    digits and hyphens, starts with a letter, does not end with a hyphen
    and is at most 63 characters long.
    :param portal_spec: Portal specification
    :param portal_name: Name of portal (fallback)
    :return: Normalized name (may be empty, if the source name has no letters)
    :rtype: str
    """
    try:
        name = portal_spec['compute']['instance']['name']
    except KeyError:
        name = portal_name

    # Make name compliant with RFC1035
    name = re.sub(r'[_\s]', '-', name).lower()
    name = re.sub(r'^[^a-z]*', '', name)
    name = re.sub(r'[^a-z0-9-]', '', name)

    # Truncate first and strip hyphens afterwards: truncation itself may expose a trailing hyphen
    return name[:63].rstrip('-')


def build_instance_props(portal_spec, instance_name):
    """
    Build properties (request body) of a compute instance from the portal specification.
    :param portal_spec: Portal specification
    :param instance_name: Name of the instance
    :return: Dictionary of instance properties
    :rtype: dict
    """
    # Define shortcuts
    instance_spec = portal_spec['compute']['instance']
    auth_spec = portal_spec['compute']['auth']
    zone = instance_spec['availability_zone']

    # Construct partial url for machine type
    if instance_spec['type'] == 'custom':
        cpu = instance_spec['cpu']
        if cpu > 1:
            # Round number of CPUs up to an even value
            cpu = int(math.ceil(cpu / 2.0) * 2)
        # Convert memory from Gb to Mb, rounding up to a multiple of 256 Mb
        memory = int(math.ceil(instance_spec['memory'] * 4.0)) * 256
        machine_type = 'zones/{}/machineTypes/custom-{}-{}'.format(zone, cpu, memory)
    else:
        machine_type = 'zones/{}/machineTypes/{}'.format(zone, instance_spec['type'])

    # Read public key from file
    with open(auth_spec['public_ssh_key'], 'r') as f:
        public_ssh_key = f.readline()

    # Fill props
    props = {
        'scheduling': {
            'preemptible': instance_spec['preemptible']
        },
        'networkInterfaces': [
            {
                'network': 'global/networks/default',
                'accessConfigs': [
                    {
                        'name': 'External NAT',
                        'type': 'ONE_TO_ONE_NAT'
                    }
                ]
            }
        ],
        'machineType': machine_type,
        'name': instance_name,
        'disks': [{
            'initializeParams': {
                'diskName': '{}-boot'.format(instance_name),
                'diskSizeGb': 20,
                'sourceImage': 'global/images/{}'.format(instance_spec['image'])

            },
            'autoDelete': True,
            'boot': True
        }],
        'metadata': {
            'items': [
                {
                    'key': 'ssh-keys',
                    'value': '{}:{}'.format(auth_spec['user'], public_ssh_key)
                }
            ]
        }
    }

    # Specify GPU, if needed
    if 'gpu' in instance_spec:
        props['guestAccelerators'] = [{
            'acceleratorType': 'zones/{}/acceleratorTypes/{}'.format(zone, instance_spec['gpu']['type']),
            'acceleratorCount': instance_spec['gpu']['count']
        }]

    # Add persistent volumes (a spec without any volumes yields only the boot disk)
    for volume in portal_spec.get('persistent_volumes') or []:
        props['disks'].append({
            'source': 'zones/{}/disks/{}'.format(zone, volume['volume_id']),
            'mode': 'READ_WRITE',
            'autoDelete': False,
            'boot': False
        })

    return props
code-block:: json 24 | 25 | { 26 | "aws_region": "current AWS region", 27 | "aws_access_key": "access key for your AWS account", 28 | "aws_secret_key": "secret key for your AWS account" 29 | } 30 | 31 | Credentials (access and secret keys) for programmatic access on behalf of your AWS account can be found in the `IAM Console `_. **It is recommended to create a separate user** for programmatic access via Portal Gun. 32 | 33 | AWS Access Rights 34 | ================= 35 | 36 | Portal Gun requires the following access rights:: 37 | 38 | iam:PassRole 39 | 40 | ec2:DescribeAccountAttributes 41 | 42 | ec2:DescribeAvailabilityZones 43 | ec2:DescribeSubnets 44 | 45 | ec2:CreateVolume 46 | ec2:ModifyVolume 47 | ec2:AttachVolume 48 | ec2:DetachVolume 49 | ec2:DeleteVolume 50 | ec2:DescribeVolumes 51 | ec2:DescribeVolumeStatus 52 | ec2:DescribeVolumeAttribute 53 | ec2:DescribeVolumesModifications 54 | 55 | ec2:RequestSpotFleet 56 | ec2:CancelSpotFleetRequests 57 | ec2:RequestSpotInstances 58 | ec2:CancelSpotInstanceRequests 59 | ec2:ModifySpotFleetRequest 60 | ec2:ModifyInstanceAttribute 61 | ec2:DescribeSpotFleetRequests 62 | ec2:DescribeSpotInstanceRequests 63 | ec2:DescribeSpotFleetInstances 64 | ec2:DescribeSpotPriceHistory 65 | ec2:DescribeSpotFleetRequestHistory 66 | ec2:DescribeInstances 67 | ec2:DescribeInstanceStatus 68 | ec2:DescribeInstanceAttribute 69 | 70 | ec2:CreateTags 71 | ec2:DeleteTags 72 | ec2:DescribeTags 73 | 74 | `IAM Policy `_ is the most convenient way to grant required permissions. 75 | Create a new policy and attach it to a user which will be used for programmatic access via Portal Gun. 76 | 77 | Reference policy granting required permissions can be found `here <_static/reference_policy.json>`_. You can make it more strict, for instance, by limiting access right to a particular region. 
def to_aws_tags(tags):
    """
    Convert tags from dictionary to a format expected by AWS:
    [{'Key': key, 'Value': value}]
    :param tags: Dictionary of tags
    :return: List of dictionaries, one per tag
    """
    return [{'Key': k, 'Value': v} for k, v in tags.items()]


def from_aws_tags(tags):
    """
    Convert tags from AWS format [{'Key': key, 'Value': value}] to dictionary
    :param tags: List of dictionaries, one per tag
    :return: Dictionary of tags
    """
    return {tag['Key']: tag['Value'] for tag in tags}


def single_instance_spot_fleet_request(portal_spec, portal_name, user):
    """
    Build config of a spot fleet request for a single instance, valid for 60 days from now.
    :param portal_spec: Portal specification
    :param portal_name: Name of portal (stored in 'portal-name' tag of the instance)
    :param user: Name of user (stored in 'created-by' tag of the instance)
    :return: Spot fleet request config
    :rtype: dict
    """
    # Define shortcuts
    instance_spec = portal_spec['compute']['instance']
    network_spec = portal_spec['compute']['network']
    auth_spec = portal_spec['compute']['auth']

    # Read the clock once, so that validity window is exactly 60 days long
    now = datetime.datetime.utcnow()
    time_format = '%Y-%m-%dT%H:%M:%SZ'

    fleet_request_config = {
        'AllocationStrategy': 'lowestPrice',
        'IamFleetRole': instance_spec['iam_fleet_role'],
        'TargetCapacity': 1,
        'ValidFrom': now.strftime(time_format),
        'ValidUntil': (now + datetime.timedelta(days=60)).strftime(time_format),
        'TerminateInstancesWithExpiration': True,
        'Type': 'request',
        'LaunchSpecifications': [
            {
                'ImageId': instance_spec['image_id'],
                'InstanceType': instance_spec['type'],
                'KeyName': auth_spec['key_pair_name'],
                'Placement': {
                    'AvailabilityZone': instance_spec['availability_zone']
                },
                'NetworkInterfaces': [{
                    'SubnetId': network_spec['subnet_id'],
                    'Groups': [network_spec['security_group_id']],
                    'DeviceIndex': 0
                }],
                'TagSpecifications': [{
                    'ResourceType': 'instance',
                    'Tags': [
                        {'Key': 'portal-name', 'Value': portal_name},
                        {'Key': 'created-by', 'Value': user},
                    ]
                }]
            }
        ]
    }

    # Add provided optional fields
    if 'ebs_optimized' in instance_spec:
        fleet_request_config['LaunchSpecifications'][0]['EbsOptimized'] = instance_spec['ebs_optimized']

    return fleet_request_config


def build_instance_launch_spec(portal_spec):
    """
    Build launch specification of an instance from the portal specification.
    :param portal_spec: Portal specification
    :return: Launch specification
    :rtype: dict
    """
    # Define shortcuts
    instance_spec = portal_spec['compute']['instance']
    network_spec = portal_spec['compute']['network']
    auth_spec = portal_spec['compute']['auth']

    # Set required fields
    aws_launch_spec = {
        'SecurityGroupIds': [network_spec['security_group_id']],
        'ImageId': instance_spec['image_id'],
        'InstanceType': instance_spec['type'],
        'KeyName': auth_spec['key_pair_name'],
        'Placement': {
            'AvailabilityZone': instance_spec['availability_zone']
        }
    }

    # Add provided optional fields
    if 'ebs_optimized' in instance_spec:
        aws_launch_spec['EbsOptimized'] = instance_spec['ebs_optimized']

    return aws_launch_spec
BaseCommand.__init__(self, args) 20 | 21 | @staticmethod 22 | def cmd(): 23 | return 'channel' 24 | 25 | @classmethod 26 | def add_subparser(cls, subparsers): 27 | parser = subparsers.add_parser(cls.cmd(), help='Open channels for files synchronization') 28 | parser.add_argument('portal', help='Name of portal') 29 | 30 | def run(self): 31 | # Find, parse and validate configs 32 | with print_scope('Checking configuration:', 'Done.\n'): 33 | portal_name = get_portal_name(self._args.portal) 34 | portal_spec = get_portal_spec(portal_name) 35 | provider_name = get_provider_from_portal(portal_spec) 36 | provider_config = get_provider_config(self._args.config, provider_name) 37 | 38 | # Ensure there is at least one channel spec 39 | with step('Check specifications for channels', 40 | error_message='Portal specification does not contain any channel'): 41 | channels = portal_spec['channels'] 42 | if len(channels) == 0: 43 | raise Exception() 44 | 45 | # Create appropriate command handler for given cloud provider 46 | handler = create_handler(provider_name, provider_config) 47 | 48 | identity_file, user, host, disable_known_hosts = handler.get_ssh_params(portal_spec, portal_name) 49 | 50 | # Print information about the channels 51 | with print_scope('Channels defined for portal `{}`:'.format(portal_name), ''): 52 | for i in range(len(channels)): 53 | channel = channels[i] 54 | with print_scope('Channel #{} ({}):'.format(i, channel['direction'].upper())): 55 | print('Local: {}'.format(channel['local_path'])) 56 | print('Remote: {}'.format(channel['remote_path'])) 57 | 58 | # Configure ssh connection via fabric 59 | fab_conn = fab.create_connection(host, user, identity_file) 60 | 61 | # Periodically sync files across all channels 62 | print('Syncing... 
def get_portal_name(portal_arg):
    """Return `portal_arg` without its last dot-separated suffix (e.g. `.json`).

    An argument that contains no dot is returned unchanged.
    """
    stem, dot, _suffix = portal_arg.rpartition('.')

    # rpartition puts everything into the last slot when there is no dot at all
    if not dot:
        return portal_arg

    return stem
def get_provider_from_env(choices):
    """Return the provider named by the `cloud_provider_env` environment variable.

    None is returned when the variable is not set or when its value is not
    one of `choices`.
    """
    candidate = environ.get(cloud_provider_env)

    # Variable is not set at all
    if candidate is None:
        return None

    return candidate if candidate in choices else None
class GcpClient(object):
    """Small wrapper around the Google Compute Engine API for one project and one zone.

    The underlying API client is built lazily on first use (see `gce_client`).
    Every request method is decorated with `gcp_api_caller`.
    """

    def __init__(self, service_account_file, project, region):
        """Store connection settings; no request is made here.

        :param service_account_file: path of the service account file used to build credentials
        :param project: value passed as `project` to every request
        :param region: value passed as `zone` to every request
        """
        self._service_account_file = service_account_file
        self._project = project
        self._region = region
        self._gce_client = None  # created on demand by gce_client()

    @gcp_api_caller()
    def request_instance(self, props):
        """Insert a new instance described by `props` and return the API response."""
        response = self.gce_client().instances().insert(project=self._project, zone=self._region, body=props).execute()

        return response

    @gcp_api_caller()
    def get_instance(self, name):
        """Return the API response describing the instance called `name`."""
        response = self.gce_client().instances().get(project=self._project, zone=self._region, instance=name) \
            .execute()

        return response

    @gcp_api_caller()
    def find_instance(self, name):
        """Return the first instance matching `name`, or None when nothing matches."""
        flt = 'name = {}'.format(name)
        response = self.gce_client().instances().list(project=self._project, zone=self._region, filter=flt) \
            .execute()

        if 'items' not in response or len(response['items']) == 0:
            return None

        return response['items'][0]

    @gcp_api_caller()
    def delete_instance(self, name):
        """Delete the instance called `name` and return the API response."""
        response = self.gce_client().instances().delete(project=self._project, zone=self._region, instance=name) \
            .execute()

        return response

    @gcp_api_caller()
    def get_operation(self, name):
        """Return the API response describing the zone operation called `name`."""
        response = self.gce_client().zoneOperations().get(project=self._project, zone=self._region, operation=name) \
            .execute()

        return response

    @gcp_api_caller()
    def cancel_instance_request(self, name):
        """Cancel an instance request; issues the same delete call as `delete_instance`."""
        response = self.gce_client().instances().delete(project=self._project, zone=self._region, instance=name).execute()

        return response

    @gcp_api_caller()
    def get_volumes(self):
        """Return the list of disks in the configured project/zone (empty list when there are none)."""
        response = self.gce_client().disks().list(project=self._project, zone=self._region).execute()

        # The response may carry no `items` key when there is nothing to list
        # (find_instance guards against the same case), so do not index blindly.
        return response.get('items', [])

    def gce_client(self):
        """Return the Compute Engine API client, building it on the first call.

        :raises ProviderRequestError: if the service account file cannot be read
        """
        if self._gce_client is None:
            try:
                credentials = service_account.Credentials.from_service_account_file(self._service_account_file)
            except IOError as e:
                raise ProviderRequestError('Could not find service account file: {}'.format(e.filename))

            self._gce_client = googleapiclient.discovery.build('compute', 'v1', credentials=credentials)

        return self._gce_client
If you follow the recommended workflow (:ref:`see below <ref-workflow>`), you should be able to open the portal again and find everything exactly as you left it. 22 | 23 | A portal is defined by a *portal specification* file which describes a particular environment in JSON format. 24 | 25 | Portal specification includes:: 26 | - characteristics of a Spot Instance to be requested (instance type, key pair for secure connection, security group, availability zone, etc.); 27 | - software configuration (AMI, extra dependencies to be installed, etc.); 28 | - persistent data storage (see below); 29 | - data synchronization channels (see below). 30 | 31 | Persistent Volume 32 | ----------------- 33 | 34 | AWS Spot Instances are volatile by nature; therefore, some external storage is needed to persist the data. The most efficient option is `EBS Volume `_. 35 | 36 | Portal Gun allows you to manage EBS Volumes from the command line. It also automatically attaches and mounts volumes to instances according to the portal specifications. You might have a single volume to store everything (dataset, code, checkpoints of training, etc.) or use separate volumes for each type of data. 37 | 38 | Channel 39 | ------- 40 | 41 | *Channels* can be defined in portal specification to synchronize files between a Spot Instance and your local machine. Synchronization is done continuously using ``rsync`` and should be started explicitly with a command. Every channel is either *inbound* (files are moved from remote to local) or *outbound* (files are moved from local to remote). 42 | 43 | For instance, you may edit scripts locally and configure a channel to send them to the remote instance after every save. You might configure another channel to automatically get some intermediate results from the remote instance to your local machine for preview. 44 | 45 | .. _ref-workflow: 46 | 47 | Typical Workflow 48 | ================ 49 | 50 | A typical Deep Learning workflow with Portal Gun is as follows: 51 | 1.
Using Portal Gun create a new volume (e.g. named 'data') for all your data; 52 | 2. Configure a portal backed by the 'data' volume and a non-GPU instance; 53 | 3. Open the portal configured in step 2; 54 | 4. Connect to the non-GPU instance and copy all necessary data to the 'data' volume; 55 | 5. Close the portal configured in step 2; 56 | 6. Configure a portal backed by the 'data' volume and a GPU instance; 57 | 7. Open the portal configured in step 6; 58 | 8. Run training on the GPU instance; 59 | 9. Close the portal configured in step 6. -------------------------------------------------------------------------------- /portal_gun/commands/volume.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from portal_gun.commands.helpers import get_provider_config, get_provider_from_env, get_provider_from_user 4 | from portal_gun.context_managers.print_scope import print_scope 5 | from .base_command import BaseCommand 6 | from .handlers import list_providers, describe_providers, create_handler 7 | 8 | 9 | class VolumeCommand(BaseCommand): 10 | def __init__(self, args): 11 | BaseCommand.__init__(self, args) 12 | 13 | self._proper_tag_key = 'dimension' 14 | self._proper_tag_value = 'C-137' 15 | self._service_tags = [self._proper_tag_key, 'created-by', 'mount-point'] 16 | self._default_size = 50 # Gb 17 | self._min_size = 1 # Gb 18 | self._max_size = 16384 # Gb 19 | 20 | @staticmethod 21 | def cmd(): 22 | return 'volume' 23 | 24 | @classmethod 25 | def add_subparser(cls, command_parsers): 26 | parser = command_parsers.add_parser(cls.cmd(), help='Group of subcommands related to persistent volumes') 27 | provider_group = parser.add_mutually_exclusive_group() 28 | for desc in describe_providers(): 29 | provider_group.add_argument('--{}'.format(desc['name']), action='store_const', const=desc['name'], 30 | dest='provider', help='Set {} as cloud provider'.format(desc['long_name'])) 31 | 32 | subcommand_parsers = 
parser.add_subparsers(title='subcommands', dest='subcommand') 33 | 34 | # List 35 | parser_list = subcommand_parsers.add_parser('list', help='List persistent volumes') 36 | parser_list.add_argument('-a', '--all', dest='all', action='store_true', 37 | help='Show all volumes, not only ones created by Portal Gun.') 38 | parser_list.set_defaults(actor=lambda handler, args: handler.list_volumes(args)) 39 | 40 | # Create 41 | parser_create = subcommand_parsers.add_parser('create', help='Create new volume') 42 | parser_create.add_argument('-n', '--name', dest='name', default=None, help='Set name for new volume.') 43 | parser_create.add_argument('-s', '--size', dest='size', default=None, type=int, 44 | help='Set size (in Gb) for new volume.') 45 | parser_create.add_argument('-z', '--zone', dest='zone', default=None, help='Set availability zone for new volume.') 46 | parser_create.add_argument('-S', '--snapshot', dest='snapshot', default=None, 47 | help='Set Id of a snapshot to create new volume from.') 48 | parser_create.add_argument('-t', '--tags', nargs='+', dest='tags', metavar='key:value', 49 | help='Set user tags for new volume.') 50 | parser_create.set_defaults(actor=lambda handler, args: handler.create_volume(args)) 51 | # TODO: add silent mode 52 | 53 | # Update 54 | parser_update = subcommand_parsers.add_parser('update', help='Update persistent volume') 55 | parser_update.add_argument(dest='volume_id', help='Volume Id.') 56 | parser_update.add_argument('-n', '--name', dest='name', help='Update name of volume.') 57 | parser_update.add_argument('-s', '--size', dest='size', type=int, help='Update size of volume.') 58 | parser_update.add_argument('-t', '--tags', nargs='+', dest='tags', metavar='key:value', 59 | help='Add user tags for volume.') 60 | parser_update.set_defaults(actor=lambda handler, args: handler.update_volume(args)) 61 | 62 | # Delete 63 | parser_delete = subcommand_parsers.add_parser('delete', help='Delete persistent volume') 64 | 
def mount_volume(conn, device, mounting_point, user, group):
    """Mount `device` at `mounting_point` on the remote host, formatting it first if it is blank.

    :param conn: connection object providing `sudo(cmd, hide=...)` and `run(cmd, hide=...)`
    :param device: device path on the remote host
    :param mounting_point: directory to mount the device at (created if missing)
    :param user: owner to assign to the mounting point of a freshly formatted volume
    :param group: group to assign to the mounting point of a freshly formatted volume
    """
    # Catch tail of greeting output
    conn.sudo('whoami', hide=True)

    # Inspect volume's file system
    res = conn.sudo('file -s {}'.format(device), hide=True)

    # `file -s` reports "<device>: data" when no file system is recognized
    has_file_system = res.stdout.strip() != '{}: data'.format(device)
    if not has_file_system:
        conn.sudo('mkfs -t ext4 {}'.format(device), hide=True)

    # Create mounting point
    conn.run('mkdir -p {}'.format(mounting_point), hide=True)

    # Mount volume
    conn.sudo('mount {} {}'.format(device, mounting_point), hide=True)

    # If file system has just been created, fix user and group of the mounting point.
    # chown expects OWNER:GROUP, so the user goes first.
    if not has_file_system:
        conn.sudo('chown -R {}:{} {}'.format(user, group, mounting_point), hide=True)
75 | keys = conn.connect_kwargs.get("key_filename", []) 76 | # TODO: would definitely be nice for Connection/FabricConfig to expose an 77 | # always-a-list, always-up-to-date-from-all-sources attribute to save us 78 | # from having to do this sort of thing. (may want to wait for Paramiko auth 79 | # overhaul tho!) 80 | if isinstance(keys, six.string_types): 81 | keys = [keys] 82 | if keys: 83 | key_string = "-i " + " -i ".join(keys) 84 | # Get base cxn params 85 | user, host, port = conn.user, conn.host, conn.port 86 | port_string = "-p {}".format(port) 87 | # Remote shell (SSH) options 88 | rsh_string = "" 89 | # Strict host key checking 90 | disable_keys = "-o StrictHostKeyChecking=no" 91 | if not strict_host_keys and disable_keys not in ssh_opts: 92 | ssh_opts += " {}".format(disable_keys) 93 | rsh_parts = [key_string, port_string, ssh_opts] 94 | if any(rsh_parts): 95 | rsh_string = "--rsh='ssh {}'".format(" ".join(rsh_parts)) 96 | # Set up options part of string 97 | options_map = { 98 | "delete": "--delete" if allow_delete else "", 99 | "exclude": exclude_opts.format(*exclusions), 100 | "rsh": rsh_string, 101 | "extra": rsync_opts, 102 | } 103 | options = "{delete}{exclude} -pthrvz {extra} {rsh}".format(**options_map) 104 | 105 | # Create and run final command string 106 | # TODO: richer host object exposing stuff like .address_is_ipv6 or whatever 107 | if host.count(":") > 1: 108 | # Square brackets are mandatory for IPv6 rsync address, 109 | # even if port number is not specified 110 | cmd = "rsync {opt:} {local:} [{user:}@{host:}]:{remote:}" if is_upload else "rsync {opt:} [{user:}@{host:}]:{remote:} {local:}" 111 | else: 112 | cmd = "rsync {opt:} {local:} {user:}@{host:}:{remote:}" if is_upload else "rsync {opt:} {user:}@{host:}:{remote:} {local:}" 113 | 114 | cmd = cmd.format(opt=options, local=local_path, user=user, host=host, remote=remote_path) 115 | res = conn.local(cmd, hide=True) 116 | 117 | # Get transferred files 118 | transferred_files = 
res.stdout.strip('\n').split('\n')[1:-3] 119 | 120 | if len(transferred_files) > 0: 121 | print('\n'.join(transferred_files)) 122 | 123 | 124 | __all__ = [ 125 | 'create_connection', 126 | 'mount_volume', 127 | 'install_python_packages', 128 | 'install_packages', 129 | 'sync_files' 130 | ] 131 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/stable/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | 15 | import os 16 | import sys 17 | sys.path.insert(0, os.path.abspath('..')) 18 | 19 | import portal_gun 20 | 21 | 22 | # -- Project information ----------------------------------------------------- 23 | 24 | project = 'Portal Gun' 25 | author = portal_gun.__author__ 26 | copyright = '2018, {}'.format(author) 27 | 28 | # The short X.Y version 29 | version = portal_gun.__version__ 30 | # The full version, including alpha/beta/rc tags 31 | release = portal_gun.__version__ 32 | 33 | 34 | # -- General configuration --------------------------------------------------- 35 | 36 | # If your documentation needs a minimal Sphinx version, state it here. 37 | # 38 | # needs_sphinx = '1.0' 39 | 40 | # Add any Sphinx extension module names here, as strings. They can be 41 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 42 | # ones. 
43 | extensions = [ 44 | 'sphinx.ext.todo', 45 | 'sphinx.ext.ifconfig', 46 | 'sphinx.ext.viewcode', 47 | ] 48 | 49 | # Add any paths that contain templates here, relative to this directory. 50 | templates_path = ['_templates'] 51 | 52 | # The suffix(es) of source filenames. 53 | # You can specify multiple suffix as a list of string: 54 | # 55 | # source_suffix = ['.rst', '.md'] 56 | source_suffix = '.rst' 57 | 58 | # The master toctree document. 59 | master_doc = 'index' 60 | 61 | # The language for content autogenerated by Sphinx. Refer to documentation 62 | # for a list of supported languages. 63 | # 64 | # This is also used if you do content translation via gettext catalogs. 65 | # Usually you set "language" from the command line for these cases. 66 | language = None 67 | 68 | # List of patterns, relative to source directory, that match files and 69 | # directories to ignore when looking for source files. 70 | # This pattern also affects html_static_path and html_extra_path . 71 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 72 | 73 | # The name of the Pygments (syntax highlighting) style to use. 74 | pygments_style = 'sphinx' 75 | 76 | 77 | # -- Options for HTML output ------------------------------------------------- 78 | 79 | # The theme to use for HTML and HTML Help pages. See the documentation for 80 | # a list of builtin themes. 81 | # 82 | html_theme = 'alabaster' 83 | 84 | html_show_sourcelink = False 85 | 86 | # Theme options are theme-specific and customize the look and feel of a theme 87 | # further. For a list of options available for each theme, see the 88 | # documentation. 89 | # 90 | html_theme_options = { 91 | 'logo': 'portal_gun.png', 92 | 'github_user': 'Coderik', 93 | 'github_repo': 'portal-gun', 94 | 'github_banner': True, 95 | 'github_type': 'star' 96 | } 97 | 98 | # Add any paths that contain custom static files (such as style sheets) here, 99 | # relative to this directory. 
They are copied after the builtin static files, 100 | # so a file named "default.css" will overwrite the builtin "default.css". 101 | html_static_path = ['_static'] 102 | 103 | # Custom sidebar templates, must be a dictionary that maps document names 104 | # to template names. 105 | # 106 | # The default sidebars (for documents that don't match any pattern) are 107 | # defined by theme itself. Builtin themes are using these templates by 108 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 109 | # 'searchbox.html']``. 110 | 111 | html_sidebars = { 112 | 'index': ['about.html', 'searchbox.html'], 113 | '**': ['about.html', 'localtoc.html', 'relations.html', 'searchbox.html'] 114 | } 115 | 116 | 117 | # -- Options for HTMLHelp output --------------------------------------------- 118 | 119 | # Output file base name for HTML help builder. 120 | htmlhelp_basename = 'PortalGundoc' 121 | 122 | 123 | # -- Options for LaTeX output ------------------------------------------------ 124 | 125 | latex_elements = { 126 | # The paper size ('letterpaper' or 'a4paper'). 127 | # 128 | # 'papersize': 'letterpaper', 129 | 130 | # The font size ('10pt', '11pt' or '12pt'). 131 | # 132 | # 'pointsize': '10pt', 133 | 134 | # Additional stuff for the LaTeX preamble. 135 | # 136 | # 'preamble': '', 137 | 138 | # Latex figure (float) alignment 139 | # 140 | # 'figure_align': 'htbp', 141 | } 142 | 143 | # Grouping the document tree into LaTeX files. List of tuples 144 | # (source start file, target name, title, 145 | # author, documentclass [howto, manual, or own class]). 146 | latex_documents = [ 147 | (master_doc, 'PortalGun.tex', 'Portal Gun Documentation', 148 | 'Vadim Fedorov', 'manual'), 149 | ] 150 | 151 | 152 | # -- Options for manual page output ------------------------------------------ 153 | 154 | # One entry per manual page. List of tuples 155 | # (source start file, name, description, authors, manual section). 
156 | man_pages = [ 157 | (master_doc, 'portalgun', 'Portal Gun Documentation', 158 | [author], 1) 159 | ] 160 | 161 | 162 | # -- Options for Texinfo output ---------------------------------------------- 163 | 164 | # Grouping the document tree into Texinfo files. List of tuples 165 | # (source start file, target name, title, author, 166 | # dir menu entry, description, category) 167 | texinfo_documents = [ 168 | (master_doc, 'PortalGun', 'Portal Gun Documentation', 169 | author, 'PortalGun', 'One line description of project.', 170 | 'Miscellaneous'), 171 | ] 172 | 173 | 174 | # -- Extension configuration ------------------------------------------------- 175 | 176 | # -- Options for todo extension ---------------------------------------------- 177 | 178 | # If true, `todo` and `todoList` produce output, else they produce nothing. 179 | todo_include_todos = True -------------------------------------------------------------------------------- /docs/commands.rst: -------------------------------------------------------------------------------- 1 | .. _commands: 2 | 3 | ======== 4 | Commands 5 | ======== 6 | 7 | Print top-level help message:: 8 | 9 | $ portal -h 10 | 11 | Add ``-h`` (or ``--help``) flag after commands and command groups to print corresponding help messages. For instance, print help message for the ``volume`` group including the list of commands:: 12 | 13 | $ portal volume -h 14 | 15 | **Top-level command options:** 16 | 17 | .. cmdoption:: -c CONFIG, --config CONFIG 18 | 19 | Set name and location of configuration file. 20 | 21 | .. _volume_cmd: 22 | 23 | Persistent Volumes 24 | ================== 25 | 26 | This section documents a group of commands that are used to manage persistent volumes. For information on how to configure attachment of persistent volumes to instances see :ref:`Portal Specification ` section. 
27 | 28 | Create 29 | ------ 30 | 31 | Create a new EBS volume:: 32 | 33 | $ portal volume create 34 | 35 | Every volume requires **size** (in Gb) and **availability zone** to be specified. **Name** is optional, but recommended. If these three properties are not set using the command options, they will be requested from the standard input. 36 | 37 | Upon successful creation of a new volume its ```` will be provided. 38 | 39 | **Command options:** 40 | 41 | .. cmdoption:: -n NAME, --name NAME 42 | 43 | Set name for new volume. 44 | 45 | .. cmdoption:: -s SIZE, --size SIZE 46 | 47 | Set size (in Gb) for new volume. 48 | 49 | .. cmdoption:: -z ZONE, --zone ZONE 50 | 51 | Set availability zone for new volume. 52 | 53 | .. cmdoption:: -S SNAPSHOT, --snapshot SNAPSHOT 54 | 55 | Set Id of a snapshot to create new volume from. 56 | 57 | .. cmdoption:: -t key:value [key:value ...], --tags key:value [key:value ...] 58 | 59 | Set user tags for new volume. 60 | 61 | List 62 | ---- 63 | 64 | List existing EBS volume:: 65 | 66 | $ portal volume list 67 | 68 | By default ``list`` command outputs only the volumes created by Portal Gun on behalf of the current AWS user. To list all volumes use ``-a`` flag. 69 | 70 | **Command options:** 71 | 72 | .. cmdoption:: -a, --all 73 | 74 | Show all volumes, not only ones created by Portal Gun. 75 | 76 | Update 77 | ------ 78 | 79 | Update an AWS volume:: 80 | 81 | $ portal volume update 82 | 83 | **Command options:** 84 | 85 | .. cmdoption:: -n NAME, --name NAME 86 | 87 | Update name of volume. 88 | 89 | .. cmdoption:: -s SIZE, --size SIZE 90 | 91 | Update size of volume. 92 | 93 | .. cmdoption:: -t key:value [key:value ...], --tags key:value [key:value ...] 94 | 95 | Add user tags for volume. 96 | 97 | Delete 98 | ------ 99 | 100 | Delete an AWS volume:: 101 | 102 | $ portal volume delete 103 | 104 | By default ``delete`` command deletes only the volumes created by Portal Gun on behalf of the current AWS user. 
To force deletion of a volume use ``-f`` flag. 105 | 106 | **Command options:** 107 | 108 | .. cmdoption:: -f, --force 109 | 110 | Delete any volume, even not owned. 111 | 112 | ---- 113 | 114 | .. _portal_cmd: 115 | 116 | Portals 117 | ======= 118 | 119 | *Portal* is the main concept of Portal Gun (see :ref:`Concepts <concepts>` for details). 120 | 121 | .. _portal_cmd_init: 122 | 123 | Init 124 | ---- 125 | 126 | Create a draft *portal specification* file:: 127 | 128 | $ portal init 129 | 130 | A file with the name ``.json`` will be created. Modify this file to set the appropriate values (see :ref:`Portal Specification ` section). 131 | 132 | Open 133 | ---- 134 | 135 | To open a portal means to request and configure a Spot Instance according to the *portal specification*. Open a portal:: 136 | 137 | $ portal open 138 | 139 | Ssh 140 | --- 141 | 142 | Once the portal is opened, connect to the remote instance via ssh:: 143 | 144 | $ portal ssh 145 | 146 | For long-running tasks like training a model it is particularly useful to be able to close the current ssh session without interrupting the running task. One way of achieving this is offered by ``tmux``. "It lets you switch easily between several programs in one terminal, detach them (they keep running in the background) and reattach them to a different terminal." - tmux `wiki `_. You can run ``tmux`` within an ssh session and then run the long task within a ``tmux`` session. Portal Gun allows you to use a tmux session automatically with the ``-t`` command option. 147 | 148 | **Command options:** 149 | 150 | .. cmdoption:: -t [session], --tmux [session] 151 | 152 | Automatically open tmux session upon connection. Default session name is `portal`. 153 | 154 | Info 155 | ---- 156 | 157 | Check information about a portal:: 158 | 159 | $ portal info 160 | 161 | Information includes portal status (open or closed). If portal is open, information about the instance and attached volumes is provided.
162 | 163 | When Portal Gun is used in a shell script, it might be useful to get specific bits of information without the rest of the output. In this case use command option ``-f`` to get the value of one particular field. Supported fields are: 164 | 165 | * name - portal name; 166 | * status - portal status (open or closed); 167 | * id - instance id; 168 | * type - instance type; 169 | * user - remote user; 170 | * host - remote host; 171 | * ip - public IP of instance; 172 | * remote - user@host; 173 | * key - local ssh key file. 174 | 175 | For instance, to copy a file from the remote instance to the local machine you can use Portal Gun to look up connection details:: 176 | 177 | $ scp -i "`portal info -f key`" `portal info -f remote`:/path/to/file /local/folder/ 178 | 179 | **Command options:** 180 | 181 | .. cmdoption:: -f FIELD, --field FIELD 182 | 183 | Print value for a specified field (name, status, id, type, user, host, ip, remote, key). 184 | 185 | Close 186 | ----- 187 | 188 | To close a portal means to cancel a Spot Instance request and terminate the instance itself. Close a portal:: 189 | 190 | $ portal close 191 | 192 | ---- 193 | 194 | .. _channel_cmd: 195 | 196 | Channels 197 | ======== 198 | 199 | Channels are used to sync remote and local folders. A channel has direction, source and target folders, and other properties. Every channel belongs to a portal and should be configured in the corresponding portal specification file (see :ref:`Portal Specification ` section for details). 200 | 201 | Channel 202 | ------- 203 | 204 | Start syncing specified folders:: 205 | 206 | $ portal channel 207 | 208 | Synchronization of files over the channels is done continuously using ``rsync``. Data transfer happens every time a new file appears or an existing file is changed in the source folder. 209 | 210 | To stop synchronization press ``^C``.
# NOTE: this implementation has been extracted from https://github.com/francbartoli/marshmallow-oneofschema
# which is a fork of https://github.com/maximkulkin/marshmallow-oneofschema
# As soon as PR #13 in maximkulkin/marshmallow-oneofschema is merged, this implementation should be
# replaced by a proper marshmallow-oneofschema package


class OneOfSchema(Schema):
    """
    This is a special kind of schema that actually multiplexes other schemas
    based on object type. When serializing values, it uses get_obj_type() method
    to get object type name. Then it uses `type_schemas` name-to-Schema mapping
    to get schema for that particular object type, serializes object using that
    schema and adds an extra "type" field with name of object type.
    Deserialization is reverse.

    Example:

        class Foo(object):
            def __init__(self, foo):
                self.foo = foo

        class Bar(object):
            def __init__(self, bar):
                self.bar = bar

        class FooSchema(marshmallow.Schema):
            foo = marshmallow.fields.String(required=True)

            @marshmallow.post_load
            def make_foo(self, data):
                return Foo(**data)

        class BarSchema(marshmallow.Schema):
            bar = marshmallow.fields.Integer(required=True)

            @marshmallow.post_load
            def make_bar(self, data):
                return Bar(**data)

        class MyUberSchema(OneOfSchema):
            type_schemas = {
                'foo': FooSchema,
                'bar': BarSchema,
            }

            def get_obj_type(self, obj):
                if isinstance(obj, Foo):
                    return 'foo'
                elif isinstance(obj, Bar):
                    return 'bar'
                else:
                    raise Exception('Unknown object type: %s' % repr(obj))

        MyUberSchema().dump([Foo(foo='hello'), Bar(bar=123)], many=True)
        # => [{'type': 'foo', 'foo': 'hello'}, {'type': 'bar', 'bar': 123}]

    You can control type field name added to serialized object representation by
    setting `type_field` class property. Set `type_field_remove` to False to keep
    the type field in the data handed to the selected schema on load.
    """
    # Name of the key that carries the object type in serialized data.
    type_field = 'type'
    # Whether the type key is stripped before the data reaches the selected schema.
    type_field_remove = True
    # Mapping of type name -> Schema class or Schema instance.
    # Must be a mapping: it is queried with `.get()` in _dump() and _load().
    type_schemas = {}

    def get_obj_type(self, obj):
        """Returns name of object schema"""
        return obj.__class__.__name__

    def dump(self, obj, many=None, update_fields=True, **kwargs):
        """
        Serialize an object (or a collection of objects when `many` is true).

        :param obj: Object to serialize, or an iterable of objects
        :param many: Overrides `self.many` when not None
        :param update_fields: Forwarded to the selected schema's dump()
        :return: Serialized data (a list when `many` is true)
        :raises ValidationError: if an object cannot be serialized; in `many`
            mode messages are keyed by the index of the offending item
        """
        many = self.many if many is None else bool(many)

        if not many:
            # Errors of a single object propagate unchanged from _dump()
            return self._dump(obj, update_fields, **kwargs)

        result_data = []
        result_errors = {}
        for idx, o in enumerate(obj):
            try:
                result_data.append(self._dump(o, update_fields, **kwargs))
            except ValidationError as error:
                result_errors[idx] = error.messages
                result_data.append(error.valid_data)

        if result_errors:
            raise ValidationError(result_errors, data=obj, valid_data=result_data)

        return result_data

    def _dump(self, obj, update_fields=True, **kwargs):
        """
        Serialize a single object with the schema registered for its type.

        :raises ValidationError: if the object type is unknown or no schema
            is registered for it
        """
        obj_type = self.get_obj_type(obj)
        if not obj_type:
            # Raise instead of returning an (data, errors) tuple: dump() returns
            # whatever this method returns as the serialized result
            raise ValidationError({
                '_schema': 'Unknown object class: %s' % obj.__class__.__name__
            })

        type_schema = self.type_schemas.get(obj_type)
        if not type_schema:
            raise ValidationError({
                '_schema': 'Unsupported object type: %s' % obj_type
            })

        # Both Schema classes and ready-made instances may be registered
        schema = (
            type_schema if isinstance(type_schema, Schema)
            else type_schema()
        )

        schema.context.update(getattr(self, 'context', {}))

        result = schema.dump(
            obj, many=False, update_fields=update_fields, **kwargs
        )
        if result:
            result[self.type_field] = obj_type
        return result

    def load(self, data, many=None, partial=None):
        """
        Deserialize data (or a collection of items when `many` is true).

        :param data: Dictionary to deserialize, or an iterable of dictionaries
        :param many: Overrides `self.many` when not None
        :param partial: Overrides `self.partial` when not None
        :return: Deserialized data (a list when `many` is true)
        :raises ValidationError: if any item fails to load; messages are keyed
            by item index (0 for a single item)
        """
        result_data = []
        result_errors = {}
        many = self.many if many is None else bool(many)
        if partial is None:
            partial = self.partial

        if not many:
            try:
                result_data = self._load(data, partial=partial)
            except ValidationError as error:
                result_errors[0] = error.messages
                result_data.append(error.valid_data)
        else:
            for idx, item in enumerate(data):
                try:
                    result_data.append(self._load(item, partial=partial))
                except ValidationError as error:
                    result_errors[idx] = error.messages
                    result_data.append(error.valid_data)

        if result_errors:
            raise ValidationError(result_errors, data=data, valid_data=result_data)

        return result_data

    def _load(self, data, partial=None):
        """
        Deserialize a single dictionary with the schema selected by its type field.

        :raises ValidationError: if data is not a dict, the type field is missing
            or empty, or no schema is registered for its value
        """
        if not isinstance(data, dict):
            raise ValidationError({'_schema': 'Invalid data type: %s' % data})

        # Work on a copy, since the type field may be removed below
        data = dict(data)

        data_type = data.get(self.type_field)
        if self.type_field in data and self.type_field_remove:
            data.pop(self.type_field)

        if not data_type:
            raise ValidationError({
                self.type_field: ['Missing data for required field.']
            })

        try:
            type_schema = self.type_schemas.get(data_type)
        except TypeError:
            # data_type could be unhashable
            raise ValidationError({
                self.type_field: ['Invalid value: %s' % data_type]
            })
        if not type_schema:
            raise ValidationError({
                self.type_field: ['Unsupported value: %s' % data_type],
            })

        schema = (
            type_schema if isinstance(type_schema, Schema) else type_schema()
        )

        schema.context.update(getattr(self, 'context', {}))

        return schema.load(data, many=False, partial=partial)

    def validate(self, data, many=None, partial=None):
        """
        Validate data by attempting to load it.

        :return: Dictionary of validation messages; empty if data is valid
        """
        try:
            self.load(data, many=many, partial=partial)
        except ValidationError as ve:
            return ve.messages
        return {}
def aws_api_caller():
    """
    Create a decorator for methods that call AWS API.
    The decorated function re-raises botocore's EndpointConnectionError and
    ClientError as ProviderRequestError.
    :return: Decorator
    """
    from functools import wraps

    def api_caller_decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except EndpointConnectionError:
                raise ProviderRequestError('Could not make request to AWS.')
            except ClientError as e:
                raise ProviderRequestError(str(e))

        return wrapper

    return api_caller_decorator


class AwsClient(object):
    """
    Thin wrapper around boto3 EC2 and STS clients.
    Clients are created lazily on first use. Every API method checks the HTTP
    status code of the response and raises ProviderRequestError on failure.
    """

    def __init__(self, access_key, secret_key, region):
        self._access_key = access_key
        self._secret_key = secret_key
        self._region = region
        # boto3 clients are created on demand by ec2_client() / sts_client()
        self._ec2_client = None
        self._sts_client = None

    @aws_api_caller()
    def get_user_identity(self):
        """
        :return: Raw response of STS `get_caller_identity` call
        """
        # Call API
        response = self.sts_client().get_caller_identity()

        self._check_status_code(response)

        return response

    @aws_api_caller()
    def get_availability_zones(self):
        """
        :return: Names of zones whose state is 'available'
        :rtype: list
        """
        # Call API
        response = self.ec2_client().describe_availability_zones()

        self._check_status_code(response)

        try:
            zones = [zone['ZoneName'] for zone in response['AvailabilityZones'] if zone['State'] == 'available']
        except KeyError as e:
            # e.args[0] is the missing key; `e.message` does not exist on Python 3
            raise ProviderRequestError('Response from AWS has unexpected format: {}.'.format(e.args[0]))

        return zones

    @aws_api_caller()
    def get_subnets(self, availability_zone):
        """
        :param availability_zone: Name of availability zone
        :return: Subnets of the zone that are marked as default for it
        :rtype: list
        """
        # Define filters
        filters = [{'Name': 'availability-zone', 'Values': [availability_zone]},
                   {'Name': 'default-for-az', 'Values': ['true']}]

        # Call API
        response = self.ec2_client().describe_subnets(Filters=filters)

        self._check_status_code(response)

        try:
            subnets = response['Subnets']
        except KeyError as e:
            raise ProviderRequestError('Response from AWS has unexpected format: {}.'.format(e.args[0]))

        return subnets

    @aws_api_caller()
    def find_spot_instance(self, portal_name, user):
        """
        Find a running or pending instance tagged with given portal name and user.
        :return: Instance description or None, if nothing was found
        """
        # Define filters
        filters = [{'Name': 'tag:portal-name', 'Values': [portal_name]},
                   {'Name': 'tag:created-by', 'Values': [user]},
                   {'Name': 'instance-state-name', 'Values': ['running', 'pending']}]

        # Call API
        response = self.ec2_client().describe_instances(Filters=filters)

        self._check_status_code(response)

        return self._first_instance(response)

    @aws_api_caller()
    def get_instance(self, instance_id):
        """
        :return: Instance description or None, if nothing was found
        """
        # Call API
        response = self.ec2_client().describe_instances(InstanceIds=[instance_id])

        self._check_status_code(response)

        return self._first_instance(response)

    @aws_api_caller()
    def get_spot_fleet_instances(self, spot_fleet_request_id):
        """
        :return: 'ActiveInstances' of the spot fleet request
        """
        # Call API
        response = self.ec2_client().describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_request_id)

        self._check_status_code(response)

        return response['ActiveInstances']

    @aws_api_caller()
    def get_spot_fleet_request(self, spot_fleet_request_id):
        """
        :return: Spot fleet request config or None, if nothing was found
        """
        # Call API
        response = self.ec2_client().describe_spot_fleet_requests(SpotFleetRequestIds=[spot_fleet_request_id])

        self._check_status_code(response)

        configs = response['SpotFleetRequestConfigs']

        return configs[0] if configs else None

    @aws_api_caller()
    def get_volumes_by_id(self, volume_ids):
        """
        :param volume_ids: One or several volume Ids
        :type volume_ids: string or list
        :return: Descriptions of volumes
        """
        # Call API
        response = self.ec2_client().describe_volumes(VolumeIds=AwsClient._as_list(volume_ids))

        self._check_status_code(response)

        return response['Volumes']

    @aws_api_caller()
    def get_volumes(self, filters=None):
        """
        :param filters: Mapping of filter name to one or several values
        :type filters: dict
        :return: Descriptions of volumes
        """
        if filters is None:
            filters = {}

        # Convert list of filters to the expected format
        aws_filters = [{'Name': k, 'Values': AwsClient._as_list(v)} for k, v in filters.items()]

        # Call API
        response = self.ec2_client().describe_volumes(Filters=aws_filters)

        self._check_status_code(response)

        return response['Volumes']

    @aws_api_caller()
    def create_volume(self, size, availability_zone, tags=None, snapshot_id=None):
        """
        Create a 'gp2' volume.
        :param size: Size of volume
        :param availability_zone: Name of availability zone
        :param tags: Dictionary of tags (optional)
        :param snapshot_id: Id of snapshot to create volume from (optional)
        :return: Id of created volume
        """
        request = {'AvailabilityZone': availability_zone,
                   'Size': size,
                   'VolumeType': 'gp2'}

        # Optional parameters are sent only when they carry a value,
        # instead of an empty snapshot id or an empty tag specification
        if snapshot_id:
            request['SnapshotId'] = snapshot_id

        # Convert tags to the expected format
        aws_tags = to_aws_tags(tags if tags is not None else {})
        if aws_tags:
            request['TagSpecifications'] = [{'ResourceType': 'volume', 'Tags': aws_tags}]

        # Call API
        response = self.ec2_client().create_volume(**request)

        self._check_status_code(response)

        return response['VolumeId']

    @aws_api_caller()
    def update_volume(self, volume_id, size):
        """
        Change size of a volume.
        :return: Raw response of `modify_volume` call
        """
        # Call API
        response = self.ec2_client().modify_volume(VolumeId=volume_id,
                                                   Size=size)

        self._check_status_code(response)

        return response

    @aws_api_caller()
    def attach_volume(self, instance_id, volume_id, device):
        """
        :return: Raw response of `attach_volume` call
        """
        # Call API
        response = self.ec2_client().attach_volume(InstanceId=instance_id,
                                                   VolumeId=volume_id,
                                                   Device=device)

        self._check_status_code(response)

        return response

    @aws_api_caller()
    def delete_volume(self, volume_id):
        """
        :return: Raw response of `delete_volume` call
        """
        # Call API
        response = self.ec2_client().delete_volume(VolumeId=volume_id)

        self._check_status_code(response)

        return response

    @aws_api_caller()
    def add_tags(self, resource_ids, tags):
        """
        Add or overwrite tags for an EC2 resource (e.g. an instance or a volume).
        :param resource_ids: One or several resources to be affected
        :param tags: Dictionary of tags
        :type resource_ids: string or list
        :type tags: dict
        :return: True
        """
        # Convert tags to the expected format
        aws_tags = to_aws_tags(tags)

        # Call API
        response = self.ec2_client().create_tags(Resources=AwsClient._as_list(resource_ids), Tags=aws_tags)

        self._check_status_code(response)

        return True

    @aws_api_caller()
    def remove_tags(self, resource_ids, keys):
        """
        Remove tags for an EC2 resource (e.g. an instance or a volume).
        :param resource_ids: One or several resources to be affected
        :param keys: One or several tag keys to be removed
        :type resource_ids: string or list
        :type keys: string or list
        :return: True
        """
        aws_tags = [{'Key': key} for key in AwsClient._as_list(keys)]

        # Call API
        response = self.ec2_client().delete_tags(Resources=AwsClient._as_list(resource_ids), Tags=aws_tags)

        self._check_status_code(response)

        return True

    @aws_api_caller()
    def request_spot_fleet(self, config):
        """
        :param config: Spot fleet request config, passed as is
        :return: Raw response of `request_spot_fleet` call
        """
        # Call API
        response = self.ec2_client().request_spot_fleet(SpotFleetRequestConfig=config)

        self._check_status_code(response)

        return response

    @aws_api_caller()
    def cancel_spot_fleet_request(self, spot_fleet_request_id):
        """
        Cancel spot fleet request and terminate its instances.
        :return: True
        """
        # Call API
        response = self.ec2_client().cancel_spot_fleet_requests(SpotFleetRequestIds=[spot_fleet_request_id],
                                                                TerminateInstances=True)

        self._check_status_code(response)

        # TODO: check the response to make sure request was canceled
        return True

    def ec2_client(self):
        """
        :return: EC2 client for the configured region, created on first call
        """
        if self._ec2_client is None:
            self._ec2_client = boto3.client('ec2',
                                            self._region,
                                            aws_access_key_id=self._access_key,
                                            aws_secret_access_key=self._secret_key)

        return self._ec2_client

    def sts_client(self):
        """
        :return: STS client, created on first call
        """
        if self._sts_client is None:
            self._sts_client = boto3.client('sts',
                                            aws_access_key_id=self._access_key,
                                            aws_secret_access_key=self._secret_key)

        return self._sts_client

    @staticmethod
    def _first_instance(response):
        """
        Pick the first instance from a `describe_instances` response.
        :param response: Response with 'Reservations', each holding 'Instances'
        :return: Instance description or None, if there are no instances
        """
        for reservation in response['Reservations']:
            if reservation['Instances']:
                return reservation['Instances'][0]

        return None

    @staticmethod
    def _as_list(x):
        """
        Ensure that argument is a list
        :param x: Individual element, list or tuple
        :return: List
        :rtype: list
        """
        if isinstance(x, list):
            return x
        if isinstance(x, tuple):
            return list(x)

        return [x]

    @staticmethod
    def _check_status_code(response):
        """
        :raises ProviderRequestError: if HTTP status code of the response is not 200
        """
        status_code = response['ResponseMetadata']['HTTPStatusCode']
        if status_code != 200:
            raise ProviderRequestError('Request to AWS failed with status code {}.'.format(status_code))
67 | 68 | spot_instance . **image_id** 69 | """""""""""""""""""""""""""" 70 | 71 | *Type: string. Required.* 72 | 73 | Id of `Amazon Machine Images (AMI) `_ to be used to launch the Instance. 74 | 75 | For information about Deep Learning AMIs see :ref:`below `. 76 | 77 | spot_instance . **key_pair_name** 78 | """"""""""""""""""""""""""""""""" 79 | 80 | *Type: string. Required.* 81 | 82 | Name of `AWS EC2 Key Pair `_ to be used to connect to the Instance. The Key Pair should match the identity file. 83 | 84 | spot_instance . **identity_file** 85 | """"""""""""""""""""""""""""""""" 86 | 87 | *Type: string. Required.* 88 | 89 | Location of identity file (private key) to be used for authentication when connecting to the Instance. This file is generated by AWS upon creation of a new Key Pair. The identity file should match the Key Pair. 90 | 91 | spot_instance . **security_group_id** 92 | """"""""""""""""""""""""""""""""""""" 93 | 94 | *Type: string. Required.* 95 | 96 | Id of `AWS EC2 Security Group `_ to be associated with the Instance. Security Group controls which inbound and outbound traffic is allowed for the Instance. Make sure to at least allow inbound ssh traffic (port 22). 97 | 98 | spot_instance . **availability_zone** 99 | """"""""""""""""""""""""""""""""""""" 100 | 101 | *Type: string. Required.* 102 | 103 | Name of `AWS Availability Zone `_ in which to launch the Instance. Availability Zone has to be within the Region, specified in the :ref:`application config `. 104 | 105 | Note that Spot Instance prices might differ between Availability Zones. 106 | 107 | spot_instance . **subnet_id** 108 | """"""""""""""""""""""""""""" 109 | 110 | *Type: string. Optional.* 111 | 112 | Id of `Subnet `_ to be used for the Instance. If not specified, default Subnet of the Availability Zone is used. 113 | 114 | spot_instance . **ebs_optimized** 115 | """"""""""""""""""""""""""""""""" 116 | 117 | *Type: boolean. Optional.* 118 | 119 | Enable/disable `EBS Optimization `_. 
An EBS–optimized instance uses an optimized configuration stack and provides additional, dedicated capacity for EBS I/O. 120 | 121 | spot_instance . **remote_group** 122 | """""""""""""""""""""""""""""""" 123 | 124 | *Type: string. Required.* 125 | 126 | Default AMI user group. For images based on Ubuntu in most cases the group will be *ubuntu*. If in doubt, check AMI usage instructions. 127 | 128 | spot_instance . **remote_user** 129 | """"""""""""""""""""""""""""""" 130 | 131 | *Type: string. Required.* 132 | 133 | Default AMI username. For images based on Ubuntu in most cases the username will be *ubuntu*. If in doubt, check AMI usage instructions. 134 | 135 | spot_instance . **python_virtual_env** 136 | """""""""""""""""""""""""""""""""""""" 137 | 138 | *Type: string. Optional.* 139 | 140 | Default Python virtual environment to be used to install extra Python packages. Should be specified when *spot_instance.extra_python_packages* is specified. 141 | 142 | spot_instance . **extra_python_packages** 143 | """"""""""""""""""""""""""""""""""""""""" 144 | 145 | *Type: array of strings. Optional.* 146 | 147 | Extra Python packages to be installed in the default Python virtual environment. 148 | 149 | **spot_fleet** 150 | ^^^^^^^^^^^^^^ 151 | 152 | *Type: object. Required.* 153 | 154 | Specification of a Spot Instance Fleet. 155 | 156 | spot_fleet . **iam_fleet_role** 157 | """"""""""""""""""""""""""""""" 158 | 159 | *Type: string. Required.* 160 | 161 | IAM role that grants the Spot Fleet permission to terminate Spot Instances on your behalf when you cancel its Spot Fleet request. For instance: 162 | 163 | *arn:aws:iam::123456789012:role/aws-ec2-spot-fleet-tagging-role* 164 | 165 | where "123456789012" should be replaced by your AWS Account Id which can be found in `AWS Console `_. 166 | 167 | .. _portal_spec_volumes: 168 | 169 | **persistent_volumes** 170 | ^^^^^^^^^^^^^^^^^^^^^^ 171 | 172 | *Type: array of objects. 
Required.* 173 | 174 | Specifications of EBS volumes to be attached. Use :ref:`volume ` group of commands to manage and list volumes. 175 | 176 | **Note:** to be able to attach EBS Volumes to an Instance, they should be in the same Availability Zone. 177 | 178 | persistent_volumes[] . **volume_id** 179 | """""""""""""""""""""""""""""""""""" 180 | 181 | *Type: string. Required.* 182 | 183 | Id of EBS volume to be attached to the Instance. 184 | 185 | persistent_volumes[] . **device** 186 | """"""""""""""""""""""""""""""""" 187 | 188 | *Type: string. Required.* 189 | 190 | Name of device to represent the attached volume. For example, ``/dev/xvdf``. See `documentation `_ for details. 191 | 192 | persistent_volumes[] . **mount_point** 193 | """""""""""""""""""""""""""""""""""""" 194 | 195 | *Type: string. Required.* 196 | 197 | Mounting point within the Instance file system, where device representing the volume should be mounted. For example, ``/home/ubuntu/workspace`` (assuming that AMI username is *ubuntu*). 198 | 199 | .. _portal_spec_channels: 200 | 201 | **channels** 202 | ^^^^^^^^^^^^ 203 | 204 | *Type: array of objects. Required.* 205 | 206 | Specifications of file synchronization channels. 207 | 208 | 209 | channels[] . **direction** 210 | """""""""""""""""""""""""" 211 | 212 | *Type: string. Required.* 213 | 214 | Direction of file transfer. Expected values are "*in*" and "*out*". Inbound channel transfers files from the remote Instance to the local machine. Outbound channel transfers files from the local machine to the remote Instance. 215 | 216 | channels[] . **local_path** 217 | """"""""""""""""""""""""""" 218 | 219 | *Type: string. Required.* 220 | 221 | Local path to be used in synchronization. Note that synchronization is done via ``rsync``, therefore, similar rules regarding the trailing slash (/) in the source path are applied (see :ref:`excerpt ` of rsync help for details). 222 | 223 | channels[] . 
**remote_path** 224 | """""""""""""""""""""""""""" 225 | 226 | *Type: string. Required.* 227 | 228 | Remote path to be used in synchronization. Note that synchronization is done via ``rsync``, therefore, similar rules regarding the trailing slash (/) in the source path are applied (see :ref:`excerpt ` of rsync help for details). 229 | 230 | channels[] . **recursive** 231 | """""""""""""""""""""""""" 232 | 233 | *Type: boolean. Optional.* 234 | 235 | Enable/disable recursive synchronization. Disabled by default. 236 | 237 | channels[] . **delay** 238 | """""""""""""""""""""" 239 | 240 | *Type: float. Optional.* 241 | 242 | Delay between two consecutive synchronization attempts. Defaults to 1 second. 243 | 244 | ---- 245 | 246 | Additional Details 247 | ================== 248 | 249 | .. _deep_learning_amis: 250 | 251 | Deep Learning AMIs 252 | ^^^^^^^^^^^^^^^^^^ 253 | 254 | `Amazon Machine Images (AMI) `_ are used to create virtual machines within the AWS EC2. They capture the exact state of software environment: operating system, libraries, applications, etc. One can think of them as templates. Pre-configured AMIs can be found in `AWS Marketplace `_. Some of them are free to use, others have per hour license price depending on the set of pre-installed software. EC2 users can also create their own Images. 255 | 256 | `Deep Learning AMIs `_ - is a group of AMIs created by Amazon specifically for deep learning applications. They come pre-installed with open-source deep learning frameworks including TensorFlow, Apache MXNet, PyTorch, Chainer, Microsoft Cognitive Toolkit, Caffe, Caffe2, Theano, and Keras, optimized for high performance on Amazon EC2 instances. These AMIs are free to use, you only pay for the AWS resources needed to store and run your applications. Official documentation, guides and tutorials can be found `here `_. 257 | 258 | There are several different flavors of Deep Learning AMIs. Check the `guide `_ to know the difference between them. 
259 | 260 | In order to instruct Portal Gun to use one of the Deep Learning AMIs to create an AWS Instance you need to know its **ID**: 261 | 262 | 1. Go to `AWS Marketplace `_ and search for *"deep learning ami"*; 263 | 2. Pick an image from the search results, e.g. *Deep Learning Base AMI (Ubuntu)*; 264 | 3. On the AMI's page click **Continue to Subscribe** button; 265 | 4. On the opened page select **Manual Launch** tab; 266 | 5. In the **Launch** section you will see AMI IDs for different regions. 267 | 268 | .. _rsync_help: 269 | 270 | Rsync Help on Trailing Slash 271 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 272 | 273 | *An excerpt of* ``man rsync``: 274 | 275 | Recursively transfer all files from the directory src/bar on the machine foo into the /data/tmp/bar directory on the local machine:: 276 | 277 | $ rsync foo:src/bar /data/tmp 278 | 279 | A trailing slash on the source changes this behavior to avoid creating an additional directory level at the destination:: 280 | 281 | $ rsync foo:src/bar/ /data/tmp 282 | 283 | You can think of a trailing / on a source as meaning "copy the contents of this directory" as opposed to "copy the directory by name", but in both cases the attributes of the containing directory are transferred to the containing directory on the destination. 
In other words, each of the following commands copies the files in the same way, including their setting of the attributes of /dest/foo:: 284 | 285 | $ rsync /src/foo /dest 286 | $ rsync /src/foo/ /dest/foo 287 | -------------------------------------------------------------------------------- /portal_gun/commands/handlers/gcp_handler.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import datetime 4 | import sys 5 | import time 6 | 7 | from portal_gun.configuration.draft import generate_draft 8 | from portal_gun.configuration.schemas import PortalSchema, ComputeSchema 9 | import portal_gun.providers.gcp.helpers as gcp_helpers 10 | import portal_gun.fabric as fab 11 | from portal_gun.commands.exceptions import CommandError 12 | from portal_gun.commands.handlers.base_handler import BaseHandler 13 | from portal_gun.context_managers.print_scope import print_scope 14 | from portal_gun.context_managers.step import step 15 | from portal_gun.providers.gcp.gcp_client import GcpClient 16 | from portal_gun.providers.gcp.pretty_print import print_volume 17 | 18 | 19 | class GcpHandler(BaseHandler): 20 | def __init__(self, config): 21 | super(GcpHandler, self).__init__(config) 22 | 23 | @staticmethod 24 | def provider_name(): 25 | return 'gcp' 26 | 27 | @staticmethod 28 | def provider_long_name(): 29 | return 'Google Cloud Platform' 30 | 31 | @staticmethod 32 | def generate_portal_spec(): 33 | return generate_draft(PortalSchema(), selectors={ComputeSchema: 'gcp'}) 34 | 35 | def open_portal(self, portal_spec, portal_name): 36 | # Create GCP client 37 | gcp = self._create_client() 38 | 39 | # Define shortcuts 40 | compute_spec = portal_spec['compute'] 41 | instance_spec = compute_spec['instance'] 42 | # network_spec = compute_spec['network'] 43 | auth_spec = compute_spec['auth'] 44 | 45 | instance_name = gcp_helpers.get_instance_name(portal_spec, portal_name) 46 | 47 | with print_scope('Retrieving data from GCP:', 'Done.\n'): 48 | # 
Ensure that instance does not yet exist 49 | with step('Check already running instances', 50 | error_message='Portal `{}` seems to be already opened'.format(portal_name), 51 | catch=[RuntimeError]): 52 | instance_info = gcp.find_instance(instance_name) 53 | 54 | if instance_info is not None: 55 | raise RuntimeError('Instance is already running') 56 | # TODO: Retrieving other data from GCP 57 | 58 | # Make request for instance 59 | with print_scope('Requesting an instance:'): 60 | instance_props = gcp_helpers.build_instance_props(portal_spec, instance_name) 61 | operation = gcp.request_instance(instance_props) 62 | 63 | # Wait for instance request to be fulfilled 64 | print('Waiting for the instance to be created...') 65 | print('(usually it takes around a minute, but might take much longer)') 66 | try: 67 | elapsed_seconds = self._wait_for(operation, gcp) 68 | except KeyboardInterrupt: 69 | print('\n') 70 | print('Interrupting...') 71 | 72 | # Cancel spot instance request 73 | gcp.cancel_instance_request(instance_name) 74 | 75 | raise CommandError('Instance request has been cancelled.') 76 | print('\nInstance was created in {} seconds.\n'.format(elapsed_seconds)) 77 | 78 | # Get information about the created instance 79 | instance_info = gcp.get_instance(instance_name) 80 | 81 | public_ip = instance_info['networkInterfaces'][0]['accessConfigs'][0]['natIP'] 82 | public_dns = public_ip 83 | 84 | # Configure ssh connection via fabric 85 | fab_conn = fab.create_connection(public_dns, auth_spec['user'], auth_spec['private_ssh_key']) 86 | 87 | with print_scope('Preparing the instance:', 'Instance is ready.\n'): 88 | # Mount persistent volumes 89 | for i in range(len(portal_spec['persistent_volumes'])): 90 | with step('Mount volume #{}'.format(i), error_message='Could not mount volume', 91 | catch=[RuntimeError]): 92 | volume_spec = portal_spec['persistent_volumes'][i] 93 | 94 | # Mount volume 95 | fab.mount_volume(fab_conn, volume_spec['device'], 
volume_spec['mount_point'], 96 | auth_spec['user'], auth_spec['group']) 97 | 98 | # TODO: consider importing and executing custom fab tasks instead 99 | # Install extra python packages, if needed 100 | if 'provision_actions' in compute_spec and len(compute_spec['provision_actions']) > 0: 101 | for action_spec in compute_spec['provision_actions']: 102 | if action_spec['name'] == 'install-python-packages': 103 | virtual_env = action_spec['args']['virtual_env'] 104 | packages = action_spec['args']['packages'] 105 | with step('Install extra python packages', error_message='Could not install python packages', 106 | catch=[RuntimeError]): 107 | fab.install_python_packages(fab_conn, virtual_env, packages) 108 | elif action_spec['name'] == 'install-packages': 109 | packages = action_spec['args']['packages'] 110 | with step('Install extra packages', error_message='Could not install extra packages', 111 | catch=[RuntimeError]): 112 | fab.install_packages(fab_conn, packages) 113 | 114 | # Print summary 115 | print('Portal `{}` is now opened.'.format(portal_name)) 116 | with print_scope('Summary:', ''): 117 | with print_scope('Instance:'): 118 | print('Id: {}'.format(instance_info['id'])) 119 | print('Name: {}'.format(instance_name)) 120 | print('Type: {}'.format(instance_info['machineType'].rsplit('/', 1)[1])) 121 | print('Public IP: {}'.format(public_ip)) 122 | print('Public DNS name: {}'.format(public_dns)) 123 | with print_scope('Persistent volumes:'): 124 | for volume_spec in portal_spec['persistent_volumes']: 125 | print('{}: {}'.format(volume_spec['device'], volume_spec['mount_point'])) 126 | 127 | # Print ssh command 128 | print('Use the following command to connect to the remote machine:') 129 | print('ssh -i "{}" {}@{}'.format(auth_spec['private_ssh_key'], 130 | auth_spec['user'], 131 | public_dns)) 132 | 133 | def close_portal(self, portal_spec, portal_name): 134 | # Create GCP client 135 | gcp = self._create_client() 136 | 137 | instance_name = 
gcp_helpers.get_instance_name(portal_spec, portal_name) 138 | 139 | with print_scope('Retrieving data from GCP:', 'Done.\n'): 140 | # Get spot instance 141 | with step('Get instance', error_message='Portal `{}` does not seem to be opened'.format(portal_name), 142 | catch=[RuntimeError]): 143 | instance_info = gcp.find_instance(instance_name) 144 | 145 | if instance_info is None: 146 | raise RuntimeError('Instance is not running') 147 | 148 | # Delete instance 149 | operation = gcp.delete_instance(instance_name) 150 | 151 | # Wait for instance to be deleted 152 | print('Waiting for the instance to be deleted...') 153 | try: 154 | elapsed_seconds = self._wait_for(operation, gcp) 155 | print('Portal `{}` has been closed in {} seconds.'.format(portal_name, elapsed_seconds)) 156 | except KeyboardInterrupt: 157 | print('\n') 158 | print('Stop waiting. Instance will still be deleted eventually.') 159 | 160 | def show_portal_info(self, portal_spec, portal_name): 161 | # Create GCP client 162 | gcp = self._create_client() 163 | 164 | # Define shortcut 165 | auth_spec = portal_spec['compute']['auth'] 166 | 167 | instance_name = gcp_helpers.get_instance_name(portal_spec, portal_name) 168 | 169 | volumes = [] 170 | with print_scope('Retrieving data from GCP:', 'Done.\n'): 171 | # Get instance 172 | with step('Get instance', error_message='Portal `{}` does not seem to be opened'.format(portal_name), 173 | catch=[RuntimeError]): 174 | instance_info = gcp.find_instance(instance_name) 175 | 176 | # Get persistent volumes, if portal is opened 177 | # if instance_info is not None: 178 | # with step('Get volumes'): 179 | # # TODO: get volumes 180 | 181 | # Print status 182 | if instance_info is not None: 183 | public_ip = instance_info['networkInterfaces'][0]['accessConfigs'][0]['natIP'] 184 | public_dns = public_ip 185 | 186 | with print_scope('Summary:', ''): 187 | print('Name: {}'.format(portal_name)) 188 | print('Status: open') 189 | 190 | with print_scope('Instance:', ''): 191 | 
print('Id: {}'.format(instance_info['id'])) 192 | print('Name: {}'.format(instance_name)) 193 | print('Type: {}'.format(instance_info['machineType'].rsplit('/', 1)[1])) 194 | print('Public IP: {}'.format(public_ip)) 195 | print('Public DNS name: {}'.format(public_dns)) 196 | print('User: {}'.format(auth_spec['user'])) 197 | 198 | # TODO: print volumes' details 199 | # with print_scope('Persistent volumes:', ''): 200 | # for i in range(len(volumes)): 201 | # volume = volumes[i] 202 | # with print_scope('Volume #{}:'.format(i), ''): 203 | # self._print_volume_info(volume) 204 | 205 | # Print ssh command 206 | with print_scope('Use the following command to connect to the remote machine:'): 207 | print('ssh -i "{}" {}@{}'.format(auth_spec['private_ssh_key'], 208 | auth_spec['user'], 209 | public_dns)) 210 | else: 211 | with print_scope('Summary:'): 212 | print('Name: {}'.format(portal_name)) 213 | print('Status: close') 214 | 215 | def get_portal_info_field(self, portal_spec, portal_name, field): 216 | # Define shortcut 217 | auth_spec = portal_spec['compute']['auth'] 218 | 219 | if field == 'name': 220 | return portal_name 221 | if field == 'user': 222 | return auth_spec['user'] 223 | if field == 'key': 224 | return auth_spec['private_ssh_key'] 225 | 226 | # Create GCP client 227 | gcp = self._create_client() 228 | 229 | # Get instance 230 | instance_name = gcp_helpers.get_instance_name(portal_spec, portal_name) 231 | instance_info = gcp.find_instance(instance_name) 232 | 233 | if field == 'status': 234 | return 'open' if instance_info is not None else 'close' 235 | 236 | # If portal is closed, we cannot provide any other information 237 | if instance_info is None: 238 | return None 239 | 240 | if field == 'id': 241 | return instance_info['id'] 242 | if field == 'type': 243 | return instance_info['machineType'].rsplit('/', 1)[1] 244 | if field == 'host': 245 | return instance_info['networkInterfaces'][0]['accessConfigs'][0]['natIP'] 246 | if field == 'ip': 247 | 
return instance_info['networkInterfaces'][0]['accessConfigs'][0]['natIP'] 248 | if field == 'remote': 249 | return '{}@{}'.format(auth_spec['user'], 250 | instance_info['networkInterfaces'][0]['accessConfigs'][0]['natIP']) 251 | 252 | return None 253 | 254 | def get_ssh_params(self, portal_spec, portal_name): 255 | # Create GCP client 256 | gcp = self._create_client() 257 | 258 | # Define shortcut 259 | auth_spec = portal_spec['compute']['auth'] 260 | 261 | with print_scope('Retrieving data from GCP:', 'Done.\n'): 262 | # Get current user 263 | # with step('Get user identity'): 264 | # aws_user = aws.get_user_identity() 265 | 266 | # Get spot instance 267 | with step('Get instance', error_message='Portal `{}` does not seem to be opened'.format(portal_name)): 268 | instance_name = gcp_helpers.get_instance_name(portal_spec, portal_name) 269 | instance_info = gcp.get_instance(instance_name) 270 | if instance_info is None: 271 | raise CommandError('Portal `{}` does not seem to be opened'.format(portal_name)) 272 | 273 | public_ip = instance_info['networkInterfaces'][0]['accessConfigs'][0]['natIP'] 274 | 275 | # Return parameters for ssh 276 | return (auth_spec['private_ssh_key'], 277 | auth_spec['user'], 278 | public_ip, 279 | True) 280 | 281 | def list_volumes(self, args): 282 | # Create GCP client 283 | gcp = self._create_client() 284 | 285 | volumes = gcp.get_volumes() 286 | 287 | # Pretty print list of volumes 288 | list(map(print_volume, volumes)) 289 | 290 | def create_volume(self, args): 291 | raise NotImplementedError('Every subclass of BaseHandler should implement create_volume() method.') 292 | 293 | def update_volume(self, args): 294 | raise NotImplementedError('Every subclass of BaseHandler should implement update_volume() method.') 295 | 296 | def delete_volume(self, args): 297 | raise NotImplementedError('Every subclass of BaseHandler should implement delete_volume() method.') 298 | 299 | def _create_client(self): 300 | assert self._config 301 | 302 | 
return GcpClient(self._config['service_account_file'], self._config['project'], self._config['region']) 303 | 304 | def _wait_for(self, operation, gcp_client): 305 | begin_time = datetime.datetime.now() 306 | next_time = begin_time 307 | while True: 308 | # Repeat status request every N seconds 309 | if datetime.datetime.now() > next_time: 310 | operation = gcp_client.get_operation(operation['name']) 311 | next_time += datetime.timedelta(seconds=5) 312 | 313 | # Compute time spent waiting 314 | elapsed = datetime.datetime.now() - begin_time 315 | 316 | # Check operation status 317 | request_state = operation['status'] 318 | if request_state == 'DONE': 319 | break 320 | else: 321 | print('Elapsed {}s. Operation is {}' 322 | .format(elapsed.seconds, request_state), end='\r') 323 | 324 | sys.stdout.flush() # ensure stdout is flushed immediately. 325 | time.sleep(0.5) 326 | 327 | return (datetime.datetime.now() - begin_time).seconds 328 | -------------------------------------------------------------------------------- /portal_gun/commands/handlers/aws_handler.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import sys 3 | import time 4 | 5 | import portal_gun.providers.aws.helpers as aws_helpers 6 | import portal_gun.fabric as fab 7 | from portal_gun.commands.exceptions import CommandError 8 | from portal_gun.commands.handlers.base_handler import BaseHandler 9 | from portal_gun.configuration.draft import generate_draft 10 | from portal_gun.configuration.schemas import PortalSchema, ComputeSchema 11 | from portal_gun.context_managers.print_scope import print_scope 12 | from portal_gun.context_managers.step import step 13 | from portal_gun.providers.aws.aws_client import AwsClient 14 | from portal_gun.providers.aws.pretty_print import print_volume 15 | 16 | 17 | class AwsHandler(BaseHandler): 18 | def __init__(self, config): 19 | super(AwsHandler, self).__init__(config) 20 | 21 | self._proper_tag_key =
'dimension' 22 | self._proper_tag_value = 'C-137' 23 | self._service_tags = [self._proper_tag_key, 'created-by', 'mount-point'] 24 | self._default_size = 50 # Gb 25 | self._min_size = 1 # Gb 26 | self._max_size = 16384 # Gb 27 | 28 | @staticmethod 29 | def provider_name(): 30 | return 'aws' 31 | 32 | @staticmethod 33 | def provider_long_name(): 34 | return 'Amazon Web Services' 35 | 36 | @staticmethod 37 | def generate_portal_spec(): 38 | return generate_draft(PortalSchema(), selectors={ComputeSchema: 'aws'}) 39 | 40 | def open_portal(self, portal_spec, portal_name): 41 | # Create AWS client 42 | aws = self._create_client() 43 | 44 | # Define shortcuts 45 | compute_spec = portal_spec['compute'] 46 | instance_spec = compute_spec['instance'] 47 | network_spec = compute_spec['network'] 48 | auth_spec = compute_spec['auth'] 49 | 50 | with print_scope('Retrieving data from AWS:', 'Done.\n'): 51 | # Get current user 52 | with step('Get user identity'): 53 | user = aws.get_user_identity() 54 | 55 | # Ensure that instance does not yet exist 56 | with step('Check already running instances', 57 | error_message='Portal `{}` seems to be already opened'.format(portal_name), 58 | catch=[RuntimeError]): 59 | spot_instance = aws.find_spot_instance(portal_name, user['Arn']) 60 | 61 | if spot_instance is not None: 62 | raise RuntimeError('Instance is already running') 63 | 64 | # Ensure persistent volumes are available 65 | with step('Check volumes availability', catch=[RuntimeError]): 66 | volume_ids = [volume_spec['volume_id'] for volume_spec in portal_spec['persistent_volumes']] 67 | volumes = aws.get_volumes_by_id(volume_ids) 68 | 69 | if not all([volume['State'] == 'available' for volume in volumes]): 70 | states = ['{} is {}'.format(volume['VolumeId'], volume['State']) for volume in volumes] 71 | raise RuntimeError('Not all volumes are available: {}'.format(', '.join(states))) 72 | 73 | # If subnet Id is not provided, pick the default subnet of the availability zone 74 | if 
'subnet_id' not in network_spec or not network_spec['subnet_id']: 75 | with step('Get subnet id', catch=[IndexError, KeyError]): 76 | subnets = aws.get_subnets(instance_spec['availability_zone']) 77 | network_spec['subnet_id'] = subnets[0]['SubnetId'] 78 | 79 | # Make request for Spot instance 80 | with print_scope('Requesting a Spot instance of type {}:'.format(instance_spec['type'])): 81 | request_config = aws_helpers.single_instance_spot_fleet_request(portal_spec, portal_name, user['Arn']) 82 | response = aws.request_spot_fleet(request_config) 83 | spot_fleet_request_id = response['SpotFleetRequestId'] 84 | 85 | # Wait for spot fleet request to be fulfilled 86 | print('Waiting for the Spot instance to be created...') 87 | print('(usually it takes around a minute, but might take much longer)') 88 | begin_time = datetime.datetime.now() 89 | next_time = begin_time 90 | try: 91 | while True: 92 | # Repeat status request every N seconds 93 | if datetime.datetime.now() > next_time: 94 | spot_fleet_request = aws.get_spot_fleet_request(spot_fleet_request_id) 95 | next_time += datetime.timedelta(seconds=5) 96 | 97 | # Compute time spent waiting 98 | elapsed = datetime.datetime.now() - begin_time 99 | 100 | # Check request state and activity status 101 | request_state = spot_fleet_request['SpotFleetRequestState'] 102 | if request_state == 'active': 103 | spot_request_status = spot_fleet_request['ActivityStatus'] 104 | if spot_request_status == 'fulfilled': 105 | break 106 | else: 107 | print('Elapsed {}s. Spot request is {} and has status `{}`' 108 | .format(elapsed.seconds, request_state, spot_request_status), end='\r') 109 | else: 110 | print('Elapsed {}s. Spot request is {}'.format(elapsed.seconds, request_state), end='\r') 111 | 112 | sys.stdout.flush() # ensure stdout is flushed immediately.
113 | time.sleep(0.5) 114 | except KeyboardInterrupt: 115 | print('\n') 116 | print('Interrupting...') 117 | 118 | # Cancel spot instance request 119 | aws.cancel_spot_fleet_request(spot_fleet_request_id) 120 | 121 | raise CommandError('Spot request has been cancelled.') 122 | print('\nSpot instance was created in {} seconds.\n'.format((datetime.datetime.now() - begin_time).seconds)) 123 | 124 | # Get id of the created instance 125 | spot_fleet_instances = aws.get_spot_fleet_instances(spot_fleet_request_id) 126 | instance_id = spot_fleet_instances[0]['InstanceId'] 127 | 128 | # Get information about the created instance 129 | instance_info = aws.get_instance(instance_id) 130 | 131 | # Make requests to attach persistent volumes 132 | with print_scope('Attaching persistent volumes:'): 133 | for volume_spec in portal_spec['persistent_volumes']: 134 | response = aws.attach_volume(instance_id, volume_spec['volume_id'], volume_spec['device']) 135 | 136 | # Check status code 137 | if response['State'] not in ['attaching', 'attached']: 138 | raise CommandError('Could not attach persistent volume `{}`'.format(volume_spec['volume_id'])) 139 | 140 | # Wait for persistent volumes to be attached 141 | print('Waiting for the persistent volumes to be attached...') 142 | begin_time = datetime.datetime.now() 143 | next_time = begin_time 144 | while True: 145 | # Repeat status request every N seconds 146 | if datetime.datetime.now() > next_time: 147 | volumes = aws.get_volumes_by_id(volume_ids) 148 | next_time += datetime.timedelta(seconds=1) 149 | 150 | # Compute time spent waiting 151 | elapsed = datetime.datetime.now() - begin_time 152 | 153 | if all([volume['Attachments'][0]['State'] == 'attached' for volume in volumes]): 154 | break 155 | else: 156 | states = ['{} - `{}`'.format(volume['VolumeId'], volume['Attachments'][0]['State']) 157 | for volume in volumes] 158 | print('Elapsed {}s.
States: {}'.format(elapsed.seconds, ', '.join(states)), end='\r') 159 | 160 | sys.stdout.flush() # ensure stdout is flushed immediately. 161 | time.sleep(0.5) 162 | print('\nPersistent volumes were attached in {} seconds.\n'.format((datetime.datetime.now() - begin_time).seconds)) 163 | 164 | # Configure ssh connection via fabric 165 | fab_conn = fab.create_connection(instance_info['PublicDnsName'], auth_spec['user'], auth_spec['identity_file']) 166 | 167 | with print_scope('Preparing the instance:', 'Instance is ready.\n'): 168 | # Mount persistent volumes 169 | for i in range(len(portal_spec['persistent_volumes'])): 170 | with step('Mount volume #{}'.format(i), error_message='Could not mount volume', 171 | catch=[RuntimeError]): 172 | volume_spec = portal_spec['persistent_volumes'][i] 173 | 174 | # Mount volume 175 | fab.mount_volume(fab_conn, volume_spec['device'], volume_spec['mount_point'], 176 | auth_spec['user'], auth_spec['group']) 177 | 178 | # Store extra information in volume's tags 179 | aws.add_tags(volume_spec['volume_id'], {'mount-point': volume_spec['mount_point']}) 180 | 181 | # TODO: consider importing and executing custom fab tasks instead 182 | # Install extra python packages, if needed 183 | if 'provision_actions' in compute_spec and len(compute_spec['provision_actions']) > 0: 184 | for action_spec in compute_spec['provision_actions']: 185 | if action_spec['name'] == 'install-python-packages': 186 | virtual_env = action_spec['args']['virtual_env'] 187 | packages = action_spec['args']['packages'] 188 | with step('Install extra python packages', error_message='Could not install python packages', 189 | catch=[RuntimeError]): 190 | fab.install_python_packages(fab_conn, virtual_env, packages) 191 | 192 | # Print summary 193 | print('Portal `{}` is now opened.'.format(portal_name)) 194 | with print_scope('Summary:', ''): 195 | with print_scope('Instance:'): 196 | print('Id: {}'.format(instance_id)) 197 | print('Type: 
{}'.format(instance_info['InstanceType'])) 198 | print('Public IP: {}'.format(instance_info['PublicIpAddress'])) 199 | print('Public DNS name: {}'.format(instance_info['PublicDnsName'])) 200 | with print_scope('Persistent volumes:'): 201 | for volume_spec in portal_spec['persistent_volumes']: 202 | print('{}: {}'.format(volume_spec['device'], volume_spec['mount_point'])) 203 | 204 | # Print ssh command 205 | print('Use the following command to connect to the remote machine:') 206 | print('ssh -i "{}" {}@{}'.format(auth_spec['identity_file'], 207 | auth_spec['user'], 208 | instance_info['PublicDnsName'])) 209 | 210 | def close_portal(self, portal_spec, portal_name): 211 | # Create AWS client 212 | aws = self._create_client() 213 | 214 | with print_scope('Retrieving data from AWS:', 'Done.\n'): 215 | # Get current user 216 | with step('Get user identity'): 217 | user = aws.get_user_identity() 218 | 219 | # Get spot instance 220 | with step('Get spot instance', error_message='Portal `{}` does not seem to be opened'.format(portal_name), 221 | catch=[RuntimeError]): 222 | spot_instance = aws.find_spot_instance(portal_name, user['Arn']) 223 | 224 | if spot_instance is None: 225 | raise RuntimeError('Instance is not running') 226 | 227 | spot_fleet_request_id = \ 228 | list(filter(lambda tag: tag['Key'] == 'aws:ec2spot:fleet-request-id', spot_instance['Tags']))[0]['Value'] 229 | 230 | # Get spot fleet request 231 | with step('Get spot request', error_message='Portal `{}` does not seem to be opened'.format(portal_name), 232 | catch=[RuntimeError]): 233 | spot_fleet_request = aws.get_spot_fleet_request(spot_fleet_request_id) 234 | 235 | if spot_fleet_request is None: 236 | raise RuntimeError('Could not find spot instance request') 237 | 238 | # TODO: print fleet and instance statistics 239 | 240 | # Cancel spot instance request 241 | aws.cancel_spot_fleet_request(spot_fleet_request_id) 242 | 243 | # Clean up volumes' tags 244 | volume_ids = [volume['Ebs']['VolumeId'] 245 | for
volume in spot_instance['BlockDeviceMappings'] 246 | if not volume['Ebs']['DeleteOnTermination']] 247 | aws.remove_tags(volume_ids, 'mount-point') 248 | 249 | print('Portal `{}` has been closed.'.format(portal_name)) 250 | 251 | def show_portal_info(self, portal_spec, portal_name): 252 | # Create AWS client 253 | aws = self._create_client() 254 | 255 | # Define shortcut 256 | auth_spec = portal_spec['compute']['auth'] 257 | 258 | volumes = [] 259 | with print_scope('Retrieving data from AWS:', 'Done.\n'): 260 | # Get current user 261 | with step('Get user identity'): 262 | aws_user = aws.get_user_identity() 263 | 264 | # Get spot instance 265 | with step('Get spot instance', error_message='Portal `{}` does not seem to be opened'.format(portal_name), 266 | catch=[RuntimeError]): 267 | instance_info = aws.find_spot_instance(portal_name, aws_user['Arn']) 268 | 269 | # Get persistent volumes, if portal is opened 270 | if instance_info is not None: 271 | with step('Get volumes'): 272 | volume_ids = [volume['Ebs']['VolumeId'] 273 | for volume in instance_info['BlockDeviceMappings'] 274 | if not volume['Ebs']['DeleteOnTermination']] 275 | volumes = aws.get_volumes_by_id(volume_ids) 276 | 277 | # Print status 278 | if instance_info is not None: 279 | with print_scope('Summary:', ''): 280 | print('Name: {}'.format(portal_name)) 281 | print('Status: open') 282 | 283 | with print_scope('Instance:', ''): 284 | print('Id: {}'.format(instance_info['InstanceId'])) 285 | print('Type: {}'.format(instance_info['InstanceType'])) 286 | print('Public IP: {}'.format(instance_info['PublicIpAddress'])) 287 | print('Public DNS name: {}'.format(instance_info['PublicDnsName'])) 288 | print('User: {}'.format(auth_spec['user'])) 289 | 290 | with print_scope('Persistent volumes:', ''): 291 | for i in range(len(volumes)): 292 | volume = volumes[i] 293 | with print_scope('Volume #{}:'.format(i), ''): 294 | self._print_volume_info(volume) 295 | 296 | # Print ssh command 297 | with print_scope('Use 
the following command to connect to the remote machine:'): 298 | print('ssh -i "{}" {}@{}'.format(auth_spec['identity_file'], 299 | auth_spec['user'], 300 | instance_info['PublicDnsName'])) 301 | else: 302 | with print_scope('Summary:'): 303 | print('Name: {}'.format(portal_name)) 304 | print('Status: close') 305 | 306 | def get_portal_info_field(self, portal_spec, portal_name, field): 307 | # Define shortcut 308 | auth_spec = portal_spec['compute']['auth'] 309 | 310 | if field == 'name': 311 | return portal_name 312 | if field == 'user': 313 | return auth_spec['user'] 314 | if field == 'key': 315 | return auth_spec['identity_file'] 316 | 317 | # Create AWS client 318 | aws = self._create_client() 319 | 320 | # Get current user 321 | aws_user = aws.get_user_identity() 322 | 323 | # Get spot instance 324 | instance_info = aws.find_spot_instance(portal_name, aws_user['Arn']) 325 | 326 | if field == 'status': 327 | return 'open' if instance_info is not None else 'close' 328 | 329 | # If portal is closed, we cannot provide any other information 330 | if instance_info is None: 331 | return None 332 | 333 | if field == 'id': 334 | return instance_info['InstanceId'] 335 | if field == 'type': 336 | return instance_info['InstanceType'] 337 | if field == 'host': 338 | return instance_info['PublicDnsName'] 339 | if field == 'ip': 340 | return instance_info['PublicIpAddress'] 341 | if field == 'remote': 342 | return '{}@{}'.format(auth_spec['user'], instance_info['PublicDnsName']) 343 | 344 | return None 345 | 346 | def get_ssh_params(self, portal_spec, portal_name): 347 | # Create AWS client 348 | aws = self._create_client() 349 | 350 | # Define shortcut 351 | auth_spec = portal_spec['compute']['auth'] 352 | 353 | with print_scope('Retrieving data from AWS:', 'Done.\n'): 354 | # Get current user 355 | with step('Get user identity'): 356 | aws_user = aws.get_user_identity() 357 | 358 | # Get spot instance 359 | with step('Get spot instance', error_message='Portal `{}` does not 
seem to be opened'.format(portal_name)): 360 | instance_info = aws.find_spot_instance(portal_name, aws_user['Arn']) 361 | if instance_info is None: 362 | raise CommandError('Portal `{}` does not seem to be opened'.format(portal_name)) 363 | 364 | # Return parameters for ssh 365 | return (auth_spec['identity_file'], 366 | auth_spec['user'], 367 | instance_info['PublicDnsName'], 368 | False) 369 | 370 | def list_volumes(self, args): 371 | # Create AWS client 372 | aws = self._create_client() 373 | 374 | with print_scope('Retrieving data from AWS:', 'Done.\n'): 375 | if not args.all: 376 | # Get current user 377 | with step('Get user identity'): 378 | user = aws.get_user_identity() 379 | 380 | # Get list of volumes owned by user 381 | with step('Get list of proper volumes'): 382 | volumes = aws.get_volumes(self._get_proper_volume_filter(user)) 383 | else: 384 | # Get list of all volumes 385 | with step('Get list of volumes'): 386 | volumes = aws.get_volumes() 387 | 388 | # Filter tags of every volume 389 | volumes = (self._filter_tags(volume) for volume in volumes) 390 | 391 | # Pretty print list of volumes 392 | list(map(print_volume, volumes)) 393 | 394 | def create_volume(self, args): 395 | # Create AWS client 396 | aws = self._create_client() 397 | 398 | with print_scope('Retrieving data from AWS:', 'Done.\n'): 399 | # Get current user 400 | with step('Get user identity'): 401 | user = aws.get_user_identity() 402 | 403 | # Get the list of availability zones 404 | with step('Get Availability Zones'): 405 | availability_zones = aws.get_availability_zones() 406 | 407 | print('Creating new persistent volume.') 408 | 409 | # Get properties of the new volume 410 | name = args.name 411 | size = args.size 412 | availability_zone = args.zone 413 | snapshot_id = args.snapshot 414 | 415 | # Ask for name, if not provided 416 | if name is None: 417 | print('Enter name for the new volume (no name by default): ', end='') 418 | name = input() or None 419 | 420 | # Ask for
size, if not provided 421 | if args.size is None: 422 | print('Enter size of the new volume in Gb ({}): '.format(self._default_size), end='') 423 | size = input() or self._default_size 424 | try: 425 | size = int(size) 426 | except ValueError as e: 427 | raise CommandError('Size has to be an integer.') 428 | 429 | # Check size parameter 430 | if size < self._min_size: 431 | raise CommandError('Specified size {}Gb is smaller than the lower limit of {}Gb.' 432 | .format(size, self._min_size)) 433 | elif size > self._max_size: 434 | raise CommandError('Specified size {}Gb is bigger than the upper limit of {}Gb.' 435 | .format(size, self._max_size)) 436 | 437 | # Ask for availability zone, if not provided 438 | if availability_zone is None: 439 | print('Enter availability zone for the new volume ({}): '.format(availability_zones[0]), end='') 440 | availability_zone = input() or availability_zones[0] 441 | 442 | # Check availability zone 443 | if availability_zone not in availability_zones: 444 | raise CommandError('Unexpected availability zone "{}". Available zones are: {}.'
445 | .format(availability_zone, ', '.join(availability_zones))) 446 | 447 | # Set tags 448 | tags = {'Name': name, 'created-by': user['Arn'], self._proper_tag_key: self._proper_tag_value} 449 | 450 | # Add user-specified tags, if provided 451 | if args.tags is not None: 452 | tags.update(self._parse_tags(args.tags)) 453 | 454 | # Create volume 455 | volume_id = aws.create_volume(size, availability_zone, tags, snapshot_id) 456 | 457 | print('New persistent volume has been created.\nVolume id: {}'.format(volume_id)) 458 | 459 | def update_volume(self, args): 460 | # Create AWS client 461 | aws = self._create_client() 462 | 463 | updates = 0 464 | 465 | # Get user tags 466 | tags = self._parse_tags(args.tags) 467 | 468 | # Add 'Name' tag, if specified 469 | if args.name is not None: 470 | tags.update({'Name': args.name}) 471 | 472 | # Update tags, if specified 473 | if len(tags) > 0: 474 | aws.add_tags(args.volume_id, tags) 475 | updates += len(tags) 476 | 477 | # Update size, if specified 478 | if args.size is not None: 479 | aws.update_volume(args.volume_id, args.size) 480 | updates += 1 481 | 482 | if updates > 0: 483 | print('Volume {} is updated.'.format(args.volume_id)) 484 | else: 485 | print('Nothing to do.') 486 | 487 | def delete_volume(self, args): 488 | # Create AWS client 489 | aws = self._create_client() 490 | 491 | with print_scope('Retrieving data from AWS:', 'Done.\n'): 492 | # Get current user 493 | with step('Get user identity'): 494 | user = aws.get_user_identity() 495 | 496 | # Get details of the volume to be deleted 497 | with step('Get volume details'): 498 | volume = aws.get_volumes_by_id(args.volume_id)[0] 499 | 500 | if not self._is_proper_volume(volume, user) and not args.force: 501 | raise CommandError('Volume {} is not owned by you.
Use -f flag to force deletion.'.format(args.volume_id)) 502 | 503 | aws.delete_volume(args.volume_id) 504 | 505 | print('Volume {} is deleted.'.format(args.volume_id)) 506 | 507 | def _create_client(self): 508 | assert self._config 509 | 510 | return AwsClient(self._config['access_key'], 511 | self._config['secret_key'], self._config['region']) 512 | 513 | def _print_volume_info(self, volume): 514 | tags = volume['Tags'] if 'Tags' in volume else [] 515 | 516 | # Look for specific tags 517 | name = next((tag['Value'] for tag in tags if tag['Key'] == 'Name'), '') 518 | mount_point = next((tag['Value'] for tag in tags if tag['Key'] == 'mount-point'), 'n/a') 519 | 520 | print('Id: {}'.format(volume['VolumeId'])) 521 | print('Name: {}'.format(name)) 522 | print('Size: {}Gb'.format(volume['Size'])) 523 | print('Device: {}'.format(volume['Attachments'][0]['Device'])) 524 | print('Mount point: {}'.format(mount_point)) 525 | 526 | def _filter_tags(self, volume): 527 | if 'Tags' in volume: 528 | volume['Tags'] = [tag for tag in volume['Tags'] if tag['Key'] not in self._service_tags] 529 | 530 | return volume 531 | 532 | def _is_proper_volume(self, volume, user): 533 | try: 534 | tags = aws_helpers.from_aws_tags(volume['Tags']) 535 | return tags[self._proper_tag_key] == self._proper_tag_value and tags['created-by'] == user['Arn'] 536 | except KeyError: 537 | return False 538 | 539 | def _get_proper_volume_filter(self, user): 540 | return {'tag:{}'.format(self._proper_tag_key): self._proper_tag_value, 'tag:created-by': user['Arn']} 541 | 542 | def _parse_tags(self, tags): 543 | """ 544 | Parse tags from command line arguments. 545 | :param tags: List of tag args in 'key:value' format. 
546 | :return: Tags in dictionary format 547 | """ 548 | return {key_value[0]: key_value[1] for key_value in 549 | (tag.split(':') for tag in (tags or [])) 550 | if len(key_value) == 2 and len(key_value[0]) > 0 and len(key_value[1]) > 0} 551 | --------------------------------------------------------------------------------