├── __init__.py
├── ansible
    ├── hosts
    ├── ansible.cfg
    └── play.yml
├── utils
    ├── __init__.py
    ├── aws_spot_exception.py
    ├── pricing_util.py
    ├── az_zone.py
    └── aws_spot_instance.py
├── .gitignore
├── pip_requirements.txt
├── user_config.py
├── README.md
└── main.py

/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/ansible/hosts:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pickle
*.pyc
venv
play.retry
--------------------------------------------------------------------------------
/utils/aws_spot_exception.py:
--------------------------------------------------------------------------------

class SpotConstraintException(Exception):
    """Raised when AWS spot constraints cannot be met.

    For example, AWSSpotInstance.get_ip raises it when a spot request stops being
    'pending' without having been fulfilled; main.launch_instances catches it and
    cancels that request.
    """
    pass

--------------------------------------------------------------------------------
/ansible/ansible.cfg:
--------------------------------------------------------------------------------
[defaults]
# Disable SSH host-key verification (presumably because every spot launch is a new host).
host_key_checking = False
# Inventory file; AWSSpotInstance.add_to_ansible_hosts appends instance IPs to ansible/hosts.
# NOTE(review): `hostfile` looks like the deprecated name of `inventory` - confirm for the pinned Ansible version.
hostfile = hosts
# Login user on the instances (same value as user_config.SSH_USER_NAME's default).
remote_user = ubuntu
--------------------------------------------------------------------------------
/ansible/play.yml:
--------------------------------------------------------------------------------
---
# NOTE(review): every task below is commented out, so `tasks` is empty (null) and the
# play does nothing - confirm ansible-playbook accepts this for the pinned Ansible version.
- hosts: all
  tasks:
  # - name: update yadlt
  #   pip: name=yadlt state=latest

  # - name: install sklearn dev
  #   shell: sudo rm -rf ~/.cache/matplotlib && sudo apt-get install build-essential python-dev python-setuptools python-numpy python-scipy libatlas-dev libatlas3gf-base && sudo pip install cython && git clone "https://github.com/scikit-learn/scikit-learn.git" && cd scikit-learn && make && sudo python setup.py install

--------------------------------------------------------------------------------
/pip_requirements.txt:
--------------------------------------------------------------------------------
ansible==2.6.18
appscript==1.0.1
awscli==1.10.32
boto3==1.3.1
botocore==1.4.22
cffi==1.6.0
colorama==0.3.3
cryptography==1.3.2
docutils==0.12
enum34==1.1.6
futures==3.0.5
httplib2==0.9.2
idna==2.1
ipaddress==1.0.16
Jinja2==2.8
jmespath==0.9.0
MarkupSafe==0.23
numpy==1.11.0
paramiko==2.0.9
pyasn1==0.1.9
pycparser==2.14
pycrypto==2.6.1
python-dateutil==2.5.3
PyYAML==5.1
rsa==3.3
s3transfer==0.0.1
six==1.10.0
wheel==0.24.0

--------------------------------------------------------------------------------
/user_config.py:
--------------------------------------------------------------------------------

# =============== Default configs ==================
# Regions scanned (one EC2 call each) when building the region -> availability-zone cache.
AWS_REGIONS = ['us-east-1', 'us-west-2', 'us-west-1', 'eu-west-1', 'eu-central-1', 'ap-southeast-1',
               'ap-northeast-1', 'ap-northeast-2', 'ap-southeast-2', 'sa-east-1']
# Age in days after which az_dict.pickle is considered stale and fetched again.
AZ_PICKLE_EXPIRE_TIME_DAYS = 30
# Intended lifetime in seconds of cached spot-pricing data.
# NOTE(review): pricing_util does not currently load a spot-pricing cache, so this may have no effect.
SPOT_PRICING_PICKLE_EXPIRE_SEC = 30 * 60


# =============== Personal config ==================
# AWS credentials passed explicitly to every boto3 client. Never commit real values.
AWS_ACCESS_KEY_ID = ''
AWS_SECRET_ACCESS_KEY = ''
KEY_NAME = ''  # EC2 key pair name sent in the spot request's LaunchSpecification
SECURITY_GROUP_ID = ''  # security group id sent in the spot request's LaunchSpecification
SECURITY_GROUP = ''  # NOTE(review): not referenced anywhere else in this repository
AMI_ID = ''  # image launched for every spot instance
# main.py only ever launches INSTANCE_TYPES[0].
INSTANCE_TYPES = ['g2.2xlarge']
# Spot price sent with every request (presumably USD per instance-hour).
BID = 0.20
SSH_USER_NAME = 'ubuntu'  # user for the SSH session opened by open_ssh_term
QTY_INSTANCES = 1  # number of spot instances main.py tries to launch
SERVER_TIMEOUT = 60 * 5  # seconds to wait for the HTTP/SSH ports to open

# Post-launch steps performed by main.py for each launched instance.
WAIT_FOR_HTTP = True  # wait until port 80 accepts connections
WAIT_FOR_SSH = True  # wait until port 22 accepts connections
OPEN_IN_BROWSER = True  # open http://<ip>:80 in a browser tab
OPEN_SSH = True  # open an SSH session in Terminal (macOS only, via appscript)
ADD_TO_ANSIBLE_HOSTS = True  # append the instance IP to ansible/hosts
RUN_ANSIBLE = True  # run ansible/play.yml after the per-instance steps

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# This repository is no longer relevant given the changes to AWS spot instance pricing.

## AWS-SPOT-BOT
A tool for finding and launching the cheapest and most reliable AWS spot instances. Using an unsophisticated algorithm, it launches instances in availability zones that have a low price and a low price variance, so that your instance is less likely to be shut down by changes in demand. It is primarily intended for machine learning researchers who want to spawn GPU instances without incurring large costs.

### Usage
Edit `user_config.py` to your specifications, then run `main.py`.

### Ansible
For convenience, Ansible is integrated into this tool. This allows one to automatically run tasks on the servers after they are launched.
This saves one from needing to rebuild AMIs every time a change is required. See `user_config.py` and `main.py` for more details. Be warned that
hosts are not automatically removed from the Ansible `hosts` file.


### DISCLAIMER
This library is something I threw together for my personal use. The code is not well tested and is in no way production-worthy. Feel free to contribute.
17 | 18 | 19 | ### Requested contributions 20 | - add a check to report how many instances you currently have running 21 | - add to pypy 22 | - search the project for "todo" and improve those items 23 | 24 | 25 | ### License 26 | MIT 27 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from utils import pricing_util 4 | from utils.aws_spot_instance import AWSSpotInstance 5 | from utils.aws_spot_exception import SpotConstraintException 6 | import aws_spot_bot.user_config as uconf 7 | 8 | 9 | def launch_instances(qty): 10 | """Launches QTY instances and returns the instance objects.""" 11 | best_az = pricing_util.get_best_az() 12 | launched_instances = [] 13 | print "Best availability zone:", best_az.name 14 | 15 | for idx in range(qty): 16 | print '>> Launching instance #%s' % idx 17 | si = AWSSpotInstance(best_az.region, best_az.name, uconf.INSTANCE_TYPES[0], uconf.AMI_ID, uconf.BID) 18 | si.request_instance() 19 | try: 20 | si.get_ip() 21 | except SpotConstraintException, e: 22 | print(">> ", e.message) 23 | si.cancel_spot_request() 24 | continue 25 | launched_instances.append(si) 26 | 27 | return launched_instances 28 | 29 | 30 | if __name__ == '__main__': 31 | instances = launch_instances(uconf.QTY_INSTANCES) 32 | 33 | for si in instances: 34 | if uconf.WAIT_FOR_HTTP: 35 | si.wait_for_http() 36 | if uconf.WAIT_FOR_SSH: 37 | si.wait_for_ssh() 38 | if uconf.OPEN_IN_BROWSER: 39 | si.open_in_browser() 40 | if uconf.OPEN_SSH: 41 | si.open_ssh_term() 42 | if uconf.ADD_TO_ANSIBLE_HOSTS: 43 | si.add_to_ansible_hosts() 44 | 45 | if uconf.RUN_ANSIBLE: 46 | os.system('cd ansible && ansible-playbook -s play.yml') 47 | 48 | -------------------------------------------------------------------------------- /utils/pricing_util.py: -------------------------------------------------------------------------------- 1 | from operator import 
attrgetter 2 | import os 3 | import pickle 4 | import datetime 5 | 6 | import boto3 7 | 8 | from az_zone import AZZone 9 | import aws_spot_bot.user_config as uconf 10 | 11 | 12 | def modification_date(filename): 13 | t = os.path.getmtime(filename) 14 | return datetime.datetime.fromtimestamp(t) 15 | 16 | 17 | def generate_region_AZ_dict(): 18 | """ Generates a dict of {'region': [availability_zones, az2]} """ 19 | print "Getting all regions and AZ's... (this may take some time)" 20 | region_az = {} 21 | for region in uconf.AWS_REGIONS: 22 | client = boto3.setup_default_session(region_name=region) 23 | client = boto3.client('ec2', aws_access_key_id=uconf.AWS_ACCESS_KEY_ID, aws_secret_access_key=uconf.AWS_SECRET_ACCESS_KEY) 24 | 25 | avail_zones = [] 26 | for zone in client.describe_availability_zones()['AvailabilityZones']: 27 | if zone['State'] == 'available': 28 | avail_zones.append(zone['ZoneName']) 29 | region_az[region] = avail_zones 30 | print ">>", region 31 | 32 | return region_az 33 | 34 | 35 | def get_initialized_azs(): 36 | az_pickle_fn = "az_dict.pickle" 37 | az_objects_fn = 'az_objs_list.pickle' 38 | last_valid_AZ_time = datetime.datetime.now() - datetime.timedelta(days=uconf.AZ_PICKLE_EXPIRE_TIME_DAYS) 39 | last_valid_spot_time = datetime.datetime.now() - datetime.timedelta(seconds=uconf.SPOT_PRICING_PICKLE_EXPIRE_SEC) 40 | 41 | # Loads AZs from pickle if it exists and is less than 30 days old, else fetches them 42 | if os.path.isfile(az_pickle_fn) and modification_date(az_pickle_fn) > last_valid_AZ_time: 43 | az_dict = pickle.load(open(az_pickle_fn, "rb")) 44 | else: 45 | az_dict = generate_region_AZ_dict() 46 | pickle.dump(az_dict, open(az_pickle_fn, "wb")) 47 | 48 | # Loads AZs from pickle if it exists and is less than 30 days old, else fetches them 49 | if False and os.path.isfile(az_objects_fn) and modification_date(az_objects_fn) > last_valid_spot_time: 50 | az_objects = pickle.load(open(az_objects_fn, "rb")) 51 | else: 52 | az_objects = [] 53 | # 
Get the spot pricing for each AZ 54 | for region, azs in az_dict.iteritems(): 55 | for az in azs: 56 | az_obj = AZZone(region, az) 57 | az_objects.append(az_obj) 58 | 59 | # pickle.dump(az_objects, open(az_objects_fn, "wb")) 60 | 61 | return az_objects 62 | 63 | 64 | def get_best_az(): 65 | azs = get_initialized_azs() 66 | 67 | for az in azs: 68 | az.calculate_score(uconf.INSTANCE_TYPES, 0.65) 69 | 70 | # Sort the AZs by score and return the best one 71 | sorted_azs = sorted(azs, key=attrgetter('score')) 72 | 73 | for az in sorted_azs: 74 | print az.name 75 | print '>> price:', az.current_price 76 | print '>> mean:', az.spot_price_mean 77 | print '>> variance:', az.spot_price_variance 78 | print '>> score:', az.score 79 | 80 | return sorted_azs[-1] 81 | -------------------------------------------------------------------------------- /utils/az_zone.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import numpy as np 3 | import boto3 4 | 5 | import aws_spot_bot.user_config as uconf 6 | 7 | 8 | class AZZone(): 9 | 10 | def __init__(self, region, name): 11 | self.region = region 12 | self.name = name 13 | self.client = boto3.setup_default_session(region_name=self.region) 14 | self.client = boto3.client('ec2', aws_access_key_id=uconf.AWS_ACCESS_KEY_ID, 15 | aws_secret_access_key=uconf.AWS_SECRET_ACCESS_KEY) 16 | self.spot_pricing_history = None 17 | self.score = None 18 | 19 | @property 20 | def spot_price_variance(self): 21 | prices = [float(record['SpotPrice']) for record in self.spot_pricing_history] 22 | return np.var(prices) 23 | 24 | @property 25 | def spot_price_mean(self): 26 | prices = [float(record['SpotPrice']) for record in self.spot_pricing_history] 27 | return np.mean(prices) 28 | 29 | @property 30 | def current_price(self): 31 | if self.spot_pricing_history: 32 | return float(self.spot_pricing_history[0]['SpotPrice']) 33 | elif self.spot_pricing_history == []: 34 | return None 35 | else: 36 | 
raise Exception("You must fetch the history before calling this property") 37 | 38 | def get_spot_pricing_history(self, instance_types, product_descriptions=['Linux/UNIX']): 39 | """ Returns the spot price history given a specified AZ and region.""" 40 | print "Getting spot prices for", self.name 41 | 42 | response = self.client.describe_spot_price_history( 43 | DryRun=False, 44 | StartTime=datetime.datetime.now() - datetime.timedelta(days=7), 45 | EndTime=datetime.datetime.now(), 46 | InstanceTypes=instance_types, 47 | AvailabilityZone=self.name, 48 | ProductDescriptions=product_descriptions) 49 | 50 | self.spot_pricing_history = response.get('SpotPriceHistory', []) 51 | return response 52 | 53 | def calculate_score(self, instance_types, bid, update=False): 54 | if self.spot_pricing_history is None: 55 | self.get_spot_pricing_history(instance_types) 56 | elif update: 57 | self.get_spot_pricing_history(instance_types) 58 | 59 | # TODO: This should be removed but I am lazy and this is easier than catching exceptions 60 | # @jgre can you fix? 61 | if self.spot_pricing_history == []: 62 | return -1e10 63 | 64 | # We are not interested in this AZ if its more than the bid, so lets just return 65 | if self.current_price > bid: 66 | return 0 67 | 68 | # Here we multiply each item by a weight. 69 | # These weights are arbitrary and probably not ideal. 70 | # There is much room for improvement on this scoring algorithm, but this algorithm 71 | # works for most light use cases. Feel free to contribute! 
72 | current_price_s = bid - self.current_price 73 | variance_s = -5 * (self.spot_price_variance * self.spot_price_mean) 74 | mean_s = 0.5 * (bid - self.spot_price_mean) 75 | 76 | self.score = current_price_s + variance_s + mean_s 77 | return self.score 78 | -------------------------------------------------------------------------------- /utils/aws_spot_instance.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | import webbrowser 5 | import socket 6 | import datetime 7 | 8 | import appscript 9 | import boto3 10 | 11 | import aws_spot_bot.user_config as uconf 12 | from aws_spot_exception import SpotConstraintException 13 | 14 | class AWSSpotInstance(): 15 | 16 | def __init__(self, region, az_zone, instance_type, ami_id, bid): 17 | self.random_id = str(random.random() * 1000) 18 | self.az_zone = az_zone 19 | self.region = region 20 | self.instance_type = instance_type 21 | self.ip = None 22 | self.bid = bid 23 | self.ami_id = ami_id 24 | self.key_name = uconf.KEY_NAME 25 | self.security_group_id = uconf.SECURITY_GROUP_ID 26 | 27 | # == Boto3 related tools == 28 | boto3.setup_default_session(region_name=self.region) 29 | self.client = boto3.client('ec2', aws_access_key_id=uconf.AWS_ACCESS_KEY_ID, aws_secret_access_key=uconf.AWS_SECRET_ACCESS_KEY) 30 | session = boto3.session.Session(aws_access_key_id=uconf.AWS_ACCESS_KEY_ID, aws_secret_access_key=uconf.AWS_SECRET_ACCESS_KEY, region_name='us-east-1') 31 | self.ec2_instance = session.resource('ec2') 32 | 33 | # == Values that we need to wait to get from AWS == 34 | self.spot_instance_request_id = None 35 | self.instance_id = None 36 | self.status_code = None 37 | self.ip = None 38 | 39 | def request_instance(self): 40 | """Boots the instance on AWS""" 41 | print ">> Requesting instance" 42 | response = self.client.request_spot_instances( 43 | SpotPrice=str(self.bid), 44 | ClientToken=self.random_id, 45 | InstanceCount=1, 46 | 
Type='one-time', 47 | # ValidUntil=datetime.datetime.utcnow() + datetime.timedelta(seconds=60 * 100), 48 | LaunchSpecification={ 49 | 'ImageId': self.ami_id, 50 | 'KeyName': self.key_name, 51 | 'InstanceType': self.instance_type, 52 | 'Placement': { 53 | 'AvailabilityZone': self.az_zone, 54 | }, 55 | 'EbsOptimized': False, 56 | 'SecurityGroupIds': [ 57 | self.security_group_id 58 | ] 59 | } 60 | ) 61 | self.spot_instance_request_id = response.get('SpotInstanceRequests')[0].get('SpotInstanceRequestId') 62 | return response 63 | 64 | def get_spot_request_status(self): 65 | print ">> Checking instance status" 66 | response = self.client.describe_spot_instance_requests( 67 | SpotInstanceRequestIds=[self.spot_instance_request_id], 68 | ) 69 | self.status_code = response.get('SpotInstanceRequests')[0].get('Status').get('Code') 70 | self.instance_id = response.get('SpotInstanceRequests')[0].get('InstanceId') 71 | return {'status_code': self.status_code, 'instance_id': self.instance_id} 72 | 73 | def cancel_spot_request(self): 74 | print ">> Cancelling spot request" 75 | response = self.client.cancel_spot_instance_requests( 76 | SpotInstanceRequestIds=[self.spot_instance_request_id], 77 | ) 78 | return response 79 | 80 | def get_ip(self): 81 | if self.ip: 82 | return self.ip 83 | 84 | if not self.status_code: 85 | self.get_spot_request_status() 86 | 87 | for idx in range(100): 88 | if not self.instance_id: 89 | if 'pending' in self.status_code: 90 | time.sleep(3) 91 | self.get_spot_request_status() 92 | else: 93 | raise SpotConstraintException("Spot constraints can't be met: " + self.status_code) 94 | else: 95 | self.ip = self.ec2_instance.Instance(self.instance_id).public_ip_address 96 | break 97 | 98 | # TODO: improve this 99 | if not self.ip: 100 | raise Exception('There is no public IP address for this instance... 
Maybe the bid failed..') 101 | 102 | return self.ip 103 | 104 | def terminate(self): 105 | """Terminates the instance on AWS""" 106 | pass 107 | 108 | def open_ssh_term(self): 109 | """Opens your default terminal and starts SSH session to the instance""" 110 | # TODO. This wont work on non osx machines. 111 | appscript.app('Terminal').do_script('ssh ' + uconf.SSH_USER_NAME + '@' + self.get_ip()) 112 | 113 | def open_in_browser(self, port='80'): 114 | """Opens the instance in your browser to the specified port. 115 | Default port is Jupyter server 116 | """ 117 | webbrowser.open_new_tab('http://' + self.ip + ':' + port) 118 | 119 | def add_to_ansible_hosts(self): 120 | path = os.path.dirname(os.path.dirname(__file__)) 121 | with open(path + '/ansible/hosts', 'a') as file: 122 | file.write(str(self.ip) + '\n') 123 | 124 | def wait_for_http(self, port=80, timeout=uconf.SERVER_TIMEOUT): 125 | """Waits until port 80 is open on this instance. 126 | This is a useful way to check if the system has booted. 127 | """ 128 | self.wait_for_port(port, timeout) 129 | 130 | def wait_for_ssh(self, port=22, timeout=uconf.SERVER_TIMEOUT): 131 | """Waits until port 22 is open on this instance. 132 | This is a useful way to check if the system has booted. 133 | """ 134 | self.wait_for_port(port, timeout) 135 | 136 | def wait_for_port(self, port, timeout=uconf.SERVER_TIMEOUT): 137 | """Waits until port is open on this instance. 138 | This is a useful way to check if the system has booted and the HTTP server is running. 139 | """ 140 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 141 | sock.settimeout(timeout) 142 | start = datetime.datetime.now() 143 | print ">> waiting for port", port 144 | 145 | if not self.get_ip(): 146 | raise Exception("Error getting IP for this instance. 
Instance must have an IP before calling this method.") 147 | 148 | while True: 149 | # We need this try block because depending on the parameters the system will cause the connection 150 | # to timeout early. 151 | try: 152 | if sock.connect_ex((self.get_ip(), port)): 153 | # we got a connection, lets return 154 | return 155 | else: 156 | time.sleep(3) 157 | except: 158 | # TODO: catch the timeout exception and ignore that, but every other exception should be raised 159 | # The system timeout, no problem 160 | pass 161 | 162 | if (datetime.datetime.now() - start).seconds > timeout: 163 | print (datetime.datetime.now() - start).seconds 164 | raise Exception("Connection timed out. Try increasing the timeout amount, or fix your server.") 165 | 166 | print ">> port %s is live" % (port) 167 | 168 | if __name__ == '__main__': 169 | import pricing_util 170 | # best_az = pricing_util.get_best_az() 171 | # print best_az.region 172 | # print best_az.name 173 | region = 'us-east-1' 174 | az_zone = 'us-east-1d' 175 | instance_type = uconf.INSTANCE_TYPES[0] 176 | si = AWSSpotInstance(region, az_zone, instance_type, uconf.AMI_ID, uconf.BID) 177 | response = si.request_instance() 178 | print si.get_ip() 179 | si.wait_for_ssh() 180 | si.wait_for_http() 181 | 182 | si.open_in_browser() 183 | si.open_ssh_term() 184 | --------------------------------------------------------------------------------