├── .gitignore
├── README.md
├── __init__.py
├── img
├── 1m_mbbl_result_table.jpeg
├── mbbl_front.png
├── mbbl_result_table.jpeg
└── mbbl_stats.png
├── main
├── __init__.py
├── deepmimic_main.py
├── gail_mf_main.py
├── gps_main.py
├── ilqr_main.py
├── inverse_dynamics_IM.py
├── mbmf_main.py
├── metrpo_gnn_main.py
├── metrpo_main.py
├── mf_main.py
├── pets_main.py
├── random_main.py
├── rs_main.py
└── test.py
├── mbbl
├── __init__.py
├── config
│ ├── __init__.py
│ ├── base_config.py
│ ├── cem_config.py
│ ├── ggnn_config.py
│ ├── gps_config.py
│ ├── il_config.py
│ ├── ilqr_config.py
│ ├── init_path.py
│ ├── mbmf_config.py
│ ├── metrpo_config.py
│ ├── mf_config.py
│ └── rs_config.py
├── env
│ ├── __init__.py
│ ├── base_env_wrapper.py
│ ├── bullet_env
│ │ ├── __init__.py
│ │ ├── bullet_humanoid.py
│ │ ├── depth_roboschool.py
│ │ ├── humanoid.py
│ │ └── motion_capture_data.py
│ ├── dm_env
│ │ ├── __init__.py
│ │ ├── assets
│ │ │ ├── cheetah_pos.xml
│ │ │ ├── common
│ │ │ │ ├── __init__.py
│ │ │ │ ├── materials.xml
│ │ │ │ ├── skybox.xml
│ │ │ │ └── visual.xml
│ │ │ ├── humanoid_CMU.xml
│ │ │ └── reference
│ │ │ │ ├── acrobot.xml
│ │ │ │ ├── ball_in_cup.xml
│ │ │ │ ├── cartpole.xml
│ │ │ │ ├── cheetah.xml
│ │ │ │ ├── finger.xml
│ │ │ │ ├── fish.xml
│ │ │ │ ├── hopper.xml
│ │ │ │ ├── humanoid.xml
│ │ │ │ ├── humanoid_CMU.xml
│ │ │ │ ├── lqr.xml
│ │ │ │ ├── manipulator.xml
│ │ │ │ ├── pendulum.xml
│ │ │ │ ├── point_mass.xml
│ │ │ │ ├── reacher.xml
│ │ │ │ ├── stacker.xml
│ │ │ │ ├── swimmer.xml
│ │ │ │ └── walker.xml
│ │ ├── dm_env.py
│ │ ├── humanoid_env.py
│ │ └── pos_dm_env.py
│ ├── env_register.py
│ ├── env_util.py
│ ├── fake_env.py
│ ├── gym_env
│ │ ├── __init__.py
│ │ ├── acrobot.py
│ │ ├── box2d
│ │ │ ├── __init__.py
│ │ │ ├── core.py
│ │ │ ├── walker.py
│ │ │ └── wrappers.py
│ │ ├── box2d_lunar_lander.py
│ │ ├── box2d_racer.py
│ │ ├── box2d_walker.py
│ │ ├── cartpole.py
│ │ ├── delayed_walker.py
│ │ ├── fix_swimmer
│ │ │ ├── __init__.py
│ │ │ ├── assets
│ │ │ │ └── fixed_swimmer.xml
│ │ │ └── fixed_swimmer.py
│ │ ├── fixed_swimmer.py
│ │ ├── fixed_walker.py
│ │ ├── humanoid.py
│ │ ├── invertedPendulum.py
│ │ ├── mountain_car.py
│ │ ├── noise_gym_cartpole.py
│ │ ├── noise_gym_cheetah.py
│ │ ├── noise_gym_pendulum.py
│ │ ├── pendulum.py
│ │ ├── pets.py
│ │ ├── pets_env
│ │ │ ├── __init__.py
│ │ │ ├── assets
│ │ │ │ ├── cartpole.xml
│ │ │ │ ├── half_cheetah.xml
│ │ │ │ ├── pusher.xml
│ │ │ │ └── reacher3d.xml
│ │ │ ├── cartpole.py
│ │ │ ├── half_cheetah.py
│ │ │ ├── half_cheetah_config.py
│ │ │ ├── pusher.py
│ │ │ ├── pusher_config.py
│ │ │ ├── reacher.py
│ │ │ └── reacher_config.py
│ │ ├── point.py
│ │ ├── point_env.py
│ │ ├── reacher.py
│ │ └── walker.py
│ ├── render.py
│ └── render_wrapper.py
├── network
│ ├── __init__.py
│ ├── dynamics
│ │ ├── __init__.py
│ │ ├── base_dynamics.py
│ │ ├── bayesian_forward_dynamics.py
│ │ ├── deterministic_forward_dynamics.py
│ │ ├── deterministic_forward_ggnn_dynamics.py
│ │ ├── groundtruth_forward_dynamics.py
│ │ ├── linear_stochastic_forward_dynamics_gmm_prior.py
│ │ └── stochastic_forward_dynamics.py
│ ├── policy
│ │ ├── __init__.py
│ │ ├── base_policy.py
│ │ ├── cem_policy.py
│ │ ├── gps_policy_gmm_refit.py
│ │ ├── mbmf_policy.py
│ │ ├── ppo_cnn_policy.py
│ │ ├── ppo_policy.py
│ │ ├── random_policy.py
│ │ └── trpo_policy.py
│ └── reward
│ │ ├── GAN_reward.py
│ │ ├── __init__.py
│ │ ├── base_reward.py
│ │ ├── deepmimic_reward.py
│ │ └── groundtruth_reward.py
├── sampler
│ ├── __init__.py
│ ├── base_sampler.py
│ ├── mbmf_sampler.py
│ ├── readme.md
│ ├── singletask_ilqr_sampler.py
│ ├── singletask_metrpo_sampler.py
│ ├── singletask_pets_sampler.py
│ ├── singletask_random_sampler.py
│ └── singletask_sampler.py
├── trainer
│ ├── __init__.py
│ ├── base_trainer.py
│ ├── gail_trainer.py
│ ├── gps_trainer.py
│ ├── mbmf_trainer.py
│ ├── metrpo_trainer.py
│ ├── readme.md
│ └── shooting_trainer.py
├── util
│ ├── __init__.py
│ ├── base_main.py
│ ├── common
│ │ ├── __init__.py
│ │ ├── fpdb.py
│ │ ├── ggnn_utils.py
│ │ ├── logger.py
│ │ ├── misc_utils.py
│ │ ├── model_saver.py
│ │ ├── parallel_util.py
│ │ ├── replay_buffer.py
│ │ ├── summary_handler.py
│ │ ├── tf_ggnn_networks.py
│ │ ├── tf_networks.py
│ │ ├── tf_norm.py
│ │ ├── tf_utils.py
│ │ ├── vis_debug.py
│ │ └── whitening_util.py
│ ├── gps
│ │ ├── __init__.py
│ │ └── gps_utils.py
│ ├── il
│ │ ├── __init__.py
│ │ ├── camera_model.py
│ │ ├── camera_pose_ID_solver.py
│ │ ├── expert_data_util.py
│ │ ├── il_util.py
│ │ ├── pose_visualization.py
│ │ └── test_inverse_dynamics.py
│ ├── ilqr
│ │ ├── __init__.py
│ │ ├── ilqr_data_wrapper.py
│ │ ├── ilqr_utils.py
│ │ └── stochastic_ilqr_data_wrapper.py
│ └── kfac
│ │ ├── estimator.py
│ │ ├── fisher_blocks.py
│ │ ├── fisher_factors.py
│ │ ├── layer_collection.py
│ │ ├── loss_functions.py
│ │ ├── optimizer.py
│ │ └── utils.py
└── worker
│ ├── __init__.py
│ ├── base_worker.py
│ ├── cem_worker.py
│ ├── mbmf_worker.py
│ ├── metrpo_worker.py
│ ├── mf_worker.py
│ ├── model_worker.py
│ ├── readme.md
│ └── rs_worker.py
├── scripts
├── exp_1_performance_curve
│ ├── ilqr.sh
│ ├── ilqr_depth.sh
│ ├── mbmf.sh
│ ├── mbmf_1m.sh
│ ├── pets_gt.sh
│ ├── pets_gt_checkdone.sh
│ ├── ppo.sh
│ ├── random_policy.sh
│ ├── rs.sh
│ ├── rs_1.sh
│ ├── rs_2.sh
│ ├── rs_gt.sh
│ ├── rs_gt_checkdone.sh
│ └── trpo.sh
├── exp_2_dilemma
│ └── gt_computation.sh
├── exp_2_dynamics_schemes
│ ├── rs_act_norm.sh
│ └── rs_lr_batch.sh
├── exp_7_traj
│ └── ilqr_numtraj.sh
├── exp_9_planning_depth
│ └── ilqr_depth.sh
├── ilqr
│ ├── ilqr.sh
│ ├── ilqr_depth.sh
│ ├── ilqr_iter.sh
│ └── ilqr_numtraj.sh
├── performance_curve
│ └── mbmf.sh
└── test
│ ├── gym_humanoid.sh
│ ├── humanoid.sh
│ ├── humanoid_new.sh
│ └── mbmf.sh
├── setup.py
└── tests
├── test_utils
├── test_gt_dynamics.py
└── test_replay_buffer.py
└── tf-test-guide.md
/.gitignore:
--------------------------------------------------------------------------------
1 | MUJUCO_LOG.txt
2 | imitation
3 | main/MUJUCO_LOG.txt
4 | *.pdf
5 | .nfs*
6 | output/
7 | *.p
8 | *.csv
9 | *.npy
10 | paper/
11 | *.mp4
12 | *.log
13 | *.swp
14 | *.swo
15 | *.swn
16 | summary/
17 | log/
18 | *.jpg
19 | tool/batch_summary_*
20 | video/
21 | *.zip
22 | *cpython*
23 | data/
24 |
25 | # IDE files
26 | .idea/
27 |
28 | # pytest cache files
29 | .pytest_cache
30 | # Byte-compiled / optimized / DLL files
31 | __pycache__/
32 | *.py[cod]
33 | *$py.class
34 |
35 | # C extensions
36 | *.so
37 |
38 | # Distribution / packaging
39 | .Python
40 | build/
41 | develop-eggs/
42 | dist/
43 | downloads/
44 | eggs/
45 | .eggs/
46 | lib/
47 | lib64/
48 | parts/
49 | sdist/
50 | var/
51 | wheels/
52 | *.egg-info/
53 | .installed.cfg
54 | *.egg
55 | MANIFEST
56 |
57 | # PyInstaller
58 | # Usually these files are written by a python script from a template
59 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
60 | *.manifest
61 | *.spec
62 |
63 | # Installer logs
64 | pip-log.txt
65 | pip-delete-this-directory.txt
66 |
67 | # Unit test / coverage reports
68 | htmlcov/
69 | .tox/
70 | .nox/
71 | .coverage
72 | .coverage.*
73 | .cache
74 | nosetests.xml
75 | coverage.xml
76 | *.cover
77 | .hypothesis/
78 | .pytest_cache/
79 |
80 | # Translations
81 | *.mo
82 | *.pot
83 |
84 | # Django stuff:
85 | *.log
86 | local_settings.py
87 | db.sqlite3
88 |
89 | # Flask stuff:
90 | instance/
91 | .webassets-cache
92 |
93 | # Scrapy stuff:
94 | .scrapy
95 |
96 | # Sphinx documentation
97 | docs/_build/
98 |
99 | # PyBuilder
100 | target/
101 |
102 | # Jupyter Notebook
103 | .ipynb_checkpoints
104 |
105 | # IPython
106 | profile_default/
107 | ipython_config.py
108 |
109 | # pyenv
110 | .python-version
111 |
112 | # celery beat schedule file
113 | celerybeat-schedule
114 |
115 | # SageMath parsed files
116 | *.sage.py
117 |
118 | # Environments
119 | .env
120 | .venv
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # mkdocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 | .dmypy.json
135 | dmypy.json
136 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/__init__.py
--------------------------------------------------------------------------------
/img/1m_mbbl_result_table.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/img/1m_mbbl_result_table.jpeg
--------------------------------------------------------------------------------
/img/mbbl_front.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/img/mbbl_front.png
--------------------------------------------------------------------------------
/img/mbbl_result_table.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/img/mbbl_result_table.jpeg
--------------------------------------------------------------------------------
/img/mbbl_stats.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/img/mbbl_stats.png
--------------------------------------------------------------------------------
/main/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/main/__init__.py
--------------------------------------------------------------------------------
/main/deepmimic_main.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief: main function
3 | # -----------------------------------------------------------------------------
4 |
5 | from mbbl.config import base_config
6 | from mbbl.config import init_path
7 | from mbbl.config import mf_config
8 | from mbbl.config import il_config
9 | from mbbl.util.common import logger
10 | from mf_main import train
11 |
12 |
13 | def main():
14 | parser = base_config.get_base_config()
15 | parser = mf_config.get_mf_config(parser)
16 | parser = il_config.get_il_config(parser)
17 | args = base_config.make_parser(parser)
18 |
19 | if args.write_log:
20 | logger.set_file_handler(path=args.output_dir,
21 | prefix='deepmimic-mf-' + args.task,
22 | time_str=args.exp_id)
23 |
24 | # no random policy for model-free rl
25 | assert args.random_timesteps == 0
26 |
27 | print('Training starts at {}'.format(init_path.get_abs_base_dir()))
28 | from mbbl.trainer import gail_trainer
29 | from mbbl.sampler import singletask_sampler
30 | from mbbl.worker import mf_worker
31 | import mbbl.network.policy.trpo_policy
32 | import mbbl.network.policy.ppo_policy
33 |
34 | policy_network = {
35 | 'ppo': mbbl.network.policy.ppo_policy.policy_network,
36 | 'trpo': mbbl.network.policy.trpo_policy.policy_network
37 | }[args.trust_region_method]
38 |
39 |     # here the dynamics and reward networks are only placeholders; they
40 |     # cannot be called to predict the next state or the reward
41 | from mbbl.network.dynamics.base_dynamics import base_dynamics_network
42 | from mbbl.network.reward.deepmimic_reward import reward_network
43 |
44 | train(gail_trainer, singletask_sampler, mf_worker,
45 | base_dynamics_network, policy_network, reward_network, args)
46 |
47 |
48 | if __name__ == '__main__':
49 | main()
50 |
--------------------------------------------------------------------------------
/main/gail_mf_main.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief: main function
3 | # -----------------------------------------------------------------------------
4 |
5 | from mbbl.config import base_config
6 | from mbbl.config import init_path
7 | from mbbl.config import mf_config
8 | from mbbl.config import il_config
9 | from mbbl.util.common import logger
10 | from mf_main import train
11 |
12 |
13 | def main():
14 | parser = base_config.get_base_config()
15 | parser = mf_config.get_mf_config(parser)
16 | parser = il_config.get_il_config(parser)
17 | args = base_config.make_parser(parser)
18 |
19 | if args.write_log:
20 | logger.set_file_handler(path=args.output_dir,
21 | prefix='gail-mfrl-mf-' + args.task,
22 | time_str=args.exp_id)
23 |
24 | # no random policy for model-free rl
25 | assert args.random_timesteps == 0
26 |
27 | print('Training starts at {}'.format(init_path.get_abs_base_dir()))
28 | from mbbl.trainer import gail_trainer
29 | from mbbl.sampler import singletask_sampler
30 | from mbbl.worker import mf_worker
31 | import mbbl.network.policy.trpo_policy
32 | import mbbl.network.policy.ppo_policy
33 |
34 | policy_network = {
35 | 'ppo': mbbl.network.policy.ppo_policy.policy_network,
36 | 'trpo': mbbl.network.policy.trpo_policy.policy_network
37 | }[args.trust_region_method]
38 |
39 |     # here the dynamics and reward networks are only placeholders; they
40 |     # cannot be called to predict the next state or the reward
41 | from mbbl.network.dynamics.base_dynamics import base_dynamics_network
42 | from mbbl.network.reward.GAN_reward import reward_network
43 |
44 | train(gail_trainer, singletask_sampler, mf_worker,
45 | base_dynamics_network, policy_network, reward_network, args)
46 |
47 |
48 | if __name__ == '__main__':
49 | main()
50 |
--------------------------------------------------------------------------------
/main/gps_main.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief: main function
3 | # -----------------------------------------------------------------------------
4 |
5 | from mf_main import train  # GPS uses a training loop similar to the MF one
6 | from mbbl.config import base_config
7 | from mbbl.config import gps_config
8 | from mbbl.config import ilqr_config
9 | from mbbl.config import init_path
10 | from mbbl.util.common import logger
11 |
12 |
13 | def main():
14 | parser = base_config.get_base_config()
15 | parser = ilqr_config.get_ilqr_config(parser)
16 | parser = gps_config.get_gps_config(parser)
17 | args = base_config.make_parser(parser)
18 |
19 | if args.write_log:
20 | logger.set_file_handler(path=args.output_dir,
21 | prefix='mbrl-gps-' + args.task,
22 | time_str=args.exp_id)
23 |
24 | print('Training starts at {}'.format(init_path.get_abs_base_dir()))
25 | from mbbl.trainer import gps_trainer
26 | from mbbl.sampler import singletask_sampler
27 | from mbbl.worker import mf_worker
28 | from mbbl.network.policy.gps_policy_gmm_refit import policy_network
29 |
30 | assert not args.gt_dynamics and args.gt_reward
31 | from mbbl.network.dynamics.linear_stochastic_forward_dynamics_gmm_prior \
32 | import dynamics_network
33 | from mbbl.network.reward.groundtruth_reward import reward_network
34 |
35 | train(gps_trainer, singletask_sampler, mf_worker,
36 | dynamics_network, policy_network, reward_network, args)
37 |
38 |
39 | if __name__ == '__main__':
40 | main()
41 |
--------------------------------------------------------------------------------
/main/ilqr_main.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief: main function
3 | # -----------------------------------------------------------------------------
4 |
5 | from mbbl.config import base_config
6 | from mbbl.config import ilqr_config
7 | from mbbl.config import init_path
8 | from mbbl.util.base_main import train
9 | from mbbl.util.common import logger
10 |
11 |
12 | def main():
13 | parser = base_config.get_base_config()
14 | parser = ilqr_config.get_ilqr_config(parser)
15 | args = base_config.make_parser(parser)
16 |
17 | if args.write_log:
18 | logger.set_file_handler(path=args.output_dir,
19 | prefix='mbrl-ilqr' + args.task,
20 | time_str=args.exp_id)
21 |
22 | print('Training starts at {}'.format(init_path.get_abs_base_dir()))
23 | from mbbl.trainer import shooting_trainer
24 | from mbbl.sampler import singletask_ilqr_sampler
25 | from mbbl.worker import model_worker
26 | from mbbl.network.policy.random_policy import policy_network
27 |
28 | if args.gt_dynamics:
29 | from mbbl.network.dynamics.groundtruth_forward_dynamics import \
30 | dynamics_network
31 | else:
32 | from mbbl.network.dynamics.deterministic_forward_dynamics import \
33 | dynamics_network
34 |
35 | if args.gt_reward:
36 | from mbbl.network.reward.groundtruth_reward import reward_network
37 | else:
38 | from mbbl.network.reward.deterministic_reward import reward_network
39 |
40 |     if (not args.gt_reward) or (not args.gt_dynamics):
41 |         raise NotImplementedError("Learned dynamics/reward for iLQR is not implemented yet")
42 |
43 | train(shooting_trainer, singletask_ilqr_sampler, model_worker,
44 | dynamics_network, policy_network, reward_network, args)
45 |
46 |
47 | if __name__ == '__main__':
48 | main()
49 |
--------------------------------------------------------------------------------
/main/inverse_dynamics_IM.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief: main function
3 | # -----------------------------------------------------------------------------
4 |
5 | from mbbl.config import base_config
6 | from mbbl.config import init_path
7 | from mbbl.config import rs_config
8 | from mbbl.config import il_config
9 | from mbbl.util.common import logger
10 | from mbbl.util.il import camera_pose_ID_solver
11 |
12 |
13 | def main():
14 | parser = base_config.get_base_config()
15 | parser = rs_config.get_rs_config(parser)
16 | parser = il_config.get_il_config(parser)
17 | args = base_config.make_parser(parser)
18 | args = il_config.post_process_config(args)
19 |
20 | if args.write_log:
21 | args.log_path = logger.set_file_handler(
22 | path=args.output_dir, prefix='inverse_dynamics' + args.task,
23 | time_str=args.exp_id
24 | )
25 |
26 | print('Training starts at {}'.format(init_path.get_abs_base_dir()))
27 |
28 | train(args)
29 |
30 |
31 | def train(args):
32 | training_args = {
33 | 'physics_loss_lambda': args.physics_loss_lambda,
34 | 'lbfgs_opt_iteration': args.lbfgs_opt_iteration
35 | }
36 | solver = camera_pose_ID_solver.solver(
37 | args.expert_data_name, args.sol_qpos_freq, args.opt_var_list,
38 | args.camera_info, args.camera_id, training_args,
39 | args.imitation_length,
40 | args.gt_camera_info, args.log_path
41 | )
42 |
43 | solver.solve()
44 |
45 |
46 | if __name__ == '__main__':
47 | main()
48 |
--------------------------------------------------------------------------------
/main/mf_main.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief: main function
3 | # -----------------------------------------------------------------------------
4 |
5 | import os
6 | os.environ['MUJOCO_GL'] = "osmesa"
7 | import time
8 | from collections import OrderedDict
9 |
10 | from mbbl.config import base_config
11 | from mbbl.config import init_path
12 | from mbbl.config import mf_config
13 | from mbbl.util.base_main import make_sampler, make_trainer, log_results
14 | from mbbl.util.common import logger
15 | from mbbl.util.common import parallel_util
16 |
17 |
18 | def train(trainer, sampler, worker, dynamics, policy, reward, args=None):
19 | logger.info('Training starts at {}'.format(init_path.get_abs_base_dir()))
20 | network_type = {'policy': policy, 'dynamics': dynamics, 'reward': reward}
21 |
22 | # make the trainer and sampler
23 | sampler_agent = make_sampler(sampler, worker, network_type, args)
24 | trainer_tasks, trainer_results, trainer_agent, init_weights = \
25 | make_trainer(trainer, network_type, args)
26 | sampler_agent.set_weights(init_weights)
27 |
28 | timer_dict = OrderedDict()
29 | timer_dict['Program Start'] = time.time()
30 | current_iteration = 0
31 |
32 | while True:
33 | timer_dict['** Program Total Time **'] = time.time()
34 |
35 | # step 1: collect rollout data
36 | rollout_data = \
37 | sampler_agent.rollouts_using_worker_playing(use_true_env=True)
38 |
39 | timer_dict['Generate Rollout'] = time.time()
40 |
41 | # step 2: train the weights for dynamics and policy network
42 | training_info = {'network_to_train': ['dynamics', 'reward', 'policy']}
43 | trainer_tasks.put(
44 | (parallel_util.TRAIN_SIGNAL,
45 | {'data': rollout_data['data'], 'training_info': training_info})
46 | )
47 | trainer_tasks.join()
48 | training_return = trainer_results.get()
49 | timer_dict['Train Weights'] = time.time()
50 |
51 | # step 4: update the weights
52 | sampler_agent.set_weights(training_return['network_weights'])
53 | timer_dict['Assign Weights'] = time.time()
54 |
55 | # log and print the results
56 | log_results(training_return, timer_dict)
57 |
58 | # if totalsteps > args.max_timesteps:
59 | if training_return['totalsteps'] > args.max_timesteps:
60 | break
61 | else:
62 | current_iteration += 1
63 |
64 | # end of training
65 | sampler_agent.end()
66 | trainer_tasks.put((parallel_util.END_SIGNAL, None))
67 |
68 |
69 | def main():
70 | parser = base_config.get_base_config()
71 | parser = mf_config.get_mf_config(parser)
72 | args = base_config.make_parser(parser)
73 |
74 | if args.write_log:
75 | logger.set_file_handler(path=args.output_dir,
76 | prefix='mfrl-mf' + args.task,
77 | time_str=args.exp_id)
78 |
79 | # no random policy for model-free rl
80 | assert args.random_timesteps == 0
81 |
82 | print('Training starts at {}'.format(init_path.get_abs_base_dir()))
83 | from mbbl.trainer import shooting_trainer
84 | from mbbl.sampler import singletask_sampler
85 | from mbbl.worker import mf_worker
86 | import mbbl.network.policy.trpo_policy
87 | import mbbl.network.policy.ppo_policy
88 |
89 | policy_network = {
90 | 'ppo': mbbl.network.policy.ppo_policy.policy_network,
91 | 'trpo': mbbl.network.policy.trpo_policy.policy_network
92 | }[args.trust_region_method]
93 |
94 |     # here the dynamics and reward networks are only placeholders; they
95 |     # cannot be called to predict the next state or the reward
96 | from mbbl.network.dynamics.base_dynamics import base_dynamics_network
97 | from mbbl.network.reward.groundtruth_reward import base_reward_network
98 |
99 | train(shooting_trainer, singletask_sampler, mf_worker,
100 | base_dynamics_network, policy_network, base_reward_network, args)
101 |
102 |
103 | if __name__ == '__main__':
104 | main()
105 |
--------------------------------------------------------------------------------
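
A minimal sketch of the config composition that main() above performs, useful for inspecting the merged defaults. It assumes the base parser defines no required flags and bypasses base_config.make_parser() (which main() uses and which may do extra post-processing); main() above remains the authoritative entry point.

# Hedged sketch: build the same parser chain as mf_main.main() and inspect it.
from mbbl.config import base_config, mf_config

parser = base_config.get_base_config()
parser = mf_config.get_mf_config(parser)
args = parser.parse_args(['--trust_region_method', 'ppo'])

print(args.trust_region_method)      # 'ppo'
print(args.gae_lam, args.target_kl)  # 0.95 0.01 (defaults from mf_config)
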
/main/pets_main.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief: main function
3 | # -----------------------------------------------------------------------------
4 |
5 | from mbbl.config import base_config
6 | from mbbl.config import cem_config
7 | from mbbl.config import init_path
8 | from mbbl.util.base_main import train
9 | from mbbl.util.common import logger
10 |
11 |
12 | def main():
13 | parser = base_config.get_base_config()
14 | parser = cem_config.get_cem_config(parser)
15 | args = base_config.make_parser(parser)
16 |
17 | if args.write_log:
18 | logger.set_file_handler(path=args.output_dir,
19 | prefix='mbrl-cem' + args.task,
20 | time_str=args.exp_id)
21 |
22 | print('Training starts at {}'.format(init_path.get_abs_base_dir()))
23 | from mbbl.trainer import shooting_trainer
24 | from mbbl.sampler import singletask_pets_sampler
25 | from mbbl.worker import cem_worker
26 | from mbbl.network.policy.random_policy import policy_network
27 |
28 | if args.gt_dynamics:
29 | from mbbl.network.dynamics.groundtruth_forward_dynamics import \
30 | dynamics_network
31 | else:
32 | from mbbl.network.dynamics.deterministic_forward_dynamics import \
33 | dynamics_network
34 |
35 | if args.gt_reward:
36 | from mbbl.network.reward.groundtruth_reward import reward_network
37 | else:
38 | from mbbl.network.reward.deterministic_reward import reward_network
39 |
40 | train(shooting_trainer, singletask_pets_sampler, cem_worker,
41 | dynamics_network, policy_network, reward_network, args)
42 |
43 |
44 | if __name__ == '__main__':
45 | main()
46 |
--------------------------------------------------------------------------------
/main/random_main.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief: main function
3 | # -----------------------------------------------------------------------------
4 |
5 | from mbbl.config import base_config
6 | from mbbl.config import ilqr_config
7 | from mbbl.config import init_path
8 | from mbbl.util.base_main import train
9 | from mbbl.util.common import logger
10 |
11 |
12 | def main():
13 | parser = base_config.get_base_config()
14 | parser = ilqr_config.get_ilqr_config(parser)
15 | args = base_config.make_parser(parser)
16 |
17 | if args.write_log:
18 | logger.set_file_handler(path=args.output_dir,
19 |                                 prefix='mbrl-random-' + args.task,
20 | time_str=args.exp_id)
21 |
22 | print('Training starts at {}'.format(init_path.get_abs_base_dir()))
23 | from mbbl.trainer import shooting_trainer
24 | from mbbl.sampler import singletask_random_sampler
25 | from mbbl.worker import model_worker
26 | from mbbl.network.policy.random_policy import policy_network
27 |
28 | from mbbl.network.dynamics.groundtruth_forward_dynamics import \
29 | dynamics_network
30 |
31 | if args.gt_reward:
32 | from mbbl.network.reward.groundtruth_reward import reward_network
33 | else:
34 | from mbbl.network.reward.deterministic_reward import reward_network
35 |
36 | train(shooting_trainer, singletask_random_sampler, model_worker,
37 | dynamics_network, policy_network, reward_network, args)
38 |
39 |
40 | if __name__ == '__main__':
41 | main()
42 |
--------------------------------------------------------------------------------
/main/rs_main.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief: main function
3 | # -----------------------------------------------------------------------------
4 |
5 | from mbbl.config import base_config
6 | from mbbl.config import init_path
7 | from mbbl.config import rs_config
8 | from mbbl.util.base_main import train
9 | from mbbl.util.common import logger
10 |
11 |
12 | def main():
13 | parser = base_config.get_base_config()
14 | parser = rs_config.get_rs_config(parser)
15 | args = base_config.make_parser(parser)
16 |
17 | if args.write_log:
18 | logger.set_file_handler(path=args.output_dir,
19 | prefix='mbrl-rs' + args.task,
20 | time_str=args.exp_id)
21 |
22 | print('Training starts at {}'.format(init_path.get_abs_base_dir()))
23 | from mbbl.trainer import shooting_trainer
24 | from mbbl.sampler import singletask_sampler
25 | from mbbl.worker import rs_worker
26 | from mbbl.network.policy.random_policy import policy_network
27 |
28 | if args.gt_dynamics:
29 | from mbbl.network.dynamics.groundtruth_forward_dynamics import \
30 | dynamics_network
31 | else:
32 | from mbbl.network.dynamics.deterministic_forward_dynamics import \
33 | dynamics_network
34 |
35 | if args.gt_reward:
36 | from mbbl.network.reward.groundtruth_reward import reward_network
37 | else:
38 | from mbbl.network.reward.deterministic_reward import reward_network
39 |
40 | train(shooting_trainer, singletask_sampler, rs_worker,
41 | dynamics_network, policy_network, reward_network, args)
42 |
43 |
44 | if __name__ == '__main__':
45 | main()
46 |
--------------------------------------------------------------------------------
/main/test.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief: main function
3 | # -----------------------------------------------------------------------------
4 |
5 | import multiprocessing
6 | import time
7 | import os
8 | from collections import OrderedDict
9 |
10 | from mbbl.config import init_path
11 | from mbbl.util.common import parallel_util
12 | from mbbl.util.common import logger
13 |
14 | if __name__ == '__main__':
15 | print('a')
16 |
--------------------------------------------------------------------------------
/mbbl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/__init__.py
--------------------------------------------------------------------------------
/mbbl/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/config/__init__.py
--------------------------------------------------------------------------------
/mbbl/config/cem_config.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @brief:
3 | # record the parameters here
4 | # ------------------------------------------------------------------------------
5 |
6 |
7 | def get_cem_config(parser):
8 | parser.add_argument("--dynamics_val_percentage", type=float, default=0.33)
9 | parser.add_argument("--dynamics_val_max_size", type=int, default=5000)
10 | parser.add_argument("--num_planning_traj", type=int, default=10)
11 | parser.add_argument("--planning_depth", type=int, default=10)
12 | parser.add_argument("--cem_learning_rate", type=float, default=0.1)
13 | parser.add_argument("--cem_num_iters", type=int, default=5)
14 | parser.add_argument("--cem_elites_fraction", type=float, default=0.1)
15 |
16 | parser.add_argument("--check_done", type=int, default=0)
17 |
18 | return parser
19 |
--------------------------------------------------------------------------------
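
The flags above parameterize the cross-entropy-method planner used by pets_main.py (shown earlier). Below is a small standalone illustration of the elite-selection arithmetic that --num_planning_traj and --cem_elites_fraction suggest; how cem_worker.py actually combines them is not shown in this dump, so the max(1, ...) rounding is an assumption.

import numpy as np

num_planning_traj, cem_elites_fraction = 10, 0.1
returns = np.random.randn(num_planning_traj)            # one return per sampled action sequence
num_elites = max(1, int(num_planning_traj * cem_elites_fraction))
elite_idx = np.argsort(returns)[-num_elites:]           # indices of the best trajectories
print(num_elites, elite_idx)
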
/mbbl/config/ggnn_config.py:
--------------------------------------------------------------------------------
1 |
2 | def get_gnn_config(parser):
3 |
4 | parser.add_argument('--ggnn_keep_prob', type=float, default=0.8)
5 |
6 | parser.add_argument('--t_step', type=int, default=3)
7 |
8 | parser.add_argument('--embed_layer', type=int, default=1)
9 | parser.add_argument('--embed_neuron', type=int, default=256)
10 | parser.add_argument('--prop_layer', type=int, default=1)
11 | parser.add_argument('--prop_neuron', type=int, default=1024)
12 | parser.add_argument('--output_layer', type=int, default=1)
13 | parser.add_argument('--output_neuron', type=int, default=1024)
14 |
15 | parser.add_argument('--prop_normalize', action='store_true')
16 |
17 | parser.add_argument('--d_output', action='store_true')
18 | parser.add_argument('--d_bins', type=int, default=51)
19 |
20 | return parser
21 |
--------------------------------------------------------------------------------
/mbbl/config/gps_config.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @brief:
3 | # ------------------------------------------------------------------------------
4 |
5 |
6 | def get_gps_config(parser):
7 |
8 | # the linear gaussian dynamics with gmm prior
9 | parser.add_argument("--gmm_num_cluster", type=int, default=30)
10 | parser.add_argument("--gmm_max_iteration", type=int, default=100)
11 | parser.add_argument("--gmm_batch_size", type=int, default=2000)
12 | parser.add_argument("--gmm_prior_strength", type=float, default=1.0)
13 | parser.add_argument("--gps_dynamics_cov_reg", type=float, default=1e-6)
14 | parser.add_argument("--gps_policy_cov_reg", type=float, default=1e-8)
15 | # parser.add_argument("--gps_nn_policy_cov_reg", type=float, default=1e-6)
16 |
17 | parser.add_argument("--gps_max_backward_pass_trials", type=int,
18 | default=20)
19 |
20 | # the constraints on the kl between policy and traj
21 | parser.add_argument("--gps_init_traj_kl_eta", type=float, default=1.0)
22 | parser.add_argument("--gps_min_eta", type=float, default=1e-8)
23 | parser.add_argument("--gps_max_eta", type=float, default=1e16)
24 | parser.add_argument("--gps_eta_multiplier", type=float, default=1e-4)
25 |
26 | parser.add_argument("--gps_min_kl_step_mult", type=float, default=1e-2)
27 | parser.add_argument("--gps_max_kl_step_mult", type=float, default=1e2)
28 | parser.add_argument("--gps_init_kl_step_mult", type=float, default=1.0)
29 |
30 | parser.add_argument("--gps_base_kl_step", type=float, default=1.0)
31 |
32 | # the penalty on the entropy of the traj
33 | parser.add_argument("--gps_traj_ent_epsilon", type=float, default=0.001)
34 |
35 | parser.add_argument("--gps_policy_cov_damping", type=float, default=0.0001)
36 | parser.add_argument("--gps_init_state_replay_size", type=int, default=1000)
37 |
38 | # single traj update
39 | parser.add_argument("--gps_single_condition", type=int, default=1)
40 | return parser
41 |
--------------------------------------------------------------------------------
/mbbl/config/il_config.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief:
3 | # record the parameters here
4 | # @author:
5 | # Tingwu Wang, 2017, June, 12th
6 | # ------------------------------------------------------------------------------
7 |
8 | from mbbl.env.env_register import _ENV_INFO
9 | from mbbl.util.il.camera_model import xyaxis2quaternion
10 | import copy
11 |
12 |
13 | def get_il_config(parser):
14 | # get the parameters
15 | parser.add_argument("--GAN_reward_clip_value", type=float, default=10.0)
16 | parser.add_argument("--GAN_ent_coeff", type=float, default=0.0)
17 |
18 | parser.add_argument("--expert_data_name", type=str, default='')
19 | parser.add_argument("--traj_episode_num", type=int, default=-1)
20 |
21 | parser.add_argument("--gan_timesteps_per_epoch", type=int, default=2000)
22 | parser.add_argument("--positive_negative_ratio", type=float, default=1)
23 |
24 | # the config for the inverse dynamics planner
25 | parser.add_argument("--sol_qpos_freq", type=int, default=1)
26 | parser.add_argument("--camera_id", type=int, default=0)
27 | parser.add_argument("--imitation_length", type=int, default=10,
28 |                         help="use 10 for debugging; 100 is closer to a real run")
29 |
30 | parser.add_argument("--opt_var_list", type=str,
31 | # default='qpos',
32 | default='quaternion',
33 | # default='qpos-xyz_pos-quaternion-fov-image_size',
34 | help='use the "-" to divide variables')
35 |
36 | parser.add_argument("--physics_loss_lambda", type=float, default=1.0,
37 | help='lambda weights the inverse or physics loss')
38 | parser.add_argument("--lbfgs_opt_iteration", type=int, default=1,
39 | help='number of iterations for the inner lbfgs loops')
40 |
41 | return parser
42 |
43 |
44 | def post_process_config(args):
45 | """ @brief:
46 |             1. Parse the opt_var_list string into a list
47 |             2. Parse the camera info
48 | """
49 | # hack to parse the opt_var_list
50 | args.opt_var_list = args.opt_var_list.split('-')
51 | for key in args.opt_var_list:
52 | assert key in ['qpos', 'xyz_pos', 'quaternion', 'fov', 'image_size']
53 |
54 | # the camera_info
55 | assert 'camera_info' in _ENV_INFO[args.task_name]
56 | args.camera_info = _ENV_INFO[args.task_name]['camera_info']
57 |
58 | # generate the quaternion data
59 | for cam in args.camera_info:
60 | if 'quaternion' not in cam:
61 | cam['quaternion'] = xyaxis2quaternion(cam['xyaxis'])
62 |
63 | args.gt_camera_info = \
64 | copy.deepcopy(_ENV_INFO[args.task_name]['camera_info'])
65 |
66 |     # mask out the ground-truth values of the variables being optimized
67 | for key in ['qpos', 'xyz_pos', 'quaternion', 'fov', 'image_size']:
68 | if key in args.opt_var_list:
69 | args.camera_info[args.camera_id][key] = None
70 |
71 | return args
72 |
--------------------------------------------------------------------------------
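
A small illustration of the "-"-separated opt_var_list convention and the ground-truth masking performed by post_process_config() above; the camera_info dict below is a toy stand-in for the per-task entries in mbbl.env.env_register._ENV_INFO.

opt_var_list = 'qpos-quaternion'.split('-')
camera_info = [{'qpos': [0.0], 'quaternion': [1, 0, 0, 0], 'fov': 45}]
for key in opt_var_list:
    camera_info[0][key] = None   # variables being optimized lose their ground truth
print(opt_var_list)              # ['qpos', 'quaternion']
print(camera_info[0])            # {'qpos': None, 'quaternion': None, 'fov': 45}
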
/mbbl/config/ilqr_config.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @brief:
3 | # ------------------------------------------------------------------------------
4 |
5 |
6 | def get_ilqr_config(parser):
7 | # get the parameters
8 | parser.add_argument("--finite_difference_eps", type=float, default=1e-6)
9 | parser.add_argument("--num_ilqr_traj", type=int, default=1,
10 | help='number of different initializations of ilqr')
11 | parser.add_argument("--ilqr_depth", type=int, default=20)
12 | parser.add_argument("--ilqr_linesearch_accept_ratio", type=float,
13 | default=0.01)
14 | parser.add_argument("--ilqr_linesearch_decay_factor", type=float,
15 | default=5.0)
16 | parser.add_argument("--ilqr_iteration", type=int, default=10)
17 | parser.add_argument("--max_ilqr_linesearch_backtrack", type=int,
18 | default=30)
19 |
20 | parser.add_argument("--LM_damping_type", type=str, default='V',
21 | help="['V', 'Q'], whether to put damping on V or Q")
22 | parser.add_argument("--init_LM_damping", type=float, default=0.1,
23 | help="initial value of Levenberg-Marquardt_damping")
24 | parser.add_argument("--min_LM_damping", type=float, default=1e-3)
25 | parser.add_argument("--max_LM_damping", type=float, default=1e10)
26 | parser.add_argument("--init_LM_damping_multiplier", type=float, default=1.0)
27 | parser.add_argument("--LM_damping_factor", type=float, default=1.6)
28 |
29 | return parser
30 |
--------------------------------------------------------------------------------
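
--finite_difference_eps above controls the step size of the numerical derivatives taken through the dynamics and reward. A central-difference sketch of that idea follows (whether ilqr_utils.py uses central or one-sided differences is an assumption here):

finite_difference_eps = 1e-6

def f(x):
    return x ** 3

x0 = 1.0
grad = (f(x0 + finite_difference_eps)
        - f(x0 - finite_difference_eps)) / (2 * finite_difference_eps)
print(round(grad, 4))   # ~3.0, i.e. d/dx x^3 at x = 1
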
/mbbl/config/init_path.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @brief:
3 | # In this file we init the path
4 | # @author:
5 | # Written by Tingwu Wang, 2016/Sept/22nd
6 | # ------------------------------------------------------------------------------
7 |
8 |
9 | import os.path as osp
10 | import datetime
11 |
12 | real_file_path = osp.realpath(__file__.replace('pyc', 'py'))
13 | _this_dir = osp.dirname(real_file_path)
14 |
15 | running_start_time = datetime.datetime.now()
16 | time = str(running_start_time.strftime("%Y_%m_%d-%X"))
17 |
18 | _base_dir = osp.join(_this_dir, '..', '..')
19 |
20 |
21 | def bypass_frost_warning():
22 | return 0
23 |
24 |
25 | def get_base_dir():
26 | return _base_dir
27 |
28 |
29 | def get_time():
30 | return time
31 |
32 |
33 | def get_abs_base_dir():
34 | return osp.abspath(_base_dir)
35 |
--------------------------------------------------------------------------------
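
The helpers above are what the entry points in main/ call when printing "Training starts at ..."; a minimal usage sketch:

from mbbl.config import init_path

print(init_path.get_abs_base_dir())   # absolute repository base directory
print(init_path.get_time())           # timestamp string frozen at import time
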
/mbbl/config/mbmf_config.py:
--------------------------------------------------------------------------------
1 |
2 | def get_mbmf_config(parser):
3 | # get the parameters
4 | # MF config
5 | parser.add_argument("--value_lr", type=float, default=3e-4)
6 | parser.add_argument("--value_epochs", type=int, default=20)
7 | parser.add_argument("--value_network_shape", type=str, default='64,64')
8 | # parser.add_argument("--value_batch_size", type=int, default=64)
9 | parser.add_argument("--value_activation_type", type=str, default='tanh')
10 | parser.add_argument("--value_normalizer_type", type=str, default='none')
11 |
12 | parser.add_argument("--trust_region_method", type=str, default='ppo',
13 | help='["ppo", "trpo"]')
14 | parser.add_argument("--gae_lam", type=float, default=0.95)
15 | parser.add_argument("--fisher_cg_damping", type=float, default=0.1)
16 | parser.add_argument("--target_kl", type=float, default=0.01)
17 | parser.add_argument("--cg_iterations", type=int, default=10)
18 |
19 | # RS config
20 | parser.add_argument("--dynamics_val_percentage", type=float, default=0.33)
21 | parser.add_argument("--dynamics_val_max_size", type=int, default=5000)
22 | parser.add_argument("--num_planning_traj", type=int, default=10)
23 | parser.add_argument("--planning_depth", type=int, default=10)
24 | parser.add_argument("--mb_timesteps", type=int, default=2e4)
25 |
26 | # Imitation learning config
27 | parser.add_argument("--initial_policy_lr", type=float, default=1e-4)
28 | parser.add_argument("--initial_policy_bs", type=int, default=500)
29 | parser.add_argument("--dagger_iter", type=int, default=3)
30 | parser.add_argument("--dagger_epoch", type=int, default=70)
31 | parser.add_argument("--dagger_timesteps_per_iter", type=int, default=1750)
32 |
33 | parser.add_argument("--ppo_clip", type=float, default=0.1)
34 | parser.add_argument("--num_minibatches", type=int, default=10)
35 | parser.add_argument("--target_kl_high", type=float, default=2)
36 | parser.add_argument("--target_kl_low", type=float, default=0.5)
37 | parser.add_argument("--use_weight_decay", type=int, default=0)
38 | parser.add_argument("--weight_decay_coeff", type=float, default=1e-5)
39 |
40 | parser.add_argument("--use_kl_penalty", type=int, default=0)
41 | parser.add_argument("--kl_alpha", type=float, default=1.5)
42 | parser.add_argument("--kl_eta", type=float, default=50)
43 |
44 | parser.add_argument("--policy_lr_schedule", type=str, default='linear',
45 | help='["linear", "constant", "adaptive"]')
46 | parser.add_argument("--policy_lr_alpha", type=int, default=2)
47 | return parser
48 |
--------------------------------------------------------------------------------
/mbbl/config/metrpo_config.py:
--------------------------------------------------------------------------------
1 |
2 | def get_metrpo_config(parser):
3 | # get the parameters
4 | parser.add_argument("--value_lr", type=float, default=3e-4)
5 | parser.add_argument("--value_epochs", type=int, default=20)
6 | parser.add_argument("--value_network_shape", type=str, default='64,64')
7 | # parser.add_argument("--value_batch_size", type=int, default=64)
8 | parser.add_argument("--value_activation_type", type=str, default='tanh')
9 | parser.add_argument("--value_normalizer_type", type=str, default='none')
10 |
11 | parser.add_argument("--gae_lam", type=float, default=0.95)
12 | parser.add_argument("--fisher_cg_damping", type=float, default=0.1)
13 | parser.add_argument("--target_kl", type=float, default=0.01)
14 | parser.add_argument("--cg_iterations", type=int, default=10)
15 |
16 | parser.add_argument("--dynamics_val_percentage", type=float, default=0.33)
17 | parser.add_argument("--dynamics_val_max_size", type=int, default=10000)
18 | parser.add_argument("--max_fake_timesteps", type=int, default=1e5)
19 |
20 | return parser
21 |
--------------------------------------------------------------------------------
/mbbl/config/mf_config.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @author:
3 | # Tingwu Wang
4 | # ------------------------------------------------------------------------------
5 |
6 |
7 | def get_mf_config(parser):
8 | # get the parameters
9 | parser.add_argument("--value_lr", type=float, default=3e-4)
10 | parser.add_argument("--value_epochs", type=int, default=20)
11 | parser.add_argument("--value_network_shape", type=str, default='64,64')
12 | # parser.add_argument("--value_batch_size", type=int, default=64)
13 | parser.add_argument("--value_activation_type", type=str, default='tanh')
14 | parser.add_argument("--value_normalizer_type", type=str, default='none')
15 |
16 | parser.add_argument("--trust_region_method", type=str, default='trpo',
17 | help='["ppo", "trpo"]')
18 | parser.add_argument("--gae_lam", type=float, default=0.95)
19 | parser.add_argument("--fisher_cg_damping", type=float, default=0.1)
20 | parser.add_argument("--target_kl", type=float, default=0.01)
21 | parser.add_argument("--cg_iterations", type=int, default=10)
22 |
23 | parser.add_argument("--ppo_clip", type=float, default=0.1)
24 | parser.add_argument("--num_minibatches", type=int, default=10)
25 | parser.add_argument("--target_kl_high", type=float, default=2)
26 | parser.add_argument("--target_kl_low", type=float, default=0.5)
27 | parser.add_argument("--use_weight_decay", type=int, default=0)
28 | parser.add_argument("--weight_decay_coeff", type=float, default=1e-5)
29 |
30 | parser.add_argument("--use_kl_penalty", type=int, default=0)
31 | parser.add_argument("--kl_alpha", type=float, default=1.5)
32 | parser.add_argument("--kl_eta", type=float, default=50)
33 |
34 | parser.add_argument("--policy_lr_schedule", type=str, default='linear',
35 | help='["linear", "constant", "adaptive"]')
36 | parser.add_argument("--policy_lr_alpha", type=int, default=2)
37 | return parser
38 |
--------------------------------------------------------------------------------
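
--value_network_shape above encodes hidden-layer sizes as a comma-separated string; the split-and-int parsing below is an assumption about how the value network consumes it.

value_network_shape = '64,64'
hidden_sizes = [int(size) for size in value_network_shape.split(',')]
assert hidden_sizes == [64, 64]
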
/mbbl/config/rs_config.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @brief:
3 | # record the parameters here
4 | # @author:
5 | # Tingwu Wang, 2017, June, 12th
6 | # ------------------------------------------------------------------------------
7 |
8 |
9 | def get_rs_config(parser):
10 | # get the parameters
11 | parser.add_argument("--dynamics_val_percentage", type=float, default=0.33)
12 | parser.add_argument("--dynamics_val_max_size", type=int, default=5000)
13 | parser.add_argument("--num_planning_traj", type=int, default=10)
14 | parser.add_argument("--planning_depth", type=int, default=10)
15 | parser.add_argument("--check_done", type=int, default=0)
16 |
17 | return parser
18 |
--------------------------------------------------------------------------------
/mbbl/env/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/env/__init__.py
--------------------------------------------------------------------------------
/mbbl/env/base_env_wrapper.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @author:
3 | # Tingwu Wang
4 | # @brief:
5 | # The environment wrapper
6 | # -----------------------------------------------------------------------------
7 | import numpy as np
8 |
9 |
10 | class base_env(object):
11 |
12 | def __init__(self, env_name, rand_seed, misc_info={}):
13 | self._env_name = env_name
14 | self._seed = rand_seed
15 | self._npr = np.random.RandomState(self._seed)
16 | self._misc_info = misc_info
17 |
18 | # build the environment
19 | self._build_env()
20 | self._set_groundtruth_api()
21 |
22 | def step(self, action):
23 | raise NotImplementedError
24 |
25 | def reset(self, control_info={}):
26 | raise NotImplementedError
27 |
28 | def _build_env(self):
29 | raise NotImplementedError
30 |
31 | def _set_groundtruth_api(self):
32 | """ @brief:
33 | In this function, we could provide the ground-truth dynamics
34 | and rewards APIs for the agent to call.
35 | For the new environments, if we don't set their ground-truth
36 | apis, then we cannot test the algorithm using ground-truth
37 | dynamics or reward
38 | """
39 |
40 | def fdynamics(data_dict):
41 | raise NotImplementedError
42 | self.fdynamics = fdynamics
43 |
44 | def reward(data_dict):
45 | raise NotImplementedError
46 | self.reward = reward
47 |
48 | def reward_derivative(data_dict, target):
49 | raise NotImplementedError
50 | self.reward_derivative = reward_derivative
51 |
--------------------------------------------------------------------------------
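
A hedged sketch of a minimal subclass of base_env; "toy_env" and its additive dynamics are illustrative only. Note that base_env.__init__ calls _build_env() and _set_groundtruth_api(), so _build_env() must be overridden (the base _set_groundtruth_api() installs APIs that raise NotImplementedError when called).

import numpy as np
from mbbl.env.base_env_wrapper import base_env


class toy_env(base_env):

    def _build_env(self):
        self._state = np.zeros(2)

    def step(self, action):
        self._state = self._state + np.asarray(action)
        reward, done, info = 0.0, False, {}
        return self._state, reward, done, info

    def reset(self, control_info={}):
        self._state = np.zeros(2)
        return self._state


env = toy_env('toy_env', rand_seed=1234)
print(env.step([0.1, -0.1]))   # (array([ 0.1, -0.1]), 0.0, False, {})
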
/mbbl/env/bullet_env/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/env/bullet_env/__init__.py
--------------------------------------------------------------------------------
/mbbl/env/bullet_env/motion_capture_data.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | class MotionCaptureData(object):
4 | def __init__(self):
5 | self.Reset()
6 |
7 | def Reset(self):
8 | self._motion_data = []
9 |
10 | def Load(self, path):
11 | with open(path, 'r') as f:
12 | self._motion_data = json.load(f)
13 |
14 | def NumFrames(self):
15 | return len(self._motion_data['Frames'])
16 |
17 | def KeyFrameDuraction(self):
18 | return self._motion_data['Frames'][0][0]
19 |
20 |
--------------------------------------------------------------------------------
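
A usage sketch for MotionCaptureData above; "motion.json" is a hypothetical file following the {"Frames": [[frame_duration, ...], ...]} layout that Load(), NumFrames() and KeyFrameDuraction() expect.

import json
from mbbl.env.bullet_env.motion_capture_data import MotionCaptureData

with open('motion.json', 'w') as f:
    json.dump({'Frames': [[0.0333, 0, 0, 0], [0.0333, 0, 0, 1]]}, f)

data = MotionCaptureData()
data.Load('motion.json')
print(data.NumFrames(), data.KeyFrameDuraction())   # 2 0.0333
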
/mbbl/env/dm_env/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/env/dm_env/__init__.py
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/cheetah_pos.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/common/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Functions to manage the common assets for domains."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | from dm_control.utils import io as resources
24 |
25 | _SUITE_DIR = os.path.dirname(os.path.dirname(__file__))
26 | _FILENAMES = [
27 | "./common/materials.xml",
28 | "./common/skybox.xml",
29 | "./common/visual.xml",
30 | ]
31 |
32 | ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename))
33 | for filename in _FILENAMES}
34 |
35 |
36 | def read_model(model_filename):
37 | """Reads a model XML file and returns its contents as a string."""
38 | return resources.GetResource(os.path.join(_SUITE_DIR, model_filename))
39 |
--------------------------------------------------------------------------------
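
A usage sketch for read_model() and ASSETS above. It assumes dm_control is installed and that the assets directory is importable from this path (Python 3 namespace packages); model filenames are resolved relative to the assets directory.

from mbbl.env.dm_env.assets.common import ASSETS, read_model

xml_string = read_model('cheetah_pos.xml')   # contents of assets/cheetah_pos.xml
print(sorted(ASSETS))                        # the three common include files
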
/mbbl/env/dm_env/assets/common/materials.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/common/skybox.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/common/visual.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/reference/acrobot.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/reference/ball_in_cup.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/reference/cartpole.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/reference/cheetah.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/reference/finger.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/reference/fish.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/reference/hopper.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/reference/lqr.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/reference/pendulum.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/reference/point_mass.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump; the markup was stripped during extraction.]
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/reference/reacher.xml:
--------------------------------------------------------------------------------
(XML model definition not captured in this export.)
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/reference/swimmer.xml:
--------------------------------------------------------------------------------
(XML model definition not captured in this export.)
--------------------------------------------------------------------------------
/mbbl/env/dm_env/assets/reference/walker.xml:
--------------------------------------------------------------------------------
(XML model definition not captured in this export.)
--------------------------------------------------------------------------------
/mbbl/env/fake_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class fake_env(object):
5 |
6 | def __init__(self, env, model):
7 | self._env = env
8 | self._model = model
9 | self._state = None
10 | self._current_step = 0
11 | self._max_length = self._env._env_info['max_length']
12 | self._obs_bounds = (-1e5, 1e5)
13 |
14 | def step(self, action):
15 | self._current_step += 1
16 | next_state, reward = self._model(self._state, action)
17 | next_state = np.clip(next_state, *self._obs_bounds)
18 | self._state = next_state
19 |
20 | if self._current_step > self._max_length:
21 | done = True
22 | else:
23 | done = False
24 | return next_state, reward, done, None
25 |
26 | def reset(self):
27 | self._current_step = 0
28 | ob, reward, done, info = self._env.reset()
29 | self._state = ob
30 |
31 | return ob, reward, done, info
32 |
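
A minimal usage sketch of the wrapper above; `model_fn` and `_toy_env` are illustrative assumptions (any callable mapping (state, action) -> (next_state, reward) and any mbbl env wrapper whose reset() returns (ob, reward, done, info) would do):

    import numpy as np
    from mbbl.env.fake_env import fake_env

    def model_fn(state, action):
        # placeholder model: drift the state by the action, constant zero reward
        return state + action, 0.0

    class _toy_env(object):
        # stand-in for an mbbl env wrapper, exposing only what fake_env touches
        _env_info = {'max_length': 100}
        def reset(self):
            return np.zeros(3), 0.0, False, {}

    planning_env = fake_env(_toy_env(), model_fn)
    ob, _, _, _ = planning_env.reset()                         # resets through the wrapped env
    next_ob, reward, done, _ = planning_env.step(np.zeros(3))  # steps through the model only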
--------------------------------------------------------------------------------
/mbbl/env/gym_env/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/env/gym_env/__init__.py
--------------------------------------------------------------------------------
/mbbl/env/gym_env/box2d/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Wed Aug 1 12:59:34 2018
5 |
6 | @author: matthewszhang
7 | """
8 |
 9 | from mbbl.env.gym_env.box2d.core import box2d_make
10 | from mbbl.env.gym_env.box2d.wrappers import LunarLanderWrapper, \
11 | WalkerWrapper, RacerWrapper, Box2D_Wrapper
--------------------------------------------------------------------------------
/mbbl/env/gym_env/box2d/core.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Wed Aug 1 12:13:58 2018
5 |
6 | @author: matthewszhang
7 | """
8 | from mbbl.env.gym_env.box2d.wrappers import LunarLanderWrapper, WalkerWrapper, RacerWrapper
9 |
10 | _WRAPPER_DICT = {'LunarLanderContinuous':LunarLanderWrapper,
11 | 'LunarLander':LunarLanderWrapper,
12 | 'BipedalWalker':WalkerWrapper,
13 | 'BipedalWalkerHardcore':WalkerWrapper,
14 | 'CarRacing':RacerWrapper}
15 |
16 | def get_wrapper(gym_id):
17 | try:
18 | return _WRAPPER_DICT[gym_id]
19 | except KeyError:
20 | raise KeyError("Non-existing Box2D env: {}".format(gym_id))
21 |
22 | def box2d_make(gym_id): # naive build of environment, leave safeties for gym.make
23 | import re
24 | remove_version = re.compile(r'-v(\d+)$') # version safety
25 | gym_id_base = remove_version.sub('', gym_id)
26 |
27 | wrapper = get_wrapper(gym_id_base)
28 | return wrapper(gym_id)
29 |
30 |
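
An illustrative call pattern for box2d_make; whether a particular id and version tag is actually available depends on the installed gym and its Box2D extras:

    from mbbl.env.gym_env.box2d.core import box2d_make

    walker = box2d_make('BipedalWalker-v2')   # '-v2' is stripped, dispatches to WalkerWrapper
    racer = box2d_make('CarRacing-v0')        # dispatches to RacerWrapper
    box2d_make('Ant-v2')                      # raises KeyError: not a Box2D id handled here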
--------------------------------------------------------------------------------
/mbbl/env/gym_env/box2d_racer.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @author:
3 | # Matthew Zhang
4 | # @brief:
5 | # Alternate code-up of the box2d car racing environment from gym
6 | # -----------------------------------------------------------------------------
7 | from mbbl.env import base_env_wrapper as bew
8 | import mbbl.env.gym_env.box2d.wrappers
9 |
10 | Racer = mbbl.env.gym_env.box2d.wrappers.RacerWrapper
11 |
12 | class env(bew.base_env):
13 | pass
14 |
--------------------------------------------------------------------------------
/mbbl/env/gym_env/fix_swimmer/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 | register(
3 | id='Fixswimmer-v1',
4 | entry_point='mbbl.env.gym_env.fix_swimmer.fixed_swimmer:fixedSwimmerEnv'
5 | )
6 |
--------------------------------------------------------------------------------
/mbbl/env/gym_env/fix_swimmer/assets/fixed_swimmer.xml:
--------------------------------------------------------------------------------
(XML model definition not captured in this export.)
--------------------------------------------------------------------------------
/mbbl/env/gym_env/fix_swimmer/fixed_swimmer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym import utils
3 | from gym.envs.mujoco import mujoco_env
4 | import os
5 |
6 |
7 | class fixedSwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
8 |
9 | def __init__(self):
10 | dir_path = os.path.dirname(os.path.realpath(__file__))
11 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/fixed_swimmer.xml' % dir_path, 4)
12 | utils.EzPickle.__init__(self)
13 |
14 | def _step(self, a):
15 | ctrl_cost_coeff = 0.0001
16 |
17 | """
18 | xposbefore = self.model.data.qpos[0, 0]
19 | self.do_simulation(a, self.frame_skip)
20 | xposafter = self.model.data.qpos[0, 0]
21 | """
22 |
23 | self.xposbefore = self.model.data.site_xpos[0][0] / self.dt
24 | self.do_simulation(a, self.frame_skip)
25 | self.xposafter = self.model.data.site_xpos[0][0] / self.dt
26 | self.pos_diff = self.xposafter - self.xposbefore
27 |
28 | reward_fwd = self.xposafter - self.xposbefore
29 | reward_ctrl = - ctrl_cost_coeff * np.square(a).sum()
30 | reward = reward_fwd + reward_ctrl
31 | ob = self._get_obs()
32 | return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)
33 |
34 | def _get_obs(self):
35 | qpos = self.model.data.qpos
36 | qvel = self.model.data.qvel
37 | return np.concatenate([qpos.flat[2:], qvel.flat, self.pos_diff.flat])
38 |
39 | def reset_model(self):
40 | self.set_state(
41 | self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq),
42 | self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv)
43 | )
44 | return self._get_obs()
45 |
--------------------------------------------------------------------------------
/mbbl/env/gym_env/pets_env/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 |
3 |
4 | register(
5 | id='MBRLCartpole-v0',
6 | entry_point='mbbl.env.gym_env.pets_env.cartpole:CartpoleEnv'
7 | )
8 |
9 |
10 | register(
11 | id='MBRLReacher3D-v0',
12 | entry_point='mbbl.env.gym_env.pets_env.reacher:Reacher3DEnv'
13 | )
14 |
15 |
16 | register(
17 | id='MBRLPusher-v0',
18 | entry_point='mbbl.env.gym_env.pets_env.pusher:PusherEnv'
19 | )
20 |
21 |
22 | register(
23 | id='MBRLHalfCheetah-v0',
24 | entry_point='mbbl.env.gym_env.pets_env.half_cheetah:HalfCheetahEnv'
25 | )
26 |
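
Importing this package is what makes the ids above resolvable by gym; a minimal sketch (assumes a working MuJoCo / mujoco-py setup, since the registered classes subclass MujocoEnv):

    import gym
    import mbbl.env.gym_env.pets_env   # side effect: runs the register() calls above

    env = gym.make('MBRLCartpole-v0')
    ob = env.reset()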
--------------------------------------------------------------------------------
/mbbl/env/gym_env/pets_env/assets/cartpole.xml:
--------------------------------------------------------------------------------
(XML model definition not captured in this export.)
--------------------------------------------------------------------------------
/mbbl/env/gym_env/pets_env/cartpole.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 |
7 | import numpy as np
8 | from gym import utils
9 | from gym.envs.mujoco import mujoco_env
10 |
11 |
12 | class CartpoleEnv(mujoco_env.MujocoEnv, utils.EzPickle):
13 | PENDULUM_LENGTH = 0.6
14 |
15 | def __init__(self):
16 | utils.EzPickle.__init__(self)
17 | dir_path = os.path.dirname(os.path.realpath(__file__))
18 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/cartpole.xml' % dir_path, 2)
19 |
20 | def _step(self, a):
21 | self.do_simulation(a, self.frame_skip)
22 | ob = self._get_obs()
23 |
24 | cost_lscale = CartpoleEnv.PENDULUM_LENGTH
25 | reward = np.exp(
26 | -np.sum(np.square(self._get_ee_pos(ob) - np.array([0.0, CartpoleEnv.PENDULUM_LENGTH]))) / (cost_lscale ** 2)
27 | )
28 | reward -= 0.01 * np.sum(np.square(a))
29 |
30 | done = False
31 | return ob, reward, done, {}
32 |
33 | def reset_model(self):
34 | qpos = self.init_qpos + np.random.normal(0, 0.1, np.shape(self.init_qpos))
35 | qvel = self.init_qvel + np.random.normal(0, 0.1, np.shape(self.init_qvel))
36 | self.set_state(qpos, qvel)
37 | return self._get_obs()
38 |
39 | def _get_obs(self):
40 | return np.concatenate([self.model.data.qpos, self.model.data.qvel]).ravel()
41 |
42 | @staticmethod
43 | def _get_ee_pos(x):
44 | x0, theta = x[0], x[1]
45 | return np.array([
46 | x0 - CartpoleEnv.PENDULUM_LENGTH * np.sin(theta),
47 | -CartpoleEnv.PENDULUM_LENGTH * np.cos(theta)
48 | ])
49 |
50 | def viewer_setup(self):
51 | v = self.viewer
52 | v.cam.trackbodyid = 0
53 | v.cam.distance = v.model.stat.extent
54 |
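
A small worked check of the reward shaping above (the observation layout, qpos followed by qvel, is taken from _get_obs; importing the class requires mujoco-py): with the cart at the origin and the pole rotated so the end effector sits at the target [0, PENDULUM_LENGTH], the state-dependent term reaches its maximum of 1.0.

    import numpy as np
    from mbbl.env.gym_env.pets_env.cartpole import CartpoleEnv

    L = CartpoleEnv.PENDULUM_LENGTH                   # 0.6
    ob = np.array([0.0, np.pi, 0.0, 0.0])             # cart at 0, pole angle pi
    ee = CartpoleEnv._get_ee_pos(ob)                  # approximately [0.0, 0.6]
    state_reward = np.exp(-np.sum(np.square(ee - np.array([0.0, L]))) / L ** 2)
    # state_reward ~= 1.0; the executed action then subtracts 0.01 * sum(a ** 2)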
--------------------------------------------------------------------------------
/mbbl/env/gym_env/pets_env/half_cheetah.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 |
7 | import numpy as np
8 | from gym import utils
9 | from gym.envs.mujoco import mujoco_env
10 |
11 |
12 | class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
13 |
14 | def __init__(self):
15 | self.prev_qpos = None
16 | dir_path = os.path.dirname(os.path.realpath(__file__))
17 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/half_cheetah.xml' % dir_path, 5)
18 | utils.EzPickle.__init__(self)
19 |
20 | def _step(self, action):
21 | self.prev_qpos = np.copy(self.model.data.qpos.flat)
22 | self.do_simulation(action, self.frame_skip)
23 | ob = self._get_obs()
24 |
25 | reward_ctrl = -0.1 * np.square(action).sum()
26 | reward_run = ob[0] - 0.0 * np.square(ob[2])
27 | reward = reward_run + reward_ctrl
28 |
29 | done = False
30 | return ob, reward, done, {}
31 |
32 | def _get_obs(self):
33 | return np.concatenate([
34 | (self.model.data.qpos.flat[:1] - self.prev_qpos[:1]) / self.dt,
35 | self.model.data.qpos.flat[1:],
36 | self.model.data.qvel.flat,
37 | ])
38 |
39 | def reset_model(self):
40 | qpos = self.init_qpos + np.random.normal(loc=0, scale=0.001, size=self.model.nq)
41 | qvel = self.init_qvel + np.random.normal(loc=0, scale=0.001, size=self.model.nv)
42 | self.set_state(qpos, qvel)
43 | self.prev_qpos = np.copy(self.model.data.qpos.flat)
44 | return self._get_obs()
45 |
46 | def viewer_setup(self):
47 | self.viewer.cam.distance = self.model.stat.extent * 0.25
48 | self.viewer.cam.elevation = -55
49 |
--------------------------------------------------------------------------------
/mbbl/env/gym_env/pets_env/half_cheetah_config.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import numpy as np
6 | import gym
7 |
8 |
9 | class HalfCheetahConfigModule:
10 | ENV_NAME = "MBRLHalfCheetah-v0"
11 | TASK_HORIZON = 1000
12 | NTRAIN_ITERS = 300
13 | NROLLOUTS_PER_ITER = 1
14 | PLAN_HOR = 30
15 | INIT_VAR = 0.25
16 | MODEL_IN, MODEL_OUT = 24, 18  # model input: 18-dim obs + 6-dim action; output: 18-dim obs
17 | GP_NINDUCING_POINTS = 300
18 |
19 | def __init__(self):
20 | self.ENV = gym.make(self.ENV_NAME)
21 | self.NN_TRAIN_CFG = {"epochs": 5}
22 | self.OPT_CFG = {
23 | "Random": {
24 | "popsize": 2500
25 | },
26 | "GBPRandom": {
27 | "popsize": 2500
28 | },
29 | "GBPCEM": {
30 | "popsize": 500,
31 | "num_elites": 50,
32 | "max_iters": 5,
33 | "alpha": 0.1
34 | },
35 | "CEM": {
36 | "popsize": 500,
37 | "num_elites": 50,
38 | "max_iters": 5,
39 | "alpha": 0.1
40 | },
41 | "PWCEM": {
42 | "popsize": 500,
43 | "num_elites": 50,
44 | "max_iters": 5,
45 | "alpha": 0.1
46 | },
47 | "POCEM": {
48 | "popsize": 500,
49 | "num_elites": 50,
50 | "max_iters": 5,
51 | "alpha": 0.1
52 | }
53 | }
54 |
55 | @staticmethod
56 | def obs_preproc(obs):
57 | return np.concatenate([obs[:, 1:2], np.sin(obs[:, 2:3]), np.cos(obs[:, 2:3]), obs[:, 3:]], axis=1)
58 |
59 | @staticmethod
60 | def obs_postproc(obs, pred):
61 | return np.concatenate([pred[:, :1], obs[:, 1:] + pred[:, 1:]], axis=1)
62 |
63 | @staticmethod
64 | def targ_proc(obs, next_obs):
65 | return np.concatenate([next_obs[:, :1], next_obs[:, 1:] - obs[:, 1:]], axis=1)
66 |
67 | @staticmethod
68 | def obs_cost_fn(obs):
69 | return -obs[:, 0]
70 |
71 | @staticmethod
72 | def ac_cost_fn(acs):
73 | return 0.1 * np.sum(np.square(acs), axis=1)
74 |
75 | def nn_constructor(self, model_init_cfg, misc=None):
76 | pass
77 |
78 | def gp_constructor(self, model_init_cfg):
79 | pass
80 |
81 |
82 | CONFIG_MODULE = HalfCheetahConfigModule
83 |
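
The preprocessing trio above encodes a common PETS-style convention: channel 0 of the observation (the finite-difference forward velocity from _get_obs) is predicted directly, while every other channel is predicted as a delta. A small sketch showing that targ_proc and obs_postproc invert each other (shapes are illustrative):

    import numpy as np
    from mbbl.env.gym_env.pets_env.half_cheetah_config import HalfCheetahConfigModule

    obs = np.random.randn(4, 18)
    next_obs = np.random.randn(4, 18)
    target = HalfCheetahConfigModule.targ_proc(obs, next_obs)       # what the model is trained on
    recovered = HalfCheetahConfigModule.obs_postproc(obs, target)   # what planning reconstructs
    assert np.allclose(recovered, next_obs)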
--------------------------------------------------------------------------------
/mbbl/env/gym_env/pets_env/pusher.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 |
7 | import numpy as np
8 | from gym import utils
9 | from gym.envs.mujoco import mujoco_env
10 |
11 |
12 | class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
13 |
14 | def __init__(self):
15 | dir_path = os.path.dirname(os.path.realpath(__file__))
16 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/pusher.xml' % dir_path, 4)
17 | utils.EzPickle.__init__(self)
18 | self.reset_model()
19 |
20 | def _step(self, a):
21 | obj_pos = self.get_body_com("object")
22 | vec_1 = obj_pos - self.get_body_com("tips_arm")
23 | vec_2 = obj_pos - self.get_body_com("goal")
24 |
25 | reward_near = -np.sum(np.abs(vec_1))
26 | reward_dist = -np.sum(np.abs(vec_2))
27 | reward_ctrl = -np.square(a).sum()
28 | reward = 1.25 * reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near
29 |
30 | self.do_simulation(a, self.frame_skip)
31 | ob = self._get_obs()
32 | done = False
33 | return ob, reward, done, {}
34 |
35 | def viewer_setup(self):
36 | self.viewer.cam.trackbodyid = -1
37 | self.viewer.cam.distance = 4.0
38 |
39 | def reset_model(self):
40 | qpos = np.copy(self.init_qpos)
41 |
42 | self.goal_pos = np.asarray([0, 0])
43 | self.cylinder_pos = np.array([-0.25, 0.15]) + np.random.normal(0, 0.025, [2])
44 |
45 | qpos[-4:-2] = self.cylinder_pos
46 | qpos[-2:] = self.goal_pos
47 | qvel = self.init_qvel + \
48 | self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
49 | qvel[-4:] = 0
50 | self.set_state(qpos, qvel)
51 | self.ac_goal_pos = self.get_body_com("goal")
52 |
53 | return self._get_obs()
54 |
55 | def _get_obs(self):
56 | return np.concatenate([
57 | self.model.data.qpos.flat[:7],
58 | self.model.data.qvel.flat[:7],
59 | self.get_body_com("tips_arm"),
60 | self.get_body_com("object"),
61 | self.get_body_com("goal"),
62 | ])
63 |
--------------------------------------------------------------------------------
/mbbl/env/gym_env/pets_env/pusher_config.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import numpy as np
6 | import gym
7 |
8 |
9 | class PusherConfigModule:
10 | ENV_NAME = "MBRLPusher-v0"
11 | TASK_HORIZON = 150
12 | NTRAIN_ITERS = 100
13 | NROLLOUTS_PER_ITER = 1
14 | PLAN_HOR = 25
15 | INIT_VAR = 0.25
16 | MODEL_IN, MODEL_OUT = 27, 20
17 | GP_NINDUCING_POINTS = 200
18 |
19 | def __init__(self):
20 | self.ENV = gym.make(self.ENV_NAME)
21 | self.NN_TRAIN_CFG = {"epochs": 5}
22 | self.OPT_CFG = {
23 | "Random": {
24 | "popsize": 2500
25 | },
26 | "CEM": {
27 | "popsize": 500,
28 | "num_elites": 50,
29 | "max_iters": 5,
30 | "alpha": 0.1
31 | },
32 | "GBPRandom": {
33 | "popsize": 2500
34 | },
35 | "GBPCEM": {
36 | "popsize": 500,
37 | "num_elites": 50,
38 | "max_iters": 5,
39 | "alpha": 0.1
40 | },
41 | "PWCEM": {
42 | "popsize": 500,
43 | "num_elites": 50,
44 | "max_iters": 5,
45 | "alpha": 0.1
46 | },
47 | "POCEM": {
48 | "popsize": 500,
49 | "num_elites": 50,
50 | "max_iters": 5,
51 | "alpha": 0.1
52 | }
53 | }
54 |
55 | @staticmethod
56 | def obs_postproc(obs, pred):
57 | return obs + pred
58 |
59 | @staticmethod
60 | def targ_proc(obs, next_obs):
61 | return next_obs - obs
62 |
63 | def obs_cost_fn(self, obs):
64 | to_w, og_w = 0.5, 1.25
65 | tip_pos, obj_pos, goal_pos = obs[:, 14:17], obs[:, 17:20], obs[:, -3:]
66 |
67 | tip_obj_dist = np.sum(np.abs(tip_pos - obj_pos), axis=1)
68 | obj_goal_dist = np.sum(np.abs(goal_pos - obj_pos), axis=1)
69 | return to_w * tip_obj_dist + og_w * obj_goal_dist
70 |
71 | @staticmethod
72 | def ac_cost_fn(acs):
73 | return 0.1 * np.sum(np.square(acs), axis=1)
74 |
75 | def nn_constructor(self, model_init_cfg, misc):
76 | pass
77 |
78 | def gp_constructor(self, model_init_cfg):
79 | pass
80 |
81 |
82 | CONFIG_MODULE = PusherConfigModule
83 |
--------------------------------------------------------------------------------
/mbbl/env/gym_env/pets_env/reacher.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 |
7 | import numpy as np
8 | from gym import utils
9 | from gym.envs.mujoco import mujoco_env
10 |
11 |
12 | class Reacher3DEnv(mujoco_env.MujocoEnv, utils.EzPickle):
13 |
14 | def __init__(self):
15 | self.viewer = None
16 | utils.EzPickle.__init__(self)
17 | dir_path = os.path.dirname(os.path.realpath(__file__))
18 | self.goal = np.zeros(3)
19 | mujoco_env.MujocoEnv.__init__(self, os.path.join(dir_path, 'assets/reacher3d.xml'), 2)
20 |
21 | def _step(self, a):
22 | self.do_simulation(a, self.frame_skip)
23 | ob = self._get_obs()
24 | reward = -np.sum(np.square(self.get_EE_pos(ob[None]) - self.goal))
25 | reward -= 0.01 * np.square(a).sum()
26 | done = False
27 | return ob, reward, done, dict(reward_dist=0, reward_ctrl=0)
28 |
29 | def viewer_setup(self):
30 | self.viewer.cam.trackbodyid = 1
31 | self.viewer.cam.distance = 2.5
32 | self.viewer.cam.elevation = -30
33 | self.viewer.cam.azimuth = 270
34 |
35 | def reset_model(self):
36 | qpos, qvel = np.copy(self.init_qpos), np.copy(self.init_qvel)
37 | qpos[-3:] += np.random.normal(loc=0, scale=0.1, size=[3])
38 | qvel[-3:] = 0
39 | self.goal = qpos[-3:]
40 | self.set_state(qpos, qvel)
41 | return self._get_obs()
42 |
43 | def _get_obs(self):
44 | raw_obs = np.concatenate([
45 | self.model.data.qpos.flat, self.model.data.qvel.flat[:-3],
46 | ])
47 |
48 | EE_pos = np.reshape(self.get_EE_pos(raw_obs[None]), [-1])
49 |
50 | return np.concatenate([raw_obs, EE_pos])
51 |
52 | def get_EE_pos(self, states):
53 | theta1, theta2, theta3, theta4, theta5, theta6, theta7 = \
54 | states[:, :1], states[:, 1:2], states[:, 2:3], states[:, 3:4], states[:, 4:5], states[:, 5:6], states[:, 6:]
55 |
56 | rot_axis = np.concatenate([np.cos(theta2) * np.cos(theta1), np.cos(theta2) * np.sin(theta1), -np.sin(theta2)],
57 | axis=1)
58 | rot_perp_axis = np.concatenate([-np.sin(theta1), np.cos(theta1), np.zeros(theta1.shape)], axis=1)
59 | cur_end = np.concatenate([
60 | 0.1 * np.cos(theta1) + 0.4 * np.cos(theta1) * np.cos(theta2),
61 | 0.1 * np.sin(theta1) + 0.4 * np.sin(theta1) * np.cos(theta2) - 0.188,
62 | -0.4 * np.sin(theta2)
63 | ], axis=1)
64 |
65 | for length, hinge, roll in [(0.321, theta4, theta3), (0.16828, theta6, theta5)]:
66 | perp_all_axis = np.cross(rot_axis, rot_perp_axis)
67 | x = np.cos(hinge) * rot_axis
68 | y = np.sin(hinge) * np.sin(roll) * rot_perp_axis
69 | z = -np.sin(hinge) * np.cos(roll) * perp_all_axis
70 | new_rot_axis = x + y + z
71 | new_rot_perp_axis = np.cross(new_rot_axis, rot_axis)
72 | new_rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] = \
73 | rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30]
74 | new_rot_perp_axis /= np.linalg.norm(new_rot_perp_axis, axis=1, keepdims=True)
75 | rot_axis, rot_perp_axis, cur_end = new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis
76 |
77 | return cur_end
78 |
--------------------------------------------------------------------------------
/mbbl/env/gym_env/pets_env/reacher_config.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import numpy as np
6 | import gym
7 |
8 |
9 | class ReacherConfigModule:
10 | ENV_NAME = "MBRLReacher3D-v0"
11 | TASK_HORIZON = 150
12 | NTRAIN_ITERS = 100
13 | NROLLOUTS_PER_ITER = 1
14 | PLAN_HOR = 25
15 | INIT_VAR = 0.25
16 | MODEL_IN, MODEL_OUT = 24, 17
17 | GP_NINDUCING_POINTS = 200
18 |
19 | def __init__(self):
20 | self.ENV = gym.make(self.ENV_NAME)
21 | self.ENV.reset()
22 | self.NN_TRAIN_CFG = {"epochs": 5}
23 | self.OPT_CFG = {
24 | "Random": {
25 | "popsize": 2000
26 | },
27 | "CEM": {
28 | "popsize": 400,
29 | "num_elites": 40,
30 | "max_iters": 5,
31 | "alpha": 0.1
32 | },
33 | "GBPRandom": {
34 | "popsize": 2000
35 | },
36 | "GBPCEM": {
37 | "popsize": 400,
38 | "num_elites": 40,
39 | "max_iters": 5,
40 | "alpha": 0.1
41 | },
42 | "PWCEM": {
43 | "popsize": 400,
44 | "num_elites": 40,
45 | "max_iters": 5,
46 | "alpha": 0.1
47 | },
48 | "POCEM": {
49 | "popsize": 400,
50 | "num_elites": 40,
51 | "max_iters": 5,
52 | "alpha": 0.1
53 | }
54 | }
55 | self.UPDATE_FNS = [self.update_goal]
56 |
57 | @staticmethod
58 | def obs_postproc(obs, pred):
59 | return obs + pred
60 |
61 | @staticmethod
62 | def targ_proc(obs, next_obs):
63 | return next_obs - obs
64 |
65 | def update_goal(self, sess=None):
66 | if sess is not None:
67 | self.goal.load(self.ENV.goal, sess)
68 |
69 | def obs_cost_fn(self, obs):
70 | self.ENV.goal = obs[:, 7: 10]
71 | ee_pos = obs[:, -3:]
72 | return np.sum(np.square(ee_pos - self.ENV.goal), axis=1)
73 |
74 | @staticmethod
75 | def ac_cost_fn(acs):
76 | return 0.01 * np.sum(np.square(acs), axis=1)
77 |
78 | def nn_constructor(self, model_init_cfg, misc=None):
79 | pass
80 |
81 | def gp_constructor(self, model_init_cfg):
82 | pass
83 |
84 | @staticmethod
85 | def get_ee_pos(states, are_tensors=False):
86 | theta1, theta2, theta3, theta4, theta5, theta6, theta7 = \
87 | states[:, :1], states[:, 1:2], states[:, 2:3], states[:, 3:4], states[:, 4:5], states[:, 5:6], states[:, 6:]
88 | rot_axis = np.concatenate([np.cos(theta2) * np.cos(theta1), np.cos(theta2) * np.sin(theta1), -np.sin(theta2)],
89 | axis=1)
90 | rot_perp_axis = np.concatenate([-np.sin(theta1), np.cos(theta1), np.zeros(theta1.shape)], axis=1)
91 | cur_end = np.concatenate([
92 | 0.1 * np.cos(theta1) + 0.4 * np.cos(theta1) * np.cos(theta2),
93 | 0.1 * np.sin(theta1) + 0.4 * np.sin(theta1) * np.cos(theta2) - 0.188,
94 | -0.4 * np.sin(theta2)
95 | ], axis=1)
96 |
97 | for length, hinge, roll in [(0.321, theta4, theta3), (0.16828, theta6, theta5)]:
98 | perp_all_axis = np.cross(rot_axis, rot_perp_axis)
99 | x = np.cos(hinge) * rot_axis
100 | y = np.sin(hinge) * np.sin(roll) * rot_perp_axis
101 | z = -np.sin(hinge) * np.cos(roll) * perp_all_axis
102 | new_rot_axis = x + y + z
103 | new_rot_perp_axis = np.cross(new_rot_axis, rot_axis)
104 | new_rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30] = \
105 | rot_perp_axis[np.linalg.norm(new_rot_perp_axis, axis=1) < 1e-30]
106 | new_rot_perp_axis /= np.linalg.norm(new_rot_perp_axis, axis=1, keepdims=True)
107 | rot_axis, rot_perp_axis, cur_end = new_rot_axis, new_rot_perp_axis, cur_end + length * new_rot_axis
108 |
109 | return cur_end
110 |
111 |
112 | CONFIG_MODULE = ReacherConfigModule
113 |
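
The forward-kinematics helper above (duplicated from Reacher3DEnv.get_EE_pos) maps a batch of states whose first seven columns are the arm's joint angles to one Cartesian end-effector position per row; since it is a static method, it can be probed without a MuJoCo instance. A small sketch with illustrative joint values:

    import numpy as np
    from mbbl.env.gym_env.pets_env.reacher_config import ReacherConfigModule

    states = np.zeros((5, 7))                    # five states, all seven joints at zero
    ee = ReacherConfigModule.get_ee_pos(states)  # shape (5, 3): one xyz position per row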
--------------------------------------------------------------------------------
/mbbl/env/gym_env/point.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | import gym
4 | import numpy as np
5 |
6 | from mbbl.config import init_path
7 | from mbbl.env import base_env_wrapper as bew
8 | from mbbl.env import env_register
9 | from mbbl.env.gym_env import point_env
10 |
11 |
12 | class env(bew.base_env):
13 |
14 | POINT = ['gym_point']
15 |
16 | def __init__(self, env_name, rand_seed, misc_info):
17 |
18 | super(env, self).__init__(env_name, rand_seed, misc_info)
19 | self._base_path = init_path.get_abs_base_dir()
20 |
21 |
22 | def step(self, action):
23 |
24 | ob, reward, _, info = self._env.step(action)
25 |
26 | self._current_step += 1
27 | if self._current_step > self._env_info['max_length']:
28 | done = True
29 | else:
30 | done = False # will raise warnings -> set logger flag to ignore
31 | self._old_ob = np.array(ob)
32 | return ob, reward, done, info
33 |
34 | def reset(self):
35 | self._current_step = 0
36 | self._old_ob = self._env.reset()
37 | return self._old_ob, 0.0, False, {}
38 |
39 | def _build_env(self):
40 | _env_name = {
41 | 'gym_point': 'Point-v0',
42 | }
43 |
44 | # make the environments
45 | self._env = gym.make(_env_name[self._env_name])
46 | self._env_info = env_register.get_env_info(self._env_name)
47 |
48 | def _set_dynamics_api(self):
49 |
50 | def fdynamics(data_dict):
51 | self._env.env.state = data_dict['start_state']
52 | action = data_dict['action']
53 | return self._env.step(action)[0]
54 | self.fdynamics = fdynamics
55 |
56 | def _set_reward_api(self):
57 |
58 | def reward(data_dict):
59 | action = np.clip(data_dict['action'], -0.025, 0.025)
60 | state = np.clip(data_dict['start_state'] + action, -1, 1)
61 | return -np.linalg.norm(state)
62 |
63 | self.reward = reward
64 |
65 | def _set_groundtruth_api(self):
66 | self._set_dynamics_api()
67 | self._set_reward_api()
68 |
69 |
--------------------------------------------------------------------------------
/mbbl/env/gym_env/point_env.py:
--------------------------------------------------------------------------------
1 | """
2 | This environment is adapted from the Point environment developed by Rocky Duan, Peter Chen, Pieter Abbeel for the Berkeley Deep RL Bootcamp, August 2017. Bootcamp website with slides and lecture videos: https://sites.google.com/view/deep-rl-bootcamp/.
3 | """
4 |
5 | from gym import Env
6 | from gym import spaces
7 | from gym.envs.registration import register
8 | from gym.utils import seeding
9 | import numpy as np
10 |
11 |
12 | class PointEnv(Env):
13 | metadata = {
14 | 'render.modes': ['human', 'rgb_array'],
15 | 'video.frames_per_second': 50
16 | }
17 |
18 | def __init__(self):
19 | self.action_space = spaces.Box(low=-1, high=1, shape=(2,))
20 | self.observation_space = spaces.Box(low=-1, high=1, shape=(2,))
21 |
22 | self._seed()
23 | self.viewer = None
24 | self.state = None
25 |
26 | def _seed(self, seed=None):
27 | self.np_random, seed = seeding.np_random(seed)
28 | return [seed]
29 |
30 | def _step(self, action):
31 | action = np.clip(action, -0.025, 0.025)
32 | self.state = np.clip(self.state + action, -1, 1)
33 | return np.array(self.state), -np.linalg.norm(self.state), False, {}
34 |
35 | def _reset(self):
36 | while True:
37 | self.state = self.np_random.uniform(low=-1, high=1, size=(2,))
38 | # Sample states that are far away
39 | if np.linalg.norm(self.state) > 0.9:
40 | break
41 | return np.array(self.state)
42 |
43 | def _render(self, mode='human', close=False):
44 | if close:
45 | if self.viewer is not None:
46 | self.viewer.close()
47 | self.viewer = None
48 | return
49 |
50 | screen_width = 800
51 | screen_height = 800
52 |
53 | if self.viewer is None:
54 | from gym.envs.classic_control import rendering
55 | self.viewer = rendering.Viewer(screen_width, screen_height)
56 |
57 | agent = rendering.make_circle(
58 | min(screen_height, screen_width) * 0.03)
59 | origin = rendering.make_circle(
60 | min(screen_height, screen_width) * 0.03)
61 | trans = rendering.Transform(translation=(0, 0))
62 | agent.add_attr(trans)
63 | self.trans = trans
64 | agent.set_color(1, 0, 0)
65 | origin.set_color(0, 0, 0)
66 | origin.add_attr(rendering.Transform(
67 | translation=(screen_width // 2, screen_height // 2)))
68 | self.viewer.add_geom(agent)
69 | self.viewer.add_geom(origin)
70 |
71 | # self.trans.set_translation(0, 0)
72 | self.trans.set_translation(
73 | (self.state[0] + 1) / 2 * screen_width,
74 | (self.state[1] + 1) / 2 * screen_height,
75 | )
76 |
77 | return self.viewer.render(return_rgb_array=mode == 'rgb_array')
78 |
79 | register(
80 | 'Point-v0',
81 |     entry_point='mbbl.env.gym_env.point_env:PointEnv',
82 | timestep_limit=40,
83 | )
84 |
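
The analytic reward API in mbbl/env/gym_env/point.py (earlier in this directory) mirrors the clipping done in _step above; a small consistency sketch with illustrative values:

    import numpy as np

    state = np.array([0.5, -0.5])
    action = np.array([0.1, 0.1])                      # outside the +/-0.025 band
    clipped_action = np.clip(action, -0.025, 0.025)
    next_state = np.clip(state + clipped_action, -1, 1)
    reward = -np.linalg.norm(next_state)               # identical to what _step returns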
--------------------------------------------------------------------------------
/mbbl/env/render.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Wed Aug 8 14:55:24 2018
5 |
6 | @author: matthewszhang
7 | """
8 | import pickle
9 | import time
10 |
11 | import numpy as np
12 |
13 | from mbbl.env.env_register import make_env
14 |
15 | class rendering(object):
16 | def __init__(self, env_name):
17 | self.env, _ = make_env(env_name, 1234)
18 | self.env.reset()
19 |
20 | def render(self, transition):
21 | self.env.fdynamics(transition)
22 | self.env._env.render()
23 | time.sleep(1/30)
24 |
25 | def get_rendering_config():
26 | import argparse
27 | parser = argparse.ArgumentParser(description='Get rendering from states')
28 |
29 | parser.add_argument("--task", type=str, default='gym_cheetah',
30 | help='the mujoco environment to test')
31 | parser.add_argument("--render_file", type=str, default='ep-0',
32 | help='pickle outputs to render')
33 | return parser
34 |
35 | def main():
36 | parser = get_rendering_config()
37 | args = parser.parse_args()
38 |
39 | render_env = rendering(args.task)
40 | with open(args.render_file, 'rb') as pickle_load:
41 | transition_data = pickle.load(pickle_load)
42 |
43 | for transition in transition_data:
44 | render_transition = {
45 | 'start_state':np.asarray(transition['start_state']),
46 | 'action':np.asarray(transition['action']),
47 | }
48 | render_env.render(render_transition)
49 |
50 | if __name__ == '__main__':
51 | main()
52 |
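
The script expects a pickle holding a list of transition dicts with 'start_state' and 'action' entries, which matches the layout render_wrapper.py (the next file) writes out. A sketch of preparing and replaying such a file (the task name, dimensions and file name are illustrative assumptions):

    import pickle
    import numpy as np

    transitions = [{'start_state': np.zeros(17).tolist(),
                    'action': np.zeros(6).tolist()}]
    with open('ep_demo.p', 'wb') as f:
        pickle.dump(transitions, f)
    # then, from the repository root:
    #     python mbbl/env/render.py --task gym_cheetah --render_file ep_demo.p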
--------------------------------------------------------------------------------
/mbbl/env/render_wrapper.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Wed Aug 8 12:03:11 2018
5 |
6 | @author: matthewszhang
7 | """
8 | import os.path as osp
9 | import pickle
10 | import re
11 |
12 | from mbbl.util.common import logger
13 |
14 | RENDER_EPISODE = 100
15 |
16 | class render_wrapper(object):
17 | def __init__(self, env_name, *args, **kwargs):
18 | remove_render = re.compile(r'__render$')
19 |
20 | self.env_name = remove_render.sub('', env_name)
21 | from mbbl.env import env_register
22 | self.env, _ = env_register.make_env(self.env_name, *args, **kwargs)
23 | self.episode_number = 0
24 |
25 | # Getting path from logger
26 | self.path = logger._get_path()
27 | self.obs_buffer = []
28 |
29 | def step(self, action, *args, **kwargs):
30 | if (self.episode_number - 1) % RENDER_EPISODE == 0:
31 | self.obs_buffer.append({
32 | 'start_state':self.env._old_ob.tolist(),
33 | 'action':action.tolist()
34 | })
35 | return self.env.step(action, *args, **kwargs)
36 |
37 | def reset(self, *args, **kwargs):
38 | self.episode_number += 1
39 | if self.obs_buffer:
40 | file_name = osp.join(self.path, 'ep_{}.p'.format(self.episode_number))
41 | with open(file_name, 'wb') as pickle_file:
42 | pickle.dump(self.obs_buffer, pickle_file, protocol=pickle.HIGHEST_PROTOCOL)
43 | self.obs_buffer = []
44 |
45 | return self.env.reset(*args, **kwargs)
46 |
47 | def fdynamics(self, *args, **kwargs):
48 | return self.env.fdynamics(*args, **kwargs)
49 |
50 | def reward(self, *args, **kwargs):
51 | return self.env.reward(*args, **kwargs)
52 |
53 | def reward_derivative(self, *args, **kwargs):
54 | return self.env.reward_derivative(*args, **kwargs)
55 |
--------------------------------------------------------------------------------
/mbbl/network/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/network/__init__.py
--------------------------------------------------------------------------------
/mbbl/network/dynamics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/network/dynamics/__init__.py
--------------------------------------------------------------------------------
/mbbl/network/dynamics/stochastic_forward_dynamics.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @author:
3 | # Tingwu Wang
4 | # @brief:
5 | # -----------------------------------------------------------------------------
6 | from .base_dynamics import base_dynamics_network
7 | from mbbl.config import init_path
8 | from mbbl.util.common import logger
9 | from mbbl.util.common import tf_networks
10 |
11 |
12 | class dynamics_network(base_dynamics_network):
13 | '''
14 | @brief:
15 | '''
16 |
17 | def __init__(self, args, session, name_scope,
18 | observation_size, action_size):
19 | '''
20 | @input:
21 | @ob_placeholder:
22 | if this placeholder is not given, we will make one in this
23 | class.
24 |
25 | @trainable:
26 | If it is set to true, then the policy weights will be
27 | trained. It is useful when the class is a subnet which
28 | is not trainable
29 | '''
30 | raise NotImplementedError  # this dynamics variant is still a stub; the initialization below never runs
31 | super(dynamics_network, self).__init__(
32 | args, session, name_scope, observation_size, action_size
33 | )
34 | self._base_dir = init_path.get_abs_base_dir()
35 | self._debug_it = 0
36 |
--------------------------------------------------------------------------------
/mbbl/network/policy/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/network/policy/__init__.py
--------------------------------------------------------------------------------
/mbbl/network/policy/base_policy.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @author:
3 | # Tingwu Wang
4 | # -----------------------------------------------------------------------------
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from mbbl.config import init_path
9 | from mbbl.util.common import whitening_util
10 | from mbbl.util.common import tf_utils
11 |
12 |
13 | class base_policy_network(object):
14 | '''
15 | @brief:
16 | In this object class, we define the network structure, the restore
17 | function and save function.
18 | It will only be called in the agent/agent.py
19 | '''
20 |
21 | def __init__(self, args, session, name_scope,
22 | observation_size, action_size):
23 | self.args = args
24 |
25 | self._session = session
26 | self._name_scope = name_scope
27 |
28 | self._observation_size = observation_size
29 | self._action_size = action_size
30 |
31 | self._task_name = args.task_name
32 | self._network_shape = args.policy_network_shape
33 |
34 | self._npr = np.random.RandomState(args.seed)
35 |
36 | self._whitening_operator = {}
37 | self._whitening_variable = []
38 | self._base_dir = init_path.get_abs_base_dir()
39 |
40 | def build_network(self):
41 | raise NotImplementedError
42 |
43 | def build_loss(self):
44 | raise NotImplementedError
45 |
46 | def _build_ph(self):
47 |
48 | # initialize the running mean and std (whitening)
49 | whitening_util.add_whitening_operator(
50 | self._whitening_operator, self._whitening_variable,
51 | 'state', self._observation_size
52 | )
53 |
54 | # initialize the input placeholder
55 | self._input_ph = {
56 | 'start_state': tf.placeholder(
57 | tf.float32, [None, self._observation_size], name='start_state'
58 | )
59 | }
60 |
61 | def get_input_placeholder(self):
62 | return self._input_ph
63 |
64 | def get_weights(self):
65 | return None
66 |
67 | def set_weights(self, weights_dict):
68 | pass
69 |
70 | def _set_var_list(self):
71 | # collect the tf variable and the trainable tf variable
72 | self._trainable_var_list = [var for var in tf.trainable_variables()
73 | if self._name_scope in var.name]
74 |
75 | self._all_var_list = [var for var in tf.global_variables()
76 | if self._name_scope in var.name]
77 |
78 | # the weights that actually matter
79 | self._network_var_list = \
80 | self._trainable_var_list + self._whitening_variable
81 |
82 | self._set_network_weights = tf_utils.set_network_weights(
83 | self._session, self._network_var_list, self._name_scope
84 | )
85 |
86 | self._get_network_weights = tf_utils.get_network_weights(
87 | self._session, self._network_var_list, self._name_scope
88 | )
89 |
90 | def load_checkpoint(self, ckpt_path):
91 | pass
92 |
93 | def save_checkpoint(self, ckpt_path):
94 | pass
95 |
96 | def get_whitening_operator(self):
97 | return self._whitening_operator
98 |
99 | def _set_whitening_var(self, whitening_stats):
100 | whitening_util.set_whitening_var(
101 | self._session, self._whitening_operator, whitening_stats, ['state']
102 | )
103 |
104 | def train(self, data_dict, replay_buffer, training_info={}):
105 | raise NotImplementedError
106 |
107 | def eval(self, data_dict):
108 | raise NotImplementedError
109 |
110 | def act(self, data_dict):
111 | raise NotImplementedError
112 |
--------------------------------------------------------------------------------
/mbbl/network/policy/cem_policy.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @author:
3 | # Tingwu Wang
4 | # -----------------------------------------------------------------------------
5 | from .base_policy import base_policy_network
6 | from mbbl.config import init_path
7 |
8 |
9 | class policy_network(base_policy_network):
10 | '''
11 | @brief:
12 | In this object class, we define the network structure, the restore
13 | function and save function.
14 | It will only be called in the agent/agent.py
15 | '''
16 |
17 | def __init__(self, args, session, name_scope,
18 | observation_size, action_size):
19 |
20 | super(policy_network, self).__init__(
21 | args, session, name_scope, observation_size, action_size
22 | )
23 | self._base_dir = init_path.get_abs_base_dir()
24 |
25 | def build_network(self):
26 | pass
27 |
28 | def build_loss(self):
29 | pass
30 |
31 | def train(self, data_dict, replay_buffer, training_info={}):
32 | pass
33 |
34 | def act(self, data_dict):
35 | raise NotImplementedError
36 |
--------------------------------------------------------------------------------
/mbbl/network/policy/random_policy.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @author:
3 | # Tingwu Wang
4 | # -----------------------------------------------------------------------------
5 | from .base_policy import base_policy_network
6 | from mbbl.config import init_path
7 |
8 |
9 | class policy_network(base_policy_network):
10 | '''
11 | @brief:
12 | In this object class, we define the network structure, the restore
13 | function and save function.
14 | It will only be called in the agent/agent.py
15 | '''
16 |
17 | def __init__(self, args, session, name_scope,
18 | observation_size, action_size):
19 |
20 | super(policy_network, self).__init__(
21 | args, session, name_scope, observation_size, action_size
22 | )
23 | self._base_dir = init_path.get_abs_base_dir()
24 |
25 | def build_network(self):
26 | pass
27 |
28 | def build_loss(self):
29 | pass
30 |
31 | def train(self, data_dict, replay_buffer, training_info={}):
32 | pass
33 |
34 | def act(self, data_dict):
35 | # action range from -1 to 1
36 | action = self._npr.uniform(
37 | -1, 1, [len(data_dict['start_state']), self._action_size]
38 | )
39 | return action, -1, -1
40 |
--------------------------------------------------------------------------------
/mbbl/network/reward/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/network/reward/__init__.py
--------------------------------------------------------------------------------
/mbbl/network/reward/base_reward.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @author:
3 | # Tingwu Wang
4 | # -----------------------------------------------------------------------------
5 | import tensorflow as tf
6 | import numpy as np
7 |
8 | from mbbl.config import init_path
9 | from mbbl.util.common import whitening_util
10 |
11 |
12 | class base_reward_network(object):
13 | '''
14 | @brief:
15 | '''
16 |
17 | def __init__(self, args, session, name_scope,
18 | observation_size, action_size):
19 | self.args = args
20 |
21 | self._session = session
22 | self._name_scope = name_scope
23 |
24 | self._observation_size = observation_size
25 | self._action_size = action_size
26 | self._output_size = 1
27 |
28 | self._task_name = args.task_name
29 | self._network_shape = args.reward_network_shape
30 |
31 | self._npr = np.random.RandomState(args.seed)
32 |
33 | self._whitening_operator = {}
34 | self._whitening_variable = []
35 | self._base_dir = init_path.get_abs_base_dir()
36 |
37 | def build_network(self):
38 | pass
39 |
40 | def build_loss(self):
41 | pass
42 |
43 | def get_weights(self):
44 | return None
45 |
46 | def set_weights(self, weights_dict):
47 | pass
48 |
49 | def _build_ph(self):
50 |
51 | # initialize the running mean and std (whitening)
52 | whitening_util.add_whitening_operator(
53 | self._whitening_operator, self._whitening_variable,
54 | 'state', self._observation_size
55 | )
56 |
57 | # initialize the input placeholder
58 | self._input_ph = {
59 | 'start_state': tf.placeholder(
60 | tf.float32, [None, self._observation_size], name='start_state'
61 | )
62 | }
63 |
64 | def get_input_placeholder(self):
65 | return self._input_ph
66 |
67 | def _set_var_list(self):
68 | # collect the tf variable and the trainable tf variable
69 | self._trainable_var_list = [var for var in tf.trainable_variables()
70 | if self._name_scope in var.name]
71 |
72 | self._all_var_list = [var for var in tf.global_variables()
73 | if self._name_scope in var.name]
74 |
75 | # the weights that actually matter
76 | self._network_var_list = \
77 | self._trainable_var_list + self._whitening_variable
78 |
79 | def load_checkpoint(self, ckpt_path):
80 | pass
81 |
82 | def save_checkpoint(self, ckpt_path):
83 | pass
84 |
85 | def get_whitening_operator(self):
86 | return self._whitening_operator
87 |
88 | def _set_whitening_var(self, whitening_stats):
89 | whitening_util.set_whitening_var(
90 | self._session, self._whitening_operator, whitening_stats, ['state']
91 | )
92 |
93 | def train(self, data_dict, replay_buffer, training_info={}):
94 | pass
95 |
96 | def eval(self, data_dict):
97 | raise NotImplementedError
98 |
99 | def pred(self, data_dict):
100 | raise NotImplementedError
101 |
102 | def use_groundtruth_network(self):
103 | return False
104 |
--------------------------------------------------------------------------------
/mbbl/network/reward/deepmimic_reward.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @author:
3 | # Tingwu Wang
4 | # @brief:
5 | # -----------------------------------------------------------------------------
6 |
7 | from .base_reward import base_reward_network
8 | from mbbl.config import init_path
9 | from mbbl.util.il import expert_data_util
10 | # import numpy as np
11 | # from mbbl.util.common import tf_networks
12 | # from mbbl.util.common import tf_utils
13 | # from mbbl.util.common import logger
14 | # import tensorflow as tf
15 |
16 |
17 | class reward_network(base_reward_network):
18 | '''
19 | @brief:
20 | '''
21 |
22 | def __init__(self, args, session, name_scope,
23 | observation_size, action_size):
24 | '''
25 | @input:
26 | @ob_placeholder:
27 | if this placeholder is not given, we will make one in this
28 | class.
29 |
30 | @trainable:
31 | If it is set to true, then the policy weights will be
32 | trained. It is useful when the class is a subnet which
33 | is not trainable
34 | '''
35 | super(reward_network, self).__init__(
36 | args, session, name_scope, observation_size, action_size
37 | )
38 | self._base_dir = init_path.get_abs_base_dir()
39 |
40 | # load the expert data
41 | self._expert_trajectory_obs = expert_data_util.load_expert_trajectory(
42 | self.args.expert_data_name, self.args.traj_episode_num
43 | )
44 |
45 | def build_network(self):
46 | """ @Brief:
47 | in deepmimic, we don't need a neural network to produce reward
48 | """
49 | pass
50 |
51 | def build_loss(self):
52 | """ @Brief:
53 | Similarly, in deepmimic, we don't need a neural network
54 | """
55 | pass
56 |
57 | def train(self, data_dict, replay_buffer, training_info={}):
58 | """ @brief:
59 | """
60 | return {}
61 |
62 | def eval(self, data_dict):
63 | pass
64 |
65 | def use_groundtruth_network(self):
66 | return False
67 |
68 | def generate_rewards(self, rollout_data):
69 | """@brief:
70 | This function should be called before _preprocess_data
71 | """
72 | for path in rollout_data:
73 | # the predicted value function (baseline function)
74 | path["raw_rewards"] = path['rewards'] # preserve the raw reward
75 |
76 | # TODO: generate the r = r_task + r_imitation, see the paper,
77 | # use self._expert_trajectory_obs
78 | path["rewards"] = 0.0 * path['raw_rewards']
79 | return rollout_data
80 |
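
One possible way to fill in the TODO above, loosely following the weighted-sum reward of the DeepMimic paper. Everything here is an assumption rather than the repository's implementation: the weights, the Gaussian width, the path['obs'] key, and the hypothetical helper `match_expert_obs` that pairs each step of a path with an expert observation.

    import numpy as np

    W_TASK, W_IMIT, SIGMA = 0.5, 0.5, 2.0

    def imitation_reward(path_obs, expert_obs):
        # per-step Gaussian similarity between rollout and matched expert observations
        sq_dist = np.sum(np.square(path_obs - expert_obs), axis=-1)
        return np.exp(-sq_dist / SIGMA)

    # inside generate_rewards(), instead of zeroing the rewards:
    #     expert_obs = match_expert_obs(path['obs'], self._expert_trajectory_obs)
    #     path['rewards'] = W_TASK * path['raw_rewards'] + \
    #         W_IMIT * imitation_reward(path['obs'], expert_obs)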
--------------------------------------------------------------------------------
/mbbl/network/reward/groundtruth_reward.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @author:
3 | # Tingwu Wang
4 | # @brief:
5 | # -----------------------------------------------------------------------------
6 | import numpy as np
7 |
8 | from .base_reward import base_reward_network
9 | from mbbl.config import init_path
10 | from mbbl.env import env_register
11 |
12 |
13 | class reward_network(base_reward_network):
14 | '''
15 | @brief:
16 | '''
17 |
18 | def __init__(self, args, session, name_scope,
19 | observation_size, action_size):
20 | '''
21 | @input:
22 | @ob_placeholder:
23 | if this placeholder is not given, we will make one in this
24 | class.
25 |
26 | @trainable:
27 | If it is set to true, then the policy weights will be
28 | trained. It is useful when the class is a subnet which
29 | is not trainable
30 | '''
31 | super(reward_network, self).__init__(
32 | args, session, name_scope, observation_size, action_size
33 | )
34 | self._base_dir = init_path.get_abs_base_dir()
35 |
36 | def build_network(self):
37 | self._env, self._env_info = env_register.make_env(
38 | self.args.task_name, self._npr.randint(0, 9999)
39 | )
40 |
41 | def build_loss(self):
42 | pass
43 |
44 | def train(self, data_dict, replay_buffer, training_info={}):
45 | pass
46 |
47 | def eval(self, data_dict):
48 | pass
49 |
50 | def pred(self, data_dict):
51 | reward = []
52 | for i_data in range(len(data_dict['action'])):
53 | key_list = ['start_state', 'action', 'next_state'] \
54 | if 'next_state' in data_dict else ['start_state', 'action']
55 |
56 | i_reward = self._env.reward(
57 | {key: data_dict[key][i_data] for key in key_list}
58 | )
59 | reward.append(i_reward)
60 | return np.stack(reward), -1, -1
61 |
62 | def use_groundtruth_network(self):
63 | return True
64 |
65 | def get_derivative(self, data_dict, target):
66 | derivative_data = {}
67 | for derivative_key in target:
68 | derivative_data[derivative_key] = \
69 | self._env.reward_derivative(data_dict, derivative_key)
70 | return derivative_data
71 |
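
A minimal sketch of the data layout pred() consumes (sizes are illustrative): batched arrays with one row per transition, with 'next_state' optional and only used when the task's reward depends on it.

    import numpy as np

    data_dict = {
        'start_state': np.zeros((3, 17)),   # three transitions, 17-dim observations
        'action': np.zeros((3, 6)),         # matching 6-dim actions
    }
    # pred() loops over the rows, queries self._env.reward(...) for each one,
    # and returns (np.stack(rewards), -1, -1).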
--------------------------------------------------------------------------------
/mbbl/sampler/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/sampler/__init__.py
--------------------------------------------------------------------------------
/mbbl/sampler/base_sampler.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # Tingwu Wang
3 | # -----------------------------------------------------------------------------
4 | import multiprocessing
5 |
6 | import numpy as np
7 |
8 | from mbbl.config import init_path
9 | from mbbl.env import env_register
10 | from mbbl.util.common import parallel_util
11 |
12 |
13 | class base_sampler(object):
14 |
15 | def __init__(self, args, worker_type, network_type):
16 | '''
17 | @brief:
18 | the master agent has several actors (or samplers) to do the
19 | sampling for it.
20 | '''
21 | self.args = args
22 | self._npr = np.random.RandomState(args.seed + 23333)
23 | self._observation_size, self._action_size, _ = \
24 | env_register.io_information(self.args.task)
25 | self._worker_type = worker_type
26 | self._network_type = network_type
27 |
28 | # init the multiprocess actors
29 | self._task_queue = multiprocessing.JoinableQueue()
30 | self._result_queue = multiprocessing.Queue()
31 | self._init_workers()
32 | self._build_env()
33 | self._base_path = init_path.get_abs_base_dir()
34 |
35 | self._current_iteration = 0
36 |
37 | def set_weights(self, weights):
38 | for i_agent in range(self.args.num_workers):
39 | self._task_queue.put((parallel_util.AGENT_SET_WEIGHTS,
40 | weights))
41 | self._task_queue.join()
42 |
43 | def end(self):
44 | for i_agent in range(self.args.num_workers):
45 | self._task_queue.put((parallel_util.END_ROLLOUT_SIGNAL, None))
46 |
47 | def rollouts_using_worker_planning(self, num_timesteps=None,
48 | use_random_action=False):
49 | """ @brief:
50 | Workers are only used to do the planning.
51 | The sampler will choose the control signals and interact with
52 | the env.
53 |
54 | Run the experiments until a total of @timesteps_per_batch
55 | timesteps are collected.
56 | @return:
57 | {'data': None}
58 | """
59 | raise NotImplementedError
60 |
61 | def rollouts_using_worker_playing(self, num_timesteps=None,
62 | use_random_action=False,
63 | using_true_env=False):
64 | """ @brief:
65 | Workers do the planning, choose the control signals and
66 | interact with the env themselves. The sampler only gathers
67 | the resulting trajectories.
68 |
69 | Run the experiments until a total of @timesteps_per_batch
70 | timesteps are collected.
71 | @input:
72 | If @using_true_env is set to True, the worker will interact with
73 | the environment. Otherwise it will interact with the env it
74 | models (the trainable dynamics and reward)
75 | @return:
76 | {'data': None}
77 | """
78 | raise NotImplementedError
79 |
80 | def _init_workers(self):
81 | '''
82 | @brief: init the actors and start the multiprocessing
83 | '''
84 | self._actors = []
85 |
86 | # the sub actor that only do the sampling
87 | for i in range(self.args.num_workers):
88 | self._actors.append(
89 | self._worker_type.worker(
90 | self.args, self._observation_size, self._action_size,
91 | self._network_type, self._task_queue, self._result_queue, i,
92 | )
93 | )
94 |
95 | # todo: start
96 | for i_actor in self._actors:
97 | i_actor.start()
98 |
99 | def _build_env(self):
100 | self._env, self._env_info = env_register.make_env(
101 | self.args.task, self._npr.randint(0, 999999),
102 | {'allow_monitor': self.args.monitor}
103 | )
104 |
--------------------------------------------------------------------------------
/mbbl/sampler/readme.md:
--------------------------------------------------------------------------------
1 | rs: random shooting method
2 |
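
A minimal sketch of the random-shooting idea named above (the model callables, action bounds and sizes are illustrative assumptions): sample many random action sequences, roll each one through the model, and execute the first action of the best sequence.

    import numpy as np

    def random_shooting(state, dynamics_fn, reward_fn, action_size,
                        num_candidates=1000, horizon=10, npr=np.random):
        # candidate action sequences, uniform in [-1, 1] like the random policy
        actions = npr.uniform(-1, 1, [num_candidates, horizon, action_size])
        returns = np.zeros(num_candidates)
        states = np.repeat(state[None, :], num_candidates, axis=0)
        for t in range(horizon):
            returns += reward_fn(states, actions[:, t])     # batched reward model
            states = dynamics_fn(states, actions[:, t])     # batched dynamics model
        return actions[np.argmax(returns), 0]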
--------------------------------------------------------------------------------
/mbbl/sampler/singletask_metrpo_sampler.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # Tingwu Wang
3 | # -----------------------------------------------------------------------------
4 | import numpy as np
5 |
6 | from .base_sampler import base_sampler
7 | from mbbl.config import init_path
8 | from mbbl.util.common import logger
9 | from mbbl.util.common import parallel_util
10 |
11 |
12 | class sampler(base_sampler):
13 |
14 | def __init__(self, args, worker_type, network_type):
15 | '''
16 | @brief:
17 | the master agent has several actors (or samplers) to do the
18 | sampling for it.
19 | '''
20 | super(sampler, self).__init__(args, worker_type, network_type)
21 | self._base_path = init_path.get_abs_base_dir()
22 | self._avg_episode_len = self._env_info['max_length']
23 |
24 | def rollouts_using_worker_playing(self, num_timesteps=None,
25 | use_random_action=False,
26 | use_true_env=False):
27 | """ @brief:
28 | In this case, the sampler will call workers to generate data
29 | """
30 | self._current_iteration += 1
31 | num_timesteps_received = 0
32 | timesteps_needed = self.args.timesteps_per_batch \
33 | if num_timesteps is None else num_timesteps
34 | rollout_data = []
35 |
36 | while True:
37 | # how many episodes are expected to complete the current dataset?
38 | num_estimated_episode = int(
39 | np.ceil(timesteps_needed / self._avg_episode_len)
40 | )
41 |
42 | # send out the task for each worker to play
43 | for _ in range(num_estimated_episode):
44 | self._task_queue.put((parallel_util.WORKER_PLAYING,
45 | {'use_true_env': use_true_env,
46 | 'use_random_action': use_random_action}))
47 | self._task_queue.join()
48 |
49 | # collect the data
50 | for _ in range(num_estimated_episode):
51 | traj_episode = self._result_queue.get()
52 | rollout_data.append(traj_episode)
53 | num_timesteps_received += len(traj_episode['rewards'])
54 |
55 | # update average timesteps per episode and timestep remains
56 | self._avg_episode_len = \
57 | float(num_timesteps_received) / len(rollout_data)
58 | timesteps_needed = self.args.timesteps_per_batch - \
59 | num_timesteps_received
60 |
61 | logger.info('Finished {}th episode'.format(len(rollout_data)))
62 | if timesteps_needed <= 0 or self.args.test:
63 | break
64 |
65 | logger.info('{} timesteps from {} episodes collected'.format(
66 | num_timesteps_received, len(rollout_data))
67 | )
68 |
69 | return {'data': rollout_data}
70 |
--------------------------------------------------------------------------------
/mbbl/sampler/singletask_random_sampler.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # Tingwu Wang
3 | # -----------------------------------------------------------------------------
4 |
5 | from mbbl.config import init_path
6 | from mbbl.sampler import singletask_sampler
7 | from mbbl.util.common import logger
8 |
9 |
10 | class sampler(singletask_sampler.sampler):
11 |
12 | def __init__(self, args, worker_type, network_type):
13 | '''
14 | @brief:
15 | the master agent has several actors (or samplers) to do the
16 | sampling for it.
17 | '''
18 | super(sampler, self).__init__(args, worker_type, network_type)
19 | self._base_path = init_path.get_abs_base_dir()
20 |
21 | def rollouts_using_worker_planning(self, num_timesteps=None,
22 | use_random_action=False):
23 | ''' @brief:
24 | Run the experiments until a total of @timesteps_per_batch
25 | timesteps are collected.
26 | '''
27 | self._current_iteration += 1
28 | num_timesteps_received = 0
29 | timesteps_needed = self.args.timesteps_per_batch \
30 | if num_timesteps is None else num_timesteps
31 | rollout_data = []
32 |
33 | while True:
34 | # play one episode and record the data
35 | traj_episode = self._play(use_random_action)
36 | logger.info('done with episode')
37 | rollout_data.append(traj_episode)
38 | num_timesteps_received += len(traj_episode['rewards'])
39 |
40 | # update the number of timesteps still needed
41 | timesteps_needed = self.args.timesteps_per_batch - \
42 | num_timesteps_received
43 |
44 | if timesteps_needed <= 0 or self.args.test:
45 | break
46 |
47 | logger.info('{} timesteps from {} episodes collected'.format(
48 | num_timesteps_received, len(rollout_data))
49 | )
50 |
51 | return {'data': rollout_data}
52 |
53 | def _act(self, state, control_info={'use_random_action': False}):
54 | # use random policy
55 | action = self._npr.uniform(-1, 1, [self._action_size])
56 | return action, [-1], [-1]
57 |
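Unlike the worker-driven samplers, this one plays episodes in the main process and simply draws uniform actions in [-1, 1]. A small sketch of consuming its output, where sampler_instance is a hypothetical, already-constructed instance of the class above:

    import numpy as np

    rollout = sampler_instance.rollouts_using_worker_planning(use_random_action=True)
    episode_returns = [np.sum(episode['rewards']) for episode in rollout['data']]
    print('mean return over {} episodes: {:.3f}'.format(
        len(episode_returns), np.mean(episode_returns)))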
--------------------------------------------------------------------------------
/mbbl/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/trainer/__init__.py
--------------------------------------------------------------------------------
/mbbl/trainer/gail_trainer.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @brief:
3 | # The optimization agent is responsible for doing the updates.
4 | # @author:
5 | # ------------------------------------------------------------------------------
6 | from .base_trainer import base_trainer
7 | from mbbl.util.common import logger
8 | import numpy as np
9 | from collections import OrderedDict
10 |
11 |
12 | class trainer(base_trainer):
13 |
14 | def __init__(self, args, network_type, task_queue, result_queue,
15 | name_scope='trainer'):
16 | # the base agent
17 | super(trainer, self).__init__(
18 | args=args, network_type=network_type,
19 | task_queue=task_queue, result_queue=result_queue,
20 | name_scope=name_scope
21 | )
22 |
23 | def _update_parameters(self, rollout_data, training_info):
24 | # update the running whitening statistics with the new rollouts
25 | self._update_whitening_stats(rollout_data)
26 |
27 | # generate the reward from discriminator
28 | rollout_data = self._network['reward'][0].generate_rewards(rollout_data)
29 |
30 | training_data = self._preprocess_data(rollout_data)
31 | training_stats = OrderedDict()
32 | training_stats['avg_reward'] = training_data['avg_reward']
33 | training_stats['avg_reward_std'] = training_data['avg_reward_std']
34 |
35 | assert 'reward' in training_info['network_to_train']
36 | # train the policy
37 | for key in training_info['network_to_train']:
38 | for i_network in range(self._num_model_ensemble[key]):
39 | i_stats = self._network[key][i_network].train(
40 | training_data, self._replay_buffer, training_info={}
41 | )
42 | if i_stats is not None:
43 | training_stats.update(i_stats)
44 | self._replay_buffer.add_data(training_data)
45 |
46 | # record the actual reward (not from the discriminator)
47 | self._get_groundtruth_reward(rollout_data, training_stats)
48 | return training_stats
49 |
50 | def _get_groundtruth_reward(self, rollout_data, training_stats):
51 |
52 | for i_episode in rollout_data:
53 | i_episode['raw_episodic_reward'] = sum(i_episode['raw_rewards'])
54 | avg_reward = np.mean([i_episode['raw_episodic_reward']
55 | for i_episode in rollout_data])
56 | logger.info('Raw reward: {}'.format(avg_reward))
57 | training_stats['RAW_reward'] = avg_reward
58 |
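The GAIL trainer optimizes against rewards produced by the discriminator, so the environment's own reward is tracked separately under 'raw_rewards' for logging. A small sketch of that bookkeeping on made-up rollouts:

    import numpy as np

    rollout_data = [{'raw_rewards': [1.0, 0.5, 0.2]},   # episode 1
                    {'raw_rewards': [0.8, 0.9]}]        # episode 2
    for episode in rollout_data:
        episode['raw_episodic_reward'] = sum(episode['raw_rewards'])
    avg_raw_reward = np.mean([e['raw_episodic_reward'] for e in rollout_data])
    # avg_raw_reward == 1.7 here; this is the number logged as 'RAW_reward'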
--------------------------------------------------------------------------------
/mbbl/trainer/metrpo_trainer.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @brief:
3 | # The optimization agent is responsible for doing the updates.
4 | # @author:
5 | # ------------------------------------------------------------------------------
6 | from .base_trainer import base_trainer
7 |
8 |
9 | class trainer(base_trainer):
10 |
11 | def __init__(self, args, network_type, task_queue, result_queue,
12 | name_scope='trainer'):
13 | # the base agent
14 | super(trainer, self).__init__(
15 | args=args,
16 | network_type=network_type,
17 | task_queue=task_queue,
18 | result_queue=result_queue,
19 | name_scope=name_scope
20 | )
21 | # self._base_path = init_path.get_abs_base_dir()
22 |
23 | def _update_parameters(self, rollout_data, training_info):
24 | # update the running whitening statistics with the new rollouts
25 | self._update_whitening_stats(rollout_data)
26 | training_data = self._preprocess_data(rollout_data)
27 | training_stats = {'avg_reward': training_data['avg_reward']}
28 |
29 | # TODO(GD): add for loop to train policy?
30 | # train the policy
31 | for key in training_info['network_to_train']:
32 | for i_network in range(self._num_model_ensemble[key]):
33 | i_stats = self._network[key][i_network].train(
34 | training_data, self._replay_buffer, training_info={}
35 | )
36 | if i_stats is not None:
37 | training_stats.update(i_stats)
38 | self._replay_buffer.add_data(training_data)
39 | return training_stats
40 |
--------------------------------------------------------------------------------
/mbbl/trainer/readme.md:
--------------------------------------------------------------------------------
1 | rs: random shooting method
2 |
--------------------------------------------------------------------------------
/mbbl/trainer/shooting_trainer.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @brief:
3 | # The optimization agent is responsible for doing the updates.
4 | # @author:
5 | # ------------------------------------------------------------------------------
6 | from .base_trainer import base_trainer
7 |
8 |
9 | class trainer(base_trainer):
10 |
11 | def __init__(self, args, network_type, task_queue, result_queue,
12 | name_scope='trainer'):
13 | # the base agent
14 | super(trainer, self).__init__(
15 | args=args, network_type=network_type,
16 | task_queue=task_queue, result_queue=result_queue,
17 | name_scope=name_scope
18 | )
19 | # self._base_path = init_path.get_abs_base_dir()
20 |
21 | def _update_parameters(self, rollout_data, training_info):
22 | # update the running whitening statistics with the new rollouts
23 | self._update_whitening_stats(rollout_data)
24 | training_data = self._preprocess_data(rollout_data)
25 | training_stats = {'avg_reward': training_data['avg_reward'],
26 | 'avg_reward_std': training_data['avg_reward_std']}
27 |
28 | # train the policy
29 | for key in training_info['network_to_train']:
30 | for i_network in range(self._num_model_ensemble[key]):
31 | i_stats = self._network[key][i_network].train(
32 | training_data, self._replay_buffer, training_info={}
33 | )
34 | if i_stats is not None:
35 | training_stats.update(i_stats)
36 | self._replay_buffer.add_data(training_data)
37 | return training_stats
38 |
--------------------------------------------------------------------------------
/mbbl/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/util/__init__.py
--------------------------------------------------------------------------------
/mbbl/util/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/util/common/__init__.py
--------------------------------------------------------------------------------
/mbbl/util/common/fpdb.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief:
3 | # some helper functions about pdb.
4 | # -----------------------------------------------------------------------------
5 |
6 | import sys
7 | import pdb
8 |
9 |
10 | class fpdb(pdb.Pdb):
11 | '''
12 | @brief:
13 | a Pdb subclass that may be used from a forked multiprocessing
14 | child
15 | '''
16 |
17 | def interaction(self, *args, **kwargs):
18 | _stdin = sys.stdin
19 | try:
20 | sys.stdin = open('/dev/stdin')
21 | pdb.Pdb.interaction(self, *args, **kwargs)
22 | finally:
23 | sys.stdin = _stdin
24 |
25 |
26 | if __name__ == '__main__':
27 | # example of how to use it elsewhere in a multi-process project:
28 | # from util import fpdb
29 | fpdb = fpdb.fpdb()
30 | fpdb.set_trace()
31 |
--------------------------------------------------------------------------------
/mbbl/util/common/ggnn_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | GGNN utils
3 | '''
4 | import numpy as np
5 |
6 |
7 | def compact2sparse_representation(mat, total_edge_type):
8 | ''' @brief: convert a compact adjacency matrix (entry = edge type, 0 = no edge)
9 | into the sparse GGNN adjacency matrix of shape [N, N * total_edge_type * 2] '''
10 | N, _ = mat.shape
11 |
12 | sparse_mat = np.zeros((N, N * total_edge_type * 2))
13 |
14 | for i in range(N):
15 | for j in range(N):
16 | if mat[i, j] == 0: continue
17 |
18 | edge_type = mat[i, j]
19 | _from = i
20 | _to = j
21 |
22 | in_x = j
23 | in_y = i + N * (edge_type - 1)
24 | sparse_mat[int(in_x), int(in_y)] = 1
25 |
26 | # fill the outgoing-edge block (the incoming-edge block is filled above)
27 | out_x = i
28 | out_y = total_edge_type + j + N * (edge_type - 1)
29 | sparse_mat[int(out_x), int(out_y)] = 1
30 |
31 | return sparse_mat.astype('int')
32 |
33 | def manual_parser(env):
34 |
35 | def _hopper():
36 | graph = np.array([
37 | [0, 1, 0, 0],
38 | [1, 0, 1, 0],
39 | [0, 1, 0, 1],
40 | [0, 0, 1, 0]
41 | ])
42 |
43 | geom_meta_info = np.array([
44 | [0.9, 0, 0, 1.45, 0, 0, 1.05, 0.05],
45 | [0.9, 0, 0, 1.05, 0, 0, 0.6, 0.05],
46 | [0.9, 0, 0, 0.6, 0, 0, 0.1, 0.04],
47 | [2.0, -0.14, 0, 0.1, 0.26, 0, 0.1, 0.06]
48 | ])
49 |
50 | joint_meta_info = np.array([
51 | [0, 0, 0, 0, 0, 0],
52 | [0, 0, 1.05, -150, 0, 200],
53 | [0, 0, 0.6, -150, 0, 200],
54 | [0, 0, 0.1, -45, 45, 200]
55 | ])
56 |
57 | meta_info = np.hstack( (geom_meta_info, joint_meta_info) )
58 |
59 | ob_assign = np.array([
60 | 0, 0, 1, 2, 3,
61 | 0, 0, 0, 1, 2, 3
62 | ])
63 |
64 | ac_assign = np.array([
65 | 1, 2, 3
66 | ])
67 |
68 | ggnn_info = {}
69 | ggnn_info['n_node'] = graph.shape[0]
70 | ggnn_info['n_node_type'] = 1
71 |
72 | ggnn_info['node_anno_dim'] = meta_info.shape[1]
73 | ggnn_info['node_state_dim'] = 6 # observation space
74 | ggnn_info['node_embed_dim'] = 512 # ggnn hidden states dim - node_state_dim
75 |
76 | ggnn_info['n_edge_type'] = graph.max()
77 | ggnn_info['output_dim'] = 5 # the number of occurrences in ob_assign
78 |
79 | return graph, ob_assign, ac_assign, meta_info, ggnn_info
80 |
81 | def _half_cheetah():
82 | raise NotImplementedError
83 |
84 | def _walker():
85 | raise NotImplementedError
86 |
87 | def _ant():
88 | raise NotImplementedError
89 |
90 | if env == 'gym_hopper': return _hopper()
91 | else:
92 | raise NotImplementedError
93 |
94 |
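A short sketch of how the two helpers fit together for the hopper; the shapes follow directly from the code above (4 nodes, a single edge type):

    from mbbl.util.common.ggnn_utils import manual_parser, compact2sparse_representation

    graph, ob_assign, ac_assign, meta_info, ggnn_info = manual_parser('gym_hopper')
    sparse_adj = compact2sparse_representation(graph, ggnn_info['n_edge_type'])
    # graph is a (4, 4) compact adjacency matrix holding edge-type ids;
    # sparse_adj is the (4, 4 * n_edge_type * 2) 0/1 in/out adjacency matrix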
--------------------------------------------------------------------------------
/mbbl/util/common/logger.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief:
3 | # The logger here will be called all across the project. It is inspired
4 | # by Yuxin Wu (ppwwyyxx@gmail.com)
5 | #
6 | # @author:
7 | # Tingwu Wang, 2017, Feb, 20th
8 | # -----------------------------------------------------------------------------
9 |
10 | import logging
11 | import sys
12 | import os
13 | import datetime
14 | from termcolor import colored
15 |
16 | __all__ = ['set_file_handler'] # the actual worker is the '_logger'
17 |
18 |
19 | class _MyFormatter(logging.Formatter):
20 | '''
21 | @brief:
22 | a formatter that colorizes the timestamp and the warning / error tags
23 | '''
24 |
25 | def format(self, record):
26 | date = colored('[%(asctime)s @%(filename)s:%(lineno)d]', 'green')
27 | msg = '%(message)s'
28 |
29 | if record.levelno == logging.WARNING:
30 | fmt = date + ' ' + \
31 | colored('WRN', 'red', attrs=[]) + ' ' + msg
32 | elif record.levelno == logging.ERROR or \
33 | record.levelno == logging.CRITICAL:
34 | fmt = date + ' ' + \
35 | colored('ERR', 'red', attrs=['underline']) + ' ' + msg
36 | else:
37 | fmt = date + ' ' + msg
38 |
39 | if hasattr(self, '_style'):
40 | # Python 3 compatibility
41 | self._style._fmt = fmt
42 | self._fmt = fmt
43 |
44 | return super(self.__class__, self).format(record)
45 |
46 |
47 | _logger = logging.getLogger('joint_embedding')
48 | _logger.propagate = False
49 | _logger.setLevel(logging.INFO)
50 |
51 | # set the console output handler
52 | con_handler = logging.StreamHandler(sys.stdout)
53 | con_handler.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))
54 | _logger.addHandler(con_handler)
55 |
56 |
57 | class GLOBAL_PATH(object):
58 |
59 | def __init__(self, path=None):
60 | if path is None:
61 | path = os.getcwd()
62 | self.path = path
63 |
64 | def _set_path(self, path):
65 | self.path = path
66 |
67 | def _get_path(self):
68 | return self.path
69 |
70 |
71 | PATH = GLOBAL_PATH()
72 |
73 |
74 | def set_file_handler(path=None, prefix='', time_str=''):
75 | # set the file output handler
76 | if time_str == '':
77 | file_name = prefix + \
78 | datetime.datetime.now().strftime("%A_%d_%B_%Y_%I:%M%p") + '.log'
79 | else:
80 | file_name = prefix + time_str + '.log'
81 |
82 | if path is None:
83 | mod = sys.modules['__main__']
84 | path = os.path.join(os.path.abspath(mod.__file__), '..', '..', 'log')
85 | else:
86 | path = os.path.join(path, 'log')
87 | path = os.path.abspath(path)
88 |
89 | path = os.path.join(path, file_name)
90 | if not os.path.exists(path):
91 | os.makedirs(path)
92 |
93 | PATH._set_path(path)
94 | path = os.path.join(path, file_name)
95 | from tensorboard_logger import configure
96 | configure(path)
97 |
98 | file_handler = logging.FileHandler(
99 | filename=os.path.join(path, 'logger'), encoding='utf-8', mode='w')
100 | file_handler.setFormatter(_MyFormatter(datefmt='%m%d %H:%M:%S'))
101 | _logger.addHandler(file_handler)
102 |
103 | _logger.info('Log file set to {}'.format(path))
104 | return path
105 |
106 |
107 | def _get_path():
108 | return PATH._get_path()
109 |
110 |
111 | _LOGGING_METHOD = ['info', 'warning', 'error', 'critical',
112 | 'warn', 'exception', 'debug']
113 |
114 | # export logger functions
115 | for func in _LOGGING_METHOD:
116 | locals()[func] = getattr(_logger, func)
117 |
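Typical usage elsewhere in the repo is to point the logger at an output directory once and then call the exported module-level functions; set_file_handler also configures tensorboard_logger, so that package needs to be installed. A minimal sketch (the path and prefix are made up):

    from mbbl.util.common import logger

    logger.set_file_handler(path='/tmp/mbbl_exp', prefix='rs_gym_cheetah_')
    logger.info('starting training')
    logger.warning('written to both the console and the log file')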
--------------------------------------------------------------------------------
/mbbl/util/common/parallel_util.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief:
3 | Define the signals used for parallel training
4 | # @author:
5 | # Tingwu Wang
6 | # -----------------------------------------------------------------------------
7 |
8 | TRAIN_SIGNAL = 1
9 | SAVE_SIGNAL = 2
10 |
11 | # it makes the main trpo agent push its weights into the tunnel
12 | START_SIGNAL = 3
13 |
14 | # it ends the training
15 | END_SIGNAL = 4
16 |
17 | # ends the rollout
18 | END_ROLLOUT_SIGNAL = 5
19 |
20 | # ask the rollout agents to collect the ob normalizer's info
21 | AGENT_COLLECT_FILTER_INFO = 6
22 |
23 | # ask the rollout agents to synchronize the ob normalizer's info
24 | AGENT_SYNCHRONIZE_FILTER = 7
25 |
26 | # ask the agents to set their parameters of network
27 | AGENT_SET_WEIGHTS = 8
28 |
29 | # reset
30 | RESET_SIGNAL = 9
31 |
32 | # Initial training for the mbmf policy network.
33 | MBMF_INITIAL = 666
34 |
35 | # ask for policy network.
36 | GET_POLICY_NETWORK = 6666
37 |
38 | # get and set the policy network weights.
39 | GET_POLICY_WEIGHT = 66
40 | SET_POLICY_WEIGHT = 66666
41 |
42 |
43 | WORKER_PLANNING = 10
44 | WORKER_PLAYING = 11
45 | WORKER_GET_MODEL = 12
46 | WORKER_RATE_ACTIONS = 13
47 |
48 | # make sure that no signals are using the same number
49 | var_dict = locals()
50 | var_list = [var_dict[var] for var in dir() if
51 | (not var.startswith('_') and type(var_dict[var]) == int)]
52 |
53 | assert len(var_list) == len(set(var_list))
54 |
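Every message on the shared queues is a (signal, payload) tuple built from these constants. Below is a stripped-down sketch of the kind of dispatch loop a worker process runs against them; the real loops live in mbbl/worker and handle many more signals:

    from mbbl.util.common import parallel_util

    def worker_loop(task_queue, result_queue):
        while True:
            signal, payload = task_queue.get()
            if signal == parallel_util.END_SIGNAL:
                task_queue.task_done()
                break
            elif signal == parallel_util.WORKER_PLAYING:
                # play one episode according to the payload, then report it
                result_queue.put({'rewards': []})  # placeholder episode
                task_queue.task_done()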
--------------------------------------------------------------------------------
/mbbl/util/common/replay_buffer.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief: save the true datapoints into a buffer
3 | # @author: Tingwu Wang
4 | # -----------------------------------------------------------------------------
5 | import numpy as np
6 |
7 |
8 | class replay_buffer(object):
9 |
10 | def __init__(self, use_buffer, buffer_size, rand_seed,
11 | observation_size, action_size, save_reward=False):
12 |
13 | self._use_buffer = use_buffer
14 | self._buffer_size = buffer_size
15 | self._npr = np.random.RandomState(rand_seed)
16 |
17 | if not self._use_buffer:
18 | self._buffer_size = 0
19 |
20 | self._observation_size = observation_size
21 | self._action_size = action_size
22 |
23 | reward_data_size = self._buffer_size if save_reward else 0
24 | self._data = {
25 | 'start_state': np.zeros(
26 | [self._buffer_size, self._observation_size],
27 | dtype=np.float16
28 | ),
29 |
30 | 'end_state': np.zeros(
31 | [self._buffer_size, self._observation_size],
32 | dtype=np.float16
33 | ),
34 |
35 | 'action': np.zeros(
36 | [self._buffer_size, self._action_size],
37 | dtype=np.float16
38 | ),
39 |
40 | 'reward': np.zeros([reward_data_size], dtype=np.float16)
41 | }
42 | self._data_key = [key for key in self._data if len(self._data[key]) > 0]
43 |
44 | self._current_id = 0
45 | self._occupied_size = 0
46 |
47 | def add_data(self, new_data):
48 | if self._buffer_size == 0:
49 | return
50 |
51 | num_new_data = len(new_data['start_state'])
52 |
53 | if num_new_data + self._current_id > self._buffer_size:
54 | num_after_full = num_new_data + self._current_id - self._buffer_size
55 | for key in self._data_key:
56 | # filling the tail part
57 | self._data[key][self._current_id: self._buffer_size] = \
58 | new_data[key][0: self._buffer_size - self._current_id]
59 |
60 | # filling the head part
61 | self._data[key][0: num_after_full] = \
62 | new_data[key][self._buffer_size - self._current_id:]
63 |
64 | else:
65 |
66 | for key in self._data_key:
67 | self._data[key][self._current_id:
68 | self._current_id + num_new_data] = \
69 | new_data[key]
70 |
71 | self._current_id = \
72 | (self._current_id + num_new_data) % self._buffer_size
73 | self._occupied_size = \
74 | min(self._buffer_size, self._occupied_size + num_new_data)
75 |
76 | def get_data(self, batch_size):
77 |
78 | # sample a random batch from the stored data
79 | sample_id = self._npr.randint(0, self._occupied_size, batch_size)
80 | return {key: self._data[key][sample_id] for key in self._data_key}
81 |
82 | def get_current_size(self):
83 | return self._occupied_size
84 |
85 | def get_all_data(self):
86 | return {key: self._data[key][:self._occupied_size]
87 | for key in self._data_key}
88 |
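The buffer is a fixed-size circular store keyed by 'start_state', 'end_state' and 'action' (plus 'reward' when save_reward=True); once full, add_data wraps around and overwrites the oldest entries. A minimal usage sketch with made-up sizes:

    import numpy as np
    from mbbl.util.common.replay_buffer import replay_buffer

    buf = replay_buffer(use_buffer=True, buffer_size=1000, rand_seed=1234,
                        observation_size=17, action_size=6)
    buf.add_data({'start_state': np.zeros([32, 17]),
                  'end_state': np.zeros([32, 17]),
                  'action': np.zeros([32, 6])})
    batch = buf.get_data(batch_size=16)  # dict of arrays with 16 rows each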
--------------------------------------------------------------------------------
/mbbl/util/common/summary_handler.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief:
3 | define how the summaries are recorded
4 | # @author:
5 | # Tingwu Wang
6 | # -----------------------------------------------------------------------------
7 | import os
8 |
9 | import tensorflow as tf
10 |
11 | from mbbl.config import init_path
12 | from mbbl.util.common import logger
13 |
14 |
15 | class summary_handler(object):
16 | '''
17 | @brief:
18 | Tell the handler where to record all the information.
19 | Normally, we want to record the value-loss prediction and the
20 | average reward (and possibly the learning rate)
21 | '''
22 |
23 | def __init__(self, sess, summary_name, enable=True, summary_dir=None):
24 | # the interface we need
25 | self.summary = None
26 | self.sess = sess
27 | self.enable = enable
28 | if not self.enable: # the summary handler is disabled
29 | return
30 | if summary_dir is None:
31 | self.path = os.path.join(
32 | init_path.get_base_dir(), 'summary'
33 | )
34 | else:
35 | self.path = os.path.join(summary_dir, 'summary')
36 | self.path = os.path.abspath(self.path)
37 |
38 | if not os.path.exists(self.path):
39 | os.makedirs(self.path)
40 | self.path = os.path.join(self.path, summary_name)
41 |
42 | self.train_writer = tf.summary.FileWriter(self.path, self.sess.graph)
43 |
44 | logger.info(
45 | 'summary writer initialized, writing to {}'.format(self.path))
46 |
47 | def get_tf_summary(self):
48 | assert self.summary is not None, logger.error(
49 | 'tf summary not defined, call the summary object separately')
50 | return self.summary
51 |
52 |
53 | class gym_summary_handler(summary_handler):
54 | '''
55 | @brief:
56 | For the gym environments, we pass in the scalar variables we want to record
57 | '''
58 |
59 | def __init__(self, sess, summary_name, enable=True,
60 | scalar_var_list=dict(), summary_dir=None):
61 | super(self.__class__, self).__init__(sess, summary_name, enable=enable,
62 | summary_dir=summary_dir)
63 | if not self.enable:
64 | return
65 | assert type(scalar_var_list) == dict, logger.error(
66 | 'We only take the dict where the name is given as the key')
67 |
68 | if len(scalar_var_list) > 0:
69 | self.summary_list = []
70 | for name, var in scalar_var_list.items():
71 | self.summary_list.append(tf.summary.scalar(name, var))
72 | self.summary = tf.summary.merge(self.summary_list)
73 |
74 | def manually_add_scalar_summary(self, summary_name, summary_value, x_axis):
75 | '''
76 | @brief:
77 | useful for recording scalars such as the average game length and
78 | the average reward
79 | @input:
80 | x_axis could be either the episode number or the step number
81 | '''
82 | if not self.enable: # it happens when we are just debugging
83 | return
84 |
85 | if 'expert_traj' in summary_name:
86 | return
87 |
88 | summary = tf.Summary(
89 | value=[tf.Summary.Value(
90 | tag=summary_name, simple_value=summary_value
91 | ), ]
92 | )
93 | self.train_writer.add_summary(summary, x_axis)
94 |
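A minimal usage sketch, assuming a TensorFlow 1.x session and a made-up experiment name and directory; scalars added this way show up in TensorBoard under the given tag:

    import tensorflow as tf
    from mbbl.util.common.summary_handler import gym_summary_handler

    sess = tf.Session()
    summary = gym_summary_handler(sess, summary_name='rs_gym_cheetah',
                                  summary_dir='/tmp/mbbl_exp')
    summary.manually_add_scalar_summary('avg_reward', 123.4, x_axis=3000)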
--------------------------------------------------------------------------------
/mbbl/util/common/tf_norm.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @brief: define the batchnorm and layernorm in this function
3 | # ------------------------------------------------------------------------------
4 |
5 | import tensorflow as tf
6 |
7 |
8 | def layer_norm(x, name_scope, epsilon=1e-5, use_bias=True,
9 | use_scale=True, gamma_init=None, data_format='NHWC'):
10 | """
11 | @Brief: code modified from ppwwyyxx github.com/ppwwyyxx/tensorpack/,
12 | under layer_norm.py.
13 | Layer Normalization layer, as described in the paper:
14 | https://arxiv.org/abs/1607.06450.
15 | @input:
16 | x (tf.Tensor): a 4D or 2D tensor. When 4D, the layout should
17 | match data_format.
18 | """
19 | with tf.variable_scope(name_scope):
20 | shape = x.get_shape().as_list()
21 | ndims = len(shape)
22 | assert ndims in [2, 4]
23 |
24 | mean, var = tf.nn.moments(x, list(range(1, len(shape))), keep_dims=True)
25 |
26 | if data_format == 'NCHW':
27 | chan = shape[1]
28 | new_shape = [1, chan, 1, 1]
29 | else:
30 | chan = shape[-1]
31 | new_shape = [1, 1, 1, chan]
32 | if ndims == 2:
33 | new_shape = [1, chan]
34 |
35 | if use_bias:
36 | beta = tf.get_variable(
37 | 'beta', [chan], initializer=tf.constant_initializer()
38 | )
39 | beta = tf.reshape(beta, new_shape)
40 | else:
41 | beta = tf.zeros([1] * ndims, name='beta')
42 | if use_scale:
43 | if gamma_init is None:
44 | gamma_init = tf.constant_initializer(1.0)
45 | gamma = tf.get_variable('gamma', [chan], initializer=gamma_init)
46 | gamma = tf.reshape(gamma, new_shape)
47 | else:
48 | gamma = tf.ones([1] * ndims, name='gamma')
49 |
50 | ret = tf.nn.batch_normalization(
51 | x, mean, var, beta, gamma, epsilon, name='output'
52 | )
53 | return ret
54 |
55 |
56 | def batch_norm_with_train(x, name_scope, epsilon=1e-5, momentum=0.9):
57 | ret = tf.contrib.layers.batch_norm(
58 | x, decay=momentum, updates_collections=None, epsilon=epsilon,
59 | scale=True, is_training=True, scope=name_scope
60 | )
61 | return ret
62 |
63 |
64 | def batch_norm_without_train(x, name_scope, epsilon=1e-5, momentum=0.9):
65 | ret = tf.contrib.layers.batch_norm(
66 | x, decay=momentum, updates_collections=None, epsilon=epsilon,
67 | scale=True, is_training=False, scope=name_scope
68 | )
69 | return ret
70 |
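Both normalizers take and return a tensor of the same shape; layer_norm accepts 2D or 4D inputs, while the batch-norm wrappers delegate to tf.contrib.layers.batch_norm. A minimal TensorFlow 1.x sketch with a made-up feature size:

    import tensorflow as tf
    from mbbl.util.common.tf_norm import layer_norm, batch_norm_with_train

    x = tf.placeholder(tf.float32, [None, 64])
    h_ln = layer_norm(x, name_scope='ln_0')             # per-sample normalization
    h_bn = batch_norm_with_train(x, name_scope='bn_0')  # per-batch normalization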
--------------------------------------------------------------------------------
/mbbl/util/gps/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/util/gps/__init__.py
--------------------------------------------------------------------------------
/mbbl/util/il/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/util/il/__init__.py
--------------------------------------------------------------------------------
/mbbl/util/il/test_inverse_dynamics.py:
--------------------------------------------------------------------------------
1 | """
2 | @brief: understand some statistics about the inverse dynamics
3 | @author: Tingwu Wang
4 |
5 | @Date: Jan 12, 2019
6 | """
7 | # import matplotlib.pyplot as plt
8 | # from mbbl.env.dm_env.pos_dm_env import POS_CONNECTION
9 | from mbbl.env.env_register import make_env
10 | # from mbbl.util.common import logger
11 | # from mbbl.config import init_path
12 | # import cv2
13 | # import os
14 | # from skimage import draw
15 | import numpy as np
16 | # from mbbl.util.il.expert_data_util import load_pose_data
17 | # import argparse
18 |
19 |
20 | if __name__ == '__main__':
21 |
22 | env, env_info = make_env("cheetah-run-pos", 1234)
23 | control_info = env.get_controller_info()
24 | dynamics_env, _ = make_env("cheetah-run-pos", 1234)
25 |
26 | # generate the data
27 | env.reset()
28 | for i in range(1000):
29 | action = np.random.randn(env_info['action_size'])
30 | qpos = np.array(env._env.physics.data.qpos, copy=True)
31 | old_qpos = np.array(env._env.physics.data.qpos, copy=True)
32 | old_qvel = np.array(env._env.physics.data.qvel, copy=True)
33 | old_qacc = np.array(env._env.physics.data.qacc, copy=True)
34 | old_qfrc_inverse = np.array(env._env.physics.data.qfrc_inverse, copy=True)
35 | _, _, _, _ = env.step(action)
36 | ctrl = np.array(env._env.physics.data.ctrl, copy=True)
37 | qvel = np.array(env._env.physics.data.qvel, copy=True)
38 | qacc = np.array(env._env.physics.data.qacc, copy=True)
39 |
40 | # see the inverse
41 | qfrc_inverse = dynamics_env._env.physics.get_inverse_output(qpos, qvel, qacc)
42 | qfrc_action = qfrc_inverse[None, control_info['actuated_id']]
43 | action = ctrl * control_info['gear']
44 | print("predicted action: {}\n".format(qfrc_action))
45 | print("groundtruth action: {}\n".format(action))
46 |
--------------------------------------------------------------------------------
/mbbl/util/ilqr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/util/ilqr/__init__.py
--------------------------------------------------------------------------------
/mbbl/util/ilqr/ilqr_utils.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------------
2 | # @brief:
3 | # -----------------------------------------------------------------------------
4 | from mbbl.config import init_path
5 |
6 | _BASE_DIR = init_path.get_abs_base_dir()
7 |
8 |
9 | def update_damping_lambda(traj_data, increase, damping_args):
10 | if increase:
11 | traj_data['lambda_multiplier'] = max(
12 | traj_data['lambda_multiplier'] * damping_args['factor'],
13 | damping_args['factor']
14 | )
15 | traj_data['damping_lambda'] = max(
16 | traj_data['damping_lambda'] * traj_data['lambda_multiplier'],
17 | damping_args['min_damping']
18 | )
19 | else: # decrease
20 | traj_data['lambda_multiplier'] = min(
21 | traj_data['lambda_multiplier'] / damping_args['factor'],
22 | 1.0 / damping_args['factor']
23 | )
24 | traj_data['damping_lambda'] = \
25 | traj_data['damping_lambda'] * traj_data['lambda_multiplier'] * \
26 | (traj_data['damping_lambda'] > damping_args['min_damping'])
27 |
28 | traj_data['damping_lambda'] = \
29 | max(traj_data['damping_lambda'], damping_args['min_damping'])
30 | if traj_data['damping_lambda'] > damping_args['max_damping']:
31 | traj_data['active'] = False
32 |
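This is the usual Levenberg-Marquardt-style damping schedule for iLQR: grow the damping (and its multiplier) when an update is rejected, shrink it when the update is accepted, and deactivate the trajectory once the damping exceeds max_damping. A sketch with made-up settings:

    from mbbl.util.ilqr.ilqr_utils import update_damping_lambda

    traj_data = {'damping_lambda': 1e-3, 'lambda_multiplier': 1.0, 'active': True}
    damping_args = {'factor': 10.0, 'min_damping': 1e-6, 'max_damping': 1e3}

    update_damping_lambda(traj_data, increase=True, damping_args=damping_args)
    # -> lambda_multiplier == 10.0, damping_lambda == 1e-2
    update_damping_lambda(traj_data, increase=False, damping_args=damping_args)
    # -> lambda_multiplier == 0.1, damping_lambda == 1e-3, 'active' still True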
--------------------------------------------------------------------------------
/mbbl/worker/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WilsonWangTHU/mbbl/bb88a016de2fcd8ea0ed9c4d5c539817d2b476e7/mbbl/worker/__init__.py
--------------------------------------------------------------------------------
/mbbl/worker/cem_worker.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @brief:
3 | # ------------------------------------------------------------------------------
4 | import numpy as np
5 |
6 | from .base_worker import base_worker
7 | from mbbl.config import init_path
8 | # from mbbl.util.common import parallel_util
9 |
10 |
11 | def detect_done(ob, env_name, check_done):
12 | if env_name == 'gym_fhopper':
13 | height, ang = ob[:, 0], ob[:, 1]
14 | done = np.logical_or(height <= 0.7, abs(ang) >= 0.2)
15 |
16 | elif env_name == 'gym_fwalker2d':
17 | height, ang = ob[:, 0], ob[:, 1]
18 | done = np.logical_or(
19 | height >= 2.0,
20 | np.logical_or(height <= 0.8, abs(ang) >= 1.0)
21 | )
22 |
23 | elif env_name == 'gym_fant':
24 | height = ob[:, 0]
25 | done = np.logical_or(height > 1.0, height < 0.2)
26 |
27 | elif env_name in ['gym_fant2', 'gym_fant5', 'gym_fant10',
28 | 'gym_fant20', 'gym_fant30']:
29 | height = ob[:, 0]
30 | done = np.logical_or(height > 1.0, height < 0.2)
31 | else:
32 | done = np.zeros([ob.shape[0]])
33 |
34 | if not check_done:
35 | done[:] = False
36 |
37 | return done
38 |
39 |
40 | class worker(base_worker):
41 | EPSILON = 0.001
42 |
43 | def __init__(self, args, observation_size, action_size,
44 | network_type, task_queue, result_queue, worker_id,
45 | name_scope='planning_worker'):
46 |
47 | # the base agent
48 | super(worker, self).__init__(args, observation_size, action_size,
49 | network_type, task_queue, result_queue,
50 | worker_id, name_scope)
51 | self._base_dir = init_path.get_base_dir()
52 | self._alpha = args.cem_learning_rate
53 | self._num_iters = args.cem_num_iters
54 | self._elites_fraction = args.cem_elites_fraction
55 |
56 | def _plan(self, planning_data):
57 | num_traj = planning_data['state'].shape[0]
58 | sample_action = np.reshape(
59 | planning_data['samples'],
60 | [num_traj, planning_data['depth'], planning_data['action_size']]
61 | )
62 | current_state = planning_data['state']
63 | total_reward = 0
64 | done = np.zeros([num_traj])
65 |
66 | for i_depth in range(planning_data['depth']):
67 | action = sample_action[:, i_depth, :]
68 |
69 | # pred next state
70 | next_state, _, _ = self._network['dynamics'][0].pred(
71 | {'start_state': current_state, 'action': action}
72 | )
73 |
74 | # pred the reward
75 | reward, _, _ = self._network['reward'][0].pred(
76 | {'start_state': current_state, 'action': action}
77 | )
78 |
79 | total_reward += reward * (1 - done)
80 | current_state = next_state
81 |
82 | # mark the done marker
83 | this_done = detect_done(next_state, self.args.task, self.args.check_done)
84 | done = np.logical_or(this_done, done)
85 | # if np.any(done):
86 | # from mbbl.util.common.fpdb import fpdb; fpdb().set_trace()
87 |
88 | # Return the cost
89 | return_dict = {'costs': -total_reward,
90 | 'sample_id': planning_data['id']}
91 | return return_dict
92 |
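detect_done operates on a batch of predicted next states; for the hopper/walker/ant variants above it only reads the first one or two observation columns (height and torso angle), so a two-column array is enough to illustrate it (values are made up):

    import numpy as np
    from mbbl.worker.cem_worker import detect_done

    # three candidate next states for gym_fhopper: columns are [height, angle]
    ob = np.array([[1.2, 0.05],
                   [0.6, 0.00],
                   [1.0, 0.50]])
    done = detect_done(ob, 'gym_fhopper', check_done=True)
    # -> [False, True, True]: row 2 fell below 0.7, row 3 tipped past 0.2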
--------------------------------------------------------------------------------
/mbbl/worker/metrpo_worker.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @brief:
3 | # ------------------------------------------------------------------------------
4 | import numpy as np
5 |
6 | from .base_worker import base_worker
7 | from mbbl.config import init_path
8 | from mbbl.env import env_register
9 | from mbbl.env.env_util import play_episode_with_env
10 | from mbbl.env.fake_env import fake_env
11 |
12 |
13 | class worker(base_worker):
14 |
15 | def __init__(self, args, observation_size, action_size,
16 | network_type, task_queue, result_queue, worker_id,
17 | name_scope='planning_worker'):
18 |
19 | # the base agent
20 | super(worker, self).__init__(args, observation_size, action_size,
21 | network_type, task_queue, result_queue,
22 | worker_id, name_scope)
23 | self._base_dir = init_path.get_base_dir()
24 |
25 | # build the environments
26 | self._build_env()
27 |
28 | def _build_env(self):
29 | self._env, self._env_info = env_register.make_env(
30 | self.args.task, self._npr.randint(0, 9999),
31 | {'allow_monitor': self.args.monitor and self._worker_id == 0}
32 | )
33 | self._fake_env = fake_env(self._env, self._step)
34 |
35 | def _plan(self, planning_data):
36 | raise NotImplementedError
37 |
38 | def _play(self, planning_data):
39 | '''
40 | # TODO NOTE:
41 | var_list = self._network['policy'][0]._trainable_var_list
42 | print('')
43 | for var in var_list:
44 | print(var.name)
45 | # print(var.name, self._session.run(var)[-1])
46 | '''
47 | if planning_data['use_true_env']:
48 | traj_episode = play_episode_with_env(
49 | self._env, self._act,
50 | {'use_random_action': planning_data['use_random_action']}
51 | )
52 | else:
53 | traj_episode = play_episode_with_env(
54 | self._fake_env, self._act,
55 | {'use_random_action': planning_data['use_random_action']}
56 | )
57 | return traj_episode
58 |
59 | def _act(self, state,
60 | control_info={'use_random_action': False}):
61 |
62 | if 'use_random_action' in control_info and \
63 | control_info['use_random_action']:
64 | # use random policy
65 | action = self._npr.uniform(-1, 1, [self._action_size])
66 | return action, [-1], [-1]
67 |
68 | else:
69 | # call the policy network
70 | return self._network['policy'][0].act({'start_state': state})
71 |
72 | def _step(self, state, action):
73 | state = np.reshape(state, [-1, self._observation_size])
74 | action = np.reshape(action, [-1, self._action_size])
75 | # pred next state
76 | next_state, _, _ = self._network['dynamics'][0].pred(
77 | {'start_state': state, 'action': action}
78 | )
79 |
80 | # pred the reward
81 | reward, _, _ = self._network['reward'][0].pred(
82 | {'start_state': state, 'action': action}
83 | )
84 | # from util.common.fpdb import fpdb; fpdb().set_trace()
85 | return next_state[0], reward[0]
86 |
--------------------------------------------------------------------------------
/mbbl/worker/mf_worker.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @brief:
3 | # ------------------------------------------------------------------------------
4 | from .base_worker import base_worker
5 | from mbbl.config import init_path
6 | from mbbl.env.env_util import play_episode_with_env
7 | from mbbl.util.common import logger
8 | import numpy as np
9 |
10 |
11 | class worker(base_worker):
12 |
13 | def __init__(self, args, observation_size, action_size,
14 | network_type, task_queue, result_queue, worker_id,
15 | name_scope='planning_worker'):
16 |
17 | # the base agent
18 | super(worker, self).__init__(args, observation_size, action_size,
19 | network_type, task_queue, result_queue,
20 | worker_id, name_scope)
21 | self._base_dir = init_path.get_base_dir()
22 | self._previous_reward = -np.inf
23 |
24 | # build the environments
25 | self._build_env()
26 |
27 | def _plan(self, planning_data):
28 | raise NotImplementedError
29 |
30 | def _play(self, planning_data):
31 | if self.args.num_expert_episode_to_save > 0 and \
32 | self._previous_reward > self._env_solved_reward and \
33 | self._worker_id == 0:
34 | start_save_episode = True
35 | logger.info('Last episodic reward: %.4f' % self._previous_reward)
36 | logger.info('Minimum reward of %.4f is needed to start saving'
37 | % self._env_solved_reward)
38 | logger.info('[SAVING] Worker %d will record its episode data'
39 | % self._worker_id)
40 | else:
41 | start_save_episode = False
42 | if self.args.num_expert_episode_to_save > 0 \
43 | and self._worker_id == 0:
44 | logger.info('Last episodic reward: %.4f' %
45 | self._previous_reward)
46 | logger.info('Minimum reward of %.4f is needed to start saving'
47 | % self._env_solved_reward)
48 |
49 | traj_episode = play_episode_with_env(
50 | self._env, self._act,
51 | {'use_random_action': planning_data['use_random_action'],
52 | 'record_flag': start_save_episode,
53 | 'num_episode': self.args.num_expert_episode_to_save,
54 | 'data_name': self.args.task + '_' + self.args.exp_id}
55 | )
56 | self._previous_reward = np.sum(traj_episode['rewards'])
57 | return traj_episode
58 |
59 | def _act(self, state,
60 | control_info={'use_random_action': False}):
61 |
62 | if 'use_random_action' in control_info and \
63 | control_info['use_random_action']:
64 | # use random policy
65 | action = self._npr.uniform(-1, 1, [self._action_size])
66 | return action, [-1], [-1]
67 |
68 | else:
69 | # call the policy network
70 | return self._network['policy'][0].act({'start_state': state})
71 |
--------------------------------------------------------------------------------
/mbbl/worker/model_worker.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # @brief:
3 | # ------------------------------------------------------------------------------
4 | from .base_worker import base_worker
5 | from mbbl.config import init_path
6 |
7 |
8 | class worker(base_worker):
9 |
10 | def __init__(self, args, observation_size, action_size,
11 | network_type, task_queue, result_queue, worker_id,
12 | name_scope='derivative_worker'):
13 |
14 | # the base agent
15 | super(worker, self).__init__(args, observation_size, action_size,
16 | network_type, task_queue, result_queue,
17 | worker_id, name_scope)
18 | self._base_dir = init_path.get_base_dir()
19 |
20 | # build the environments
21 | self._build_env()
22 |
23 | def _dynamics_derivative(self, data_dict,
24 | target=['state', 'action', 'state-action']):
25 |
26 | assert len(self._network['dynamics']) == 1
27 | return self._network['dynamics'][0].get_derivative(data_dict, target)
28 |
29 | def _reward_derivative(self, data_dict,
30 | target=['state', 'action', 'state-state']):
31 |
32 | assert len(self._network['reward']) == 1
33 | return self._network['reward'][0].get_derivative(data_dict, target)
34 |
35 | def _dynamics(self, data_dict):
36 | assert len(self._network['dynamics']) == 1
37 | return {'end_state': self._network['dynamics'][0].pred(data_dict)[0]}
38 |
39 | def _reward(self, data_dict):
40 | assert len(self._network['reward']) == 1
41 | return {'reward': self._network['reward'][0].pred(data_dict)[0]}
42 |
--------------------------------------------------------------------------------
/mbbl/worker/readme.md:
--------------------------------------------------------------------------------
1 | rs: random shooting method
2 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/ilqr.sh:
--------------------------------------------------------------------------------
1 | # the first batch:
2 | # gym_reacher, gym_cheetah, gym_walker2d, gym_hopper, gym_swimmer, gym_ant
3 |
4 | # the second batch:
5 | # gym_pendulum, gym_invertedPendulum, gym_acrobot, gym_mountain, gym_cartpole
6 | # 200 , 100 , 200 , 200 , 200
7 |
8 | # bash ilqr.sh gym_pendulum 4000 200
9 | # bash ilqr.sh gym_invertedPendulum 2000 100
10 | # bash ilqr.sh gym_acrobot 4000 200
11 | # bash ilqr.sh gym_mountain 4000 200
12 | # bash ilqr.sh gym_cartpole 4000 200
13 |
14 | # $1 is the environment
15 | # ilqr_depth: [10, 20, 30, 50]
16 | for ilqr_depth in 10 20 30 50 ; do
17 | python main/ilqr_main.py --max_timesteps $2 --task $1 --timesteps_per_batch $3 --ilqr_iteration 10 --ilqr_depth $ilqr_depth --max_ilqr_linesearch_backtrack 10 --exp_id $1-depth-$ilqr_depth --num_workers 2 --gt_dynamics 1
18 | done
19 |
20 | for num_ilqr_traj in 2 5 10; do
21 | python main/ilqr_main.py --max_timesteps $2 --task $1 --timesteps_per_batch $3 --ilqr_iteration 10 --ilqr_depth 30 --max_ilqr_linesearch_backtrack 10 --num_ilqr_traj $num_ilqr_traj --exp_id $1-traj-$num_ilqr_traj --num_workers 2 --gt_dynamics 1
22 | done
23 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/ilqr_depth.sh:
--------------------------------------------------------------------------------
1 |
2 | # $1 is the environment
3 | # ilqr_depth: [10, 20, 30, 50]
4 | for ilqr_depth in 10 20 30 50 ; do
5 | python main/ilqr_main.py --max_timesteps 2000 --task $1 --timesteps_per_batch 50 --ilqr_iteration 10 --ilqr_depth $ilqr_depth --max_ilqr_linesearch_backtrack 10 --exp_id $1-depth-$ilqr_depth --num_workers 2 --gt_dynamics 1
6 | done
7 |
8 | for num_ilqr_traj in 2 5 10; do
9 | python main/ilqr_main.py --max_timesteps 2000 --task $1 --timesteps_per_batch 50 --ilqr_iteration 10 --ilqr_depth 30 --max_ilqr_linesearch_backtrack 10 --num_ilqr_traj $num_ilqr_traj --exp_id $1-traj-$num_ilqr_traj --num_workers 2 --gt_dynamics 1
10 | done
11 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/mbmf.sh:
--------------------------------------------------------------------------------
1 | # evaluate the performance of mbmf (with ppo as the model-free learner):
2 | # batch size
3 | # max_timesteps 1e8
4 | # gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant
5 | # gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole
6 |
7 | for trust_region_method in ppo; do
8 | for batch_size in 1000; do
9 | for env_type in $1; do
10 | for seed in 1234 2341 3412 4123; do
11 | python main/mbmf_main.py --exp_id mbmf_${env_type}_${trust_region_method}_seed_${seed}\
12 | --task $env_type \
13 | --trust_region_method ${trust_region_method} \
14 | --num_planning_traj 5000 --planning_depth 20 --random_timesteps 1000 \
15 | --timesteps_per_batch $batch_size --dynamics_epochs 30 \
16 | --num_workers 20 --mb_timesteps 7000 --dagger_epoch 300 \
17 | --dagger_timesteps_per_iter 1750 --max_timesteps 200000 \
18 | --seed $seed --dynamics_batch_size 500
19 | done
20 | done
21 | done
22 | done
23 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/mbmf_1m.sh:
--------------------------------------------------------------------------------
1 | # evaluate the performance of mbmf with a 1M-timestep budget:
2 | # batch size
3 | # max_timesteps 1e8
4 | # gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant
5 |
6 | for batch_size in 2000; do
7 | for env_type in $1; do
8 | for seed in 1234 2341 3412 4123; do
9 | python main/mbmf_main.py --exp_id mbmf_${env_type}_seed_${seed}_1m \
10 | --task $env_type \
11 | --num_planning_traj 5000 --planning_depth 20 --random_timesteps 10000 \
12 | --timesteps_per_batch $batch_size --dynamics_epochs 30 --num_workers 10 \
13 | --mb_timesteps 70000 --dagger_epoch 300 --dagger_timesteps_per_iter 1750 \
14 | --trust_region_method ppo \
15 | --max_timesteps 1000000 --seed $seed --dynamics_batch_size 500
16 | done
17 | done
18 | done
19 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/pets_gt.sh:
--------------------------------------------------------------------------------
1 | # evaluate the performance of pets with ground-truth dynamics:
2 | # batch size
3 | # max_timesteps 1e8
4 |
5 | # bash gym_reacher 2000
6 | # bash gym_cheetah 20000
7 | # bash gym_walker2d 20000
8 | # bash gym_hopper 20000
9 | # bash gym_swimmer 20000
10 | # bash gym_ant 20000
11 | # bash gym_pendulum 5000
12 | # bash gym_invertedPendulum 2500
13 | # bash gym_acrobot 5000
14 | # bash gym_mountain 5000
15 | # bash gym_cartpole 5000
16 |
17 |
18 | # bash gym_reacher 2000 30
19 | # bash gym_cheetah 20000 100
20 | # bash gym_walker2d 20000 50
21 | # bash gym_hopper 20000 100
22 | # bash gym_swimmer 20000 50
23 | # bash gym_ant 20000 30
24 | # bash gym_pendulum 5000 30
25 | # bash gym_invertedPendulum 2500 30
26 | # bash gym_acrobot 5000 30
27 | # bash gym_mountain 5000 30
28 | # bash gym_cartpole 5000 30
29 |
30 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do
31 | # for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do
32 | for env_type in $1; do
33 | python main/pets_main.py --exp_id rs_${env_type}\
34 | --task $env_type \
35 | --num_planning_traj 500 --planning_depth $2 --random_timesteps 0 \
36 | --timesteps_per_batch 1 --num_workers 10 --max_timesteps 20000 \
37 | --gt_dynamics 1
38 | done
39 |
40 | # python2 main/pets_main.py --timesteps_per_batch 2000 --task gym_cheetah --num_workers 5 --planning_depth 15 --num_planning_traj 500
41 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/pets_gt_checkdone.sh:
--------------------------------------------------------------------------------
1 | # evaluate the performance of pets with ground-truth dynamics and termination checking:
2 | # batch size
3 | # max_timesteps 1e8
4 |
5 | # bash gym_reacher 2000
6 | # bash gym_cheetah 20000
7 | # bash gym_walker2d 20000
8 | # bash gym_hopper 20000
9 | # bash gym_swimmer 20000
10 | # bash gym_ant 20000
11 | # bash gym_pendulum 5000
12 | # bash gym_invertedPendulum 2500
13 | # bash gym_acrobot 5000
14 | # bash gym_mountain 5000
15 | # bash gym_cartpole 5000
16 |
17 |
18 | # bash gym_reacher 2000 30
19 | # bash gym_cheetah 20000 100
20 | # bash gym_walker2d 20000 50
21 | # bash gym_hopper 20000 100
22 | # bash gym_swimmer 20000 50
23 | # bash gym_ant 20000 30
24 | # bash gym_pendulum 5000 30
25 | # bash gym_invertedPendulum 2500 30
26 | # bash gym_acrobot 5000 30
27 | # bash gym_mountain 5000 30
28 | # bash gym_cartpole 5000 30
29 |
30 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do
31 | # for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do
32 | for env_type in $1; do
33 | python main/pets_main.py --exp_id rs_${env_type}\
34 | --task $env_type \
35 | --num_planning_traj 500 --planning_depth $2 --random_timesteps 0 \
36 | --timesteps_per_batch 1 --num_workers 10 --max_timesteps 20000 \
37 | --gt_dynamics 1 --check_done 1
38 | done
39 |
40 | # python2 main/pets_main.py --timesteps_per_batch 2000 --task gym_cheetah --num_workers 5 --planning_depth 15 --num_planning_traj 500
41 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/ppo.sh:
--------------------------------------------------------------------------------
1 | # evaluate the performance of ppo on the environments:
2 | # batch size
3 | # max_timesteps 1e8
4 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do
5 |
6 |
7 | for batch_size in 2000 5000; do
8 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do
9 | for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do
10 | for seed in 1234 2341 3412 4123; do
11 | python main/mf_main.py --exp_id ppo_${env_type}_batch_${batch_size}_seed_${seed} \
12 | --timesteps_per_batch $batch_size --task $env_type \
13 | --num_workers 5 --trust_region_method ppo --max_timesteps 1000000
14 | done
15 | done
16 | done
17 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/random_policy.sh:
--------------------------------------------------------------------------------
1 | # evaluate the performance of a random policy on the environments:
2 | # batch size
3 | # max_timesteps 1e8
4 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do
5 |
6 |
7 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do
8 | for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do
9 | python main/random_main.py --exp_id random_${env_type} \
10 | --timesteps_per_batch 1 --task $env_type \
11 | --num_workers 1 --max_timesteps 40000
12 | done
13 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/rs.sh:
--------------------------------------------------------------------------------
1 | # evaluate the performance of random shooting (rs) on the environments:
2 | # batch size
3 | # max_timesteps 1e8
4 |
5 |
6 | for seed in 1234 2341 3412 4123; do
7 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do
8 | for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do
9 | python main/rs_main.py --exp_id rs_${env_type}_seed_${seed}\
10 | --task $env_type \
11 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 10000 \
12 | --timesteps_per_batch 3000 --num_workers 20 --max_timesteps 200000 --seed $seed
13 | done
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/rs_1.sh:
--------------------------------------------------------------------------------
1 | # evaluate the performance of random shooting (rs), first environment batch:
2 | # batch size
3 | # max_timesteps 1e8
4 |
5 |
6 | for seed in 1234 2341 3412 4123; do
7 | for env_type in gym_reacher gym_cheetah gym_walker2d; do
8 | python main/rs_main.py --exp_id rs_${env_type}_seed_${seed}\
9 | --task $env_type \
10 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 10000 \
11 | --timesteps_per_batch 3000 --num_workers 20 --max_timesteps 300000 --seed $seed
12 | done
13 | done
14 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/rs_2.sh:
--------------------------------------------------------------------------------
1 | # evaluate the performance of random shooting (rs), second environment batch:
2 | # batch size
3 | # max_timesteps 1e8
4 |
5 |
6 | for seed in 1234 2341 3412 4123; do
7 | for env_type in gym_hopper gym_swimmer gym_ant; do
8 | python main/rs_main.py --exp_id rs_${env_type}_seed_${seed}\
9 | --task $env_type \
10 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 10000 \
11 | --timesteps_per_batch 3000 --num_workers 20 --max_timesteps 300000 --seed $seed
12 | done
13 | done
14 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/rs_gt.sh:
--------------------------------------------------------------------------------
1 | # evaluate the performance of random shooting (rs) with ground-truth dynamics:
2 | # batch size
3 | # max_timesteps 1e8
4 |
5 | # bash gym_reacher 2000
6 | # bash gym_cheetah 20000
7 | # bash gym_walker2d 20000
8 | # bash gym_hopper 20000
9 | # bash gym_swimmer 20000
10 | # bash gym_ant 20000
11 | # bash gym_pendulum 5000
12 | # bash gym_invertedPendulum 2500
13 | # bash gym_acrobot 5000
14 | # bash gym_mountain 5000
15 | # bash gym_cartpole 5000
16 |
17 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do
18 | # for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do
19 | for env_type in $1; do
20 | python main/rs_main.py --exp_id rs_${env_type}\
21 | --task $env_type \
22 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 0 \
23 | --timesteps_per_batch 1 --num_workers 20 --max_timesteps 20000 \
24 | --gt_dynamics 1
25 | done
26 |
27 | # python2 main/pets_main.py --timesteps_per_batch 2000 --task gym_cheetah --num_workers 5 --planning_depth 15 --num_planning_traj 500
28 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/rs_gt_checkdone.sh:
--------------------------------------------------------------------------------
1 | # evaluate the performance of random shooting (rs) with ground-truth dynamics and termination checking:
2 | # batch size
3 | # max_timesteps 1e8
4 |
5 | # bash gym_reacher 2000
6 | # bash gym_cheetah 20000
7 | # bash gym_walker2d 20000
8 | # bash gym_hopper 20000
9 | # bash gym_swimmer 20000
10 | # bash gym_ant 20000
11 | # bash gym_pendulum 5000
12 | # bash gym_invertedPendulum 2500
13 | # bash gym_acrobot 5000
14 | # bash gym_mountain 5000
15 | # bash gym_cartpole 5000
16 |
17 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do
18 | # for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do
19 | for env_type in $1; do
20 | python main/rs_main.py --exp_id rs_${env_type}\
21 | --task $env_type \
22 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 0 \
23 | --timesteps_per_batch 1 --num_workers 20 --max_timesteps 20000 \
24 | --gt_dynamics 1 --check_done 1
25 | done
26 |
27 | # python2 main/pets_main.py --timesteps_per_batch 2000 --task gym_cheetah --num_workers 5 --planning_depth 15 --num_planning_traj 500
28 |
--------------------------------------------------------------------------------
/scripts/exp_1_performance_curve/trpo.sh:
--------------------------------------------------------------------------------
1 | # evaluate the performance of trpo on the environments:
2 | # batch size
3 | # max_timesteps 1e8
4 |
5 |
6 | for batch_size in 2000 5000; do
7 | # for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do
8 | for env_type in gym_pendulum gym_invertedPendulum gym_acrobot gym_mountain gym_cartpole; do
9 | for seed in 1234 2341 3412 4123; do
10 | python main/mf_main.py --exp_id trpo_${env_type}_batch_${batch_size}_seed_${seed} \
11 | --timesteps_per_batch $batch_size --task $env_type \
12 | --num_workers 5 --trust_region_method trpo --max_timesteps 1000000
13 | done
14 | done
15 | done
16 |
--------------------------------------------------------------------------------
/scripts/exp_2_dilemma/gt_computation.sh:
--------------------------------------------------------------------------------
1 |
2 | for env_type in gym_cheetah gym_ant; do
3 | python main/rs_main.py --exp_id rs_${env_type}_depth_$1 \
4 | --task $env_type \
5 | --num_planning_traj 1000 --planning_depth $1 --random_timesteps 0 \
6 | --timesteps_per_batch 1 --num_workers 20 --max_timesteps 10000 \
7 | --gt_dynamics 1 --check_done 1
8 |
9 | python main/pets_main.py --exp_id rs_${env_type}_depth_$1 \
10 | --task $env_type \
11 | --num_planning_traj 500 --planning_depth $1 --random_timesteps 0 \
12 | --timesteps_per_batch 1 --num_workers 20 --max_timesteps 10000 \
13 | --gt_dynamics 1 --check_done 1
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/exp_2_dynamics_schemes/rs_act_norm.sh:
--------------------------------------------------------------------------------
 1 | # see how random shooting works with different dynamics network activation and normalization schemes:
2 | # batch size
3 | # max_timesteps 1e8
4 |
5 | # Tanh + None (done)
6 | # Relu + None
7 | # Leaky-Relu + None
8 | # Tanh + layer-norm
9 | # Tanh + batch-norm
10 | # Relu + layer-norm
11 | # Relu + batch-norm
12 | # Leaky-relu + layernorm
13 | # Leaky-relu + batch-norm
14 | # Tanh + None (done)
15 | # Relu + None
16 | # Leaky-Relu + None
17 | # Tanh + layer-norm
18 | # Tanh + batch-norm
19 | # Relu + layer-norm
20 | # Relu + batch-norm
21 | # Leaky-relu + layernorm
22 | # Leaky-relu + batch-norm
23 |
24 |
25 | for seed in 1234 2341 3412 4123; do
26 | for env_type in gym_reacher gym_cheetah gym_ant; do
27 | for dynamics_activation_type in 'leaky_relu' 'tanh' 'relu' 'swish' 'none'; do
28 | for dynamics_normalizer_type in 'layer_norm' 'batch_norm' 'none'; do
29 |
30 | python main/rs_main.py \
31 | --exp_id rs_${env_type}_act_${dynamics_activation_type}_norm_${dynamics_normalizer_type}_seed_${seed} \
32 | --task $env_type \
33 | --dynamics_batch_size 512 \
34 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 10000 \
35 | --timesteps_per_batch 3000 --num_workers 20 --max_timesteps 200000 \
36 | --seed $seed \
37 | --dynamics_activation_type $dynamics_activation_type \
38 | --dynamics_normalizer_type $dynamics_normalizer_type
39 |
40 | done
41 | done
42 | done
43 | done
44 |
--------------------------------------------------------------------------------
/scripts/exp_2_dynamics_schemes/rs_lr_batch.sh:
--------------------------------------------------------------------------------
 1 | # see how random shooting works with different dynamics learning rates, batch sizes and training epochs:
2 | # batch size
3 | # max_timesteps 1e8
4 |
5 | # Tanh + None (done)
6 | # Relu + None
7 | # Leaky-Relu + None
8 | # Tanh + layer-norm
9 | # Tanh + batch-norm
10 | # Relu + layer-norm
11 | # Relu + batch-norm
12 | # Leaky-relu + layernorm
13 | # Leaky-relu + batch-norm
14 | # Tanh + None (done)
15 | # Relu + None
16 | # Leaky-Relu + None
17 | # Tanh + layer-norm
18 | # Tanh + batch-norm
19 | # Relu + layer-norm
20 | # Relu + batch-norm
21 | # Leaky-relu + layernorm
22 | # Leaky-relu + batch-norm
23 |
24 |
25 | for seed in 1234 2341 3412 4123; do
26 | for env_type in gym_cheetah; do
27 | for dynamics_lr in 0.0003 0.001 0.003; do
28 | for dynamics_batch_size in 256 512 1024; do
29 | for dynamics_epochs in 10 30 50; do
30 |
31 | python main/rs_main.py \
32 | --exp_id rs_${env_type}_lr_${dynamics_lr}_batchsize_${dynamics_batch_size}_epochs_${dynamics_epochs}_seed_${seed} \
33 | --task $env_type \
34 | --dynamics_batch_size 512 \
35 | --num_planning_traj 1000 --planning_depth 10 --random_timesteps 10000 \
36 | --timesteps_per_batch 3000 --num_workers 20 --max_timesteps 200000 \
37 | --seed $seed \
38 | --dynamics_lr $dynamics_lr \
39 | --dynamics_batch_size $dynamics_batch_size \
40 | --dynamics_epochs $dynamics_epochs
41 | done
42 |
43 | done
44 | done
45 | done
46 | done
47 |
--------------------------------------------------------------------------------
/scripts/exp_7_traj/ilqr_numtraj.sh:
--------------------------------------------------------------------------------
 1 | # num_ilqr_traj: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
2 | for num_ilqr_traj in 1 3 5 7 9 2 4 6 8 10; do
3 | python main/ilqr_main.py --max_timesteps 10000 --task $1-1000 --timesteps_per_batch 1000 \
4 | --ilqr_iteration 10 --ilqr_depth 10 --max_ilqr_linesearch_backtrack 10 --num_ilqr_traj $num_ilqr_traj --exp_id $1-traj-$num_ilqr_traj --num_workers 2 --gt_dynamics 1
5 | done
6 |
--------------------------------------------------------------------------------
/scripts/exp_9_planning_depth/ilqr_depth.sh:
--------------------------------------------------------------------------------
1 |
2 | # $1 is the environment
 3 | # ilqr_depth: [10, 20, 40, 60, 80, 100]
4 | for ilqr_depth in 100 80 60 40 20 10; do
5 | python main/ilqr_main.py --max_timesteps 10000 --task $1-1000 --timesteps_per_batch 1000 \
6 | --ilqr_iteration 5 --ilqr_depth $ilqr_depth --max_ilqr_linesearch_backtrack 10 \
7 | --exp_id $1-depth-$ilqr_depth --num_workers 2 --gt_dynamics 1
8 | done
9 |
--------------------------------------------------------------------------------
/scripts/ilqr/ilqr.sh:
--------------------------------------------------------------------------------
1 |
2 | # $1 is the environment
3 | # ilqr_depth: [10, 20, 50, 100]
4 | for ilqr_depth in 10 20 50 100; do
 5 |     python main/ilqr_main.py --max_timesteps 20000 --task $1-1000 --timesteps_per_batch 1000 --ilqr_iteration 10 --ilqr_depth $ilqr_depth --max_ilqr_linesearch_backtrack 10 --exp_id $1-depth-$ilqr_depth --num_workers 2
6 | done
7 |
8 | # ilqr_iteration: [5, 10, 15, 20]
9 | for ilqr_iteration in 5 10 15 20; do
10 |     python main/ilqr_main.py --max_timesteps 20000 --task $1-1000 --timesteps_per_batch 20000 --ilqr_iteration $ilqr_iteration --ilqr_depth 10 --max_ilqr_linesearch_backtrack 10 --exp_id $1-iteration_$ilqr_iteration --num_workers 2
11 | done
12 |
13 | # num_ilqr_traj: [2, 5, 10]
14 | for num_ilqr_traj in 2 5 10; do
15 |     python main/ilqr_main.py --max_timesteps 20000 --task $1-1000 --timesteps_per_batch 1000 --ilqr_iteration 10 --ilqr_depth 10 --max_ilqr_linesearch_backtrack 10 --num_ilqr_traj $num_ilqr_traj --exp_id $1-traj-$num_ilqr_traj --num_workers 2
16 |
17 | done
18 |
--------------------------------------------------------------------------------
/scripts/ilqr/ilqr_depth.sh:
--------------------------------------------------------------------------------
1 |
2 | # $1 is the environment
3 | # ilqr_depth: [10, 20, 50, 100]
4 | for ilqr_depth in 10 20 50 100; do
5 | python main/ilqr_main.py --max_timesteps 20000 --task $1-1000 --timesteps_per_batch 1000 --ilqr_iteration 10 --ilqr_depth $ilqr_depth --max_ilqr_linesearch_backtrack 10 --exp_id $1-depth-$ilqr_depth --num_workers 2
6 | done
7 |
--------------------------------------------------------------------------------
/scripts/ilqr/ilqr_iter.sh:
--------------------------------------------------------------------------------
1 | # ilqr_iteration: [5, 10, 15, 20]
2 | for ilqr_iteration in 5 10 15 20; do
3 | python main/ilqr_main.py --max_timesteps 20000 --task $1-1000 --timesteps_per_batch 20000 --ilqr_iteration $ilqr_iteration --ilqr_depth 10 --max_ilqr_linesearch_backtrack 10 --exp_id $1-iteration_$ilqr_iteration --num_workers 2
4 | done
5 |
--------------------------------------------------------------------------------
/scripts/ilqr/ilqr_numtraj.sh:
--------------------------------------------------------------------------------
1 | # num_ilqr_traj: [2, 5, 10]
2 | for num_ilqr_traj in 2 5 10; do
3 | python main/ilqr_main.py --max_timesteps 20000 --task $1-1000 --timesteps_per_batch 1000 --ilqr_iteration 10 --ilqr_depth 10 --max_ilqr_linesearch_backtrack 10 --num_ilqr_traj $num_ilqr_traj --exp_id $1-traj-$num_ilqr_traj --num_workers 2
4 |
5 | done
6 |
--------------------------------------------------------------------------------
/scripts/performance_curve/mbmf.sh:
--------------------------------------------------------------------------------
 1 | # see how mbmf works for the environments in terms of the performance:
2 | # batch size
3 | # max_timesteps 1e8
4 |
5 | for seed in 1234 2341 3412 4123; do
6 | for env_type in gym_reacher gym_cheetah gym_walker2d gym_hopper gym_swimmer gym_ant; do
 7 |         python main/mbmf_main.py --exp_id mbmf_${env_type}_seed_${seed} \
8 | --task $env_type \
9 | --num_planning_traj 5000 --planning_depth 20 --random_timesteps 10000 --timesteps_per_batch 3000 --dynamics_epochs 30 --num_workers 24 --mb_timesteps 70000 --dagger_epoch 300 --dagger_timesteps_per_iter 1750 --max_timesteps 10000000 --seed $seed
10 | done
11 | done
12 |
--------------------------------------------------------------------------------
/scripts/test/gym_humanoid.sh:
--------------------------------------------------------------------------------
1 | # see how the humanoid works in terms of the performance:
2 | # trpo / ppo
3 | # batch size
4 | # max_timesteps 1e8
5 |
6 |
7 | for batch_size in 5000 20000 50000 2000; do
8 | for tr_method in ppo trpo; do
9 | for env_type in gym_humanoid; do
10 | python main/mf_main.py --exp_id ${env_type}_${tr_method}_${batch_size} \
11 | --timesteps_per_batch $batch_size --task $env_type \
12 | --num_workers 5 --trust_region_method $tr_method --max_timesteps 50000000
13 | done
14 | done
15 | done
16 |
--------------------------------------------------------------------------------
/scripts/test/humanoid.sh:
--------------------------------------------------------------------------------
1 | # see how the humanoid works in terms of the performance:
2 | # trpo / ppo
3 | # batch size
4 | # max_timesteps 1e8
5 |
6 |
7 | for batch_size in 5000 20000 50000; do
8 | for tr_method in ppo trpo; do
9 | for env_type in dm-humanoid dm-humanoid-noise; do
10 | python main/mf_main.py --exp_id ${env_type}_${tr_method}_${batch_size} \
11 | --timesteps_per_batch $batch_size --task $env_type \
12 | --num_workers 5 --trust_region_method $tr_method --max_timesteps 50000000 \
13 | --num_expert_episode_to_save 5
14 | done
15 | done
16 | done
17 |
--------------------------------------------------------------------------------
/scripts/test/humanoid_new.sh:
--------------------------------------------------------------------------------
1 | # see how the humanoid works in terms of the performance:
2 | # trpo / ppo
3 | # batch size
4 | # max_timesteps 1e8
5 |
6 |
7 | export PYTHONPATH="$PYTHONPATH:$PWD"
8 |
9 | for env_type in dm-humanoid dm-humanoid-noise; do
10 | for batch_size in 5000 20000 50000; do
11 | for tr_method in ppo trpo; do
12 | python main/mf_main.py --exp_id ${env_type}_${tr_method}_${batch_size} \
13 | --timesteps_per_batch $batch_size --task $env_type \
14 | --num_workers 5 --trust_region_method $tr_method --max_timesteps 50000000
15 | done
16 | done
17 | done
18 |
--------------------------------------------------------------------------------
/scripts/test/mbmf.sh:
--------------------------------------------------------------------------------
1 | seeds=($(shuf -i 0-1000 -n 3))
2 | for seed in "${seeds[@]}"; do
3 | echo "$seed"
4 | # Swimmer
5 | python main/mbmf_main.py --task gym_swimmer --num_planning_traj 5000 --planning_depth 20 --random_timesteps 10000 --timesteps_per_batch 3000 --dynamics_epochs 30 --output_dir mbmf --num_workers 24 --mb_timesteps 70000 --dagger_epoch 300 --dagger_timesteps_per_iter 1750 --max_timesteps 10000000 --seed $seed
6 | # Half cheetah
7 | python main/mbmf_main.py --task gym_cheetah --num_planning_traj 1000 --planning_depth 20 --random_timesteps 10000 --timesteps_per_batch 9000 --dynamics_epochs 60 --output_dir mbmf --num_workers 24 --mb_timesteps 80000 --dagger_epoch 300 --dagger_timesteps_per_iter 2000 --max_timesteps 100000000 --seed $seed
8 | # Hopper
9 | python main/mbmf_main.py --task gym_hopper --num_planning_traj 1000 --planning_depth 40 --random_timesteps 4000 --timesteps_per_batch 10000 --dynamics_epochs 40 --output_dir mbmf --num_workers 24 --mb_timesteps 50000 --dagger_epoch 200 --dagger_timesteps_per_iter 5000 --max_timesteps 100000000 --seed $seed --dagger_saved_rollout 60 --dagger_iter 5
10 |     #python main/mf_main.py --task gym_hopper --timesteps_per_batch 50000 --policy_batch_size 50000 --output_dir mbmf --num_workers 24 --seed $seed --max_timesteps 100000000
11 | done
12 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """Package setup script."""
2 | import setuptools
3 |
4 | setuptools.setup(
5 | name="mbbl",
6 | version="0.1",
7 | description="Model-based RL Baselines",
8 | author="Tingwu Wang",
9 | author_email="tingwuwang@cs.toronto.edu",
10 | packages=setuptools.find_packages(),
11 | package_data={'': ['./env/dm_env/assets/*.xml',
12 | './env/dm_env/assets/common/*.xml',
13 | './env/gym_env/fix_swimmer/assets/*.xml',
14 | './env/gym_env/pets_env/assets/*.xml']},
15 | include_package_data=True,
16 | install_requires=[
17 | ],
18 | )
19 |
20 | """
21 | "pyquaternion",
22 | "beautifulsoup4",
23 | "Box2D>=2.3.2",
24 | "num2words",
25 | "six",
26 | "tensorboard_logger",
27 | "tensorflow==1.12.0",
28 | "termcolor",
29 | "gym[mujoco]==0.7.4",
30 | "mujoco-py==0.5.7",
31 | """
32 |
--------------------------------------------------------------------------------
/tests/tf-test-guide.md:
--------------------------------------------------------------------------------
1 | # Tensorflow Code Testing Guide
 2 | This guide provides guidance and examples for testing machine learning code written in TensorFlow.
3 |
4 | ## Tensorflow Testing Framework
5 | Import the tensorflow built-in testing framework as follows:
6 | ```python
7 | from tensorflow.python.platform import test
8 | ```
9 | Or simply `import tensorflow as tf` and use `tf.test`.
10 | For the rest of this document, we will refer to the framework as `tf.test`.
11 |
12 | ### Introduction to `tf.test`
13 | 1. Recommended structure for the test code.
14 |
15 | Similar to `unittest`, write your tests as classes that inherit from `tf.test.TestCase`.
16 | Then in the `main` scope, run `tf.test.main()`.
17 |
18 | Example:
19 | ```python
20 | import tensorflow as tf
21 |
22 | class SeriousTest(tf.test.TestCase):
23 | def test_method_1(self):
24 | actual_value = expected_value = 1
25 | self.assertEqual(actual_value, expected_value)
26 | def test_method_2(self):
27 | actual_value = 1
28 | expected_value = 2
29 | self.assertNotEqual(actual_value, expected_value)
30 |
31 | if __name__ == "__main__":
32 | tf.test.main()
33 | ```
34 |
35 |
36 | 2. Tensorflow unit test class [`tf.test.TestCase`](https://www.tensorflow.org/api_docs/python/tf/test/TestCase#test_session).
37 |
38 | `tf.test.TestCase` is built on top of the standard Python `unittest.TestCase` and adds many TensorFlow- and numpy-aware assertion methods.
39 |
40 | Check out the test code in [TF Slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim/python/slim) for good examples.
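
For a quick taste of those additional methods, `assertAllClose` and `assertAllEqual` accept numpy arrays and nested Python structures directly. A minimal sketch (the test class and values below are illustrative, not taken from this code base):
```python
import numpy as np
import tensorflow as tf

class ExtraAssertionsTest(tf.test.TestCase):
  def test_assert_all_close(self):
    # assertAllClose compares element-wise, up to a small default tolerance.
    expected = np.array([1.0, 2.0, 3.0])
    actual = expected + 1e-7
    self.assertAllClose(actual, expected)
  def test_assert_all_equal(self):
    # assertAllEqual requires exact element-wise equality.
    self.assertAllEqual([[1, 2], [3, 4]], [[1, 2], [3, 4]])

if __name__ == "__main__":
  tf.test.main()
```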
41 |
42 |
43 |
44 | ### Useful Tips
45 | #### 1. Use `test_session()`
46 | When using `tf.test.TestCase`, you should almost always use
47 | the [`self.test_session()`](https://www.tensorflow.org/api_docs/python/tf/test/TestCase#test_session) method.
48 | It returns a TensorFlow Session for use in executing tests.
49 |
50 | See example in the [AlexNet tests](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/alexnet_test.py)
51 | as part of the TF Slim code base.
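
A minimal sketch of the pattern (the test class, op, and values are illustrative, not from this code base): build a small graph inside the test method and evaluate it through the session returned by `test_session()`:
```python
import tensorflow as tf

class MatmulTest(tf.test.TestCase):
  def test_matmul_value(self):
    # test_session() creates (and caches) a Session bound to the graph
    # built inside this test method.
    with self.test_session() as sess:
      a = tf.constant([[1.0, 2.0]])
      b = tf.constant([[3.0], [4.0]])
      product = tf.matmul(a, b)
      self.assertAllClose(sess.run(product), [[11.0]])

if __name__ == "__main__":
  tf.test.main()
```
Because the session is managed by the test harness, each test method stays short and does not need its own `tf.Session()` boilerplate.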
52 |
53 | #### 2. Test that a graph is built correctly
54 |
55 | The method [`tf.test.assert_equal_graph_def`](https://www.tensorflow.org/api_docs/python/tf/test/assert_equal_graph_def)
56 | asserts two `GraphDef` objects are the same, ignoring versions and ordering of nodes, attrs, and control inputs.
57 |
58 | Note: you can turn a `Graph` object into a `GraphDef` object using `graph.as_graph_def()`.
59 |
60 | Example:
61 | ```python
62 | import tensorflow as tf
63 |
64 | graph_actual = tf.Graph()
65 | with graph_actual.as_default():
66 | x = tf.placeholder(tf.float32, shape=[None, 3])
67 | variable_name = tf.Variable(initial_value=tf.random_normal(shape=[3, 5]))
68 |
69 | graph_expected = tf.Graph()
70 | with graph_expected.as_default():
71 | x = tf.placeholder(tf.float32, shape=[None, 3])
72 | variable_name = tf.Variable(initial_value=tf.random_normal(shape=[3, 5]))
73 |
74 | tf.test.assert_equal_graph_def(graph_actual.as_graph_def(), graph_expected.as_graph_def())
75 | # test will pass.
76 | ```
--------------------------------------------------------------------------------