├── .gitignore ├── README.md ├── __init__.py ├── img ├── 1m_mbbl_result_table.jpeg ├── mbbl_front.png ├── mbbl_result_table.jpeg └── mbbl_stats.png ├── main ├── __init__.py ├── deepmimic_main.py ├── gail_mf_main.py ├── gps_main.py ├── ilqr_main.py ├── inverse_dynamics_IM.py ├── mbmf_main.py ├── metrpo_gnn_main.py ├── metrpo_main.py ├── mf_main.py ├── pets_main.py ├── random_main.py ├── rs_main.py └── test.py ├── mbbl ├── __init__.py ├── config │ ├── __init__.py │ ├── base_config.py │ ├── cem_config.py │ ├── ggnn_config.py │ ├── gps_config.py │ ├── il_config.py │ ├── ilqr_config.py │ ├── init_path.py │ ├── mbmf_config.py │ ├── metrpo_config.py │ ├── mf_config.py │ └── rs_config.py ├── env │ ├── __init__.py │ ├── base_env_wrapper.py │ ├── bullet_env │ │ ├── __init__.py │ │ ├── bullet_humanoid.py │ │ ├── depth_roboschool.py │ │ ├── humanoid.py │ │ └── motion_capture_data.py │ ├── dm_env │ │ ├── __init__.py │ │ ├── assets │ │ │ ├── cheetah_pos.xml │ │ │ ├── common │ │ │ │ ├── __init__.py │ │ │ │ ├── materials.xml │ │ │ │ ├── skybox.xml │ │ │ │ └── visual.xml │ │ │ ├── humanoid_CMU.xml │ │ │ └── reference │ │ │ │ ├── acrobot.xml │ │ │ │ ├── ball_in_cup.xml │ │ │ │ ├── cartpole.xml │ │ │ │ ├── cheetah.xml │ │ │ │ ├── finger.xml │ │ │ │ ├── fish.xml │ │ │ │ ├── hopper.xml │ │ │ │ ├── humanoid.xml │ │ │ │ ├── humanoid_CMU.xml │ │ │ │ ├── lqr.xml │ │ │ │ ├── manipulator.xml │ │ │ │ ├── pendulum.xml │ │ │ │ ├── point_mass.xml │ │ │ │ ├── reacher.xml │ │ │ │ ├── stacker.xml │ │ │ │ ├── swimmer.xml │ │ │ │ └── walker.xml │ │ ├── dm_env.py │ │ ├── humanoid_env.py │ │ └── pos_dm_env.py │ ├── env_register.py │ ├── env_util.py │ ├── fake_env.py │ ├── gym_env │ │ ├── __init__.py │ │ ├── acrobot.py │ │ ├── box2d │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ ├── walker.py │ │ │ └── wrappers.py │ │ ├── box2d_lunar_lander.py │ │ ├── box2d_racer.py │ │ ├── box2d_walker.py │ │ ├── cartpole.py │ │ ├── delayed_walker.py │ │ ├── fix_swimmer │ │ │ ├── __init__.py │ │ │ ├── assets │ │ │ │ └── fixed_swimmer.xml │ │ │ └── fixed_swimmer.py │ │ ├── fixed_swimmer.py │ │ ├── fixed_walker.py │ │ ├── humanoid.py │ │ ├── invertedPendulum.py │ │ ├── mountain_car.py │ │ ├── noise_gym_cartpole.py │ │ ├── noise_gym_cheetah.py │ │ ├── noise_gym_pendulum.py │ │ ├── pendulum.py │ │ ├── pets.py │ │ ├── pets_env │ │ │ ├── __init__.py │ │ │ ├── assets │ │ │ │ ├── cartpole.xml │ │ │ │ ├── half_cheetah.xml │ │ │ │ ├── pusher.xml │ │ │ │ └── reacher3d.xml │ │ │ ├── cartpole.py │ │ │ ├── half_cheetah.py │ │ │ ├── half_cheetah_config.py │ │ │ ├── pusher.py │ │ │ ├── pusher_config.py │ │ │ ├── reacher.py │ │ │ └── reacher_config.py │ │ ├── point.py │ │ ├── point_env.py │ │ ├── reacher.py │ │ └── walker.py │ ├── render.py │ └── render_wrapper.py ├── network │ ├── __init__.py │ ├── dynamics │ │ ├── __init__.py │ │ ├── base_dynamics.py │ │ ├── bayesian_forward_dynamics.py │ │ ├── deterministic_forward_dynamics.py │ │ ├── deterministic_forward_ggnn_dynamics.py │ │ ├── groundtruth_forward_dynamics.py │ │ ├── linear_stochastic_forward_dynamics_gmm_prior.py │ │ └── stochastic_forward_dynamics.py │ ├── policy │ │ ├── __init__.py │ │ ├── base_policy.py │ │ ├── cem_policy.py │ │ ├── gps_policy_gmm_refit.py │ │ ├── mbmf_policy.py │ │ ├── ppo_cnn_policy.py │ │ ├── ppo_policy.py │ │ ├── random_policy.py │ │ └── trpo_policy.py │ └── reward │ │ ├── GAN_reward.py │ │ ├── __init__.py │ │ ├── base_reward.py │ │ ├── deepmimic_reward.py │ │ └── groundtruth_reward.py ├── sampler │ ├── __init__.py │ ├── base_sampler.py │ ├── mbmf_sampler.py │ ├── readme.md │ ├── 
singletask_ilqr_sampler.py │ ├── singletask_metrpo_sampler.py │ ├── singletask_pets_sampler.py │ ├── singletask_random_sampler.py │ └── singletask_sampler.py ├── trainer │ ├── __init__.py │ ├── base_trainer.py │ ├── gail_trainer.py │ ├── gps_trainer.py │ ├── mbmf_trainer.py │ ├── metrpo_trainer.py │ ├── readme.md │ └── shooting_trainer.py ├── util │ ├── __init__.py │ ├── base_main.py │ ├── common │ │ ├── __init__.py │ │ ├── fpdb.py │ │ ├── ggnn_utils.py │ │ ├── logger.py │ │ ├── misc_utils.py │ │ ├── model_saver.py │ │ ├── parallel_util.py │ │ ├── replay_buffer.py │ │ ├── summary_handler.py │ │ ├── tf_ggnn_networks.py │ │ ├── tf_networks.py │ │ ├── tf_norm.py │ │ ├── tf_utils.py │ │ ├── vis_debug.py │ │ └── whitening_util.py │ ├── gps │ │ ├── __init__.py │ │ └── gps_utils.py │ ├── il │ │ ├── __init__.py │ │ ├── camera_model.py │ │ ├── camera_pose_ID_solver.py │ │ ├── expert_data_util.py │ │ ├── il_util.py │ │ ├── pose_visualization.py │ │ └── test_inverse_dynamics.py │ ├── ilqr │ │ ├── __init__.py │ │ ├── ilqr_data_wrapper.py │ │ ├── ilqr_utils.py │ │ └── stochastic_ilqr_data_wrapper.py │ └── kfac │ │ ├── estimator.py │ │ ├── fisher_blocks.py │ │ ├── fisher_factors.py │ │ ├── layer_collection.py │ │ ├── loss_functions.py │ │ ├── optimizer.py │ │ └── utils.py └── worker │ ├── __init__.py │ ├── base_worker.py │ ├── cem_worker.py │ ├── mbmf_worker.py │ ├── metrpo_worker.py │ ├── mf_worker.py │ ├── model_worker.py │ ├── readme.md │ └── rs_worker.py ├── scripts ├── exp_1_performance_curve │ ├── ilqr.sh │ ├── ilqr_depth.sh │ ├── mbmf.sh │ ├── mbmf_1m.sh │ ├── pets_gt.sh │ ├── pets_gt_checkdone.sh │ ├── ppo.sh │ ├── random_policy.sh │ ├── rs.sh │ ├── rs_1.sh │ ├── rs_2.sh │ ├── rs_gt.sh │ ├── rs_gt_checkdone.sh │ └── trpo.sh ├── exp_2_dilemma │ └── gt_computation.sh ├── exp_2_dynamics_schemes │ ├── rs_act_norm.sh │ └── rs_lr_batch.sh ├── exp_7_traj │ └── ilqr_numtraj.sh ├── exp_9_planning_depth │ └── ilqr_depth.sh ├── ilqr │ ├── ilqr.sh │ ├── ilqr_depth.sh │ ├── ilqr_iter.sh │ └── ilqr_numtraj.sh ├── performance_curve │ └── mbmf.sh └── test │ ├── gym_humanoid.sh │ ├── humanoid.sh │ ├── humanoid_new.sh │ └── mbmf.sh ├── setup.py └── tests ├── test_utils ├── test_gt_dynamics.py └── test_replay_buffer.py └── tf-test-guide.md /.gitignore: -------------------------------------------------------------------------------- 1 | MUJUCO_LOG.txt 2 | imitation 3 | main/MUJUCO_LOG.txt 4 | *.pdf 5 | .nfs* 6 | output/ 7 | *.p 8 | *.csv 9 | *.npy 10 | paper/ 11 | *.mp4 12 | *.log 13 | *.swp 14 | *.swo 15 | *.swn 16 | summary/ 17 | log/ 18 | *.jpg 19 | tool/batch_summary_* 20 | video/ 21 | *.zip 22 | *cpython* 23 | data/ 24 | 25 | # IDE files 26 | .idea/ 27 | 28 | # pytest cache files 29 | .pytest_cache 30 | # Byte-compiled / optimized / DLL files 31 | __pycache__/ 32 | *.py[cod] 33 | *$py.class 34 | 35 | # C extensions 36 | *.so 37 | 38 | # Distribution / packaging 39 | .Python 40 | build/ 41 | develop-eggs/ 42 | dist/ 43 | downloads/ 44 | eggs/ 45 | .eggs/ 46 | lib/ 47 | lib64/ 48 | parts/ 49 | sdist/ 50 | var/ 51 | wheels/ 52 | *.egg-info/ 53 | .installed.cfg 54 | *.egg 55 | MANIFEST 56 | 57 | # PyInstaller 58 | # Usually these files are written by a python script from a template 59 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
60 | *.manifest 61 | *.spec 62 | 63 | # Installer logs 64 | pip-log.txt 65 | pip-delete-this-directory.txt 66 | 67 | # Unit test / coverage reports 68 | htmlcov/ 69 | .tox/ 70 | .nox/ 71 | .coverage 72 | .coverage.* 73 | .cache 74 | nosetests.xml 75 | coverage.xml 76 | *.cover 77 | .hypothesis/ 78 | .pytest_cache/ 79 | 80 | # Translations 81 | *.mo 82 | *.pot 83 | 84 | # Django stuff: 85 | *.log 86 | local_settings.py 87 | db.sqlite3 88 | 89 | # Flask stuff: 90 | instance/ 91 | .webassets-cache 92 | 93 | # Scrapy stuff: 94 | .scrapy 95 | 96 | # Sphinx documentation 97 | docs/_build/ 98 | 99 | # PyBuilder 100 | target/ 101 | 102 | # Jupyter Notebook 103 | .ipynb_checkpoints 104 | 105 | # IPython 106 | profile_default/ 107 | ipython_config.py 108 | 109 | # pyenv 110 | .python-version 111 | 112 | # celery beat schedule file 113 | celerybeat-schedule 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/__init__.py -------------------------------------------------------------------------------- /img/1m_mbbl_result_table.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/img/1m_mbbl_result_table.jpeg -------------------------------------------------------------------------------- /img/mbbl_front.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/img/mbbl_front.png -------------------------------------------------------------------------------- /img/mbbl_result_table.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/img/mbbl_result_table.jpeg -------------------------------------------------------------------------------- /img/mbbl_stats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/img/mbbl_stats.png -------------------------------------------------------------------------------- /main/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/main/__init__.py -------------------------------------------------------------------------------- /main/deepmimic_main.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: main fucntion 3 | # ----------------------------------------------------------------------------- 4 | 5 | from mbbl.config import base_config 6 | from mbbl.config import init_path 7 | from mbbl.config import mf_config 8 | from mbbl.config import 
il_config 9 | from mbbl.util.common import logger 10 | from mf_main import train 11 | 12 | 13 | def main(): 14 | parser = base_config.get_base_config() 15 | parser = mf_config.get_mf_config(parser) 16 | parser = il_config.get_il_config(parser) 17 | args = base_config.make_parser(parser) 18 | 19 | if args.write_log: 20 | logger.set_file_handler(path=args.output_dir, 21 | prefix='deepmimic-mf-' + args.task, 22 | time_str=args.exp_id) 23 | 24 | # no random policy for model-free rl 25 | assert args.random_timesteps == 0 26 | 27 | print('Training starts at {}'.format(init_path.get_abs_base_dir())) 28 | from mbbl.trainer import gail_trainer 29 | from mbbl.sampler import singletask_sampler 30 | from mbbl.worker import mf_worker 31 | import mbbl.network.policy.trpo_policy 32 | import mbbl.network.policy.ppo_policy 33 | 34 | policy_network = { 35 | 'ppo': mbbl.network.policy.ppo_policy.policy_network, 36 | 'trpo': mbbl.network.policy.trpo_policy.policy_network 37 | }[args.trust_region_method] 38 | 39 | # here the dynamics and reward are simply placeholders, which cannot be 40 | # called to pred next state or reward 41 | from mbbl.network.dynamics.base_dynamics import base_dynamics_network 42 | from mbbl.network.reward.deepmimic_reward import reward_network 43 | 44 | train(gail_trainer, singletask_sampler, mf_worker, 45 | base_dynamics_network, policy_network, reward_network, args) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /main/gail_mf_main.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: main fucntion 3 | # ----------------------------------------------------------------------------- 4 | 5 | from mbbl.config import base_config 6 | from mbbl.config import init_path 7 | from mbbl.config import mf_config 8 | from mbbl.config import il_config 9 | from mbbl.util.common import logger 10 | from mf_main import train 11 | 12 | 13 | def main(): 14 | parser = base_config.get_base_config() 15 | parser = mf_config.get_mf_config(parser) 16 | parser = il_config.get_il_config(parser) 17 | args = base_config.make_parser(parser) 18 | 19 | if args.write_log: 20 | logger.set_file_handler(path=args.output_dir, 21 | prefix='gail-mfrl-mf-' + args.task, 22 | time_str=args.exp_id) 23 | 24 | # no random policy for model-free rl 25 | assert args.random_timesteps == 0 26 | 27 | print('Training starts at {}'.format(init_path.get_abs_base_dir())) 28 | from mbbl.trainer import gail_trainer 29 | from mbbl.sampler import singletask_sampler 30 | from mbbl.worker import mf_worker 31 | import mbbl.network.policy.trpo_policy 32 | import mbbl.network.policy.ppo_policy 33 | 34 | policy_network = { 35 | 'ppo': mbbl.network.policy.ppo_policy.policy_network, 36 | 'trpo': mbbl.network.policy.trpo_policy.policy_network 37 | }[args.trust_region_method] 38 | 39 | # here the dynamics and reward are simply placeholders, which cannot be 40 | # called to pred next state or reward 41 | from mbbl.network.dynamics.base_dynamics import base_dynamics_network 42 | from mbbl.network.reward.GAN_reward import reward_network 43 | 44 | train(gail_trainer, singletask_sampler, mf_worker, 45 | base_dynamics_network, policy_network, reward_network, args) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /main/gps_main.py: 
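# Illustrative sketch -- NOT a file from this repository. The main/*_main.py
# entry points shown above (deepmimic_main.py, gail_mf_main.py) share one
# recipe: chain the config helpers onto a single argparse parser, then pick the
# policy network by dict-dispatch on args.trust_region_method. The stand-in
# functions below are hypothetical so the snippet runs on its own; the real
# implementations live in mbbl.config.* and mbbl.network.policy.*.
import argparse


def _ppo_policy_network():
    # stand-in for mbbl.network.policy.ppo_policy.policy_network
    return 'ppo policy network'


def _trpo_policy_network():
    # stand-in for mbbl.network.policy.trpo_policy.policy_network
    return 'trpo policy network'


def _get_example_parser():
    # mirrors base_config.get_base_config() followed by
    # mf_config.get_mf_config(parser) in the entry points above
    parser = argparse.ArgumentParser()
    parser.add_argument("--trust_region_method", type=str, default='trpo',
                        help='["ppo", "trpo"]')
    return parser


if __name__ == '__main__':
    example_args = _get_example_parser().parse_args(['--trust_region_method', 'ppo'])
    policy_network = {
        'ppo': _ppo_policy_network,
        'trpo': _trpo_policy_network,
    }[example_args.trust_region_method]
    print(policy_network())  # prints: ppo policy network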
-------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: main fucntion 3 | # ----------------------------------------------------------------------------- 4 | 5 | from mf_main import train # gps use similar trainer as mf trainer 6 | from mbbl.config import base_config 7 | from mbbl.config import gps_config 8 | from mbbl.config import ilqr_config 9 | from mbbl.config import init_path 10 | from mbbl.util.common import logger 11 | 12 | 13 | def main(): 14 | parser = base_config.get_base_config() 15 | parser = ilqr_config.get_ilqr_config(parser) 16 | parser = gps_config.get_gps_config(parser) 17 | args = base_config.make_parser(parser) 18 | 19 | if args.write_log: 20 | logger.set_file_handler(path=args.output_dir, 21 | prefix='mbrl-gps-' + args.task, 22 | time_str=args.exp_id) 23 | 24 | print('Training starts at {}'.format(init_path.get_abs_base_dir())) 25 | from mbbl.trainer import gps_trainer 26 | from mbbl.sampler import singletask_sampler 27 | from mbbl.worker import mf_worker 28 | from mbbl.network.policy.gps_policy_gmm_refit import policy_network 29 | 30 | assert not args.gt_dynamics and args.gt_reward 31 | from mbbl.network.dynamics.linear_stochastic_forward_dynamics_gmm_prior \ 32 | import dynamics_network 33 | from mbbl.network.reward.groundtruth_reward import reward_network 34 | 35 | train(gps_trainer, singletask_sampler, mf_worker, 36 | dynamics_network, policy_network, reward_network, args) 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /main/ilqr_main.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: main fucntion 3 | # ----------------------------------------------------------------------------- 4 | 5 | from mbbl.config import base_config 6 | from mbbl.config import ilqr_config 7 | from mbbl.config import init_path 8 | from mbbl.util.base_main import train 9 | from mbbl.util.common import logger 10 | 11 | 12 | def main(): 13 | parser = base_config.get_base_config() 14 | parser = ilqr_config.get_ilqr_config(parser) 15 | args = base_config.make_parser(parser) 16 | 17 | if args.write_log: 18 | logger.set_file_handler(path=args.output_dir, 19 | prefix='mbrl-ilqr' + args.task, 20 | time_str=args.exp_id) 21 | 22 | print('Training starts at {}'.format(init_path.get_abs_base_dir())) 23 | from mbbl.trainer import shooting_trainer 24 | from mbbl.sampler import singletask_ilqr_sampler 25 | from mbbl.worker import model_worker 26 | from mbbl.network.policy.random_policy import policy_network 27 | 28 | if args.gt_dynamics: 29 | from mbbl.network.dynamics.groundtruth_forward_dynamics import \ 30 | dynamics_network 31 | else: 32 | from mbbl.network.dynamics.deterministic_forward_dynamics import \ 33 | dynamics_network 34 | 35 | if args.gt_reward: 36 | from mbbl.network.reward.groundtruth_reward import reward_network 37 | else: 38 | from mbbl.network.reward.deterministic_reward import reward_network 39 | 40 | if (not args.gt_reward) or not (args.gt_dynamics): 41 | raise NotImplementedError('Havent finished! 
Oooooops') 42 | 43 | train(shooting_trainer, singletask_ilqr_sampler, model_worker, 44 | dynamics_network, policy_network, reward_network, args) 45 | 46 | 47 | if __name__ == '__main__': 48 | main() 49 | -------------------------------------------------------------------------------- /main/inverse_dynamics_IM.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: main fucntion 3 | # ----------------------------------------------------------------------------- 4 | 5 | from mbbl.config import base_config 6 | from mbbl.config import init_path 7 | from mbbl.config import rs_config 8 | from mbbl.config import il_config 9 | from mbbl.util.common import logger 10 | from mbbl.util.il import camera_pose_ID_solver 11 | 12 | 13 | def main(): 14 | parser = base_config.get_base_config() 15 | parser = rs_config.get_rs_config(parser) 16 | parser = il_config.get_il_config(parser) 17 | args = base_config.make_parser(parser) 18 | args = il_config.post_process_config(args) 19 | 20 | if args.write_log: 21 | args.log_path = logger.set_file_handler( 22 | path=args.output_dir, prefix='inverse_dynamics' + args.task, 23 | time_str=args.exp_id 24 | ) 25 | 26 | print('Training starts at {}'.format(init_path.get_abs_base_dir())) 27 | 28 | train(args) 29 | 30 | 31 | def train(args): 32 | training_args = { 33 | 'physics_loss_lambda': args.physics_loss_lambda, 34 | 'lbfgs_opt_iteration': args.lbfgs_opt_iteration 35 | } 36 | solver = camera_pose_ID_solver.solver( 37 | args.expert_data_name, args.sol_qpos_freq, args.opt_var_list, 38 | args.camera_info, args.camera_id, training_args, 39 | args.imitation_length, 40 | args.gt_camera_info, args.log_path 41 | ) 42 | 43 | solver.solve() 44 | 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /main/mf_main.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: main fucntion 3 | # ----------------------------------------------------------------------------- 4 | 5 | import os 6 | os.environ['MUJOCO_GL'] = "osmesa" 7 | import time 8 | from collections import OrderedDict 9 | 10 | from mbbl.config import base_config 11 | from mbbl.config import init_path 12 | from mbbl.config import mf_config 13 | from mbbl.util.base_main import make_sampler, make_trainer, log_results 14 | from mbbl.util.common import logger 15 | from mbbl.util.common import parallel_util 16 | 17 | 18 | def train(trainer, sampler, worker, dynamics, policy, reward, args=None): 19 | logger.info('Training starts at {}'.format(init_path.get_abs_base_dir())) 20 | network_type = {'policy': policy, 'dynamics': dynamics, 'reward': reward} 21 | 22 | # make the trainer and sampler 23 | sampler_agent = make_sampler(sampler, worker, network_type, args) 24 | trainer_tasks, trainer_results, trainer_agent, init_weights = \ 25 | make_trainer(trainer, network_type, args) 26 | sampler_agent.set_weights(init_weights) 27 | 28 | timer_dict = OrderedDict() 29 | timer_dict['Program Start'] = time.time() 30 | current_iteration = 0 31 | 32 | while True: 33 | timer_dict['** Program Total Time **'] = time.time() 34 | 35 | # step 1: collect rollout data 36 | rollout_data = \ 37 | sampler_agent.rollouts_using_worker_playing(use_true_env=True) 38 | 39 | timer_dict['Generate Rollout'] = time.time() 40 | 41 | # 
step 2: train the weights for dynamics and policy network 42 | training_info = {'network_to_train': ['dynamics', 'reward', 'policy']} 43 | trainer_tasks.put( 44 | (parallel_util.TRAIN_SIGNAL, 45 | {'data': rollout_data['data'], 'training_info': training_info}) 46 | ) 47 | trainer_tasks.join() 48 | training_return = trainer_results.get() 49 | timer_dict['Train Weights'] = time.time() 50 | 51 | # step 4: update the weights 52 | sampler_agent.set_weights(training_return['network_weights']) 53 | timer_dict['Assign Weights'] = time.time() 54 | 55 | # log and print the results 56 | log_results(training_return, timer_dict) 57 | 58 | # if totalsteps > args.max_timesteps: 59 | if training_return['totalsteps'] > args.max_timesteps: 60 | break 61 | else: 62 | current_iteration += 1 63 | 64 | # end of training 65 | sampler_agent.end() 66 | trainer_tasks.put((parallel_util.END_SIGNAL, None)) 67 | 68 | 69 | def main(): 70 | parser = base_config.get_base_config() 71 | parser = mf_config.get_mf_config(parser) 72 | args = base_config.make_parser(parser) 73 | 74 | if args.write_log: 75 | logger.set_file_handler(path=args.output_dir, 76 | prefix='mfrl-mf' + args.task, 77 | time_str=args.exp_id) 78 | 79 | # no random policy for model-free rl 80 | assert args.random_timesteps == 0 81 | 82 | print('Training starts at {}'.format(init_path.get_abs_base_dir())) 83 | from mbbl.trainer import shooting_trainer 84 | from mbbl.sampler import singletask_sampler 85 | from mbbl.worker import mf_worker 86 | import mbbl.network.policy.trpo_policy 87 | import mbbl.network.policy.ppo_policy 88 | 89 | policy_network = { 90 | 'ppo': mbbl.network.policy.ppo_policy.policy_network, 91 | 'trpo': mbbl.network.policy.trpo_policy.policy_network 92 | }[args.trust_region_method] 93 | 94 | # here the dynamics and reward are simply placeholders, which cannot be 95 | # called to pred next state or reward 96 | from mbbl.network.dynamics.base_dynamics import base_dynamics_network 97 | from mbbl.network.reward.groundtruth_reward import base_reward_network 98 | 99 | train(shooting_trainer, singletask_sampler, mf_worker, 100 | base_dynamics_network, policy_network, base_reward_network, args) 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /main/pets_main.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: main fucntion 3 | # ----------------------------------------------------------------------------- 4 | 5 | from mbbl.config import base_config 6 | from mbbl.config import cem_config 7 | from mbbl.config import init_path 8 | from mbbl.util.base_main import train 9 | from mbbl.util.common import logger 10 | 11 | 12 | def main(): 13 | parser = base_config.get_base_config() 14 | parser = cem_config.get_cem_config(parser) 15 | args = base_config.make_parser(parser) 16 | 17 | if args.write_log: 18 | logger.set_file_handler(path=args.output_dir, 19 | prefix='mbrl-cem' + args.task, 20 | time_str=args.exp_id) 21 | 22 | print('Training starts at {}'.format(init_path.get_abs_base_dir())) 23 | from mbbl.trainer import shooting_trainer 24 | from mbbl.sampler import singletask_pets_sampler 25 | from mbbl.worker import cem_worker 26 | from mbbl.network.policy.random_policy import policy_network 27 | 28 | if args.gt_dynamics: 29 | from mbbl.network.dynamics.groundtruth_forward_dynamics import \ 30 | dynamics_network 31 | else: 32 | from 
mbbl.network.dynamics.deterministic_forward_dynamics import \ 33 | dynamics_network 34 | 35 | if args.gt_reward: 36 | from mbbl.network.reward.groundtruth_reward import reward_network 37 | else: 38 | from mbbl.network.reward.deterministic_reward import reward_network 39 | 40 | train(shooting_trainer, singletask_pets_sampler, cem_worker, 41 | dynamics_network, policy_network, reward_network, args) 42 | 43 | 44 | if __name__ == '__main__': 45 | main() 46 | -------------------------------------------------------------------------------- /main/random_main.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: main fucntion 3 | # ----------------------------------------------------------------------------- 4 | 5 | from mbbl.config import base_config 6 | from mbbl.config import ilqr_config 7 | from mbbl.config import init_path 8 | from mbbl.util.base_main import train 9 | from mbbl.util.common import logger 10 | 11 | 12 | def main(): 13 | parser = base_config.get_base_config() 14 | parser = ilqr_config.get_ilqr_config(parser) 15 | args = base_config.make_parser(parser) 16 | 17 | if args.write_log: 18 | logger.set_file_handler(path=args.output_dir, 19 | prefix='mbrl-ilqr' + args.task, 20 | time_str=args.exp_id) 21 | 22 | print('Training starts at {}'.format(init_path.get_abs_base_dir())) 23 | from mbbl.trainer import shooting_trainer 24 | from mbbl.sampler import singletask_random_sampler 25 | from mbbl.worker import model_worker 26 | from mbbl.network.policy.random_policy import policy_network 27 | 28 | from mbbl.network.dynamics.groundtruth_forward_dynamics import \ 29 | dynamics_network 30 | 31 | if args.gt_reward: 32 | from mbbl.network.reward.groundtruth_reward import reward_network 33 | else: 34 | from mbbl.network.reward.deterministic_reward import reward_network 35 | 36 | train(shooting_trainer, singletask_random_sampler, model_worker, 37 | dynamics_network, policy_network, reward_network, args) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /main/rs_main.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: main fucntion 3 | # ----------------------------------------------------------------------------- 4 | 5 | from mbbl.config import base_config 6 | from mbbl.config import init_path 7 | from mbbl.config import rs_config 8 | from mbbl.util.base_main import train 9 | from mbbl.util.common import logger 10 | 11 | 12 | def main(): 13 | parser = base_config.get_base_config() 14 | parser = rs_config.get_rs_config(parser) 15 | args = base_config.make_parser(parser) 16 | 17 | if args.write_log: 18 | logger.set_file_handler(path=args.output_dir, 19 | prefix='mbrl-rs' + args.task, 20 | time_str=args.exp_id) 21 | 22 | print('Training starts at {}'.format(init_path.get_abs_base_dir())) 23 | from mbbl.trainer import shooting_trainer 24 | from mbbl.sampler import singletask_sampler 25 | from mbbl.worker import rs_worker 26 | from mbbl.network.policy.random_policy import policy_network 27 | 28 | if args.gt_dynamics: 29 | from mbbl.network.dynamics.groundtruth_forward_dynamics import \ 30 | dynamics_network 31 | else: 32 | from mbbl.network.dynamics.deterministic_forward_dynamics import \ 33 | dynamics_network 34 | 35 | if args.gt_reward: 36 | from 
mbbl.network.reward.groundtruth_reward import reward_network 37 | else: 38 | from mbbl.network.reward.deterministic_reward import reward_network 39 | 40 | train(shooting_trainer, singletask_sampler, rs_worker, 41 | dynamics_network, policy_network, reward_network, args) 42 | 43 | 44 | if __name__ == '__main__': 45 | main() 46 | -------------------------------------------------------------------------------- /main/test.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: main fucntion 3 | # ----------------------------------------------------------------------------- 4 | 5 | import multiprocessing 6 | import time 7 | import os 8 | from collections import OrderedDict 9 | 10 | from mbbl.config import init_path 11 | from mbbl.util.common import parallel_util 12 | from mbbl.util.common import logger 13 | 14 | if __name__ == '__main__': 15 | print('a') 16 | -------------------------------------------------------------------------------- /mbbl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/__init__.py -------------------------------------------------------------------------------- /mbbl/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/config/__init__.py -------------------------------------------------------------------------------- /mbbl/config/cem_config.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # @brief: 3 | # record the parameters here 4 | # ------------------------------------------------------------------------------ 5 | 6 | 7 | def get_cem_config(parser): 8 | parser.add_argument("--dynamics_val_percentage", type=float, default=0.33) 9 | parser.add_argument("--dynamics_val_max_size", type=int, default=5000) 10 | parser.add_argument("--num_planning_traj", type=int, default=10) 11 | parser.add_argument("--planning_depth", type=int, default=10) 12 | parser.add_argument("--cem_learning_rate", type=float, default=0.1) 13 | parser.add_argument("--cem_num_iters", type=int, default=5) 14 | parser.add_argument("--cem_elites_fraction", type=float, default=0.1) 15 | 16 | parser.add_argument("--check_done", type=int, default=0) 17 | 18 | return parser 19 | -------------------------------------------------------------------------------- /mbbl/config/ggnn_config.py: -------------------------------------------------------------------------------- 1 | 2 | def get_gnn_config(parser): 3 | 4 | parser.add_argument('--ggnn_keep_prob', type=float, default=0.8) 5 | 6 | parser.add_argument('--t_step', type=int, default=3) 7 | 8 | parser.add_argument('--embed_layer', type=int, default=1) 9 | parser.add_argument('--embed_neuron', type=int, default=256) 10 | parser.add_argument('--prop_layer', type=int, default=1) 11 | parser.add_argument('--prop_neuron', type=int, default=1024) 12 | parser.add_argument('--output_layer', type=int, default=1) 13 | parser.add_argument('--output_neuron', type=int, default=1024) 14 | 15 | parser.add_argument('--prop_normalize', action='store_true') 16 | 17 | parser.add_argument('--d_output', action='store_true') 18 | 
parser.add_argument('--d_bins', type=int, default=51) 19 | 20 | return parser 21 | -------------------------------------------------------------------------------- /mbbl/config/gps_config.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # @brief: 3 | # ------------------------------------------------------------------------------ 4 | 5 | 6 | def get_gps_config(parser): 7 | 8 | # the linear gaussian dynamics with gmm prior 9 | parser.add_argument("--gmm_num_cluster", type=int, default=30) 10 | parser.add_argument("--gmm_max_iteration", type=int, default=100) 11 | parser.add_argument("--gmm_batch_size", type=int, default=2000) 12 | parser.add_argument("--gmm_prior_strength", type=float, default=1.0) 13 | parser.add_argument("--gps_dynamics_cov_reg", type=float, default=1e-6) 14 | parser.add_argument("--gps_policy_cov_reg", type=float, default=1e-8) 15 | # parser.add_argument("--gps_nn_policy_cov_reg", type=float, default=1e-6) 16 | 17 | parser.add_argument("--gps_max_backward_pass_trials", type=int, 18 | default=20) 19 | 20 | # the constraints on the kl between policy and traj 21 | parser.add_argument("--gps_init_traj_kl_eta", type=float, default=1.0) 22 | parser.add_argument("--gps_min_eta", type=float, default=1e-8) 23 | parser.add_argument("--gps_max_eta", type=float, default=1e16) 24 | parser.add_argument("--gps_eta_multiplier", type=float, default=1e-4) 25 | 26 | parser.add_argument("--gps_min_kl_step_mult", type=float, default=1e-2) 27 | parser.add_argument("--gps_max_kl_step_mult", type=float, default=1e2) 28 | parser.add_argument("--gps_init_kl_step_mult", type=float, default=1.0) 29 | 30 | parser.add_argument("--gps_base_kl_step", type=float, default=1.0) 31 | 32 | # the penalty on the entropy of the traj 33 | parser.add_argument("--gps_traj_ent_epsilon", type=float, default=0.001) 34 | 35 | parser.add_argument("--gps_policy_cov_damping", type=float, default=0.0001) 36 | parser.add_argument("--gps_init_state_replay_size", type=int, default=1000) 37 | 38 | # single traj update 39 | parser.add_argument("--gps_single_condition", type=int, default=1) 40 | return parser 41 | -------------------------------------------------------------------------------- /mbbl/config/il_config.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: 3 | # record the parameters here 4 | # @author: 5 | # Tingwu Wang, 2017, June, 12th 6 | # ------------------------------------------------------------------------------ 7 | 8 | from mbbl.env.env_register import _ENV_INFO 9 | from mbbl.util.il.camera_model import xyaxis2quaternion 10 | import copy 11 | 12 | 13 | def get_il_config(parser): 14 | # get the parameters 15 | parser.add_argument("--GAN_reward_clip_value", type=float, default=10.0) 16 | parser.add_argument("--GAN_ent_coeff", type=float, default=0.0) 17 | 18 | parser.add_argument("--expert_data_name", type=str, default='') 19 | parser.add_argument("--traj_episode_num", type=int, default=-1) 20 | 21 | parser.add_argument("--gan_timesteps_per_epoch", type=int, default=2000) 22 | parser.add_argument("--positive_negative_ratio", type=float, default=1) 23 | 24 | # the config for the inverse dynamics planner 25 | parser.add_argument("--sol_qpos_freq", type=int, default=1) 26 | parser.add_argument("--camera_id", type=int, default=0) 27 | 
parser.add_argument("--imitation_length", type=int, default=10, 28 | help="debug using 10, 100 sounds more like real running") 29 | 30 | parser.add_argument("--opt_var_list", type=str, 31 | # default='qpos', 32 | default='quaternion', 33 | # default='qpos-xyz_pos-quaternion-fov-image_size', 34 | help='use the "-" to divide variables') 35 | 36 | parser.add_argument("--physics_loss_lambda", type=float, default=1.0, 37 | help='lambda weights the inverse or physics loss') 38 | parser.add_argument("--lbfgs_opt_iteration", type=int, default=1, 39 | help='number of iterations for the inner lbfgs loops') 40 | 41 | return parser 42 | 43 | 44 | def post_process_config(args): 45 | """ @brief: 46 | 1. Parse the opt_var_list string into list 47 | 2. parse the camera info 48 | """ 49 | # hack to parse the opt_var_list 50 | args.opt_var_list = args.opt_var_list.split('-') 51 | for key in args.opt_var_list: 52 | assert key in ['qpos', 'xyz_pos', 'quaternion', 'fov', 'image_size'] 53 | 54 | # the camera_info 55 | assert 'camera_info' in _ENV_INFO[args.task_name] 56 | args.camera_info = _ENV_INFO[args.task_name]['camera_info'] 57 | 58 | # generate the quaternion data 59 | for cam in args.camera_info: 60 | if 'quaternion' not in cam: 61 | cam['quaternion'] = xyaxis2quaternion(cam['xyaxis']) 62 | 63 | args.gt_camera_info = \ 64 | copy.deepcopy(_ENV_INFO[args.task_name]['camera_info']) 65 | 66 | # the mask the groundtruth that is not in given? 67 | for key in ['qpos', 'xyz_pos', 'quaternion', 'fov', 'image_size']: 68 | if key in args.opt_var_list: 69 | args.camera_info[args.camera_id][key] = None 70 | 71 | return args 72 | -------------------------------------------------------------------------------- /mbbl/config/ilqr_config.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # @brief: 3 | # ------------------------------------------------------------------------------ 4 | 5 | 6 | def get_ilqr_config(parser): 7 | # get the parameters 8 | parser.add_argument("--finite_difference_eps", type=float, default=1e-6) 9 | parser.add_argument("--num_ilqr_traj", type=int, default=1, 10 | help='number of different initializations of ilqr') 11 | parser.add_argument("--ilqr_depth", type=int, default=20) 12 | parser.add_argument("--ilqr_linesearch_accept_ratio", type=float, 13 | default=0.01) 14 | parser.add_argument("--ilqr_linesearch_decay_factor", type=float, 15 | default=5.0) 16 | parser.add_argument("--ilqr_iteration", type=int, default=10) 17 | parser.add_argument("--max_ilqr_linesearch_backtrack", type=int, 18 | default=30) 19 | 20 | parser.add_argument("--LM_damping_type", type=str, default='V', 21 | help="['V', 'Q'], whether to put damping on V or Q") 22 | parser.add_argument("--init_LM_damping", type=float, default=0.1, 23 | help="initial value of Levenberg-Marquardt_damping") 24 | parser.add_argument("--min_LM_damping", type=float, default=1e-3) 25 | parser.add_argument("--max_LM_damping", type=float, default=1e10) 26 | parser.add_argument("--init_LM_damping_multiplier", type=float, default=1.0) 27 | parser.add_argument("--LM_damping_factor", type=float, default=1.6) 28 | 29 | return parser 30 | -------------------------------------------------------------------------------- /mbbl/config/init_path.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # @brief: 3 | # In this 
file we init the path 4 | # @author: 5 | # Written by Tingwu Wang, 2016/Sept/22nd 6 | # ------------------------------------------------------------------------------ 7 | 8 | 9 | import os.path as osp 10 | import datetime 11 | 12 | real_file_path = osp.realpath(__file__.replace('pyc', 'py')) 13 | _this_dir = osp.dirname(real_file_path) 14 | 15 | running_start_time = datetime.datetime.now() 16 | time = str(running_start_time.strftime("%Y_%m_%d-%X")) 17 | 18 | _base_dir = osp.join(_this_dir, '..', '..') 19 | 20 | 21 | def bypass_frost_warning(): 22 | return 0 23 | 24 | 25 | def get_base_dir(): 26 | return _base_dir 27 | 28 | 29 | def get_time(): 30 | return time 31 | 32 | 33 | def get_abs_base_dir(): 34 | return osp.abspath(_base_dir) 35 | -------------------------------------------------------------------------------- /mbbl/config/mbmf_config.py: -------------------------------------------------------------------------------- 1 | 2 | def get_mbmf_config(parser): 3 | # get the parameters 4 | # MF config 5 | parser.add_argument("--value_lr", type=float, default=3e-4) 6 | parser.add_argument("--value_epochs", type=int, default=20) 7 | parser.add_argument("--value_network_shape", type=str, default='64,64') 8 | # parser.add_argument("--value_batch_size", type=int, default=64) 9 | parser.add_argument("--value_activation_type", type=str, default='tanh') 10 | parser.add_argument("--value_normalizer_type", type=str, default='none') 11 | 12 | parser.add_argument("--trust_region_method", type=str, default='ppo', 13 | help='["ppo", "trpo"]') 14 | parser.add_argument("--gae_lam", type=float, default=0.95) 15 | parser.add_argument("--fisher_cg_damping", type=float, default=0.1) 16 | parser.add_argument("--target_kl", type=float, default=0.01) 17 | parser.add_argument("--cg_iterations", type=int, default=10) 18 | 19 | # RS config 20 | parser.add_argument("--dynamics_val_percentage", type=float, default=0.33) 21 | parser.add_argument("--dynamics_val_max_size", type=int, default=5000) 22 | parser.add_argument("--num_planning_traj", type=int, default=10) 23 | parser.add_argument("--planning_depth", type=int, default=10) 24 | parser.add_argument("--mb_timesteps", type=int, default=2e4) 25 | 26 | # Imitation learning config 27 | parser.add_argument("--initial_policy_lr", type=float, default=1e-4) 28 | parser.add_argument("--initial_policy_bs", type=int, default=500) 29 | parser.add_argument("--dagger_iter", type=int, default=3) 30 | parser.add_argument("--dagger_epoch", type=int, default=70) 31 | parser.add_argument("--dagger_timesteps_per_iter", type=int, default=1750) 32 | 33 | parser.add_argument("--ppo_clip", type=float, default=0.1) 34 | parser.add_argument("--num_minibatches", type=int, default=10) 35 | parser.add_argument("--target_kl_high", type=float, default=2) 36 | parser.add_argument("--target_kl_low", type=float, default=0.5) 37 | parser.add_argument("--use_weight_decay", type=int, default=0) 38 | parser.add_argument("--weight_decay_coeff", type=float, default=1e-5) 39 | 40 | parser.add_argument("--use_kl_penalty", type=int, default=0) 41 | parser.add_argument("--kl_alpha", type=float, default=1.5) 42 | parser.add_argument("--kl_eta", type=float, default=50) 43 | 44 | parser.add_argument("--policy_lr_schedule", type=str, default='linear', 45 | help='["linear", "constant", "adaptive"]') 46 | parser.add_argument("--policy_lr_alpha", type=int, default=2) 47 | return parser 48 | -------------------------------------------------------------------------------- /mbbl/config/metrpo_config.py: 
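# Illustrative sketch -- an assumption, not code from this repository. Several
# of the flags registered in the config files above take comma-separated
# strings, e.g. --value_network_shape '64,64'. A consumer would typically split
# such a string into integer layer widths before building the value/policy MLP;
# the helper name below is hypothetical.
def parse_network_shape(shape_str):
    """Turn a shape string such as '64,64' into a list of integer layer widths."""
    return [int(width) for width in shape_str.split(',') if width]


assert parse_network_shape('64,64') == [64, 64]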
-------------------------------------------------------------------------------- 1 | 2 | def get_metrpo_config(parser): 3 | # get the parameters 4 | parser.add_argument("--value_lr", type=float, default=3e-4) 5 | parser.add_argument("--value_epochs", type=int, default=20) 6 | parser.add_argument("--value_network_shape", type=str, default='64,64') 7 | # parser.add_argument("--value_batch_size", type=int, default=64) 8 | parser.add_argument("--value_activation_type", type=str, default='tanh') 9 | parser.add_argument("--value_normalizer_type", type=str, default='none') 10 | 11 | parser.add_argument("--gae_lam", type=float, default=0.95) 12 | parser.add_argument("--fisher_cg_damping", type=float, default=0.1) 13 | parser.add_argument("--target_kl", type=float, default=0.01) 14 | parser.add_argument("--cg_iterations", type=int, default=10) 15 | 16 | parser.add_argument("--dynamics_val_percentage", type=float, default=0.33) 17 | parser.add_argument("--dynamics_val_max_size", type=int, default=10000) 18 | parser.add_argument("--max_fake_timesteps", type=int, default=1e5) 19 | 20 | return parser 21 | -------------------------------------------------------------------------------- /mbbl/config/mf_config.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # @author: 3 | # Tingwu Wang 4 | # ------------------------------------------------------------------------------ 5 | 6 | 7 | def get_mf_config(parser): 8 | # get the parameters 9 | parser.add_argument("--value_lr", type=float, default=3e-4) 10 | parser.add_argument("--value_epochs", type=int, default=20) 11 | parser.add_argument("--value_network_shape", type=str, default='64,64') 12 | # parser.add_argument("--value_batch_size", type=int, default=64) 13 | parser.add_argument("--value_activation_type", type=str, default='tanh') 14 | parser.add_argument("--value_normalizer_type", type=str, default='none') 15 | 16 | parser.add_argument("--trust_region_method", type=str, default='trpo', 17 | help='["ppo", "trpo"]') 18 | parser.add_argument("--gae_lam", type=float, default=0.95) 19 | parser.add_argument("--fisher_cg_damping", type=float, default=0.1) 20 | parser.add_argument("--target_kl", type=float, default=0.01) 21 | parser.add_argument("--cg_iterations", type=int, default=10) 22 | 23 | parser.add_argument("--ppo_clip", type=float, default=0.1) 24 | parser.add_argument("--num_minibatches", type=int, default=10) 25 | parser.add_argument("--target_kl_high", type=float, default=2) 26 | parser.add_argument("--target_kl_low", type=float, default=0.5) 27 | parser.add_argument("--use_weight_decay", type=int, default=0) 28 | parser.add_argument("--weight_decay_coeff", type=float, default=1e-5) 29 | 30 | parser.add_argument("--use_kl_penalty", type=int, default=0) 31 | parser.add_argument("--kl_alpha", type=float, default=1.5) 32 | parser.add_argument("--kl_eta", type=float, default=50) 33 | 34 | parser.add_argument("--policy_lr_schedule", type=str, default='linear', 35 | help='["linear", "constant", "adaptive"]') 36 | parser.add_argument("--policy_lr_alpha", type=int, default=2) 37 | return parser 38 | -------------------------------------------------------------------------------- /mbbl/config/rs_config.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # @brief: 3 | # record the parameters here 4 | # @author: 5 | # 
Tingwu Wang, 2017, June, 12th 6 | # ------------------------------------------------------------------------------ 7 | 8 | 9 | def get_rs_config(parser): 10 | # get the parameters 11 | parser.add_argument("--dynamics_val_percentage", type=float, default=0.33) 12 | parser.add_argument("--dynamics_val_max_size", type=int, default=5000) 13 | parser.add_argument("--num_planning_traj", type=int, default=10) 14 | parser.add_argument("--planning_depth", type=int, default=10) 15 | parser.add_argument("--check_done", type=int, default=0) 16 | 17 | return parser 18 | -------------------------------------------------------------------------------- /mbbl/env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/env/__init__.py -------------------------------------------------------------------------------- /mbbl/env/base_env_wrapper.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @author: 3 | # Tingwu Wang 4 | # @brief: 5 | # The environment wrapper 6 | # ----------------------------------------------------------------------------- 7 | import numpy as np 8 | 9 | 10 | class base_env(object): 11 | 12 | def __init__(self, env_name, rand_seed, misc_info={}): 13 | self._env_name = env_name 14 | self._seed = rand_seed 15 | self._npr = np.random.RandomState(self._seed) 16 | self._misc_info = misc_info 17 | 18 | # build the environment 19 | self._build_env() 20 | self._set_groundtruth_api() 21 | 22 | def step(self, action): 23 | raise NotImplementedError 24 | 25 | def reset(self, control_info={}): 26 | raise NotImplementedError 27 | 28 | def _build_env(self): 29 | raise NotImplementedError 30 | 31 | def _set_groundtruth_api(self): 32 | """ @brief: 33 | In this function, we could provide the ground-truth dynamics 34 | and rewards APIs for the agent to call. 
35 | For the new environments, if we don't set their ground-truth 36 | apis, then we cannot test the algorithm using ground-truth 37 | dynamics or reward 38 | """ 39 | 40 | def fdynamics(data_dict): 41 | raise NotImplementedError 42 | self.fdynamics = fdynamics 43 | 44 | def reward(data_dict): 45 | raise NotImplementedError 46 | self.reward = reward 47 | 48 | def reward_derivative(data_dict, target): 49 | raise NotImplementedError 50 | self.reward_derivative = reward_derivative 51 | -------------------------------------------------------------------------------- /mbbl/env/bullet_env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/env/bullet_env/__init__.py -------------------------------------------------------------------------------- /mbbl/env/bullet_env/motion_capture_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | class MotionCaptureData(object): 4 | def __init__(self): 5 | self.Reset() 6 | 7 | def Reset(self): 8 | self._motion_data = [] 9 | 10 | def Load(self, path): 11 | with open(path, 'r') as f: 12 | self._motion_data = json.load(f) 13 | 14 | def NumFrames(self): 15 | return len(self._motion_data['Frames']) 16 | 17 | def KeyFrameDuraction(self): 18 | return self._motion_data['Frames'][0][0] 19 | 20 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/env/dm_env/__init__.py -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/cheetah_pos.xml: -------------------------------------------------------------------------------- 1 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 94 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """Functions to manage the common assets for domains.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | from dm_control.utils import io as resources 24 | 25 | _SUITE_DIR = os.path.dirname(os.path.dirname(__file__)) 26 | _FILENAMES = [ 27 | "./common/materials.xml", 28 | "./common/skybox.xml", 29 | "./common/visual.xml", 30 | ] 31 | 32 | ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename)) 33 | for filename in _FILENAMES} 34 | 35 | 36 | def read_model(model_filename): 37 | """Reads a model XML file and returns its contents as a string.""" 38 | return resources.GetResource(os.path.join(_SUITE_DIR, model_filename)) 39 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/common/materials.xml: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/common/skybox.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/common/visual.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/reference/acrobot.xml: -------------------------------------------------------------------------------- 1 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/reference/ball_in_cup.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/reference/cartpole.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/reference/cheetah.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 74 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/reference/finger.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 
39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/reference/fish.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/reference/hopper.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 67 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/reference/lqr.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/reference/pendulum.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/reference/point_mass.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/reference/reacher.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/reference/swimmer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /mbbl/env/dm_env/assets/reference/walker.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 71 | -------------------------------------------------------------------------------- /mbbl/env/fake_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class fake_env(object): 5 | 6 | def 
__init__(self, env, model): 7 | self._env = env 8 | self._model = model 9 | self._state = None 10 | self._current_step = 0 11 | self._max_length = self._env._env_info['max_length'] 12 | self._obs_bounds = (-1e5, 1e5) 13 | 14 | def step(self, action): 15 | self._current_step += 1 16 | next_state, reward = self._model(self._state, action) 17 | next_state = np.clip(next_state, *self._obs_bounds) 18 | self._state = next_state 19 | 20 | if self._current_step > self._max_length: 21 | done = True 22 | else: 23 | done = False 24 | return next_state, reward, done, None 25 | 26 | def reset(self): 27 | self._current_step = 0 28 | ob, reward, done, info = self._env.reset() 29 | self._state = ob 30 | 31 | return ob, reward, done, info 32 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/env/gym_env/__init__.py -------------------------------------------------------------------------------- /mbbl/env/gym_env/box2d/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Aug 1 12:59:34 2018 5 | 6 | @author: matthewszhang 7 | """ 8 | 9 | from env.gym_env.box2d.core import box2d_make 10 | from env.gym_env.box2d.wrappers import LunarLanderWrapper, \ 11 | WalkerWrapper, RacerWrapper, Box2D_Wrapper -------------------------------------------------------------------------------- /mbbl/env/gym_env/box2d/core.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Aug 1 12:13:58 2018 5 | 6 | @author: matthewszhang 7 | """ 8 | from mbbl.env.gym_env.box2d.wrappers import LunarLanderWrapper, WalkerWrapper, RacerWrapper 9 | 10 | _WRAPPER_DICT = {'LunarLanderContinuous':LunarLanderWrapper, 11 | 'LunarLander':LunarLanderWrapper, 12 | 'BipedalWalker':WalkerWrapper, 13 | 'BipedalWalkerHardcore':WalkerWrapper, 14 | 'CarRacing':RacerWrapper} 15 | 16 | def get_wrapper(gym_id): 17 | try: 18 | return _WRAPPER_DICT[gym_id] 19 | except: 20 | raise KeyError("Non-existing Box2D env") 21 | 22 | def box2d_make(gym_id): # naive build of environment, leave safeties for gym.make 23 | import re 24 | remove_version = re.compile(r'-v(\d+)$') # version safety 25 | gym_id_base = remove_version.sub('', gym_id) 26 | 27 | wrapper = get_wrapper(gym_id_base) 28 | return wrapper(gym_id) 29 | 30 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/box2d_racer.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @author: 3 | # Matthew Zhang 4 | # @brief: 5 | # Alternate code-up of the box2d lunar lander environment from gym 6 | # ----------------------------------------------------------------------------- 7 | from mbbl.env import base_env_wrapper as bew 8 | import mbbl.env.gym_env.box2d.wrappers 9 | 10 | Racer = mbbl.env.gym_env.box2d.wrappers.RacerWrapper 11 | 12 | class env(bew.base_env): 13 | pass 14 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/fix_swimmer/__init__.py: -------------------------------------------------------------------------------- 1 | 
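# A minimal usage sketch for the registration performed in this package (illustrative,
# assuming gym and mujoco-py are installed and `mbbl` is on the PYTHONPATH; the import
# path mirrors the repository layout):
#
#     import gym
#     import mbbl.env.gym_env.fix_swimmer  # noqa: F401 -- runs the register() call below
#     env = gym.make('Fixswimmer-v1')      # resolves to fixedSwimmerEnv via the entry_point
#     ob = env.reset()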
from gym.envs.registration import register 2 | register( 3 | id='Fixswimmer-v1', 4 | entry_point='mbbl.env.gym_env.fix_swimmer.fixed_swimmer:fixedSwimmerEnv' 5 | ) 6 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/fix_swimmer/assets/fixed_swimmer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 44 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/fix_swimmer/fixed_swimmer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from gym.envs.mujoco import mujoco_env 4 | import os 5 | 6 | 7 | class fixedSwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): 8 | 9 | def __init__(self): 10 | dir_path = os.path.dirname(os.path.realpath(__file__)) 11 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/fixed_swimmer.xml' % dir_path, 4) 12 | utils.EzPickle.__init__(self) 13 | 14 | def _step(self, a): 15 | ctrl_cost_coeff = 0.0001 16 | 17 | """ 18 | xposbefore = self.model.data.qpos[0, 0] 19 | self.do_simulation(a, self.frame_skip) 20 | xposafter = self.model.data.qpos[0, 0] 21 | """ 22 | 23 | self.xposbefore = self.model.data.site_xpos[0][0] / self.dt 24 | self.do_simulation(a, self.frame_skip) 25 | self.xposafter = self.model.data.site_xpos[0][0] / self.dt 26 | self.pos_diff = self.xposafter - self.xposbefore 27 | 28 | reward_fwd = self.xposafter - self.xposbefore 29 | reward_ctrl = - ctrl_cost_coeff * np.square(a).sum() 30 | reward = reward_fwd + reward_ctrl 31 | ob = self._get_obs() 32 | return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl) 33 | 34 | def _get_obs(self): 35 | qpos = self.model.data.qpos 36 | qvel = self.model.data.qvel 37 | return np.concatenate([qpos.flat[2:], qvel.flat, self.pos_diff.flat]) 38 | 39 | def reset_model(self): 40 | self.set_state( 41 | self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq), 42 | self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv) 43 | ) 44 | return self._get_obs() 45 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/pets_env/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | 4 | register( 5 | id='MBRLCartpole-v0', 6 | entry_point='mbbl.env.gym_env.pets_env.cartpole:CartpoleEnv' 7 | ) 8 | 9 | 10 | register( 11 | id='MBRLReacher3D-v0', 12 | entry_point='mbbl.env.gym_env.pets_env.reacher:Reacher3DEnv' 13 | ) 14 | 15 | 16 | register( 17 | id='MBRLPusher-v0', 18 | entry_point='mbbl.env.gym_env.pets_env.pusher:PusherEnv' 19 | ) 20 | 21 | 22 | register( 23 | id='MBRLHalfCheetah-v0', 24 | entry_point='mbbl.env.gym_env.pets_env.half_cheetah:HalfCheetahEnv' 25 | ) 26 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/pets_env/assets/cartpole.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 35 | 36 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/pets_env/cartpole.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import os 6 | 7 | import numpy 
as np 8 | from gym import utils 9 | from gym.envs.mujoco import mujoco_env 10 | 11 | 12 | class CartpoleEnv(mujoco_env.MujocoEnv, utils.EzPickle): 13 | PENDULUM_LENGTH = 0.6 14 | 15 | def __init__(self): 16 | utils.EzPickle.__init__(self) 17 | dir_path = os.path.dirname(os.path.realpath(__file__)) 18 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/cartpole.xml' % dir_path, 2) 19 | 20 | def _step(self, a): 21 | self.do_simulation(a, self.frame_skip) 22 | ob = self._get_obs() 23 | 24 | cost_lscale = CartpoleEnv.PENDULUM_LENGTH 25 | reward = np.exp( 26 | -np.sum(np.square(self._get_ee_pos(ob) - np.array([0.0, CartpoleEnv.PENDULUM_LENGTH]))) / (cost_lscale ** 2) 27 | ) 28 | reward -= 0.01 * np.sum(np.square(a)) 29 | 30 | done = False 31 | return ob, reward, done, {} 32 | 33 | def reset_model(self): 34 | qpos = self.init_qpos + np.random.normal(0, 0.1, np.shape(self.init_qpos)) 35 | qvel = self.init_qvel + np.random.normal(0, 0.1, np.shape(self.init_qvel)) 36 | self.set_state(qpos, qvel) 37 | return self._get_obs() 38 | 39 | def _get_obs(self): 40 | return np.concatenate([self.model.data.qpos, self.model.data.qvel]).ravel() 41 | 42 | @staticmethod 43 | def _get_ee_pos(x): 44 | x0, theta = x[0], x[1] 45 | return np.array([ 46 | x0 - CartpoleEnv.PENDULUM_LENGTH * np.sin(theta), 47 | -CartpoleEnv.PENDULUM_LENGTH * np.cos(theta) 48 | ]) 49 | 50 | def viewer_setup(self): 51 | v = self.viewer 52 | v.cam.trackbodyid = 0 53 | v.cam.distance = v.model.stat.extent 54 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/pets_env/half_cheetah.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import os 6 | 7 | import numpy as np 8 | from gym import utils 9 | from gym.envs.mujoco import mujoco_env 10 | 11 | 12 | class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle): 13 | 14 | def __init__(self): 15 | self.prev_qpos = None 16 | dir_path = os.path.dirname(os.path.realpath(__file__)) 17 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/half_cheetah.xml' % dir_path, 5) 18 | utils.EzPickle.__init__(self) 19 | 20 | def _step(self, action): 21 | self.prev_qpos = np.copy(self.model.data.qpos.flat) 22 | self.do_simulation(action, self.frame_skip) 23 | ob = self._get_obs() 24 | 25 | reward_ctrl = -0.1 * np.square(action).sum() 26 | reward_run = ob[0] - 0.0 * np.square(ob[2]) 27 | reward = reward_run + reward_ctrl 28 | 29 | done = False 30 | return ob, reward, done, {} 31 | 32 | def _get_obs(self): 33 | return np.concatenate([ 34 | (self.model.data.qpos.flat[:1] - self.prev_qpos[:1]) / self.dt, 35 | self.model.data.qpos.flat[1:], 36 | self.model.data.qvel.flat, 37 | ]) 38 | 39 | def reset_model(self): 40 | qpos = self.init_qpos + np.random.normal(loc=0, scale=0.001, size=self.model.nq) 41 | qvel = self.init_qvel + np.random.normal(loc=0, scale=0.001, size=self.model.nv) 42 | self.set_state(qpos, qvel) 43 | self.prev_qpos = np.copy(self.model.data.qpos.flat) 44 | return self._get_obs() 45 | 46 | def viewer_setup(self): 47 | self.viewer.cam.distance = self.model.stat.extent * 0.25 48 | self.viewer.cam.elevation = -55 49 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/pets_env/half_cheetah_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 
3 | from __future__ import absolute_import 4 | 5 | import numpy as np 6 | import gym 7 | 8 | 9 | class HalfCheetahConfigModule: 10 | ENV_NAME = "MBRLHalfCheetah-v0" 11 | TASK_HORIZON = 1000 12 | NTRAIN_ITERS = 300 13 | NROLLOUTS_PER_ITER = 1 14 | PLAN_HOR = 30 15 | INIT_VAR = 0.25 16 | MODEL_IN, MODEL_OUT = 24, 18 # obs - > 18, action 6 17 | GP_NINDUCING_POINTS = 300 18 | 19 | def __init__(self): 20 | self.ENV = gym.make(self.ENV_NAME) 21 | self.NN_TRAIN_CFG = {"epochs": 5} 22 | self.OPT_CFG = { 23 | "Random": { 24 | "popsize": 2500 25 | }, 26 | "GBPRandom": { 27 | "popsize": 2500 28 | }, 29 | "GBPCEM": { 30 | "popsize": 500, 31 | "num_elites": 50, 32 | "max_iters": 5, 33 | "alpha": 0.1 34 | }, 35 | "CEM": { 36 | "popsize": 500, 37 | "num_elites": 50, 38 | "max_iters": 5, 39 | "alpha": 0.1 40 | }, 41 | "PWCEM": { 42 | "popsize": 500, 43 | "num_elites": 50, 44 | "max_iters": 5, 45 | "alpha": 0.1 46 | }, 47 | "POCEM": { 48 | "popsize": 500, 49 | "num_elites": 50, 50 | "max_iters": 5, 51 | "alpha": 0.1 52 | } 53 | } 54 | 55 | @staticmethod 56 | def obs_preproc(obs): 57 | return np.concatenate([obs[:, 1:2], np.sin(obs[:, 2:3]), np.cos(obs[:, 2:3]), obs[:, 3:]], axis=1) 58 | 59 | @staticmethod 60 | def obs_postproc(obs, pred): 61 | return np.concatenate([pred[:, :1], obs[:, 1:] + pred[:, 1:]], axis=1) 62 | 63 | @staticmethod 64 | def targ_proc(obs, next_obs): 65 | return np.concatenate([next_obs[:, :1], next_obs[:, 1:] - obs[:, 1:]], axis=1) 66 | 67 | @staticmethod 68 | def obs_cost_fn(obs): 69 | return -obs[:, 0] 70 | 71 | @staticmethod 72 | def ac_cost_fn(acs): 73 | return 0.1 * np.sum(np.square(acs), axis=1) 74 | 75 | def nn_constructor(self, model_init_cfg, misc=None): 76 | pass 77 | 78 | def gp_constructor(self, model_init_cfg): 79 | pass 80 | 81 | 82 | CONFIG_MODULE = HalfCheetahConfigModule 83 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/pets_env/pusher.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import os 6 | 7 | import numpy as np 8 | from gym import utils 9 | from gym.envs.mujoco import mujoco_env 10 | 11 | 12 | class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle): 13 | 14 | def __init__(self): 15 | dir_path = os.path.dirname(os.path.realpath(__file__)) 16 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/pusher.xml' % dir_path, 4) 17 | utils.EzPickle.__init__(self) 18 | self.reset_model() 19 | 20 | def _step(self, a): 21 | obj_pos = self.get_body_com("object"), 22 | vec_1 = obj_pos - self.get_body_com("tips_arm") 23 | vec_2 = obj_pos - self.get_body_com("goal") 24 | 25 | reward_near = -np.sum(np.abs(vec_1)) 26 | reward_dist = -np.sum(np.abs(vec_2)) 27 | reward_ctrl = -np.square(a).sum() 28 | reward = 1.25 * reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near 29 | 30 | self.do_simulation(a, self.frame_skip) 31 | ob = self._get_obs() 32 | done = False 33 | return ob, reward, done, {} 34 | 35 | def viewer_setup(self): 36 | self.viewer.cam.trackbodyid = -1 37 | self.viewer.cam.distance = 4.0 38 | 39 | def reset_model(self): 40 | qpos = self.init_qpos 41 | 42 | self.goal_pos = np.asarray([0, 0]) 43 | self.cylinder_pos = np.array([-0.25, 0.15]) + np.random.normal(0, 0.025, [2]) 44 | 45 | qpos[-4:-2] = self.cylinder_pos 46 | qpos[-2:] = self.goal_pos 47 | qvel = self.init_qvel + \ 48 | self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv) 49 | qvel[-4:] = 
0 50 | self.set_state(qpos, qvel) 51 | self.ac_goal_pos = self.get_body_com("goal") 52 | 53 | return self._get_obs() 54 | 55 | def _get_obs(self): 56 | return np.concatenate([ 57 | self.model.data.qpos.flat[:7], 58 | self.model.data.qvel.flat[:7], 59 | self.get_body_com("tips_arm"), 60 | self.get_body_com("object"), 61 | self.get_body_com("goal"), 62 | ]) 63 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/pets_env/pusher_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import numpy as np 6 | import gym 7 | 8 | 9 | class PusherConfigModule: 10 | ENV_NAME = "MBRLPusher-v0" 11 | TASK_HORIZON = 150 12 | NTRAIN_ITERS = 100 13 | NROLLOUTS_PER_ITER = 1 14 | PLAN_HOR = 25 15 | INIT_VAR = 0.25 16 | MODEL_IN, MODEL_OUT = 27, 20 17 | GP_NINDUCING_POINTS = 200 18 | 19 | def __init__(self): 20 | self.ENV = gym.make(self.ENV_NAME) 21 | self.NN_TRAIN_CFG = {"epochs": 5} 22 | self.OPT_CFG = { 23 | "Random": { 24 | "popsize": 2500 25 | }, 26 | "CEM": { 27 | "popsize": 500, 28 | "num_elites": 50, 29 | "max_iters": 5, 30 | "alpha": 0.1 31 | }, 32 | "GBPRandom": { 33 | "popsize": 2500 34 | }, 35 | "GBPCEM": { 36 | "popsize": 500, 37 | "num_elites": 50, 38 | "max_iters": 5, 39 | "alpha": 0.1 40 | }, 41 | "PWCEM": { 42 | "popsize": 500, 43 | "num_elites": 50, 44 | "max_iters": 5, 45 | "alpha": 0.1 46 | }, 47 | "POCEM": { 48 | "popsize": 500, 49 | "num_elites": 50, 50 | "max_iters": 5, 51 | "alpha": 0.1 52 | } 53 | } 54 | 55 | @staticmethod 56 | def obs_postproc(obs, pred): 57 | return obs + pred 58 | 59 | @staticmethod 60 | def targ_proc(obs, next_obs): 61 | return next_obs - obs 62 | 63 | def obs_cost_fn(self, obs): 64 | to_w, og_w = 0.5, 1.25 65 | tip_pos, obj_pos, goal_pos = obs[:, 14:17], obs[:, 17:20], obs[:, -3:] 66 | 67 | tip_obj_dist = np.sum(np.abs(tip_pos - obj_pos), axis=1) 68 | obj_goal_dist = np.sum(np.abs(goal_pos - obj_pos), axis=1) 69 | return to_w * tip_obj_dist + og_w * obj_goal_dist 70 | 71 | @staticmethod 72 | def ac_cost_fn(acs): 73 | return 0.1 * np.sum(np.square(acs), axis=1) 74 | 75 | def nn_constructor(self, model_init_cfg, misc): 76 | pass 77 | 78 | def gp_constructor(self, model_init_cfg): 79 | pass 80 | 81 | 82 | CONFIG_MODULE = PusherConfigModule 83 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/pets_env/reacher.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import os 6 | 7 | import numpy as np 8 | from gym import utils 9 | from gym.envs.mujoco import mujoco_env 10 | 11 | 12 | class Reacher3DEnv(mujoco_env.MujocoEnv, utils.EzPickle): 13 | 14 | def __init__(self): 15 | self.viewer = None 16 | utils.EzPickle.__init__(self) 17 | dir_path = os.path.dirname(os.path.realpath(__file__)) 18 | self.goal = np.zeros(3) 19 | mujoco_env.MujocoEnv.__init__(self, os.path.join(dir_path, 'assets/reacher3d.xml'), 2) 20 | 21 | def _step(self, a): 22 | self.do_simulation(a, self.frame_skip) 23 | ob = self._get_obs() 24 | reward = -np.sum(np.square(self.get_EE_pos(ob[None]) - self.goal)) 25 | reward -= 0.01 * np.square(a).sum() 26 | done = False 27 | return ob, reward, done, dict(reward_dist=0, reward_ctrl=0) 28 | 29 | def viewer_setup(self): 30 | 
self.viewer.cam.trackbodyid = 1 31 | self.viewer.cam.distance = 2.5 32 | self.viewer.cam.elevation = -30 33 | self.viewer.cam.azimuth = 270 34 | 35 | def reset_model(self): 36 | qpos, qvel = np.copy(self.init_qpos), np.copy(self.init_qvel) 37 | qpos[-3:] += np.random.normal(loc=0, scale=0.1, size=[3]) 38 | qvel[-3:] = 0 39 | self.goal = qpos[-3:] 40 | self.set_state(qpos, qvel) 41 | return self._get_obs() 42 | 43 | def _get_obs(self): 44 | raw_obs = np.concatenate([ 45 | self.model.data.qpos.flat, self.model.data.qvel.flat[:-3], 46 | ]) 47 | 48 | EE_pos = np.reshape(self.get_EE_pos(raw_obs[None]), [-1]) 49 | 50 | return np.concatenate([raw_obs, EE_pos]) 51 | 52 | def get_EE_pos(self, states): 53 | theta1, theta2, theta3, theta4, theta5, theta6, theta7 = \ 54 | states[:, :1], states[:, 1:2], states[:, 2:3], states[:, 3:4], states[:, 4:5], states[:, 5:6], states[:, 6:] 55 | 56 | rot_axis = np.concatenate([np.cos(theta2) * np.cos(theta1), np.cos(theta2) * np.sin(theta1), -np.sin(theta2)], 57 | axis=1) 58 | rot_perp_axis = np.concatenate([-np.sin(theta1), np.cos(theta1), np.zeros(theta1.shape)], axis=1) 59 | cur_end = np.concatenate([ 60 | 0.1 * np.cos(theta1) + 0.4 * np.cos(theta1) * np.cos(theta2), 61 | 0.1 * np.sin(theta1) + 0.4 * np.sin(theta1) * np.cos(theta2) - 0.188, 62 | -0.4 * np.sin(theta2) 63 | ], axis=1) 64 | 65 | for length, hinge, roll in [(0.321, theta4, theta3), (0.16828, theta6, theta5)]: 66 | perp_all_axis = np.cross(rot_axis, rot_perp_axis) 67 | x = np.cos(hinge) * rot_axis 68 | y = np.sin(hinge) * np.sin(roll) * rot_perp_axis 69 | z = -np.sin(hinge) * np.cos(roll) * perp_all_axis 70 | new_rot_axis = x + y + z 71 | new_rot_perp_axis = np.cross(new_rot_axis, rot_axis) 72 | new_rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] = \ 73 | rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] 74 | new_rot_perp_axis /= np.linalg.norm(new_rot_perp_axis, axis=1, keepdims=True) 75 | rot_axis, rot_perp_axis, cur_end = new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis 76 | 77 | return cur_end 78 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/pets_env/reacher_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import numpy as np 6 | import gym 7 | 8 | 9 | class ReacherConfigModule: 10 | ENV_NAME = "MBRLReacher3D-v0" 11 | TASK_HORIZON = 150 12 | NTRAIN_ITERS = 100 13 | NROLLOUTS_PER_ITER = 1 14 | PLAN_HOR = 25 15 | INIT_VAR = 0.25 16 | MODEL_IN, MODEL_OUT = 24, 17 17 | GP_NINDUCING_POINTS = 200 18 | 19 | def __init__(self): 20 | self.ENV = gym.make(self.ENV_NAME) 21 | self.ENV.reset() 22 | self.NN_TRAIN_CFG = {"epochs": 5} 23 | self.OPT_CFG = { 24 | "Random": { 25 | "popsize": 2000 26 | }, 27 | "CEM": { 28 | "popsize": 400, 29 | "num_elites": 40, 30 | "max_iters": 5, 31 | "alpha": 0.1 32 | }, 33 | "GBPRandom": { 34 | "popsize": 2000 35 | }, 36 | "GBPCEM": { 37 | "popsize": 400, 38 | "num_elites": 40, 39 | "max_iters": 5, 40 | "alpha": 0.1 41 | }, 42 | "PWCEM": { 43 | "popsize": 400, 44 | "num_elites": 40, 45 | "max_iters": 5, 46 | "alpha": 0.1 47 | }, 48 | "POCEM": { 49 | "popsize": 400, 50 | "num_elites": 40, 51 | "max_iters": 5, 52 | "alpha": 0.1 53 | } 54 | } 55 | self.UPDATE_FNS = [self.update_goal] 56 | 57 | @staticmethod 58 | def obs_postproc(obs, pred): 59 | return obs + pred 60 | 61 | @staticmethod 62 | def targ_proc(obs, 
next_obs): 63 | return next_obs - obs 64 | 65 | def update_goal(self, sess=None): 66 | if sess is not None: 67 | self.goal.load(self.ENV.goal, sess) 68 | 69 | def obs_cost_fn(self, obs): 70 | self.ENV.goal = obs[:, 7: 10] 71 | ee_pos = obs[:, -3:] 72 | return np.sum(np.square(ee_pos - self.ENV.goal), axis=1) 73 | 74 | @staticmethod 75 | def ac_cost_fn(acs): 76 | return 0.01 * np.sum(np.square(acs), axis=1) 77 | 78 | def nn_constructor(self, model_init_cfg, misc=None): 79 | pass 80 | 81 | def gp_constructor(self, model_init_cfg): 82 | pass 83 | 84 | @staticmethod 85 | def get_ee_pos(states, are_tensors=False): 86 | theta1, theta2, theta3, theta4, theta5, theta6, theta7 = \ 87 | states[:, :1], states[:, 1:2], states[:, 2:3], states[:, 3:4], states[:, 4:5], states[:, 5:6], states[:, 6:] 88 | rot_axis = np.concatenate([np.cos(theta2) * np.cos(theta1), np.cos(theta2) * np.sin(theta1), -np.sin(theta2)], 89 | axis=1) 90 | rot_perp_axis = np.concatenate([-np.sin(theta1), np.cos(theta1), np.zeros(theta1.shape)], axis=1) 91 | cur_end = np.concatenate([ 92 | 0.1 * np.cos(theta1) + 0.4 * np.cos(theta1) * np.cos(theta2), 93 | 0.1 * np.sin(theta1) + 0.4 * np.sin(theta1) * np.cos(theta2) - 0.188, 94 | -0.4 * np.sin(theta2) 95 | ], axis=1) 96 | 97 | for length, hinge, roll in [(0.321, theta4, theta3), (0.16828, theta6, theta5)]: 98 | perp_all_axis = np.cross(rot_axis, rot_perp_axis) 99 | x = np.cos(hinge) * rot_axis 100 | y = np.sin(hinge) * np.sin(roll) * rot_perp_axis 101 | z = -np.sin(hinge) * np.cos(roll) * perp_all_axis 102 | new_rot_axis = x + y + z 103 | new_rot_perp_axis = np.cross(new_rot_axis, rot_axis) 104 | new_rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] = \ 105 | rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] 106 | new_rot_perp_axis /= np.linalg.norm(new_rot_perp_axis, axis=1, keepdims=True) 107 | rot_axis, rot_perp_axis, cur_end = new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis 108 | 109 | return cur_end 110 | 111 | 112 | CONFIG_MODULE = ReacherConfigModule 113 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/point.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import gym 4 | import numpy as np 5 | 6 | from mbbl.config import init_path 7 | from mbbl.env import base_env_wrapper as bew 8 | from mbbl.env import env_register 9 | from mbbl.env.gym_env import point_env 10 | 11 | 12 | class env(bew.base_env): 13 | 14 | POINT = ['gym_point'] 15 | 16 | def __init__(self, env_name, rand_seed, misc_info): 17 | 18 | super(env, self).__init__(env_name, rand_seed, misc_info) 19 | self._base_path = init_path.get_abs_base_dir() 20 | 21 | 22 | def step(self, action): 23 | 24 | ob, reward, _, info = self._env.step(action) 25 | 26 | self._current_step += 1 27 | if self._current_step > self._env_info['max_length']: 28 | done = True 29 | else: 30 | done = False # will raise warnings -> set logger flag to ignore 31 | self._old_ob = np.array(ob) 32 | return ob, reward, done, info 33 | 34 | def reset(self): 35 | self._current_step = 0 36 | self._old_ob = self._env.reset() 37 | return self._old_ob, 0.0, False, {} 38 | 39 | def _build_env(self): 40 | _env_name = { 41 | 'gym_point': 'Point-v0', 42 | } 43 | 44 | # make the environments 45 | self._env = gym.make(_env_name[self._env_name]) 46 | self._env_info = env_register.get_env_info(self._env_name) 47 | 48 | def _set_dynamics_api(self): 49 | 50 | def 
fdynamics(data_dict): 51 | self._env.env.state = data_dict['start_state'] 52 | action = data_dict['action'] 53 | return self._env.step(action)[0] 54 | self.fdynamics = fdynamics 55 | 56 | def _set_reward_api(self): 57 | 58 | def reward(data_dict): 59 | action = np.clip(data_dict['action'], -0.025, 0.025) 60 | state = np.clip(data_dict['start_state'] + action, -1, 1) 61 | return -np.linalg.norm(state) 62 | 63 | self.reward = reward 64 | 65 | def _set_groundtruth_api(self): 66 | self._set_dynamics_api() 67 | self._set_reward_api() 68 | 69 | -------------------------------------------------------------------------------- /mbbl/env/gym_env/point_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | This environment is adapted from the Point environment developed by Rocky Duan, Peter Chen, Pieter Abbeel for the Berkeley Deep RL Bootcamp, August 2017. Bootcamp website with slides and lecture videos: https://sites.google.com/view/deep-rl-bootcamp/. 3 | """ 4 | 5 | from gym import Env 6 | from gym import spaces 7 | from gym.envs.registration import register 8 | from gym.utils import seeding 9 | import numpy as np 10 | 11 | 12 | class PointEnv(Env): 13 | metadata = { 14 | 'render.modes': ['human', 'rgb_array'], 15 | 'video.frames_per_second': 50 16 | } 17 | 18 | def __init__(self): 19 | self.action_space = spaces.Box(low=-1, high=1, shape=(2,)) 20 | self.observation_space = spaces.Box(low=-1, high=1, shape=(2,)) 21 | 22 | self._seed() 23 | self.viewer = None 24 | self.state = None 25 | 26 | def _seed(self, seed=None): 27 | self.np_random, seed = seeding.np_random(seed) 28 | return [seed] 29 | 30 | def _step(self, action): 31 | action = np.clip(action, -0.025, 0.025) 32 | self.state = np.clip(self.state + action, -1, 1) 33 | return np.array(self.state), -np.linalg.norm(self.state), False, {} 34 | 35 | def _reset(self): 36 | while True: 37 | self.state = self.np_random.uniform(low=-1, high=1, size=(2,)) 38 | # Sample states that are far away 39 | if np.linalg.norm(self.state) > 0.9: 40 | break 41 | return np.array(self.state) 42 | 43 | def _render(self, mode='human', close=False): 44 | if close: 45 | if self.viewer is not None: 46 | self.viewer.close() 47 | self.viewer = None 48 | return 49 | 50 | screen_width = 800 51 | screen_height = 800 52 | 53 | if self.viewer is None: 54 | from gym.envs.classic_control import rendering 55 | self.viewer = rendering.Viewer(screen_width, screen_height) 56 | 57 | agent = rendering.make_circle( 58 | min(screen_height, screen_width) * 0.03) 59 | origin = rendering.make_circle( 60 | min(screen_height, screen_width) * 0.03) 61 | trans = rendering.Transform(translation=(0, 0)) 62 | agent.add_attr(trans) 63 | self.trans = trans 64 | agent.set_color(1, 0, 0) 65 | origin.set_color(0, 0, 0) 66 | origin.add_attr(rendering.Transform( 67 | translation=(screen_width // 2, screen_height // 2))) 68 | self.viewer.add_geom(agent) 69 | self.viewer.add_geom(origin) 70 | 71 | # self.trans.set_translation(0, 0) 72 | self.trans.set_translation( 73 | (self.state[0] + 1) / 2 * screen_width, 74 | (self.state[1] + 1) / 2 * screen_height, 75 | ) 76 | 77 | return self.viewer.render(return_rgb_array=mode == 'rgb_array') 78 | 79 | register( 80 | 'Point-v0', 81 | entry_point='env.gym_env.point_env:PointEnv', 82 | timestep_limit=40, 83 | ) 84 | -------------------------------------------------------------------------------- /mbbl/env/render.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 
2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Aug 8 14:55:24 2018 5 | 6 | @author: matthewszhang 7 | """ 8 | import pickle 9 | import time 10 | 11 | import numpy as np 12 | 13 | from mbbl.env.env_register import make_env 14 | 15 | class rendering(object): 16 | def __init__(self, env_name): 17 | self.env, _ = make_env(env_name, 1234) 18 | self.env.reset() 19 | 20 | def render(self, transition): 21 | self.env.fdynamics(transition) 22 | self.env._env.render() 23 | time.sleep(1/30) 24 | 25 | def get_rendering_config(): 26 | import argparse 27 | parser = argparse.ArgumentParser(description='Get rendering from states') 28 | 29 | parser.add_argument("--task", type=str, default='gym_cheetah', 30 | help='the mujoco environment to test') 31 | parser.add_argument("--render_file", type=str, default='ep-0', 32 | help='pickle outputs to render') 33 | return parser 34 | 35 | def main(): 36 | parser = get_rendering_config() 37 | args = parser.parse_args() 38 | 39 | render_env = rendering(args.task) 40 | with open(args.render_file, 'rb') as pickle_load: 41 | transition_data = pickle.load(pickle_load) 42 | 43 | for transition in transition_data: 44 | render_transition = { 45 | 'start_state':np.asarray(transition['start_state']), 46 | 'action':np.asarray(transition['action']), 47 | } 48 | render_env.render(render_transition) 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /mbbl/env/render_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Aug 8 12:03:11 2018 5 | 6 | @author: matthewszhang 7 | """ 8 | import os.path as osp 9 | import pickle 10 | import re 11 | 12 | from mbbl.util.common import logger 13 | 14 | RENDER_EPISODE = 100 15 | 16 | class render_wrapper(object): 17 | def __init__(self, env_name, *args, **kwargs): 18 | remove_render = re.compile(r'__render$') 19 | 20 | self.env_name = remove_render.sub('', env_name) 21 | from mbbl.env import env_register 22 | self.env, _ = env_register.make_env(self.env_name, *args, **kwargs) 23 | self.episode_number = 0 24 | 25 | # Getting path from logger 26 | self.path = logger._get_path() 27 | self.obs_buffer = [] 28 | 29 | def step(self, action, *args, **kwargs): 30 | if (self.episode_number - 1) % RENDER_EPISODE == 0: 31 | self.obs_buffer.append({ 32 | 'start_state':self.env._old_ob.tolist(), 33 | 'action':action.tolist() 34 | }) 35 | return self.env.step(action, *args, **kwargs) 36 | 37 | def reset(self, *args, **kwargs): 38 | self.episode_number += 1 39 | if self.obs_buffer: 40 | file_name = osp.join(self.path, 'ep_{}.p'.format(self.episode_number)) 41 | with open(file_name, 'wb') as pickle_file: 42 | pickle.dump(self.obs_buffer, pickle_file, protocol=pickle.HIGHEST_PROTOCOL) 43 | self.obs_buffer = [] 44 | 45 | return self.env.reset(*args, **kwargs) 46 | 47 | def fdynamics(self, *args, **kwargs): 48 | return self.env.fdynamics(*args, **kwargs) 49 | 50 | def reward(self, *args, **kwargs): 51 | return self.env.reward(*args, **kwargs) 52 | 53 | def reward_derivative(self, *args, **kwargs): 54 | return self.env.reward_derivative(*args, **kwargs) 55 | -------------------------------------------------------------------------------- /mbbl/network/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/network/__init__.py -------------------------------------------------------------------------------- /mbbl/network/dynamics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/network/dynamics/__init__.py -------------------------------------------------------------------------------- /mbbl/network/dynamics/stochastic_forward_dynamics.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @author: 3 | # Tingwu Wang 4 | # @brief: 5 | # ----------------------------------------------------------------------------- 6 | from .base_dynamics import base_dynamics_network 7 | from mbbl.config import init_path 8 | from mbbl.util.common import logger 9 | from mbbl.util.common import tf_networks 10 | 11 | 12 | class dynamics_network(base_dynamics_network): 13 | ''' 14 | @brief: 15 | ''' 16 | 17 | def __init__(self, args, session, name_scope, 18 | observation_size, action_size): 19 | ''' 20 | @input: 21 | @ob_placeholder: 22 | if this placeholder is not given, we will make one in this 23 | class. 24 | 25 | @trainable: 26 | If it is set to true, then the policy weights will be 27 | trained. It is useful when the class is a subnet which 28 | is not trainable 29 | ''' 30 | raise NotImplementedError 31 | super(dynamics_network, self).__init__( 32 | args, session, name_scope, observation_size, action_size 33 | ) 34 | self._base_dir = init_path.get_abs_base_dir() 35 | self._debug_it = 0 36 | -------------------------------------------------------------------------------- /mbbl/network/policy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/network/policy/__init__.py -------------------------------------------------------------------------------- /mbbl/network/policy/base_policy.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @author: 3 | # Tingwu Wang 4 | # ----------------------------------------------------------------------------- 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from mbbl.config import init_path 9 | from mbbl.util.common import whitening_util 10 | from mbbl.util.common import tf_utils 11 | 12 | 13 | class base_policy_network(object): 14 | ''' 15 | @brief: 16 | In this object class, we define the network structure, the restore 17 | function and save function. 
18 | It will only be called in the agent/agent.py 19 | ''' 20 | 21 | def __init__(self, args, session, name_scope, 22 | observation_size, action_size): 23 | self.args = args 24 | 25 | self._session = session 26 | self._name_scope = name_scope 27 | 28 | self._observation_size = observation_size 29 | self._action_size = action_size 30 | 31 | self._task_name = args.task_name 32 | self._network_shape = args.policy_network_shape 33 | 34 | self._npr = np.random.RandomState(args.seed) 35 | 36 | self._whitening_operator = {} 37 | self._whitening_variable = [] 38 | self._base_dir = init_path.get_abs_base_dir() 39 | 40 | def build_network(self): 41 | raise NotImplementedError 42 | 43 | def build_loss(self): 44 | raise NotImplementedError 45 | 46 | def _build_ph(self): 47 | 48 | # initialize the running mean and std (whitening) 49 | whitening_util.add_whitening_operator( 50 | self._whitening_operator, self._whitening_variable, 51 | 'state', self._observation_size 52 | ) 53 | 54 | # initialize the input placeholder 55 | self._input_ph = { 56 | 'start_state': tf.placeholder( 57 | tf.float32, [None, self._observation_size], name='start_state' 58 | ) 59 | } 60 | 61 | def get_input_placeholder(self): 62 | return self._input_ph 63 | 64 | def get_weights(self): 65 | return None 66 | 67 | def set_weights(self, weights_dict): 68 | pass 69 | 70 | def _set_var_list(self): 71 | # collect the tf variable and the trainable tf variable 72 | self._trainable_var_list = [var for var in tf.trainable_variables() 73 | if self._name_scope in var.name] 74 | 75 | self._all_var_list = [var for var in tf.global_variables() 76 | if self._name_scope in var.name] 77 | 78 | # the weights that actually matter 79 | self._network_var_list = \ 80 | self._trainable_var_list + self._whitening_variable 81 | 82 | self._set_network_weights = tf_utils.set_network_weights( 83 | self._session, self._network_var_list, self._name_scope 84 | ) 85 | 86 | self._get_network_weights = tf_utils.get_network_weights( 87 | self._session, self._network_var_list, self._name_scope 88 | ) 89 | 90 | def load_checkpoint(self, ckpt_path): 91 | pass 92 | 93 | def save_checkpoint(self, ckpt_path): 94 | pass 95 | 96 | def get_whitening_operator(self): 97 | return self._whitening_operator 98 | 99 | def _set_whitening_var(self, whitening_stats): 100 | whitening_util.set_whitening_var( 101 | self._session, self._whitening_operator, whitening_stats, ['state'] 102 | ) 103 | 104 | def train(self, data_dict, replay_buffer, training_info={}): 105 | raise NotImplementedError 106 | 107 | def eval(self, data_dict): 108 | raise NotImplementedError 109 | 110 | def act(self, data_dict): 111 | raise NotImplementedError 112 | -------------------------------------------------------------------------------- /mbbl/network/policy/cem_policy.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @author: 3 | # Tingwu Wang 4 | # ----------------------------------------------------------------------------- 5 | from .base_policy import base_policy_network 6 | from mbbl.config import init_path 7 | 8 | 9 | class policy_network(base_policy_network): 10 | ''' 11 | @brief: 12 | In this object class, we define the network structure, the restore 13 | function and save function. 
14 | It will only be called in the agent/agent.py 15 | ''' 16 | 17 | def __init__(self, args, session, name_scope, 18 | observation_size, action_size): 19 | 20 | super(policy_network, self).__init__( 21 | args, session, name_scope, observation_size, action_size 22 | ) 23 | self._base_dir = init_path.get_abs_base_dir() 24 | 25 | def build_network(self): 26 | pass 27 | 28 | def build_loss(self): 29 | pass 30 | 31 | def train(self, data_dict, replay_buffer, training_info={}): 32 | pass 33 | 34 | def act(self, data_dict): 35 | raise NotImplementedError 36 | -------------------------------------------------------------------------------- /mbbl/network/policy/random_policy.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @author: 3 | # Tingwu Wang 4 | # ----------------------------------------------------------------------------- 5 | from .base_policy import base_policy_network 6 | from mbbl.config import init_path 7 | 8 | 9 | class policy_network(base_policy_network): 10 | ''' 11 | @brief: 12 | In this object class, we define the network structure, the restore 13 | function and save function. 14 | It will only be called in the agent/agent.py 15 | ''' 16 | 17 | def __init__(self, args, session, name_scope, 18 | observation_size, action_size): 19 | 20 | super(policy_network, self).__init__( 21 | args, session, name_scope, observation_size, action_size 22 | ) 23 | self._base_dir = init_path.get_abs_base_dir() 24 | 25 | def build_network(self): 26 | pass 27 | 28 | def build_loss(self): 29 | pass 30 | 31 | def train(self, data_dict, replay_buffer, training_info={}): 32 | pass 33 | 34 | def act(self, data_dict): 35 | # action range from -1 to 1 36 | action = self._npr.uniform( 37 | -1, 1, [len(data_dict['start_state']), self._action_size] 38 | ) 39 | return action, -1, -1 40 | -------------------------------------------------------------------------------- /mbbl/network/reward/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/network/reward/__init__.py -------------------------------------------------------------------------------- /mbbl/network/reward/base_reward.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @author: 3 | # Tingwu Wang 4 | # ----------------------------------------------------------------------------- 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | from mbbl.config import init_path 9 | from mbbl.util.common import whitening_util 10 | 11 | 12 | class base_reward_network(object): 13 | ''' 14 | @brief: 15 | ''' 16 | 17 | def __init__(self, args, session, name_scope, 18 | observation_size, action_size): 19 | self.args = args 20 | 21 | self._session = session 22 | self._name_scope = name_scope 23 | 24 | self._observation_size = observation_size 25 | self._action_size = action_size 26 | self._output_size = 1 27 | 28 | self._task_name = args.task_name 29 | self._network_shape = args.reward_network_shape 30 | 31 | self._npr = np.random.RandomState(args.seed) 32 | 33 | self._whitening_operator = {} 34 | self._whitening_variable = [] 35 | self._base_dir = init_path.get_abs_base_dir() 36 | 37 | def build_network(self): 38 | pass 39 | 40 | def build_loss(self): 41 | pass 42 | 43 | def 
get_weights(self): 44 | return None 45 | 46 | def set_weights(self, weights_dict): 47 | pass 48 | 49 | def _build_ph(self): 50 | 51 | # initialize the running mean and std (whitening) 52 | whitening_util.add_whitening_operator( 53 | self._whitening_operator, self._whitening_variable, 54 | 'state', self._observation_size 55 | ) 56 | 57 | # initialize the input placeholder 58 | self._input_ph = { 59 | 'start_state': tf.placeholder( 60 | tf.float32, [None, self._observation_size], name='start_state' 61 | ) 62 | } 63 | 64 | def get_input_placeholder(self): 65 | return self._input_ph 66 | 67 | def _set_var_list(self): 68 | # collect the tf variable and the trainable tf variable 69 | self._trainable_var_list = [var for var in tf.trainable_variables() 70 | if self._name_scope in var.name] 71 | 72 | self._all_var_list = [var for var in tf.global_variables() 73 | if self._name_scope in var.name] 74 | 75 | # the weights that actually matter 76 | self._network_var_list = \ 77 | self._trainable_var_list + self._whitening_variable 78 | 79 | def load_checkpoint(self, ckpt_path): 80 | pass 81 | 82 | def save_checkpoint(self, ckpt_path): 83 | pass 84 | 85 | def get_whitening_operator(self): 86 | return self._whitening_operator 87 | 88 | def _set_whitening_var(self, whitening_stats): 89 | whitening_util.set_whitening_var( 90 | self._session, self._whitening_operator, whitening_stats, ['state'] 91 | ) 92 | 93 | def train(self, data_dict, replay_buffer, training_info={}): 94 | pass 95 | 96 | def eval(self, data_dict): 97 | raise NotImplementedError 98 | 99 | def pred(self, data_dict): 100 | raise NotImplementedError 101 | 102 | def use_groundtruth_network(self): 103 | return False 104 | -------------------------------------------------------------------------------- /mbbl/network/reward/deepmimic_reward.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @author: 3 | # Tingwu Wang 4 | # @brief: 5 | # ----------------------------------------------------------------------------- 6 | 7 | from .base_reward import base_reward_network 8 | from mbbl.config import init_path 9 | from mbbl.util.il import expert_data_util 10 | # import numpy as np 11 | # from mbbl.util.common import tf_networks 12 | # from mbbl.util.common import tf_utils 13 | # from mbbl.util.common import logger 14 | # import tensorflow as tf 15 | 16 | 17 | class reward_network(base_reward_network): 18 | ''' 19 | @brief: 20 | ''' 21 | 22 | def __init__(self, args, session, name_scope, 23 | observation_size, action_size): 24 | ''' 25 | @input: 26 | @ob_placeholder: 27 | if this placeholder is not given, we will make one in this 28 | class. 29 | 30 | @trainable: 31 | If it is set to true, then the policy weights will be 32 | trained. 
It is useful when the class is a subnet which 33 | is not trainable 34 | ''' 35 | super(reward_network, self).__init__( 36 | args, session, name_scope, observation_size, action_size 37 | ) 38 | self._base_dir = init_path.get_abs_base_dir() 39 | 40 | # load the expert data 41 | self._expert_trajectory_obs = expert_data_util.load_expert_trajectory( 42 | self.args.expert_data_name, self.args.traj_episode_num 43 | ) 44 | 45 | def build_network(self): 46 | """ @Brief: 47 | in deepmimic, we don't need a neural network to produce reward 48 | """ 49 | pass 50 | 51 | def build_loss(self): 52 | """ @Brief: 53 | Similarly, in deepmimic, we don't need a neural network 54 | """ 55 | pass 56 | 57 | def train(self, data_dict, replay_buffer, training_info={}): 58 | """ @brief: 59 | """ 60 | return {} 61 | 62 | def eval(self, data_dict): 63 | pass 64 | 65 | def use_groundtruth_network(self): 66 | return False 67 | 68 | def generate_rewards(self, rollout_data): 69 | """@brief: 70 | This function should be called before _preprocess_data 71 | """ 72 | for path in rollout_data: 73 | # the predicted value function (baseline function) 74 | path["raw_rewards"] = path['rewards'] # preserve the raw reward 75 | 76 | # TODO: generate the r = r_task + r_imitation, see the paper, 77 | # use self._expert_trajectory_obs 78 | path["rewards"] = 0.0 * path['raw_rewards'] 79 | return rollout_data 80 | -------------------------------------------------------------------------------- /mbbl/network/reward/groundtruth_reward.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @author: 3 | # Tingwu Wang 4 | # @brief: 5 | # ----------------------------------------------------------------------------- 6 | import numpy as np 7 | 8 | from .base_reward import base_reward_network 9 | from mbbl.config import init_path 10 | from mbbl.env import env_register 11 | 12 | 13 | class reward_network(base_reward_network): 14 | ''' 15 | @brief: 16 | ''' 17 | 18 | def __init__(self, args, session, name_scope, 19 | observation_size, action_size): 20 | ''' 21 | @input: 22 | @ob_placeholder: 23 | if this placeholder is not given, we will make one in this 24 | class. 25 | 26 | @trainable: 27 | If it is set to true, then the policy weights will be 28 | trained. 
It is useful when the class is a subnet which 29 | is not trainable 30 | ''' 31 | super(reward_network, self).__init__( 32 | args, session, name_scope, observation_size, action_size 33 | ) 34 | self._base_dir = init_path.get_abs_base_dir() 35 | 36 | def build_network(self): 37 | self._env, self._env_info = env_register.make_env( 38 | self.args.task_name, self._npr.randint(0, 9999) 39 | ) 40 | 41 | def build_loss(self): 42 | pass 43 | 44 | def train(self, data_dict, replay_buffer, training_info={}): 45 | pass 46 | 47 | def eval(self, data_dict): 48 | pass 49 | 50 | def pred(self, data_dict): 51 | reward = [] 52 | for i_data in range(len(data_dict['action'])): 53 | key_list = ['start_state', 'action', 'next_state'] \ 54 | if 'next_state' in data_dict else ['start_state', 'action'] 55 | 56 | i_reward = self._env.reward( 57 | {key: data_dict[key][i_data] for key in key_list} 58 | ) 59 | reward.append(i_reward) 60 | return np.stack(reward), -1, -1 61 | 62 | def use_groundtruth_network(self): 63 | return True 64 | 65 | def get_derivative(self, data_dict, target): 66 | derivative_data = {} 67 | for derivative_key in target: 68 | derivative_data[derivative_key] = \ 69 | self._env.reward_derivative(data_dict, derivative_key) 70 | return derivative_data 71 | -------------------------------------------------------------------------------- /mbbl/sampler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/sampler/__init__.py -------------------------------------------------------------------------------- /mbbl/sampler/base_sampler.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # Tingwu Wang 3 | # ----------------------------------------------------------------------------- 4 | import multiprocessing 5 | 6 | import numpy as np 7 | 8 | from mbbl.config import init_path 9 | from mbbl.env import env_register 10 | from mbbl.util.common import parallel_util 11 | 12 | 13 | class base_sampler(object): 14 | 15 | def __init__(self, args, worker_type, network_type): 16 | ''' 17 | @brief: 18 | the master agent has several actors (or samplers) to do the 19 | sampling for it. 20 | ''' 21 | self.args = args 22 | self._npr = np.random.RandomState(args.seed + 23333) 23 | self._observation_size, self._action_size, _ = \ 24 | env_register.io_information(self.args.task) 25 | self._worker_type = worker_type 26 | self._network_type = network_type 27 | 28 | # init the multiprocess actors 29 | self._task_queue = multiprocessing.JoinableQueue() 30 | self._result_queue = multiprocessing.Queue() 31 | self._init_workers() 32 | self._build_env() 33 | self._base_path = init_path.get_abs_base_dir() 34 | 35 | self._current_iteration = 0 36 | 37 | def set_weights(self, weights): 38 | for i_agent in range(self.args.num_workers): 39 | self._task_queue.put((parallel_util.AGENT_SET_WEIGHTS, 40 | weights)) 41 | self._task_queue.join() 42 | 43 | def end(self): 44 | for i_agent in range(self.args.num_workers): 45 | self._task_queue.put((parallel_util.END_ROLLOUT_SIGNAL, None)) 46 | 47 | def rollouts_using_worker_planning(self, num_timesteps=None, 48 | use_random_action=False): 49 | """ @brief: 50 | Workers are only used to do the planning. 51 | The sampler will choose the control signals and interact with 52 | the env. 
53 | 54 | Run the experiments until a total of @timesteps_per_batch 55 | timesteps are collected. 56 | @return: 57 | {'data': None} 58 | """ 59 | raise NotImplementedError 60 | 61 | def rollouts_using_worker_playing(self, num_timesteps=None, 62 | use_random_action=False, 63 | using_true_env=False): 64 | """ @brief: 65 | Workers are used to do the planning, choose the control signals 66 | and interact with the env. The sampler will choose the control 67 | signals and interact with the env. 68 | 69 | Run the experiments until a total of @timesteps_per_batch 70 | timesteps are collected. 71 | @input: 72 | If @using_true_env is set to True, the worker will interact with 73 | the environment. Otherwise it will interact with the env it 74 | models (the trainable dynamics and reward) 75 | @return: 76 | {'data': None} 77 | """ 78 | raise NotImplementedError 79 | 80 | def _init_workers(self): 81 | ''' 82 | @brief: init the actors and start the multiprocessing 83 | ''' 84 | self._actors = [] 85 | 86 | # the sub actor that only do the sampling 87 | for i in range(self.args.num_workers): 88 | self._actors.append( 89 | self._worker_type.worker( 90 | self.args, self._observation_size, self._action_size, 91 | self._network_type, self._task_queue, self._result_queue, i, 92 | ) 93 | ) 94 | 95 | # todo: start 96 | for i_actor in self._actors: 97 | i_actor.start() 98 | 99 | def _build_env(self): 100 | self._env, self._env_info = env_register.make_env( 101 | self.args.task, self._npr.randint(0, 999999), 102 | {'allow_monitor': self.args.monitor} 103 | ) 104 | -------------------------------------------------------------------------------- /mbbl/sampler/readme.md: -------------------------------------------------------------------------------- 1 | rs: random shooting method 2 | -------------------------------------------------------------------------------- /mbbl/sampler/singletask_metrpo_sampler.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # Tingwu Wang 3 | # ----------------------------------------------------------------------------- 4 | import numpy as np 5 | 6 | from .base_sampler import base_sampler 7 | from mbbl.config import init_path 8 | from mbbl.util.common import logger 9 | from mbbl.util.common import parallel_util 10 | 11 | 12 | class sampler(base_sampler): 13 | 14 | def __init__(self, args, worker_type, network_type): 15 | ''' 16 | @brief: 17 | the master agent has several actors (or samplers) to do the 18 | sampling for it. 19 | ''' 20 | super(sampler, self).__init__(args, worker_type, network_type) 21 | self._base_path = init_path.get_abs_base_dir() 22 | self._avg_episode_len = self._env_info['max_length'] 23 | 24 | def rollouts_using_worker_playing(self, num_timesteps=None, 25 | use_random_action=False, 26 | use_true_env=False): 27 | """ @brief: 28 | In this case, the sampler will call workers to generate data 29 | """ 30 | self._current_iteration += 1 31 | num_timesteps_received = 0 32 | timesteps_needed = self.args.timesteps_per_batch \ 33 | if num_timesteps is None else num_timesteps 34 | rollout_data = [] 35 | 36 | while True: 37 | # how many episodes are expected to complete the current dataset? 
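# Illustrative arithmetic for the estimate computed below (the numbers are hypothetical):
# with timesteps_needed = 5000 and self._avg_episode_len = 1000.0, the expression gives
# int(np.ceil(5000 / 1000.0)) = 5 episodes to dispatch to the workers; _avg_episode_len
# is refreshed once results are collected, so the estimate tightens on the next pass.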
38 | num_estimiated_episode = int( 39 | np.ceil(timesteps_needed / self._avg_episode_len) 40 | ) 41 | 42 | # send out the task for each worker to play 43 | for _ in range(num_estimiated_episode): 44 | self._task_queue.put((parallel_util.WORKER_PLAYING, 45 | {'use_true_env': use_true_env, 46 | 'use_random_action': use_random_action})) 47 | self._task_queue.join() 48 | 49 | # collect the data 50 | for _ in range(num_estimiated_episode): 51 | traj_episode = self._result_queue.get() 52 | rollout_data.append(traj_episode) 53 | num_timesteps_received += len(traj_episode['rewards']) 54 | 55 | # update average timesteps per episode and timestep remains 56 | self._avg_episode_len = \ 57 | float(num_timesteps_received) / len(rollout_data) 58 | timesteps_needed = self.args.timesteps_per_batch - \ 59 | num_timesteps_received 60 | 61 | logger.info('Finished {}th episode'.format(len(rollout_data))) 62 | if timesteps_needed <= 0 or self.args.test: 63 | break 64 | 65 | logger.info('{} timesteps from {} episodes collected'.format( 66 | num_timesteps_received, len(rollout_data)) 67 | ) 68 | 69 | return {'data': rollout_data} 70 | -------------------------------------------------------------------------------- /mbbl/sampler/singletask_random_sampler.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # Tingwu Wang 3 | # ----------------------------------------------------------------------------- 4 | 5 | from mbbl.config import init_path 6 | from mbbl.sampler import singletask_sampler 7 | from mbbl.util.common import logger 8 | 9 | 10 | class sampler(singletask_sampler.sampler): 11 | 12 | def __init__(self, args, worker_type, network_type): 13 | ''' 14 | @brief: 15 | the master agent has several actors (or samplers) to do the 16 | sampling for it. 17 | ''' 18 | super(sampler, self).__init__(args, worker_type, network_type) 19 | self._base_path = init_path.get_abs_base_dir() 20 | 21 | def rollouts_using_worker_planning(self, num_timesteps=None, 22 | use_random_action=False): 23 | ''' @brief: 24 | Run the experiments until a total of @timesteps_per_batch 25 | timesteps are collected. 
26 | ''' 27 | self._current_iteration += 1 28 | num_timesteps_received = 0 29 | timesteps_needed = self.args.timesteps_per_batch \ 30 | if num_timesteps is None else num_timesteps 31 | rollout_data = [] 32 | 33 | while True: 34 | # init the data 35 | traj_episode = self._play(use_random_action) 36 | logger.info('done with episode') 37 | rollout_data.append(traj_episode) 38 | num_timesteps_received += len(traj_episode['rewards']) 39 | 40 | # update average timesteps per episode 41 | timesteps_needed = self.args.timesteps_per_batch - \ 42 | num_timesteps_received 43 | 44 | if timesteps_needed <= 0 or self.args.test: 45 | break 46 | 47 | logger.info('{} timesteps from {} episodes collected'.format( 48 | num_timesteps_received, len(rollout_data)) 49 | ) 50 | 51 | return {'data': rollout_data} 52 | 53 | def _act(self, state, control_info={'use_random_action': False}): 54 | # use random policy 55 | action = self._npr.uniform(-1, 1, [self._action_size]) 56 | return action, [-1], [-1] 57 | -------------------------------------------------------------------------------- /mbbl/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/trainer/__init__.py -------------------------------------------------------------------------------- /mbbl/trainer/gail_trainer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # @brief: 3 | # The optimization agent is responsible for doing the updates. 4 | # @author: 5 | # ------------------------------------------------------------------------------ 6 | from .base_trainer import base_trainer 7 | from mbbl.util.common import logger 8 | import numpy as np 9 | from collections import OrderedDict 10 | 11 | 12 | class trainer(base_trainer): 13 | 14 | def __init__(self, args, network_type, task_queue, result_queue, 15 | name_scope='trainer'): 16 | # the base agent 17 | super(trainer, self).__init__( 18 | args=args, network_type=network_type, 19 | task_queue=task_queue, result_queue=result_queue, 20 | name_scope=name_scope 21 | ) 22 | 23 | def _update_parameters(self, rollout_data, training_info): 24 | # get the observation list 25 | self._update_whitening_stats(rollout_data) 26 | 27 | # generate the reward from discriminator 28 | rollout_data = self._network['reward'][0].generate_rewards(rollout_data) 29 | 30 | training_data = self._preprocess_data(rollout_data) 31 | training_stats = OrderedDict() 32 | training_stats['avg_reward'] = training_data['avg_reward'] 33 | training_stats['avg_reward_std'] = training_data['avg_reward_std'] 34 | 35 | assert 'reward' in training_info['network_to_train'] 36 | # train the policy 37 | for key in training_info['network_to_train']: 38 | for i_network in range(self._num_model_ensemble[key]): 39 | i_stats = self._network[key][i_network].train( 40 | training_data, self._replay_buffer, training_info={} 41 | ) 42 | if i_stats is not None: 43 | training_stats.update(i_stats) 44 | self._replay_buffer.add_data(training_data) 45 | 46 | # record the actual reward (not from the discriminator) 47 | self._get_groundtruth_reward(rollout_data, training_stats) 48 | return training_stats 49 | 50 | def _get_groundtruth_reward(self, rollout_data, training_stats): 51 | 52 | for i_episode in rollout_data: 53 | i_episode['raw_episodic_reward'] = sum(i_episode['raw_rewards']) 54 | avg_reward = 
np.mean([i_episode['raw_episodic_reward'] 55 | for i_episode in rollout_data]) 56 | logger.info('Raw reward: {}'.format(avg_reward)) 57 | training_stats['RAW_reward'] = avg_reward 58 | -------------------------------------------------------------------------------- /mbbl/trainer/metrpo_trainer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # @brief: 3 | # The optimization agent is responsible for doing the updates. 4 | # @author: 5 | # ------------------------------------------------------------------------------ 6 | from .base_trainer import base_trainer 7 | 8 | 9 | class trainer(base_trainer): 10 | 11 | def __init__(self, args, network_type, task_queue, result_queue, 12 | name_scope='trainer'): 13 | # the base agent 14 | super(trainer, self).__init__( 15 | args=args, 16 | network_type=network_type, 17 | task_queue=task_queue, 18 | result_queue=result_queue, 19 | name_scope=name_scope 20 | ) 21 | # self._base_path = init_path.get_abs_base_dir() 22 | 23 | def _update_parameters(self, rollout_data, training_info): 24 | # get the observation list 25 | self._update_whitening_stats(rollout_data) 26 | training_data = self._preprocess_data(rollout_data) 27 | training_stats = {'avg_reward': training_data['avg_reward']} 28 | 29 | # TODO(GD): add for loop to train policy? 30 | # train the policy 31 | for key in training_info['network_to_train']: 32 | for i_network in range(self._num_model_ensemble[key]): 33 | i_stats = self._network[key][i_network].train( 34 | training_data, self._replay_buffer, training_info={} 35 | ) 36 | if i_stats is not None: 37 | training_stats.update(i_stats) 38 | self._replay_buffer.add_data(training_data) 39 | return training_stats 40 | -------------------------------------------------------------------------------- /mbbl/trainer/readme.md: -------------------------------------------------------------------------------- 1 | rs: random shooting method 2 | -------------------------------------------------------------------------------- /mbbl/trainer/shooting_trainer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # @brief: 3 | # The optimization agent is responsible for doing the updates. 
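# NOTE: the trainers in this directory (gail / metrpo / shooting) share the
# same update skeleton. A condensed sketch of that pattern is kept here for
# reference; `self._network`, `self._num_model_ensemble` and
# `self._replay_buffer` are the attributes set up in `base_trainer`, and the
# snippet is illustrative rather than a drop-in implementation:
#
#     def _update_parameters(self, rollout_data, training_info):
#         self._update_whitening_stats(rollout_data)
#         training_data = self._preprocess_data(rollout_data)
#         stats = {'avg_reward': training_data['avg_reward']}
#         for key in training_info['network_to_train']:
#             for i_network in range(self._num_model_ensemble[key]):
#                 i_stats = self._network[key][i_network].train(
#                     training_data, self._replay_buffer, training_info={})
#                 if i_stats is not None:
#                     stats.update(i_stats)
#         self._replay_buffer.add_data(training_data)
#         return stats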
4 | # @author: 5 | # ------------------------------------------------------------------------------ 6 | from .base_trainer import base_trainer 7 | 8 | 9 | class trainer(base_trainer): 10 | 11 | def __init__(self, args, network_type, task_queue, result_queue, 12 | name_scope='trainer'): 13 | # the base agent 14 | super(trainer, self).__init__( 15 | args=args, network_type=network_type, 16 | task_queue=task_queue, result_queue=result_queue, 17 | name_scope=name_scope 18 | ) 19 | # self._base_path = init_path.get_abs_base_dir() 20 | 21 | def _update_parameters(self, rollout_data, training_info): 22 | # get the observation list 23 | self._update_whitening_stats(rollout_data) 24 | training_data = self._preprocess_data(rollout_data) 25 | training_stats = {'avg_reward': training_data['avg_reward'], 26 | 'avg_reward_std': training_data['avg_reward_std']} 27 | 28 | # train the policy 29 | for key in training_info['network_to_train']: 30 | for i_network in range(self._num_model_ensemble[key]): 31 | i_stats = self._network[key][i_network].train( 32 | training_data, self._replay_buffer, training_info={} 33 | ) 34 | if i_stats is not None: 35 | training_stats.update(i_stats) 36 | self._replay_buffer.add_data(training_data) 37 | return training_stats 38 | -------------------------------------------------------------------------------- /mbbl/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/util/__init__.py -------------------------------------------------------------------------------- /mbbl/util/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/util/common/__init__.py -------------------------------------------------------------------------------- /mbbl/util/common/fpdb.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: 3 | # some helper functions about pdb. 4 | # ----------------------------------------------------------------------------- 5 | 6 | import sys 7 | import pdb 8 | 9 | 10 | class fpdb(pdb.Pdb): 11 | ''' 12 | @brief: 13 | a Pdb subclass that may be used from a import forked multiprocessing 14 | child 15 | ''' 16 | 17 | def interaction(self, *args, **kwargs): 18 | _stdin = sys.stdin 19 | try: 20 | sys.stdin = open('/dev/stdin') 21 | pdb.Pdb.interaction(self, *args, **kwargs) 22 | finally: 23 | sys.stdin = _stdin 24 | 25 | 26 | if __name__ == '__main__': 27 | # how to use it somewhere outside in a multi-process project. 
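    # For example, a sketch of the one-liner dropped into a worker process
    # elsewhere in this repo (see the commented-out lines in
    # mbbl/worker/cem_worker.py and mbbl/worker/metrpo_worker.py):
    #
    #     from mbbl.util.common.fpdb import fpdb; fpdb().set_trace()
    #
    # Because interaction() reopens /dev/stdin, the breakpoint stays usable
    # even when the caller is a forked multiprocessing child.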
28 | # from util import fpdb 29 | fpdb = fpdb.fpdb() 30 | fpdb.set_trace() 31 | -------------------------------------------------------------------------------- /mbbl/util/common/ggnn_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | GGNN utils 3 | ''' 4 | import numpy as np 5 | 6 | 7 | def compact2sparse_representation(mat, total_edge_type): 8 | ''' 9 | ''' 10 | N, _ = mat.shape 11 | 12 | sparse_mat = np.zeros((N, N * total_edge_type * 2)) 13 | 14 | for i in range(N): 15 | for j in range(N): 16 | if mat[i, j] == 0: continue 17 | 18 | edge_type = mat[i, j] 19 | _from = i 20 | _to = j 21 | 22 | in_x = j 23 | in_y = i + N * (edge_type - 1) 24 | sparse_mat[int(in_x), int(in_y)] = 1 25 | 26 | # fill out 27 | out_x = i 28 | out_y = total_edge_type + j + N * (edge_type - 1) 29 | sparse_mat[int(out_x), int(out_y)] = 1 30 | 31 | return sparse_mat.astype('int') 32 | 33 | def manual_parser(env): 34 | 35 | def _hopper(): 36 | graph = np.array([ 37 | [0, 1, 0, 0], 38 | [1, 0, 1, 0], 39 | [0, 1, 0, 1], 40 | [0, 0, 1, 0] 41 | ]) 42 | 43 | geom_meta_info = np.array([ 44 | [0.9, 0, 0, 1.45, 0, 0, 1.05, 0.05], 45 | [0.9, 0, 0, 1.05, 0, 0, 0.6, 0.05], 46 | [0.9, 0, 0, 0.6, 0, 0, 0.1, 0.04], 47 | [2.0, -0.14, 0, 0.1, 0.26, 0, 0.1, 0.06] 48 | ]) 49 | 50 | joint_meta_info = np.array([ 51 | [0, 0, 0, 0, 0, 0], 52 | [0, 0, 1.05, -150, 0, 200], 53 | [0, 0, 0.6, -150, 0, 200], 54 | [0, 0, 0.1, -45, 45, 200] 55 | ]) 56 | 57 | meta_info = np.hstack( (geom_meta_info, joint_meta_info) ) 58 | 59 | ob_assign = np.array([ 60 | 0, 0, 1, 2, 3, 61 | 0, 0, 0, 1, 2, 3 62 | ]) 63 | 64 | ac_assign = np.array([ 65 | 1, 2, 3 66 | ]) 67 | 68 | ggnn_info = {} 69 | ggnn_info['n_node'] = graph.shape[0] 70 | ggnn_info['n_node_type'] = 1 71 | 72 | ggnn_info['node_anno_dim'] = meta_info.shape[1] 73 | ggnn_info['node_state_dim'] = 6 # observation space 74 | ggnn_info['node_embed_dim'] = 512 # ggnn hidden states dim - node_state_dim 75 | 76 | ggnn_info['n_edge_type'] = graph.max() 77 | ggnn_info['output_dim'] = 5 # the number of occurrences in ob_assign 78 | 79 | return graph, ob_assign, ac_assign, meta_info, ggnn_info 80 | 81 | def _half_cheetah(): 82 | raise NotImplementedError 83 | 84 | def _walker(): 85 | raise NotImplementedError 86 | 87 | def _ant(): 88 | raise NotImplementedError 89 | 90 | if env == 'gym_hopper': return _hopper() 91 | else: 92 | raise NotImplementedError 93 | 94 | -------------------------------------------------------------------------------- /mbbl/util/common/logger.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: 3 | # The logger here will be called all across the project. 
It is inspired 4 | # by Yuxin Wu (ppwwyyxx@gmail.com) 5 | # 6 | # @author: 7 | # Tingwu Wang, 2017, Feb, 20th 8 | # ----------------------------------------------------------------------------- 9 | 10 | import logging 11 | import sys 12 | import os 13 | import datetime 14 | from termcolor import colored 15 | 16 | __all__ = ['set_file_handler'] # the actual worker is the '_logger' 17 | 18 | 19 | class _MyFormatter(logging.Formatter): 20 | ''' 21 | @brief: 22 | a class to make sure the format could be used 23 | ''' 24 | 25 | def format(self, record): 26 | date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green') 27 | msg = '%(message)s' 28 | 29 | if record.levelno == logging.WARNING: 30 | fmt = date + ' ' + \ 31 | colored('WRN', 'red', attrs=[]) + ' ' + msg 32 | elif record.levelno == logging.ERROR or \ 33 | record.levelno == logging.CRITICAL: 34 | fmt = date + ' ' + \ 35 | colored('ERR', 'red', attrs=['underline']) + ' ' + msg 36 | else: 37 | fmt = date + ' ' + msg 38 | 39 | if hasattr(self, '_style'): 40 | # Python3 compatibilty 41 | self._style._fmt = fmt 42 | self._fmt = fmt 43 | 44 | return super(self.__class__, self).format(record) 45 | 46 | 47 | _logger = logging.getLogger('joint_embedding') 48 | _logger.propagate = False 49 | _logger.setLevel(logging.INFO) 50 | 51 | # set the console output handler 52 | con_handler = logging.StreamHandler(sys.stdout) 53 | con_handler.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S')) 54 | _logger.addHandler(con_handler) 55 | 56 | 57 | class GLOBAL_PATH(object): 58 | 59 | def __init__(self, path=None): 60 | if path is None: 61 | path = os.getcwd() 62 | self.path = path 63 | 64 | def _set_path(self, path): 65 | self.path = path 66 | 67 | def _get_path(self): 68 | return self.path 69 | 70 | 71 | PATH = GLOBAL_PATH() 72 | 73 | 74 | def set_file_handler(path=None, prefix='', time_str=''): 75 | # set the file output handler 76 | if time_str == '': 77 | file_name = prefix + \ 78 | datetime.datetime.now().strftime("%A_%d_%B_%Y_%I:%M%p") + '.log' 79 | else: 80 | file_name = prefix + time_str + '.log' 81 | 82 | if path is None: 83 | mod = sys.modules['__main__'] 84 | path = os.path.join(os.path.abspath(mod.__file__), '..', '..', 'log') 85 | else: 86 | path = os.path.join(path, 'log') 87 | path = os.path.abspath(path) 88 | 89 | path = os.path.join(path, file_name) 90 | if not os.path.exists(path): 91 | os.makedirs(path) 92 | 93 | PATH._set_path(path) 94 | path = os.path.join(path, file_name) 95 | from tensorboard_logger import configure 96 | configure(path) 97 | 98 | file_handler = logging.FileHandler( 99 | filename=os.path.join(path, 'logger'), encoding='utf-8', mode='w') 100 | file_handler.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S')) 101 | _logger.addHandler(file_handler) 102 | 103 | _logger.info('Log file set to {}'.format(path)) 104 | return path 105 | 106 | 107 | def _get_path(): 108 | return PATH._get_path() 109 | 110 | 111 | _LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 112 | 'warn', 'exception', 'debug'] 113 | 114 | # export logger functions 115 | for func in _LOGGING_METHOD: 116 | locals()[func] = getattr(_logger, func) 117 | -------------------------------------------------------------------------------- /mbbl/util/common/parallel_util.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: 3 | # Define some signals used during parallel 4 | # @author: 5 | # Tingwu Wang 6 | # 
----------------------------------------------------------------------------- 7 | 8 | TRAIN_SIGNAL = 1 9 | SAVE_SIGNAL = 2 10 | 11 | # it makes the main trpo agent push its weights into the tunnel 12 | START_SIGNAL = 3 13 | 14 | # it ends the training 15 | END_SIGNAL = 4 16 | 17 | # ends the rollout 18 | END_ROLLOUT_SIGNAL = 5 19 | 20 | # ask the rollout agents to collect the ob normalizer's info 21 | AGENT_COLLECT_FILTER_INFO = 6 22 | 23 | # ask the rollout agents to synchronize the ob normalizer's info 24 | AGENT_SYNCHRONIZE_FILTER = 7 25 | 26 | # ask the agents to set their parameters of network 27 | AGENT_SET_WEIGHTS = 8 28 | 29 | # reset 30 | RESET_SIGNAL = 9 31 | 32 | # Initial training for mbmf policy netwrok. 33 | MBMF_INITIAL = 666 34 | 35 | # ask for policy network. 36 | GET_POLICY_NETWORK = 6666 37 | 38 | # ask and set for policy network weight. 39 | GET_POLICY_WEIGHT = 66 40 | SET_POLICY_WEIGHT = 66666 41 | 42 | 43 | WORKER_PLANNING = 10 44 | WORKER_PLAYING = 11 45 | WORKER_GET_MODEL = 12 46 | WORKER_RATE_ACTIONS = 13 47 | 48 | # make sure that no signals are using the same number 49 | var_dict = locals() 50 | var_list = [var_dict[var] for var in dir() if 51 | (not var.startswith('_') and type(var_dict[var]) == int)] 52 | 53 | assert len(var_list) == len(set(var_list)) 54 | -------------------------------------------------------------------------------- /mbbl/util/common/replay_buffer.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: save the true datapoints into a buffer 3 | # @author: Tingwu Wang 4 | # ----------------------------------------------------------------------------- 5 | import numpy as np 6 | 7 | 8 | class replay_buffer(object): 9 | 10 | def __init__(self, use_buffer, buffer_size, rand_seed, 11 | observation_size, action_size, save_reward=False): 12 | 13 | self._use_buffer = use_buffer 14 | self._buffer_size = buffer_size 15 | self._npr = np.random.RandomState(rand_seed) 16 | 17 | if not self._use_buffer: 18 | self._buffer_size = 0 19 | 20 | self._observation_size = observation_size 21 | self._action_size = action_size 22 | 23 | reward_data_size = self._buffer_size if save_reward else 0 24 | self._data = { 25 | 'start_state': np.zeros( 26 | [self._buffer_size, self._observation_size], 27 | dtype=np.float16 28 | ), 29 | 30 | 'end_state': np.zeros( 31 | [self._buffer_size, self._observation_size], 32 | dtype=np.float16 33 | ), 34 | 35 | 'action': np.zeros( 36 | [self._buffer_size, self._action_size], 37 | dtype=np.float16 38 | ), 39 | 40 | 'reward': np.zeros([reward_data_size], dtype=np.float16) 41 | } 42 | self._data_key = [key for key in self._data if len(self._data[key]) > 0] 43 | 44 | self._current_id = 0 45 | self._occupied_size = 0 46 | 47 | def add_data(self, new_data): 48 | if self._buffer_size == 0: 49 | return 50 | 51 | num_new_data = len(new_data['start_state']) 52 | 53 | if num_new_data + self._current_id > self._buffer_size: 54 | num_after_full = num_new_data + self._current_id - self._buffer_size 55 | for key in self._data_key: 56 | # filling the tail part 57 | self._data[key][self._current_id: self._buffer_size] = \ 58 | new_data[key][0: self._buffer_size - self._current_id] 59 | 60 | # filling the head part 61 | self._data[key][0: num_after_full] = \ 62 | new_data[key][self._buffer_size - self._current_id:] 63 | 64 | else: 65 | 66 | for key in self._data_key: 67 | self._data[key][self._current_id: 68 | self._current_id 
+ num_new_data] = \ 69 | new_data[key] 70 | 71 | self._current_id = \ 72 | (self._current_id + num_new_data) % self._buffer_size 73 | self._occupied_size = \ 74 | min(self._buffer_size, self._occupied_size + num_new_data) 75 | 76 | def get_data(self, batch_size): 77 | 78 | # the data from old data 79 | sample_id = self._npr.randint(0, self._occupied_size, batch_size) 80 | return {key: self._data[key][sample_id] for key in self._data_key} 81 | 82 | def get_current_size(self): 83 | return self._occupied_size 84 | 85 | def get_all_data(self): 86 | return {key: self._data[key][:self._occupied_size] 87 | for key in self._data_key} 88 | -------------------------------------------------------------------------------- /mbbl/util/common/summary_handler.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: 3 | # define the how to record the summary 4 | # @author: 5 | # Tingwu Wang 6 | # ----------------------------------------------------------------------------- 7 | import os 8 | 9 | import tensorflow as tf 10 | 11 | from mbbl.config import init_path 12 | from mbbl.util.common import logger 13 | 14 | 15 | class summary_handler(object): 16 | ''' 17 | @brief: 18 | Tell the handler where to record all the information. 19 | Normally, we want to record the prediction of value loss, and the 20 | average reward (maybe learning rate) 21 | ''' 22 | 23 | def __init__(self, sess, summary_name, enable=True, summary_dir=None): 24 | # the interface we need 25 | self.summary = None 26 | self.sess = sess 27 | self.enable = enable 28 | if not self.enable: # the summary handler is disabled 29 | return 30 | if summary_dir is None: 31 | self.path = os.path.join( 32 | init_path.get_base_dir(), 'summary' 33 | ) 34 | else: 35 | self.path = os.path.join(summary_dir, 'summary') 36 | self.path = os.path.abspath(self.path) 37 | 38 | if not os.path.exists(self.path): 39 | os.makedirs(self.path) 40 | self.path = os.path.join(self.path, summary_name) 41 | 42 | self.train_writer = tf.summary.FileWriter(self.path, self.sess.graph) 43 | 44 | logger.info( 45 | 'summary write initialized, writing to {}'.format(self.path)) 46 | 47 | def get_tf_summary(self): 48 | assert self.summary is not None, logger.error( 49 | 'tf summary not defined, call the summary object separately') 50 | return self.summary 51 | 52 | 53 | class gym_summary_handler(summary_handler): 54 | ''' 55 | @brief: 56 | For the gym environment, we pass the stuff we want to record 57 | ''' 58 | 59 | def __init__(self, sess, summary_name, enable=True, 60 | scalar_var_list=dict(), summary_dir=None): 61 | super(self.__class__, self).__init__(sess, summary_name, enable=enable, 62 | summary_dir=summary_dir) 63 | if not self.enable: 64 | return 65 | assert type(scalar_var_list) == dict, logger.error( 66 | 'We only take the dict where the name is given as the key') 67 | 68 | if len(scalar_var_list) > 0: 69 | self.summary_list = [] 70 | for name, var in scalar_var_list.items(): 71 | self.summary_list.append(tf.summary.scalar(name, var)) 72 | self.summary = tf.summary.merge(self.summary_list) 73 | 74 | def manually_add_scalar_summary(self, summary_name, summary_value, x_axis): 75 | ''' 76 | @brief: 77 | might be useful to record the average game_length, and average 78 | reward 79 | @input: 80 | x_axis could either be the episode number of step number 81 | ''' 82 | if not self.enable: # it happens when we are just debugging 83 | return 84 | 85 | if 
'expert_traj' in summary_name: 86 | return 87 | 88 | summary = tf.Summary( 89 | value=[tf.Summary.Value( 90 | tag=summary_name, simple_value=summary_value 91 | ), ] 92 | ) 93 | self.train_writer.add_summary(summary, x_axis) 94 | -------------------------------------------------------------------------------- /mbbl/util/common/tf_norm.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # @brief: define the batchnorm and layernorm in this function 3 | # ------------------------------------------------------------------------------ 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def layer_norm(x, name_scope, epsilon=1e-5, use_bias=True, 9 | use_scale=True, gamma_init=None, data_format='NHWC'): 10 | """ 11 | @Brief: code modified from ppwwyyxx github.com/ppwwyyxx/tensorpack/, 12 | under layer_norm.py. 13 | Layer Normalization layer, as described in the paper: 14 | https://arxiv.org/abs/1607.06450. 15 | @input: 16 | x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should 17 | match data_format. 18 | """ 19 | with tf.variable_scope(name_scope): 20 | shape = x.get_shape().as_list() 21 | ndims = len(shape) 22 | assert ndims in [2, 4] 23 | 24 | mean, var = tf.nn.moments(x, list(range(1, len(shape))), keep_dims=True) 25 | 26 | if data_format == 'NCHW': 27 | chan = shape[1] 28 | new_shape = [1, chan, 1, 1] 29 | else: 30 | chan = shape[-1] 31 | new_shape = [1, 1, 1, chan] 32 | if ndims == 2: 33 | new_shape = [1, chan] 34 | 35 | if use_bias: 36 | beta = tf.get_variable( 37 | 'beta', [chan], initializer=tf.constant_initializer() 38 | ) 39 | beta = tf.reshape(beta, new_shape) 40 | else: 41 | beta = tf.zeros([1] * ndims, name='beta') 42 | if use_scale: 43 | if gamma_init is None: 44 | gamma_init = tf.constant_initializer(1.0) 45 | gamma = tf.get_variable('gamma', [chan], initializer=gamma_init) 46 | gamma = tf.reshape(gamma, new_shape) 47 | else: 48 | gamma = tf.ones([1] * ndims, name='gamma') 49 | 50 | ret = tf.nn.batch_normalization( 51 | x, mean, var, beta, gamma, epsilon, name='output' 52 | ) 53 | return ret 54 | 55 | 56 | def batch_norm_with_train(x, name_scope, epsilon=1e-5, momentum=0.9): 57 | ret = tf.contrib.layers.batch_norm( 58 | x, decay=momentum, updates_collections=None, epsilon=epsilon, 59 | scale=True, is_training=True, scope=name_scope 60 | ) 61 | return ret 62 | 63 | 64 | def batch_norm_without_train(x, name_scope, epsilon=1e-5, momentum=0.9): 65 | ret = tf.contrib.layers.batch_norm( 66 | x, decay=momentum, updates_collections=None, epsilon=epsilon, 67 | scale=True, is_training=False, scope=name_scope 68 | ) 69 | return ret 70 | -------------------------------------------------------------------------------- /mbbl/util/gps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/util/gps/__init__.py -------------------------------------------------------------------------------- /mbbl/util/il/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/util/il/__init__.py -------------------------------------------------------------------------------- /mbbl/util/il/test_inverse_dynamics.py: -------------------------------------------------------------------------------- 1 | """ 2 | @brief: understand 
some statistics about the inverse dynamics 3 | @author: Tingwu Wang 4 | 5 | @Date: Jan 12, 2019 6 | """ 7 | # import matplotlib.pyplot as plt 8 | # from mbbl.env.dm_env.pos_dm_env import POS_CONNECTION 9 | from mbbl.env.env_register import make_env 10 | # from mbbl.util.common import logger 11 | # from mbbl.config import init_path 12 | # import cv2 13 | # import os 14 | # from skimage import draw 15 | import numpy as np 16 | # from mbbl.util.il.expert_data_util import load_pose_data 17 | # import argparse 18 | 19 | 20 | if __name__ == '__main__': 21 | 22 | env, env_info = make_env("cheetah-run-pos", 1234) 23 | control_info = env.get_controller_info() 24 | dynamics_env, _ = make_env("cheetah-run-pos", 1234) 25 | 26 | # generate the data 27 | env.reset() 28 | for i in range(1000): 29 | action = np.random.randn(env_info['action_size']) 30 | qpos = np.array(env._env.physics.data.qpos, copy=True) 31 | old_qpos = np.array(env._env.physics.data.qpos, copy=True) 32 | old_qvel = np.array(env._env.physics.data.qvel, copy=True) 33 | old_qacc = np.array(env._env.physics.data.qacc, copy=True) 34 | old_qfrc_inverse = np.array(env._env.physics.data.qfrc_inverse, copy=True) 35 | _, _, _, _ = env.step(action) 36 | ctrl = np.array(env._env.physics.data.ctrl, copy=True) 37 | qvel = np.array(env._env.physics.data.qvel, copy=True) 38 | qacc = np.array(env._env.physics.data.qacc, copy=True) 39 | 40 | # see the inverse 41 | qfrc_inverse = dynamics_env._env.physics.get_inverse_output(qpos, qvel, qacc) 42 | qfrc_action = qfrc_inverse[None, control_info['actuated_id']] 43 | action = ctrl * control_info['gear'] 44 | print("predicted action: {}\n".format(qfrc_action)) 45 | print("groundtruth action: {}\n".format(action)) 46 | -------------------------------------------------------------------------------- /mbbl/util/ilqr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/util/ilqr/__init__.py -------------------------------------------------------------------------------- /mbbl/util/ilqr/ilqr_utils.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # @brief: 3 | # ----------------------------------------------------------------------------- 4 | from mbbl.config import init_path 5 | 6 | _BASE_DIR = init_path.get_abs_base_dir() 7 | 8 | 9 | def update_damping_lambda(traj_data, increase, damping_args): 10 | if increase: 11 | traj_data['lambda_multiplier'] = max( 12 | traj_data['lambda_multiplier'] * damping_args['factor'], 13 | damping_args['factor'] 14 | ) 15 | traj_data['damping_lambda'] = max( 16 | traj_data['damping_lambda'] * traj_data['lambda_multiplier'], 17 | damping_args['min_damping'] 18 | ) 19 | else: # decrease 20 | traj_data['lambda_multiplier'] = min( 21 | traj_data['lambda_multiplier'] / damping_args['factor'], 22 | 1.0 / damping_args['factor'] 23 | ) 24 | traj_data['damping_lambda'] = \ 25 | traj_data['damping_lambda'] * traj_data['lambda_multiplier'] * \ 26 | (traj_data['damping_lambda'] > damping_args['min_damping']) 27 | 28 | traj_data['damping_lambda'] = \ 29 | max(traj_data['damping_lambda'], damping_args['min_damping']) 30 | if traj_data['damping_lambda'] > damping_args['max_damping']: 31 | traj_data['active'] = False 32 | -------------------------------------------------------------------------------- /mbbl/worker/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/worker/__init__.py -------------------------------------------------------------------------------- /mbbl/worker/cem_worker.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # @brief: 3 | # ------------------------------------------------------------------------------ 4 | import numpy as np 5 | 6 | from .base_worker import base_worker 7 | from mbbl.config import init_path 8 | # from mbbl.util.common import parallel_util 9 | 10 | 11 | def detect_done(ob, env_name, check_done): 12 | if env_name == 'gym_fhopper': 13 | height, ang = ob[:, 0], ob[:, 1] 14 | done = np.logical_or(height <= 0.7, abs(ang) >= 0.2) 15 | 16 | elif env_name == 'gym_fwalker2d': 17 | height, ang = ob[:, 0], ob[:, 1] 18 | done = np.logical_or( 19 | height >= 2.0, 20 | np.logical_or(height <= 0.8, abs(ang) >= 1.0) 21 | ) 22 | 23 | elif env_name == 'gym_fant': 24 | height = ob[:, 0] 25 | done = np.logical_or(height > 1.0, height < 0.2) 26 | 27 | elif env_name in ['gym_fant2', 'gym_fant5', 'gym_fant10', 28 | 'gym_fant20', 'gym_fant30']: 29 | height = ob[:, 0] 30 | done = np.logical_or(height > 1.0, height < 0.2) 31 | else: 32 | done = np.zeros([ob.shape[0]]) 33 | 34 | if not check_done: 35 | done[:] = False 36 | 37 | return done 38 | 39 | 40 | class worker(base_worker): 41 | EPSILON = 0.001 42 | 43 | def __init__(self, args, observation_size, action_size, 44 | network_type, task_queue, result_queue, worker_id, 45 | name_scope='planning_worker'): 46 | 47 | # the base agent 48 | super(worker, self).__init__(args, observation_size, action_size, 49 | network_type, task_queue, result_queue, 50 | worker_id, name_scope) 51 | self._base_dir = init_path.get_base_dir() 52 | self._alpha = args.cem_learning_rate 53 | self._num_iters = args.cem_num_iters 54 | self._elites_fraction = args.cem_elites_fraction 55 | 56 | def _plan(self, planning_data): 57 | num_traj = planning_data['state'].shape[0] 58 | sample_action = np.reshape( 59 | planning_data['samples'], 60 | [num_traj, planning_data['depth'], planning_data['action_size']] 61 | ) 62 | current_state = planning_data['state'] 63 | total_reward = 0 64 | done = np.zeros([num_traj]) 65 | 66 | for i_depth in range(planning_data['depth']): 67 | action = sample_action[:, i_depth, :] 68 | 69 | # pred next state 70 | next_state, _, _ = self._network['dynamics'][0].pred( 71 | {'start_state': current_state, 'action': action} 72 | ) 73 | 74 | # pred the reward 75 | reward, _, _ = self._network['reward'][0].pred( 76 | {'start_state': current_state, 'action': action} 77 | ) 78 | 79 | total_reward += reward * (1 - done) 80 | current_state = next_state 81 | 82 | # mark the done marker 83 | this_done = detect_done(next_state, self.args.task, self.args.check_done) 84 | done = np.logical_or(this_done, done) 85 | # if np.any(done): 86 | # from mbbl.util.common.fpdb import fpdb; fpdb().set_trace() 87 | 88 | # Return the cost 89 | return_dict = {'costs': -total_reward, 90 | 'sample_id': planning_data['id']} 91 | return return_dict 92 | -------------------------------------------------------------------------------- /mbbl/worker/metrpo_worker.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 
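# NOTE (illustrative): the shooting-style planning workers above score a
# candidate action sequence by rolling the learned dynamics / reward models
# forward and freezing the return of trajectories that have already
# terminated. Roughly, with `dynamics` and `reward` standing in for
# self._network['dynamics'][0] and self._network['reward'][0]:
#
#     total_reward, done = 0, np.zeros([num_traj])
#     for action in action_sequence:                    # one step per depth
#         next_state, _, _ = dynamics.pred(
#             {'start_state': state, 'action': action})
#         r, _, _ = reward.pred(
#             {'start_state': state, 'action': action})
#         total_reward += r * (1 - done)                # masked after done
#         done = np.logical_or(
#             detect_done(next_state, task, check_done), done)
#         state = next_state
#     cost = -total_reward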
2 | # @brief: 3 | # ------------------------------------------------------------------------------ 4 | import numpy as np 5 | 6 | from .base_worker import base_worker 7 | from mbbl.config import init_path 8 | from mbbl.env import env_register 9 | from mbbl.env.env_util import play_episode_with_env 10 | from mbbl.env.fake_env import fake_env 11 | 12 | 13 | class worker(base_worker): 14 | 15 | def __init__(self, args, observation_size, action_size, 16 | network_type, task_queue, result_queue, worker_id, 17 | name_scope='planning_worker'): 18 | 19 | # the base agent 20 | super(worker, self).__init__(args, observation_size, action_size, 21 | network_type, task_queue, result_queue, 22 | worker_id, name_scope) 23 | self._base_dir = init_path.get_base_dir() 24 | 25 | # build the environments 26 | self._build_env() 27 | 28 | def _build_env(self): 29 | self._env, self._env_info = env_register.make_env( 30 | self.args.task, self._npr.randint(0, 9999), 31 | {'allow_monitor': self.args.monitor and self._worker_id == 0} 32 | ) 33 | self._fake_env = fake_env(self._env, self._step) 34 | 35 | def _plan(self, planning_data): 36 | raise NotImplementedError 37 | 38 | def _play(self, planning_data): 39 | ''' 40 | # TODO NOTE: 41 | var_list = self._network['policy'][0]._trainable_var_list 42 | print('') 43 | for var in var_list: 44 | print(var.name) 45 | # print(var.name, self._session.run(var)[-1]) 46 | ''' 47 | if planning_data['use_true_env']: 48 | traj_episode = play_episode_with_env( 49 | self._env, self._act, 50 | {'use_random_action': planning_data['use_random_action']} 51 | ) 52 | else: 53 | traj_episode = play_episode_with_env( 54 | self._fake_env, self._act, 55 | {'use_random_action': planning_data['use_random_action']} 56 | ) 57 | return traj_episode 58 | 59 | def _act(self, state, 60 | control_info={'use_random_action': False}): 61 | 62 | if 'use_random_action' in control_info and \ 63 | control_info['use_random_action']: 64 | # use random policy 65 | action = self._npr.uniform(-1, 1, [self._action_size]) 66 | return action, [-1], [-1] 67 | 68 | else: 69 | # call the policy network 70 | return self._network['policy'][0].act({'start_state': state}) 71 | 72 | def _step(self, state, action): 73 | state = np.reshape(state, [-1, self._observation_size]) 74 | action = np.reshape(action, [-1, self._action_size]) 75 | # pred next state 76 | next_state, _, _ = self._network['dynamics'][0].pred( 77 | {'start_state': state, 'action': action} 78 | ) 79 | 80 | # pred the reward 81 | reward, _, _ = self._network['reward'][0].pred( 82 | {'start_state': state, 'action': action} 83 | ) 84 | # from util.common.fpdb import fpdb; fpdb().set_trace() 85 | return next_state[0], reward[0] 86 | -------------------------------------------------------------------------------- /mbbl/worker/mf_worker.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # @brief: 3 | # ------------------------------------------------------------------------------ 4 | from .base_worker import base_worker 5 | from mbbl.config import init_path 6 | from mbbl.env.env_util import play_episode_with_env 7 | from mbbl.util.common import logger 8 | import numpy as np 9 | 10 | 11 | class worker(base_worker): 12 | 13 | def __init__(self, args, observation_size, action_size, 14 | network_type, task_queue, result_queue, worker_id, 15 | name_scope='planning_worker'): 16 | 17 | # the base agent 18 | super(worker, self).__init__(args, 
observation_size, action_size, 19 | network_type, task_queue, result_queue, 20 | worker_id, name_scope) 21 | self._base_dir = init_path.get_base_dir() 22 | self._previous_reward = -np.inf 23 | 24 | # build the environments 25 | self._build_env() 26 | 27 | def _plan(self, planning_data): 28 | raise NotImplementedError 29 | 30 | def _play(self, planning_data): 31 | if self.args.num_expert_episode_to_save > 0 and \ 32 | self._previous_reward > self._env_solved_reward and \ 33 | self._worker_id == 0: 34 | start_save_episode = True 35 | logger.info('Last episodic reward: %.4f' % self._previous_reward) 36 | logger.info('Minimum reward of %.4f is needed to start saving' 37 | % self._env_solved_reward) 38 | logger.info('[SAVING] Worker %d will record its episode data' 39 | % self._worker_id) 40 | else: 41 | start_save_episode = False 42 | if self.args.num_expert_episode_to_save > 0 \ 43 | and self._worker_id == 0: 44 | logger.info('Last episodic reward: %.4f' % 45 | self._previous_reward) 46 | logger.info('Minimum reward of %.4f is needed to start saving' 47 | % self._env_solved_reward) 48 | 49 | traj_episode = play_episode_with_env( 50 | self._env, self._act, 51 | {'use_random_action': planning_data['use_random_action'], 52 | 'record_flag': start_save_episode, 53 | 'num_episode': self.args.num_expert_episode_to_save, 54 | 'data_name': self.args.task + '_' + self.args.exp_id} 55 | ) 56 | self._previous_reward = np.sum(traj_episode['rewards']) 57 | return traj_episode 58 | 59 | def _act(self, state, 60 | control_info={'use_random_action': False}): 61 | 62 | if 'use_random_action' in control_info and \ 63 | control_info['use_random_action']: 64 | # use random policy 65 | action = self._npr.uniform(-1, 1, [self._action_size]) 66 | return action, [-1], [-1] 67 | 68 | else: 69 | # call the policy network 70 | return self._network['policy'][0].act({'start_state': state}) 71 | -------------------------------------------------------------------------------- /mbbl/worker/model_worker.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # @brief: 3 | # ------------------------------------------------------------------------------ 4 | from .base_worker import base_worker 5 | from mbbl.config import init_path 6 | 7 | 8 | class worker(base_worker): 9 | 10 | def __init__(self, args, observation_size, action_size, 11 | network_type, task_queue, result_queue, worker_id, 12 | name_scope='derivative_worker'): 13 | 14 | # the base agent 15 | super(worker, self).__init__(args, observation_size, action_size, 16 | network_type, task_queue, result_queue, 17 | worker_id, name_scope) 18 | self._base_dir = init_path.get_base_dir() 19 | 20 | # build the environments 21 | self._build_env() 22 | 23 | def _dynamics_derivative(self, data_dict, 24 | target=['state', 'action', 'state-action']): 25 | 26 | assert len(self._network['dynamics']) == 1 27 | return self._network['dynamics'][0].get_derivative(data_dict, target) 28 | 29 | def _reward_derivative(self, data_dict, 30 | target=['state', 'action', 'state-state']): 31 | 32 | assert len(self._network['reward']) == 1 33 | return self._network['reward'][0].get_derivative(data_dict, target) 34 | 35 | def _dynamics(self, data_dict): 36 | assert len(self._network['dynamics']) == 1 37 | return {'end_state': self._network['dynamics'][0].pred(data_dict)[0]} 38 | 39 | def _reward(self, data_dict): 40 | assert len(self._network['reward']) == 1 41 | return {'reward': 
self._network['reward'][0].pred(data_dict)[0]} 42 | -------------------------------------------------------------------------------- /mbbl/worker/readme.md: -------------------------------------------------------------------------------- 1 | rs: random shooting method 2 | -------------------------------------------------------------------------------- /scripts/exp_1_performance_curve/ilqr.sh: -------------------------------------------------------------------------------- 1 | # the first batch: 2 | # gym_reacher, gym_cheetah, gym_walker2d, gym_hopper, gym_swimmer, gym_ant 3 | 4 | # the second batch: 5 | # gym_pendulum, gym_invertedPendulum, gym_acrobot, gym_mountain, gym_cartpole 6 | # 200 , 100 , 200 , 200 , 200 7 | 8 | # bash ilqr.sh gym_pendulum 4000 200 9 | # bash ilqr.sh gym_invertedPendulum 2000 100 10 | # bash ilqr.sh gym_acrobot 4000 200 11 | # bash ilqr.sh gym_mountain 4000 200 12 | # bash ilqr.sh gym_cartpole 4000 200 13 | 14 | # $1 is the environment 15 | # ilqr_depth: [10, 20, 50, 100] 16 | for ilqr_depth in 10 20 30 50 ; do 17 | python main/ilqr_main.py --max_timesteps $2 --task $1 --timesteps_per_batch $3 --ilqr_iteration 10 --ilqr_depth $ilqr_depth --max_ilqr_linesearch_backtrack 10 --exp_id $1-depth-$ilqr_depth --num_workers 2 --gt_dynamics 1 18 | done 19 | 20 | for num_ilqr_traj in 2 5 10; do 21 | python main/ilqr_main.py --max_timesteps $2 --task $1 --timesteps_per_batch $3 --ilqr_iteration 10 --ilqr_depth 30 --max_ilqr_linesearch_backtrack 10 --num_ilqr_traj $num_ilqr_traj --exp_id $1-traj-$num_ilqr_traj --num_workers 2 --gt_dynamics 1 22 | done 23 | -------------------------------------------------------------------------------- /scripts/exp_1_performance_curve/ilqr_depth.sh: -------------------------------------------------------------------------------- 1 | 2 | # $1 is the environment 3 | # ilqr_depth: [10, 20, 50, 100] 4 | for ilqr_depth in 10 20 30 50 ; do 5 | python main/ilqr_main.py --max_timesteps 2000 --task $1 --timesteps_per_batch 50 --ilqr_iteration 10 --ilqr_depth $ilqr_depth --max_ilqr_linesearch_backtrack 10 --exp_id $1-depth-$ilqr_depth --num_workers 2 --gt_dynamics 1 6 | done 7 | 8 | for num_ilqr_traj in 2 5 10; do 9 | python main/ilqr_main.py --max_timesteps 2000 --task $1 --timesteps_per_batch 50 --ilqr_iteration 10 --ilqr_depth 30 --max_ilqr_linesearch_backtrack 10 --num_ilqr_traj $num_ilqr_traj --exp_id $1-traj-$num_ilqr_traj --num_workers 2 --gt_dynamics 1 10 | done 11 | -------------------------------------------------------------------------------- /scripts/exp_1_performance_curve/mbmf.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | # gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant 5 | # gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole 6 | 7 | for trust_region_method in ppo; do 8 | for batch_size in 1000; do 9 | for env_type in $1; do 10 | for seed in 1234 2341 3412 4123; do 11 | python main/mbmf_main.py --exp_id mbmf_${env_type}_${trust_region_method}_seed_${seed}\ 12 | --task $env_type \ 13 | --trust_region_method ${trust_region_method} \ 14 | --num_planning_traj 5000 --planning_depth 20 --random_timesteps 1000 \ 15 | --timesteps_per_batch $batch_size --dynamics_epochs 30 \ 16 | --num_workers 20 --mb_timesteps 7000 --dagger_epoch 300 \ 17 | --dagger_timesteps_per_iter 1750 --max_timesteps 200000 \ 18 | --seed $seed --dynamics_batch_size 500 19 | 
done 20 | done 21 | done 22 | done 23 | -------------------------------------------------------------------------------- /scripts/exp_1_performance_curve/mbmf_1m.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | # gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant 5 | 6 | for batch_size in 2000; do 7 | for env_type in $1; do 8 | for seed in 1234 2341 3412 4123; do 9 | python main/mbmf_main.py --exp_id mbmf_${env_type}_seed_${seed}_1m \ 10 | --task $env_type \ 11 | --num_planning_traj 5000 --planning_depth 20 --random_timesteps 10000 \ 12 | --timesteps_per_batch $batch_size --dynamics_epochs 30 --num_workers 10 \ 13 | --mb_timesteps 70000 --dagger_epoch 300 --dagger_timesteps_per_iter 1750 \ 14 | --trust_region_method ppo \ 15 | --max_timesteps 1000000 --seed $seed --dynamics_batch_size 500 16 | done 17 | done 18 | done 19 | -------------------------------------------------------------------------------- /scripts/exp_1_performance_curve/pets_gt.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | 5 | # bash gym_reacher 2000 6 | # bash gym_cheetah 20000 7 | # bash gym_walker2d 20000 8 | # bash gym_hopper 20000 9 | # bash gym_swimmer 20000 10 | # bash gym_ant 20000 11 | # bash gym_pendulum 5000 12 | # bash gym_invertedPendulum 2500 13 | # bash gym_acrobot 5000 14 | # bash gym_mountain 5000 15 | # bash gym_cartpole 5000 16 | 17 | 18 | # bash gym_reacher 2000 30 19 | # bash gym_cheetah 20000 100 20 | # bash gym_walker2d 20000 50 21 | # bash gym_hopper 20000 100 22 | # bash gym_swimmer 20000 50 23 | # bash gym_ant 20000 30 24 | # bash gym_pendulum 5000 30 25 | # bash gym_invertedPendulum 2500 30 26 | # bash gym_acrobot 5000 30 27 | # bash gym_mountain 5000 30 28 | # bash gym_cartpole 5000 30 29 | 30 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do 31 | # for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do 32 | for env_type in $1; do 33 | python main/pets_main.py --exp_id rs_${env_type}\ 34 | --task $env_type \ 35 | --num_planning_traj 500 --planning_depth $2 --random_timesteps 0 \ 36 | --timesteps_per_batch 1 --num_workers 10 --max_timesteps 20000 \ 37 | --gt_dynamics 1 38 | done 39 | 40 | # python2 main/pets_main.py --timesteps_per_batch 2000 --task gym_cheetah --num_workers 5 --planning_depth 15 --num_planning_traj 500 41 | -------------------------------------------------------------------------------- /scripts/exp_1_performance_curve/pets_gt_checkdone.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | 5 | # bash gym_reacher 2000 6 | # bash gym_cheetah 20000 7 | # bash gym_walker2d 20000 8 | # bash gym_hopper 20000 9 | # bash gym_swimmer 20000 10 | # bash gym_ant 20000 11 | # bash gym_pendulum 5000 12 | # bash gym_invertedPendulum 2500 13 | # bash gym_acrobot 5000 14 | # bash gym_mountain 5000 15 | # bash gym_cartpole 5000 16 | 17 | 18 | # bash gym_reacher 2000 30 19 | # bash gym_cheetah 20000 100 20 | # bash gym_walker2d 20000 50 21 | # bash gym_hopper 20000 100 22 | # bash gym_swimmer 20000 50 23 | # bash gym_ant 20000 30 24 | # 
bash gym_pendulum 5000 30 25 | # bash gym_invertedPendulum 2500 30 26 | # bash gym_acrobot 5000 30 27 | # bash gym_mountain 5000 30 28 | # bash gym_cartpole 5000 30 29 | 30 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do 31 | # for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do 32 | for env_type in $1; do 33 | python main/pets_main.py --exp_id rs_${env_type}\ 34 | --task $env_type \ 35 | --num_planning_traj 500 --planning_depth $2 --random_timesteps 0 \ 36 | --timesteps_per_batch 1 --num_workers 10 --max_timesteps 20000 \ 37 | --gt_dynamics 1 --check_done 1 38 | done 39 | 40 | # python2 main/pets_main.py --timesteps_per_batch 2000 --task gym_cheetah --num_workers 5 --planning_depth 15 --num_planning_traj 500 41 | -------------------------------------------------------------------------------- /scripts/exp_1_performance_curve/ppo.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do 5 | 6 | 7 | for batch_size in 2000 5000; do 8 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do 9 | for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do 10 | for seed in 1234 2341 3412 4123; do 11 | python main/mf_main.py --exp_id ppo_${env_type}_batch_${batch_size}_seed_${seed} \ 12 | --timesteps_per_batch $batch_size --task $env_type \ 13 | --num_workers 5 --trust_region_method ppo --max_timesteps 1000000 14 | done 15 | done 16 | done 17 | -------------------------------------------------------------------------------- /scripts/exp_1_performance_curve/random_policy.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do 5 | 6 | 7 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do 8 | for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do 9 | python main/random_main.py --exp_id random_${env_type} \ 10 | --timesteps_per_batch 1 --task $env_type \ 11 | --num_workers 1 --max_timesteps 40000 12 | done 13 | -------------------------------------------------------------------------------- /scripts/exp_1_performance_curve/rs.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | 5 | 6 | for seed in 1234 2341 3412 4123; do 7 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do 8 | for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do 9 | python main/rs_main.py --exp_id rs_${env_type}_seed_${seed}\ 10 | --task $env_type \ 11 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 10000 \ 12 | --timesteps_per_batch 3000 --num_workers 20 --max_timesteps 200000 --seed $seed 13 | done 14 | done 15 | -------------------------------------------------------------------------------- 
/scripts/exp_1_performance_curve/rs_1.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | 5 | 6 | for seed in 1234 2341 3412 4123; do 7 | for env_type in gym_reacher gym_cheetah gym_walker2d; do 8 | python main/rs_main.py --exp_id rs_${env_type}_seed_${seed}\ 9 | --task $env_type \ 10 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 10000 \ 11 | --timesteps_per_batch 3000 --num_workers 20 --max_timesteps 300000 --seed $seed 12 | done 13 | done 14 | -------------------------------------------------------------------------------- /scripts/exp_1_performance_curve/rs_2.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | 5 | 6 | for seed in 1234 2341 3412 4123; do 7 | for env_type in gym_hopper gym_swimmer gym_ant; do 8 | python main/rs_main.py --exp_id rs_${env_type}_seed_${seed}\ 9 | --task $env_type \ 10 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 10000 \ 11 | --timesteps_per_batch 3000 --num_workers 20 --max_timesteps 300000 --seed $seed 12 | done 13 | done 14 | -------------------------------------------------------------------------------- /scripts/exp_1_performance_curve/rs_gt.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | 5 | # bash gym_reacher 2000 6 | # bash gym_cheetah 20000 7 | # bash gym_walker2d 20000 8 | # bash gym_hopper 20000 9 | # bash gym_swimmer 20000 10 | # bash gym_ant 20000 11 | # bash gym_pendulum 5000 12 | # bash gym_invertedPendulum 2500 13 | # bash gym_acrobot 5000 14 | # bash gym_mountain 5000 15 | # bash gym_cartpole 5000 16 | 17 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do 18 | # for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do 19 | for env_type in $1; do 20 | python main/rs_main.py --exp_id rs_${env_type}\ 21 | --task $env_type \ 22 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 0 \ 23 | --timesteps_per_batch 1 --num_workers 20 --max_timesteps 20000 \ 24 | --gt_dynamics 1 25 | done 26 | 27 | # python2 main/pets_main.py --timesteps_per_batch 2000 --task gym_cheetah --num_workers 5 --planning_depth 15 --num_planning_traj 500 28 | -------------------------------------------------------------------------------- /scripts/exp_1_performance_curve/rs_gt_checkdone.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | 5 | # bash gym_reacher 2000 6 | # bash gym_cheetah 20000 7 | # bash gym_walker2d 20000 8 | # bash gym_hopper 20000 9 | # bash gym_swimmer 20000 10 | # bash gym_ant 20000 11 | # bash gym_pendulum 5000 12 | # bash gym_invertedPendulum 2500 13 | # bash gym_acrobot 5000 14 | # bash gym_mountain 5000 15 | # bash gym_cartpole 5000 16 | 17 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do 18 | # for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do 19 | for env_type in $1; do 20 | python main/rs_main.py --exp_id rs_${env_type}\ 
21 | --task $env_type \ 22 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 0 \ 23 | --timesteps_per_batch 1 --num_workers 20 --max_timesteps 20000 \ 24 | --gt_dynamics 1 --check_done 1 25 | done 26 | 27 | # python2 main/pets_main.py --timesteps_per_batch 2000 --task gym_cheetah --num_workers 5 --planning_depth 15 --num_planning_traj 500 28 | -------------------------------------------------------------------------------- /scripts/exp_1_performance_curve/trpo.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | 5 | 6 | for batch_size in 2000 5000; do 7 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do 8 | for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do 9 | for seed in 1234 2341 3412 4123; do 10 | python main/mf_main.py --exp_id trpo_${env_type}_batch_${batch_size}_seed_${seed} \ 11 | --timesteps_per_batch $batch_size --task $env_type \ 12 | --num_workers 5 --trust_region_method trpo --max_timesteps 1000000 13 | done 14 | done 15 | done 16 | -------------------------------------------------------------------------------- /scripts/exp_2_dilemma/gt_computation.sh: -------------------------------------------------------------------------------- 1 | 2 | for env_type in gym_cheetah gym_ant; do 3 | python main/rs_main.py --exp_id rs_${env_type}_depth_$1 \ 4 | --task $env_type \ 5 | --num_planning_traj 1000 --planning_depth $1 --random_timesteps 0 \ 6 | --timesteps_per_batch 1 --num_workers 20 --max_timesteps 10000 \ 7 | --gt_dynamics 1 --check_done 1 8 | 9 | python main/pets_main.py --exp_id rs_${env_type}_depth_$1 \ 10 | --task $env_type \ 11 | --num_planning_traj 500 --planning_depth $1 --random_timesteps 0 \ 12 | --timesteps_per_batch 1 --num_workers 20 --max_timesteps 10000 \ 13 | --gt_dynamics 1 --check_done 1 14 | done 15 | -------------------------------------------------------------------------------- /scripts/exp_2_dynamics_schemes/rs_act_norm.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | 5 | # Tanh + None (done) 6 | # Relu + None 7 | # Leaky-Relu + None 8 | # Tanh + layer-norm 9 | # Tanh + batch-norm 10 | # Relu + layer-norm 11 | # Relu + batch-norm 12 | # Leaky-relu + layernorm 13 | # Leaky-relu + batch-norm 14 | # Tanh + None (done) 15 | # Relu + None 16 | # Leaky-Relu + None 17 | # Tanh + layer-norm 18 | # Tanh + batch-norm 19 | # Relu + layer-norm 20 | # Relu + batch-norm 21 | # Leaky-relu + layernorm 22 | # Leaky-relu + batch-norm 23 | 24 | 25 | for seed in 1234 2341 3412 4123; do 26 | for env_type in gym_reacher gym_cheetah gym_ant; do 27 | for dynamics_activation_type in 'leaky_relu' 'tanh' 'relu' 'swish' 'none'; do 28 | for dynamics_normalizer_type in 'layer_norm' 'batch_norm' 'none'; do 29 | 30 | python main/rs_main.py \ 31 | --exp_id rs_${env_type}_act_${dynamics_activation_type}_norm_${dynamics_normalizer_type}_seed_${seed} \ 32 | --task $env_type \ 33 | --dynamics_batch_size 512 \ 34 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 10000 \ 35 | --timesteps_per_batch 3000 --num_workers 20 --max_timesteps 200000 \ 36 | --seed $seed \ 37 | --dynamics_activation_type $dynamics_activation_type \ 38 | --dynamics_normalizer_type 
$dynamics_normalizer_type 39 | 40 | done 41 | done 42 | done 43 | done 44 | -------------------------------------------------------------------------------- /scripts/exp_2_dynamics_schemes/rs_lr_batch.sh: -------------------------------------------------------------------------------- 1 | # see how the ppo works for the environments in terms of the performance: 2 | # batch size 3 | # max_timesteps 1e8 4 | 5 | # Tanh + None (done) 6 | # Relu + None 7 | # Leaky-Relu + None 8 | # Tanh + layer-norm 9 | # Tanh + batch-norm 10 | # Relu + layer-norm 11 | # Relu + batch-norm 12 | # Leaky-relu + layernorm 13 | # Leaky-relu + batch-norm 14 | # Tanh + None (done) 15 | # Relu + None 16 | # Leaky-Relu + None 17 | # Tanh + layer-norm 18 | # Tanh + batch-norm 19 | # Relu + layer-norm 20 | # Relu + batch-norm 21 | # Leaky-relu + layernorm 22 | # Leaky-relu + batch-norm 23 | 24 | 25 | for seed in 1234 2341 3412 4123; do 26 | for env_type in gym_cheetah; do 27 | for dynamics_lr in 0.0003 0.001 0.003; do 28 | for dynamics_batch_size in 256 512 1024; do 29 | for dynamics_epochs in 10 30 50; do 30 | 31 | python main/rs_main.py \ 32 | --exp_id rs_${env_type}_lr_${dynamics_lr}_batchsize_${dynamics_batch_size}_epochs_${dynamics_epochs}_seed_${seed} \ 33 | --task $env_type \ 34 | --dynamics_batch_size 512 \ 35 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 10000 \ 36 | --timesteps_per_batch 3000 --num_workers 20 --max_timesteps 200000 \ 37 | --seed $seed \ 38 | --dynamics_lr $dynamics_lr \ 39 | --dynamics_batch_size $dynamics_batch_size \ 40 | --dynamics_epochs $dynamics_epochs 41 | done 42 | 43 | done 44 | done 45 | done 46 | done 47 | -------------------------------------------------------------------------------- /scripts/exp_7_traj/ilqr_numtraj.sh: -------------------------------------------------------------------------------- 1 | # num_ilqr_traj: [2, 5, 10] 2 | for num_ilqr_traj in 1 3 5 7 9 2 4 6 8 10; do 3 | python main/ilqr_main.py --max_timesteps 10000 --task $1-1000 --timesteps_per_batch 1000 \ 4 | --ilqr_iteration 10 --ilqr_depth 10 --max_ilqr_linesearch_backtrack 10 --num_ilqr_traj $num_ilqr_traj --exp_id $1-traj-$num_ilqr_traj --num_workers 2 --gt_dynamics 1 5 | done 6 | -------------------------------------------------------------------------------- /scripts/exp_9_planning_depth/ilqr_depth.sh: -------------------------------------------------------------------------------- 1 | 2 | # $1 is the environment 3 | # ilqr_depth: [10, 20, 50, 100] 4 | for ilqr_depth in 100 80 60 40 20 10; do 5 | python main/ilqr_main.py --max_timesteps 10000 --task $1-1000 --timesteps_per_batch 1000 \ 6 | --ilqr_iteration 5 --ilqr_depth $ilqr_depth --max_ilqr_linesearch_backtrack 10 \ 7 | --exp_id $1-depth-$ilqr_depth --num_workers 2 --gt_dynamics 1 8 | done 9 | -------------------------------------------------------------------------------- /scripts/ilqr/ilqr.sh: -------------------------------------------------------------------------------- 1 | 2 | # $1 is the environment 3 | # ilqr_depth: [10, 20, 50, 100] 4 | for ilqr_depth in 10 20 50 100; do 5 | python main/ilqr_main.py --max_timesteps 20000 --task $1-1000 --timesteps_per_batch 1000 --ilqr_iteration 10 --ilqr_depth $ilqr_depth --max_ilqr_linesearch_backtrack 10 --exp_id cheetah_depth_$ilqr_depth --num_workers 2 6 | done 7 | 8 | # ilqr_iteration: [5, 10, 15, 20] 9 | for ilqr_iteration in 5 10 15 20; do 10 | python main/ilqr_main.py --max_timesteps 20000 --task $1-1000 --timesteps_per_batch 20000 --ilqr_iteration $ilqr_iteration --ilqr_depth 10 
--max_ilqr_linesearch_backtrack 10 --exp_id cheetah_ilqr_iteration_$ilqr_iteration --num_workers 2 11 | done 12 | 13 | # num_ilqr_traj: [2, 5, 10] 14 | for num_ilqr_traj in 2 5 10; do 15 | python main/ilqr_main.py --max_timesteps 20000 --task $1-1000 --timesteps_per_batch 1000 --ilqr_iteration 10 --ilqr_depth 10 --max_ilqr_linesearch_backtrack 10 --num_ilqr_traj $num_ilqr_traj --exp_id cheetah_traj_$num_ilqr_traj --num_workers 2 16 | 17 | done 18 | -------------------------------------------------------------------------------- /scripts/ilqr/ilqr_depth.sh: -------------------------------------------------------------------------------- 1 | 2 | # $1 is the environment 3 | # ilqr_depth: [10, 20, 50, 100] 4 | for ilqr_depth in 10 20 50 100; do 5 | python main/ilqr_main.py --max_timesteps 20000 --task $1-1000 --timesteps_per_batch 1000 --ilqr_iteration 10 --ilqr_depth $ilqr_depth --max_ilqr_linesearch_backtrack 10 --exp_id $1-depth-$ilqr_depth --num_workers 2 6 | done 7 | -------------------------------------------------------------------------------- /scripts/ilqr/ilqr_iter.sh: -------------------------------------------------------------------------------- 1 | # ilqr_iteration: [5, 10, 15, 20] 2 | for ilqr_iteration in 5 10 15 20; do 3 | python main/ilqr_main.py --max_timesteps 20000 --task $1-1000 --timesteps_per_batch 20000 --ilqr_iteration $ilqr_iteration --ilqr_depth 10 --max_ilqr_linesearch_backtrack 10 --exp_id $1-iteration_$ilqr_iteration --num_workers 2 4 | done 5 | -------------------------------------------------------------------------------- /scripts/ilqr/ilqr_numtraj.sh: -------------------------------------------------------------------------------- 1 | # num_ilqr_traj: [2, 5, 10] 2 | for num_ilqr_traj in 2 5 10; do 3 | python main/ilqr_main.py --max_timesteps 20000 --task $1-1000 --timesteps_per_batch 1000 --ilqr_iteration 10 --ilqr_depth 10 --max_ilqr_linesearch_backtrack 10 --num_ilqr_traj $num_ilqr_traj --exp_id $1-traj-$num_ilqr_traj --num_workers 2 4 | 5 | done 6 | -------------------------------------------------------------------------------- /scripts/performance_curve/mbmf.sh: -------------------------------------------------------------------------------- 1 | # see how mbmf works for the environments in terms of performance: 2 | # batch size 3 | # max_timesteps 1e7 4 | 5 | for seed in 1234 2341 3412 4123; do 6 | for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do 7 | python main/mbmf_main.py --exp_id mbmf_${env_type}_seed_${seed} \ 8 | --task $env_type \ 9 | --num_planning_traj 5000 --planning_depth 20 --random_timesteps 10000 --timesteps_per_batch 3000 --dynamics_epochs 30 --num_workers 24 --mb_timesteps 70000 --dagger_epoch 300 --dagger_timesteps_per_iter 1750 --max_timesteps 10000000 --seed $seed 10 | done 11 | done 12 | -------------------------------------------------------------------------------- /scripts/test/gym_humanoid.sh: -------------------------------------------------------------------------------- 1 | # see how the humanoid works in terms of performance: 2 | # trpo / ppo 3 | # batch size 4 | # max_timesteps 5e7 5 | 6 | 7 | for batch_size in 5000 20000 50000 2000; do 8 | for tr_method in ppo trpo; do 9 | for env_type in gym_humanoid; do 10 | python main/mf_main.py --exp_id ${env_type}_${tr_method}_${batch_size} \ 11 | --timesteps_per_batch $batch_size --task $env_type \ 12 | --num_workers 5 --trust_region_method $tr_method --max_timesteps 50000000 13 | done 14 | done 15 | done 16 | 
-------------------------------------------------------------------------------- /scripts/test/humanoid.sh: -------------------------------------------------------------------------------- 1 | # see how the humanoid works in terms of performance: 2 | # trpo / ppo 3 | # batch size 4 | # max_timesteps 5e7 5 | 6 | 7 | for batch_size in 5000 20000 50000; do 8 | for tr_method in ppo trpo; do 9 | for env_type in dm-humanoid dm-humanoid-noise; do 10 | python main/mf_main.py --exp_id ${env_type}_${tr_method}_${batch_size} \ 11 | --timesteps_per_batch $batch_size --task $env_type \ 12 | --num_workers 5 --trust_region_method $tr_method --max_timesteps 50000000 \ 13 | --num_expert_episode_to_save 5 14 | done 15 | done 16 | done 17 | -------------------------------------------------------------------------------- /scripts/test/humanoid_new.sh: -------------------------------------------------------------------------------- 1 | # see how the humanoid works in terms of performance: 2 | # trpo / ppo 3 | # batch size 4 | # max_timesteps 5e7 5 | 6 | 7 | export PYTHONPATH="$PYTHONPATH:$PWD" 8 | 9 | for env_type in dm-humanoid dm-humanoid-noise; do 10 | for batch_size in 5000 20000 50000; do 11 | for tr_method in ppo trpo; do 12 | python main/mf_main.py --exp_id ${env_type}_${tr_method}_${batch_size} \ 13 | --timesteps_per_batch $batch_size --task $env_type \ 14 | --num_workers 5 --trust_region_method $tr_method --max_timesteps 50000000 15 | done 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /scripts/test/mbmf.sh: -------------------------------------------------------------------------------- 1 | seeds=($(shuf -i 0-1000 -n 3)) 2 | for seed in "${seeds[@]}"; do 3 | echo "$seed" 4 | # Swimmer 5 | python main/mbmf_main.py --task gym_swimmer --num_planning_traj 5000 --planning_depth 20 --random_timesteps 10000 --timesteps_per_batch 3000 --dynamics_epochs 30 --output_dir mbmf --num_workers 24 --mb_timesteps 70000 --dagger_epoch 300 --dagger_timesteps_per_iter 1750 --max_timesteps 10000000 --seed $seed 6 | # Half cheetah 7 | python main/mbmf_main.py --task gym_cheetah --num_planning_traj 1000 --planning_depth 20 --random_timesteps 10000 --timesteps_per_batch 9000 --dynamics_epochs 60 --output_dir mbmf --num_workers 24 --mb_timesteps 80000 --dagger_epoch 300 --dagger_timesteps_per_iter 2000 --max_timesteps 100000000 --seed $seed 8 | # Hopper 9 | python main/mbmf_main.py --task gym_hopper --num_planning_traj 1000 --planning_depth 40 --random_timesteps 4000 --timesteps_per_batch 10000 --dynamics_epochs 40 --output_dir mbmf --num_workers 24 --mb_timesteps 50000 --dagger_epoch 200 --dagger_timesteps_per_iter 5000 --max_timesteps 100000000 --seed $seed --dagger_saved_rollout 60 --dagger_iter 5 10 | #python main/mf_main.py --task gym_hopper --timesteps_per_batch 50000 --policy_batch_size 50000 --output_dir mbmf --num_workers 24 --seed $seed --max_timesteps 100000000 11 | done 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Package setup script.""" 2 | import setuptools 3 | 4 | setuptools.setup( 5 | name="mbbl", 6 | version="0.1", 7 | description="Model-based RL Baselines", 8 | author="Tingwu Wang", 9 | author_email="tingwuwang@cs.toronto.edu", 10 | packages=setuptools.find_packages(), 11 | package_data={'': ['./env/dm_env/assets/*.xml', 12 | './env/dm_env/assets/common/*.xml', 13 | 
'./env/gym_env/fix_swimmer/assets/*.xml', 14 | './env/gym_env/pets_env/assets/*.xml']}, 15 | include_package_data=True, 16 | install_requires=[ 17 | ], 18 | ) 19 | 20 | """ 21 | "pyquaternion", 22 | "beautifulsoup4", 23 | "Box2D>=2.3.2", 24 | "num2words", 25 | "six", 26 | "tensorboard_logger", 27 | "tensorflow==1.12.0", 28 | "termcolor", 29 | "gym[mujoco]==0.7.4", 30 | "mujoco-py==0.5.7", 31 | """ 32 | -------------------------------------------------------------------------------- /tests/tf-test-guide.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Code Testing Guide 2 | This guide provides guidance on testing machine learning code written in TensorFlow. 3 | 4 | ## TensorFlow Testing Framework 5 | Import TensorFlow's built-in testing framework as follows: 6 | ```python 7 | from tensorflow.python.platform import test 8 | ``` 9 | Or simply `import tensorflow as tf` and use `tf.test`. 10 | For the rest of this document, we will refer to the framework as `tf.test`. 11 | 12 | ### Introduction to `tf.test` 13 | 1. Recommended structure for the test code. 14 | 15 | As with `unittest`, write your tests as classes that inherit from `tf.test.TestCase`. 16 | Then, in the `__main__` block, run `tf.test.main()`. 17 | 18 | Example: 19 | ```python 20 | import tensorflow as tf 21 | 22 | class SeriousTest(tf.test.TestCase): 23 | def test_method_1(self): 24 | actual_value = expected_value = 1 25 | self.assertEqual(actual_value, expected_value) 26 | def test_method_2(self): 27 | actual_value = 1 28 | expected_value = 2 29 | self.assertNotEqual(actual_value, expected_value) 30 | 31 | if __name__ == "__main__": 32 | tf.test.main() 33 | ``` 34 | 35 | 36 | 2. The TensorFlow unit test class [`tf.test.TestCase`](https://www.tensorflow.org/api_docs/python/tf/test/TestCase#test_session). 37 | 38 | `tf.test.TestCase` is built on top of the standard Python `unittest.TestCase` and adds many TensorFlow-specific assertion methods. 39 | 40 | Check out the test code in [TF Slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim/python/slim) for good examples. 41 | 42 | 43 | 44 | ### Useful Tips 45 | #### 1. Use `test_session()` 46 | When using `tf.test.TestCase`, you should almost always use 47 | the [`self.test_session()`](https://www.tensorflow.org/api_docs/python/tf/test/TestCase#test_session) method. 48 | It returns a TensorFlow `Session` for use in executing tests. 49 | 50 | See the [AlexNet tests](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/alexnet_test.py) 51 | in the TF Slim code base for an example, and the minimal sketch in Tip 3 at the end of this guide. 52 | 53 | #### 2. Test that a graph is built correctly 54 | 55 | The method [`tf.test.assert_equal_graph_def`](https://www.tensorflow.org/api_docs/python/tf/test/assert_equal_graph_def) 56 | asserts that two `GraphDef` objects are the same, ignoring versions and the ordering of nodes, attrs, and control inputs. 57 | 58 | Note: you can turn a `Graph` object into a `GraphDef` object using `graph.as_graph_def()`. 
59 | 60 | Example: 61 | ```python 62 | import tensorflow as tf 63 | 64 | graph_actual = tf.Graph() 65 | with graph_actual.as_default(): 66 | x = tf.placeholder(tf.float32, shape=[None, 3]) 67 | variable_name = tf.Variable(initial_value=tf.random_normal(shape=[3, 5])) 68 | 69 | graph_expected = tf.Graph() 70 | with graph_expected.as_default(): 71 | x = tf.placeholder(tf.float32, shape=[None, 3]) 72 | variable_name = tf.Variable(initial_value=tf.random_normal(shape=[3, 5])) 73 | 74 | tf.test.assert_equal_graph_def(graph_actual.as_graph_def(), graph_expected.as_graph_def()) 75 | # The two graphs are built identically, so the assertion passes. 76 | ```
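77 | 78 | #### 3. Putting the tips together: a minimal runnable test (sketch) 79 | 80 | The sketch below complements Tip 1 by evaluating a small op numerically inside `self.test_session()`. The class, method, and tensor names are illustrative, and the snippet assumes the TF 1.x API (e.g. `tensorflow==1.12.0`, as listed in `setup.py`). 81 | ```python 82 | import tensorflow as tf 83 | 84 | 85 | # Minimal sketch with illustrative names; assumes the TF 1.x API. 86 | class MatMulTest(tf.test.TestCase): 87 | def test_matmul_value(self): 88 | # Build a tiny graph: (1, 3) x (3, 2) -> (1, 2). 89 | x = tf.constant([[1.0, 2.0, 3.0]]) 90 | w = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) 91 | y = tf.matmul(x, w) 92 | # test_session() returns a Session and installs it as the default one. 93 | with self.test_session(): 94 | self.assertAllClose(y.eval(), [[4.0, 5.0]]) 95 | 96 | 97 | if __name__ == "__main__": 98 | tf.test.main() 99 | ``` 100 | Running the file directly invokes `tf.test.main()`, which discovers and runs every `test_*` method in the module. --------------------------------------------------------------------------------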