├── .gitignore ├── LICENCE ├── README.md ├── bo-mbexp.py ├── bohb ├── HPWorker.py └── __init__.py ├── config_space └── config_space.py ├── dmbrl ├── config │ ├── __init__.py │ ├── cartpole.py │ ├── default.py │ ├── halfcheetah_v3.py │ ├── hopper.py │ ├── pusher.py │ ├── reacher.py │ └── template.py ├── controllers │ ├── Controller.py │ ├── MPC.py │ └── __init__.py ├── env │ ├── __init__.py │ ├── assets │ │ ├── cartpole.xml │ │ ├── half_cheetah.xml │ │ ├── hopper.xml │ │ ├── pusher.xml │ │ └── reacher3d.xml │ ├── cartpole.py │ ├── half_cheetah_v3.py │ ├── hopper.py │ ├── pusher.py │ └── reacher.py ├── misc │ ├── Agent.py │ ├── DotmapUtils.py │ ├── MBExp.py │ ├── MBwBOExp.py │ ├── MBwPBTBTExp.py │ ├── MBwPBTExp.py │ ├── __init__.py │ ├── optimizers │ │ ├── __init__.py │ │ ├── cem.py │ │ ├── optimizer.py │ │ └── random.py │ └── render.py └── modeling │ ├── __init__.py │ ├── layers │ ├── FC.py │ └── __init__.py │ ├── models │ ├── BNN.py │ ├── NN.py │ ├── TFGP.py │ └── __init__.py │ └── utils │ ├── HPO.py │ ├── TensorStandardScaler.py │ └── __init__.py ├── environment.yml ├── eval_schedules.py ├── learned_schedule ├── daisy_hb_model_train.json ├── daisy_hb_planning.json ├── daisy_pbt_model_train_incumbent.json ├── daisy_pbt_model_train_top_avg.json ├── daisy_pbt_planning_incumbent.json ├── daisy_pbt_planning_top_avg.json ├── daisy_rs_model_train.json ├── daisy_rs_planning.json ├── halfcheetah_default.json ├── halfcheetah_hb_model_train.json ├── halfcheetah_hb_planning.json ├── halfcheetah_pbt_model_train_incumbent.json ├── halfcheetah_pbt_model_train_top_avg.json ├── halfcheetah_pbt_planning_incumbent.json ├── halfcheetah_pbt_planning_top_avg.json ├── halfcheetah_pbtbt_model_train.json ├── halfcheetah_pbtbt_planning.json ├── halfcheetah_rs_model_train.json ├── halfcheetah_rs_planning.json ├── hopper_hb_model_train.json ├── hopper_hb_planning.json ├── hopper_pbt_model_train_incumbent.json ├── hopper_pbt_model_train_top_avg.json ├── hopper_pbt_planning_incumbent.json ├── hopper_pbt_planning_top_avg.json ├── hopper_pbtbt_model_train.json ├── hopper_pbtbt_planning.json ├── hopper_rs_model_train.json ├── hopper_rs_planning.json ├── pusher_default.json ├── pusher_hb_model_train.json ├── pusher_hb_planning.json ├── quadruped_hb_model_train.json ├── quadruped_hb_planning.json ├── quadruped_pbt_model_train_incumbent.json ├── quadruped_pbt_model_train_top_avg.json ├── quadruped_pbt_planning_incumbent.json ├── quadruped_pbt_planning_top_avg.json ├── quadruped_rs_model_train.json ├── quadruped_rs_planning.json ├── reacher_hb_model_train.json ├── reacher_hb_planning.json ├── reacher_pbt_model_train_incumbent.json ├── reacher_pbt_model_train_top_avg.json ├── reacher_pbt_planning_incumbent.json ├── reacher_pbt_planning_top_avg.json ├── reacher_rs_model_train.json └── reacher_rs_planning.json ├── mbexp.py ├── pbt-bt-mbexp.py ├── pbt-mbexp.py ├── pbt ├── __init__.py ├── backtrack_controller.py ├── backtrack_scheduler.py ├── controller.py ├── exploitation │ ├── __init__.py │ ├── constant.py │ ├── exploitation_strategy.py │ └── truncation.py ├── exploration │ ├── __init__.py │ ├── constant.py │ ├── exploration_bt.py │ ├── exploration_strategy.py │ ├── model_based.py │ ├── models │ │ ├── __init__.py │ │ ├── config_tree │ │ │ ├── __init__.py │ │ │ ├── config_tree.py │ │ │ └── nodes │ │ │ │ ├── __init__.py │ │ │ │ ├── categorical.py │ │ │ │ ├── float.py │ │ │ │ ├── integer.py │ │ │ │ ├── log_float.py │ │ │ │ └── node.py │ │ ├── model.py │ │ └── tree_parzen_estimator.py │ ├── perturb.py │ ├── perturb_and_resample.py 
│ └── resample.py ├── garbage_collector.py ├── network │ ├── __init__.py │ ├── controller_adapter.py │ ├── controller_daemon.py │ ├── daemon.py │ ├── worker_adapter.py │ └── worker_daemon.py ├── population │ ├── __init__.py │ ├── member.py │ ├── population.py │ └── trial.py ├── scheduler.py ├── tqdm_logger.py └── worker.py ├── requirements.txt └── scripts ├── hyperband.sh ├── pbt-bt.sh └── pbt.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | tags 3 | *.pyc 4 | log/ 5 | *.out 6 | .ipynb_checkpoints/* 7 | .vscode/ 8 | *.mat 9 | *.pdf 10 | *.png 11 | visualization/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Author: 2 | Baohe Zhang 3 | 4 | # Code Instructions 5 | 6 | ### Overview: 7 | This code is mainly based on Kurtland Chua's implementation of the PETS paper. See [code](https://github.com/kchua/handful-of-trials). 8 | 9 | On top of that, several components were added to integrate Bayesian optimization and Population-Based Training. 10 | 11 | config_space (folder): 12 | **config_space.py** : Defines the hyperparameter configuration space for each environment. This is the key component. 13 | 14 | pbt (folder): 15 | This folder contains the code for Population-Based Training and Population-Based Training with back-tracking. The experiment objects can be found in pbt-mbexp.py and pbt-bt-mbexp.py. 16 | 17 | bohb (folder): 18 | This folder contains the Bayesian optimization method [BOHB](https://github.com/automl/HpBandSter/blob/master/hpbandster/optimizers/bohb.py). It also provides Hyperband and random search functionality. 19 | 20 | ### Environment Setup 21 | You can create an Anaconda environment from the provided environment.yml file, which contains all the libraries needed to run this code. 22 | 23 | ### Scripts for Running the Code 24 | To run the experiments with fixed hyperparameters, refer to the following scripts. The scripts folder contains examples for running on a Slurm cluster. 25 | 26 | Example commands to start the code: 27 | ``` 28 | # xxx refers to the cluster partition 29 | # Hyperband 30 | cd scripts 31 | sbatch -p xxx -a 1-10 hyperband.sh 32 | 33 | 34 | # PBT 35 | cd scripts 36 | sbatch -p xxx -a 1-10 pbt.sh 37 | 38 | # PBT-BT 39 | cd scripts 40 | sbatch -p xxx -a 1-10 pbt-bt.sh 41 | 42 | ``` 43 | 44 | ### Learned Schedules 45 | In the learned_schedule folder, we have uploaded a few hyperparameter schedules learned by running HPO methods on PETS. You should be able to reproduce the results reported in our [paper](https://arxiv.org/abs/2102.13651). The evaluation function is in the eval_schedules.py file. 46 | 47 | ### Disclaimer 48 | This code is not yet fully cleaned up and might contain issues. Feel free to contact me if anything is unclear. 49 | 50 | ### Known Issue 51 | When running PBT on a Slurm-based cluster, the controller occupies one GPU without really using it, since the controller does not train any model. One possible way to alleviate this temporarily is to start the controller first and then start the workers, but the user must then make sure that the workers know the directory which stores the server connection file.
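52 | 53 | ### Example: Running BOHB Locally (sketch) 54 | The scripts above target Slurm. As a minimal sketch of a local run of the BOHB optimizer, based on the flags defined in bo-mbexp.py (the master/worker split follows the usual HpBandSter pattern; the exact invocation below is an assumption, not a tested recipe): 55 | ``` 56 | # Start the optimizer/master process; -run_id ties master and workers together 57 | python bo-mbexp.py -env halfcheetah_v3 -opt_type bohb -run_id 1 -seed 0 -config_names model_learning_rate plan_hor 58 | # Start one or more workers with the same run_id (-interface defaults to lo) 59 | python bo-mbexp.py -env halfcheetah_v3 -opt_type bohb -run_id 1 -seed 0 -config_names model_learning_rate plan_hor -worker -worker_id 0 60 | ``` 61 | 62 | ### Extending the Search Space (sketch) 63 | New environments are added by extending DEFAULT_CONFIGSPACE in config_space/config_space.py with the same ConfigSpace hyperparameter types used for the existing environments. The entry below is purely illustrative ("my_env" and the bounds are hypothetical): 64 | ```python 65 | import ConfigSpace.hyperparameters as CSH 66 | 67 | from config_space.config_space import DEFAULT_CONFIGSPACE 68 | 69 | # Hypothetical entry: register a search space for a new environment. 70 | DEFAULT_CONFIGSPACE["my_env"] = { 71 |     "model_learning_rate": CSH.UniformFloatHyperparameter("model_learning_rate", lower=1e-5, upper=4e-2, default_value=1e-3, log=True), 72 |     "plan_hor": CSH.UniformIntegerHyperparameter("plan_hor", lower=5, upper=40, default_value=25, log=False), 73 | } 74 | ```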
-------------------------------------------------------------------------------- /bo-mbexp.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import os 6 | import argparse 7 | import pprint 8 | 9 | from dotmap import DotMap 10 | 11 | from dmbrl.misc.MBwBOExp import MBWithBOExperiment 12 | from dmbrl.controllers.MPC import MPC 13 | from dmbrl.config import create_config 14 | 15 | import tensorflow as tf 16 | import numpy as np 17 | import random 18 | 19 | def create_cfg_creator(env, ctrl_type, ctrl_args, base_overrides, logdir): 20 | ctrl_args = DotMap(**{key: val for (key, val) in ctrl_args}) 21 | 22 | def cfg_creator(additional_overrides=None): 23 | if additional_overrides is not None: 24 | return create_config(env, ctrl_type, ctrl_args, base_overrides + additional_overrides, logdir) 25 | return create_config(env, ctrl_type, ctrl_args, base_overrides, logdir) 26 | 27 | return cfg_creator 28 | 29 | 30 | def main(args): 31 | cfg_creator = create_cfg_creator(args.env, args.ctrl_type, args.ctrl_arg, args.override, args.logdir) 32 | cfg = cfg_creator()[0] 33 | cfg.pprint() 34 | 35 | if args.ctrl_type == "MPC": 36 | cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg) 37 | 38 | exp = MBWithBOExperiment(cfg.exp_cfg, cfg_creator, args) 39 | os.makedirs(exp.logdir, exist_ok=True) 40 | 41 | exp.run_experiment() 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('-env', type=str, required=True, 47 | help='Environment name: select from [cartpole, reacher, pusher, halfcheetah_v3, hopper]') 48 | parser.add_argument('-ca', '--ctrl_arg', action='append', nargs=2, default=[], 49 | help='Controller arguments, see https://github.com/kchua/handful-of-trials#controller-arguments') 50 | parser.add_argument('-o', '--override', action='append', nargs=2, default=[], 51 | help='Override default parameters, see https://github.com/kchua/handful-of-trials#overrides') 52 | parser.add_argument('-logdir', type=str, default='log', 53 | help='Directory to which results will be logged (default: ./log)') 54 | parser.add_argument('-ctrl_type', type=str, default='MPC', 55 | help='Control type to be applied (default: MPC)') 56 | # Parser for running BOHB on cluster 57 | parser.add_argument('-worker', action='store_true', 58 | help='Flag to turn this into a worker process') 59 | parser.add_argument('-interface', type=str, default='lo', 60 | help='Interface name to use for creating host') 61 | parser.add_argument('-run_id', type=str, default='111', 62 | help='A unique run id for this optimization run.
An easy option is to use' 63 | ' the job id of the cluster scheduler.') 64 | parser.add_argument('-worker_id', type=int, default=0, 65 | help='The ID of the worker') 66 | parser.add_argument('-config_names', type=str, default="model_learning_rate", nargs="+", 67 | help='Specify which hyperparameters to optimize') 68 | parser.add_argument('-opt_type', type=str, default="bohb", 69 | help='Specify which optimizer to use') 70 | parser.add_argument('-seed', type=int, default=0, 71 | help='Specify the random seed to use') 72 | args = parser.parse_args() 73 | tf.set_random_seed(args.seed) 74 | np.random.seed(args.seed) 75 | random.seed(args.seed) 76 | 77 | print(args) 78 | main(args) 79 | -------------------------------------------------------------------------------- /bohb/HPWorker.py: -------------------------------------------------------------------------------- 1 | import ConfigSpace as CS 2 | from hpbandster.core.worker import Worker 3 | import numpy as np 4 | from config_space.config_space import DEFAULT_CONFIGSPACE 5 | 6 | class HPWorker(Worker): 7 | """Worker for BOHB/Hyperband/Random Search 8 | We use the implementation based on https://github.com/automl/HpBandSter 9 | """ 10 | def __init__(self, 11 | train_func, 12 | env, 13 | config_names, 14 | last_k, 15 | **kwargs): 16 | """ 17 | Initialization 18 | """ 19 | super().__init__(**kwargs) 20 | self.train_func = train_func 21 | if len(config_names) == 0: 22 | raise ValueError("config_names cannot be an empty list") 23 | self.config_names = config_names 24 | self.env = env 25 | self.last_k = last_k 26 | 27 | @staticmethod 28 | def get_configspace(env, config_names): 29 | """ 30 | Construct the configuration space 31 | Param: 32 | env: (str) Name of the environment that we will use 33 | config_names: (list) Names of the hyperparameters to include in the search space 34 | Return: 35 | cs: (ConfigurationSpace) A configuration space which defines the search space 36 | """ 37 | cs = CS.ConfigurationSpace() 38 | config_lst = [] 39 | for config_name in config_names: 40 | config_lst.append(DEFAULT_CONFIGSPACE[env][config_name]) 41 | cs.add_hyperparameters(config_lst) 42 | return cs 43 | 44 | def compute(self, config_id, config, budget, **kwargs): 45 | """ 46 | Evaluates the configuration on the defined budget and returns the validation performance. 47 | Args: 48 | config_id: (int) configuration id 49 | config: dictionary containing the sampled configurations by the optimizer 50 | budget: (float) amount of time/epochs/etc.
the model can use to train 51 | Returns: 52 | dictionary with mandatory fields: 53 | 'loss' (scalar) 54 | 'info' (dict) 55 | """ 56 | train_func = self.train_func 57 | # Construct the configurations 58 | cfg = dict() 59 | for config_name in DEFAULT_CONFIGSPACE[self.env]: 60 | if config_name in self.config_names: 61 | cfg[config_name] = config[config_name] 62 | else: 63 | cfg[config_name] = DEFAULT_CONFIGSPACE[self.env][config_name].default_value 64 | 65 | # Run one configuration 66 | traj_rets, traj_eval_rets = train_func(config_id, cfg, budget) 67 | all_rets = np.concatenate([traj_rets, traj_eval_rets], axis=1) 68 | loss = - np.mean(all_rets, axis=1)[-self.last_k:].mean().item() 69 | return ({ 70 | 'loss': loss, 71 | 'info': {} 72 | }) 73 | 74 | -------------------------------------------------------------------------------- /bohb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/HPO_for_RL/d82c7ddd6fe19834c088137570530f11761d9390/bohb/__init__.py -------------------------------------------------------------------------------- /config_space/config_space.py: -------------------------------------------------------------------------------- 1 | import ConfigSpace as CS 2 | import ConfigSpace.hyperparameters as CSH 3 | 4 | """ 5 | This defines the search space of each environment. The default values are chosen based on Chua et al. 2018. 6 | See https://github.com/kchua/handful-of-trials for more information 7 | """ 8 | DEFAULT_CONFIGSPACE = { 9 | "cartpole": { 10 | # Model Architecture 11 | "num_hidden_layers" : CSH.UniformIntegerHyperparameter("num_hidden_layers", lower=2, upper=8, default_value=3, log=False), 12 | "hidden_layer_width" : CSH.UniformIntegerHyperparameter("hidden_layer_width", lower=100, upper=600, default_value=500, log=True), 13 | "act_idx" : CSH.CategoricalHyperparameter(name='activation', choices=['relu', 'tanh', 'sigmoid', 'swish'], default_value='swish'), 14 | # Model Optimizer 15 | "model_learning_rate" : CSH.UniformFloatHyperparameter('model_learning_rate', lower=1e-5, upper=4e-2, default_value=1e-3, log=True), 16 | "model_weight_decay" : CSH.UniformFloatHyperparameter('model_weight_decay', lower=1e-7, upper=1e-1, default_value=0.00025, log=True), 17 | "model_opt_idx" : CSH.CategoricalHyperparameter(name="model_opt_idx", choices=['adam', 'adadelta', 'adagrad', 'sgd', 'rms'], default_value='adam'), 18 | "model_train_epoch" : CSH.UniformIntegerHyperparameter("model_train_epoch", lower=3, upper=15, default_value=5, log=False), 19 | # Planner 20 | "num_cem_iters" : CSH.UniformIntegerHyperparameter("num_cem_iters", lower=4, upper=6, default_value=5, log=False), 21 | "cem_popsize" : CSH.UniformIntegerHyperparameter("cem_popsize", lower=200, upper=700, default_value=400, log=True), 22 | "cem_alpha" : CSH.UniformFloatHyperparameter("cem_alpha", lower=0.05, upper=0.2, default_value=0.1, log=False), 23 | "cem_elites_ratio" : CSH.UniformFloatHyperparameter("cem_elites_ratio", lower=0.04, upper=0.5, default_value=0.1, log=True), 24 | "plan_hor" : CSH.UniformIntegerHyperparameter("plan_hor", lower=5, upper=40, default_value=25, log=False), 25 | }, 26 | "pusher": { 27 | # Model Architecture 28 | "num_hidden_layers" : CSH.UniformIntegerHyperparameter("num_hidden_layers", lower=3, upper=8, default_value=3, log=False), 29 | "hidden_layer_width" : CSH.UniformIntegerHyperparameter("hidden_layer_width", lower=100, upper=600, default_value=200, log=True), 30 | "act_idx" :
CSH.CategoricalHyperparameter(name='activation', choices=['relu', 'tanh', 'sigmoid', 'swish'], default_value='swish'), 31 | # Model Optimizer 32 | "model_learning_rate" : CSH.UniformFloatHyperparameter('model_learning_rate', lower=3e-5, upper=3e-3, default_value=0.001, log=True), 33 | "model_weight_decay" : CSH.UniformFloatHyperparameter('model_weight_decay', lower=1e-7, upper=1e-1, default_value=5e-4, log=True), 34 | "model_opt_idx" : CSH.CategoricalHyperparameter(name="model_opt_idx", choices=['adam', 'adadelta', 'adagrad', 'sgd', 'rms'], default_value='adam'), 35 | "model_train_epoch" : CSH.UniformIntegerHyperparameter("model_train_epoch", lower=3, upper=20, default_value=5, log=False), 36 | # Planners 37 | "num_cem_iters" : CSH.UniformIntegerHyperparameter("num_cem_iters", lower=3, upper=10, default_value=5, log=False), 38 | "cem_popsize" : CSH.UniformIntegerHyperparameter("cem_popsize", lower=100, upper=700, default_value=500, log=True), 39 | "cem_alpha" : CSH.UniformFloatHyperparameter("cem_alpha", lower=0.01, upper=0.5, default_value=0.1, log=False), 40 | "cem_elites_ratio" : CSH.UniformFloatHyperparameter("cem_elites_ratio", lower=0.04, upper=0.5, default_value=0.1, log=True), 41 | "plan_hor" : CSH.UniformIntegerHyperparameter("plan_hor", lower=5, upper=40, default_value=25, log=False), 42 | }, 43 | "reacher": { 44 | # Model Architecture 45 | "num_hidden_layers" : CSH.UniformIntegerHyperparameter("num_hidden_layers", lower=3, upper=8, default_value=4, log=False), 46 | "hidden_layer_width" : CSH.UniformIntegerHyperparameter("hidden_layer_width", lower=100, upper=600, default_value=200, log=True), 47 | "act_idx" : CSH.CategoricalHyperparameter(name='activation', choices=['relu', 'tanh', 'sigmoid', 'swish'], default_value='swish'), 48 | # Model Optimizer 49 | "model_learning_rate" : CSH.UniformFloatHyperparameter('model_learning_rate', lower=1e-5, upper=4e-2, default_value=0.00075, log=True), 50 | "model_weight_decay" : CSH.UniformFloatHyperparameter('model_weight_decay', lower=1e-7, upper=1e-1, default_value=0.0005, log=True), 51 | "model_opt_idx" : CSH.CategoricalHyperparameter(name="model_opt_idx", choices=['adam', 'adadelta', 'adagrad', 'sgd', 'rms'], default_value='adam'), 52 | "model_train_epoch" : CSH.UniformIntegerHyperparameter("model_train_epoch", lower=3, upper=20, default_value=5, log=False), 53 | # Planner 54 | "num_cem_iters" : CSH.UniformIntegerHyperparameter("num_cem_iters", lower=4, upper=6, default_value=5, log=False), 55 | "cem_popsize" : CSH.UniformIntegerHyperparameter("cem_popsize", lower=200, upper=700, default_value=400, log=True), 56 | "cem_alpha" : CSH.UniformFloatHyperparameter("cem_alpha", lower=0.05, upper=0.4, default_value=0.1, log=False), 57 | "cem_elites_ratio" : CSH.UniformFloatHyperparameter("cem_elites_ratio", lower=0.04, upper=0.5, default_value=0.1, log=True), 58 | "plan_hor" : CSH.UniformIntegerHyperparameter("plan_hor", lower=5, upper=40, default_value=25, log=False), 59 | }, 60 | "halfcheetah_v3": { 61 | # Model Architecture 62 | "num_hidden_layers" : CSH.UniformIntegerHyperparameter("num_hidden_layers", lower=2, upper=8, default_value=4, log=False), 63 | "hidden_layer_width" : CSH.UniformIntegerHyperparameter("hidden_layer_width", lower=100, upper=600, default_value=200, log=False), 64 | "act_idx" : CSH.CategoricalHyperparameter(name='activation', choices=['relu', 'tanh', 'sigmoid', 'swish'], default_value='swish'), 65 | # Model Optimizer 66 | "model_learning_rate" : CSH.UniformFloatHyperparameter('model_learning_rate', lower=1e-5, upper=4e-2, 
default_value=0.001, log=True), 67 | "model_weight_decay" : CSH.UniformFloatHyperparameter('model_weight_decay', lower=1e-7, upper=1e-1, default_value=0.000075, log=True), 68 | "model_opt_idx" : CSH.CategoricalHyperparameter(name="model_opt_idx", choices=['adam', 'adadelta', 'adagrad', 'sgd', 'rms'], default_value='adam'), 69 | "model_train_epoch" : CSH.UniformIntegerHyperparameter("model_train_epoch", lower=3, upper=20, default_value=5, log=False), 70 | # Planner 71 | "num_cem_iters" : CSH.UniformIntegerHyperparameter("num_cem_iters", lower=3, upper=8, default_value=5, log=False), 72 | "cem_popsize" : CSH.UniformIntegerHyperparameter("cem_popsize", lower=200, upper=700, default_value=500, log=True), 73 | "cem_alpha" : CSH.UniformFloatHyperparameter("cem_alpha", lower=0.05, upper=0.2, default_value=0.1, log=False), 74 | "cem_elites_ratio" : CSH.UniformFloatHyperparameter("cem_elites_ratio", lower=0.04, upper=0.5, default_value=0.1, log=True), 75 | "plan_hor" : CSH.UniformIntegerHyperparameter("plan_hor", lower=5, upper=60, default_value=30, log=False), 76 | }, 77 | "hopper": { 78 | # Model Architecture 79 | "num_hidden_layers" : CSH.UniformIntegerHyperparameter("num_hidden_layers", lower=2, upper=8, default_value=4, log=False), 80 | "hidden_layer_width" : CSH.UniformIntegerHyperparameter("hidden_layer_width", lower=100, upper=600, default_value=200, log=False), 81 | "act_idx" : CSH.CategoricalHyperparameter(name='activation', choices=['relu', 'tanh', 'sigmoid', 'swish'], default_value='swish'), 82 | # Model Optimizer 83 | "model_learning_rate" : CSH.UniformFloatHyperparameter('model_learning_rate', lower=1e-5, upper=4e-2, default_value=0.001, log=True), 84 | "model_weight_decay" : CSH.UniformFloatHyperparameter('model_weight_decay', lower=1e-7, upper=1e-1, default_value=0.000075, log=True), 85 | "model_opt_idx" : CSH.CategoricalHyperparameter(name="model_opt_idx", choices=['adam', 'adadelta', 'adagrad', 'sgd', 'rms'], default_value='adam'), 86 | "model_train_epoch" : CSH.UniformIntegerHyperparameter("model_train_epoch", lower=3, upper=20, default_value=5, log=False), 87 | # Planner 88 | "num_cem_iters" : CSH.UniformIntegerHyperparameter("num_cem_iters", lower=3, upper=8, default_value=5, log=False), 89 | "cem_popsize" : CSH.UniformIntegerHyperparameter("cem_popsize", lower=200, upper=700, default_value=500, log=True), 90 | "cem_alpha" : CSH.UniformFloatHyperparameter("cem_alpha", lower=0.05, upper=0.2, default_value=0.1, log=False), 91 | "cem_elites_ratio" : CSH.UniformFloatHyperparameter("cem_elites_ratio", lower=0.04, upper=0.5, default_value=0.1, log=True), 92 | "plan_hor" : CSH.UniformIntegerHyperparameter("plan_hor", lower=5, upper=60, default_value=30, log=False), 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /dmbrl/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .default import create_config -------------------------------------------------------------------------------- /dmbrl/config/cartpole.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from dotmap import DotMap 8 | import gym 9 | 10 | from dmbrl.misc.DotmapUtils import get_required_argument 11 | from dmbrl.modeling.layers import FC 12 | import dmbrl.env 13 | 14 | 15 | class CartpoleConfigModule: 16 | 
ENV_NAME = "MBRLCartpole-v0" 17 | TASK_HORIZON = 200 18 | NTRAIN_ITERS = 50 19 | NROLLOUTS_PER_ITER = 1 20 | PLAN_HOR = 25 21 | MODEL_IN, MODEL_OUT = 6, 4 22 | GP_NINDUCING_POINTS = 200 23 | 24 | def __init__(self): 25 | self.ENV = gym.make(self.ENV_NAME) 26 | cfg = tf.ConfigProto() 27 | cfg.gpu_options.allow_growth = True 28 | self.SESS = tf.Session(config=cfg) 29 | self.NN_TRAIN_CFG = {"epochs": 5} 30 | self.OPT_CFG = { 31 | "Random": { 32 | "popsize": 2000 33 | }, 34 | "CEM": { 35 | "popsize": 400, 36 | "num_elites": 40, 37 | "max_iters": 5, 38 | "alpha": 0.1 39 | } 40 | } 41 | 42 | @staticmethod 43 | def obs_preproc(obs): 44 | if isinstance(obs, np.ndarray): 45 | return np.concatenate([np.sin(obs[:, 1:2]), np.cos(obs[:, 1:2]), obs[:, :1], obs[:, 2:]], axis=1) 46 | else: 47 | return tf.concat([tf.sin(obs[:, 1:2]), tf.cos(obs[:, 1:2]), obs[:, :1], obs[:, 2:]], axis=1) 48 | 49 | @staticmethod 50 | def obs_postproc(obs, pred): 51 | return obs + pred 52 | 53 | @staticmethod 54 | def targ_proc(obs, next_obs): 55 | return next_obs - obs 56 | 57 | @staticmethod 58 | def obs_cost_fn(obs): 59 | if isinstance(obs, np.ndarray): 60 | return -np.exp(-np.sum( 61 | np.square(CartpoleConfigModule._get_ee_pos(obs, are_tensors=False) - np.array([0.0, 0.6])), axis=1 62 | ) / (0.6 ** 2)) 63 | else: 64 | return -tf.exp(-tf.reduce_sum( 65 | tf.square(CartpoleConfigModule._get_ee_pos(obs, are_tensors=True) - np.array([0.0, 0.6])), axis=1 66 | ) / (0.6 ** 2)) 67 | 68 | @staticmethod 69 | def ac_cost_fn(acs): 70 | if isinstance(acs, np.ndarray): 71 | return 0.01 * np.sum(np.square(acs), axis=1) 72 | else: 73 | return 0.01 * tf.reduce_sum(tf.square(acs), axis=1) 74 | 75 | def nn_constructor(self, model_init_cfg): 76 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap( 77 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"), 78 | sess=self.SESS, load_model=model_init_cfg.get("load_model", False), 79 | model_dir=model_init_cfg.get("model_dir", None) 80 | )) 81 | if not model_init_cfg.get("load_model", False): 82 | model.add(FC(500, input_dim=self.MODEL_IN, activation='swish', weight_decay=0.0001)) 83 | model.add(FC(500, activation='swish', weight_decay=0.00025)) 84 | model.add(FC(500, activation='swish', weight_decay=0.00025)) 85 | model.add(FC(self.MODEL_OUT, weight_decay=0.0005)) 86 | model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.001}) 87 | return model 88 | 89 | def gp_constructor(self, model_init_cfg): 90 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap( 91 | name="model", 92 | kernel_class=get_required_argument(model_init_cfg, "kernel_class", "Must provide kernel class"), 93 | kernel_args=model_init_cfg.get("kernel_args", {}), 94 | num_inducing_points=get_required_argument( 95 | model_init_cfg, "num_inducing_points", "Must provide number of inducing points." 
96 | ), 97 | sess=self.SESS 98 | )) 99 | return model 100 | 101 | @staticmethod 102 | def _get_ee_pos(obs, are_tensors=False): 103 | x0, theta = obs[:, :1], obs[:, 1:2] 104 | if are_tensors: 105 | return tf.concat([x0 - 0.6 * tf.sin(theta), -0.6 * tf.cos(theta)], axis=1) 106 | else: 107 | return np.concatenate([x0 - 0.6 * np.sin(theta), -0.6 * np.cos(theta)], axis=1) 108 | 109 | 110 | CONFIG_MODULE = CartpoleConfigModule 111 | -------------------------------------------------------------------------------- /dmbrl/config/halfcheetah_v3.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from dotmap import DotMap 8 | import gym 9 | 10 | from dmbrl.misc.DotmapUtils import get_required_argument 11 | from dmbrl.modeling.layers import FC 12 | import dmbrl.env 13 | 14 | 15 | class HalfCheetahConfigModule: 16 | ENV_NAME = "MBRLHalfCheetah-v3" 17 | TASK_HORIZON = 1000 18 | NTRAIN_ITERS = 300 19 | NROLLOUTS_PER_ITER = 1 20 | PLAN_HOR = 15 # Modified from 30 -> 15 21 | MODEL_IN, MODEL_OUT = 24, 18 22 | GP_NINDUCING_POINTS = 300 23 | 24 | def __init__(self): 25 | self.ENV = gym.make(self.ENV_NAME) 26 | cfg = tf.ConfigProto() 27 | cfg.gpu_options.allow_growth = True 28 | self.SESS = tf.Session(config=cfg) 29 | self.NN_TRAIN_CFG = {"epochs": 5} 30 | self.OPT_CFG = { 31 | "Random": { 32 | "popsize": 2500 33 | }, 34 | "CEM": { 35 | "popsize": 500, 36 | "num_elites": 50, 37 | "max_iters": 5, 38 | "alpha": 0.1 39 | } 40 | } 41 | 42 | @staticmethod 43 | def obs_preproc(obs): 44 | if isinstance(obs, np.ndarray): 45 | return np.concatenate([obs[:, 1:2], np.sin(obs[:, 2:3]), np.cos(obs[:, 2:3]), obs[:, 3:]], axis=1) 46 | else: 47 | return tf.concat([obs[:, 1:2], tf.sin(obs[:, 2:3]), tf.cos(obs[:, 2:3]), obs[:, 3:]], axis=1) 48 | 49 | @staticmethod 50 | def obs_postproc(obs, pred): 51 | if isinstance(obs, np.ndarray): 52 | return np.concatenate([pred[:, :1], obs[:, 1:] + pred[:, 1:]], axis=1) 53 | else: 54 | return tf.concat([pred[:, :1], obs[:, 1:] + pred[:, 1:]], axis=1) 55 | 56 | @staticmethod 57 | def targ_proc(obs, next_obs): 58 | return np.concatenate([next_obs[:, :1], next_obs[:, 1:] - obs[:, 1:]], axis=1) 59 | 60 | @staticmethod 61 | def obs_cost_fn(obs): 62 | return -obs[:, 0] 63 | 64 | @staticmethod 65 | def ac_cost_fn(acs): 66 | if isinstance(acs, np.ndarray): 67 | return 0.1 * np.sum(np.square(acs), axis=1) 68 | else: 69 | return 0.1 * tf.reduce_sum(tf.square(acs), axis=1) 70 | 71 | def nn_constructor(self, model_init_cfg): 72 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap( 73 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"), 74 | sess=self.SESS, load_model=model_init_cfg.get("load_model", False), 75 | model_dir=model_init_cfg.get("model_dir", None) 76 | )) 77 | if not model_init_cfg.get("load_model", False): 78 | model.add(FC(200, input_dim=self.MODEL_IN, activation="swish", weight_decay=0.000025)) 79 | model.add(FC(200, activation="swish", weight_decay=0.00005)) 80 | model.add(FC(200, activation="swish", weight_decay=0.000075)) 81 | model.add(FC(200, activation="swish", weight_decay=0.000075)) 82 | model.add(FC(self.MODEL_OUT, weight_decay=0.0001)) 83 | model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.001}) 84 | return model 85 | 86 | def
gp_constructor(self, model_init_cfg): 87 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap( 88 | name="model", 89 | kernel_class=get_required_argument(model_init_cfg, "kernel_class", "Must provide kernel class"), 90 | kernel_args=model_init_cfg.get("kernel_args", {}), 91 | num_inducing_points=get_required_argument( 92 | model_init_cfg, "num_inducing_points", "Must provide number of inducing points." 93 | ), 94 | sess=self.SESS 95 | )) 96 | return model 97 | 98 | 99 | CONFIG_MODULE = HalfCheetahConfigModule 100 | -------------------------------------------------------------------------------- /dmbrl/config/hopper.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from dotmap import DotMap 8 | import gym 9 | 10 | from dmbrl.misc.DotmapUtils import get_required_argument 11 | from dmbrl.modeling.layers import FC 12 | import dmbrl.env 13 | 14 | 15 | class HopperConfigModule: 16 | ENV_NAME = "MBRLHopper-v3" 17 | TASK_HORIZON = 1000 18 | NTRAIN_ITERS = 300 19 | NROLLOUTS_PER_ITER = 1 20 | PLAN_HOR = 15 # Modified from 30 -> 15 21 | MODEL_IN, MODEL_OUT = 14, 12 22 | GP_NINDUCING_POINTS = 300 23 | 24 | def __init__(self): 25 | self.ENV = gym.make(self.ENV_NAME) 26 | cfg = tf.ConfigProto() 27 | cfg.gpu_options.allow_growth = True 28 | self.SESS = tf.Session(config=cfg) 29 | self.NN_TRAIN_CFG = {"epochs": 5} 30 | self.OPT_CFG = { 31 | "Random": { 32 | "popsize": 2500 33 | }, 34 | "CEM": { 35 | "popsize": 500, 36 | "num_elites": 50, 37 | "max_iters": 5, 38 | "alpha": 0.1 39 | } 40 | } 41 | 42 | @staticmethod 43 | def obs_preproc(obs): 44 | return obs[:, 1:] 45 | 46 | @staticmethod 47 | def obs_postproc(obs, pred): 48 | if isinstance(obs, np.ndarray): 49 | return np.concatenate([pred[:, :1], obs[:, 1:] + pred[:, 1:]], axis=1) 50 | else: 51 | return tf.concat([pred[:, :1], obs[:, 1:] + pred[:, 1:]], axis=1) 52 | 53 | @staticmethod 54 | def targ_proc(obs, next_obs): 55 | return np.concatenate([next_obs[:, :1], next_obs[:, 1:] - obs[:, 1:]], axis=1) 56 | 57 | @staticmethod 58 | def obs_cost_fn(obs): 59 | obs_cost = - obs[:, 0] 60 | return obs_cost 61 | # height, ang = obs[1:3] 62 | # if isinstance(obs, np.ndarray): 63 | # alive = (np.isfinite(obs).all() and (np.abs(obs[2:]) < 100).all() and 64 | # (height > .7) and (np.abs(ang) < .2)) 65 | # else: 66 | # alive = (tf.reduce_all(tf.math.is_finite(obs)) and (tf.reduce_all(tf.math.abs(obs[2:]) < 100)) and 67 | # (height > .7) and (tf.math.abs(ang) < .2)) 68 | # alive_bonus = 1.0 if alive else 0.0 69 | 70 | # return obs_cost - alive_bonus 71 | 72 | @staticmethod 73 | def ac_cost_fn(acs): 74 | if isinstance(acs, np.ndarray): 75 | return 1e-3 * np.sum(np.square(acs), axis=1) 76 | else: 77 | return 1e-3 * tf.reduce_sum(tf.square(acs), axis=1) 78 | 79 | def nn_constructor(self, model_init_cfg): 80 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap( 81 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"), 82 | sess=self.SESS, load_model=model_init_cfg.get("load_model", False), 83 | model_dir=model_init_cfg.get("model_dir", None) 84 | )) 85 | if not model_init_cfg.get("load_model", False): 86 | model.add(FC(200, input_dim=self.MODEL_IN, activation="swish", weight_decay=0.000025)) 87
| model.add(FC(200, activation="swish", weight_decay=0.00005)) 88 | model.add(FC(200, activation="swish", weight_decay=0.000075)) 89 | model.add(FC(200, activation="swish", weight_decay=0.000075)) 90 | model.add(FC(self.MODEL_OUT, weight_decay=0.0001)) 91 | model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.001}) 92 | return model 93 | 94 | def gp_constructor(self, model_init_cfg): 95 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap( 96 | name="model", 97 | kernel_class=get_required_argument(model_init_cfg, "kernel_class", "Must provide kernel class"), 98 | kernel_args=model_init_cfg.get("kernel_args", {}), 99 | num_inducing_points=get_required_argument( 100 | model_init_cfg, "num_inducing_points", "Must provide number of inducing points." 101 | ), 102 | sess=self.SESS 103 | )) 104 | return model 105 | 106 | 107 | CONFIG_MODULE = HopperConfigModule 108 | -------------------------------------------------------------------------------- /dmbrl/config/pusher.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from dotmap import DotMap 8 | import gym 9 | 10 | from dmbrl.misc.DotmapUtils import get_required_argument 11 | from dmbrl.modeling.layers import FC 12 | import dmbrl.env 13 | 14 | 15 | class PusherConfigModule: 16 | ENV_NAME = "MBRLPusher-v0" 17 | TASK_HORIZON = 150 18 | NTRAIN_ITERS = 100 19 | NROLLOUTS_PER_ITER = 1 20 | PLAN_HOR = 25 21 | MODEL_IN, MODEL_OUT = 27, 20 22 | GP_NINDUCING_POINTS = 200 23 | 24 | def __init__(self): 25 | self.ENV = gym.make(self.ENV_NAME) 26 | cfg = tf.ConfigProto() 27 | cfg.gpu_options.allow_growth = True 28 | self.SESS = tf.Session(config=cfg) 29 | self.NN_TRAIN_CFG = {"epochs": 5} 30 | self.OPT_CFG = { 31 | "Random": { 32 | "popsize": 2500 33 | }, 34 | "CEM": { 35 | "popsize": 500, 36 | "num_elites": 50, 37 | "max_iters": 5, 38 | "alpha": 0.1 39 | } 40 | } 41 | 42 | @staticmethod 43 | def obs_postproc(obs, pred): 44 | return obs + pred 45 | 46 | @staticmethod 47 | def targ_proc(obs, next_obs): 48 | return next_obs - obs 49 | 50 | def obs_cost_fn(self, obs): 51 | to_w, og_w = 0.5, 1.25 52 | tip_pos, obj_pos, goal_pos = obs[:, 14:17], obs[:, 17:20], self.ENV.ac_goal_pos 53 | 54 | if isinstance(obs, np.ndarray): 55 | tip_obj_dist = np.sum(np.abs(tip_pos - obj_pos), axis=1) 56 | obj_goal_dist = np.sum(np.abs(goal_pos - obj_pos), axis=1) 57 | return to_w * tip_obj_dist + og_w * obj_goal_dist 58 | else: 59 | tip_obj_dist = tf.reduce_sum(tf.abs(tip_pos - obj_pos), axis=1) 60 | obj_goal_dist = tf.reduce_sum(tf.abs(goal_pos - obj_pos), axis=1) 61 | return to_w * tip_obj_dist + og_w * obj_goal_dist 62 | 63 | @staticmethod 64 | def ac_cost_fn(acs): 65 | if isinstance(acs, np.ndarray): 66 | return 0.1 * np.sum(np.square(acs), axis=1) 67 | else: 68 | return 0.1 * tf.reduce_sum(tf.square(acs), axis=1) 69 | 70 | def nn_constructor(self, model_init_cfg): 71 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap( 72 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"), 73 | sess=self.SESS, load_model=model_init_cfg.get("load_model", False), 74 | model_dir=model_init_cfg.get("model_dir", None) 75 | )) 76 | if not model_init_cfg.get("load_model", False): 77 | model.add(FC(200,
input_dim=self.MODEL_IN, activation="swish", weight_decay=0.00025)) 78 | model.add(FC(200, activation="swish", weight_decay=0.0005)) 79 | model.add(FC(200, activation="swish", weight_decay=0.0005)) 80 | model.add(FC(self.MODEL_OUT, weight_decay=0.00075)) 81 | model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.001}) 82 | return model 83 | 84 | def gp_constructor(self, model_init_cfg): 85 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap( 86 | name="model", 87 | kernel_class=get_required_argument(model_init_cfg, "kernel_class", "Must provide kernel class"), 88 | kernel_args=model_init_cfg.get("kernel_args", {}), 89 | num_inducing_points=get_required_argument( 90 | model_init_cfg, "num_inducing_points", "Must provide number of inducing points." 91 | ), 92 | sess=self.SESS 93 | )) 94 | return model 95 | 96 | def value_net_constructor(self, model_init_cfg): 97 | model = NN(DotMap( 98 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"), 99 | sess=self.SESS, load_model=model_init_cfg.get("load_model", False), 100 | model_dir=model_init_cfg.get("model_dir", None) 101 | )) 102 | if not model_init_cfg.get("load_model", False): 103 | model.add(FC(200, input_dim=self.MODEL_IN, activation="swish", weight_decay=0.00025)) 104 | model.add(FC(200, activation="swish", weight_decay=0.0005)) 105 | model.add(FC(200, activation="swish", weight_decay=0.0005)) 106 | model.add(FC(self.MODEL_OUT, weight_decay=0.00075)) 107 | model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.001}) 108 | return model 109 | 110 | 111 | CONFIG_MODULE = PusherConfigModule 112 | -------------------------------------------------------------------------------- /dmbrl/config/reacher.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from dotmap import DotMap 8 | import gym 9 | 10 | from dmbrl.misc.DotmapUtils import get_required_argument 11 | from dmbrl.modeling.layers import FC 12 | import dmbrl.env 13 | 14 | 15 | class ReacherConfigModule: 16 | ENV_NAME = "MBRLReacher3D-v0" 17 | TASK_HORIZON = 150 18 | NTRAIN_ITERS = 100 19 | NROLLOUTS_PER_ITER = 1 20 | PLAN_HOR = 25 21 | MODEL_IN, MODEL_OUT = 24, 17 22 | GP_NINDUCING_POINTS = 200 23 | 24 | def __init__(self): 25 | self.ENV = gym.make(self.ENV_NAME) 26 | self.ENV.reset() 27 | cfg = tf.ConfigProto() 28 | cfg.gpu_options.allow_growth = True 29 | self.SESS = tf.Session(config=cfg) 30 | self.NN_TRAIN_CFG = {"epochs": 5} 31 | self.OPT_CFG = { 32 | "Random": { 33 | "popsize": 2000 34 | }, 35 | "CEM": { 36 | "popsize": 400, 37 | "num_elites": 40, 38 | "max_iters": 5, 39 | "alpha": 0.1 40 | } 41 | } 42 | self.UPDATE_FNS = [self.update_goal] 43 | 44 | self.goal = tf.Variable(self.ENV.goal, dtype=tf.float32) 45 | self.SESS.run(self.goal.initializer) 46 | 47 | @staticmethod 48 | def obs_postproc(obs, pred): 49 | return obs + pred 50 | 51 | @staticmethod 52 | def targ_proc(obs, next_obs): 53 | return next_obs - obs 54 | 55 | def update_goal(self, sess=None): 56 | if sess is not None: 57 | self.goal.load(self.ENV.goal, sess) 58 | 59 | def obs_cost_fn(self, obs): 60 | if isinstance(obs, np.ndarray): 61 | return np.sum(np.square(ReacherConfigModule.get_ee_pos(obs, are_tensors=False) - self.ENV.goal), axis=1) 62 | else: 63 | return 
tf.reduce_sum(tf.square(ReacherConfigModule.get_ee_pos(obs, are_tensors=True) - self.goal), axis=1) 64 | 65 | @staticmethod 66 | def ac_cost_fn(acs): 67 | if isinstance(acs, np.ndarray): 68 | return 0.01 * np.sum(np.square(acs), axis=1) 69 | else: 70 | return 0.01 * tf.reduce_sum(tf.square(acs), axis=1) 71 | 72 | def nn_constructor(self, model_init_cfg): 73 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap( 74 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"), 75 | sess=self.SESS, load_model=model_init_cfg.get("load_model", False), 76 | model_dir=model_init_cfg.get("model_dir", None) 77 | )) 78 | if not model_init_cfg.get("load_model", False): 79 | model.add(FC(200, input_dim=self.MODEL_IN, activation="swish", weight_decay=0.00025)) 80 | model.add(FC(200, activation="swish", weight_decay=0.0005)) 81 | model.add(FC(200, activation="swish", weight_decay=0.0005)) 82 | model.add(FC(200, activation="swish", weight_decay=0.0005)) 83 | model.add(FC(self.MODEL_OUT, weight_decay=0.00075)) 84 | model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.00075}) 85 | return model 86 | 87 | def gp_constructor(self, model_init_cfg): 88 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap( 89 | name="model", 90 | kernel_class=get_required_argument(model_init_cfg, "kernel_class", "Must provide kernel class"), 91 | kernel_args=model_init_cfg.get("kernel_args", {}), 92 | num_inducing_points=get_required_argument( 93 | model_init_cfg, "num_inducing_points", "Must provide number of inducing points." 94 | ), 95 | sess=self.SESS 96 | )) 97 | return model 98 | 99 | @staticmethod 100 | def get_ee_pos(states, are_tensors=False): 101 | theta1, theta2, theta3, theta4, theta5, theta6, theta7 = \ 102 | states[:, :1], states[:, 1:2], states[:, 2:3], states[:, 3:4], states[:, 4:5], states[:, 5:6], states[:, 6:] 103 | if are_tensors: 104 | rot_axis = tf.concat([tf.cos(theta2) * tf.cos(theta1), tf.cos(theta2) * tf.sin(theta1), -tf.sin(theta2)], 105 | axis=1) 106 | rot_perp_axis = tf.concat([-tf.sin(theta1), tf.cos(theta1), tf.zeros(tf.shape(theta1))], axis=1) 107 | cur_end = tf.concat([ 108 | 0.1 * tf.cos(theta1) + 0.4 * tf.cos(theta1) * tf.cos(theta2), 109 | 0.1 * tf.sin(theta1) + 0.4 * tf.sin(theta1) * tf.cos(theta2) - 0.188, 110 | -0.4 * tf.sin(theta2) 111 | ], axis=1) 112 | 113 | for length, hinge, roll in [(0.321, theta4, theta3), (0.16828, theta6, theta5)]: 114 | perp_all_axis = tf.cross(rot_axis, rot_perp_axis) 115 | x = tf.cos(hinge) * rot_axis 116 | y = tf.sin(hinge) * tf.sin(roll) * rot_perp_axis 117 | z = -tf.sin(hinge) * tf.cos(roll) * perp_all_axis 118 | new_rot_axis = x + y + z 119 | new_rot_perp_axis = tf.cross(new_rot_axis, rot_axis) 120 | new_rot_perp_axis = tf.where(tf.less(tf.norm(new_rot_perp_axis, axis=1), 1e-30), 121 | rot_perp_axis, new_rot_perp_axis) 122 | new_rot_perp_axis /= tf.norm(new_rot_perp_axis, axis=1, keepdims=True) 123 | rot_axis, rot_perp_axis, cur_end = new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis 124 | else: 125 | rot_axis = np.concatenate([np.cos(theta2) * np.cos(theta1), np.cos(theta2) * np.sin(theta1), -np.sin(theta2)], 126 | axis=1) 127 | rot_perp_axis = np.concatenate([-np.sin(theta1), np.cos(theta1), np.zeros(theta1.shape)], axis=1) 128 | cur_end = np.concatenate([ 129 | 0.1 * np.cos(theta1) + 0.4 * np.cos(theta1) * np.cos(theta2), 130 | 0.1 * np.sin(theta1) + 0.4 * np.sin(theta1) * np.cos(theta2) - 0.188, 131 
| -0.4 * np.sin(theta2) 132 | ], axis=1) 133 | 134 | for length, hinge, roll in [(0.321, theta4, theta3), (0.16828, theta6, theta5)]: 135 | perp_all_axis = np.cross(rot_axis, rot_perp_axis) 136 | x = np.cos(hinge) * rot_axis 137 | y = np.sin(hinge) * np.sin(roll) * rot_perp_axis 138 | z = -np.sin(hinge) * np.cos(roll) * perp_all_axis 139 | new_rot_axis = x + y + z 140 | new_rot_perp_axis = np.cross(new_rot_axis, rot_axis) 141 | new_rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] = \ 142 | rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] 143 | new_rot_perp_axis /= np.linalg.norm(new_rot_perp_axis, axis=1, keepdims=True) 144 | rot_axis, rot_perp_axis, cur_end = new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis 145 | 146 | return cur_end 147 | 148 | 149 | CONFIG_MODULE = ReacherConfigModule 150 | -------------------------------------------------------------------------------- /dmbrl/config/template.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from dotmap import DotMap 8 | import gym 9 | 10 | from dmbrl.misc.DotmapUtils import get_required_argument 11 | from dmbrl.modeling.layers import FC 12 | 13 | 14 | class EnvConfigModule: 15 | ENV_NAME = None 16 | TASK_HORIZON = None 17 | NTRAIN_ITERS = None 18 | NROLLOUTS_PER_ITER = None 19 | PLAN_HOR = None 20 | 21 | def __init__(self): 22 | self.ENV = gym.make(self.ENV_NAME) 23 | cfg = tf.ConfigProto() 24 | cfg.gpu_options.allow_growth = True 25 | self.SESS = tf.Session(config=cfg) 26 | self.NN_TRAIN_CFG = {"epochs": None} 27 | self.OPT_CFG = { 28 | "Random": { 29 | "popsize": None 30 | }, 31 | "CEM": { 32 | "popsize": None, 33 | "num_elites": None, 34 | "max_iters": None, 35 | "alpha": None 36 | } 37 | } 38 | self.UPDATE_FNS = [] 39 | 40 | # Fill in other things to be done here. 41 | 42 | @staticmethod 43 | def obs_preproc(obs): 44 | # Note: Must be able to process both NumPy and Tensorflow arrays. 45 | if isinstance(obs, np.ndarray): 46 | raise NotImplementedError() 47 | else: 48 | raise NotImplementedError 49 | 50 | @staticmethod 51 | def obs_postproc(obs, pred): 52 | # Note: Must be able to process both NumPy and Tensorflow arrays. 53 | if isinstance(obs, np.ndarray): 54 | raise NotImplementedError() 55 | else: 56 | raise NotImplementedError() 57 | 58 | @staticmethod 59 | def targ_proc(obs, next_obs): 60 | # Note: Only needs to process NumPy arrays. 61 | raise NotImplementedError() 62 | 63 | @staticmethod 64 | def obs_cost_fn(obs): 65 | # Note: Must be able to process both NumPy and Tensorflow arrays. 66 | if isinstance(obs, np.ndarray): 67 | raise NotImplementedError() 68 | else: 69 | raise NotImplementedError() 70 | 71 | @staticmethod 72 | def ac_cost_fn(acs): 73 | # Note: Must be able to process both NumPy and Tensorflow arrays. 74 | if isinstance(acs, np.ndarray): 75 | raise NotImplementedError() 76 | else: 77 | raise NotImplementedError() 78 | 79 | def nn_constructor(self, model_init_cfg): 80 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap( 81 | name="model", num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"), 82 | sess=self.SESS 83 | )) 84 | # Construct model below. For example: 85 | # model.add(FC(*args)) 86 | # ... 
87 | # model.finalize(tf.train.AdamOptimizer, {"learning_rate": 0.001}) 88 | return model 89 | 90 | 91 | CONFIG_MODULE = EnvConfigModule 92 | 93 | -------------------------------------------------------------------------------- /dmbrl/controllers/Controller.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | 6 | class Controller: 7 | def __init__(self, *args, **kwargs): 8 | """Creates class instance. 9 | """ 10 | pass 11 | 12 | def train(self, obs_trajs, acs_trajs, rews_trajs): 13 | """Trains this controller using lists of trajectories. 14 | """ 15 | raise NotImplementedError("Must be implemented in subclass.") 16 | 17 | def reset(self): 18 | """Resets this controller. 19 | """ 20 | raise NotImplementedError("Must be implemented in subclass.") 21 | 22 | def act(self, obs, t, get_pred_cost=False): 23 | """Performs an action. 24 | """ 25 | raise NotImplementedError("Must be implemented in subclass.") 26 | 27 | def dump_logs(self, primary_logdir, iter_logdir): 28 | """Dumps logs into primary log directory and per-train iteration log directory. 29 | """ 30 | raise NotImplementedError("Must be implemented in subclass.") 31 | -------------------------------------------------------------------------------- /dmbrl/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | from .MPC import MPC 2 | -------------------------------------------------------------------------------- /dmbrl/env/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | 4 | register( 5 | id='MBRLCartpole-v0', 6 | entry_point='dmbrl.env.cartpole:CartpoleEnv' 7 | ) 8 | 9 | 10 | register( 11 | id='MBRLReacher3D-v0', 12 | entry_point='dmbrl.env.reacher:Reacher3DEnv' 13 | ) 14 | 15 | 16 | register( 17 | id='MBRLPusher-v0', 18 | entry_point='dmbrl.env.pusher:PusherEnv' 19 | ) 20 | 21 | register( 22 | id='MBRLHalfCheetah-v3', 23 | entry_point='dmbrl.env.half_cheetah_v3:HalfCheetahEnv' 24 | ) 25 | 26 | register( 27 | id='MBRLHopper-v3', 28 | entry_point='dmbrl.env.hopper:HopperEnv' 29 | ) 30 | -------------------------------------------------------------------------------- /dmbrl/env/assets/cartpole.xml: -------------------------------------------------------------------------------- (MuJoCo model XML; markup not preserved in this export) -------------------------------------------------------------------------------- /dmbrl/env/assets/half_cheetah.xml: -------------------------------------------------------------------------------- (MuJoCo model XML; markup not preserved in this export) -------------------------------------------------------------------------------- /dmbrl/env/assets/hopper.xml: -------------------------------------------------------------------------------- (MuJoCo model XML; markup not preserved in this export) -------------------------------------------------------------------------------- /dmbrl/env/assets/pusher.xml: -------------------------------------------------------------------------------- (MuJoCo model XML; markup not preserved in this export) -------------------------------------------------------------------------------- /dmbrl/env/cartpole.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import
absolute_import 4 | 5 | import os 6 | 7 | import numpy as np 8 | from gym import utils 9 | from gym.envs.mujoco import mujoco_env 10 | 11 | 12 | class CartpoleEnv(mujoco_env.MujocoEnv, utils.EzPickle): 13 | PENDULUM_LENGTH = 0.6 14 | 15 | def __init__(self): 16 | utils.EzPickle.__init__(self) 17 | dir_path = os.path.dirname(os.path.realpath(__file__)) 18 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/cartpole.xml' % dir_path, 2) 19 | 20 | def step(self, a): 21 | self.do_simulation(a, self.frame_skip) 22 | ob = self._get_state() 23 | 24 | cost_lscale = CartpoleEnv.PENDULUM_LENGTH 25 | reward = np.exp( 26 | -np.sum(np.square(self._get_ee_pos(ob) - np.array([0.0, CartpoleEnv.PENDULUM_LENGTH]))) / (cost_lscale ** 2) 27 | ) 28 | reward -= 0.01 * np.sum(np.square(a)) 29 | 30 | done = False 31 | return ob, reward, done, {} 32 | 33 | def reset_model(self): 34 | qpos = self.init_qpos + np.random.normal(0, 0.1, np.shape(self.init_qpos)) 35 | qvel = self.init_qvel + np.random.normal(0, 0.1, np.shape(self.init_qvel)) 36 | self.set_state(qpos, qvel) 37 | return self._get_state() 38 | 39 | def _get_state(self): 40 | return np.concatenate([self.data.qpos, self.data.qvel]).ravel() 41 | 42 | @staticmethod 43 | def _get_ee_pos(x): 44 | x0, theta = x[0], x[1] 45 | return np.array([ 46 | x0 - CartpoleEnv.PENDULUM_LENGTH * np.sin(theta), 47 | -CartpoleEnv.PENDULUM_LENGTH * np.cos(theta) 48 | ]) 49 | 50 | def viewer_setup(self): 51 | v = self.viewer 52 | v.cam.trackbodyid = 0 53 | v.cam.distance = v.model.stat.extent 54 | -------------------------------------------------------------------------------- /dmbrl/env/half_cheetah_v3.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | import os 5 | import numpy as np 6 | from gym import utils 7 | from gym.envs.mujoco import mujoco_env 8 | 9 | DEFAULT_CAMERA_CONFIG = { 10 | # 'trackbodyid': 1, 11 | # 'lookat': np.array((0.0, 0.0, 2.0)), 12 | 'distance': 4.0, 13 | } 14 | 15 | 16 | class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle): 17 | def __init__(self, 18 | xml_file='half_cheetah.xml', 19 | forward_reward_weight=1.0, 20 | ctrl_cost_weight=0.1, 21 | reset_noise_scale=0.1): 22 | utils.EzPickle.__init__(**locals()) 23 | dir_path = os.path.dirname(os.path.realpath(__file__)) 24 | 25 | self._forward_reward_weight = forward_reward_weight 26 | 27 | self._ctrl_cost_weight = ctrl_cost_weight 28 | 29 | self._reset_noise_scale = reset_noise_scale 30 | self.prev_qpos = None 31 | 32 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/half_cheetah.xml' % dir_path, 5) 33 | 34 | 35 | def control_cost(self, action): 36 | control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) 37 | return control_cost 38 | 39 | def step(self, action): 40 | x_position_before = self.sim.data.qpos[0] 41 | self.prev_qpos = np.copy(self.sim.data.qpos.flat) 42 | 43 | self.do_simulation(action, self.frame_skip) 44 | x_position_after = self.sim.data.qpos[0] 45 | x_velocity = ((x_position_after - x_position_before) 46 | / self.dt) 47 | 48 | ctrl_cost = self.control_cost(action) 49 | 50 | forward_reward = self._forward_reward_weight * x_velocity 51 | 52 | observation = self._get_obs() 53 | reward = forward_reward - ctrl_cost 54 | done = False 55 | info = {} 56 | 57 | return observation, reward, done, info 58 | 59 | 60 | def _get_obs(self): 61 | position = self.sim.data.qpos.flat.copy() 62 | velocity = self.sim.data.qvel.flat.copy() 63 | 
64 | return np.concatenate([ 65 | (position[:1] - self.prev_qpos[:1]) / self.dt, 66 | position[1:], 67 | velocity, 68 | ]).ravel() 69 | 70 | def _get_state(self): 71 | return np.concatenate([ 72 | self.sim.data.qpos.flat, 73 | self.sim.data.qvel.flat, 74 | ]) 75 | 76 | def reset_model(self): 77 | noise_low = -self._reset_noise_scale 78 | noise_high = self._reset_noise_scale 79 | 80 | qpos = self.init_qpos + self.np_random.uniform( 81 | low=noise_low, high=noise_high, size=self.model.nq) 82 | qvel = self.init_qvel + self._reset_noise_scale * self.np_random.randn( 83 | self.model.nv) 84 | 85 | self.set_state(qpos, qvel) 86 | 87 | observation = self._get_obs() 88 | return observation 89 | 90 | def viewer_setup(self): 91 | for key, value in DEFAULT_CAMERA_CONFIG.items(): 92 | if isinstance(value, np.ndarray): 93 | getattr(self.viewer.cam, key)[:] = value 94 | else: 95 | setattr(self.viewer.cam, key, value) 96 | -------------------------------------------------------------------------------- /dmbrl/env/hopper.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | import os 5 | import numpy as np 6 | from gym import utils 7 | from gym.envs.mujoco import mujoco_env 8 | 9 | 10 | class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle): 11 | def __init__(self, 12 | xml_file='hopper.xml'): 13 | 14 | utils.EzPickle.__init__(**locals()) 15 | dir_path = os.path.dirname(os.path.realpath(__file__)) 16 | self.prev_qpos = None 17 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/hopper.xml' % dir_path, 4) 18 | 19 | 20 | def step(self, a): 21 | posbefore = self.sim.data.qpos[0] 22 | self.prev_qpos = np.copy(self.sim.data.qpos.flat) 23 | self.do_simulation(a, self.frame_skip) 24 | posafter, height, ang = self.sim.data.qpos[0:3] 25 | 26 | s = self.state_vector() 27 | 28 | alive = (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and 29 | (height > .7) and (abs(ang) < .2)) 30 | 31 | alive_bonus = 1.0 if alive else 0.0 32 | 33 | forward_reward = (posafter - posbefore) / self.dt 34 | control_cost = 1e-3 * np.square(a).sum() 35 | reward = forward_reward - control_cost + alive_bonus 36 | 37 | done = False 38 | ob = self._get_obs() 39 | return ob, reward, done, {"forward_reward" : forward_reward} 40 | 41 | def _get_obs(self): 42 | 43 | position = self.sim.data.qpos.flat.copy() 44 | velocity = self.sim.data.qvel.flat.copy() 45 | 46 | return np.concatenate([ 47 | (position[:1] - self.prev_qpos[:1]) / self.dt, 48 | position[1:], 49 | np.clip(velocity, -10, 10) 50 | ]) 51 | 52 | def _get_state(self): 53 | return np.concatenate([ 54 | self.sim.data.qpos.flat, 55 | self.sim.data.qvel.flat, 56 | ]) 57 | 58 | def reset_model(self): 59 | qpos = self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq) 60 | qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) 61 | self.set_state(qpos, qvel) 62 | return self._get_obs() 63 | 64 | def viewer_setup(self): 65 | self.viewer.cam.trackbodyid = 2 66 | self.viewer.cam.distance = self.model.stat.extent * 0.75 67 | self.viewer.cam.lookat[2] = 1.15 68 | self.viewer.cam.elevation = -20 -------------------------------------------------------------------------------- /dmbrl/env/pusher.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import 
absolute_import
4 |
5 | import os
6 |
7 | import numpy as np
8 | from gym import utils
9 | from gym.envs.mujoco import mujoco_env
10 |
11 |
12 | class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
13 |     def __init__(self):
14 |         dir_path = os.path.dirname(os.path.realpath(__file__))
15 |         mujoco_env.MujocoEnv.__init__(self, '%s/assets/pusher.xml' % dir_path, 4)
16 |         utils.EzPickle.__init__(self)
17 |         self.reset_model()
18 |
19 |     def step(self, a):  # renamed from _step: gym >= 0.9 (0.14 here) calls step()
20 |         obj_pos = self.get_body_com("object")  # no trailing comma: a tuple here would silently change the arithmetic below
21 |         vec_1 = obj_pos - self.get_body_com("tips_arm")
22 |         vec_2 = obj_pos - self.get_body_com("goal")
23 |
24 |         reward_near = -np.sum(np.abs(vec_1))
25 |         reward_dist = -np.sum(np.abs(vec_2))
26 |         reward_ctrl = -np.square(a).sum()
27 |         reward = 1.25 * reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near
28 |
29 |         self.do_simulation(a, self.frame_skip)
30 |         ob = self._get_obs()
31 |         done = False
32 |         return ob, reward, done, {}
33 |
34 |     def viewer_setup(self):
35 |         self.viewer.cam.trackbodyid = -1
36 |         self.viewer.cam.distance = 4.0
37 |
38 |     def reset_model(self):
39 |         qpos = self.init_qpos
40 |
41 |         self.goal_pos = np.asarray([0, 0])
42 |         self.cylinder_pos = np.array([-0.25, 0.15]) + np.random.normal(0, 0.025, [2])
43 |
44 |         qpos[-4:-2] = self.cylinder_pos
45 |         qpos[-2:] = self.goal_pos
46 |         qvel = self.init_qvel + self.np_random.uniform(low=-0.005,
47 |                                                        high=0.005, size=self.model.nv)
48 |         qvel[-4:] = 0
49 |         self.set_state(qpos, qvel)
50 |         self.ac_goal_pos = self.get_body_com("goal")
51 |
52 |         return self._get_obs()
53 |
54 |     def _get_obs(self):
55 |         return np.concatenate([
56 |             self.sim.data.qpos.flat[:7],  # sim.data (mujoco-py >= 1.50); model.data no longer exists
57 |             self.sim.data.qvel.flat[:7],
58 |             self.get_body_com("tips_arm"),
59 |             self.get_body_com("object"),
60 |         ])
61 |
62 |     def _get_state(self):
63 |         return np.concatenate([
64 |             self.sim.data.qpos.flat,
65 |             self.sim.data.qvel.flat,
66 |         ])
67 |
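For reference, here is a minimal, self-contained sketch of the reward that the fixed `step()` above computes. The helper name `pusher_reward` and its arguments are introduced here for illustration only; they are not part of the repository:

```python
import numpy as np

def pusher_reward(obj_pos, tip_pos, goal_pos, action):
    # Mirrors PusherEnv.step(): L1 distances for reaching/pushing, quadratic action penalty.
    reward_near = -np.sum(np.abs(obj_pos - tip_pos))   # keep the arm tip near the object
    reward_dist = -np.sum(np.abs(obj_pos - goal_pos))  # push the object toward the goal
    reward_ctrl = -np.square(action).sum()             # control-effort penalty
    return 1.25 * reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near

# Example: object halfway between tip and goal, zero action.
print(pusher_reward(np.zeros(3), np.array([0.1, 0, 0]), np.array([-0.1, 0, 0]), np.zeros(7)))
```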
-------------------------------------------------------------------------------- /dmbrl/env/reacher.py: --------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 |
7 | import numpy as np
8 | from gym import utils
9 | from gym.envs.mujoco import mujoco_env
10 |
11 |
12 | class Reacher3DEnv(mujoco_env.MujocoEnv, utils.EzPickle):
13 |     def __init__(self):
14 |         self.viewer = None
15 |         utils.EzPickle.__init__(self)
16 |         dir_path = os.path.dirname(os.path.realpath(__file__))
17 |         self.goal = np.zeros(3)
18 |         mujoco_env.MujocoEnv.__init__(self, os.path.join(dir_path, 'assets/reacher3d.xml'), 2)
19 |
20 |     def step(self, a):
21 |         self.do_simulation(a, self.frame_skip)
22 |         ob = self._get_obs()
23 |         reward = -np.sum(np.square(self.get_EE_pos(ob[None]) - self.goal))
24 |         reward -= 0.01 * np.square(a).sum()
25 |         done = False
26 |         return ob, reward, done, dict(reward_dist=0, reward_ctrl=0)
27 |
28 |     def viewer_setup(self):
29 |         self.viewer.cam.trackbodyid = 1
30 |         self.viewer.cam.distance = 2.5
31 |         self.viewer.cam.elevation = -30
32 |         self.viewer.cam.azimuth = 270
33 |
34 |     def reset_model(self):
35 |         qpos, qvel = np.copy(self.init_qpos), np.copy(self.init_qvel)
36 |         qpos[-3:] += np.random.normal(loc=0, scale=0.1, size=[3])
37 |         qvel[-3:] = 0
38 |         self.goal = qpos[-3:]
39 |         self.set_state(qpos, qvel)
40 |         return self._get_obs()
41 |
42 |     def _get_obs(self):
43 |         return np.concatenate([
44 |             self.sim.data.qpos.flat,
45 |             self.sim.data.qvel.flat[:-3],
46 |         ])
47 |
48 |     def _get_state(self):
49 |         return np.concatenate([
50 |             self.sim.data.qpos.flat,
51 |             self.sim.data.qvel.flat,
52 |         ])
53 |
54 |     def get_EE_pos(self, states):
55 |         theta1, theta2, theta3, theta4, theta5, theta6, theta7 = \
56 |             states[:, :1], states[:, 1:2], states[:, 2:3], states[:, 3:4], states[:, 4:5], states[:, 5:6], states[:, 6:]
57 |
58 |         rot_axis = np.concatenate([np.cos(theta2) * np.cos(theta1), np.cos(theta2) * np.sin(theta1), -np.sin(theta2)],
59 |                                   axis=1)
60 |         rot_perp_axis = np.concatenate([-np.sin(theta1), np.cos(theta1), np.zeros(theta1.shape)], axis=1)
61 |         cur_end = np.concatenate([
62 |             0.1 * np.cos(theta1) + 0.4 * np.cos(theta1) * np.cos(theta2),
63 |             0.1 * np.sin(theta1) + 0.4 * np.sin(theta1) * np.cos(theta2) - 0.188,
64 |             -0.4 * np.sin(theta2)
65 |         ], axis=1)
66 |
67 |         for length, hinge, roll in [(0.321, theta4, theta3), (0.16828, theta6, theta5)]:
68 |             perp_all_axis = np.cross(rot_axis, rot_perp_axis)
69 |             x = np.cos(hinge) * rot_axis
70 |             y = np.sin(hinge) * np.sin(roll) * rot_perp_axis
71 |             z = -np.sin(hinge) * np.cos(roll) * perp_all_axis
72 |             new_rot_axis = x + y + z
73 |             new_rot_perp_axis = np.cross(new_rot_axis, rot_axis)
74 |             new_rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] = \
75 |                 rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30]
76 |             new_rot_perp_axis /= np.linalg.norm(new_rot_perp_axis, axis=1, keepdims=True)
77 |             rot_axis, rot_perp_axis, cur_end = new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis
78 |
79 |         return cur_end
80 |
81 |
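Because `get_EE_pos` reimplements the arm's forward kinematics in closed form on batched state arrays, the same reward that `step()` computes can be evaluated on imagined observations without stepping the simulator. A minimal sketch, assuming `env` is a constructed `Reacher3DEnv`; the helper name `reacher_reward` is illustrative and not part of the repository:

```python
import numpy as np

def reacher_reward(env, ob, action):
    # Mirrors Reacher3DEnv.step(): negative squared end-effector distance to the
    # goal, plus a small quadratic control penalty, for a single observation.
    ee_pos = env.get_EE_pos(ob[None])              # batched analytic kinematics
    return -np.sum(np.square(ee_pos - env.goal)) - 0.01 * np.square(action).sum()
```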
-------------------------------------------------------------------------------- /dmbrl/misc/Agent.py: --------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import numpy as np
6 | # from gym.monitoring import VideoRecorder
7 | from dotmap import DotMap
8 | import mujoco_py
9 | import time
10 |
11 |
12 | class Agent:
13 |     """A general class for RL agents.
14 |     """
15 |     def __init__(self, params):
16 |         """Initializes an agent.
17 |
18 |         Arguments:
19 |             params: (DotMap) A DotMap of agent parameters.
20 |                 .env: (OpenAI gym environment) The environment for this agent.
21 |                 .noisy_actions: (bool) Indicates whether random Gaussian noise will
22 |                     be added to the actions of this agent.
23 |                 .noise_stddev: (float) The standard deviation to be used for the
24 |                     action noise if params.noisy_actions is True.
25 |         """
26 |         self.env = params.env
27 |         self.noise_stddev = params.noise_stddev if params.get("noisy_actions", False) else None
28 |
29 |         if isinstance(self.env, DotMap):
30 |             raise ValueError("Environment must be provided to the agent at initialization.")
31 |         if (not isinstance(self.noise_stddev, float)) and params.get("noisy_actions", False):
32 |             raise ValueError("Must provide standard deviation for noise for noisy actions.")
33 |
34 |         if self.noise_stddev is not None:
35 |             self.dU = self.env.action_space.shape[0]
36 |
37 |     def save(self, path):
38 |         """
39 |         Save the environment to the given path
40 |         """
41 |         pass
42 |
43 |     def load(self, path):
44 |         """
45 |         Load the environment from the given path
46 |         """
47 |         pass
48 |
49 |     def sample(self, horizon, policy, record_fname=None, catch_error=False):
50 |         """Samples a rollout from the agent.
51 |
52 |         Arguments:
53 |             horizon: (int) The length of the rollout to generate from the agent.
54 |             policy: (policy) The policy that the agent will use for actions.
55 |             record_fname: (str/None) The name of the file to which a recording of the rollout
56 |                 will be saved. If None, the rollout will not be recorded.
57 |
58 |         Returns: (dict) A dictionary containing data from the rollout.
59 |             The keys are 'obs', 'ac', 'reward_sum', 'rewards', and 'init_state' ('error_state' is added when catch_error is True).
60 |         """
61 |         video_record = record_fname is not None
62 |         recorder = None
63 |
64 |         times, rewards = [], []
65 |
66 |         O, A, reward_sum, done = [self.env.reset().tolist()], [], 0, False
67 |         init_state = self.env._get_state().tolist()
68 |         policy.reset()
69 |
70 |         if catch_error:
71 |             error_state = False
72 |             for t in range(horizon):
73 |                 start = time.time()
74 |                 A.append(policy.act(O[t], t).tolist())
75 |                 times.append(time.time() - start)
76 |                 try:
77 |                     if self.noise_stddev is None:
78 |                         obs, reward, done, info = self.env.step(np.array(A[t]))
79 |                     else:
80 |                         action = A[t] + np.random.normal(loc=0, scale=self.noise_stddev, size=[self.dU])
81 |                         action = np.minimum(np.maximum(action, self.env.action_space.low), self.env.action_space.high)
82 |                         obs, reward, done, info = self.env.step(action)
83 |                     O.append(obs.tolist())
84 |                     reward_sum += reward
85 |                     rewards.append(reward)
86 |
87 |                     if done:
88 |                         break
89 |                 except mujoco_py.builder.MujocoException:
90 |                     error_state = True
91 |                     break
92 |
93 |             print("Average action selection time: ", np.mean(times))
94 |             print("Rollout length: ", len(A))
95 |
96 |             return {
97 |                 "obs": O,
98 |                 "ac": A,
99 |                 "reward_sum": reward_sum,
100 |                 "rewards": rewards,
101 |                 "init_state": init_state,
102 |                 "error_state": error_state
103 |             }
104 |         else:
105 |             for t in range(horizon):
106 |                 start = time.time()
107 |                 A.append(policy.act(O[t], t).tolist())
108 |                 times.append(time.time() - start)
109 |
110 |                 if self.noise_stddev is None:
111 |                     obs, reward, done, info = self.env.step(np.array(A[t]))
112 |                 else:
113 |                     action = A[t] + np.random.normal(loc=0, scale=self.noise_stddev, size=[self.dU])
114 |                     action = np.minimum(np.maximum(action, self.env.action_space.low), self.env.action_space.high)
115 |                     obs, reward, done, info = self.env.step(action)
116 |                 O.append(obs.tolist())
117 |                 reward_sum += reward
118 |                 rewards.append(reward)
119 |                 if done:
120 |                     break
121 |
122 |             print("Average action selection time: ", np.mean(times))
123 |             print("Rollout length: ", len(A))
124 |
125 |             return {
126 |                 "obs": O,
127 |                 "ac": A,
128 |                 "reward_sum": reward_sum,
129 |                 "rewards": rewards,
130 |                 "init_state": init_state
131 |             }
132 |
-------------------------------------------------------------------------------- /dmbrl/misc/DotmapUtils.py: --------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 |
6 | def get_required_argument(dotmap, key, message, default=None):
7 |     val = dotmap.get(key, default)
8 |     if val is default:
9 |         raise ValueError(message)
10 |     return val
11 |
-------------------------------------------------------------------------------- /dmbrl/misc/MBExp.py: --------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 | from time import time, localtime, strftime
7 |
8 | import numpy as np
9 | from scipy.io import savemat
10 | from dotmap import DotMap
11 |
12 | from dmbrl.misc.DotmapUtils import get_required_argument
13 | from dmbrl.misc.Agent import Agent
14 |
15 |
16 | class MBExperiment:
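# Editor's note: each run_experiment() iteration below (i) samples rollouts from the
# Agent under the current MPC policy, (ii) appends observations/actions/returns to
# logs.mat, and (iii) retrains the policy's dynamics model on the rollouts collected
# in that iteration.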
17 | def __init__(self, params): 18 | """Initializes class instance. 19 | 20 | Argument: 21 | params (DotMap): A DotMap containing the following: 22 | .sim_cfg: 23 | .env (gym.env): Environment for this experiment 24 | .task_hor (int): Task horizon 25 | .stochastic (bool): (optional) If True, agent adds noise to its actions. 26 | Must provide noise_std (see below). Defaults to False. 27 | .noise_std (float): for stochastic agents, noise of the form N(0, noise_std^2I) 28 | will be added. 29 | 30 | .exp_cfg: 31 | .ntrain_iters (int): Number of training iterations to be performed. 32 | .nrollouts_per_iter (int): (optional) Number of rollouts done between training 33 | iterations. Defaults to 1. 34 | .ninit_rollouts (int): (optional) Number of initial rollouts. Defaults to 1. 35 | .policy (controller): Policy that will be trained. 36 | 37 | .log_cfg: 38 | .logdir (str): Parent of directory path where experiment data will be saved. 39 | Experiment will be saved in logdir/ 40 | .nrecord (int): (optional) Number of rollouts to record for every iteration. 41 | Defaults to 0. 42 | .neval (int): (optional) Number of rollouts for performance evaluation. 43 | Defaults to 1. 44 | """ 45 | self.env = get_required_argument(params.sim_cfg, "env", "Must provide environment.") 46 | self.task_hor = get_required_argument(params.sim_cfg, "task_hor", "Must provide task horizon.") 47 | if params.sim_cfg.get("stochastic", False): 48 | self.agent = Agent(DotMap( 49 | env=self.env, noisy_actions=True, 50 | noise_stddev=get_required_argument( 51 | params.sim_cfg, 52 | "noise_std", 53 | "Must provide noise standard deviation in the case of a stochastic environment." 54 | ) 55 | )) 56 | else: 57 | self.agent = Agent(DotMap(env=self.env, noisy_actions=False)) 58 | 59 | self.ntrain_iters = get_required_argument( 60 | params.exp_cfg, "ntrain_iters", "Must provide number of training iterations." 61 | ) 62 | self.nrollouts_per_iter = params.exp_cfg.get("nrollouts_per_iter", 1) 63 | self.ninit_rollouts = params.exp_cfg.get("ninit_rollouts", 1) 64 | self.policy = get_required_argument(params.exp_cfg, "policy", "Must provide a policy.") 65 | 66 | self.logdir = os.path.join( 67 | get_required_argument(params.log_cfg, "logdir", "Must provide log parent directory."), 68 | strftime("%Y-%m-%d--%H:%M:%S", localtime()) 69 | ) 70 | self.nrecord = params.log_cfg.get("nrecord", 0) 71 | self.neval = params.log_cfg.get("neval", 1) 72 | 73 | def run_experiment(self): 74 | """Perform experiment. 75 | """ 76 | os.makedirs(self.logdir, exist_ok=True) 77 | 78 | traj_obs, traj_acs, traj_rets, traj_rews = [], [], [], [] 79 | 80 | # Perform initial rollouts 81 | samples = [] 82 | for i in range(self.ninit_rollouts): 83 | samples.append( 84 | self.agent.sample( 85 | self.task_hor, self.policy 86 | ) 87 | ) 88 | traj_obs.append(samples[-1]["obs"]) 89 | traj_acs.append(samples[-1]["ac"]) 90 | traj_rews.append(samples[-1]["rewards"]) 91 | 92 | if self.ninit_rollouts > 0: 93 | self.policy.train( 94 | [sample["obs"] for sample in samples], 95 | [sample["ac"] for sample in samples], 96 | [sample["rewards"] for sample in samples] 97 | ) 98 | 99 | # Training loop 100 | for i in range(self.ntrain_iters): 101 | print("####################################################################") 102 | print("Starting training iteration %d." 
% (i + 1)) 103 | 104 | iter_dir = os.path.join(self.logdir, "train_iter%d" % (i + 1)) 105 | os.makedirs(iter_dir, exist_ok=True) 106 | 107 | samples = [] 108 | for j in range(self.nrecord): 109 | samples.append( 110 | self.agent.sample( 111 | self.task_hor, self.policy, 112 | os.path.join(iter_dir, "rollout%d.mp4" % j) 113 | ) 114 | ) 115 | if self.nrecord > 0: 116 | for item in filter(lambda f: f.endswith(".json"), os.listdir(iter_dir)): 117 | os.remove(os.path.join(iter_dir, item)) 118 | for j in range(max(self.neval, self.nrollouts_per_iter) - self.nrecord): 119 | samples.append( 120 | self.agent.sample( 121 | self.task_hor, self.policy 122 | ) 123 | ) 124 | print("Rewards obtained:", [sample["reward_sum"] for sample in samples[:self.neval]]) 125 | traj_obs.extend([sample["obs"] for sample in samples[:self.nrollouts_per_iter]]) 126 | traj_acs.extend([sample["ac"] for sample in samples[:self.nrollouts_per_iter]]) 127 | traj_rets.extend([sample["reward_sum"] for sample in samples[:self.neval]]) 128 | traj_rews.extend([sample["rewards"] for sample in samples[:self.nrollouts_per_iter]]) 129 | samples = samples[:self.nrollouts_per_iter] 130 | 131 | self.policy.dump_logs(self.logdir, iter_dir) 132 | savemat( 133 | os.path.join(self.logdir, "logs.mat"), 134 | { 135 | "observations": traj_obs, 136 | "actions": traj_acs, 137 | "returns": traj_rets, 138 | "rewards": traj_rews 139 | } 140 | ) 141 | # Delete iteration directory if not used 142 | if len(os.listdir(iter_dir)) == 0: 143 | os.rmdir(iter_dir) 144 | 145 | if i < self.ntrain_iters - 1: 146 | self.policy.train( 147 | [sample["obs"] for sample in samples], 148 | [sample["ac"] for sample in samples], 149 | [sample["rewards"] for sample in samples] 150 | ) 151 | -------------------------------------------------------------------------------- /dmbrl/misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/HPO_for_RL/d82c7ddd6fe19834c088137570530f11761d9390/dmbrl/misc/__init__.py -------------------------------------------------------------------------------- /dmbrl/misc/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .cem import CEMOptimizer 2 | from .random import RandomOptimizer -------------------------------------------------------------------------------- /dmbrl/misc/optimizers/cem.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import scipy.stats as stats 8 | 9 | from .optimizer import Optimizer 10 | 11 | 12 | class CEMOptimizer(Optimizer): 13 | """A Tensorflow-compatible CEM optimizer. 14 | """ 15 | def __init__(self, sol_dim, max_iters, popsize, num_elites, tf_session=None, 16 | upper_bound=None, lower_bound=None, epsilon=0.001, alpha=0.25): 17 | """Creates an instance of this class. 18 | 19 | Arguments: 20 | sol_dim (int): The dimensionality of the problem space 21 | max_iters (int): The maximum number of iterations to perform during optimization 22 | popsize (int): The number of candidate solutions to be sampled at every iteration 23 | num_elites (int): The number of top solutions that will be used to obtain the distribution 24 | at the next iteration. 25 | tf_session (tf.Session): (optional) Session to be used for this optimizer. 
Defaults to None, 26 | in which case any functions passed in cannot be tf.Tensor-valued. 27 | upper_bound (np.array): An array of upper bounds 28 | lower_bound (np.array): An array of lower bounds 29 | epsilon (float): A minimum variance. If the maximum variance drops below epsilon, optimization is 30 | stopped. 31 | alpha (float): Controls how much of the previous mean and variance is used for the next iteration. 32 | next_mean = alpha * old_mean + (1 - alpha) * elite_mean, and similarly for variance. 33 | """ 34 | super().__init__() 35 | self.sol_dim, self.max_iters, self.popsize, self.num_elites = sol_dim, max_iters, popsize, num_elites 36 | self.ub, self.lb = upper_bound, lower_bound 37 | self.epsilon, self.alpha = epsilon, alpha 38 | self.tf_sess = tf_session 39 | 40 | if num_elites > popsize: 41 | raise ValueError("Number of elites must be at most the population size.") 42 | 43 | if self.tf_sess is not None: 44 | with self.tf_sess.graph.as_default(): 45 | with tf.variable_scope("CEMSolver") as scope: 46 | self.init_mean = tf.placeholder(dtype=tf.float32, shape=[sol_dim]) 47 | self.init_var = tf.placeholder(dtype=tf.float32, shape=[sol_dim]) 48 | 49 | self.num_opt_iters, self.mean, self.var = None, None, None 50 | self.tf_compatible, self.cost_function = None, None 51 | 52 | def setup(self, cost_function, tf_compatible): 53 | """Sets up this optimizer using a given cost function. 54 | 55 | Arguments: 56 | cost_function (func): A function for computing costs over a batch of candidate solutions. 57 | tf_compatible (bool): True if the cost function provided is tf.Tensor-valued. 58 | 59 | Returns: None 60 | """ 61 | if tf_compatible and self.tf_sess is None: 62 | raise RuntimeError("Cannot pass in a tf.Tensor-valued cost function without passing in a TensorFlow " 63 | "session into the constructor") 64 | 65 | self.tf_compatible = tf_compatible 66 | 67 | if not tf_compatible: 68 | self.cost_function = cost_function 69 | else: 70 | def continue_optimization(t, mean, var, best_val, best_sol): 71 | return tf.logical_and(tf.less(t, self.max_iters), tf.reduce_max(var) > self.epsilon) 72 | 73 | def iteration(t, mean, var, best_val, best_sol): 74 | lb_dist, ub_dist = mean - self.lb, self.ub - mean 75 | constrained_var = tf.minimum(tf.minimum(tf.square(lb_dist / 2), tf.square(ub_dist / 2)), var) 76 | samples = tf.truncated_normal([self.popsize, self.sol_dim], mean, tf.sqrt(constrained_var)) 77 | 78 | costs = cost_function(samples) 79 | values, indices = tf.nn.top_k(-costs, k=self.num_elites, sorted=True) 80 | 81 | best_val, best_sol = tf.cond( 82 | tf.less(-values[0], best_val), 83 | lambda: (-values[0], samples[indices[0]]), 84 | lambda: (best_val, best_sol) 85 | ) 86 | 87 | elites = tf.gather(samples, indices) 88 | new_mean = tf.reduce_mean(elites, axis=0) 89 | new_var = tf.reduce_mean(tf.square(elites - new_mean), axis=0) 90 | 91 | mean = self.alpha * mean + (1 - self.alpha) * new_mean 92 | var = self.alpha * var + (1 - self.alpha) * new_var 93 | 94 | return t + 1, mean, var, best_val, best_sol 95 | 96 | with self.tf_sess.graph.as_default(): 97 | self.num_opt_iters, self.mean, self.var, self.best_val, self.best_sol = tf.while_loop( 98 | cond=continue_optimization, body=iteration, 99 | loop_vars=[0, self.init_mean, self.init_var, float("inf"), self.init_mean] 100 | ) 101 | 102 | def reset(self): 103 | pass 104 | 105 | def obtain_solution(self, init_mean, init_var): 106 | """Optimizes the cost function using the provided initial candidate distribution 107 | 108 | Arguments: 109 | init_mean 
(np.ndarray): The mean of the initial candidate distribution.
110 |             init_var (np.ndarray): The variance of the initial candidate distribution.
111 |
112 |         Returns: (np.ndarray) The best solution found.
113 |         """
114 |         if self.tf_compatible:
115 |             sol, solvar = self.tf_sess.run(
116 |                 [self.mean, self.var],
117 |                 feed_dict={self.init_mean: init_mean, self.init_var: init_var}
118 |             )
119 |         else:
120 |             mean, var, t = init_mean, init_var, 0
121 |             X = stats.truncnorm(-2, 2, loc=np.zeros_like(mean), scale=np.ones_like(mean))
122 |
123 |             while (t < self.max_iters) and np.max(var) > self.epsilon:
124 |                 lb_dist, ub_dist = mean - self.lb, self.ub - mean
125 |                 constrained_var = np.minimum(np.minimum(np.square(lb_dist / 2), np.square(ub_dist / 2)), var)
126 |
127 |                 samples = X.rvs(size=[self.popsize, self.sol_dim]) * np.sqrt(constrained_var) + mean
128 |                 costs = self.cost_function(samples)
129 |                 elites = samples[np.argsort(costs)][:self.num_elites]
130 |
131 |                 new_mean = np.mean(elites, axis=0)
132 |                 new_var = np.var(elites, axis=0)
133 |
134 |                 mean = self.alpha * mean + (1 - self.alpha) * new_mean
135 |                 var = self.alpha * var + (1 - self.alpha) * new_var
136 |
137 |                 t += 1
138 |             sol, solvar = mean, var
139 |         return sol
140 |
141 |
-------------------------------------------------------------------------------- /dmbrl/misc/optimizers/optimizer.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 |
6 | class Optimizer:
7 |     def __init__(self, *args, **kwargs):
8 |         pass
9 |
10 |     def setup(self, cost_function, tf_compatible):
11 |         raise NotImplementedError("Must be implemented in subclass.")
12 |
13 |     def reset(self):
14 |         raise NotImplementedError("Must be implemented in subclass.")
15 |
16 |     def obtain_solution(self, *args, **kwargs):
17 |         raise NotImplementedError("Must be implemented in subclass.")
18 |
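To make the CEM update implemented by `CEMOptimizer` above concrete, here is a self-contained NumPy toy run of the same smoothing rule (`next = alpha * old + (1 - alpha) * elite statistics`) on a quadratic cost. `toy_cost` and all constants are illustrative, and the truncated-normal sampling and bound handling of the real optimizer are omitted:

```python
import numpy as np

def toy_cost(samples):                                  # cost over a batch of candidates
    return np.sum(np.square(samples - 3.0), axis=1)     # minimized at [3., 3.]

mean, var = np.zeros(2), 25.0 * np.ones(2)              # initial candidate distribution
alpha, num_elites, popsize, max_iters = 0.25, 10, 100, 10

for _ in range(max_iters):
    samples = mean + np.sqrt(var) * np.random.randn(popsize, 2)
    elites = samples[np.argsort(toy_cost(samples))][:num_elites]
    # Same smoothing as CEMOptimizer.obtain_solution(): keep alpha of the old stats.
    mean = alpha * mean + (1 - alpha) * np.mean(elites, axis=0)
    var = alpha * var + (1 - alpha) * np.var(elites, axis=0)

print(mean)  # converges toward [3., 3.]
```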
-------------------------------------------------------------------------------- /dmbrl/misc/optimizers/random.py: --------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import absolute_import
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from .optimizer import Optimizer
9 |
10 |
11 | class RandomOptimizer(Optimizer):
12 |     def __init__(self, sol_dim, popsize, tf_session,
13 |                  upper_bound=None, lower_bound=None):
14 |         """Creates an instance of this class.
15 |
16 |         Arguments:
17 |             sol_dim (int): The dimensionality of the problem space
18 |             popsize (int): The number of candidate solutions to be sampled at every iteration
19 |             tf_session (tf.Session): Session to be used for this optimizer. If None, any cost
20 |                 functions passed in cannot be tf.Tensor-valued.
21 |             upper_bound (np.array): An array of upper bounds
22 |             lower_bound (np.array): An array of lower bounds
23 |         """
24 |         super().__init__()
25 |         self.sol_dim = sol_dim
26 |         self.popsize = popsize
27 |         self.ub, self.lb = upper_bound, lower_bound
28 |         self.tf_sess = tf_session
29 |         self.solution = None
30 |         self.tf_compatible, self.cost_function = None, None
31 |
32 |     def setup(self, cost_function, tf_compatible):
33 |         """Sets up this optimizer using a given cost function.
34 |
35 |         Arguments:
36 |             cost_function (func): A function for computing costs over a batch of candidate solutions.
37 |             tf_compatible (bool): True if the cost function provided is tf.Tensor-valued.
38 |
39 |         Returns: None
40 |         """
41 |         if tf_compatible and self.tf_sess is None:
42 |             raise RuntimeError("Cannot pass in a tf.Tensor-valued cost function without passing in a TensorFlow "
43 |                                "session into the constructor")
44 |
45 |         if not tf_compatible:
46 |             self.tf_compatible = False
47 |             self.cost_function = cost_function
48 |         else:
49 |             with self.tf_sess.graph.as_default():
50 |                 self.tf_compatible = True
51 |                 solutions = tf.random_uniform([self.popsize, self.sol_dim], self.lb, self.ub)  # (minval, maxval) order; the bounds were swapped here
52 |                 costs = cost_function(solutions)
53 |                 self.solution = solutions[tf.cast(tf.argmin(costs), tf.int32)]
54 |
55 |     def reset(self):
56 |         pass
57 |
58 |     def obtain_solution(self, *args, **kwargs):
59 |         """Optimizes the cost function provided in setup().
60 |
61 |         Arguments:
62 |             Any arguments (e.g. an initial mean and variance) are accepted for interface
63 |             compatibility with CEMOptimizer but are ignored by random search.
64 |         """
65 |         if self.tf_compatible:
66 |             return self.tf_sess.run(self.solution)
67 |         else:
68 |             solutions = np.random.uniform(self.lb, self.ub, [self.popsize, self.sol_dim])
69 |             costs = self.cost_function(solutions)
70 |             return solutions[np.argmin(costs)]
71 |
-------------------------------------------------------------------------------- /dmbrl/misc/render.py: --------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 | import argparse
7 | import pprint
8 |
9 | from dotmap import DotMap
10 |
11 | from dmbrl.misc.MBExp import MBExperiment
12 | from dmbrl.controllers.MPC import MPC
13 | from dmbrl.config import create_config
14 |
15 |
16 | def main(env, ctrl_type, ctrl_args, overrides, model_dir, logdir):
17 |     ctrl_args = DotMap(**{key: val for (key, val) in ctrl_args})
18 |
19 |     overrides.append(["ctrl_cfg.prop_cfg.model_init_cfg.model_dir", model_dir])
20 |     overrides.append(["ctrl_cfg.prop_cfg.model_init_cfg.load_model", "True"])
21 |     overrides.append(["ctrl_cfg.prop_cfg.model_pretrained", "True"])
22 |     overrides.append(["exp_cfg.exp_cfg.ninit_rollouts", "0"])
23 |     overrides.append(["exp_cfg.exp_cfg.ntrain_iters", "1"])
24 |     overrides.append(["exp_cfg.log_cfg.nrecord", "1"])
25 |
26 |     cfg = create_config(env, ctrl_type, ctrl_args, overrides, logdir)
27 |     cfg.pprint()
28 |
29 |     if ctrl_type == "MPC":
30 |         cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)
31 |     exp = MBExperiment(cfg.exp_cfg)
32 |
33 |     os.makedirs(exp.logdir)
34 |     with open(os.path.join(exp.logdir, "config.txt"), "w") as f:
35 |         f.write(pprint.pformat(cfg.toDict()))
36 |
37 |     exp.run_experiment()
38 |
39 |
40 | if __name__ == "__main__":
41 |     parser = argparse.ArgumentParser()
42 |     parser.add_argument('-env', type=str, required=True)
43 |     parser.add_argument('-ca', '--ctrl_arg', action='append', nargs=2, default=[])
44 |     parser.add_argument('-o', '--override', action='append', nargs=2, default=[])
45 |     parser.add_argument('-model-dir', type=str, required=True)
46 |     parser.add_argument('-logdir', type=str, required=True)
47 |     args = parser.parse_args()
48 |
49 |     main(args.env, "MPC", args.ctrl_arg, args.override, args.model_dir, args.logdir)
50 |
-------------------------------------------------------------------------------- /dmbrl/modeling/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/HPO_for_RL/d82c7ddd6fe19834c088137570530f11761d9390/dmbrl/modeling/__init__.py -------------------------------------------------------------------------------- /dmbrl/modeling/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .FC import FC -------------------------------------------------------------------------------- /dmbrl/modeling/models/TFGP.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | import gpflow 9 | 10 | from dmbrl.misc.DotmapUtils import get_required_argument 11 | 12 | 13 | class TFGP: 14 | def __init__(self, params): 15 | """Initializes class instance. 16 | 17 | Arguments: 18 | params 19 | .name (str): Model name 20 | .kernel_class (class): Kernel class 21 | .kernel_args (args): Kernel args 22 | .num_inducing_points (int): Number of inducing points 23 | .sess (tf.Session): Tensorflow session 24 | """ 25 | self.name = params.get("name", "GP") 26 | self.kernel_class = get_required_argument(params, "kernel_class", "Must provide kernel class.") 27 | self.kernel_args = params.get("kernel_args", {}) 28 | self.num_inducing_points = get_required_argument( 29 | params, "num_inducing_points", "Must provide number of inducing points." 30 | ) 31 | 32 | if params.get("sess", None) is None: 33 | config = tf.ConfigProto() 34 | config.gpu_options.allow_growth = True 35 | self._sess = tf.Session(config=config) 36 | else: 37 | self._sess = params.get("sess") 38 | 39 | with self._sess.as_default(): 40 | with tf.variable_scope(self.name): 41 | output_dim = self.kernel_args["output_dim"] 42 | del self.kernel_args["output_dim"] 43 | self.model = gpflow.models.SGPR( 44 | np.zeros([1, self.kernel_args["input_dim"]]), 45 | np.zeros([1, output_dim]), 46 | kern=self.kernel_class(**self.kernel_args), 47 | Z=np.zeros([self.num_inducing_points, self.kernel_args["input_dim"]]) 48 | ) 49 | self.model.initialize() 50 | 51 | @property 52 | def is_probabilistic(self): 53 | return True 54 | 55 | @property 56 | def sess(self): 57 | return self._sess 58 | 59 | @property 60 | def is_tf_model(self): 61 | return True 62 | 63 | def train(self, inputs, targets, 64 | *args, **kwargs): 65 | """Optimizes the parameters of the internal GP model. 66 | 67 | Arguments: 68 | inputs: (np.ndarray) An array of inputs. 69 | targets: (np.ndarray) An array of targets. 70 | num_restarts: (int) The number of times that the optimization of 71 | the GP will be restarted to obtain a good set of parameters. 72 | 73 | Returns: None. 74 | """ 75 | perm = np.random.permutation(inputs.shape[0]) 76 | inputs, targets = inputs[perm], targets[perm] 77 | Z = np.copy(inputs[:self.num_inducing_points]) 78 | if Z.shape[0] < self.num_inducing_points: 79 | Z = np.concatenate([Z, np.zeros([self.num_inducing_points - Z.shape[0], Z.shape[1]])]) 80 | self.model.X = inputs 81 | self.model.Y = targets 82 | self.model.feature.Z = Z 83 | with self.sess.as_default(): 84 | self.model.compile() 85 | print("Optimizing model... ", end="") 86 | gpflow.train.ScipyOptimizer().minimize(self.model) 87 | print("Done.") 88 | 89 | def predict(self, inputs, *args, **kwargs): 90 | """Returns the predictions of this model on inputs. 
91 |
92 |         Arguments:
93 |             inputs: (np.ndarray) The inputs on which predictions will be returned.
94 |             ign_var: (bool) Unused by this model; the mean and variance are always both returned.
95 |
96 |         Returns: (np.ndarrays) The mean and variance of the model on the new points.
97 |         """
98 |         if self.model is None:
99 |             raise RuntimeError("Cannot make predictions without initial batch of data.")
100 |
101 |         with self.sess.as_default():
102 |             mean, var = self.model.predict_y(inputs)
103 |         return mean, var
104 |
105 |     def create_prediction_tensors(self, inputs, *args, **kwargs):
106 |         """Returns tf.Tensors for the mean and variance of the model's predictions on an input tensor."""
107 |         if self.model is None:
108 |             raise RuntimeError("Cannot make predictions without initial batch of data.")
109 |
110 |         inputs = tf.cast(inputs, tf.float64)
111 |         mean, var = self.model._build_predict(inputs, full_cov=False)
112 |         return tf.cast(mean, dtype=tf.float32), tf.cast(var, tf.float32)
113 |
114 |     def save(self, *args, **kwargs):
115 |         pass
116 |
-------------------------------------------------------------------------------- /dmbrl/modeling/models/__init__.py: --------------------------------------------------------------------------------
1 | from .BNN import BNN
2 | from .NN import NN
3 | from .TFGP import TFGP
4 |
-------------------------------------------------------------------------------- /dmbrl/modeling/utils/HPO.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from dmbrl.controllers import MPC
3 | from dmbrl.misc.DotmapUtils import get_required_argument
4 | from dmbrl.modeling.layers import FC
5 | from dotmap import DotMap
6 |
7 |
8 | def build_policy_constructor(exp):
9 |     """Returns a function which translates a parameter vector into a configuration
10 |     that can be used to construct a policy.
11 |
12 |     Args:
13 |         exp: experiment object
14 |
15 |     Returns:
16 |         translate_into_model: A function which will construct a controller (e.g. MPC) given the hyperparameters
17 |     """
18 |     def translate_into_model(param_dict, config_space, reset=True):
19 |         """Translate a set of hyperparameters into a policy
20 |
21 |         Args:
22 |             param_dict: dictionary which contains the hyperparameters of the policy
23 |             config_space: Configuration space which will be used to determine the default hyperparameter values
24 |             reset (bool, optional): If True, reset the default graph of Tensorflow. Defaults to True.
25 |
26 |         Returns:
27 |             policy: MPC controller
28 |         """
29 |         # Clear the graph to avoid conflicts with previous graphs.
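# Editor's note: tf.reset_default_graph() below also invalidates the variables and
# session of any previously built policy, so every hyperparameter configuration is
# evaluated with a freshly constructed model and MPC controller.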
30 | if reset: 31 | tf.reset_default_graph() 32 | 33 | # Give parameters names to avoid bugs 34 | config_dict = dict() 35 | config_space = config_space[exp.env_name] 36 | for config_name in config_space.keys(): 37 | if config_name in exp.config_names: 38 | config_dict[config_name] = param_dict[config_name] 39 | else: 40 | config_dict[config_name] = config_space[config_name].default_value 41 | 42 | num_hidden_layers = int(config_dict["num_hidden_layers"]) 43 | hidden_layer_width = int(config_dict["hidden_layer_width"]) 44 | act_idx = config_dict["act_idx"] 45 | model_learning_rate = config_dict["model_learning_rate"] 46 | model_weight_decay = config_dict["model_weight_decay"] 47 | model_opt_idx = config_dict["model_opt_idx"] 48 | num_cem_iters = int(config_dict["num_cem_iters"]) 49 | cem_popsize = int(config_dict["cem_popsize"]) 50 | cem_alpha = config_dict["cem_alpha"] 51 | num_cem_elites = int(config_dict["cem_elites_ratio"]*cem_popsize) 52 | model_train_epoch = int(config_dict["model_train_epoch"]) 53 | plan_hor = int(config_dict["plan_hor"]) 54 | 55 | 56 | # Instantiation of parameter vector 57 | 58 | if exp.opt_type == 'CEM': 59 | overrides = [ 60 | ("ctrl_cfg.opt_cfg.cfg.max_iters", num_cem_iters), 61 | ("ctrl_cfg.opt_cfg.cfg.popsize", cem_popsize), 62 | ("ctrl_cfg.opt_cfg.cfg.num_elites", num_cem_elites), 63 | ("ctrl_cfg.opt_cfg.cfg.alpha", cem_alpha), 64 | ("ctrl_cfg.opt_cfg.plan_hor", plan_hor), 65 | ("ctrl_cfg.prop_cfg.model_train_cfg.epochs", model_train_epoch) 66 | ] 67 | elif exp.opt_type == 'Random': 68 | overrides = [ 69 | ("ctrl_cfg.opt_cfg.cfg.popsize", cem_popsize), 70 | ("ctrl_cfg.opt_cfg.plan_hor", plan_hor), 71 | ("ctrl_cfg.prop_cfg.model_train_cfg.epochs", model_train_epoch) 72 | ] 73 | else: 74 | print(exp.opt_type) 75 | raise(NotImplementedError("Given control type %s is unknown" %exp.opt_type)) 76 | 77 | cfg, cfg_module = exp.cfg_creator(overrides) 78 | 79 | def nn_constructor(model_init_cfg): 80 | model = get_required_argument(model_init_cfg, "model_class", "Must provide model class")(DotMap( 81 | name="model", 82 | num_networks=get_required_argument(model_init_cfg, "num_nets", "Must provide ensemble size"), 83 | sess=cfg_module.SESS, load_model=model_init_cfg.get("load_model", False), 84 | model_dir=model_init_cfg.get("model_dir", None) 85 | )) 86 | if not model_init_cfg.get("load_model", False): 87 | model.add(FC( 88 | hidden_layer_width, input_dim=cfg_module.MODEL_IN, 89 | activation=exp.activation_fns[act_idx], weight_decay=model_weight_decay 90 | )) 91 | for _ in range(num_hidden_layers - 1): 92 | model.add(FC( 93 | hidden_layer_width, activation=exp.activation_fns[act_idx], weight_decay=model_weight_decay 94 | )) 95 | model.add(FC(cfg_module.MODEL_OUT, weight_decay=model_weight_decay)) 96 | model.finalize(exp.model_optimizers[model_opt_idx], {"learning_rate": model_learning_rate}) 97 | return model 98 | 99 | cfg.ctrl_cfg.prop_cfg.model_init_cfg.model_constructor = nn_constructor 100 | 101 | # Build up model 102 | policy = MPC(cfg.ctrl_cfg) 103 | 104 | return policy 105 | 106 | return translate_into_model -------------------------------------------------------------------------------- /dmbrl/modeling/utils/TensorStandardScaler.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | 9 | class TensorStandardScaler: 10 | """Helper class for automatically 
normalizing inputs into the network.
11 |     """
12 |     def __init__(self, x_dim):
13 |         """Initializes a scaler.
14 |
15 |         Arguments:
16 |             x_dim (int): The dimensionality of the inputs into the scaler.
17 |
18 |         Returns: None.
19 |         """
20 |         self.fitted = False
21 |         with tf.variable_scope("Scaler"):
22 |             self.mu = tf.get_variable(
23 |                 name="scaler_mu", shape=[1, x_dim], initializer=tf.constant_initializer(0.0),
24 |                 trainable=False
25 |             )
26 |             self.sigma = tf.get_variable(
27 |                 name="scaler_std", shape=[1, x_dim], initializer=tf.constant_initializer(1.0),
28 |                 trainable=False
29 |             )
30 |
31 |         self.cached_mu, self.cached_sigma = np.zeros([1, x_dim]), np.ones([1, x_dim])  # [1, x_dim] so load_cache() always matches the variables' shape
32 |
33 |     def fit(self, data):
34 |         """Runs two ops, one for assigning the mean of the data to the internal mean, and
35 |         another for assigning the standard deviation of the data to the internal standard deviation.
36 |         This function must be called within a 'with <session>.as_default()' block.
37 |
38 |         Arguments:
39 |             data (np.ndarray): A numpy array containing the input data.
40 |
41 |         Returns: None.
42 |         """
43 |         mu = np.mean(data, axis=0, keepdims=True)
44 |         sigma = np.std(data, axis=0, keepdims=True)
45 |         sigma[sigma < 1e-12] = 1.0
46 |
47 |         self.mu.load(mu)
48 |         self.sigma.load(sigma)
49 |         self.fitted = True
50 |         self.cache()
51 |
52 |     def transform(self, data):
53 |         """Transforms the input matrix data using the parameters of this scaler.
54 |
55 |         Arguments:
56 |             data (np.array): A numpy array containing the points to be transformed.
57 |
58 |         Returns: (np.array) The transformed dataset.
59 |         """
60 |         return (data - self.mu) / self.sigma
61 |
62 |     def inverse_transform(self, data):
63 |         """Undoes the transformation performed by this scaler.
64 |
65 |         Arguments:
66 |             data (np.array): A numpy array containing the points to be transformed.
67 |
68 |         Returns: (np.array) The de-normalized dataset.
69 |         """
70 |         return self.sigma * data + self.mu
71 |
72 |     def get_vars(self):
73 |         """Returns a list of variables managed by this object.
74 |
75 |         Returns: (list) The list of variables.
76 |         """
77 |         return [self.mu, self.sigma]
78 |
79 |     def set_vars(self, mu, sigma):
80 |         """Set the mu and sigma according to the given values.
81 |         """
82 |         self.mu.load(mu)
83 |         self.sigma.load(sigma)
84 |         self.cache()
85 |
86 |     def cache(self):
87 |         """Caches current values of this scaler.
88 |
89 |         Returns: None.
90 |         """
91 |         self.cached_mu = self.mu.eval()
92 |         self.cached_sigma = self.sigma.eval()
93 |
94 |     def load_cache(self):
95 |         """Loads values from the cache.
96 |
97 |         Returns: None.
98 | """ 99 | self.mu.load(self.cached_mu) 100 | self.sigma.load(self.cached_sigma) 101 | -------------------------------------------------------------------------------- /dmbrl/modeling/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .TensorStandardScaler import TensorStandardScaler -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mbrl 2 | channels: 3 | - menpo 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - attrs=19.3.0=py_0 10 | - backcall=0.1.0=py36_0 11 | - blas=1.0=mkl 12 | - bleach=3.1.0=py36_0 13 | - ca-certificates=2021.5.25=h06a4308_1 14 | - certifi=2021.5.30=py36h06a4308_0 15 | - cffi=1.12.3=py36h2e261b9_0 16 | - cudatoolkit=10.0.130=0 17 | - dbus=1.13.12=h746ee38_0 18 | - defusedxml=0.6.0=py_0 19 | - entrypoints=0.3=py36_0 20 | - expat=2.2.6=he6710b0_0 21 | - fontconfig=2.13.0=h9420a91_0 22 | - freetype=2.9.1=h8a8886c_1 23 | - gettext=0.19.8.1=hc5be6a0_1002 24 | - glfw3=3.2.1=0 25 | - glib=2.63.1=h5a9c865_0 26 | - gmp=6.1.2=h6c8ec71_1 27 | - gst-plugins-base=1.14.0=hbbd80ab_1 28 | - gstreamer=1.14.0=hb453b48_1 29 | - icu=58.2=h9c2bf20_1 30 | - importlib_metadata=1.5.0=py36_0 31 | - intel-openmp=2019.4=243 32 | - ipykernel=5.1.4=py36h39e3cac_0 33 | - ipython=7.12.0=py36h5ca1d4c_0 34 | - ipython_genutils=0.2.0=py36_0 35 | - ipywidgets=7.5.1=py_0 36 | - jedi=0.16.0=py36_0 37 | - jinja2=2.11.1=py_0 38 | - jpeg=9b=h024ee3a_2 39 | - jsonschema=3.2.0=py36_0 40 | - jupyter=1.0.0=py36_7 41 | - jupyter_client=5.3.4=py36_0 42 | - jupyter_console=6.1.0=py_0 43 | - jupyter_core=4.6.1=py36_0 44 | - libedit=3.1.20181209=hc058e9b_0 45 | - libffi=3.2.1=hd88cf55_4 46 | - libgcc-ng=9.1.0=hdf63c60_0 47 | - libgfortran-ng=7.3.0=hdf63c60_0 48 | - libgpg-error=1.36=he1b5a44_0 49 | - libpng=1.6.37=hbc83047_0 50 | - libsodium=1.0.16=h1bed415_0 51 | - libstdcxx-ng=9.1.0=hdf63c60_0 52 | - libtiff=4.0.10=h2733197_2 53 | - libuuid=1.0.3=h1bed415_2 54 | - libxcb=1.13=h1bed415_1 55 | - libxml2=2.9.9=hea5a465_1 56 | - llvm=3.3=0 57 | - markupsafe=1.1.1=py36h7b6447c_0 58 | - mesa=10.5.4=0 59 | - mistune=0.8.4=py36h7b6447c_0 60 | - mkl=2019.4=243 61 | - mkl-service=2.0.2=py36h7b6447c_0 62 | - mkl_fft=1.0.14=py36ha843d7b_0 63 | - mkl_random=1.0.2=py36hd81dba3_0 64 | - nbconvert=5.6.1=py36_0 65 | - nbformat=5.0.4=py_0 66 | - ncurses=6.1=he6710b0_1 67 | - ninja=1.9.0=py36hfd86e86_0 68 | - notebook=6.0.3=py36_0 69 | - numpy-base=1.19.1=py36hfa32c7d_0 70 | - olefile=0.46=py36_0 71 | - openssl=1.1.1k=h27cfd23_0 72 | - osmesa=12.2.2.dev=0 73 | - pandoc=2.2.3.2=0 74 | - pandocfilters=1.4.2=py36_1 75 | - parso=0.6.1=py_0 76 | - patchelf=0.9=he6710b0_3 77 | - pcre=8.43=he6710b0_0 78 | - pexpect=4.8.0=py36_0 79 | - pickleshare=0.7.5=py36_0 80 | - pillow=6.1.0=py36h34e0f95_0 81 | - pip=19.2.2=py36_0 82 | - prometheus_client=0.7.1=py_0 83 | - prompt_toolkit=3.0.3=py_0 84 | - ptyprocess=0.6.0=py36_0 85 | - pycparser=2.19=py36_0 86 | - pygments=2.5.2=py_0 87 | - pyqt=5.9.2=py36h05f1152_2 88 | - pyrsistent=0.15.7=py36h7b6447c_0 89 | - python=3.6.9=h265db76_0 90 | - pyzmq=18.1.1=py36he6710b0_0 91 | - qt=5.9.7=h5867ecd_1 92 | - qtconsole=4.6.0=py_1 93 | - readline=7.0=h7b6447c_5 94 | - send2trash=1.5.0=py36_0 95 | - setuptools=41.0.1=py36_0 96 | - sip=4.19.8=py36hf484d3e_0 97 | - six=1.12.0=py36_0 98 | - sqlite=3.29.0=h7b6447c_0 99 | - system=5.8=2 100 | - terminado=0.8.3=py36_0 
101 | - testpath=0.4.4=py_0 102 | - tk=8.6.8=hbc83047_0 103 | - tornado=6.0.3=py36h7b6447c_3 104 | - traitlets=4.3.3=py36_0 105 | - wcwidth=0.1.8=py_0 106 | - webencodings=0.5.1=py36_1 107 | - wheel=0.33.4=py36_0 108 | - widgetsnbextension=3.5.1=py36_0 109 | - xz=5.2.4=h14c3975_4 110 | - zeromq=4.3.1=he6710b0_3 111 | - zipp=2.2.0=py_0 112 | - zlib=1.2.11=h7b6447c_3 113 | - zstd=1.3.7=h0b5b093_0 114 | - pip: 115 | - absl-py==0.9.0 116 | - astor==0.8.1 117 | - astunparse==1.6.3 118 | - cachetools==4.0.0 119 | - chardet==3.0.4 120 | - cloudpickle==1.2.1 121 | - configspace==0.4.10 122 | - cycler==0.10.0 123 | - cython==0.29.13 124 | - dataclasses==0.7 125 | - decorator==4.4.0 126 | - dm-env==1.2 127 | - dm-tree==0.1.5 128 | - dotmap==1.2.20 129 | - future==0.16.0 130 | - gast==0.2.2 131 | - glfw==1.8.3 132 | - google-auth==1.11.2 133 | - google-auth-oauthlib==0.4.1 134 | - google-pasta==0.2.0 135 | - googledrivedownloader==0.4 136 | - gpflow==1.1.0 137 | - grpcio==1.27.2 138 | - gym==0.14.0 139 | - h5py==2.10.0 140 | - hpbandster==0.7.4 141 | - idna==2.8 142 | - imageio==2.5.0 143 | - isodate==0.6.0 144 | - joblib==0.13.2 145 | - keras-applications==1.0.8 146 | - keras-preprocessing==1.1.2 147 | - kiwisolver==1.1.0 148 | - labmaze==1.0.3 149 | - lockfile==0.12.2 150 | - lxml==4.5.2 151 | - markdown==3.2.1 152 | - matplotlib==3.1.1 153 | - more-itertools==8.3.0 154 | - mujoco-py==2.0.2.5 155 | - multipledispatch==0.6.0 156 | - netifaces==0.10.9 157 | - networkx==2.3 158 | - numpy==1.17.4 159 | - oauthlib==3.1.0 160 | - omegaconf==2.0.0 161 | - opencv-python==4.4.0.42 162 | - opt-einsum==3.2.1 163 | - packaging==20.4 164 | - pandas==0.25.1 165 | - patsy==0.5.1 166 | - pluggy==0.13.1 167 | - plyfile==0.7 168 | - protobuf==3.10.0 169 | - py==1.8.1 170 | - pyasn1==0.4.8 171 | - pyasn1-modules==0.2.8 172 | - pyglet==1.3.2 173 | - pyopengl==3.1.5 174 | - pyparsing==2.4.2 175 | - pyro4==4.80 176 | - pytest==5.4.2 177 | - python-dateutil==2.8.0 178 | - pytz==2019.2 179 | - pyyaml==5.3.1 180 | - rdflib==4.2.2 181 | - requests==2.22.0 182 | - requests-oauthlib==1.3.0 183 | - rsa==4.0 184 | - scikit-learn==0.21.3 185 | - scipy==1.3.1 186 | - seaborn==0.11.0 187 | - serpent==1.30.2 188 | - statsmodels==0.11.1 189 | - tensorboard==1.15.0 190 | - tensorboard-plugin-wit==1.6.0.post3 191 | - tensorflow==1.15.0 192 | - tensorflow-estimator==1.15.1 193 | - termcolor==1.1.0 194 | - tqdm==4.19.4 195 | - typing==3.7.4.1 196 | - typing-extensions==3.7.4.2 197 | - urllib3==1.25.3 198 | - werkzeug==1.0.0 199 | - wrapt==1.12.1 200 | prefix: /home/zhangb/anaconda3/envs/mbrl 201 | -------------------------------------------------------------------------------- /learned_schedule/reacher_rs_model_train.json: -------------------------------------------------------------------------------- 1 | [{"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 
0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 
0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 
0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}, {"model_learning_rate": 0.00011291116329538502, "model_train_epoch": 4, "model_weight_decay": 4.611330762141144e-05}] -------------------------------------------------------------------------------- /mbexp.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import os 6 | import argparse 7 | import pprint 8 | 9 | from dotmap import DotMap 10 | 11 | from dmbrl.misc.MBExp import MBExperiment 12 | from dmbrl.controllers.MPC import MPC 13 | from dmbrl.config import create_config 14 | 15 | 16 | def main(env, ctrl_type, ctrl_args, overrides, logdir): 17 | ctrl_args = DotMap(**{key: val for (key, val) in ctrl_args}) 18 | cfg = create_config(env, ctrl_type, ctrl_args, overrides, logdir)[0] 19 | cfg.pprint() 20 | 21 | if ctrl_type == "MPC": 22 | cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg) 23 | exp = MBExperiment(cfg.exp_cfg) 24 | 25 | os.makedirs(exp.logdir) 26 | with open(os.path.join(exp.logdir, "config.txt"), "w") as f: 27 | f.write(pprint.pformat(cfg.toDict())) 28 | 29 | exp.run_experiment() 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('-env', type=str, required=True, 35 | help='Environment name: select from [cartpole, reacher, pusher, halfcheetah]') 36 | parser.add_argument('-ca', '--ctrl_arg', action='append', nargs=2, default=[], 37 | help='Controller arguments, see https://github.com/kchua/handful-of-trials#controller-arguments') 38 | parser.add_argument('-o', '--override', action='append', nargs=2, default=[], 39 | help='Override default parameters, see https://github.com/kchua/handful-of-trials#overrides') 40 | parser.add_argument('-logdir', type=str, default='log', 41 | help='Directory to which results will be logged (default: ./log)') 42 | args = parser.parse_args() 43 
| 44 |     main(args.env, "MPC", args.ctrl_arg, args.override, args.logdir) 45 | -------------------------------------------------------------------------------- /pbt-bt-mbexp.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import argparse 6 | 7 | from dotmap import DotMap 8 | 9 | from dmbrl.misc.MBwPBTBTExp import MBWithPBTBTExperiment 10 | from dmbrl.controllers.MPC import MPC 11 | from dmbrl.config import create_config 12 | 13 | import tensorflow as tf 14 | import numpy as np 15 | 16 | def create_cfg_creator(env, ctrl_type, ctrl_args, base_overrides, logdir): 17 | ctrl_args = DotMap(**{key: val for (key, val) in ctrl_args}) 18 | 19 | def cfg_creator(additional_overrides=None): 20 | if additional_overrides is not None: 21 | return create_config(env, ctrl_type, ctrl_args, base_overrides + additional_overrides, logdir) 22 | return create_config(env, ctrl_type, ctrl_args, base_overrides, logdir) 23 | 24 | return cfg_creator 25 | 26 | 27 | def main(args): 28 | cfg_creator = create_cfg_creator(args.env, args.ctrl_type, args.ctrl_arg, args.override, args.logdir) 29 | cfg = cfg_creator()[0] 30 | cfg.pprint() 31 | 32 | if args.ctrl_type == "MPC": 33 | cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg) 34 | 35 | exp = MBWithPBTBTExperiment(cfg.exp_cfg, cfg_creator, args) 36 | exp.run_experiment() 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('-env', type=str, required=True, 42 | help='Environment name: select from [cartpole, reacher, pusher, halfcheetah, halfcheetah_v3]') 43 | parser.add_argument('-ca', '--ctrl_arg', action='append', nargs=2, default=[], 44 | help='Controller arguments, see https://github.com/kchua/handful-of-trials#controller-arguments') 45 | parser.add_argument('-o', '--override', action='append', nargs=2, default=[], 46 | help='Override default parameters, see https://github.com/kchua/handful-of-trials#overrides') 47 | parser.add_argument('-logdir', type=str, default='log', 48 | help='Directory to which results will be logged (default: ./log)') 49 | parser.add_argument('-ctrl_type', type=str, default='MPC', 50 | help='Control type to be applied (default: MPC)') 51 | # Parser for running dynamic scheduler on cluster 52 | parser.add_argument('-config_names', type=str, default=["model_learning_rate"], nargs="+", 53 | help='Specify which hyperparameters to optimize') 54 | parser.add_argument('-seed', type=int, default=0, 55 | help='Specify the random seed of the experiment') 56 | parser.add_argument('-worker_id', type=int, default=0, 57 | help='The worker id, e.g. using SLURM ARRAY JOB ID') 58 | parser.add_argument('-worker', action='store_true', 59 | help='Run this process as a worker; otherwise a new controller is started') 60 | parser.add_argument('-sample_from_percent', type=float, default=0.2, 61 | help='Sample exploitation targets from this top fraction of the population') 62 | parser.add_argument('-resample_if_not_in_percent', type=float, default=0.8, 63 | help='Resample if the member is not within this top fraction of the population') 64 | parser.add_argument('-resample_probability', type=float, default=0.25, 65 | help='Probability of an exploited member resampling configurations randomly') 66 | parser.add_argument('-resample_prob_decay', type=float, default=1, 67 | help='Decay factor applied to resample_if_not_in_percent') 68 | parser.add_argument('-max_steps', type=int, default=60*40, 69 | help='Maximum number of steps to take') 70 | parser.add_argument('-delta_t', type=int, default=30, 71 | help='Perform backtracking once every delta_t steps') 72 | parser.add_argument('-tolerance', type=float, default=0.2, 73 | help='Tolerated relative performance drop before backtracking') 74 | parser.add_argument('-elite_ratio', type=float, default=0.125, 75 | help='Fraction of the population kept as elites') 76 | 77 | args = parser.parse_args() 78 | print(args) 79 | # Set the random seeds of the experiment 80 | tf.set_random_seed(args.seed) 81 | np.random.seed(args.seed) 82 | main(args) 83 |
-------------------------------------------------------------------------------- /pbt-mbexp.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import argparse 6 | 7 | from dotmap import DotMap 8 | 9 | from dmbrl.misc.MBwPBTExp import MBWithPBTExperiment 10 | from dmbrl.controllers.MPC import MPC 11 | from dmbrl.config import create_config 12 | 13 | import tensorflow as tf 14 | import numpy as np 15 | 16 | def create_cfg_creator(env, ctrl_type, ctrl_args, base_overrides, logdir): 17 | ctrl_args = DotMap(**{key: val for (key, val) in ctrl_args}) 18 | 19 | def cfg_creator(additional_overrides=None): 20 | if additional_overrides is not None: 21 | return create_config(env, ctrl_type, ctrl_args, base_overrides + additional_overrides, logdir) 22 | return create_config(env, ctrl_type, ctrl_args, base_overrides, logdir) 23 | 24 | return cfg_creator 25 | 26 | 27 | def main(args): 28 | cfg_creator = create_cfg_creator(args.env, args.ctrl_type, args.ctrl_arg, args.override, args.logdir) 29 | cfg = cfg_creator()[0] 30 | cfg.pprint() 31 | 32 | if args.ctrl_type == "MPC": 33 | cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg) 34 | 35 | exp = MBWithPBTExperiment(cfg.exp_cfg, cfg_creator, args) 36 | exp.run_experiment() 37 | 38 | 39 | if __name__ == "__main__": 40 | print("Exp start!") 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('-env', type=str, required=True, 43 | help='Environment name: select from [cartpole, reacher, pusher, halfcheetah, halfcheetah_v3]') 44 | parser.add_argument('-ca', '--ctrl_arg', action='append', nargs=2, default=[], 45 | help='Controller arguments, see https://github.com/kchua/handful-of-trials#controller-arguments') 46 | parser.add_argument('-o', '--override', action='append', nargs=2, default=[], 47 | help='Override default parameters, see https://github.com/kchua/handful-of-trials#overrides') 48 | parser.add_argument('-logdir', type=str, default='log', 49 | help='Directory to which results will be logged (default: ./log)') 50 | parser.add_argument('-ctrl_type', type=str, default='MPC', 51 | help='Control type to be applied (default: MPC)') 52 | # Parser for running dynamic scheduler on cluster 53 | parser.add_argument('-config_names', type=str, default=["model_learning_rate"], nargs="+", 54 | help='Specify which hyperparameters to optimize') 55 | parser.add_argument('-seed', type=int, default=0, 56 | help='Specify the random seed of the experiment') 57 | parser.add_argument('-worker_id', type=int, default=0, 58 | help='The worker id, e.g. using SLURM ARRAY JOB ID') 59 | parser.add_argument('-worker', action='store_true', 60 | help='Run this process as a worker; otherwise a new controller is started') 61 | parser.add_argument('-sample_from_percent', type=float, default=0.2, 62 | help='Sample exploitation targets from this top fraction of the population') 63 | parser.add_argument('-resample_if_not_in_percent', type=float, default=0.8, 64 | help='Resample if the member is not within this top fraction of the population') 65 | parser.add_argument('-resample_probability', type=float, default=0.25, 66 | help='Probability of an exploited member resampling configurations randomly') 67 | parser.add_argument('-resample_prob_decay', type=float, default=1, 68 | help='Decay factor applied to resample_if_not_in_percent') 69 | parser.add_argument('-not_copy_data', action='store_true', 70 | help='If set, do not copy training data to new trials') 71 | 72 | args = parser.parse_args() 73 | print(args) 74 | # Set the random seeds of the experiment 75 | tf.set_random_seed(args.seed) 76 | np.random.seed(args.seed) 77 | main(args) 78 |
-------------------------------------------------------------------------------- /pbt/__init__.py: -------------------------------------------------------------------------------- 1 | from pbt.controller import PBTController 2 | from pbt.worker import PBTWorker 3 | 4 | __all__ = ['PBTController', 'PBTWorker'] 5 |
-------------------------------------------------------------------------------- /pbt/backtrack_controller.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import random 4 | import logging 5 | import os 6 | import sys 7 | from time import sleep 8 | 9 | from pbt.exploitation import Truncation 10 | from pbt.exploration import Perturb 11 | from pbt.garbage_collector import GarbageCollector 12 | from pbt.network import ControllerDaemon 13 | from pbt.population import Population 14 | from pbt.backtrack_scheduler import BacktrackScheduler 15 | from pbt.tqdm_logger import TqdmLoggingHandler 16 | 17 | class PBTwBT_Controller: 18 | def __init__( 19 | self, pop_size, start_hyperparameters, exploitation=Truncation(), 20 | exploration=Perturb(), ready=lambda: True, 21 | exploration_bt=lambda x: x, 22 | stop=lambda iterations, _: iterations >= 100.0, 23 | data_path=os.getcwd(), results_path=os.getcwd(), 24 | max_steps=sys.maxsize, delta_t=30, tolerance=0.2, elite_ratio=0.125): 25 | 26 | self.logger = logging.getLogger('pbt') 27 | self.logger.setLevel(logging.DEBUG) 28 | self.logger.addHandler(TqdmLoggingHandler()) 29 | self.daemon = None 30 | self.population = Population(pop_size, stop, results_path, backtracking=True, elite_ratio=elite_ratio) 31 | self.data_path = data_path 32 | self.scheduler = BacktrackScheduler( 33 | self.population, start_hyperparameters, exploitation, exploration, 34 | exploration_bt, delta_t, tolerance) 35 | self.garbage_collector = GarbageCollector(data_path) 36 | self.ready = ready 37 | self.max_steps = max_steps 38 | self.total_steps = 0 39 | 40 | self.workers = {} 41 | self._done = False 42 | 43 | 
self.logger.info( 44 | f'Started controller with parameters: ' + 45 | f'population size: {pop_size}, ' 46 | f'data_path: {data_path}, ' 47 | f'results_path: {results_path}') 48 | 49 | def start_daemon(self): 50 | self.logger.info('Starting daemon.') 51 | self.daemon = ControllerDaemon(self) 52 | self.daemon.start() 53 | 54 | def register_worker(self, worker): 55 | self.logger.debug(f'Worker {worker.worker_id} registered.') 56 | self.workers[worker.worker_id] = worker 57 | 58 | def request_trial(self): 59 | trial = self.scheduler.get_trial() 60 | return trial.to_tuple() 61 | 62 | def send_evaluation(self, member_id, score): 63 | self.logger.debug( 64 | f'Receiving evaluation for member {member_id}: {score}') 65 | # TODO: Some sort of ready function? 66 | min_timestep = self.population.get_min_time_step() 67 | # exclude all elite members (PBT-BT only) 68 | elites = self.population.get_elites() 69 | # self.garbage_collector.collect(member_id, min_timestep, elites) 70 | trial = self.population.update(member_id, score) 71 | self.scheduler.update_exploration(trial) 72 | self.total_steps += 1 73 | self.logger.debug( 74 | f'Current total step is: {self.total_steps}') 75 | if self.population.is_done(): 76 | self.logger.info('Nothing more to do. Shutting down.') 77 | self._shut_down_workers() 78 | sleep(10) 79 | if self.daemon: 80 | self.daemon.shut_down() 81 | 82 | if self.total_steps >= self.max_steps: 83 | self.logger.info('Reached the maximum number of steps. Shutting down.') 84 | self._shut_down_workers() 85 | sleep(10) 86 | if self.daemon: 87 | self.daemon.shut_down() 88 | 89 | def _shut_down_workers(self): 90 | for worker in self.workers.values(): 91 | worker.stop() 92 |
-------------------------------------------------------------------------------- /pbt/backtrack_scheduler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pbt.population import NoTrial, Trial 4 | from pbt.tqdm_logger import TqdmLoggingHandler 5 | import random 6 | from itertools import product 7 | 8 | class BacktrackScheduler: 9 | def __init__( 10 | self, population, start_hyperparameters, exploitation, exploration, exploration_bt, delta_t=30, tolerance=0.2): 11 | """Initialize a scheduler with a backtracking mechanism. 12 | """ 13 | self.logger = logging.getLogger('pbt') 14 | self.logger.setLevel(logging.DEBUG) 15 | self.population = population 16 | self.exploitation = exploitation 17 | self.exploration = exploration 18 | self.exploration_bt = exploration_bt 19 | 20 | self.start_hyperparameters = start_hyperparameters 21 | self.delta_t = delta_t 22 | self.tolerance = tolerance 23 | self.record = {} # For tabu search 24 | 25 | def get_trial(self): 26 | self.logger.debug('Trial requested.') 27 | member = self.population.get_next_member() 28 | if not member: 29 | self.logger.debug('No trial ready.') 30 | return NoTrial() 31 | 32 | # initial run with starting hyperparameters 33 | if member.time_step == 0: 34 | start_hyperparameters = {cfg_name : self.start_hyperparameters[cfg_name]() 35 | for cfg_name in self.start_hyperparameters.keys()} 36 | trial = Trial( 37 | member.member_id, -1, 0, -1, start_hyperparameters) 38 | self.population.save_trial(trial) 39 | self.logger.debug(f'Returning first trial {trial}.') 40 | return trial 41 | 42 | if member.time_step % self.delta_t == 0: 43 | # check whether performance dropped by more than the tolerated percentage 44 | self.logger.debug(f'Generating trial for member {member.member_id} with time step {member.time_step} with BT.') 45 | self.logger.debug(f'member {member.member_id} actual time step is {member._actual_time_step}.') 46 | elites = self.population.get_elites() 47 | model_id, model_time_step = self.backtracking_exploitation(member, elites) 48 | hyperparameters = self.population.get_hyperparameters_by_time_step(model_id, model_time_step) 49 | if model_id == member.member_id and model_time_step == member.time_step - 1: 50 | self.logger.debug(f'Staying with current model {model_id}.') 51 | else: 52 | self.logger.debug(f'Backtracking to model {model_id}, time step {model_time_step}.') 53 | member.set_actual_time_step(model_time_step) 54 | self.logger.debug(f'Set model {model_id} actual time step to {model_time_step}.') 55 | # TODO: Replace it with backtracking exploration 56 | hyperparameters = self.exploration_bt(hyperparameters, model_id, model_time_step) 57 | self.logger.debug(f'Using exploration. New: {hyperparameters}') 58 | trial = Trial( 59 | member.member_id, model_id, member.time_step, model_time_step, 60 | hyperparameters) 61 | self.population.save_trial(trial) 62 | return trial 63 | 64 | # Jointly do standard PBT with elites 65 | self.logger.debug(f'Generating trial for member {member.member_id} with time step {member.time_step}.') 66 | self.logger.debug(f'member {member.member_id} actual time step is {member._actual_time_step}.') 67 | scores = self.population.get_scores() 68 | # self.logger.debug(f'Collected all scores.') 69 | # model_id indicates the model that we want to copy 70 | model_id = self.exploitation(member.member_id, scores) 71 | # model_time_step shows the timestep of the model to be copied 72 | model_time_step = self.population.get_latest_time_step(model_id) - 1 73 | # actual_time_step = self.population.get_actual_time_step_by_member_id(model_id) 74 | # self.logger.debug(f'Safely get timestep') 75 | hyperparameters = self.population.get_hyperparameters_by_time_step(model_id, model_time_step) 76 | # self.logger.debug(f'Safely get hyperparameters') 77 | if model_id != member.member_id: 78 | self.logger.debug(f'Copying model {model_id} at time step {model_time_step}.') 79 | member.set_actual_time_step(model_time_step) 80 | self.logger.debug(f'Set model {model_id} actual time step to {model_time_step}.') 81 | hyperparameters = self.exploration(hyperparameters) 82 | self.logger.debug(f'Using exploration. 
New: {hyperparameters}') 83 | else: 84 | self.logger.debug(f'Staying with current model {model_id}.') 85 | trial = Trial( 86 | member.member_id, model_id, member.time_step, model_time_step, 87 | hyperparameters) 88 | self.population.save_trial(trial) 89 | # self.logger.debug(f'Safely saved trial') 90 | return trial 91 | 92 | def update_exploration(self, trial): 93 | self.exploration.update(trial) 94 | 95 | # For PBT-BT only 96 | def backtracking_exploitation(self, member, elites): 97 | # TODO: handle zero division 98 | percentage_change = (member.get_last_score() - elites[-1].score) / abs(elites[-1].score) 99 | self.logger.debug(f'Performance changed {percentage_change}.') 100 | if percentage_change < - self.tolerance: 101 | trial = random.choice(elites) 102 | else: 103 | trial = member.get_last_trial() 104 | 105 | return trial.member_id, trial.time_step 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /pbt/controller.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import random 4 | import logging 5 | import os 6 | import sys 7 | from time import sleep 8 | 9 | from pbt.exploitation import Truncation 10 | from pbt.exploration import Perturb 11 | from pbt.garbage_collector import GarbageCollector 12 | from pbt.network import ControllerDaemon 13 | from pbt.population import Population 14 | from pbt.scheduler import Scheduler 15 | from pbt.tqdm_logger import TqdmLoggingHandler 16 | 17 | class PBTController: 18 | def __init__( 19 | self, pop_size, start_hyperparameters, exploitation=Truncation(), 20 | exploration=Perturb(), ready=lambda: True, 21 | stop=lambda iterations, _: iterations >= 100.0, 22 | data_path=os.getcwd(), results_path=os.getcwd(), 23 | max_steps=sys.maxsize): 24 | 25 | self.logger = logging.getLogger('pbt') 26 | self.logger.setLevel(logging.DEBUG) 27 | self.logger.addHandler(TqdmLoggingHandler()) 28 | self.daemon = None 29 | self.population = Population(pop_size, stop, results_path) 30 | self.data_path = data_path 31 | self.scheduler = Scheduler( 32 | self.population, start_hyperparameters, exploitation, exploration) 33 | self.garbage_collector = GarbageCollector(data_path) 34 | self.ready = ready 35 | self.max_steps = max_steps 36 | self.total_steps = 0 37 | 38 | self.workers = {} 39 | self._done = False 40 | 41 | self.logger.info( 42 | f'Started controller with parameters: ' + 43 | f'population size: {pop_size}, ' 44 | f'data_path: {data_path}, ' 45 | f'results_path: {results_path}') 46 | 47 | def start_daemon(self): 48 | self.logger.info('Starting daemon.') 49 | self.daemon = ControllerDaemon(self) 50 | self.daemon.start() 51 | 52 | def register_worker(self, worker): 53 | self.logger.debug(f'Worker {worker.worker_id} registered.') 54 | self.workers[worker.worker_id] = worker 55 | 56 | def request_trial(self): 57 | trial = self.scheduler.get_trial() 58 | return trial.to_tuple() 59 | 60 | def send_evaluation(self, member_id, score): 61 | self.logger.debug( 62 | f'Receiving evaluation for member {member_id}: {score}') 63 | # TODO: Some sort of ready function? 
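# Note (descriptive comment, not in the original source): the population-wide
# minimum time step is a safe garbage-collection horizon, since exploitation
# never copies a checkpoint older than it. Plain PBT keeps no elites, hence
# the empty list passed to the collector below.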
64 | min_timestep = self.population.get_min_time_step() 65 | # elites to exclude from garbage collection (PBT-BT only; empty for plain PBT) 66 | elites = [] 67 | self.logger.info(f'min time step is {min_timestep}') 68 | self.garbage_collector.collect(member_id, min_timestep, elites) 69 | trial = self.population.update(member_id, score) 70 | self.scheduler.update_exploration(trial) 71 | self.total_steps += 1 72 | if self.population.is_done(): 73 | self.logger.info('Nothing more to do. Shutting down.') 74 | self._shut_down_workers() 75 | sleep(10) 76 | if self.daemon: 77 | self.daemon.shut_down() 78 | 79 | if self.total_steps >= self.max_steps: 80 | self.logger.info('Reached the maximum number of steps. Shutting down.') 81 | self._shut_down_workers() 82 | sleep(10) 83 | if self.daemon: 84 | self.daemon.shut_down() 85 | 86 | def _shut_down_workers(self): 87 | for worker in self.workers.values(): 88 | worker.stop() 89 |
-------------------------------------------------------------------------------- /pbt/exploitation/__init__.py: -------------------------------------------------------------------------------- 1 | from .exploitation_strategy import ExploitationStrategy 2 | from .truncation import Truncation 3 | 4 | __all__ = ['ExploitationStrategy', 'Truncation'] 5 |
-------------------------------------------------------------------------------- /pbt/exploitation/constant.py: -------------------------------------------------------------------------------- 1 | class Constant: 2 | def __init__(self): 3 | """ 4 | Constant mechanism which doesn't exploit any member 5 | """ 6 | pass 7 | 8 | def __call__(self, own_name: str, scores: dict) -> str: 9 | """ 10 | Return the name of the given member 11 | :param own_name: The name of the current agent 12 | :param scores: A dict with names and scores of all agents 13 | :return: The name of the given member, i.e. own_name 14 | """ 15 | return own_name
-------------------------------------------------------------------------------- /pbt/exploitation/exploitation_strategy.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class ExploitationStrategy: 5 | @abc.abstractmethod 6 | def __call__(self, own_name: str, scores: dict) -> str: 7 | """ 8 | This method should implement the exploitation behaviour. 9 | :param own_name: The name of the current agent 10 | :param scores: The names and scores of all agents 11 | :return: The name of the agent to copy 12 | """ 13 | raise NotImplementedError('This method has to be overwritten!') 14 |
-------------------------------------------------------------------------------- /pbt/exploitation/truncation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import operator 3 | import random 4 | 5 | 6 | class Truncation: 7 | def __init__(self, sample_from_percent=0.2, resample_if_not_in_percent=0.8): 8 | """ 9 | A simple truncation mechanism as specified in (Jaderberg et al., 2017) 10 | :param sample_from_percent: Percent of best part of the population 11 | :param resample_if_not_in_percent: Percent of untouched agents 12 | """ 13 | self.sample_from_percent = sample_from_percent 14 | self.resample_if_not_in_percent = resample_if_not_in_percent 15 | 16 | def __call__(self, own_name: str, scores: dict) -> str: 17 | """ 18 | Find a better agent, or return own_name if the agent is good enough. 
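For example, with the defaults and a population of 20, a member outside the best ceil(19 * 0.8) = 16 copies one of the best ceil(19 * 0.2) = 4 members, chosen uniformly at random.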
19 | :param own_name: The name of the current agent 20 | :param scores: A dict with names and scores of all agents 21 | :return: The name of the chosen better agent 22 | """ 23 | if len(scores) == 1: 24 | return own_name 25 | if own_name in self._get_best(self.resample_if_not_in_percent, scores): 26 | return own_name 27 | else: 28 | return random.choice(list( 29 | self._get_best(self.sample_from_percent, scores).keys())) 30 | 31 | def _get_best(self, percent, scores): 32 | sorted_scores = sorted( 33 | scores.items(), reverse=True, key=operator.itemgetter(1)) 34 | last_index = math.ceil((len(scores) - 1) * percent) 35 | return {name: score for name, score in sorted_scores[:last_index]} 36 |
-------------------------------------------------------------------------------- /pbt/exploration/__init__.py: -------------------------------------------------------------------------------- 1 | from .exploration_strategy import ExplorationStrategy 2 | from .perturb import Perturb 3 | from .resample import Resample 4 | from .perturb_and_resample import PerturbAndResample 5 | from pbt.exploration.models.tree_parzen_estimator import TreeParzenEstimator 6 | from .exploration_bt import Exploration_BT 7 | __all__ = [ 8 | 'ExplorationStrategy', 'Perturb', 'Resample', 'PerturbAndResample', 9 | 'TreeParzenEstimator', 'Exploration_BT'] 10 |
-------------------------------------------------------------------------------- /pbt/exploration/constant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Constant: 4 | """ 5 | A simple constant mechanism which returns the same configuration as given. 6 | """ 7 | def __init__(self): 8 | pass 9 | 10 | def __call__(self, hyperparameters: dict) -> dict: 11 | """ 12 | Return the nodes unchanged. 13 | :param hyperparameters: A dict with nodes. 14 | :return: An unchanged copy of the nodes. 15 | """ 16 | result = hyperparameters.copy() 17 | return result 18 | 19 |
-------------------------------------------------------------------------------- /pbt/exploration/exploration_bt.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from pbt.exploration import Resample 4 | import itertools 5 | 6 | class Exploration_BT(): 7 | def __init__( 8 | self, mutations: dict, cs_space: dict, resample_probability: float = 0.25, 9 | boundaries={}): 10 | """ 11 | A backtracking exploration strategy that combines tabu-style perturbation with random resampling. 12 | :param mutations: A dictionary with hyperparameter names and mutations 13 | :param resample_probability: The probability to resample for each call 14 | """ 15 | self.resample_probability = resample_probability 16 | self.resample = Resample(mutations=mutations) 17 | self.records = {} 18 | self.cs_space = cs_space 19 | self.boundaries = boundaries 20 | 21 | def __call__(self, hyperparameters: dict, model_id: int, model_time_step: int) -> dict: 22 | """ 23 | Tabu search over all possible perturbations. 
If all possibilities are tried, then sample randomly 24 | from the configuration space 25 | :param hyperparameters: The nodes to perturb 26 | :return: Perturbed and probably resampled nodes 27 | """ 28 | result = hyperparameters.copy() 29 | num_hyperparameters = len(result) 30 | if (model_id, model_time_step) not in self.records: 31 | self.records[(model_id, model_time_step)] = list(itertools.product([-1,1], repeat=num_hyperparameters)) 32 | random.shuffle(self.records[(model_id, model_time_step)]) 33 | 34 | if random.random() < self.resample_probability or len(self.records[(model_id, model_time_step)]) == 0: 35 | result = self.resample(result) 36 | else: 37 | result = self.perturb(result, model_id, model_time_step) 38 | 39 | return result 40 | 41 | def perturb(self, hyperparameters: dict, model_id: int, model_time_step: int) -> dict: 42 | directions = self.records[(model_id, model_time_step)].pop() 43 | for i, key in enumerate(sorted(hyperparameters)): 44 | temp_value = self.cs_space[key]._inverse_transform(hyperparameters[key]) 45 | temp_value += directions[i] * 0.2 * temp_value 46 | hyperparameters[key] = self.cs_space[key]._transform(temp_value) 47 | self.ensure_boundaries(hyperparameters) 48 | return hyperparameters 49 | 50 | def ensure_boundaries(self, result): 51 | for key in result: 52 | if key not in self.boundaries: 53 | continue 54 | if result[key] < self.boundaries[key][0]: 55 | result[key] = self.boundaries[key][0] 56 | elif result[key] > self.boundaries[key][1]: 57 | result[key] = self.boundaries[key][1] 58 | -------------------------------------------------------------------------------- /pbt/exploration/exploration_strategy.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class ExplorationStrategy: 5 | @abstractmethod 6 | def __call__(self, hyperparameters): 7 | """ 8 | This method should implement the exploration behaviour. 
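Implementations take the current hyperparameters and return a modified copy; stateful strategies (e.g. ModelBased) can additionally override update() to observe finished trials.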
9 | :param hyperparameters: The nodes to explore 10 | :return: The changed nodes 11 | """ 12 | raise NotImplementedError('This method has to be overwritten!') 13 | 14 | def update(self, trial): 15 | pass 16 |
-------------------------------------------------------------------------------- /pbt/exploration/model_based.py: -------------------------------------------------------------------------------- 1 | from pbt.exploration import ExplorationStrategy 2 | 3 | 4 | class ModelBased(ExplorationStrategy): 5 | def __init__(self, model): 6 | self.model = model 7 | 8 | def __call__(self, hyperparameters): 9 | return self.model.sample() 10 | 11 | def update(self, trial): 12 | self.model.update(trial.to_array()) 13 |
-------------------------------------------------------------------------------- /pbt/exploration/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Model 2 | from .tree_parzen_estimator import TreeParzenEstimator 3 | 4 | __all__ = ['Model', 'TreeParzenEstimator'] 5 |
-------------------------------------------------------------------------------- /pbt/exploration/models/config_tree/__init__.py: -------------------------------------------------------------------------------- 1 | from .config_tree import ConfigTree 2 | 3 | __all__ = ['ConfigTree'] 4 |
-------------------------------------------------------------------------------- /pbt/exploration/models/config_tree/config_tree.py: -------------------------------------------------------------------------------- 1 | class ConfigTree: 2 | def __init__(self, root): 3 | self.root = root 4 | self.all_nodes = self._get_all_nodes() 5 | self.hyperparameter_names = self._get_hyperparameter_names() 6 | self._distribute_indices() 7 | 8 | def sample(self): 9 | result = self._sample_from_root_nodes() 10 | for hyperparameter in self.hyperparameter_names: 11 | if hyperparameter not in result: 12 | result[hyperparameter] = None 13 | return result 14 | 15 | def uniform_sample(self): 16 | result = {} 17 | for node in self.all_nodes: 18 | node.uniform_sample(result) 19 | return result 20 | 21 | def evaluate(self, data): 22 | scores = [0.0 for _ in data] 23 | for node in self.root: 24 | node.evaluate(data, scores) 25 | return scores 26 | 27 | def fit(self, data): 28 | for node in self.all_nodes: 29 | node.fit([point[node.index] for point in data]) 30 | 31 | def structural_copy(self): 32 | return ConfigTree([node.structural_copy() for node in self.root]) 33 | 34 | def _distribute_indices(self): 35 | for i, node in enumerate(self.all_nodes): 36 | node.index = i + 3 # Places 0, 1, and 2 hold score, time_step, and improvement 
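# (Descriptive comment, added for clarity: this offset mirrors Trial.to_array,
# which emits [score, time_step, improvement, hp_0, hp_1, ...] with the
# hyperparameters in sorted-name order, the same order as all_nodes.)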
37 | 38 | def _sample_from_root_nodes(self): 39 | result = {} 40 | for node in self.root: 41 | node.sample(result) 42 | return result 43 | 44 | def _get_all_nodes(self): 45 | all_nodes = [] 46 | for node in self.root: 47 | all_nodes.append(node) 48 | all_nodes += node.get_children() 49 | return sorted(all_nodes, key=lambda n: n.name) 50 | 51 | def _get_hyperparameter_names(self): 52 | return [node.name for node in self.all_nodes] 53 |
-------------------------------------------------------------------------------- /pbt/exploration/models/config_tree/nodes/__init__.py: -------------------------------------------------------------------------------- 1 | from .node import Node 2 | from .categorical import Categorical 3 | from .float import Float 4 | from .log_float import LogFloat 5 | from .integer import Integer 6 | 7 | __all__ = ['Node', 'Categorical', 'Float', 'LogFloat', 'Integer'] 8 |
-------------------------------------------------------------------------------- /pbt/exploration/models/config_tree/nodes/categorical.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | import numpy as np 4 | 5 | from pbt.exploration.models.config_tree.nodes import Node 6 | 7 | 8 | class Categorical(Node): 9 | def __init__(self, name, values): 10 | self.name = name 11 | self.keys = list(values.keys()) 12 | self.values = values 13 | self.probabilities = [1.0/len(self.keys) for _ in self.keys] 14 | 15 | def sample(self, result): 16 | value = str(np.random.choice(self.keys, p=self.probabilities)) 17 | result[self.name] = value 18 | if self.values[value] is not None: 19 | self.values[value].sample(result) 20 | 21 | def uniform_sample(self, result): 22 | result[self.name] = str(np.random.choice(self.keys)) 23 | 24 | def evaluate(self, data, scores): 25 | single_scores = [ 26 | self._get_log_density(point[self.name]) for point in data] 27 | for i, single_score in enumerate(single_scores): 28 | scores[i] += single_score 29 | 30 | def _get_log_density(self, value): 31 | return np.log(self.probabilities[self.keys.index(value)]) 32 | 33 | def fit(self, values): 34 | counter = Counter(values) 35 | self.probabilities = [ 36 | counter[key]/len(values) for key in self.keys] 37 | self._add_random_exploration() 38 | 39 | def _add_random_exploration(self, random_probability=0.1): 40 | factor = 1.0 - random_probability 41 | addend = random_probability / len(self.probabilities) 42 | 43 | for i, _ in enumerate(self.probabilities): 44 | self.probabilities[i] *= factor 45 | self.probabilities[i] += addend 46 | 47 | def get_children(self): 48 | result = [] 49 | for child in self.values.values(): 50 | if child is not None: 51 | result.append(child) 52 | result += child.get_children() 53 | return result 54 | 55 | def structural_copy(self): 56 | values = { 57 | name: node.structural_copy() if node is not None else None 58 | for name, node in self.values.items()} 59 | return Categorical(self.name, values) 60 |
-------------------------------------------------------------------------------- /pbt/exploration/models/config_tree/nodes/float.py: -------------------------------------------------------------------------------- 1 | from bokeh.io import show 2 | from bokeh.plotting import Figure 3 | 4 | import numpy as np 5 | from sklearn.neighbors import KernelDensity 6 | 7 | from pbt.exploration.models.config_tree.nodes import Node 8 | 9 | 10 | class Float(Node): 11 | def __init__(self, name, low, high, width=20): 12 | self.low = low 13 | self.high = high 14 | 15 | 
self.kde = KernelDensity((high - low) / width) 16 | 17 | self.kde.fit(np.array([high-low])[:, None]) 18 | 19 | super().__init__(name) 20 | 21 | def sample(self, result): 22 | value = float('inf') 23 | while value < self.low or value > self.high: 24 | value = float(self.kde.sample()) 25 | result[self.name] = value 26 | 27 | def uniform_sample(self, result): 28 | result[self.name] = float(np.random.uniform( 29 | low=self.low, high=self.high)) 30 | 31 | def evaluate(self, data, scores): 32 | single_scores = self.kde.score_samples( 33 | np.array([point[self.name] for point in data])[:, None]) 34 | for i, single_score in enumerate(single_scores): 35 | scores[i] += single_score 36 | 37 | def fit(self, values): 38 | self.kde.fit(np.array(values)[:, None]) 39 | 40 | def get_children(self): 41 | return [] 42 | 43 | def structural_copy(self): 44 | return Float(self.name, self.low, self.high) 45 | 46 | 47 | if __name__ == '__main__': 48 | node = Float('lr', low=1e-5, high=1e-3) 49 | node.fit([0.0005, 1e-5, 1e-5]) 50 | results = [] 51 | for i in range(100000): 52 | r = {} 53 | node.sample(r) 54 | results.append(r['lr']) 55 | 56 | hist, edges = np.histogram(results, bins=100) 57 | 58 | plot = Figure() 59 | plot.quad(top=hist, bottom=0, left=edges[:-1], right=edges[1:]) 60 | show(plot) 61 | -------------------------------------------------------------------------------- /pbt/exploration/models/config_tree/nodes/integer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.neighbors import KernelDensity 3 | 4 | from pbt.exploration.models.config_tree.nodes import Node 5 | 6 | 7 | class Integer(Node): 8 | def __init__(self, name, low, high, width=20): 9 | self.low = low 10 | self.high = high 11 | self.kde = KernelDensity((high - low) / width) 12 | 13 | self.kde.fit(np.array([high-low])[:, None]) 14 | 15 | super().__init__(name) 16 | 17 | def sample(self, result): 18 | value = float('inf') 19 | while value < self.low or value > self.high: 20 | value = int(self.kde.sample()) 21 | result[self.name] = value 22 | 23 | def uniform_sample(self, result): 24 | result[self.name] = int(np.round( 25 | np.random.uniform(low=self.low, high=self.high))) 26 | 27 | def evaluate(self, data, scores): 28 | single_scores = self.kde.score_samples( 29 | np.array([point[self.name] for point in data])[:, None]) 30 | for i, single_score in enumerate(single_scores): 31 | scores[i] += single_score 32 | 33 | def fit(self, values): 34 | self.kde.fit(np.array(values)[:, None]) 35 | 36 | def get_children(self): 37 | return [] 38 | 39 | def structural_copy(self): 40 | return Integer(self.name, self.low, self.high) 41 | -------------------------------------------------------------------------------- /pbt/exploration/models/config_tree/nodes/log_float.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pbt.exploration.models.config_tree.nodes import Float 4 | 5 | 6 | class LogFloat(Float): 7 | def __init__(self, name, low, high, width=20): 8 | if low <= 0: 9 | raise ValueError('"low" has to be greater than 0!') 10 | super().__init__(name, low, high, width) 11 | 12 | def sample(self, result): 13 | value = float('inf') 14 | while value < self.low or value > self.high: 15 | value = float(10**self.kde.sample()) 16 | result[self.name] = value 17 | 18 | def uniform_sample(self, result): 19 | result[self.name] = float(10**np.random.uniform( 20 | low=np.log10(self.low), high=np.log10(self.high))) 21 | return 
result[self.name] 22 | 23 | def evaluate(self, data, scores): 24 | single_scores = self.kde.score_samples( 25 | np.array([np.log10(point[self.name]) for point in data])[:, None]) 26 | for i, single_score in enumerate(single_scores): 27 | scores[i] += single_score 28 | 29 | def fit(self, values): 30 | self.kde.fit(np.array([np.log10(value) for value in values])[:, None]) 31 | 32 | def get_children(self): 33 | return [] 34 | 35 | def structural_copy(self): 36 | return LogFloat(self.name, self.low, self.high) 37 | -------------------------------------------------------------------------------- /pbt/exploration/models/config_tree/nodes/node.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Node: 5 | def __init__(self, name): 6 | self.name = name 7 | self.index = None 8 | 9 | @abc.abstractmethod 10 | def sample(self, result): 11 | raise NotImplementedError 12 | 13 | @abc.abstractmethod 14 | def uniform_sample(self, result): 15 | raise NotImplementedError 16 | 17 | @abc.abstractmethod 18 | def evaluate(self, data, scores): 19 | raise NotImplementedError 20 | 21 | @abc.abstractmethod 22 | def fit(self, values): 23 | raise NotImplementedError 24 | 25 | @abc.abstractmethod 26 | def get_children(self): 27 | raise NotImplementedError 28 | 29 | @abc.abstractmethod 30 | def structural_copy(self): 31 | raise NotImplementedError 32 | -------------------------------------------------------------------------------- /pbt/exploration/models/model.py: -------------------------------------------------------------------------------- 1 | class Model: 2 | def sample(self): 3 | raise NotImplementedError 4 | -------------------------------------------------------------------------------- /pbt/exploration/models/tree_parzen_estimator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | 6 | from pbt.exploration.models import Model 7 | from pbt.tqdm_logger import TqdmLoggingHandler 8 | 9 | class TreeParzenEstimator(Model): 10 | def __init__( 11 | self, config_tree, best_percent=0.2, uniform_percent=0.25, 12 | sample_size=10, new_samples_until_update=10, window_size=20, 13 | mode='improvement', split='time_step'): 14 | self.l_tree = config_tree 15 | self.g_tree = config_tree.structural_copy() 16 | 17 | self._best_percent = best_percent 18 | self._uniform_percent = uniform_percent 19 | self._sample_size = sample_size 20 | self._window_size = window_size 21 | 22 | self._new_samples_until_update = new_samples_until_update 23 | self._counter = 0 24 | 25 | self.data = defaultdict(list) 26 | 27 | self.logger = logging.getLogger('pbt') 28 | self.logger.setLevel(logging.DEBUG) 29 | self.logger.addHandler(TqdmLoggingHandler()) 30 | if mode == 'improvement': 31 | self._mode = 2 32 | else: 33 | self._mode = 0 34 | 35 | if split != 'time_step': 36 | self._split_function = self._split_data 37 | else: 38 | self._split_function = self._split_by_time_step 39 | 40 | def sample(self): 41 | if np.random.random() < self._uniform_percent: 42 | self.logger.debug('Using random sampling.') 43 | return self.l_tree.uniform_sample() 44 | self.logger.debug('TPE sampling:') 45 | samples = [self.l_tree.sample() for _ in range(self._sample_size)] 46 | self.logger.debug(f'Samples: {samples}') 47 | l_scores = self.l_tree.evaluate(samples) 48 | g_scores = self.g_tree.evaluate(samples) 49 | scores = [l/g for l, g in zip(l_scores, g_scores)] 50 | 
self.logger.debug(f'Scores: {scores}') 51 | return samples[np.argmin(scores)] 52 | 53 | def update(self, trial): 54 | self.data[trial[1]].append(trial) 55 | 56 | self._counter += 1 57 | if self._counter >= self._new_samples_until_update: 58 | self._fit_data() 59 | self._counter = 0 60 | 61 | def _fit_data(self): 62 | indices = range( 63 | max(0, len(self.data) - self._window_size), 64 | len(self.data)) 65 | sliding_window = [self.data[x] for x in indices] 66 | good, bad = self._split_function(sliding_window) 67 | self.l_tree.fit(good) 68 | self.g_tree.fit(bad) 69 | 70 | def _split_by_time_step(self, data): 71 | good, bad = [], [] 72 | for time_step in data: 73 | new_good, new_bad = self._split_data([time_step]) 74 | good += new_good 75 | bad += new_bad 76 | return good, bad 77 | 78 | def _split_data(self, data): 79 | scores = sorted([ 80 | trial[self._mode] 81 | for time_step in data for trial in time_step], reverse=True) 82 | pivot_score = scores[int(len(scores) * self._best_percent)] 83 | good, bad = [], [] 84 | for time_step in data: 85 | for point in time_step: 86 | if point[self._mode] >= pivot_score: 87 | good.append(point) 88 | if point[self._mode] <= pivot_score: 89 | bad.append(point) 90 | return good, bad 91 | 92 | 93 | -------------------------------------------------------------------------------- /pbt/exploration/perturb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Perturb: 4 | """ 5 | A simple perturb mechanism as specified in (Jaderberg et al., 2017). 6 | """ 7 | def __init__(self, cs_space=None, boundaries={}): 8 | self.boundaries = boundaries 9 | self.cs_space = cs_space 10 | 11 | def __call__(self, hyperparameters: dict) -> dict: 12 | """ 13 | Perturb the nodes in the input. 14 | :param hyperparameters: A dict with nodes. 15 | :return: The perturbed nodes. 16 | """ 17 | result = hyperparameters.copy() 18 | 19 | for key in hyperparameters: 20 | temp_value = self.cs_space[key]._inverse_transform(result[key]) 21 | temp_value += np.random.choice([-1, 1]) * 0.2 * temp_value 22 | result[key] = self.cs_space[key]._transform(temp_value) 23 | self.ensure_boundaries(result) 24 | return result 25 | 26 | def ensure_boundaries(self, result): 27 | for key in result: 28 | if key not in self.boundaries: 29 | continue 30 | if result[key] < self.boundaries[key][0]: 31 | result[key] = self.boundaries[key][0] 32 | elif result[key] > self.boundaries[key][1]: 33 | result[key] = self.boundaries[key][1] 34 | -------------------------------------------------------------------------------- /pbt/exploration/perturb_and_resample.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from pbt.exploration import ExplorationStrategy, Perturb, Resample 4 | 5 | 6 | class PerturbAndResample(ExplorationStrategy): 7 | def __init__( 8 | self, mutations: dict, cs_space: dict, resample_probability: float = 0.25, 9 | boundaries={}): 10 | """ 11 | A strategy to do both perturb and resample. 
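Each call first perturbs every hyperparameter by a factor of 0.8 or 1.2 (applied in the transformed space of cs_space), then with probability resample_probability replaces the result with a fresh sample from the given mutations.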
12 | :param mutations: A dictionary with hyperparameter names and mutations 13 | :param resample_probability: The probability to resample for each call 14 | """ 15 | self.resample_probability = resample_probability 16 | 17 | self.perturb = Perturb(cs_space=cs_space, boundaries=boundaries) 18 | self.resample = Resample(mutations=mutations) 19 | 20 | def __call__(self, hyperparameters: dict) -> dict: 21 | """ 22 | Perturb all nodes specified by mutations and then resample 23 | each hyperparameter depending on the resample_probability. 24 | :param hyperparameters: The nodes to perturb 25 | :return: Perturbed and probably resampled nodes 26 | """ 27 | result = self.perturb(hyperparameters) 28 | 29 | if random.random() < self.resample_probability: 30 | result = self.resample(result) 31 | 32 | return result 33 | -------------------------------------------------------------------------------- /pbt/exploration/resample.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | class Resample: 5 | def __init__(self, mutations: dict): 6 | """ 7 | A simple resample mechanism as specified in (Jaderberg et al., 2017). 8 | :param mutations: A dict of all nodes and its mutations. 9 | """ 10 | self.mutations = mutations 11 | 12 | def __call__(self, hyperparameters: dict) -> dict: 13 | """ 14 | Resample nodes given by the specified mutations. 15 | :param hyperparameters: All nodes 16 | :return: All nodes with specified nodes resampled 17 | """ 18 | result = hyperparameters.copy() 19 | 20 | for key, value in self.mutations.items(): 21 | result[key] = value() if callable(value) else random.choice(value) 22 | 23 | return result 24 | -------------------------------------------------------------------------------- /pbt/garbage_collector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class GarbageCollector: 6 | def __init__(self, data_path): 7 | self.data_path = data_path 8 | 9 | def collect(self, member_id, min_timestep, elites): 10 | if min_timestep - 1 < 0: 11 | return 12 | else: 13 | for i in range(min_timestep-1): 14 | if member_id in [elite.member_id for elite in elites] and i in [elite.time_step for elite in elites]: 15 | continue 16 | member_path = os.path.join(self.data_path, str(member_id), str(i)) 17 | for file_name in ["state_dict.npz", "traj_acs.json", "traj_obs.json", "traj_rews.json"]: 18 | file_path = os.path.join(member_path, file_name) 19 | if os.path.exists(file_path): 20 | os.remove(file_path) 21 | #else: 22 | # print("Can not delete the file in path %s as it doesn't exists" %file_path) 23 | 24 | def _get_integer_dirs(self, path): 25 | return [int(i) for i in os.listdir(path) if i.isdigit()] 26 | -------------------------------------------------------------------------------- /pbt/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .controller_adapter import ControllerAdapter 2 | from .worker_adapter import WorkerAdapter 3 | 4 | from .daemon import Daemon 5 | from .controller_daemon import ControllerDaemon, CONTROLLER_URI_FILENAME 6 | from .worker_daemon import WorkerDaemon 7 | 8 | __all__ = [ 9 | 'ControllerAdapter', 'WorkerAdapter', 10 | 'Daemon', 'ControllerDaemon', 'WorkerDaemon', 'CONTROLLER_URI_FILENAME'] 11 | -------------------------------------------------------------------------------- /pbt/network/controller_adapter.py: 
-------------------------------------------------------------------------------- 1 | from Pyro4.errors import ConnectionClosedError, CommunicationError 2 | 3 | from pbt.population import NoTrial 4 | 5 | 6 | class ControllerAdapter: 7 | def __init__(self, controller): 8 | self._controller = controller 9 | 10 | def register_worker_by_uri(self, uri): 11 | try: 12 | self._controller.register_worker_by_uri(uri) 13 | return True 14 | except CommunicationError: 15 | # URI from file is not valid -> kill worker 16 | return False 17 | 18 | def request_trial(self): 19 | try: 20 | return self._controller.request_trial() 21 | except ConnectionClosedError: 22 | return NoTrial().to_tuple() 23 | 24 | def send_evaluation(self, member_id, score): 25 | try: 26 | self._controller.send_evaluation(member_id, float(score)) 27 | except ConnectionClosedError: 28 | return 29 | -------------------------------------------------------------------------------- /pbt/network/controller_daemon.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import Pyro4 4 | 5 | from pbt.network import Daemon, WorkerAdapter 6 | from pbt.tqdm_logger import TqdmLoggingHandler 7 | CONTROLLER_URI_FILENAME = 'controller_uri.txt' 8 | 9 | 10 | @Pyro4.expose 11 | class ControllerDaemon(Daemon): 12 | def __init__(self, controller): 13 | self.logger = logging.getLogger('pbt') 14 | self.logger.setLevel(logging.DEBUG) 15 | # self.logger.addHandler(TqdmLoggingHandler()) 16 | self.controller = controller 17 | self.pyro_daemon = None 18 | 19 | def start(self): 20 | Pyro4.config.SERVERTYPE = 'multiplex' 21 | self.pyro_daemon = Pyro4.Daemon(host=self._get_hostname()) 22 | uri = self.pyro_daemon.register(self) 23 | self._save_pyro_uri(uri) 24 | self.pyro_daemon.requestLoop() 25 | 26 | def register_worker_by_uri(self, uri): 27 | self.controller.register_worker(WorkerAdapter(Pyro4.Proxy(uri))) 28 | 29 | def request_trial(self): 30 | return self.controller.request_trial() 31 | 32 | def send_evaluation(self, member_id, score): 33 | self.controller.send_evaluation(member_id, score) 34 | 35 | def shut_down(self): 36 | self.pyro_daemon.shutdown() 37 | 38 | def _save_pyro_uri(self, uri): 39 | save_path = os.path.join( 40 | self.controller.data_path, CONTROLLER_URI_FILENAME) 41 | if not os.path.isdir(self.controller.data_path): 42 | os.makedirs(self.controller.data_path) 43 | with open(save_path, 'w') as f: 44 | f.write(str(uri)) 45 | self.logger.info(f'Saved pyro uri at {save_path}.') 46 | -------------------------------------------------------------------------------- /pbt/network/daemon.py: -------------------------------------------------------------------------------- 1 | import socket 2 | 3 | 4 | class Daemon: 5 | def _get_hostname(self): 6 | my_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 7 | my_socket.connect(('10.255.255.255', 1)) 8 | ip = my_socket.getsockname()[0] 9 | my_socket.close() 10 | return ip 11 | -------------------------------------------------------------------------------- /pbt/network/worker_adapter.py: -------------------------------------------------------------------------------- 1 | from Pyro4.errors import CommunicationError 2 | 3 | 4 | class WorkerAdapter: 5 | def __init__(self, worker): 6 | self._worker = worker 7 | 8 | @property 9 | def worker_id(self): 10 | return self._worker.worker_id 11 | 12 | def ping(self): 13 | try: 14 | return self._worker.ping() 15 | except CommunicationError: 16 | return False 17 | 18 | def stop(self): 19 | try: 20 | 
self._worker.stop() 21 | except CommunicationError: 22 | pass 23 |
-------------------------------------------------------------------------------- /pbt/network/worker_daemon.py: -------------------------------------------------------------------------------- 1 | from threading import Thread 2 | 3 | import Pyro4 4 | 5 | from pbt.network import Daemon 6 | 7 | 8 | @Pyro4.expose 9 | class WorkerDaemon(Daemon): 10 | def __init__(self, worker): 11 | self.worker = worker 12 | self.pyro_daemon = None 13 | 14 | @property 15 | def worker_id(self): 16 | return self.worker.worker_id 17 | 18 | def start(self): 19 | self.pyro_daemon = Pyro4.Daemon(host=self._get_hostname()) 20 | uri = self.pyro_daemon.register(self) 21 | thread = Thread(target=self.pyro_daemon.requestLoop) 22 | thread.start() 23 | return uri 24 | 25 | def ping(self): 26 | return True 27 | 28 | def stop(self): 29 | self.worker.stop() 30 | self.pyro_daemon.shutdown() 31 |
-------------------------------------------------------------------------------- /pbt/population/__init__.py: -------------------------------------------------------------------------------- 1 | from .trial import Trial, NoTrial 2 | from .member import Member 3 | from .population import Population 4 | 5 | __all__ = ['Trial', 'NoTrial', 'Member', 'Population'] 6 |
-------------------------------------------------------------------------------- /pbt/population/member.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 5 | class Member: 6 | def __init__(self, member_id, stop): 7 | self.member_id = member_id 8 | self.stop = stop 9 | 10 | self.trials = {} 11 | self.time_step = 0 12 | self.max_time_step = None 13 | self.last_score = 0 14 | self._actual_time_step = 0 15 | 16 | self.is_free = True 17 | self.is_done = False 18 | 19 | def __repr__(self): 20 | return f'<Member {self.member_id}, time_step: {self.time_step}>' 21 | 22 | def assign_trial(self, trial, last_score): 23 | self.is_free = False 24 | self.trials[trial.time_step] = trial 25 | self.last_score = last_score 26 | 27 | def reached_time_step(self, time_step): 28 | if time_step not in self.trials: 29 | return False 30 | return self.trials[time_step].score is not None 31 | 32 | def get_last_score(self): 33 | if self.time_step == 0: 34 | return -float('inf') 35 | return self.trials[self.time_step - 1].score 36 | 37 | def get_score_by_time_step(self, time_step): 38 | if self.time_step == 0: 39 | return -float('inf') 40 | if time_step not in self.trials: 41 | return None 42 | return self.trials[time_step].score 43 | 44 | def get_last_trial(self): 45 | return self.trials[self.time_step - 1] 46 | 47 | def get_all_scores(self): 48 | return [trial.score for trial in self.trials.values()] 49 | 50 | def get_best_trial(self): 51 | valid_trials = [trial for trial in self.trials.values() if trial.score is not None] 52 | if len(valid_trials) == 0: 53 | return None 54 | else: 55 | return max(valid_trials, key=lambda trial: trial.score) 56 | 57 | # def get_best_trial_by_delta_t(self, delta_t): 58 | # return sorted(self.trials[::delta_t], key=lambda trial: trial.score)[-1] 59 | 60 | def get_hyperparameters_by_time_step(self, time_step): 61 | if time_step not in self.trials: 62 | return None 63 | return self.trials[time_step].hyperparameters 64 | 65 | def get_hyperparameters(self): 66 | return self.trials[self.time_step - 1].hyperparameters 67 | 68 | def save_score(self, score): 69 | trial = self.trials[self.time_step] 70 | trial.score = score 71 | trial.improvement = score - self.last_score 72 | self.time_step += 1 73 | 
self._actual_time_step += 1 74 | 75 | # TODO: Move this to own function 76 | if self.max_time_step: 77 | if self._actual_time_step >= self.max_time_step: 78 | self.is_done = True 79 | else: 80 | self.is_free = True 81 | else: 82 | if self.stop(self._actual_time_step, score): 83 | self.is_done = True 84 | else: 85 | self.is_free = True 86 | 87 | return trial 88 | 89 | def set_actual_time_step(self, time_step): 90 | self._actual_time_step = time_step 91 | 92 | def get_actual_time_step(self): 93 | return self._actual_time_step 94 | 95 | 96 | def log_last_result(self, results_path): 97 | last_result = self.trials[self.time_step - 1] 98 | member_path = os.path.join(results_path, str(self.member_id)) 99 | step_path = os.path.join(member_path, str(last_result.time_step)) 100 | if not os.path.isdir(step_path): 101 | os.makedirs(step_path) 102 | with open(os.path.join(member_path, 'scores.txt'), 'a+') as f: 103 | f.write(f'{last_result.score}\n') 104 | with open(os.path.join(step_path, 'nodes.json'), 'w') as f: 105 | json.dump(last_result.hyperparameters, f) 106 | with open(os.path.join(step_path, 'add.json'), 'w') as f: 107 | json.dump( 108 | {'copied from': self.trials[self.time_step - 1].model_id, 109 | 'time step:': self.trials[self.time_step - 1].model_time_step}, f) 110 | 111 | def _create_hyperparameters(self): 112 | if self.time_step == 0: 113 | return self.exploration.get_start_hyperparameters() 114 | return self.exploration() 115 | -------------------------------------------------------------------------------- /pbt/population/population.py: -------------------------------------------------------------------------------- 1 | from pbt.population import Member 2 | 3 | 4 | class Population: 5 | def __init__(self, size, stop, results_path, mode='asynchronous', backtracking=False, elite_ratio=0.1): 6 | self.members = [Member(i, stop) for i in range(size)] 7 | self.mode = mode 8 | self.results_path = results_path 9 | self.backtracking = backtracking 10 | self.elite_size = max(round(size * elite_ratio), 1) 11 | 12 | def get_next_member(self): 13 | time_steps = [member.time_step for member in self.members] 14 | for time_step in range(min(time_steps), max(time_steps) + 1): 15 | for member in self.members: 16 | if member.is_free and member.time_step == time_step: 17 | return member 18 | else: 19 | if self.mode == 'synchronous': 20 | return None 21 | 22 | def get_scores(self): 23 | return { 24 | member.member_id: member.get_last_score() 25 | for member in self.members} 26 | 27 | def get_scores_by_time_step(self, time_step): 28 | scores = {} 29 | for member in self.members: 30 | member_score = member.get_score_by_time_step(time_step) 31 | if member_score: 32 | scores[member.member_id] = member_score 33 | return scores 34 | 35 | def get_hyperparameters(self, member_id): 36 | return self.members[member_id].get_hyperparameters() 37 | 38 | def get_hyperparameters_by_time_step(self, member_id, time_step): 39 | return self.members[member_id].get_hyperparameters_by_time_step(time_step) 40 | 41 | def get_latest_time_step(self, model_id): 42 | return self.members[model_id].time_step 43 | 44 | def get_min_time_step(self): 45 | return min([member.time_step for member in self.members]) 46 | 47 | def save_trial(self, trial): 48 | self.members[trial.member_id].assign_trial( 49 | trial, self._get_last_score(trial)) 50 | 51 | def _get_last_score(self, trial): 52 | member_of_model = self.members[trial.model_id] 53 | if member_of_model.time_step == 0: 54 | return 0.0 55 | else: 56 | return 
-------------------------------------------------------------------------------- /pbt/population/population.py: -------------------------------------------------------------------------------- 1 | from pbt.population import Member 2 | 3 | 4 | class Population: 5 | def __init__(self, size, stop, results_path, mode='asynchronous', backtracking=False, elite_ratio=0.1): 6 | self.members = [Member(i, stop) for i in range(size)] 7 | self.mode = mode 8 | self.results_path = results_path 9 | self.backtracking = backtracking 10 | self.elite_size = max(round(size * elite_ratio), 1) 11 | 12 | def get_next_member(self): 13 | time_steps = [member.time_step for member in self.members] 14 | for time_step in range(min(time_steps), max(time_steps) + 1): 15 | for member in self.members: 16 | if member.is_free and member.time_step == time_step: 17 | return member 18 | else: # for-else: no member at this time_step was free 19 | if self.mode == 'synchronous': 20 | return None 21 | 22 | def get_scores(self): 23 | return { 24 | member.member_id: member.get_last_score() 25 | for member in self.members} 26 | 27 | def get_scores_by_time_step(self, time_step): 28 | scores = {} 29 | for member in self.members: 30 | member_score = member.get_score_by_time_step(time_step) 31 | if member_score is not None: 32 | scores[member.member_id] = member_score 33 | return scores 34 | 35 | def get_hyperparameters(self, member_id): 36 | return self.members[member_id].get_hyperparameters() 37 | 38 | def get_hyperparameters_by_time_step(self, member_id, time_step): 39 | return self.members[member_id].get_hyperparameters_by_time_step(time_step) 40 | 41 | def get_latest_time_step(self, model_id): 42 | return self.members[model_id].time_step 43 | 44 | def get_min_time_step(self): 45 | return min([member.time_step for member in self.members]) 46 | 47 | def save_trial(self, trial): 48 | self.members[trial.member_id].assign_trial( 49 | trial, self._get_last_score(trial)) 50 | 51 | def _get_last_score(self, trial): 52 | member_of_model = self.members[trial.model_id] 53 | if member_of_model.time_step == 0: 54 | return 0.0 55 | else: 56 | return member_of_model.get_last_score() 57 | 58 | def update(self, member_id, score): 59 | current_member = self.members[member_id] 60 | trial = current_member.save_score(score) 61 | current_member.log_last_result(self.results_path) 62 | if current_member.is_done and not current_member.max_time_step: 63 | self._set_max_time_step(current_member.time_step) 64 | return trial 65 | 66 | def is_done(self): 67 | return all(member.is_done for member in self.members) 68 | 69 | def _set_max_time_step(self, max_time_step): 70 | for member in self.members: 71 | member.max_time_step = max_time_step 72 | 73 | def get_elites(self): 74 | # For PBT-BT only; return an empty list when not in PBT-BT mode 75 | if not self.backtracking: 76 | return [] 77 | else: 78 | best_trials = self.get_best_trials() 79 | return best_trials[:self.elite_size] 80 | # return sorted(best_trials, key=lambda trial: trial.score, reverse=True)[:self.elite_size] 81 | 82 | # def get_elites_by_delta_t(self, delta_t): 83 | # # For PBT-BT only, return empty list if it's not in the PBT-BT mode 84 | # if not self.backtracking: 85 | # return [] 86 | # else: 87 | # best_trials = self.get_best_trials_by_delta_t(delta_t) 88 | # return sorted(best_trials, key=lambda trial: trial.score, reverse=True)[:self.elite_size] 89 | 90 | def get_best_trials(self): 91 | best_trials = (member.get_best_trial() for member in self.members) 92 | return [trial for trial in best_trials if trial is not None] 93 | 94 | # def get_best_trials_by_delta_t(self, delta_t): 95 | # return [member.get_best_trial_by_delta_t(delta_t) for member in self.members]
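A short walk-through of the `Population` flow that the scheduler drives (fetch a free member, save its trial, report a score), assuming the `pbt` package is importable; the size, stop rule, and score are illustrative:

```python
import tempfile

from pbt.population import Population, Trial

results_path = tempfile.mkdtemp()  # update() logs results to disk
pop = Population(size=2, stop=lambda step, score: step >= 1,
                 results_path=results_path)

member = pop.get_next_member()            # both members start at time_step 0
trial = Trial(member.member_id, -1, 0, -1, {'model_learning_rate': 1e-3})
pop.save_trial(trial)                     # marks the member as busy
pop.update(member.member_id, 1.0)         # writes scores.txt / nodes.json / add.json
print(pop.is_done())                      # False: the second member is still free
```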
-------------------------------------------------------------------------------- /pbt/population/trial.py: -------------------------------------------------------------------------------- 1 | class Trial: 2 | def __init__( 3 | self, member_id, model_id, time_step, model_time_step, 4 | hyperparameters): 5 | self.member_id = member_id 6 | self.model_id = model_id 7 | self.time_step = time_step 8 | self.model_time_step = model_time_step 9 | self.hyperparameters = self._clean_hyperparameters(hyperparameters) 10 | self.score = None 11 | self.improvement = None 12 | 13 | def __repr__(self): 14 | return ( 15 | f'<Trial member_id={self.member_id} ' 16 | f'model_id={self.model_id} ' 17 | f'time_step={self.time_step} ' 18 | f'model_time_step={self.model_time_step} ' 19 | f'hyperparameters={self.hyperparameters} ' 20 | f'score={self.score} ' 21 | f'improvement={self.improvement}>') 22 | 23 | def _clean_hyperparameters(self, hyperparameters): 24 | result = hyperparameters.copy() 25 | for key, value in result.items(): 26 | if callable(value): 27 | result[key] = value() 28 | return result 29 | 30 | def is_valid(self): 31 | return True 32 | 33 | def to_tuple(self): 34 | """ 35 | Return this trial as tuple (easier to send over network). 36 | :return: (member_id, model_id, time_step, model_time_step, hyperparameters) 37 | """ 38 | return \ 39 | self.member_id, self.model_id, self.time_step, \ 40 | self.model_time_step, self.hyperparameters 41 | 42 | def to_array(self): 43 | result = [self.score, self.time_step, self.improvement] 44 | result += [ 45 | self.hyperparameters[key] 46 | for key in sorted(self.hyperparameters)] 47 | return result 48 | 49 | @staticmethod 50 | def from_tuple( 51 | member_id, model_id, time_step, model_time_step, hyperparameters): 52 | return Trial( 53 | member_id, model_id, time_step, model_time_step, hyperparameters) \ 54 | if member_id != -1 \ 55 | else NoTrial() 56 | 57 | def copy(self): 58 | return Trial( 59 | self.member_id, self.model_id, self.time_step, self.model_time_step, 60 | self.hyperparameters.copy()) 61 | 62 | 63 | class NoTrial(Trial): 64 | def __init__(self): 65 | super().__init__(-1, -1, -1, -1, {}) 66 | 67 | def is_valid(self): 68 | return False 69 |
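The tuple round-trip above is what actually crosses the Pyro4 boundary between controller and worker; note that `_clean_hyperparameters` resolves callable values eagerly. A small sketch, assuming the `pbt` package is importable; the hyperparameter name and sampler are made up:

```python
from pbt.population import NoTrial, Trial

trial = Trial(member_id=0, model_id=1, time_step=2, model_time_step=1,
              hyperparameters={'model_train_epoch': lambda: 5})
print(trial.hyperparameters)   # {'model_train_epoch': 5}: the callable was resolved

clone = Trial.from_tuple(*trial.to_tuple())   # what the receiving side rebuilds
print(clone.is_valid())                       # True

# A NoTrial round-trips to a NoTrial, which is how workers detect "nothing to do".
print(Trial.from_tuple(*NoTrial().to_tuple()).is_valid())   # False
```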
-------------------------------------------------------------------------------- /pbt/scheduler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pbt.population import NoTrial, Trial 4 | from pbt.tqdm_logger import TqdmLoggingHandler 5 | 6 | class Scheduler: 7 | def __init__( 8 | self, population, start_hyperparameters, exploitation, exploration): 9 | self.logger = logging.getLogger('pbt') 10 | self.logger.setLevel(logging.DEBUG) 11 | self.logger.addHandler(TqdmLoggingHandler()) 12 | self.population = population 13 | self.exploitation = exploitation 14 | self.exploration = exploration 15 | self.start_hyperparameters = start_hyperparameters 16 | 17 | def get_trial(self): 18 | self.logger.debug('Trial requested.') 19 | member = self.population.get_next_member() 20 | if not member: 21 | self.logger.debug('No trial ready.') 22 | return NoTrial() 23 | 24 | if member.time_step == 0: 25 | start_hyperparameters = { 26 | cfg_name: sampler() for cfg_name, sampler in self.start_hyperparameters.items()} 27 | trial = Trial( 28 | member.member_id, -1, 0, -1, start_hyperparameters) 29 | self.population.save_trial(trial) 30 | self.logger.debug(f'Returning first trial {trial}.') 31 | return trial 32 | 33 | self.logger.debug(f'Generating trial for member {member.member_id}.') 34 | scores = self.population.get_scores_by_time_step(member.time_step - 1) 35 | # model_id indicates the model that we want to copy 36 | model_id = self.exploitation(member.member_id, scores) 37 | # model_time_step = self.population.get_latest_time_step(model_id) - 1 38 | model_time_step = member.time_step - 1 39 | hyperparameters = self.population.get_hyperparameters_by_time_step(model_id, model_time_step) 40 | if model_id != member.member_id: 41 | self.logger.debug(f'Copying model {model_id}.') 42 | hyperparameters = self.exploration(hyperparameters) 43 | self.logger.debug(f'Using exploration. New: {hyperparameters}') 44 | else: 45 | self.logger.debug(f'Staying with current model {model_id}.') 46 | trial = Trial( 47 | member.member_id, model_id, member.time_step, model_time_step, 48 | hyperparameters) 49 | self.population.save_trial(trial) 50 | return trial 51 | 52 | def update_exploration(self, trial): 53 | self.exploration.update(trial) 54 | -------------------------------------------------------------------------------- /pbt/tqdm_logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import tqdm 3 | 4 | class TqdmLoggingHandler(logging.Handler): 5 | def __init__(self, level=logging.NOTSET): 6 | super().__init__(level) 7 | 8 | def emit(self, record): 9 | try: 10 | msg = self.format(record) 11 | tqdm.tqdm.write(msg) 12 | self.flush() 13 | except (KeyboardInterrupt, SystemExit): 14 | raise 15 | except: 16 | self.handleError(record)
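`TqdmLoggingHandler` routes log records through `tqdm.write` so they are printed above any active progress bar instead of garbling it. A minimal sketch of the wiring used by `Scheduler` and `PBTWorker`, assuming `tqdm` is installed:

```python
import logging
from time import sleep

from tqdm import trange

from pbt.tqdm_logger import TqdmLoggingHandler

logger = logging.getLogger('pbt')
logger.setLevel(logging.DEBUG)
logger.addHandler(TqdmLoggingHandler())

for i in trange(3):
    logger.info('step %d', i)  # emitted via tqdm.tqdm.write, above the bar
    sleep(0.1)
```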
ValueError("criterion of last k must have at least same as initial_step + 1") 68 | self.criterion = Criterion(criterion_mode, **kwargs) 69 | 70 | def register(self, controller=None): 71 | if controller: 72 | self.logger.info('Registered controller directly.') 73 | self._controller = controller 74 | self._controller.register_worker(self) 75 | else: 76 | self.logger.info('Registered controller over network.') 77 | self._controller = ControllerAdapter(self._discover_controller()) 78 | self._run_daemon() 79 | 80 | def _run_daemon(self): 81 | self.logger.info('Starting worker daemon.') 82 | daemon = WorkerDaemon(self) 83 | uri = daemon.start() 84 | success = self._controller.register_worker_by_uri(uri) 85 | if not success: 86 | daemon.stop() 87 | raise Exception(f'The read controller URI "{uri}" is not valid!') 88 | 89 | def run(self): 90 | if not self._controller: 91 | self.register() 92 | 93 | while not self.is_done: 94 | self._run_iteration() 95 | 96 | def stop(self): 97 | self.logger.info('Shutting down worker.') 98 | self.is_done = True 99 | 100 | def _run_iteration(self): 101 | trial = self._load_trial() 102 | if not trial: 103 | return 104 | 105 | hyperparameters = trial.hyperparameters 106 | policy = self.policy_constructor(hyperparameters, DEFAULT_CONFIGSPACE) 107 | if trial.time_step == 0: 108 | traj_obs, traj_acs, traj_rets, traj_rews, info = self.train_func(self.agent, policy, step=self.initial_step) 109 | else: 110 | path = self._get_last_model_path(trial) 111 | policy.load_model(path) 112 | traj_obs, traj_acs, traj_rets, traj_rews = self.load_training_data(path) 113 | traj_obs, traj_acs, traj_rets, traj_rews, info = self.train_func(self.agent, policy, 114 | traj_obs, traj_acs, traj_rets, traj_rews, step=self.step) 115 | score = self.criterion(traj_rets, info=info) 116 | 117 | save_path = self._create_model_path(trial) 118 | policy.save_model(save_path) 119 | self.save_training_data(save_path, traj_obs, traj_acs, traj_rets, traj_rews) 120 | self.save_infomation(save_path, info) 121 | self._send_evaluation(trial.member_id, score) 122 | 123 | 124 | def load_training_data(self, path): 125 | with open(os.path.join(path, "traj_obs.json"), 'r') as f: 126 | traj_obs = json.load(f) 127 | with open(os.path.join(path, "traj_acs.json"), 'r') as f: 128 | traj_acs = json.load(f) 129 | with open(os.path.join(path, "traj_rets.json"), 'r') as f: 130 | traj_rets = json.load(f) 131 | with open(os.path.join(path, "traj_rews.json"), 'r') as f: 132 | traj_rews = json.load(f) 133 | 134 | return traj_obs, traj_acs, traj_rets, traj_rews 135 | 136 | def save_training_data(self, path, traj_obs, traj_acs, traj_rets, traj_rews): 137 | with open(os.path.join(path, "traj_obs.json"), 'w') as f: 138 | json.dump(traj_obs, f) 139 | with open(os.path.join(path, "traj_acs.json"), 'w') as f: 140 | json.dump(traj_acs, f) 141 | with open(os.path.join(path, "traj_rets.json"), 'w') as f: 142 | json.dump(traj_rets, f) 143 | with open(os.path.join(path, "traj_rews.json"), 'w') as f: 144 | json.dump(traj_rews, f) 145 | 146 | def save_infomation(self, path, info): 147 | savemat( 148 | os.path.join(path, "infos.mat"), 149 | { 150 | key : info[key] for key in info.keys() 151 | }, 152 | long_field_names=True 153 | ) 154 | 155 | def _load_trial(self): 156 | trial = Trial.from_tuple(*self._get_trial()) 157 | 158 | if not trial.is_valid(): 159 | self.logger.info( 160 | f'No trial ready. 
124 | def load_training_data(self, path): 125 | with open(os.path.join(path, "traj_obs.json"), 'r') as f: 126 | traj_obs = json.load(f) 127 | with open(os.path.join(path, "traj_acs.json"), 'r') as f: 128 | traj_acs = json.load(f) 129 | with open(os.path.join(path, "traj_rets.json"), 'r') as f: 130 | traj_rets = json.load(f) 131 | with open(os.path.join(path, "traj_rews.json"), 'r') as f: 132 | traj_rews = json.load(f) 133 | 134 | return traj_obs, traj_acs, traj_rets, traj_rews 135 | 136 | def save_training_data(self, path, traj_obs, traj_acs, traj_rets, traj_rews): 137 | with open(os.path.join(path, "traj_obs.json"), 'w') as f: 138 | json.dump(traj_obs, f) 139 | with open(os.path.join(path, "traj_acs.json"), 'w') as f: 140 | json.dump(traj_acs, f) 141 | with open(os.path.join(path, "traj_rets.json"), 'w') as f: 142 | json.dump(traj_rets, f) 143 | with open(os.path.join(path, "traj_rews.json"), 'w') as f: 144 | json.dump(traj_rews, f) 145 | 146 | def save_information(self, path, info): 147 | savemat( 148 | os.path.join(path, "infos.mat"), 149 | { 150 | key: info[key] for key in info.keys() 151 | }, 152 | long_field_names=True 153 | ) 154 | 155 | def _load_trial(self): 156 | trial = Trial.from_tuple(*self._get_trial()) 157 | 158 | if not trial.is_valid(): 159 | self.logger.info( 160 | f'No trial ready. Waiting for {self.wait_time} seconds.') 161 | sleep(self.wait_time) 162 | else: 163 | self.logger.info(f'Got valid trial {trial} from controller.') 164 | return trial 165 | 166 | def _discover_controller(self): 167 | self.logger.debug('Discovering controller.') 168 | file_path = os.path.join(self.data_path, CONTROLLER_URI_FILENAME) 169 | self.logger.debug(file_path) 170 | # Give the controller a moment to write its URI file (sleep is imported above). 171 | sleep(5) 172 | for number_of_try in range(5): 173 | try: 174 | with open(file_path, 'r') as f: 175 | uri = f.readline().strip() 176 | break 177 | except FileNotFoundError: 178 | self.logger.info('Can\'t reach controller. Waiting ...') 179 | sleep(5) 180 | if number_of_try < 4: 181 | continue 182 | raise Exception('Can\'t reach controller!') 183 | return Pyro4.Proxy(uri) 184 | 185 | def _get_trial(self): 186 | # TODO: Intercept connection issues 187 | return self._controller.request_trial() 188 | 189 | def _get_last_model_path(self, trial): 190 | # TODO: Remove ambiguity with model_id <-> member_id 191 | if self.not_copy_data: 192 | path = os.path.join( 193 | self.data_path, str(trial.member_id), str(trial.model_time_step)) 194 | else: 195 | path = os.path.join( 196 | self.data_path, str(trial.model_id), str(trial.model_time_step)) 197 | return path 198 | 199 | def _create_model_path(self, trial): 200 | path = os.path.join( 201 | self.data_path, str(trial.member_id), str(trial.time_step)) 202 | if not os.path.isdir(path): 203 | os.makedirs(path) 204 | return path 205 | 206 | def _create_extra_data_path(self, trial): 207 | path = os.path.join( 208 | self.data_path, 'extra', str(trial.member_id), str(trial.time_step)) 209 | if not os.path.isdir(path): 210 | os.makedirs(path) 211 | return path 212 | 213 | def _send_evaluation(self, member_id, score): 214 | # TODO: Handle connection issues 215 | self.logger.info(f'Sending evaluation. 
Score: {score}') 216 | self._controller.send_evaluation(member_id, float(score)) 217 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | astor==0.8.1 3 | astunparse==1.6.3 4 | cachetools==4.0.0 5 | chardet==3.0.4 6 | cloudpickle==1.2.1 7 | configspace==0.4.10 8 | cycler==0.10.0 9 | cython==0.29.13 10 | dataclasses==0.6 11 | decorator==4.4.0 12 | dm-env==1.2 13 | dm-tree==0.1.5 14 | dotmap==1.2.20 15 | future==0.16.0 16 | gast==0.2.2 17 | glfw==1.8.3 18 | google-auth==1.11.2 19 | google-auth-oauthlib==0.4.1 20 | google-pasta==0.2.0 21 | googledrivedownloader==0.4 22 | gpflow==1.1.0 23 | grpcio==1.27.2 24 | gym==0.14.0 25 | h5py==2.10.0 26 | hpbandster==0.7.4 27 | idna==2.8 28 | imageio==2.5.0 29 | isodate==0.6.0 30 | joblib==0.13.2 31 | keras-applications==1.0.8 32 | keras-preprocessing==1.1.2 33 | kiwisolver==1.1.0 34 | labmaze==1.0.3 35 | lockfile==0.12.2 36 | lxml==4.5.2 37 | markdown==3.2.1 38 | matplotlib==3.1.1 39 | more-itertools==8.3.0 40 | mujoco-py==2.0.2.5 41 | multipledispatch==0.6.0 42 | netifaces==0.10.9 43 | networkx==2.3 44 | numpy==1.17.4 45 | oauthlib==3.1.0 46 | omegaconf==2.0.0 47 | opencv-python==4.4.0.42 48 | opt-einsum==3.2.1 49 | packaging==20.4 50 | pandas==0.25.1 51 | patsy==0.5.1 52 | pluggy==0.13.1 53 | plyfile==0.7 54 | protobuf==3.10.0 55 | py==1.8.1 56 | pyasn1==0.4.8 57 | pyasn1-modules==0.2.8 58 | pyglet==1.3.2 59 | pyopengl==3.1.5 60 | pyparsing==2.4.2 61 | pyro4==4.80 62 | pytest==5.4.2 63 | python-dateutil==2.8.0 64 | pytz==2019.2 65 | pyyaml==5.3.1 66 | rdflib==4.2.2 67 | requests==2.22.0 68 | requests-oauthlib==1.3.0 69 | rsa==4.0 70 | scikit-learn==0.21.3 71 | scipy==1.3.1 72 | seaborn==0.11.0 73 | serpent==1.30.2 74 | statsmodels==0.11.1 75 | tensorboard-plugin-wit==1.6.0.post3 76 | termcolor==1.1.0 77 | tqdm==4.19.4 78 | typing==3.7.4.1 79 | typing-extensions==3.7.4.2 80 | urllib3==1.25.3 81 | werkzeug==1.0.0 82 | wrapt==1.12.1 -------------------------------------------------------------------------------- /scripts/hyperband.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Example command 4 | 5 | # activate your conda env 6 | source ~/.bashrc 7 | source activate mbrl 8 | 9 | INTERFACE=`ip route get 8.8.8.8 | cut -d' ' -f5 | head -1` 10 | echo Interface read:$INTERFACE 11 | 12 | # ENV=halfcheetah_v3 13 | # TASK_HORIZON=1000 14 | ENV=reacher 15 | TASK_HORIZON=10 16 | 17 | RUN_ID=$SLURM_ARRAY_JOB_ID 18 | # Replace with your directory 19 | DIR=log/$ENV\_$RUN_ID 20 | NINIT_ROLLOUTS=1 21 | MIN_BUDGET=79 22 | MAX_BUDGET=80 23 | NOPT_ITER=5 24 | LAST_K=3 25 | ETA=2 26 | NEVAL=1 27 | 28 | SEED=0 29 | OPT_TYPE=hyperband # or [random, bohb] 30 | PROP_TYPE=TSinf 31 | 32 | OPT_TYPE_PETS=CEM 33 | # CONFIG_NAMES=plan_hor\ num_cem_iters\ cem_popsize\ cem_elites_ratio\ cem_alpha 34 | CONFIG_NAMES=model_weight_decay\ model_learning_rate\ model_train_epoch 35 | 36 | cd .. 
37 | 38 | if [ $SLURM_ARRAY_TASK_ID -eq 1 ] 39 | then 40 | python -u bo-mbexp.py -config_names $CONFIG_NAMES \ 41 | -opt_type $OPT_TYPE \ 42 | -run_id $RUN_ID \ 43 | -env $ENV \ 44 | -logdir $DIR \ 45 | -worker_id $SLURM_ARRAY_TASK_ID \ 46 | -seed $SEED \ 47 | -interface $INTERFACE \ 48 | -o exp_cfg.log_cfg.neval $NEVAL \ 49 | -o exp_cfg.exp_cfg.ninit_rollouts $NINIT_ROLLOUTS \ 50 | -o exp_cfg.bo_cfg.min_budget $MIN_BUDGET \ 51 | -o exp_cfg.bo_cfg.max_budget $MAX_BUDGET \ 52 | -o exp_cfg.bo_cfg.nopt_iter $NOPT_ITER \ 53 | -o exp_cfg.bo_cfg.eta $ETA \ 54 | -o exp_cfg.bo_cfg.last_k $LAST_K \ 55 | -o exp_cfg.sim_cfg.task_hor $TASK_HORIZON \ 56 | -ca prop-type $PROP_TYPE \ 57 | -ca opt-type $OPT_TYPE_PETS 58 | else 59 | python -u bo-mbexp.py -config_names $CONFIG_NAMES \ 60 | -opt_type $OPT_TYPE \ 61 | -run_id $RUN_ID \ 62 | -worker \ 63 | -env $ENV \ 64 | -logdir $DIR \ 65 | -worker_id $SLURM_ARRAY_TASK_ID \ 66 | -seed $SEED \ 67 | -interface $INTERFACE \ 68 | -o exp_cfg.log_cfg.neval $NEVAL \ 69 | -o exp_cfg.exp_cfg.ninit_rollouts $NINIT_ROLLOUTS \ 70 | -o exp_cfg.bo_cfg.min_budget $MIN_BUDGET \ 71 | -o exp_cfg.bo_cfg.max_budget $MAX_BUDGET \ 72 | -o exp_cfg.bo_cfg.nopt_iter $NOPT_ITER \ 73 | -o exp_cfg.bo_cfg.eta $ETA \ 74 | -o exp_cfg.bo_cfg.last_k $LAST_K \ 75 | -o exp_cfg.sim_cfg.task_hor $TASK_HORIZON \ 76 | -ca prop-type $PROP_TYPE \ 77 | -ca opt-type $OPT_TYPE_PETS 78 | fi 79 | 80 | -------------------------------------------------------------------------------- /scripts/pbt-bt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Example command 3 | 4 | # activate your conda env 5 | source ~/.bashrc 6 | source activate mbrl 7 | 8 | SAMPLE_FROM_PERCENT=0.2 9 | RESAMPLE_IF_NOT_IN_PERCENT=0.8 10 | RESAMPLE_PROBABILITY=0.25 11 | 12 | MAX_STEPS=2400 13 | DELTA_T=6 14 | TOLERANCE=0.2 15 | ELITE_RATIO=0.2 16 | 17 | POPULATION_SIZE=5 18 | CRITERION_MODE=lastk 19 | TASK_HORIZON=10 20 | LAST_K=3 21 | INITIAL_STEP=6 22 | BUDGET=3 23 | STEP=5 24 | 25 | # PROP_TYPE=E 26 | PROP_TYPE=TSinf 27 | 28 | OPT_TYPE_PETS=CEM 29 | # TOTAL STEPS = (BUDGET - 1) * STEP + INITIAL_STEP 30 | 31 | #ENV=pusher 32 | ENV=halfcheetah_v3 33 | # ENV=hopper 34 | # ENV=cartpole 35 | 36 | # Replace with your directory 37 | DIR=log/$ENV\_$SLURM_ARRAY_JOB_ID 38 | WORKER_ID=$SLURM_ARRAY_JOB_ID 39 | 40 | # CONFIG_NAMES=plan_hor\ num_cem_iters\ cem_popsize\ cem_elites_ratio\ cem_alpha 41 | CONFIG_NAMES=model_weight_decay\ model_learning_rate\ model_train_epoch 42 | NEVAL=1 43 | 44 | cd .. 
45 | 46 | 47 | if [ $SLURM_ARRAY_TASK_ID -eq 1 ] 48 | then 49 | python -u pbt-bt-mbexp.py -config_names $CONFIG_NAMES \ 50 | -seed 0 \ 51 | -env $ENV \ 52 | -logdir $DIR \ 53 | -worker_id $SLURM_ARRAY_TASK_ID \ 54 | -sample_from_percent $SAMPLE_FROM_PERCENT \ 55 | -resample_if_not_in_percent $RESAMPLE_IF_NOT_IN_PERCENT \ 56 | -resample_probability $RESAMPLE_PROBABILITY \ 57 | -max_steps $MAX_STEPS \ 58 | -delta_t $DELTA_T \ 59 | -tolerance $TOLERANCE \ 60 | -elite_ratio $ELITE_RATIO \ 61 | -o exp_cfg.log_cfg.neval $NEVAL \ 62 | -o exp_cfg.pbt_cfg.pop_size $POPULATION_SIZE \ 63 | -o exp_cfg.pbt_cfg.budget $BUDGET \ 64 | -o exp_cfg.pbt_cfg.criterion_mode $CRITERION_MODE \ 65 | -o exp_cfg.pbt_cfg.last_k $LAST_K \ 66 | -o exp_cfg.pbt_cfg.initial_step $INITIAL_STEP \ 67 | -o exp_cfg.pbt_cfg.step $STEP \ 68 | -o exp_cfg.sim_cfg.task_hor $TASK_HORIZON \ 69 | -ca prop-type $PROP_TYPE \ 70 | -ca opt-type $OPT_TYPE_PETS 71 | else 72 | python -u pbt-bt-mbexp.py -config_names $CONFIG_NAMES \ 73 | -seed 0 \ 74 | -worker \ 75 | -env $ENV \ 76 | -logdir $DIR \ 77 | -worker_id $WORKER_ID \ 78 | -max_steps $MAX_STEPS \ 79 | -delta_t $DELTA_T \ 80 | -tolerance $TOLERANCE \ 81 | -elite_ratio $ELITE_RATIO \ 82 | -o exp_cfg.log_cfg.neval $NEVAL \ 83 | -o exp_cfg.pbt_cfg.pop_size $POPULATION_SIZE \ 84 | -o exp_cfg.pbt_cfg.budget $BUDGET \ 85 | -o exp_cfg.pbt_cfg.criterion_mode $CRITERION_MODE \ 86 | -o exp_cfg.pbt_cfg.last_k $LAST_K \ 87 | -o exp_cfg.pbt_cfg.initial_step $INITIAL_STEP \ 88 | -o exp_cfg.pbt_cfg.step $STEP \ 89 | -o exp_cfg.sim_cfg.task_hor $TASK_HORIZON \ 90 | -ca prop-type $PROP_TYPE \ 91 | -ca opt-type $OPT_TYPE_PETS 92 | fi -------------------------------------------------------------------------------- /scripts/pbt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # activate your conda env 4 | source ~/.bashrc 5 | source activate mbrl 6 | 7 | SAMPLE_FROM_PERCENT=0.2 8 | RESAMPLE_IF_NOT_IN_PERCENT=0.8 9 | RESAMPLE_PROBABILITY=0.25 10 | 11 | 12 | POPULATION_SIZE=40 13 | CRITERION_MODE=lastk 14 | TASK_HORIZON=10 15 | LAST_K=3 16 | INITIAL_STEP=6 17 | BUDGET=60 18 | STEP=5 19 | 20 | PROP_TYPE=E 21 | OPT_TYPE_PETS=CEM 22 | # TOTAL STEPS = (BUDGET - 1) * STEP + INITIAL_STEP 23 | 24 | ENV=halfcheetah_v3 25 | 26 | # Replace with your directory 27 | DIR=log/$ENV\_$SLURM_ARRAY_JOB_ID 28 | WORKER_ID=$SLURM_ARRAY_JOB_ID 29 | 30 | # CONFIG_NAMES=plan_hor\ num_cem_iters\ cem_popsize\ cem_elites_ratio\ cem_alpha 31 | CONFIG_NAMES=model_weight_decay\ model_learning_rate\ model_train_epoch 32 | NEVAL=1 33 | 34 | cd .. 
35 | 36 | if [ $SLURM_ARRAY_TASK_ID -eq 1 ] 37 | then 38 | python -u pbt-mbexp.py -config_names $CONFIG_NAMES \ 39 | -seed 0 \ 40 | -env $ENV \ 41 | -logdir $DIR \ 42 | -worker_id $SLURM_ARRAY_TASK_ID \ 43 | -sample_from_percent $SAMPLE_FROM_PERCENT \ 44 | -resample_if_not_in_percent $RESAMPLE_IF_NOT_IN_PERCENT \ 45 | -resample_probability $RESAMPLE_PROBABILITY \ 46 | -o exp_cfg.log_cfg.neval $NEVAL \ 47 | -o exp_cfg.pbt_cfg.pop_size $POPULATION_SIZE \ 48 | -o exp_cfg.pbt_cfg.budget $BUDGET \ 49 | -o exp_cfg.pbt_cfg.criterion_mode $CRITERION_MODE \ 50 | -o exp_cfg.pbt_cfg.last_k $LAST_K \ 51 | -o exp_cfg.pbt_cfg.initial_step $INITIAL_STEP \ 52 | -o exp_cfg.pbt_cfg.step $STEP \ 53 | -o exp_cfg.sim_cfg.task_hor $TASK_HORIZON \ 54 | -ca prop-type $PROP_TYPE \ 55 | -ca opt-type $OPT_TYPE_PETS 56 | else 57 | python -u pbt-mbexp.py -config_names $CONFIG_NAMES \ 58 | -seed 0 \ 59 | -worker \ 60 | -env $ENV \ 61 | -logdir $DIR \ 62 | -worker_id $WORKER_ID \ 63 | -o exp_cfg.log_cfg.neval $NEVAL \ 64 | -o exp_cfg.pbt_cfg.pop_size $POPULATION_SIZE \ 65 | -o exp_cfg.pbt_cfg.budget $BUDGET \ 66 | -o exp_cfg.pbt_cfg.criterion_mode $CRITERION_MODE \ 67 | -o exp_cfg.pbt_cfg.last_k $LAST_K \ 68 | -o exp_cfg.pbt_cfg.initial_step $INITIAL_STEP \ 69 | -o exp_cfg.pbt_cfg.step $STEP \ 70 | -o exp_cfg.sim_cfg.task_hor $TASK_HORIZON \ 71 | -ca prop-type $PROP_TYPE \ 72 | -ca opt-type $OPT_TYPE_PETS 73 | fi --------------------------------------------------------------------------------