├── studio ├── util │ ├── __init__.py │ ├── logs.py │ └── gpu_util.py ├── artifacts │ └── __init__.py ├── queues │ ├── __init__.py │ ├── queues_setup.py │ ├── qclient_cache.py │ └── local_queue.py ├── credentials │ └── __init__.py ├── db_providers │ ├── __init__.py │ ├── local_db_provider.py │ ├── s3_provider.py │ └── db_provider_setup.py ├── experiments │ └── __init__.py ├── optimizer_plugins │ ├── __init__.py │ └── opt_util.py ├── payload_builders │ ├── __init__.py │ ├── unencrypted_payload_builder.py │ └── payload_builder.py ├── storage │ ├── __init__.py │ ├── storage_type.py │ ├── storage_setup.py │ ├── http_storage_handler.py │ ├── storage_handler_factory.py │ ├── storage_handler.py │ └── local_storage_handler.py ├── dependencies_policies │ ├── __init__.py │ ├── dependencies_policy.py │ └── studio_dependencies_policy.py ├── scripts │ ├── studio-runs │ ├── studio-run │ ├── studio-serve │ ├── studio-ui │ ├── studio-local-worker │ ├── studio-remote-worker │ ├── studio │ ├── studio-add-credentials │ ├── ec2_worker_startup.sh │ ├── gcloud_worker_startup.sh │ ├── studio-start-remote-worker │ └── install_studio.sh ├── completion_service │ ├── __init__.py │ ├── completion_service_testfunc.py │ ├── completion_service_testfunc_files.py │ ├── completion_service_testfunc_saveload.py │ ├── completion_service_client.py │ └── encryptor.py ├── static │ ├── tfs_small.png │ └── Studio.ml-icon-std-1000.png ├── torch │ ├── __init__.py │ ├── summary_test.py │ ├── saver.py │ └── summary.py ├── templates │ ├── error.html │ ├── user_details.html │ ├── all_experiments.html │ ├── dashboard.html │ ├── project_details.html │ ├── projects.html │ └── users.html ├── __init__.py ├── appengine_config.py ├── app.yaml ├── apiserver_config.yaml ├── run_magic.py.stub ├── aws │ ├── aws_amis.yaml │ └── aws_prices.yaml ├── default_config.yaml ├── client_config.yaml ├── server_config.yaml ├── patches │ └── requests │ │ └── models.py.patch ├── rmq_config.yml ├── postgres_provider.py ├── serve.py ├── firebase_provider.py ├── ed25519_key_util.py ├── deploy_apiserver.sh ├── fs_tracker.py ├── cloud_worker_util.py ├── remote_worker.py ├── git_util.py ├── experiment_submitter.py ├── magics.py ├── cli.py └── model.py ├── .git_archival.txt ├── .gitattributes ├── logo.png ├── docs ├── logo.png ├── _static │ └── img │ │ └── logo.png ├── docker.rst ├── examples.rst ├── testing.rst ├── authentication.rst ├── ci_testing.rst ├── ec2_setup.rst ├── cli.rst ├── local_filesystem_setup.rst ├── customenv.rst ├── index.rst ├── jupyter.rst ├── containers.rst ├── faq.rst ├── README.rst └── gcloud_setup.rst ├── extra_example_requirements.txt ├── tests ├── hyperparam_hello_world.py ├── check_style.sh ├── test_array.py ├── stop_experiment.py ├── test_config_http_server.yaml ├── tf_hello_world.py ├── model_increment.py ├── conflicting_args.py ├── art_hello_world.py ├── config_http_client.yaml ├── test_config_s3_storage.yaml ├── test_config_gcloud_storage.yaml ├── test_config_http_client.yaml ├── env_detect.py ├── save_model.py ├── test_bad_config.yaml ├── runner_test.py ├── test_config_s3.yaml ├── config_http_server.yaml ├── gpu_util_test.py ├── test_config_env.yaml ├── test_config_gs.yaml ├── test_config_auth.yaml ├── test_config_datacenter.yaml ├── fs_tracker_test.py ├── test_config.yaml ├── git_util_test.py ├── hyperparam_test.py ├── util_test.py ├── model_test.py ├── serving_test.py └── http_provider_hosted_test.py ├── MANIFEST.in ├── test_requirements-cs.txt ├── test_requirements.txt ├── requirements-cs.txt ├── examples ├── pytorch │ └── README.md 
├── tensorflow │ ├── helloworld.py │ └── train_mnist.py ├── general │ ├── report_system_info.py │ ├── print_norm_linreg.py │ └── train_linreg.py └── keras │ ├── train_mnist.py │ ├── train_mnist_keras_mutligpu.py │ ├── multi_gpu.py │ ├── fashion_mnist.py │ └── train_cifar10.py ├── studioml.bib ├── studioml.bibtex ├── Dockerfile ├── .github └── dependabot.yml ├── requirements.txt ├── Dockerfile_keras_example ├── runtests.sh ├── Dockerfile_standalone_testing ├── .gitignore ├── .travis.yml └── test-runner.yaml /studio/util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /studio/artifacts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /studio/queues/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /studio/credentials/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /studio/db_providers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /studio/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /studio/optimizer_plugins/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /studio/payload_builders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /studio/storage/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /studio/dependencies_policies/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.git_archival.txt: -------------------------------------------------------------------------------- 1 | ref-names: HEAD -> master 2 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | .git_archival.txt export-subst 2 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/studioml/studio/HEAD/logo.png -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/studioml/studio/HEAD/docs/logo.png -------------------------------------------------------------------------------- /extra_example_requirements.txt: -------------------------------------------------------------------------------- 1 | 
pillow 2 | keras 3 | jupyter 4 | tensorflow 5 | 6 | -------------------------------------------------------------------------------- /studio/scripts/studio-runs: -------------------------------------------------------------------------------- 1 | #!python 2 | from studio import cli 3 | cli.main() 4 | -------------------------------------------------------------------------------- /studio/scripts/studio-run: -------------------------------------------------------------------------------- 1 | #!python 2 | from studio import runner 3 | runner.main() 4 | -------------------------------------------------------------------------------- /studio/scripts/studio-serve: -------------------------------------------------------------------------------- 1 | #!python 2 | from studio import serve 3 | serve.main() 4 | -------------------------------------------------------------------------------- /studio/completion_service/__init__.py: -------------------------------------------------------------------------------- 1 | from .completion_service import CompletionService 2 | -------------------------------------------------------------------------------- /studio/scripts/studio-ui: -------------------------------------------------------------------------------- 1 | #!python 2 | from studio import apiserver 3 | apiserver.main() 4 | -------------------------------------------------------------------------------- /docs/_static/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/studioml/studio/HEAD/docs/_static/img/logo.png -------------------------------------------------------------------------------- /studio/scripts/studio-local-worker: -------------------------------------------------------------------------------- 1 | #!python 2 | from studio import local_worker 3 | local_worker.main() 4 | -------------------------------------------------------------------------------- /studio/static/tfs_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/studioml/studio/HEAD/studio/static/tfs_small.png -------------------------------------------------------------------------------- /studio/scripts/studio-remote-worker: -------------------------------------------------------------------------------- 1 | #!python 2 | from studio import remote_worker 3 | remote_worker.main() 4 | -------------------------------------------------------------------------------- /studio/torch/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .saver import Saver, load_checkpoint, save_checkpoint 3 | from .summary import Reporter 4 | -------------------------------------------------------------------------------- /tests/hyperparam_hello_world.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | learning_rate = 0.3 4 | print(learning_rate) 5 | 6 | sys.stdout.flush() 7 | -------------------------------------------------------------------------------- /studio/static/Studio.ml-icon-std-1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/studioml/studio/HEAD/studio/static/Studio.ml-icon-std-1000.png -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.rst 2 | include 
studio/default_config.yaml 3 | include studio/templates/*.html 4 | include requirements.txt 5 | -------------------------------------------------------------------------------- /test_requirements-cs.txt: -------------------------------------------------------------------------------- 1 | pycodestyle 2 | pytest_xdist 3 | tensorflow 4 | keras 5 | certifi==2023.7.22 6 | pillow 7 | timeout_decorator 8 | 9 | -------------------------------------------------------------------------------- /test_requirements.txt: -------------------------------------------------------------------------------- 1 | pycodestyle 2 | pytest_xdist 3 | tensorflow 4 | keras 5 | certifi==2023.7.22 6 | pillow 7 | timeout_decorator 8 | 9 | -------------------------------------------------------------------------------- /docs/docker.rst: -------------------------------------------------------------------------------- 1 | To run a docker image with AWS credentials, use: 2 | sudo docker run -v ${HOME}/.aws/credentials:/root/.aws/credentials:ro standalone_testing 3 | 4 | -------------------------------------------------------------------------------- /studio/templates/error.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 |

Error

5 | {{ errormsg }} 6 | {% endblock %} 7 | -------------------------------------------------------------------------------- /studio/__init__.py: -------------------------------------------------------------------------------- 1 | from setuptools_scm import get_version 2 | 3 | try: 4 | __version__ = get_version(root='..', relative_to=__file__) 5 | except BaseException: 6 | pass 7 | -------------------------------------------------------------------------------- /requirements-cs.txt: -------------------------------------------------------------------------------- 1 | pip 2 | setuptools_scm 3 | 4 | configparser 5 | 6 | requests 7 | 8 | PyYAML 9 | pyhocon 10 | 11 | boto3 12 | 13 | filelock 14 | pika >= 1.1.0 15 | 16 | -------------------------------------------------------------------------------- /docs/examples.rst: -------------------------------------------------------------------------------- 1 | In order to try examples in studio/examples directory, 2 | be sure to install additional Python dependencies by: 3 | 4 | pip install -r extra_example_requirements.txt 5 | 6 | -------------------------------------------------------------------------------- /studio/appengine_config.py: -------------------------------------------------------------------------------- 1 | from google.appengine.ext import vendor 2 | import tempfile 3 | import subprocess 4 | 5 | tempfile.SpooledTemporaryFile = tempfile.TemporaryFile 6 | subprocess.Popen = None 7 | 8 | vendor.add('lib') 9 | -------------------------------------------------------------------------------- /tests/check_style.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | APP_ROOT=$(dirname $(dirname $(stat -fm $0))) 4 | echo $APP_ROOT 5 | pep8 --show-source --statistics $APP_ROOT 6 | # find $APP_ROOT -name "*.py" -exec pep8 --show-source {} \; 7 | -------------------------------------------------------------------------------- /tests/test_array.py: -------------------------------------------------------------------------------- 1 | from studio import fs_tracker 2 | import numpy as np 3 | 4 | try: 5 | lr = np.load(fs_tracker.get_artifact('lr')) 6 | except BaseException: 7 | lr = np.random.random(10) 8 | 9 | print("fitness: %s" % np.abs(np.sum(lr))) 10 | -------------------------------------------------------------------------------- /tests/stop_experiment.py: -------------------------------------------------------------------------------- 1 | import time 2 | from studio import logs 3 | 4 | logger = logs.get_logger('helloworld') 5 | logger.setLevel(10) 6 | 7 | i = 0 8 | while True: 9 | logger.info('{} seconds passed '.format(i)) 10 | time.sleep(1) 11 | i += 1 12 | -------------------------------------------------------------------------------- /tests/test_config_http_server.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: gs 3 | bucket: "studioml-meta" 4 | authentication: none 5 | 6 | storage: 7 | type: gcloud 8 | bucket: "studioml-artifacts" 9 | 10 | server: 11 | authentication: github 12 | 13 | -------------------------------------------------------------------------------- /tests/tf_hello_world.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | a = tf.constant([1.0, 5.0]) 4 | 5 | @tf.function 6 | def forward(x): 7 | return x + 1.0 8 | 9 | result = forward(a) 10 | assert len(result) == 2 11 | 12 | print("[ {} {} ]".format(result[0], result[1])) 13 | 
-------------------------------------------------------------------------------- /studio/storage/storage_type.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | class StorageType(Enum): 4 | storageHTTP = 1 5 | storageS3 = 2 6 | storageLocal = 3 7 | storageFirebase = 4 8 | storageDockerHub = 5 9 | storageSHub = 6 10 | 11 | storageInvalid = 99 -------------------------------------------------------------------------------- /tests/model_increment.py: -------------------------------------------------------------------------------- 1 | import six 2 | 3 | 4 | def create_model(modeldir): 5 | 6 | def model(data): 7 | retval = {} 8 | for k, v in six.iteritems(data): 9 | retval[k] = v + 1 10 | return retval 11 | 12 | return model 13 | -------------------------------------------------------------------------------- /studio/app.yaml: -------------------------------------------------------------------------------- 1 | runtime: python27 2 | # api_verison: 1 3 | threadsafe: false 4 | 5 | handlers: 6 | - url: /.* 7 | script: studio.apiserver.app 8 | 9 | libraries: 10 | - name: ssl 11 | version: latest 12 | - name: numpy 13 | version: latest 14 | 15 | -------------------------------------------------------------------------------- /studio/completion_service/completion_service_testfunc.py: -------------------------------------------------------------------------------- 1 | 2 | def clientFunction(args, files): 3 | print('client function call with args ' + 4 | str(args) + ' and files ' + str(files)) 5 | return args 6 | 7 | 8 | if __name__ == "__main__": 9 | clientFunction() 10 | -------------------------------------------------------------------------------- /tests/conflicting_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | 5 | parser = argparse.ArgumentParser(description='Test argument conflict') 6 | parser.add_argument('--experiment', '-e', help='experiment key', required=True) 7 | args = parser.parse_args() 8 | 9 | print("Experiment key = " + args.experiment) 10 | sys.stdout.flush() 11 | -------------------------------------------------------------------------------- /examples/pytorch/README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Examples for Studio 2 | 3 | General dependency: 4 | 5 | ``` 6 | pip install torch torchvision 7 | ``` 8 | 9 | ## MNIST 10 | 11 | Straightforward example for MNIST dataset, taken from `https://github.com/pytorch/examples/tree/master/mnist`: 12 | 13 | ``` 14 | studio run mnist.py 15 | ``` 16 | 17 | -------------------------------------------------------------------------------- /studio/apiserver_config.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: gs 3 | bucket: "studioml-meta" 4 | projectId: "studio-ed756" 5 | 6 | authentication: none 7 | 8 | server: 9 | authentication: github 10 | 11 | storage: 12 | type: gcloud 13 | bucket: studio-ed756.appspot.com 14 | 15 | 16 | verbose: debug 17 | 18 | 19 | -------------------------------------------------------------------------------- /tests/art_hello_world.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from studio import fs_tracker 3 | 4 | print(fs_tracker.get_artifact('f')) 5 | with open(fs_tracker.get_artifact('f'), 'r') as f: 6 | print(f.read()) 7 | 8 | if len(sys.argv) > 1: 9 | with 
open(fs_tracker.get_artifact('f'), 'w') as f: 10 | f.write(sys.argv[1]) 11 | 12 | 13 | sys.stdout.flush() 14 | -------------------------------------------------------------------------------- /examples/tensorflow/helloworld.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | import tensorflow as tf 4 | 5 | 6 | s = tf.Session() 7 | 8 | x = tf.constant([1.0, 2.0]) 9 | y = x * 2 10 | 11 | logging.basicConfig() 12 | logger = logging.getLogger('helloworld') 13 | logger.setLevel(10) 14 | 15 | while True: 16 | logger.info(s.run(y)) 17 | time.sleep(10) 18 | -------------------------------------------------------------------------------- /examples/general/report_system_info.py: -------------------------------------------------------------------------------- 1 | import psutil 2 | 3 | g = 2**30 + 0.0 4 | meminfo = psutil.virtual_memory() 5 | diskinfo = psutil.disk_usage('/') 6 | print("Cpu count = {}".format(psutil.cpu_count())) 7 | print("RAM: total {}g, free {}g".format(meminfo.total / g, meminfo.free / g)) 8 | print("HDD: total {}g, free {}g".format(diskinfo.total / g, diskinfo.free / g)) 9 | -------------------------------------------------------------------------------- /studio/dependencies_policies/dependencies_policy.py: -------------------------------------------------------------------------------- 1 | class DependencyPolicy: 2 | """ 3 | Abstract class representing some policy 4 | for generating Python packages dependencies 5 | to be used for submitted experiment. 6 | """ 7 | 8 | def generate(self, resources_needed): 9 | raise NotImplementedError('Not implemented DependencyPolicy') 10 | -------------------------------------------------------------------------------- /studio/templates/user_details.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 |

User {{ user }}

5 | 6 | {{ experimenttable(delete_button=true) }} 7 | 8 | 11 | 12 | {% endblock %} 13 | -------------------------------------------------------------------------------- /tests/config_http_client.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: http 3 | serverUrl: "http://localhost:5000" 4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8 5 | authDomain: "studio-ed756.firebaseapp.com" 6 | 7 | saveMetricsFrequency: 1m 8 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes) 9 | 10 | cloud: 11 | cpus: 1 12 | gpus: 0 13 | ram: 4g 14 | hdd: 10g 15 | -------------------------------------------------------------------------------- /studio/templates/all_experiments.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 | 5 |

All experiments

6 | 7 | {{ experimenttable() }} 8 | 9 | 14 | 15 | 16 | 17 | {% endblock %} 18 | -------------------------------------------------------------------------------- /studioml.bib: -------------------------------------------------------------------------------- 1 | @software{StudioML, 2 | author = {Zhokhov, Peter and Denissov, Andrei and Mutch, Karl}, 3 | title = {Studio}, 4 | howpublished = {\url{https://studio.ml}}, 5 | publisher={Cognizant Evolutionary AI}, 6 | version = {0.0.48}, 7 | url = {https://github.com/studioml/studio/tree/0.0.48}, 8 | year = {2017--2021}, 9 | annotate = {Studio: Simplify and expedite the model building process} 10 | } 11 | -------------------------------------------------------------------------------- /studioml.bibtex: -------------------------------------------------------------------------------- 1 | @software{StudioML, 2 | author = {Zhokhov, Peter and Denissov, Andrei and Mutch, Karl}, 3 | title = {Studio}, 4 | howpublished = {\url{https://studio.ml}}, 5 | publisher={Cognizant Evolutionary AI}, 6 | version = {0.0.48}, 7 | url = {https://github.com/studioml/studio/tree/0.0.48}, 8 | year = {2017--2021}, 9 | annotate = {Studio: Simplify and expedite the model building process} 10 | } 11 | -------------------------------------------------------------------------------- /tests/test_config_s3_storage.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: FireBase 3 | 4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8 5 | authDomain: "{}.firebaseapp.com" 6 | databaseURL: "https://{}.firebaseio.com" 7 | projectId: studio-ed756 8 | storageBucket: "{}.appspot.com" 9 | messagingSenderId: 81790704397 10 | 11 | guest: true 12 | 13 | storage: 14 | type: s3 15 | bucket: "studioml-artifacts" 16 | 17 | -------------------------------------------------------------------------------- /examples/general/print_norm_linreg.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from studio import fs_tracker 4 | import pickle 5 | 6 | 7 | weights_list = sorted( 8 | glob.glob( 9 | os.path.join( 10 | fs_tracker.get_artifact('w'), 11 | '*.pck'))) 12 | 13 | print('*****') 14 | print(weights_list[-1]) 15 | with open(weights_list[-1], 'r') as f: 16 | w = pickle.load(f) 17 | 18 | print(w.dot(w)) 19 | print('*****') 20 | -------------------------------------------------------------------------------- /tests/test_config_gcloud_storage.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: FireBase 3 | 4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8 5 | authDomain: "{}.firebaseapp.com" 6 | databaseURL: "https://{}.firebaseio.com" 7 | projectId: studio-ed756 8 | storageBucket: "{}.appspot.com" 9 | messagingSenderId: 81790704397 10 | 11 | guest: true 12 | 13 | storage: 14 | type: gcloud 15 | bucket: "studio-ed756.appspot.com" 16 | 17 | -------------------------------------------------------------------------------- /tests/test_config_http_client.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: http 3 | serverUrl: "https://studio-ed756.appspot.com" 4 | authentication: 5 | type: none 6 | 7 | log: 8 | name: output.log 9 | 10 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes) 11 | saveMetricsFrequency: 1m 12 | 13 | verbose: debug 14 | cloud: 15 | gcloud: 16 | zone: us-east1-c 17 | 18 | resources_needed: 19 | cpus: 2 20 | ram: 3g 21 | hdd: 
60g 22 | gpus: 0 23 | 24 | -------------------------------------------------------------------------------- /studio/payload_builders/unencrypted_payload_builder.py: -------------------------------------------------------------------------------- 1 | from studio.payload_builders.payload_builder import PayloadBuilder 2 | from studio.experiments.experiment import Experiment 3 | 4 | class UnencryptedPayloadBuilder(PayloadBuilder): 5 | """ 6 | Simple payload builder constructing 7 | unencrypted experiment payloads. 8 | """ 9 | def construct(self, experiment: Experiment, config, packages): 10 | return { 'experiment': experiment.to_dict(), 11 | 'config': config} 12 | -------------------------------------------------------------------------------- /tests/env_detect.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | AWSInstance = "aws" in list( 4 | yaml.load( 5 | open( 6 | "tests/test_config.yaml", 7 | "r"), Loader=yaml.SafeLoader)["cloud"].keys()) 8 | GcloudInstance = "gcloud" in list( 9 | yaml.load( 10 | open( 11 | "tests/test_config.yaml", 12 | "r"), Loader=yaml.SafeLoader)["cloud"].keys()) 13 | 14 | 15 | def on_gcp(): 16 | return GcloudInstance 17 | 18 | 19 | def on_aws(): 20 | return AWSInstance 21 | -------------------------------------------------------------------------------- /studio/payload_builders/payload_builder.py: -------------------------------------------------------------------------------- 1 | class PayloadBuilder: 2 | """ 3 | Abstract class representing 4 | payload object construction from experiment components. 5 | Result is payload ready to be submitted for execution. 6 | """ 7 | def __init__(self, name: str): 8 | self.name = name if name else 'NO NAME' 9 | 10 | def construct(self, experiment, config, packages): 11 | raise NotImplementedError( 12 | 'Not implemented for payload builder {0}'.format(self.name)) 13 | -------------------------------------------------------------------------------- /docs/testing.rst: -------------------------------------------------------------------------------- 1 | Tests can only be run from the main studio directory. 2 | Use "python -m pytest tests/{testname}" to run individual tests 3 | and "python -m pytest tests" to run all tests. 4 | 5 | Be sure to install additional Python dependencies for running tests 6 | by: 7 | 8 | pip install -r test_requirements.txt 9 | 10 | You will also have to set the cloud environment and credentials in test_config.yaml 11 | prior to running tests. 12 | Also verify that the "studio" executable is in your PATH 13 | to be able to run the local_worker test. 
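For example, to run just one of the test modules shipped in this repository (tests/util_test.py is one of them), use "python -m pytest tests/util_test.py" from the repository root.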
14 | 15 | 16 | -------------------------------------------------------------------------------- /studio/completion_service/completion_service_testfunc_files.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import six 3 | from studio import util 4 | 5 | 6 | def clientFunction(args, files): 7 | print('client function call with args ' + 8 | str(args) + ' and files ' + str(files)) 9 | 10 | cs_files = {'output', 'clientscript', 'args'} 11 | filehashes = { 12 | k: util.filehash( 13 | v, 14 | hashobj=hashlib.md5()) for k, 15 | v in six.iteritems(files) if k not in cs_files} 16 | 17 | return (args, filehashes) 18 | -------------------------------------------------------------------------------- /studio/util/logs.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | DEBUG = logging.DEBUG 4 | INFO = logging.INFO 5 | ERROR = logging.ERROR 6 | 7 | logging.basicConfig( 8 | format='%(asctime)s %(levelname)-6s %(name)s - %(message)s', 9 | level=ERROR, 10 | datefmt='%Y-%m-%d %H:%M:%S') 11 | 12 | 13 | def get_logger(name): 14 | return logging.getLogger(name) 15 | 16 | 17 | def debug(line): 18 | return logging.debug(line) 19 | 20 | 21 | def error(line): 22 | return logging.error(line) 23 | 24 | 25 | def info(line): 26 | return logging.info(line) 27 | -------------------------------------------------------------------------------- /tests/save_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from keras.layers import Dense 5 | from keras.models import Sequential 6 | 7 | from studio import fs_tracker 8 | 9 | model = Sequential() 10 | model.add(Dense(2, input_shape=(2,))) 11 | 12 | weights = model.get_weights() 13 | new_weights = [np.array([[2, 0], [0, 2]])] 14 | # print weights 15 | # new_weights = [] 16 | # for weight in weights: 17 | # new_weights.append(weight + 1) 18 | 19 | model.set_weights(new_weights) 20 | model.save(os.path.join(fs_tracker.get_model_directory(), 'weights.h5')) 21 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04 2 | 3 | # add tensorflow-gpu to use with gpu to sudo pip install 4 | # to use on linux machines with gpus 5 | RUN apt-get update && \ 6 | apt-get -y install python-pip python-dev python3-pip python3-dev python3 git wget && \ 7 | python -m pip install --upgrade pip && \ 8 | python3 -m pip install --upgrade pip 9 | 10 | COPY . /studio 11 | RUN cd studio && \ 12 | python -m pip install -e . --upgrade && \ 13 | python3 -m pip install -e . 
--upgrade 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /tests/test_bad_config.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: FireBase 3 | 4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8 5 | authDomain: "{}.firebaseapp.com" 6 | databaseURL: "https://{}.firebaseio.com" 7 | projectId: studio-e756 8 | storageBucket: "{}.appspot.com" 9 | messagingSenderId: 81790704397 10 | 11 | guest: true 12 | 13 | log: 14 | name: output.log 15 | 16 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes) 17 | saveMetricsFrequency: 1m 18 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "github-actions" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /studio/torch/summary_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from io import StringIO 3 | 4 | from studio.torch import summary 5 | 6 | 7 | class ReporterTest(unittest.TestCase): 8 | 9 | def test_summary_report(self): 10 | r = summary.Reporter(log_interval=2, smooth_interval=2) 11 | out = StringIO() 12 | r.add(0, 'k', 0.1) 13 | r.add(1, 'k', 0.2) 14 | r.report() 15 | r.add(2, 'k', 0.3) 16 | r.report(out) 17 | self.assertEqual(out.getvalue(), "Step 2: k = 0.25000") 18 | 19 | 20 | if __name__ == "__main__": 21 | unittest.main() 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pip 2 | 3 | setuptools_scm 4 | setuptools_scm_git_archive 5 | 6 | configparser 7 | 8 | numpy 9 | 10 | flask 11 | 12 | cma 13 | 14 | psutil 15 | 16 | apscheduler 17 | openssh_key_parser 18 | pycryptodome 19 | sshpubkeys 20 | PyNaCl 21 | requests 22 | requests_toolbelt 23 | python_jwt 24 | sseclient 25 | 26 | terminaltables 27 | 28 | PyYAML 29 | pyhocon 30 | 31 | google-api-core 32 | google-api-python-client 33 | google-cloud-storage 34 | google-cloud-pubsub 35 | google-auth-httplib2 36 | oauth2client==3.0.0 37 | 38 | boto3 39 | rsa 40 | 41 | filelock 42 | pika >= 1.1.0 43 | 44 | -------------------------------------------------------------------------------- /tests/runner_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from studio import model 4 | 5 | 6 | class RunnerTest(unittest.TestCase): 7 | 8 | def test_add_packages(self): 9 | 10 | list1 = ['keras==2.0.5', 'boto3==1.1.3'] 11 | list2 = ['keras==1.0.9', 'h5py==2.7.0', 'abc'] 12 | 13 | result = set(model.add_packages(list1, list2)) 14 | expected_result = set(['boto3==1.1.3', 'h5py==2.7.0', 15 | 'keras==1.0.9', 'abc']) 16 | 17 | self.assertEqual(result, expected_result) 18 | 19 | 20 | if __name__ == '__main__': 21 | 
unittest.main() 22 | -------------------------------------------------------------------------------- /Dockerfile_keras_example: -------------------------------------------------------------------------------- 1 | FROM ubuntu:16.04 2 | 3 | MAINTAINER jiamingjxu@gmail.com 4 | 5 | ENV LANG C.UTF-8 6 | 7 | RUN mkdir -p /setupTesting 8 | 9 | COPY . /setupTesting 10 | 11 | WORKDIR /setupTesting 12 | 13 | RUN apt-get update && \ 14 | apt-get install -y python-pip libpq-dev python-dev && \ 15 | apt-get install -y git && \ 16 | pip install -U pytest && \ 17 | pip install -r test_requirements.txt && \ 18 | python setup.py build && \ 19 | python setup.py install 20 | 21 | CMD studio run --lifetime=30m --max-duration=20m --gpus 4 --queue=rmq_kmutch --force-git /examples/keras/train_mnist_keras.py -------------------------------------------------------------------------------- /tests/test_config_s3.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: s3 3 | 4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8 5 | authDomain: "{}.firebaseapp.com" 6 | databaseURL: "https://{}.firebaseio.com" 7 | projectId: studio-ed756 8 | storageBucket: "{}.appspot.com" 9 | messagingSenderId: 81790704397 10 | 11 | bucket: studioml-meta 12 | guest: true 13 | max_keys: -1 14 | 15 | log: 16 | name: output.log 17 | 18 | verbose: error 19 | 20 | cloud: 21 | gcloud: 22 | zone: us-central1-f 23 | 24 | resources_needed: 25 | cpus: 2 26 | ram: 3g 27 | hdd: 10g 28 | gpus: 0 29 | 30 | -------------------------------------------------------------------------------- /tests/config_http_server.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: FireBase 3 | 4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8 5 | authDomain: "{}.firebaseapp.com" 6 | databaseURL: "https://{}.firebaseio.com" 7 | projectId: studio-ed756 8 | storageBucket: "{}.appspot.com" 9 | messagingSenderId: 81790704397 10 | serviceAccount: /Users/peter.zhokhov/gkeys/studio-ed756-firebase-adminsdk-4ug1o-7202e6402d.json 11 | 12 | storage: 13 | type: gcloud 14 | bucket: "studio-ed756.appspot.com" 15 | 16 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes) 17 | saveMetricsFrequency: 1m 18 | 19 | 20 | -------------------------------------------------------------------------------- /studio/run_magic.py.stub: -------------------------------------------------------------------------------- 1 | import __main__ as _M 2 | import pickle 3 | import six 4 | import gzip 5 | 6 | from studio import fs_tracker 7 | 8 | with gzip.open(fs_tracker.get_artifact('_ns'), 'rb') as f: 9 | _M.__dict__.update(pickle.load(f)) 10 | 11 | {script} 12 | 13 | ns_dict = _M.__dict__.copy() 14 | pickleable_ns = dict() 15 | for varname, var in six.iteritems(ns_dict): 16 | try: 17 | pickle.dumps(var) 18 | pickleable_ns[varname] = var 19 | except BaseException: 20 | pass 21 | 22 | with open(fs_tracker.get_artifact('_ns'), 'w') as f: 23 | f.write(pickle.dumps(pickleable_ns)) 24 | 25 | -------------------------------------------------------------------------------- /tests/gpu_util_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from studio.util.gpu_util import memstr2int 3 | 4 | 5 | class GpuUtilTest(unittest.TestCase): 6 | 7 | def test_memstr2int(self): 8 | self.assertEqual(memstr2int('123 Mb'), 123 * (2**20)) 9 | self.assertEqual(memstr2int('456 MiB'), 456 * (2**20)) 10 | self.assertEqual(memstr2int('23 
Gb'), 23 * (2**30)) 11 | self.assertEqual(memstr2int('23 GiB'), 23 * (2**30)) 12 | self.assertEqual(memstr2int('23 '), 23) 13 | 14 | with self.assertRaises(ValueError): 15 | memstr2int('300 spartans') 16 | 17 | 18 | if __name__ == "__main__": 19 | unittest.main() 20 | -------------------------------------------------------------------------------- /tests/test_config_env.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: FireBase 3 | 4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8 5 | authDomain: "{}.firebaseapp.com" 6 | databaseURL: "https://{}.firebaseio.com" 7 | projectId: studio-ed756 8 | storageBucket: "{}.appspot.com" 9 | messagingSenderId: 81790704397 10 | 11 | guest: true 12 | 13 | log: 14 | name: output.log 15 | 16 | test_key: $TEST_VAR1 17 | test_section: 18 | test_key: $TEST_VAR2 19 | 20 | 21 | cloud: 22 | type: google 23 | zone: us-central1-f 24 | 25 | cpus: 2 26 | ram: 3g 27 | hdd: 10g 28 | gpus: 0 29 | 30 | -------------------------------------------------------------------------------- /tests/test_config_gs.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: gs 3 | 4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8 5 | authDomain: "{}.firebaseapp.com" 6 | databaseURL: "https://{}.firebaseio.com" 7 | projectId: studio-ed756 8 | storageBucket: "{}.appspot.com" 9 | messagingSenderId: 81790704397 10 | 11 | bucket: studioml-meta 12 | guest: true 13 | max_keys: -1 14 | 15 | log: 16 | name: output.log 17 | 18 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes) 19 | verbose: error 20 | 21 | cloud: 22 | gcloud: 23 | zone: us-central1-f 24 | 25 | resources_needed: 26 | cpus: 2 27 | ram: 3g 28 | hdd: 10g 29 | gpus: 0 30 | 31 | -------------------------------------------------------------------------------- /studio/templates/dashboard.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 | 5 |

Your experiments

6 | 7 | {{ experimenttable() }} 8 | 9 | 21 | {% endblock %} 22 | -------------------------------------------------------------------------------- /tests/test_config_auth.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: FireBase 3 | 4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8 5 | authDomain: "{}.firebaseapp.com" 6 | databaseURL: "https://{}.firebaseio.com" 7 | projectId: studio-ed756 8 | storageBucket: "{}.appspot.com" 9 | messagingSenderId: 81790704397 10 | 11 | authentication: 12 | type: firebase 13 | use_email_auth: true 14 | email: authtest@sentient.ai 15 | password: HumptyDumptyS@tOnAWall 16 | authDomain: "studio-ed756.firebaseapp.com" 17 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8 18 | 19 | log: 20 | name: output.log 21 | -------------------------------------------------------------------------------- /tests/test_config_datacenter.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: http 3 | serverUrl: https://studio-sentient.appspot.com 4 | projectId: studio-sentient 5 | compression: None 6 | guest: true 7 | 8 | log: 9 | name: output.log 10 | 11 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes) 12 | saveMetricsFrequency: 1m 13 | 14 | verbose: debug 15 | cloud: 16 | gcloud: 17 | zone: us-east1-c 18 | 19 | resources_needed: 20 | cpus: 2 21 | ram: 3g 22 | hdd: 10g 23 | gpus: 0 24 | 25 | env: 26 | AWS_ACCESS_KEY_ID: $AWS_ACCESS_KEY_ID 27 | AWS_SECRET_ACCESS_KEY: $AWS_SECRET_ACCESS_KEY 28 | AWS_DEFAULT_REGION: us-west-2 29 | 30 | 31 | 32 | healthcheck: python -c "import tensorflow" 33 | 34 | -------------------------------------------------------------------------------- /studio/aws/aws_amis.yaml: -------------------------------------------------------------------------------- 1 | ubuntu16.04: 2 | # us-east-1: ami-da05a4a0 # Vanilla ubuntu 16.04 3 | us-east-1: ami-d91866a3 # studio.ml enabled 4 | # us-east-1: ami-f346c289 # aws deep learning ami 5 | 6 | us-east-2: ami-336b4456 7 | us-west-1: ami-1c1d217c 8 | # us-west-2: ami-0a00ce72 # Vanilla ubuntu 16.04 9 | us-west-2: ami-aa508ad2 # studio.ml enabled 10 | 11 | ca-central-1: ami-8a71c9ee 12 | eu-west-1: ami-add175d4 13 | eu-central-1: ami-97e953f8 14 | eu-west-2: ami-ecbea388 15 | ap-southeast-1: ami-67a6e604 16 | ap-southeast-2: ami-41c12e23 17 | ap-northeast-2: ami-7b1cb915 18 | ap-northeast-1: ami-15872773 19 | ap-south-1: ami-bc0d40d3 20 | sa-east-1: ami-466b132a 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /runtests.sh: -------------------------------------------------------------------------------- 1 | python tests/artifact_store_test.py 2 | python tests/cloud_worker_test.py 3 | python tests/fs_tracker_test.py 4 | python tests/git_util_test.py 5 | python tests/gpu_util_test.py 6 | python tests/http_provider_hosted_test.py 7 | python tests/http_provider_test.py 8 | python tests/hyperparam_test.py 9 | python tests/local_worker_test.py 10 | python tests/model_test.py 11 | python tests/model_util_test.py 12 | python tests/providers_test.py 13 | python tests/queue_test.py 14 | python tests/remote_worker_test.py 15 | python tests/runner_test.py 16 | python tests/serving_test.py 17 | python tests/util_test.py 18 | 19 | -------------------------------------------------------------------------------- /studio/templates/project_details.html: -------------------------------------------------------------------------------- 1 | 
{% extends "base.html" %} 2 | 3 | {% block content %} 4 |

Project {{ project }}

5 | 6 | {{ experimenttable(showproject=false) }} 7 | 8 | {% if allow_tensorboard %} 9 | 10 | {%- endif %} 11 | 12 | 24 | 25 | {% endblock %} 26 | -------------------------------------------------------------------------------- /studio/completion_service/completion_service_testfunc_saveload.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from studio import fs_tracker 4 | 5 | 6 | def clientFunction(args, files): 7 | print('client function call with args ' + 8 | str(args) + ' and files ' + str(files)) 9 | 10 | modelfile = 'model.dat' 11 | filename = files.get('model') or \ 12 | os.path.join(fs_tracker.get_artifact('modeldir'), modelfile) 13 | 14 | print("Trying to load file {}".format(filename)) 15 | 16 | if os.path.exists(filename): 17 | with open(filename, 'rb') as f: 18 | args = pickle.loads(f.read()) + 1 19 | 20 | else: 21 | print("Trying to write file {}".format(filename)) 22 | with open(filename, 'wb') as f: 23 | f.write(pickle.dumps(args, protocol=2)) 24 | 25 | return args 26 | 27 | 28 | if __name__ == "__main__": 29 | clientFunction('test', {}) 30 | -------------------------------------------------------------------------------- /docs/authentication.rst: -------------------------------------------------------------------------------- 1 | Authentication 2 | ============== 3 | 4 | Currently, Studio uses GitHub auth for authentication. For command-line tools 5 | (studio ui, studio run etc) the authentication is done via personal 6 | access tokens. When no token is present (e.g. when you run studio for 7 | the first time), you will be asked to input your github username and 8 | password. Studio DOES NOT store your username or password, instead, 9 | those are being sent to GitHub API server in exchange to an access token. 10 | The access token is being saved and used from that point on. 11 | The personal access tokens do not expire, can be transferred from one 12 | machine to another, and, if necessary, can be revoked by going to 13 | GitHub -> Settings -> Developer Settings -> Personal Access Tokens 14 | 15 | Authentication for the hosted UI server (https://zoo.studio.ml) follows 16 | the standard GitHub Auth flow for web apps. 
17 | -------------------------------------------------------------------------------- /studio/default_config.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: http 3 | serverUrl: https://zoo.studio.ml 4 | authentication: github 5 | 6 | storage: 7 | type: gcloud 8 | bucket: studio-ed756.appspot.com 9 | 10 | 11 | 12 | queue: local 13 | 14 | saveMetricsFrequency: 1m 15 | saveWorkspaceFrequency: 1m 16 | verbose: error 17 | 18 | cloud: 19 | gcloud: 20 | zone: us-east1-c 21 | 22 | resources_needed: 23 | cpus: 2 24 | ram: 3g 25 | hdd: 60g 26 | gpus: 0 27 | 28 | sleep_time: 1 29 | worker_timeout: 30 30 | 31 | optimizer: 32 | cmaes_config: 33 | popsize: 100 34 | sigma0: 0.25 35 | load_best_only: false 36 | load_checkpoint_file: 37 | visualization: true 38 | result_dir: "~/Desktop/" 39 | checkpoint_interval: 0 40 | termination_criterion: 41 | generation: 5 42 | fitness: 999 43 | skip_gen_thres: 1.0 44 | skip_gen_timeout: 30 45 | -------------------------------------------------------------------------------- /docs/ci_testing.rst: -------------------------------------------------------------------------------- 1 | Requirements: Docker, Dockerhub Account, Kubernetes, Keel 2 | 3 | https://docs.docker.com/install/ 4 | 5 | https://keel.sh/v1/guide/installation.html 6 | 7 | https://kubernetes.io/docs/tasks/tools/install-kubectl/ 8 | 9 | https://keel.sh/v1/guide/installation.html 10 | 11 | To run individual tests, edit the Dockerfile_standalone_testing. 12 | 13 | After editing build the image using 14 | "docker image build --tag [dockerhubUsername]/standalone_testing:latest . -f Dockerfile_standalone_testing" 15 | 16 | May have to use sudo 17 | 18 | Push the image to your docker account with 19 | 20 | "docker push [dockerhubUsername]/standalone_testing" 21 | 22 | Then to run the tests edit the test-runner.yaml:56 to 23 | 24 | "- image: [dockerhubUsername]/standalone_testing" 25 | 26 | Finally use "kubectl apply -f test-runner.yaml" to automatically run tests, 27 | 28 | results can be seen using "kubectl log test-runner-xxxxxxx-xxxxx" 29 | -------------------------------------------------------------------------------- /Dockerfile_standalone_testing: -------------------------------------------------------------------------------- 1 | FROM ubuntu:16.04 2 | 3 | MAINTAINER jiamingjxu@gmail.com 4 | 5 | ENV LANG C.UTF-8 6 | 7 | RUN mkdir -p /setupTesting 8 | 9 | COPY . 
/setupTesting 10 | 11 | WORKDIR /setupTesting 12 | 13 | RUN apt-get update && apt-get install -y \ 14 | curl 15 | 16 | RUN \ 17 | apt-get update && apt-get install -y apt-transport-https && \ 18 | curl -s https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - && \ 19 | echo "deb https://apt.kubernetes.io/ kubernetes-xenial main" | tee -a /etc/apt/sources.list.d/kubernetes.list && \ 20 | apt-get update && \ 21 | apt-get install -y kubectl 22 | 23 | RUN apt-get update && \ 24 | apt-get install -y python-pip libpq-dev python-dev && \ 25 | apt-get install -y git && \ 26 | pip install -U pytest && \ 27 | pip install -r test_requirements.txt && \ 28 | python setup.py build && \ 29 | python setup.py install 30 | 31 | CMD python -m pytest tests/util_test.py 32 | -------------------------------------------------------------------------------- /studio/client_config.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: http 3 | serverUrl: http://localhost:5000 4 | authentication: 5 | type: github 6 | 7 | storage: 8 | type: gcloud 9 | bucket: studioml-sentient-artifacts 10 | 11 | 12 | 13 | queue: local 14 | 15 | saveMetricsFrequency: 1m 16 | saveWorkspaceFrequency: 1m 17 | verbose: error 18 | 19 | cloud: 20 | gcloud: 21 | zone: us-central1-f 22 | 23 | resources_needed: 24 | cpus: 2 25 | ram: 3g 26 | hdd: 10g 27 | gpus: 0 28 | 29 | sleep_time: 1 30 | worker_timeout: 30 31 | 32 | optimizer: 33 | cmaes_config: 34 | popsize: 100 35 | sigma0: 0.25 36 | load_best_only: false 37 | load_checkpoint_file: 38 | visualization: true 39 | result_dir: "~/Desktop/" 40 | checkpoint_interval: 0 41 | termination_criterion: 42 | generation: 5 43 | fitness: 999 44 | skip_gen_thres: 1.0 45 | skip_gen_timeout: 30 46 | -------------------------------------------------------------------------------- /docs/ec2_setup.rst: -------------------------------------------------------------------------------- 1 | Setting up Amazon EC2 2 | ===================== 3 | 4 | This page describes the process of configuring Studio to work 5 | with Amazon EC2. We assume that you already have AWS credentials 6 | and an AWS account set up. 7 | 8 | Install boto3 9 | ------------- 10 | 11 | Studio interacts with AWS via the boto3 API. Thus, in order to use EC2 12 | cloud you'll need to install boto3: 13 | 14 | :: 15 | 16 | pip install boto3 17 | 18 | Set up credentials 19 | ------------------ 20 | 21 | Add credentials to a location where boto3 can access them. The 22 | recommended way is to install the AWS CLI: 23 | 24 | :: 25 | 26 | pip install awscli 27 | 28 | and then run 29 | 30 | :: 31 | 32 | aws configure 33 | 34 | and enter your AWS credentials and region. The output format cam be left as 35 | None. 
Alternatively, use any method of letting boto3 know the 36 | credentials described here: 37 | http://boto3.readthedocs.io/en/latest/guide/configuration.html 38 | -------------------------------------------------------------------------------- /tests/fs_tracker_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | 4 | from studio import fs_tracker 5 | 6 | 7 | class StudioLoggingTest(unittest.TestCase): 8 | 9 | def test_get_model_directory_args(self): 10 | experimentName = 'testExperiment' 11 | modelDir = fs_tracker.get_model_directory(experimentName) 12 | self.assertTrue( 13 | modelDir == os.path.join( 14 | os.path.expanduser('~'), 15 | '.studioml/experiments/testExperiment/modeldir')) 16 | 17 | def test_get_model_directory_noargs(self): 18 | testExperiment = 'testExperiment' 19 | testPath = os.path.join( 20 | os.path.expanduser('~'), 21 | '.studioml/experiments', 22 | testExperiment, 'modeldir') 23 | 24 | os.environ['STUDIOML_EXPERIMENT'] = testExperiment 25 | self.assertTrue(testPath == fs_tracker.get_model_directory()) 26 | 27 | 28 | if __name__ == "__main__": 29 | unittest.main() 30 | -------------------------------------------------------------------------------- /examples/general/train_linreg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import os 4 | 5 | no_samples = 100 6 | dim_samples = 5 7 | 8 | learning_rate = 0.01 9 | no_steps = 10 10 | 11 | X = np.random.random((no_samples, dim_samples)) 12 | y = np.random.random((no_samples,)) 13 | 14 | w = np.random.random((dim_samples,)) 15 | 16 | for step in range(no_steps): 17 | yhat = X.dot(w) 18 | err = (yhat - y) 19 | dw = err.dot(X) 20 | w -= learning_rate * dw 21 | loss = 0.5 * err.dot(err) 22 | 23 | print("step = {}, loss = {}, L2 norm = {}".format(step, loss, w.dot(w))) 24 | 25 | # with open(os.path.expanduser('~/weights/lr_w_{}_{}.pck' 26 | # .format(step, loss)), 'w') as f: 27 | # f.write(pickle.dumps(w)) 28 | 29 | from studio import fs_tracker 30 | with open(os.path.join(fs_tracker.get_artifact('weights'), 31 | 'lr_w_{}_{}.pck'.format(step, loss)), 32 | 'w') as f: 33 | f.write(pickle.dumps(w)) 34 | -------------------------------------------------------------------------------- /studio/server_config.yaml: -------------------------------------------------------------------------------- 1 | database: 2 | type: gs 3 | bucket: studioml-sentient-meta 4 | projectId: "studio-sentient" 5 | 6 | authentication: 7 | type: github 8 | 9 | storage: 10 | type: gcloud 11 | bucket: studioml-sentient-artifacts 12 | 13 | 14 | 15 | queue: local 16 | 17 | saveMetricsFrequency: 1m 18 | saveWorkspaceFrequency: 1m 19 | verbose: error 20 | 21 | cloud: 22 | gcloud: 23 | zone: us-central1-f 24 | 25 | resources_needed: 26 | cpus: 2 27 | ram: 3g 28 | hdd: 10g 29 | gpus: 0 30 | 31 | sleep_time: 1 32 | worker_timeout: 30 33 | 34 | optimizer: 35 | cmaes_config: 36 | popsize: 100 37 | sigma0: 0.25 38 | load_best_only: false 39 | load_checkpoint_file: 40 | visualization: true 41 | result_dir: "~/Desktop/" 42 | checkpoint_interval: 0 43 | termination_criterion: 44 | generation: 5 45 | fitness: 999 46 | skip_gen_thres: 1.0 47 | skip_gen_timeout: 30 48 | -------------------------------------------------------------------------------- /tests/test_config.yaml: -------------------------------------------------------------------------------- 1 | ## NOTE: to run unit tests in your environment, 2 | ## please provide 
your own credentials for S3 and RMQ servers access! 3 | database: 4 | type: s3 5 | authentication: none 6 | aws_access_key: ********* 7 | aws_secret_key: ************** 8 | bucket: test-database 9 | endpoint: http://******************************/ 10 | 11 | env: 12 | AWS_DEFAULT_REGION: us-west-2 13 | 14 | server: 15 | authentication: None 16 | 17 | storage: 18 | type: s3 19 | aws_access_key: ********* 20 | aws_secret_key: ************** 21 | bucket: test-storage 22 | endpoint: http://******************************/ 23 | 24 | log: 25 | name: output.log 26 | 27 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes) 28 | saveMetricsFrequency: 1m 29 | verbose: error 30 | 31 | cloud: 32 | queue: 33 | rmq: amqp://********************************** 34 | 35 | resources_needed: 36 | cpus: 2 37 | ram: 3g 38 | hdd: 60g 39 | gpus: 0 40 | 41 | -------------------------------------------------------------------------------- /studio/patches/requests/models.py.patch: -------------------------------------------------------------------------------- 1 | --- lib/requests/models.py 2018-01-03 14:10:27.000000000 -0800 2 | +++ models.py 2018-01-03 14:09:01.000000000 -0800 3 | @@ -742,8 +742,16 @@ 4 | # Special case for urllib3. 5 | if hasattr(self.raw, 'stream'): 6 | try: 7 | - for chunk in self.raw.stream(chunk_size, decode_content=True): 8 | - yield chunk 9 | + if isinstance(self.raw._original_response._method, int): 10 | + while True: 11 | + chunk = self.raw.read(chunk_size, decode_content=True) 12 | + if not chunk: 13 | + break 14 | + yield chunk 15 | + else: 16 | + for chunk in self.raw.stream(chunk_size, decode_content=True): 17 | + yield chunk 18 | + 19 | except ProtocolError as e: 20 | raise ChunkedEncodingError(e) 21 | except DecodeError as e: 22 | -------------------------------------------------------------------------------- /docs/cli.rst: -------------------------------------------------------------------------------- 1 | ====================== 2 | Command-line interface 3 | ====================== 4 | 5 | In some cases, a (semi-)programmatic way of keeping track of experiments may be preferred. On top of the Python and HTTP API, we provide 6 | a command-line tool to get a quick overview of exising experiments and take actions on them. Commands available 7 | at the moment are: 8 | 9 | - ``studio runs list users`` - lists all users 10 | - ``studio runs list projects`` - lists all projects 11 | - ``studio runs list [user]`` - lists your (default) or someone else's experiments 12 | - ``studio runs list project `` - lists all experiments in a project 13 | - ``studio runs list all`` - lists all experiments 14 | 15 | - ``studio runs kill `` - deletes experiment 16 | - ``studio runs stop `` - stops experiment 17 | 18 | Note that for now if the experiment is running, killing it will NOT automatically stop the runner. You should stop the experiment first, ensure its status has been changed to stopped, and then kill it. This is a known issue, and we are working on a solution. 
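For example, to remove an experiment whose key is ``my_experiment_key`` (a placeholder value), first run ``studio runs stop my_experiment_key``, wait until its status shows as stopped, and only then run ``studio runs kill my_experiment_key``.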
19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /studio/rmq_config.yml: -------------------------------------------------------------------------------- 1 | database: 2 | type: s3 3 | endpoint: https://s3-us-west-2.amazonaws.com/ 4 | bucket: "karl-mutch-rmq" 5 | compression: None 6 | authentication: none 7 | 8 | storage: 9 | type: s3 10 | endpoint: https://s3-us-west-2.amazonaws.com/ 11 | bucket: "karl-mutch-rmq" 12 | compression: None 13 | 14 | runner: 15 | slack_destination: "@karl.mutch" 16 | 17 | cloud: 18 | queue: 19 | rmq: "amqp://guest:guest@localhost:5672/" 20 | 21 | verbose: debug 22 | saveWorkspaceFrequency: 3m 23 | experimentLifetime: 20m 24 | 25 | resources_needed: 26 | cpus: 1 27 | gpus: 1 28 | hdd: 3gb 29 | ram: 2gb 30 | gpuMem: 3gb 31 | 32 | env: 33 | AWS_ACCESS_KEY_ID: **removed** 34 | AWS_DEFAULT_REGION: us-west-2 35 | AWS_SECRET_ACCESS_KEY: **removed** 36 | PATH: "%PATH%:./bin" 37 | 38 | pip: 39 | - keras==2.0.8 40 | 41 | optimizer: 42 | cmaes_config: 43 | popsize: 100 44 | sigma0: 0.25 45 | load_best_only: false 46 | load_checkpoint_file: 47 | visualization: true 48 | result_dir: "~/Desktop/" 49 | checkpoint_interval: 0 50 | termination_criterion: 51 | generation: 5 52 | fitness: 999 53 | skip_gen_thres: 1.0 54 | skip_gen_timeout: 30 55 | -------------------------------------------------------------------------------- /tests/git_util_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import tempfile 3 | import os 4 | import subprocess 5 | import uuid 6 | import re 7 | from studio import git_util 8 | 9 | 10 | class GitUtilTest(unittest.TestCase): 11 | 12 | def test_is_git(self): 13 | self.assertTrue(git_util.is_git()) 14 | 15 | def test_is_not_git(self): 16 | self.assertFalse(git_util.is_git(tempfile.gettempdir())) 17 | 18 | def test_is_not_clean(self): 19 | filename = str(uuid.uuid4()) 20 | subprocess.call(['touch', filename]) 21 | is_clean = git_util.is_clean() 22 | os.remove(filename) 23 | self.assertFalse(is_clean) 24 | 25 | @unittest.skipIf(os.environ.get('TEST_GIT_REPO_ADDRESS') != 1, 26 | 'skip if being tested from a forked repo') 27 | def test_repo_url(self): 28 | expected = re.compile( 29 | 'https{0,1}://github\.com/studioml/studio(\.git){0,1}') 30 | expected2 = re.compile( 31 | 'git@github\.com:studioml/studio(\.git){0,1}') 32 | actual = git_util.get_repo_url(remove_user=True) 33 | self.assertTrue( 34 | (expected.match(actual) is not None) or 35 | (expected2.match(actual) is not None)) 36 | 37 | 38 | if __name__ == "__main__": 39 | unittest.main() 40 | -------------------------------------------------------------------------------- /studio/storage/storage_setup.py: -------------------------------------------------------------------------------- 1 | from studio.storage.storage_handler import StorageHandler 2 | from studio.storage.storage_type import StorageType 3 | from studio.util.logs import INFO 4 | 5 | DB_KEY = "database" 6 | STORE_KEY = "store" 7 | 8 | # Global dictionary which keeps Database Provider 9 | # and Artifact Store objects created from experiment configuration. 
10 | _storage_setup = None 11 | 12 | _storage_verbose_level = INFO 13 | 14 | def setup_storage(db_provider, artifact_store): 15 | global _storage_setup 16 | _storage_setup = { DB_KEY: db_provider, STORE_KEY: artifact_store } 17 | 18 | def get_storage_db_provider(): 19 | global _storage_setup 20 | if _storage_setup is None: 21 | return None 22 | return _storage_setup.get(DB_KEY, None) 23 | 24 | def get_storage_artifact_store(): 25 | global _storage_setup 26 | if _storage_setup is None: 27 | return None 28 | return _storage_setup.get(STORE_KEY, None) 29 | 30 | def reset_storage(): 31 | global _storage_setup 32 | _storage_setup = None 33 | 34 | def get_storage_verbose_level(): 35 | global _storage_verbose_level 36 | return _storage_verbose_level 37 | 38 | def set_storage_verbose_level(level: int): 39 | global _storage_verbose_level 40 | _storage_verbose_level = level 41 | 42 | 43 | -------------------------------------------------------------------------------- /studio/templates/projects.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 |

Projects
5 |
6 | 7 | 38 | {% endblock %} 39 | -------------------------------------------------------------------------------- /examples/keras/train_mnist.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import tensorflow_datasets as tfds 4 | 5 | (ds_train, ds_test), ds_info = tfds.load( 6 | 'mnist', 7 | split=['train', 'test'], 8 | shuffle_files=True, 9 | as_supervised=True, 10 | with_info=True, 11 | ) 12 | 13 | 14 | def normalize_img(image, label): 15 | """Normalizes images: `uint8` -> `float32`.""" 16 | return tf.cast(image, tf.float32) / 255., label 17 | 18 | 19 | ds_train = ds_train.map( 20 | normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE) 21 | ds_train = ds_train.cache() 22 | ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples) 23 | ds_train = ds_train.batch(128) 24 | ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE) 25 | 26 | ds_test = ds_test.map( 27 | normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE) 28 | ds_test = ds_test.batch(128) 29 | ds_test = ds_test.cache() 30 | ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE) 31 | 32 | 33 | model = tf.keras.models.Sequential([ 34 | tf.keras.layers.Flatten(input_shape=(28, 28)), 35 | tf.keras.layers.Dense(128, activation='relu'), 36 | tf.keras.layers.Dense(10) 37 | ]) 38 | model.compile( 39 | optimizer=tf.keras.optimizers.Adam(0.001), 40 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 41 | metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], 42 | ) 43 | 44 | model.fit( 45 | ds_train, 46 | epochs=6, 47 | validation_data=ds_test, 48 | ) 49 | 50 | tf.saved_model.save(model, "../modeldir") 51 | -------------------------------------------------------------------------------- /docs/local_filesystem_setup.rst: -------------------------------------------------------------------------------- 1 | Setting up experiment storage and database in local filesystem 2 | ============================================================== 3 | 4 | This page describes how to setup studioml to use 5 | local filesystem for storing experiment artifacts and meta-data. 6 | With this option, there is no need to setup any external 7 | connection to S3/Minio/GCS etc. 8 | 9 | StudioML configuration 10 | -------------------- 11 | 12 | :: 13 | 14 | "studio_ml_config": { 15 | 16 | ... 17 | 18 | "database": { 19 | "type": "local", 20 | "endpoint": SOME_DB_LOCAL_PATH, 21 | "bucket": DB_BUCKET_NAME, 22 | "authentication": "none" 23 | }, 24 | "storage": { 25 | "type": "local", 26 | "endpoint": SOME_ARTIFACTS_LOCAL_PATH, 27 | "bucket": ARTIFACTS_BUCKET_NAME, 28 | } 29 | 30 | ... 31 | } 32 | 33 | 34 | With StudioML database type set to "local", 35 | all experiment meta-data will be stored locally under 36 | directory: SOME_DB_LOCAL_PATH/DB_BUCKET_NAME. 37 | Similarly, with storage type set to "local", 38 | all experiment artifacts will be stored locally under 39 | directory: SOME_ARTIFACTS_LOCAL_PATH/ARTIFACTS_BUCKET_NAME. 40 | 41 | Note: if you are using "local" mode, it is recommended to use it 42 | for both storage and database configuration. 43 | But it's technically possible to mix, for example, local storage configuration 44 | and S3-based database configuration etc. 
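As a concrete illustration, the placeholders above could be filled in as follows; the paths and bucket names here are purely hypothetical examples, not defaults:

::

    "studio_ml_config": {
        "database": {
            "type": "local",
            "endpoint": "~/studioml_meta",
            "bucket": "experiments-db",
            "authentication": "none"
        },
        "storage": {
            "type": "local",
            "endpoint": "~/studioml_artifacts",
            "bucket": "experiments-artifacts"
        }
    }

With this configuration, experiment meta-data would be written under
~/studioml_meta/experiments-db and artifacts under
~/studioml_artifacts/experiments-artifacts.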
45 | 46 | -------------------------------------------------------------------------------- /studio/aws/aws_prices.yaml: -------------------------------------------------------------------------------- 1 | t2.nano : 0.0058 2 | t2.micro : 0.0116 3 | t2.small : 0.023 4 | t2.medium : 0.0464 5 | t2.large : 0.0928 6 | t2.xlarge : 0.1856 7 | t2.2xlarge : 0.3712 8 | m4.large : 0.1 9 | m4.xlarge : 0.2 10 | m4.2xlarge : 0.4 11 | m4.4xlarge : 0.8 12 | m4.10xlarge : 2 13 | m4.16xlarge : 3.2 14 | m3.medium : 0.067 15 | m3.large : 0.133 16 | m3.xlarge : 0.266 17 | m3.2xlarge : 0.532 18 | c5.large : 0.085 19 | c5.xlarge : 0.17 20 | c5.2xlarge : 0.34 21 | c5.4xlarge : 0.68 22 | c5.9xlarge : 1.53 23 | c5.18xlarge : 3.06 24 | c4.large : 0.1 25 | c4.xlarge : 0.199 26 | c4.2xlarge : 0.398 27 | c4.4xlarge : 0.796 28 | c4.8xlarge : 1.591 29 | c3.large : 0.105 30 | c3.xlarge : 0.21 31 | c3.2xlarge : 0.42 32 | c3.4xlarge : 0.84 33 | c3.8xlarge : 1.68 34 | p2.xlarge : 0.9 35 | p2.8xlarge : 7.2 36 | p2.16xlarge : 14.4 37 | p3.2xlarge : 3.06 38 | p3.8xlarge : 12.24 39 | p3.16xlarge : 24.48 40 | g2.2xlarge : 0.65 41 | g2.8xlarge : 2.6 42 | g3.4xlarge : 1.14 43 | g3.8xlarge : 2.28 44 | g3.16xlarge : 4.56 45 | f1.2xlarge : 1.65 46 | f1.16xlarge : 13.2 47 | x1.16xlarge : 6.669 48 | x1.32xlarge : 13.338 49 | x1e.32xlarge : 26.688 50 | r3.large : 0.166 51 | r3.xlarge : 0.333 52 | r3.2xlarge : 0.665 53 | r3.4xlarge : 1.33 54 | r3.8xlarge : 2.66 55 | r4.large : 0.133 56 | r4.xlarge : 0.266 57 | r4.2xlarge : 0.532 58 | r4.4xlarge : 1.064 59 | r4.8xlarge : 2.128 60 | r4.16xlarge : 4.256 61 | i3.large : 0.156 62 | i3.xlarge : 0.312 63 | i3.2xlarge : 0.624 64 | i3.4xlarge : 1.248 65 | i3.8xlarge : 2.496 66 | i3.16xlarge : 4.992 67 | d2.xlarge : 0.69 68 | d2.2xlarge : 1.38 69 | d2.4xlarge : 2.76 70 | d2.8xlarge : 5.52 71 | -------------------------------------------------------------------------------- /studio/queues/queues_setup.py: -------------------------------------------------------------------------------- 1 | """Data providers.""" 2 | import uuid 3 | 4 | from studio.queues.local_queue import LocalQueue 5 | from studio.queues.sqs_queue import SQSQueue 6 | from studio.queues.qclient_cache import get_cached_queue, shutdown_cached_queue 7 | 8 | def get_queue( 9 | queue_name=None, 10 | cloud=None, 11 | config=None, 12 | logger=None, 13 | close_after=None, 14 | verbose=10): 15 | _ = verbose 16 | if queue_name is None: 17 | if cloud in ['gcloud', 'gcspot']: 18 | queue_name = 'pubsub_' + str(uuid.uuid4()) 19 | elif cloud in ['ec2', 'ec2spot']: 20 | queue_name = 'sqs_' + str(uuid.uuid4()) 21 | else: 22 | queue_name = 'local_' + str(uuid.uuid4()) 23 | 24 | if queue_name.startswith('ec2') or \ 25 | queue_name.startswith('sqs'): 26 | return SQSQueue(queue_name, config=config, logger=logger) 27 | if queue_name.startswith('rmq_'): 28 | return get_cached_queue( 29 | name=queue_name, 30 | route='StudioML.' 
+ queue_name, 31 | config=config, 32 | close_after=close_after, 33 | logger=logger) 34 | if queue_name.startswith('local'): 35 | return LocalQueue(queue_name, logger=logger) 36 | return None 37 | 38 | def shutdown_queue(queue, logger=None, delete_queue=True): 39 | if queue is None: 40 | return 41 | queue_name = queue.get_name() 42 | if queue_name.startswith("rmq_"): 43 | shutdown_cached_queue(queue, logger, delete_queue) 44 | else: 45 | queue.shutdown(delete_queue) 46 | -------------------------------------------------------------------------------- /studio/templates/users.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block content %} 4 |

Users
5 |
6 | 7 | 47 | {% endblock %} 48 | -------------------------------------------------------------------------------- /studio/postgres_provider.py: -------------------------------------------------------------------------------- 1 | class PostgresProvider(object): 2 | """Data provider for Postgres.""" 3 | 4 | def __init__(self, connection_uri): 5 | # TODO: implement connection 6 | pass 7 | 8 | def add_experiment(self, experiment): 9 | raise NotImplementedError() 10 | 11 | def delete_experiment(self, experiment): 12 | raise NotImplementedError() 13 | 14 | def start_experiment(self, experiment): 15 | raise NotImplementedError() 16 | 17 | def stop_experiment(self, experiment): 18 | raise NotImplementedError() 19 | 20 | def finish_experiment(self, experiment): 21 | raise NotImplementedError() 22 | 23 | def get_experiment(self, key): 24 | raise NotImplementedError() 25 | 26 | def get_user_experiments(self, user): 27 | raise NotImplementedError() 28 | 29 | def get_projects(self): 30 | raise NotImplementedError() 31 | 32 | def get_project_experiments(self): 33 | raise NotImplementedError() 34 | 35 | def get_artifacts(self): 36 | raise NotImplementedError() 37 | 38 | def get_artifact(self): 39 | raise NotImplementedError() 40 | 41 | def get_users(self): 42 | raise NotImplementedError() 43 | 44 | def checkpoint_experiment(self, experiment): 45 | raise NotImplementedError() 46 | 47 | def refresh_auth_token(self, email, refresh_token): 48 | raise NotImplementedError() 49 | 50 | def is_auth_expired(self): 51 | raise NotImplementedError() 52 | 53 | def can_write_experiment(self, key=None, user=None): 54 | raise NotImplementedError() 55 | 56 | def register_user(self, userid, email): 57 | raise NotImplementedError() 58 | -------------------------------------------------------------------------------- /studio/optimizer_plugins/opt_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import sys 5 | import numpy as np 6 | 7 | try: 8 | import matplotlib 9 | matplotlib.use("Agg") 10 | import matplotlib.pyplot as plt 11 | except BaseException: 12 | pass 13 | 14 | EPSILON = 1e-12 15 | 16 | 17 | def scale_var(var, min_range, max_range): 18 | return (var - min_range) / max((max_range - min_range), EPSILON) 19 | 20 | 21 | def unscale_var(var, min_range, max_range): 22 | return (var * (max_range - min_range)) + min_range 23 | 24 | 25 | def visualize_fitness(fitness_file=None, best_fitnesses=None, 26 | mean_fitnesses=None, outfile="fitness.png"): 27 | if best_fitnesses is None or mean_fitnesses is None: 28 | assert os.path.exists(fitness_file) 29 | best_fitnesses = [] 30 | mean_fitnesses = [] 31 | with open(fitness_file) as f: 32 | for line in f.readlines(): 33 | best_fit, mean_fit = [float(x) for x in line.rstrip().split()] 34 | best_fitnesses.append(best_fit) 35 | mean_fitnesses.append(mean_fit) 36 | 37 | plt.figure(figsize=(16, 12)) 38 | plt.plot(np.arange(len(best_fitnesses)), best_fitnesses, 39 | label="Best Fitness") 40 | plt.plot(np.arange(len(mean_fitnesses)), mean_fitnesses, 41 | label="Mean Fitness") 42 | plt.xlabel("Generation") 43 | plt.ylabel("Fitness") 44 | plt.grid() 45 | plt.legend(loc='lower right') 46 | 47 | outfile = os.path.abspath(os.path.expanduser(outfile)) 48 | plt.savefig(outfile, bbox_inches='tight') 49 | 50 | 51 | if __name__ == "__main__": 52 | func = eval(sys.argv[1]) 53 | func(*sys.argv[2:]) 54 | -------------------------------------------------------------------------------- /studio/scripts/studio: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | function print_help { 5 | echo "List of allowed commands:" 6 | echo " " 7 | echo " studio ui -- launches WebUI" 8 | echo " studio run -- runs an experiment from a python script" 9 | echo " studio runs -- CLI to view and manipulate existing experiments" 10 | echo " studio start remote worker -- orchestrates (re)-starts of worker that listens to the remote queue in a docker container" 11 | echo " studio remote worker -- start a worker that listens to remote queue" 12 | echo " studio add credentials -- create a docker image with credentials baked in" 13 | echo " " 14 | echo "Type --help for a list of options for each command" 15 | } 16 | 17 | function parse_command { 18 | 19 | expected_command=$1 20 | shift 21 | cmd="studio" 22 | for w in $expected_command; do 23 | if [ "$w" = "$1" ]; then 24 | cmd="$cmd-$w" 25 | shift 26 | else 27 | echo "Unknown command $1" 28 | print_help 29 | exit 1 30 | fi 31 | done 32 | 33 | echo $cmd $* 34 | eval $cmd $* 35 | exit $? 36 | } 37 | 38 | 39 | case $1 in 40 | run) 41 | parse_command "run" $* 42 | ;; 43 | 44 | serve) 45 | parse_command "serve" $* 46 | ;; 47 | 48 | runs) 49 | parse_command "runs" $* 50 | ;; 51 | 52 | ui) 53 | parse_command "ui" $* 54 | ;; 55 | 56 | start) 57 | parse_command "start remote worker" $* 58 | ;; 59 | 60 | remote) 61 | parse_command "remote worker" $* 62 | ;; 63 | 64 | add) 65 | parse_command "add credentials" $* 66 | ;; 67 | 68 | esac 69 | 70 | echo "Unknown command $1" 71 | print_help 72 | exit 1 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /examples/keras/train_mnist_keras_mutligpu.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from keras.layers import Input, Dense 3 | from keras.models import Model 4 | from keras.datasets import mnist 5 | from keras.utils import to_categorical 6 | from keras.callbacks import ModelCheckpoint, TensorBoard 7 | 8 | from studio import fs_tracker 9 | from studio import multi_gpu 10 | 11 | # this placeholder will contain our input digits, as flat vectors 12 | img = Input((784,)) 13 | # fully-connected layer with 128 units and ReLU activation 14 | x = Dense(128, activation='relu')(img) 15 | x = Dense(128, activation='relu')(x) 16 | # output layer with 10 units and a softmax activation 17 | preds = Dense(10, activation='softmax')(x) 18 | 19 | 20 | no_gpus = 2 21 | batch_size = 128 22 | 23 | model = Model(img, preds) 24 | model = multi_gpu.make_parallel(model, no_gpus) 25 | model.compile(loss='categorical_crossentropy', optimizer='adam') 26 | 27 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 28 | 29 | x_train = x_train.reshape(60000, 784) 30 | x_test = x_test.reshape(10000, 784) 31 | x_train = x_train.astype('float32') 32 | x_test = x_test.astype('float32') 33 | x_train /= 255 34 | x_test /= 255 35 | 36 | # convert class vectors to binary class matrices 37 | y_train = to_categorical(y_train, 10) 38 | y_test = to_categorical(y_test, 10) 39 | 40 | 41 | checkpointer = ModelCheckpoint( 42 | fs_tracker.get_model_directory() + 43 | '/checkpoint.{epoch:02d}-{val_loss:.2f}.hdf') 44 | 45 | 46 | tbcallback = TensorBoard(log_dir=fs_tracker.get_tensorboard_dir(), 47 | histogram_freq=0, 48 | write_graph=True, 49 | write_images=False) 50 | 51 | 52 | model.fit( 53 | x_train, 54 | y_train, 55 | validation_data=(x_test, y_test), 56 | epochs=int(sys.argv[1]), 57 | batch_size=batch_size * 
no_gpus, 58 | callbacks=[checkpointer, tbcallback]) 59 | -------------------------------------------------------------------------------- /docs/customenv.rst: -------------------------------------------------------------------------------- 1 | Custom environments 2 | =================== 3 | 4 | Using custom environment variables at runtime 5 | --------------------------------------------- 6 | 7 | You can add an env section to your yaml configuration file in order to send environment variables into your runner environment variables table. Variables can be prefixed with a $ sign if you wish to substitute local environment variables into your run configuration. Be aware that all values are stored in clear text. If you wish to exchange secrets you will need to encrypt them into your configuration file and then decrypt your secrets within your python code used during the experiment. 8 | 9 | 10 | Customization of python environment for the workers 11 | --------------------------------------------------- 12 | 13 | Sometimes your experiment relies on an older / custom version of some 14 | python package. For example, the Keras API has changed quite a bit between 15 | versions 1 and 2. What if you are using a new environment locally, but 16 | would like to re-run old experiments that needed older version of 17 | packages? Or, for example, you'd like to see if your code would work 18 | with the latest version of a package. Studio gives you this 19 | opportunity. 20 | 21 | :: 22 | 23 | studio run --python-pkg=== 24 | 25 | allows you to run ```` on a remote / cloud worker with a 26 | specific version of a package. You can also omit ``==`` 27 | to install the latest version of the package (which may not be 28 | equal to the version in your environment). Note that if a package with a 29 | custom version has dependencies conflicting with the current version, the situation 30 | gets tricky. For now, it is up to pip to resolve conflicts. In some 31 | cases it may fail and you'll have to manually specify dependencies 32 | versions by adding more ``--python-pkg`` arguments. 
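As a hypothetical example, to re-run a script against the Keras 1.x API on a remote worker (the package version and script name below are placeholders):

::

    studio run --python-pkg=keras==1.2.2 my_script.py

The ``--python-pkg`` argument can be repeated to pin additional packages in cases where pip cannot resolve the dependency conflicts on its own.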
33 | -------------------------------------------------------------------------------- /tests/hyperparam_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | 4 | from studio.hyperparameter import HyperparameterParser, Hyperparameter 5 | from studio import logs 6 | 7 | 8 | class RunnerArgs(object): 9 | def __init__(self): 10 | self.optimizer = "grid" 11 | self.verbose = False 12 | 13 | 14 | class HyperparamTest(unittest.TestCase): 15 | def test_parse_range(self): 16 | logger = logs.get_logger('test_stop_experiment') 17 | h = HyperparameterParser(RunnerArgs(), logger) 18 | range_strs = ['1,2,3', ':5', '2:5', '0.1:0.05:0.3', '0.1:3:0.3', 19 | '0.01:4l:10'] 20 | gd_truths = [ 21 | [ 22 | 1.0, 2.0, 3.0], [ 23 | 0.0, 1.0, 2.0, 3.0, 4.0, 5.0], [ 24 | 2.0, 3.0, 4.0, 5.0], [ 25 | 0.1, 0.15, 0.2, 0.25, 0.3], [ 26 | 0.1, 0.2, 0.3], [ 27 | 0.01, 0.1, 1, 10]] 28 | 29 | for range_str, gd_truth in zip(range_strs, gd_truths): 30 | hyperparameter = h._parse_grid("test", range_str) 31 | self.assertTrue(np.isclose(hyperparameter.values, gd_truth).all()) 32 | 33 | def test_unfold_tuples(self): 34 | logger = logs.get_logger('test_stop_experiment') 35 | h = HyperparameterParser(RunnerArgs(), logger) 36 | 37 | hyperparams = [Hyperparameter(name='a', values=[1, 2, 3]), 38 | Hyperparameter(name='b', values=[4, 5])] 39 | 40 | expected_tuples = [ 41 | {'a': 1, 'b': 4}, {'a': 2, 'b': 4}, {'a': 3, 'b': 4}, 42 | {'a': 1, 'b': 5}, {'a': 2, 'b': 5}, {'a': 3, 'b': 5}] 43 | 44 | self.assertEqual( 45 | sorted(h.convert_to_tuples(hyperparams), key=lambda x: str(x)), 46 | sorted(expected_tuples, key=lambda x: str(x))) 47 | 48 | 49 | if __name__ == '__main__': 50 | unittest.main() 51 | -------------------------------------------------------------------------------- /examples/keras/multi_gpu.py: -------------------------------------------------------------------------------- 1 | from keras.layers import merge 2 | from keras.layers.core import Lambda 3 | from keras.models import Model 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def make_parallel(model, gpu_count): 9 | def get_slice(data, idx, parts): 10 | shape = tf.shape(data) 11 | size = tf.concat([shape[:1] // parts, shape[1:]], 0) 12 | stride = tf.concat([shape[:1] // parts, shape[1:] * 0], 0) 13 | start = stride * idx 14 | return tf.slice(data, start, size) 15 | 16 | outputs_all = [] 17 | for i in range(len(model.outputs)): 18 | outputs_all.append([]) 19 | 20 | # Place a copy of the model on each GPU, each getting a slice of the batch 21 | for i in range(gpu_count): 22 | with tf.device('/gpu:%d' % i): 23 | with tf.name_scope('tower_%d' % i): 24 | 25 | inputs = [] 26 | # Slice each input into a piece for processing on this GPU 27 | for x in model.inputs: 28 | input_shape = tuple(x.get_shape().as_list())[1:] 29 | slice_n = Lambda( 30 | get_slice, output_shape=input_shape, arguments={ 31 | 'idx': i, 'parts': gpu_count})(x) 32 | inputs.append(slice_n) 33 | 34 | outputs = model(inputs) 35 | 36 | if not isinstance(outputs, list): 37 | outputs = [outputs] 38 | 39 | # Save all the outputs for merging back together later 40 | for l in range(len(outputs)): 41 | outputs_all[l].append(outputs[l]) 42 | 43 | # merge outputs on CPU 44 | with tf.device('/cpu:0'): 45 | merged = [] 46 | for outputs in outputs_all: 47 | merged.append(merge(outputs, mode='concat', concat_axis=0)) 48 | 49 | return Model(input=model.inputs, output=merged) 50 | 
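``make_parallel`` above replicates the model on each GPU, feeds each replica a slice of the input batch, and concatenates the per-tower outputs on the CPU. A minimal usage sketch, modeled on ``train_mnist_keras_mutligpu.py`` earlier in this listing (the GPU count and batch size are illustrative values; that example imports the helper as ``from studio import multi_gpu``):

::

    from keras.layers import Input, Dense
    from keras.models import Model

    from studio import multi_gpu   # import path used by train_mnist_keras_mutligpu.py

    no_gpus = 2        # each batch is sliced across 2 GPU towers
    batch_size = 128   # per-tower size; feed batches of batch_size * no_gpus

    img = Input((784,))
    x = Dense(128, activation='relu')(img)
    preds = Dense(10, activation='softmax')(x)

    model = Model(img, preds)
    model = multi_gpu.make_parallel(model, no_gpus)
    model.compile(loss='categorical_crossentropy', optimizer='adam')
    # model.fit(x_train, y_train, batch_size=batch_size * no_gpus, ...)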
-------------------------------------------------------------------------------- /studio/queues/qclient_cache.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | from studio.queues.rabbit_queue import RMQueue 4 | from studio.util.util import check_for_kb_interrupt 5 | 6 | _queue_cache = {} 7 | 8 | def get_cached_queue( 9 | name, 10 | route, 11 | config=None, 12 | logger=None, 13 | close_after=None): 14 | 15 | queue = _queue_cache.get(name, None) 16 | if queue is not None: 17 | if logger is not None: 18 | logger.info("Got queue named %s from queue cache.", name) 19 | return queue 20 | 21 | queue = RMQueue( 22 | queue=name, 23 | route=route, 24 | config=config, 25 | logger=logger) 26 | 27 | if logger is not None: 28 | logger.info("Created new queue named %s.", name) 29 | 30 | if close_after is not None and close_after.total_seconds() > 0: 31 | thr = threading.Timer( 32 | interval=close_after.total_seconds(), 33 | function=purge_rmq, 34 | kwargs={ 35 | "q": queue, 36 | "logger": logger}) 37 | thr.setDaemon(True) 38 | thr.start() 39 | 40 | _queue_cache[name] = queue 41 | if logger is not None: 42 | logger.info("Added queue named %s to queue cache.", name) 43 | return queue 44 | 45 | def shutdown_cached_queue(queue, logger=None, delete_queue=True): 46 | if queue is None: 47 | return 48 | 49 | _queue_cache.pop(queue.get_name(), None) 50 | if logger is not None: 51 | logger.info("Removed queue named %s from queue cache.", 52 | queue.get_name()) 53 | 54 | queue.shutdown(delete_queue) 55 | 56 | 57 | def purge_rmq(queue, logger): 58 | if queue is None: 59 | return 60 | 61 | try: 62 | queue.shutdown(True) 63 | except BaseException as exc: 64 | check_for_kb_interrupt() 65 | logger.warning(exc) 66 | return 67 | return 68 | -------------------------------------------------------------------------------- /studio/serve.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from studio import runner 4 | 5 | 6 | def main(): 7 | argparser = argparse.ArgumentParser() 8 | argparser.add_argument( 9 | '--wrapper', '-w', 10 | help='python script with function create_model ' + 11 | 'that takes modeldir ' 12 | '(that is, directory where experiment saves ' + 13 | 'the checkpoints etc)' + 14 | 'and returns dict -> dict function (model).' 
+ 15 | 'By default, studio-serve will try to determine ' + 16 | 'this function automatically.', 17 | default=None 18 | ) 19 | 20 | argparser.add_argument('--port', 21 | help='port to run Flask server on', 22 | type=int, 23 | default=5000) 24 | 25 | argparser.add_argument('--host', 26 | help='host name.', 27 | default='0.0.0.0') 28 | 29 | argparser.add_argument( 30 | '--killafter', 31 | help='Shut down after this many seconds of inactivity', 32 | default=3600) 33 | 34 | options, other_args = argparser.parse_known_args(sys.argv[1:]) 35 | serve_args = ['studio::serve_main'] 36 | 37 | assert len(other_args) >= 1 38 | experiment_key = other_args[-1] 39 | runner_args = other_args[:-1] 40 | runner_args.append('--reuse={}/modeldir:modeldata'.format(experiment_key)) 41 | runner_args.append('--force-git') 42 | runner_args.append('--port=' + str(options.port)) 43 | 44 | if options.wrapper: 45 | serve_args.append('--wrapper=' + options.wrapper) 46 | serve_args.append('--port=' + str(options.port)) 47 | 48 | serve_args.append('--host=' + options.host) 49 | serve_args.append('--killafter=' + str(options.killafter)) 50 | 51 | total_args = runner_args + serve_args 52 | runner.main(total_args) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | studio/.DS_Store 92 | .DS_Store 93 | 94 | auth/.firebaserc 95 | auth/firebase.json 96 | 97 | # Ignore directories used by the packaging tools, 98 | # the go runner, and google tooling 99 | # 100 | bin/ 101 | include/ 102 | local/ 103 | pkg/ 104 | src/ 105 | pip-selfcheck.json 106 | 107 | # ctags 108 | .tags 109 | .tags_sorted_by_file 110 | 111 | #version - for deployment from travis 112 | .version 113 | examples/.envrc 114 | -------------------------------------------------------------------------------- /studio/dependencies_policies/studio_dependencies_policy.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | try: 4 | try: 5 | from pip._internal.operations import freeze 6 | except Exception: 7 | from pip.operations import freeze 8 | except ImportError: 9 | freeze = None 10 | 11 | from studio.dependencies_policies.dependencies_policy import DependencyPolicy 12 | 13 | 14 | class StudioDependencyPolicy(DependencyPolicy): 15 | """ 16 | StudioML policy for adjusting experiment dependencies 17 | to specific execution environment and required resources. 
18 | """ 19 | 20 | def generate(self, resources_needed): 21 | if freeze is None: 22 | raise ValueError( 23 | "freeze operation is not available for StudioDependencyPolicy") 24 | 25 | needs_gpu = self._needs_gpu(resources_needed) 26 | packages = freeze.freeze() 27 | result = [] 28 | for pkg in packages: 29 | if self._is_special_reference(pkg): 30 | # git repo or directly installed package 31 | result.append(pkg) 32 | elif '==' in pkg: 33 | # pypi package 34 | pkey = re.search(r'^.*?(?=\=\=)', pkg).group(0) 35 | pversion = re.search(r'(?<=\=\=).*\Z', pkg).group(0) 36 | 37 | if needs_gpu and \ 38 | (pkey == 'tensorflow' or pkey == 'tf-nightly'): 39 | pkey = pkey + '-gpu' 40 | 41 | # TODO add installation logic for torch 42 | result.append(pkey + '==' + pversion) 43 | return result 44 | 45 | def _is_special_reference(self, pkg: str): 46 | if pkg.startswith('-e git+'): 47 | return True 48 | if 'git+https://' in pkg or 'file://' in pkg: 49 | return True 50 | return False 51 | 52 | def _needs_gpu(self, resources_needed): 53 | return resources_needed is not None and \ 54 | int(resources_needed.get('gpus')) > 0 55 | -------------------------------------------------------------------------------- /studio/firebase_provider.py: -------------------------------------------------------------------------------- 1 | from studio.db_providers.keyvalue_provider import KeyValueProvider 2 | from .firebase_storage_handler import FirebaseStorageHandler 3 | 4 | class FirebaseProvider(KeyValueProvider): 5 | 6 | def __init__(self, db_config, blocking_auth=True): 7 | self.meta_store = FirebaseStorageHandler(db_config) 8 | 9 | super().__init__( 10 | db_config, 11 | self.meta_store, 12 | blocking_auth) 13 | 14 | def _get(self, key, shallow=False): 15 | try: 16 | splitKey = key.split('/') 17 | key_path = '/'.join(splitKey[:-1]) 18 | key_name = splitKey[-1] 19 | dbobj = self.app.database().child(key_path).child(key_name) 20 | return dbobj.get(self.auth.get_token(), shallow=shallow).val() \ 21 | if self.auth else dbobj.get(shallow=shallow).val() 22 | except Exception as err: 23 | self.logger.warn(("Getting key {} from a database " + 24 | "raised an exception: {}").format(key, err)) 25 | return None 26 | 27 | def _set(self, key, value): 28 | try: 29 | splitKey = key.split('/') 30 | key_path = '/'.join(splitKey[:-1]) 31 | key_name = splitKey[-1] 32 | dbobj = self.app.database().child(key_path) 33 | if self.auth: 34 | dbobj.update({key_name: value}, self.auth.get_token()) 35 | else: 36 | dbobj.update({key_name: value}) 37 | except Exception as err: 38 | self.logger.warn(("Putting key {}, value {} into a database " + 39 | "raised an exception: {}") 40 | .format(key, value, err)) 41 | 42 | def _delete(self, key, shallow=True, token=None): 43 | dbobj = self.app.database().child(key) 44 | 45 | if self.auth: 46 | dbobj.remove(self.auth.get_token()) 47 | else: 48 | dbobj.remove() 49 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Studio.ml 3 | ========= 4 | 5 | `Github `_ | 6 | `pip `_ | 7 | 8 | 9 | Studio is a model management framework written in Python to help simplify and expedite your model building experience. It was developed to minimize the overhead involved with scheduling, running, monitoring and managing artifacts of your machine learning experiments. 
No one wants to spend their time configuring different machines, setting up dependencies, or playing archeologist to track down previous model artifacts. 10 | 11 | Most of the features are compatible with any Python machine learning framework (`Keras `__, `TensorFlow `__, `PyTorch `__, `scikit-learn `__, etc) without invasion of your code; some extra features are available for Keras and TensorFlow. 12 | 13 | **Use Studio to:** 14 | 15 | - Capture experiment information- Python environment, files, dependencies and logs- without modifying the experiment code. Monitor and organize experiments using a web dashboard that integrates with TensorBoard. 16 | - Run experiments locally, remotely, or in the cloud (Google Cloud or Amazon EC2) 17 | - Manage artifacts 18 | - Perform hyperparameter search 19 | - Create customizable Python environments for remote workers. 20 | 21 | 22 | .. toctree:: 23 | :hidden: 24 | :caption: Introduction 25 | 26 | Getting Started 27 | installation 28 | authentication 29 | cli 30 | 31 | .. toctree:: 32 | :hidden: 33 | :caption: Main Documentation 34 | 35 | artifacts 36 | hyperparams 37 | model_pipelines 38 | setup_database 39 | 40 | .. toctree:: 41 | :hidden: 42 | :caption: Remote computing 43 | 44 | remote_worker 45 | customenv 46 | 47 | .. toctree:: 48 | :hidden: 49 | :caption: Cloud computing 50 | 51 | cloud 52 | ec2_setup 53 | gcloud_setup 54 | 55 | -------------------------------------------------------------------------------- /tests/util_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from studio import util 3 | from random import randint 4 | 5 | 6 | class UtilTest(unittest.TestCase): 7 | def test_remove_backspaces(self): 8 | testline = 'abcd\x08\x08\x08efg\x08\x08hi\x08' 9 | removed = util.remove_backspaces(testline) 10 | self.assertTrue(removed == 'aeh') 11 | 12 | testline = 'abcd\x08\x08\x08efg\x08\x08hi' 13 | removed = util.remove_backspaces(testline) 14 | self.assertTrue(removed == 'aehi') 15 | 16 | testline = 'abcd' 17 | removed = util.remove_backspaces(testline) 18 | self.assertTrue(removed == 'abcd') 19 | 20 | testline = 'abcd\n\ndef' 21 | removed = util.remove_backspaces(testline) 22 | self.assertTrue(removed == testline) 23 | 24 | def test_retry(self): 25 | attempts = [0] 26 | value = randint(0, 1000) 27 | 28 | def failing_func(): 29 | attempts[0] += 1 30 | if attempts[0] != 2: 31 | raise ValueError('Attempt {} failed' 32 | .format(attempts[0])) 33 | return value 34 | 35 | retval = util.retry(failing_func, 36 | no_retries=2, 37 | sleep_time=1, 38 | exception_class=ValueError) 39 | 40 | self.assertEqual(retval, value) 41 | self.assertEqual(attempts, [2]) 42 | 43 | # test out for catching different exception class 44 | try: 45 | retval = util.retry(failing_func, 46 | no_retries=2, 47 | sleep_time=1, 48 | exception_class=OSError) 49 | except ValueError: 50 | pass 51 | else: 52 | self.assertTrue(False) 53 | 54 | def test_str2duration(self): 55 | self.assertEqual( 56 | int(util.str2duration('30m').total_seconds()), 57 | 1800) 58 | 59 | 60 | if __name__ == "__main__": 61 | unittest.main() 62 | -------------------------------------------------------------------------------- /docs/jupyter.rst: -------------------------------------------------------------------------------- 1 | Jupyter / ipython notebooks 2 | =========================== 3 | 4 | Studio can be used not only with scripts, but also with 5 | jupyter notebooks. 
The main idea is as follows - 6 | the cell annotated with a special cell magic is being treated 7 | as a separate script; and the variables are being passed in and 8 | out as artifacts (this means that all variables the cell 9 | depends on have to be pickleable). The script can then be run 10 | either locally (in which case the main benefit of studio 11 | is keeping track of all runs of the cell), or in the cloud / remotely. 12 | 13 | To use Studio in your notebook, add 14 | 15 | :: 16 | 17 | from studio import magics 18 | 19 | to the import section of your notebook. 20 | 21 | Then annotate the cell that you'd like to run via studio with 22 | 23 | :: 24 | 25 | %%studio_run 26 | 27 | This will execute the statements in the cell using studio, 28 | also passing ```` to the runner. 29 | For example, let's imagine that a variable ``x`` is declared in 30 | your notebook. Then 31 | 32 | :: 33 | 34 | %%studio_run --cloud=gcloud 35 | x += 1 36 | 37 | will do the increment of the variable ``x`` in the notebook namespace 38 | using a google cloud compute 39 | instance (given that increment of a variable in python does not take a millisecond, 40 | spinning up an entire instance to do that is probably the most wasteful thing you 41 | have seen today, but you get the idea :). The ``%%studio_run`` cell magic 42 | accepts the same arguments as the ``studio run`` command, please refer to 43 | `` for a more involved discussion of cloud and hardware selection options. 44 | 45 | Every run with studio will get a unique key and can be viewed as an experiment in 46 | studio ui. 47 | 48 | The only limitation to using studio in a notebook is that variables being used 49 | in a studio-run cell have to be pickleable. That means that, for example, you 50 | cannot use lambda functions defined elsewhere, because those are not 51 | pickleable. 52 | 53 | 54 | -------------------------------------------------------------------------------- /docs/containers.rst: -------------------------------------------------------------------------------- 1 | ========================= 2 | Containerized experiments 3 | ========================= 4 | 5 | Some experiments may require more than just a specific python environment to be run reproducibly. For instance, 2017 NIPS running 6 | competition relied on a specific set of system-level pacakges for walker physics simulations. To address such experiments, Studio.ML 7 | supports execution in containers by using Singularity (https://singularity.lbl.gov). Singularity supports both Docker and its own format 8 | of containers. Containers can be used in two main ways: 9 | 10 | 1. Running experiment using container environment 11 | ------------------------------------------------- 12 | In this mode, an environment is set up within the container, but the python code is outside. Studio.ML with help of Singularity 13 | mounts copy of current directory and artifacts into the container and executes the script. Typical command line will look like 14 | 15 | :: 16 | 17 | studio run --container=/path/to/container.simg script.py args 18 | 19 | 20 | Note that if your script is using Studio.ML library functions (such as `fs_tracker.get_artifact()`), Studio.ML will need to be 21 | installed within the container. 22 | 23 | 2. Running experiment using executable container 24 | ------------------------------------------------ 25 | Both singularity and docker support executable containers. 
Studio.ML experiment can consist solely out of an executable container: 26 | 27 | :: 28 | 29 | studio run --container=/path/to/container.simg 30 | 31 | In this case, the code does not even need to be python, but all Studio.ML perks (such as cloud execution with hardware selection, 32 | keeping track of inputs and outputs of the experiment etc) still apply. There is even an artifact management - artifacts will be 33 | seen in the container in the folder one level up from working directory. 34 | 35 | Containers can be located either locally as `*.simg` files, or in the Singularity/Docker hub. In the latter case, provide a link that 36 | starts with `shub://` or `dockerhub://` 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /studio/db_providers/local_db_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from studio.db_providers.keyvalue_provider import KeyValueProvider 5 | from studio.storage.storage_handler_factory import StorageHandlerFactory 6 | from studio.storage.storage_type import StorageType 7 | from studio.util import util 8 | 9 | class LocalDbProvider(KeyValueProvider): 10 | 11 | def __init__(self, config, blocking_auth=True): 12 | self.config = config 13 | self.bucket = config.get('bucket', 'studioml-meta') 14 | 15 | factory: StorageHandlerFactory = StorageHandlerFactory.get_factory() 16 | self.meta_store = factory.get_handler(StorageType.storageLocal, config) 17 | 18 | self.endpoint = self.meta_store.get_endpoint() 19 | self.db_root = os.path.join(self.endpoint, self.bucket) 20 | self._ensure_path_dirs_exist(self.db_root) 21 | 22 | super().__init__(config, self.meta_store, blocking_auth) 23 | 24 | def _ensure_path_dirs_exist(self, path): 25 | dirs = os.path.dirname(path) 26 | os.makedirs(dirs, mode = 0o777, exist_ok = True) 27 | 28 | def _get(self, key, shallow=False): 29 | file_name = os.path.join(self.db_root, key) 30 | if not os.path.exists(file_name): 31 | return None 32 | try: 33 | with open(file_name) as infile: 34 | result = json.load(infile) 35 | except BaseException as exc: 36 | self.logger.error("FAILED to load file %s - %s", file_name, exc) 37 | result = None 38 | return result 39 | 40 | def _delete(self, key, shallow=True): 41 | file_name = os.path.join(self.db_root, key) 42 | if os.path.exists(file_name): 43 | self.logger.debug("Deleting local database file %s.", file_name) 44 | util.delete_local_path(file_name, self.db_root, shallow) 45 | 46 | def _set(self, key, value): 47 | file_name = os.path.join(self.db_root, key) 48 | self._ensure_path_dirs_exist(file_name) 49 | with open(file_name, 'w') as outfile: 50 | json.dump(value, outfile) 51 | -------------------------------------------------------------------------------- /tests/model_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import uuid 3 | try: 4 | try: 5 | from pip._internal.operations import freeze 6 | except Exception: 7 | from pip.operations import freeze 8 | except ImportError: 9 | freeze = None 10 | import os 11 | 12 | from studio import model 13 | from studio.experiments.experiment import create_experiment 14 | from studio.dependencies_policies.studio_dependencies_policy import StudioDependencyPolicy 15 | 16 | 17 | def get_test_experiment(): 18 | filename = 'test.py' 19 | args = ['a', 'b', 'c'] 20 | experiment_name = 'test_experiment_' + str(uuid.uuid4()) 21 | experiment = create_experiment(filename, 
args, experiment_name, dependency_policy=StudioDependencyPolicy()) 22 | return experiment, experiment_name, filename, args 23 | 24 | 25 | class ModelTest(unittest.TestCase): 26 | def test_create_experiment(self): 27 | _, experiment_name, filename, args = get_test_experiment() 28 | experiment_project = 'create_experiment_project' 29 | experiment = create_experiment( 30 | filename, args, experiment_name, experiment_project, dependency_policy=StudioDependencyPolicy()) 31 | 32 | packages = [p for p in freeze.freeze()] 33 | 34 | self.assertTrue(experiment.key == experiment_name) 35 | self.assertTrue(experiment.filename == filename) 36 | self.assertTrue(experiment.args == args) 37 | self.assertTrue(experiment.project == experiment_project) 38 | self.assertTrue(sorted(experiment.pythonenv) == sorted(packages)) 39 | 40 | def test_get_config_env(self): 41 | value1 = str(uuid.uuid4()) 42 | os.environ['TEST_VAR1'] = value1 43 | value2 = str(uuid.uuid4()) 44 | os.environ['TEST_VAR2'] = value2 45 | 46 | config = model.get_config( 47 | os.path.join(os.path.dirname(os.path.realpath(__file__)), 48 | 'test_config_env.yaml')) 49 | self.assertEqual(config['test_key'], value1) 50 | self.assertEqual(config['test_section']['test_key'], value2) 51 | 52 | 53 | if __name__ == "__main__": 54 | unittest.main() 55 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: trusty 2 | 3 | language: python 4 | python: 5 | - "2.7" 6 | - "3.6" 7 | 8 | env: 9 | fast_finish: true 10 | 11 | script: 12 | - git fetch --unshallow 13 | - 'if [ -n "$GOOGLE_CREDENTIALS_ENCODED" ]; then echo "$GOOGLE_CREDENTIALS_ENCODED" | base64 --decode > ./credentials.json && export GOOGLE_APPLICATION_CREDENTIALS=$(pwd)/credentials.json; fi' 14 | # - 'if [ -n "$GOOGLE_CREDENTIALS_DS_ENCODED" ]; then echo "$GOOGLE_CREDENTIALS_DS_ENCODED" | base64 --decode > ./credentials_ds.json && export GOOGLE_APPLICATION_CREDENTIALS_DC=$(pwd)/credentials_ds.json; fi' 15 | - pip install -e . 16 | - pip install -r test_requirements.txt 17 | 18 | - pycodestyle --show-source --exclude=.eggs --ignore=W605,W504,E704 . 
19 | 20 | - echo GOOGLE_APPLICATION_CREDENTIALS=$GOOGLE_APPLICATION_CREDENTIALS 21 | - echo GOOGLE_APPLICATION_CREDENTIALS_DC=$GOOGLE_APPLICATION_CREDENTIALS_DC 22 | - pip freeze 23 | - pytest --collect-only -v 24 | - pytest -n 4 -v --durations=0 -l 25 | 26 | jobs: 27 | include: 28 | - stage: deploy 29 | python: 2.7 30 | script: 31 | - git fetch --unshallow 32 | - python setup.py sdist 33 | 34 | deploy: 35 | provider: pypi 36 | distributions: "sdist" 37 | user: pzhokhov 38 | password: 39 | secure: "FkfIvyF3PReaXOuCNSeqlg9SAiocc8WxCzlbqnfssG6JxsIBWNuoOPtgKtiYmMO+LBXnzpZ8CkTjoHT2oWWBD8r92QNmdIOze1MPs5xQmGxu12TEc5dEgMUnvenV3vkOZGC/hhwbio16Dqfd8PlHYBRkduvvqrvRD1xJPsggx9tgZwZ+Vvv0h51/BpOIMxm0xV3qu1U5A3BIvyzTUucvCAPbmHeZj+tWtiTq3OEwpc812PsCNkFNEOIBUkwafu1VE15tFcud/pZyFYwEK8/z35CJDGY4oVKWZyoC5/Gp9w658ps6MnjWP36Y1GNj4wpQPL2ftvGzRfXvfZ/FcSi4mreon24YPYFlY76ezLibWq8m2/ZBlQQwchFi48nGoalDJ7a92hyVkrRr1UP87+fbrZW8u4l/2HiKYKgotAixnoKKVOaWcmhraLve+kLE5F6e0lwmwCOHQwO9Cz4PECiZtxf8ePdDj+Dr9RLI3rKehRATpL+kXkR4WzKXPJAfDiPS6WiOz/melZJIttQcFZpOUVCs79yKWj8E95atbEy0jFaGn1PR+lFHqFYsg8TTe9Q9L/qMc+ECzBzQ4ryClCKDMYylCqbfDpOmCtHZw0a79IKmSyyvNbjnFMhAcn+GHFp+O46iF7Zxb1fAmp45AZ48bG/FCVqp66EXr590Fy1gjh0=" 40 | skip_cleanup: true 41 | on: 42 | # tags: true 43 | branch: master 44 | repo: studioml/studio 45 | condition: $TRAVIS_EVENT_TYPE == "push" 46 | -------------------------------------------------------------------------------- /examples/tensorflow/train_mnist.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import sys 3 | import os 4 | import time 5 | 6 | from keras.objectives import categorical_crossentropy 7 | from keras.layers import Dense 8 | from tensorflow.examples.tutorials.mnist import input_data 9 | 10 | from keras import backend as K 11 | from studio import fs_tracker 12 | 13 | import logging 14 | 15 | logging.basicConfig() 16 | 17 | sess = tf.Session() 18 | K.set_session(sess) 19 | # Now let's get started with our MNIST model. 
We can start building a 20 | # classifier exactly as you would do in TensorFlow: 21 | 22 | # this placeholder will contain our input digits, as flat vectors 23 | img = tf.placeholder(tf.float32, shape=(None, 784)) 24 | # We can then use Keras layers to speed up the model definition process: 25 | 26 | 27 | # Keras layers can be called on TensorFlow tensors: 28 | # fully-connected layer with 128 units and ReLU activation 29 | x = Dense(128, activation='relu')(img) 30 | x = Dense(128, activation='relu')(x) 31 | # output layer with 10 units and a softmax activation 32 | preds = Dense(10, activation='softmax')(x) 33 | # We define the placeholder for the labels, and the loss function we will use: 34 | 35 | labels = tf.placeholder(tf.float32, shape=(None, 10)) 36 | 37 | loss = tf.reduce_mean(categorical_crossentropy(labels, preds)) 38 | # Let's train the model with a TensorFlow optimizer: 39 | 40 | mnist_data = input_data.read_data_sets('MNIST_data', one_hot=True) 41 | 42 | global_step = tf.Variable(0, name='global_step', trainable=False) 43 | train_step = tf.train.GradientDescentOptimizer( 44 | 0.5).minimize(loss, global_step=global_step) 45 | # Initialize all variables 46 | init_op = tf.global_variables_initializer() 47 | saver = tf.train.Saver() 48 | sess.run(init_op) 49 | 50 | 51 | logger = logging.get_logger('train_mnist') 52 | logger.setLevel(10) 53 | # Run training loop 54 | with sess.as_default(): 55 | while True: 56 | batch = mnist_data.train.next_batch(50) 57 | train_step.run(feed_dict={img: batch[0], 58 | labels: batch[1]}) 59 | 60 | sys.stdout.flush() 61 | saver.save( 62 | sess, 63 | os.path.join( 64 | fs_tracker.get_model_directory(), 65 | "ckpt"), 66 | global_step=global_step) 67 | time.sleep(1) 68 | -------------------------------------------------------------------------------- /examples/keras/fashion_mnist.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import urllib 3 | 4 | from keras.layers import Dense, Flatten 5 | 6 | from keras.models import Sequential 7 | 8 | ### 9 | # AS OF 10/18/2017, fashion_mnist available as a part of github master 10 | # branch of keras 11 | # but not a part of pypi package 12 | # Therefore, to use this, you'll need keras installed from a git repo: 13 | # git clone https://github.com/fchollet/keras && cd keras && pip install . 
14 | ### 15 | from keras.datasets import fashion_mnist 16 | from keras.utils import to_categorical 17 | 18 | from keras.callbacks import ModelCheckpoint, TensorBoard 19 | from keras import optimizers 20 | 21 | import numpy as np 22 | from PIL import Image 23 | from io import BytesIO 24 | 25 | from studio import fs_tracker, model_util 26 | 27 | (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data() 28 | 29 | x_train = x_train.reshape(60000, 28, 28, 1) 30 | x_test = x_test.reshape(10000, 28, 28, 1) 31 | x_train = x_train.astype('float32') 32 | x_test = x_test.astype('float32') 33 | x_train /= 255 34 | x_test /= 255 35 | 36 | # convert class vectors to binary class matrices 37 | y_train = to_categorical(y_train, 10) 38 | y_test = to_categorical(y_test, 10) 39 | 40 | 41 | model = Sequential() 42 | 43 | model.add(Flatten(input_shape=(28, 28, 1))) 44 | model.add(Dense(128, activation='relu')) 45 | model.add(Dense(128, activation='relu')) 46 | 47 | model.add(Dense(10, activation='softmax')) 48 | model.summary() 49 | 50 | 51 | batch_size = 128 52 | no_epochs = int(sys.argv[1]) if len(sys.argv) > 1 else 10 53 | lr = 0.01 54 | 55 | print('learning rate = {}'.format(lr)) 56 | print('batch size = {}'.format(batch_size)) 57 | print('no_epochs = {}'.format(no_epochs)) 58 | 59 | model.compile(loss='categorical_crossentropy', optimizer=optimizers.SGD(lr=lr), 60 | metrics=['accuracy']) 61 | 62 | 63 | checkpointer = ModelCheckpoint( 64 | fs_tracker.get_model_directory() + 65 | '/checkpoint.{epoch:02d}-{val_loss:.2f}.hdf') 66 | 67 | 68 | tbcallback = TensorBoard(log_dir=fs_tracker.get_tensorboard_dir(), 69 | histogram_freq=0, 70 | write_graph=True, 71 | write_images=True) 72 | 73 | 74 | model.fit( 75 | x_train, y_train, validation_data=( 76 | x_test, 77 | y_test), 78 | epochs=no_epochs, 79 | callbacks=[checkpointer, tbcallback], 80 | batch_size=batch_size) 81 | -------------------------------------------------------------------------------- /studio/ed25519_key_util.py: -------------------------------------------------------------------------------- 1 | import openssh_key.private_key_list as pkl 2 | from openssh_key.key import PublicKey 3 | 4 | class Ed25519KeyUtil: 5 | 6 | @classmethod 7 | def parse_private_key_file(cls, filepath: str, logger): 8 | """ 9 | Parse a file with ed25519 private key in OPEN SSH PRIVATE key format. 10 | :param filepath: file path to a key file 11 | :param logger: logger to use for messages 12 | :return: (public key part as bytes, private key part as bytes) 13 | """ 14 | contents: str = '' 15 | try: 16 | with open(filepath, "r") as f: 17 | contents = f.read() 18 | except Exception as exc: 19 | if logger is not None: 20 | logger.error("FAILED to read keyfile %s: %s", 21 | filepath, exc) 22 | return None, None 23 | try: 24 | key_data = pkl.PrivateKeyList.from_string(contents) 25 | data_public = key_data[0].private.params['public'] 26 | data_private = key_data[0].private.params['private_public'] 27 | return data_public, data_private[:32] 28 | except Exception as exc: 29 | if logger is not None: 30 | logger.error("FAILED to decode keyfile format %s: %s", 31 | filepath, exc) 32 | return None, None 33 | 34 | @classmethod 35 | def parse_public_key_file(cls, filepath: str, logger): 36 | """ 37 | Parse a file with ed25519 public. 
38 | :param filepath: file path to a key file 39 | :param logger: logger to use for messages 40 | :return: public key part as bytes 41 | """ 42 | contents: str = '' 43 | try: 44 | with open(filepath, "r") as f: 45 | contents = f.read() 46 | except Exception as exc: 47 | if logger is not None: 48 | logger.error("FAILED to read keyfile %s: %s", 49 | filepath, exc) 50 | return None, None 51 | try: 52 | key_data = PublicKey.from_string(contents) 53 | data_public = key_data.params['public'] 54 | return data_public 55 | except Exception as exc: 56 | if logger is not None: 57 | logger.error("FAILED to decode keyfile format %s: %s", 58 | filepath, exc) 59 | return None 60 | -------------------------------------------------------------------------------- /studio/torch/saver.py: -------------------------------------------------------------------------------- 1 | """Tools to save/restore model from checkpoints.""" 2 | 3 | import os 4 | try: 5 | import torch 6 | except ImportError: 7 | torch = None 8 | 9 | 10 | def load_checkpoint(model, optimizer, model_dir, map_to_cpu=False): 11 | path = os.path.join(model_dir, 'checkpoint') 12 | if os.path.exists(path): 13 | print("Loading model from %s" % path) 14 | if map_to_cpu: 15 | checkpoint = torch.load( 16 | path, map_location=lambda storage, location: storage) 17 | else: 18 | checkpoint = torch.load(path) 19 | old_state_dict = model.state_dict() 20 | for key in old_state_dict.keys(): 21 | if key not in checkpoint['model']: 22 | checkpoint['model'][key] = old_state_dict[key] 23 | model.load_state_dict(checkpoint['model']) 24 | optimizer.load_state_dict(checkpoint['optimizer']) 25 | return checkpoint.get('step', 0) 26 | return 0 27 | 28 | 29 | def save_checkpoint(model, optimizer, step, model_dir, ignore=[]): 30 | if not os.path.exists(model_dir): 31 | os.makedirs(model_dir) 32 | path = os.path.join(model_dir, 'checkpoint') 33 | state_dict = model.state_dict() 34 | if ignore: 35 | for key in state_dict.keys(): 36 | for item in ignore: 37 | if key.startswith(item): 38 | state_dict.pop(key) 39 | torch.save({ 40 | 'model': state_dict, 41 | 'optimizer': optimizer.state_dict(), 42 | 'step': step 43 | }, path) 44 | 45 | 46 | class Saver(object): 47 | """Class to manage save and restore for the model and optimizer.""" 48 | 49 | def __init__(self, model, optimizer): 50 | self._model = model 51 | self._optimizer = optimizer 52 | 53 | def restore(self, model_dir, map_to_cpu=False): 54 | """Restores model and optimizer from given directory. 55 | 56 | Returns: 57 | Last training step for the model restored. 58 | """ 59 | last_step = load_checkpoint( 60 | self._model, self._optimizer, model_dir, map_to_cpu) 61 | return last_step 62 | 63 | def save(self, model_dir, step): 64 | """Saves model and optimizer to given directory. 65 | 66 | Args: 67 | model_dir: Model directory to save. 68 | step: Current training step. 
69 | """ 70 | save_checkpoint(self._model, self._optimizer, step, model_dir) 71 | -------------------------------------------------------------------------------- /examples/keras/train_cifar10.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from studio import fs_tracker 4 | 5 | from keras.layers import Dense, Flatten, Conv2D, BatchNormalization 6 | 7 | from keras.models import Sequential 8 | from keras.datasets import cifar10 9 | from keras.utils import to_categorical 10 | 11 | from keras import optimizers 12 | from keras.callbacks import ModelCheckpoint, TensorBoard 13 | 14 | import tensorflow as tf 15 | from keras import backend as backend 16 | 17 | config = tf.ConfigProto() 18 | config.gpu_options.allow_growth = True 19 | sess = tf.Session(config=config) 20 | 21 | backend.set_session(sess) 22 | 23 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 24 | 25 | x_train = x_train.reshape(50000, 32, 32, 3) 26 | x_test = x_test.reshape(10000, 32, 32, 3) 27 | x_train = x_train.astype('float32') 28 | x_test = x_test.astype('float32') 29 | x_train /= 255 30 | x_test /= 255 31 | 32 | # convert class vectors to binary class matrices 33 | y_train = to_categorical(y_train, 10) 34 | y_test = to_categorical(y_test, 10) 35 | 36 | 37 | model = Sequential() 38 | 39 | model.add(Conv2D(32, (5, 5), activation='relu', input_shape=(32, 32, 3))) 40 | model.add(Conv2D(32, (3, 3), activation='relu')) 41 | model.add(BatchNormalization()) 42 | model.add(Conv2D(64, (5, 5), activation='relu')) 43 | model.add(Conv2D(64, (3, 3), activation='relu')) 44 | model.add(BatchNormalization()) 45 | model.add(Conv2D(128, (3, 3), activation='relu')) 46 | model.add(Conv2D(128, (3, 3), activation='relu')) 47 | model.add(BatchNormalization()) 48 | model.add(Flatten()) 49 | model.add(Dense(10, activation='softmax')) 50 | model.summary() 51 | 52 | 53 | batch_size = 128 54 | no_epochs = int(sys.argv[1]) if len(sys.argv) > 1 else 40 55 | lr = 0.01 56 | 57 | print('learning rate = {}'.format(lr)) 58 | print('batch size = {}'.format(batch_size)) 59 | print('no_epochs = {}'.format(no_epochs)) 60 | 61 | model.compile(loss='categorical_crossentropy', optimizer=optimizers.SGD(lr=lr), 62 | metrics=['accuracy']) 63 | 64 | 65 | checkpointer = ModelCheckpoint( 66 | fs_tracker.get_model_directory() + 67 | '/checkpoint.{epoch:02d}-{val_loss:.2f}.hdf') 68 | 69 | 70 | tbcallback = TensorBoard(log_dir=fs_tracker.get_tensorboard_dir(), 71 | histogram_freq=0, 72 | write_graph=True, 73 | write_images=True) 74 | 75 | 76 | model.fit( 77 | x_train, y_train, validation_data=( 78 | x_test, 79 | y_test), 80 | epochs=no_epochs, 81 | callbacks=[checkpointer, tbcallback], 82 | batch_size=batch_size) 83 | -------------------------------------------------------------------------------- /studio/deploy_apiserver.sh: -------------------------------------------------------------------------------- 1 | #if [[ -z "$FIREBASE_ADMIN_CREDENTIALS" ]]; then 2 | # echo "*** Firbase admin credentials file reqiured! ***" 3 | # echo "Input path to firebase admin credentials json" 4 | # echo "(you can also set FIREBASE_ADMIN_CREDENTIALS env variable manually):" 5 | # read -p ">>" firebase_creds 6 | #else 7 | # firebase_creds=$FIREBASE_ADMIN_CREDENTIALS 8 | #fi 9 | 10 | #if [ ! -f $firebase_creds ]; then 11 | # echo " *** File $firebase_creds does not exist! 
***" 12 | # exit 1 13 | #fi 14 | 15 | #creds="./firebase_admin_creds.json" 16 | #cp $firebase_creds $creds 17 | 18 | config="apiserver_config.yaml" 19 | if [ -n "$2" ]; then 20 | config=$2 21 | fi 22 | 23 | echo "config file = $config" 24 | 25 | if [ "$1" = "gae" ]; then 26 | 27 | mv default_config.yaml default_config.yaml.orig 28 | cp $config default_config.yaml 29 | cp app.yaml app.yaml.old 30 | echo "env_variables:" >> app.yaml 31 | 32 | if [ -n "$AWS_ACCESS_KEY_ID" ]; then 33 | echo "exporting AWS env variables to app.yaml" 34 | echo " AWS_ACCESS_KEY_ID: $AWS_ACCESS_KEY_ID" >> app.yaml 35 | echo " AWS_SECRET_ACCESS_KEY: $AWS_SECRET_ACCESS_KEY" >> app.yaml 36 | echo " AWS_DEFAULT_REGION: $AWS_DEFAULT_REGION" >> app.yaml 37 | fi 38 | 39 | if [ -n "$STUDIO_GITHUB_ID" ]; then 40 | echo "exporting github secret env variables to app.yaml" 41 | echo " STUDIO_GITHUB_ID: $STUDIO_GITHUB_ID" >> app.yaml 42 | echo " STUDIO_GITHUB_SECRET: $STUDIO_GITHUB_SECRET" >> app.yaml 43 | fi 44 | 45 | rm -rf lib 46 | # pip install -t lib -r ../requirements.txt 47 | pip install -t lib ../ 48 | # pip install -t lib -r ../extra_server_requirements.txt 49 | 50 | # patch library files where necessary 51 | for patch in $(find patches -name "*.patch"); do 52 | filename=${patch#patches/} 53 | filename=${filename%.patch} 54 | patch lib/$filename $patch 55 | done 56 | 57 | rm lib/tensorflow/python/_pywrap_tensorflow_internal.so 58 | echo "" > lib/tensorflow/__init__.py 59 | 60 | # dev_appserver.py app.yaml --dev_appserver_log_level debug 61 | yes Y | gcloud app deploy --no-promote 62 | 63 | mv default_config.yaml.orig default_config.yaml 64 | mv app.yaml.old app.yaml 65 | else if [ "$1" = "local" ]; then 66 | port=$2 67 | studio ui --config=apiserver_config.yaml --port=$port 68 | else 69 | echo "*** unknown target: $1 (should be either gae or local) ***" 70 | exit 1 71 | fi 72 | fi 73 | 74 | # rm -f $creds 75 | # rm -rf lib 76 | -------------------------------------------------------------------------------- /studio/fs_tracker.py: -------------------------------------------------------------------------------- 1 | """Utilities to track and record file system.""" 2 | 3 | import os 4 | import shutil 5 | 6 | from studio.artifacts import artifacts_tracker 7 | 8 | STUDIOML_EXPERIMENT = 'STUDIOML_EXPERIMENT' 9 | STUDIOML_HOME = 'STUDIOML_HOME' 10 | STUDIOML_ARTIFACT_MAPPING = 'STUDIOML_ARTIFACT_MAPPING' 11 | 12 | 13 | def get_experiment_key(): 14 | return artifacts_tracker.get_experiment_key() 15 | 16 | 17 | def get_studio_home(): 18 | return artifacts_tracker.get_studio_home() 19 | 20 | 21 | def setup_experiment(env, experiment, clean=True): 22 | artifacts_tracker.setup_experiment(env, experiment, clean=clean) 23 | 24 | 25 | def get_artifact(tag): 26 | return artifacts_tracker.get_experiment(tag) 27 | 28 | 29 | def get_artifacts(): 30 | return artifacts_tracker.get_artifacts() 31 | 32 | 33 | def get_model_directory(experiment_name=None): 34 | return get_artifact_cache('modeldir', experiment_name) 35 | 36 | 37 | def get_artifact_cache(tag, experiment_name=None): 38 | return artifacts_tracker.get_artifact_cache( 39 | tag, 40 | experiment_name=experiment_name) 41 | 42 | 43 | def get_blob_cache(blobkey): 44 | return artifacts_tracker.get_blob_cache(blobkey) 45 | 46 | def get_model_directory(experiment_name=None): 47 | return get_artifact_cache('modeldir', experiment_name) 48 | 49 | def _get_artifact_mapping_path(experiment_name=None): 50 | experiment_name = experiment_name if experiment_name else \ 51 | 
os.environ[STUDIOML_EXPERIMENT] 52 | 53 | basepath = os.path.join( 54 | get_studio_home(), 55 | 'artifact_mappings', 56 | experiment_name 57 | ) 58 | if not os.path.exists(basepath): 59 | os.makedirs(basepath) 60 | 61 | return os.path.join(basepath, 'artifacts.json') 62 | 63 | 64 | def _get_experiment_key(experiment): 65 | if not isinstance(experiment, str): 66 | return experiment.key 67 | else: 68 | return experiment 69 | 70 | 71 | def _setup_model_directory(experiment_name, clean=False): 72 | path = get_model_directory(experiment_name) 73 | if clean and os.path.exists(path): 74 | shutil.rmtree(path) 75 | 76 | if not os.path.exists(path): 77 | os.makedirs(path) 78 | 79 | 80 | def get_queue_directory(): 81 | queue_dir = os.path.join( 82 | get_studio_home(), 83 | 'queue') 84 | if not os.path.exists(queue_dir): 85 | try: 86 | os.makedirs(queue_dir) 87 | except OSError: 88 | pass 89 | 90 | return queue_dir 91 | 92 | 93 | def get_tensorboard_dir(experiment_name=None): 94 | return get_artifact_cache('tb', experiment_name) 95 | -------------------------------------------------------------------------------- /studio/cloud_worker_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from studio.util.util import rand_string 4 | """ 5 | Utility functions for anything shared in common by ec2cloud_worker and 6 | gcloud_worker 7 | """ 8 | 9 | INDENT = 4 10 | 11 | 12 | def insert_user_startup_script(user_startup_script, startup_script_str, 13 | logger): 14 | if user_startup_script is None: 15 | return startup_script_str 16 | 17 | try: 18 | with open(os.path.abspath(os.path.expanduser( 19 | user_startup_script))) as f: 20 | user_startup_script_lines = f.read().splitlines() 21 | except BaseException: 22 | if user_startup_script is not None: 23 | logger.warn("User startup script (%s) cannot be loaded" % 24 | user_startup_script) 25 | return startup_script_str 26 | 27 | startup_script_lines = startup_script_str.splitlines() 28 | new_startup_script_lines = [] 29 | whitespace = " " * INDENT 30 | for line in startup_script_lines: 31 | 32 | if line.startswith("studio remote worker") or \ 33 | line.startswith("studio-remote-worker"): 34 | curr_working_dir = "curr_working_dir_%s" % rand_string(32) 35 | func_name = "user_script_%s" % rand_string(32) 36 | 37 | new_startup_script_lines.append("%s=$(pwd)\n" % curr_working_dir) 38 | new_startup_script_lines.append("cd ~\n") 39 | new_startup_script_lines.append("%s()(\n" % func_name) 40 | for user_line in user_startup_script_lines: 41 | if user_line.startswith("#!"): 42 | continue 43 | new_startup_script_lines.append("%s%s\n" % 44 | (whitespace, user_line)) 45 | 46 | new_startup_script_lines.append("%scd $%s\n" % 47 | (whitespace, curr_working_dir)) 48 | new_startup_script_lines.append("%s%s\n" % 49 | (whitespace, line)) 50 | new_startup_script_lines.append(")\n") 51 | new_startup_script_lines.append("%s\n" % func_name) 52 | else: 53 | new_startup_script_lines.append("%s\n" % line) 54 | 55 | new_startup_script = "".join(new_startup_script_lines) 56 | logger.info('Inserting the following user startup script' 57 | ' into the default startup script:') 58 | logger.info("\n".join(user_startup_script_lines)) 59 | 60 | # with open("/home/jason/Desktop/script.sh", 'wb') as f: 61 | # f.write(new_startup_script) 62 | # sys.exit() 63 | 64 | return new_startup_script 65 | -------------------------------------------------------------------------------- /studio/storage/http_storage_handler.py: 
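A note on the handler listed below: it adapts plain HTTP/HTTPS endpoints to the common StorageHandler interface. A minimal usage sketch follows — the URL is hypothetical, and passing credentials_dict=None simply mirrors what StorageHandlerFactory does when no credentials are configured; the actual transfer is delegated to storage_util.download_file.

from studio.storage.http_storage_handler import HTTPStorageHandler

# The handler is bound to a single remote object; per the implementation below,
# download_file ignores the key argument and always fetches the URL the handler
# was constructed with.
handler = HTTPStorageHandler('https://example.com/artifacts/model.tgz',
                             credentials_dict=None)
handler.download_file(None, '/tmp/model.tgz')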
-------------------------------------------------------------------------------- 1 | import os 2 | from urllib.parse import urlparse 3 | from typing import Dict 4 | from studio.util import logs 5 | from studio.credentials.credentials import Credentials 6 | from studio.storage.storage_setup import get_storage_verbose_level 7 | from studio.storage.storage_type import StorageType 8 | from studio.storage.storage_handler import StorageHandler 9 | from studio.storage import storage_util 10 | 11 | class HTTPStorageHandler(StorageHandler): 12 | def __init__(self, remote_path, credentials_dict, 13 | timestamp=None, 14 | compression=None): 15 | 16 | self.logger = logs.get_logger(self.__class__.__name__) 17 | self.logger.setLevel(get_storage_verbose_level()) 18 | 19 | self.url = remote_path 20 | self.timestamp = timestamp 21 | 22 | parsed_url = urlparse(self.url) 23 | self.scheme = parsed_url.scheme 24 | self.endpoint = parsed_url.netloc 25 | self.path = parsed_url.path 26 | self.credentials = Credentials(credentials_dict) 27 | 28 | super().__init__(StorageType.storageHTTP, 29 | self.logger, 30 | False, 31 | compression=compression) 32 | 33 | def upload_file(self, key, local_path): 34 | storage_util.upload_file(self.url, local_path, self.logger) 35 | 36 | def download_file(self, key, local_path): 37 | return storage_util.download_file(self.url, local_path, self.logger) 38 | 39 | def download_remote_path(self, remote_path, local_path): 40 | head, _ = os.path.split(local_path) 41 | if head is not None: 42 | os.makedirs(head, exist_ok=True) 43 | return storage_util.download_file(remote_path, local_path, self.logger) 44 | 45 | @classmethod 46 | def get_id(cls, config: Dict) -> str: 47 | endpoint = config.get('endpoint', None) 48 | if endpoint is None: 49 | return None 50 | creds: Credentials = Credentials.get_credentials(config) 51 | creds_fingerprint = creds.get_fingerprint() if creds else '' 52 | return '[http]{0}::{1}'.format(endpoint, creds_fingerprint) 53 | 54 | def get_local_destination(self, remote_path: str): 55 | parsed_url = urlparse(remote_path) 56 | parts = parsed_url.path.split('/') 57 | return None, parts[len(parts)-1] 58 | 59 | def delete_file(self, key, shallow=True): 60 | raise NotImplementedError 61 | 62 | def get_file_url(self, key): 63 | return self.url 64 | 65 | def get_file_timestamp(self, key): 66 | return self.timestamp 67 | 68 | def get_qualified_location(self, key): 69 | return self.url 70 | -------------------------------------------------------------------------------- /studio/scripts/studio-add-credentials: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # script that builds docker image with credentials baked in 4 | # usage: 5 | # studio-add-credentials [--base-image=] [--tag=] [--check-gpu] 6 | # 7 | # --base-image specifies the base image to add credentials to. Default is peterzhokhoff/tfstudio 8 | # --tag specifes the tag of the resulting image. Default is _creds 9 | # --check-gpu option (if specified) checks if nvidia-smi works correctly on a current machine, and if it is not, uninstalls tensorflow-gpu from docker image. 10 | # Without --check-gpu option the built docker image may not work on a current machine. 
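#
# Example invocation (illustrative only; the base image is the default shown below,
# the tag is arbitrary, and the script assumes docker is available plus whatever
# ~/.aws, ~/.studioml/keys and GOOGLE_APPLICATION_CREDENTIALS you want baked in):
#
#   studio-add-credentials --base-image=peterzhokhoff/tfstudio --tag=tfstudio_creds
#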
11 | 12 | base_img=peterzhokhoff/tfstudio 13 | 14 | while [[ $# -gt 0 ]] 15 | do 16 | key="$1" 17 | case ${key%%=*} in 18 | 19 | -b|--base-image) 20 | base_img="${1##*=}" 21 | ;; 22 | 23 | -t|--tag) 24 | output_img="${1##*=}" 25 | ;; 26 | 27 | esac 28 | shift 29 | done 30 | 31 | if [ -z $output_img ]; then 32 | output_img=$base_img"_creds" 33 | fi 34 | 35 | # mypath=$(pwd)/${0%/*} 36 | dockerfile=".Dockerfile_bake_creds" 37 | awspath=$HOME/.aws 38 | 39 | # cd $mypath/../../ 40 | 41 | echo "Base image: $base_img" 42 | echo "Tag: $output_img" 43 | 44 | 45 | echo "Uninstall tensorflow-gpu: $uninstall_tfgpu" 46 | 47 | contextdir=$TMPDIR/studioml_container_context 48 | mkdir $contextdir 49 | cd $contextdir 50 | 51 | if [ -d $HOME/.studioml/keys ]; then 52 | cp -r $HOME/.studioml/keys .keys 53 | fi 54 | 55 | if [ -n $GOOGLE_APPLICATION_CREDENTIALS ]; then 56 | cp $GOOGLE_APPLICATION_CREDENTIALS .gac_credentials 57 | fi 58 | 59 | if [ -d $awspath ]; then 60 | cp -r $awspath .aws 61 | fi 62 | 63 | # build dockerfile 64 | echo "Constructing dockerfile..." 65 | echo "FROM $base_img" > $dockerfile 66 | 67 | if [ -d $HOME/.studioml/keys/ ]; then 68 | echo "ADD .keys /root/.studioml/keys" >> $dockerfile 69 | fi 70 | 71 | if [ -n $GOOGLE_APPLICATION_CREDENTIALS ]; then 72 | echo "ADD .gac_credentials /root/gac_credentials" >> $dockerfile 73 | echo "ENV GOOGLE_APPLICATION_CREDENTIALS /root/gac_credentials" >> $dockerfile 74 | fi 75 | 76 | if [ -d $awspath ]; then 77 | echo "ADD .aws /root/.aws" >> $dockerfile 78 | fi 79 | 80 | echo "Done. Resulting dockerfile: " 81 | cat $dockerfile 82 | 83 | 84 | # build docker image 85 | echo "Building image..." 86 | docker build -t $output_img -f $dockerfile . 87 | echo "Done" 88 | 89 | # cleanup 90 | echo "Cleaning up..." 91 | rm -rf .keys 92 | rm -rf $dockerfile 93 | 94 | if [ -n $GOOGLE_APPLICATION_CREDENTIALS ]; then 95 | rm -rf $gac_file 96 | fi 97 | 98 | if [ -d $awspath ]; then 99 | rm -rf .aws 100 | fi 101 | echo "Done" 102 | -------------------------------------------------------------------------------- /studio/remote_worker.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | from studio import model, logs 5 | from .local_worker import worker_loop 6 | from .pubsub_queue import PubsubQueue 7 | from .sqs_queue import SQSQueue 8 | from .util import parse_verbosity 9 | 10 | from .qclient_cache import get_cached_queue 11 | 12 | 13 | def main(args=sys.argv): 14 | logger = logs.get_logger('studio-remote-worker') 15 | parser = argparse.ArgumentParser( 16 | description='Studio remote worker. \ 17 | Usage: studio-remote-worker \ 18 | ') 19 | parser.add_argument('--config', help='configuration file', default=None) 20 | 21 | parser.add_argument( 22 | '--guest', 23 | help='Guest mode (does not require db credentials)', 24 | action='store_true') 25 | 26 | parser.add_argument( 27 | '--single-run', 28 | help='quit after a single run (regardless of the state of the queue)', 29 | action='store_true') 30 | 31 | parser.add_argument('--queue', help='queue name', required=True) 32 | parser.add_argument( 33 | '--verbose', '-v', 34 | help='Verbosity level. 
Allowed vaules: ' + 35 | 'debug, info, warn, error, crit ' + 36 | 'or numerical value of logger levels.', 37 | default=None) 38 | 39 | parser.add_argument( 40 | '--timeout', '-t', 41 | help='Timeout after which remote worker stops listening (in seconds)', 42 | type=int, 43 | default=100) 44 | 45 | parsed_args, script_args = parser.parse_known_args(args) 46 | verbose = parse_verbosity(parsed_args.verbose) 47 | logger.setLevel(verbose) 48 | 49 | config = None 50 | if parsed_args.config is not None: 51 | config = model.get_config(parsed_args.config) 52 | 53 | if parsed_args.queue.startswith('ec2_') or \ 54 | parsed_args.queue.startswith('sqs_'): 55 | queue = SQSQueue(parsed_args.queue, verbose=verbose) 56 | elif parsed_args.queue.startswith('rmq_'): 57 | queue = get_cached_queue( 58 | name=parsed_args.queue, 59 | route='StudioML.' + parsed_args.queue, 60 | config=config, 61 | logger=logger, 62 | verbose=verbose) 63 | else: 64 | queue = PubsubQueue(parsed_args.queue, verbose=verbose) 65 | 66 | logger.info('Waiting for work') 67 | 68 | timeout_before = parsed_args.timeout 69 | timeout_after = timeout_before if timeout_before > 0 else 0 70 | # wait_for_messages(queue, timeout_before, logger) 71 | 72 | logger.info('Starting working') 73 | worker_loop(queue, parsed_args, 74 | single_experiment=parsed_args.single_run, 75 | timeout=timeout_after, 76 | verbose=verbose) 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /studio/storage/storage_handler_factory.py: -------------------------------------------------------------------------------- 1 | from studio.util import logs 2 | from typing import Dict 3 | 4 | from studio.storage.http_storage_handler import HTTPStorageHandler 5 | from studio.storage.local_storage_handler import LocalStorageHandler 6 | from studio.storage.storage_setup import get_storage_verbose_level 7 | from studio.storage.storage_handler import StorageHandler 8 | from studio.storage.storage_type import StorageType 9 | from studio.storage.s3_storage_handler import S3StorageHandler 10 | 11 | _storage_factory = None 12 | 13 | class StorageHandlerFactory: 14 | def __init__(self): 15 | self.logger = logs.get_logger(self.__class__.__name__) 16 | self.logger.setLevel(get_storage_verbose_level()) 17 | self.handlers_cache = dict() 18 | self.cleanup_at_exit: bool = True 19 | 20 | @classmethod 21 | def get_factory(cls): 22 | global _storage_factory 23 | if _storage_factory is None: 24 | _storage_factory = StorageHandlerFactory() 25 | return _storage_factory 26 | 27 | def set_cleanup_at_exit(self, value: bool): 28 | self.cleanup_at_exit = value 29 | 30 | def cleanup(self): 31 | if not self.cleanup_at_exit: 32 | return 33 | 34 | for _, handler in self.handlers_cache.items(): 35 | if handler is not None: 36 | handler.cleanup() 37 | 38 | def get_handler(self, handler_type: StorageType, 39 | config: Dict) -> StorageHandler: 40 | if handler_type == StorageType.storageS3: 41 | handler_id: str = S3StorageHandler.get_id(config) 42 | handler = self.handlers_cache.get(handler_id, None) 43 | if handler is None: 44 | handler = S3StorageHandler(config) 45 | self.handlers_cache[handler_id] = handler 46 | return handler 47 | if handler_type == StorageType.storageHTTP: 48 | handler_id: str = HTTPStorageHandler.get_id(config) 49 | handler = self.handlers_cache.get(handler_id, None) 50 | if handler is None: 51 | handler = HTTPStorageHandler( 52 | config.get('endpoint', None), 53 | config.get('credentials', None)) 54 | 
self.handlers_cache[handler_id] = handler 55 | return handler 56 | if handler_type == StorageType.storageLocal: 57 | handler_id: str = LocalStorageHandler.get_id(config) 58 | handler = self.handlers_cache.get(handler_id, None) 59 | if handler is None: 60 | handler = LocalStorageHandler(config) 61 | self.handlers_cache[handler_id] = handler 62 | return handler 63 | self.logger("FAILED to get storage handler: unsupported type %s", 64 | repr(handler_type)) 65 | return None -------------------------------------------------------------------------------- /tests/serving_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import subprocess 3 | import time 4 | from random import randint 5 | import os 6 | import uuid 7 | import requests 8 | 9 | from studio import model 10 | from env_detect import on_gcp 11 | 12 | 13 | @unittest.skipIf( 14 | not on_gcp(), 15 | 'User indicated not on gcp') 16 | class UserIndicatedOnGCPTest(unittest.TestCase): 17 | def test_on_enviornment(self): 18 | self.assertTrue('GOOGLE_APPLICATION_CREDENTIALS' in os.environ.keys()) 19 | 20 | 21 | @unittest.skipIf( 22 | (not on_gcp()) or 23 | 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ.keys(), 24 | 'Skipping due to userinput or GCP Not detected' + 25 | "GOOGLE_APPLICATION_CREDENTIALS is missing, needed for " + 26 | "server to communicate with storage") 27 | class ServingTest(unittest.TestCase): 28 | 29 | _mutliprocess_shared_ = True 30 | 31 | def _test_serving(self, data_in, expected_data_out, wrapper=None): 32 | 33 | self.port = randint(5000, 9000) 34 | server_experimentid = 'test_serving_' + str(uuid.uuid4()) 35 | 36 | args = [ 37 | 'studio', 'run', 38 | '--force-git', 39 | '--verbose=debug', 40 | '--experiment=' + server_experimentid, 41 | '--config=' + self.get_config_path(), 42 | 'studio::serve_main', 43 | '--port=' + str(self.port), 44 | '--host=localhost' 45 | ] 46 | 47 | if wrapper: 48 | args.append('--wrapper=' + wrapper) 49 | 50 | subprocess.Popen(args, cwd=os.path.dirname(__file__)) 51 | time.sleep(60) 52 | 53 | try: 54 | retval = requests.post( 55 | url='http://localhost:' + str(self.port), json=data_in) 56 | data_out = retval.json() 57 | assert data_out == expected_data_out 58 | 59 | finally: 60 | with model.get_db_provider(model.get_config( 61 | self.get_config_path())) as db: 62 | 63 | db.stop_experiment(server_experimentid) 64 | time.sleep(20) 65 | db.delete_experiment(server_experimentid) 66 | 67 | def test_serving_identity(self): 68 | data = {"a": "b"} 69 | self._test_serving( 70 | data_in=data, 71 | expected_data_out=data 72 | ) 73 | 74 | def test_serving_increment(self): 75 | data_in = {"a": 1} 76 | data_out = {"a": 2} 77 | 78 | self._test_serving( 79 | data_in=data_in, 80 | expected_data_out=data_out, 81 | wrapper='model_increment.py' 82 | ) 83 | 84 | def get_config_path(self): 85 | return os.path.join( 86 | os.path.dirname(__file__), 87 | 'test_config_http_client.yaml' 88 | ) 89 | 90 | 91 | if __name__ == '__main__': 92 | unittest.main() 93 | -------------------------------------------------------------------------------- /studio/git_util.py: -------------------------------------------------------------------------------- 1 | import re 2 | import subprocess 3 | import os 4 | 5 | 6 | def get_git_info(path='.', abort_dirty=True): 7 | info = {} 8 | if not is_git(path): 9 | return None 10 | 11 | if abort_dirty and not is_clean(path): 12 | return None 13 | 14 | info['url'] = get_repo_url(path) 15 | info['commit'] = get_commit(path) 16 | return 
info 17 | 18 | 19 | def is_git(path='.'): 20 | p = subprocess.Popen( 21 | ['git', 'status'], 22 | stdout=subprocess.DEVNULL, 23 | stderr=subprocess.DEVNULL, 24 | cwd=path) 25 | 26 | p.wait() 27 | return (p.returncode == 0) 28 | 29 | 30 | def is_clean(path='.'): 31 | p = subprocess.Popen( 32 | ['git', 'status', '-s'], 33 | stdout=subprocess.PIPE, 34 | stderr=subprocess.PIPE, 35 | cwd=path) 36 | 37 | stdout, _ = p.communicate() 38 | if not p.returncode == 0: 39 | return False 40 | 41 | return (stdout.strip() == '') 42 | 43 | 44 | def get_repo_url(path='.', remove_user=True): 45 | p = subprocess.Popen( 46 | ['git', 'config', '--get', 'remote.origin.url'], 47 | stdout=subprocess.PIPE, 48 | stderr=subprocess.PIPE, 49 | cwd=path) 50 | 51 | stdout, _ = p.communicate() 52 | if p.returncode != 0: 53 | return None 54 | 55 | url = stdout.strip() 56 | if remove_user: 57 | url = re.sub('(?<=://).*@', '', url.decode('utf-8')) 58 | return url 59 | 60 | 61 | def get_branch(path='.'): 62 | p = subprocess.Popen( 63 | ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], 64 | stdout=subprocess.PIPE, 65 | stderr=subprocess.STDOUT, 66 | cwd=path) 67 | 68 | stdout, _ = p.communicate() 69 | if p.returncode == 0: 70 | return None 71 | 72 | return stdout.strip().decode('utf8') 73 | 74 | 75 | def get_commit(path='.'): 76 | p = subprocess.Popen( 77 | ['git', 'rev-parse', 'HEAD'], 78 | stdout=subprocess.PIPE, 79 | stderr=subprocess.STDOUT, 80 | cwd=path) 81 | 82 | stdout, _ = p.communicate() 83 | if p.returncode != 0: 84 | return None 85 | 86 | return stdout.strip().decode('utf8') 87 | 88 | 89 | def get_my_repo_url(): 90 | mypath = os.path.dirname(os.path.realpath(__file__)) 91 | repo = get_repo_url(mypath) 92 | if repo is None: 93 | repo = "https://github.com/studioml/studio" 94 | return repo 95 | 96 | 97 | def get_my_branch(): 98 | mypath = os.path.dirname(os.path.realpath(__file__)) 99 | branch = get_branch(mypath) 100 | if branch is None: 101 | branch = "master" 102 | return branch 103 | 104 | 105 | def get_my_checkout_target(): 106 | mypath = os.path.dirname(os.path.realpath(__file__)) 107 | target = get_commit(mypath) 108 | if target is None: 109 | target = get_my_branch() 110 | 111 | return target 112 | -------------------------------------------------------------------------------- /studio/completion_service/completion_service_client.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import shutil 3 | import pickle 4 | import os 5 | import sys 6 | import six 7 | import signal 8 | import pdb 9 | 10 | from studio import fs_tracker, model, logs, util 11 | 12 | logger = logs.get_logger('completion_service_client') 13 | try: 14 | logger.setLevel(model.parse_verbosity(sys.argv[1])) 15 | except BaseException: 16 | logger.setLevel(10) 17 | 18 | 19 | def main(): 20 | logger.debug('copying and importing client module') 21 | logger.debug('getting file mappings') 22 | 23 | # Register signal handler for signal.SIGUSR1 24 | # which will invoke built-in Python debugger: 25 | signal.signal(signal.SIGUSR1, lambda sig, stack: pdb.set_trace()) 26 | 27 | artifacts = fs_tracker.get_artifacts() 28 | files = {} 29 | logger.debug("Artifacts = {}".format(artifacts)) 30 | 31 | for tag, path in six.iteritems(artifacts): 32 | if tag not in {'workspace', 'modeldir', 'tb', '_runner'}: 33 | if os.path.isfile(path): 34 | files[tag] = path 35 | elif os.path.isdir(path): 36 | dirlist = os.listdir(path) 37 | if any(dirlist): 38 | files[tag] = os.path.join( 39 | path, 40 | dirlist[0] 41 | ) 42 
| 43 | logger.debug("Files = {}".format(files)) 44 | script_path = files['clientscript'] 45 | retval_path = fs_tracker.get_artifact('retval') 46 | util.rm_rf(retval_path) 47 | 48 | # script_name = os.path.basename(script_path) 49 | new_script_path = os.path.join(os.getcwd(), '_clientscript.py') 50 | shutil.copy(script_path, new_script_path) 51 | 52 | script_path = new_script_path 53 | logger.debug("script path: " + script_path) 54 | 55 | mypath = os.path.dirname(script_path) 56 | sys.path.append(mypath) 57 | # os.path.splitext(os.path.basename(script_path))[0] 58 | module_name = '_clientscript' 59 | 60 | client_module = importlib.import_module(module_name) 61 | logger.debug('loading args') 62 | 63 | args_path = files['args'] 64 | 65 | with open(args_path, 'rb') as f: 66 | args = pickle.loads(f.read()) 67 | 68 | logger.debug('calling client function') 69 | retval = client_module.clientFunction(args, files) 70 | 71 | logger.debug('saving the return value') 72 | if os.path.isdir(fs_tracker.get_artifact('clientscript')): 73 | # on go runner: 74 | logger.debug("Running in a go runner, creating {} for retval" 75 | .format(retval_path)) 76 | try: 77 | os.mkdir(retval_path) 78 | except OSError: 79 | logger.debug('retval dir present') 80 | 81 | retval_path = os.path.join(retval_path, 'retval') 82 | logger.debug("New retval_path is {}".format(retval_path)) 83 | 84 | logger.debug('Saving retval') 85 | with open(retval_path, 'wb') as f: 86 | f.write(pickle.dumps(retval, protocol=2)) 87 | logger.debug('Done') 88 | 89 | 90 | if __name__ == "__main__": 91 | main() 92 | -------------------------------------------------------------------------------- /docs/faq.rst: -------------------------------------------------------------------------------- 1 | Frequently Asked Questions 2 | ========================== 3 | 4 | `Join us on Slack! `_ 5 | ----------------------------------------------- 6 | 7 | - `What is the complete list of tools Studio.ML is compatible with?`_ 8 | 9 | - `Is Studio.ML compatible with Python 3?`_ 10 | 11 | - `Do I need to change my code to use Studio.ML?`_ 12 | 13 | - `How can I track the training of my models?`_ 14 | 15 | - `How does Studio.ml integrate with Google Cloud or Amazon EC2?`_ 16 | 17 | - `Is it possible to view the experiment artifacts outside of the Web UI?`_ 18 | 19 | 20 | _`What is the complete list of tools Studio.ML is compatible with?` 21 | ------------------------------------------------------------------ 22 | 23 | Keras, TensorFlow, PyTorch, scikit-learn, pandas - anything that runs in python 24 | 25 | _`Is Studio.ML compatible with Python 3?` 26 | -------------------------------- 27 | 28 | Yes! Studio.ML is now compatible to use with Python 3. 29 | 30 | _`Can I use Studio.ML with my jupyter / ipython notebooks?` 31 | ----------------------------------------------------------- 32 | 33 | Yes! The basic usage pattern is import ``magics`` module from studio, 34 | and then annotate cells that need to be run via studio with 35 | ``%%studio_run`` cell magic (optionally followed by the same command-line arguments that 36 | ``studio run`` accepts. Please refer to `` the for more info 37 | 38 | 39 | _`Do I need to change my code to use Studio.ML?` 40 | --------------------------------------------- 41 | 42 | Studio is designed to minimize any invasion of your existing code. Running an experiment with Studio should be as simple as replacing ``python`` with ``studio run`` in your command line with a few flags for capturing your workspace or naming your experiments. 
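For example, an experiment normally started with ``python train_mnist_keras.py`` (the training script from the README example) can be captured and tracked simply with

::

    studio run train_mnist_keras.py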
43 | 44 | _`How can I track the training of my models?` 45 | -------------------- 46 | 47 | You can manage any of your experiments- current, old or queued- through the web interface. Simply run ``studio ui`` to launch the UI to view details of any of your experiments. 48 | 49 | _`How does Studio.ml integrate with Google Cloud or Amazon EC2?` 50 | ----------------- 51 | 52 | We use standard Python tools like Boto and Google Cloud Python Client to launch GPU instances that are used for model training and de-provision them when the experiment is finished. 53 | 54 | _`Is it possible to view the experiment artifacts outside of the Web UI?` 55 | ------------------- 56 | 57 | Yes! 58 | 59 | :: 60 | 61 | from studio import model 62 | 63 | with model.get_db_provider() as db: 64 | experiment = db.get_experiment() 65 | 66 | 67 | will return an experiment object that contains all the information about the experiment with key ````, including artifacts. 68 | The artifacts can then be downloaded: 69 | 70 | :: 71 | 72 | with model.get_db_provider() as db: 73 | artifact_path = db.get_artifact(experiment.artifacts[]) 74 | 75 | will download an artifact with tag ```` and return a local path to it in ``artifact_path`` variable 76 | 77 | -------------------------------------------------------------------------------- /studio/scripts/ec2_worker_startup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | exec > >(tee -i ~/ec2_worker_logfile.txt) 4 | exec 2>&1 5 | 6 | cd ~ 7 | mkdir .aws 8 | echo "[default]" > .aws/config 9 | echo "region = {region}" >> .aws/config 10 | 11 | mkdir -p .studioml/keys 12 | key_name="{auth_key}" 13 | queue_name="{queue_name}" 14 | echo "{auth_data}" | base64 --decode > .studioml/keys/$key_name 15 | echo "{google_app_credentials}" | base64 --decode > credentials.json 16 | 17 | export GOOGLE_APPLICATION_CREDENTIALS=~/credentials.json 18 | 19 | export AWS_ACCESS_KEY_ID="{aws_access_key}" 20 | export AWS_SECRET_ACCESS_KEY="{aws_secret_key}" 21 | 22 | code_url_base="https://storage.googleapis.com/studio-ed756.appspot.com/src" 23 | #code_ver="tfstudio-64_config_location-2017-08-04_1.tgz" 24 | 25 | 26 | autoscaling_group="{autoscaling_group}" 27 | instance_id=$(wget -q -O - http://169.254.169.254/latest/meta-data/instance-id) 28 | 29 | echo "Environment varibles:" 30 | env 31 | 32 | {install_studio} 33 | 34 | python $(which studio-remote-worker) --queue=$queue_name --verbose=debug --timeout={timeout} 35 | 36 | # sudo update-alternatives --set python3 37 | 38 | # shutdown the instance 39 | echo "Work done" 40 | 41 | hostname=$(hostname) 42 | aws s3 cp /var/log/cloud-init-output.log "s3://studioml-logs/$queue_name/$hostname.txt" 43 | 44 | if [[ -n $(who) ]]; then 45 | echo "Users are logged in, not shutting down" 46 | echo "Do not forget to shut the instance down manually" 47 | exit 0 48 | fi 49 | 50 | 51 | 52 | if [ -n $autoscaling_group ]; then 53 | 54 | echo "Getting info for auto-scaling group $autoscaling_group" 55 | 56 | asg_info="aws autoscaling describe-auto-scaling-groups --auto-scaling-group-name $autoscaling_group" 57 | desired_size=$( $asg_info | jq --raw-output ".AutoScalingGroups | .[0] | .DesiredCapacity" ) 58 | launch_config=$( $asg_info | jq --raw-output ".AutoScalingGroups | .[0] | .LaunchConfigurationName" ) 59 | 60 | echo "Launch config: $launch_config" 61 | echo "Current autoscaling group size (desired): $desired_size" 62 | 63 | if [[ $desired_size -gt 1 ]]; then 64 | echo "Detaching myself ($instance_id) from the 
ASG $autoscaling_group" 65 | aws autoscaling detach-instances --instance-ids $instance_id --auto-scaling-group-name $autoscaling_group --should-decrement-desired-capacity 66 | #new_desired_size=$((desired_size - 1)) 67 | #echo "Decreasing ASG size to $new_desired_size" 68 | #aws autoscaling update-auto-scaling-group --auto-scaling-group-name $autoscaling_group --desired-capacity $new_desired_size 69 | else 70 | echo "Deleting launch configuration and auto-scaling group" 71 | aws autoscaling delete-auto-scaling-group --auto-scaling-group-name $autoscaling_group --force-delete 72 | aws autoscaling delete-launch-configuration --launch-configuration-name $launch_config 73 | fi 74 | # if desired_size > 1 decrease desired size (with cooldown - so that it does not try to remove any other instances!) 75 | # else delete the group - that should to the shutdown 76 | # 77 | 78 | fi 79 | aws s3 cp /var/log/cloud-init-output.log "s3://studioml-logs/$queue_name/$hostname.txt" 80 | echo "Shutting the instance down!" 81 | sudo shutdown now 82 | -------------------------------------------------------------------------------- /studio/util/gpu_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import xml.etree.ElementTree as ET 4 | 5 | from studio.util.util import sixdecode 6 | 7 | 8 | def memstr2int(string): 9 | conversion_factors = [ 10 | ('Mb', 2**20), ('MiB', 2**20), ('m', 2**20), ('mb', 2**20), 11 | ('Gb', 2**30), ('GiB', 2**30), ('g', 2**30), ('gb', 2**30), 12 | ('kb', 2**10), ('k', 2**10) 13 | ] 14 | 15 | for key, factor in conversion_factors: 16 | if string.endswith(key): 17 | return int(float(string.replace(key, '')) * factor) 18 | 19 | return int(string) 20 | 21 | 22 | def get_available_gpus(gpu_mem_needed=None, strict=False): 23 | gpus = _get_gpu_info() 24 | 25 | def check_gpu_nomem_strict(gpu): 26 | return memstr2int(gpu.find('fb_memory_usage').find('used').text) < \ 27 | memstr2int(gpu.find('fb_memory_usage').find('free').text) 28 | if gpu_mem_needed is None: 29 | if strict: 30 | return [gpu.find('minor_number').text for gpu in gpus if 31 | check_gpu_nomem_strict(gpu)] 32 | return [gpu.find('minor_number').text for gpu in gpus] 33 | 34 | gpu_mem_needed = memstr2int(gpu_mem_needed) 35 | 36 | def check_gpu_mem_strict(gpu): 37 | return gpu_mem_needed < \ 38 | memstr2int(gpu.find('fb_memory_usage').find('free').text) 39 | 40 | def check_gpu_mem_loose(gpu): 41 | return (gpu_mem_needed < 42 | memstr2int(gpu.find('fb_memory_usage').find('total').text)) \ 43 | and check_gpu_nomem_strict(gpu) 44 | 45 | if strict: 46 | return [gpu.find('minor_number').text 47 | for gpu in gpus if check_gpu_mem_strict(gpu)] 48 | return [gpu.find('minor_number').text 49 | for gpu in gpus if check_gpu_mem_loose(gpu)] 50 | 51 | 52 | def _get_gpu_info(): 53 | try: 54 | with subprocess.Popen(['nvidia-smi', '-q', '-x'], 55 | stdout=subprocess.PIPE, 56 | stderr=subprocess.STDOUT) as smi_proc: 57 | smi_output, _ = smi_proc.communicate() 58 | xmlroot = ET.fromstring(sixdecode(smi_output)) 59 | return xmlroot.findall('gpu') 60 | except Exception: 61 | return [] 62 | 63 | 64 | def get_gpus_summary(): 65 | info = _get_gpu_info() 66 | 67 | def info_to_summary(gpuinfo): 68 | util = gpuinfo.find('utilization').find('gpu_util').text 69 | mem = gpuinfo.find('fb_memory_usage').find('used').text 70 | 71 | return "util: {}, mem {}".format(util, memstr2int(mem)) 72 | 73 | return " ".join([ 74 | "gpu {} {}".format( 75 | gpuinfo.find('minor_number').text, 76 | 
info_to_summary(gpuinfo)) for gpuinfo in info]) 77 | 78 | 79 | def get_gpu_mapping(): 80 | no_gpus = len(_get_gpu_info()) 81 | return {str(i): i for i in range(no_gpus)} 82 | 83 | 84 | def _find_my_gpus(prop='minor_number'): 85 | gpu_info = _get_gpu_info() 86 | my_gpus = [g.find(prop).text for g in gpu_info if os.getpid() in [int( 87 | p.find('pid').text) for p in 88 | g.find('processes').findall('process_info')]] 89 | 90 | return my_gpus 91 | 92 | 93 | if __name__ == "__main__": 94 | print(get_gpu_mapping()) 95 | -------------------------------------------------------------------------------- /studio/torch/summary.py: -------------------------------------------------------------------------------- 1 | """Tools to simplify PyTorch reporting and integrate with TensorBoard.""" 2 | 3 | import collections 4 | import six 5 | import time 6 | 7 | try: 8 | from tensorflow import summary as tb_summary 9 | except ImportError: 10 | tb_summary = None 11 | 12 | 13 | class TensorBoardWriter(object): 14 | """Write events in TensorBoard format.""" 15 | 16 | def __init__(self, logdir): 17 | if tb_summary is None: 18 | raise ValueError( 19 | "You must install TensorFlow " + 20 | "to use Tensorboard summary writer.") 21 | self._writer = tb_summary.FileWriter(logdir) 22 | 23 | def add(self, step, key, value): 24 | summary = tb_summary.Summary() 25 | summary_value = summary.value.add() 26 | summary_value.tag = key 27 | summary_value.simple_value = value 28 | self._writer.add_summary(summary, global_step=step) 29 | 30 | def flush(self): 31 | self._writer.flush() 32 | 33 | def close(self): 34 | self._writer.close() 35 | 36 | 37 | class Reporter(object): 38 | """Manages reporting of metrics.""" 39 | 40 | def __init__(self, log_interval=10, logdir=None, smooth_interval=10): 41 | self._writer = None 42 | if logdir: 43 | self._writer = TensorBoardWriter(logdir) 44 | self._last_step = 0 45 | self._last_reported_step = None 46 | self._last_reported_time = None 47 | self._log_interval = log_interval 48 | self._smooth_interval = smooth_interval 49 | self._metrics = collections.defaultdict(collections.deque) 50 | 51 | def record(self, step, **kwargs): 52 | for key, value in six.iteritems(kwargs): 53 | self.add(step, key, value) 54 | 55 | def add(self, step, key, value): 56 | self._last_step = step 57 | self._metrics[key].append(value) 58 | if len(self._metrics[key]) > self._smooth_interval: 59 | self._metrics[key].popleft() 60 | if self._last_step % self._log_interval == 0: 61 | if self._writer: 62 | self._writer.add(step, key, value) 63 | 64 | def report(self, stdout=None): 65 | if self._last_step % self._log_interval == 0: 66 | def smooth(values): 67 | return (sum(values) / len(values)) if values else 0.0 68 | metrics = ','.join(["%s = %.5f" % (k, smooth(v)) 69 | for k, v in six.iteritems(self._metrics)]) 70 | if self._last_reported_time: 71 | elapsed_secs = time.time() - self._last_reported_time 72 | metrics += " (%.3f sec)" % elapsed_secs 73 | if self._writer: 74 | elapsed_steps = float( 75 | self._last_step - self._last_reported_step) 76 | self._writer.add( 77 | self._last_step, 'step/sec', 78 | elapsed_steps / elapsed_secs) 79 | 80 | line = u"Step {}: {}".format(self._last_step, metrics) 81 | if stdout: 82 | stdout.write(line) 83 | else: 84 | print(line) 85 | 86 | self._last_reported_time = time.time() 87 | self._last_reported_step = self._last_step 88 | -------------------------------------------------------------------------------- /studio/storage/storage_handler.py: 
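A note before the listing: StorageHandler defines the abstract interface that the concrete backends (S3, HTTP, local) implement and that StorageHandlerFactory hands out. A minimal sketch of a custom backend is shown here — InMemoryStorageHandler is purely illustrative (not part of studio), and it reuses StorageType.storageLocal only for the sake of the example:

from studio.storage.storage_handler import StorageHandler
from studio.storage.storage_type import StorageType
from studio.util import logs


class InMemoryStorageHandler(StorageHandler):
    """Illustrative handler that keeps uploaded blobs in a dict."""

    def __init__(self, config):
        self._blobs = {}
        super().__init__(StorageType.storageLocal,
                         logs.get_logger(self.__class__.__name__))

    def upload_file(self, key, local_path):
        # Read the local file and store its bytes under the given key.
        with open(local_path, 'rb') as f:
            self._blobs[key] = f.read()

    def download_file(self, key, local_path):
        # Write the stored bytes back out to the requested local path.
        with open(local_path, 'wb') as f:
            f.write(self._blobs[key])

    def delete_file(self, key, shallow=True):
        self._blobs.pop(key, None)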
-------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | import time 4 | from typing import Dict 5 | 6 | from studio.storage.storage_type import StorageType 7 | from studio.util.util import get_temp_filename, check_for_kb_interrupt 8 | 9 | # StorageHandler encapsulates the logic of basic storage operations 10 | # for specific storage endpoint (S3, http, local etc.) 11 | # together with access credentials for this endpoint. 12 | class StorageHandler: 13 | def __init__(self, storage_type: StorageType, 14 | logger, 15 | measure_timestamp_diff=False, 16 | compression=None): 17 | self.type = storage_type 18 | self.logger = logger 19 | self.compression = compression 20 | self._timestamp_shift = 0 21 | if measure_timestamp_diff: 22 | try: 23 | self._timestamp_shift = self._measure_timestamp_diff() 24 | except BaseException: 25 | check_for_kb_interrupt() 26 | self._timestamp_shift = 0 27 | 28 | @classmethod 29 | def get_id(cls, config: Dict) -> str: 30 | raise NotImplementedError("Not implemented: upload_file") 31 | 32 | def upload_file(self, key, local_path): 33 | raise NotImplementedError("Not implemented: upload_file") 34 | 35 | def download_file(self, key, local_path): 36 | raise NotImplementedError("Not implemented: download_file") 37 | 38 | def download_remote_path(self, remote_path, local_path): 39 | raise NotImplementedError("Not implemented: download_remote_path") 40 | 41 | def delete_file(self, key, shallow=True): 42 | raise NotImplementedError("Not implemented: delete_file") 43 | 44 | def get_file_url(self, key, method='GET'): 45 | raise NotImplementedError("Not implemented: get_file_url") 46 | 47 | def get_file_timestamp(self, key): 48 | raise NotImplementedError("Not implemented: get_file_timestamp") 49 | 50 | def get_qualified_location(self, key): 51 | raise NotImplementedError("Not implemented: get_qualified_location") 52 | 53 | def get_local_destination(self, remote_path: str): 54 | raise NotImplementedError("Not implemented: get_local_destination") 55 | 56 | def get_timestamp_shift(self): 57 | return self._timestamp_shift 58 | 59 | def get_compression(self): 60 | return self.compression 61 | 62 | def cleanup(self): 63 | pass 64 | 65 | def _measure_timestamp_diff(self): 66 | max_diff = 60 67 | tmpfile = get_temp_filename() + '.txt' 68 | with open(tmpfile, 'w') as f: 69 | f.write('timestamp_diff_test') 70 | key = 'tests/' + str(uuid.uuid4()) 71 | self.upload_file(key, tmpfile) 72 | remote_timestamp = self.get_file_timestamp(key) 73 | 74 | if remote_timestamp is not None: 75 | now_remote_diff = time.time() - remote_timestamp 76 | self.storage_handler.delete_file(key) 77 | os.remove(tmpfile) 78 | 79 | assert -max_diff < now_remote_diff and \ 80 | now_remote_diff < max_diff, \ 81 | "Timestamp difference is more than 60 seconds. " + \ 82 | "You'll need to adjust local clock for caching " + \ 83 | "to work correctly" 84 | 85 | return -now_remote_diff if now_remote_diff < 0 else 0 86 | 87 | -------------------------------------------------------------------------------- /tests/http_provider_hosted_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import tempfile 4 | import uuid 5 | 6 | from studio import model 7 | from model_test import get_test_experiment 8 | 9 | # We are not currently working with HTTP providers. 
10 | @unittest.skip 11 | class HTTPProviderHostedTest(unittest.TestCase): 12 | 13 | def get_db_provider(self, config_name): 14 | config_file = os.path.join( 15 | os.path.dirname( 16 | os.path.realpath(__file__)), 17 | config_name) 18 | return model.get_db_provider(model.get_config(config_file)) 19 | 20 | def test_add_get_delete_experiment(self): 21 | with self.get_db_provider('test_config_http_client.yaml') as hp: 22 | 23 | experiment_tuple = get_test_experiment() 24 | hp.add_experiment(experiment_tuple[0]) 25 | experiment = hp.get_experiment(experiment_tuple[0].key) 26 | self.assertEquals(experiment.key, experiment_tuple[0].key) 27 | self.assertEquals( 28 | experiment.filename, 29 | experiment_tuple[0].filename) 30 | self.assertEquals(experiment.args, experiment_tuple[0].args) 31 | 32 | hp.delete_experiment(experiment_tuple[1]) 33 | 34 | self.assertTrue(hp.get_experiment(experiment_tuple[1]) is None) 35 | 36 | def test_start_experiment(self): 37 | with self.get_db_provider('test_config_http_client.yaml') as hp: 38 | experiment_tuple = get_test_experiment() 39 | 40 | hp.add_experiment(experiment_tuple[0]) 41 | hp.start_experiment(experiment_tuple[0]) 42 | 43 | experiment = hp.get_experiment(experiment_tuple[1]) 44 | 45 | self.assertTrue(experiment.status == 'running') 46 | 47 | self.assertEquals(experiment.key, experiment_tuple[0].key) 48 | self.assertEquals( 49 | experiment.filename, 50 | experiment_tuple[0].filename) 51 | self.assertEquals(experiment.args, experiment_tuple[0].args) 52 | 53 | hp.finish_experiment(experiment_tuple[0]) 54 | hp.delete_experiment(experiment_tuple[1]) 55 | 56 | def test_add_get_experiment_artifacts(self): 57 | experiment_tuple = get_test_experiment() 58 | e_experiment = experiment_tuple[0] 59 | e_artifacts = e_experiment.artifacts 60 | 61 | a1_filename = os.path.join(tempfile.gettempdir(), str(uuid.uuid4())) 62 | a2_filename = os.path.join(tempfile.gettempdir(), str(uuid.uuid4())) 63 | 64 | with open(a1_filename, 'w') as f: 65 | f.write('hello world') 66 | 67 | e_artifacts['a1'] = { 68 | 'local': a1_filename, 69 | 'mutable': False 70 | } 71 | 72 | e_artifacts['a2'] = { 73 | 'local': a2_filename, 74 | 'mutable': True 75 | } 76 | 77 | with self.get_db_provider('test_config_http_client.yaml') as db: 78 | db.add_experiment(e_experiment) 79 | 80 | experiment = db.get_experiment(e_experiment.key) 81 | self.assertEquals(experiment.key, e_experiment.key) 82 | self.assertEquals(experiment.filename, e_experiment.filename) 83 | self.assertEquals(experiment.args, e_experiment.args) 84 | db.delete_experiment(e_experiment.key) 85 | os.remove(a1_filename) 86 | 87 | 88 | if __name__ == '__main__': 89 | unittest.main() 90 | -------------------------------------------------------------------------------- /studio/experiment_submitter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | import traceback 5 | 6 | from studio.artifacts.artifact import Artifact 7 | from studio.db_providers import db_provider_setup 8 | from studio.experiments.experiment import Experiment 9 | from studio import git_util 10 | from studio.payload_builders.payload_builder import PayloadBuilder 11 | from studio.payload_builders.unencrypted_payload_builder import UnencryptedPayloadBuilder 12 | from studio.encrypted_payload_builder import EncryptedPayloadBuilder 13 | from studio.storage import storage_setup 14 | from studio.util import util 15 | 16 | def submit_experiments( 17 | experiments, 18 | config=None, 19 | logger=None, 
20 | queue=None, 21 | python_pkg=[], 22 | external_payload_builder: PayloadBuilder=None): 23 | 24 | num_experiments = len(experiments) 25 | 26 | payload_builder = external_payload_builder 27 | if payload_builder is None: 28 | # Setup our own payload builder 29 | payload_builder = UnencryptedPayloadBuilder("simple-payload") 30 | # Are we using experiment payload encryption? 31 | public_key_path = config.get('public_key_path', None) 32 | if public_key_path is not None: 33 | logger.info("Using RSA public key path: {0}".format(public_key_path)) 34 | signing_key_path = config.get('signing_key_path', None) 35 | if signing_key_path is not None: 36 | logger.info("Using RSA signing key path: {0}".format(signing_key_path)) 37 | payload_builder = \ 38 | EncryptedPayloadBuilder( 39 | "cs-rsa-encryptor [{0}]".format(public_key_path), 40 | public_key_path, signing_key_path) 41 | 42 | start_time = time.time() 43 | 44 | # Reset our storage setup, which will guarantee 45 | # that we rebuild our database and storage provider objects 46 | # that's important in the case that previous experiment batch 47 | # cleaned up after itself. 48 | storage_setup.reset_storage() 49 | 50 | for experiment in experiments: 51 | # Update Python environment info for our experiments: 52 | experiment.pythonenv = util.add_packages(experiment.pythonenv, python_pkg) 53 | 54 | # Add experiment to database: 55 | try: 56 | with db_provider_setup.get_db_provider(config) as db: 57 | _add_git_info(experiment, logger) 58 | db.add_experiment(experiment) 59 | except BaseException: 60 | traceback.print_exc() 61 | raise 62 | 63 | payload = payload_builder.construct(experiment, config, python_pkg) 64 | 65 | logger.info("Submitting experiment: {0}" 66 | .format(json.dumps(payload, indent=4))) 67 | 68 | queue.enqueue(json.dumps(payload)) 69 | logger.info("studio run: submitted experiment " + experiment.key) 70 | 71 | logger.info("Added {0} experiment(s) in {1} seconds to queue {2}" 72 | .format(num_experiments, int(time.time() - start_time), queue.get_name())) 73 | return queue.get_name() 74 | 75 | def _add_git_info(experiment: Experiment, logger): 76 | wrk_space: Artifact = experiment.artifacts.get('workspace', None) 77 | if wrk_space is not None: 78 | if wrk_space.local_path is not None and \ 79 | os.path.exists(wrk_space.local_path): 80 | if logger is not None: 81 | logger.info("git location for experiment %s", wrk_space.local_path) 82 | experiment.git = git_util.get_git_info(wrk_space.local_path) 83 | -------------------------------------------------------------------------------- /studio/scripts/gcloud_worker_startup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | exec > >(tee -i ~/gcloud_worker_logfile.txt) 4 | exec 2>&1 5 | 6 | metadata_url="http://metadata.google.internal/computeMetadata/v1/instance" 7 | queue_name=$(curl "$metadata_url/attributes/queue_name" -H "Metadata-Flavor: Google") 8 | key_name=$(curl "$metadata_url/attributes/auth_key" -H "Metadata-Flavor: Google") 9 | timeout=$(curl "$metadata_url/attributes/timeout" -H "Metadata-Flavor: Google") 10 | 11 | zone=$(curl "$metadata_url/zone" -H "Metadata-Flavor: Google") 12 | instance_name=$(curl "$metadata_url/name" -H "Metadata-Flavor: Google") 13 | group_name=$(curl "$metadata_url/attributes/groupname" -H "Metadata-Flavor: Google") 14 | 15 | echo Instance name is $instance_name 16 | echo Group name is $group_name 17 | 18 | cd ~ 19 | 20 | mkdir -p .studioml/keys 21 | curl "$metadata_url/attributes/auth_data" -H 
"Metadata-Flavor: Google" > .studioml/keys/$key_name 22 | curl "$metadata_url/attributes/credentials" -H "Metadata-Flavor: Google" > credentials.json 23 | export GOOGLE_APPLICATION_CREDENTIALS=~/credentials.json 24 | 25 | 26 | : "${{GOOGLE_APPLICATION_CREDENTIALS?Need to point GOOGLE_APPLICATION_CREDENTIALS to the google credentials file}}" 27 | : "${{queue_name?Queue name is not specified (pass as a script argument}}" 28 | 29 | gac_path=${{GOOGLE_APPLICATION_CREDENTIALS%/*}} 30 | gac_name=${{GOOGLE_APPLICATION_CREDENTIALS##*/}} 31 | #bash_cmd="git clone $repo && \ 32 | # cd studio && \ 33 | # git checkout $branch && \ 34 | # sudo pip install --upgrade pip && \ 35 | # sudo pip install -e . --upgrade && \ 36 | # mkdir /workspace && cd /workspace && \ 37 | # studio-rworker --queue=$queue_name" 38 | 39 | code_url_base="https://storage.googleapis.com/studio-ed756.appspot.com/src" 40 | #code_ver="tfstudio-64_config_location-2017-08-04_1.tgz" 41 | 42 | echo "Environment varibles:" 43 | env 44 | 45 | {install_studio} 46 | 47 | python $(which studio-remote-worker) --queue=$queue_name --verbose=debug --timeout=$timeout 48 | 49 | logbucket={log_bucket} 50 | if [[ -n $logbucket ]]; then 51 | gsutil cp /var/log/syslog gs://$logbucket/$queue_name/$instance_name.log 52 | fi 53 | 54 | if [[ -n $(who) ]]; then 55 | echo "Users logged in, preventing auto-shutdown" 56 | echo "Do not forget to turn the instance off manually" 57 | exit 0 58 | fi 59 | 60 | # shutdown the instance 61 | not_spot=$(echo "$group_name" | grep "Error 404" | wc -l) 62 | echo "not_spot = $not_spot" 63 | 64 | if [[ "$not_spot" -eq "0" ]]; then 65 | current_size=$(gcloud compute instance-groups managed describe $group_name --zone $zone | grep "targetSize" | awk '{{print $2}}') 66 | echo Current group size is $current_size 67 | if [[ $current_size -gt 1 ]]; then 68 | echo "Deleting myself (that is, $instance_name) from $group_name" 69 | gcloud compute instance-groups managed delete-instances $group_name --zone $zone --instances $instance_name 70 | else 71 | template=$(gcloud compute instance-groups managed describe $group_name --zone $zone | grep "instanceTemplate" | awk '{{print $2}}') 72 | echo "Detaching myself, deleting group $group_name and the template $template" 73 | gcloud compute instance-groups managed abandon-instances $group_name --zone $zone --instances $instance_name 74 | sleep 5 75 | gcloud compute instance-groups managed delete $group_name --zone $zone --quiet 76 | sleep 5 77 | gcloud compute instance-templates delete $template --quiet 78 | fi 79 | 80 | fi 81 | echo "Shutting down" 82 | gcloud compute instances delete $instance_name --zone $zone --quiet 83 | -------------------------------------------------------------------------------- /studio/db_providers/s3_provider.py: -------------------------------------------------------------------------------- 1 | import json 2 | from studio.db_providers.keyvalue_provider import KeyValueProvider 3 | from studio.storage.storage_handler_factory import StorageHandlerFactory 4 | from studio.storage.storage_type import StorageType 5 | 6 | 7 | class S3Provider(KeyValueProvider): 8 | 9 | def __init__(self, config, blocking_auth=True): 10 | self.config = config 11 | self.bucket = config.get('bucket', 'studioml-meta') 12 | 13 | factory: StorageHandlerFactory = StorageHandlerFactory.get_factory() 14 | self.meta_store = factory.get_handler(StorageType.storageS3, config) 15 | 16 | super().__init__( 17 | config, 18 | self.meta_store, 19 | blocking_auth) 20 | 21 | def _get(self, key, 
shallow=False): 22 | try: 23 | response = self.meta_store.client.list_objects( 24 | Bucket=self.bucket, 25 | Prefix=key, 26 | Delimiter='/', 27 | MaxKeys=1024*16 28 | ) 29 | except Exception as exc: 30 | msg: str = "FAILED to list objects in bucket {0}: {1}"\ 31 | .format(self.bucket, exc) 32 | self._report_fatal(msg) 33 | return None 34 | 35 | if response is None: 36 | return None 37 | 38 | if 'Contents' not in response.keys(): 39 | return None 40 | 41 | key_count = len(response['Contents']) 42 | 43 | if key_count == 0: 44 | return None 45 | 46 | for key_item in response['Contents']: 47 | if 'Key' in key_item.keys() and key_item['Key'] == key: 48 | response = self.meta_store.client.get_object( 49 | Bucket=self.bucket, 50 | Key=key) 51 | return json.loads(response['Body'].read().decode("utf-8")) 52 | 53 | return None 54 | 55 | def _delete(self, key, shallow=True): 56 | self.logger.info("S3 deleting object: %s/%s", self.bucket, key) 57 | 58 | try: 59 | response = self.meta_store.client.delete_object( 60 | Bucket=self.bucket, 61 | Key=key) 62 | except Exception as exc: 63 | msg: str = "FAILED to delete object {0} in bucket {1}: {2}"\ 64 | .format(key, self.bucket, exc) 65 | self.logger.info(msg) 66 | return 67 | 68 | reason = response['ResponseMetadata'] if response else "None" 69 | if response is None or\ 70 | response['ResponseMetadata']['HTTPStatusCode'] != 204: 71 | msg: str = ('attempt to delete key {0} in bucket {1}' + 72 | ' returned response {2}')\ 73 | .format(key, self.bucket, reason) 74 | self.logger.info(msg) 75 | 76 | def _set(self, key, value): 77 | try: 78 | response = self.meta_store.client.put_object( 79 | Bucket=self.bucket, 80 | Key=key, 81 | Body=json.dumps(value)) 82 | except Exception as exc: 83 | msg: str = "FAILED to write object {0} in bucket {1}: {2}"\ 84 | .format(key, self.bucket, exc) 85 | self._report_fatal(msg) 86 | 87 | reason = response['ResponseMetadata'] if response else "None" 88 | if response is None or \ 89 | response['ResponseMetadata']['HTTPStatusCode'] != 200: 90 | msg: str = ('attempt to write key {0} in bucket {1}' + 91 | ' returned response {2}')\ 92 | .format(key, self.bucket, reason) 93 | self._report_fatal(msg) 94 | -------------------------------------------------------------------------------- /studio/completion_service/encryptor.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from Crypto.PublicKey import RSA 3 | from Crypto.Cipher import PKCS1_OAEP 4 | import nacl.secret 5 | import nacl.utils 6 | import base64 7 | 8 | class Encryptor: 9 | """ 10 | Implementation for experiment payload builder 11 | using public key RSA encryption. 
12 | """ 13 | def __init__(self, keypath: str): 14 | """ 15 | param: keypath - file path to .pem file with public key 16 | """ 17 | 18 | self.key_path = keypath 19 | self.recipient_key = None 20 | try: 21 | self.recipient_key = RSA.import_key(open(self.key_path).read()) 22 | except: 23 | print( 24 | "FAILED to import recipient public key from: {0}".format(self.key_path)) 25 | return 26 | 27 | def _import_rsa_key(self, key_path: str): 28 | key = None 29 | try: 30 | key = RSA.import_key(open(key_path).read()) 31 | except: 32 | self.logger.error( 33 | "FAILED to import RSA key from: {0}".format(key_path)) 34 | key = None 35 | return key 36 | 37 | def _encrypt_str(self, workload: str): 38 | # Generate one-time symmetric session key: 39 | session_key = nacl.utils.random(32) 40 | 41 | # Encrypt the data with the NaCL session key 42 | data_to_encrypt = workload.encode("utf-8") 43 | box_out = nacl.secret.SecretBox(session_key) 44 | encrypted_data = box_out.encrypt(data_to_encrypt) 45 | encrypted_data_text = base64.b64encode(encrypted_data) 46 | 47 | # Encrypt the session key with the public RSA key 48 | cipher_rsa = PKCS1_OAEP.new(self.recipient_key) 49 | encrypted_session_key = cipher_rsa.encrypt(session_key) 50 | encrypted_session_key_text = base64.b64encode(encrypted_session_key) 51 | 52 | return encrypted_session_key_text, encrypted_data_text 53 | 54 | def _decrypt_data(self, private_key_path, encrypted_key_text, encrypted_data_text): 55 | private_key = self._import_rsa_key(private_key_path) 56 | if private_key is None: 57 | return None 58 | 59 | try: 60 | private_key = RSA.import_key(open(private_key_path).read()) 61 | except: 62 | self.logger.error( 63 | "FAILED to import private key from: {0}".format(private_key_path)) 64 | return None 65 | 66 | # Decrypt the session key with the private RSA key 67 | cipher_rsa = PKCS1_OAEP.new(private_key) 68 | session_key = cipher_rsa.decrypt( 69 | base64.b64decode(encrypted_key_text)) 70 | 71 | # Decrypt the data with the NaCL session key 72 | box_in = nacl.secret.SecretBox(session_key) 73 | decrypted_data = box_in.decrypt( 74 | base64.b64decode(encrypted_data_text)) 75 | decrypted_data = decrypted_data.decode("utf-8") 76 | 77 | return decrypted_data 78 | 79 | def encrypt(self, payload: str): 80 | enc_key, enc_payload = self._encrypt_str(payload) 81 | 82 | enc_key_str = enc_key.decode("utf-8") 83 | enc_payload_str = enc_payload.decode("utf-8") 84 | 85 | return "{0},{1}".format(enc_key_str, enc_payload_str) 86 | 87 | def main(): 88 | if len(sys.argv) < 3: 89 | print("USAGE {0} public-key-file-path string-to-encrypt" 90 | .format(sys.argv[0])) 91 | return 92 | 93 | encryptor = Encryptor(sys.argv[1]) 94 | data = sys.argv[2] 95 | print(data) 96 | result = encryptor.encrypt(data) 97 | print(result) 98 | 99 | if __name__ == '__main__': 100 | main() 101 | -------------------------------------------------------------------------------- /docs/README.rst: -------------------------------------------------------------------------------- 1 | .. raw:: html 2 | 3 |

4 | 5 |

6 |
7 | |Hex.pm| |Build.pm|
8 |
9 | Studio is a model management framework written in Python to help simplify and expedite your model-building experience. It was developed to minimize the overhead involved with scheduling, running, monitoring and managing artifacts of your machine learning experiments. No one wants to spend their time configuring different machines, setting up dependencies, or playing archeologist to track down previous model artifacts.
10 |
11 | Most of the features are compatible with any Python machine learning
12 | framework (`Keras `__,
13 | `TensorFlow `__,
14 | `PyTorch `__,
15 | `scikit-learn `__, etc.);
16 | some extra features are available for Keras and TensorFlow.
17 |
18 | **Use Studio to:**
19 |
20 | * Capture experiment information (Python environment, files, dependencies and logs) without modifying the experiment code.
21 | * Monitor and organize experiments using a web dashboard that integrates with TensorBoard.
22 | * Run experiments locally, remotely, or in the cloud (Google Cloud or Amazon EC2).
23 | * Manage artifacts.
24 | * Perform hyperparameter search.
25 | * Create customizable Python environments for remote workers.
26 |
27 | NOTE: the ``studio`` package is compatible with Python 2 and 3!
28 |
29 | Example usage
30 | -------------
31 |
32 | Start the visualizer:
33 |
34 | ::
35 |
36 |     studio ui
37 |
38 | Run your jobs:
39 |
40 | ::
41 |
42 |     studio run train_mnist_keras.py
43 |
44 | You can see the results of your job at http://localhost:5000. Run
45 | ``studio {ui|run} --help`` for a full list of ui / runner options.
46 | WARNING: because Studio tries to create a reproducible environment
47 | for your experiment, running it in a large folder will take
48 | a while to archive and upload the folder.
49 |
50 | Installation
51 | ------------
52 |
53 | Install ``studioml`` from the master PyPI repository:
54 |
55 | ::
56 |
57 |     pip install studioml
58 |
59 | Find more `details `__ on installation methods and the release process.
60 |
61 | Authentication
62 | --------------
63 |
64 | Currently Studio supports two methods of authentication: `email / password `__ and a `Google account `__. To use studio runner and studio ui in guest
65 | mode, uncomment "guest: true" under the database section in
66 | studio/default\_config.yaml.
67 |
68 | Alternatively, you can set up your own database and configure Studio to
69 | use it. See `setting up a database `__. This is the
70 | preferred option if you want to keep your models and artifacts private.
71 |
72 |
73 | Further reading and cool features
74 | ---------------------------------
75 |
76 | - `Running experiments remotely `__
77 |
78 | - `Custom Python environments for remote workers `__
79 |
80 | - `Running experiments in the cloud `__
81 |
82 | - `Google Cloud setup instructions `__
83 |
84 | - `Amazon EC2 setup instructions `__
85 |
86 | - `Artifact management `__
87 | - `Hyperparameter search `__
88 | - `Pipelines for trained models `__
89 | - `Containerized experiments `__
90 |
91 | .. |Hex.pm| image:: https://img.shields.io/hexpm/l/plug.svg
92 |    :target: https://github.com/studioml/studio/blob/master/LICENSE
93 |
94 | ..
|Build.pm| image:: https://travis-ci.org/studioml/studio.svg?branch=master 95 | :target: https://travis-ci.org/studioml/studio.svg?branch=master 96 | -------------------------------------------------------------------------------- /studio/magics.py: -------------------------------------------------------------------------------- 1 | # This code can be put in any Python module, it does not require IPython 2 | # itself to be running already. It only creates the magics subclass but 3 | # doesn't instantiate it yet. 4 | # from __future__ import print_function 5 | import pickle 6 | from IPython.core.magic import (Magics, magics_class, line_cell_magic) 7 | 8 | from types import ModuleType 9 | import six 10 | import subprocess 11 | import uuid 12 | import os 13 | import time 14 | import gzip 15 | from apscheduler.schedulers.background import BackgroundScheduler 16 | 17 | from studio.extra_util import rsync_cp 18 | from studio import fs_tracker, model 19 | from studio.runner import main as runner_main 20 | 21 | from studio.util import logs 22 | 23 | 24 | @magics_class 25 | class StudioMagics(Magics): 26 | 27 | @line_cell_magic 28 | def studio_run(self, line, cell=None): 29 | script_text = [] 30 | pickleable_ns = {} 31 | 32 | for varname, var in six.iteritems(self.shell.user_ns): 33 | if not varname.startswith('__'): 34 | if isinstance(var, ModuleType) and \ 35 | var.__name__ != 'studio.magics': 36 | script_text.append( 37 | 'import {} as {}'.format(var.__name__, varname) 38 | ) 39 | 40 | else: 41 | try: 42 | pickle.dumps(var) 43 | pickleable_ns[varname] = var 44 | except BaseException: 45 | pass 46 | 47 | script_text.append(cell) 48 | script_text = '\n'.join(script_text) 49 | stub_path = os.path.join( 50 | os.path.dirname(os.path.realpath(__file__)), 51 | 'run_magic.py.stub') 52 | 53 | with open(stub_path) as f: 54 | script_stub = f.read() 55 | 56 | script = script_stub.format(script=script_text) 57 | 58 | experiment_key = str(int(time.time())) + \ 59 | "_jupyter_" + str(uuid.uuid4()) 60 | 61 | print('Running studio with experiment key ' + experiment_key) 62 | config = model.get_config() 63 | if config['database']['type'] == 'http': 64 | print("Experiment progress can be viewed/shared at:") 65 | print("{}/experiment/{}".format( 66 | config['database']['serverUrl'], 67 | experiment_key)) 68 | 69 | workspace_new = fs_tracker.get_artifact_cache( 70 | 'workspace', experiment_key) 71 | 72 | rsync_cp('.', workspace_new) 73 | with open(os.path.join(workspace_new, '_script.py'), 'w') as f: 74 | f.write(script) 75 | 76 | ns_path = fs_tracker.get_artifact_cache('_ns', experiment_key) 77 | 78 | with gzip.open(ns_path, 'wb') as f: 79 | f.write(pickle.dumps(pickleable_ns)) 80 | 81 | if any(line): 82 | runner_args = line.strip().split(' ') 83 | else: 84 | runner_args = [] 85 | 86 | runner_args.append('--capture={}:_ns'.format(ns_path)) 87 | runner_args.append('--capture-once=.:workspace') 88 | runner_args.append('--force-git') 89 | runner_args.append('--experiment=' + experiment_key) 90 | 91 | notebook_cwd = os.getcwd() 92 | os.chdir(workspace_new) 93 | print(runner_args + ['_script.py']) 94 | runner_main(runner_args + ['_script.py']) 95 | os.chdir(notebook_cwd) 96 | 97 | with model.get_db_provider() as db: 98 | while True: 99 | experiment = db.get_experiment(experiment_key) 100 | if experiment and experiment.status == 'finished': 101 | break 102 | 103 | time.sleep(10) 104 | 105 | new_ns_path = db.get_artifact(experiment.artifacts['_ns']) 106 | 107 | with open(new_ns_path) as f: 108 | new_ns = 
pickle.loads(f.read()) 109 | 110 | self.shell.user_ns.update(new_ns) 111 | 112 | 113 | ip = get_ipython() 114 | ip.register_magics(StudioMagics) 115 | -------------------------------------------------------------------------------- /studio/scripts/studio-start-remote-worker: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Script that starts remote worker 4 | # Usage: 5 | # 6 | # studio-start-remote-worker --queue= [--image= argument" 57 | exit 1 58 | fi 59 | 60 | echo "Docker image = $docker_img" 61 | echo "Queue = $queue_name" 62 | 63 | eval $docker_cmd 64 | if [ $? != 0 ]; then 65 | echo "Docker not installed! Install docker." 66 | exit 1 67 | fi 68 | 69 | eval nvidia-smi 70 | if [ $? == 0 ]; then 71 | eval nvidia-docker 72 | if [ $? == 0 ]; then 73 | docker_cmd=nvidia-docker 74 | bash_cmd="" 75 | else 76 | echo "Warning! nvidia-docker is not installed correctly, won't be able to use gpus" 77 | fi 78 | fi 79 | 80 | echo "docker_cmd = $docker_cmd" 81 | if [ $docker_cmd == 'docker' ]; then 82 | bash_cmd="$bash_cmd python -m pip install -e ./studio && python3 -m pip install -e ./studio && " 83 | fi 84 | 85 | bash_cmd="$bash_cmd studio-remote-worker --queue=$queue_name --verbose=debug --timeout=$timeout --single-run" 86 | 87 | : "${GOOGLE_APPLICATION_CREDENTIALS?Need to point GOOGLE_APPLICATION_CREDENTIALS to the google credentials file}" 88 | : "${queue_name?Queue name is not specified (pass as a script argument}" 89 | 90 | gac_path=${GOOGLE_APPLICATION_CREDENTIALS%/*} 91 | gac_name=${GOOGLE_APPLICATION_CREDENTIALS##*/} 92 | 93 | #bash_cmd="git clone $repo && \ 94 | # cd studio && \ 95 | # git checkout $branch && \ 96 | # sudo pip install --upgrade pip && \ 97 | # sudo pip install -e . --upgrade && \ 98 | # mkdir /workspace && cd /workspace && \ 99 | # studio-rworker --queue=$queue_name" 100 | 101 | 102 | # loop until killed 103 | 104 | docker_args="run --rm --pid=host " 105 | if [[ $no_cache -ne 1 ]]; then 106 | docker_args="$docker_args -v $HOME/.studioml/experiments:/root/.studioml/experiemnts" 107 | docker_args="$docker_args -v $HOME/.studioml/blobcache:/root/.studioml/blobcache" 108 | fi 109 | 110 | if [[ $baked_credentials -ne 1 ]]; then 111 | docker_args="$docker_args -v $HOME/.studioml/keys:/root/.studioml/keys" 112 | docker_args="$docker_args -v $gac_path:/creds -e GOOGLE_APPLICATION_CREDENTIALS=/creds/$gac_name" 113 | fi 114 | 115 | echo "Docker args = $docker_args" 116 | echo "Docker image = $docker_img" 117 | 118 | while true 119 | do 120 | echo $bash_cmd 121 | $docker_cmd pull $docker_img 122 | $docker_cmd $docker_args $docker_img /bin/bash -c "$bash_cmd" 123 | if [ $single_run ]; 124 | then 125 | exit 0 126 | fi 127 | done 128 | 129 | -------------------------------------------------------------------------------- /studio/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import time 4 | from terminaltables import AsciiTable 5 | 6 | from studio import model 7 | from studio.util import logs 8 | 9 | _my_logger = None 10 | 11 | 12 | def print_help(): 13 | print('Usage: studio runs [command] arguments') 14 | print('\ncommand can be one of the following:') 15 | print('') 16 | print('\tlist [username] - display experiments') 17 | print('\tstop [experiment] - stop running experiment') 18 | print('\tkill [experiment] - stop and delete experiment') 19 | print('') 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser() 24 | 
parser.add_argument('--config', help='configuration file', default=None) 25 | parser.add_argument( 26 | '--short', '-s', help='Brief output - names of experiments only', 27 | action='store_true') 28 | 29 | cli_args, script_args = parser.parse_known_args(sys.argv) 30 | 31 | get_logger().setLevel(10) 32 | 33 | if len(script_args) < 2: 34 | get_logger().critical('No command provided!') 35 | parser.print_help() 36 | print_help() 37 | return 38 | 39 | cmd = script_args[1] 40 | 41 | if cmd == 'list': 42 | _list(script_args[2:], cli_args) 43 | elif cmd == 'stop': 44 | _stop(script_args[2:], cli_args) 45 | elif cmd == 'kill': 46 | _kill(script_args[2:], cli_args) 47 | 48 | else: 49 | get_logger().critical('Unknown command ' + cmd) 50 | parser.print_help() 51 | print_help() 52 | return 53 | 54 | 55 | def _list(args, cli_args): 56 | with model.get_db_provider(cli_args.config) as db: 57 | if len(args) == 0: 58 | experiments = db.get_user_experiments() 59 | elif args[0] == 'project': 60 | assert len(args) == 2 61 | experiments = db.get_project_experiments(args[1]) 62 | elif args[0] == 'users': 63 | assert len(args) == 1 64 | users = db.get_users() 65 | for u in users: 66 | print(users[u].get('email')) 67 | return 68 | elif args[0] == 'user': 69 | assert len(args) == 2 70 | users = db.get_users() 71 | user_ids = [u for u in users if users[u].get('email') == args[1]] 72 | assert len(user_ids) == 1, \ 73 | 'The user with email ' + args[1] + \ 74 | 'not found!' 75 | experiments = db.get_user_experiments(user_ids[0]) 76 | elif args[0] == 'all': 77 | assert len(args) == 1 78 | users = db.get_users() 79 | experiments = [] 80 | for u in users: 81 | experiments += db.get_user_experiments(u) 82 | else: 83 | get_logger().critical('Unknown command ' + args[0]) 84 | return 85 | 86 | if cli_args.short: 87 | for e in experiments: 88 | print(e) 89 | return 90 | 91 | experiments = [db.get_experiment(e) for e in experiments] 92 | 93 | experiments.sort(key=lambda e: -e.time_added) 94 | table = [['Time added', 'Key', 'Project', 'Status']] 95 | 96 | for e in experiments: 97 | table.append([ 98 | time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(e.time_added)), 99 | e.key, 100 | e.project, 101 | e.status]) 102 | 103 | print(AsciiTable(table).table) 104 | 105 | 106 | def _stop(args, cli_args): 107 | with model.get_db_provider(cli_args.config) as db: 108 | for e in args: 109 | get_logger().info('Stopping experiment ' + e) 110 | db.stop_experiment(e) 111 | 112 | 113 | def _kill(args, cli_args): 114 | with model.get_db_provider(cli_args.config) as db: 115 | for e in args: 116 | get_logger().info('Deleting experiment ' + e) 117 | db.delete_experiment(e) 118 | 119 | 120 | def get_logger(): 121 | global _my_logger 122 | if not _my_logger: 123 | _my_logger = logs.get_logger('studio-runs') 124 | return _my_logger 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /studio/scripts/install_studio.sh: -------------------------------------------------------------------------------- 1 | repo_url="{repo_url}" 2 | branch="{studioml_branch}" 3 | 4 | if [ ! -d "studio" ]; then 5 | echo "Installing system packages..." 
6 | sudo apt -y update 7 | sudo apt install -y wget git jq 8 | sudo apt install -y python python-pip python-dev python3 python3-dev python3-pip dh-autoreconf build-essential 9 | echo "python2 version: " $(python -V) 10 | 11 | sudo python -m pip install --upgrade pip 12 | sudo python -m pip install --upgrade awscli boto3 13 | 14 | echo "python3 version: " $(python3 -V) 15 | 16 | sudo python3 -m pip install --upgrade pip 17 | sudo python3 -m pip install --upgrade awscli boto3 18 | 19 | # Install singularity 20 | git clone https://github.com/singularityware/singularity.git 21 | cd singularity 22 | ./autogen.sh 23 | ./configure --prefix=/usr/local --sysconfdir=/etc 24 | time (make && make install) 25 | cd .. 26 | 27 | time apt-get -y install python3-tk python-tk 28 | time python -m pip install http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp27-cp27mu-linux_x86_64.whl 29 | time python3 -m pip install http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp35-cp35m-linux_x86_64.whl 30 | 31 | nvidia-smi 32 | nvidia_smi_error=$? 33 | 34 | if [ "{use_gpus}" -eq 1 ] && [ "$nvidia_smi_error" -ne 0 ]; then 35 | cudnn5="libcudnn5_5.1.10-1_cuda8.0_amd64.deb" 36 | cudnn6="libcudnn6_6.0.21-1_cuda8.0_amd64.deb" 37 | cuda_base="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/" 38 | cuda_ver="cuda-repo-ubuntu1604_8.0.61-1_amd64.deb" 39 | 40 | cuda_driver='nvidia-diag-driver-local-repo-ubuntu1604-384.66_1.0-1_amd64.deb' 41 | wget $code_url_base/$cuda_driver 42 | dpkg -i $cuda_driver 43 | apt-key add /var/nvidia-diag-driver-local-repo-384.66/7fa2af80.pub 44 | apt-get -y update 45 | apt-get -y install cuda-drivers 46 | apt-get -y install unzip 47 | 48 | # install cuda 49 | cuda_url="https://developer.nvidia.com/compute/cuda/8.0/Prod2/local_installers/cuda_8.0.61_375.26_linux-run" 50 | cuda_patch_url="https://developer.nvidia.com/compute/cuda/8.0/Prod2/patches/2/cuda_8.0.61.2_linux-run" 51 | 52 | # install cuda 53 | wget $cuda_url 54 | wget $cuda_patch_url 55 | #udo dpkg -i $cuda_ver 56 | #sudo apt -y update 57 | #sudo apt install -y "cuda-8.0" 58 | sh ./cuda_8.0.61_375.26_linux-run --silent --toolkit 59 | sh ./cuda_8.0.61.2_linux-run --silent --accept-eula 60 | 61 | export PATH=$PATH:/usr/local/cuda-8.0/bin 62 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-8.0/lib64 63 | 64 | # wget $cuda_base/$cuda_ver 65 | # sudo dpkg -i $cuda_ver 66 | # sudo apt -y update 67 | # sudo apt install -y "cuda-8.0" 68 | 69 | # install cudnn 70 | wget $code_url_base/$cudnn5 71 | wget $code_url_base/$cudnn6 72 | sudo dpkg -i $cudnn5 73 | sudo dpkg -i $cudnn6 74 | 75 | # sudo python -m pip install tf-nightly tf-nightly-gpu --upgrade 76 | # sudo python3 -m pip install tf-nightly tf-nightly-gpu --upgrade 77 | else 78 | sudo apt install -y default-jre 79 | fi 80 | fi 81 | 82 | #if [[ "{use_gpus}" -ne 1 ]]; then 83 | # rm -rf /usr/lib/x86_64-linux-gnu/libcuda* 84 | #fi 85 | 86 | rm -rf studio 87 | git clone $repo_url 88 | if [[ $? -ne 0 ]]; then 89 | git clone https://github.com/studioml/studio 90 | fi 91 | 92 | cd studio 93 | git pull 94 | git checkout $branch 95 | 96 | apoptosis() {{ 97 | while : 98 | do 99 | date 100 | shutdown +1 101 | (nvidia-smi; shutdown -c; echo "nvidia-smi functional, preventing shutdown")& 102 | sleep 90 103 | done 104 | 105 | }} 106 | 107 | if [[ "{use_gpus}" -eq 1 ]]; then 108 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-8.0/lib64 109 | # (apoptosis > autoterminate.log)& 110 | fi 111 | 112 | time python -m pip install -e . 
--upgrade 113 | time python3 -m pip install -e . --upgrade 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /test-runner.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: v1 3 | kind: PersistentVolumeClaim 4 | metadata: 5 | # This name uniquely identifies the PVC. Will be used in deployment below. 6 | name: test-runner-pv-claim 7 | labels: 8 | app: test-runner-storage-claim 9 | namespace: default 10 | spec: 11 | # Read more about access modes here: https://kubernetes.io/docs/user-guide/persistent-volumes/#access-modes 12 | accessModes: 13 | - ReadWriteMany 14 | resources: 15 | # This is the request for storage. Should be available in the cluster. 16 | requests: 17 | storage: 10Gi 18 | # Uncomment and add storageClass specific to your requirements below. Read more https://kubernetes.io/docs/concepts/storage/persistent-volumes/#class-1 19 | #storageClassName: 20 | --- 21 | apiVersion: extensions/v1beta1 22 | kind: Deployment 23 | metadata: 24 | name: test-runner 25 | namespace: default 26 | labels: 27 | name: "test-runer" 28 | keel.sh/policy: force 29 | keel.sh/trigger: poll 30 | spec: 31 | template: 32 | metadata: 33 | name: test-runner 34 | namespace: default 35 | labels: 36 | app: test-runner 37 | spec: 38 | volumes: 39 | - name: test-runner-storage 40 | persistentVolumeClaim: 41 | # Name of the PVC created earlier 42 | claimName: test-runner-pv-claim 43 | - name: podinfo 44 | downwardAPI: 45 | items: 46 | - path: "namespace" 47 | fieldRef: 48 | fieldPath: metadata.namespace 49 | - path: "annotations" 50 | fieldRef: 51 | fieldPath: metadata.annotations 52 | - path: "labels" 53 | fieldRef: 54 | fieldPath: metadata.labels 55 | containers: 56 | - image: xujiamin9/standalone_testing 57 | imagePullPolicy: Always 58 | name: test-runner 59 | env: 60 | - name: K8S_POD_NAME 61 | valueFrom: 62 | fieldRef: 63 | fieldPath: metadata.name 64 | - name: K8S_NAMESPACE 65 | valueFrom: 66 | fieldRef: 67 | fieldPath: metadata.namespace 68 | ports: 69 | - containerPort: 8080 70 | resources: 71 | requests: 72 | memory: "128Mi" 73 | cpu: "500m" 74 | limits: 75 | memory: "1024Mi" 76 | cpu: 4 77 | volumeMounts: 78 | - name: test-runner-storage # must match the volume name, above 79 | mountPath: "/build" 80 | - name: podinfo 81 | mountPath: /etc/podinfo 82 | readOnly: false 83 | lifecycle: 84 | postStart: 85 | exec: 86 | command: 87 | - "/bin/bash" 88 | - "-c" 89 | - > 90 | set -euo pipefail ; 91 | IFS=$'\n\t' ; 92 | echo "Starting the keel modifications" $K8S_POD_NAME ; 93 | kubectl label deployment test-runner keel.sh/policy- --namespace=$K8S_NAMESPACE ; 94 | curl -v --cacert /var/run/secrets/kubernetes.io/serviceaccount/ca.crt -H "Authorization: Bearer $(cat /var/runsecrets/kubernetes.io/serviceaccount/token)" https://$KUBERNETES_SERVICE_HOST:$KUBERNETES_PORT_443_TCP_PORT/api/v1/namespaces/$K8S_NAMESPACE/pods/$K8S_POD_NAME 95 | preStop: 96 | exec: 97 | command: 98 | - "/bin/bash" 99 | - "-c" 100 | - > 101 | set -euo pipefail; 102 | IFS=$'\n\t' ; 103 | echo "Starting the namespace injections etc" $K8S_POD_NAME ; 104 | kubectl label deployment test-runner keel.sh/policy=force --namespace=$K8S_NAMESPACE ; 105 | for (( ; ; )) ; 106 | do ; 107 | sleep 10 ; 108 | done 109 | -------------------------------------------------------------------------------- /studio/storage/local_storage_handler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
shutil 3 | from typing import Dict 4 | 5 | from studio.storage.storage_setup import get_storage_verbose_level 6 | from studio.storage.storage_type import StorageType 7 | from studio.storage.storage_handler import StorageHandler 8 | from studio.util import logs 9 | from studio.util import util 10 | 11 | class LocalStorageHandler(StorageHandler): 12 | def __init__(self, config, 13 | measure_timestamp_diff=False, 14 | compression=None): 15 | 16 | self.logger = logs.get_logger(self.__class__.__name__) 17 | self.logger.setLevel(get_storage_verbose_level()) 18 | 19 | if compression is None: 20 | compression = config.get('compression', None) 21 | 22 | self.endpoint = config.get('endpoint', '~') 23 | self.endpoint = os.path.realpath(os.path.expanduser(self.endpoint)) 24 | if not os.path.exists(self.endpoint) \ 25 | or not os.path.isdir(self.endpoint): 26 | msg: str = "Store root {0} doesn't exist or not a directory. Aborting."\ 27 | .format(self.endpoint) 28 | self._report_fatal(msg) 29 | 30 | self.bucket = config.get('bucket', 'storage') 31 | self.store_root = os.path.join(self.endpoint, self.bucket) 32 | self._ensure_path_dirs_exist(self.store_root) 33 | 34 | super().__init__(StorageType.storageLocal, 35 | self.logger, 36 | measure_timestamp_diff, 37 | compression=compression) 38 | 39 | def _ensure_path_dirs_exist(self, path): 40 | dirs = os.path.dirname(path) 41 | os.makedirs(dirs, mode = 0o777, exist_ok = True) 42 | 43 | def _copy_file(self, from_path, to_path): 44 | try: 45 | shutil.copyfile(from_path, to_path) 46 | except Exception as exc: 47 | msg: str = "FAILED to copy '{0}' to '{1}': {2}. Aborting."\ 48 | .format(from_path, to_path,exc) 49 | self._report_fatal(msg) 50 | 51 | def upload_file(self, key, local_path): 52 | target_path = os.path.join(self.store_root, key) 53 | if not os.path.exists(local_path): 54 | self.logger.debug( 55 | "Local path {0} does not exist. SKIPPING upload to {1}" 56 | .format(local_path, target_path)) 57 | return False 58 | self._ensure_path_dirs_exist(target_path) 59 | self._copy_file(local_path, target_path) 60 | return True 61 | 62 | def download_file(self, key, local_path): 63 | source_path = os.path.join(self.store_root, key) 64 | if not os.path.exists(source_path): 65 | self.logger.debug( 66 | "Source path {0} does not exist. 
SKIPPING download to {1}" 67 | .format(source_path, local_path)) 68 | return False 69 | self._ensure_path_dirs_exist(local_path) 70 | self._copy_file(source_path, local_path) 71 | return True 72 | 73 | def delete_file(self, key, shallow=True): 74 | key_path: str = self._get_file_path_from_key(key) 75 | if os.path.exists(key_path): 76 | self.logger.debug("Deleting local file {0}.".format(key_path)) 77 | util.delete_local_path(key_path, self.store_root, False) 78 | 79 | @classmethod 80 | def get_id(cls, config: Dict) -> str: 81 | endpoint = config.get('endpoint', None) 82 | if endpoint is None: 83 | return None 84 | return '[local]{0}'.format(endpoint) 85 | 86 | def _get_file_path_from_key(self, key: str): 87 | return str(os.path.join(self.store_root, key)) 88 | 89 | def get_file_url(self, key, method='GET'): 90 | return self._get_file_path_from_key(key) 91 | 92 | def get_file_timestamp(self, key): 93 | key_path: str = self._get_file_path_from_key(key) 94 | if os.path.exists(key_path): 95 | return os.path.getmtime(key_path) 96 | else: 97 | return None 98 | 99 | def get_qualified_location(self, key): 100 | return 'file:/' + self.store_root + '/' + key 101 | 102 | def get_endpoint(self): 103 | return self.endpoint 104 | 105 | def _report_fatal(self, msg: str): 106 | util.report_fatal(msg, self.logger) 107 | -------------------------------------------------------------------------------- /studio/model.py: -------------------------------------------------------------------------------- 1 | """Data providers.""" 2 | import uuid 3 | 4 | from studio.firebase_provider import FirebaseProvider 5 | from studio.http_provider import HTTPProvider 6 | from studio.pubsub_queue import PubsubQueue 7 | from studio.gcloud_worker import GCloudWorkerManager 8 | from studio.ec2cloud_worker import EC2WorkerManager 9 | from studio.util.util import parse_verbosity 10 | from studio.auth import get_auth 11 | 12 | from studio.db_providers import db_provider_setup 13 | from studio.queues import queues_setup 14 | from studio.storage.storage_setup import setup_storage, get_storage_db_provider,\ 15 | reset_storage, set_storage_verbose_level 16 | from studio.util import logs 17 | 18 | def reset_storage_providers(): 19 | reset_storage() 20 | 21 | 22 | def get_config(config_file=None): 23 | return db_provider_setup.get_config(config_file=config_file) 24 | 25 | 26 | def get_db_provider(config=None, blocking_auth=True): 27 | 28 | db_provider = get_storage_db_provider() 29 | if db_provider is not None: 30 | return db_provider 31 | 32 | if config is None: 33 | config = get_config() 34 | verbose = parse_verbosity(config.get('verbose', None)) 35 | 36 | # Save this verbosity level as global for the whole experiment job: 37 | set_storage_verbose_level(verbose) 38 | 39 | logger = logs.get_logger("get_db_provider") 40 | logger.setLevel(verbose) 41 | logger.debug('Choosing db provider with config:') 42 | logger.debug(config) 43 | 44 | if 'storage' in config.keys(): 45 | artifact_store = db_provider_setup.get_artifact_store(config['storage']) 46 | else: 47 | artifact_store = None 48 | 49 | assert 'database' in config.keys() 50 | db_config = config['database'] 51 | if db_config['type'].lower() == 'firebase': 52 | db_provider = FirebaseProvider(db_config, 53 | blocking_auth=blocking_auth) 54 | 55 | elif db_config['type'].lower() == 'http': 56 | db_provider = HTTPProvider(db_config, 57 | verbose=verbose, 58 | blocking_auth=blocking_auth) 59 | else: 60 | db_provider = db_provider_setup.get_db_provider( 61 | config=config, 
blocking_auth=blocking_auth) 62 | 63 | setup_storage(db_provider, artifact_store) 64 | return db_provider 65 | 66 | def get_queue( 67 | queue_name=None, 68 | cloud=None, 69 | config=None, 70 | logger=None, 71 | close_after=None, 72 | verbose=10): 73 | queue = queues_setup.get_queue(queue_name=queue_name, 74 | cloud=cloud, 75 | config=config, 76 | logger=logger, 77 | close_after=close_after, 78 | verbose=verbose) 79 | if queue is None: 80 | queue = PubsubQueue(queue_name, verbose=verbose) 81 | return queue 82 | 83 | def shutdown_queue(queue, logger=None, delete_queue=True): 84 | queues_setup.shutdown_queue(queue, logger=logger, delete_queue=delete_queue) 85 | 86 | def get_worker_manager(config, cloud=None, verbose=10): 87 | if cloud is None: 88 | return None 89 | 90 | assert cloud in ['gcloud', 'gcspot', 'ec2', 'ec2spot'] 91 | logger = logs.get_logger('runner.get_worker_manager') 92 | logger.setLevel(verbose) 93 | 94 | auth = get_auth(config['database']['authentication']) 95 | auth_cookie = auth.get_token_file() if auth else None 96 | 97 | branch = config['cloud'].get('branch') 98 | 99 | logger.info('using branch {}'.format(branch)) 100 | 101 | if cloud in ['gcloud', 'gcspot']: 102 | 103 | cloudconfig = config['cloud']['gcloud'] 104 | worker_manager = GCloudWorkerManager( 105 | auth_cookie=auth_cookie, 106 | zone=cloudconfig['zone'], 107 | branch=branch, 108 | user_startup_script=config['cloud'].get('user_startup_script') 109 | ) 110 | 111 | if cloud in ['ec2', 'ec2spot']: 112 | worker_manager = EC2WorkerManager( 113 | auth_cookie=auth_cookie, 114 | branch=branch, 115 | user_startup_script=config['cloud'].get('user_startup_script') 116 | ) 117 | return worker_manager 118 | 119 | -------------------------------------------------------------------------------- /studio/db_providers/db_provider_setup.py: -------------------------------------------------------------------------------- 1 | """Data providers.""" 2 | import os 3 | import yaml 4 | import pyhocon 5 | 6 | from studio.db_providers.local_db_provider import LocalDbProvider 7 | from studio.db_providers.s3_provider import S3Provider 8 | from studio.storage.storage_handler import StorageHandler 9 | from studio.storage.storage_handler_factory import StorageHandlerFactory 10 | from studio.storage.storage_setup import setup_storage, get_storage_db_provider,\ 11 | set_storage_verbose_level 12 | from studio.storage.storage_type import StorageType 13 | from studio.util import logs 14 | from studio.util.util import parse_verbosity 15 | 16 | def get_config(config_file=None): 17 | 18 | config_paths = [] 19 | if config_file: 20 | if not os.path.exists(config_file): 21 | raise ValueError('User config file {} not found' 22 | .format(config_file)) 23 | config_paths.append(os.path.expanduser(config_file)) 24 | 25 | config_paths.append(os.path.expanduser('~/.studioml/config.yaml')) 26 | config_paths.append( 27 | os.path.join( 28 | os.path.dirname(os.path.realpath(__file__)), 29 | "default_config.yaml")) 30 | 31 | for path in config_paths: 32 | if not os.path.exists(path): 33 | continue 34 | 35 | with(open(path)) as f_in: 36 | if path.endswith('.hocon'): 37 | config = pyhocon.ConfigFactory.parse_string(f_in.read()) 38 | else: 39 | config = yaml.load(f_in.read(), Loader=yaml.FullLoader) 40 | 41 | def replace_with_env(config): 42 | for key, value in config.items(): 43 | if isinstance(value, str): 44 | config[key] = os.path.expandvars(value) 45 | 46 | elif isinstance(value, dict): 47 | replace_with_env(value) 48 | 49 | replace_with_env(config) 50 | 51 | 
return config 52 | 53 | raise ValueError('None of the config paths {0} exists!' 54 | .format(config_paths)) 55 | 56 | def get_artifact_store(config) -> StorageHandler: 57 | storage_type: str = config['type'].lower() 58 | 59 | factory: StorageHandlerFactory = StorageHandlerFactory.get_factory() 60 | if storage_type == 's3': 61 | handler = factory.get_handler(StorageType.storageS3, config) 62 | return handler 63 | if storage_type == 'local': 64 | handler = factory.get_handler(StorageType.storageLocal, config) 65 | return handler 66 | raise ValueError('Unknown storage type: ' + storage_type) 67 | 68 | def get_db_provider(config=None, blocking_auth=True): 69 | 70 | db_provider = get_storage_db_provider() 71 | if db_provider is not None: 72 | return db_provider 73 | 74 | if config is None: 75 | config = get_config() 76 | verbose = parse_verbosity(config.get('verbose', None)) 77 | 78 | # Save this verbosity level as global for the whole experiment job: 79 | set_storage_verbose_level(verbose) 80 | 81 | logger = logs.get_logger("get_db_provider") 82 | logger.setLevel(verbose) 83 | logger.debug('Choosing db provider with config:') 84 | logger.debug(config) 85 | 86 | if 'storage' in config.keys(): 87 | artifact_store = get_artifact_store(config['storage']) 88 | else: 89 | artifact_store = None 90 | 91 | assert 'database' in config.keys() 92 | db_config = config['database'] 93 | if db_config['type'].lower() == 's3': 94 | db_provider = S3Provider(db_config, 95 | blocking_auth=blocking_auth) 96 | if artifact_store is None: 97 | artifact_store = db_provider.get_storage_handler() 98 | 99 | elif db_config['type'].lower() == 'gs': 100 | raise NotImplementedError("GS is not supported.") 101 | 102 | elif db_config['type'].lower() == 'local': 103 | db_provider = LocalDbProvider(db_config, 104 | blocking_auth=blocking_auth) 105 | if artifact_store is None: 106 | artifact_store = db_provider.get_storage_handler() 107 | 108 | else: 109 | raise ValueError('Unknown type of the database ' + db_config['type']) 110 | 111 | setup_storage(db_provider, artifact_store) 112 | return db_provider 113 | -------------------------------------------------------------------------------- /docs/gcloud_setup.rst: -------------------------------------------------------------------------------- 1 | Setting up Google Cloud Compute 2 | =============================== 3 | 4 | This page describes the process of setting up Google Cloud and 5 | configuring Studio to integrate with it. 6 | 7 | Configuring Google Cloud Compute 8 | -------------------------------- 9 | 10 | Create and select a new Google Cloud project 11 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 12 | 13 | Go to the Google Cloud console (https://console.cloud.google.com), and 14 | either choose a project that you will use to back cloud 15 | computing or create a new one. If you have not used the Google console 16 | before and there are no projects, there will be a big button "create 17 | project" in the dashboard. Otherwise, you can create a new project by 18 | selecting the drop-down arrow next to current project name in the top 19 | panel, and then clicking the "+" button. 20 | 21 | Enable billing for the project 22 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 23 | 24 | Google Cloud computing actually bills you for the compute time you 25 | use, so you must have billing enabled. On the bright side, when you 26 | sign up with Google Cloud they provide $300 of promotional credit, so 27 | really in the beginning you are still using it for free. 
On the not so 28 | bright side, to use machines with gpus you'll need to show 29 | that you are a legitimate customer and add $35 to your billing account. 30 | In order to enable billing, go to the left-hand pane in the Google Cloud 31 | console, select billing, and follow the instructions to set up your payment 32 | method. 33 | 34 | Generate service credentials 35 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 36 | 37 | The machines that submit cloud jobs will need to be authorized with 38 | service credentials. Go to the left-hand pane in the Google Cloud console and 39 | select API Manager -> Credentials. Then click the "Create credentials" 40 | button, choose service account key, leave key type as JSON, and in the 41 | "Service account" drop-down select "New service account". Enter a 42 | service account name (the name can be virtually anything and won't 43 | matter for the rest of the instructions). The important part is selecting a 44 | role. Click the "Select a role" dropdown menu, in "Project" select "Service 45 | Account Actor", and then scroll down to "Compute Engine" and select "Compute 46 | Engine Admin (v1)". Then scroll down to "Pub/Sub", and add a role 47 | "Pub/Sub editor" (this is required to create queues, publish and read 48 | messages from them). If you are planning to use Google Cloud storage 49 | (directly, without the Firebase layer) for artifact storage, select the Storage 50 | Admin role as well. You can also add other roles if you are planning to use 51 | these credentials in other applications. When done, click "Create". 52 | Google Cloud console should generate a json credentials file and save it 53 | to your computer. 54 | 55 | Configuring Studio 56 | ------------------ 57 | 58 | Adding credentials 59 | ~~~~~~~~~~~~~~~~~~ 60 | 61 | Copy the json file credentials to the machine where Studio will be 62 | run, and create the environment variable ``GOOGLE_APPLICATION_CREDENTIALS`` 63 | that points to it. That is, run 64 | 65 | :: 66 | 67 | export GOOGLE_APPLICATION_CREDENTIALS=/path/to/credentials.json 68 | 69 | Note that this variable will be gone when you restart the terminal, so 70 | if you want to reuse it, add it to ``~/.bashrc`` (linux) or 71 | ``~/.bash_profile`` (OS X) 72 | 73 | Modifying the configuration file 74 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 75 | 76 | In the config file (the one that you use with the ``--config`` flag, or, if you 77 | use the default, in the ``studio/default_config.yaml``), go to the ``cloud`` 78 | section. Change projectId to the project id of the Google project for which 79 | you enabled cloud computing. You can also modify the default instance 80 | parameters (see `Cloud computing for studio `__ for 81 | limitations though). 82 | 83 | Test 84 | ~~~~ 85 | 86 | To test if things are set up correctly, go to 87 | ``studio/examples/general`` and run 88 | 89 | :: 90 | 91 | studio run --cloud=gcloud report_system_info.py 92 | 93 | Then run ``studio`` locally, and watch the new experiment. In a little 94 | while, it should change its status to "finished" and show the system 95 | information (number of cpus, amount of ram / hdd) of a default instance. 96 | See `Cloud computing for studio `__ for more instructions on 97 | using an instance with specific hardware parameters. 
98 | -------------------------------------------------------------------------------- /studio/queues/local_queue.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | import time 4 | import filelock 5 | 6 | from studio.artifacts.artifacts_tracker import get_studio_home 7 | from studio.storage.storage_setup import get_storage_verbose_level 8 | from studio.util import logs 9 | from studio.util.util import check_for_kb_interrupt, rm_rf 10 | 11 | LOCK_FILE_NAME = 'lock.lock' 12 | 13 | class LocalQueue: 14 | def __init__(self, name: str, path: str = None, logger=None): 15 | if logger is not None: 16 | self._logger = logger 17 | else: 18 | self._logger = logs.get_logger('LocalQueue') 19 | self._logger.setLevel(get_storage_verbose_level()) 20 | 21 | self.name = name 22 | if path is None: 23 | self.path = self._get_queue_directory() 24 | else: 25 | self.path = path 26 | self.path = os.path.join(self.path, name) 27 | os.makedirs(self.path, exist_ok=True) 28 | 29 | # Local queue is considered active, iff its directory exists. 30 | self._lock_path = os.path.join(self.path, LOCK_FILE_NAME) 31 | self._lock = filelock.SoftFileLock(self._lock_path) 32 | 33 | 34 | def _get_queue_status(self): 35 | is_active = os.path.isdir(self.path) 36 | if is_active: 37 | try: 38 | with self._lock: 39 | files = os.listdir(self.path) 40 | files.remove(LOCK_FILE_NAME) 41 | return True, files 42 | except BaseException as exc: 43 | check_for_kb_interrupt() 44 | self._logger.info("FAILED to get queue status for %s - %s", 45 | self.path, exc) 46 | # Ignore possible exception: 47 | # we just want list of files without internal lock file 48 | return False, list() 49 | 50 | def _get_queue_directory(self): 51 | queue_dir: str = os.path.join( 52 | get_studio_home(), 53 | 'queue') 54 | return queue_dir 55 | 56 | def has_next(self): 57 | is_active, files = self._get_queue_status() 58 | return is_active and len(files) > 0 59 | 60 | def is_active(self): 61 | is_active = os.path.isdir(self.path) 62 | return is_active 63 | 64 | def clean(self, timeout=0): 65 | _ = timeout 66 | rm_rf(self.path) 67 | 68 | def delete(self): 69 | self.clean() 70 | 71 | def _get_time(self, file: str): 72 | return os.path.getmtime(os.path.join(self.path, file)) 73 | 74 | def dequeue(self, acknowledge=True, timeout=0): 75 | sleep_in_seconds = 1 76 | total_wait_time = 0 77 | if not self.is_active(): 78 | return None 79 | 80 | while True: 81 | with self._lock: 82 | is_active, files = self._get_queue_status() 83 | if not is_active: 84 | return None, None 85 | if any(files): 86 | first_file = min([(p, self._get_time(p)) for p in files], 87 | key=lambda t: t[1])[0] 88 | first_file = os.path.join(self.path, first_file) 89 | 90 | with open(first_file, 'r') as f_in: 91 | data = f_in.read() 92 | 93 | if acknowledge: 94 | self.acknowledge(first_file) 95 | return data, None 96 | return data, first_file 97 | 98 | if total_wait_time >= timeout: 99 | return None, None 100 | time.sleep(sleep_in_seconds) 101 | total_wait_time += sleep_in_seconds 102 | 103 | def enqueue(self, data): 104 | with self._lock: 105 | filename = os.path.join(self.path, str(uuid.uuid4())) 106 | with open(filename, 'w') as f_out: 107 | f_out.write(data) 108 | 109 | def acknowledge(self, key): 110 | try: 111 | os.remove(key) 112 | except BaseException: 113 | check_for_kb_interrupt() 114 | 115 | def hold(self, key, minutes): 116 | _ = minutes 117 | self.acknowledge(key) 118 | 119 | def get_name(self): 120 | return self.name 121 | 122 | def 
get_path(self): 123 | return self.path 124 | 125 | def shutdown(self, delete_queue=True): 126 | _ = delete_queue 127 | self.delete() 128 | --------------------------------------------------------------------------------
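
A minimal usage sketch for the LocalQueue class defined in studio/queues/local_queue.py above. It is not a file in the repository and assumes only the constructor and methods shown there; the queue name, payload and temporary directory are made up for illustration.

# Illustrative sketch: exercising LocalQueue from studio/queues/local_queue.py.
import tempfile

from studio.queues.local_queue import LocalQueue

# Use a throwaway directory so the example does not touch ~/.studioml/queue.
queue_root = tempfile.mkdtemp()

queue = LocalQueue(name='example-queue', path=queue_root)
queue.enqueue('{"experiment": "example", "status": "queued"}')

# dequeue() returns a (payload, key) tuple; with acknowledge=True the backing
# file is removed as soon as the payload has been read.
payload, _ = queue.dequeue(acknowledge=True, timeout=5)
print(payload)

# shutdown() deletes the queue directory, which makes the queue inactive.
queue.shutdown()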