├── studio
├── util
│ ├── __init__.py
│ ├── logs.py
│ └── gpu_util.py
├── artifacts
│ └── __init__.py
├── queues
│ ├── __init__.py
│ ├── queues_setup.py
│ ├── qclient_cache.py
│ └── local_queue.py
├── credentials
│ └── __init__.py
├── db_providers
│ ├── __init__.py
│ ├── local_db_provider.py
│ ├── s3_provider.py
│ └── db_provider_setup.py
├── experiments
│ └── __init__.py
├── optimizer_plugins
│ ├── __init__.py
│ └── opt_util.py
├── payload_builders
│ ├── __init__.py
│ ├── unencrypted_payload_builder.py
│ └── payload_builder.py
├── storage
│ ├── __init__.py
│ ├── storage_type.py
│ ├── storage_setup.py
│ ├── http_storage_handler.py
│ ├── storage_handler_factory.py
│ ├── storage_handler.py
│ └── local_storage_handler.py
├── dependencies_policies
│ ├── __init__.py
│ ├── dependencies_policy.py
│ └── studio_dependencies_policy.py
├── scripts
│ ├── studio-runs
│ ├── studio-run
│ ├── studio-serve
│ ├── studio-ui
│ ├── studio-local-worker
│ ├── studio-remote-worker
│ ├── studio
│ ├── studio-add-credentials
│ ├── ec2_worker_startup.sh
│ ├── gcloud_worker_startup.sh
│ ├── studio-start-remote-worker
│ └── install_studio.sh
├── completion_service
│ ├── __init__.py
│ ├── completion_service_testfunc.py
│ ├── completion_service_testfunc_files.py
│ ├── completion_service_testfunc_saveload.py
│ ├── completion_service_client.py
│ └── encryptor.py
├── static
│ ├── tfs_small.png
│ └── Studio.ml-icon-std-1000.png
├── torch
│ ├── __init__.py
│ ├── summary_test.py
│ ├── saver.py
│ └── summary.py
├── templates
│ ├── error.html
│ ├── user_details.html
│ ├── all_experiments.html
│ ├── dashboard.html
│ ├── project_details.html
│ ├── projects.html
│ └── users.html
├── __init__.py
├── appengine_config.py
├── app.yaml
├── apiserver_config.yaml
├── run_magic.py.stub
├── aws
│ ├── aws_amis.yaml
│ └── aws_prices.yaml
├── default_config.yaml
├── client_config.yaml
├── server_config.yaml
├── patches
│ └── requests
│ │ └── models.py.patch
├── rmq_config.yml
├── postgres_provider.py
├── serve.py
├── firebase_provider.py
├── ed25519_key_util.py
├── deploy_apiserver.sh
├── fs_tracker.py
├── cloud_worker_util.py
├── remote_worker.py
├── git_util.py
├── experiment_submitter.py
├── magics.py
├── cli.py
└── model.py
├── .git_archival.txt
├── .gitattributes
├── logo.png
├── docs
├── logo.png
├── _static
│ └── img
│ │ └── logo.png
├── docker.rst
├── examples.rst
├── testing.rst
├── authentication.rst
├── ci_testing.rst
├── ec2_setup.rst
├── cli.rst
├── local_filesystem_setup.rst
├── customenv.rst
├── index.rst
├── jupyter.rst
├── containers.rst
├── faq.rst
├── README.rst
└── gcloud_setup.rst
├── extra_example_requirements.txt
├── tests
├── hyperparam_hello_world.py
├── check_style.sh
├── test_array.py
├── stop_experiment.py
├── test_config_http_server.yaml
├── tf_hello_world.py
├── model_increment.py
├── conflicting_args.py
├── art_hello_world.py
├── config_http_client.yaml
├── test_config_s3_storage.yaml
├── test_config_gcloud_storage.yaml
├── test_config_http_client.yaml
├── env_detect.py
├── save_model.py
├── test_bad_config.yaml
├── runner_test.py
├── test_config_s3.yaml
├── config_http_server.yaml
├── gpu_util_test.py
├── test_config_env.yaml
├── test_config_gs.yaml
├── test_config_auth.yaml
├── test_config_datacenter.yaml
├── fs_tracker_test.py
├── test_config.yaml
├── git_util_test.py
├── hyperparam_test.py
├── util_test.py
├── model_test.py
├── serving_test.py
└── http_provider_hosted_test.py
├── MANIFEST.in
├── test_requirements-cs.txt
├── test_requirements.txt
├── requirements-cs.txt
├── examples
├── pytorch
│ └── README.md
├── tensorflow
│ ├── helloworld.py
│ └── train_mnist.py
├── general
│ ├── report_system_info.py
│ ├── print_norm_linreg.py
│ └── train_linreg.py
└── keras
│ ├── train_mnist.py
│ ├── train_mnist_keras_mutligpu.py
│ ├── multi_gpu.py
│ ├── fashion_mnist.py
│ └── train_cifar10.py
├── studioml.bib
├── studioml.bibtex
├── Dockerfile
├── .github
└── dependabot.yml
├── requirements.txt
├── Dockerfile_keras_example
├── runtests.sh
├── Dockerfile_standalone_testing
├── .gitignore
├── .travis.yml
└── test-runner.yaml
/studio/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/studio/artifacts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/studio/queues/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/studio/credentials/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/studio/db_providers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/studio/experiments/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/studio/optimizer_plugins/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/studio/payload_builders/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/studio/storage/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/studio/dependencies_policies/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.git_archival.txt:
--------------------------------------------------------------------------------
1 | ref-names: HEAD -> master
2 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | .git_archival.txt export-subst
2 |
--------------------------------------------------------------------------------
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/studioml/studio/HEAD/logo.png
--------------------------------------------------------------------------------
/docs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/studioml/studio/HEAD/docs/logo.png
--------------------------------------------------------------------------------
/extra_example_requirements.txt:
--------------------------------------------------------------------------------
1 | pillow
2 | keras
3 | jupyter
4 | tensorflow
5 |
6 |
--------------------------------------------------------------------------------
/studio/scripts/studio-runs:
--------------------------------------------------------------------------------
1 | #!python
2 | from studio import cli
3 | cli.main()
4 |
--------------------------------------------------------------------------------
/studio/scripts/studio-run:
--------------------------------------------------------------------------------
1 | #!python
2 | from studio import runner
3 | runner.main()
4 |
--------------------------------------------------------------------------------
/studio/scripts/studio-serve:
--------------------------------------------------------------------------------
1 | #!python
2 | from studio import serve
3 | serve.main()
4 |
--------------------------------------------------------------------------------
/studio/completion_service/__init__.py:
--------------------------------------------------------------------------------
1 | from .completion_service import CompletionService
2 |
--------------------------------------------------------------------------------
/studio/scripts/studio-ui:
--------------------------------------------------------------------------------
1 | #!python
2 | from studio import apiserver
3 | apiserver.main()
4 |
--------------------------------------------------------------------------------
/docs/_static/img/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/studioml/studio/HEAD/docs/_static/img/logo.png
--------------------------------------------------------------------------------
/studio/scripts/studio-local-worker:
--------------------------------------------------------------------------------
1 | #!python
2 | from studio import local_worker
3 | local_worker.main()
4 |
--------------------------------------------------------------------------------
/studio/static/tfs_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/studioml/studio/HEAD/studio/static/tfs_small.png
--------------------------------------------------------------------------------
/studio/scripts/studio-remote-worker:
--------------------------------------------------------------------------------
1 | #!python
2 | from studio import remote_worker
3 | remote_worker.main()
4 |
--------------------------------------------------------------------------------
/studio/torch/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .saver import Saver, load_checkpoint, save_checkpoint
3 | from .summary import Reporter
4 |
--------------------------------------------------------------------------------
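Usage note (not part of the original file): a minimal sketch of the `Reporter` metric logger exported above, mirroring the calls exercised in `studio/torch/summary_test.py` later in this dump. The assumption that `report()` without an argument writes to stdout is mine, not stated in the source.

```
from studio.torch import Reporter

# Log a metric every 2 steps and smooth over the last 2 values,
# matching the constructor arguments used in summary_test.py.
r = Reporter(log_interval=2, smooth_interval=2)
for step, loss in enumerate([0.3, 0.25, 0.2]):
    r.add(step, 'loss', loss)
r.report()  # assumed to print the smoothed metric for the latest step
```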
/tests/hyperparam_hello_world.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | learning_rate = 0.3
4 | print(learning_rate)
5 |
6 | sys.stdout.flush()
7 |
--------------------------------------------------------------------------------
/studio/static/Studio.ml-icon-std-1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/studioml/studio/HEAD/studio/static/Studio.ml-icon-std-1000.png
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.rst
2 | include studio/default_config.yaml
3 | include studio/templates/*.html
4 | include requirements.txt
5 |
--------------------------------------------------------------------------------
/test_requirements-cs.txt:
--------------------------------------------------------------------------------
1 | pycodestyle
2 | pytest_xdist
3 | tensorflow
4 | keras
5 | certifi==2023.7.22
6 | pillow
7 | timeout_decorator
8 |
9 |
--------------------------------------------------------------------------------
/test_requirements.txt:
--------------------------------------------------------------------------------
1 | pycodestyle
2 | pytest_xdist
3 | tensorflow
4 | keras
5 | certifi==2023.7.22
6 | pillow
7 | timeout_decorator
8 |
9 |
--------------------------------------------------------------------------------
/docs/docker.rst:
--------------------------------------------------------------------------------
1 | To run a docker image with AWS credentials, use:
2 | sudo docker run -v ${HOME}/.aws/credentials:/root/.aws/credentials:ro standalone_testing
3 |
4 |
--------------------------------------------------------------------------------
/studio/templates/error.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 |
3 | {% block content %}
4 | Error
5 | {{ errormsg }} |
6 | {% endblock %}
7 |
--------------------------------------------------------------------------------
/studio/__init__.py:
--------------------------------------------------------------------------------
1 | from setuptools_scm import get_version
2 |
3 | try:
4 | __version__ = get_version(root='..', relative_to=__file__)
5 | except BaseException:
6 | __version__ = None
7 |
--------------------------------------------------------------------------------
/requirements-cs.txt:
--------------------------------------------------------------------------------
1 | pip
2 | setuptools_scm
3 |
4 | configparser
5 |
6 | requests
7 |
8 | PyYAML
9 | pyhocon
10 |
11 | boto3
12 |
13 | filelock
14 | pika >= 1.1.0
15 |
16 |
--------------------------------------------------------------------------------
/docs/examples.rst:
--------------------------------------------------------------------------------
1 | In order to try the examples in the studio/examples directory,
2 | be sure to install the additional Python dependencies first:
3 |
4 | pip install -r extra_example_requirements.txt
5 |
6 |
--------------------------------------------------------------------------------
/studio/appengine_config.py:
--------------------------------------------------------------------------------
1 | from google.appengine.ext import vendor
2 | import tempfile
3 | import subprocess
4 |
5 | tempfile.SpooledTemporaryFile = tempfile.TemporaryFile
6 | subprocess.Popen = None
7 |
8 | vendor.add('lib')
9 |
--------------------------------------------------------------------------------
/tests/check_style.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | APP_ROOT=$(cd "$(dirname "$0")/.." && pwd)
4 | echo "$APP_ROOT"
5 | pycodestyle --show-source --statistics "$APP_ROOT"
6 | # find $APP_ROOT -name "*.py" -exec pycodestyle --show-source {} \;
7 |
--------------------------------------------------------------------------------
/tests/test_array.py:
--------------------------------------------------------------------------------
1 | from studio import fs_tracker
2 | import numpy as np
3 |
4 | try:
5 | lr = np.load(fs_tracker.get_artifact('lr'))
6 | except BaseException:
7 | lr = np.random.random(10)
8 |
9 | print("fitness: %s" % np.abs(np.sum(lr)))
10 |
--------------------------------------------------------------------------------
/tests/stop_experiment.py:
--------------------------------------------------------------------------------
1 | import time
2 | from studio import logs
3 |
4 | logger = logs.get_logger('helloworld')
5 | logger.setLevel(10)
6 |
7 | i = 0
8 | while True:
9 | logger.info('{} seconds passed '.format(i))
10 | time.sleep(1)
11 | i += 1
12 |
--------------------------------------------------------------------------------
/tests/test_config_http_server.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: gs
3 | bucket: "studioml-meta"
4 | authentication: none
5 |
6 | storage:
7 | type: gcloud
8 | bucket: "studioml-artifacts"
9 |
10 | server:
11 | authentication: github
12 |
13 |
--------------------------------------------------------------------------------
/tests/tf_hello_world.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | a = tf.constant([1.0, 5.0])
4 |
5 | @tf.function
6 | def forward(x):
7 | return x + 1.0
8 |
9 | result = forward(a)
10 | assert len(result) == 2
11 |
12 | print("[ {} {} ]".format(result[0], result[1]))
13 |
--------------------------------------------------------------------------------
/studio/storage/storage_type.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | class StorageType(Enum):
4 | storageHTTP = 1
5 | storageS3 = 2
6 | storageLocal = 3
7 | storageFirebase = 4
8 | storageDockerHub = 5
9 | storageSHub = 6
10 |
11 | storageInvalid = 99
--------------------------------------------------------------------------------
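For illustration only: one possible way a config ``type`` string could map onto this enum. The real lookup presumably lives in ``storage_setup.py`` / ``storage_handler_factory.py`` (not shown in this dump), so treat the mapping below as an assumption.

```
from studio.storage.storage_type import StorageType

# Hypothetical string-to-enum mapping; names mirror the enum members above.
_TYPE_MAP = {
    'http': StorageType.storageHTTP,
    's3': StorageType.storageS3,
    'local': StorageType.storageLocal,
    'firebase': StorageType.storageFirebase,
}

def parse_storage_type(name):
    """Return the StorageType for a config string, or storageInvalid."""
    return _TYPE_MAP.get(str(name).lower(), StorageType.storageInvalid)

print(parse_storage_type('s3'))      # StorageType.storageS3
print(parse_storage_type('gcloud'))  # StorageType.storageInvalid (no gcloud member in the enum above)
```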
/tests/model_increment.py:
--------------------------------------------------------------------------------
1 | import six
2 |
3 |
4 | def create_model(modeldir):
5 |
6 | def model(data):
7 | retval = {}
8 | for k, v in six.iteritems(data):
9 | retval[k] = v + 1
10 | return retval
11 |
12 | return model
13 |
--------------------------------------------------------------------------------
/studio/app.yaml:
--------------------------------------------------------------------------------
1 | runtime: python27
2 | # api_version: 1
3 | threadsafe: false
4 |
5 | handlers:
6 | - url: /.*
7 | script: studio.apiserver.app
8 |
9 | libraries:
10 | - name: ssl
11 | version: latest
12 | - name: numpy
13 | version: latest
14 |
15 |
--------------------------------------------------------------------------------
/studio/completion_service/completion_service_testfunc.py:
--------------------------------------------------------------------------------
1 |
2 | def clientFunction(args, files):
3 | print('client function call with args ' +
4 | str(args) + ' and files ' + str(files))
5 | return args
6 |
7 |
8 | if __name__ == "__main__":
9 | clientFunction('test', {})
10 |
--------------------------------------------------------------------------------
/tests/conflicting_args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 |
4 |
5 | parser = argparse.ArgumentParser(description='Test argument conflict')
6 | parser.add_argument('--experiment', '-e', help='experiment key', required=True)
7 | args = parser.parse_args()
8 |
9 | print("Experiment key = " + args.experiment)
10 | sys.stdout.flush()
11 |
--------------------------------------------------------------------------------
/examples/pytorch/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch Examples for Studio
2 |
3 | General dependency:
4 |
5 | ```
6 | pip install torch torchvision
7 | ```
8 |
9 | ## MNIST
10 |
11 | A straightforward example for the MNIST dataset, taken from `https://github.com/pytorch/examples/tree/master/mnist`:
12 |
13 | ```
14 | studio run mnist.py
15 | ```
16 |
17 |
--------------------------------------------------------------------------------
/studio/apiserver_config.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: gs
3 | bucket: "studioml-meta"
4 | projectId: "studio-ed756"
5 |
6 | authentication: none
7 |
8 | server:
9 | authentication: github
10 |
11 | storage:
12 | type: gcloud
13 | bucket: studio-ed756.appspot.com
14 |
15 |
16 | verbose: debug
17 |
18 |
19 |
--------------------------------------------------------------------------------
/tests/art_hello_world.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from studio import fs_tracker
3 |
4 | print(fs_tracker.get_artifact('f'))
5 | with open(fs_tracker.get_artifact('f'), 'r') as f:
6 | print(f.read())
7 |
8 | if len(sys.argv) > 1:
9 | with open(fs_tracker.get_artifact('f'), 'w') as f:
10 | f.write(sys.argv[1])
11 |
12 |
13 | sys.stdout.flush()
14 |
--------------------------------------------------------------------------------
/examples/tensorflow/helloworld.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | import tensorflow as tf
4 |
5 |
6 | s = tf.Session()
7 |
8 | x = tf.constant([1.0, 2.0])
9 | y = x * 2
10 |
11 | logging.basicConfig()
12 | logger = logging.getLogger('helloworld')
13 | logger.setLevel(10)
14 |
15 | while True:
16 | logger.info(s.run(y))
17 | time.sleep(10)
18 |
--------------------------------------------------------------------------------
/examples/general/report_system_info.py:
--------------------------------------------------------------------------------
1 | import psutil
2 |
3 | g = 2**30 + 0.0
4 | meminfo = psutil.virtual_memory()
5 | diskinfo = psutil.disk_usage('/')
6 | print("Cpu count = {}".format(psutil.cpu_count()))
7 | print("RAM: total {}g, free {}g".format(meminfo.total / g, meminfo.free / g))
8 | print("HDD: total {}g, free {}g".format(diskinfo.total / g, diskinfo.free / g))
9 |
--------------------------------------------------------------------------------
/studio/dependencies_policies/dependencies_policy.py:
--------------------------------------------------------------------------------
1 | class DependencyPolicy:
2 | """
3 | Abstract class representing some policy
4 | for generating Python packages dependencies
5 | to be used for submitted experiment.
6 | """
7 |
8 | def generate(self, resources_needed):
9 | raise NotImplementedError('Not implemented DependencyPolicy')
10 |
--------------------------------------------------------------------------------
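A minimal sketch of a concrete policy, assuming ``generate()`` is expected to return a list of pip requirement strings; the project's real implementation is ``StudioDependenciesPolicy`` in ``studio_dependencies_policy.py``, which is not shown in this dump.

```
from studio.dependencies_policies.dependencies_policy import DependencyPolicy

class PinnedDependenciesPolicy(DependencyPolicy):
    """Hypothetical policy that always returns the same pinned packages."""

    def __init__(self, packages):
        self.packages = list(packages)

    def generate(self, resources_needed):
        # Ignores resources_needed; a real policy might add, say, a GPU build
        # of a framework when resources_needed requests GPUs.
        return self.packages

policy = PinnedDependenciesPolicy(['numpy', 'PyYAML'])
print(policy.generate(resources_needed={'gpus': 0}))
```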
/studio/templates/user_details.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 |
3 | {% block content %}
4 |
5 |
6 | {{ experimenttable(delete_button=true) }}
7 |
8 |
11 |
12 | {% endblock %}
13 |
--------------------------------------------------------------------------------
/tests/config_http_client.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: http
3 | serverUrl: "http://localhost:5000"
4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8
5 | authDomain: "studio-ed756.firebaseapp.com"
6 |
7 | saveMetricsFrequency: 1m
8 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes)
9 |
10 | cloud:
11 | cpus: 1
12 | gpus: 0
13 | ram: 4g
14 | hdd: 10g
15 |
--------------------------------------------------------------------------------
/studio/templates/all_experiments.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 |
3 | {% block content %}
4 |
5 |
6 |
7 | {{ experimenttable() }}
8 |
9 |
14 |
15 |
16 |
17 | {% endblock %}
18 |
--------------------------------------------------------------------------------
/studioml.bib:
--------------------------------------------------------------------------------
1 | @software{StudioML,
2 | author = {Zhokhov, Peter and Denissov, Andrei and Mutch, Karl},
3 | title = {Studio},
4 | howpublished = {\url{https://studio.ml}},
5 | publisher={Cognizant Evolutionary AI},
6 | version = {0.0.48},
7 | url = {https://github.com/studioml/studio/tree/0.0.48},
8 | year = {2017--2021},
9 | annotate = {Studio: Simplify and expedite the model building process}
10 | }
11 |
--------------------------------------------------------------------------------
/studioml.bibtex:
--------------------------------------------------------------------------------
1 | @software{StudioML,
2 | author = {Zhokhov, Peter and Denissov, Andrei and Mutch, Karl},
3 | title = {Studio},
4 | howpublished = {\url{https://studio.ml}},
5 | publisher={Cognizant Evolutionary AI},
6 | version = {0.0.48},
7 | url = {https://github.com/studioml/studio/tree/0.0.48},
8 | year = {2017--2021},
9 | annotate = {Studio: Simplify and expedite the model building process}
10 | }
11 |
--------------------------------------------------------------------------------
/tests/test_config_s3_storage.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: FireBase
3 |
4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8
5 | authDomain: "{}.firebaseapp.com"
6 | databaseURL: "https://{}.firebaseio.com"
7 | projectId: studio-ed756
8 | storageBucket: "{}.appspot.com"
9 | messagingSenderId: 81790704397
10 |
11 | guest: true
12 |
13 | storage:
14 | type: s3
15 | bucket: "studioml-artifacts"
16 |
17 |
--------------------------------------------------------------------------------
/examples/general/print_norm_linreg.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | from studio import fs_tracker
4 | import pickle
5 |
6 |
7 | weights_list = sorted(
8 | glob.glob(
9 | os.path.join(
10 | fs_tracker.get_artifact('w'),
11 | '*.pck')))
12 |
13 | print('*****')
14 | print(weights_list[-1])
15 | with open(weights_list[-1], 'rb') as f:
16 | w = pickle.load(f)
17 |
18 | print(w.dot(w))
19 | print('*****')
20 |
--------------------------------------------------------------------------------
/tests/test_config_gcloud_storage.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: FireBase
3 |
4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8
5 | authDomain: "{}.firebaseapp.com"
6 | databaseURL: "https://{}.firebaseio.com"
7 | projectId: studio-ed756
8 | storageBucket: "{}.appspot.com"
9 | messagingSenderId: 81790704397
10 |
11 | guest: true
12 |
13 | storage:
14 | type: gcloud
15 | bucket: "studio-ed756.appspot.com"
16 |
17 |
--------------------------------------------------------------------------------
/tests/test_config_http_client.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: http
3 | serverUrl: "https://studio-ed756.appspot.com"
4 | authentication:
5 | type: none
6 |
7 | log:
8 | name: output.log
9 |
10 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes)
11 | saveMetricsFrequency: 1m
12 |
13 | verbose: debug
14 | cloud:
15 | gcloud:
16 | zone: us-east1-c
17 |
18 | resources_needed:
19 | cpus: 2
20 | ram: 3g
21 | hdd: 60g
22 | gpus: 0
23 |
24 |
--------------------------------------------------------------------------------
/studio/payload_builders/unencrypted_payload_builder.py:
--------------------------------------------------------------------------------
1 | from studio.payload_builders.payload_builder import PayloadBuilder
2 | from studio.experiments.experiment import Experiment
3 |
4 | class UnencryptedPayloadBuilder(PayloadBuilder):
5 | """
6 | Simple payload builder constructing
7 | unencrypted experiment payloads.
8 | """
9 | def construct(self, experiment: Experiment, config, packages):
10 | return { 'experiment': experiment.to_dict(),
11 | 'config': config}
12 |
--------------------------------------------------------------------------------
/tests/env_detect.py:
--------------------------------------------------------------------------------
1 | import yaml
2 |
3 | AWSInstance = "aws" in list(
4 | yaml.load(
5 | open(
6 | "tests/test_config.yaml",
7 | "r"), Loader=yaml.SafeLoader)["cloud"].keys())
8 | GcloudInstance = "gcloud" in list(
9 | yaml.load(
10 | open(
11 | "tests/test_config.yaml",
12 | "r"), Loader=yaml.SafeLoader)["cloud"].keys())
13 |
14 |
15 | def on_gcp():
16 | return GcloudInstance
17 |
18 |
19 | def on_aws():
20 | return AWSInstance
21 |
--------------------------------------------------------------------------------
/studio/payload_builders/payload_builder.py:
--------------------------------------------------------------------------------
1 | class PayloadBuilder:
2 | """
3 | Abstract class representing
4 | payload object construction from experiment components.
5 | Result is payload ready to be submitted for execution.
6 | """
7 | def __init__(self, name: str):
8 | self.name = name if name else 'NO NAME'
9 |
10 | def construct(self, experiment, config, packages):
11 | raise NotImplementedError(
12 | 'Not implemented for payload builder {0}'.format(self.name))
13 |
--------------------------------------------------------------------------------
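A hedged usage sketch tying the two payload-builder files above together. The ``Experiment`` class lives in ``studio/experiments`` and is not shown in this dump, so a stand-in object with a ``to_dict()`` method is used here purely for illustration.

```
from studio.payload_builders.unencrypted_payload_builder import UnencryptedPayloadBuilder

class FakeExperiment:
    """Stand-in for studio's Experiment; only to_dict() is needed by construct()."""
    def to_dict(self):
        return {'key': 'example-experiment'}

builder = UnencryptedPayloadBuilder('unencrypted')  # PayloadBuilder.__init__ requires a name
payload = builder.construct(FakeExperiment(),
                            config={'verbose': 'error'},
                            packages=['keras==2.0.8'])
print(payload)  # {'experiment': {'key': 'example-experiment'}, 'config': {'verbose': 'error'}}
```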
/docs/testing.rst:
--------------------------------------------------------------------------------
1 | Tests can only be run from the main studio directory:
2 | use "python -m pytest tests/{testname}" to run individual tests
3 | use "python -m pytest tests" to run all tests
4 |
5 | Be sure to install the additional Python dependencies required for running
6 | tests:
7 |
8 | pip install -r test_requirements.txt
9 |
10 | You also need to set the cloud environment and credentials in test_config.yaml
11 | prior to running tests.
12 | Also verify that the "studio" executable is in your PATH
13 | so that the local_worker test can run.
14 |
15 |
16 |
--------------------------------------------------------------------------------
/studio/completion_service/completion_service_testfunc_files.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import six
3 | from studio import util
4 |
5 |
6 | def clientFunction(args, files):
7 | print('client function call with args ' +
8 | str(args) + ' and files ' + str(files))
9 |
10 | cs_files = {'output', 'clientscript', 'args'}
11 | filehashes = {
12 | k: util.filehash(
13 | v,
14 | hashobj=hashlib.md5()) for k,
15 | v in six.iteritems(files) if k not in cs_files}
16 |
17 | return (args, filehashes)
18 |
--------------------------------------------------------------------------------
/studio/util/logs.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | DEBUG = logging.DEBUG
4 | INFO = logging.INFO
5 | ERROR = logging.ERROR
6 |
7 | logging.basicConfig(
8 | format='%(asctime)s %(levelname)-6s %(name)s - %(message)s',
9 | level=ERROR,
10 | datefmt='%Y-%m-%d %H:%M:%S')
11 |
12 |
13 | def get_logger(name):
14 | return logging.getLogger(name)
15 |
16 |
17 | def debug(line):
18 | return logging.debug(line)
19 |
20 |
21 | def error(line):
22 | return logging.error(line)
23 |
24 |
25 | def info(line):
26 | return logging.info(line)
27 |
--------------------------------------------------------------------------------
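A short usage sketch for this logging wrapper; the logger name is arbitrary. Note that tests elsewhere in this dump import it as ``from studio import logs``, which suggests a re-export, but the path used below matches the file location shown above.

```
from studio.util import logs

logger = logs.get_logger('my_experiment')  # name is arbitrary
logger.setLevel(logs.INFO)                 # DEBUG / INFO / ERROR constants are re-exported
logger.info('experiment started')

# The module-level helpers log through the root logger instead:
logs.error('something went wrong')
```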
/tests/save_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from keras.layers import Dense
5 | from keras.models import Sequential
6 |
7 | from studio import fs_tracker
8 |
9 | model = Sequential()
10 | model.add(Dense(2, input_shape=(2,)))
11 |
12 | weights = model.get_weights()
13 | new_weights = [np.array([[2, 0], [0, 2]]), np.zeros(2)]
14 | # print weights
15 | # new_weights = []
16 | # for weight in weights:
17 | # new_weights.append(weight + 1)
18 |
19 | model.set_weights(new_weights)
20 | model.save(os.path.join(fs_tracker.get_model_directory(), 'weights.h5'))
21 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04
2 |
3 | # add tensorflow-gpu to use with gpu to sudo pip install
4 | # to use on linux machines with gpus
5 | RUN apt-get update && \
6 | apt-get -y install python-pip python-dev python3-pip python3-dev python3 git wget && \
7 | python -m pip install --upgrade pip && \
8 | python3 -m pip install --upgrade pip
9 |
10 | COPY . /studio
11 | RUN cd studio && \
12 | python -m pip install -e . --upgrade && \
13 | python3 -m pip install -e . --upgrade
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/tests/test_bad_config.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: FireBase
3 |
4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8
5 | authDomain: "{}.firebaseapp.com"
6 | databaseURL: "https://{}.firebaseio.com"
7 | projectId: studio-e756
8 | storageBucket: "{}.appspot.com"
9 | messagingSenderId: 81790704397
10 |
11 | guest: true
12 |
13 | log:
14 | name: output.log
15 |
16 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes)
17 | saveMetricsFrequency: 1m
18 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # To get started with Dependabot version updates, you'll need to specify which
2 | # package ecosystems to update and where the package manifests are located.
3 | # Please see the documentation for all configuration options:
4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
5 |
6 | version: 2
7 | updates:
8 | - package-ecosystem: "github-actions" # See documentation for possible values
9 | directory: "/" # Location of package manifests
10 | schedule:
11 | interval: "weekly"
12 |
--------------------------------------------------------------------------------
/studio/torch/summary_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from io import StringIO
3 |
4 | from studio.torch import summary
5 |
6 |
7 | class ReporterTest(unittest.TestCase):
8 |
9 | def test_summary_report(self):
10 | r = summary.Reporter(log_interval=2, smooth_interval=2)
11 | out = StringIO()
12 | r.add(0, 'k', 0.1)
13 | r.add(1, 'k', 0.2)
14 | r.report()
15 | r.add(2, 'k', 0.3)
16 | r.report(out)
17 | self.assertEqual(out.getvalue(), "Step 2: k = 0.25000")
18 |
19 |
20 | if __name__ == "__main__":
21 | unittest.main()
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pip
2 |
3 | setuptools_scm
4 | setuptools_scm_git_archive
5 |
6 | configparser
7 |
8 | numpy
9 |
10 | flask
11 |
12 | cma
13 |
14 | psutil
15 |
16 | apscheduler
17 | openssh_key_parser
18 | pycryptodome
19 | sshpubkeys
20 | PyNaCl
21 | requests
22 | requests_toolbelt
23 | python_jwt
24 | sseclient
25 |
26 | terminaltables
27 |
28 | PyYAML
29 | pyhocon
30 |
31 | google-api-core
32 | google-api-python-client
33 | google-cloud-storage
34 | google-cloud-pubsub
35 | google-auth-httplib2
36 | oauth2client==3.0.0
37 |
38 | boto3
39 | rsa
40 |
41 | filelock
42 | pika >= 1.1.0
43 |
44 |
--------------------------------------------------------------------------------
/tests/runner_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from studio import model
4 |
5 |
6 | class RunnerTest(unittest.TestCase):
7 |
8 | def test_add_packages(self):
9 |
10 | list1 = ['keras==2.0.5', 'boto3==1.1.3']
11 | list2 = ['keras==1.0.9', 'h5py==2.7.0', 'abc']
12 |
13 | result = set(model.add_packages(list1, list2))
14 | expected_result = set(['boto3==1.1.3', 'h5py==2.7.0',
15 | 'keras==1.0.9', 'abc'])
16 |
17 | self.assertEqual(result, expected_result)
18 |
19 |
20 | if __name__ == '__main__':
21 | unittest.main()
22 |
--------------------------------------------------------------------------------
/Dockerfile_keras_example:
--------------------------------------------------------------------------------
1 | FROM ubuntu:16.04
2 |
3 | MAINTAINER jiamingjxu@gmail.com
4 |
5 | ENV LANG C.UTF-8
6 |
7 | RUN mkdir -p /setupTesting
8 |
9 | COPY . /setupTesting
10 |
11 | WORKDIR /setupTesting
12 |
13 | RUN apt-get update && \
14 | apt-get install -y python-pip libpq-dev python-dev && \
15 | apt-get install -y git && \
16 | pip install -U pytest && \
17 | pip install -r test_requirements.txt && \
18 | python setup.py build && \
19 | python setup.py install
20 |
21 | CMD studio run --lifetime=30m --max-duration=20m --gpus 4 --queue=rmq_kmutch --force-git /examples/keras/train_mnist_keras.py
--------------------------------------------------------------------------------
/tests/test_config_s3.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: s3
3 |
4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8
5 | authDomain: "{}.firebaseapp.com"
6 | databaseURL: "https://{}.firebaseio.com"
7 | projectId: studio-ed756
8 | storageBucket: "{}.appspot.com"
9 | messagingSenderId: 81790704397
10 |
11 | bucket: studioml-meta
12 | guest: true
13 | max_keys: -1
14 |
15 | log:
16 | name: output.log
17 |
18 | verbose: error
19 |
20 | cloud:
21 | gcloud:
22 | zone: us-central1-f
23 |
24 | resources_needed:
25 | cpus: 2
26 | ram: 3g
27 | hdd: 10g
28 | gpus: 0
29 |
30 |
--------------------------------------------------------------------------------
/tests/config_http_server.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: FireBase
3 |
4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8
5 | authDomain: "{}.firebaseapp.com"
6 | databaseURL: "https://{}.firebaseio.com"
7 | projectId: studio-ed756
8 | storageBucket: "{}.appspot.com"
9 | messagingSenderId: 81790704397
10 | serviceAccount: /Users/peter.zhokhov/gkeys/studio-ed756-firebase-adminsdk-4ug1o-7202e6402d.json
11 |
12 | storage:
13 | type: gcloud
14 | bucket: "studio-ed756.appspot.com"
15 |
16 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes)
17 | saveMetricsFrequency: 1m
18 |
19 |
20 |
--------------------------------------------------------------------------------
/studio/run_magic.py.stub:
--------------------------------------------------------------------------------
1 | import __main__ as _M
2 | import pickle
3 | import six
4 | import gzip
5 |
6 | from studio import fs_tracker
7 |
8 | with gzip.open(fs_tracker.get_artifact('_ns'), 'rb') as f:
9 | _M.__dict__.update(pickle.load(f))
10 |
11 | {script}
12 |
13 | ns_dict = _M.__dict__.copy()
14 | pickleable_ns = dict()
15 | for varname, var in six.iteritems(ns_dict):
16 | try:
17 | pickle.dumps(var)
18 | pickleable_ns[varname] = var
19 | except BaseException:
20 | pass
21 |
22 | with open(fs_tracker.get_artifact('_ns'), 'wb') as f:
23 | f.write(pickle.dumps(pickleable_ns))
24 |
25 |
--------------------------------------------------------------------------------
/tests/gpu_util_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from studio.util.gpu_util import memstr2int
3 |
4 |
5 | class GpuUtilTest(unittest.TestCase):
6 |
7 | def test_memstr2int(self):
8 | self.assertEqual(memstr2int('123 Mb'), 123 * (2**20))
9 | self.assertEqual(memstr2int('456 MiB'), 456 * (2**20))
10 | self.assertEqual(memstr2int('23 Gb'), 23 * (2**30))
11 | self.assertEqual(memstr2int('23 GiB'), 23 * (2**30))
12 | self.assertEqual(memstr2int('23 '), 23)
13 |
14 | with self.assertRaises(ValueError):
15 | memstr2int('300 spartans')
16 |
17 |
18 | if __name__ == "__main__":
19 | unittest.main()
20 |
--------------------------------------------------------------------------------
/tests/test_config_env.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: FireBase
3 |
4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8
5 | authDomain: "{}.firebaseapp.com"
6 | databaseURL: "https://{}.firebaseio.com"
7 | projectId: studio-ed756
8 | storageBucket: "{}.appspot.com"
9 | messagingSenderId: 81790704397
10 |
11 | guest: true
12 |
13 | log:
14 | name: output.log
15 |
16 | test_key: $TEST_VAR1
17 | test_section:
18 | test_key: $TEST_VAR2
19 |
20 |
21 | cloud:
22 | type: google
23 | zone: us-central1-f
24 |
25 | cpus: 2
26 | ram: 3g
27 | hdd: 10g
28 | gpus: 0
29 |
30 |
--------------------------------------------------------------------------------
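The ``$TEST_VAR1`` / ``$TEST_VAR2`` values above imply that Studio's config loader expands environment variables; that loader is not shown in this dump, so the snippet below is only an illustrative sketch of such an expansion, not Studio's actual implementation.

```
import os
import re

def expand_env(value):
    """Recursively replace $VAR tokens with environment values (leave unknown ones as-is)."""
    if isinstance(value, str):
        return re.sub(r'\$(\w+)',
                      lambda m: os.environ.get(m.group(1), m.group(0)),
                      value)
    if isinstance(value, dict):
        return {k: expand_env(v) for k, v in value.items()}
    if isinstance(value, list):
        return [expand_env(v) for v in value]
    return value

os.environ['TEST_VAR1'] = 'expanded'
print(expand_env({'test_key': '$TEST_VAR1', 'test_section': {'test_key': '$TEST_VAR2'}}))
# {'test_key': 'expanded', 'test_section': {'test_key': '$TEST_VAR2'}}
```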
/tests/test_config_gs.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: gs
3 |
4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8
5 | authDomain: "{}.firebaseapp.com"
6 | databaseURL: "https://{}.firebaseio.com"
7 | projectId: studio-ed756
8 | storageBucket: "{}.appspot.com"
9 | messagingSenderId: 81790704397
10 |
11 | bucket: studioml-meta
12 | guest: true
13 | max_keys: -1
14 |
15 | log:
16 | name: output.log
17 |
18 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes)
19 | verbose: error
20 |
21 | cloud:
22 | gcloud:
23 | zone: us-central1-f
24 |
25 | resources_needed:
26 | cpus: 2
27 | ram: 3g
28 | hdd: 10g
29 | gpus: 0
30 |
31 |
--------------------------------------------------------------------------------
/studio/templates/dashboard.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 |
3 | {% block content %}
4 |
5 |
6 |
7 | {{ experimenttable() }}
8 |
9 |
21 | {% endblock %}
22 |
--------------------------------------------------------------------------------
/tests/test_config_auth.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: FireBase
3 |
4 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8
5 | authDomain: "{}.firebaseapp.com"
6 | databaseURL: "https://{}.firebaseio.com"
7 | projectId: studio-ed756
8 | storageBucket: "{}.appspot.com"
9 | messagingSenderId: 81790704397
10 |
11 | authentication:
12 | type: firebase
13 | use_email_auth: true
14 | email: authtest@sentient.ai
15 | password: HumptyDumptyS@tOnAWall
16 | authDomain: "studio-ed756.firebaseapp.com"
17 | apiKey: AIzaSyCLQbp5X2B4SWzBw-sz9rUnGHNSdMl0Yx8
18 |
19 | log:
20 | name: output.log
21 |
--------------------------------------------------------------------------------
/tests/test_config_datacenter.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: http
3 | serverUrl: https://studio-sentient.appspot.com
4 | projectId: studio-sentient
5 | compression: None
6 | guest: true
7 |
8 | log:
9 | name: output.log
10 |
11 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes)
12 | saveMetricsFrequency: 1m
13 |
14 | verbose: debug
15 | cloud:
16 | gcloud:
17 | zone: us-east1-c
18 |
19 | resources_needed:
20 | cpus: 2
21 | ram: 3g
22 | hdd: 10g
23 | gpus: 0
24 |
25 | env:
26 | AWS_ACCESS_KEY_ID: $AWS_ACCESS_KEY_ID
27 | AWS_SECRET_ACCESS_KEY: $AWS_SECRET_ACCESS_KEY
28 | AWS_DEFAULT_REGION: us-west-2
29 |
30 |
31 |
32 | healthcheck: python -c "import tensorflow"
33 |
34 |
--------------------------------------------------------------------------------
/studio/aws/aws_amis.yaml:
--------------------------------------------------------------------------------
1 | ubuntu16.04:
2 | # us-east-1: ami-da05a4a0 # Vanilla ubuntu 16.04
3 | us-east-1: ami-d91866a3 # studio.ml enabled
4 | # us-east-1: ami-f346c289 # aws deep learning ami
5 |
6 | us-east-2: ami-336b4456
7 | us-west-1: ami-1c1d217c
8 | # us-west-2: ami-0a00ce72 # Vanilla ubuntu 16.04
9 | us-west-2: ami-aa508ad2 # studio.ml enabled
10 |
11 | ca-central-1: ami-8a71c9ee
12 | eu-west-1: ami-add175d4
13 | eu-central-1: ami-97e953f8
14 | eu-west-2: ami-ecbea388
15 | ap-southeast-1: ami-67a6e604
16 | ap-southeast-2: ami-41c12e23
17 | ap-northeast-2: ami-7b1cb915
18 | ap-northeast-1: ami-15872773
19 | ap-south-1: ami-bc0d40d3
20 | sa-east-1: ami-466b132a
21 |
22 |
23 |
24 |
25 |
--------------------------------------------------------------------------------
/runtests.sh:
--------------------------------------------------------------------------------
1 | python tests/artifact_store_test.py
2 | python tests/cloud_worker_test.py
3 | python tests/fs_tracker_test.py
4 | python tests/git_util_test.py
5 | python tests/gpu_util_test.py
6 | python tests/http_provider_hosted_test.py
7 | python tests/http_provider_test.py
8 | python tests/hyperparam_test.py
9 | python tests/local_worker_test.py
10 | python tests/model_test.py
11 | python tests/model_util_test.py
12 | python tests/providers_test.py
13 | python tests/queue_test.py
14 | python tests/remote_worker_test.py
15 | python tests/runner_test.py
16 | python tests/serving_test.py
17 | python tests/util_test.py
18 |
19 |
--------------------------------------------------------------------------------
/studio/templates/project_details.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 |
3 | {% block content %}
4 |
5 |
6 | {{ experimenttable(showproject=false) }}
7 |
8 | {% if allow_tensorboard %}
9 |
10 | {%- endif %}
11 |
12 |
24 |
25 | {% endblock %}
26 |
--------------------------------------------------------------------------------
/studio/completion_service/completion_service_testfunc_saveload.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from studio import fs_tracker
4 |
5 |
6 | def clientFunction(args, files):
7 | print('client function call with args ' +
8 | str(args) + ' and files ' + str(files))
9 |
10 | modelfile = 'model.dat'
11 | filename = files.get('model') or \
12 | os.path.join(fs_tracker.get_artifact('modeldir'), modelfile)
13 |
14 | print("Trying to load file {}".format(filename))
15 |
16 | if os.path.exists(filename):
17 | with open(filename, 'rb') as f:
18 | args = pickle.loads(f.read()) + 1
19 |
20 | else:
21 | print("Trying to write file {}".format(filename))
22 | with open(filename, 'wb') as f:
23 | f.write(pickle.dumps(args, protocol=2))
24 |
25 | return args
26 |
27 |
28 | if __name__ == "__main__":
29 | clientFunction('test', {})
30 |
--------------------------------------------------------------------------------
/docs/authentication.rst:
--------------------------------------------------------------------------------
1 | Authentication
2 | ==============
3 |
4 | Currently, Studio uses GitHub for authentication. For the command-line tools
5 | (studio ui, studio run, etc.) authentication is done via personal
6 | access tokens. When no token is present (e.g. when you run studio for
7 | the first time), you will be asked to enter your GitHub username and
8 | password. Studio DOES NOT store your username or password; instead,
9 | they are sent to the GitHub API server in exchange for an access token.
10 | The access token is saved and used from that point on.
11 | Personal access tokens do not expire, can be transferred from one
12 | machine to another, and, if necessary, can be revoked by going to
13 | GitHub -> Settings -> Developer Settings -> Personal Access Tokens.
14 |
15 | Authentication for the hosted UI server (https://zoo.studio.ml) follows
16 | the standard GitHub Auth flow for web apps.
17 |
--------------------------------------------------------------------------------
/studio/default_config.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: http
3 | serverUrl: https://zoo.studio.ml
4 | authentication: github
5 |
6 | storage:
7 | type: gcloud
8 | bucket: studio-ed756.appspot.com
9 |
10 |
11 |
12 | queue: local
13 |
14 | saveMetricsFrequency: 1m
15 | saveWorkspaceFrequency: 1m
16 | verbose: error
17 |
18 | cloud:
19 | gcloud:
20 | zone: us-east1-c
21 |
22 | resources_needed:
23 | cpus: 2
24 | ram: 3g
25 | hdd: 60g
26 | gpus: 0
27 |
28 | sleep_time: 1
29 | worker_timeout: 30
30 |
31 | optimizer:
32 | cmaes_config:
33 | popsize: 100
34 | sigma0: 0.25
35 | load_best_only: false
36 | load_checkpoint_file:
37 | visualization: true
38 | result_dir: "~/Desktop/"
39 | checkpoint_interval: 0
40 | termination_criterion:
41 | generation: 5
42 | fitness: 999
43 | skip_gen_thres: 1.0
44 | skip_gen_timeout: 30
45 |
--------------------------------------------------------------------------------
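To see how this default config is shaped, here is a small sketch that loads it with plain PyYAML, the same way ``tests/env_detect.py`` reads ``tests/test_config.yaml``; Studio's own config loader may do more (defaults merging, environment-variable expansion), so this is only a reading aid.

```
import yaml

# Path assumes the snippet is run from the repository root.
with open('studio/default_config.yaml') as f:
    config = yaml.load(f, Loader=yaml.SafeLoader)

print(config['database']['type'])          # http
print(config['queue'])                     # local
print(config['resources_needed']['gpus'])  # 0
```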
/docs/ci_testing.rst:
--------------------------------------------------------------------------------
1 | Requirements: Docker, a Docker Hub account, Kubernetes, Keel
2 |
3 | https://docs.docker.com/install/
4 |
5 | https://kubernetes.io/docs/tasks/tools/install-kubectl/
6 |
7 | https://keel.sh/v1/guide/installation.html
8 |
9 | To run individual tests, edit Dockerfile_standalone_testing.
10 |
11 | After editing, build the image using
12 | "docker image build --tag [dockerhubUsername]/standalone_testing:latest . -f Dockerfile_standalone_testing"
13 | (you may have to use sudo).
14 |
15 | Push the image to your Docker Hub account with
16 |
17 | "docker push [dockerhubUsername]/standalone_testing"
18 |
19 | Then, to run the tests, edit test-runner.yaml:56 to read
20 |
21 | "- image: [dockerhubUsername]/standalone_testing"
22 |
23 | Finally, use "kubectl apply -f test-runner.yaml" to run the tests automatically;
24 | the results can be seen using "kubectl logs test-runner-xxxxxxx-xxxxx"
25 |
--------------------------------------------------------------------------------
/Dockerfile_standalone_testing:
--------------------------------------------------------------------------------
1 | FROM ubuntu:16.04
2 |
3 | MAINTAINER jiamingjxu@gmail.com
4 |
5 | ENV LANG C.UTF-8
6 |
7 | RUN mkdir -p /setupTesting
8 |
9 | COPY . /setupTesting
10 |
11 | WORKDIR /setupTesting
12 |
13 | RUN apt-get update && apt-get install -y \
14 | curl
15 |
16 | RUN \
17 | apt-get update && apt-get install -y apt-transport-https && \
18 | curl -s https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - && \
19 | echo "deb https://apt.kubernetes.io/ kubernetes-xenial main" | tee -a /etc/apt/sources.list.d/kubernetes.list && \
20 | apt-get update && \
21 | apt-get install -y kubectl
22 |
23 | RUN apt-get update && \
24 | apt-get install -y python-pip libpq-dev python-dev && \
25 | apt-get install -y git && \
26 | pip install -U pytest && \
27 | pip install -r test_requirements.txt && \
28 | python setup.py build && \
29 | python setup.py install
30 |
31 | CMD python -m pytest tests/util_test.py
32 |
--------------------------------------------------------------------------------
/studio/client_config.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: http
3 | serverUrl: http://localhost:5000
4 | authentication:
5 | type: github
6 |
7 | storage:
8 | type: gcloud
9 | bucket: studioml-sentient-artifacts
10 |
11 |
12 |
13 | queue: local
14 |
15 | saveMetricsFrequency: 1m
16 | saveWorkspaceFrequency: 1m
17 | verbose: error
18 |
19 | cloud:
20 | gcloud:
21 | zone: us-central1-f
22 |
23 | resources_needed:
24 | cpus: 2
25 | ram: 3g
26 | hdd: 10g
27 | gpus: 0
28 |
29 | sleep_time: 1
30 | worker_timeout: 30
31 |
32 | optimizer:
33 | cmaes_config:
34 | popsize: 100
35 | sigma0: 0.25
36 | load_best_only: false
37 | load_checkpoint_file:
38 | visualization: true
39 | result_dir: "~/Desktop/"
40 | checkpoint_interval: 0
41 | termination_criterion:
42 | generation: 5
43 | fitness: 999
44 | skip_gen_thres: 1.0
45 | skip_gen_timeout: 30
46 |
--------------------------------------------------------------------------------
/docs/ec2_setup.rst:
--------------------------------------------------------------------------------
1 | Setting up Amazon EC2
2 | =====================
3 |
4 | This page describes the process of configuring Studio to work
5 | with Amazon EC2. We assume that you already have AWS credentials
6 | and an AWS account set up.
7 |
8 | Install boto3
9 | -------------
10 |
11 | Studio interacts with AWS via the boto3 API. Thus, in order to use the EC2
12 | cloud, you'll need to install boto3:
13 |
14 | ::
15 |
16 | pip install boto3
17 |
18 | Set up credentials
19 | ------------------
20 |
21 | Add credentials to a location where boto3 can access them. The
22 | recommended way is to install the AWS CLI:
23 |
24 | ::
25 |
26 | pip install awscli
27 |
28 | and then run
29 |
30 | ::
31 |
32 | aws configure
33 |
34 | and enter your AWS credentials and region. The output format can be left as
35 | None. Alternatively, use any other method of providing boto3 with
36 | credentials, as described here:
37 | http://boto3.readthedocs.io/en/latest/guide/configuration.html
38 |
--------------------------------------------------------------------------------
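Not part of the original page: a small, optional sanity check (a sketch, assuming boto3 is installed and credentials are configured) to confirm that boto3 can see the credentials set up via ``aws configure``.

```
import boto3

# boto3 resolves credentials from ~/.aws/credentials, environment variables, etc.
session = boto3.Session()
assert session.get_credentials() is not None, 'boto3 could not locate AWS credentials'

# STS reports which account/identity the credentials belong to.
identity = boto3.client('sts').get_caller_identity()
print(identity['Account'], identity['Arn'])
```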
/tests/fs_tracker_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 |
4 | from studio import fs_tracker
5 |
6 |
7 | class StudioLoggingTest(unittest.TestCase):
8 |
9 | def test_get_model_directory_args(self):
10 | experimentName = 'testExperiment'
11 | modelDir = fs_tracker.get_model_directory(experimentName)
12 | self.assertTrue(
13 | modelDir == os.path.join(
14 | os.path.expanduser('~'),
15 | '.studioml/experiments/testExperiment/modeldir'))
16 |
17 | def test_get_model_directory_noargs(self):
18 | testExperiment = 'testExperiment'
19 | testPath = os.path.join(
20 | os.path.expanduser('~'),
21 | '.studioml/experiments',
22 | testExperiment, 'modeldir')
23 |
24 | os.environ['STUDIOML_EXPERIMENT'] = testExperiment
25 | self.assertTrue(testPath == fs_tracker.get_model_directory())
26 |
27 |
28 | if __name__ == "__main__":
29 | unittest.main()
30 |
--------------------------------------------------------------------------------
/examples/general/train_linreg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import os
4 |
5 | no_samples = 100
6 | dim_samples = 5
7 |
8 | learning_rate = 0.01
9 | no_steps = 10
10 |
11 | X = np.random.random((no_samples, dim_samples))
12 | y = np.random.random((no_samples,))
13 |
14 | w = np.random.random((dim_samples,))
15 |
16 | for step in range(no_steps):
17 | yhat = X.dot(w)
18 | err = (yhat - y)
19 | dw = err.dot(X)
20 | w -= learning_rate * dw
21 | loss = 0.5 * err.dot(err)
22 |
23 | print("step = {}, loss = {}, L2 norm = {}".format(step, loss, w.dot(w)))
24 |
25 | # with open(os.path.expanduser('~/weights/lr_w_{}_{}.pck'
26 | # .format(step, loss)), 'w') as f:
27 | # f.write(pickle.dumps(w))
28 |
29 | from studio import fs_tracker
30 | with open(os.path.join(fs_tracker.get_artifact('weights'),
31 | 'lr_w_{}_{}.pck'.format(step, loss)),
32 | 'wb') as f:
33 | f.write(pickle.dumps(w))
34 |
--------------------------------------------------------------------------------
/studio/server_config.yaml:
--------------------------------------------------------------------------------
1 | database:
2 | type: gs
3 | bucket: studioml-sentient-meta
4 | projectId: "studio-sentient"
5 |
6 | authentication:
7 | type: github
8 |
9 | storage:
10 | type: gcloud
11 | bucket: studioml-sentient-artifacts
12 |
13 |
14 |
15 | queue: local
16 |
17 | saveMetricsFrequency: 1m
18 | saveWorkspaceFrequency: 1m
19 | verbose: error
20 |
21 | cloud:
22 | gcloud:
23 | zone: us-central1-f
24 |
25 | resources_needed:
26 | cpus: 2
27 | ram: 3g
28 | hdd: 10g
29 | gpus: 0
30 |
31 | sleep_time: 1
32 | worker_timeout: 30
33 |
34 | optimizer:
35 | cmaes_config:
36 | popsize: 100
37 | sigma0: 0.25
38 | load_best_only: false
39 | load_checkpoint_file:
40 | visualization: true
41 | result_dir: "~/Desktop/"
42 | checkpoint_interval: 0
43 | termination_criterion:
44 | generation: 5
45 | fitness: 999
46 | skip_gen_thres: 1.0
47 | skip_gen_timeout: 30
48 |
--------------------------------------------------------------------------------
/tests/test_config.yaml:
--------------------------------------------------------------------------------
1 | ## NOTE: to run unit tests in your environment,
2 | ## please provide your own credentials for S3 and RMQ servers access!
3 | database:
4 | type: s3
5 | authentication: none
6 | aws_access_key: *********
7 | aws_secret_key: **************
8 | bucket: test-database
9 | endpoint: http://******************************/
10 |
11 | env:
12 | AWS_DEFAULT_REGION: us-west-2
13 |
14 | server:
15 | authentication: None
16 |
17 | storage:
18 | type: s3
19 | aws_access_key: *********
20 | aws_secret_key: **************
21 | bucket: test-storage
22 | endpoint: http://******************************/
23 |
24 | log:
25 | name: output.log
26 |
27 | saveWorkspaceFrequency: 1m #how often is workspace being saved (minutes)
28 | saveMetricsFrequency: 1m
29 | verbose: error
30 |
31 | cloud:
32 | queue:
33 | rmq: amqp://**********************************
34 |
35 | resources_needed:
36 | cpus: 2
37 | ram: 3g
38 | hdd: 60g
39 | gpus: 0
40 |
41 |
--------------------------------------------------------------------------------
/studio/patches/requests/models.py.patch:
--------------------------------------------------------------------------------
1 | --- lib/requests/models.py 2018-01-03 14:10:27.000000000 -0800
2 | +++ models.py 2018-01-03 14:09:01.000000000 -0800
3 | @@ -742,8 +742,16 @@
4 | # Special case for urllib3.
5 | if hasattr(self.raw, 'stream'):
6 | try:
7 | - for chunk in self.raw.stream(chunk_size, decode_content=True):
8 | - yield chunk
9 | + if isinstance(self.raw._original_response._method, int):
10 | + while True:
11 | + chunk = self.raw.read(chunk_size, decode_content=True)
12 | + if not chunk:
13 | + break
14 | + yield chunk
15 | + else:
16 | + for chunk in self.raw.stream(chunk_size, decode_content=True):
17 | + yield chunk
18 | +
19 | except ProtocolError as e:
20 | raise ChunkedEncodingError(e)
21 | except DecodeError as e:
22 |
--------------------------------------------------------------------------------
/docs/cli.rst:
--------------------------------------------------------------------------------
1 | ======================
2 | Command-line interface
3 | ======================
4 |
5 | In some cases, a (semi-)programmatic way of keeping track of experiments may be preferred. On top of the Python and HTTP API, we provide
6 | a command-line tool to get a quick overview of existing experiments and take actions on them. Commands available
7 | at the moment are:
8 |
9 | - ``studio runs list users`` - lists all users
10 | - ``studio runs list projects`` - lists all projects
11 | - ``studio runs list [user]`` - lists your (default) or someone else's experiments
12 | - ``studio runs list project <project_name>`` - lists all experiments in a project
13 | - ``studio runs list all`` - lists all experiments
14 |
15 | - ``studio runs kill <experiment_key>`` - deletes an experiment
16 | - ``studio runs stop <experiment_key>`` - stops an experiment
17 |
18 | Note that for now if the experiment is running, killing it will NOT automatically stop the runner. You should stop the experiment first, ensure its status has been changed to stopped, and then kill it. This is a known issue, and we are working on a solution.
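
A minimal sketch of that stop-then-kill sequence, assuming a hypothetical experiment key ``my_experiment``:

::

    studio runs stop my_experiment
    studio runs list all     # wait until the experiment shows up as stopped
    studio runs kill my_experiment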
19 |
20 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/studio/rmq_config.yml:
--------------------------------------------------------------------------------
1 | database:
2 | type: s3
3 | endpoint: https://s3-us-west-2.amazonaws.com/
4 | bucket: "karl-mutch-rmq"
5 | compression: None
6 | authentication: none
7 |
8 | storage:
9 | type: s3
10 | endpoint: https://s3-us-west-2.amazonaws.com/
11 | bucket: "karl-mutch-rmq"
12 | compression: None
13 |
14 | runner:
15 | slack_destination: "@karl.mutch"
16 |
17 | cloud:
18 | queue:
19 | rmq: "amqp://guest:guest@localhost:5672/"
20 |
21 | verbose: debug
22 | saveWorkspaceFrequency: 3m
23 | experimentLifetime: 20m
24 |
25 | resources_needed:
26 | cpus: 1
27 | gpus: 1
28 | hdd: 3gb
29 | ram: 2gb
30 | gpuMem: 3gb
31 |
32 | env:
33 | AWS_ACCESS_KEY_ID: **removed**
34 | AWS_DEFAULT_REGION: us-west-2
35 | AWS_SECRET_ACCESS_KEY: **removed**
36 | PATH: "%PATH%:./bin"
37 |
38 | pip:
39 | - keras==2.0.8
40 |
41 | optimizer:
42 | cmaes_config:
43 | popsize: 100
44 | sigma0: 0.25
45 | load_best_only: false
46 | load_checkpoint_file:
47 | visualization: true
48 | result_dir: "~/Desktop/"
49 | checkpoint_interval: 0
50 | termination_criterion:
51 | generation: 5
52 | fitness: 999
53 | skip_gen_thres: 1.0
54 | skip_gen_timeout: 30
55 |
--------------------------------------------------------------------------------
/tests/git_util_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import tempfile
3 | import os
4 | import subprocess
5 | import uuid
6 | import re
7 | from studio import git_util
8 |
9 |
10 | class GitUtilTest(unittest.TestCase):
11 |
12 | def test_is_git(self):
13 | self.assertTrue(git_util.is_git())
14 |
15 | def test_is_not_git(self):
16 | self.assertFalse(git_util.is_git(tempfile.gettempdir()))
17 |
18 | def test_is_not_clean(self):
19 | filename = str(uuid.uuid4())
20 | subprocess.call(['touch', filename])
21 | is_clean = git_util.is_clean()
22 | os.remove(filename)
23 | self.assertFalse(is_clean)
24 |
25 | @unittest.skipIf(os.environ.get('TEST_GIT_REPO_ADDRESS') != '1',
26 | 'skip if being tested from a forked repo')
27 | def test_repo_url(self):
28 | expected = re.compile(
29 | r'https{0,1}://github\.com/studioml/studio(\.git){0,1}')
30 | expected2 = re.compile(
31 | r'git@github\.com:studioml/studio(\.git){0,1}')
32 | actual = git_util.get_repo_url(remove_user=True)
33 | self.assertTrue(
34 | (expected.match(actual) is not None) or
35 | (expected2.match(actual) is not None))
36 |
37 |
38 | if __name__ == "__main__":
39 | unittest.main()
40 |
--------------------------------------------------------------------------------
/studio/storage/storage_setup.py:
--------------------------------------------------------------------------------
1 | from studio.storage.storage_handler import StorageHandler
2 | from studio.storage.storage_type import StorageType
3 | from studio.util.logs import INFO
4 |
5 | DB_KEY = "database"
6 | STORE_KEY = "store"
7 |
8 | # Global dictionary which keeps Database Provider
9 | # and Artifact Store objects created from experiment configuration.
10 | _storage_setup = None
11 |
12 | _storage_verbose_level = INFO
13 |
14 | def setup_storage(db_provider, artifact_store):
15 | global _storage_setup
16 | _storage_setup = { DB_KEY: db_provider, STORE_KEY: artifact_store }
17 |
18 | def get_storage_db_provider():
19 | global _storage_setup
20 | if _storage_setup is None:
21 | return None
22 | return _storage_setup.get(DB_KEY, None)
23 |
24 | def get_storage_artifact_store():
25 | global _storage_setup
26 | if _storage_setup is None:
27 | return None
28 | return _storage_setup.get(STORE_KEY, None)
29 |
30 | def reset_storage():
31 | global _storage_setup
32 | _storage_setup = None
33 |
34 | def get_storage_verbose_level():
35 | global _storage_verbose_level
36 | return _storage_verbose_level
37 |
38 | def set_storage_verbose_level(level: int):
39 | global _storage_verbose_level
40 | _storage_verbose_level = level
41 |
42 |
43 |
--------------------------------------------------------------------------------
/studio/templates/projects.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 |
3 | {% block content %}
4 |
5 |
6 |
7 |
38 | {% endblock %}
39 |
--------------------------------------------------------------------------------
/examples/keras/train_mnist.py:
--------------------------------------------------------------------------------
1 |
2 | import tensorflow as tf
3 | import tensorflow_datasets as tfds
4 |
5 | (ds_train, ds_test), ds_info = tfds.load(
6 | 'mnist',
7 | split=['train', 'test'],
8 | shuffle_files=True,
9 | as_supervised=True,
10 | with_info=True,
11 | )
12 |
13 |
14 | def normalize_img(image, label):
15 | """Normalizes images: `uint8` -> `float32`."""
16 | return tf.cast(image, tf.float32) / 255., label
17 |
18 |
19 | ds_train = ds_train.map(
20 | normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
21 | ds_train = ds_train.cache()
22 | ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
23 | ds_train = ds_train.batch(128)
24 | ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)
25 |
26 | ds_test = ds_test.map(
27 | normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
28 | ds_test = ds_test.batch(128)
29 | ds_test = ds_test.cache()
30 | ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)
31 |
32 |
33 | model = tf.keras.models.Sequential([
34 | tf.keras.layers.Flatten(input_shape=(28, 28)),
35 | tf.keras.layers.Dense(128, activation='relu'),
36 | tf.keras.layers.Dense(10)
37 | ])
38 | model.compile(
39 | optimizer=tf.keras.optimizers.Adam(0.001),
40 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
41 | metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
42 | )
43 |
44 | model.fit(
45 | ds_train,
46 | epochs=6,
47 | validation_data=ds_test,
48 | )
49 |
50 | tf.saved_model.save(model, "../modeldir")
51 |
--------------------------------------------------------------------------------
/docs/local_filesystem_setup.rst:
--------------------------------------------------------------------------------
1 | Setting up experiment storage and database in local filesystem
2 | ==============================================================
3 |
4 | This page describes how to set up StudioML to use
5 | the local filesystem for storing experiment artifacts and meta-data.
6 | With this option, there is no need to set up any external
7 | connection to S3/Minio/GCS etc.
8 |
9 | StudioML configuration
10 | ------------------------
11 |
12 | ::
13 |
14 | "studio_ml_config": {
15 |
16 | ...
17 |
18 | "database": {
19 | "type": "local",
20 | "endpoint": SOME_DB_LOCAL_PATH,
21 | "bucket": DB_BUCKET_NAME,
22 | "authentication": "none"
23 | },
24 | "storage": {
25 | "type": "local",
26 | "endpoint": SOME_ARTIFACTS_LOCAL_PATH,
27 | "bucket": ARTIFACTS_BUCKET_NAME,
28 | }
29 |
30 | ...
31 | }
32 |
33 |
34 | With StudioML database type set to "local",
35 | all experiment meta-data will be stored locally under
36 | directory: SOME_DB_LOCAL_PATH/DB_BUCKET_NAME.
37 | Similarly, with storage type set to "local",
38 | all experiment artifacts will be stored locally under
39 | directory: SOME_ARTIFACTS_LOCAL_PATH/ARTIFACTS_BUCKET_NAME.
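
For instance, a minimal sketch of a fully local setup (the paths and bucket names below are purely hypothetical):

::

    "studio_ml_config": {
        "database": {
            "type": "local",
            "endpoint": "~/.studioml/db",
            "bucket": "studioml-meta",
            "authentication": "none"
        },
        "storage": {
            "type": "local",
            "endpoint": "~/.studioml/artifacts",
            "bucket": "studioml-artifacts"
        }
    }

With this sketch, experiment meta-data would land under ~/.studioml/db/studioml-meta
and artifacts under ~/.studioml/artifacts/studioml-artifacts.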
40 |
41 | Note: if you are using "local" mode, it is recommended to use it
42 | for both the storage and the database configuration,
43 | although it is technically possible to mix, for example, a local storage
44 | configuration with an S3-based database configuration.
45 |
46 |
--------------------------------------------------------------------------------
/studio/aws/aws_prices.yaml:
--------------------------------------------------------------------------------
1 | t2.nano : 0.0058
2 | t2.micro : 0.0116
3 | t2.small : 0.023
4 | t2.medium : 0.0464
5 | t2.large : 0.0928
6 | t2.xlarge : 0.1856
7 | t2.2xlarge : 0.3712
8 | m4.large : 0.1
9 | m4.xlarge : 0.2
10 | m4.2xlarge : 0.4
11 | m4.4xlarge : 0.8
12 | m4.10xlarge : 2
13 | m4.16xlarge : 3.2
14 | m3.medium : 0.067
15 | m3.large : 0.133
16 | m3.xlarge : 0.266
17 | m3.2xlarge : 0.532
18 | c5.large : 0.085
19 | c5.xlarge : 0.17
20 | c5.2xlarge : 0.34
21 | c5.4xlarge : 0.68
22 | c5.9xlarge : 1.53
23 | c5.18xlarge : 3.06
24 | c4.large : 0.1
25 | c4.xlarge : 0.199
26 | c4.2xlarge : 0.398
27 | c4.4xlarge : 0.796
28 | c4.8xlarge : 1.591
29 | c3.large : 0.105
30 | c3.xlarge : 0.21
31 | c3.2xlarge : 0.42
32 | c3.4xlarge : 0.84
33 | c3.8xlarge : 1.68
34 | p2.xlarge : 0.9
35 | p2.8xlarge : 7.2
36 | p2.16xlarge : 14.4
37 | p3.2xlarge : 3.06
38 | p3.8xlarge : 12.24
39 | p3.16xlarge : 24.48
40 | g2.2xlarge : 0.65
41 | g2.8xlarge : 2.6
42 | g3.4xlarge : 1.14
43 | g3.8xlarge : 2.28
44 | g3.16xlarge : 4.56
45 | f1.2xlarge : 1.65
46 | f1.16xlarge : 13.2
47 | x1.16xlarge : 6.669
48 | x1.32xlarge : 13.338
49 | x1e.32xlarge : 26.688
50 | r3.large : 0.166
51 | r3.xlarge : 0.333
52 | r3.2xlarge : 0.665
53 | r3.4xlarge : 1.33
54 | r3.8xlarge : 2.66
55 | r4.large : 0.133
56 | r4.xlarge : 0.266
57 | r4.2xlarge : 0.532
58 | r4.4xlarge : 1.064
59 | r4.8xlarge : 2.128
60 | r4.16xlarge : 4.256
61 | i3.large : 0.156
62 | i3.xlarge : 0.312
63 | i3.2xlarge : 0.624
64 | i3.4xlarge : 1.248
65 | i3.8xlarge : 2.496
66 | i3.16xlarge : 4.992
67 | d2.xlarge : 0.69
68 | d2.2xlarge : 1.38
69 | d2.4xlarge : 2.76
70 | d2.8xlarge : 5.52
71 |
--------------------------------------------------------------------------------
/studio/queues/queues_setup.py:
--------------------------------------------------------------------------------
1 | """Data providers."""
2 | import uuid
3 |
4 | from studio.queues.local_queue import LocalQueue
5 | from studio.queues.sqs_queue import SQSQueue
6 | from studio.queues.qclient_cache import get_cached_queue, shutdown_cached_queue
7 |
8 | def get_queue(
9 | queue_name=None,
10 | cloud=None,
11 | config=None,
12 | logger=None,
13 | close_after=None,
14 | verbose=10):
15 | _ = verbose
16 | if queue_name is None:
17 | if cloud in ['gcloud', 'gcspot']:
18 | queue_name = 'pubsub_' + str(uuid.uuid4())
19 | elif cloud in ['ec2', 'ec2spot']:
20 | queue_name = 'sqs_' + str(uuid.uuid4())
21 | else:
22 | queue_name = 'local_' + str(uuid.uuid4())
23 |
24 | if queue_name.startswith('ec2') or \
25 | queue_name.startswith('sqs'):
26 | return SQSQueue(queue_name, config=config, logger=logger)
27 | if queue_name.startswith('rmq_'):
28 | return get_cached_queue(
29 | name=queue_name,
30 | route='StudioML.' + queue_name,
31 | config=config,
32 | close_after=close_after,
33 | logger=logger)
34 | if queue_name.startswith('local'):
35 | return LocalQueue(queue_name, logger=logger)
36 | return None
37 |
38 | def shutdown_queue(queue, logger=None, delete_queue=True):
39 | if queue is None:
40 | return
41 | queue_name = queue.get_name()
42 | if queue_name.startswith("rmq_"):
43 | shutdown_cached_queue(queue, logger, delete_queue)
44 | else:
45 | queue.shutdown(delete_queue)
46 |
--------------------------------------------------------------------------------
/studio/templates/users.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 |
3 | {% block content %}
4 | Users
5 |
6 |
7 |
47 | {% endblock %}
48 |
--------------------------------------------------------------------------------
/studio/postgres_provider.py:
--------------------------------------------------------------------------------
1 | class PostgresProvider(object):
2 | """Data provider for Postgres."""
3 |
4 | def __init__(self, connection_uri):
5 | # TODO: implement connection
6 | pass
7 |
8 | def add_experiment(self, experiment):
9 | raise NotImplementedError()
10 |
11 | def delete_experiment(self, experiment):
12 | raise NotImplementedError()
13 |
14 | def start_experiment(self, experiment):
15 | raise NotImplementedError()
16 |
17 | def stop_experiment(self, experiment):
18 | raise NotImplementedError()
19 |
20 | def finish_experiment(self, experiment):
21 | raise NotImplementedError()
22 |
23 | def get_experiment(self, key):
24 | raise NotImplementedError()
25 |
26 | def get_user_experiments(self, user):
27 | raise NotImplementedError()
28 |
29 | def get_projects(self):
30 | raise NotImplementedError()
31 |
32 | def get_project_experiments(self):
33 | raise NotImplementedError()
34 |
35 | def get_artifacts(self):
36 | raise NotImplementedError()
37 |
38 | def get_artifact(self):
39 | raise NotImplementedError()
40 |
41 | def get_users(self):
42 | raise NotImplementedError()
43 |
44 | def checkpoint_experiment(self, experiment):
45 | raise NotImplementedError()
46 |
47 | def refresh_auth_token(self, email, refresh_token):
48 | raise NotImplementedError()
49 |
50 | def is_auth_expired(self):
51 | raise NotImplementedError()
52 |
53 | def can_write_experiment(self, key=None, user=None):
54 | raise NotImplementedError()
55 |
56 | def register_user(self, userid, email):
57 | raise NotImplementedError()
58 |
--------------------------------------------------------------------------------
/studio/optimizer_plugins/opt_util.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import sys
5 | import numpy as np
6 |
7 | try:
8 | import matplotlib
9 | matplotlib.use("Agg")
10 | import matplotlib.pyplot as plt
11 | except BaseException:
12 | pass
13 |
14 | EPSILON = 1e-12
15 |
16 |
17 | def scale_var(var, min_range, max_range):
18 | return (var - min_range) / max((max_range - min_range), EPSILON)
19 |
20 |
21 | def unscale_var(var, min_range, max_range):
22 | return (var * (max_range - min_range)) + min_range
23 |
24 |
25 | def visualize_fitness(fitness_file=None, best_fitnesses=None,
26 | mean_fitnesses=None, outfile="fitness.png"):
27 | if best_fitnesses is None or mean_fitnesses is None:
28 | assert os.path.exists(fitness_file)
29 | best_fitnesses = []
30 | mean_fitnesses = []
31 | with open(fitness_file) as f:
32 | for line in f.readlines():
33 | best_fit, mean_fit = [float(x) for x in line.rstrip().split()]
34 | best_fitnesses.append(best_fit)
35 | mean_fitnesses.append(mean_fit)
36 |
37 | plt.figure(figsize=(16, 12))
38 | plt.plot(np.arange(len(best_fitnesses)), best_fitnesses,
39 | label="Best Fitness")
40 | plt.plot(np.arange(len(mean_fitnesses)), mean_fitnesses,
41 | label="Mean Fitness")
42 | plt.xlabel("Generation")
43 | plt.ylabel("Fitness")
44 | plt.grid()
45 | plt.legend(loc='lower right')
46 |
47 | outfile = os.path.abspath(os.path.expanduser(outfile))
48 | plt.savefig(outfile, bbox_inches='tight')
49 |
50 |
51 | if __name__ == "__main__":
52 | func = eval(sys.argv[1])
53 | func(*sys.argv[2:])
54 |
--------------------------------------------------------------------------------
/studio/scripts/studio:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | function print_help {
5 | echo "List of allowed commands:"
6 | echo " "
7 | echo " studio ui -- launches WebUI"
8 | echo " studio run -- runs an experiment from a python script"
9 | echo " studio runs -- CLI to view and manipulate existing experiments"
10 | echo " studio start remote worker -- orchestrates (re)-starts of worker that listens to the remote queue in a docker container"
11 | echo " studio remote worker -- start a worker that listens to remote queue"
12 | echo " studio add credentials -- create a docker image with credentials baked in"
13 | echo " "
14 | echo "Type --help for a list of options for each command"
15 | }
16 |
17 | function parse_command {
18 |
19 | expected_command=$1
20 | shift
21 | cmd="studio"
22 | for w in $expected_command; do
23 | if [ "$w" = "$1" ]; then
24 | cmd="$cmd-$w"
25 | shift
26 | else
27 | echo "Unknown command $1"
28 | print_help
29 | exit 1
30 | fi
31 | done
32 |
33 | echo $cmd $*
34 | eval $cmd $*
35 | exit $?
36 | }
37 |
38 |
39 | case $1 in
40 | run)
41 | parse_command "run" $*
42 | ;;
43 |
44 | serve)
45 | parse_command "serve" $*
46 | ;;
47 |
48 | runs)
49 | parse_command "runs" $*
50 | ;;
51 |
52 | ui)
53 | parse_command "ui" $*
54 | ;;
55 |
56 | start)
57 | parse_command "start remote worker" $*
58 | ;;
59 |
60 | remote)
61 | parse_command "remote worker" $*
62 | ;;
63 |
64 | add)
65 | parse_command "add credentials" $*
66 | ;;
67 |
68 | esac
69 |
70 | echo "Unknown command $1"
71 | print_help
72 | exit 1
73 |
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/examples/keras/train_mnist_keras_mutligpu.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from keras.layers import Input, Dense
3 | from keras.models import Model
4 | from keras.datasets import mnist
5 | from keras.utils import to_categorical
6 | from keras.callbacks import ModelCheckpoint, TensorBoard
7 |
8 | from studio import fs_tracker
9 | from studio import multi_gpu
10 |
11 | # this placeholder will contain our input digits, as flat vectors
12 | img = Input((784,))
13 | # fully-connected layer with 128 units and ReLU activation
14 | x = Dense(128, activation='relu')(img)
15 | x = Dense(128, activation='relu')(x)
16 | # output layer with 10 units and a softmax activation
17 | preds = Dense(10, activation='softmax')(x)
18 |
19 |
20 | no_gpus = 2
21 | batch_size = 128
22 |
23 | model = Model(img, preds)
24 | model = multi_gpu.make_parallel(model, no_gpus)
25 | model.compile(loss='categorical_crossentropy', optimizer='adam')
26 |
27 | (x_train, y_train), (x_test, y_test) = mnist.load_data()
28 |
29 | x_train = x_train.reshape(60000, 784)
30 | x_test = x_test.reshape(10000, 784)
31 | x_train = x_train.astype('float32')
32 | x_test = x_test.astype('float32')
33 | x_train /= 255
34 | x_test /= 255
35 |
36 | # convert class vectors to binary class matrices
37 | y_train = to_categorical(y_train, 10)
38 | y_test = to_categorical(y_test, 10)
39 |
40 |
41 | checkpointer = ModelCheckpoint(
42 | fs_tracker.get_model_directory() +
43 | '/checkpoint.{epoch:02d}-{val_loss:.2f}.hdf')
44 |
45 |
46 | tbcallback = TensorBoard(log_dir=fs_tracker.get_tensorboard_dir(),
47 | histogram_freq=0,
48 | write_graph=True,
49 | write_images=False)
50 |
51 |
52 | model.fit(
53 | x_train,
54 | y_train,
55 | validation_data=(x_test, y_test),
56 | epochs=int(sys.argv[1]),
57 | batch_size=batch_size * no_gpus,
58 | callbacks=[checkpointer, tbcallback])
59 |
--------------------------------------------------------------------------------
/docs/customenv.rst:
--------------------------------------------------------------------------------
1 | Custom environments
2 | ===================
3 |
4 | Using custom environment variables at runtime
5 | ---------------------------------------------
6 |
7 | You can add an ``env`` section to your yaml configuration file to pass environment variables into the runner's environment. A value can be prefixed with a ``$`` sign if you wish to substitute a local environment variable into your run configuration. Be aware that all values are stored in clear text: if you need to exchange secrets, encrypt them before putting them into the configuration file and decrypt them within the Python code of your experiment.
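
A minimal sketch of such an ``env`` section (the variable names below are hypothetical; the second value is substituted from the local environment because of the ``$`` prefix):

::

    env:
        DATA_DIR: /data/mnist
        AWS_DEFAULT_REGION: $AWS_DEFAULT_REGION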
8 |
9 |
10 | Customization of python environment for the workers
11 | ---------------------------------------------------
12 |
13 | Sometimes your experiment relies on an older / custom version of some
14 | python package. For example, the Keras API has changed quite a bit between
15 | versions 1 and 2. What if you are using a new environment locally, but
16 | would like to re-run old experiments that needed older version of
17 | packages? Or, for example, you'd like to see if your code would work
18 | with the latest version of a package. Studio gives you this
19 | opportunity.
20 |
21 | ::
22 |
23 | studio run --python-pkg=<package_name>==<package_version>
24 |
25 | allows you to run your script on a remote / cloud worker with a
26 | specific version of a package. You can also omit ``==<package_version>``
27 | to install the latest version of the package (which may not be
28 | equal to the version in your environment). Note that if a package with a
29 | custom version has dependencies conflicting with the current version, the situation
30 | gets tricky. For now, it is up to pip to resolve conflicts. In some
31 | cases it may fail and you'll have to manually specify dependency
32 | versions by adding more ``--python-pkg`` arguments.
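
For example, a hypothetical invocation that pins an older Keras version for a single run (``my_script.py`` is a placeholder) might look like:

::

    studio run --python-pkg=keras==1.2.2 my_script.py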
33 |
--------------------------------------------------------------------------------
/tests/hyperparam_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 |
4 | from studio.hyperparameter import HyperparameterParser, Hyperparameter
5 | from studio import logs
6 |
7 |
8 | class RunnerArgs(object):
9 | def __init__(self):
10 | self.optimizer = "grid"
11 | self.verbose = False
12 |
13 |
14 | class HyperparamTest(unittest.TestCase):
15 | def test_parse_range(self):
16 | logger = logs.get_logger('test_stop_experiment')
17 | h = HyperparameterParser(RunnerArgs(), logger)
18 | range_strs = ['1,2,3', ':5', '2:5', '0.1:0.05:0.3', '0.1:3:0.3',
19 | '0.01:4l:10']
20 | gd_truths = [
21 | [1.0, 2.0, 3.0],
22 | [0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
23 | [2.0, 3.0, 4.0, 5.0],
24 | [0.1, 0.15, 0.2, 0.25, 0.3],
25 | [0.1, 0.2, 0.3],
26 | [0.01, 0.1, 1, 10],
27 | ]
28 |
29 | for range_str, gd_truth in zip(range_strs, gd_truths):
30 | hyperparameter = h._parse_grid("test", range_str)
31 | self.assertTrue(np.isclose(hyperparameter.values, gd_truth).all())
32 |
33 | def test_unfold_tuples(self):
34 | logger = logs.get_logger('test_stop_experiment')
35 | h = HyperparameterParser(RunnerArgs(), logger)
36 |
37 | hyperparams = [Hyperparameter(name='a', values=[1, 2, 3]),
38 | Hyperparameter(name='b', values=[4, 5])]
39 |
40 | expected_tuples = [
41 | {'a': 1, 'b': 4}, {'a': 2, 'b': 4}, {'a': 3, 'b': 4},
42 | {'a': 1, 'b': 5}, {'a': 2, 'b': 5}, {'a': 3, 'b': 5}]
43 |
44 | self.assertEqual(
45 | sorted(h.convert_to_tuples(hyperparams), key=lambda x: str(x)),
46 | sorted(expected_tuples, key=lambda x: str(x)))
47 |
48 |
49 | if __name__ == '__main__':
50 | unittest.main()
51 |
--------------------------------------------------------------------------------
/examples/keras/multi_gpu.py:
--------------------------------------------------------------------------------
1 | from keras.layers import merge
2 | from keras.layers.core import Lambda
3 | from keras.models import Model
4 |
5 | import tensorflow as tf
6 |
7 |
8 | def make_parallel(model, gpu_count):
9 | def get_slice(data, idx, parts):
10 | shape = tf.shape(data)
11 | size = tf.concat([shape[:1] // parts, shape[1:]], 0)
12 | stride = tf.concat([shape[:1] // parts, shape[1:] * 0], 0)
13 | start = stride * idx
14 | return tf.slice(data, start, size)
15 |
16 | outputs_all = []
17 | for i in range(len(model.outputs)):
18 | outputs_all.append([])
19 |
20 | # Place a copy of the model on each GPU, each getting a slice of the batch
21 | for i in range(gpu_count):
22 | with tf.device('/gpu:%d' % i):
23 | with tf.name_scope('tower_%d' % i):
24 |
25 | inputs = []
26 | # Slice each input into a piece for processing on this GPU
27 | for x in model.inputs:
28 | input_shape = tuple(x.get_shape().as_list())[1:]
29 | slice_n = Lambda(
30 | get_slice, output_shape=input_shape, arguments={
31 | 'idx': i, 'parts': gpu_count})(x)
32 | inputs.append(slice_n)
33 |
34 | outputs = model(inputs)
35 |
36 | if not isinstance(outputs, list):
37 | outputs = [outputs]
38 |
39 | # Save all the outputs for merging back together later
40 | for l in range(len(outputs)):
41 | outputs_all[l].append(outputs[l])
42 |
43 | # merge outputs on CPU
44 | with tf.device('/cpu:0'):
45 | merged = []
46 | for outputs in outputs_all:
47 | merged.append(merge(outputs, mode='concat', concat_axis=0))
48 |
49 | return Model(input=model.inputs, output=merged)
50 |
--------------------------------------------------------------------------------
/studio/queues/qclient_cache.py:
--------------------------------------------------------------------------------
1 | import threading
2 |
3 | from studio.queues.rabbit_queue import RMQueue
4 | from studio.util.util import check_for_kb_interrupt
5 |
6 | _queue_cache = {}
7 |
8 | def get_cached_queue(
9 | name,
10 | route,
11 | config=None,
12 | logger=None,
13 | close_after=None):
14 |
15 | queue = _queue_cache.get(name, None)
16 | if queue is not None:
17 | if logger is not None:
18 | logger.info("Got queue named %s from queue cache.", name)
19 | return queue
20 |
21 | queue = RMQueue(
22 | queue=name,
23 | route=route,
24 | config=config,
25 | logger=logger)
26 |
27 | if logger is not None:
28 | logger.info("Created new queue named %s.", name)
29 |
30 | if close_after is not None and close_after.total_seconds() > 0:
31 | thr = threading.Timer(
32 | interval=close_after.total_seconds(),
33 | function=purge_rmq,
34 | kwargs={
35 | "q": queue,
36 | "logger": logger})
37 | thr.daemon = True
38 | thr.start()
39 |
40 | _queue_cache[name] = queue
41 | if logger is not None:
42 | logger.info("Added queue named %s to queue cache.", name)
43 | return queue
44 |
45 | def shutdown_cached_queue(queue, logger=None, delete_queue=True):
46 | if queue is None:
47 | return
48 |
49 | _queue_cache.pop(queue.get_name(), None)
50 | if logger is not None:
51 | logger.info("Removed queue named %s from queue cache.",
52 | queue.get_name())
53 |
54 | queue.shutdown(delete_queue)
55 |
56 |
57 | def purge_rmq(queue, logger):
58 | if queue is None:
59 | return
60 |
61 | try:
62 | queue.shutdown(True)
63 | except BaseException as exc:
64 | check_for_kb_interrupt()
65 | logger.warning(exc)
66 | return
67 | return
68 |
--------------------------------------------------------------------------------
/studio/serve.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | from studio import runner
4 |
5 |
6 | def main():
7 | argparser = argparse.ArgumentParser()
8 | argparser.add_argument(
9 | '--wrapper', '-w',
10 | help='python script with function create_model ' +
11 | 'that takes modeldir '
12 | '(that is, directory where experiment saves ' +
13 | 'the checkpoints etc) ' +
14 | 'and returns dict -> dict function (model). ' +
15 | 'By default, studio-serve will try to determine ' +
16 | 'this function automatically.',
17 | default=None
18 | )
19 |
20 | argparser.add_argument('--port',
21 | help='port to run Flask server on',
22 | type=int,
23 | default=5000)
24 |
25 | argparser.add_argument('--host',
26 | help='host name.',
27 | default='0.0.0.0')
28 |
29 | argparser.add_argument(
30 | '--killafter',
31 | help='Shut down after this many seconds of inactivity',
32 | default=3600)
33 |
34 | options, other_args = argparser.parse_known_args(sys.argv[1:])
35 | serve_args = ['studio::serve_main']
36 |
37 | assert len(other_args) >= 1
38 | experiment_key = other_args[-1]
39 | runner_args = other_args[:-1]
40 | runner_args.append('--reuse={}/modeldir:modeldata'.format(experiment_key))
41 | runner_args.append('--force-git')
42 | runner_args.append('--port=' + str(options.port))
43 |
44 | if options.wrapper:
45 | serve_args.append('--wrapper=' + options.wrapper)
46 | serve_args.append('--port=' + str(options.port))
47 |
48 | serve_args.append('--host=' + options.host)
49 | serve_args.append('--killafter=' + str(options.killafter))
50 |
51 | total_args = runner_args + serve_args
52 | runner.main(total_args)
53 |
54 |
55 | if __name__ == '__main__':
56 | main()
57 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
91 | studio/.DS_Store
92 | .DS_Store
93 |
94 | auth/.firebaserc
95 | auth/firebase.json
96 |
97 | # Ignore directories used by the packaging tools,
98 | # the go runner, and google tooling
99 | #
100 | bin/
101 | include/
102 | local/
103 | pkg/
104 | src/
105 | pip-selfcheck.json
106 |
107 | # ctags
108 | .tags
109 | .tags_sorted_by_file
110 |
111 | #version - for deployment from travis
112 | .version
113 | examples/.envrc
114 |
--------------------------------------------------------------------------------
/studio/dependencies_policies/studio_dependencies_policy.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | try:
4 | try:
5 | from pip._internal.operations import freeze
6 | except Exception:
7 | from pip.operations import freeze
8 | except ImportError:
9 | freeze = None
10 |
11 | from studio.dependencies_policies.dependencies_policy import DependencyPolicy
12 |
13 |
14 | class StudioDependencyPolicy(DependencyPolicy):
15 | """
16 | StudioML policy for adjusting experiment dependencies
17 | to specific execution environment and required resources.
18 | """
19 |
20 | def generate(self, resources_needed):
21 | if freeze is None:
22 | raise ValueError(
23 | "freeze operation is not available for StudioDependencyPolicy")
24 |
25 | needs_gpu = self._needs_gpu(resources_needed)
26 | packages = freeze.freeze()
27 | result = []
28 | for pkg in packages:
29 | if self._is_special_reference(pkg):
30 | # git repo or directly installed package
31 | result.append(pkg)
32 | elif '==' in pkg:
33 | # pypi package
34 | pkey = re.search(r'^.*?(?=\=\=)', pkg).group(0)
35 | pversion = re.search(r'(?<=\=\=).*\Z', pkg).group(0)
36 |
37 | if needs_gpu and \
38 | (pkey == 'tensorflow' or pkey == 'tf-nightly'):
39 | pkey = pkey + '-gpu'
40 |
41 | # TODO add installation logic for torch
42 | result.append(pkey + '==' + pversion)
43 | return result
44 |
45 | def _is_special_reference(self, pkg: str):
46 | if pkg.startswith('-e git+'):
47 | return True
48 | if 'git+https://' in pkg or 'file://' in pkg:
49 | return True
50 | return False
51 |
52 | def _needs_gpu(self, resources_needed):
53 | return resources_needed is not None and \
54 | int(resources_needed.get('gpus')) > 0
55 |
--------------------------------------------------------------------------------
/studio/firebase_provider.py:
--------------------------------------------------------------------------------
1 | from studio.db_providers.keyvalue_provider import KeyValueProvider
2 | from .firebase_storage_handler import FirebaseStorageHandler
3 |
4 | class FirebaseProvider(KeyValueProvider):
5 |
6 | def __init__(self, db_config, blocking_auth=True):
7 | self.meta_store = FirebaseStorageHandler(db_config)
8 |
9 | super().__init__(
10 | db_config,
11 | self.meta_store,
12 | blocking_auth)
13 |
14 | def _get(self, key, shallow=False):
15 | try:
16 | splitKey = key.split('/')
17 | key_path = '/'.join(splitKey[:-1])
18 | key_name = splitKey[-1]
19 | dbobj = self.app.database().child(key_path).child(key_name)
20 | return dbobj.get(self.auth.get_token(), shallow=shallow).val() \
21 | if self.auth else dbobj.get(shallow=shallow).val()
22 | except Exception as err:
23 | self.logger.warn(("Getting key {} from a database " +
24 | "raised an exception: {}").format(key, err))
25 | return None
26 |
27 | def _set(self, key, value):
28 | try:
29 | splitKey = key.split('/')
30 | key_path = '/'.join(splitKey[:-1])
31 | key_name = splitKey[-1]
32 | dbobj = self.app.database().child(key_path)
33 | if self.auth:
34 | dbobj.update({key_name: value}, self.auth.get_token())
35 | else:
36 | dbobj.update({key_name: value})
37 | except Exception as err:
38 | self.logger.warn(("Putting key {}, value {} into a database " +
39 | "raised an exception: {}")
40 | .format(key, value, err))
41 |
42 | def _delete(self, key, shallow=True, token=None):
43 | dbobj = self.app.database().child(key)
44 |
45 | if self.auth:
46 | dbobj.remove(self.auth.get_token())
47 | else:
48 | dbobj.remove()
49 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | =========
2 | Studio.ml
3 | =========
4 |
5 | `Github <https://github.com/studioml/studio>`_ |
6 | `pip <https://pypi.org/project/studioml/>`_ |
7 |
8 |
9 | Studio is a model management framework written in Python to help simplify and expedite your model building experience. It was developed to minimize the overhead involved with scheduling, running, monitoring and managing artifacts of your machine learning experiments. No one wants to spend their time configuring different machines, setting up dependencies, or playing archeologist to track down previous model artifacts.
10 |
11 | Most of the features are compatible with any Python machine learning framework (`Keras <https://keras.io>`__, `TensorFlow <https://www.tensorflow.org>`__, `PyTorch <https://pytorch.org>`__, `scikit-learn <https://scikit-learn.org>`__, etc.) without intruding on your code; some extra features are available for Keras and TensorFlow.
12 |
13 | **Use Studio to:**
14 |
15 | - Capture experiment information (Python environment, files, dependencies, and logs) without modifying the experiment code. Monitor and organize experiments using a web dashboard that integrates with TensorBoard.
16 | - Run experiments locally, remotely, or in the cloud (Google Cloud or Amazon EC2)
17 | - Manage artifacts
18 | - Perform hyperparameter search
19 | - Create customizable Python environments for remote workers.
20 |
21 |
22 | .. toctree::
23 | :hidden:
24 | :caption: Introduction
25 |
26 | Getting Started
27 | installation
28 | authentication
29 | cli
30 |
31 | .. toctree::
32 | :hidden:
33 | :caption: Main Documentation
34 |
35 | artifacts
36 | hyperparams
37 | model_pipelines
38 | setup_database
39 |
40 | .. toctree::
41 | :hidden:
42 | :caption: Remote computing
43 |
44 | remote_worker
45 | customenv
46 |
47 | .. toctree::
48 | :hidden:
49 | :caption: Cloud computing
50 |
51 | cloud
52 | ec2_setup
53 | gcloud_setup
54 |
55 |
--------------------------------------------------------------------------------
/tests/util_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from studio import util
3 | from random import randint
4 |
5 |
6 | class UtilTest(unittest.TestCase):
7 | def test_remove_backspaces(self):
8 | testline = 'abcd\x08\x08\x08efg\x08\x08hi\x08'
9 | removed = util.remove_backspaces(testline)
10 | self.assertTrue(removed == 'aeh')
11 |
12 | testline = 'abcd\x08\x08\x08efg\x08\x08hi'
13 | removed = util.remove_backspaces(testline)
14 | self.assertTrue(removed == 'aehi')
15 |
16 | testline = 'abcd'
17 | removed = util.remove_backspaces(testline)
18 | self.assertTrue(removed == 'abcd')
19 |
20 | testline = 'abcd\n\ndef'
21 | removed = util.remove_backspaces(testline)
22 | self.assertTrue(removed == testline)
23 |
24 | def test_retry(self):
25 | attempts = [0]
26 | value = randint(0, 1000)
27 |
28 | def failing_func():
29 | attempts[0] += 1
30 | if attempts[0] != 2:
31 | raise ValueError('Attempt {} failed'
32 | .format(attempts[0]))
33 | return value
34 |
35 | retval = util.retry(failing_func,
36 | no_retries=2,
37 | sleep_time=1,
38 | exception_class=ValueError)
39 |
40 | self.assertEqual(retval, value)
41 | self.assertEqual(attempts, [2])
42 |
43 | # test out for catching different exception class
44 | try:
45 | retval = util.retry(failing_func,
46 | no_retries=2,
47 | sleep_time=1,
48 | exception_class=OSError)
49 | except ValueError:
50 | pass
51 | else:
52 | self.assertTrue(False)
53 |
54 | def test_str2duration(self):
55 | self.assertEqual(
56 | int(util.str2duration('30m').total_seconds()),
57 | 1800)
58 |
59 |
60 | if __name__ == "__main__":
61 | unittest.main()
62 |
--------------------------------------------------------------------------------
/docs/jupyter.rst:
--------------------------------------------------------------------------------
1 | Jupyter / ipython notebooks
2 | ===========================
3 |
4 | Studio can be used not only with scripts, but also with
5 | Jupyter notebooks. The main idea is as follows:
6 | a cell annotated with a special cell magic is treated
7 | as a separate script, and the variables are passed in and
8 | out as artifacts (this means that all variables the cell
9 | depends on have to be pickleable). The script can then be run
10 | either locally (in which case the main benefit of studio
11 | is keeping track of all runs of the cell), or in the cloud / remotely.
12 |
13 | To use Studio in your notebook, add
14 |
15 | ::
16 |
17 | from studio import magics
18 |
19 | to the import section of your notebook.
20 |
21 | Then annotate the cell that you'd like to run via studio with
22 |
23 | ::
24 |
25 | %%studio_run
26 |
27 | This will execute the statements in the cell using studio,
28 | also passing any additional arguments to the runner.
29 | For example, let's imagine that a variable ``x`` is declared in
30 | your notebook. Then
31 |
32 | ::
33 |
34 | %%studio_run --cloud=gcloud
35 | x += 1
36 |
37 | will do the increment of the variable ``x`` in the notebook namespace
38 | using a google cloud compute
39 | instance (given that incrementing a variable in python takes well under a millisecond,
40 | spinning up an entire instance to do that is probably the most wasteful thing you
41 | have seen today, but you get the idea :). The ``%%studio_run`` cell magic
42 | accepts the same arguments as the ``studio run`` command, please refer to
43 | the cloud computing documentation for a more involved discussion of cloud and hardware selection options.
44 |
45 | Every run with studio will get a unique key and can be viewed as an experiment in
46 | studio ui.
47 |
48 | The only limitation to using studio in a notebook is that variables being used
49 | in a studio-run cell have to be pickleable. That means that, for example, you
50 | cannot use lambda functions defined elsewhere, because those are not
51 | pickleable.
52 |
53 |
54 |
--------------------------------------------------------------------------------
/docs/containers.rst:
--------------------------------------------------------------------------------
1 | =========================
2 | Containerized experiments
3 | =========================
4 |
5 | Some experiments may require more than just a specific python environment to be run reproducibly. For instance, the 2017 NIPS running
6 | competition relied on a specific set of system-level packages for walker physics simulations. To address such experiments, Studio.ML
7 | supports execution in containers by using Singularity (https://singularity.lbl.gov). Singularity supports both Docker and its own format
8 | of containers. Containers can be used in two main ways:
9 |
10 | 1. Running experiment using container environment
11 | -------------------------------------------------
12 | In this mode, an environment is set up within the container, but the python code is outside. Studio.ML, with the help of Singularity,
13 | mounts a copy of the current directory and the artifacts into the container and executes the script. A typical command line looks like
14 |
15 | ::
16 |
17 | studio run --container=/path/to/container.simg script.py args
18 |
19 |
20 | Note that if your script is using Studio.ML library functions (such as `fs_tracker.get_artifact()`), Studio.ML will need to be
21 | installed within the container.
22 |
23 | 2. Running experiment using executable container
24 | ------------------------------------------------
25 | Both singularity and docker support executable containers. A Studio.ML experiment can consist solely of an executable container:
26 |
27 | ::
28 |
29 | studio run --container=/path/to/container.simg
30 |
31 | In this case, the code does not even need to be python, but all Studio.ML perks (such as cloud execution with hardware selection,
32 | keeping track of inputs and outputs of the experiment, etc.) still apply. Artifact management also works: artifacts will be
33 | visible in the container in the folder one level up from the working directory.
34 |
35 | Containers can be located either locally as `*.simg` files, or in the Singularity/Docker hub. In the latter case, provide a link that
36 | starts with `shub://` or `dockerhub://`.
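
For instance, a hypothetical hub-hosted container might be referenced as:

::

    studio run --container=shub://some-user/some-container script.py args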
37 |
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/studio/db_providers/local_db_provider.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | from studio.db_providers.keyvalue_provider import KeyValueProvider
5 | from studio.storage.storage_handler_factory import StorageHandlerFactory
6 | from studio.storage.storage_type import StorageType
7 | from studio.util import util
8 |
9 | class LocalDbProvider(KeyValueProvider):
10 |
11 | def __init__(self, config, blocking_auth=True):
12 | self.config = config
13 | self.bucket = config.get('bucket', 'studioml-meta')
14 |
15 | factory: StorageHandlerFactory = StorageHandlerFactory.get_factory()
16 | self.meta_store = factory.get_handler(StorageType.storageLocal, config)
17 |
18 | self.endpoint = self.meta_store.get_endpoint()
19 | self.db_root = os.path.join(self.endpoint, self.bucket)
20 | self._ensure_path_dirs_exist(self.db_root)
21 |
22 | super().__init__(config, self.meta_store, blocking_auth)
23 |
24 | def _ensure_path_dirs_exist(self, path):
25 | dirs = os.path.dirname(path)
26 | os.makedirs(dirs, mode=0o777, exist_ok=True)
27 |
28 | def _get(self, key, shallow=False):
29 | file_name = os.path.join(self.db_root, key)
30 | if not os.path.exists(file_name):
31 | return None
32 | try:
33 | with open(file_name) as infile:
34 | result = json.load(infile)
35 | except BaseException as exc:
36 | self.logger.error("FAILED to load file %s - %s", file_name, exc)
37 | result = None
38 | return result
39 |
40 | def _delete(self, key, shallow=True):
41 | file_name = os.path.join(self.db_root, key)
42 | if os.path.exists(file_name):
43 | self.logger.debug("Deleting local database file %s.", file_name)
44 | util.delete_local_path(file_name, self.db_root, shallow)
45 |
46 | def _set(self, key, value):
47 | file_name = os.path.join(self.db_root, key)
48 | self._ensure_path_dirs_exist(file_name)
49 | with open(file_name, 'w') as outfile:
50 | json.dump(value, outfile)
51 |
--------------------------------------------------------------------------------
/tests/model_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import uuid
3 | try:
4 | try:
5 | from pip._internal.operations import freeze
6 | except Exception:
7 | from pip.operations import freeze
8 | except ImportError:
9 | freeze = None
10 | import os
11 |
12 | from studio import model
13 | from studio.experiments.experiment import create_experiment
14 | from studio.dependencies_policies.studio_dependencies_policy import StudioDependencyPolicy
15 |
16 |
17 | def get_test_experiment():
18 | filename = 'test.py'
19 | args = ['a', 'b', 'c']
20 | experiment_name = 'test_experiment_' + str(uuid.uuid4())
21 | experiment = create_experiment(filename, args, experiment_name, dependency_policy=StudioDependencyPolicy())
22 | return experiment, experiment_name, filename, args
23 |
24 |
25 | class ModelTest(unittest.TestCase):
26 | def test_create_experiment(self):
27 | _, experiment_name, filename, args = get_test_experiment()
28 | experiment_project = 'create_experiment_project'
29 | experiment = create_experiment(
30 | filename, args, experiment_name, experiment_project, dependency_policy=StudioDependencyPolicy())
31 |
32 | packages = [p for p in freeze.freeze()]
33 |
34 | self.assertTrue(experiment.key == experiment_name)
35 | self.assertTrue(experiment.filename == filename)
36 | self.assertTrue(experiment.args == args)
37 | self.assertTrue(experiment.project == experiment_project)
38 | self.assertTrue(sorted(experiment.pythonenv) == sorted(packages))
39 |
40 | def test_get_config_env(self):
41 | value1 = str(uuid.uuid4())
42 | os.environ['TEST_VAR1'] = value1
43 | value2 = str(uuid.uuid4())
44 | os.environ['TEST_VAR2'] = value2
45 |
46 | config = model.get_config(
47 | os.path.join(os.path.dirname(os.path.realpath(__file__)),
48 | 'test_config_env.yaml'))
49 | self.assertEqual(config['test_key'], value1)
50 | self.assertEqual(config['test_section']['test_key'], value2)
51 |
52 |
53 | if __name__ == "__main__":
54 | unittest.main()
55 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | dist: trusty
2 |
3 | language: python
4 | python:
5 | - "2.7"
6 | - "3.6"
7 |
8 | env:
9 | fast_finish: true
10 |
11 | script:
12 | - git fetch --unshallow
13 | - 'if [ -n "$GOOGLE_CREDENTIALS_ENCODED" ]; then echo "$GOOGLE_CREDENTIALS_ENCODED" | base64 --decode > ./credentials.json && export GOOGLE_APPLICATION_CREDENTIALS=$(pwd)/credentials.json; fi'
14 | # - 'if [ -n "$GOOGLE_CREDENTIALS_DS_ENCODED" ]; then echo "$GOOGLE_CREDENTIALS_DS_ENCODED" | base64 --decode > ./credentials_ds.json && export GOOGLE_APPLICATION_CREDENTIALS_DC=$(pwd)/credentials_ds.json; fi'
15 | - pip install -e .
16 | - pip install -r test_requirements.txt
17 |
18 | - pycodestyle --show-source --exclude=.eggs --ignore=W605,W504,E704 .
19 |
20 | - echo GOOGLE_APPLICATION_CREDENTIALS=$GOOGLE_APPLICATION_CREDENTIALS
21 | - echo GOOGLE_APPLICATION_CREDENTIALS_DC=$GOOGLE_APPLICATION_CREDENTIALS_DC
22 | - pip freeze
23 | - pytest --collect-only -v
24 | - pytest -n 4 -v --durations=0 -l
25 |
26 | jobs:
27 | include:
28 | - stage: deploy
29 | python: 2.7
30 | script:
31 | - git fetch --unshallow
32 | - python setup.py sdist
33 |
34 | deploy:
35 | provider: pypi
36 | distributions: "sdist"
37 | user: pzhokhov
38 | password:
39 | secure: "FkfIvyF3PReaXOuCNSeqlg9SAiocc8WxCzlbqnfssG6JxsIBWNuoOPtgKtiYmMO+LBXnzpZ8CkTjoHT2oWWBD8r92QNmdIOze1MPs5xQmGxu12TEc5dEgMUnvenV3vkOZGC/hhwbio16Dqfd8PlHYBRkduvvqrvRD1xJPsggx9tgZwZ+Vvv0h51/BpOIMxm0xV3qu1U5A3BIvyzTUucvCAPbmHeZj+tWtiTq3OEwpc812PsCNkFNEOIBUkwafu1VE15tFcud/pZyFYwEK8/z35CJDGY4oVKWZyoC5/Gp9w658ps6MnjWP36Y1GNj4wpQPL2ftvGzRfXvfZ/FcSi4mreon24YPYFlY76ezLibWq8m2/ZBlQQwchFi48nGoalDJ7a92hyVkrRr1UP87+fbrZW8u4l/2HiKYKgotAixnoKKVOaWcmhraLve+kLE5F6e0lwmwCOHQwO9Cz4PECiZtxf8ePdDj+Dr9RLI3rKehRATpL+kXkR4WzKXPJAfDiPS6WiOz/melZJIttQcFZpOUVCs79yKWj8E95atbEy0jFaGn1PR+lFHqFYsg8TTe9Q9L/qMc+ECzBzQ4ryClCKDMYylCqbfDpOmCtHZw0a79IKmSyyvNbjnFMhAcn+GHFp+O46iF7Zxb1fAmp45AZ48bG/FCVqp66EXr590Fy1gjh0="
40 | skip_cleanup: true
41 | on:
42 | # tags: true
43 | branch: master
44 | repo: studioml/studio
45 | condition: $TRAVIS_EVENT_TYPE == "push"
46 |
--------------------------------------------------------------------------------
/examples/tensorflow/train_mnist.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import sys
3 | import os
4 | import time
5 |
6 | from keras.objectives import categorical_crossentropy
7 | from keras.layers import Dense
8 | from tensorflow.examples.tutorials.mnist import input_data
9 |
10 | from keras import backend as K
11 | from studio import fs_tracker
12 |
13 | import logging
14 |
15 | logging.basicConfig()
16 |
17 | sess = tf.Session()
18 | K.set_session(sess)
19 | # Now let's get started with our MNIST model. We can start building a
20 | # classifier exactly as you would do in TensorFlow:
21 |
22 | # this placeholder will contain our input digits, as flat vectors
23 | img = tf.placeholder(tf.float32, shape=(None, 784))
24 | # We can then use Keras layers to speed up the model definition process:
25 |
26 |
27 | # Keras layers can be called on TensorFlow tensors:
28 | # fully-connected layer with 128 units and ReLU activation
29 | x = Dense(128, activation='relu')(img)
30 | x = Dense(128, activation='relu')(x)
31 | # output layer with 10 units and a softmax activation
32 | preds = Dense(10, activation='softmax')(x)
33 | # We define the placeholder for the labels, and the loss function we will use:
34 |
35 | labels = tf.placeholder(tf.float32, shape=(None, 10))
36 |
37 | loss = tf.reduce_mean(categorical_crossentropy(labels, preds))
38 | # Let's train the model with a TensorFlow optimizer:
39 |
40 | mnist_data = input_data.read_data_sets('MNIST_data', one_hot=True)
41 |
42 | global_step = tf.Variable(0, name='global_step', trainable=False)
43 | train_step = tf.train.GradientDescentOptimizer(
44 | 0.5).minimize(loss, global_step=global_step)
45 | # Initialize all variables
46 | init_op = tf.global_variables_initializer()
47 | saver = tf.train.Saver()
48 | sess.run(init_op)
49 |
50 |
51 | logger = logging.getLogger('train_mnist')
52 | logger.setLevel(10)
53 | # Run training loop
54 | with sess.as_default():
55 | while True:
56 | batch = mnist_data.train.next_batch(50)
57 | train_step.run(feed_dict={img: batch[0],
58 | labels: batch[1]})
59 |
60 | sys.stdout.flush()
61 | saver.save(
62 | sess,
63 | os.path.join(
64 | fs_tracker.get_model_directory(),
65 | "ckpt"),
66 | global_step=global_step)
67 | time.sleep(1)
68 |
--------------------------------------------------------------------------------
/examples/keras/fashion_mnist.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import urllib
3 |
4 | from keras.layers import Dense, Flatten
5 |
6 | from keras.models import Sequential
7 |
8 | ###
9 | # AS OF 10/18/2017, fashion_mnist available as a part of github master
10 | # branch of keras
11 | # but not a part of pypi package
12 | # Therefore, to use this, you'll need keras installed from a git repo:
13 | # git clone https://github.com/fchollet/keras && cd keras && pip install .
14 | ###
15 | from keras.datasets import fashion_mnist
16 | from keras.utils import to_categorical
17 |
18 | from keras.callbacks import ModelCheckpoint, TensorBoard
19 | from keras import optimizers
20 |
21 | import numpy as np
22 | from PIL import Image
23 | from io import BytesIO
24 |
25 | from studio import fs_tracker, model_util
26 |
27 | (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
28 |
29 | x_train = x_train.reshape(60000, 28, 28, 1)
30 | x_test = x_test.reshape(10000, 28, 28, 1)
31 | x_train = x_train.astype('float32')
32 | x_test = x_test.astype('float32')
33 | x_train /= 255
34 | x_test /= 255
35 |
36 | # convert class vectors to binary class matrices
37 | y_train = to_categorical(y_train, 10)
38 | y_test = to_categorical(y_test, 10)
39 |
40 |
41 | model = Sequential()
42 |
43 | model.add(Flatten(input_shape=(28, 28, 1)))
44 | model.add(Dense(128, activation='relu'))
45 | model.add(Dense(128, activation='relu'))
46 |
47 | model.add(Dense(10, activation='softmax'))
48 | model.summary()
49 |
50 |
51 | batch_size = 128
52 | no_epochs = int(sys.argv[1]) if len(sys.argv) > 1 else 10
53 | lr = 0.01
54 |
55 | print('learning rate = {}'.format(lr))
56 | print('batch size = {}'.format(batch_size))
57 | print('no_epochs = {}'.format(no_epochs))
58 |
59 | model.compile(loss='categorical_crossentropy', optimizer=optimizers.SGD(lr=lr),
60 | metrics=['accuracy'])
61 |
62 |
63 | checkpointer = ModelCheckpoint(
64 | fs_tracker.get_model_directory() +
65 | '/checkpoint.{epoch:02d}-{val_loss:.2f}.hdf')
66 |
67 |
68 | tbcallback = TensorBoard(log_dir=fs_tracker.get_tensorboard_dir(),
69 | histogram_freq=0,
70 | write_graph=True,
71 | write_images=True)
72 |
73 |
74 | model.fit(
75 | x_train, y_train, validation_data=(
76 | x_test,
77 | y_test),
78 | epochs=no_epochs,
79 | callbacks=[checkpointer, tbcallback],
80 | batch_size=batch_size)
81 |
--------------------------------------------------------------------------------
/studio/ed25519_key_util.py:
--------------------------------------------------------------------------------
1 | import openssh_key.private_key_list as pkl
2 | from openssh_key.key import PublicKey
3 |
4 | class Ed25519KeyUtil:
5 |
6 | @classmethod
7 | def parse_private_key_file(cls, filepath: str, logger):
8 | """
9 | Parse a file with ed25519 private key in OPEN SSH PRIVATE key format.
10 | :param filepath: file path to a key file
11 | :param logger: logger to use for messages
12 | :return: (public key part as bytes, private key part as bytes)
13 | """
14 | contents: str = ''
15 | try:
16 | with open(filepath, "r") as f:
17 | contents = f.read()
18 | except Exception as exc:
19 | if logger is not None:
20 | logger.error("FAILED to read keyfile %s: %s",
21 | filepath, exc)
22 | return None, None
23 | try:
24 | key_data = pkl.PrivateKeyList.from_string(contents)
25 | data_public = key_data[0].private.params['public']
26 | data_private = key_data[0].private.params['private_public']
27 | return data_public, data_private[:32]
28 | except Exception as exc:
29 | if logger is not None:
30 | logger.error("FAILED to decode keyfile format %s: %s",
31 | filepath, exc)
32 | return None, None
33 |
34 | @classmethod
35 | def parse_public_key_file(cls, filepath: str, logger):
36 | """
37 | Parse a file with ed25519 public.
38 | :param filepath: file path to a key file
39 | :param logger: logger to use for messages
40 | :return: public key part as bytes
41 | """
42 | contents: str = ''
43 | try:
44 | with open(filepath, "r") as f:
45 | contents = f.read()
46 | except Exception as exc:
47 | if logger is not None:
48 | logger.error("FAILED to read keyfile %s: %s",
49 | filepath, exc)
50 | return None
51 | try:
52 | key_data = PublicKey.from_string(contents)
53 | data_public = key_data.params['public']
54 | return data_public
55 | except Exception as exc:
56 | if logger is not None:
57 | logger.error("FAILED to decode keyfile format %s: %s",
58 | filepath, exc)
59 | return None
60 |
--------------------------------------------------------------------------------
/studio/torch/saver.py:
--------------------------------------------------------------------------------
1 | """Tools to save/restore model from checkpoints."""
2 |
3 | import os
4 | try:
5 | import torch
6 | except ImportError:
7 | torch = None
8 |
9 |
10 | def load_checkpoint(model, optimizer, model_dir, map_to_cpu=False):
11 | path = os.path.join(model_dir, 'checkpoint')
12 | if os.path.exists(path):
13 | print("Loading model from %s" % path)
14 | if map_to_cpu:
15 | checkpoint = torch.load(
16 | path, map_location=lambda storage, location: storage)
17 | else:
18 | checkpoint = torch.load(path)
19 | old_state_dict = model.state_dict()
20 | for key in old_state_dict.keys():
21 | if key not in checkpoint['model']:
22 | checkpoint['model'][key] = old_state_dict[key]
23 | model.load_state_dict(checkpoint['model'])
24 | optimizer.load_state_dict(checkpoint['optimizer'])
25 | return checkpoint.get('step', 0)
26 | return 0
27 |
28 |
29 | def save_checkpoint(model, optimizer, step, model_dir, ignore=[]):
30 | if not os.path.exists(model_dir):
31 | os.makedirs(model_dir)
32 | path = os.path.join(model_dir, 'checkpoint')
33 | state_dict = model.state_dict()
34 | if ignore:
35 | for key in list(state_dict.keys()):
36 | for item in ignore:
37 | if key.startswith(item):
38 | state_dict.pop(key)
39 | torch.save({
40 | 'model': state_dict,
41 | 'optimizer': optimizer.state_dict(),
42 | 'step': step
43 | }, path)
44 |
45 |
46 | class Saver(object):
47 | """Class to manage save and restore for the model and optimizer."""
48 |
49 | def __init__(self, model, optimizer):
50 | self._model = model
51 | self._optimizer = optimizer
52 |
53 | def restore(self, model_dir, map_to_cpu=False):
54 | """Restores model and optimizer from given directory.
55 |
56 | Returns:
57 | Last training step for the model restored.
58 | """
59 | last_step = load_checkpoint(
60 | self._model, self._optimizer, model_dir, map_to_cpu)
61 | return last_step
62 |
63 | def save(self, model_dir, step):
64 | """Saves model and optimizer to given directory.
65 |
66 | Args:
67 | model_dir: Model directory to save.
68 | step: Current training step.
69 | """
70 | save_checkpoint(self._model, self._optimizer, step, model_dir)
71 |
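72 |
73 | # Usage sketch (illustrative; assumes `model` is a torch.nn.Module,
74 | # `optimizer` is a torch.optim optimizer, and torch is installed):
75 | #
76 | #     saver = Saver(model, optimizer)
77 | #     last_step = saver.restore('./checkpoints')  # 0 if no checkpoint yet
78 | #     for step in range(last_step, last_step + 100):
79 | #         ...  # one training step
80 | #         saver.save('./checkpoints', step)
81 | #
82 | # save_checkpoint() writes a single file named 'checkpoint' under the given
83 | # directory; load_checkpoint() fills any missing keys from the model's
84 | # current state before loading.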
--------------------------------------------------------------------------------
/examples/keras/train_cifar10.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from studio import fs_tracker
4 |
5 | from keras.layers import Dense, Flatten, Conv2D, BatchNormalization
6 |
7 | from keras.models import Sequential
8 | from keras.datasets import cifar10
9 | from keras.utils import to_categorical
10 |
11 | from keras import optimizers
12 | from keras.callbacks import ModelCheckpoint, TensorBoard
13 |
14 | import tensorflow as tf
15 | from keras import backend as backend
16 |
17 | config = tf.ConfigProto()
18 | config.gpu_options.allow_growth = True
19 | sess = tf.Session(config=config)
20 |
21 | backend.set_session(sess)
22 |
23 | (x_train, y_train), (x_test, y_test) = cifar10.load_data()
24 |
25 | x_train = x_train.reshape(50000, 32, 32, 3)
26 | x_test = x_test.reshape(10000, 32, 32, 3)
27 | x_train = x_train.astype('float32')
28 | x_test = x_test.astype('float32')
29 | x_train /= 255
30 | x_test /= 255
31 |
32 | # convert class vectors to binary class matrices
33 | y_train = to_categorical(y_train, 10)
34 | y_test = to_categorical(y_test, 10)
35 |
36 |
37 | model = Sequential()
38 |
39 | model.add(Conv2D(32, (5, 5), activation='relu', input_shape=(32, 32, 3)))
40 | model.add(Conv2D(32, (3, 3), activation='relu'))
41 | model.add(BatchNormalization())
42 | model.add(Conv2D(64, (5, 5), activation='relu'))
43 | model.add(Conv2D(64, (3, 3), activation='relu'))
44 | model.add(BatchNormalization())
45 | model.add(Conv2D(128, (3, 3), activation='relu'))
46 | model.add(Conv2D(128, (3, 3), activation='relu'))
47 | model.add(BatchNormalization())
48 | model.add(Flatten())
49 | model.add(Dense(10, activation='softmax'))
50 | model.summary()
51 |
52 |
53 | batch_size = 128
54 | no_epochs = int(sys.argv[1]) if len(sys.argv) > 1 else 40
55 | lr = 0.01
56 |
57 | print('learning rate = {}'.format(lr))
58 | print('batch size = {}'.format(batch_size))
59 | print('no_epochs = {}'.format(no_epochs))
60 |
61 | model.compile(loss='categorical_crossentropy', optimizer=optimizers.SGD(lr=lr),
62 | metrics=['accuracy'])
63 |
64 |
65 | checkpointer = ModelCheckpoint(
66 | fs_tracker.get_model_directory() +
67 | '/checkpoint.{epoch:02d}-{val_loss:.2f}.hdf')
68 |
69 |
70 | tbcallback = TensorBoard(log_dir=fs_tracker.get_tensorboard_dir(),
71 | histogram_freq=0,
72 | write_graph=True,
73 | write_images=True)
74 |
75 |
76 | model.fit(
77 | x_train, y_train, validation_data=(
78 | x_test,
79 | y_test),
80 | epochs=no_epochs,
81 | callbacks=[checkpointer, tbcallback],
82 | batch_size=batch_size)
83 |
--------------------------------------------------------------------------------
/studio/deploy_apiserver.sh:
--------------------------------------------------------------------------------
1 | #if [[ -z "$FIREBASE_ADMIN_CREDENTIALS" ]]; then
2 | # echo "*** Firebase admin credentials file required! ***"
3 | # echo "Input path to firebase admin credentials json"
4 | # echo "(you can also set FIREBASE_ADMIN_CREDENTIALS env variable manually):"
5 | # read -p ">>" firebase_creds
6 | #else
7 | # firebase_creds=$FIREBASE_ADMIN_CREDENTIALS
8 | #fi
9 |
10 | #if [ ! -f $firebase_creds ]; then
11 | # echo " *** File $firebase_creds does not exist! ***"
12 | # exit 1
13 | #fi
14 |
15 | #creds="./firebase_admin_creds.json"
16 | #cp $firebase_creds $creds
17 |
18 | config="apiserver_config.yaml"
19 | if [ -n "$2" ]; then
20 | config=$2
21 | fi
22 |
23 | echo "config file = $config"
24 |
25 | if [ "$1" = "gae" ]; then
26 |
27 | mv default_config.yaml default_config.yaml.orig
28 | cp $config default_config.yaml
29 | cp app.yaml app.yaml.old
30 | echo "env_variables:" >> app.yaml
31 |
32 | if [ -n "$AWS_ACCESS_KEY_ID" ]; then
33 | echo "exporting AWS env variables to app.yaml"
34 | echo " AWS_ACCESS_KEY_ID: $AWS_ACCESS_KEY_ID" >> app.yaml
35 | echo " AWS_SECRET_ACCESS_KEY: $AWS_SECRET_ACCESS_KEY" >> app.yaml
36 | echo " AWS_DEFAULT_REGION: $AWS_DEFAULT_REGION" >> app.yaml
37 | fi
38 |
39 | if [ -n "$STUDIO_GITHUB_ID" ]; then
40 | echo "exporting github secret env variables to app.yaml"
41 | echo " STUDIO_GITHUB_ID: $STUDIO_GITHUB_ID" >> app.yaml
42 | echo " STUDIO_GITHUB_SECRET: $STUDIO_GITHUB_SECRET" >> app.yaml
43 | fi
44 |
45 | rm -rf lib
46 | # pip install -t lib -r ../requirements.txt
47 | pip install -t lib ../
48 | # pip install -t lib -r ../extra_server_requirements.txt
49 |
50 | # patch library files where necessary
51 | for patch in $(find patches -name "*.patch"); do
52 | filename=${patch#patches/}
53 | filename=${filename%.patch}
54 | patch lib/$filename $patch
55 | done
56 |
57 | rm lib/tensorflow/python/_pywrap_tensorflow_internal.so
58 | echo "" > lib/tensorflow/__init__.py
59 |
60 | # dev_appserver.py app.yaml --dev_appserver_log_level debug
61 | yes Y | gcloud app deploy --no-promote
62 |
63 | mv default_config.yaml.orig default_config.yaml
64 | mv app.yaml.old app.yaml
65 | elif [ "$1" = "local" ]; then
66 |     port=$2
67 |     studio ui --config=apiserver_config.yaml --port=$port
68 | else
69 |     echo "*** unknown target: $1 (should be either gae or local) ***"
70 |     exit 1
71 | fi
72 |
73 |
74 | # rm -f $creds
75 | # rm -rf lib
76 |
--------------------------------------------------------------------------------
/studio/fs_tracker.py:
--------------------------------------------------------------------------------
1 | """Utilities to track and record file system."""
2 |
3 | import os
4 | import shutil
5 |
6 | from studio.artifacts import artifacts_tracker
7 |
8 | STUDIOML_EXPERIMENT = 'STUDIOML_EXPERIMENT'
9 | STUDIOML_HOME = 'STUDIOML_HOME'
10 | STUDIOML_ARTIFACT_MAPPING = 'STUDIOML_ARTIFACT_MAPPING'
11 |
12 |
13 | def get_experiment_key():
14 | return artifacts_tracker.get_experiment_key()
15 |
16 |
17 | def get_studio_home():
18 | return artifacts_tracker.get_studio_home()
19 |
20 |
21 | def setup_experiment(env, experiment, clean=True):
22 | artifacts_tracker.setup_experiment(env, experiment, clean=clean)
23 |
24 |
25 | def get_artifact(tag):
26 | return artifacts_tracker.get_experiment(tag)
27 |
28 |
29 | def get_artifacts():
30 | return artifacts_tracker.get_artifacts()
31 |
32 |
33 | def get_model_directory(experiment_name=None):
34 | return get_artifact_cache('modeldir', experiment_name)
35 |
36 |
37 | def get_artifact_cache(tag, experiment_name=None):
38 | return artifacts_tracker.get_artifact_cache(
39 | tag,
40 | experiment_name=experiment_name)
41 |
42 |
43 | def get_blob_cache(blobkey):
44 | return artifacts_tracker.get_blob_cache(blobkey)
45 |
48 |
49 | def _get_artifact_mapping_path(experiment_name=None):
50 | experiment_name = experiment_name if experiment_name else \
51 | os.environ[STUDIOML_EXPERIMENT]
52 |
53 | basepath = os.path.join(
54 | get_studio_home(),
55 | 'artifact_mappings',
56 | experiment_name
57 | )
58 | if not os.path.exists(basepath):
59 | os.makedirs(basepath)
60 |
61 | return os.path.join(basepath, 'artifacts.json')
62 |
63 |
64 | def _get_experiment_key(experiment):
65 | if not isinstance(experiment, str):
66 | return experiment.key
67 | else:
68 | return experiment
69 |
70 |
71 | def _setup_model_directory(experiment_name, clean=False):
72 | path = get_model_directory(experiment_name)
73 | if clean and os.path.exists(path):
74 | shutil.rmtree(path)
75 |
76 | if not os.path.exists(path):
77 | os.makedirs(path)
78 |
79 |
80 | def get_queue_directory():
81 | queue_dir = os.path.join(
82 | get_studio_home(),
83 | 'queue')
84 | if not os.path.exists(queue_dir):
85 | try:
86 | os.makedirs(queue_dir)
87 | except OSError:
88 | pass
89 |
90 | return queue_dir
91 |
92 |
93 | def get_tensorboard_dir(experiment_name=None):
94 | return get_artifact_cache('tb', experiment_name)
95 |
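96 |
97 | # Usage sketch from inside a running experiment (illustrative; these calls
98 | # assume the studio environment was set up, e.g. by `studio run`):
99 | #
100 | #     from studio import fs_tracker
101 | #     model_dir = fs_tracker.get_model_directory()  # per-experiment model dir
102 | #     tb_dir = fs_tracker.get_tensorboard_dir()     # tensorboard log dir
103 | #     data_path = fs_tracker.get_artifact('data')   # 'data' is a sample tag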
--------------------------------------------------------------------------------
/studio/cloud_worker_util.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from studio.util.util import rand_string
4 | """
5 | Utility functions for anything shared in common by ec2cloud_worker and
6 | gcloud_worker
7 | """
8 |
9 | INDENT = 4
10 |
11 |
12 | def insert_user_startup_script(user_startup_script, startup_script_str,
13 | logger):
14 | if user_startup_script is None:
15 | return startup_script_str
16 |
17 | try:
18 | with open(os.path.abspath(os.path.expanduser(
19 | user_startup_script))) as f:
20 | user_startup_script_lines = f.read().splitlines()
21 | except BaseException:
22 | if user_startup_script is not None:
23 |             logger.warning("User startup script (%s) cannot be loaded" %
24 | user_startup_script)
25 | return startup_script_str
26 |
27 | startup_script_lines = startup_script_str.splitlines()
28 | new_startup_script_lines = []
29 | whitespace = " " * INDENT
30 | for line in startup_script_lines:
31 |
32 | if line.startswith("studio remote worker") or \
33 | line.startswith("studio-remote-worker"):
34 | curr_working_dir = "curr_working_dir_%s" % rand_string(32)
35 | func_name = "user_script_%s" % rand_string(32)
36 |
37 | new_startup_script_lines.append("%s=$(pwd)\n" % curr_working_dir)
38 | new_startup_script_lines.append("cd ~\n")
39 | new_startup_script_lines.append("%s()(\n" % func_name)
40 | for user_line in user_startup_script_lines:
41 | if user_line.startswith("#!"):
42 | continue
43 | new_startup_script_lines.append("%s%s\n" %
44 | (whitespace, user_line))
45 |
46 | new_startup_script_lines.append("%scd $%s\n" %
47 | (whitespace, curr_working_dir))
48 | new_startup_script_lines.append("%s%s\n" %
49 | (whitespace, line))
50 | new_startup_script_lines.append(")\n")
51 | new_startup_script_lines.append("%s\n" % func_name)
52 | else:
53 | new_startup_script_lines.append("%s\n" % line)
54 |
55 | new_startup_script = "".join(new_startup_script_lines)
56 | logger.info('Inserting the following user startup script'
57 | ' into the default startup script:')
58 | logger.info("\n".join(user_startup_script_lines))
59 |
60 | # with open("/home/jason/Desktop/script.sh", 'wb') as f:
61 | # f.write(new_startup_script)
62 | # sys.exit()
63 |
64 | return new_startup_script
65 |
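66 |
67 | # Illustrative example of the transformation (names and commands below are
68 | # hypothetical): given a startup script containing the line
69 | # "studio-remote-worker --queue=q --timeout=300", the returned script wraps
70 | # the user's commands and that worker line into a generated function:
71 | #
72 | #     curr_working_dir_<rand>=$(pwd)
73 | #     cd ~
74 | #     user_script_<rand>()(
75 | #         <user startup commands>
76 | #         cd $curr_working_dir_<rand>
77 | #         studio-remote-worker --queue=q --timeout=300
78 | #     )
79 | #     user_script_<rand>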
--------------------------------------------------------------------------------
/studio/storage/http_storage_handler.py:
--------------------------------------------------------------------------------
1 | import os
2 | from urllib.parse import urlparse
3 | from typing import Dict
4 | from studio.util import logs
5 | from studio.credentials.credentials import Credentials
6 | from studio.storage.storage_setup import get_storage_verbose_level
7 | from studio.storage.storage_type import StorageType
8 | from studio.storage.storage_handler import StorageHandler
9 | from studio.storage import storage_util
10 |
11 | class HTTPStorageHandler(StorageHandler):
12 | def __init__(self, remote_path, credentials_dict,
13 | timestamp=None,
14 | compression=None):
15 |
16 | self.logger = logs.get_logger(self.__class__.__name__)
17 | self.logger.setLevel(get_storage_verbose_level())
18 |
19 | self.url = remote_path
20 | self.timestamp = timestamp
21 |
22 | parsed_url = urlparse(self.url)
23 | self.scheme = parsed_url.scheme
24 | self.endpoint = parsed_url.netloc
25 | self.path = parsed_url.path
26 | self.credentials = Credentials(credentials_dict)
27 |
28 | super().__init__(StorageType.storageHTTP,
29 | self.logger,
30 | False,
31 | compression=compression)
32 |
33 | def upload_file(self, key, local_path):
34 | storage_util.upload_file(self.url, local_path, self.logger)
35 |
36 | def download_file(self, key, local_path):
37 | return storage_util.download_file(self.url, local_path, self.logger)
38 |
39 | def download_remote_path(self, remote_path, local_path):
40 | head, _ = os.path.split(local_path)
41 |         if head:
42 | os.makedirs(head, exist_ok=True)
43 | return storage_util.download_file(remote_path, local_path, self.logger)
44 |
45 | @classmethod
46 | def get_id(cls, config: Dict) -> str:
47 | endpoint = config.get('endpoint', None)
48 | if endpoint is None:
49 | return None
50 | creds: Credentials = Credentials.get_credentials(config)
51 | creds_fingerprint = creds.get_fingerprint() if creds else ''
52 | return '[http]{0}::{1}'.format(endpoint, creds_fingerprint)
53 |
54 | def get_local_destination(self, remote_path: str):
55 | parsed_url = urlparse(remote_path)
56 | parts = parsed_url.path.split('/')
57 |         return None, parts[-1]
58 |
59 | def delete_file(self, key, shallow=True):
60 | raise NotImplementedError
61 |
62 | def get_file_url(self, key):
63 | return self.url
64 |
65 | def get_file_timestamp(self, key):
66 | return self.timestamp
67 |
68 | def get_qualified_location(self, key):
69 | return self.url
70 |
--------------------------------------------------------------------------------
/studio/scripts/studio-add-credentials:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # script that builds docker image with credentials baked in
4 | # usage:
5 | # studio-add-credentials [--base-image=] [--tag=] [--check-gpu]
6 | #
7 | # --base-image specifies the base image to add credentials to. Default is peterzhokhoff/tfstudio
8 | # --tag specifies the tag of the resulting image. Default is _creds
9 | # --check-gpu option (if specified) checks if nvidia-smi works correctly on a current machine, and if it is not, uninstalls tensorflow-gpu from docker image.
10 | # Without --check-gpu option the built docker image may not work on a current machine.
11 |
12 | base_img=peterzhokhoff/tfstudio
13 |
14 | while [[ $# -gt 0 ]]
15 | do
16 | key="$1"
17 | case ${key%%=*} in
18 |
19 | -b|--base-image)
20 | base_img="${1##*=}"
21 | ;;
22 |
23 | -t|--tag)
24 | output_img="${1##*=}"
25 | ;;
26 |
27 | esac
28 | shift
29 | done
30 |
31 | if [ -z "$output_img" ]; then
32 | output_img=$base_img"_creds"
33 | fi
34 |
35 | # mypath=$(pwd)/${0%/*}
36 | dockerfile=".Dockerfile_bake_creds"
37 | awspath=$HOME/.aws
38 |
39 | # cd $mypath/../../
40 |
41 | echo "Base image: $base_img"
42 | echo "Tag: $output_img"
43 |
44 |
45 | echo "Uninstall tensorflow-gpu: $uninstall_tfgpu"
46 |
47 | contextdir=${TMPDIR:-/tmp}/studioml_container_context
48 | mkdir -p $contextdir
49 | cd $contextdir
50 |
51 | if [ -d $HOME/.studioml/keys ]; then
52 | cp -r $HOME/.studioml/keys .keys
53 | fi
54 |
55 | if [ -n "$GOOGLE_APPLICATION_CREDENTIALS" ]; then
56 | cp $GOOGLE_APPLICATION_CREDENTIALS .gac_credentials
57 | fi
58 |
59 | if [ -d $awspath ]; then
60 | cp -r $awspath .aws
61 | fi
62 |
63 | # build dockerfile
64 | echo "Constructing dockerfile..."
65 | echo "FROM $base_img" > $dockerfile
66 |
67 | if [ -d $HOME/.studioml/keys/ ]; then
68 | echo "ADD .keys /root/.studioml/keys" >> $dockerfile
69 | fi
70 |
71 | if [ -n "$GOOGLE_APPLICATION_CREDENTIALS" ]; then
72 | echo "ADD .gac_credentials /root/gac_credentials" >> $dockerfile
73 | echo "ENV GOOGLE_APPLICATION_CREDENTIALS /root/gac_credentials" >> $dockerfile
74 | fi
75 |
76 | if [ -d $awspath ]; then
77 | echo "ADD .aws /root/.aws" >> $dockerfile
78 | fi
79 |
80 | echo "Done. Resulting dockerfile: "
81 | cat $dockerfile
82 |
83 |
84 | # build docker image
85 | echo "Building image..."
86 | docker build -t $output_img -f $dockerfile .
87 | echo "Done"
88 |
89 | # cleanup
90 | echo "Cleaning up..."
91 | rm -rf .keys
92 | rm -rf $dockerfile
93 |
94 | if [ -n "$GOOGLE_APPLICATION_CREDENTIALS" ]; then
95 |     rm -f .gac_credentials
96 | fi
97 |
98 | if [ -d $awspath ]; then
99 | rm -rf .aws
100 | fi
101 | echo "Done"
102 |
--------------------------------------------------------------------------------
/studio/remote_worker.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 |
4 | from studio import model, logs
5 | from .local_worker import worker_loop
6 | from .pubsub_queue import PubsubQueue
7 | from .sqs_queue import SQSQueue
8 | from .util import parse_verbosity
9 |
10 | from .qclient_cache import get_cached_queue
11 |
12 |
13 | def main(args=sys.argv):
14 | logger = logs.get_logger('studio-remote-worker')
15 | parser = argparse.ArgumentParser(
16 | description='Studio remote worker. \
17 | Usage: studio-remote-worker \
18 | ')
19 | parser.add_argument('--config', help='configuration file', default=None)
20 |
21 | parser.add_argument(
22 | '--guest',
23 | help='Guest mode (does not require db credentials)',
24 | action='store_true')
25 |
26 | parser.add_argument(
27 | '--single-run',
28 | help='quit after a single run (regardless of the state of the queue)',
29 | action='store_true')
30 |
31 | parser.add_argument('--queue', help='queue name', required=True)
32 | parser.add_argument(
33 | '--verbose', '-v',
34 |         help='Verbosity level. Allowed values: ' +
35 | 'debug, info, warn, error, crit ' +
36 | 'or numerical value of logger levels.',
37 | default=None)
38 |
39 | parser.add_argument(
40 | '--timeout', '-t',
41 | help='Timeout after which remote worker stops listening (in seconds)',
42 | type=int,
43 | default=100)
44 |
45 | parsed_args, script_args = parser.parse_known_args(args)
46 | verbose = parse_verbosity(parsed_args.verbose)
47 | logger.setLevel(verbose)
48 |
49 | config = None
50 | if parsed_args.config is not None:
51 | config = model.get_config(parsed_args.config)
52 |
53 | if parsed_args.queue.startswith('ec2_') or \
54 | parsed_args.queue.startswith('sqs_'):
55 | queue = SQSQueue(parsed_args.queue, verbose=verbose)
56 | elif parsed_args.queue.startswith('rmq_'):
57 | queue = get_cached_queue(
58 | name=parsed_args.queue,
59 | route='StudioML.' + parsed_args.queue,
60 | config=config,
61 | logger=logger,
62 | verbose=verbose)
63 | else:
64 | queue = PubsubQueue(parsed_args.queue, verbose=verbose)
65 |
66 | logger.info('Waiting for work')
67 |
68 | timeout_before = parsed_args.timeout
69 | timeout_after = timeout_before if timeout_before > 0 else 0
70 | # wait_for_messages(queue, timeout_before, logger)
71 |
72 | logger.info('Starting working')
73 | worker_loop(queue, parsed_args,
74 | single_experiment=parsed_args.single_run,
75 | timeout=timeout_after,
76 | verbose=verbose)
77 |
78 |
79 | if __name__ == "__main__":
80 | main()
81 |
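82 |
83 | # Example invocation (queue name and timeout are illustrative): a queue name
84 | # prefixed with 'rmq_' uses the cached RabbitMQ queue, 'sqs_'/'ec2_' use SQS,
85 | # anything else uses Google Pub/Sub.
86 | #
87 | #     studio-remote-worker --queue=rmq_myqueue --verbose=info --timeout=300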
--------------------------------------------------------------------------------
/studio/storage/storage_handler_factory.py:
--------------------------------------------------------------------------------
1 | from studio.util import logs
2 | from typing import Dict
3 |
4 | from studio.storage.http_storage_handler import HTTPStorageHandler
5 | from studio.storage.local_storage_handler import LocalStorageHandler
6 | from studio.storage.storage_setup import get_storage_verbose_level
7 | from studio.storage.storage_handler import StorageHandler
8 | from studio.storage.storage_type import StorageType
9 | from studio.storage.s3_storage_handler import S3StorageHandler
10 |
11 | _storage_factory = None
12 |
13 | class StorageHandlerFactory:
14 | def __init__(self):
15 | self.logger = logs.get_logger(self.__class__.__name__)
16 | self.logger.setLevel(get_storage_verbose_level())
17 | self.handlers_cache = dict()
18 | self.cleanup_at_exit: bool = True
19 |
20 | @classmethod
21 | def get_factory(cls):
22 | global _storage_factory
23 | if _storage_factory is None:
24 | _storage_factory = StorageHandlerFactory()
25 | return _storage_factory
26 |
27 | def set_cleanup_at_exit(self, value: bool):
28 | self.cleanup_at_exit = value
29 |
30 | def cleanup(self):
31 | if not self.cleanup_at_exit:
32 | return
33 |
34 | for _, handler in self.handlers_cache.items():
35 | if handler is not None:
36 | handler.cleanup()
37 |
38 | def get_handler(self, handler_type: StorageType,
39 | config: Dict) -> StorageHandler:
40 | if handler_type == StorageType.storageS3:
41 | handler_id: str = S3StorageHandler.get_id(config)
42 | handler = self.handlers_cache.get(handler_id, None)
43 | if handler is None:
44 | handler = S3StorageHandler(config)
45 | self.handlers_cache[handler_id] = handler
46 | return handler
47 | if handler_type == StorageType.storageHTTP:
48 | handler_id: str = HTTPStorageHandler.get_id(config)
49 | handler = self.handlers_cache.get(handler_id, None)
50 | if handler is None:
51 | handler = HTTPStorageHandler(
52 | config.get('endpoint', None),
53 | config.get('credentials', None))
54 | self.handlers_cache[handler_id] = handler
55 | return handler
56 | if handler_type == StorageType.storageLocal:
57 | handler_id: str = LocalStorageHandler.get_id(config)
58 | handler = self.handlers_cache.get(handler_id, None)
59 | if handler is None:
60 | handler = LocalStorageHandler(config)
61 | self.handlers_cache[handler_id] = handler
62 | return handler
63 |         self.logger.error("FAILED to get storage handler: "
64 |                           "unsupported type %s", repr(handler_type))
65 | return None
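66 |
67 | # Usage sketch (illustrative; the config keys follow the handlers above,
68 | # e.g. an 'endpoint' plus optional 'credentials' for the HTTP handler):
69 | #
70 | #     factory = StorageHandlerFactory.get_factory()
71 | #     handler = factory.get_handler(
72 | #         StorageType.storageHTTP,
73 | #         {'endpoint': 'https://example.com/data.tgz', 'credentials': None})
74 | #     ...
75 | #     factory.cleanup()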
--------------------------------------------------------------------------------
/tests/serving_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import subprocess
3 | import time
4 | from random import randint
5 | import os
6 | import uuid
7 | import requests
8 |
9 | from studio import model
10 | from env_detect import on_gcp
11 |
12 |
13 | @unittest.skipIf(
14 | not on_gcp(),
15 | 'User indicated not on gcp')
16 | class UserIndicatedOnGCPTest(unittest.TestCase):
17 |     def test_on_environment(self):
18 | self.assertTrue('GOOGLE_APPLICATION_CREDENTIALS' in os.environ.keys())
19 |
20 |
21 | @unittest.skipIf(
22 | (not on_gcp()) or
23 | 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ.keys(),
24 |     'Skipping due to user input or GCP not detected; ' +
25 | "GOOGLE_APPLICATION_CREDENTIALS is missing, needed for " +
26 | "server to communicate with storage")
27 | class ServingTest(unittest.TestCase):
28 |
29 |     _multiprocess_shared_ = True
30 |
31 | def _test_serving(self, data_in, expected_data_out, wrapper=None):
32 |
33 | self.port = randint(5000, 9000)
34 | server_experimentid = 'test_serving_' + str(uuid.uuid4())
35 |
36 | args = [
37 | 'studio', 'run',
38 | '--force-git',
39 | '--verbose=debug',
40 | '--experiment=' + server_experimentid,
41 | '--config=' + self.get_config_path(),
42 | 'studio::serve_main',
43 | '--port=' + str(self.port),
44 | '--host=localhost'
45 | ]
46 |
47 | if wrapper:
48 | args.append('--wrapper=' + wrapper)
49 |
50 | subprocess.Popen(args, cwd=os.path.dirname(__file__))
51 | time.sleep(60)
52 |
53 | try:
54 | retval = requests.post(
55 | url='http://localhost:' + str(self.port), json=data_in)
56 | data_out = retval.json()
57 | assert data_out == expected_data_out
58 |
59 | finally:
60 | with model.get_db_provider(model.get_config(
61 | self.get_config_path())) as db:
62 |
63 | db.stop_experiment(server_experimentid)
64 | time.sleep(20)
65 | db.delete_experiment(server_experimentid)
66 |
67 | def test_serving_identity(self):
68 | data = {"a": "b"}
69 | self._test_serving(
70 | data_in=data,
71 | expected_data_out=data
72 | )
73 |
74 | def test_serving_increment(self):
75 | data_in = {"a": 1}
76 | data_out = {"a": 2}
77 |
78 | self._test_serving(
79 | data_in=data_in,
80 | expected_data_out=data_out,
81 | wrapper='model_increment.py'
82 | )
83 |
84 | def get_config_path(self):
85 | return os.path.join(
86 | os.path.dirname(__file__),
87 | 'test_config_http_client.yaml'
88 | )
89 |
90 |
91 | if __name__ == '__main__':
92 | unittest.main()
93 |
--------------------------------------------------------------------------------
/studio/git_util.py:
--------------------------------------------------------------------------------
1 | import re
2 | import subprocess
3 | import os
4 |
5 |
6 | def get_git_info(path='.', abort_dirty=True):
7 | info = {}
8 | if not is_git(path):
9 | return None
10 |
11 | if abort_dirty and not is_clean(path):
12 | return None
13 |
14 | info['url'] = get_repo_url(path)
15 | info['commit'] = get_commit(path)
16 | return info
17 |
18 |
19 | def is_git(path='.'):
20 | p = subprocess.Popen(
21 | ['git', 'status'],
22 | stdout=subprocess.DEVNULL,
23 | stderr=subprocess.DEVNULL,
24 | cwd=path)
25 |
26 | p.wait()
27 | return (p.returncode == 0)
28 |
29 |
30 | def is_clean(path='.'):
31 | p = subprocess.Popen(
32 | ['git', 'status', '-s'],
33 | stdout=subprocess.PIPE,
34 | stderr=subprocess.PIPE,
35 | cwd=path)
36 |
37 | stdout, _ = p.communicate()
38 | if not p.returncode == 0:
39 | return False
40 |
41 | return (stdout.strip() == '')
42 |
43 |
44 | def get_repo_url(path='.', remove_user=True):
45 | p = subprocess.Popen(
46 | ['git', 'config', '--get', 'remote.origin.url'],
47 | stdout=subprocess.PIPE,
48 | stderr=subprocess.PIPE,
49 | cwd=path)
50 |
51 | stdout, _ = p.communicate()
52 | if p.returncode != 0:
53 | return None
54 |
55 | url = stdout.strip()
56 | if remove_user:
57 | url = re.sub('(?<=://).*@', '', url.decode('utf-8'))
58 | return url
59 |
60 |
61 | def get_branch(path='.'):
62 | p = subprocess.Popen(
63 | ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
64 | stdout=subprocess.PIPE,
65 | stderr=subprocess.STDOUT,
66 | cwd=path)
67 |
68 | stdout, _ = p.communicate()
69 |     if p.returncode != 0:
70 | return None
71 |
72 | return stdout.strip().decode('utf8')
73 |
74 |
75 | def get_commit(path='.'):
76 | p = subprocess.Popen(
77 | ['git', 'rev-parse', 'HEAD'],
78 | stdout=subprocess.PIPE,
79 | stderr=subprocess.STDOUT,
80 | cwd=path)
81 |
82 | stdout, _ = p.communicate()
83 | if p.returncode != 0:
84 | return None
85 |
86 | return stdout.strip().decode('utf8')
87 |
88 |
89 | def get_my_repo_url():
90 | mypath = os.path.dirname(os.path.realpath(__file__))
91 | repo = get_repo_url(mypath)
92 | if repo is None:
93 | repo = "https://github.com/studioml/studio"
94 | return repo
95 |
96 |
97 | def get_my_branch():
98 | mypath = os.path.dirname(os.path.realpath(__file__))
99 | branch = get_branch(mypath)
100 | if branch is None:
101 | branch = "master"
102 | return branch
103 |
104 |
105 | def get_my_checkout_target():
106 | mypath = os.path.dirname(os.path.realpath(__file__))
107 | target = get_commit(mypath)
108 | if target is None:
109 | target = get_my_branch()
110 |
111 | return target
112 |
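113 |
114 | # Minimal sketch of typical usage (safe to run anywhere: the helpers fall
115 | # back to None or repository defaults outside a clean git checkout):
116 | if __name__ == "__main__":
117 |     print(get_git_info('.'))
118 |     print(get_my_repo_url())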
--------------------------------------------------------------------------------
/studio/completion_service/completion_service_client.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import shutil
3 | import pickle
4 | import os
5 | import sys
6 | import six
7 | import signal
8 | import pdb
9 |
10 | from studio import fs_tracker, model, logs, util
11 |
12 | logger = logs.get_logger('completion_service_client')
13 | try:
14 | logger.setLevel(model.parse_verbosity(sys.argv[1]))
15 | except BaseException:
16 | logger.setLevel(10)
17 |
18 |
19 | def main():
20 | logger.debug('copying and importing client module')
21 | logger.debug('getting file mappings')
22 |
23 | # Register signal handler for signal.SIGUSR1
24 | # which will invoke built-in Python debugger:
25 | signal.signal(signal.SIGUSR1, lambda sig, stack: pdb.set_trace())
26 |
27 | artifacts = fs_tracker.get_artifacts()
28 | files = {}
29 | logger.debug("Artifacts = {}".format(artifacts))
30 |
31 | for tag, path in six.iteritems(artifacts):
32 | if tag not in {'workspace', 'modeldir', 'tb', '_runner'}:
33 | if os.path.isfile(path):
34 | files[tag] = path
35 | elif os.path.isdir(path):
36 | dirlist = os.listdir(path)
37 | if any(dirlist):
38 | files[tag] = os.path.join(
39 | path,
40 | dirlist[0]
41 | )
42 |
43 | logger.debug("Files = {}".format(files))
44 | script_path = files['clientscript']
45 | retval_path = fs_tracker.get_artifact('retval')
46 | util.rm_rf(retval_path)
47 |
48 | # script_name = os.path.basename(script_path)
49 | new_script_path = os.path.join(os.getcwd(), '_clientscript.py')
50 | shutil.copy(script_path, new_script_path)
51 |
52 | script_path = new_script_path
53 | logger.debug("script path: " + script_path)
54 |
55 | mypath = os.path.dirname(script_path)
56 | sys.path.append(mypath)
57 | # os.path.splitext(os.path.basename(script_path))[0]
58 | module_name = '_clientscript'
59 |
60 | client_module = importlib.import_module(module_name)
61 | logger.debug('loading args')
62 |
63 | args_path = files['args']
64 |
65 | with open(args_path, 'rb') as f:
66 | args = pickle.loads(f.read())
67 |
68 | logger.debug('calling client function')
69 | retval = client_module.clientFunction(args, files)
70 |
71 | logger.debug('saving the return value')
72 | if os.path.isdir(fs_tracker.get_artifact('clientscript')):
73 | # on go runner:
74 | logger.debug("Running in a go runner, creating {} for retval"
75 | .format(retval_path))
76 | try:
77 | os.mkdir(retval_path)
78 | except OSError:
79 | logger.debug('retval dir present')
80 |
81 | retval_path = os.path.join(retval_path, 'retval')
82 | logger.debug("New retval_path is {}".format(retval_path))
83 |
84 | logger.debug('Saving retval')
85 | with open(retval_path, 'wb') as f:
86 | f.write(pickle.dumps(retval, protocol=2))
87 | logger.debug('Done')
88 |
89 |
90 | if __name__ == "__main__":
91 | main()
92 |
--------------------------------------------------------------------------------
/docs/faq.rst:
--------------------------------------------------------------------------------
1 | Frequently Asked Questions
2 | ==========================
3 |
4 | `Join us on Slack! `_
5 | -----------------------------------------------
6 |
7 | - `What is the complete list of tools Studio.ML is compatible with?`_
8 |
9 | - `Is Studio.ML compatible with Python 3?`_
10 |
11 | - `Do I need to change my code to use Studio.ML?`_
12 |
13 | - `How can I track the training of my models?`_
14 |
15 | - `How does Studio.ml integrate with Google Cloud or Amazon EC2?`_
16 |
17 | - `Is it possible to view the experiment artifacts outside of the Web UI?`_
18 |
19 |
20 | _`What is the complete list of tools Studio.ML is compatible with?`
21 | ------------------------------------------------------------------
22 |
23 | Keras, TensorFlow, PyTorch, scikit-learn, pandas - anything that runs in Python.
24 |
25 | _`Is Studio.ML compatible with Python 3?`
26 | ---------------------------------------------
27 |
28 | Yes! Studio.ML is now compatible with Python 3.
29 |
30 | _`Can I use Studio.ML with my jupyter / ipython notebooks?`
31 | -----------------------------------------------------------
32 |
33 | Yes! The basic usage pattern is to import the ``magics`` module from studio
34 | and then annotate cells that need to be run via studio with the
35 | ``%%studio_run`` cell magic, optionally followed by the same command-line
36 | arguments that ``studio run`` accepts, as shown in the example below.
37 |
38 |
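For example (the experiment name below is purely illustrative):

::

    %%studio_run --experiment=my_notebook_experiment
    print('this cell is executed as a Studio.ML experiment')
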
39 | _`Do I need to change my code to use Studio.ML?`
40 | ---------------------------------------------
41 |
42 | Studio is designed to be minimally invasive to your existing code. Running an experiment with Studio should be as simple as replacing ``python`` with ``studio run`` on your command line, with a few optional flags for capturing your workspace or naming your experiment.
43 |
44 | _`How can I track the training of my models?`
45 | --------------------
46 |
47 | You can manage any of your experiments - current, old, or queued - through the web interface. Simply run ``studio ui`` to launch the UI and view details of any of your experiments.
48 |
49 | _`How does Studio.ml integrate with Google Cloud or Amazon EC2?`
50 | -----------------
51 |
52 | We use standard Python tools like Boto and Google Cloud Python Client to launch GPU instances that are used for model training and de-provision them when the experiment is finished.
53 |
54 | _`Is it possible to view the experiment artifacts outside of the Web UI?`
55 | -------------------
56 |
57 | Yes!
58 |
59 | ::
60 |
61 | from studio import model
62 |
63 | with model.get_db_provider() as db:
64 | experiment = db.get_experiment()
65 |
66 |
67 | will return an experiment object that contains all the information about the experiment with key ````, including artifacts.
68 | The artifacts can then be downloaded:
69 |
70 | ::
71 |
72 | with model.get_db_provider() as db:
73 | artifact_path = db.get_artifact(experiment.artifacts[])
74 |
75 | will download an artifact with tag ```` and return a local path to it in the ``artifact_path`` variable.
76 |
77 |
--------------------------------------------------------------------------------
/studio/scripts/ec2_worker_startup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | exec > >(tee -i ~/ec2_worker_logfile.txt)
4 | exec 2>&1
5 |
6 | cd ~
7 | mkdir .aws
8 | echo "[default]" > .aws/config
9 | echo "region = {region}" >> .aws/config
10 |
11 | mkdir -p .studioml/keys
12 | key_name="{auth_key}"
13 | queue_name="{queue_name}"
14 | echo "{auth_data}" | base64 --decode > .studioml/keys/$key_name
15 | echo "{google_app_credentials}" | base64 --decode > credentials.json
16 |
17 | export GOOGLE_APPLICATION_CREDENTIALS=~/credentials.json
18 |
19 | export AWS_ACCESS_KEY_ID="{aws_access_key}"
20 | export AWS_SECRET_ACCESS_KEY="{aws_secret_key}"
21 |
22 | code_url_base="https://storage.googleapis.com/studio-ed756.appspot.com/src"
23 | #code_ver="tfstudio-64_config_location-2017-08-04_1.tgz"
24 |
25 |
26 | autoscaling_group="{autoscaling_group}"
27 | instance_id=$(wget -q -O - http://169.254.169.254/latest/meta-data/instance-id)
28 |
29 | echo "Environment variables:"
30 | env
31 |
32 | {install_studio}
33 |
34 | python $(which studio-remote-worker) --queue=$queue_name --verbose=debug --timeout={timeout}
35 |
36 | # sudo update-alternatives --set python3
37 |
38 | # shutdown the instance
39 | echo "Work done"
40 |
41 | hostname=$(hostname)
42 | aws s3 cp /var/log/cloud-init-output.log "s3://studioml-logs/$queue_name/$hostname.txt"
43 |
44 | if [[ -n $(who) ]]; then
45 | echo "Users are logged in, not shutting down"
46 | echo "Do not forget to shut the instance down manually"
47 | exit 0
48 | fi
49 |
50 |
51 |
52 | if [ -n "$autoscaling_group" ]; then
53 |
54 | echo "Getting info for auto-scaling group $autoscaling_group"
55 |
56 | asg_info="aws autoscaling describe-auto-scaling-groups --auto-scaling-group-name $autoscaling_group"
57 | desired_size=$( $asg_info | jq --raw-output ".AutoScalingGroups | .[0] | .DesiredCapacity" )
58 | launch_config=$( $asg_info | jq --raw-output ".AutoScalingGroups | .[0] | .LaunchConfigurationName" )
59 |
60 | echo "Launch config: $launch_config"
61 | echo "Current autoscaling group size (desired): $desired_size"
62 |
63 | if [[ $desired_size -gt 1 ]]; then
64 | echo "Detaching myself ($instance_id) from the ASG $autoscaling_group"
65 | aws autoscaling detach-instances --instance-ids $instance_id --auto-scaling-group-name $autoscaling_group --should-decrement-desired-capacity
66 | #new_desired_size=$((desired_size - 1))
67 | #echo "Decreasing ASG size to $new_desired_size"
68 | #aws autoscaling update-auto-scaling-group --auto-scaling-group-name $autoscaling_group --desired-capacity $new_desired_size
69 | else
70 | echo "Deleting launch configuration and auto-scaling group"
71 | aws autoscaling delete-auto-scaling-group --auto-scaling-group-name $autoscaling_group --force-delete
72 | aws autoscaling delete-launch-configuration --launch-configuration-name $launch_config
73 | fi
74 | # if desired_size > 1 decrease desired size (with cooldown - so that it does not try to remove any other instances!)
75 |   # else delete the group - that should do the shutdown
76 | #
77 |
78 | fi
79 | aws s3 cp /var/log/cloud-init-output.log "s3://studioml-logs/$queue_name/$hostname.txt"
80 | echo "Shutting the instance down!"
81 | sudo shutdown now
82 |
--------------------------------------------------------------------------------
/studio/util/gpu_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import xml.etree.ElementTree as ET
4 |
5 | from studio.util.util import sixdecode
6 |
7 |
8 | def memstr2int(string):
9 | conversion_factors = [
10 | ('Mb', 2**20), ('MiB', 2**20), ('m', 2**20), ('mb', 2**20),
11 | ('Gb', 2**30), ('GiB', 2**30), ('g', 2**30), ('gb', 2**30),
12 | ('kb', 2**10), ('k', 2**10)
13 | ]
14 |
15 | for key, factor in conversion_factors:
16 | if string.endswith(key):
17 | return int(float(string.replace(key, '')) * factor)
18 |
19 | return int(string)
20 |
21 |
22 | def get_available_gpus(gpu_mem_needed=None, strict=False):
23 | gpus = _get_gpu_info()
24 |
25 | def check_gpu_nomem_strict(gpu):
26 | return memstr2int(gpu.find('fb_memory_usage').find('used').text) < \
27 | memstr2int(gpu.find('fb_memory_usage').find('free').text)
28 | if gpu_mem_needed is None:
29 | if strict:
30 | return [gpu.find('minor_number').text for gpu in gpus if
31 | check_gpu_nomem_strict(gpu)]
32 | return [gpu.find('minor_number').text for gpu in gpus]
33 |
34 | gpu_mem_needed = memstr2int(gpu_mem_needed)
35 |
36 | def check_gpu_mem_strict(gpu):
37 | return gpu_mem_needed < \
38 | memstr2int(gpu.find('fb_memory_usage').find('free').text)
39 |
40 | def check_gpu_mem_loose(gpu):
41 | return (gpu_mem_needed <
42 | memstr2int(gpu.find('fb_memory_usage').find('total').text)) \
43 | and check_gpu_nomem_strict(gpu)
44 |
45 | if strict:
46 | return [gpu.find('minor_number').text
47 | for gpu in gpus if check_gpu_mem_strict(gpu)]
48 | return [gpu.find('minor_number').text
49 | for gpu in gpus if check_gpu_mem_loose(gpu)]
50 |
51 |
52 | def _get_gpu_info():
53 | try:
54 | with subprocess.Popen(['nvidia-smi', '-q', '-x'],
55 | stdout=subprocess.PIPE,
56 | stderr=subprocess.STDOUT) as smi_proc:
57 | smi_output, _ = smi_proc.communicate()
58 | xmlroot = ET.fromstring(sixdecode(smi_output))
59 | return xmlroot.findall('gpu')
60 | except Exception:
61 | return []
62 |
63 |
64 | def get_gpus_summary():
65 | info = _get_gpu_info()
66 |
67 | def info_to_summary(gpuinfo):
68 | util = gpuinfo.find('utilization').find('gpu_util').text
69 | mem = gpuinfo.find('fb_memory_usage').find('used').text
70 |
71 | return "util: {}, mem {}".format(util, memstr2int(mem))
72 |
73 | return " ".join([
74 | "gpu {} {}".format(
75 | gpuinfo.find('minor_number').text,
76 | info_to_summary(gpuinfo)) for gpuinfo in info])
77 |
78 |
79 | def get_gpu_mapping():
80 | no_gpus = len(_get_gpu_info())
81 | return {str(i): i for i in range(no_gpus)}
82 |
83 |
84 | def _find_my_gpus(prop='minor_number'):
85 | gpu_info = _get_gpu_info()
86 | my_gpus = [g.find(prop).text for g in gpu_info if os.getpid() in [int(
87 | p.find('pid').text) for p in
88 | g.find('processes').findall('process_info')]]
89 |
90 | return my_gpus
91 |
92 |
93 | if __name__ == "__main__":
94 | print(get_gpu_mapping())
95 |
--------------------------------------------------------------------------------
/studio/torch/summary.py:
--------------------------------------------------------------------------------
1 | """Tools to simplify PyTorch reporting and integrate with TensorBoard."""
2 |
3 | import collections
4 | import six
5 | import time
6 |
7 | try:
8 | from tensorflow import summary as tb_summary
9 | except ImportError:
10 | tb_summary = None
11 |
12 |
13 | class TensorBoardWriter(object):
14 | """Write events in TensorBoard format."""
15 |
16 | def __init__(self, logdir):
17 | if tb_summary is None:
18 | raise ValueError(
19 | "You must install TensorFlow " +
20 | "to use Tensorboard summary writer.")
21 | self._writer = tb_summary.FileWriter(logdir)
22 |
23 | def add(self, step, key, value):
24 | summary = tb_summary.Summary()
25 | summary_value = summary.value.add()
26 | summary_value.tag = key
27 | summary_value.simple_value = value
28 | self._writer.add_summary(summary, global_step=step)
29 |
30 | def flush(self):
31 | self._writer.flush()
32 |
33 | def close(self):
34 | self._writer.close()
35 |
36 |
37 | class Reporter(object):
38 | """Manages reporting of metrics."""
39 |
40 | def __init__(self, log_interval=10, logdir=None, smooth_interval=10):
41 | self._writer = None
42 | if logdir:
43 | self._writer = TensorBoardWriter(logdir)
44 | self._last_step = 0
45 | self._last_reported_step = None
46 | self._last_reported_time = None
47 | self._log_interval = log_interval
48 | self._smooth_interval = smooth_interval
49 | self._metrics = collections.defaultdict(collections.deque)
50 |
51 | def record(self, step, **kwargs):
52 | for key, value in six.iteritems(kwargs):
53 | self.add(step, key, value)
54 |
55 | def add(self, step, key, value):
56 | self._last_step = step
57 | self._metrics[key].append(value)
58 | if len(self._metrics[key]) > self._smooth_interval:
59 | self._metrics[key].popleft()
60 | if self._last_step % self._log_interval == 0:
61 | if self._writer:
62 | self._writer.add(step, key, value)
63 |
64 | def report(self, stdout=None):
65 | if self._last_step % self._log_interval == 0:
66 | def smooth(values):
67 | return (sum(values) / len(values)) if values else 0.0
68 | metrics = ','.join(["%s = %.5f" % (k, smooth(v))
69 | for k, v in six.iteritems(self._metrics)])
70 | if self._last_reported_time:
71 | elapsed_secs = time.time() - self._last_reported_time
72 | metrics += " (%.3f sec)" % elapsed_secs
73 | if self._writer:
74 | elapsed_steps = float(
75 | self._last_step - self._last_reported_step)
76 | self._writer.add(
77 | self._last_step, 'step/sec',
78 | elapsed_steps / elapsed_secs)
79 |
80 | line = u"Step {}: {}".format(self._last_step, metrics)
81 | if stdout:
82 | stdout.write(line)
83 | else:
84 | print(line)
85 |
86 | self._last_reported_time = time.time()
87 | self._last_reported_step = self._last_step
88 |
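89 |
90 | # Usage sketch (illustrative; logdir is optional and only needed for
91 | # TensorBoard output, which requires TensorFlow):
92 | #
93 | #     reporter = Reporter(log_interval=10, logdir='/tmp/tblogs')
94 | #     for step in range(100):
95 | #         loss = 1.0 / (step + 1)  # stand-in for a real metric
96 | #         reporter.record(step, loss=loss)
97 | #         reporter.report()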
--------------------------------------------------------------------------------
/studio/storage/storage_handler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uuid
3 | import time
4 | from typing import Dict
5 |
6 | from studio.storage.storage_type import StorageType
7 | from studio.util.util import get_temp_filename, check_for_kb_interrupt
8 |
9 | # StorageHandler encapsulates the logic of basic storage operations
10 | # for specific storage endpoint (S3, http, local etc.)
11 | # together with access credentials for this endpoint.
12 | class StorageHandler:
13 | def __init__(self, storage_type: StorageType,
14 | logger,
15 | measure_timestamp_diff=False,
16 | compression=None):
17 | self.type = storage_type
18 | self.logger = logger
19 | self.compression = compression
20 | self._timestamp_shift = 0
21 | if measure_timestamp_diff:
22 | try:
23 | self._timestamp_shift = self._measure_timestamp_diff()
24 | except BaseException:
25 | check_for_kb_interrupt()
26 | self._timestamp_shift = 0
27 |
28 | @classmethod
29 | def get_id(cls, config: Dict) -> str:
30 |         raise NotImplementedError("Not implemented: get_id")
31 |
32 | def upload_file(self, key, local_path):
33 | raise NotImplementedError("Not implemented: upload_file")
34 |
35 | def download_file(self, key, local_path):
36 | raise NotImplementedError("Not implemented: download_file")
37 |
38 | def download_remote_path(self, remote_path, local_path):
39 | raise NotImplementedError("Not implemented: download_remote_path")
40 |
41 | def delete_file(self, key, shallow=True):
42 | raise NotImplementedError("Not implemented: delete_file")
43 |
44 | def get_file_url(self, key, method='GET'):
45 | raise NotImplementedError("Not implemented: get_file_url")
46 |
47 | def get_file_timestamp(self, key):
48 | raise NotImplementedError("Not implemented: get_file_timestamp")
49 |
50 | def get_qualified_location(self, key):
51 | raise NotImplementedError("Not implemented: get_qualified_location")
52 |
53 | def get_local_destination(self, remote_path: str):
54 | raise NotImplementedError("Not implemented: get_local_destination")
55 |
56 | def get_timestamp_shift(self):
57 | return self._timestamp_shift
58 |
59 | def get_compression(self):
60 | return self.compression
61 |
62 | def cleanup(self):
63 | pass
64 |
65 | def _measure_timestamp_diff(self):
66 | max_diff = 60
67 | tmpfile = get_temp_filename() + '.txt'
68 | with open(tmpfile, 'w') as f:
69 | f.write('timestamp_diff_test')
70 | key = 'tests/' + str(uuid.uuid4())
71 | self.upload_file(key, tmpfile)
72 | remote_timestamp = self.get_file_timestamp(key)
73 |
74 | if remote_timestamp is not None:
75 | now_remote_diff = time.time() - remote_timestamp
76 |             self.delete_file(key)
77 | os.remove(tmpfile)
78 |
79 | assert -max_diff < now_remote_diff and \
80 | now_remote_diff < max_diff, \
81 | "Timestamp difference is more than 60 seconds. " + \
82 | "You'll need to adjust local clock for caching " + \
83 | "to work correctly"
84 |
85 | return -now_remote_diff if now_remote_diff < 0 else 0
86 |
87 |
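88 | # Subclassing sketch (illustrative; the class and behavior below are
89 | # hypothetical): a concrete handler overrides the storage operations above.
90 | #
91 | #     class MyStorageHandler(StorageHandler):
92 | #         def __init__(self, config, logger):
93 | #             super().__init__(StorageType.storageLocal, logger)
94 | #
95 | #         def upload_file(self, key, local_path):
96 | #             ...  # copy local_path into the backing store under 'key'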
--------------------------------------------------------------------------------
/tests/http_provider_hosted_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import tempfile
4 | import uuid
5 |
6 | from studio import model
7 | from model_test import get_test_experiment
8 |
9 | # We are not currently working with HTTP providers.
10 | @unittest.skip("We are not currently working with HTTP providers.")
11 | class HTTPProviderHostedTest(unittest.TestCase):
12 |
13 | def get_db_provider(self, config_name):
14 | config_file = os.path.join(
15 | os.path.dirname(
16 | os.path.realpath(__file__)),
17 | config_name)
18 | return model.get_db_provider(model.get_config(config_file))
19 |
20 | def test_add_get_delete_experiment(self):
21 | with self.get_db_provider('test_config_http_client.yaml') as hp:
22 |
23 | experiment_tuple = get_test_experiment()
24 | hp.add_experiment(experiment_tuple[0])
25 | experiment = hp.get_experiment(experiment_tuple[0].key)
26 | self.assertEquals(experiment.key, experiment_tuple[0].key)
27 | self.assertEquals(
28 | experiment.filename,
29 | experiment_tuple[0].filename)
30 | self.assertEquals(experiment.args, experiment_tuple[0].args)
31 |
32 | hp.delete_experiment(experiment_tuple[1])
33 |
34 | self.assertTrue(hp.get_experiment(experiment_tuple[1]) is None)
35 |
36 | def test_start_experiment(self):
37 | with self.get_db_provider('test_config_http_client.yaml') as hp:
38 | experiment_tuple = get_test_experiment()
39 |
40 | hp.add_experiment(experiment_tuple[0])
41 | hp.start_experiment(experiment_tuple[0])
42 |
43 | experiment = hp.get_experiment(experiment_tuple[1])
44 |
45 | self.assertTrue(experiment.status == 'running')
46 |
47 | self.assertEquals(experiment.key, experiment_tuple[0].key)
48 | self.assertEquals(
49 | experiment.filename,
50 | experiment_tuple[0].filename)
51 | self.assertEquals(experiment.args, experiment_tuple[0].args)
52 |
53 | hp.finish_experiment(experiment_tuple[0])
54 | hp.delete_experiment(experiment_tuple[1])
55 |
56 | def test_add_get_experiment_artifacts(self):
57 | experiment_tuple = get_test_experiment()
58 | e_experiment = experiment_tuple[0]
59 | e_artifacts = e_experiment.artifacts
60 |
61 | a1_filename = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
62 | a2_filename = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
63 |
64 | with open(a1_filename, 'w') as f:
65 | f.write('hello world')
66 |
67 | e_artifacts['a1'] = {
68 | 'local': a1_filename,
69 | 'mutable': False
70 | }
71 |
72 | e_artifacts['a2'] = {
73 | 'local': a2_filename,
74 | 'mutable': True
75 | }
76 |
77 | with self.get_db_provider('test_config_http_client.yaml') as db:
78 | db.add_experiment(e_experiment)
79 |
80 | experiment = db.get_experiment(e_experiment.key)
81 | self.assertEquals(experiment.key, e_experiment.key)
82 | self.assertEquals(experiment.filename, e_experiment.filename)
83 | self.assertEquals(experiment.args, e_experiment.args)
84 | db.delete_experiment(e_experiment.key)
85 | os.remove(a1_filename)
86 |
87 |
88 | if __name__ == '__main__':
89 | unittest.main()
90 |
--------------------------------------------------------------------------------
/studio/experiment_submitter.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | import traceback
5 |
6 | from studio.artifacts.artifact import Artifact
7 | from studio.db_providers import db_provider_setup
8 | from studio.experiments.experiment import Experiment
9 | from studio import git_util
10 | from studio.payload_builders.payload_builder import PayloadBuilder
11 | from studio.payload_builders.unencrypted_payload_builder import UnencryptedPayloadBuilder
12 | from studio.encrypted_payload_builder import EncryptedPayloadBuilder
13 | from studio.storage import storage_setup
14 | from studio.util import util
15 |
16 | def submit_experiments(
17 | experiments,
18 | config=None,
19 | logger=None,
20 | queue=None,
21 | python_pkg=[],
22 | external_payload_builder: PayloadBuilder=None):
23 |
24 | num_experiments = len(experiments)
25 |
26 | payload_builder = external_payload_builder
27 | if payload_builder is None:
28 | # Setup our own payload builder
29 | payload_builder = UnencryptedPayloadBuilder("simple-payload")
30 | # Are we using experiment payload encryption?
31 | public_key_path = config.get('public_key_path', None)
32 | if public_key_path is not None:
33 | logger.info("Using RSA public key path: {0}".format(public_key_path))
34 | signing_key_path = config.get('signing_key_path', None)
35 | if signing_key_path is not None:
36 | logger.info("Using RSA signing key path: {0}".format(signing_key_path))
37 | payload_builder = \
38 | EncryptedPayloadBuilder(
39 | "cs-rsa-encryptor [{0}]".format(public_key_path),
40 | public_key_path, signing_key_path)
41 |
42 | start_time = time.time()
43 |
44 | # Reset our storage setup, which will guarantee
45 | # that we rebuild our database and storage provider objects
46 | # that's important in the case that previous experiment batch
47 | # cleaned up after itself.
48 | storage_setup.reset_storage()
49 |
50 | for experiment in experiments:
51 | # Update Python environment info for our experiments:
52 | experiment.pythonenv = util.add_packages(experiment.pythonenv, python_pkg)
53 |
54 | # Add experiment to database:
55 | try:
56 | with db_provider_setup.get_db_provider(config) as db:
57 | _add_git_info(experiment, logger)
58 | db.add_experiment(experiment)
59 | except BaseException:
60 | traceback.print_exc()
61 | raise
62 |
63 | payload = payload_builder.construct(experiment, config, python_pkg)
64 |
65 | logger.info("Submitting experiment: {0}"
66 | .format(json.dumps(payload, indent=4)))
67 |
68 | queue.enqueue(json.dumps(payload))
69 | logger.info("studio run: submitted experiment " + experiment.key)
70 |
71 | logger.info("Added {0} experiment(s) in {1} seconds to queue {2}"
72 | .format(num_experiments, int(time.time() - start_time), queue.get_name()))
73 | return queue.get_name()
74 |
75 | def _add_git_info(experiment: Experiment, logger):
76 | wrk_space: Artifact = experiment.artifacts.get('workspace', None)
77 | if wrk_space is not None:
78 | if wrk_space.local_path is not None and \
79 | os.path.exists(wrk_space.local_path):
80 | if logger is not None:
81 | logger.info("git location for experiment %s", wrk_space.local_path)
82 | experiment.git = git_util.get_git_info(wrk_space.local_path)
83 |
--------------------------------------------------------------------------------
/studio/scripts/gcloud_worker_startup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | exec > >(tee -i ~/gcloud_worker_logfile.txt)
4 | exec 2>&1
5 |
6 | metadata_url="http://metadata.google.internal/computeMetadata/v1/instance"
7 | queue_name=$(curl "$metadata_url/attributes/queue_name" -H "Metadata-Flavor: Google")
8 | key_name=$(curl "$metadata_url/attributes/auth_key" -H "Metadata-Flavor: Google")
9 | timeout=$(curl "$metadata_url/attributes/timeout" -H "Metadata-Flavor: Google")
10 |
11 | zone=$(curl "$metadata_url/zone" -H "Metadata-Flavor: Google")
12 | instance_name=$(curl "$metadata_url/name" -H "Metadata-Flavor: Google")
13 | group_name=$(curl "$metadata_url/attributes/groupname" -H "Metadata-Flavor: Google")
14 |
15 | echo Instance name is $instance_name
16 | echo Group name is $group_name
17 |
18 | cd ~
19 |
20 | mkdir -p .studioml/keys
21 | curl "$metadata_url/attributes/auth_data" -H "Metadata-Flavor: Google" > .studioml/keys/$key_name
22 | curl "$metadata_url/attributes/credentials" -H "Metadata-Flavor: Google" > credentials.json
23 | export GOOGLE_APPLICATION_CREDENTIALS=~/credentials.json
24 |
25 |
26 | : "${{GOOGLE_APPLICATION_CREDENTIALS?Need to point GOOGLE_APPLICATION_CREDENTIALS to the google credentials file}}"
27 | : "${{queue_name?Queue name is not specified (pass as a script argument)}}"
28 |
29 | gac_path=${{GOOGLE_APPLICATION_CREDENTIALS%/*}}
30 | gac_name=${{GOOGLE_APPLICATION_CREDENTIALS##*/}}
31 | #bash_cmd="git clone $repo && \
32 | # cd studio && \
33 | # git checkout $branch && \
34 | # sudo pip install --upgrade pip && \
35 | # sudo pip install -e . --upgrade && \
36 | # mkdir /workspace && cd /workspace && \
37 | # studio-rworker --queue=$queue_name"
38 |
39 | code_url_base="https://storage.googleapis.com/studio-ed756.appspot.com/src"
40 | #code_ver="tfstudio-64_config_location-2017-08-04_1.tgz"
41 |
42 | echo "Environment variables:"
43 | env
44 |
45 | {install_studio}
46 |
47 | python $(which studio-remote-worker) --queue=$queue_name --verbose=debug --timeout=$timeout
48 |
49 | logbucket={log_bucket}
50 | if [[ -n $logbucket ]]; then
51 | gsutil cp /var/log/syslog gs://$logbucket/$queue_name/$instance_name.log
52 | fi
53 |
54 | if [[ -n $(who) ]]; then
55 | echo "Users logged in, preventing auto-shutdown"
56 | echo "Do not forget to turn the instance off manually"
57 | exit 0
58 | fi
59 |
60 | # shutdown the instance
61 | not_spot=$(echo "$group_name" | grep "Error 404" | wc -l)
62 | echo "not_spot = $not_spot"
63 |
64 | if [[ "$not_spot" -eq "0" ]]; then
65 | current_size=$(gcloud compute instance-groups managed describe $group_name --zone $zone | grep "targetSize" | awk '{{print $2}}')
66 | echo Current group size is $current_size
67 | if [[ $current_size -gt 1 ]]; then
68 | echo "Deleting myself (that is, $instance_name) from $group_name"
69 | gcloud compute instance-groups managed delete-instances $group_name --zone $zone --instances $instance_name
70 | else
71 | template=$(gcloud compute instance-groups managed describe $group_name --zone $zone | grep "instanceTemplate" | awk '{{print $2}}')
72 | echo "Detaching myself, deleting group $group_name and the template $template"
73 | gcloud compute instance-groups managed abandon-instances $group_name --zone $zone --instances $instance_name
74 | sleep 5
75 | gcloud compute instance-groups managed delete $group_name --zone $zone --quiet
76 | sleep 5
77 | gcloud compute instance-templates delete $template --quiet
78 | fi
79 |
80 | fi
81 | echo "Shutting down"
82 | gcloud compute instances delete $instance_name --zone $zone --quiet
83 |
--------------------------------------------------------------------------------
/studio/db_providers/s3_provider.py:
--------------------------------------------------------------------------------
1 | import json
2 | from studio.db_providers.keyvalue_provider import KeyValueProvider
3 | from studio.storage.storage_handler_factory import StorageHandlerFactory
4 | from studio.storage.storage_type import StorageType
5 |
6 |
7 | class S3Provider(KeyValueProvider):
8 |
9 | def __init__(self, config, blocking_auth=True):
10 | self.config = config
11 | self.bucket = config.get('bucket', 'studioml-meta')
12 |
13 | factory: StorageHandlerFactory = StorageHandlerFactory.get_factory()
14 | self.meta_store = factory.get_handler(StorageType.storageS3, config)
15 |
16 | super().__init__(
17 | config,
18 | self.meta_store,
19 | blocking_auth)
20 |
21 | def _get(self, key, shallow=False):
22 | try:
23 | response = self.meta_store.client.list_objects(
24 | Bucket=self.bucket,
25 | Prefix=key,
26 | Delimiter='/',
27 | MaxKeys=1024*16
28 | )
29 | except Exception as exc:
30 | msg: str = "FAILED to list objects in bucket {0}: {1}"\
31 | .format(self.bucket, exc)
32 | self._report_fatal(msg)
33 | return None
34 |
35 | if response is None:
36 | return None
37 |
38 | if 'Contents' not in response.keys():
39 | return None
40 |
41 | key_count = len(response['Contents'])
42 |
43 | if key_count == 0:
44 | return None
45 |
46 | for key_item in response['Contents']:
47 | if 'Key' in key_item.keys() and key_item['Key'] == key:
48 | response = self.meta_store.client.get_object(
49 | Bucket=self.bucket,
50 | Key=key)
51 | return json.loads(response['Body'].read().decode("utf-8"))
52 |
53 | return None
54 |
55 | def _delete(self, key, shallow=True):
56 | self.logger.info("S3 deleting object: %s/%s", self.bucket, key)
57 |
58 | try:
59 | response = self.meta_store.client.delete_object(
60 | Bucket=self.bucket,
61 | Key=key)
62 | except Exception as exc:
63 | msg: str = "FAILED to delete object {0} in bucket {1}: {2}"\
64 | .format(key, self.bucket, exc)
65 | self.logger.info(msg)
66 | return
67 |
68 | reason = response['ResponseMetadata'] if response else "None"
69 | if response is None or\
70 | response['ResponseMetadata']['HTTPStatusCode'] != 204:
71 | msg: str = ('attempt to delete key {0} in bucket {1}' +
72 | ' returned response {2}')\
73 | .format(key, self.bucket, reason)
74 | self.logger.info(msg)
75 |
76 | def _set(self, key, value):
77 | try:
78 | response = self.meta_store.client.put_object(
79 | Bucket=self.bucket,
80 | Key=key,
81 | Body=json.dumps(value))
82 | except Exception as exc:
83 | msg: str = "FAILED to write object {0} in bucket {1}: {2}"\
84 | .format(key, self.bucket, exc)
85 |             self._report_fatal(msg)
86 |             return
87 | reason = response['ResponseMetadata'] if response else "None"
88 | if response is None or \
89 | response['ResponseMetadata']['HTTPStatusCode'] != 200:
90 | msg: str = ('attempt to write key {0} in bucket {1}' +
91 | ' returned response {2}')\
92 | .format(key, self.bucket, reason)
93 | self._report_fatal(msg)
94 |
--------------------------------------------------------------------------------
/studio/completion_service/encryptor.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from Crypto.PublicKey import RSA
3 | from Crypto.Cipher import PKCS1_OAEP
4 | import nacl.secret
5 | import nacl.utils
6 | import base64
7 |
8 | class Encryptor:
9 | """
10 |     Encrypts experiment payloads with a one-time NaCl session key,
11 |     which is itself wrapped using RSA public-key encryption.
12 | """
13 | def __init__(self, keypath: str):
14 | """
15 | param: keypath - file path to .pem file with public key
16 | """
17 |
18 | self.key_path = keypath
19 | self.recipient_key = None
20 | try:
21 | self.recipient_key = RSA.import_key(open(self.key_path).read())
22 |         except Exception:
23 | print(
24 | "FAILED to import recipient public key from: {0}".format(self.key_path))
25 | return
26 |
27 | def _import_rsa_key(self, key_path: str):
28 | key = None
29 | try:
30 | key = RSA.import_key(open(key_path).read())
31 |         except Exception:
32 |             print(
33 | "FAILED to import RSA key from: {0}".format(key_path))
34 | key = None
35 | return key
36 |
37 | def _encrypt_str(self, workload: str):
38 | # Generate one-time symmetric session key:
39 | session_key = nacl.utils.random(32)
40 |
41 | # Encrypt the data with the NaCL session key
42 | data_to_encrypt = workload.encode("utf-8")
43 | box_out = nacl.secret.SecretBox(session_key)
44 | encrypted_data = box_out.encrypt(data_to_encrypt)
45 | encrypted_data_text = base64.b64encode(encrypted_data)
46 |
47 | # Encrypt the session key with the public RSA key
48 | cipher_rsa = PKCS1_OAEP.new(self.recipient_key)
49 | encrypted_session_key = cipher_rsa.encrypt(session_key)
50 | encrypted_session_key_text = base64.b64encode(encrypted_session_key)
51 |
52 | return encrypted_session_key_text, encrypted_data_text
53 |
54 | def _decrypt_data(self, private_key_path, encrypted_key_text, encrypted_data_text):
55 | private_key = self._import_rsa_key(private_key_path)
56 | if private_key is None:
57 | return None
58 |
59 | try:
60 | private_key = RSA.import_key(open(private_key_path).read())
61 |         except Exception:
62 |             print(
63 | "FAILED to import private key from: {0}".format(private_key_path))
64 | return None
65 |
66 | # Decrypt the session key with the private RSA key
67 | cipher_rsa = PKCS1_OAEP.new(private_key)
68 | session_key = cipher_rsa.decrypt(
69 | base64.b64decode(encrypted_key_text))
70 |
71 | # Decrypt the data with the NaCL session key
72 | box_in = nacl.secret.SecretBox(session_key)
73 | decrypted_data = box_in.decrypt(
74 | base64.b64decode(encrypted_data_text))
75 | decrypted_data = decrypted_data.decode("utf-8")
76 |
77 | return decrypted_data
78 |
79 | def encrypt(self, payload: str):
80 | enc_key, enc_payload = self._encrypt_str(payload)
81 |
82 | enc_key_str = enc_key.decode("utf-8")
83 | enc_payload_str = enc_payload.decode("utf-8")
84 |
85 | return "{0},{1}".format(enc_key_str, enc_payload_str)
86 |
87 | def main():
88 | if len(sys.argv) < 3:
89 | print("USAGE {0} public-key-file-path string-to-encrypt"
90 | .format(sys.argv[0]))
91 | return
92 |
93 | encryptor = Encryptor(sys.argv[1])
94 | data = sys.argv[2]
95 | print(data)
96 | result = encryptor.encrypt(data)
97 | print(result)
98 |
99 | if __name__ == '__main__':
100 | main()
101 |
--------------------------------------------------------------------------------
/docs/README.rst:
--------------------------------------------------------------------------------
6 |
7 | |Hex.pm| |Build.pm|
8 |
9 | Studio is a model management framework written in Python to help simplify and expedite your model building experience. It was developed to minimize the overhead involved with scheduling, running, monitoring and managing artifacts of your machine learning experiments. No one wants to spend their time configuring different machines, setting up dependencies, or playing archeologist to track down previous model artifacts.
10 |
11 | Most of the features are compatible with any Python machine learning
12 | framework (`Keras `__,
13 | `TensorFlow `__,
14 | `PyTorch `__,
15 | `scikit-learn `__, etc);
16 | some extra features are available for Keras and TensorFlow.
17 |
18 | **Use Studio to:**
19 |
20 | * Capture experiment information (Python environment, files, dependencies, and logs) without modifying the experiment code.
21 | * Monitor and organize experiments using a web dashboard that integrates with TensorBoard.
22 | * Run experiments locally, remotely, or in the cloud (Google Cloud or Amazon EC2).
23 | * Manage artifacts.
24 | * Perform hyperparameter search.
25 | * Create customizable Python environments for remote workers.
26 |
27 | NOTE: the ``studio`` package is compatible with both Python 2 and 3!
28 |
29 | Example usage
30 | -------------
31 |
32 | Start visualizer:
33 |
34 | ::
35 |
36 | studio ui
37 |
38 | Run your jobs:
39 |
40 | ::
41 |
42 | studio run train_mnist_keras.py
43 |
44 | You can see results of your job at http://localhost:5000. Run
45 | ``studio {ui|run} --help`` for a full list of ui / runner options.
46 | WARNING: because studio tries to create a reproducible environment
47 | for your experiment, if you run it in a large folder, it will take
48 | a while to archive and upload the folder.
49 |
50 | Installation
51 | ------------
52 |
53 | Install ``studioml`` from the main PyPI repository:
54 |
55 | ::
56 |
57 | pip install studioml
58 |
59 | Find more `details `__ on installation methods and the release process.
60 |
61 | Authentication
62 | --------------
63 |
64 | Currently Studio supports two methods of authentication: `email / password `__ and a `Google account `__. To use studio runner and studio ui in guest
65 | mode, uncomment "guest: true" under the database section in
66 | studio/default\_config.yaml.
67 |
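For reference, the relevant fragment of the config would look roughly like
this (a sketch; only the ``guest`` flag is the line you change, the rest of
the database section stays as it is in your config):

::

    database:
        # ... existing database settings ...
        guest: true
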
68 | Alternatively, you can set up your own database and configure Studio to
69 | use it. See `setting up database `__. This is a
70 | preferred option if you want to keep your models and artifacts private.
71 |
72 |
73 | Further reading and cool features
74 | ---------------------------------
75 |
76 | - `Running experiments remotely `__
77 |
78 | - `Custom Python environments for remote workers `__
79 |
80 | - `Running experiments in the cloud `__
81 |
82 | - `Google Cloud setup instructions `__
83 |
84 | - `Amazon EC2 setup instructions `__
85 |
86 | - `Artifact management `__
87 | - `Hyperparameter search `__
88 | - `Pipelines for trained models `__
89 | - `Containerized experiments `__
90 |
91 | .. |Hex.pm| image:: https://img.shields.io/hexpm/l/plug.svg
92 | :target: https://github.com/studioml/studio/blob/master/LICENSE
93 |
94 | .. |Build.pm| image:: https://travis-ci.org/studioml/studio.svg?branch=master
95 |    :target: https://travis-ci.org/studioml/studio
96 |
--------------------------------------------------------------------------------
/studio/magics.py:
--------------------------------------------------------------------------------
1 | # This code can be put in any Python module, it does not require IPython
2 | # itself to be running already. It only creates the magics subclass but
3 | # doesn't instantiate it yet.
4 | # from __future__ import print_function
5 | import pickle
6 | from IPython.core.magic import (Magics, magics_class, line_cell_magic)
7 |
8 | from types import ModuleType
9 | import six
10 | import subprocess
11 | import uuid
12 | import os
13 | import time
14 | import gzip
15 | from apscheduler.schedulers.background import BackgroundScheduler
16 |
17 | from studio.extra_util import rsync_cp
18 | from studio import fs_tracker, model
19 | from studio.runner import main as runner_main
20 |
21 | from studio.util import logs
22 |
23 |
24 | @magics_class
25 | class StudioMagics(Magics):
26 |
27 | @line_cell_magic
28 | def studio_run(self, line, cell=None):
29 | script_text = []
30 | pickleable_ns = {}
31 |
32 | for varname, var in six.iteritems(self.shell.user_ns):
33 | if not varname.startswith('__'):
34 | if isinstance(var, ModuleType) and \
35 | var.__name__ != 'studio.magics':
36 | script_text.append(
37 | 'import {} as {}'.format(var.__name__, varname)
38 | )
39 |
40 | else:
41 | try:
42 | pickle.dumps(var)
43 | pickleable_ns[varname] = var
44 | except BaseException:
45 | pass
46 |
47 | script_text.append(cell)
48 | script_text = '\n'.join(script_text)
49 | stub_path = os.path.join(
50 | os.path.dirname(os.path.realpath(__file__)),
51 | 'run_magic.py.stub')
52 |
53 | with open(stub_path) as f:
54 | script_stub = f.read()
55 |
56 | script = script_stub.format(script=script_text)
57 |
58 | experiment_key = str(int(time.time())) + \
59 | "_jupyter_" + str(uuid.uuid4())
60 |
61 | print('Running studio with experiment key ' + experiment_key)
62 | config = model.get_config()
63 | if config['database']['type'] == 'http':
64 | print("Experiment progress can be viewed/shared at:")
65 | print("{}/experiment/{}".format(
66 | config['database']['serverUrl'],
67 | experiment_key))
68 |
69 | workspace_new = fs_tracker.get_artifact_cache(
70 | 'workspace', experiment_key)
71 |
72 | rsync_cp('.', workspace_new)
73 | with open(os.path.join(workspace_new, '_script.py'), 'w') as f:
74 | f.write(script)
75 |
76 | ns_path = fs_tracker.get_artifact_cache('_ns', experiment_key)
77 |
78 | with gzip.open(ns_path, 'wb') as f:
79 | f.write(pickle.dumps(pickleable_ns))
80 |
81 | if any(line):
82 | runner_args = line.strip().split(' ')
83 | else:
84 | runner_args = []
85 |
86 | runner_args.append('--capture={}:_ns'.format(ns_path))
87 | runner_args.append('--capture-once=.:workspace')
88 | runner_args.append('--force-git')
89 | runner_args.append('--experiment=' + experiment_key)
90 |
91 | notebook_cwd = os.getcwd()
92 | os.chdir(workspace_new)
93 | print(runner_args + ['_script.py'])
94 | runner_main(runner_args + ['_script.py'])
95 | os.chdir(notebook_cwd)
96 |
97 | with model.get_db_provider() as db:
98 | while True:
99 | experiment = db.get_experiment(experiment_key)
100 | if experiment and experiment.status == 'finished':
101 | break
102 |
103 | time.sleep(10)
104 |
105 | new_ns_path = db.get_artifact(experiment.artifacts['_ns'])
106 |
107 |             with open(new_ns_path, 'rb') as f:
108 | new_ns = pickle.loads(f.read())
109 |
110 | self.shell.user_ns.update(new_ns)
111 |
112 |
113 | ip = get_ipython()
114 | ip.register_magics(StudioMagics)
115 |
--------------------------------------------------------------------------------
/studio/scripts/studio-start-remote-worker:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Script that starts remote worker
4 | # Usage:
5 | #
6 | # studio-start-remote-worker --queue= [--image= argument"
57 | exit 1
58 | fi
59 |
60 | echo "Docker image = $docker_img"
61 | echo "Queue = $queue_name"
62 |
63 | eval $docker_cmd
64 | if [ $? != 0 ]; then
65 | echo "Docker not installed! Install docker."
66 | exit 1
67 | fi
68 |
69 | eval nvidia-smi
70 | if [ $? == 0 ]; then
71 | eval nvidia-docker
72 | if [ $? == 0 ]; then
73 | docker_cmd=nvidia-docker
74 | bash_cmd=""
75 | else
76 | echo "Warning! nvidia-docker is not installed correctly, won't be able to use gpus"
77 | fi
78 | fi
79 |
80 | echo "docker_cmd = $docker_cmd"
81 | if [ $docker_cmd == 'docker' ]; then
82 | bash_cmd="$bash_cmd python -m pip install -e ./studio && python3 -m pip install -e ./studio && "
83 | fi
84 |
85 | bash_cmd="$bash_cmd studio-remote-worker --queue=$queue_name --verbose=debug --timeout=$timeout --single-run"
86 |
87 | : "${GOOGLE_APPLICATION_CREDENTIALS?Need to point GOOGLE_APPLICATION_CREDENTIALS to the google credentials file}"
88 | : "${queue_name?Queue name is not specified (pass as a script argument}"
89 |
90 | gac_path=${GOOGLE_APPLICATION_CREDENTIALS%/*}
91 | gac_name=${GOOGLE_APPLICATION_CREDENTIALS##*/}
92 |
93 | #bash_cmd="git clone $repo && \
94 | # cd studio && \
95 | # git checkout $branch && \
96 | # sudo pip install --upgrade pip && \
97 | # sudo pip install -e . --upgrade && \
98 | # mkdir /workspace && cd /workspace && \
99 | # studio-rworker --queue=$queue_name"
100 |
101 |
102 | # loop until killed
103 |
104 | docker_args="run --rm --pid=host "
105 | if [[ $no_cache -ne 1 ]]; then
106 |     docker_args="$docker_args -v $HOME/.studioml/experiments:/root/.studioml/experiments"
107 | docker_args="$docker_args -v $HOME/.studioml/blobcache:/root/.studioml/blobcache"
108 | fi
109 |
110 | if [[ $baked_credentials -ne 1 ]]; then
111 | docker_args="$docker_args -v $HOME/.studioml/keys:/root/.studioml/keys"
112 | docker_args="$docker_args -v $gac_path:/creds -e GOOGLE_APPLICATION_CREDENTIALS=/creds/$gac_name"
113 | fi
114 |
115 | echo "Docker args = $docker_args"
116 | echo "Docker image = $docker_img"
117 |
118 | while true
119 | do
120 | echo $bash_cmd
121 | $docker_cmd pull $docker_img
122 | $docker_cmd $docker_args $docker_img /bin/bash -c "$bash_cmd"
123 | if [ $single_run ];
124 | then
125 | exit 0
126 | fi
127 | done
128 |
129 |
--------------------------------------------------------------------------------
/studio/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import time
4 | from terminaltables import AsciiTable
5 |
6 | from studio import model
7 | from studio.util import logs
8 |
9 | _my_logger = None
10 |
11 |
12 | def print_help():
13 | print('Usage: studio runs [command] arguments')
14 | print('\ncommand can be one of the following:')
15 | print('')
16 | print('\tlist [username] - display experiments')
17 | print('\tstop [experiment] - stop running experiment')
18 | print('\tkill [experiment] - stop and delete experiment')
19 | print('')
20 |
21 |
22 | def main():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument('--config', help='configuration file', default=None)
25 | parser.add_argument(
26 | '--short', '-s', help='Brief output - names of experiments only',
27 | action='store_true')
28 |
29 | cli_args, script_args = parser.parse_known_args(sys.argv)
30 |
31 | get_logger().setLevel(10)
32 |
33 | if len(script_args) < 2:
34 | get_logger().critical('No command provided!')
35 | parser.print_help()
36 | print_help()
37 | return
38 |
39 | cmd = script_args[1]
40 |
41 | if cmd == 'list':
42 | _list(script_args[2:], cli_args)
43 | elif cmd == 'stop':
44 | _stop(script_args[2:], cli_args)
45 | elif cmd == 'kill':
46 | _kill(script_args[2:], cli_args)
47 |
48 | else:
49 | get_logger().critical('Unknown command ' + cmd)
50 | parser.print_help()
51 | print_help()
52 | return
53 |
54 |
55 | def _list(args, cli_args):
56 | with model.get_db_provider(cli_args.config) as db:
57 | if len(args) == 0:
58 | experiments = db.get_user_experiments()
59 | elif args[0] == 'project':
60 | assert len(args) == 2
61 | experiments = db.get_project_experiments(args[1])
62 | elif args[0] == 'users':
63 | assert len(args) == 1
64 | users = db.get_users()
65 | for u in users:
66 | print(users[u].get('email'))
67 | return
68 | elif args[0] == 'user':
69 | assert len(args) == 2
70 | users = db.get_users()
71 | user_ids = [u for u in users if users[u].get('email') == args[1]]
72 | assert len(user_ids) == 1, \
73 | 'The user with email ' + args[1] + \
74 |                 ' not found!'
75 | experiments = db.get_user_experiments(user_ids[0])
76 | elif args[0] == 'all':
77 | assert len(args) == 1
78 | users = db.get_users()
79 | experiments = []
80 | for u in users:
81 | experiments += db.get_user_experiments(u)
82 | else:
83 | get_logger().critical('Unknown command ' + args[0])
84 | return
85 |
86 | if cli_args.short:
87 | for e in experiments:
88 | print(e)
89 | return
90 |
91 | experiments = [db.get_experiment(e) for e in experiments]
92 |
93 | experiments.sort(key=lambda e: -e.time_added)
94 | table = [['Time added', 'Key', 'Project', 'Status']]
95 |
96 | for e in experiments:
97 | table.append([
98 | time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(e.time_added)),
99 | e.key,
100 | e.project,
101 | e.status])
102 |
103 | print(AsciiTable(table).table)
104 |
105 |
106 | def _stop(args, cli_args):
107 | with model.get_db_provider(cli_args.config) as db:
108 | for e in args:
109 | get_logger().info('Stopping experiment ' + e)
110 | db.stop_experiment(e)
111 |
112 |
113 | def _kill(args, cli_args):
114 | with model.get_db_provider(cli_args.config) as db:
115 | for e in args:
116 | get_logger().info('Deleting experiment ' + e)
117 | db.delete_experiment(e)
118 |
119 |
120 | def get_logger():
121 | global _my_logger
122 | if not _my_logger:
123 | _my_logger = logs.get_logger('studio-runs')
124 | return _my_logger
125 |
126 |
127 | if __name__ == '__main__':
128 | main()
129 |
--------------------------------------------------------------------------------
/studio/scripts/install_studio.sh:
--------------------------------------------------------------------------------
1 | repo_url="{repo_url}"
2 | branch="{studioml_branch}"
3 |
4 | if [ ! -d "studio" ]; then
5 | echo "Installing system packages..."
6 | sudo apt -y update
7 | sudo apt install -y wget git jq
8 | sudo apt install -y python python-pip python-dev python3 python3-dev python3-pip dh-autoreconf build-essential
9 | echo "python2 version: " $(python -V)
10 |
11 | sudo python -m pip install --upgrade pip
12 | sudo python -m pip install --upgrade awscli boto3
13 |
14 | echo "python3 version: " $(python3 -V)
15 |
16 | sudo python3 -m pip install --upgrade pip
17 | sudo python3 -m pip install --upgrade awscli boto3
18 |
19 | # Install singularity
20 | git clone https://github.com/singularityware/singularity.git
21 | cd singularity
22 | ./autogen.sh
23 | ./configure --prefix=/usr/local --sysconfdir=/etc
24 | time (make && make install)
25 | cd ..
26 |
27 | time apt-get -y install python3-tk python-tk
28 | time python -m pip install http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp27-cp27mu-linux_x86_64.whl
29 | time python3 -m pip install http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp35-cp35m-linux_x86_64.whl
30 |
31 | nvidia-smi
32 | nvidia_smi_error=$?
33 |
34 | if [ "{use_gpus}" -eq 1 ] && [ "$nvidia_smi_error" -ne 0 ]; then
35 | cudnn5="libcudnn5_5.1.10-1_cuda8.0_amd64.deb"
36 | cudnn6="libcudnn6_6.0.21-1_cuda8.0_amd64.deb"
37 | cuda_base="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/"
38 | cuda_ver="cuda-repo-ubuntu1604_8.0.61-1_amd64.deb"
39 |
40 | cuda_driver='nvidia-diag-driver-local-repo-ubuntu1604-384.66_1.0-1_amd64.deb'
41 | wget $code_url_base/$cuda_driver
42 | dpkg -i $cuda_driver
43 | apt-key add /var/nvidia-diag-driver-local-repo-384.66/7fa2af80.pub
44 | apt-get -y update
45 | apt-get -y install cuda-drivers
46 | apt-get -y install unzip
47 |
48 | # install cuda
49 | cuda_url="https://developer.nvidia.com/compute/cuda/8.0/Prod2/local_installers/cuda_8.0.61_375.26_linux-run"
50 | cuda_patch_url="https://developer.nvidia.com/compute/cuda/8.0/Prod2/patches/2/cuda_8.0.61.2_linux-run"
51 |
52 | # install cuda
53 | wget $cuda_url
54 | wget $cuda_patch_url
55 | #udo dpkg -i $cuda_ver
56 | #sudo apt -y update
57 | #sudo apt install -y "cuda-8.0"
58 | sh ./cuda_8.0.61_375.26_linux-run --silent --toolkit
59 | sh ./cuda_8.0.61.2_linux-run --silent --accept-eula
60 |
61 | export PATH=$PATH:/usr/local/cuda-8.0/bin
62 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-8.0/lib64
63 |
64 | # wget $cuda_base/$cuda_ver
65 | # sudo dpkg -i $cuda_ver
66 | # sudo apt -y update
67 | # sudo apt install -y "cuda-8.0"
68 |
69 | # install cudnn
70 | wget $code_url_base/$cudnn5
71 | wget $code_url_base/$cudnn6
72 | sudo dpkg -i $cudnn5
73 | sudo dpkg -i $cudnn6
74 |
75 | # sudo python -m pip install tf-nightly tf-nightly-gpu --upgrade
76 | # sudo python3 -m pip install tf-nightly tf-nightly-gpu --upgrade
77 | else
78 | sudo apt install -y default-jre
79 | fi
80 | fi
81 |
82 | #if [[ "{use_gpus}" -ne 1 ]]; then
83 | # rm -rf /usr/lib/x86_64-linux-gnu/libcuda*
84 | #fi
85 |
86 | rm -rf studio
87 | git clone $repo_url
88 | if [[ $? -ne 0 ]]; then
89 | git clone https://github.com/studioml/studio
90 | fi
91 |
92 | cd studio
93 | git pull
94 | git checkout $branch
95 |
96 | apoptosis() {{
97 | while :
98 | do
99 | date
100 | shutdown +1
101 | (nvidia-smi; shutdown -c; echo "nvidia-smi functional, preventing shutdown")&
102 | sleep 90
103 | done
104 |
105 | }}
106 |
107 | if [[ "{use_gpus}" -eq 1 ]]; then
108 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-8.0/lib64
109 | # (apoptosis > autoterminate.log)&
110 | fi
111 |
112 | time python -m pip install -e . --upgrade
113 | time python3 -m pip install -e . --upgrade
114 |
115 |
116 |
117 |
--------------------------------------------------------------------------------
/test-runner.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | apiVersion: v1
3 | kind: PersistentVolumeClaim
4 | metadata:
5 | # This name uniquely identifies the PVC. Will be used in deployment below.
6 | name: test-runner-pv-claim
7 | labels:
8 | app: test-runner-storage-claim
9 | namespace: default
10 | spec:
11 | # Read more about access modes here: https://kubernetes.io/docs/user-guide/persistent-volumes/#access-modes
12 | accessModes:
13 | - ReadWriteMany
14 | resources:
15 | # This is the request for storage. Should be available in the cluster.
16 | requests:
17 | storage: 10Gi
18 | # Uncomment and add storageClass specific to your requirements below. Read more https://kubernetes.io/docs/concepts/storage/persistent-volumes/#class-1
19 | #storageClassName:
20 | ---
21 | apiVersion: extensions/v1beta1
22 | kind: Deployment
23 | metadata:
24 | name: test-runner
25 | namespace: default
26 | labels:
27 |     name: "test-runner"
28 | keel.sh/policy: force
29 | keel.sh/trigger: poll
30 | spec:
31 | template:
32 | metadata:
33 | name: test-runner
34 | namespace: default
35 | labels:
36 | app: test-runner
37 | spec:
38 | volumes:
39 | - name: test-runner-storage
40 | persistentVolumeClaim:
41 | # Name of the PVC created earlier
42 | claimName: test-runner-pv-claim
43 | - name: podinfo
44 | downwardAPI:
45 | items:
46 | - path: "namespace"
47 | fieldRef:
48 | fieldPath: metadata.namespace
49 | - path: "annotations"
50 | fieldRef:
51 | fieldPath: metadata.annotations
52 | - path: "labels"
53 | fieldRef:
54 | fieldPath: metadata.labels
55 | containers:
56 | - image: xujiamin9/standalone_testing
57 | imagePullPolicy: Always
58 | name: test-runner
59 | env:
60 | - name: K8S_POD_NAME
61 | valueFrom:
62 | fieldRef:
63 | fieldPath: metadata.name
64 | - name: K8S_NAMESPACE
65 | valueFrom:
66 | fieldRef:
67 | fieldPath: metadata.namespace
68 | ports:
69 | - containerPort: 8080
70 | resources:
71 | requests:
72 | memory: "128Mi"
73 | cpu: "500m"
74 | limits:
75 | memory: "1024Mi"
76 | cpu: 4
77 | volumeMounts:
78 | - name: test-runner-storage # must match the volume name, above
79 | mountPath: "/build"
80 | - name: podinfo
81 | mountPath: /etc/podinfo
82 | readOnly: false
83 | lifecycle:
84 | postStart:
85 | exec:
86 | command:
87 | - "/bin/bash"
88 | - "-c"
89 | - >
90 | set -euo pipefail ;
91 | IFS=$'\n\t' ;
92 | echo "Starting the keel modifications" $K8S_POD_NAME ;
93 | kubectl label deployment test-runner keel.sh/policy- --namespace=$K8S_NAMESPACE ;
94 |                 curl -v --cacert /var/run/secrets/kubernetes.io/serviceaccount/ca.crt -H "Authorization: Bearer $(cat /var/run/secrets/kubernetes.io/serviceaccount/token)" https://$KUBERNETES_SERVICE_HOST:$KUBERNETES_PORT_443_TCP_PORT/api/v1/namespaces/$K8S_NAMESPACE/pods/$K8S_POD_NAME
95 | preStop:
96 | exec:
97 | command:
98 | - "/bin/bash"
99 | - "-c"
100 | - >
101 | set -euo pipefail;
102 | IFS=$'\n\t' ;
103 | echo "Starting the namespace injections etc" $K8S_POD_NAME ;
104 | kubectl label deployment test-runner keel.sh/policy=force --namespace=$K8S_NAMESPACE ;
105 | for (( ; ; )) ;
106 |                 do
107 | sleep 10 ;
108 | done
109 |
--------------------------------------------------------------------------------
/studio/storage/local_storage_handler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from typing import Dict
4 |
5 | from studio.storage.storage_setup import get_storage_verbose_level
6 | from studio.storage.storage_type import StorageType
7 | from studio.storage.storage_handler import StorageHandler
8 | from studio.util import logs
9 | from studio.util import util
10 |
11 | class LocalStorageHandler(StorageHandler):
12 | def __init__(self, config,
13 | measure_timestamp_diff=False,
14 | compression=None):
15 |
16 | self.logger = logs.get_logger(self.__class__.__name__)
17 | self.logger.setLevel(get_storage_verbose_level())
18 |
19 | if compression is None:
20 | compression = config.get('compression', None)
21 |
22 | self.endpoint = config.get('endpoint', '~')
23 | self.endpoint = os.path.realpath(os.path.expanduser(self.endpoint))
24 | if not os.path.exists(self.endpoint) \
25 | or not os.path.isdir(self.endpoint):
26 | msg: str = "Store root {0} doesn't exist or not a directory. Aborting."\
27 | .format(self.endpoint)
28 | self._report_fatal(msg)
29 |
30 | self.bucket = config.get('bucket', 'storage')
31 | self.store_root = os.path.join(self.endpoint, self.bucket)
32 | self._ensure_path_dirs_exist(self.store_root)
33 |
34 | super().__init__(StorageType.storageLocal,
35 | self.logger,
36 | measure_timestamp_diff,
37 | compression=compression)
38 |
39 | def _ensure_path_dirs_exist(self, path):
40 | dirs = os.path.dirname(path)
41 |         os.makedirs(dirs, mode=0o777, exist_ok=True)
42 |
43 | def _copy_file(self, from_path, to_path):
44 | try:
45 | shutil.copyfile(from_path, to_path)
46 | except Exception as exc:
47 | msg: str = "FAILED to copy '{0}' to '{1}': {2}. Aborting."\
48 |                 .format(from_path, to_path, exc)
49 | self._report_fatal(msg)
50 |
51 | def upload_file(self, key, local_path):
52 | target_path = os.path.join(self.store_root, key)
53 | if not os.path.exists(local_path):
54 | self.logger.debug(
55 | "Local path {0} does not exist. SKIPPING upload to {1}"
56 | .format(local_path, target_path))
57 | return False
58 | self._ensure_path_dirs_exist(target_path)
59 | self._copy_file(local_path, target_path)
60 | return True
61 |
62 | def download_file(self, key, local_path):
63 | source_path = os.path.join(self.store_root, key)
64 | if not os.path.exists(source_path):
65 | self.logger.debug(
66 | "Source path {0} does not exist. SKIPPING download to {1}"
67 | .format(source_path, local_path))
68 | return False
69 | self._ensure_path_dirs_exist(local_path)
70 | self._copy_file(source_path, local_path)
71 | return True
72 |
73 | def delete_file(self, key, shallow=True):
74 | key_path: str = self._get_file_path_from_key(key)
75 | if os.path.exists(key_path):
76 | self.logger.debug("Deleting local file {0}.".format(key_path))
77 | util.delete_local_path(key_path, self.store_root, False)
78 |
79 | @classmethod
80 | def get_id(cls, config: Dict) -> str:
81 | endpoint = config.get('endpoint', None)
82 | if endpoint is None:
83 | return None
84 | return '[local]{0}'.format(endpoint)
85 |
86 | def _get_file_path_from_key(self, key: str):
87 | return str(os.path.join(self.store_root, key))
88 |
89 | def get_file_url(self, key, method='GET'):
90 | return self._get_file_path_from_key(key)
91 |
92 | def get_file_timestamp(self, key):
93 | key_path: str = self._get_file_path_from_key(key)
94 | if os.path.exists(key_path):
95 | return os.path.getmtime(key_path)
96 | else:
97 | return None
98 |
99 | def get_qualified_location(self, key):
100 | return 'file:/' + self.store_root + '/' + key
101 |
102 | def get_endpoint(self):
103 | return self.endpoint
104 |
105 | def _report_fatal(self, msg: str):
106 | util.report_fatal(msg, self.logger)
107 |
--------------------------------------------------------------------------------
/studio/model.py:
--------------------------------------------------------------------------------
1 | """Data providers."""
2 | import uuid
3 |
4 | from studio.firebase_provider import FirebaseProvider
5 | from studio.http_provider import HTTPProvider
6 | from studio.pubsub_queue import PubsubQueue
7 | from studio.gcloud_worker import GCloudWorkerManager
8 | from studio.ec2cloud_worker import EC2WorkerManager
9 | from studio.util.util import parse_verbosity
10 | from studio.auth import get_auth
11 |
12 | from studio.db_providers import db_provider_setup
13 | from studio.queues import queues_setup
14 | from studio.storage.storage_setup import setup_storage, get_storage_db_provider,\
15 | reset_storage, set_storage_verbose_level
16 | from studio.util import logs
17 |
18 | def reset_storage_providers():
19 | reset_storage()
20 |
21 |
22 | def get_config(config_file=None):
23 | return db_provider_setup.get_config(config_file=config_file)
24 |
25 |
26 | def get_db_provider(config=None, blocking_auth=True):
27 |
28 | db_provider = get_storage_db_provider()
29 | if db_provider is not None:
30 | return db_provider
31 |
32 | if config is None:
33 | config = get_config()
34 | verbose = parse_verbosity(config.get('verbose', None))
35 |
36 | # Save this verbosity level as global for the whole experiment job:
37 | set_storage_verbose_level(verbose)
38 |
39 | logger = logs.get_logger("get_db_provider")
40 | logger.setLevel(verbose)
41 | logger.debug('Choosing db provider with config:')
42 | logger.debug(config)
43 |
44 | if 'storage' in config.keys():
45 | artifact_store = db_provider_setup.get_artifact_store(config['storage'])
46 | else:
47 | artifact_store = None
48 |
49 | assert 'database' in config.keys()
50 | db_config = config['database']
51 | if db_config['type'].lower() == 'firebase':
52 | db_provider = FirebaseProvider(db_config,
53 | blocking_auth=blocking_auth)
54 |
55 | elif db_config['type'].lower() == 'http':
56 | db_provider = HTTPProvider(db_config,
57 | verbose=verbose,
58 | blocking_auth=blocking_auth)
59 | else:
60 | db_provider = db_provider_setup.get_db_provider(
61 | config=config, blocking_auth=blocking_auth)
62 |
63 | setup_storage(db_provider, artifact_store)
64 | return db_provider
65 |
66 | def get_queue(
67 | queue_name=None,
68 | cloud=None,
69 | config=None,
70 | logger=None,
71 | close_after=None,
72 | verbose=10):
73 | queue = queues_setup.get_queue(queue_name=queue_name,
74 | cloud=cloud,
75 | config=config,
76 | logger=logger,
77 | close_after=close_after,
78 | verbose=verbose)
79 | if queue is None:
80 | queue = PubsubQueue(queue_name, verbose=verbose)
81 | return queue
82 |
83 | def shutdown_queue(queue, logger=None, delete_queue=True):
84 | queues_setup.shutdown_queue(queue, logger=logger, delete_queue=delete_queue)
85 |
86 | def get_worker_manager(config, cloud=None, verbose=10):
87 | if cloud is None:
88 | return None
89 |
90 | assert cloud in ['gcloud', 'gcspot', 'ec2', 'ec2spot']
91 | logger = logs.get_logger('runner.get_worker_manager')
92 | logger.setLevel(verbose)
93 |
94 | auth = get_auth(config['database']['authentication'])
95 | auth_cookie = auth.get_token_file() if auth else None
96 |
97 | branch = config['cloud'].get('branch')
98 |
99 | logger.info('using branch {}'.format(branch))
100 |
101 | if cloud in ['gcloud', 'gcspot']:
102 |
103 | cloudconfig = config['cloud']['gcloud']
104 | worker_manager = GCloudWorkerManager(
105 | auth_cookie=auth_cookie,
106 | zone=cloudconfig['zone'],
107 | branch=branch,
108 | user_startup_script=config['cloud'].get('user_startup_script')
109 | )
110 |
111 | if cloud in ['ec2', 'ec2spot']:
112 | worker_manager = EC2WorkerManager(
113 | auth_cookie=auth_cookie,
114 | branch=branch,
115 | user_startup_script=config['cloud'].get('user_startup_script')
116 | )
117 | return worker_manager
118 |
119 |
--------------------------------------------------------------------------------
/studio/db_providers/db_provider_setup.py:
--------------------------------------------------------------------------------
1 | """Data providers."""
2 | import os
3 | import yaml
4 | import pyhocon
5 |
6 | from studio.db_providers.local_db_provider import LocalDbProvider
7 | from studio.db_providers.s3_provider import S3Provider
8 | from studio.storage.storage_handler import StorageHandler
9 | from studio.storage.storage_handler_factory import StorageHandlerFactory
10 | from studio.storage.storage_setup import setup_storage, get_storage_db_provider,\
11 | set_storage_verbose_level
12 | from studio.storage.storage_type import StorageType
13 | from studio.util import logs
14 | from studio.util.util import parse_verbosity
15 |
16 | def get_config(config_file=None):
17 |
18 | config_paths = []
19 | if config_file:
20 | if not os.path.exists(config_file):
21 | raise ValueError('User config file {} not found'
22 | .format(config_file))
23 | config_paths.append(os.path.expanduser(config_file))
24 |
25 | config_paths.append(os.path.expanduser('~/.studioml/config.yaml'))
26 | config_paths.append(
27 | os.path.join(
28 | os.path.dirname(os.path.realpath(__file__)),
29 | "default_config.yaml"))
30 |
31 | for path in config_paths:
32 | if not os.path.exists(path):
33 | continue
34 |
35 |         with open(path) as f_in:
36 | if path.endswith('.hocon'):
37 | config = pyhocon.ConfigFactory.parse_string(f_in.read())
38 | else:
39 | config = yaml.load(f_in.read(), Loader=yaml.FullLoader)
40 |
41 | def replace_with_env(config):
42 | for key, value in config.items():
43 | if isinstance(value, str):
44 | config[key] = os.path.expandvars(value)
45 |
46 | elif isinstance(value, dict):
47 | replace_with_env(value)
48 |
49 | replace_with_env(config)
50 |
51 | return config
52 |
53 | raise ValueError('None of the config paths {0} exists!'
54 | .format(config_paths))
55 |
56 | def get_artifact_store(config) -> StorageHandler:
57 | storage_type: str = config['type'].lower()
58 |
59 | factory: StorageHandlerFactory = StorageHandlerFactory.get_factory()
60 | if storage_type == 's3':
61 | handler = factory.get_handler(StorageType.storageS3, config)
62 | return handler
63 | if storage_type == 'local':
64 | handler = factory.get_handler(StorageType.storageLocal, config)
65 | return handler
66 | raise ValueError('Unknown storage type: ' + storage_type)
67 |
68 | def get_db_provider(config=None, blocking_auth=True):
69 |
70 | db_provider = get_storage_db_provider()
71 | if db_provider is not None:
72 | return db_provider
73 |
74 | if config is None:
75 | config = get_config()
76 | verbose = parse_verbosity(config.get('verbose', None))
77 |
78 | # Save this verbosity level as global for the whole experiment job:
79 | set_storage_verbose_level(verbose)
80 |
81 | logger = logs.get_logger("get_db_provider")
82 | logger.setLevel(verbose)
83 | logger.debug('Choosing db provider with config:')
84 | logger.debug(config)
85 |
86 | if 'storage' in config.keys():
87 | artifact_store = get_artifact_store(config['storage'])
88 | else:
89 | artifact_store = None
90 |
91 | assert 'database' in config.keys()
92 | db_config = config['database']
93 | if db_config['type'].lower() == 's3':
94 | db_provider = S3Provider(db_config,
95 | blocking_auth=blocking_auth)
96 | if artifact_store is None:
97 | artifact_store = db_provider.get_storage_handler()
98 |
99 | elif db_config['type'].lower() == 'gs':
100 | raise NotImplementedError("GS is not supported.")
101 |
102 | elif db_config['type'].lower() == 'local':
103 | db_provider = LocalDbProvider(db_config,
104 | blocking_auth=blocking_auth)
105 | if artifact_store is None:
106 | artifact_store = db_provider.get_storage_handler()
107 |
108 | else:
109 | raise ValueError('Unknown type of the database ' + db_config['type'])
110 |
111 | setup_storage(db_provider, artifact_store)
112 | return db_provider
113 |
--------------------------------------------------------------------------------
/docs/gcloud_setup.rst:
--------------------------------------------------------------------------------
1 | Setting up Google Cloud Compute
2 | ===============================
3 |
4 | This page describes the process of setting up Google Cloud and
5 | configuring Studio to integrate with it.
6 |
7 | Configuring Google Cloud Compute
8 | --------------------------------
9 |
10 | Create and select a new Google Cloud project
11 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
12 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
13 | Go to the Google Cloud console (https://console.cloud.google.com), and
14 | either choose a project that you will use to back cloud
15 | computing or create a new one. If you have not used the Google console
16 | before and there are no projects, there will be a big button "create
17 | project" in the dashboard. Otherwise, you can create a new project by
18 | selecting the drop-down arrow next to current project name in the top
19 | panel, and then clicking the "+" button.
20 |
21 | Enable billing for the project
22 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
23 |
24 | Google Cloud computing actually bills you for the compute time you
25 | use, so you must have billing enabled. On the bright side, when you
26 | sign up with Google Cloud they provide $300 of promotional credit, so
27 | really in the beginning you are still using it for free. On the not so
28 | bright side, to use machines with GPUs you'll need to show
29 | that you are a legitimate customer and add $35 to your billing account.
30 | In order to enable billing, go to the left-hand pane in the Google Cloud
31 | console, select billing, and follow the instructions to set up your payment
32 | method.
33 |
34 | Generate service credentials
35 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
36 |
37 | The machines that submit cloud jobs will need to be authorized with
38 | service credentials. Go to the left-hand pane in the Google Cloud console and
39 | select API Manager -> Credentials. Then click the "Create credentials"
40 | button, choose service account key, leave key type as JSON, and in the
41 | "Service account" drop-down select "New service account". Enter a
42 | service account name (the name can be virtually anything and won't
43 | matter for the rest of the instructions). The important part is selecting a
44 | role. Click the "Select a role" dropdown menu, in "Project" select "Service
45 | Account Actor", and then scroll down to "Compute Engine" and select "Compute
46 | Engine Admin (v1)". Then scroll down to "Pub/Sub", and add a role
47 | "Pub/Sub editor" (this is required to create queues, publish and read
48 | messages from them). If you are planning to use Google Cloud storage
49 | (directly, without the Firebase layer) for artifact storage, select the Storage
50 | Admin role as well. You can also add other roles if you are planning to use
51 | these credentials in other applications. When done, click "Create".
52 | Google Cloud console should generate a json credentials file and save it
53 | to your computer.
54 |
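If you prefer the command line over the console UI, roughly the same result
can be achieved with the ``gcloud`` CLI (a sketch; the service account name,
key file name, and exact set of roles are placeholders to adapt to your setup):

::

    # create the service account
    gcloud iam service-accounts create studio-worker --display-name "studio worker"

    # grant the roles described above (repeat for any additional roles you need)
    gcloud projects add-iam-policy-binding MY_PROJECT_ID \
        --member "serviceAccount:studio-worker@MY_PROJECT_ID.iam.gserviceaccount.com" \
        --role "roles/compute.admin"
    gcloud projects add-iam-policy-binding MY_PROJECT_ID \
        --member "serviceAccount:studio-worker@MY_PROJECT_ID.iam.gserviceaccount.com" \
        --role "roles/pubsub.editor"

    # download a JSON key for the account
    gcloud iam service-accounts keys create credentials.json \
        --iam-account studio-worker@MY_PROJECT_ID.iam.gserviceaccount.com
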
55 | Configuring Studio
56 | ------------------
57 |
58 | Adding credentials
59 | ~~~~~~~~~~~~~~~~~~
60 |
61 | Copy the json file credentials to the machine where Studio will be
62 | run, and create the environment variable ``GOOGLE_APPLICATION_CREDENTIALS``
63 | that points to it. That is, run
64 |
65 | ::
66 |
67 | export GOOGLE_APPLICATION_CREDENTIALS=/path/to/credentials.json
68 |
69 | Note that this variable will be gone when you restart the terminal, so
70 | if you want to reuse it, add it to ``~/.bashrc`` (linux) or
71 | ``~/.bash_profile`` (OS X)
72 |
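For example, to make the variable persist across shells, you can append the
export to your shell profile (adjust the credentials path and the profile
file to your setup):

::

    echo 'export GOOGLE_APPLICATION_CREDENTIALS=/path/to/credentials.json' >> ~/.bashrc
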
73 | Modifying the configuration file
74 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
75 |
76 | In the config file (the one that you use with the ``--config`` flag, or, if you
77 | use the default, in the ``studio/default_config.yaml``), go to the ``cloud``
78 | section. Change projectId to the project id of the Google project for which
79 | you enabled cloud computing. You can also modify the default instance
80 | parameters (see `Cloud computing for studio `__ for
81 | limitations though).
82 |
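As a rough sketch (the surrounding keys and their nesting should be left
exactly as they already appear in your config file; only the project id value
shown here is a placeholder), the edit looks like:

::

    cloud:
        projectId: my-gcloud-project-123   # replace with your project id
        # ... keep the remaining instance parameters as they are ...
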
83 | Test
84 | ~~~~
85 |
86 | To test if things are set up correctly, go to
87 | ``studio/examples/general`` and run
88 |
89 | ::
90 |
91 | studio run --cloud=gcloud report_system_info.py
92 |
93 | Then run ``studio`` locally, and watch the new experiment. In a little
94 | while, it should change its status to "finished" and show the system
95 | information (number of cpus, amount of ram / hdd) of a default instance.
96 | See `Cloud computing for studio `__ for more instructions on
97 | using an instance with specific hardware parameters.
98 |
--------------------------------------------------------------------------------
/studio/queues/local_queue.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uuid
3 | import time
4 | import filelock
5 |
6 | from studio.artifacts.artifacts_tracker import get_studio_home
7 | from studio.storage.storage_setup import get_storage_verbose_level
8 | from studio.util import logs
9 | from studio.util.util import check_for_kb_interrupt, rm_rf
10 |
11 | LOCK_FILE_NAME = 'lock.lock'
12 |
13 | class LocalQueue:
14 | def __init__(self, name: str, path: str = None, logger=None):
15 | if logger is not None:
16 | self._logger = logger
17 | else:
18 | self._logger = logs.get_logger('LocalQueue')
19 | self._logger.setLevel(get_storage_verbose_level())
20 |
21 | self.name = name
22 | if path is None:
23 | self.path = self._get_queue_directory()
24 | else:
25 | self.path = path
26 | self.path = os.path.join(self.path, name)
27 | os.makedirs(self.path, exist_ok=True)
28 |
29 | # Local queue is considered active, iff its directory exists.
30 | self._lock_path = os.path.join(self.path, LOCK_FILE_NAME)
31 | self._lock = filelock.SoftFileLock(self._lock_path)
32 |
33 |
34 | def _get_queue_status(self):
35 | is_active = os.path.isdir(self.path)
36 | if is_active:
37 | try:
38 | with self._lock:
39 | files = os.listdir(self.path)
40 | files.remove(LOCK_FILE_NAME)
41 | return True, files
42 | except BaseException as exc:
43 | check_for_kb_interrupt()
44 | self._logger.info("FAILED to get queue status for %s - %s",
45 | self.path, exc)
46 | # Ignore possible exception:
47 | # we just want list of files without internal lock file
48 | return False, list()
49 |
50 | def _get_queue_directory(self):
51 | queue_dir: str = os.path.join(
52 | get_studio_home(),
53 | 'queue')
54 | return queue_dir
55 |
56 | def has_next(self):
57 | is_active, files = self._get_queue_status()
58 | return is_active and len(files) > 0
59 |
60 | def is_active(self):
61 | is_active = os.path.isdir(self.path)
62 | return is_active
63 |
64 | def clean(self, timeout=0):
65 | _ = timeout
66 | rm_rf(self.path)
67 |
68 | def delete(self):
69 | self.clean()
70 |
71 | def _get_time(self, file: str):
72 | return os.path.getmtime(os.path.join(self.path, file))
73 |
74 | def dequeue(self, acknowledge=True, timeout=0):
75 | sleep_in_seconds = 1
76 | total_wait_time = 0
77 | if not self.is_active():
78 |             return None, None
79 |
80 | while True:
81 | with self._lock:
82 | is_active, files = self._get_queue_status()
83 | if not is_active:
84 | return None, None
85 | if any(files):
86 | first_file = min([(p, self._get_time(p)) for p in files],
87 | key=lambda t: t[1])[0]
88 | first_file = os.path.join(self.path, first_file)
89 |
90 | with open(first_file, 'r') as f_in:
91 | data = f_in.read()
92 |
93 | if acknowledge:
94 | self.acknowledge(first_file)
95 | return data, None
96 | return data, first_file
97 |
98 | if total_wait_time >= timeout:
99 | return None, None
100 | time.sleep(sleep_in_seconds)
101 | total_wait_time += sleep_in_seconds
102 |
103 | def enqueue(self, data):
104 | with self._lock:
105 | filename = os.path.join(self.path, str(uuid.uuid4()))
106 | with open(filename, 'w') as f_out:
107 | f_out.write(data)
108 |
109 | def acknowledge(self, key):
110 | try:
111 | os.remove(key)
112 | except BaseException:
113 | check_for_kb_interrupt()
114 |
115 | def hold(self, key, minutes):
116 | _ = minutes
117 | self.acknowledge(key)
118 |
119 | def get_name(self):
120 | return self.name
121 |
122 | def get_path(self):
123 | return self.path
124 |
125 | def shutdown(self, delete_queue=True):
126 | _ = delete_queue
127 | self.delete()
128 |
--------------------------------------------------------------------------------