├── .gitignore
├── LICENCE
├── README.md
├── bo-mbexp.py
├── bohb
│   ├── HPWorker.py
│   └── __init__.py
├── config_space
│   └── config_space.py
├── dmbrl
│   ├── config
│   │   ├── __init__.py
│   │   ├── cartpole.py
│   │   ├── default.py
│   │   ├── halfcheetah_v3.py
│   │   ├── hopper.py
│   │   ├── pusher.py
│   │   ├── reacher.py
│   │   └── template.py
│   ├── controllers
│   │   ├── Controller.py
│   │   ├── MPC.py
│   │   └── __init__.py
│   ├── env
│   │   ├── __init__.py
│   │   ├── assets
│   │   │   ├── cartpole.xml
│   │   │   ├── half_cheetah.xml
│   │   │   ├── hopper.xml
│   │   │   ├── pusher.xml
│   │   │   └── reacher3d.xml
│   │   ├── cartpole.py
│   │   ├── half_cheetah_v3.py
│   │   ├── hopper.py
│   │   ├── pusher.py
│   │   └── reacher.py
│   ├── misc
│   │   ├── Agent.py
│   │   ├── DotmapUtils.py
│   │   ├── MBExp.py
│   │   ├── MBwBOExp.py
│   │   ├── MBwPBTBTExp.py
│   │   ├── MBwPBTExp.py
│   │   ├── __init__.py
│   │   ├── optimizers
│   │   │   ├── __init__.py
│   │   │   ├── cem.py
│   │   │   ├── optimizer.py
│   │   │   └── random.py
│   │   └── render.py
│   └── modeling
│       ├── __init__.py
│       ├── layers
│       │   ├── FC.py
│       │   └── __init__.py
│       ├── models
│       │   ├── BNN.py
│       │   ├── NN.py
│       │   ├── TFGP.py
│       │   └── __init__.py
│       └── utils
│           ├── HPO.py
│           ├── TensorStandardScaler.py
│           └── __init__.py
├── environment.yml
├── eval_schedules.py
├── learned_schedule
│   ├── daisy_hb_model_train.json
│   ├── daisy_hb_planning.json
│   ├── daisy_pbt_model_train_incumbent.json
│   ├── daisy_pbt_model_train_top_avg.json
│   ├── daisy_pbt_planning_incumbent.json
│   ├── daisy_pbt_planning_top_avg.json
│   ├── daisy_rs_model_train.json
│   ├── daisy_rs_planning.json
│   ├── halfcheetah_default.json
│   ├── halfcheetah_hb_model_train.json
│   ├── halfcheetah_hb_planning.json
│   ├── halfcheetah_pbt_model_train_incumbent.json
│   ├── halfcheetah_pbt_model_train_top_avg.json
│   ├── halfcheetah_pbt_planning_incumbent.json
│   ├── halfcheetah_pbt_planning_top_avg.json
│   ├── halfcheetah_pbtbt_model_train.json
│   ├── halfcheetah_pbtbt_planning.json
│   ├── halfcheetah_rs_model_train.json
│   ├── halfcheetah_rs_planning.json
│   ├── hopper_hb_model_train.json
│   ├── hopper_hb_planning.json
│   ├── hopper_pbt_model_train_incumbent.json
│   ├── hopper_pbt_model_train_top_avg.json
│   ├── hopper_pbt_planning_incumbent.json
│   ├── hopper_pbt_planning_top_avg.json
│   ├── hopper_pbtbt_model_train.json
│   ├── hopper_pbtbt_planning.json
│   ├── hopper_rs_model_train.json
│   ├── hopper_rs_planning.json
│   ├── pusher_default.json
│   ├── pusher_hb_model_train.json
│   ├── pusher_hb_planning.json
│   ├── quadruped_hb_model_train.json
│   ├── quadruped_hb_planning.json
│   ├── quadruped_pbt_model_train_incumbent.json
│   ├── quadruped_pbt_model_train_top_avg.json
│   ├── quadruped_pbt_planning_incumbent.json
│   ├── quadruped_pbt_planning_top_avg.json
│   ├── quadruped_rs_model_train.json
│   ├── quadruped_rs_planning.json
│   ├── reacher_hb_model_train.json
│   ├── reacher_hb_planning.json
│   ├── reacher_pbt_model_train_incumbent.json
│   ├── reacher_pbt_model_train_top_avg.json
│   ├── reacher_pbt_planning_incumbent.json
│   ├── reacher_pbt_planning_top_avg.json
│   ├── reacher_rs_model_train.json
│   └── reacher_rs_planning.json
├── mbexp.py
├── pbt-bt-mbexp.py
├── pbt-mbexp.py
├── pbt
│   ├── __init__.py
│   ├── backtrack_controller.py
│   ├── backtrack_scheduler.py
│   ├── controller.py
│   ├── exploitation
│   │   ├── __init__.py
│   │   ├── constant.py
│   │   ├── exploitation_strategy.py
│   │   └── truncation.py
│   ├── exploration
│   │   ├── __init__.py
│   │   ├── constant.py
│   │   ├── exploration_bt.py
│   │   ├── exploration_strategy.py
│   │   ├── model_based.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── config_tree
│   │   │   │   ├── __init__.py
│   │   │   │   ├── config_tree.py
│   │   │   │   └── nodes
│   │   │   │       ├── __init__.py
│   │   │   │       ├── categorical.py
│   │   │   │       ├── float.py
│   │   │   │       ├── integer.py
│   │   │   │       ├── log_float.py
│   │   │   │       └── node.py
│   │   │   ├── model.py
│   │   │   └── tree_parzen_estimator.py
│   │   ├── perturb.py
│   │   ├── perturb_and_resample.py
│   │   └── resample.py
│   ├── garbage_collector.py
│   ├── network
│   │   ├── __init__.py
│   │   ├── controller_adapter.py
│   │   ├── controller_daemon.py
│   │   ├── daemon.py
│   │   ├── worker_adapter.py
│   │   └── worker_daemon.py
│   ├── population
│   │   ├── __init__.py
│   │   ├── member.py
│   │   ├── population.py
│   │   └── trial.py
│   ├── scheduler.py
│   ├── tqdm_logger.py
│   └── worker.py
├── requirements.txt
└── scripts
    ├── hyperband.sh
    ├── pbt-bt.sh
    └── pbt.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | tags
3 | *.pyc
4 | log/
5 | *.out
6 | .ipynb_checkpoints/*
7 | .vscode/
8 | *.mat
9 | *.pdf
10 | *.png
11 | visualization/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Author:
2 | Baohe Zhang
3 |
4 | # Code Instructions
5 |
6 | ### Overview
7 | This code is mainly based on Kurtland Chua's implementation of the PETS paper. See [code](https://github.com/kchua/handful-of-trials).
8 |
9 | On top of that, several components were added to support the Bayesian optimization and Population-Based Training methods.
10 |
11 | config_space (folder):
12 | **config_space.py**: Defines the configuration space of hyperparameters for each environment. It is the key component on which all the HPO methods build.
13 |
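A full `ConfigurationSpace` can be assembled from these definitions and sampled, mirroring what `bohb/HPWorker.get_configspace` does internally (a minimal sketch):

```
import ConfigSpace as CS
from config_space.config_space import DEFAULT_CONFIGSPACE

# Build a search space over two cartpole hyperparameters and draw one sample.
cs = CS.ConfigurationSpace()
cs.add_hyperparameters([DEFAULT_CONFIGSPACE["cartpole"][name]
                        for name in ("model_learning_rate", "plan_hor")])
print(cs.sample_configuration())
```
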
14 | pbt (folder):
15 | This folder contains the code for Population-Based Training (PBT) and Population-Based Training with backtracking (PBT-BT). The corresponding experiment objects can be found in pbt-mbexp.py and pbt-bt-mbexp.py.
16 |
17 | bohb (folder):
18 | This folder contains the Bayesian optimization method [BOHB](https://github.com/automl/HpBandSter/blob/master/hpbandster/optimizers/bohb.py). It also provides Hyperband and random search functionality.
19 |
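The worker in bohb/HPWorker.py plugs into HpBandSter in the usual way. bo-mbexp.py takes care of this wiring, but the pattern looks roughly as follows (a sketch following the standard HpBandSter recipe; `my_train_func` is a hypothetical stand-in for the PETS training loop):

```
import hpbandster.core.nameserver as hpns
from hpbandster.optimizers import BOHB
from bohb.HPWorker import HPWorker

def my_train_func(config_id, cfg, budget):
    # Hypothetical: run PETS for `budget` training iterations with the
    # hyperparameters in `cfg`, returning (train_returns, eval_returns).
    raise NotImplementedError

# A name server lets the optimizer and the workers find each other.
ns = hpns.NameServer(run_id='example', host='127.0.0.1', port=None)
ns.start()

worker = HPWorker(train_func=my_train_func, env='cartpole',
                  config_names=['model_learning_rate'], last_k=5,
                  run_id='example', nameserver='127.0.0.1')
worker.run(background=True)

bohb = BOHB(configspace=HPWorker.get_configspace('cartpole', ['model_learning_rate']),
            run_id='example', nameserver='127.0.0.1', min_budget=3, max_budget=27)
result = bohb.run(n_iterations=4)
bohb.shutdown(shutdown_workers=True)
ns.shutdown()
```
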
20 | ### Environment Setting
21 | You can use the provided Anaconda environment file (environment.yml), which contains all the libraries needed to run this code.
22 |
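For example (assuming conda is installed; the environment name is taken from the `name:` field in environment.yml):

```
conda env create -f environment.yml
conda activate <name-from-environment.yml>
```
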
23 | ### Scripts for Running the Code
24 | To run the experiments with a fixed set of hyperparameters, you can refer to the following scripts. The scripts folder contains examples for running on a Slurm cluster.
25 |
26 | Example commands to start the code:
27 | ```
28 | # xxx refers to the cluster partition
29 | # Hyperband
30 | cd scripts
31 | sbatch -p xxx -a 1-10 hyperband.sh
33 |
34 | # PBT
35 | cd scripts
36 | sbatch -p xxx -a 1-10 pbt.sh
37 |
38 | # PBT-BT
39 | cd scripts
40 | sbatch -p xxx -a 1-10 pbt-bt.sh
41 |
42 | ```
43 |
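Without Slurm, a BOHB/Hyperband run can also be started directly with the flags defined in bo-mbexp.py (an illustrative invocation; adjust the environment, hyperparameters, and network interface to your setup):

```
# Start the optimizer process on the loopback interface.
python bo-mbexp.py -env cartpole -opt_type bohb -config_names model_learning_rate plan_hor -run_id 111 -interface lo
# Additional worker processes join the same run with -worker.
python bo-mbexp.py -env cartpole -opt_type bohb -config_names model_learning_rate plan_hor -run_id 111 -interface lo -worker -worker_id 1
```
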
44 | ### Learned Schedules
45 | In the learned_schedule folder, we have uploaded a few hyperparameter schedules learned by running the HPO methods on PETS. With these, you should be able to reproduce the results reported in our [paper](https://arxiv.org/abs/2102.13651). The evaluation function is in the eval_schedules.py file.
46 |
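The schedules are plain JSON files, so they can be inspected directly before evaluation (a minimal sketch; the exact key layout inside each file may vary between methods):

```
import json

with open('learned_schedule/hopper_pbt_planning_incumbent.json') as f:
    schedule = json.load(f)
print(schedule)
```
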
47 | ### Disclaimer
48 | This code is not yet perfectly cleaned up and might contain issues. Feel free to contact me if anything is unclear.
49 |
50 | ### Known Issue
51 | When running PBT on a Slurm-based cluster, the controller occupies one GPU without really using it, since the controller does not train any model. One way to alleviate this temporarily is to start the controller first and then start the workers, but the user must then make sure the workers know the directory that stores the server connection file.
--------------------------------------------------------------------------------
/bo-mbexp.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 | import argparse
7 | import pprint
8 |
9 | from dotmap import DotMap
10 |
11 | from dmbrl.misc.MBwBOExp import MBWithBOExperiment
12 | from dmbrl.controllers.MPC import MPC
13 | from dmbrl.config import create_config
14 |
15 | import tensorflow as tf
16 | import numpy as np
17 | import random
18 |
19 | def create_cfg_creator(env, ctrl_type, ctrl_args, base_overrides, logdir):
20 | ctrl_args = DotMap(**{key: val for (key, val) in ctrl_args})
21 |
22 | def cfg_creator(additional_overrides=None):
23 | if additional_overrides is not None:
24 | return create_config(env, ctrl_type, ctrl_args, base_overrides + additional_overrides, logdir)
25 | return create_config(env, ctrl_type, ctrl_args, base_overrides, logdir)
26 |
27 | return cfg_creator
28 |
29 |
30 | def main(args):
31 | cfg_creator = create_cfg_creator(args.env, args.ctrl_type, args.ctrl_arg, args.override, args.logdir)
32 | cfg = cfg_creator()[0]
33 | cfg.pprint()
34 |
35 | if args.ctrl_type == "MPC":
36 | cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)
37 |
38 | exp = MBWithBOExperiment(cfg.exp_cfg, cfg_creator, args)
39 | os.makedirs(exp.logdir, exist_ok=True)
40 |
41 | exp.run_experiment()
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument('-env', type=str, required=True,
47 | help='Environment name: select from [cartpole, reacher, pusher, halfcheetah_v3, hopper]')
48 | parser.add_argument('-ca', '--ctrl_arg', action='append', nargs=2, default=[],
49 | help='Controller arguments, see https://github.com/kchua/handful-of-trials#controller-arguments')
50 | parser.add_argument('-o', '--override', action='append', nargs=2, default=[],
51 | help='Override default parameters, see https://github.com/kchua/handful-of-trials#overrides')
52 | parser.add_argument('-logdir', type=str, default='log',
53 | help='Directory to which results will be logged (default: ./log)')
54 | parser.add_argument('-ctrl_type', type=str, default='MPC',
55 | help='Control type will be applied (default: MPC)')
56 | # Parser for running BOHB on cluster
57 | parser.add_argument('-worker', action='store_true',
58 | help='Flag to turn this into a worker process')
59 | parser.add_argument('-interface', type=str, default='lo',
60 | help='Interface name to use for creating host')
61 | parser.add_argument('-run_id', type=str, default='111',
62 | help='A unique run id for this optimization run. An easy option is to use'
63 | ' the job id of the cluster scheduler.')
64 | parser.add_argument('-worker_id', type=int, default=0,
65 | help='The ID of the worker')
66 | parser.add_argument('-config_names', type=str, default=["model_learning_rate"], nargs="+",
67 | help='Specify which hyperparameters to optimize')
68 | parser.add_argument('-opt_type', type=str, default="bohb",
69 | help='Specify which optimizer to use')
70 | parser.add_argument('-seed', type=int, default=0,
71 | help='Specify the random seed to use')
72 | args = parser.parse_args()
73 | tf.set_random_seed(args.seed)
74 | np.random.seed(args.seed)
75 | random.seed(args.seed)
76 |
77 | print(args)
78 | main(args)
79 |
--------------------------------------------------------------------------------
/bohb/HPWorker.py:
--------------------------------------------------------------------------------
1 | import ConfigSpace as CS
2 | from hpbandster.core.worker import Worker
3 | import numpy as np
4 | from config_space.config_space import DEFAULT_CONFIGSPACE
5 |
6 | class HPWorker(Worker):
7 | """"Worker for BOHB/Hyperband/Random Search
8 | We use the implementation based on https://github.com/automl/HpBandSter
9 | """
10 | def __init__(self,
11 | train_func,
12 | env,
13 | config_names,
14 | last_k,
15 | **kwargs):
16 | """
17 | Initialization
18 | """
19 | super().__init__(**kwargs)
20 | self.train_func = train_func
21 | if len(config_names) == 0:
22 | raise ValueError("config_names can not be an empty list")
23 | self.config_names = config_names
24 | self.env = env
25 | self.last_k = last_k
26 |
27 | @staticmethod
28 | def get_configspace(env, config_names):
29 | """
30 | Construct the configuration space
31 | Param:
32 | env: (str) Name of the environment that will be used
33 |
34 | Return:
35 | cs: (ConfigurationSpace) A configuration space which defines the search space
36 | """
37 | cs = CS.ConfigurationSpace()
38 | config_lst = []
39 | for config_name in config_names:
40 | config_lst.append(DEFAULT_CONFIGSPACE[env][config_name])
41 | cs.add_hyperparameters(config_lst)
42 | return cs
43 |
44 | def compute(self, config_id, config, budget, **kwargs):
45 | """
46 | Evaluates the configuration on the defined budget and returns the validation performance.
47 | Args:
48 | config_id: configuration id assigned by the optimizer (a tuple in HpBandSter)
49 | config: dictionary containing the sampled configurations by the optimizer
50 | budget: (float) amount of time/epochs/etc. the model can use to train
51 | Returns:
52 | dictionary with mandatory fields:
53 | 'loss' (scalar)
54 | 'info' (dict)
55 | """
56 | train_func = self.train_func
57 | # Construct the configurations
58 | cfg = dict()
59 | for config_name in DEFAULT_CONFIGSPACE[self.env]:
60 | if config_name in self.config_names:
61 | cfg[config_name] = config[config_name]
62 | else:
63 | cfg[config_name] = DEFAULT_CONFIGSPACE[self.env][config_name].default_value
64 |
65 | # Run one configuration
66 | traj_rets, traj_eval_rets = train_func(config_id, cfg, budget)
67 | all_rets = np.concatenate([traj_rets, traj_eval_rets], axis=1)  # per-iteration training and evaluation returns
68 | loss = - np.mean(all_rets, axis=1)[-self.last_k:].mean().item()  # negated mean return over the last k iterations (HpBandSter minimizes loss)
69 | return ({
70 | 'loss': loss,
71 | 'info': {}
72 | })
73 |
74 |
--------------------------------------------------------------------------------
/bohb/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/automl/HPO_for_RL/d82c7ddd6fe19834c088137570530f11761d9390/bohb/__init__.py
--------------------------------------------------------------------------------
/config_space/config_space.py:
--------------------------------------------------------------------------------
1 | import ConfigSpace as CS
2 | import ConfigSpace.hyperparameters as CSH
3 |
4 | """
5 | This defines the search space of each environment. The default values are chosen based on Chua et al. 2018.
6 | See https://github.com/kchua/handful-of-trials for more information
7 | """
8 | DEFAULT_CONFIGSPACE = {
9 | "cartpole": {
10 | # Model Architecture
11 | "num_hidden_layers" : CSH.UniformIntegerHyperparameter("num_hidden_layers", lower=2, upper=8, default_value=3, log=False),
12 | "hidden_layer_width" : CSH.UniformIntegerHyperparameter("hidden_layer_width", lower=100, upper=600, default_value=500, log=True),
13 | "act_idx" : CSH.CategoricalHyperparameter(name='activation', choices=['relu', 'tanh', 'sigmoid', 'swish'], default_value='swish'),
14 | # Model Optimizer
15 | "model_learning_rate" : CSH.UniformFloatHyperparameter('model_learning_rate', lower=1e-5, upper=4e-2, default_value=1e-3, log=True),
16 | "model_weight_decay" : CSH.UniformFloatHyperparameter('model_weight_decay', lower=1e-7, upper=1e-1, default_value=0.00025, log=True),
17 | "model_opt_idx" : CSH.CategoricalHyperparameter(name="model_opt_idx", choices=['adam', 'adadelta', 'adagrad', 'sgd', 'rms'], default_value='adam'),
18 | "model_train_epoch" : CSH.UniformIntegerHyperparameter("model_train_epoch", lower=3, upper=15, default_value=5, log=False),
19 | # Planner
20 | "num_cem_iters" : CSH.UniformIntegerHyperparameter("num_cem_iters", lower=4, upper=6, default_value=5, log=False),
21 | "cem_popsize" : CSH.UniformIntegerHyperparameter("cem_popsize", lower=200, upper=700, default_value=400, log=True),
22 | "cem_alpha" : CSH.UniformFloatHyperparameter("cem_alpha", lower=0.05, upper=0.2, default_value=0.1, log=False),
23 | "cem_elites_ratio" : CSH.UniformFloatHyperparameter("cem_elites_ratio", lower=0.04, upper=0.5, default_value=0.1, log=True),
24 | "plan_hor" : CSH.UniformIntegerHyperparameter("plan_hor", lower=5, upper=40, default_value=25, log=False),
25 | },
26 | "pusher": {
27 | # Model Architecture
28 | "num_hidden_layers" : CSH.UniformIntegerHyperparameter("num_hidden_layers", lower=3, upper=8, default_value=3, log=False),
29 | "hidden_layer_width" : CSH.UniformIntegerHyperparameter("hidden_layer_width", lower=100, upper=600, default_value=200, log=True),
30 | "act_idx" : CSH.CategoricalHyperparameter(name='activation', choices=['relu', 'tanh', 'sigmoid', 'swish'], default_value='swish'),
31 | # Model Optimizer
32 | "model_learning_rate" : CSH.UniformFloatHyperparameter('model_learning_rate', lower=3e-5, upper=3e-3, default_value=0.001, log=True),
33 | "model_weight_decay" : CSH.UniformFloatHyperparameter('model_weight_decay', lower=1e-7, upper=1e-1, default_value=5e-4, log=True),
34 | "model_opt_idx" : CSH.CategoricalHyperparameter(name="model_opt_idx", choices=['adam', 'adadelta', 'adagrad', 'sgd', 'rms'], default_value='adam'),
35 | "model_train_epoch" : CSH.UniformIntegerHyperparameter("model_train_epoch", lower=3, upper=20, default_value=5, log=False),
36 | # Planners
37 | "num_cem_iters" : CSH.UniformIntegerHyperparameter("num_cem_iters", lower=3, upper=10, default_value=5, log=False),
38 | "cem_popsize" : CSH.UniformIntegerHyperparameter("cem_popsize", lower=100, upper=700, default_value=500, log=True),
39 | "cem_alpha" : CSH.UniformFloatHyperparameter("cem_alpha", lower=0.01, upper=0.5, default_value=0.1, log=False),
40 | "cem_elites_ratio" : CSH.UniformFloatHyperparameter("cem_elites_ratio", lower=0.04, upper=0.5, default_value=0.1, log=True),
41 | "plan_hor" : CSH.UniformIntegerHyperparameter("plan_hor", lower=5, upper=40, default_value=25, log=False),
42 | },
43 | "reacher": {
44 | # Model Architecture
45 | "num_hidden_layers" : CSH.UniformIntegerHyperparameter("num_hidden_layers", lower=3, upper=8, default_value=4, log=False),
46 | "hidden_layer_width" : CSH.UniformIntegerHyperparameter("hidden_layer_width", lower=100, upper=600, default_value=200, log=True),
47 | "act_idx" : CSH.CategoricalHyperparameter(name='activation', choices=['relu', 'tanh', 'sigmoid', 'swish'], default_value='swish'),
48 | # Model Optimizer
49 | "model_learning_rate" : CSH.UniformFloatHyperparameter('model_learning_rate', lower=1e-5, upper=4e-2, default_value=0.00075, log=True),
50 | "model_weight_decay" : CSH.UniformFloatHyperparameter('model_weight_decay', lower=1e-7, upper=1e-1, default_value=0.0005, log=True),
51 | "model_opt_idx" : CSH.CategoricalHyperparameter(name="model_opt_idx", choices=['adam', 'adadelta', 'adagrad', 'sgd', 'rms'], default_value='adam'),
52 | "model_train_epoch" : CSH.UniformIntegerHyperparameter("model_train_epoch", lower=3, upper=20, default_value=5, log=False),
53 | # Planner
54 | "num_cem_iters" : CSH.UniformIntegerHyperparameter("num_cem_iters", lower=4, upper=6, default_value=5, log=False),
55 | "cem_popsize" : CSH.UniformIntegerHyperparameter("cem_popsize", lower=200, upper=700, default_value=400, log=True),
56 | "cem_alpha" : CSH.UniformFloatHyperparameter("cem_alpha", lower=0.05, upper=0.4, default_value=0.1, log=False),
57 | "cem_elites_ratio" : CSH.UniformFloatHyperparameter("cem_elites_ratio", lower=0.04, upper=0.5, default_value=0.1, log=True),
58 | "plan_hor" : CSH.UniformIntegerHyperparameter("plan_hor", lower=5, upper=40, default_value=25, log=False),
59 | },
60 | "halfcheetah_v3": {
61 | # Model Architecture
62 | "num_hidden_layers" : CSH.UniformIntegerHyperparameter("num_hidden_layers", lower=2, upper=8, default_value=4, log=False),
63 | "hidden_layer_width" : CSH.UniformIntegerHyperparameter("hidden_layer_width", lower=100, upper=600, default_value=200, log=False),
64 | "act_idx" : CSH.CategoricalHyperparameter(name='activation', choices=['relu', 'tanh', 'sigmoid', 'swish'], default_value='swish'),
65 | # Model Optimizer
66 | "model_learning_rate" : CSH.UniformFloatHyperparameter('model_learning_rate', lower=1e-5, upper=4e-2, default_value=0.001, log=True),
67 | "model_weight_decay" : CSH.UniformFloatHyperparameter('model_weight_decay', lower=1e-7, upper=1e-1, default_value=0.000075, log=True),
68 | "model_opt_idx" : CSH.CategoricalHyperparameter(name="model_opt_idx", choices=['adam', 'adadelta', 'adagrad', 'sgd', 'rms'], default_value='adam'),
69 | "model_train_epoch" : CSH.UniformIntegerHyperparameter("model_train_epoch", lower=3, upper=20, default_value=5, log=False),
70 | # Planner
71 | "num_cem_iters" : CSH.UniformIntegerHyperparameter("num_cem_iters", lower=3, upper=8, default_value=5, log=False),
72 | "cem_popsize" : CSH.UniformIntegerHyperparameter("cem_popsize", lower=200, upper=700, default_value=500, log=True),
73 | "cem_alpha" : CSH.UniformFloatHyperparameter("cem_alpha", lower=0.05, upper=0.2, default_value=0.1, log=False),
74 | "cem_elites_ratio" : CSH.UniformFloatHyperparameter("cem_elites_ratio", lower=0.04, upper=0.5, default_value=0.1, log=True),
75 | "plan_hor" : CSH.UniformIntegerHyperparameter("plan_hor", lower=5, upper=60, default_value=30, log=False),
76 | },
77 | "hopper": {
78 | # Model Architecture
79 | "num_hidden_layers" : CSH.UniformIntegerHyperparameter("num_hidden_layers", lower=2, upper=8, default_value=4, log=False),
80 | "hidden_layer_width" : CSH.UniformIntegerHyperparameter("hidden_layer_width", lower=100, upper=600, default_value=200, log=False),
81 | "act_idx" : CSH.CategoricalHyperparameter(name='activation', choices=['relu', 'tanh', 'sigmoid', 'swish'], default_value='swish'),
82 | # Model Optimizer
83 | "model_learning_rate" : CSH.UniformFloatHyperparameter('model_learning_rate', lower=1e-5, upper=4e-2, default_value=0.001, log=True),
84 | "model_weight_decay" : CSH.UniformFloatHyperparameter('model_weight_decay', lower=1e-7, upper=1e-1, default_value=0.000075, log=True),
85 | "model_opt_idx" : CSH.CategoricalHyperparameter(name="model_opt_idx", choices=['adam', 'adadelta', 'adagrad', 'sgd', 'rms'], default_value='adam'),
86 | "model_train_epoch" : CSH.UniformIntegerHyperparameter("model_train_epoch", lower=3, upper=20, default_value=5, log=False),
87 | # Planner
88 | "num_cem_iters" : CSH.UniformIntegerHyperparameter("num_cem_iters", lower=3, upper=8, default_value=5, log=False),
89 | "cem_popsize" : CSH.UniformIntegerHyperparameter("cem_popsize", lower=200, upper=700, default_value=500, log=True),
90 | "cem_alpha" : CSH.UniformFloatHyperparameter("cem_alpha", lower=0.05, upper=0.2, default_value=0.1, log=False),
91 | "cem_elites_ratio" : CSH.UniformFloatHyperparameter("cem_elites_ratio", lower=0.04, upper=0.5, default_value=0.1, log=True),
92 | "plan_hor" : CSH.UniformIntegerHyperparameter("plan_hor", lower=5, upper=60, default_value=30, log=False),
93 | }
94 | }
95 |
--------------------------------------------------------------------------------
/dmbrl/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .default import create_config
--------------------------------------------------------------------------------
/dmbrl/config/cartpole.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from dotmap import DotMap
8 | import gym
9 |
10 | from dmbrl.misc.DotmapUtils import get_required_argument
11 | from dmbrl.modeling.layers import FC
12 | import dmbrl.env
13 |
14 |
15 | class CartpoleConfigModule:
16 | ENV_NAME = "MBRLCartpole-v0"
17 | TASK_HORIZON = 200
18 | NTRAIN_ITERS = 50
19 | NROLLOUTS_PER_ITER = 1
20 | PLAN_HOR = 25
21 | MODEL_IN, MODEL_OUT = 6, 4
22 | GP_NINDUCING_POINTS = 200
23 |
24 | def __init__(self):
25 | self.ENV = gym.make(self.ENV_NAME)
26 | cfg = tf.ConfigProto()
27 | cfg.gpu_options.allow_growth = True
28 | self.SESS = tf.Session(config=cfg)
29 | self.NN_TRAIN_CFG = {"epochs": 5}
30 | self.OPT_CFG = {
31 | "Random": {
32 | "popsize": 2000
33 | },
34 | "CEM": {
35 | "popsize": 400,
36 | "num_elites": 40,
37 | "max_iters": 5,
38 | "alpha": 0.1
39 | }
40 | }
41 |
42 | @staticmethod
43 | def obs_preproc(obs):
44 | if isinstance(obs, np.ndarray):
45 | return np.concatenate([np.sin(obs[:, 1:2]), np.cos(obs[:, 1:2]), obs[:, :1], obs[:, 2:]], axis=1)
46 | else:
47 | return tf.concat([tf.sin(obs[:, 1:2]), tf.cos(obs[:, 1:2]), obs[:, :1], obs[:, 2:]], axis=1)
48 |
49 | @staticmethod
50 | def obs_postproc(obs, pred):
51 | return obs + pred
52 |
53 | @staticmethod
54 | def targ_proc(obs, next_obs):
55 | return next_obs - obs
56 |
57 | @staticmethod
58 | def obs_cost_fn(obs):
59 | if isinstance(obs, np.ndarray):
60 | return -np.exp(-np.sum(
61 | np.square(CartpoleConfigModule._get_ee_pos(obs, are_tensors=False) - np.array([0.0, 0.6])), axis=1
62 | ) / (0.6 ** 2))
63 | else:
64 | return -tf.exp(-tf.reduce_sum(
65 | tf.square(CartpoleConfigModule._get_ee_pos(obs, are_tensors=True) - np.array([0.0, 0.6])), axis=1
66 | ) / (0.6 ** 2))
67 |
68 | @staticmethod
69 | def ac_cost_fn(acs):
70 | if isinstance(acs, np.ndarray):
71 | return 0.01 * np.sum(np.square(acs), axis=1)
72 | else:
73 | return 0.01 * tf.reduce_sum(tf.square(acs), axis=1)
74 |
75 | def nn_constructor(self, model_init_cfg):
76 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap(
77 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"),
78 | sess=self.SESS, load_model=model_init_cfg.get("load_model", False),
79 | model_dir=model_init_cfg.get("model_dir", None)
80 | ))
81 | if not model_init_cfg.get("load_model", False):
82 | model.add(FC(500, input_dim=self.MODEL_IN, activation='swish', weight_decay=0.0001))
83 | model.add(FC(500, activation='swish', weight_decay=0.00025))
84 | model.add(FC(500, activation='swish', weight_decay=0.00025))
85 | model.add(FC(self.MODEL_OUT, weight_decay=0.0005))
86 | model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.001})
87 | return model
88 |
89 | def gp_constructor(self, model_init_cfg):
90 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap(
91 | name="model",
92 | kernel_class=get_required_argument(model_init_cfg, "kernel_class", "Must provide kernel class"),
93 | kernel_args=model_init_cfg.get("kernel_args", {}),
94 | num_inducing_points=get_required_argument(
95 | model_init_cfg, "num_inducing_points", "Must provide number of inducing points."
96 | ),
97 | sess=self.SESS
98 | ))
99 | return model
100 |
101 | @staticmethod
102 | def _get_ee_pos(obs, are_tensors=False):
103 | x0, theta = obs[:, :1], obs[:, 1:2]
104 | if are_tensors:
105 | return tf.concat([x0 - 0.6 * tf.sin(theta), -0.6 * tf.cos(theta)], axis=1)
106 | else:
107 | return np.concatenate([x0 - 0.6 * np.sin(theta), -0.6 * np.cos(theta)], axis=1)
108 |
109 |
110 | CONFIG_MODULE = CartpoleConfigModule
111 |
--------------------------------------------------------------------------------
/dmbrl/config/halfcheetah_v3.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from dotmap import DotMap
8 | import gym
9 |
10 | from dmbrl.misc.DotmapUtils import get_required_argument
11 | from dmbrl.modeling.layers import FC
12 | import dmbrl.env
13 |
14 |
15 | class HalfCheetahConfigModule:
16 | ENV_NAME = "MBRLHalfCheetah-v3"
17 | TASK_HORIZON = 1000
18 | NTRAIN_ITERS = 300
19 | NROLLOUTS_PER_ITER = 1
20 | PLAN_HOR = 15 # Modified from 30 -> 15
21 | MODEL_IN, MODEL_OUT = 24, 18
22 | GP_NINDUCING_POINTS = 300
23 |
24 | def __init__(self):
25 | self.ENV = gym.make(self.ENV_NAME)
26 | cfg = tf.ConfigProto()
27 | cfg.gpu_options.allow_growth = True
28 | self.SESS = tf.Session(config=cfg)
29 | self.NN_TRAIN_CFG = {"epochs": 5}
30 | self.OPT_CFG = {
31 | "Random": {
32 | "popsize": 2500
33 | },
34 | "CEM": {
35 | "popsize": 500,
36 | "num_elites": 50,
37 | "max_iters": 5,
38 | "alpha": 0.1
39 | }
40 | }
41 |
42 | @staticmethod
43 | def obs_preproc(obs):
44 | if isinstance(obs, np.ndarray):
45 | return np.concatenate([obs[:, 1:2], np.sin(obs[:, 2:3]), np.cos(obs[:, 2:3]), obs[:, 3:]], axis=1)
46 | else:
47 | return tf.concat([obs[:, 1:2], tf.sin(obs[:, 2:3]), tf.cos(obs[:, 2:3]), obs[:, 3:]], axis=1)
48 |
49 | @staticmethod
50 | def obs_postproc(obs, pred):
51 | if isinstance(obs, np.ndarray):
52 | return np.concatenate([pred[:, :1], obs[:, 1:] + pred[:, 1:]], axis=1)
53 | else:
54 | return tf.concat([pred[:, :1], obs[:, 1:] + pred[:, 1:]], axis=1)
55 |
56 | @staticmethod
57 | def targ_proc(obs, next_obs):
58 | return np.concatenate([next_obs[:, :1], next_obs[:, 1:] - obs[:, 1:]], axis=1)
59 |
60 | @staticmethod
61 | def obs_cost_fn(obs):
62 | return -obs[:, 0]
63 |
64 | @staticmethod
65 | def ac_cost_fn(acs):
66 | if isinstance(acs, np.ndarray):
67 | return 0.1 * np.sum(np.square(acs), axis=1)
68 | else:
69 | return 0.1 * tf.reduce_sum(tf.square(acs), axis=1)
70 |
71 | def nn_constructor(self, model_init_cfg):
72 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap(
73 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"),
74 | sess=self.SESS, load_model=model_init_cfg.get("load_model", False),
75 | model_dir=model_init_cfg.get("model_dir", None)
76 | ))
77 | if not model_init_cfg.get("load_model", False):
78 | model.add(FC(200, input_dim=self.MODEL_IN, activation="swish", weight_decay=0.000025))
79 | model.add(FC(200, activation="swish", weight_decay=0.00005))
80 | model.add(FC(200, activation="swish", weight_decay=0.000075))
81 | model.add(FC(200, activation="swish", weight_decay=0.000075))
82 | model.add(FC(self.MODEL_OUT, weight_decay=0.0001))
83 | model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.001})
84 | return model
85 |
86 | def gp_constructor(self, model_init_cfg):
87 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap(
88 | name="model",
89 | kernel_class=get_required_argument(model_init_cfg, "kernel_class", "Must provide kernel class"),
90 | kernel_args=model_init_cfg.get("kernel_args", {}),
91 | num_inducing_points=get_required_argument(
92 | model_init_cfg, "num_inducing_points", "Must provide number of inducing points."
93 | ),
94 | sess=self.SESS
95 | ))
96 | return model
97 |
98 |
99 | CONFIG_MODULE = HalfCheetahConfigModule
100 |
--------------------------------------------------------------------------------
/dmbrl/config/hopper.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from dotmap import DotMap
8 | import gym
9 |
10 | from dmbrl.misc.DotmapUtils import get_required_argument
11 | from dmbrl.modeling.layers import FC
12 | import dmbrl.env
13 |
14 |
15 | class HopperConfigModule:
16 | ENV_NAME = "MBRLHopper-v3"
17 | TASK_HORIZON = 1000
18 | NTRAIN_ITERS = 300
19 | NROLLOUTS_PER_ITER = 1
20 | PLAN_HOR = 15 # Modified from 30 -> 15
21 | MODEL_IN, MODEL_OUT = 14, 12
22 | GP_NINDUCING_POINTS = 300
23 |
24 | def __init__(self):
25 | self.ENV = gym.make(self.ENV_NAME)
26 | cfg = tf.ConfigProto()
27 | cfg.gpu_options.allow_growth = True
28 | self.SESS = tf.Session(config=cfg)
29 | self.NN_TRAIN_CFG = {"epochs": 5}
30 | self.OPT_CFG = {
31 | "Random": {
32 | "popsize": 2500
33 | },
34 | "CEM": {
35 | "popsize": 500,
36 | "num_elites": 50,
37 | "max_iters": 5,
38 | "alpha": 0.1
39 | }
40 | }
41 |
42 | @staticmethod
43 | def obs_preproc(obs):
44 | return obs[:, 1:]
45 |
46 | @staticmethod
47 | def obs_postproc(obs, pred):
48 | if isinstance(obs, np.ndarray):
49 | return np.concatenate([pred[:, :1], obs[:, 1:] + pred[:, 1:]], axis=1)
50 | else:
51 | return tf.concat([pred[:, :1], obs[:, 1:] + pred[:, 1:]], axis=1)
52 |
53 | @staticmethod
54 | def targ_proc(obs, next_obs):
55 | return np.concatenate([next_obs[:, :1], next_obs[:, 1:] - obs[:, 1:]], axis=1)
56 |
57 | @staticmethod
58 | def obs_cost_fn(obs):
59 | obs_cost = - obs[:, 0]
60 | return obs_cost
61 | # height, ang = obs[1:3]
62 | # if isinstance(obs, np.ndarray):
63 | # alive = (np.isfinite(obs).all() and (np.abs(obs[2:]) < 100).all() and
64 | # (height > .7) and (np.abs(ang) < .2))
65 | # else:
66 | # alive = (tf.reduce_all(tf.math.is_inf(obs)) and (tf.reduce_all(tf.math.abs(obs[2:]) < 100)) and
67 | # (height > .7) and (tf.math.abs(ang) < .2))
68 | # alive_bonus = 1.0 if alive else 0.0
69 |
70 | # return obs_cost - alive_bonus
71 |
72 | @staticmethod
73 | def ac_cost_fn(acs):
74 | if isinstance(acs, np.ndarray):
75 | return 1e-3 * np.sum(np.square(acs), axis=1)
76 | else:
77 | return 1e-3 * tf.reduce_sum(tf.square(acs), axis=1)
78 |
79 | def nn_constructor(self, model_init_cfg):
80 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap(
81 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"),
82 | sess=self.SESS, load_model=model_init_cfg.get("load_model", False),
83 | model_dir=model_init_cfg.get("model_dir", None)
84 | ))
85 | if not model_init_cfg.get("load_model", False):
86 | model.add(FC(200, input_dim=self.MODEL_IN, activation="swish", weight_decay=0.000025))
87 | model.add(FC(200, activation="swish", weight_decay=0.00005))
88 | model.add(FC(200, activation="swish", weight_decay=0.000075))
89 | model.add(FC(200, activation="swish", weight_decay=0.000075))
90 | model.add(FC(self.MODEL_OUT, weight_decay=0.0001))
91 | model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.001})
92 | return model
93 |
94 | def gp_constructor(self, model_init_cfg):
95 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap(
96 | name="model",
97 | kernel_class=get_required_argument(model_init_cfg, "kernel_class", "Must provide kernel class"),
98 | kernel_args=model_init_cfg.get("kernel_args", {}),
99 | num_inducing_points=get_required_argument(
100 | model_init_cfg, "num_inducing_points", "Must provide number of inducing points."
101 | ),
102 | sess=self.SESS
103 | ))
104 | return model
105 |
106 |
107 | CONFIG_MODULE = HopperConfigModule
108 |
--------------------------------------------------------------------------------
/dmbrl/config/pusher.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from dotmap import DotMap
8 | import gym
9 |
10 | from dmbrl.misc.DotmapUtils import get_required_argument
11 | from dmbrl.modeling.layers import FC
from dmbrl.modeling.models import NN  # used by value_net_constructor below
12 | import dmbrl.env
13 |
14 |
15 | class PusherConfigModule:
16 | ENV_NAME = "MBRLPusher-v0"
17 | TASK_HORIZON = 150
18 | NTRAIN_ITERS = 100
19 | NROLLOUTS_PER_ITER = 1
20 | PLAN_HOR = 25
21 | MODEL_IN, MODEL_OUT = 27, 20
22 | GP_NINDUCING_POINTS = 200
23 |
24 | def __init__(self):
25 | self.ENV = gym.make(self.ENV_NAME)
26 | cfg = tf.ConfigProto()
27 | cfg.gpu_options.allow_growth = True
28 | self.SESS = tf.Session(config=cfg)
29 | self.NN_TRAIN_CFG = {"epochs": 5}
30 | self.OPT_CFG = {
31 | "Random": {
32 | "popsize": 2500
33 | },
34 | "CEM": {
35 | "popsize": 500,
36 | "num_elites": 50,
37 | "max_iters": 5,
38 | "alpha": 0.1
39 | }
40 | }
41 |
42 | @staticmethod
43 | def obs_postproc(obs, pred):
44 | return obs + pred
45 |
46 | @staticmethod
47 | def targ_proc(obs, next_obs):
48 | return next_obs - obs
49 |
50 | def obs_cost_fn(self, obs):
51 | to_w, og_w = 0.5, 1.25
52 | tip_pos, obj_pos, goal_pos = obs[:, 14:17], obs[:, 17:20], self.ENV.ac_goal_pos
53 |
54 | if isinstance(obs, np.ndarray):
55 | tip_obj_dist = np.sum(np.abs(tip_pos - obj_pos), axis=1)
56 | obj_goal_dist = np.sum(np.abs(goal_pos - obj_pos), axis=1)
57 | return to_w * tip_obj_dist + og_w * obj_goal_dist
58 | else:
59 | tip_obj_dist = tf.reduce_sum(tf.abs(tip_pos - obj_pos), axis=1)
60 | obj_goal_dist = tf.reduce_sum(tf.abs(goal_pos - obj_pos), axis=1)
61 | return to_w * tip_obj_dist + og_w * obj_goal_dist
62 |
63 | @staticmethod
64 | def ac_cost_fn(acs):
65 | if isinstance(acs, np.ndarray):
66 | return 0.1 * np.sum(np.square(acs), axis=1)
67 | else:
68 | return 0.1 * tf.reduce_sum(tf.square(acs), axis=1)
69 |
70 | def nn_constructor(self, model_init_cfg):
71 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap(
72 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"),
73 | sess=self.SESS, load_model=model_init_cfg.get("load_model", False),
74 | model_dir=model_init_cfg.get("model_dir", None)
75 | ))
76 | if not model_init_cfg.get("load_model", False):
77 | model.add(FC(200, input_dim=self.MODEL_IN, activation="swish", weight_decay=0.00025))
78 | model.add(FC(200, activation="swish", weight_decay=0.0005))
79 | model.add(FC(200, activation="swish", weight_decay=0.0005))
80 | model.add(FC(self.MODEL_OUT, weight_decay=0.00075))
81 | model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.001})
82 | return model
83 |
84 | def gp_constructor(self, model_init_cfg):
85 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap(
86 | name="model",
87 | kernel_class=get_required_argument(model_init_cfg, "kernel_class", "Must provide kernel class"),
88 | kernel_args=model_init_cfg.get("kernel_args", {}),
89 | num_inducing_points=get_required_argument(
90 | model_init_cfg, "num_inducing_points", "Must provide number of inducing points."
91 | ),
92 | sess=self.SESS
93 | ))
94 | return model
95 |
96 | def value_net_constructor(self, model_init_cfg):
97 | model = NN(DotMap(
98 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"),
99 | sess=self.SESS, load_model=model_init_cfg.get("load_model", False),
100 | model_dir=model_init_cfg.get("model_dir", None)
101 | ))
102 | if not model_init_cfg.get("load_model", False):
103 | model.add(FC(200, input_dim=self.MODEL_IN, activation="swish", weight_decay=0.00025))
104 | model.add(FC(200, activation="swish", weight_decay=0.0005))
105 | model.add(FC(200, activation="swish", weight_decay=0.0005))
106 | model.add(FC(self.MODEL_OUT, weight_decay=0.00075))
107 | model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.001})
108 | return model
109 |
110 |
111 | CONFIG_MODULE = PusherConfigModule
112 |
--------------------------------------------------------------------------------
/dmbrl/config/reacher.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from dotmap import DotMap
8 | import gym
9 |
10 | from dmbrl.misc.DotmapUtils import get_required_argument
11 | from dmbrl.modeling.layers import FC
12 | import dmbrl.env
13 |
14 |
15 | class ReacherConfigModule:
16 | ENV_NAME = "MBRLReacher3D-v0"
17 | TASK_HORIZON = 150
18 | NTRAIN_ITERS = 100
19 | NROLLOUTS_PER_ITER = 1
20 | PLAN_HOR = 25
21 | MODEL_IN, MODEL_OUT = 24, 17
22 | GP_NINDUCING_POINTS = 200
23 |
24 | def __init__(self):
25 | self.ENV = gym.make(self.ENV_NAME)
26 | self.ENV.reset()
27 | cfg = tf.ConfigProto()
28 | cfg.gpu_options.allow_growth = True
29 | self.SESS = tf.Session(config=cfg)
30 | self.NN_TRAIN_CFG = {"epochs": 5}
31 | self.OPT_CFG = {
32 | "Random": {
33 | "popsize": 2000
34 | },
35 | "CEM": {
36 | "popsize": 400,
37 | "num_elites": 40,
38 | "max_iters": 5,
39 | "alpha": 0.1
40 | }
41 | }
42 | self.UPDATE_FNS = [self.update_goal]
43 |
44 | self.goal = tf.Variable(self.ENV.goal, dtype=tf.float32)
45 | self.SESS.run(self.goal.initializer)
46 |
47 | @staticmethod
48 | def obs_postproc(obs, pred):
49 | return obs + pred
50 |
51 | @staticmethod
52 | def targ_proc(obs, next_obs):
53 | return next_obs - obs
54 |
55 | def update_goal(self, sess=None):
56 | if sess is not None:
57 | self.goal.load(self.ENV.goal, sess)
58 |
59 | def obs_cost_fn(self, obs):
60 | if isinstance(obs, np.ndarray):
61 | return np.sum(np.square(ReacherConfigModule.get_ee_pos(obs, are_tensors=False) - self.ENV.goal), axis=1)
62 | else:
63 | return tf.reduce_sum(tf.square(ReacherConfigModule.get_ee_pos(obs, are_tensors=True) - self.goal), axis=1)
64 |
65 | @staticmethod
66 | def ac_cost_fn(acs):
67 | if isinstance(acs, np.ndarray):
68 | return 0.01 * np.sum(np.square(acs), axis=1)
69 | else:
70 | return 0.01 * tf.reduce_sum(tf.square(acs), axis=1)
71 |
72 | def nn_constructor(self, model_init_cfg):
73 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap(
74 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"),
75 | sess=self.SESS, load_model=model_init_cfg.get("load_model", False),
76 | model_dir=model_init_cfg.get("model_dir", None)
77 | ))
78 | if not model_init_cfg.get("load_model", False):
79 | model.add(FC(200, input_dim=self.MODEL_IN, activation="swish", weight_decay=0.00025))
80 | model.add(FC(200, activation="swish", weight_decay=0.0005))
81 | model.add(FC(200, activation="swish", weight_decay=0.0005))
82 | model.add(FC(200, activation="swish", weight_decay=0.0005))
83 | model.add(FC(self.MODEL_OUT, weight_decay=0.00075))
84 | model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.00075})
85 | return model
86 |
87 | def gp_constructor(self, model_init_cfg):
88 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap(
89 | name="model",
90 | kernel_class=get_required_argument(model_init_cfg, "kernel_class", "Must provide kernel class"),
91 | kernel_args=model_init_cfg.get("kernel_args", {}),
92 | num_inducing_points=get_required_argument(
93 | model_init_cfg, "num_inducing_points", "Must provide number of inducing points."
94 | ),
95 | sess=self.SESS
96 | ))
97 | return model
98 |
99 | @staticmethod
100 | def get_ee_pos(states, are_tensors=False):
101 | theta1, theta2, theta3, theta4, theta5, theta6, theta7 = \
102 | states[:, :1], states[:, 1:2], states[:, 2:3], states[:, 3:4], states[:, 4:5], states[:, 5:6], states[:, 6:]
103 | if are_tensors:
104 | rot_axis = tf.concat([tf.cos(theta2) * tf.cos(theta1), tf.cos(theta2) * tf.sin(theta1), -tf.sin(theta2)],
105 | axis=1)
106 | rot_perp_axis = tf.concat([-tf.sin(theta1), tf.cos(theta1), tf.zeros(tf.shape(theta1))], axis=1)
107 | cur_end = tf.concat([
108 | 0.1 * tf.cos(theta1) + 0.4 * tf.cos(theta1) * tf.cos(theta2),
109 | 0.1 * tf.sin(theta1) + 0.4 * tf.sin(theta1) * tf.cos(theta2) - 0.188,
110 | -0.4 * tf.sin(theta2)
111 | ], axis=1)
112 |
113 | for length, hinge, roll in [(0.321, theta4, theta3), (0.16828, theta6, theta5)]:
114 | perp_all_axis = tf.cross(rot_axis, rot_perp_axis)
115 | x = tf.cos(hinge) * rot_axis
116 | y = tf.sin(hinge) * tf.sin(roll) * rot_perp_axis
117 | z = -tf.sin(hinge) * tf.cos(roll) * perp_all_axis
118 | new_rot_axis = x + y + z
119 | new_rot_perp_axis = tf.cross(new_rot_axis, rot_axis)
120 | new_rot_perp_axis = tf.where(tf.less(tf.norm(new_rot_perp_axis, axis=1), 1e-30),
121 | rot_perp_axis, new_rot_perp_axis)
122 | new_rot_perp_axis /= tf.norm(new_rot_perp_axis, axis=1, keepdims=True)
123 | rot_axis, rot_perp_axis, cur_end = new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis
124 | else:
125 | rot_axis = np.concatenate([np.cos(theta2) * np.cos(theta1), np.cos(theta2) * np.sin(theta1), -np.sin(theta2)],
126 | axis=1)
127 | rot_perp_axis = np.concatenate([-np.sin(theta1), np.cos(theta1), np.zeros(theta1.shape)], axis=1)
128 | cur_end = np.concatenate([
129 | 0.1 * np.cos(theta1) + 0.4 * np.cos(theta1) * np.cos(theta2),
130 | 0.1 * np.sin(theta1) + 0.4 * np.sin(theta1) * np.cos(theta2) - 0.188,
131 | -0.4 * np.sin(theta2)
132 | ], axis=1)
133 |
134 | for length, hinge, roll in [(0.321, theta4, theta3), (0.16828, theta6, theta5)]:
135 | perp_all_axis = np.cross(rot_axis, rot_perp_axis)
136 | x = np.cos(hinge) * rot_axis
137 | y = np.sin(hinge) * np.sin(roll) * rot_perp_axis
138 | z = -np.sin(hinge) * np.cos(roll) * perp_all_axis
139 | new_rot_axis = x + y + z
140 | new_rot_perp_axis = np.cross(new_rot_axis, rot_axis)
141 | new_rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] = \
142 | rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30]
143 | new_rot_perp_axis /= np.linalg.norm(new_rot_perp_axis, axis=1, keepdims=True)
144 | rot_axis, rot_perp_axis, cur_end = new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis
145 |
146 | return cur_end
147 |
148 |
149 | CONFIG_MODULE = ReacherConfigModule
150 |
--------------------------------------------------------------------------------
/dmbrl/config/template.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from dotmap import DotMap
8 | import gym
9 |
10 | from dmbrl.misc.DotmapUtils import get_required_argument
11 | from dmbrl.modeling.layers import FC
12 |
13 |
14 | class EnvConfigModule:
15 | ENV_NAME = None
16 | TASK_HORIZON = None
17 | NTRAIN_ITERS = None
18 | NROLLOUTS_PER_ITER = None
19 | PLAN_HOR = None
20 |
21 | def __init__(self):
22 | self.ENV = gym.make(self.ENV_NAME)
23 | cfg = tf.ConfigProto()
24 | cfg.gpu_options.allow_growth = True
25 | self.SESS = tf.Session(config=cfg)
26 | self.NN_TRAIN_CFG = {"epochs": None}
27 | self.OPT_CFG = {
28 | "Random": {
29 | "popsize": None
30 | },
31 | "CEM": {
32 | "popsize": None,
33 | "num_elites": None,
34 | "max_iters": None,
35 | "alpha": None
36 | }
37 | }
38 | self.UPDATE_FNS = []
39 |
40 | # Fill in other things to be done here.
41 |
42 | @staticmethod
43 | def obs_preproc(obs):
44 | # Note: Must be able to process both NumPy and Tensorflow arrays.
45 | if isinstance(obs, np.ndarray):
46 | raise NotImplementedError()
47 | else:
48 | raise NotImplementedError
49 |
50 | @staticmethod
51 | def obs_postproc(obs, pred):
52 | # Note: Must be able to process both NumPy and Tensorflow arrays.
53 | if isinstance(obs, np.ndarray):
54 | raise NotImplementedError()
55 | else:
56 | raise NotImplementedError()
57 |
58 | @staticmethod
59 | def targ_proc(obs, next_obs):
60 | # Note: Only needs to process NumPy arrays.
61 | raise NotImplementedError()
62 |
63 | @staticmethod
64 | def obs_cost_fn(obs):
65 | # Note: Must be able to process both NumPy and Tensorflow arrays.
66 | if isinstance(obs, np.ndarray):
67 | raise NotImplementedError()
68 | else:
69 | raise NotImplementedError()
70 |
71 | @staticmethod
72 | def ac_cost_fn(acs):
73 | # Note: Must be able to process both NumPy and Tensorflow arrays.
74 | if isinstance(acs, np.ndarray):
75 | raise NotImplementedError()
76 | else:
77 | raise NotImplementedError()
78 |
79 | def nn_constructor(self, model_init_cfg):
80 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap(
81 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"),
82 | sess=self.SESS
83 | ))
84 | # Construct model below. For example:
85 | # model.add(FC(*args))
86 | # ...
87 | # model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.001})
88 | return model
89 |
90 |
91 | CONFIG_MODULE = EnvConfigModule
92 |
93 |
--------------------------------------------------------------------------------
/dmbrl/controllers/Controller.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 |
6 | class Controller:
7 | def __init__(self, *args, **kwargs):
8 | """Creates class instance.
9 | """
10 | pass
11 |
12 | def train(self, obs_trajs, acs_trajs, rews_trajs):
13 | """Trains this controller using lists of trajectories.
14 | """
15 | raise NotImplementedError("Must be implemented in subclass.")
16 |
17 | def reset(self):
18 | """Resets this controller.
19 | """
20 | raise NotImplementedError("Must be implemented in subclass.")
21 |
22 | def act(self, obs, t, get_pred_cost=False):
23 | """Performs an action.
24 | """
25 | raise NotImplementedError("Must be implemented in subclass.")
26 |
27 | def dump_logs(self, primary_logdir, iter_logdir):
28 | """Dumps logs into primary log directory and per-train iteration log directory.
29 | """
30 | raise NotImplementedError("Must be implemented in subclass.")
31 |
--------------------------------------------------------------------------------
/dmbrl/controllers/__init__.py:
--------------------------------------------------------------------------------
1 | from .MPC import MPC
2 |
--------------------------------------------------------------------------------
/dmbrl/env/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 |
3 |
4 | register(
5 | id='MBRLCartpole-v0',
6 | entry_point='dmbrl.env.cartpole:CartpoleEnv'
7 | )
8 |
9 |
10 | register(
11 | id='MBRLReacher3D-v0',
12 | entry_point='dmbrl.env.reacher:Reacher3DEnv'
13 | )
14 |
15 |
16 | register(
17 | id='MBRLPusher-v0',
18 | entry_point='dmbrl.env.pusher:PusherEnv'
19 | )
20 |
21 | register(
22 | id='MBRLHalfCheetah-v3',
23 | entry_point='dmbrl.env.half_cheetah_v3:HalfCheetahEnv'
24 | )
25 |
26 | register(
27 | id='MBRLHopper-v3',
28 | entry_point='dmbrl.env.hopper:HopperEnv'
29 | )
30 |
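# Illustrative usage: once `import dmbrl.env` has executed the register()
# calls above, these IDs resolve through the standard Gym registry:
#
#   import gym
#   import dmbrl.env
#   env = gym.make('MBRLCartpole-v0')
#   obs = env.reset()
#   obs, reward, done, info = env.step(env.action_space.sample())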
--------------------------------------------------------------------------------
/dmbrl/env/assets/cartpole.xml:
--------------------------------------------------------------------------------
(MuJoCo XML model definition; the markup was not preserved in this dump)
--------------------------------------------------------------------------------
/dmbrl/env/assets/half_cheetah.xml:
--------------------------------------------------------------------------------
(MuJoCo XML model definition; the markup was not preserved in this dump)
--------------------------------------------------------------------------------
/dmbrl/env/assets/hopper.xml:
--------------------------------------------------------------------------------
(MuJoCo XML model definition; the markup was not preserved in this dump)
--------------------------------------------------------------------------------
/dmbrl/env/assets/pusher.xml:
--------------------------------------------------------------------------------
(MuJoCo XML model definition; the markup was not preserved in this dump)
--------------------------------------------------------------------------------
/dmbrl/env/cartpole.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 |
7 | import numpy as np
8 | from gym import utils
9 | from gym.envs.mujoco import mujoco_env
10 |
11 |
12 | class CartpoleEnv(mujoco_env.MujocoEnv, utils.EzPickle):
13 | PENDULUM_LENGTH = 0.6
14 |
15 | def __init__(self):
16 | utils.EzPickle.__init__(self)
17 | dir_path = os.path.dirname(os.path.realpath(__file__))
18 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/cartpole.xml' % dir_path, 2)
19 |
20 | def step(self, a):
21 | self.do_simulation(a, self.frame_skip)
22 | ob = self._get_state()
23 |
24 | cost_lscale = CartpoleEnv.PENDULUM_LENGTH
25 | reward = np.exp(
26 | -np.sum(np.square(self._get_ee_pos(ob) - np.array([0.0, CartpoleEnv.PENDULUM_LENGTH]))) / (cost_lscale ** 2)
27 | )
28 | reward -= 0.01 * np.sum(np.square(a))
29 |
30 | done = False
31 | return ob, reward, done, {}
32 |
33 | def reset_model(self):
34 | qpos = self.init_qpos + np.random.normal(0, 0.1, np.shape(self.init_qpos))
35 | qvel = self.init_qvel + np.random.normal(0, 0.1, np.shape(self.init_qvel))
36 | self.set_state(qpos, qvel)
37 | return self._get_state()
38 |
39 | def _get_state(self):
40 | return np.concatenate([self.data.qpos, self.data.qvel]).ravel()
41 |
42 | @staticmethod
43 | def _get_ee_pos(x):
44 | x0, theta = x[0], x[1]
45 | return np.array([
46 | x0 - CartpoleEnv.PENDULUM_LENGTH * np.sin(theta),
47 | -CartpoleEnv.PENDULUM_LENGTH * np.cos(theta)
48 | ])
49 |
50 | def viewer_setup(self):
51 | v = self.viewer
52 | v.cam.trackbodyid = 0
53 | v.cam.distance = v.model.stat.extent
54 |
--------------------------------------------------------------------------------
/dmbrl/env/half_cheetah_v3.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 | import os
5 | import numpy as np
6 | from gym import utils
7 | from gym.envs.mujoco import mujoco_env
8 |
9 | DEFAULT_CAMERA_CONFIG = {
10 | # 'trackbodyid': 1,
11 | # 'lookat': np.array((0.0, 0.0, 2.0)),
12 | 'distance': 4.0,
13 | }
14 |
15 |
16 | class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
17 | def __init__(self,
18 | xml_file='half_cheetah.xml',
19 | forward_reward_weight=1.0,
20 | ctrl_cost_weight=0.1,
21 | reset_noise_scale=0.1):
22 | utils.EzPickle.__init__(**locals())
23 | dir_path = os.path.dirname(os.path.realpath(__file__))
24 |
25 | self._forward_reward_weight = forward_reward_weight
26 |
27 | self._ctrl_cost_weight = ctrl_cost_weight
28 |
29 | self._reset_noise_scale = reset_noise_scale
30 | self.prev_qpos = None
31 |
32 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/%s' % (dir_path, xml_file), 5)
33 |
34 |
35 | def control_cost(self, action):
36 | control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
37 | return control_cost
38 |
39 | def step(self, action):
40 | x_position_before = self.sim.data.qpos[0]
41 | self.prev_qpos = np.copy(self.sim.data.qpos.flat)
42 |
43 | self.do_simulation(action, self.frame_skip)
44 | x_position_after = self.sim.data.qpos[0]
45 | x_velocity = ((x_position_after - x_position_before)
46 | / self.dt)
47 |
48 | ctrl_cost = self.control_cost(action)
49 |
50 | forward_reward = self._forward_reward_weight * x_velocity
51 |
52 | observation = self._get_obs()
53 | reward = forward_reward - ctrl_cost
54 | done = False
55 | info = {}
56 |
57 | return observation, reward, done, info
58 |
59 |
60 | def _get_obs(self):
61 | position = self.sim.data.qpos.flat.copy()
62 | velocity = self.sim.data.qvel.flat.copy()
63 |
64 | return np.concatenate([
65 | (position[:1] - self.prev_qpos[:1]) / self.dt,
66 | position[1:],
67 | velocity,
68 | ]).ravel()
69 |
70 | def _get_state(self):
71 | return np.concatenate([
72 | self.sim.data.qpos.flat,
73 | self.sim.data.qvel.flat,
74 | ])
75 |
76 | def reset_model(self):
77 | noise_low = -self._reset_noise_scale
78 | noise_high = self._reset_noise_scale
79 |
80 | qpos = self.init_qpos + self.np_random.uniform(
81 | low=noise_low, high=noise_high, size=self.model.nq)
82 | qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn(
83 | self.model.nv)
84 |
85 | self.set_state(qpos, qvel)
86 |
87 | observation = self._get_obs()
88 | return observation
89 |
90 | def viewer_setup(self):
91 | for key, value in DEFAULT_CAMERA_CONFIG.items():
92 | if isinstance(value, np.ndarray):
93 | getattr(self.viewer.cam, key)[:] = value
94 | else:
95 | setattr(self.viewer.cam, key, value)
96 |
--------------------------------------------------------------------------------
/dmbrl/env/hopper.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 | import os
5 | import numpy as np
6 | from gym import utils
7 | from gym.envs.mujoco import mujoco_env
8 |
9 |
10 | class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
11 | def __init__(self,
12 | xml_file='hopper.xml'):
13 |
14 | utils.EzPickle.__init__(**locals())
15 | dir_path = os.path.dirname(os.path.realpath(__file__))
16 | self.prev_qpos = None
17 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/%s' % (dir_path, xml_file), 4)
18 |
19 |
20 | def step(self, a):
21 | posbefore = self.sim.data.qpos[0]
22 | self.prev_qpos = np.copy(self.sim.data.qpos.flat)
23 | self.do_simulation(a, self.frame_skip)
24 | posafter, height, ang = self.sim.data.qpos[0:3]
25 |
26 | s = self.state_vector()
27 |
28 | alive = (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and
29 | (height > .7) and (abs(ang) < .2))
30 |
31 | alive_bonus = 1.0 if alive else 0.0
32 |
33 | forward_reward = (posafter - posbefore) / self.dt
34 | control_cost = 1e-3 * np.square(a).sum()
35 | reward = forward_reward - control_cost + alive_bonus
36 |
37 | done = False
38 | ob = self._get_obs()
39 |         return ob, reward, done, {"forward_reward": forward_reward}
40 |
41 | def _get_obs(self):
42 |
43 | position = self.sim.data.qpos.flat.copy()
44 | velocity = self.sim.data.qvel.flat.copy()
45 |
46 | return np.concatenate([
47 | (position[:1] - self.prev_qpos[:1]) / self.dt,
48 | position[1:],
49 | np.clip(velocity, -10, 10)
50 | ])
51 |
52 | def _get_state(self):
53 | return np.concatenate([
54 | self.sim.data.qpos.flat,
55 | self.sim.data.qvel.flat,
56 | ])
57 |
58 | def reset_model(self):
59 | qpos = self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq)
60 | qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
61 | self.set_state(qpos, qvel)
62 | return self._get_obs()
63 |
64 | def viewer_setup(self):
65 | self.viewer.cam.trackbodyid = 2
66 | self.viewer.cam.distance = self.model.stat.extent * 0.75
67 | self.viewer.cam.lookat[2] = 1.15
68 | self.viewer.cam.elevation = -20
--------------------------------------------------------------------------------
/dmbrl/env/pusher.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 |
7 | import numpy as np
8 | from gym import utils
9 | from gym.envs.mujoco import mujoco_env
10 |
11 |
12 | class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
13 | def __init__(self):
14 | dir_path = os.path.dirname(os.path.realpath(__file__))
15 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/pusher.xml' % dir_path, 4)
16 | utils.EzPickle.__init__(self)
17 | self.reset_model()
18 |
19 |     def step(self, a):
20 |         obj_pos = self.get_body_com("object")
21 | vec_1 = obj_pos - self.get_body_com("tips_arm")
22 | vec_2 = obj_pos - self.get_body_com("goal")
23 |
24 | reward_near = -np.sum(np.abs(vec_1))
25 | reward_dist = -np.sum(np.abs(vec_2))
26 | reward_ctrl = -np.square(a).sum()
27 | reward = 1.25 * reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near
28 |
29 | self.do_simulation(a, self.frame_skip)
30 | ob = self._get_obs()
31 | done = False
32 | return ob, reward, done, {}
33 |
34 | def viewer_setup(self):
35 | self.viewer.cam.trackbodyid = -1
36 | self.viewer.cam.distance = 4.0
37 |
38 | def reset_model(self):
39 |         qpos = np.copy(self.init_qpos)
40 |
41 | self.goal_pos = np.asarray([0, 0])
42 |         self.cylinder_pos = np.array([-0.25, 0.15]) + self.np_random.normal(0, 0.025, [2])
43 |
44 | qpos[-4:-2] = self.cylinder_pos
45 | qpos[-2:] = self.goal_pos
46 | qvel = self.init_qvel + self.np_random.uniform(low=-0.005,
47 | high=0.005, size=self.model.nv)
48 | qvel[-4:] = 0
49 | self.set_state(qpos, qvel)
50 | self.ac_goal_pos = self.get_body_com("goal")
51 |
52 | return self._get_obs()
53 |
54 | def _get_obs(self):
55 | return np.concatenate([
56 |             self.sim.data.qpos.flat[:7],
57 |             self.sim.data.qvel.flat[:7],
58 | self.get_body_com("tips_arm"),
59 | self.get_body_com("object"),
60 | ])
61 |
62 | def _get_state(self):
63 | return np.concatenate([
64 |             self.sim.data.qpos.flat,
65 |             self.sim.data.qvel.flat,
66 | ])
67 |
--------------------------------------------------------------------------------
/dmbrl/env/reacher.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 |
7 | import numpy as np
8 | from gym import utils
9 | from gym.envs.mujoco import mujoco_env
10 |
11 |
12 | class Reacher3DEnv(mujoco_env.MujocoEnv, utils.EzPickle):
13 | def __init__(self):
14 | self.viewer = None
15 | utils.EzPickle.__init__(self)
16 | dir_path = os.path.dirname(os.path.realpath(__file__))
17 | self.goal = np.zeros(3)
18 | mujoco_env.MujocoEnv.__init__(self, os.path.join(dir_path, 'assets/reacher3d.xml'), 2)
19 |
20 | def step(self, a):
21 | self.do_simulation(a, self.frame_skip)
22 | ob = self._get_obs()
23 | reward = -np.sum(np.square(self.get_EE_pos(ob[None]) - self.goal))
24 | reward -= 0.01 * np.square(a).sum()
25 | done = False
26 | return ob, reward, done, dict(reward_dist=0, reward_ctrl=0)
27 |
28 | def viewer_setup(self):
29 | self.viewer.cam.trackbodyid = 1
30 | self.viewer.cam.distance = 2.5
31 | self.viewer.cam.elevation = -30
32 | self.viewer.cam.azimuth = 270
33 |
34 | def reset_model(self):
35 | qpos, qvel = np.copy(self.init_qpos), np.copy(self.init_qvel)
36 |         qpos[-3:] += self.np_random.normal(loc=0, scale=0.1, size=[3])
37 | qvel[-3:] = 0
38 | self.goal = qpos[-3:]
39 | self.set_state(qpos, qvel)
40 | return self._get_obs()
41 |
42 | def _get_obs(self):
43 | return np.concatenate([
44 | self.sim.data.qpos.flat,
45 | self.sim.data.qvel.flat[:-3],
46 | ])
47 |
48 | def _get_state(self):
49 | return np.concatenate([
50 | self.sim.data.qpos.flat,
51 | self.sim.data.qvel.flat,
52 | ])
53 |
54 | def get_EE_pos(self, states):
55 | theta1, theta2, theta3, theta4, theta5, theta6, theta7 = \
56 | states[:, :1], states[:, 1:2], states[:, 2:3], states[:, 3:4], states[:, 4:5], states[:, 5:6], states[:, 6:]
57 |
58 | rot_axis = np.concatenate([np.cos(theta2) * np.cos(theta1), np.cos(theta2) * np.sin(theta1), -np.sin(theta2)],
59 | axis=1)
60 | rot_perp_axis = np.concatenate([-np.sin(theta1), np.cos(theta1), np.zeros(theta1.shape)], axis=1)
61 | cur_end = np.concatenate([
62 | 0.1 * np.cos(theta1) + 0.4 * np.cos(theta1) * np.cos(theta2),
63 | 0.1 * np.sin(theta1) + 0.4 * np.sin(theta1) * np.cos(theta2) - 0.188,
64 | -0.4 * np.sin(theta2)
65 | ], axis=1)
66 |
67 | for length, hinge, roll in [(0.321, theta4, theta3), (0.16828, theta6, theta5)]:
68 | perp_all_axis = np.cross(rot_axis, rot_perp_axis)
69 | x = np.cos(hinge) * rot_axis
70 | y = np.sin(hinge) * np.sin(roll) * rot_perp_axis
71 | z = -np.sin(hinge) * np.cos(roll) * perp_all_axis
72 | new_rot_axis = x + y + z
73 | new_rot_perp_axis = np.cross(new_rot_axis, rot_axis)
74 | new_rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] = \
75 | rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30]
76 | new_rot_perp_axis /= np.linalg.norm(new_rot_perp_axis, axis=1, keepdims=True)
77 | rot_axis, rot_perp_axis, cur_end = new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis
78 |
79 | return cur_end
80 |
81 |
--------------------------------------------------------------------------------
/dmbrl/misc/Agent.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import numpy as np
6 | # from gym.monitoring import VideoRecorder
7 | from dotmap import DotMap
8 | import mujoco_py
9 | import time
10 |
11 |
12 | class Agent:
13 | """An general class for RL agents.
14 | """
15 | def __init__(self, params):
16 | """Initializes an agent.
17 |
18 | Arguments:
19 | params: (DotMap) A DotMap of agent parameters.
20 | .env: (OpenAI gym environment) The environment for this agent.
21 | .noisy_actions: (bool) Indicates whether random Gaussian noise will
22 | be added to the actions of this agent.
23 | .noise_stddev: (float) The standard deviation to be used for the
24 | action noise if params.noisy_actions is True.
25 | """
26 | self.env = params.env
27 | self.noise_stddev = params.noise_stddev if params.get("noisy_actions", False) else None
28 |
29 | if isinstance(self.env, DotMap):
30 | raise ValueError("Environment must be provided to the agent at initialization.")
31 | if (not isinstance(self.noise_stddev, float)) and params.get("noisy_actions", False):
32 | raise ValueError("Must provide standard deviation for noise for noisy actions.")
33 |
34 | if self.noise_stddev is not None:
35 | self.dU = self.env.action_space.shape[0]
36 |
37 | def save(self, path):
38 | """
39 | Save the environment to the given path
40 | """
41 | pass
42 |
43 | def load(self, path):
44 | """
45 |         Load the environment from the given path
46 | """
47 | pass
48 |
49 | def sample(self, horizon, policy, record_fname=None, catch_error=False):
50 | """Samples a rollout from the agent.
51 |
52 | Arguments:
53 | horizon: (int) The length of the rollout to generate from the agent.
54 | policy: (policy) The policy that the agent will use for actions.
55 |             record_fname: (str/None) The name of the file to which a recording of the rollout
56 |                 would be saved; recording is currently disabled, so this argument is unused.
57 |
58 | Returns: (dict) A dictionary containing data from the rollout.
59 |             The keys of the dictionary are 'obs', 'ac', 'reward_sum', 'rewards', and 'init_state' ('error_state' is added when catch_error is True).
60 | """
61 |         # Video recording is disabled in this version (the VideoRecorder import is
62 |         # commented out above); record_fname is accepted but unused.
63 |
64 | times, rewards = [], []
65 |
66 | O, A, reward_sum, done = [self.env.reset().tolist()], [], 0, False
67 | init_state = self.env._get_state().tolist()
68 | policy.reset()
69 |
70 | if catch_error:
71 | error_state = False
72 | for t in range(horizon):
73 | start = time.time()
74 | A.append(policy.act(O[t], t).tolist())
75 | times.append(time.time() - start)
76 | try:
77 | if self.noise_stddev is None:
78 | obs, reward, done, info = self.env.step(np.array(A[t]))
79 | else:
80 | action = A[t] + np.random.normal(loc=0, scale=self.noise_stddev, size=[self.dU])
81 | action = np.minimum(np.maximum(action, self.env.action_space.low), self.env.action_space.high)
82 | obs, reward, done, info = self.env.step(action)
83 | O.append(obs.tolist())
84 | reward_sum += reward
85 | rewards.append(reward)
86 |
87 | if done:
88 | break
89 | except mujoco_py.builder.MujocoException:
90 | error_state = True
91 | break
92 |
93 | print("Average action selection time: ", np.mean(times))
94 | print("Rollout length: ", len(A))
95 |
96 | return {
97 | "obs": O,
98 | "ac": A,
99 | "reward_sum": reward_sum,
100 | "rewards": rewards,
101 | "init_state": init_state,
102 | "error_state": error_state
103 | }
104 | else:
105 | for t in range(horizon):
106 | start = time.time()
107 | A.append(policy.act(O[t], t).tolist())
108 | times.append(time.time() - start)
109 |
110 | if self.noise_stddev is None:
111 | obs, reward, done, info = self.env.step(np.array(A[t]))
112 | else:
113 | action = A[t] + np.random.normal(loc=0, scale=self.noise_stddev, size=[self.dU])
114 | action = np.minimum(np.maximum(action, self.env.action_space.low), self.env.action_space.high)
115 | obs, reward, done, info = self.env.step(action)
116 | O.append(obs.tolist())
117 | reward_sum += reward
118 | rewards.append(reward)
119 | if done:
120 | break
121 |
122 | print("Average action selection time: ", np.mean(times))
123 | print("Rollout length: ", len(A))
124 |
125 | return {
126 | "obs": O,
127 | "ac": A,
128 | "reward_sum": reward_sum,
129 | "rewards": rewards,
130 | "init_state": init_state
131 | }
132 |
--------------------------------------------------------------------------------
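Agent.sample() only assumes a gym-style environment that additionally exposes _get_state(), and a policy object with reset() and act(obs, t). A minimal sketch of a compatible policy follows (the ZeroPolicy name is illustrative, not part of the codebase; the commented usage requires a working MuJoCo installation):

    import numpy as np
    from dotmap import DotMap
    from dmbrl.misc.Agent import Agent


    class ZeroPolicy:
        """Illustrative stand-in for an MPC controller: always outputs zero actions."""
        def __init__(self, env):
            self.env = env

        def reset(self):
            pass

        def act(self, obs, t):
            return np.zeros(self.env.action_space.shape[0])


    # env = ...  # e.g. HalfCheetahEnv() from dmbrl/env
    # agent = Agent(DotMap(env=env, noisy_actions=False))
    # rollout = agent.sample(horizon=100, policy=ZeroPolicy(env))
    # print(rollout["reward_sum"], len(rollout["ac"]))

--------------------------------------------------------------------------------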
/dmbrl/misc/DotmapUtils.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 |
6 | def get_required_argument(dotmap, key, message, default=None):
7 | val = dotmap.get(key, default)
8 | if val is default:
9 | raise ValueError(message)
10 | return val
11 |
--------------------------------------------------------------------------------
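Because the check is `val is default`, the error fires exactly when get() falls back to the sentinel, so a key that is simply absent always raises. A quick usage sketch with illustrative keys and messages:

    from dotmap import DotMap
    from dmbrl.misc.DotmapUtils import get_required_argument

    cfg = DotMap(env="halfcheetah")
    env = get_required_argument(cfg, "env", "Must provide environment.")  # "halfcheetah"

    try:
        get_required_argument(cfg, "task_hor", "Must provide task horizon.")
    except ValueError as e:
        print(e)  # raised because 'task_hor' is absent

--------------------------------------------------------------------------------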
/dmbrl/misc/MBExp.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 | from time import time, localtime, strftime
7 |
8 | import numpy as np
9 | from scipy.io import savemat
10 | from dotmap import DotMap
11 |
12 | from dmbrl.misc.DotmapUtils import get_required_argument
13 | from dmbrl.misc.Agent import Agent
14 |
15 |
16 | class MBExperiment:
17 | def __init__(self, params):
18 | """Initializes class instance.
19 |
20 | Argument:
21 | params (DotMap): A DotMap containing the following:
22 | .sim_cfg:
23 | .env (gym.env): Environment for this experiment
24 | .task_hor (int): Task horizon
25 | .stochastic (bool): (optional) If True, agent adds noise to its actions.
26 | Must provide noise_std (see below). Defaults to False.
27 | .noise_std (float): for stochastic agents, noise of the form N(0, noise_std^2I)
28 | will be added.
29 |
30 | .exp_cfg:
31 | .ntrain_iters (int): Number of training iterations to be performed.
32 | .nrollouts_per_iter (int): (optional) Number of rollouts done between training
33 | iterations. Defaults to 1.
34 | .ninit_rollouts (int): (optional) Number of initial rollouts. Defaults to 1.
35 | .policy (controller): Policy that will be trained.
36 |
37 | .log_cfg:
38 | .logdir (str): Parent of directory path where experiment data will be saved.
39 |                     Experiment will be saved in logdir/<%Y-%m-%d--%H:%M:%S>.
40 | .nrecord (int): (optional) Number of rollouts to record for every iteration.
41 | Defaults to 0.
42 | .neval (int): (optional) Number of rollouts for performance evaluation.
43 | Defaults to 1.
44 | """
45 | self.env = get_required_argument(params.sim_cfg, "env", "Must provide environment.")
46 | self.task_hor = get_required_argument(params.sim_cfg, "task_hor", "Must provide task horizon.")
47 | if params.sim_cfg.get("stochastic", False):
48 | self.agent = Agent(DotMap(
49 | env=self.env, noisy_actions=True,
50 | noise_stddev=get_required_argument(
51 | params.sim_cfg,
52 | "noise_std",
53 | "Must provide noise standard deviation in the case of a stochastic environment."
54 | )
55 | ))
56 | else:
57 | self.agent = Agent(DotMap(env=self.env, noisy_actions=False))
58 |
59 | self.ntrain_iters = get_required_argument(
60 | params.exp_cfg, "ntrain_iters", "Must provide number of training iterations."
61 | )
62 | self.nrollouts_per_iter = params.exp_cfg.get("nrollouts_per_iter", 1)
63 | self.ninit_rollouts = params.exp_cfg.get("ninit_rollouts", 1)
64 | self.policy = get_required_argument(params.exp_cfg, "policy", "Must provide a policy.")
65 |
66 | self.logdir = os.path.join(
67 | get_required_argument(params.log_cfg, "logdir", "Must provide log parent directory."),
68 | strftime("%Y-%m-%d--%H:%M:%S", localtime())
69 | )
70 | self.nrecord = params.log_cfg.get("nrecord", 0)
71 | self.neval = params.log_cfg.get("neval", 1)
72 |
73 | def run_experiment(self):
74 | """Perform experiment.
75 | """
76 | os.makedirs(self.logdir, exist_ok=True)
77 |
78 | traj_obs, traj_acs, traj_rets, traj_rews = [], [], [], []
79 |
80 | # Perform initial rollouts
81 | samples = []
82 | for i in range(self.ninit_rollouts):
83 | samples.append(
84 | self.agent.sample(
85 | self.task_hor, self.policy
86 | )
87 | )
88 | traj_obs.append(samples[-1]["obs"])
89 | traj_acs.append(samples[-1]["ac"])
90 | traj_rews.append(samples[-1]["rewards"])
91 |
92 | if self.ninit_rollouts > 0:
93 | self.policy.train(
94 | [sample["obs"] for sample in samples],
95 | [sample["ac"] for sample in samples],
96 | [sample["rewards"] for sample in samples]
97 | )
98 |
99 | # Training loop
100 | for i in range(self.ntrain_iters):
101 | print("####################################################################")
102 | print("Starting training iteration %d." % (i + 1))
103 |
104 | iter_dir = os.path.join(self.logdir, "train_iter%d" % (i + 1))
105 | os.makedirs(iter_dir, exist_ok=True)
106 |
107 | samples = []
108 | for j in range(self.nrecord):
109 | samples.append(
110 | self.agent.sample(
111 | self.task_hor, self.policy,
112 | os.path.join(iter_dir, "rollout%d.mp4" % j)
113 | )
114 | )
115 | if self.nrecord > 0:
116 | for item in filter(lambda f: f.endswith(".json"), os.listdir(iter_dir)):
117 | os.remove(os.path.join(iter_dir, item))
118 | for j in range(max(self.neval, self.nrollouts_per_iter) - self.nrecord):
119 | samples.append(
120 | self.agent.sample(
121 | self.task_hor, self.policy
122 | )
123 | )
124 | print("Rewards obtained:", [sample["reward_sum"] for sample in samples[:self.neval]])
125 | traj_obs.extend([sample["obs"] for sample in samples[:self.nrollouts_per_iter]])
126 | traj_acs.extend([sample["ac"] for sample in samples[:self.nrollouts_per_iter]])
127 | traj_rets.extend([sample["reward_sum"] for sample in samples[:self.neval]])
128 | traj_rews.extend([sample["rewards"] for sample in samples[:self.nrollouts_per_iter]])
129 | samples = samples[:self.nrollouts_per_iter]
130 |
131 | self.policy.dump_logs(self.logdir, iter_dir)
132 | savemat(
133 | os.path.join(self.logdir, "logs.mat"),
134 | {
135 | "observations": traj_obs,
136 | "actions": traj_acs,
137 | "returns": traj_rets,
138 | "rewards": traj_rews
139 | }
140 | )
141 | # Delete iteration directory if not used
142 | if len(os.listdir(iter_dir)) == 0:
143 | os.rmdir(iter_dir)
144 |
145 | if i < self.ntrain_iters - 1:
146 | self.policy.train(
147 | [sample["obs"] for sample in samples],
148 | [sample["ac"] for sample in samples],
149 | [sample["rewards"] for sample in samples]
150 | )
151 |
--------------------------------------------------------------------------------
/dmbrl/misc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/automl/HPO_for_RL/d82c7ddd6fe19834c088137570530f11761d9390/dmbrl/misc/__init__.py
--------------------------------------------------------------------------------
/dmbrl/misc/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .cem import CEMOptimizer
2 | from .random import RandomOptimizer
--------------------------------------------------------------------------------
/dmbrl/misc/optimizers/cem.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import tensorflow as tf
6 | import numpy as np
7 | import scipy.stats as stats
8 |
9 | from .optimizer import Optimizer
10 |
11 |
12 | class CEMOptimizer(Optimizer):
13 | """A Tensorflow-compatible CEM optimizer.
14 | """
15 | def __init__(self, sol_dim, max_iters, popsize, num_elites, tf_session=None,
16 | upper_bound=None, lower_bound=None, epsilon=0.001, alpha=0.25):
17 | """Creates an instance of this class.
18 |
19 | Arguments:
20 | sol_dim (int): The dimensionality of the problem space
21 | max_iters (int): The maximum number of iterations to perform during optimization
22 | popsize (int): The number of candidate solutions to be sampled at every iteration
23 | num_elites (int): The number of top solutions that will be used to obtain the distribution
24 | at the next iteration.
25 | tf_session (tf.Session): (optional) Session to be used for this optimizer. Defaults to None,
26 | in which case any functions passed in cannot be tf.Tensor-valued.
27 | upper_bound (np.array): An array of upper bounds
28 | lower_bound (np.array): An array of lower bounds
29 | epsilon (float): A minimum variance. If the maximum variance drops below epsilon, optimization is
30 | stopped.
31 | alpha (float): Controls how much of the previous mean and variance is used for the next iteration.
32 | next_mean = alpha * old_mean + (1 - alpha) * elite_mean, and similarly for variance.
33 | """
34 | super().__init__()
35 | self.sol_dim, self.max_iters, self.popsize, self.num_elites = sol_dim, max_iters, popsize, num_elites
36 | self.ub, self.lb = upper_bound, lower_bound
37 | self.epsilon, self.alpha = epsilon, alpha
38 | self.tf_sess = tf_session
39 |
40 | if num_elites > popsize:
41 | raise ValueError("Number of elites must be at most the population size.")
42 |
43 | if self.tf_sess is not None:
44 | with self.tf_sess.graph.as_default():
45 | with tf.variable_scope("CEMSolver") as scope:
46 | self.init_mean = tf.placeholder(dtype=tf.float32, shape=[sol_dim])
47 | self.init_var = tf.placeholder(dtype=tf.float32, shape=[sol_dim])
48 |
49 | self.num_opt_iters, self.mean, self.var = None, None, None
50 | self.tf_compatible, self.cost_function = None, None
51 |
52 | def setup(self, cost_function, tf_compatible):
53 | """Sets up this optimizer using a given cost function.
54 |
55 | Arguments:
56 | cost_function (func): A function for computing costs over a batch of candidate solutions.
57 | tf_compatible (bool): True if the cost function provided is tf.Tensor-valued.
58 |
59 | Returns: None
60 | """
61 | if tf_compatible and self.tf_sess is None:
62 | raise RuntimeError("Cannot pass in a tf.Tensor-valued cost function without passing in a TensorFlow "
63 | "session into the constructor")
64 |
65 | self.tf_compatible = tf_compatible
66 |
67 | if not tf_compatible:
68 | self.cost_function = cost_function
69 | else:
70 | def continue_optimization(t, mean, var, best_val, best_sol):
71 | return tf.logical_and(tf.less(t, self.max_iters), tf.reduce_max(var) > self.epsilon)
72 |
73 | def iteration(t, mean, var, best_val, best_sol):
74 | lb_dist, ub_dist = mean - self.lb, self.ub - mean
75 | constrained_var = tf.minimum(tf.minimum(tf.square(lb_dist / 2), tf.square(ub_dist / 2)), var)
76 | samples = tf.truncated_normal([self.popsize, self.sol_dim], mean, tf.sqrt(constrained_var))
77 |
78 | costs = cost_function(samples)
79 | values, indices = tf.nn.top_k(-costs, k=self.num_elites, sorted=True)
80 |
81 | best_val, best_sol = tf.cond(
82 | tf.less(-values[0], best_val),
83 | lambda: (-values[0], samples[indices[0]]),
84 | lambda: (best_val, best_sol)
85 | )
86 |
87 | elites = tf.gather(samples, indices)
88 | new_mean = tf.reduce_mean(elites, axis=0)
89 | new_var = tf.reduce_mean(tf.square(elites - new_mean), axis=0)
90 |
91 | mean = self.alpha * mean + (1 - self.alpha) * new_mean
92 | var = self.alpha * var + (1 - self.alpha) * new_var
93 |
94 | return t + 1, mean, var, best_val, best_sol
95 |
96 | with self.tf_sess.graph.as_default():
97 | self.num_opt_iters, self.mean, self.var, self.best_val, self.best_sol = tf.while_loop(
98 | cond=continue_optimization, body=iteration,
99 | loop_vars=[0, self.init_mean, self.init_var, float("inf"), self.init_mean]
100 | )
101 |
102 | def reset(self):
103 | pass
104 |
105 | def obtain_solution(self, init_mean, init_var):
106 | """Optimizes the cost function using the provided initial candidate distribution
107 |
108 | Arguments:
109 | init_mean (np.ndarray): The mean of the initial candidate distribution.
110 | init_var (np.ndarray): The variance of the initial candidate distribution.
111 | """
112 | if self.tf_compatible:
113 | sol, solvar = self.tf_sess.run(
114 | [self.mean, self.var],
115 | feed_dict={self.init_mean: init_mean, self.init_var: init_var}
116 | )
117 | else:
118 | mean, var, t = init_mean, init_var, 0
119 | X = stats.truncnorm(-2, 2, loc=np.zeros_like(mean), scale=np.ones_like(mean))
120 |
121 | while (t < self.max_iters) and np.max(var) > self.epsilon:
122 | lb_dist, ub_dist = mean - self.lb, self.ub - mean
123 | constrained_var = np.minimum(np.minimum(np.square(lb_dist / 2), np.square(ub_dist / 2)), var)
124 |
125 | samples = X.rvs(size=[self.popsize, self.sol_dim]) * np.sqrt(constrained_var) + mean
126 | costs = self.cost_function(samples)
127 | elites = samples[np.argsort(costs)][:self.num_elites]
128 |
129 | new_mean = np.mean(elites, axis=0)
130 | new_var = np.var(elites, axis=0)
131 |
132 | mean = self.alpha * mean + (1 - self.alpha) * new_mean
133 | var = self.alpha * var + (1 - self.alpha) * new_var
134 |
135 | t += 1
136 | sol, solvar = mean, var
137 | return sol
138 |
139 |
--------------------------------------------------------------------------------
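The NumPy branch of obtain_solution() above can be exercised without a TensorFlow session. A toy run on a quadratic cost (all sizes and bounds are arbitrary illustration values):

    import numpy as np
    from dmbrl.misc.optimizers import CEMOptimizer

    sol_dim = 3
    opt = CEMOptimizer(sol_dim=sol_dim, max_iters=20, popsize=200, num_elites=20,
                       upper_bound=2 * np.ones(sol_dim), lower_bound=-2 * np.ones(sol_dim))

    target = np.array([0.5, -1.0, 1.5])
    # Cost over a batch of candidates: shape [popsize, sol_dim] -> [popsize]
    opt.setup(lambda samples: np.sum(np.square(samples - target), axis=1),
              tf_compatible=False)

    sol = opt.obtain_solution(init_mean=np.zeros(sol_dim), init_var=np.ones(sol_dim))
    print(sol)  # converges near `target`

--------------------------------------------------------------------------------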
/dmbrl/misc/optimizers/optimizer.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 |
6 | class Optimizer:
7 | def __init__(self, *args, **kwargs):
8 | pass
9 |
10 | def setup(self, cost_function, tf_compatible):
11 | raise NotImplementedError("Must be implemented in subclass.")
12 |
13 | def reset(self):
14 | raise NotImplementedError("Must be implemented in subclass.")
15 |
16 | def obtain_solution(self, *args, **kwargs):
17 | raise NotImplementedError("Must be implemented in subclass.")
18 |
--------------------------------------------------------------------------------
/dmbrl/misc/optimizers/random.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import absolute_import
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from .optimizer import Optimizer
9 |
10 |
11 | class RandomOptimizer(Optimizer):
12 |     def __init__(self, sol_dim, popsize, tf_session=None,
13 |                  upper_bound=None, lower_bound=None):
14 | """Creates an instance of this class.
15 |
16 | Arguments:
17 | sol_dim (int): The dimensionality of the problem space
18 |             popsize (int): The number of candidate solutions sampled uniformly
19 |                 at random on every call to obtain_solution()
20 |             tf_session (tf.Session): (optional) Session to be used for this optimizer.
21 |                 Defaults to None, in which case any cost functions passed in cannot be
22 |                 tf.Tensor-valued.
23 |             upper_bound (np.array): An array of upper bounds
24 |             lower_bound (np.array): An array of lower bounds
25 | """
26 | super().__init__()
27 | self.sol_dim = sol_dim
28 | self.popsize = popsize
29 | self.ub, self.lb = upper_bound, lower_bound
30 | self.tf_sess = tf_session
31 | self.solution = None
32 | self.tf_compatible, self.cost_function = None, None
33 |
34 | def setup(self, cost_function, tf_compatible):
35 | """Sets up this optimizer using a given cost function.
36 |
37 | Arguments:
38 | cost_function (func): A function for computing costs over a batch of candidate solutions.
39 | tf_compatible (bool): True if the cost function provided is tf.Tensor-valued.
40 |
41 | Returns: None
42 | """
43 | if tf_compatible and self.tf_sess is None:
44 | raise RuntimeError("Cannot pass in a tf.Tensor-valued cost function without passing in a TensorFlow "
45 | "session into the constructor")
46 |
47 | if not tf_compatible:
48 | self.tf_compatible = False
49 | self.cost_function = cost_function
50 | else:
51 | with self.tf_sess.graph.as_default():
52 | self.tf_compatible = True
53 |                 solutions = tf.random_uniform([self.popsize, self.sol_dim], self.lb, self.ub)
54 | costs = cost_function(solutions)
55 | self.solution = solutions[tf.cast(tf.argmin(costs), tf.int32)]
56 |
57 | def reset(self):
58 | pass
59 |
60 | def obtain_solution(self, *args, **kwargs):
61 | """Optimizes the cost function provided in setup().
62 |
63 |         Arguments:
64 |             *args, **kwargs: Accepted for interface compatibility with
65 |                 CEMOptimizer.obtain_solution; ignored by this optimizer.
66 | """
67 | if self.tf_compatible:
68 | return self.tf_sess.run(self.solution)
69 | else:
70 | solutions = np.random.uniform(self.lb, self.ub, [self.popsize, self.sol_dim])
71 | costs = self.cost_function(solutions)
72 | return solutions[np.argmin(costs)]
73 |
--------------------------------------------------------------------------------
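The same toy cost also works for this random-search baseline; with tf_session=None it takes the pure-NumPy branch (sizes and bounds again illustrative):

    import numpy as np
    from dmbrl.misc.optimizers import RandomOptimizer

    sol_dim = 3
    opt = RandomOptimizer(sol_dim=sol_dim, popsize=5000, tf_session=None,
                          upper_bound=2 * np.ones(sol_dim),
                          lower_bound=-2 * np.ones(sol_dim))

    target = np.array([0.5, -1.0, 1.5])
    opt.setup(lambda samples: np.sum(np.square(samples - target), axis=1),
              tf_compatible=False)
    print(opt.obtain_solution())  # best of 5000 uniform samples, roughly near `target`

--------------------------------------------------------------------------------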
/dmbrl/misc/render.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 | import argparse
7 | import pprint
8 |
9 | from dotmap import DotMap
10 |
11 | from dmbrl.misc.MBExp import MBExperiment
12 | from dmbrl.controllers.MPC import MPC
13 | from dmbrl.config import create_config
14 |
15 |
16 | def main(env, ctrl_type, ctrl_args, overrides, model_dir, logdir):
17 | ctrl_args = DotMap(**{key: val for (key, val) in ctrl_args})
18 |
19 | overrides.append(["ctrl_cfg.prop_cfg.model_init_cfg.model_dir", model_dir])
20 | overrides.append(["ctrl_cfg.prop_cfg.model_init_cfg.load_model", "True"])
21 | overrides.append(["ctrl_cfg.prop_cfg.model_pretrained", "True"])
22 | overrides.append(["exp_cfg.exp_cfg.ninit_rollouts", "0"])
23 | overrides.append(["exp_cfg.exp_cfg.ntrain_iters", "1"])
24 | overrides.append(["exp_cfg.log_cfg.nrecord", "1"])
25 |
26 | cfg = create_config(env, ctrl_type, ctrl_args, overrides, logdir)
27 | cfg.pprint()
28 |
29 | if ctrl_type == "MPC":
30 | cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)
31 | exp = MBExperiment(cfg.exp_cfg)
32 |
33 | os.makedirs(exp.logdir)
34 | with open(os.path.join(exp.logdir, "config.txt"), "w") as f:
35 | f.write(pprint.pformat(cfg.toDict()))
36 |
37 | exp.run_experiment()
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument('-env', type=str, required=True)
43 | parser.add_argument('-ca', '--ctrl_arg', action='append', nargs=2, default=[])
44 | parser.add_argument('-o', '--override', action='append', nargs=2, default=[])
45 | parser.add_argument('-model-dir', type=str, required=True)
46 | parser.add_argument('-logdir', type=str, required=True)
47 | args = parser.parse_args()
48 |
49 | main(args.env, "MPC", args.ctrl_arg, args.override, args.model_dir, args.logdir)
50 |
--------------------------------------------------------------------------------
/dmbrl/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/automl/HPO_for_RL/d82c7ddd6fe19834c088137570530f11761d9390/dmbrl/modeling/__init__.py
--------------------------------------------------------------------------------
/dmbrl/modeling/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .FC import FC
--------------------------------------------------------------------------------
/dmbrl/modeling/models/TFGP.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 |
6 | import tensorflow as tf
7 | import numpy as np
8 | import gpflow
9 |
10 | from dmbrl.misc.DotmapUtils import get_required_argument
11 |
12 |
13 | class TFGP:
14 | def __init__(self, params):
15 | """Initializes class instance.
16 |
17 | Arguments:
18 | params
19 | .name (str): Model name
20 | .kernel_class (class): Kernel class
21 | .kernel_args (args): Kernel args
22 | .num_inducing_points (int): Number of inducing points
23 | .sess (tf.Session): Tensorflow session
24 | """
25 | self.name = params.get("name", "GP")
26 | self.kernel_class = get_required_argument(params, "kernel_class", "Must provide kernel class.")
27 | self.kernel_args = params.get("kernel_args", {})
28 | self.num_inducing_points = get_required_argument(
29 | params, "num_inducing_points", "Must provide number of inducing points."
30 | )
31 |
32 | if params.get("sess", None) is None:
33 | config = tf.ConfigProto()
34 | config.gpu_options.allow_growth = True
35 | self._sess = tf.Session(config=config)
36 | else:
37 | self._sess = params.get("sess")
38 |
39 | with self._sess.as_default():
40 | with tf.variable_scope(self.name):
41 | output_dim = self.kernel_args["output_dim"]
42 | del self.kernel_args["output_dim"]
43 | self.model = gpflow.models.SGPR(
44 | np.zeros([1, self.kernel_args["input_dim"]]),
45 | np.zeros([1, output_dim]),
46 | kern=self.kernel_class(**self.kernel_args),
47 | Z=np.zeros([self.num_inducing_points, self.kernel_args["input_dim"]])
48 | )
49 | self.model.initialize()
50 |
51 | @property
52 | def is_probabilistic(self):
53 | return True
54 |
55 | @property
56 | def sess(self):
57 | return self._sess
58 |
59 | @property
60 | def is_tf_model(self):
61 | return True
62 |
63 | def train(self, inputs, targets,
64 | *args, **kwargs):
65 | """Optimizes the parameters of the internal GP model.
66 |
67 | Arguments:
68 | inputs: (np.ndarray) An array of inputs.
69 | targets: (np.ndarray) An array of targets.
70 |             Additional positional/keyword arguments are accepted for interface
71 |                 compatibility, but are ignored by this model.
72 |
73 | Returns: None.
74 | """
75 | perm = np.random.permutation(inputs.shape[0])
76 | inputs, targets = inputs[perm], targets[perm]
77 | Z = np.copy(inputs[:self.num_inducing_points])
78 | if Z.shape[0] < self.num_inducing_points:
79 | Z = np.concatenate([Z, np.zeros([self.num_inducing_points - Z.shape[0], Z.shape[1]])])
80 | self.model.X = inputs
81 | self.model.Y = targets
82 | self.model.feature.Z = Z
83 | with self.sess.as_default():
84 | self.model.compile()
85 | print("Optimizing model... ", end="")
86 | gpflow.train.ScipyOptimizer().minimize(self.model)
87 | print("Done.")
88 |
89 | def predict(self, inputs, *args, **kwargs):
90 | """Returns the predictions of this model on inputs.
91 |
92 | Arguments:
93 | inputs: (np.ndarray) The inputs on which predictions will be returned.
94 |             Additional arguments are accepted for interface compatibility but ignored.
95 |
96 | Returns: (np.ndarrays) The mean and variance of the model on the new points.
97 | """
98 | if self.model is None:
99 | raise RuntimeError("Cannot make predictions without initial batch of data.")
100 |
101 | with self.sess.as_default():
102 | mean, var = self.model.predict_y(inputs)
103 | return mean, var
104 |
105 | def create_prediction_tensors(self, inputs, *args, **kwargs):
106 | ""
107 | if self.model is None:
108 | raise RuntimeError("Cannot make predictions without initial batch of data.")
109 |
110 | inputs = tf.cast(inputs, tf.float64)
111 | mean, var = self.model._build_predict(inputs, full_cov=False)
112 | return tf.cast(mean, dtype=tf.float32), tf.cast(var, tf.float32)
113 |
114 | def save(self, *args, **kwargs):
115 | pass
116 |
--------------------------------------------------------------------------------
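For reference, a minimal construction sketch, assuming the gpflow 1.x API pinned in environment.yml (kernel choice, dimensions, and data are illustrative):

    import gpflow
    import numpy as np
    from dotmap import DotMap
    from dmbrl.modeling.models import TFGP

    model = TFGP(DotMap(
        name="toy_gp",
        kernel_class=gpflow.kernels.RBF,                # gpflow 1.x kernel class
        kernel_args={"input_dim": 4, "output_dim": 2},  # output_dim is consumed by TFGP
        num_inducing_points=20,
    ))

    X = np.random.uniform(-1, 1, size=(100, 4))
    Y = np.concatenate([np.sin(X[:, :1]), np.cos(X[:, 1:2])], axis=1)
    model.train(X, Y)
    mean, var = model.predict(X[:5])  # per-point predictive mean and variance

--------------------------------------------------------------------------------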
/dmbrl/modeling/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .BNN import BNN
2 | from .NN import NN
3 | from .TFGP import TFGP
4 |
--------------------------------------------------------------------------------
/dmbrl/modeling/utils/HPO.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from dmbrl.controllers import MPC
3 | from dmbrl.misc.DotmapUtils import get_required_argument
4 | from dmbrl.modeling.layers import FC
5 | from dotmap import DotMap
6 |
7 |
8 | def build_policy_constructor(exp):
9 | """Returns a function which will translates a parameter vector into a configuration
10 | that can be used to construct a policy.
11 |
12 | Args:
13 | exp: experiment object
14 |
15 | Returns:
16 | translate_into_model: A function which will construct a controller (e.g. MPC) given the hyperparameters
17 | """
18 | def translate_into_model(param_dict, config_space, reset=True):
19 | """Translate a set of hyperparameters into a policy
20 |
21 | Args:
22 |             param_dict: dictionary which contains the hyperparameters of the policy
23 |             config_space: Configuration space used to determine the default hyperparameters
24 | reset (bool, optional): If True, reset the default graph of Tensorflow. Defaults to True.
25 |
26 | Returns:
27 | policy: MPC controller
28 | """
29 | # Clear the graph to avoid conflicts with previous graphs.
30 | if reset:
31 | tf.reset_default_graph()
32 |
33 | # Give parameters names to avoid bugs
34 | config_dict = dict()
35 | config_space = config_space[exp.env_name]
36 | for config_name in config_space.keys():
37 | if config_name in exp.config_names:
38 | config_dict[config_name] = param_dict[config_name]
39 | else:
40 | config_dict[config_name] = config_space[config_name].default_value
41 |
42 | num_hidden_layers = int(config_dict["num_hidden_layers"])
43 | hidden_layer_width = int(config_dict["hidden_layer_width"])
44 | act_idx = config_dict["act_idx"]
45 | model_learning_rate = config_dict["model_learning_rate"]
46 | model_weight_decay = config_dict["model_weight_decay"]
47 | model_opt_idx = config_dict["model_opt_idx"]
48 | num_cem_iters = int(config_dict["num_cem_iters"])
49 | cem_popsize = int(config_dict["cem_popsize"])
50 | cem_alpha = config_dict["cem_alpha"]
51 |         num_cem_elites = int(config_dict["cem_elites_ratio"] * cem_popsize)
52 | model_train_epoch = int(config_dict["model_train_epoch"])
53 | plan_hor = int(config_dict["plan_hor"])
54 |
55 |
56 |         # Translate the hyperparameters into config overrides
57 |
58 | if exp.opt_type == 'CEM':
59 | overrides = [
60 | ("ctrl_cfg.opt_cfg.cfg.max_iters", num_cem_iters),
61 | ("ctrl_cfg.opt_cfg.cfg.popsize", cem_popsize),
62 | ("ctrl_cfg.opt_cfg.cfg.num_elites", num_cem_elites),
63 | ("ctrl_cfg.opt_cfg.cfg.alpha", cem_alpha),
64 | ("ctrl_cfg.opt_cfg.plan_hor", plan_hor),
65 | ("ctrl_cfg.prop_cfg.model_train_cfg.epochs", model_train_epoch)
66 | ]
67 | elif exp.opt_type == 'Random':
68 | overrides = [
69 | ("ctrl_cfg.opt_cfg.cfg.popsize", cem_popsize),
70 | ("ctrl_cfg.opt_cfg.plan_hor", plan_hor),
71 | ("ctrl_cfg.prop_cfg.model_train_cfg.epochs", model_train_epoch)
72 | ]
73 | else:
74 | print(exp.opt_type)
75 |             raise NotImplementedError("Given optimizer type %s is unknown" % exp.opt_type)
76 |
77 | cfg, cfg_module = exp.cfg_creator(overrides)
78 |
79 | def nn_constructor(model_init_cfg):
80 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap(
81 | name="model",
82 | num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"),
83 | sess=cfg_module.SESS, load_model=model_init_cfg.get("load_model", False),
84 | model_dir=model_init_cfg.get("model_dir", None)
85 | ))
86 | if not model_init_cfg.get("load_model", False):
87 | model.add(FC(
88 | hidden_layer_width, input_dim=cfg_module.MODEL_IN,
89 | activation=exp.activation_fns[act_idx], weight_decay=model_weight_decay
90 | ))
91 | for _ in range(num_hidden_layers - 1):
92 | model.add(FC(
93 | hidden_layer_width, activation=exp.activation_fns[act_idx], weight_decay=model_weight_decay
94 | ))
95 | model.add(FC(cfg_module.MODEL_OUT, weight_decay=model_weight_decay))
96 | model.finalize(exp.model_optimizers[model_opt_idx], {"learning_rate": model_learning_rate})
97 | return model
98 |
99 | cfg.ctrl_cfg.prop_cfg.model_init_cfg.model_constructor = nn_constructor
100 |
101 | # Build up model
102 | policy = MPC(cfg.ctrl_cfg)
103 |
104 | return policy
105 |
106 | return translate_into_model
--------------------------------------------------------------------------------
/dmbrl/modeling/utils/TensorStandardScaler.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import tensorflow as tf
6 | import numpy as np
7 |
8 |
9 | class TensorStandardScaler:
10 | """Helper class for automatically normalizing inputs into the network.
11 | """
12 | def __init__(self, x_dim):
13 | """Initializes a scaler.
14 |
15 | Arguments:
16 | x_dim (int): The dimensionality of the inputs into the scaler.
17 |
18 | Returns: None.
19 | """
20 | self.fitted = False
21 | with tf.variable_scope("Scaler"):
22 | self.mu = tf.get_variable(
23 | name="scaler_mu", shape=[1, x_dim], initializer=tf.constant_initializer(0.0),
24 | trainable=False
25 | )
26 | self.sigma = tf.get_variable(
27 | name="scaler_std", shape=[1, x_dim], initializer=tf.constant_initializer(1.0),
28 | trainable=False
29 | )
30 |
31 |         self.cached_mu, self.cached_sigma = np.zeros([1, x_dim]), np.ones([1, x_dim])
32 |
33 | def fit(self, data):
34 | """Runs two ops, one for assigning the mean of the data to the internal mean, and
35 | another for assigning the standard deviation of the data to the internal standard deviation.
36 |         This function must be called within a 'with sess.as_default():' block for the session that owns these variables.
37 |
38 | Arguments:
39 | data (np.ndarray): A numpy array containing the input
40 |
41 | Returns: None.
42 | """
43 | mu = np.mean(data, axis=0, keepdims=True)
44 | sigma = np.std(data, axis=0, keepdims=True)
45 | sigma[sigma < 1e-12] = 1.0
46 |
47 | self.mu.load(mu)
48 | self.sigma.load(sigma)
49 | self.fitted = True
50 | self.cache()
51 |
52 | def transform(self, data):
53 | """Transforms the input matrix data using the parameters of this scaler.
54 |
55 | Arguments:
56 | data (np.array): A numpy array containing the points to be transformed.
57 |
58 | Returns: (np.array) The transformed dataset.
59 | """
60 | return (data - self.mu) / self.sigma
61 |
62 | def inverse_transform(self, data):
63 | """Undoes the transformation performed by this scaler.
64 |
65 | Arguments:
66 | data (np.array): A numpy array containing the points to be transformed.
67 |
68 | Returns: (np.array) The transformed dataset.
69 | """
70 | return self.sigma * data + self.mu
71 |
72 | def get_vars(self):
73 | """Returns a list of variables managed by this object.
74 |
75 | Returns: (list) The list of variables.
76 | """
77 | return [self.mu, self.sigma]
78 |
79 | def set_vars(self, mu, sigma):
80 | """Set the mu and sigma according to the given value
81 | """
82 | self.mu.load(mu)
83 | self.sigma.load(sigma)
84 | self.cache()
85 |
86 | def cache(self):
87 | """Caches current values of this scaler.
88 |
89 | Returns: None.
90 | """
91 | self.cached_mu = self.mu.eval()
92 | self.cached_sigma = self.sigma.eval()
93 |
94 | def load_cache(self):
95 | """Loads values from the cache
96 |
97 | Returns: None.
98 | """
99 | self.mu.load(self.cached_mu)
100 | self.sigma.load(self.cached_sigma)
101 |
--------------------------------------------------------------------------------
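The scaler's mean and standard deviation live in TensorFlow variables, so fitting and transforming must happen under a default session. A short TF1-style sketch (matching the tensorflow==1.15 pin; data values are illustrative):

    import numpy as np
    import tensorflow as tf
    from dmbrl.modeling.utils import TensorStandardScaler

    data = np.random.normal(loc=3.0, scale=2.0, size=(500, 4)).astype(np.float32)

    with tf.Session() as sess:
        scaler = TensorStandardScaler(x_dim=4)
        sess.run(tf.global_variables_initializer())
        scaler.fit(data)  # loads the empirical mean/std into the graph variables
        normalized = sess.run(scaler.transform(data))
        print(normalized.mean(axis=0), normalized.std(axis=0))  # ~0 and ~1 per dim

--------------------------------------------------------------------------------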
/dmbrl/modeling/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .TensorStandardScaler import TensorStandardScaler
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: mbrl
2 | channels:
3 | - menpo
4 | - conda-forge
5 | - anaconda
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - attrs=19.3.0=py_0
10 | - backcall=0.1.0=py36_0
11 | - blas=1.0=mkl
12 | - bleach=3.1.0=py36_0
13 | - ca-certificates=2021.5.25=h06a4308_1
14 | - certifi=2021.5.30=py36h06a4308_0
15 | - cffi=1.12.3=py36h2e261b9_0
16 | - cudatoolkit=10.0.130=0
17 | - dbus=1.13.12=h746ee38_0
18 | - defusedxml=0.6.0=py_0
19 | - entrypoints=0.3=py36_0
20 | - expat=2.2.6=he6710b0_0
21 | - fontconfig=2.13.0=h9420a91_0
22 | - freetype=2.9.1=h8a8886c_1
23 | - gettext=0.19.8.1=hc5be6a0_1002
24 | - glfw3=3.2.1=0
25 | - glib=2.63.1=h5a9c865_0
26 | - gmp=6.1.2=h6c8ec71_1
27 | - gst-plugins-base=1.14.0=hbbd80ab_1
28 | - gstreamer=1.14.0=hb453b48_1
29 | - icu=58.2=h9c2bf20_1
30 | - importlib_metadata=1.5.0=py36_0
31 | - intel-openmp=2019.4=243
32 | - ipykernel=5.1.4=py36h39e3cac_0
33 | - ipython=7.12.0=py36h5ca1d4c_0
34 | - ipython_genutils=0.2.0=py36_0
35 | - ipywidgets=7.5.1=py_0
36 | - jedi=0.16.0=py36_0
37 | - jinja2=2.11.1=py_0
38 | - jpeg=9b=h024ee3a_2
39 | - jsonschema=3.2.0=py36_0
40 | - jupyter=1.0.0=py36_7
41 | - jupyter_client=5.3.4=py36_0
42 | - jupyter_console=6.1.0=py_0
43 | - jupyter_core=4.6.1=py36_0
44 | - libedit=3.1.20181209=hc058e9b_0
45 | - libffi=3.2.1=hd88cf55_4
46 | - libgcc-ng=9.1.0=hdf63c60_0
47 | - libgfortran-ng=7.3.0=hdf63c60_0
48 | - libgpg-error=1.36=he1b5a44_0
49 | - libpng=1.6.37=hbc83047_0
50 | - libsodium=1.0.16=h1bed415_0
51 | - libstdcxx-ng=9.1.0=hdf63c60_0
52 | - libtiff=4.0.10=h2733197_2
53 | - libuuid=1.0.3=h1bed415_2
54 | - libxcb=1.13=h1bed415_1
55 | - libxml2=2.9.9=hea5a465_1
56 | - llvm=3.3=0
57 | - markupsafe=1.1.1=py36h7b6447c_0
58 | - mesa=10.5.4=0
59 | - mistune=0.8.4=py36h7b6447c_0
60 | - mkl=2019.4=243
61 | - mkl-service=2.0.2=py36h7b6447c_0
62 | - mkl_fft=1.0.14=py36ha843d7b_0
63 | - mkl_random=1.0.2=py36hd81dba3_0
64 | - nbconvert=5.6.1=py36_0
65 | - nbformat=5.0.4=py_0
66 | - ncurses=6.1=he6710b0_1
67 | - ninja=1.9.0=py36hfd86e86_0
68 | - notebook=6.0.3=py36_0
69 | - numpy-base=1.19.1=py36hfa32c7d_0
70 | - olefile=0.46=py36_0
71 | - openssl=1.1.1k=h27cfd23_0
72 | - osmesa=12.2.2.dev=0
73 | - pandoc=2.2.3.2=0
74 | - pandocfilters=1.4.2=py36_1
75 | - parso=0.6.1=py_0
76 | - patchelf=0.9=he6710b0_3
77 | - pcre=8.43=he6710b0_0
78 | - pexpect=4.8.0=py36_0
79 | - pickleshare=0.7.5=py36_0
80 | - pillow=6.1.0=py36h34e0f95_0
81 | - pip=19.2.2=py36_0
82 | - prometheus_client=0.7.1=py_0
83 | - prompt_toolkit=3.0.3=py_0
84 | - ptyprocess=0.6.0=py36_0
85 | - pycparser=2.19=py36_0
86 | - pygments=2.5.2=py_0
87 | - pyqt=5.9.2=py36h05f1152_2
88 | - pyrsistent=0.15.7=py36h7b6447c_0
89 | - python=3.6.9=h265db76_0
90 | - pyzmq=18.1.1=py36he6710b0_0
91 | - qt=5.9.7=h5867ecd_1
92 | - qtconsole=4.6.0=py_1
93 | - readline=7.0=h7b6447c_5
94 | - send2trash=1.5.0=py36_0
95 | - setuptools=41.0.1=py36_0
96 | - sip=4.19.8=py36hf484d3e_0
97 | - six=1.12.0=py36_0
98 | - sqlite=3.29.0=h7b6447c_0
99 | - system=5.8=2
100 | - terminado=0.8.3=py36_0
101 | - testpath=0.4.4=py_0
102 | - tk=8.6.8=hbc83047_0
103 | - tornado=6.0.3=py36h7b6447c_3
104 | - traitlets=4.3.3=py36_0
105 | - wcwidth=0.1.8=py_0
106 | - webencodings=0.5.1=py36_1
107 | - wheel=0.33.4=py36_0
108 | - widgetsnbextension=3.5.1=py36_0
109 | - xz=5.2.4=h14c3975_4
110 | - zeromq=4.3.1=he6710b0_3
111 | - zipp=2.2.0=py_0
112 | - zlib=1.2.11=h7b6447c_3
113 | - zstd=1.3.7=h0b5b093_0
114 | - pip:
115 | - absl-py==0.9.0
116 | - astor==0.8.1
117 | - astunparse==1.6.3
118 | - cachetools==4.0.0
119 | - chardet==3.0.4
120 | - cloudpickle==1.2.1
121 | - configspace==0.4.10
122 | - cycler==0.10.0
123 | - cython==0.29.13
124 | - dataclasses==0.7
125 | - decorator==4.4.0
126 | - dm-env==1.2
127 | - dm-tree==0.1.5
128 | - dotmap==1.2.20
129 | - future==0.16.0
130 | - gast==0.2.2
131 | - glfw==1.8.3
132 | - google-auth==1.11.2
133 | - google-auth-oauthlib==0.4.1
134 | - google-pasta==0.2.0
135 | - googledrivedownloader==0.4
136 | - gpflow==1.1.0
137 | - grpcio==1.27.2
138 | - gym==0.14.0
139 | - h5py==2.10.0
140 | - hpbandster==0.7.4
141 | - idna==2.8
142 | - imageio==2.5.0
143 | - isodate==0.6.0
144 | - joblib==0.13.2
145 | - keras-applications==1.0.8
146 | - keras-preprocessing==1.1.2
147 | - kiwisolver==1.1.0
148 | - labmaze==1.0.3
149 | - lockfile==0.12.2
150 | - lxml==4.5.2
151 | - markdown==3.2.1
152 | - matplotlib==3.1.1
153 | - more-itertools==8.3.0
154 | - mujoco-py==2.0.2.5
155 | - multipledispatch==0.6.0
156 | - netifaces==0.10.9
157 | - networkx==2.3
158 | - numpy==1.17.4
159 | - oauthlib==3.1.0
160 | - omegaconf==2.0.0
161 | - opencv-python==4.4.0.42
162 | - opt-einsum==3.2.1
163 | - packaging==20.4
164 | - pandas==0.25.1
165 | - patsy==0.5.1
166 | - pluggy==0.13.1
167 | - plyfile==0.7
168 | - protobuf==3.10.0
169 | - py==1.8.1
170 | - pyasn1==0.4.8
171 | - pyasn1-modules==0.2.8
172 | - pyglet==1.3.2
173 | - pyopengl==3.1.5
174 | - pyparsing==2.4.2
175 | - pyro4==4.80
176 | - pytest==5.4.2
177 | - python-dateutil==2.8.0
178 | - pytz==2019.2
179 | - pyyaml==5.3.1
180 | - rdflib==4.2.2
181 | - requests==2.22.0
182 | - requests-oauthlib==1.3.0
183 | - rsa==4.0
184 | - scikit-learn==0.21.3
185 | - scipy==1.3.1
186 | - seaborn==0.11.0
187 | - serpent==1.30.2
188 | - statsmodels==0.11.1
189 | - tensorboard==1.15.0
190 | - tensorboard-plugin-wit==1.6.0.post3
191 | - tensorflow==1.15.0
192 | - tensorflow-estimator==1.15.1
193 | - termcolor==1.1.0
194 | - tqdm==4.19.4
195 | - typing==3.7.4.1
196 | - typing-extensions==3.7.4.2
197 | - urllib3==1.25.3
198 | - werkzeug==1.0.0
199 | - wrapt==1.12.1
200 | prefix: /home/zhangb/anaconda3/envs/mbrl
201 |
--------------------------------------------------------------------------------
/learned_schedule/reacher_rs_model_train.json:
--------------------------------------------------------------------------------
1 | [{"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, 
{"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, 
{"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}]
--------------------------------------------------------------------------------
/mbexp.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 | import argparse
7 | import pprint
8 |
9 | from dotmap import DotMap
10 |
11 | from dmbrl.misc.MBExp import MBExperiment
12 | from dmbrl.controllers.MPC import MPC
13 | from dmbrl.config import create_config
14 |
15 |
16 | def main(env, ctrl_type, ctrl_args, overrides, logdir):
17 | ctrl_args = DotMap(**{key: val for (key, val) in ctrl_args})
18 | cfg = create_config(env, ctrl_type, ctrl_args, overrides, logdir)[0]
19 | cfg.pprint()
20 |
21 | if ctrl_type == "MPC":
22 | cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)
23 | exp = MBExperiment(cfg.exp_cfg)
24 |
25 | os.makedirs(exp.logdir)
26 | with open(os.path.join(exp.logdir, "config.txt"), "w") as f:
27 | f.write(pprint.pformat(cfg.toDict()))
28 |
29 | exp.run_experiment()
30 |
31 |
32 | if __name__ == "__main__":
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument('-env', type=str, required=True,
35 | help='Environment name: select from [cartpole, reacher, pusher, halfcheetah]')
36 | parser.add_argument('-ca', '--ctrl_arg', action='append', nargs=2, default=[],
37 | help='Controller arguments, see https://github.com/kchua/handful-of-trials#controller-arguments')
38 | parser.add_argument('-o', '--override', action='append', nargs=2, default=[],
39 | help='Override default parameters, see https://github.com/kchua/handful-of-trials#overrides')
40 | parser.add_argument('-logdir', type=str, default='log',
41 | help='Directory to which results will be logged (default: ./log)')
42 | args = parser.parse_args()
43 |
44 | main(args.env, "MPC", args.ctrl_arg, args.override, args.logdir)
45 |
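A minimal invocation sketch using only the flags defined above (environment name and log directory are illustrative):

    python mbexp.py -env cartpole -logdir log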
--------------------------------------------------------------------------------
/pbt-bt-mbexp.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import argparse
6 |
7 | from dotmap import DotMap
8 |
9 | from dmbrl.misc.MBwPBTBTExp import MBWithPBTBTExperiment
10 | from dmbrl.controllers.MPC import MPC
11 | from dmbrl.config import create_config
12 |
13 | import tensorflow as tf
14 | import numpy as np
15 |
16 | def create_cfg_creator(env, ctrl_type, ctrl_args, base_overrides, logdir):
17 | ctrl_args = DotMap(**{key: val for (key, val) in ctrl_args})
18 |
19 | def cfg_creator(additional_overrides=None):
20 | if additional_overrides is not None:
21 | return create_config(env, ctrl_type, ctrl_args, base_overrides + additional_overrides, logdir)
22 | return create_config(env, ctrl_type, ctrl_args, base_overrides, logdir)
23 |
24 | return cfg_creator
25 |
26 |
27 | def main(args):
28 | cfg_creator = create_cfg_creator(args.env, args.ctrl_type, args.ctrl_arg, args.override, args.logdir)
29 | cfg = cfg_creator()[0]
30 | cfg.pprint()
31 |
32 | if args.ctrl_type == "MPC":
33 | cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)
34 |
35 | exp = MBWithPBTBTExperiment(cfg.exp_cfg, cfg_creator, args)
36 | exp.run_experiment()
37 |
38 |
39 | if __name__ == "__main__":
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument('-env', type=str, required=True,
42 | help='Environment name: select from [cartpole, reacher, pusher, halfcheetah, halfcheetah_v3]')
43 | parser.add_argument('-ca', '--ctrl_arg', action='append', nargs=2, default=[],
44 | help='Controller arguments, see https://github.com/kchua/handful-of-trials#controller-arguments')
45 | parser.add_argument('-o', '--override', action='append', nargs=2, default=[],
46 | help='Override default parameters, see https://github.com/kchua/handful-of-trials#overrides')
47 | parser.add_argument('-logdir', type=str, default='log',
48 | help='Directory to which results will be logged (default: ./log)')
49 | parser.add_argument('-ctrl_type', type=str, default='MPC',
50 | help='Control type to be applied (default: MPC)')
51 | # Arguments for running the dynamic scheduler on a cluster
52 | parser.add_argument('-config_names', type=str, default=["model_learning_rate"], nargs="+",
53 | help='Specify which hyperparameters to optimize')
54 | parser.add_argument('-seed', type=int, default=0,
55 | help='Specify the random seed of the experiment')
56 | parser.add_argument('-worker_id', type=int, default=0,
57 | help='The worker id, e.g. using SLURM ARRAY JOB ID')
58 | parser.add_argument('-worker', action='store_true',
59 | help='Flag to turn this into a worker process; otherwise this will start a new controller')
60 | parser.add_argument('-sample_from_percent', type=float, default=0.2,
61 | help='Fraction of top members from which exploitation donors are sampled')
62 | parser.add_argument('-resample_if_not_in_percent', type=float, default=0.8,
63 | help='Resample if the member is not within this top fraction of the population')
64 | parser.add_argument('-resample_probability', type=float, default=0.25,
65 | help='Probability of an exploited member resampling configurations randomly')
66 | parser.add_argument('-resample_prob_decay', type=float, default=1,
67 | help='Decay factor applied to resample_if_not_in_percent')
68 | parser.add_argument('-max_steps', type=int, default=60*40,
69 | help='Maximum number of steps to take')
70 | parser.add_argument('-delta_t', type=int, default=30,
71 | help='Perform one backtracking check every delta_t steps')
72 | parser.add_argument('-tolerance', type=float, default=0.2,
73 | help='Tolerated fractional performance drop before backtracking')
74 | parser.add_argument('-elite_ratio', type=float, default=0.125,
75 | help='Fraction of the population kept as elites')
76 |
77 | args = parser.parse_args()
78 | print(args)
79 | # Set the random seeds of the experiment
80 | tf.set_random_seed(args.seed)
81 | np.random.seed(args.seed)
82 | main(args)
83 |
--------------------------------------------------------------------------------
/pbt-mbexp.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import argparse
6 |
7 | from dotmap import DotMap
8 |
9 | from dmbrl.misc.MBwPBTExp import MBWithPBTExperiment
10 | from dmbrl.controllers.MPC import MPC
11 | from dmbrl.config import create_config
12 |
13 | import tensorflow as tf
14 | import numpy as np
15 |
16 | def create_cfg_creator(env, ctrl_type, ctrl_args, base_overrides, logdir):
17 | ctrl_args = DotMap(**{key: val for (key, val) in ctrl_args})
18 |
19 | def cfg_creator(additional_overrides=None):
20 | if additional_overrides is not None:
21 | return create_config(env, ctrl_type, ctrl_args, base_overrides + additional_overrides, logdir)
22 | return create_config(env, ctrl_type, ctrl_args, base_overrides, logdir)
23 |
24 | return cfg_creator
25 |
26 |
27 | def main(args):
28 | cfg_creator = create_cfg_creator(args.env, args.ctrl_type, args.ctrl_arg, args.override, args.logdir)
29 | cfg = cfg_creator()[0]
30 | cfg.pprint()
31 |
32 | if args.ctrl_type == "MPC":
33 | cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)
34 |
35 | exp = MBWithPBTExperiment(cfg.exp_cfg, cfg_creator, args)
36 | exp.run_experiment()
37 |
38 |
39 | if __name__ == "__main__":
40 | print("Exp start!")
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument('-env', type=str, required=True,
43 | help='Environment name: select from [cartpole, reacher, pusher, halfcheetah, halfcheetah_v3]')
44 | parser.add_argument('-ca', '--ctrl_arg', action='append', nargs=2, default=[],
45 | help='Controller arguments, see https://github.com/kchua/handful-of-trials#controller-arguments')
46 | parser.add_argument('-o', '--override', action='append', nargs=2, default=[],
47 | help='Override default parameters, see https://github.com/kchua/handful-of-trials#overrides')
48 | parser.add_argument('-logdir', type=str, default='log',
49 | help='Directory to which results will be logged (default: ./log)')
50 | parser.add_argument('-ctrl_type', type=str, default='MPC',
51 | help='Control type to be applied (default: MPC)')
52 | # Arguments for running the dynamic scheduler on a cluster
53 | parser.add_argument('-config_names', type=str, default=["model_learning_rate"], nargs="+",
54 | help='Specify which hyperparameters to optimize')
55 | parser.add_argument('-seed', type=int, default=0,
56 | help='Specify the random seed of the experiment')
57 | parser.add_argument('-worker_id', type=int, default=0,
58 | help='The worker id, e.g. using SLURM ARRAY JOB ID')
59 | parser.add_argument('-worker', action='store_true',
60 | help='Flag to turn this into a worker process; otherwise this will start a new controller')
61 | parser.add_argument('-sample_from_percent', type=float, default=0.2,
62 | help='Fraction of top members from which exploitation donors are sampled')
63 | parser.add_argument('-resample_if_not_in_percent', type=float, default=0.8,
64 | help='Resample if the member is not within this top fraction of the population')
65 | parser.add_argument('-resample_probability', type=float, default=0.25,
66 | help='Probability of an exploited member resampling configurations randomly')
67 | parser.add_argument('-resample_prob_decay', type=float, default=1,
68 | help='Decay factor applied to resample_if_not_in_percent')
69 | parser.add_argument('-not_copy_data', action='store_true',
70 | help='Do not copy collected data to newly exploited trials')
71 |
72 | args = parser.parse_args()
73 | print(args)
74 | # Set the random seeds of the experiment
75 | tf.set_random_seed(args.seed)
76 | np.random.seed(args.seed)
77 | main(args)
78 |
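A launch sketch using only the flags defined above: one process acts as the controller, and each additional process is started with -worker (environment, seed, and worker id are illustrative):

    # controller
    python pbt-mbexp.py -env halfcheetah -seed 0 -logdir log
    # one worker per job, e.g. per SLURM array task
    python pbt-mbexp.py -env halfcheetah -seed 0 -logdir log -worker -worker_id 1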
--------------------------------------------------------------------------------
/pbt/__init__.py:
--------------------------------------------------------------------------------
1 | from pbt.controller import PBTController
2 | from pbt.worker import PBTWorker
3 |
4 | __all__ = ['PBTController', 'PBTWorker']
5 |
--------------------------------------------------------------------------------
/pbt/backtrack_controller.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import random
4 | import logging
5 | import os
6 | import sys
7 | from time import sleep
8 |
9 | from pbt.exploitation import Truncation
10 | from pbt.exploration import Perturb
11 | from pbt.garbage_collector import GarbageCollector
12 | from pbt.network import ControllerDaemon
13 | from pbt.population import Population
14 | from pbt.backtrack_scheduler import BacktrackScheduler
15 | from pbt.tqdm_logger import TqdmLoggingHandler
16 |
17 | class PBTwBT_Controller:
18 | def __init__(
19 | self, pop_size, start_hyperparameters, exploitation=Truncation(),
20 | exploration=Perturb(), ready=lambda: True,
21 | exploration_bt=lambda x: x,
22 | stop=lambda iterations, _: iterations >= 100.0,
23 | data_path=os.getcwd(), results_path=os.getcwd(),
24 | max_steps=sys.maxsize, delta_t=30, tolerance=0.2, elite_ratio=0.125):
25 |
26 | self.logger = logging.getLogger('pbt')
27 | self.logger.setLevel(logging.DEBUG)
28 | self.logger.addHandler(TqdmLoggingHandler())
29 | self.daemon = None
30 | self.population = Population(pop_size, stop, results_path, backtracking=True, elite_ratio=elite_ratio)
31 | self.data_path = data_path
32 | self.scheduler = BacktrackScheduler(
33 | self.population, start_hyperparameters, exploitation, exploration,
34 | exploration_bt, delta_t, tolerance)
35 | self.garbage_collector = GarbageCollector(data_path)
36 | self.ready = ready
37 | self.max_steps = max_steps
38 | self.total_steps = 0
39 |
40 | self.workers = {}
41 | self._done = False
42 |
43 | self.logger.info(
44 | f'Started controller with parameters: ' +
45 | f'population size: {pop_size}, '
46 | f'data_path: {data_path}, '
47 | f'results_path: {results_path}')
48 |
49 | def start_daemon(self):
50 | self.logger.info('Starting daemon.')
51 | self.daemon = ControllerDaemon(self)
52 | self.daemon.start()
53 |
54 | def register_worker(self, worker):
55 | self.logger.debug(f'Worker {worker.worker_id} registered.')
56 | self.workers[worker.worker_id] = worker
57 |
58 | def request_trial(self):
59 | trial = self.scheduler.get_trial()
60 | return trial.to_tuple()
61 |
62 | def send_evaluation(self, member_id, score):
63 | self.logger.debug(
64 | f'Receiving evaluation for member {member_id}: {score}')
65 | # TODO: Some sort of ready function?
66 | min_timestep = self.population.get_min_time_step()
67 | # exclude all elite members' checkpoints from collection (PBT-BT only)
68 | elites = self.population.get_elites()
69 | # self.garbage_collector.collect(member_id, min_timestep, elites)
70 | trial = self.population.update(member_id, score)
71 | self.scheduler.update_exploration(trial)
72 | self.total_steps += 1
73 | self.logger.debug(
74 | f'Current total step is: {self.total_steps}')
75 | if self.population.is_done():
76 | self.logger.info('Nothing more to do. Shutting down.')
77 | self._shut_down_workers()
78 | sleep(10)
79 | if self.daemon:
80 | self.daemon.shut_down()
81 |
82 | if self.total_steps >= self.max_steps:
83 | self.logger.info('Reached the maximum number of steps. Shutting down.')
84 | self._shut_down_workers()
85 | sleep(10)
86 | if self.daemon:
87 | self.daemon.shut_down()
88 |
89 | def _shut_down_workers(self):
90 | for worker in self.workers.values():
91 | worker.stop()
92 |
--------------------------------------------------------------------------------
/pbt/backtrack_scheduler.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from pbt.population import NoTrial, Trial
4 | from pbt.tqdm_logger import TqdmLoggingHandler
5 | import random
6 | from itertools import product
7 |
8 | class BacktrackScheduler:
9 | def __init__(
10 | self, population, start_hyperparameters, exploitation, exploration, exploration_bt, delta_t=30, tolerance=0.2):
11 | """Initialize a scheduler with backtracking mechanism
12 | """
13 | self.logger = logging.getLogger('pbt')
14 | self.logger.setLevel(logging.DEBUG)
15 | self.population = population
16 | self.exploitation = exploitation
17 | self.exploration = exploration
18 | self.exploration_bt = exploration_bt
19 |
20 | self.start_hyperparameters = start_hyperparameters
21 | self.delta_t = delta_t
22 | self.tolerance = tolerance
23 | self.record = {} # For tabu search
24 |
25 | def get_trial(self):
26 | self.logger.debug('Trial requested.')
27 | member = self.population.get_next_member()
28 | if not member:
29 | self.logger.debug('No trial ready.')
30 | return NoTrial()
31 |
32 | # initial run with the starting hyperparameters
33 | if member.time_step == 0:
34 | start_hyperparameters = {cfg_name: self.start_hyperparameters[cfg_name]()
35 | for cfg_name in self.start_hyperparameters.keys()}
36 | trial = Trial(
37 | member.member_id, -1, 0, -1, start_hyperparameters)
38 | self.population.save_trial(trial)
39 | self.logger.debug(f'Returning first trial {trial}.')
40 | return trial
41 |
42 | if member.time_step % self.delta_t == 0:
43 | # start to check if it drops by X percentage
44 | self.logger.debug(f'Generating trial for member {member.member_id} at time step {member.time_step} with backtracking.')
45 | self.logger.debug(f'Member {member.member_id} actual time step is {member._actual_time_step}.')
46 | elites = self.population.get_elites()
47 | model_id, model_time_step = self.backtracking_exploitation(member, elites)
48 | hyperparameters = self.population.get_hyperparameters_by_time_step(model_id, model_time_step)
49 | if model_id == member.member_id and model_time_step == member.time_step - 1:
50 | self.logger.debug(f'Staying with current model {model_id}.')
51 | else:
52 | self.logger.debug(f'Backtracking to model {model_id}, time step {model_time_step}.')
53 | member.set_actual_time_step(model_time_step)
54 | self.logger.debug(f'Set model {model_id} actual time step to {model_time_step}.')
55 | # TODO: Replace it with backtracking exploration
56 | hyperparameters = self.exploration_bt(hyperparameters, model_id, model_time_step)
57 | self.logger.debug(f'Using exploration. New: {hyperparameters}')
58 | trial = Trial(
59 | member.member_id, model_id, member.time_step, model_time_step,
60 | hyperparameters)
61 | self.population.save_trial(trial)
62 | return trial
63 |
64 | # Jointly do standard PBT with elites
65 | self.logger.debug(f'Generating trial for member {member.member_id} at time step {member.time_step}.')
66 | self.logger.debug(f'Member {member.member_id} actual time step is {member._actual_time_step}.')
67 | scores = self.population.get_scores()
68 | # self.logger.debug(f'Collected all scoires.')
69 | # model_id indicates the model that we want to copy
70 | model_id = self.exploitation(member.member_id, scores)
71 | # model_time_step shows the timestep of the model to be copied
72 | model_time_step = self.population.get_latest_time_step(model_id) - 1
73 | # actual_time_step = self.population.get_actual_time_step_by_member_id(model_id)
74 | # self.logger.debug(f'Safely get timestep')
75 | hyperparameters = self.population.get_hyperparameters_by_time_step(model_id, model_time_step)
76 | # self.logger.debug(f'Safely get hyperparameters')
77 | if model_id != member.member_id:
78 | self.logger.debug(f'Copying model {model_id} at time step {model_time_step}.')
79 | member.set_actual_time_step(model_time_step)
80 | self.logger.debug(f'Set model {model_id} actual time step to {model_time_step}.')
81 | hyperparameters = self.exploration(hyperparameters)
82 | self.logger.debug(f'Using exploration. New: {hyperparameters}')
83 | else:
84 | self.logger.debug(f'Staying with current model {model_id}.')
85 | trial = Trial(
86 | member.member_id, model_id, member.time_step, model_time_step,
87 | hyperparameters)
88 | self.population.save_trial(trial)
89 | # self.logger.debug(f'Safely saved trial')
90 | return trial
91 |
92 | def update_exploration(self, trial):
93 | self.exploration.update(trial)
94 |
95 | # For PBT-BT only
96 | def backtracking_exploitation(self, member, elites):
97 | # TODO: handle zero division
98 | percentage_change = (member.get_last_score() - elites[-1].score) / abs(elites[-1].score)
99 | self.logger.debug(f'Performance changed {percentage_change}.')
100 | if percentage_change < - self.tolerance:
101 | trial = random.choice(elites)
102 | else:
103 | trial = member.get_last_trial()
104 |
105 | return trial.member_id, trial.time_step
106 |
107 |
108 |
109 |
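A worked example of the tolerance check in backtracking_exploitation, with illustrative numbers: if the member's last score is 6.0 and the lowest-scoring elite has score 10.0, then percentage_change = (6.0 - 10.0) / abs(10.0) = -0.4. With tolerance = 0.2, the condition -0.4 < -0.2 holds, so the member backtracks to a randomly chosen elite checkpoint; otherwise it would continue from its own last trial.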
--------------------------------------------------------------------------------
/pbt/controller.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import random
4 | import logging
5 | import os
6 | import sys
7 | from time import sleep
8 |
9 | from pbt.exploitation import Truncation
10 | from pbt.exploration import Perturb
11 | from pbt.garbage_collector import GarbageCollector
12 | from pbt.network import ControllerDaemon
13 | from pbt.population import Population
14 | from pbt.scheduler import Scheduler
15 | from pbt.tqdm_logger import TqdmLoggingHandler
16 |
17 | class PBTController:
18 | def __init__(
19 | self, pop_size, start_hyperparameters, exploitation=Truncation(),
20 | exploration=Perturb(), ready=lambda: True,
21 | stop=lambda iterations, _: iterations >= 100.0,
22 | data_path=os.getcwd(), results_path=os.getcwd(),
23 | max_steps=sys.maxsize):
24 |
25 | self.logger = logging.getLogger('pbt')
26 | self.logger.setLevel(logging.DEBUG)
27 | self.logger.addHandler(TqdmLoggingHandler())
28 | self.daemon = None
29 | self.population = Population(pop_size, stop, results_path)
30 | self.data_path = data_path
31 | self.scheduler = Scheduler(
32 | self.population, start_hyperparameters, exploitation, exploration)
33 | self.garbage_collector = GarbageCollector(data_path)
34 | self.ready = ready
35 | self.max_steps = max_steps
36 | self.total_steps = 0
37 |
38 | self.workers = {}
39 | self._done = False
40 |
41 | self.logger.info(
42 | f'Started controller with parameters: ' +
43 | f'population size: {pop_size}, '
44 | f'data_path: {data_path}, '
45 | f'results_path: {results_path}')
46 |
47 | def start_daemon(self):
48 | self.logger.info('Starting daemon.')
49 | self.daemon = ControllerDaemon(self)
50 | self.daemon.start()
51 |
52 | def register_worker(self, worker):
53 | self.logger.debug(f'Worker {worker.worker_id} registered.')
54 | self.workers[worker.worker_id] = worker
55 |
56 | def request_trial(self):
57 | trial = self.scheduler.get_trial()
58 | return trial.to_tuple()
59 |
60 | def send_evaluation(self, member_id, score):
61 | self.logger.debug(
62 | f'Receiving evaluation for member {member_id}: {score}')
63 | # TODO: Some sort of ready function?
64 | min_timestep = self.population.get_min_time_step()
65 | # plain PBT protects no elite checkpoints (elites are PBT-BT only)
66 | elites = []
67 | self.logger.info(f'min time step is {min_timestep}')
68 | self.garbage_collector.collect(member_id, min_timestep, elites)
69 | trial = self.population.update(member_id, score)
70 | self.scheduler.update_exploration(trial)
71 | self.total_steps += 1
72 | if self.population.is_done():
73 | self.logger.info('Nothing more to do. Shutting down.')
74 | self._shut_down_workers()
75 | sleep(10)
76 | if self.daemon:
77 | self.daemon.shut_down()
78 |
79 | if self.total_steps >= self.max_steps:
80 | self.logger.info('Reached the maximum number of steps. Shutting down.')
81 | self._shut_down_workers()
82 | sleep(10)
83 | if self.daemon:
84 | self.daemon.shut_down()
85 |
86 | def _shut_down_workers(self):
87 | for worker in self.workers.values():
88 | worker.stop()
89 |
--------------------------------------------------------------------------------
/pbt/exploitation/__init__.py:
--------------------------------------------------------------------------------
1 | from .exploitation_strategy import ExploitationStrategy
2 | from .truncation import Truncation
3 |
4 | __all__ = ['ExploitationStrategy', 'Truncation']
5 |
--------------------------------------------------------------------------------
/pbt/exploitation/constant.py:
--------------------------------------------------------------------------------
1 | class Constant:
2 | def __init__(self):
3 | """
4 | Constant mechanism which doesn't exploit any member
5 | """
6 | pass
7 |
8 | def __call__(self, own_name: str, scores: dict) -> str:
9 | """
10 | Return the name of the given member unchanged (no exploitation)
11 | :param own_name: The name of the current agent
12 | :param scores: A dict with names and scores of all agents
13 | :return: The name of the current agent itself
14 | """
15 | return own_name
--------------------------------------------------------------------------------
/pbt/exploitation/exploitation_strategy.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class ExploitationStrategy:
5 | @abc.abstractmethod
6 | def __call__(self, own_name: str, scores: dict) -> str:
7 | """
8 | This method should implement the exploitation behaviour.
9 | :param own_name: The name of the current agent
10 | :param scores: The names and scores of all agents
11 | :return: The name of the agent to copy
12 | """
13 | raise NotImplementedError('This method has to be overwritten!')
14 |
--------------------------------------------------------------------------------
/pbt/exploitation/truncation.py:
--------------------------------------------------------------------------------
1 | import math
2 | import operator
3 | import random
4 |
5 |
6 | class Truncation:
7 | def __init__(self, sample_from_percent=0.2, resample_if_not_in_percent=0.8):
8 | """
9 | A simple truncation mechanism as specified in (Jaderberg et al., 2017)
10 | :param sample_from_percent: Percent of best part of the population
11 | :param resample_if_not_in_percent: Percent of untouched agents
12 | """
13 | self.sample_from_percent = sample_from_percent
14 | self.resample_if_not_in_percent = resample_if_not_in_percent
15 |
16 | def __call__(self, own_name: str, scores: dict) -> str:
17 | """
18 | Find a better agent or return own_name, if the agent is good enough.
19 | :param own_name: The name of the current agent
20 | :param scores: A dict with names and scores of all agents
21 | :return: The name of the chosen better agent
22 | """
23 | if len(scores) == 1:
24 | return own_name
25 | if own_name in self._get_best(self.resample_if_not_in_percent, scores):
26 | return own_name
27 | else:
28 | return random.choice(list(
29 | self._get_best(self.sample_from_percent, scores).keys()))
30 |
31 | def _get_best(self, percent, scores):
32 | sorted_scores = sorted(
33 | scores.items(), reverse=True, key=operator.itemgetter(1))
34 | last_index = math.ceil((len(scores) - 1) * percent)
35 | return {name: score for name, score in sorted_scores[:last_index]}
36 |
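A minimal usage sketch with a hypothetical five-member population (member ids as names, higher scores are better):

    strategy = Truncation(sample_from_percent=0.2, resample_if_not_in_percent=0.8)
    scores = {0: 5.0, 1: 9.0, 2: 1.0, 3: 7.0, 4: 3.0}
    donor = strategy(2, scores)  # member 2 is outside the top 80%, so the donor comes from the top 20%: member 1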
--------------------------------------------------------------------------------
/pbt/exploration/__init__.py:
--------------------------------------------------------------------------------
1 | from .exploration_strategy import ExplorationStrategy
2 | from .perturb import Perturb
3 | from .resample import Resample
4 | from .perturb_and_resample import PerturbAndResample
5 | from pbt.exploration.models.tree_parzen_estimator import TreeParzenEstimator
6 | from .exploration_bt import Exploration_BT
7 | __all__ = [
8 | 'ExplorationStrategy', 'Perturb', 'Resample', 'PerturbAndResample',
9 | 'TreeParzenEstimator', 'Exploration_BT']
10 |
--------------------------------------------------------------------------------
/pbt/exploration/constant.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class Constant:
4 | """
5 | A simple constant mechanism which returns the same configuration as given.
6 | """
7 | def __init__(self):
8 | pass
9 |
10 | def __call__(self, hyperparameters: dict) -> dict:
11 | """
12 | Return an unchanged copy of the input.
13 | :param hyperparameters: A dict with nodes.
14 | :return: An identical copy of the nodes.
15 | """
16 | result = hyperparameters.copy()
17 | return result
18 |
19 |
--------------------------------------------------------------------------------
/pbt/exploration/exploration_bt.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from pbt.exploration import Resample
4 | import itertools
5 |
6 | class Exploration_BT():
7 | def __init__(
8 | self, mutations: dict, cs_space: dict, resample_probability: float = 0.25,
9 | boundaries={}):
10 | """
11 | A backtracking exploration strategy: tabu search over perturbation directions, with random resampling as a fallback.
12 | :param mutations: A dictionary with hyperparameter names and mutations
13 | :param resample_probability: The probability to resample for each call
14 | """
15 | self.resample_probability = resample_probability
16 | self.resample = Resample(mutations=mutations)
17 | self.records = {}
18 | self.cs_space = cs_space
19 | self.boundaries = boundaries
20 |
21 | def __call__(self, hyperparameters: dict, model_id: int, model_time_step: int) -> dict:
22 | """
23 | Tabu search over all possible perturbation directions. If all possibilities have been tried, sample randomly
24 | from the configuration space.
25 | :param hyperparameters: The nodes to perturb
26 | :return: Perturbed and probably resampled nodes
27 | """
28 | result = hyperparameters.copy()
29 | num_hyperparameters = len(result)
30 | if (model_id, model_time_step) not in self.records:
31 | self.records[(model_id, model_time_step)] = list(itertools.product([-1,1], repeat=num_hyperparameters))
32 | random.shuffle(self.records[(model_id, model_time_step)])
33 |
34 | if random.random() < self.resample_probability or len(self.records[(model_id, model_time_step)]) == 0:
35 | result = self.resample(result)
36 | else:
37 | result = self.perturb(result, model_id, model_time_step)
38 |
39 | return result
40 |
41 | def perturb(self, hyperparameters: dict, model_id: int, model_time_step: int) -> dict:
42 | directions = self.records[(model_id, model_time_step)].pop()
43 | for i, key in enumerate(sorted(hyperparameters)):
44 | temp_value = self.cs_space[key]._inverse_transform(hyperparameters[key])
45 | temp_value += directions[i] * 0.2 * temp_value
46 | hyperparameters[key] = self.cs_space[key]._transform(temp_value)
47 | self.ensure_boundaries(hyperparameters)
48 | return hyperparameters
49 |
50 | def ensure_boundaries(self, result):
51 | for key in result:
52 | if key not in self.boundaries:
53 | continue
54 | if result[key] < self.boundaries[key][0]:
55 | result[key] = self.boundaries[key][0]
56 | elif result[key] > self.boundaries[key][1]:
57 | result[key] = self.boundaries[key][1]
58 |
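The tabu record enumerates each sign combination exactly once: for two hyperparameters, itertools.product([-1, 1], repeat=2) yields (-1, -1), (-1, 1), (1, -1), (1, 1). Each perturb call pops one untried combination and moves every hyperparameter by 20% in its (inverse-)transformed space along the popped direction; once the list for a (model_id, model_time_step) pair is exhausted, __call__ falls back to random resampling.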
--------------------------------------------------------------------------------
/pbt/exploration/exploration_strategy.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 |
4 | class ExplorationStrategy:
5 | @abstractmethod
6 | def __call__(self, hyperparameters):
7 | """
8 | This method should implement the exploration behaviour.
9 | :param hyperparameters: The nodes to explore
10 | :return: The changed nodes
11 | """
12 | raise NotImplementedError('This method has to be overwritten!')
13 |
14 | def update(self, trial):
15 | pass
16 |
--------------------------------------------------------------------------------
/pbt/exploration/model_based.py:
--------------------------------------------------------------------------------
1 | from pbt.exploration import ExplorationStrategy
2 |
3 |
4 | class ModelBased(ExplorationStrategy):
5 | def __init__(self, model):
6 | self.model = model
7 |
8 | def __call__(self, hyperparameters):
9 | return self.model.sample()
10 |
11 | def update(self, trial):
12 | self.model.update(trial.to_array())
13 |
--------------------------------------------------------------------------------
/pbt/exploration/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Model
2 | from .tree_parzen_estimator import TreeParzenEstimator
3 |
4 | __all__ = ['Model', 'TreeParzenEstimator']
5 |
--------------------------------------------------------------------------------
/pbt/exploration/models/config_tree/__init__.py:
--------------------------------------------------------------------------------
1 | from .config_tree import ConfigTree
2 |
3 | __all__ = ['ConfigTree']
4 |
--------------------------------------------------------------------------------
/pbt/exploration/models/config_tree/config_tree.py:
--------------------------------------------------------------------------------
1 | class ConfigTree:
2 | def __init__(self, root):
3 | self.root = root
4 | self.all_nodes = self._get_all_nodes()
5 | self.hyperparameter_names = self._get_hyperparameter_names()
6 | self._distribute_indices()
7 |
8 | def sample(self):
9 | result = self._sample_from_root_nodes()
10 | for hyperparameter in self.hyperparameter_names:
11 | if hyperparameter not in result:
12 | result[hyperparameter] = None
13 | return result
14 |
15 | def uniform_sample(self):
16 | result = {}
17 | for node in self.all_nodes:
18 | node.uniform_sample(result)
19 | return result
20 |
21 | def evaluate(self, data):
22 | scores = [0.0 for _ in data]
23 | for node in self.root:
24 | node.evaluate(data, scores)
25 | return scores
26 |
27 | def fit(self, data):
28 | for node in self.all_nodes:
29 | node.fit([point[node.index] for point in data])
30 |
31 | def structural_copy(self):
32 | return ConfigTree([node.structural_copy() for node in self.root])
33 |
34 | def _distribute_indices(self):
35 | for i, node in enumerate(self.all_nodes):
36 | node.index = i + 3 # Places 0, 1, and 2 hold score, time_step, and improvement
37 |
38 | def _sample_from_root_nodes(self):
39 | result = {}
40 | for node in self.root:
41 | node.sample(result)
42 | return result
43 |
44 | def _get_all_nodes(self):
45 | all_nodes = []
46 | for node in self.root:
47 | all_nodes.append(node)
48 | all_nodes += node.get_children()
49 | return sorted(all_nodes, key=lambda n: n.name)
50 |
51 | def _get_hyperparameter_names(self):
52 | return [node.name for node in self.all_nodes]
53 |
--------------------------------------------------------------------------------
/pbt/exploration/models/config_tree/nodes/__init__.py:
--------------------------------------------------------------------------------
1 | from .node import Node
2 | from .categorical import Categorical
3 | from .float import Float
4 | from .log_float import LogFloat
5 | from .integer import Integer
6 |
7 | __all__ = ['Node', 'Categorical', 'Float', 'LogFloat', 'Integer']
8 |
--------------------------------------------------------------------------------
/pbt/exploration/models/config_tree/nodes/categorical.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 |
3 | import numpy as np
4 |
5 | from pbt.exploration.models.config_tree.nodes import Node
6 |
7 |
8 | class Categorical(Node):
9 | def __init__(self, name, values):
10 | self.name = name
11 | self.keys = list(values.keys())
12 | self.values = values
13 | self.probabilities = [1.0/len(self.keys) for _ in self.keys]
14 |
15 | def sample(self, result):
16 | value = str(np.random.choice(self.keys, p=self.probabilities))
17 | result[self.name] = value
18 | if self.values[value] is not None:
19 | self.values[value].sample(result)
20 |
21 | def uniform_sample(self, result):
22 | result[self.name] = str(np.random.choice(self.keys))
23 |
24 | def evaluate(self, data, scores):
25 | single_scores = [
26 | self._get_log_density(point[self.name]) for point in data]
27 | for i, single_score in enumerate(single_scores):
28 | scores[i] += single_score
29 |
30 | def _get_log_density(self, value):
31 | return np.log(self.probabilities[self.keys.index(value)])
32 |
33 | def fit(self, values):
34 | counter = Counter(values)
35 | self.probabilities = [
36 | counter[key]/len(values) for key in self.keys]
37 | self._add_random_exploration()
38 |
39 | def _add_random_exploration(self, random_probability=0.1):
40 | factor = 1.0 - random_probability
41 | addend = random_probability / len(self.probabilities)
42 |
43 | for i, _ in enumerate(self.probabilities):
44 | self.probabilities[i] *= factor
45 | self.probabilities[i] += addend
46 |
47 | def get_children(self):
48 | result = []
49 | for child in self.values.values():
50 | if child is not None:
51 | result.append(child)
52 | result += child.get_children()
53 | return result
54 |
55 | def structural_copy(self):
56 | values = {
57 | name: node.structural_copy() if node is not None else None
58 | for name, node in self.values.items()}
59 | return Categorical(self.name, values)
60 |
--------------------------------------------------------------------------------
/pbt/exploration/models/config_tree/nodes/float.py:
--------------------------------------------------------------------------------
1 | from bokeh.io import show # bokeh is only needed for the plotting demo under __main__ below
2 | from bokeh.plotting import Figure
3 |
4 | import numpy as np
5 | from sklearn.neighbors import KernelDensity
6 |
7 | from pbt.exploration.models.config_tree.nodes import Node
8 |
9 |
10 | class Float(Node):
11 | def __init__(self, name, low, high, width=20):
12 | self.low = low
13 | self.high = high
14 |
15 | self.kde = KernelDensity(bandwidth=(high - low) / width)
16 |
17 | self.kde.fit(np.array([high-low])[:, None])
18 |
19 | super().__init__(name)
20 |
21 | def sample(self, result):
22 | value = float('inf')
23 | while value < self.low or value > self.high:
24 | value = float(self.kde.sample())
25 | result[self.name] = value
26 |
27 | def uniform_sample(self, result):
28 | result[self.name] = float(np.random.uniform(
29 | low=self.low, high=self.high))
30 |
31 | def evaluate(self, data, scores):
32 | single_scores = self.kde.score_samples(
33 | np.array([point[self.name] for point in data])[:, None])
34 | for i, single_score in enumerate(single_scores):
35 | scores[i] += single_score
36 |
37 | def fit(self, values):
38 | self.kde.fit(np.array(values)[:, None])
39 |
40 | def get_children(self):
41 | return []
42 |
43 | def structural_copy(self):
44 | return Float(self.name, self.low, self.high)
45 |
46 |
47 | if __name__ == '__main__':
48 | node = Float('lr', low=1e-5, high=1e-3)
49 | node.fit([0.0005, 1e-5, 1e-5])
50 | results = []
51 | for i in range(100000):
52 | r = {}
53 | node.sample(r)
54 | results.append(r['lr'])
55 |
56 | hist, edges = np.histogram(results, bins=100)
57 |
58 | plot = Figure()
59 | plot.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:])
60 | show(plot)
61 |
--------------------------------------------------------------------------------
/pbt/exploration/models/config_tree/nodes/integer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.neighbors import KernelDensity
3 |
4 | from pbt.exploration.models.config_tree.nodes import Node
5 |
6 |
7 | class Integer(Node):
8 | def __init__(self, name, low, high, width=20):
9 | self.low = low
10 | self.high = high
11 | self.kde = KernelDensity(bandwidth=(high - low) / width)
12 |
13 | self.kde.fit(np.array([high-low])[:, None])
14 |
15 | super().__init__(name)
16 |
17 | def sample(self, result):
18 | value = float('inf')
19 | while value < self.low or value > self.high:
20 | value = int(self.kde.sample())
21 | result[self.name] = value
22 |
23 | def uniform_sample(self, result):
24 | result[self.name] = int(np.round(
25 | np.random.uniform(low=self.low, high=self.high)))
26 |
27 | def evaluate(self, data, scores):
28 | single_scores = self.kde.score_samples(
29 | np.array([point[self.name] for point in data])[:, None])
30 | for i, single_score in enumerate(single_scores):
31 | scores[i] += single_score
32 |
33 | def fit(self, values):
34 | self.kde.fit(np.array(values)[:, None])
35 |
36 | def get_children(self):
37 | return []
38 |
39 | def structural_copy(self):
40 | return Integer(self.name, self.low, self.high)
41 |
--------------------------------------------------------------------------------
/pbt/exploration/models/config_tree/nodes/log_float.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from pbt.exploration.models.config_tree.nodes import Float
4 |
5 |
6 | class LogFloat(Float):
7 | def __init__(self, name, low, high, width=20):
8 | if low <= 0:
9 | raise ValueError('"low" has to be greater than 0!')
10 | super().__init__(name, low, high, width)
11 |
12 | def sample(self, result):
13 | value = float('inf')
14 | while value < self.low or value > self.high:
15 | value = float(10**self.kde.sample())
16 | result[self.name] = value
17 |
18 | def uniform_sample(self, result):
19 | result[self.name] = float(10**np.random.uniform(
20 | low=np.log10(self.low), high=np.log10(self.high)))
21 | return result[self.name]
22 |
23 | def evaluate(self, data, scores):
24 | single_scores = self.kde.score_samples(
25 | np.array([np.log10(point[self.name]) for point in data])[:, None])
26 | for i, single_score in enumerate(single_scores):
27 | scores[i] += single_score
28 |
29 | def fit(self, values):
30 | self.kde.fit(np.array([np.log10(value) for value in values])[:, None])
31 |
32 | def get_children(self):
33 | return []
34 |
35 | def structural_copy(self):
36 | return LogFloat(self.name, self.low, self.high)
37 |
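A worked example of the log-uniform draw in uniform_sample, with illustrative bounds low=1e-5 and high=1e-3: the exponent is drawn uniformly from [log10(1e-5), log10(1e-3)] = [-5, -3], so a draw of -4 yields 10**-4 = 1e-4, spreading samples evenly across orders of magnitude rather than linearly.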
--------------------------------------------------------------------------------
/pbt/exploration/models/config_tree/nodes/node.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Node:
5 | def __init__(self, name):
6 | self.name = name
7 | self.index = None
8 |
9 | @abc.abstractmethod
10 | def sample(self, result):
11 | raise NotImplementedError
12 |
13 | @abc.abstractmethod
14 | def uniform_sample(self, result):
15 | raise NotImplementedError
16 |
17 | @abc.abstractmethod
18 | def evaluate(self, data, scores):
19 | raise NotImplementedError
20 |
21 | @abc.abstractmethod
22 | def fit(self, values):
23 | raise NotImplementedError
24 |
25 | @abc.abstractmethod
26 | def get_children(self):
27 | raise NotImplementedError
28 |
29 | @abc.abstractmethod
30 | def structural_copy(self):
31 | raise NotImplementedError
32 |
--------------------------------------------------------------------------------
/pbt/exploration/models/model.py:
--------------------------------------------------------------------------------
1 | class Model:
2 | def sample(self):
3 | raise NotImplementedError
4 |
--------------------------------------------------------------------------------
/pbt/exploration/models/tree_parzen_estimator.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 |
6 | from pbt.exploration.models import Model
7 | from pbt.tqdm_logger import TqdmLoggingHandler
8 |
9 | class TreeParzenEstimator(Model):
10 | def __init__(
11 | self, config_tree, best_percent=0.2, uniform_percent=0.25,
12 | sample_size=10, new_samples_until_update=10, window_size=20,
13 | mode='improvement', split='time_step'):
14 | self.l_tree = config_tree
15 | self.g_tree = config_tree.structural_copy()
16 |
17 | self._best_percent = best_percent
18 | self._uniform_percent = uniform_percent
19 | self._sample_size = sample_size
20 | self._window_size = window_size
21 |
22 | self._new_samples_until_update = new_samples_until_update
23 | self._counter = 0
24 |
25 | self.data = defaultdict(list)
26 |
27 | self.logger = logging.getLogger('pbt')
28 | self.logger.setLevel(logging.DEBUG)
29 | self.logger.addHandler(TqdmLoggingHandler())
30 | if mode == 'improvement':
31 | self._mode = 2
32 | else:
33 | self._mode = 0
34 |
35 | if split != 'time_step':
36 | self._split_function = self._split_data
37 | else:
38 | self._split_function = self._split_by_time_step
39 |
40 | def sample(self):
41 | if np.random.random() < self._uniform_percent:
42 | self.logger.debug('Using random sampling.')
43 | return self.l_tree.uniform_sample()
44 | self.logger.debug('TPE sampling:')
45 | samples = [self.l_tree.sample() for _ in range(self._sample_size)]
46 | self.logger.debug(f'Samples: {samples}')
47 | l_scores = self.l_tree.evaluate(samples)
48 | g_scores = self.g_tree.evaluate(samples)
49 | scores = [l/g for l, g in zip(l_scores, g_scores)]
50 | self.logger.debug(f'Scores: {scores}')
51 | return samples[np.argmin(scores)]
52 |
53 | def update(self, trial):
54 | self.data[trial[1]].append(trial)
55 |
56 | self._counter += 1
57 | if self._counter >= self._new_samples_until_update:
58 | self._fit_data()
59 | self._counter = 0
60 |
61 | def _fit_data(self):
62 | indices = range(
63 | max(0, len(self.data) - self._window_size),
64 | len(self.data))
65 | sliding_window = [self.data[x] for x in indices]
66 | good, bad = self._split_function(sliding_window)
67 | self.l_tree.fit(good)
68 | self.g_tree.fit(bad)
69 |
70 | def _split_by_time_step(self, data):
71 | good, bad = [], []
72 | for time_step in data:
73 | new_good, new_bad = self._split_data([time_step])
74 | good += new_good
75 | bad += new_bad
76 | return good, bad
77 |
78 | def _split_data(self, data):
79 | scores = sorted([
80 | trial[self._mode]
81 | for time_step in data for trial in time_step], reverse=True)
82 | pivot_score = scores[int(len(scores) * self._best_percent)]
83 | good, bad = [], []
84 | for time_step in data:
85 | for point in time_step:
86 | if point[self._mode] >= pivot_score:
87 | good.append(point)
88 | if point[self._mode] <= pivot_score:
89 | bad.append(point)
90 | return good, bad
91 |
92 |
93 |
--------------------------------------------------------------------------------
/pbt/exploration/perturb.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class Perturb:
4 | """
5 | A simple perturb mechanism as specified in (Jaderberg et al., 2017).
6 | """
7 | def __init__(self, cs_space=None, boundaries={}):
8 | self.boundaries = boundaries
9 | self.cs_space = cs_space
10 |
11 | def __call__(self, hyperparameters: dict) -> dict:
12 | """
13 | Perturb the nodes in the input.
14 | :param hyperparameters: A dict with nodes.
15 | :return: The perturbed nodes.
16 | """
17 | result = hyperparameters.copy()
18 |
19 | for key in hyperparameters:
20 | temp_value = self.cs_space[key]._inverse_transform(result[key])
21 | temp_value += np.random.choice([-1, 1]) * 0.2 * temp_value
22 | result[key] = self.cs_space[key]._transform(temp_value)
23 | self.ensure_boundaries(result)
24 | return result
25 |
26 | def ensure_boundaries(self, result):
27 | for key in result:
28 | if key not in self.boundaries:
29 | continue
30 | if result[key] < self.boundaries[key][0]:
31 | result[key] = self.boundaries[key][0]
32 | elif result[key] > self.boundaries[key][1]:
33 | result[key] = self.boundaries[key][1]
34 |
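A minimal sketch of wiring up Perturb, assuming cs_space maps each hyperparameter name to a ConfigSpace hyperparameter (whose _inverse_transform/_transform convert between values and an internal, possibly log-scaled representation); the name and ranges are illustrative:

    from ConfigSpace.hyperparameters import UniformFloatHyperparameter

    lr = UniformFloatHyperparameter('model_learning_rate', 1e-5, 1e-2, log=True)
    perturb = Perturb(cs_space={'model_learning_rate': lr},
                      boundaries={'model_learning_rate': (1e-5, 1e-2)})
    new = perturb({'model_learning_rate': 1e-3})  # +/-20% step in the transformed space, clipped to the boundaries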
--------------------------------------------------------------------------------
/pbt/exploration/perturb_and_resample.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from pbt.exploration import ExplorationStrategy, Perturb, Resample
4 |
5 |
6 | class PerturbAndResample(ExplorationStrategy):
7 | def __init__(
8 | self, mutations: dict, cs_space: dict, resample_probability: float = 0.25,
9 | boundaries={}):
10 | """
11 | A strategy to do both perturb and resample.
12 | :param mutations: A dictionary with hyperparameter names and mutations
13 | :param resample_probability: The probability to resample for each call
14 | """
15 | self.resample_probability = resample_probability
16 |
17 | self.perturb = Perturb(cs_space=cs_space, boundaries=boundaries)
18 | self.resample = Resample(mutations=mutations)
19 |
20 | def __call__(self, hyperparameters: dict) -> dict:
21 | """
22 | Perturb all nodes specified by mutations and then resample
23 | each hyperparameter depending on the resample_probability.
24 | :param hyperparameters: The nodes to perturb
25 | :return: Perturbed and probably resampled nodes
26 | """
27 | result = self.perturb(hyperparameters)
28 |
29 | if random.random() < self.resample_probability:
30 | result = self.resample(result)
31 |
32 | return result
33 |
--------------------------------------------------------------------------------
/pbt/exploration/resample.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 |
4 | class Resample:
5 | def __init__(self, mutations: dict):
6 | """
7 | A simple resample mechanism as specified in (Jaderberg et al., 2017).
8 | :param mutations: A dict of all nodes and its mutations.
9 | """
10 | self.mutations = mutations
11 |
12 | def __call__(self, hyperparameters: dict) -> dict:
13 | """
14 | Resample nodes given by the specified mutations.
15 | :param hyperparameters: All nodes
16 | :return: All nodes with specified nodes resampled
17 | """
18 | result = hyperparameters.copy()
19 |
20 | for key, value in self.mutations.items():
21 | result[key] = value() if callable(value) else random.choice(value)
22 |
23 | return result
24 |
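A minimal usage sketch: each mutation may be either a list of choices or a callable sampler (both shown here are illustrative):

    import random

    resample = Resample(mutations={
        'model_train_epoch': [5, 10, 20],                           # random.choice over a list
        'model_learning_rate': lambda: 10**random.uniform(-5, -2),  # or a callable sampler
    })
    new = resample({'model_train_epoch': 5, 'model_learning_rate': 1e-3})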
--------------------------------------------------------------------------------
/pbt/garbage_collector.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | class GarbageCollector:
6 | def __init__(self, data_path):
7 | self.data_path = data_path
8 |
9 | def collect(self, member_id, min_timestep, elites):
10 | if min_timestep - 1 < 0:
11 | return
12 | else:
13 | for i in range(min_timestep-1):
14 | if (member_id, i) in [(elite.member_id, elite.time_step) for elite in elites]:
15 | continue
16 | member_path = os.path.join(self.data_path, str(member_id), str(i))
17 | for file_name in ["state_dict.npz", "traj_acs.json", "traj_obs.json", "traj_rews.json"]:
18 | file_path = os.path.join(member_path, file_name)
19 | if os.path.exists(file_path):
20 | os.remove(file_path)
21 | #else:
22 | # print("Can not delete the file in path %s as it doesn't exists" %file_path)
23 |
24 | def _get_integer_dirs(self, path):
25 | return [int(i) for i in os.listdir(path) if i.isdigit()]
26 |
--------------------------------------------------------------------------------
/pbt/network/__init__.py:
--------------------------------------------------------------------------------
1 | from .controller_adapter import ControllerAdapter
2 | from .worker_adapter import WorkerAdapter
3 |
4 | from .daemon import Daemon
5 | from .controller_daemon import ControllerDaemon, CONTROLLER_URI_FILENAME
6 | from .worker_daemon import WorkerDaemon
7 |
8 | __all__ = [
9 | 'ControllerAdapter', 'WorkerAdapter',
10 | 'Daemon', 'ControllerDaemon', 'WorkerDaemon', 'CONTROLLER_URI_FILENAME']
11 |
--------------------------------------------------------------------------------
/pbt/network/controller_adapter.py:
--------------------------------------------------------------------------------
1 | from Pyro4.errors import ConnectionClosedError, CommunicationError
2 |
3 | from pbt.population import NoTrial
4 |
5 |
6 | class ControllerAdapter:
7 | def __init__(self, controller):
8 | self._controller = controller
9 |
10 | def register_worker_by_uri(self, uri):
11 | try:
12 | self._controller.register_worker_by_uri(uri)
13 | return True
14 | except CommunicationError:
15 | # URI from file is not valid -> kill worker
16 | return False
17 |
18 | def request_trial(self):
19 | try:
20 | return self._controller.request_trial()
21 | except ConnectionClosedError:
22 | return NoTrial().to_tuple()
23 |
24 | def send_evaluation(self, member_id, score):
25 | try:
26 | self._controller.send_evaluation(member_id, float(score))
27 | except ConnectionClosedError:
28 | return
29 |
--------------------------------------------------------------------------------
/pbt/network/controller_daemon.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import Pyro4
4 |
5 | from pbt.network import Daemon, WorkerAdapter
6 | from pbt.tqdm_logger import TqdmLoggingHandler
7 | CONTROLLER_URI_FILENAME = 'controller_uri.txt'
8 |
9 |
10 | @Pyro4.expose
11 | class ControllerDaemon(Daemon):
12 | def __init__(self, controller):
13 | self.logger = logging.getLogger('pbt')
14 | self.logger.setLevel(logging.DEBUG)
15 | # self.logger.addHandler(TqdmLoggingHandler())
16 | self.controller = controller
17 | self.pyro_daemon = None
18 |
19 | def start(self):
20 | Pyro4.config.SERVERTYPE = 'multiplex'
21 | self.pyro_daemon = Pyro4.Daemon(host=self._get_hostname())
22 | uri = self.pyro_daemon.register(self)
23 | self._save_pyro_uri(uri)
24 | self.pyro_daemon.requestLoop()
25 |
26 | def register_worker_by_uri(self, uri):
27 | self.controller.register_worker(WorkerAdapter(Pyro4.Proxy(uri)))
28 |
29 | def request_trial(self):
30 | return self.controller.request_trial()
31 |
32 | def send_evaluation(self, member_id, score):
33 | self.controller.send_evaluation(member_id, score)
34 |
35 | def shut_down(self):
36 | self.pyro_daemon.shutdown()
37 |
38 | def _save_pyro_uri(self, uri):
39 | save_path = os.path.join(
40 | self.controller.data_path, CONTROLLER_URI_FILENAME)
41 | if not os.path.isdir(self.controller.data_path):
42 | os.makedirs(self.controller.data_path)
43 | with open(save_path, 'w') as f:
44 | f.write(str(uri))
45 | self.logger.info(f'Saved pyro uri at {save_path}.')
46 |
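A sketch of the worker side, assuming the worker shares the controller's data_path and reads the saved URI (data_path, member_id, and score are illustrative; Pyro4.Proxy is the standard Pyro4 client handle):

    import os
    import Pyro4

    with open(os.path.join(data_path, CONTROLLER_URI_FILENAME)) as f:
        controller = Pyro4.Proxy(f.read().strip())
    trial_tuple = controller.request_trial()      # Trial.to_tuple() from the scheduler
    controller.send_evaluation(member_id, score)  # report the result back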
--------------------------------------------------------------------------------
/pbt/network/daemon.py:
--------------------------------------------------------------------------------
1 | import socket
2 |
3 |
4 | class Daemon:
5 | def _get_hostname(self):
6 | my_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
7 | my_socket.connect(('10.255.255.255', 1)) # UDP connect sends no packets; it only makes the OS pick the outgoing interface
8 | ip = my_socket.getsockname()[0] # local address chosen for that route
9 | my_socket.close()
10 | return ip
11 |
--------------------------------------------------------------------------------
/pbt/network/worker_adapter.py:
--------------------------------------------------------------------------------
1 | from Pyro4.errors import CommunicationError
2 |
3 |
4 | class WorkerAdapter:
5 | def __init__(self, worker):
6 | self._worker = worker
7 |
8 | @property
9 | def worker_id(self):
10 | return self._worker.worker_id
11 |
12 | def ping(self):
13 | try:
14 | return self._worker.ping()
15 | except CommunicationError:
16 | return False
17 |
18 | def stop(self):
19 | try:
20 | self._worker.stop()
21 | except CommunicationError:
22 | pass
23 |
--------------------------------------------------------------------------------
/pbt/network/worker_daemon.py:
--------------------------------------------------------------------------------
1 | from threading import Thread
2 |
3 | import Pyro4
4 |
5 | from pbt.network import Daemon
6 |
7 |
8 | @Pyro4.expose
9 | class WorkerDaemon(Daemon):
10 | def __init__(self, worker):
11 | self.worker = worker
12 | self.pyro_daemon = None
13 |
14 | @property
15 | def worker_id(self):
16 | return self.worker.worker_id
17 |
18 | def start(self):
19 | self.pyro_daemon = Pyro4.Daemon(host=self._get_hostname())
20 | uri = self.pyro_daemon.register(self)
21 | thread = Thread(target=self.pyro_daemon.requestLoop)
22 | thread.start()
23 | return uri
24 |
25 | def ping(self):
26 | return True
27 |
28 | def stop(self):
29 | self.worker.stop()
30 | self.pyro_daemon.shutdown()
31 |
--------------------------------------------------------------------------------
/pbt/population/__init__.py:
--------------------------------------------------------------------------------
1 | from .trial import Trial, NoTrial
2 | from .member import Member
3 | from .population import Population
4 |
5 | __all__ = ['Trial', 'NoTrial', 'Member', 'Population']
6 |
--------------------------------------------------------------------------------
/pbt/population/member.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 |
5 | class Member:
6 | def __init__(self, member_id, stop):
7 | self.member_id = member_id
8 | self.stop = stop
9 |
10 | self.trials = {}
11 | self.time_step = 0
12 | self.max_time_step = None
13 | self.last_score = 0
14 | self._actual_time_step = 0
15 |
16 | self.is_free = True
17 | self.is_done = False
18 |
19 | def __repr__(self):
20 | return f'<Member {self.member_id}, time step {self.time_step}>'
21 |
22 | def assign_trial(self, trial, last_score):
23 | self.is_free = False
24 | self.trials[trial.time_step] = trial
25 | self.last_score = last_score
26 |
27 | def reached_time_step(self, time_step):
28 | if time_step not in self.trials:
29 | return False
30 | return self.trials[time_step].score is not None
31 |
32 | def get_last_score(self):
33 | if self.time_step == 0:
34 | return -float('inf')
35 | return self.trials[self.time_step - 1].score
36 |
37 | def get_score_by_time_step(self, time_step):
38 | if self.time_step == 0:
39 | return -float('inf')
40 | if time_step not in self.trials:
41 | return None
42 | return self.trials[time_step].score
43 |
44 | def get_last_trial(self):
45 | return self.trials[self.time_step - 1]
46 |
47 | def get_all_scores(self):
48 | return [trial.score for trial in self.trials.values()]
49 |
50 | def get_best_trial(self):
51 | valid_trials = [trial for trial in self.trials.values() if trial.score is not None]
52 | if len(valid_trials) == 0:
53 | return None
54 | else:
55 | return max(valid_trials, key=lambda trial: trial.score)
56 |
57 | # def get_best_trial_by_delta_t(self, delta_t):
58 | # return sorted(self.trials[::delta_t], key=lambda trial: trial.score)[-1]
59 |
60 | def get_hyperparameters_by_time_step(self, time_step):
61 | if time_step not in self.trials:
62 | return None
63 | return self.trials[time_step].hyperparameters
64 |
65 | def get_hyperparameters(self):
66 | return self.trials[self.time_step - 1].hyperparameters
67 |
68 | def save_score(self, score):
69 | trial = self.trials[self.time_step]
70 | trial.score = score
71 | trial.improvement = score - self.last_score
72 | self.time_step += 1
73 | self._actual_time_step += 1
74 |
75 | # TODO: Move this to own function
76 | if self.max_time_step:
77 | if self._actual_time_step >= self.max_time_step:
78 | self.is_done = True
79 | else:
80 | self.is_free = True
81 | else:
82 | if self.stop(self._actual_time_step, score):
83 | self.is_done = True
84 | else:
85 | self.is_free = True
86 |
87 | return trial
88 |
89 | def set_actual_time_step(self, time_step):
90 | self._actual_time_step = time_step
91 |
92 | def get_actual_time_step(self):
93 | return self._actual_time_step
94 |
95 |
96 | def log_last_result(self, results_path):
97 | last_result = self.trials[self.time_step - 1]
98 | member_path = os.path.join(results_path, str(self.member_id))
99 | step_path = os.path.join(member_path, str(last_result.time_step))
100 | if not os.path.isdir(step_path):
101 | os.makedirs(step_path)
102 | with open(os.path.join(member_path, 'scores.txt'), 'a+') as f:
103 | f.write(f'{last_result.score}\n')
104 | with open(os.path.join(step_path, 'nodes.json'), 'w') as f:
105 | json.dump(last_result.hyperparameters, f)
106 | with open(os.path.join(step_path, 'add.json'), 'w') as f:
107 | json.dump(
108 |                 {'copied from': last_result.model_id,
109 |                  'time step': last_result.model_time_step}, f)
110 |
111 |     def _create_hyperparameters(self):
112 |         # NOTE: relies on `self.exploration` being attached externally; __init__ does not set it.
113 |         if self.time_step == 0:
114 |             return self.exploration.get_start_hyperparameters()
115 |         return self.exploration()
116 |
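
Usage note: a minimal sketch of the Member lifecycle above. The stop callable and the hyperparameter values are placeholders, not taken from the experiment configs.

    from pbt.population import Member, Trial

    # Hypothetical stopping criterion: a member is done after 3 evaluations.
    stop = lambda step, score: step >= 3

    member = Member(member_id=0, stop=stop)
    trial = Trial(member_id=0, model_id=-1, time_step=0, model_time_step=-1,
                  hyperparameters={'model_learning_rate': 1e-3})
    member.assign_trial(trial, last_score=0.0)  # member is now busy
    member.save_score(42.0)                     # records the score, advances time_step
    print(member.get_last_score())              # -> 42.0
    print(member.is_free, member.is_done)       # -> True False (2 evaluations left)
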
--------------------------------------------------------------------------------
/pbt/population/population.py:
--------------------------------------------------------------------------------
1 | from pbt.population import Member
2 |
3 |
4 | class Population:
5 | def __init__(self, size, stop, results_path, mode='asynchronous', backtracking=False, elite_ratio=0.1):
6 | self.members = [Member(i, stop) for i in range(size)]
7 | self.mode = mode
8 | self.results_path = results_path
9 | self.backtracking = backtracking
10 | self.elite_size = max(round(size * elite_ratio), 1)
11 |
12 |     def get_next_member(self):
13 |         time_steps = [member.time_step for member in self.members]
14 |         for time_step in range(min(time_steps), max(time_steps) + 1):
15 |             for member in self.members:
16 |                 if member.is_free and member.time_step == time_step:
17 |                     return member
18 |             else:  # no free member at this time step (loop did not return)
19 |                 if self.mode == 'synchronous':
20 |                     return None  # synchronous mode never lets members run ahead
21 |
22 | def get_scores(self):
23 | return {
24 | member.member_id: member.get_last_score()
25 | for member in self.members}
26 |
27 | def get_scores_by_time_step(self, time_step):
28 | scores = {}
29 | for member in self.members:
30 | member_score = member.get_score_by_time_step(time_step)
31 |             if member_score is not None:  # 0.0 is a valid score
32 | scores[member.member_id] = member_score
33 | return scores
34 |
35 | def get_hyperparameters(self, member_id):
36 | return self.members[member_id].get_hyperparameters()
37 |
38 | def get_hyperparameters_by_time_step(self, member_id, time_step):
39 | return self.members[member_id].get_hyperparameters_by_time_step(time_step)
40 |
41 | def get_latest_time_step(self, model_id):
42 | return self.members[model_id].time_step
43 |
44 | def get_min_time_step(self):
45 | return min([member.time_step for member in self.members])
46 |
47 | def save_trial(self, trial):
48 | self.members[trial.member_id].assign_trial(
49 | trial, self._get_last_score(trial))
50 |
51 | def _get_last_score(self, trial):
52 | member_of_model = self.members[trial.model_id]
53 | if member_of_model.time_step == 0:
54 | return 0.0
55 | else:
56 | return member_of_model.get_last_score()
57 |
58 | def update(self, member_id, score):
59 | current_member = self.members[member_id]
60 | trial = current_member.save_score(score)
61 | current_member.log_last_result(self.results_path)
62 | if current_member.is_done and not current_member.max_time_step:
63 | self._set_max_time_step(current_member.time_step)
64 | return trial
65 |
66 | def is_done(self):
67 | return all(member.is_done for member in self.members)
68 |
69 | def _set_max_time_step(self, max_time_step):
70 | for member in self.members:
71 | member.max_time_step = max_time_step
72 |
73 |     def get_elites(self):
74 |         # For PBT-BT only; return an empty list when not in PBT-BT mode.
75 |         if not self.backtracking:
76 |             return []
77 |         else:
78 |             # Elites are the top-scoring best trials across the population.
79 |             best_trials = self.get_best_trials()
80 |             return sorted(best_trials, key=lambda trial: trial.score, reverse=True)[:self.elite_size]
81 |
82 | # def get_elites_by_delta_t(self, delta_t):
83 | # # For PBT-BT only, return empty list if it's not in the PBT-BT mode
84 | # if not self.backtracking:
85 | # return []
86 | # else:
87 | # best_trials = self.get_best_trials_by_delta_t(delta_t)
88 | # return sorted(best_trials, key=lambda trial: trial.score, reverse=True)[:self.elite_size]
89 |
90 |     def get_best_trials(self):
91 |         best_trials = [member.get_best_trial() for member in self.members]
92 |         return [trial for trial in best_trials if trial is not None]
93 | # def get_best_trials_by_delta_t(self, delta_t):
94 | # return [member.get_best_trial_by_delta_t(delta_t) for member in self.members]
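
Usage note: one scheduling round against the Population above, sketched with a placeholder stop callable and results path. update() writes score and hyperparameter logs under results_path as a side effect.

    from pbt.population import Population, Trial

    pop = Population(size=2, stop=lambda step, score: step >= 1,
                     results_path='/tmp/pbt_results')

    member = pop.get_next_member()            # both members start at time_step 0
    trial = Trial(member.member_id, -1, 0, -1, {'model_learning_rate': 1e-3})
    pop.save_trial(trial)                     # marks the member as busy
    pop.update(member.member_id, score=10.0)  # logs the result; the first member
                                              # to finish freezes max_time_step
    print(pop.get_scores())                   # -> {0: 10.0, 1: -inf}
    print(pop.is_done())                      # -> False until every member is done
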
--------------------------------------------------------------------------------
/pbt/population/trial.py:
--------------------------------------------------------------------------------
1 | class Trial:
2 | def __init__(
3 | self, member_id, model_id, time_step, model_time_step,
4 | hyperparameters):
5 | self.member_id = member_id
6 | self.model_id = model_id
7 | self.time_step = time_step
8 | self.model_time_step = model_time_step
9 | self.hyperparameters = self._clean_hyperparameters(hyperparameters)
10 | self.score = None
11 | self.improvement = None
12 |
13 | def __repr__(self):
14 |         return (
15 |             f'<Trial member_id={self.member_id} '
16 |             f'model_id={self.model_id} '
17 |             f'time_step={self.time_step} '
18 |             f'model_time_step={self.model_time_step} '
19 |             f'score={self.score} '
20 |             f'improvement={self.improvement}>'
21 |         )
22 |
23 | def _clean_hyperparameters(self, hyperparameters):
24 | result = hyperparameters.copy()
25 | for key, value in result.items():
26 | if callable(value):
27 | result[key] = value()
28 | return result
29 |
30 | def is_valid(self):
31 | return True
32 |
33 | def to_tuple(self):
34 | """
35 |         Return this trial as a tuple (easier to send over the network).
36 |         :return: (member_id, model_id, time_step, model_time_step, hyperparameters)
37 | """
38 | return \
39 | self.member_id, self.model_id, self.time_step, \
40 | self.model_time_step, self.hyperparameters
41 |
42 | def to_array(self):
43 | result = [self.score, self.time_step, self.improvement]
44 | result += [
45 | self.hyperparameters[key]
46 | for key in sorted(self.hyperparameters)]
47 | return result
48 |
49 | @staticmethod
50 | def from_tuple(
51 | member_id, model_id, time_step, model_time_step, hyperparameters):
52 | return Trial(
53 | member_id, model_id, time_step, model_time_step, hyperparameters) \
54 |             if member_id != -1 \
55 | else NoTrial()
56 |
57 | def copy(self):
58 | return Trial(
59 | self.member_id, self.model_id, self.time_step, self.model_time_step,
60 | self.hyperparameters.copy())
61 |
62 |
63 | class NoTrial(Trial):
64 | def __init__(self):
65 | super().__init__(-1, -1, -1, -1, {})
66 |
67 | def is_valid(self):
68 | return False
69 |
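
Usage note: the tuple round-trip below is how trials travel between controller and worker over Pyro4; callable hyperparameter values are resolved at construction time.

    from pbt.population import NoTrial, Trial

    trial = Trial(member_id=0, model_id=1, time_step=2, model_time_step=1,
                  hyperparameters={'model_learning_rate': lambda: 1e-3})
    packed = trial.to_tuple()            # (0, 1, 2, 1, {'model_learning_rate': 0.001})
    restored = Trial.from_tuple(*packed)
    assert restored.is_valid()

    empty = Trial.from_tuple(*NoTrial().to_tuple())
    assert not empty.is_valid()          # member_id -1 maps back to a NoTrial
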
--------------------------------------------------------------------------------
/pbt/scheduler.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from pbt.population import NoTrial, Trial
4 | from pbt.tqdm_logger import TqdmLoggingHandler
5 |
6 | class Scheduler:
7 | def __init__(
8 | self, population, start_hyperparameters, exploitation, exploration):
9 | self.logger = logging.getLogger('pbt')
10 | self.logger.setLevel(logging.DEBUG)
11 | self.logger.addHandler(TqdmLoggingHandler())
12 | self.population = population
13 | self.exploitation = exploitation
14 | self.exploration = exploration
15 | self.start_hyperparameters = start_hyperparameters
16 |
17 | def get_trial(self):
18 | self.logger.debug('Trial requested.')
19 | member = self.population.get_next_member()
20 | if not member:
21 | self.logger.debug('No trial ready.')
22 | return NoTrial()
23 |
24 | if member.time_step == 0:
25 |             # start_hyperparameters maps config names to sampler callables.
26 |             start_hyperparameters = {name: sampler() for name, sampler in self.start_hyperparameters.items()}
27 | trial = Trial(
28 | member.member_id, -1, 0, -1, start_hyperparameters)
29 | self.population.save_trial(trial)
30 | self.logger.debug(f'Returning first trial {trial}.')
31 | return trial
32 |
33 | self.logger.debug(f'Generating trial for member {member.member_id}.')
34 | scores = self.population.get_scores_by_time_step(member.time_step - 1)
35 | # model_id indicates the model that we want to copy
36 | model_id = self.exploitation(member.member_id, scores)
37 | # model_time_step = self.population.get_latest_time_step(model_id) - 1
38 | model_time_step = member.time_step - 1
39 | hyperparameters = self.population.get_hyperparameters_by_time_step(model_id, model_time_step)
40 | if model_id != member.member_id:
41 | self.logger.debug(f'Copying model {model_id}.')
42 | hyperparameters = self.exploration(hyperparameters)
43 | self.logger.debug(f'Using exploration. New: {hyperparameters}')
44 | else:
45 | self.logger.debug(f'Staying with current model {model_id}.')
46 | trial = Trial(
47 | member.member_id, model_id, member.time_step, model_time_step,
48 | hyperparameters)
49 | self.population.save_trial(trial)
50 | return trial
51 |
52 | def update_exploration(self, trial):
53 | self.exploration.update(trial)
54 |
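
Usage note: a sketch of wiring the Scheduler above. The exploitation and exploration callables here are toy stand-ins for the strategies under pbt/exploitation and pbt/exploration, and the results path is a placeholder.

    from pbt.population import Population
    from pbt.scheduler import Scheduler

    population = Population(size=4, stop=lambda step, score: step >= 10,
                            results_path='/tmp/pbt_results')
    scheduler = Scheduler(
        population,
        start_hyperparameters={'model_learning_rate': lambda: 1e-3},
        # Toy stand-ins: copy the best-scoring member, perturb nothing.
        exploitation=lambda member_id, scores: max(scores, key=scores.get),
        exploration=lambda hyperparameters: hyperparameters)

    trial = scheduler.get_trial()  # first trials sample the start hyperparameters
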
--------------------------------------------------------------------------------
/pbt/tqdm_logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import tqdm
3 |
4 | class TqdmLoggingHandler(logging.Handler):
5 | def __init__(self, level=logging.NOTSET):
6 | super().__init__(level)
7 |
8 | def emit(self, record):
9 | try:
10 | msg = self.format(record)
11 | tqdm.tqdm.write(msg)
12 | self.flush()
13 | except (KeyboardInterrupt, SystemExit):
14 | raise
15 |         except Exception:
16 | self.handleError(record)
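
Usage note: the handler routes log records through tqdm.write, so messages are printed above an active progress bar instead of breaking it. A quick sketch:

    import logging
    from time import sleep

    from tqdm import tqdm

    from pbt.tqdm_logger import TqdmLoggingHandler

    logger = logging.getLogger('demo')
    logger.setLevel(logging.INFO)
    logger.addHandler(TqdmLoggingHandler())

    for i in tqdm(range(3)):
        logger.info(f'step {i}')  # appears above the bar; the bar stays intact
        sleep(0.5)
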
--------------------------------------------------------------------------------
/pbt/worker.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | from time import sleep
5 |
6 | import numpy as np
7 | import Pyro4
8 | from scipy.io import savemat
9 | from tqdm import tqdm
10 |
11 | from config_space.config_space import DEFAULT_CONFIGSPACE
12 | from pbt.network import WorkerDaemon, ControllerAdapter, CONTROLLER_URI_FILENAME
13 | from pbt.population import Trial
14 | from pbt.tqdm_logger import TqdmLoggingHandler
15 |
16 |
17 |
18 | class Criterion:
19 | def __init__(self, criterion_mode, **kwargs):
20 | self.criterion_mode = criterion_mode
21 | self.kwargs = kwargs
22 |
23 | def __call__(self, traj_rets, traj_eval_rets=None, info=None):
24 | # Concatenate train and eval returns
25 | if traj_eval_rets is None:
26 | all_rets = traj_rets
27 | else:
28 | all_rets = np.concatenate([traj_rets, traj_eval_rets], axis=1)
29 | if self.criterion_mode == "max":
30 | # Take the maximal score over the past
31 | score = np.mean(all_rets, axis=1).max().item()
32 | elif self.criterion_mode == 'mean':
33 | # Take the average score over the past
34 | score = np.mean(all_rets, axis=1).mean().item()
35 | elif self.criterion_mode == 'lastk':
36 | last_k = self.kwargs.get('last_k', 1)
37 | score = np.mean(all_rets, axis=1)[-last_k:].mean().item()
38 | #elif self.criterion_mode == 'weighted_return':
39 | #
40 | else:
41 |             raise NotImplementedError(f'{self.criterion_mode} is an invalid criterion mode')
42 | return score
43 |
44 | class PBTWorker:
45 | # TODO: Doc, worker that sequentially calls step and evaluate
46 | def __init__(self, worker_id, agent, policy_constructor, train_func, criterion_mode='mean', data_path=os.getcwd(),
47 | wait_time=5, initial_step=4, step=1, not_copy_data=False, **kwargs):
48 | self.logger = logging.getLogger('pbt')
49 | self.logger.setLevel(logging.DEBUG)
50 | self.logger.addHandler(TqdmLoggingHandler())
51 | self.worker_id = worker_id
52 | self.agent = agent
53 | self.policy_constructor = policy_constructor
54 | self.train_func = train_func
55 | self.data_path = data_path
56 | self.wait_time = wait_time
57 | self.initial_step = initial_step
58 | self.step = step
59 | # TODO: _is_done
60 | self.is_done = False
61 | self._controller = None
62 | self.not_copy_data = not_copy_data
63 | # Ensure safety of criterion mode lastk
64 | if self.not_copy_data:
65 | self.logger.info('NOTICE: We are not copying data in this experiment.')
66 | if kwargs.get('last_k', 0) > initial_step + 1:
67 |             raise ValueError('last_k must be at most initial_step + 1')
68 | self.criterion = Criterion(criterion_mode, **kwargs)
69 |
70 | def register(self, controller=None):
71 | if controller:
72 | self.logger.info('Registered controller directly.')
73 | self._controller = controller
74 | self._controller.register_worker(self)
75 | else:
76 | self.logger.info('Registered controller over network.')
77 | self._controller = ControllerAdapter(self._discover_controller())
78 | self._run_daemon()
79 |
80 | def _run_daemon(self):
81 | self.logger.info('Starting worker daemon.')
82 | daemon = WorkerDaemon(self)
83 | uri = daemon.start()
84 | success = self._controller.register_worker_by_uri(uri)
85 | if not success:
86 | daemon.stop()
87 |             raise Exception(f'Controller rejected the worker URI "{uri}"!')
88 |
89 | def run(self):
90 | if not self._controller:
91 | self.register()
92 |
93 | while not self.is_done:
94 | self._run_iteration()
95 |
96 | def stop(self):
97 | self.logger.info('Shutting down worker.')
98 | self.is_done = True
99 |
100 | def _run_iteration(self):
101 | trial = self._load_trial()
102 | if not trial:
103 | return
104 |
105 | hyperparameters = trial.hyperparameters
106 | policy = self.policy_constructor(hyperparameters, DEFAULT_CONFIGSPACE)
107 | if trial.time_step == 0:
108 | traj_obs, traj_acs, traj_rets, traj_rews, info = self.train_func(self.agent, policy, step=self.initial_step)
109 | else:
110 | path = self._get_last_model_path(trial)
111 | policy.load_model(path)
112 | traj_obs, traj_acs, traj_rets, traj_rews = self.load_training_data(path)
113 | traj_obs, traj_acs, traj_rets, traj_rews, info = self.train_func(self.agent, policy,
114 | traj_obs, traj_acs, traj_rets, traj_rews, step=self.step)
115 | score = self.criterion(traj_rets, info=info)
116 |
117 | save_path = self._create_model_path(trial)
118 | policy.save_model(save_path)
119 | self.save_training_data(save_path, traj_obs, traj_acs, traj_rets, traj_rews)
120 |         self.save_information(save_path, info)
121 | self._send_evaluation(trial.member_id, score)
122 |
123 |
124 | def load_training_data(self, path):
125 | with open(os.path.join(path, "traj_obs.json"), 'r') as f:
126 | traj_obs = json.load(f)
127 | with open(os.path.join(path, "traj_acs.json"), 'r') as f:
128 | traj_acs = json.load(f)
129 | with open(os.path.join(path, "traj_rets.json"), 'r') as f:
130 | traj_rets = json.load(f)
131 | with open(os.path.join(path, "traj_rews.json"), 'r') as f:
132 | traj_rews = json.load(f)
133 |
134 | return traj_obs, traj_acs, traj_rets, traj_rews
135 |
136 | def save_training_data(self, path, traj_obs, traj_acs, traj_rets, traj_rews):
137 | with open(os.path.join(path, "traj_obs.json"), 'w') as f:
138 | json.dump(traj_obs, f)
139 | with open(os.path.join(path, "traj_acs.json"), 'w') as f:
140 | json.dump(traj_acs, f)
141 | with open(os.path.join(path, "traj_rets.json"), 'w') as f:
142 | json.dump(traj_rets, f)
143 | with open(os.path.join(path, "traj_rews.json"), 'w') as f:
144 | json.dump(traj_rews, f)
145 |
146 |     def save_information(self, path, info):
147 | savemat(
148 | os.path.join(path, "infos.mat"),
149 | {
150 | key : info[key] for key in info.keys()
151 | },
152 | long_field_names=True
153 | )
154 |
155 |     def _load_trial(self):
156 |         trial = Trial.from_tuple(*self._get_trial())
157 |         if not trial.is_valid():
158 |             self.logger.info(
159 |                 f'No trial ready. Waiting for {self.wait_time} seconds.')
160 |             sleep(self.wait_time)
161 |             # NoTrial is truthy, so return None to make the caller retry.
162 |             return None
163 |         self.logger.info(f'Got valid trial {trial} from controller.')
164 |         return trial
165 |
166 |     def _discover_controller(self):
167 |         self.logger.debug('Discovering controller.')
168 |         file_path = os.path.join(self.data_path, CONTROLLER_URI_FILENAME)
169 |         self.logger.debug(f'Reading controller URI from {file_path}.')
170 |         # Give the controller a head start before polling for its URI file.
171 |         sleep(5)
172 |         for _ in range(5):
173 |             try:
174 |                 with open(file_path, 'r') as f:
175 |                     uri = f.readline().strip()
176 |                 break
177 |             except FileNotFoundError:
178 |                 self.logger.info('Can\'t reach controller. Waiting ...')
179 |                 sleep(5)
180 |         else:
181 |             # All five attempts failed.
182 |             raise Exception('Can\'t reach controller!')
183 |         return Pyro4.Proxy(uri)
184 |
185 | def _get_trial(self):
186 | # TODO: Intercept connection issues
187 | return self._controller.request_trial()
188 |
189 | def _get_last_model_path(self, trial):
190 | # TODO: Remove ambiguity with model_id <-> member_id
191 | if self.not_copy_data:
192 | path = os.path.join(
193 | self.data_path, str(trial.member_id), str(trial.model_time_step))
194 | else:
195 | path = os.path.join(
196 | self.data_path, str(trial.model_id), str(trial.model_time_step))
197 | return path
198 |
199 | def _create_model_path(self, trial):
200 | path = os.path.join(
201 | self.data_path, str(trial.member_id), str(trial.time_step))
202 | if not os.path.isdir(path):
203 | os.makedirs(path)
204 | return path
205 |
206 | def _create_extra_data_path(self, trial):
207 | path = os.path.join(
208 | self.data_path, 'extra', str(trial.member_id), str(trial.time_step))
209 | if not os.path.isdir(path):
210 | os.makedirs(path)
211 | return path
212 |
213 | def _send_evaluation(self, member_id, score):
214 | # TODO: Handle connection issues
215 | self.logger.info(f'Sending evaluation. Score: {score}')
216 | self._controller.send_evaluation(member_id, float(score))
217 |
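
Usage note: how the three Criterion modes above reduce a returns matrix. The row/column layout (training iterations x rollouts) is inferred from the axis=1 mean in __call__.

    import numpy as np

    from pbt.worker import Criterion

    # 3 training iterations x 2 rollouts each; per-iteration means: [2.0, 5.0, 2.0]
    traj_rets = np.array([[1.0, 3.0],
                          [4.0, 6.0],
                          [2.0, 2.0]])

    print(Criterion('max')(traj_rets))              # 5.0 (best iteration)
    print(Criterion('mean')(traj_rets))             # 3.0 (average over all)
    print(Criterion('lastk', last_k=2)(traj_rets))  # 3.5 (mean of the last two)
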
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.9.0
2 | astor==0.8.1
3 | astunparse==1.6.3
4 | cachetools==4.0.0
5 | chardet==3.0.4
6 | cloudpickle==1.2.1
7 | configspace==0.4.10
8 | cycler==0.10.0
9 | cython==0.29.13
10 | dataclasses==0.6
11 | decorator==4.4.0
12 | dm-env==1.2
13 | dm-tree==0.1.5
14 | dotmap==1.2.20
15 | future==0.16.0
16 | gast==0.2.2
17 | glfw==1.8.3
18 | google-auth==1.11.2
19 | google-auth-oauthlib==0.4.1
20 | google-pasta==0.2.0
21 | googledrivedownloader==0.4
22 | gpflow==1.1.0
23 | grpcio==1.27.2
24 | gym==0.14.0
25 | h5py==2.10.0
26 | hpbandster==0.7.4
27 | idna==2.8
28 | imageio==2.5.0
29 | isodate==0.6.0
30 | joblib==0.13.2
31 | keras-applications==1.0.8
32 | keras-preprocessing==1.1.2
33 | kiwisolver==1.1.0
34 | labmaze==1.0.3
35 | lockfile==0.12.2
36 | lxml==4.5.2
37 | markdown==3.2.1
38 | matplotlib==3.1.1
39 | more-itertools==8.3.0
40 | mujoco-py==2.0.2.5
41 | multipledispatch==0.6.0
42 | netifaces==0.10.9
43 | networkx==2.3
44 | numpy==1.17.4
45 | oauthlib==3.1.0
46 | omegaconf==2.0.0
47 | opencv-python==4.4.0.42
48 | opt-einsum==3.2.1
49 | packaging==20.4
50 | pandas==0.25.1
51 | patsy==0.5.1
52 | pluggy==0.13.1
53 | plyfile==0.7
54 | protobuf==3.10.0
55 | py==1.8.1
56 | pyasn1==0.4.8
57 | pyasn1-modules==0.2.8
58 | pyglet==1.3.2
59 | pyopengl==3.1.5
60 | pyparsing==2.4.2
61 | pyro4==4.80
62 | pytest==5.4.2
63 | python-dateutil==2.8.0
64 | pytz==2019.2
65 | pyyaml==5.3.1
66 | rdflib==4.2.2
67 | requests==2.22.0
68 | requests-oauthlib==1.3.0
69 | rsa==4.0
70 | scikit-learn==0.21.3
71 | scipy==1.3.1
72 | seaborn==0.11.0
73 | serpent==1.30.2
74 | statsmodels==0.11.1
75 | tensorboard-plugin-wit==1.6.0.post3
76 | termcolor==1.1.0
77 | tqdm==4.19.4
78 | typing==3.7.4.1
79 | typing-extensions==3.7.4.2
80 | urllib3==1.25.3
81 | werkzeug==1.0.0
82 | wrapt==1.12.1
--------------------------------------------------------------------------------
/scripts/hyperband.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Example command
4 |
5 | # activate your conda env
6 | source ~/.bashrc
7 | source activate mbrl
8 |
9 | INTERFACE=$(ip route get 8.8.8.8 | cut -d' ' -f5 | head -1)
10 | echo "Interface read: $INTERFACE"
11 |
12 | # ENV=halfcheetah_v3
13 | # TASK_HORIZON=1000
14 | ENV=reacher
15 | TASK_HORIZON=10
16 |
17 | RUN_ID=$SLURM_ARRAY_JOB_ID
18 | # Replace with your directory
19 | DIR=log/$ENV\_$RUN_ID
20 | NINIT_ROLLOUTS=1
21 | MIN_BUDGET=79
22 | MAX_BUDGET=80
23 | NOPT_ITER=5
24 | LAST_K=3
25 | ETA=2
26 | NEVAL=1
27 |
28 | SEED=0
29 | OPT_TYPE=hyperband # or [random, bohb]
30 | PROP_TYPE=TSinf
31 |
32 | OPT_TYPE_PETS=CEM
33 | # CONFIG_NAMES=plan_hor\ num_cem_iters\ cem_popsize\ cem_elites_ratio\ cem_alpha
34 | CONFIG_NAMES=model_weight_decay\ model_learning_rate\ model_train_epoch
35 |
36 | cd ..
37 |
38 | if [ $SLURM_ARRAY_TASK_ID -eq 1 ]
39 | then
40 | python -u bo-mbexp.py -config_names $CONFIG_NAMES \
41 | -opt_type $OPT_TYPE \
42 | -run_id $RUN_ID \
43 | -env $ENV \
44 | -logdir $DIR \
45 | -worker_id $SLURM_ARRAY_TASK_ID \
46 | -seed $SEED \
47 | -interface $INTERFACE \
48 | -o exp_cfg.log_cfg.neval $NEVAL \
49 | -o exp_cfg.exp_cfg.ninit_rollouts $NINIT_ROLLOUTS \
50 | -o exp_cfg.bo_cfg.min_budget $MIN_BUDGET \
51 | -o exp_cfg.bo_cfg.max_budget $MAX_BUDGET \
52 | -o exp_cfg.bo_cfg.nopt_iter $NOPT_ITER \
53 | -o exp_cfg.bo_cfg.eta $ETA \
54 | -o exp_cfg.bo_cfg.last_k $LAST_K \
55 | -o exp_cfg.sim_cfg.task_hor $TASK_HORIZON \
56 | -ca prop-type $PROP_TYPE \
57 | -ca opt-type $OPT_TYPE_PETS
58 | else
59 | python -u bo-mbexp.py -config_names $CONFIG_NAMES \
60 | -opt_type $OPT_TYPE \
61 | -run_id $RUN_ID \
62 | -worker \
63 | -env $ENV \
64 | -logdir $DIR \
65 | -worker_id $SLURM_ARRAY_TASK_ID \
66 | -seed $SEED \
67 | -interface $INTERFACE \
68 | -o exp_cfg.log_cfg.neval $NEVAL \
69 | -o exp_cfg.exp_cfg.ninit_rollouts $NINIT_ROLLOUTS \
70 | -o exp_cfg.bo_cfg.min_budget $MIN_BUDGET \
71 | -o exp_cfg.bo_cfg.max_budget $MAX_BUDGET \
72 | -o exp_cfg.bo_cfg.nopt_iter $NOPT_ITER \
73 | -o exp_cfg.bo_cfg.eta $ETA \
74 | -o exp_cfg.bo_cfg.last_k $LAST_K \
75 | -o exp_cfg.sim_cfg.task_hor $TASK_HORIZON \
76 | -ca prop-type $PROP_TYPE \
77 | -ca opt-type $OPT_TYPE_PETS
78 | fi
79 |
80 |
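
Note on the budgets above: with MIN_BUDGET=79, MAX_BUDGET=80 and ETA=2, log_2(80/79) is about 0.018, which rounds down to zero successive-halving rungs; every configuration is then evaluated at (essentially) the full budget, so the run effectively behaves like random search over NOPT_ITER=5 iterations. A wider min/max ratio is what activates early stopping.
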
--------------------------------------------------------------------------------
/scripts/pbt-bt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Example command
3 |
4 | # activate your conda env
5 | source ~/.bashrc
6 | source activate mbrl
7 |
8 | SAMPLE_FROM_PERCENT=0.2
9 | RESAMPLE_IF_NOT_IN_PERCENT=0.8
10 | RESAMPLE_PROBABILITY=0.25
11 |
12 | MAX_STEPS=2400
13 | DELTA_T=6
14 | TOLERANCE=0.2
15 | ELITE_RATIO=0.2
16 |
17 | POPULATION_SIZE=5
18 | CRITERION_MODE=lastk
19 | TASK_HORIZON=10
20 | LAST_K=3
21 | INITIAL_STEP=6
22 | BUDGET=3
23 | STEP=5
24 |
25 | # PROP_TYPE=E
26 | PROP_TYPE=TSinf
27 |
28 | OPT_TYPE_PETS=CEM
29 | # TOTAL STEPS = (BUDGET - 1) * STEP + INITIAL_STEP
30 |
31 | #ENV=pusher
32 | ENV=halfcheetah_v3
33 | # ENV=hopper
34 | # ENV=cartpole
35 |
36 | # Replace with your directory
37 | DIR=log/$ENV\_$SLURM_ARRAY_JOB_ID
38 | WORKER_ID=$SLURM_ARRAY_JOB_ID
39 |
40 | # CONFIG_NAMES=plan_hor\ num_cem_iters\ cem_popsize\ cem_elites_ratio\ cem_alpha
41 | CONFIG_NAMES=model_weight_decay\ model_learning_rate\ model_train_epoch
42 | NEVAL=1
43 |
44 | cd ..
45 |
46 |
47 | if [ $SLURM_ARRAY_TASK_ID -eq 1 ]
48 | then
49 | python -u pbt-bt-mbexp.py -config_names $CONFIG_NAMES \
50 | -seed 0 \
51 | -env $ENV \
52 | -logdir $DIR \
53 | -worker_id $SLURM_ARRAY_TASK_ID \
54 | -sample_from_percent $SAMPLE_FROM_PERCENT \
55 | -resample_if_not_in_percent $RESAMPLE_IF_NOT_IN_PERCENT \
56 | -resample_probability $RESAMPLE_PROBABILITY \
57 | -max_steps $MAX_STEPS \
58 | -delta_t $DELTA_T \
59 | -tolerance $TOLERANCE \
60 | -elite_ratio $ELITE_RATIO \
61 | -o exp_cfg.log_cfg.neval $NEVAL \
62 | -o exp_cfg.pbt_cfg.pop_size $POPULATION_SIZE \
63 | -o exp_cfg.pbt_cfg.budget $BUDGET \
64 | -o exp_cfg.pbt_cfg.criterion_mode $CRITERION_MODE \
65 | -o exp_cfg.pbt_cfg.last_k $LAST_K \
66 | -o exp_cfg.pbt_cfg.initial_step $INITIAL_STEP \
67 | -o exp_cfg.pbt_cfg.step $STEP \
68 | -o exp_cfg.sim_cfg.task_hor $TASK_HORIZON \
69 | -ca prop-type $PROP_TYPE \
70 | -ca opt-type $OPT_TYPE_PETS
71 | else
72 | python -u pbt-bt-mbexp.py -config_names $CONFIG_NAMES \
73 | -seed 0 \
74 | -worker \
75 | -env $ENV \
76 | -logdir $DIR \
77 | -worker_id $WORKER_ID \
78 | -max_steps $MAX_STEPS \
79 | -delta_t $DELTA_T \
80 | -tolerance $TOLERANCE \
81 | -elite_ratio $ELITE_RATIO \
82 | -o exp_cfg.log_cfg.neval $NEVAL \
83 | -o exp_cfg.pbt_cfg.pop_size $POPULATION_SIZE \
84 | -o exp_cfg.pbt_cfg.budget $BUDGET \
85 | -o exp_cfg.pbt_cfg.criterion_mode $CRITERION_MODE \
86 | -o exp_cfg.pbt_cfg.last_k $LAST_K \
87 | -o exp_cfg.pbt_cfg.initial_step $INITIAL_STEP \
88 | -o exp_cfg.pbt_cfg.step $STEP \
89 | -o exp_cfg.sim_cfg.task_hor $TASK_HORIZON \
90 | -ca prop-type $PROP_TYPE \
91 | -ca opt-type $OPT_TYPE_PETS
92 | fi
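
Worked out for the values above: TOTAL STEPS = (BUDGET - 1) * STEP + INITIAL_STEP = (3 - 1) * 5 + 6 = 16 training iterations per population member.
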
--------------------------------------------------------------------------------
/scripts/pbt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # activate your conda env
4 | source ~/.bashrc
5 | source activate mbrl
6 |
7 | SAMPLE_FROM_PERCENT=0.2
8 | RESAMPLE_IF_NOT_IN_PERCENT=0.8
9 | RESAMPLE_PROBABILITY=0.25
10 |
11 |
12 | POPULATION_SIZE=40
13 | CRITERION_MODE=lastk
14 | TASK_HORIZON=10
15 | LAST_K=3
16 | INITIAL_STEP=6
17 | BUDGET=60
18 | STEP=5
19 |
20 | PROP_TYPE=E
21 | OPT_TYPE_PETS=CEM
22 | # TOTAL STEPS = (BUDGET - 1) * STEP + INITIAL_STEP
23 |
24 | ENV=halfcheetah_v3
25 |
26 | # Replace with your directory
27 | DIR=log/$ENV\_$SLURM_ARRAY_JOB_ID
28 | WORKER_ID=$SLURM_ARRAY_JOB_ID
29 |
30 | # CONFIG_NAMES=plan_hor\ num_cem_iters\ cem_popsize\ cem_elites_ratio\ cem_alpha
31 | CONFIG_NAMES=model_weight_decay\ model_learning_rate\ model_train_epoch
32 | NEVAL=1
33 |
34 | cd ..
35 |
36 | if [ $SLURM_ARRAY_TASK_ID -eq 1 ]
37 | then
38 | python -u pbt-mbexp.py -config_names $CONFIG_NAMES \
39 | -seed 0 \
40 | -env $ENV \
41 | -logdir $DIR \
42 | -worker_id $SLURM_ARRAY_TASK_ID \
43 | -sample_from_percent $SAMPLE_FROM_PERCENT \
44 | -resample_if_not_in_percent $RESAMPLE_IF_NOT_IN_PERCENT \
45 | -resample_probability $RESAMPLE_PROBABILITY \
46 | -o exp_cfg.log_cfg.neval $NEVAL \
47 | -o exp_cfg.pbt_cfg.pop_size $POPULATION_SIZE \
48 | -o exp_cfg.pbt_cfg.budget $BUDGET \
49 | -o exp_cfg.pbt_cfg.criterion_mode $CRITERION_MODE \
50 | -o exp_cfg.pbt_cfg.last_k $LAST_K \
51 | -o exp_cfg.pbt_cfg.initial_step $INITIAL_STEP \
52 | -o exp_cfg.pbt_cfg.step $STEP \
53 | -o exp_cfg.sim_cfg.task_hor $TASK_HORIZON \
54 | -ca prop-type $PROP_TYPE \
55 | -ca opt-type $OPT_TYPE_PETS
56 | else
57 | python -u pbt-mbexp.py -config_names $CONFIG_NAMES \
58 | -seed 0 \
59 | -worker \
60 | -env $ENV \
61 | -logdir $DIR \
62 | -worker_id $WORKER_ID \
63 | -o exp_cfg.log_cfg.neval $NEVAL \
64 | -o exp_cfg.pbt_cfg.pop_size $POPULATION_SIZE \
65 | -o exp_cfg.pbt_cfg.budget $BUDGET \
66 | -o exp_cfg.pbt_cfg.criterion_mode $CRITERION_MODE \
67 | -o exp_cfg.pbt_cfg.last_k $LAST_K \
68 | -o exp_cfg.pbt_cfg.initial_step $INITIAL_STEP \
69 | -o exp_cfg.pbt_cfg.step $STEP \
70 | -o exp_cfg.sim_cfg.task_hor $TASK_HORIZON \
71 | -ca prop-type $PROP_TYPE \
72 | -ca opt-type $OPT_TYPE_PETS
73 | fi
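
Here the same formula gives (60 - 1) * 5 + 6 = 301 training iterations per member.
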
--------------------------------------------------------------------------------